From b075e18c2ae94f2ada20572c05ff35730d2c5968 Mon Sep 17 00:00:00 2001
From: Joel Jakobsson <github@compiler.org>
Date: Mon, 1 Jul 2024 07:17:50 +0200
Subject: [PATCH] Optimize mul_var() for var2ndigits <= 4

---
 src/backend/utils/adt/numeric.c | 457 ++++++++++++++++++++++++++++++++
 src/include/catalog/pg_proc.dat |   3 +
 src/include/utils/numeric.h     |   2 +
 3 files changed, 462 insertions(+)

diff --git a/src/backend/utils/adt/numeric.c b/src/backend/utils/adt/numeric.c
index 5510a203b0..b74e03eefa 100644
--- a/src/backend/utils/adt/numeric.c
+++ b/src/backend/utils/adt/numeric.c
@@ -551,6 +551,9 @@ static void sub_var(const NumericVar *var1, const NumericVar *var2,
 static void mul_var(const NumericVar *var1, const NumericVar *var2,
 					NumericVar *result,
 					int rscale);
+static void mul_var_patched(const NumericVar *var1, const NumericVar *var2,
+					NumericVar *result,
+					int rscale);
 static void div_var(const NumericVar *var1, const NumericVar *var2,
 					NumericVar *result,
 					int rscale, bool round);
@@ -3115,6 +3118,130 @@ numeric_mul_opt_error(Numeric num1, Numeric num2, bool *have_error)
 }
 
 
+/*
+ * numeric_mul_patched() -
+ *
+ *	SQL-callable entry point: multiplies two numerics via the experimental
+ *	mul_var_patched() code path.  It exists only to allow direct benchmark
+ *	comparisons with the standard numeric_mul() function.
+ */
+Datum
+numeric_mul_patched(PG_FUNCTION_ARGS)
+{
+	Numeric		num1 = PG_GETARG_NUMERIC(0);
+	Numeric		num2 = PG_GETARG_NUMERIC(1);
+	Numeric		res;
+
+	res = numeric_mul_patched_opt_error(num1, num2, NULL);
+
+	PG_RETURN_NUMERIC(res);
+}
+
+
+/*
+ * numeric_mul_patched_opt_error() -
+ *
+ *	Internal version of numeric_mul_patched().
+ *	If "*have_error" flag is provided, on error it's set to true, NULL returned.
+ *	This is helpful when the caller needs to handle errors itself.
+ */
+Numeric
+numeric_mul_patched_opt_error(Numeric num1, Numeric num2, bool *have_error)
+{
+	NumericVar	arg1;
+	NumericVar	arg2;
+	NumericVar	result;
+	Numeric		res;
+
+	/*
+	 * Handle NaN and infinities
+	 */
+	if (NUMERIC_IS_SPECIAL(num1) || NUMERIC_IS_SPECIAL(num2))
+	{
+		if (NUMERIC_IS_NAN(num1) || NUMERIC_IS_NAN(num2))
+			return make_result(&const_nan);
+		if (NUMERIC_IS_PINF(num1))
+		{
+			switch (numeric_sign_internal(num2))
+			{
+				case 0:
+					return make_result(&const_nan); /* Inf * 0 */
+				case 1:
+					return make_result(&const_pinf);
+				case -1:
+					return make_result(&const_ninf);
+			}
+			Assert(false);
+		}
+		if (NUMERIC_IS_NINF(num1))
+		{
+			switch (numeric_sign_internal(num2))
+			{
+				case 0:
+					return make_result(&const_nan); /* -Inf * 0 */
+				case 1:
+					return make_result(&const_ninf);
+				case -1:
+					return make_result(&const_pinf);
+			}
+			Assert(false);
+		}
+		/* by here, num1 must be finite, so num2 is not */
+		if (NUMERIC_IS_PINF(num2))
+		{
+			switch (numeric_sign_internal(num1))
+			{
+				case 0:
+					return make_result(&const_nan); /* 0 * Inf */
+				case 1:
+					return make_result(&const_pinf);
+				case -1:
+					return make_result(&const_ninf);
+			}
+			Assert(false);
+		}
+		Assert(NUMERIC_IS_NINF(num2));
+		switch (numeric_sign_internal(num1))
+		{
+			case 0:
+				return make_result(&const_nan); /* 0 * -Inf */
+			case 1:
+				return make_result(&const_ninf);
+			case -1:
+				return make_result(&const_pinf);
+		}
+		Assert(false);
+	}
+
+	/*
+	 * Unpack the values, let mul_var_patched() compute the result and return
+	 * it.  Unlike add_var() and sub_var(), mul_var_patched() will round its
+	 * result.  In the case of numeric_mul(), which is invoked for the *
+	 * operator on numerics, we request exact representation for the product
+	 * (rscale = sum(dscale of arg1, dscale of arg2)).  If the exact result
+	 * has more digits after the decimal point than can be stored in a
+	 * numeric, we round it.  Rounding after computing the exact result
+	 * ensures that the final result is correctly rounded (rounding in
+	 * mul_var_patched() using a truncated product would not guarantee
+	 * this).
+	 */
+	init_var_from_num(num1, &arg1);
+	init_var_from_num(num2, &arg2);
+
+	init_var(&result);
+
+	mul_var_patched(&arg1, &arg2, &result, arg1.dscale + arg2.dscale);
+
+	if (result.dscale > NUMERIC_DSCALE_MAX)
+		round_var(&result, NUMERIC_DSCALE_MAX);
+
+	res = make_result_opt_error(&result, have_error);
+
+	free_var(&result);
+
+	return res;
+}
+
+
 /*
  * numeric_div() -
  *
@@ -8864,6 +8991,336 @@ mul_var(const NumericVar *var1, const NumericVar *var2, NumericVar *result,
 	strip_var(result);
 }
 
+/*
+ * mul_var_patched() -
+ *
+ *	Experimental variant of mul_var(), kept alongside the unchanged
+ *	original to allow direct benchmark comparisons.  When the shorter
+ *	input fits in one or two NBASE digits (or up to four when int128
+ *	is available), the product is computed directly in a wide integer;
+ *	otherwise it falls back to the same schoolbook algorithm as
+ *	mul_var().  This is a temporary measure before considering its
+ *	replacement of mul_var() based on benchmark outcomes.
+ */
+static void
+mul_var_patched(const NumericVar *var1, const NumericVar *var2,
+		NumericVar *result, int rscale)
+{
+	int			res_ndigits;
+	int			res_sign;
+	int			res_weight;
+	int			maxdigits;
+	int		   *dig;
+	int			carry;
+	int			maxdig;
+	int			newdig;
+	int			var1ndigits;
+	int			var2ndigits;
+	NumericDigit *var1digits;
+	NumericDigit *var2digits;
+	NumericDigit *res_digits;
+	int			i,
+				i1,
+				i2;
+
+	/*
+	 * Arrange for var1 to be the shorter of the two numbers.  This improves
+	 * performance because the inner multiplication loop is much simpler than
+	 * the outer loop, so it's better to have a smaller number of iterations
+	 * of the outer loop.  This also reduces the number of times that the
+	 * accumulator array needs to be normalized.
+	 */
+	if (var1->ndigits > var2->ndigits)
+	{
+		const NumericVar *tmp = var1;
+
+		var1 = var2;
+		var2 = tmp;
+	}
+
+	/* copy these values into local vars for speed in inner loop */
+	var1ndigits = var1->ndigits;
+	var2ndigits = var2->ndigits;
+	var1digits = var1->digits;
+	var2digits = var2->digits;
+
+	if (var1ndigits == 0 || var2ndigits == 0)
+	{
+		/* one or both inputs is zero; so is result */
+		zero_var(result);
+		result->dscale = rscale;
+		return;
+	}
+
+	/* Determine result sign and (maximum possible) weight */
+	if (var1->sign == var2->sign)
+		res_sign = NUMERIC_POS;
+	else
+		res_sign = NUMERIC_NEG;
+	res_weight = var1->weight + var2->weight + 2;
+
+	/*
+	 * Determine the number of result digits to compute.  If the exact result
+	 * would have more than rscale fractional digits, truncate the computation
+	 * with MUL_GUARD_DIGITS guard digits, i.e., ignore input digits that
+	 * would only contribute to the right of that.  (This will give the exact
+	 * rounded-to-rscale answer unless carries out of the ignored positions
+	 * would have propagated through more than MUL_GUARD_DIGITS digits.)
+	 *
+	 * Note: an exact computation could not produce more than var1ndigits +
+	 * var2ndigits digits, but we allocate one extra output digit in case
+	 * rscale-driven rounding produces a carry out of the highest exact digit.
+	 */
+	res_ndigits = var1ndigits + var2ndigits + 1;
+	maxdigits = res_weight + 1 + (rscale + DEC_DIGITS - 1) / DEC_DIGITS +
+		MUL_GUARD_DIGITS;
+	res_ndigits = Min(res_ndigits, maxdigits);
+
+	if (res_ndigits < 3)
+	{
+		/* All input digits will be ignored; so result is zero */
+		zero_var(result);
+		result->dscale = rscale;
+		return;
+	}
+
+	/*
+	 * Fast path: var2 (the shorter input after the swap above) has at most
+	 * two NBASE digits, so the product fits in an int64 and can be computed
+	 * directly.  NOTE(review): assumes res_ndigits was not clamped by
+	 * maxdigits, else high-order product digits are dropped — verify.
+	 */
+	if (var2ndigits <= 2)
+	{
+		int64		product = 0;
+		switch (var1ndigits)
+		{
+			case 1:
+				product = var1digits[0];
+				break;
+			case 2:
+				product = var1digits[0] * NBASE + var1digits[1];
+				break;
+		}
+
+		switch (var2ndigits)
+		{
+			case 1:
+				product *= var2digits[0];
+				break;
+			case 2:
+				product *= var2digits[0] * NBASE + var2digits[1];
+				break;
+		}
+
+		alloc_var(result, res_ndigits);
+		res_digits = result->digits;
+		for (i = res_ndigits - 1; i >= 0; i--)
+		{
+			res_digits[i] = product % NBASE;
+			product /= NBASE;
+		}
+		Assert(product == 0);
+
+		/*
+		 * Finally, round the result to the requested precision.
+		 */
+		result->weight = res_weight;
+		result->sign = res_sign;
+
+		/* Round to target rscale (and set result->dscale) */
+		round_var(result, rscale);
+
+		/* Strip leading and trailing zeroes */
+		strip_var(result);
+
+		return;
+	}
+#ifdef HAVE_INT128
+	/*
+	 * Fast path: var2 has at most four NBASE digits, so the product fits in
+	 * an int128 and can be computed directly.  NOTE(review): same
+	 * res_ndigits-clamping assumption as the int64 path above — verify.
+	 */
+	if (var2ndigits <= 4)
+	{
+		int128		product = 0;
+
+		switch (var1ndigits)
+		{
+			case 1:
+				product = var1digits[0];
+				break;
+			case 2:
+				product = var1digits[0] * NBASE + var1digits[1];
+				break;
+			case 3:
+				product = ((int128) var1digits[0] * NBASE + var1digits[1])
+						* NBASE + var1digits[2];
+				break;
+			case 4:
+				product = (((int128) var1digits[0] * NBASE + var1digits[1])
+						* NBASE + var1digits[2]) * NBASE + var1digits[3];
+				break;
+		}
+
+		switch (var2ndigits)
+		{
+			case 1:
+				product *= var2digits[0];
+				break;
+			case 2:
+				product *= var2digits[0] * NBASE + var2digits[1];
+				break;
+			case 3:
+				product *= ((int128) var2digits[0] * NBASE + var2digits[1])
+						* NBASE + var2digits[2];
+				break;
+			case 4:
+				product *= (((int128) var2digits[0] * NBASE + var2digits[1])
+						* NBASE + var2digits[2]) * NBASE + var2digits[3];
+				break;
+		}
+
+		alloc_var(result, res_ndigits);
+		res_digits = result->digits;
+		for (i = res_ndigits - 1; i >= 0; i--)
+		{
+			res_digits[i] = product % NBASE;
+			product /= NBASE;
+		}
+		Assert(product == 0);
+
+		/*
+		 * Finally, round the result to the requested precision.
+		 */
+		result->weight = res_weight;
+		result->sign = res_sign;
+
+		/* Round to target rscale (and set result->dscale) */
+		round_var(result, rscale);
+
+		/* Strip leading and trailing zeroes */
+		strip_var(result);
+
+		return;
+	}
+#endif
+
+	/*
+	 * We do the arithmetic in an array "dig[]" of signed int's.  Since
+	 * INT_MAX is noticeably larger than NBASE*NBASE, this gives us headroom
+	 * to avoid normalizing carries immediately.
+	 *
+	 * maxdig tracks the maximum possible value of any dig[] entry; when this
+	 * threatens to exceed INT_MAX, we take the time to propagate carries.
+	 * Furthermore, we need to ensure that overflow doesn't occur during the
+	 * carry propagation passes either.  The carry values could be as much as
+	 * INT_MAX/NBASE, so really we must normalize when digits threaten to
+	 * exceed INT_MAX - INT_MAX/NBASE.
+	 *
+	 * To avoid overflow in maxdig itself, it actually represents the max
+	 * possible value divided by NBASE-1, ie, at the top of the loop it is
+	 * known that no dig[] entry exceeds maxdig * (NBASE-1).
+	 */
+	dig = (int *) palloc0(res_ndigits * sizeof(int));
+	maxdig = 0;
+
+	/*
+	 * The least significant digits of var1 should be ignored if they don't
+	 * contribute directly to the first res_ndigits digits of the result that
+	 * we are computing.
+	 *
+	 * Digit i1 of var1 and digit i2 of var2 are multiplied and added to digit
+	 * i1+i2+2 of the accumulator array, so we need only consider digits of
+	 * var1 for which i1 <= res_ndigits - 3.
+	 */
+	for (i1 = Min(var1ndigits - 1, res_ndigits - 3); i1 >= 0; i1--)
+	{
+		NumericDigit var1digit = var1digits[i1];
+
+		if (var1digit == 0)
+			continue;
+
+		/* Time to normalize? */
+		maxdig += var1digit;
+		if (maxdig > (INT_MAX - INT_MAX / NBASE) / (NBASE - 1))
+		{
+			/* Yes, do it */
+			carry = 0;
+			for (i = res_ndigits - 1; i >= 0; i--)
+			{
+				newdig = dig[i] + carry;
+				if (newdig >= NBASE)
+				{
+					carry = newdig / NBASE;
+					newdig -= carry * NBASE;
+				}
+				else
+					carry = 0;
+				dig[i] = newdig;
+			}
+			Assert(carry == 0);
+			/* Reset maxdig to indicate new worst-case */
+			maxdig = 1 + var1digit;
+		}
+
+		/*
+		 * Add the appropriate multiple of var2 into the accumulator.
+		 *
+		 * As above, digits of var2 can be ignored if they don't contribute,
+		 * so we only include digits for which i1+i2+2 < res_ndigits.
+		 *
+		 * This inner loop is the performance bottleneck for multiplication,
+		 * so we want to keep it simple enough so that it can be
+		 * auto-vectorized.  Accordingly, process the digits left-to-right
+		 * even though schoolbook multiplication would suggest right-to-left.
+		 * Since we aren't propagating carries in this loop, the order does
+		 * not matter.
+		 */
+		{
+			int			i2limit = Min(var2ndigits, res_ndigits - i1 - 2);
+			int		   *dig_i1_2 = &dig[i1 + 2];
+
+			for (i2 = 0; i2 < i2limit; i2++)
+				dig_i1_2[i2] += var1digit * var2digits[i2];
+		}
+	}
+
+	/*
+	 * Now we do a final carry propagation pass to normalize the result, which
+	 * we combine with storing the result digits into the output. Note that
+	 * this is still done at full precision w/guard digits.
+	 */
+	alloc_var(result, res_ndigits);
+	res_digits = result->digits;
+	carry = 0;
+	for (i = res_ndigits - 1; i >= 0; i--)
+	{
+		newdig = dig[i] + carry;
+		if (newdig >= NBASE)
+		{
+			carry = newdig / NBASE;
+			newdig -= carry * NBASE;
+		}
+		else
+			carry = 0;
+		res_digits[i] = newdig;
+	}
+	Assert(carry == 0);
+
+	pfree(dig);
+
+	/*
+	 * Finally, round the result to the requested precision.
+	 */
+	result->weight = res_weight;
+	result->sign = res_sign;
+
+	/* Round to target rscale (and set result->dscale) */
+	round_var(result, rscale);
+
+	/* Strip leading and trailing zeroes */
+	strip_var(result);
+
+}
+
 
 /*
  * div_var() -
diff --git a/src/include/catalog/pg_proc.dat b/src/include/catalog/pg_proc.dat
index 6a5476d3c4..15dde0f3c3 100644
--- a/src/include/catalog/pg_proc.dat
+++ b/src/include/catalog/pg_proc.dat
@@ -4465,6 +4465,9 @@
 { oid => '1726',
   proname => 'numeric_mul', prorettype => 'numeric',
   proargtypes => 'numeric numeric', prosrc => 'numeric_mul' },
+{ oid => '6347',
+  proname => 'numeric_mul_patched', prorettype => 'numeric',
+  proargtypes => 'numeric numeric', prosrc => 'numeric_mul_patched' },
 { oid => '1727',
   proname => 'numeric_div', prorettype => 'numeric',
   proargtypes => 'numeric numeric', prosrc => 'numeric_div' },
diff --git a/src/include/utils/numeric.h b/src/include/utils/numeric.h
index 43c75c436f..9036c9db50 100644
--- a/src/include/utils/numeric.h
+++ b/src/include/utils/numeric.h
@@ -97,6 +97,8 @@ extern Numeric numeric_sub_opt_error(Numeric num1, Numeric num2,
 									 bool *have_error);
 extern Numeric numeric_mul_opt_error(Numeric num1, Numeric num2,
 									 bool *have_error);
+extern Numeric numeric_mul_patched_opt_error(Numeric num1, Numeric num2,
+									 bool *have_error);
 extern Numeric numeric_div_opt_error(Numeric num1, Numeric num2,
 									 bool *have_error);
 extern Numeric numeric_mod_opt_error(Numeric num1, Numeric num2,
-- 
2.45.1

