diff --git a/src/backend/executor/execPartition.c b/src/backend/executor/execPartition.c
index 33513ff1d1..5b74733804 100644
--- a/src/backend/executor/execPartition.c
+++ b/src/backend/executor/execPartition.c
@@ -1463,19 +1463,29 @@ ExecSetupPartitionPruneState(PlanState *planstate, List *partitionpruneinfo)
 
 		partkey = RelationGetPartitionKey(rel);
 		partdesc = RelationGetPartitionDesc(rel);
+		n_steps = list_length(pinfo->pruning_steps);
 
 		context->strategy = partkey->strategy;
 		context->partnatts = partnatts = partkey->partnatts;
-		context->partopfamily = partkey->partopfamily;
-		context->partopcintype = partkey->partopcintype;
-		context->partcollation = partkey->partcollation;
-		context->partsupfunc = partkey->partsupfunc;
+
+		context->partopfamily = (Oid *) palloc(sizeof(Oid) * partnatts);
+		memcpy(context->partopfamily, partkey->partopfamily, sizeof(Oid) * partnatts);
+
+		context->partopcintype = (Oid *) palloc(sizeof(Oid) * partnatts);
+		memcpy(context->partopcintype, partkey->partopcintype, sizeof(Oid) * partnatts);
+
+		context->partcollation = (Oid *) palloc(sizeof(Oid) * partnatts);
+		memcpy(context->partcollation, partkey->partcollation, sizeof(Oid) * partnatts);
+
+		context->stepcmpfuncs = (FmgrInfo *) palloc0(sizeof(FmgrInfo) *
+													 n_steps *
+													 partnatts);
+
 		context->nparts = pinfo->nparts;
 		context->boundinfo = partition_bounds_copy(partdesc->boundinfo, partkey);
 		context->planstate = planstate;
 
 		/* Initialize expression state for each expression we need */
-		n_steps = list_length(pinfo->pruning_steps);
 		context->exprstates = (ExprState **)
 			palloc0(sizeof(ExprState *) * n_steps * partnatts);
 		foreach(lc2, pinfo->pruning_steps)
diff --git a/src/backend/partitioning/partprune.c b/src/backend/partitioning/partprune.c
index 856bdd3a14..44be58f7ec 100644
--- a/src/backend/partitioning/partprune.c
+++ b/src/backend/partitioning/partprune.c
@@ -441,7 +441,9 @@ prune_append_rel_partitions(RelOptInfo *rel)
 	context.partopfamily = rel->part_scheme->partopfamily;
 	context.partopcintype = rel->part_scheme->partopcintype;
 	context.partcollation = rel->part_scheme->partcollation;
-	context.partsupfunc = rel->part_scheme->partsupfunc;
+	context.stepcmpfuncs = (FmgrInfo *) palloc0(sizeof(FmgrInfo) *
+												context.partnatts *
+												list_length(pruning_steps));
 	context.nparts = rel->nparts;
 	context.boundinfo = rel->boundinfo;
 
@@ -2809,7 +2811,8 @@ perform_pruning_base_step(PartitionPruneContext *context,
 	int			keyno,
 				nvalues;
 	Datum		values[PARTITION_MAX_KEYS];
-	FmgrInfo	partsupfunc[PARTITION_MAX_KEYS];
+	FmgrInfo	*partsupfunc;
+	int			stateidx;
 
 	/*
 	 * There better be the same number of expressions and compare functions.
@@ -2844,7 +2847,6 @@ perform_pruning_base_step(PartitionPruneContext *context,
 		if (lc1 != NULL)
 		{
 			Expr	   *expr;
-			int			stateidx;
 			Datum		datum;
 			bool		isnull;
 
@@ -2873,19 +2875,12 @@ perform_pruning_base_step(PartitionPruneContext *context,
 					return result;
 				}
 
-				/*
-				 * If we're going to need a different comparison function than
-				 * the one cached in the PartitionKey, we'll need to look up
-				 * the FmgrInfo.
-				 */
 				cmpfn = lfirst_oid(lc2);
 				Assert(OidIsValid(cmpfn));
-				if (cmpfn != context->partsupfunc[keyno].fn_oid)
-					fmgr_info(cmpfn, &partsupfunc[keyno]);
-				else
-					fmgr_info_copy(&partsupfunc[keyno],
-								   &context->partsupfunc[keyno],
-								   CurrentMemoryContext);
+
+				/* Check if we've cached the FmgrInfo yet */
+				if (!OidIsValid(context->stepcmpfuncs[stateidx].fn_oid))
+					fmgr_info(cmpfn, &context->stepcmpfuncs[stateidx]);
 
 				values[keyno] = datum;
 				nvalues++;
@@ -2896,6 +2891,14 @@ perform_pruning_base_step(PartitionPruneContext *context,
 		}
 	}
 
+	/*
+	 * Determine the stateidx for the 0th key and point the partsupfunc to
+	 * that element. This provides the correct array segment for the
+	 * strategy matching function below.
+	 */
+	stateidx = PruneCxtStateIdx(context->partnatts, opstep->step.step_id, 0);
+	partsupfunc = &context->stepcmpfuncs[stateidx];
+
 	switch (context->strategy)
 	{
 		case PARTITION_STRATEGY_HASH:
diff --git a/src/include/partitioning/partprune.h b/src/include/partitioning/partprune.h
index e3b3bfb7c1..b275c9375e 100644
--- a/src/include/partitioning/partprune.h
+++ b/src/include/partitioning/partprune.h
@@ -18,51 +18,56 @@
 #include "nodes/relation.h"
 
 
-/*
+/*-----------------------
  * PartitionPruneContext
+ *		Stores information to allow partition pruning on a single partitioned
+ *		table.
  *
- * Information about a partitioned table needed to perform partition pruning.
+ * strategy			Partition strategy, e.g. LIST, RANGE, HASH.
+ * partnatts		Number of attributes and exprs that make up the partition
+ *					key.
+ * partopfamily		Array of partnatts elements storing the opfamily of the
+ *					corresponding partition key element.
+ * partopcintype	Array of partnatts elements storing the Oid of opclass
+ *					if the corresponding partition key element.
+ * partcollation	Array of partnatts elements storing the collation of the
+ *					corresponding partition key element.
+ * stepcmpfuncs		An array to store FmrgInfo for each pruning step partition
+ *					key pair. The array should be indexed by PruneCtxStateIdx.
+ * nparts			Number of partitions belonging to this partitioned table.
+ * boundinfo		PartitionBoundInfo for the partitioned table.
+ * planstate		Holds the executor's planstate when being called during
+ *					execution, or NULL when being called from the planner.
+ * exprstates		Array of ExprStates, indexed as per PruneCtxStateIdx; one
+ *					for each partkey in each pruning step.  Allocated if
+ *					planstate is non-NULL, otherwise NULL.
+ * exprhasexecparam	Array of bools, each true if corresponding 'exprstate'
+ *					expression contains any PARAM_EXEC Params.  (Can be NULL
+ *					if planstate is NULL.)
+ * evalexecparams	True if it's safe to evaluate PARAM_EXEC Params.
+ *-----------------------
  */
 typedef struct PartitionPruneContext
 {
-	/* Partition key information */
 	char		strategy;
 	int			partnatts;
 	Oid		   *partopfamily;
 	Oid		   *partopcintype;
 	Oid		   *partcollation;
-	FmgrInfo   *partsupfunc;
-
-	/* Number of partitions */
+	FmgrInfo   *stepcmpfuncs;
 	int			nparts;
-
-	/* Partition boundary info */
 	PartitionBoundInfo boundinfo;
-
-	/*
-	 * This will be set when the context is used from the executor, to allow
-	 * Params to be evaluated.
-	 */
 	PlanState  *planstate;
-
-	/*
-	 * Array of ExprStates, indexed as per PruneCtxStateIdx; one for each
-	 * partkey in each pruning step.  Allocated if planstate is non-NULL,
-	 * otherwise NULL.
-	 */
 	ExprState **exprstates;
-
-	/*
-	 * Similar array of flags, each true if corresponding 'exprstate'
-	 * expression contains any PARAM_EXEC Params.  (Can be NULL if planstate
-	 * is NULL.)
-	 */
 	bool	   *exprhasexecparam;
-
-	/* true if it's safe to evaluate PARAM_EXEC Params */
 	bool		evalexecparams;
 } PartitionPruneContext;
 
+/*
+ * Determine a unique index into a 2-dimentional array based on the 3 inputs,
+ * where partnatts is the maximum possible value for keyno.  Consecutive
+ * keynos are consecutive array elements.
+ */
 #define PruneCxtStateIdx(partnatts, step_id, keyno) \
 	((partnatts) * (step_id) + (keyno))
 
