From ebc3c795ca42a66d9d0debabd50c97ed90a568b4 Mon Sep 17 00:00:00 2001
From: David Rowley <dgrowley@gmail.com>
Date: Sat, 24 Apr 2021 21:21:33 +1200
Subject: [PATCH v2] Speedup NOT IN() with a set of Consts

Commit 50e17ad28 allowed hash tables to be used for IN clauses with a set
of constants.  Here we add the same feature for NOT IN clauses.

Much of the code is shared with the IN implementation; we mostly just need
to check whether the negator of the operator used by a !useOr
ScalarArrayOpExpr is hashable.  Only some small changes are required in
the executor so that it pays attention to useOr and negates the result in
the right places.
---
 src/backend/executor/execExpr.c           |  1 +
 src/backend/executor/execExprInterp.c     | 18 ++++-
 src/backend/optimizer/util/clauses.c      | 76 +++++++++++++++-----
 src/include/executor/execExpr.h           |  1 +
 src/test/regress/expected/expressions.out | 84 +++++++++++++++++++++++
 src/test/regress/sql/expressions.sql      | 30 ++++++++
 6 files changed, 192 insertions(+), 18 deletions(-)

diff --git a/src/backend/executor/execExpr.c b/src/backend/executor/execExpr.c
index 77c9d785d9..1a74609045 100644
--- a/src/backend/executor/execExpr.c
+++ b/src/backend/executor/execExpr.c
@@ -1220,6 +1220,7 @@ ExecInitExprRec(Expr *node, ExprState *state,
 
 					/* And perform the operation */
 					scratch.opcode = EEOP_HASHED_SCALARARRAYOP;
+					scratch.d.hashedscalararrayop.useOr = opexpr->useOr;
 					scratch.d.hashedscalararrayop.finfo = finfo;
 					scratch.d.hashedscalararrayop.fcinfo_data = fcinfo;
 					scratch.d.hashedscalararrayop.fn_addr = finfo->fn_addr;
diff --git a/src/backend/executor/execExprInterp.c b/src/backend/executor/execExprInterp.c
index 094e22d392..7a65680a01 100644
--- a/src/backend/executor/execExprInterp.c
+++ b/src/backend/executor/execExprInterp.c
@@ -3481,6 +3481,7 @@ ExecEvalHashedScalarArrayOp(ExprState *state, ExprEvalStep *op, ExprContext *eco
 {
 	ScalarArrayOpExprHashTable *elements_tab = op->d.hashedscalararrayop.elements_tab;
 	FunctionCallInfo fcinfo = op->d.hashedscalararrayop.fcinfo_data;
+	bool		useOr = op->d.hashedscalararrayop.useOr;
 	bool		strictfunc = op->d.hashedscalararrayop.finfo->fn_strict;
 	Datum		scalar = fcinfo->args[0].value;
 	bool		scalar_isnull = fcinfo->args[0].isnull;
@@ -3584,7 +3585,12 @@ ExecEvalHashedScalarArrayOp(ExprState *state, ExprEvalStep *op, ExprContext *eco
 	/* Check the hash to see if we have a match. */
 	hashfound = NULL != saophash_lookup(elements_tab->hashtab, scalar);
 
-	result = BoolGetDatum(hashfound);
+	/* useOr == true means an IN clause, useOr == false is NOT IN */
+	if (useOr)
+		result = BoolGetDatum(hashfound);
+	else
+		result = BoolGetDatum(!hashfound);
+
 	resultnull = false;
 
 	/*
@@ -3593,7 +3599,7 @@ ExecEvalHashedScalarArrayOp(ExprState *state, ExprEvalStep *op, ExprContext *eco
 	 * hashtable, but instead marked if we found any when building the table
 	 * in has_nulls.
 	 */
-	if (!DatumGetBool(result) && op->d.hashedscalararrayop.has_nulls)
+	if (!hashfound && op->d.hashedscalararrayop.has_nulls)
 	{
 		if (strictfunc)
 		{
@@ -3621,6 +3627,14 @@ ExecEvalHashedScalarArrayOp(ExprState *state, ExprEvalStep *op, ExprContext *eco
 
 			result = op->d.hashedscalararrayop.fn_addr(fcinfo);
 			resultnull = fcinfo->isnull;
+
+			/*
+			 * For NOT IN, the function we just called belongs to the negator
+			 * of the clause's operator, so we must reverse its result to get
+			 * the answer for the NOT IN clause itself.
+			 */
+			if (!useOr)
+				result = !result;
 		}
 	}
 
diff --git a/src/backend/optimizer/util/clauses.c b/src/backend/optimizer/util/clauses.c
index d9ad4efc5e..4f5907a3b5 100644
--- a/src/backend/optimizer/util/clauses.c
+++ b/src/backend/optimizer/util/clauses.c
@@ -2137,27 +2137,71 @@ convert_saop_to_hashed_saop_walker(Node *node, void *context)
 		Oid			lefthashfunc;
 		Oid			righthashfunc;
 
-		if (saop->useOr && arrayarg && IsA(arrayarg, Const) &&
-			!((Const *) arrayarg)->constisnull &&
-			get_op_hash_functions(saop->opno, &lefthashfunc, &righthashfunc) &&
-			lefthashfunc == righthashfunc)
+		if (arrayarg && IsA(arrayarg, Const) &&
+			!((Const *) arrayarg)->constisnull)
 		{
-			Datum		arrdatum = ((Const *) arrayarg)->constvalue;
-			ArrayType  *arr = (ArrayType *) DatumGetPointer(arrdatum);
-			int			nitems;
+			if (saop->useOr)
+			{
+				if (get_op_hash_functions(saop->opno, &lefthashfunc, &righthashfunc) &&
+					lefthashfunc == righthashfunc)
+				{
+					Datum		arrdatum = ((Const *) arrayarg)->constvalue;
+					ArrayType  *arr = (ArrayType *) DatumGetPointer(arrdatum);
+					int			nitems;
 
-			/*
-			 * Only fill in the hash functions if the array looks large enough
-			 * for it to be worth hashing instead of doing a linear search.
-			 */
-			nitems = ArrayGetNItems(ARR_NDIM(arr), ARR_DIMS(arr));
+					/*
+					 * Only fill in the hash functions if the array looks large enough
+					 * for it to be worth hashing instead of doing a linear search.
+					 */
+					nitems = ArrayGetNItems(ARR_NDIM(arr), ARR_DIMS(arr));
 
-			if (nitems >= MIN_ARRAY_SIZE_FOR_HASHED_SAOP)
+					if (nitems >= MIN_ARRAY_SIZE_FOR_HASHED_SAOP)
+					{
+						/* Looks good. Fill in the hash functions */
+						saop->hashfuncid = lefthashfunc;
+					}
+					return true;
+				}
+			}
+			else /* !saop->useOr */
 			{
-				/* Looks good. Fill in the hash functions */
-				saop->hashfuncid = lefthashfunc;
+				Oid		negator = get_negator(saop->opno);
+
+				/*
+				 * Check if this is a NOT IN using an operator whose negator
+				 * is hashable.  If so we can still build a hash table and
+				 * just ensure the lookup items are not in the hash table.
+				 */
+				if (OidIsValid(negator) &&
+					get_op_hash_functions(negator, &lefthashfunc, &righthashfunc) &&
+					lefthashfunc == righthashfunc)
+				{
+					Datum		arrdatum = ((Const *) arrayarg)->constvalue;
+					ArrayType  *arr = (ArrayType *) DatumGetPointer(arrdatum);
+					int			nitems;
+
+					/*
+					 * Only fill in the hash functions if the array looks large enough
+					 * for it to be worth hashing instead of doing a linear search.
+					 */
+					nitems = ArrayGetNItems(ARR_NDIM(arr), ARR_DIMS(arr));
+
+					if (nitems >= MIN_ARRAY_SIZE_FOR_HASHED_SAOP)
+					{
+						/* Looks good. Fill in the hash functions */
+						saop->hashfuncid = lefthashfunc;
+
+						/*
+						 * XXX is it safe enough just to set the opfuncid to
+						 * the negator's function?  We need to leave the opno
+						 * in place so that EXPLAIN shows the correct
+						 * operator.
+						 */
+						saop->opfuncid = get_opcode(negator);
+					}
+					return true;
+				}
 			}
-			return true;
 		}
 	}
 
diff --git a/src/include/executor/execExpr.h b/src/include/executor/execExpr.h
index 785600d04d..c68668a7a0 100644
--- a/src/include/executor/execExpr.h
+++ b/src/include/executor/execExpr.h
@@ -574,6 +574,7 @@ typedef struct ExprEvalStep
 		struct
 		{
 			bool		has_nulls;
+			bool		useOr;	/* use OR or AND semantics? */
 			struct ScalarArrayOpExprHashTable *elements_tab;
 			FmgrInfo   *finfo;	/* function's lookup data */
 			FunctionCallInfo fcinfo_data;	/* arguments etc */
diff --git a/src/test/regress/expected/expressions.out b/src/test/regress/expected/expressions.out
index 5944dfd5e1..84159cb21f 100644
--- a/src/test/regress/expected/expressions.out
+++ b/src/test/regress/expected/expressions.out
@@ -216,6 +216,55 @@ select return_text_input('a') in ('a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', '
  t
 (1 row)
 
+-- NOT IN
+select return_int_input(1) not in (10, 9, 2, 8, 3, 7, 4, 6, 5, 1);
+ ?column? 
+----------
+ f
+(1 row)
+
+select return_int_input(1) not in (10, 9, 2, 8, 3, 7, 4, 6, 5, 0);
+ ?column? 
+----------
+ t
+(1 row)
+
+select return_int_input(1) not in (10, 9, 2, 8, 3, 7, 4, 6, 5, 2, null);
+ ?column? 
+----------
+ 
+(1 row)
+
+select return_int_input(1) not in (10, 9, 2, 8, 3, 7, 4, 6, 5, 1, null);
+ ?column? 
+----------
+ f
+(1 row)
+
+select return_int_input(1) not in (null, null, null, null, null, null, null, null, null, null, null);
+ ?column? 
+----------
+ 
+(1 row)
+
+select return_int_input(null::int) not in (10, 9, 2, 8, 3, 7, 4, 6, 5, 1);
+ ?column? 
+----------
+ 
+(1 row)
+
+select return_int_input(null::int) not in (10, 9, 2, 8, 3, 7, 4, 6, 5, null);
+ ?column? 
+----------
+ 
+(1 row)
+
+select return_text_input('a') not in ('a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j');
+ ?column? 
+----------
+ f
+(1 row)
+
 rollback;
 -- Test with non-strict equality function.
 -- We need to create our own type for this.
@@ -242,6 +291,11 @@ begin
   end if;
 end;
 $$ language plpgsql immutable;
+create function myintne(myint, myint) returns bool as $$
+begin
+  return not myinteq($1, $2);
+end;
+$$ language plpgsql immutable;
 create operator = (
   leftarg    = myint,
   rightarg   = myint,
@@ -252,6 +306,16 @@ create operator = (
   join       = eqjoinsel,
   merges
 );
+create operator <> (
+  leftarg    = myint,
+  rightarg   = myint,
+  commutator = <>,
+  negator    = =,
+  procedure  = myintne,
+  restrict   = eqsel,
+  join       = eqjoinsel,
+  merges
+);
 create operator class myint_ops
 default for type myint using hash as
   operator    1   =  (myint, myint),
@@ -266,6 +330,16 @@ select * from inttest where a in (1::myint,2::myint,3::myint,4::myint,5::myint,6
  
 (2 rows)
 
+select * from inttest where a not in (1::myint,2::myint,3::myint,4::myint,5::myint,6::myint,7::myint,8::myint,9::myint, null);
+ a 
+---
+(0 rows)
+
+select * from inttest where a not in (0::myint,2::myint,3::myint,4::myint,5::myint,6::myint,7::myint,8::myint,9::myint, null);
+ a 
+---
+(0 rows)
+
 -- ensure the result matched with the non-hashed version.  We simply remove
 -- some array elements so that we don't reach the hashing threshold.
 select * from inttest where a in (1::myint,2::myint,3::myint,4::myint,5::myint, null);
@@ -275,4 +349,14 @@ select * from inttest where a in (1::myint,2::myint,3::myint,4::myint,5::myint,
  
 (2 rows)
 
+select * from inttest where a not in (1::myint,2::myint,3::myint,4::myint,5::myint, null);
+ a 
+---
+(0 rows)
+
+select * from inttest where a not in (0::myint,2::myint,3::myint,4::myint,5::myint, null);
+ a 
+---
+(0 rows)
+
 rollback;
diff --git a/src/test/regress/sql/expressions.sql b/src/test/regress/sql/expressions.sql
index b3fd1b5ecb..bf30f41505 100644
--- a/src/test/regress/sql/expressions.sql
+++ b/src/test/regress/sql/expressions.sql
@@ -93,6 +93,15 @@ select return_int_input(1) in (10, 9, 2, 8, 3, 7, 4, 6, 5, 1, null);
 select return_int_input(null::int) in (10, 9, 2, 8, 3, 7, 4, 6, 5, 1);
 select return_int_input(null::int) in (10, 9, 2, 8, 3, 7, 4, 6, 5, null);
 select return_text_input('a') in ('a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j');
+-- NOT IN
+select return_int_input(1) not in (10, 9, 2, 8, 3, 7, 4, 6, 5, 1);
+select return_int_input(1) not in (10, 9, 2, 8, 3, 7, 4, 6, 5, 0);
+select return_int_input(1) not in (10, 9, 2, 8, 3, 7, 4, 6, 5, 2, null);
+select return_int_input(1) not in (10, 9, 2, 8, 3, 7, 4, 6, 5, 1, null);
+select return_int_input(1) not in (null, null, null, null, null, null, null, null, null, null, null);
+select return_int_input(null::int) not in (10, 9, 2, 8, 3, 7, 4, 6, 5, 1);
+select return_int_input(null::int) not in (10, 9, 2, 8, 3, 7, 4, 6, 5, null);
+select return_text_input('a') not in ('a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j');
 
 rollback;
 
@@ -124,6 +133,12 @@ begin
 end;
 $$ language plpgsql immutable;
 
+create function myintne(myint, myint) returns bool as $$
+begin
+  return not myinteq($1, $2);
+end;
+$$ language plpgsql immutable;
+
 create operator = (
   leftarg    = myint,
   rightarg   = myint,
@@ -135,6 +150,17 @@ create operator = (
   merges
 );
 
+create operator <> (
+  leftarg    = myint,
+  rightarg   = myint,
+  commutator = <>,
+  negator    = =,
+  procedure  = myintne,
+  restrict   = eqsel,
+  join       = eqjoinsel,
+  merges
+);
+
 create operator class myint_ops
 default for type myint using hash as
   operator    1   =  (myint, myint),
@@ -145,8 +171,12 @@ insert into inttest values(1::myint),(null);
 
 -- try an array with enough elements to cause hashing
 select * from inttest where a in (1::myint,2::myint,3::myint,4::myint,5::myint,6::myint,7::myint,8::myint,9::myint, null);
+select * from inttest where a not in (1::myint,2::myint,3::myint,4::myint,5::myint,6::myint,7::myint,8::myint,9::myint, null);
+select * from inttest where a not in (0::myint,2::myint,3::myint,4::myint,5::myint,6::myint,7::myint,8::myint,9::myint, null);
 -- ensure the result matched with the non-hashed version.  We simply remove
 -- some array elements so that we don't reach the hashing threshold.
 select * from inttest where a in (1::myint,2::myint,3::myint,4::myint,5::myint, null);
+select * from inttest where a not in (1::myint,2::myint,3::myint,4::myint,5::myint, null);
+select * from inttest where a not in (0::myint,2::myint,3::myint,4::myint,5::myint, null);
 
 rollback;
-- 
2.21.0.windows.1

