From 7132c08b48dc207ea5cd4e94b6be4f893527ec45 Mon Sep 17 00:00:00 2001
From: David Rowley <dgrowley@gmail.com>
Date: Wed, 16 Oct 2024 13:38:02 +1300
Subject: [PATCH v1] Don't store intermediate hash values in
 ExprState->resvalue

adf97c156 added support for hashing to ExprState and changed Hash Join
to use that instead of manually extracting Datums from tuples and
hashing them one column at a time.

When hashing multiple columns or expressions, the code added in that
commit stored the intermediate hash value in the ExprState's resvalue
field.  That was a mistake, as other steps may be injected into the
ExprState between the hashing steps, and those steps may read or
overwrite the stored intermediate hash value.  EEOP_PARAM_SET is an
example of such a step.

Here we fix this by adding a new dedicated field for storing
intermediate hash values and adjusting the code so that all but the
final hashing step store their result in the intermediate field.

Reported-by: Andres Freund
---
 src/backend/executor/execExpr.c       | 35 ++++++++++++++++--
 src/backend/executor/execExprInterp.c | 16 ++++----
 src/backend/jit/llvm/llvmjit_expr.c   | 16 ++++++--
 src/include/executor/execExpr.h       |  1 +
 src/test/regress/expected/join.out    | 53 +++++++++++++++++++++++++++
 src/test/regress/sql/join.sql         | 19 ++++++++++
 6 files changed, 125 insertions(+), 15 deletions(-)

diff --git a/src/backend/executor/execExpr.c b/src/backend/executor/execExpr.c
index c8077aa57b..a343d0bc6a 100644
--- a/src/backend/executor/execExpr.c
+++ b/src/backend/executor/execExpr.c
@@ -3996,6 +3996,7 @@ ExecBuildHash32Expr(TupleDesc desc, const TupleTableSlotOps *ops,
 {
 	ExprState  *state = makeNode(ExprState);
 	ExprEvalStep scratch = {0};
+	NullableDatum *iresult = NULL;
 	List	   *adjust_jumps = NIL;
 	ListCell   *lc;
 	ListCell   *lc2;
@@ -4009,6 +4010,14 @@ ExecBuildHash32Expr(TupleDesc desc, const TupleTableSlotOps *ops,
 	/* Insert setup steps as needed. */
 	ExecCreateExprSetupSteps(state, (Node *) hash_exprs);
 
+	/*
+	 * When hashing more than 1 expression or if we have an init value, we
+	 * need somewhere to store the intermediate hash value so that it's
+	 * available to be combined with the result of subsequent hashing.
+	 */
+	if (list_length(hash_exprs) > 1 || init_value != 0)
+		iresult = palloc(sizeof(NullableDatum));
+
 	if (init_value == 0)
 	{
 		/*
@@ -4024,8 +4033,8 @@ ExecBuildHash32Expr(TupleDesc desc, const TupleTableSlotOps *ops,
 		/* Set up operation to set the initial value. */
 		scratch.opcode = EEOP_HASHDATUM_SET_INITVAL;
 		scratch.d.hashdatum_initvalue.init_value = UInt32GetDatum(init_value);
-		scratch.resvalue = &state->resvalue;
-		scratch.resnull = &state->resnull;
+		scratch.resvalue = &iresult->value;
+		scratch.resnull = &iresult->isnull;
 
 		ExprEvalPushStep(state, &scratch);
 
@@ -4063,8 +4072,26 @@ ExecBuildHash32Expr(TupleDesc desc, const TupleTableSlotOps *ops,
 						&fcinfo->args[0].value,
 						&fcinfo->args[0].isnull);
 
-		scratch.resvalue = &state->resvalue;
-		scratch.resnull = &state->resnull;
+		if (i == list_length(hash_exprs) - 1)
+		{
+			/* the result for hashing the final expr is stored in the state */
+			scratch.resvalue = &state->resvalue;
+			scratch.resnull = &state->resnull;
+		}
+		else
+		{
+			Assert(iresult != NULL);
+
+			/* intermediate values are stored in an intermediate result */
+			scratch.resvalue = &iresult->value;
+			scratch.resnull = &iresult->isnull;
+		}
+
+		/*
+		 * NEXT32 opcodes need to look at the intermediate result.  We might
+		 * as well just set this for all ops.  FIRSTs won't look at it.
+		 */
+		scratch.d.hashdatum.iresult = iresult;
 
 		/* Initialize function call parameter structure too */
 		InitFunctionCallInfoData(*fcinfo, finfo, 1, inputcollid, NULL, NULL);
diff --git a/src/backend/executor/execExprInterp.c b/src/backend/executor/execExprInterp.c
index 9fd988cc99..6a7f18f6de 100644
--- a/src/backend/executor/execExprInterp.c
+++ b/src/backend/executor/execExprInterp.c
@@ -1600,10 +1600,11 @@ ExecInterpExpr(ExprState *state, ExprContext *econtext, bool *isnull)
 		EEO_CASE(EEOP_HASHDATUM_NEXT32)
 		{
 			FunctionCallInfo fcinfo = op->d.hashdatum.fcinfo_data;
-			uint32		existing_hash = DatumGetUInt32(*op->resvalue);
+			uint32		existinghash;
 
+			existinghash = DatumGetUInt32(op->d.hashdatum.iresult->value);
 			/* combine successive hash values by rotating */
-			existing_hash = pg_rotate_left32(existing_hash, 1);
+			existinghash = pg_rotate_left32(existinghash, 1);
 
 			/* leave the hash value alone on NULL inputs */
 			if (!fcinfo->args[0].isnull)
@@ -1612,10 +1613,10 @@ ExecInterpExpr(ExprState *state, ExprContext *econtext, bool *isnull)
 
 				/* execute hash func and combine with previous hash value */
 				hashvalue = DatumGetUInt32(op->d.hashdatum.fn_addr(fcinfo));
-				existing_hash = existing_hash ^ hashvalue;
+				existinghash = existinghash ^ hashvalue;
 			}
 
-			*op->resvalue = UInt32GetDatum(existing_hash);
+			*op->resvalue = UInt32GetDatum(existinghash);
 			*op->resnull = false;
 
 			EEO_NEXT();
@@ -1638,15 +1639,16 @@ ExecInterpExpr(ExprState *state, ExprContext *econtext, bool *isnull)
 			}
 			else
 			{
-				uint32		existing_hash = DatumGetUInt32(*op->resvalue);
+				uint32		existinghash;
 				uint32		hashvalue;
 
+				existinghash = DatumGetUInt32(op->d.hashdatum.iresult->value);
 				/* combine successive hash values by rotating */
-				existing_hash = pg_rotate_left32(existing_hash, 1);
+				existinghash = pg_rotate_left32(existinghash, 1);
 
 				/* execute hash func and combine with previous hash value */
 				hashvalue = DatumGetUInt32(op->d.hashdatum.fn_addr(fcinfo));
-				*op->resvalue = UInt32GetDatum(existing_hash ^ hashvalue);
+				*op->resvalue = UInt32GetDatum(existinghash ^ hashvalue);
 				*op->resnull = false;
 			}
 
diff --git a/src/backend/jit/llvm/llvmjit_expr.c b/src/backend/jit/llvm/llvmjit_expr.c
index 48ccdb942a..d03027331f 100644
--- a/src/backend/jit/llvm/llvmjit_expr.c
+++ b/src/backend/jit/llvm/llvmjit_expr.c
@@ -1940,12 +1940,16 @@ llvm_compile_expr(ExprState *state)
 					{
 						LLVMValueRef v_tmp1;
 						LLVMValueRef v_tmp2;
+						LLVMValueRef tmp;
+
+						tmp = l_ptr_const(&op->d.hashdatum.iresult->value,
+										  l_ptr(TypeSizeT));
 
 						/*
 						 * Fetch the previously hashed value from where the
-						 * EEOP_HASHDATUM_FIRST operation stored it.
+						 * previous hash operation stored it.
 						 */
-						v_prevhash = l_load(b, TypeSizeT, v_resvaluep,
+						v_prevhash = l_load(b, TypeSizeT, tmp,
 											"prevhash");
 
 						/*
@@ -2062,12 +2066,16 @@ llvm_compile_expr(ExprState *state)
 					{
 						LLVMValueRef v_tmp1;
 						LLVMValueRef v_tmp2;
+						LLVMValueRef tmp;
+
+						tmp = l_ptr_const(&op->d.hashdatum.iresult->value,
+										  l_ptr(TypeSizeT));
 
 						/*
 						 * Fetch the previously hashed value from where the
-						 * EEOP_HASHDATUM_FIRST_STRICT operation stored it.
+						 * previous hash operation stored it.
 						 */
-						v_prevhash = l_load(b, TypeSizeT, v_resvaluep,
+						v_prevhash = l_load(b, TypeSizeT, tmp,
 											"prevhash");
 
 						/*
diff --git a/src/include/executor/execExpr.h b/src/include/executor/execExpr.h
index eec0aa699e..cd97dfa062 100644
--- a/src/include/executor/execExpr.h
+++ b/src/include/executor/execExpr.h
@@ -580,6 +580,7 @@ typedef struct ExprEvalStep
 			/* faster to access without additional indirection: */
 			PGFunction	fn_addr;	/* actual call address */
 			int			jumpdone;	/* jump here on null */
+			NullableDatum *iresult; /* intermediate hash result */
 		}			hashdatum;
 
 		/* for EEOP_CONVERT_ROWTYPE */
diff --git a/src/test/regress/expected/join.out b/src/test/regress/expected/join.out
index 756c2e2496..23132c846a 100644
--- a/src/test/regress/expected/join.out
+++ b/src/test/regress/expected/join.out
@@ -2358,6 +2358,59 @@ where b.f1 = t.thousand and a.f1 = b.f1 and (a.f1+b.f1+999) = t.tenthous;
 ----+----+----------+----------
 (0 rows)
 
+-- test hash joins with multiple hash keys and subplans
+-- first ensure we get a hash join with multiple hash keys
+explain (costs off)
+select t1.unique1,t2.unique1 from tenk1 t1
+inner join tenk1 t2 on t1.two = t2.two
+  and t1.unique1 = (select min(unique1) from tenk1
+                    where t2.unique1=unique1)
+where t1.unique1 < 10 and t2.unique1 < 10
+order by t1.unique1;
+                                             QUERY PLAN                                             
+----------------------------------------------------------------------------------------------------
+ Sort
+   Sort Key: t1.unique1
+   ->  Hash Join
+         Hash Cond: ((t1.two = t2.two) AND (t1.unique1 = (SubPlan 2)))
+         ->  Bitmap Heap Scan on tenk1 t1
+               Recheck Cond: (unique1 < 10)
+               ->  Bitmap Index Scan on tenk1_unique1
+                     Index Cond: (unique1 < 10)
+         ->  Hash
+               ->  Bitmap Heap Scan on tenk1 t2
+                     Recheck Cond: (unique1 < 10)
+                     ->  Bitmap Index Scan on tenk1_unique1
+                           Index Cond: (unique1 < 10)
+               SubPlan 2
+                 ->  Result
+                       InitPlan 1
+                         ->  Limit
+                               ->  Index Only Scan using tenk1_unique1 on tenk1
+                                     Index Cond: ((unique1 IS NOT NULL) AND (unique1 = t2.unique1))
+(19 rows)
+
+-- ensure we get the expected result
+select t1.unique1,t2.unique1 from tenk1 t1
+inner join tenk1 t2 on t1.two = t2.two
+  and t1.unique1 = (select min(unique1) from tenk1
+                    where t2.unique1=unique1)
+where t1.unique1 < 10 and t2.unique1 < 10
+order by t1.unique1;
+ unique1 | unique1 
+---------+---------
+       0 |       0
+       1 |       1
+       2 |       2
+       3 |       3
+       4 |       4
+       5 |       5
+       6 |       6
+       7 |       7
+       8 |       8
+       9 |       9
+(10 rows)
+
 --
 -- checks for correct handling of quals in multiway outer joins
 --
diff --git a/src/test/regress/sql/join.sql b/src/test/regress/sql/join.sql
index 0c65e5af4b..7e958fbc33 100644
--- a/src/test/regress/sql/join.sql
+++ b/src/test/regress/sql/join.sql
@@ -441,6 +441,25 @@ select a.f1, b.f1, t.thousand, t.tenthous from
   (select sum(f1) as f1 from int4_tbl i4b) b
 where b.f1 = t.thousand and a.f1 = b.f1 and (a.f1+b.f1+999) = t.tenthous;
 
+-- test hash joins with multiple hash keys and subplans
+
+-- first ensure we get a hash join with multiple hash keys
+explain (costs off)
+select t1.unique1,t2.unique1 from tenk1 t1
+inner join tenk1 t2 on t1.two = t2.two
+  and t1.unique1 = (select min(unique1) from tenk1
+                    where t2.unique1=unique1)
+where t1.unique1 < 10 and t2.unique1 < 10
+order by t1.unique1;
+
+-- ensure we get the expected result
+select t1.unique1,t2.unique1 from tenk1 t1
+inner join tenk1 t2 on t1.two = t2.two
+  and t1.unique1 = (select min(unique1) from tenk1
+                    where t2.unique1=unique1)
+where t1.unique1 < 10 and t2.unique1 < 10
+order by t1.unique1;
+
 --
 -- checks for correct handling of quals in multiway outer joins
 --
-- 
2.40.1.windows.1

