From f029ec8ef882c3a2e1383ae82adba7d1f0576784 Mon Sep 17 00:00:00 2001
From: "dgrowley@gmail.com" <dgrowley@gmail.com>
Date: Tue, 13 Apr 2021 00:10:43 +1200
Subject: [PATCH v2] Cleanup some aggregate code in the executor

Here we alter the code that calls build_pertrans_for_aggref() so that the
function no longer needs to special-case whether it's dealing with an
aggtransfn or an aggcombinefn.  This allows us to reuse the
build_aggregate_transfn_expr() function and just get rid of the
build_aggregate_combinefn_expr() completely.

All of the special case code that was in build_pertrans_for_aggref() has
been moved up to the calling functions.

This saves about a dozen lines of code in nodeAgg.c and a few dozen more
in parse_agg.c

Also, rename a few variables in nodeAgg.c to try to make it more clear
that we're working with either a aggtransfn or an aggcombinefn.  Some of
the old names would have you believe that we were always working with an
aggtransfn.
---
 src/backend/executor/nodeAgg.c | 234 ++++++++++++++++-----------------
 src/backend/parser/parse_agg.c |  34 +----
 src/include/parser/parse_agg.h |   5 -
 3 files changed, 118 insertions(+), 155 deletions(-)

diff --git a/src/backend/executor/nodeAgg.c b/src/backend/executor/nodeAgg.c
index 8440a76fbd..ee842fa710 100644
--- a/src/backend/executor/nodeAgg.c
+++ b/src/backend/executor/nodeAgg.c
@@ -461,10 +461,11 @@ static void hashagg_tapeinfo_release(HashTapeInfo *tapeinfo, int tapenum);
 static Datum GetAggInitVal(Datum textInitVal, Oid transtype);
 static void build_pertrans_for_aggref(AggStatePerTrans pertrans,
 									  AggState *aggstate, EState *estate,
-									  Aggref *aggref, Oid aggtransfn, Oid aggtranstype,
-									  Oid aggserialfn, Oid aggdeserialfn,
-									  Datum initValue, bool initValueIsNull,
-									  Oid *inputTypes, int numArguments);
+									  Aggref *aggref, Oid transfn_oid,
+									  Oid aggtranstype, Oid aggserialfn,
+									  Oid aggdeserialfn, Datum initValue,
+									  bool initValueIsNull, Oid *inputTypes,
+									  int numArguments);
 
 
 /*
@@ -3724,8 +3725,8 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
 		Aggref	   *aggref = lfirst(l);
 		AggStatePerAgg peragg;
 		AggStatePerTrans pertrans;
-		Oid			inputTypes[FUNC_MAX_ARGS];
-		int			numArguments;
+		Oid			aggTransFnInputTypes[FUNC_MAX_ARGS];
+		int			numAggTransFnArgs;
 		int			numDirectArgs;
 		HeapTuple	aggTuple;
 		Form_pg_aggregate aggform;
@@ -3859,14 +3860,15 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
 		 * could be different from the agg's declared input types, when the
 		 * agg accepts ANY or a polymorphic type.
 		 */
-		numArguments = get_aggregate_argtypes(aggref, inputTypes);
+		numAggTransFnArgs = get_aggregate_argtypes(aggref,
+												   aggTransFnInputTypes);
 
 		/* Count the "direct" arguments, if any */
 		numDirectArgs = list_length(aggref->aggdirectargs);
 
 		/* Detect how many arguments to pass to the finalfn */
 		if (aggform->aggfinalextra)
-			peragg->numFinalArgs = numArguments + 1;
+			peragg->numFinalArgs = numAggTransFnArgs + 1;
 		else
 			peragg->numFinalArgs = numDirectArgs + 1;
 
@@ -3880,7 +3882,7 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
 		 */
 		if (OidIsValid(finalfn_oid))
 		{
-			build_aggregate_finalfn_expr(inputTypes,
+			build_aggregate_finalfn_expr(aggTransFnInputTypes,
 										 peragg->numFinalArgs,
 										 aggtranstype,
 										 aggref->aggtype,
@@ -3911,7 +3913,7 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
 			/*
 			 * If this aggregation is performing state combines, then instead
 			 * of using the transition function, we'll use the combine
-			 * function
+			 * function.
 			 */
 			if (DO_AGGSPLIT_COMBINE(aggstate->aggsplit))
 			{
@@ -3924,8 +3926,7 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
 			else
 				transfn_oid = aggform->aggtransfn;
 
-			aclresult = pg_proc_aclcheck(transfn_oid, aggOwner,
-										 ACL_EXECUTE);
+			aclresult = pg_proc_aclcheck(transfn_oid, aggOwner, ACL_EXECUTE);
 			if (aclresult != ACLCHECK_OK)
 				aclcheck_error(aclresult, OBJECT_FUNCTION,
 							   get_func_name(transfn_oid));
@@ -3943,11 +3944,72 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
 			else
 				initValue = GetAggInitVal(textInitVal, aggtranstype);
 
-			build_pertrans_for_aggref(pertrans, aggstate, estate,
-									  aggref, transfn_oid, aggtranstype,
-									  serialfn_oid, deserialfn_oid,
-									  initValue, initValueIsNull,
-									  inputTypes, numArguments);
+			if (DO_AGGSPLIT_COMBINE(aggstate->aggsplit))
+			{
+				Oid			combineFnInputTypes[] = { aggtranstype,
+													  aggtranstype };
+
+				/*
+				 * When combining there's only one input, the to-be-combined
+				 * added transition value from below (this node's transition
+				 * value is counted separately).
+				 */
+				pertrans->numTransInputs = 1;
+
+				/* aggcombinedfn always has two arguments of aggtranstype */
+				build_pertrans_for_aggref(pertrans, aggstate, estate,
+										  aggref, transfn_oid, aggtranstype,
+										  serialfn_oid, deserialfn_oid,
+										  initValue, initValueIsNull,
+										  combineFnInputTypes, 2);
+
+				/*
+				 * Ensure that a combine function to combine INTERNAL states
+				 * is not strict. This should have been checked during CREATE
+				 * AGGREGATE, but the strict property could have been changed
+				 * since then.
+				 */
+				if (pertrans->transfn.fn_strict && aggtranstype == INTERNALOID)
+					ereport(ERROR,
+					(errcode(ERRCODE_INVALID_FUNCTION_DEFINITION),
+						errmsg("combine function with transition type %s must not be declared STRICT",
+							format_type_be(aggtranstype))));
+			}
+			else
+			{
+				/* Detect how many arguments to pass to the transfn */
+				if (AGGKIND_IS_ORDERED_SET(aggref->aggkind))
+					pertrans->numTransInputs = list_length(aggref->args);
+				else
+					pertrans->numTransInputs = numAggTransFnArgs;
+
+				build_pertrans_for_aggref(pertrans, aggstate, estate,
+										  aggref, transfn_oid, aggtranstype,
+										  serialfn_oid, deserialfn_oid,
+										  initValue, initValueIsNull,
+										  aggTransFnInputTypes,
+										  numAggTransFnArgs);
+
+				/*
+				 * If the transfn is strict and the initval is NULL, make sure
+				 * input type and transtype are the same (or at least
+				 * binary-compatible), so that it's OK to use the first
+				 * aggregated input value as the initial transValue.  This
+				 * should have been checked at agg definition time, but we
+				 * must check again in case the transfn's strictness property
+				 * has been changed.
+				 */
+				if (pertrans->transfn.fn_strict && pertrans->initValueIsNull)
+				{
+					if (numAggTransFnArgs <= numDirectArgs ||
+						!IsBinaryCoercible(aggTransFnInputTypes[numDirectArgs],
+										   aggtranstype))
+						ereport(ERROR,
+						(errcode(ERRCODE_INVALID_FUNCTION_DEFINITION),
+							errmsg("aggregate %u needs to have compatible input type and transition type",
+								aggref->aggfnoid)));
+				}
+			}
 		}
 		else
 			pertrans->aggshared = true;
@@ -4039,20 +4101,24 @@ ExecInitAgg(Agg *node, EState *estate, int eflags)
  * Build the state needed to calculate a state value for an aggregate.
  *
  * This initializes all the fields in 'pertrans'. 'aggref' is the aggregate
- * to initialize the state for. 'aggtransfn', 'aggtranstype', and the rest
+ * to initialize the state for. 'transfn_oid', 'aggtranstype', and the rest
  * of the arguments could be calculated from 'aggref', but the caller has
  * calculated them already, so might as well pass them.
+ *
+ * 'transfn_oid' may be either the Oid of the aggtransfn or the aggcombinefn.
  */
 static void
 build_pertrans_for_aggref(AggStatePerTrans pertrans,
 						  AggState *aggstate, EState *estate,
 						  Aggref *aggref,
-						  Oid aggtransfn, Oid aggtranstype,
+						  Oid transfn_oid, Oid aggtranstype,
 						  Oid aggserialfn, Oid aggdeserialfn,
 						  Datum initValue, bool initValueIsNull,
 						  Oid *inputTypes, int numArguments)
 {
 	int			numGroupingSets = Max(aggstate->maxsets, 1);
+	Expr	   *transfnexpr;
+	int			numTransArgs;
 	Expr	   *serialfnexpr = NULL;
 	Expr	   *deserialfnexpr = NULL;
 	ListCell   *lc;
@@ -4067,7 +4133,7 @@ build_pertrans_for_aggref(AggStatePerTrans pertrans,
 	pertrans->aggref = aggref;
 	pertrans->aggshared = false;
 	pertrans->aggCollation = aggref->inputcollid;
-	pertrans->transfn_oid = aggtransfn;
+	pertrans->transfn_oid = transfn_oid;
 	pertrans->serialfn_oid = aggserialfn;
 	pertrans->deserialfn_oid = aggdeserialfn;
 	pertrans->initValue = initValue;
@@ -4081,111 +4147,34 @@ build_pertrans_for_aggref(AggStatePerTrans pertrans,
 
 	pertrans->aggtranstype = aggtranstype;
 
+	/* account for the current transition state */
+	numTransArgs = pertrans->numTransInputs + 1;
+
 	/*
-	 * When combining states, we have no use at all for the aggregate
-	 * function's transfn. Instead we use the combinefn.  In this case, the
-	 * transfn and transfn_oid fields of pertrans refer to the combine
-	 * function rather than the transition function.
+	 * Set up infrastructure for calling the transfn.  Note that invtrans is
+	 * not needed here.
 	 */
-	if (DO_AGGSPLIT_COMBINE(aggstate->aggsplit))
-	{
-		Expr	   *combinefnexpr;
-		size_t		numTransArgs;
-
-		/*
-		 * When combining there's only one input, the to-be-combined added
-		 * transition value from below (this node's transition value is
-		 * counted separately).
-		 */
-		pertrans->numTransInputs = 1;
-
-		/* account for the current transition state */
-		numTransArgs = pertrans->numTransInputs + 1;
-
-		build_aggregate_combinefn_expr(aggtranstype,
-									   aggref->inputcollid,
-									   aggtransfn,
-									   &combinefnexpr);
-		fmgr_info(aggtransfn, &pertrans->transfn);
-		fmgr_info_set_expr((Node *) combinefnexpr, &pertrans->transfn);
-
-		pertrans->transfn_fcinfo =
-			(FunctionCallInfo) palloc(SizeForFunctionCallInfo(2));
-		InitFunctionCallInfoData(*pertrans->transfn_fcinfo,
-								 &pertrans->transfn,
-								 numTransArgs,
-								 pertrans->aggCollation,
-								 (void *) aggstate, NULL);
-
-		/*
-		 * Ensure that a combine function to combine INTERNAL states is not
-		 * strict. This should have been checked during CREATE AGGREGATE, but
-		 * the strict property could have been changed since then.
-		 */
-		if (pertrans->transfn.fn_strict && aggtranstype == INTERNALOID)
-			ereport(ERROR,
-					(errcode(ERRCODE_INVALID_FUNCTION_DEFINITION),
-					 errmsg("combine function with transition type %s must not be declared STRICT",
-							format_type_be(aggtranstype))));
-	}
-	else
-	{
-		Expr	   *transfnexpr;
-		size_t		numTransArgs;
-
-		/* Detect how many arguments to pass to the transfn */
-		if (AGGKIND_IS_ORDERED_SET(aggref->aggkind))
-			pertrans->numTransInputs = numInputs;
-		else
-			pertrans->numTransInputs = numArguments;
+	build_aggregate_transfn_expr(inputTypes,
+								 numArguments,
+								 numDirectArgs,
+								 aggref->aggvariadic,
+								 aggtranstype,
+								 aggref->inputcollid,
+								 transfn_oid,
+								 InvalidOid,
+								 &transfnexpr,
+								 NULL);
 
-		/* account for the current transition state */
-		numTransArgs = pertrans->numTransInputs + 1;
+	fmgr_info(transfn_oid, &pertrans->transfn);
+	fmgr_info_set_expr((Node *) transfnexpr, &pertrans->transfn);
 
-		/*
-		 * Set up infrastructure for calling the transfn.  Note that
-		 * invtransfn is not needed here.
-		 */
-		build_aggregate_transfn_expr(inputTypes,
-									 numArguments,
-									 numDirectArgs,
-									 aggref->aggvariadic,
-									 aggtranstype,
-									 aggref->inputcollid,
-									 aggtransfn,
-									 InvalidOid,
-									 &transfnexpr,
-									 NULL);
-		fmgr_info(aggtransfn, &pertrans->transfn);
-		fmgr_info_set_expr((Node *) transfnexpr, &pertrans->transfn);
-
-		pertrans->transfn_fcinfo =
-			(FunctionCallInfo) palloc(SizeForFunctionCallInfo(numTransArgs));
-		InitFunctionCallInfoData(*pertrans->transfn_fcinfo,
-								 &pertrans->transfn,
-								 numTransArgs,
-								 pertrans->aggCollation,
-								 (void *) aggstate, NULL);
-
-		/*
-		 * If the transfn is strict and the initval is NULL, make sure input
-		 * type and transtype are the same (or at least binary-compatible), so
-		 * that it's OK to use the first aggregated input value as the initial
-		 * transValue.  This should have been checked at agg definition time,
-		 * but we must check again in case the transfn's strictness property
-		 * has been changed.
-		 */
-		if (pertrans->transfn.fn_strict && pertrans->initValueIsNull)
-		{
-			if (numArguments <= numDirectArgs ||
-				!IsBinaryCoercible(inputTypes[numDirectArgs],
-								   aggtranstype))
-				ereport(ERROR,
-						(errcode(ERRCODE_INVALID_FUNCTION_DEFINITION),
-						 errmsg("aggregate %u needs to have compatible input type and transition type",
-								aggref->aggfnoid)));
-		}
-	}
+	pertrans->transfn_fcinfo =
+		(FunctionCallInfo) palloc(SizeForFunctionCallInfo(numTransArgs));
+	InitFunctionCallInfoData(*pertrans->transfn_fcinfo,
+							 &pertrans->transfn,
+							 numTransArgs,
+							 pertrans->aggCollation,
+							 (void *) aggstate, NULL);
 
 	/* get info about the state value's datatype */
 	get_typlenbyval(aggtranstype,
@@ -4276,6 +4265,9 @@ build_pertrans_for_aggref(AggStatePerTrans pertrans,
 		 */
 		Assert(aggstate->aggstrategy != AGG_HASHED && aggstate->aggstrategy != AGG_MIXED);
 
+		/* ORDER BY aggregates are not supported with partial aggregation */
+		Assert(!DO_AGGSPLIT_COMBINE(aggstate->aggsplit));
+
 		/* If we have only one input, we need its len/byval info. */
 		if (numInputs == 1)
 		{
diff --git a/src/backend/parser/parse_agg.c b/src/backend/parser/parse_agg.c
index 9562ffcf3e..0719537f60 100644
--- a/src/backend/parser/parse_agg.c
+++ b/src/backend/parser/parse_agg.c
@@ -1959,6 +1959,11 @@ resolve_aggregate_transtype(Oid aggfuncid,
  * latter may be InvalidOid, however if invtransfn_oid is set then
  * transfn_oid must also be set.
  *
+ * transfn_oid may also be passed as the aggcombinefn when the *transfnexpr is
+ * to be used for a combine aggregate phase.  We expect invtransfn_oid to be
+ * InvalidOid in this case since there is no such thing as an inverse
+ * combinefn.
+ *
  * Pointers to the constructed trees are returned into *transfnexpr,
  * *invtransfnexpr. If there is no invtransfn, the respective pointer is set
  * to NULL.  Since use of the invtransfn is optional, NULL may be passed for
@@ -2021,35 +2026,6 @@ build_aggregate_transfn_expr(Oid *agg_input_types,
 	}
 }
 
-/*
- * Like build_aggregate_transfn_expr, but creates an expression tree for the
- * combine function of an aggregate, rather than the transition function.
- */
-void
-build_aggregate_combinefn_expr(Oid agg_state_type,
-							   Oid agg_input_collation,
-							   Oid combinefn_oid,
-							   Expr **combinefnexpr)
-{
-	Node	   *argp;
-	List	   *args;
-	FuncExpr   *fexpr;
-
-	/* combinefn takes two arguments of the aggregate state type */
-	argp = make_agg_arg(agg_state_type, agg_input_collation);
-
-	args = list_make2(argp, argp);
-
-	fexpr = makeFuncExpr(combinefn_oid,
-						 agg_state_type,
-						 args,
-						 InvalidOid,
-						 agg_input_collation,
-						 COERCE_EXPLICIT_CALL);
-	/* combinefn is currently never treated as variadic */
-	*combinefnexpr = (Expr *) fexpr;
-}
-
 /*
  * Like build_aggregate_transfn_expr, but creates an expression tree for the
  * serialization function of an aggregate.
diff --git a/src/include/parser/parse_agg.h b/src/include/parser/parse_agg.h
index 4dea01752a..bffbb82df6 100644
--- a/src/include/parser/parse_agg.h
+++ b/src/include/parser/parse_agg.h
@@ -46,11 +46,6 @@ extern void build_aggregate_transfn_expr(Oid *agg_input_types,
 										 Expr **transfnexpr,
 										 Expr **invtransfnexpr);
 
-extern void build_aggregate_combinefn_expr(Oid agg_state_type,
-										   Oid agg_input_collation,
-										   Oid combinefn_oid,
-										   Expr **combinefnexpr);
-
 extern void build_aggregate_serialfn_expr(Oid serialfn_oid,
 										  Expr **serialfnexpr);
 
-- 
2.27.0

