On Wed, 3 Jul 2024 at 21:45, Joel Jacobson <j...@compiler.org> wrote:
>
> > On Wed, Jul 3, 2024, at 20:57, Dean Rasheed wrote:
> >> I wouldn't expect it to ever be off by more than 1
> >
> > OK, so then the cases I found where it was off by 2 for the mul_var_int()
> > patch are unexpected?
>
> Sorry, I meant off by 2 for the mul_var_small() patch, these cases that I 
> found:
>

Yeah, so that was another bug in mul_var_small(). If rscale is made
small enough, the result indexes of the digits computed before the main
loop overlap those of the digits computed after it, so it would
overwrite digits that had already been computed.

Of course, that's fairly easy to fix, but at this point I think the
better solution is to only use mul_var_small() when an exact product
is requested. We would have to do that for mul_var_int() anyway,
because of its accuracy issues discussed earlier. I think this is a
reasonable thing to do because only functions like ln_var() and
exp_var() will ask mul_var() for a reduced-rscale result, and those
functions are likely to be dominated by computations involving larger
numbers, for which this patch wouldn't help anyway. Also those
functions are probably less widely used.

If we make that decision, a lot of the complexity in mul_var_small()
goes away, including all the conditional array accesses, making it
simpler and more efficient. v6 patch attached.

I also updated the mul_var_int() patch so that it is also only invoked
when an exact product is requested, and I noticed a couple of other
minor optimisations that could be made. Then I decided to try
implementing mul_var_int64(). This gives a pretty decent speedup for
3-digit inputs, but unfortunately it is much slower for 4-digit inputs
(for which most values will go through the 128-bit code path). I'm
attaching that too, just for information, but it's clearly not going
to be acceptable as-is.

Running your benchmark queries, I got these results:

SELECT SUM(var1*var2) FROM bench_mul_var_var1ndigits_1;
Time: 4520.874 ms (00:04.521)  -- HEAD
Time: 3937.536 ms (00:03.938)  -- v5-mul_var_int.patch
Time: 3919.495 ms (00:03.919)  -- v5-mul_var_small.patch
Time: 3916.964 ms (00:03.917)  -- v6-mul_var_int64.patch
Time: 3811.118 ms (00:03.811)  -- v6-mul_var_small.patch

SELECT SUM(var1*var2) FROM bench_mul_var_var1ndigits_2;
Time: 4762.528 ms (00:04.763)  -- HEAD
Time: 4075.546 ms (00:04.076)  -- v5-mul_var_int.patch
Time: 4055.180 ms (00:04.055)  -- v5-mul_var_small.patch
Time: 4037.866 ms (00:04.038)  -- v6-mul_var_int64.patch
Time: 4018.488 ms (00:04.018)  -- v6-mul_var_small.patch

SELECT SUM(var1*var2) FROM bench_mul_var_var1ndigits_3;
Time: 5387.514 ms (00:05.388)  -- HEAD
Time: 5350.736 ms (00:05.351)  -- v5-mul_var_int.patch
Time: 4648.449 ms (00:04.648)  -- v5-mul_var_small.patch
Time: 4655.204 ms (00:04.655)  -- v6-mul_var_int64.patch
Time: 4645.962 ms (00:04.646)  -- v6-mul_var_small.patch

SELECT SUM(var1*var2) FROM bench_mul_var_var1ndigits_4;
Time: 5617.150 ms (00:05.617)  -- HEAD
Time: 5505.913 ms (00:05.506)  -- v5-mul_var_int.patch
Time: 5486.441 ms (00:05.486)  -- v5-mul_var_small.patch
Time: 8203.081 ms (00:08.203)  -- v6-mul_var_int64.patch
Time: 5598.909 ms (00:05.599)  -- v6-mul_var_small.patch

So v6-mul_var_int64 improves on v5-mul_var_int in the 3-digit case,
but is terrible in the 4-digit case. None of the other patches touch
the 4-digit case, but it might be interesting to try extending
mul_var_small() to handle 4-digit inputs.

Regards,
Dean
diff --git a/src/backend/utils/adt/numeric.c b/src/backend/utils/adt/numeric.c
new file mode 100644
index 5510a20..bae07f2
--- a/src/backend/utils/adt/numeric.c
+++ b/src/backend/utils/adt/numeric.c
@@ -551,6 +551,8 @@ static void sub_var(const NumericVar *va
 static void mul_var(const NumericVar *var1, const NumericVar *var2,
 					NumericVar *result,
 					int rscale);
+static void mul_var_small(const NumericVar *var1, const NumericVar *var2,
+						  NumericVar *result, int rscale);
 static void div_var(const NumericVar *var1, const NumericVar *var2,
 					NumericVar *result,
 					int rscale, bool round);
@@ -8707,7 +8709,7 @@ mul_var(const NumericVar *var1, const Nu
 	var1digits = var1->digits;
 	var2digits = var2->digits;
 
-	if (var1ndigits == 0 || var2ndigits == 0)
+	if (var1ndigits == 0)
 	{
 		/* one or both inputs is zero; so is result */
 		zero_var(result);
@@ -8715,6 +8717,17 @@ mul_var(const NumericVar *var1, const Nu
 		return;
 	}
 
+	/*
+	 * If var1 has 3 digits or fewer, and we are computing the exact result,
+	 * with no rounding, delegate to mul_var_small() which uses a faster short
+	 * multiplication algorithm.
+	 */
+	if (var1ndigits <= 3 && rscale == var1->dscale + var2->dscale)
+	{
+		mul_var_small(var1, var2, result, rscale);
+		return;
+	}
+
 	/* Determine result sign and (maximum possible) weight */
 	if (var1->sign == var2->sign)
 		res_sign = NUMERIC_POS;
@@ -8858,6 +8871,168 @@ mul_var(const NumericVar *var1, const Nu
 	result->sign = res_sign;
 
 	/* Round to target rscale (and set result->dscale) */
+	round_var(result, rscale);
+
+	/* Strip leading and trailing zeroes */
+	strip_var(result);
+}
+
+
+/*
+ * mul_var_small() -
+ *
+ *	This has the same API as mul_var, but it assumes that var1 has no more
+ *	than 3 digits and var2 has at least as many digits as var1.  For variables
+ *	satisfying these conditions, the product can be computed more quickly than
+ *	the general algorithm used in mul_var.
+ */
+static void
+mul_var_small(const NumericVar *var1, const NumericVar *var2,
+			  NumericVar *result, int rscale)
+{
+	int			var1ndigits = var1->ndigits;
+	int			var2ndigits = var2->ndigits;
+	NumericDigit *var1digits = var1->digits;
+	NumericDigit *var2digits = var2->digits;
+	int			res_sign;
+	int			res_weight;
+	int			res_ndigits;
+	NumericDigit *res_buf;
+	NumericDigit *res_digits;
+	uint32		carry;
+	uint32		term;
+
+	/* Check preconditions */
+	Assert(var1ndigits <= 3);
+	Assert(var2ndigits >= var1ndigits);
+
+	/* Determine result sign and (maximum possible) weight */
+	if (var1->sign == var2->sign)
+		res_sign = NUMERIC_POS;
+	else
+		res_sign = NUMERIC_NEG;
+	res_weight = var1->weight + var2->weight + 2;
+
+	/* Determine the number of result digits to compute - cf. mul_var() */
+	res_ndigits = var1ndigits + var2ndigits + 1;
+
+	if (res_ndigits < 3)
+	{
+		/* All input digits will be ignored; so result is zero */
+		zero_var(result);
+		result->dscale = rscale;
+		return;
+	}
+
+	/* Allocate result digit array */
+	res_buf = digitbuf_alloc(res_ndigits);
+	res_buf[0] = 0;				/* spare digit for later rounding */
+	res_digits = res_buf + 1;
+
+	/*
+	 * Compute the result digits in reverse, in one pass, propagating the
+	 * carry up as we go.
+	 *
+	 * This computes res_digits[res_ndigits - 2], ... res_digits[0] by summing
+	 * the products var1digits[i1] * var2digits[i2] for which i1 + i2 + 1 is
+	 * the result index.
+	 */
+	switch (var1ndigits)
+	{
+		case 1:
+			/* ---------
+			 * 1-digit case:
+			 *		var1ndigits = 1
+			 *		var2ndigits >= 1
+			 *		res_ndigits = var2ndigits + 2
+			 * ----------
+			 */
+			carry = 0;
+			for (int i = res_ndigits - 3; i >= 0; i--)
+			{
+				term = (uint32) var1digits[0] * var2digits[i] + carry;
+				res_digits[i + 1] = (NumericDigit) (term % NBASE);
+				carry = term / NBASE;
+			}
+			res_digits[0] = (NumericDigit) carry;
+			break;
+
+		case 2:
+			/* ---------
+			 * 2-digit case:
+			 *		var1ndigits = 2
+			 *		var2ndigits >= 2
+			 *		res_ndigits = var2ndigits + 3
+			 * ----------
+			 */
+			/* last result digit and carry */
+			term = (uint32) var1digits[1] * var2digits[res_ndigits - 4];
+			res_digits[res_ndigits - 2] = (NumericDigit) (term % NBASE);
+			carry = term / NBASE;
+
+			/* remaining digits, except for the first two */
+			for (int i = res_ndigits - 4; i >= 1; i--)
+			{
+				term = (uint32) var1digits[0] * var2digits[i] +
+					(uint32) var1digits[1] * var2digits[i - 1] + carry;
+				res_digits[i + 1] = (NumericDigit) (term % NBASE);
+				carry = term / NBASE;
+			}
+
+			/* first two digits */
+			term = (uint32) var1digits[0] * var2digits[0] + carry;
+			res_digits[1] = (NumericDigit) (term % NBASE);
+			res_digits[0] = (NumericDigit) (term / NBASE);
+			break;
+
+		case 3:
+			/* ---------
+			 * 3-digit case:
+			 *		var1ndigits = 3
+			 *		var2ndigits >= 3
+			 *		res_ndigits = var2ndigits + 4
+			 * ----------
+			 */
+			/* last two result digits */
+			term = (uint32) var1digits[2] * var2digits[res_ndigits - 5];
+			res_digits[res_ndigits - 2] = (NumericDigit) (term % NBASE);
+			carry = term / NBASE;
+
+			term = (uint32) var1digits[1] * var2digits[res_ndigits - 5] +
+				(uint32) var1digits[2] * var2digits[res_ndigits - 6] + carry;
+			res_digits[res_ndigits - 3] = (NumericDigit) (term % NBASE);
+			carry = term / NBASE;
+
+			/* remaining digits, except for the first three */
+			for (int i = res_ndigits - 5; i >= 2; i--)
+			{
+				term = (uint32) var1digits[0] * var2digits[i] +
+					(uint32) var1digits[1] * var2digits[i - 1] +
+					(uint32) var1digits[2] * var2digits[i - 2] + carry;
+				res_digits[i + 1] = (NumericDigit) (term % NBASE);
+				carry = term / NBASE;
+			}
+
+			/* first three digits */
+			term = (uint32) var1digits[0] * var2digits[1] +
+				(uint32) var1digits[1] * var2digits[0] + carry;
+			res_digits[2] = (NumericDigit) (term % NBASE);
+			carry = term / NBASE;
+			term = (uint32) var1digits[0] * var2digits[0] + carry;
+			res_digits[1] = (NumericDigit) (term % NBASE);
+			res_digits[0] = (NumericDigit) (term / NBASE);
+			break;
+	}
+
+	/* Store the product in result (minus extra rounding digit) */
+	digitbuf_free(result->buf);
+	result->ndigits = res_ndigits - 1;
+	result->buf = res_buf;
+	result->digits = res_digits;
+	result->weight = res_weight - 1;
+	result->sign = res_sign;
+
+	/* Round to target rscale (and set result->dscale) */
 	round_var(result, rscale);
 
 	/* Strip leading and trailing zeroes */
diff --git a/src/backend/utils/adt/numeric.c b/src/backend/utils/adt/numeric.c
new file mode 100644
index 5510a20..96456d8
--- a/src/backend/utils/adt/numeric.c
+++ b/src/backend/utils/adt/numeric.c
@@ -551,6 +551,12 @@ static void sub_var(const NumericVar *va
 static void mul_var(const NumericVar *var1, const NumericVar *var2,
 					NumericVar *result,
 					int rscale);
+static void mul_var_int(const NumericVar *var, int ival, int ival_weight,
+						NumericVar *result, int rscale);
+#ifdef HAVE_INT128
+static void mul_var_int64(const NumericVar *var, int64 ival, int ival_weight,
+						  NumericVar *result, int rscale);
+#endif
 static void div_var(const NumericVar *var1, const NumericVar *var2,
 					NumericVar *result,
 					int rscale, bool round);
@@ -8707,7 +8713,7 @@ mul_var(const NumericVar *var1, const Nu
 	var1digits = var1->digits;
 	var2digits = var2->digits;
 
-	if (var1ndigits == 0 || var2ndigits == 0)
+	if (var1ndigits == 0)
 	{
 		/* one or both inputs is zero; so is result */
 		zero_var(result);
@@ -8715,6 +8721,54 @@ mul_var(const NumericVar *var1, const Nu
 		return;
 	}
 
+	/*
+	 * If var1 has just one or two digits, and we are computing the exact
+	 * result, with no rounding, delegate to mul_var_int(), which uses a
+	 * faster direct multiplication algorithm.
+	 *
+	 * Similarly, on platforms with 128-bit integer support, delegate to
+	 * mul_var_int64() if var1 has three or four digits, and we are computing
+	 * the exact result, with no rounding.
+	 */
+	if (var1ndigits <= 2 && rscale == var1->dscale + var2->dscale)
+	{
+		int			ifactor;
+		int			ifactor_weight;
+
+		ifactor = var1->digits[0];
+		ifactor_weight = var1->weight;
+		if (var1ndigits == 2)
+		{
+			ifactor = ifactor * NBASE + var1->digits[1];
+			ifactor_weight--;
+		}
+		if (var1->sign == NUMERIC_NEG)
+			ifactor = -ifactor;
+
+		mul_var_int(var2, ifactor, ifactor_weight, result, rscale);
+		return;
+	}
+#ifdef HAVE_INT128
+	if (var1ndigits <= 4 && rscale == var1->dscale + var2->dscale)
+	{
+		int64		ifactor;
+		int			ifactor_weight;
+
+		ifactor = var1->digits[0];
+		ifactor_weight = var1->weight;
+		for (i = 1; i < var1ndigits; i++)
+		{
+			ifactor = ifactor * NBASE + var1->digits[i];
+			ifactor_weight--;
+		}
+		if (var1->sign == NUMERIC_NEG)
+			ifactor = -ifactor;
+
+		mul_var_int64(var2, ifactor, ifactor_weight, result, rscale);
+		return;
+	}
+#endif
+
 	/* Determine result sign and (maximum possible) weight */
 	if (var1->sign == var2->sign)
 		res_sign = NUMERIC_POS;
@@ -8866,6 +8920,233 @@ mul_var(const NumericVar *var1, const Nu
 
 
 /*
+ * mul_var_int() -
+ *
+ *	Multiply a numeric variable by a 32-bit integer with the specified weight.
+ *	The product var * ival * NBASE^ival_weight is stored in result.
+ */
+static void
+mul_var_int(const NumericVar *var, int ival, int ival_weight,
+			NumericVar *result, int rscale)
+{
+	NumericDigit *var_digits = var->digits;
+	int			var_ndigits = var->ndigits;
+	int			res_sign;
+	int			res_weight;
+	int			res_ndigits;
+	NumericDigit *res_buf;
+	NumericDigit *res_digits;
+	uint32		factor;
+	uint32		carry;
+
+	if (ival == 0 || var_ndigits == 0)
+	{
+		zero_var(result);
+		result->dscale = rscale;
+		return;
+	}
+
+	/*
+	 * Determine the result sign, (maximum possible) weight and number of
+	 * digits to calculate.  The weight figured here is correct if the emitted
+	 * product has no leading zero digits; otherwise strip_var() will fix
+	 * things up.
+	 */
+	if (var->sign == NUMERIC_POS)
+		res_sign = ival > 0 ? NUMERIC_POS : NUMERIC_NEG;
+	else
+		res_sign = ival > 0 ? NUMERIC_NEG : NUMERIC_POS;
+	res_weight = var->weight + ival_weight + 3;
+	/* The number of accurate result digits we need to produce: */
+	res_ndigits = var_ndigits + 3;
+
+	res_buf = digitbuf_alloc(res_ndigits + 1);
+	res_buf[0] = 0;				/* spare digit for later rounding */
+	res_digits = res_buf + 1;
+
+	/*
+	 * Now compute the product digits by processing the input digits in
+	 * reverse and propagating the carry up as we go.
+	 *
+	 * In this algorithm, the carry from one digit to the next is at most
+	 * factor - 1, and product is at most factor * NBASE - 1, and so it needs
+	 * to be a 64-bit integer if this exceeds UINT_MAX.
+	 */
+	factor = abs(ival);
+	carry = 0;
+
+	if (factor <= UINT_MAX / NBASE)
+	{
+		/* product cannot overflow 32 bits */
+		uint32		product;
+
+		for (int i = res_ndigits - 4; i >= 0; i--)
+		{
+			product = factor * var_digits[i] + carry;
+			res_digits[i + 3] = (NumericDigit) (product % NBASE);
+			carry = product / NBASE;
+		}
+		/* note: carry < UINT_MAX / NBASE in this branch */
+		res_digits[2] = (NumericDigit) (carry % NBASE);
+		carry = carry / NBASE;
+		Assert(carry < NBASE);
+		res_digits[1] = (NumericDigit) carry;
+		res_digits[0] = 0;
+	}
+	else
+	{
+		/* product may exceed 32 bits */
+		uint64		product;
+
+		for (int i = res_ndigits - 4; i >= 0; i--)
+		{
+			product = (uint64) factor * var_digits[i] + carry;
+			res_digits[i + 3] = (NumericDigit) (product % NBASE);
+			carry = (uint32) (product / NBASE);
+		}
+		res_digits[2] = (NumericDigit) (carry % NBASE);
+		carry = carry / NBASE;
+		res_digits[1] = (NumericDigit) (carry % NBASE);
+		res_digits[0] = (NumericDigit) (carry / NBASE);
+	}
+
+	/* Store the product in result */
+	digitbuf_free(result->buf);
+	result->ndigits = res_ndigits;
+	result->buf = res_buf;
+	result->digits = res_digits;
+	result->weight = res_weight;
+	result->sign = res_sign;
+
+	/* Round to target rscale (and set result->dscale) */
+	round_var(result, rscale);
+
+	/* Strip leading and trailing zeroes */
+	strip_var(result);
+}
+
+
+#ifdef HAVE_INT128
+/*
+ * mul_var_int64() -
+ *
+ *	Multiply a numeric variable by a 64-bit integer with the specified weight.
+ *	Stores var * ival * NBASE^ival_weight in result, rounded to rscale.
+ *
+ *	This duplicates the logic in mul_var_int(), so any changes made there
+ *	should be made here too.
+ */
+static void
+mul_var_int64(const NumericVar *var, int64 ival, int ival_weight,
+			  NumericVar *result, int rscale)
+{
+	NumericDigit *var_digits = var->digits;
+	int			var_ndigits = var->ndigits;
+	int			res_sign;
+	int			res_weight;
+	int			res_ndigits;
+	NumericDigit *res_buf;
+	NumericDigit *res_digits;
+	uint64		factor;
+	uint64		carry;
+
+	if (ival == 0 || var_ndigits == 0)
+	{
+		zero_var(result);
+		result->dscale = rscale;
+		return;
+	}
+
+	/*
+	 * Determine the result sign, (maximum possible) weight and number of
+	 * digits to calculate.  The weight figured here is correct if the emitted
+	 * product has no leading zero digits; otherwise strip_var() will fix
+	 * things up.
+	 */
+	if (var->sign == NUMERIC_POS)
+		res_sign = ival > 0 ? NUMERIC_POS : NUMERIC_NEG;
+	else
+		res_sign = ival > 0 ? NUMERIC_NEG : NUMERIC_POS;
+	res_weight = var->weight + ival_weight + 5;
+	/* The number of accurate result digits we need to produce: */
+	res_ndigits = var_ndigits + 5;
+
+	res_buf = digitbuf_alloc(res_ndigits + 1);
+	res_buf[0] = 0;				/* spare digit for later rounding */
+	res_digits = res_buf + 1;
+
+	/*
+	 * Compute the product digits by processing the input digits in reverse,
+	 * propagating the carry up as we go.  |ival| is computed in unsigned
+	 * arithmetic, because negating the most negative int64 would overflow.
+	 * The carry from one digit to the next is at most factor - 1, and product
+	 * is at most factor * NBASE - 1, so product needs to be a 128-bit integer
+	 * if this exceeds PG_UINT64_MAX, i.e., if factor > PG_UINT64_MAX / NBASE.
+	 */
+	factor = (ival < 0) ? (uint64) 0 - (uint64) ival : (uint64) ival;
+	carry = 0;
+
+	if (factor <= PG_UINT64_MAX / NBASE)
+	{
+		/* product cannot overflow 64 bits */
+		uint64		product;
+
+		for (int i = res_ndigits - 6; i >= 0; i--)
+		{
+			product = factor * var_digits[i] + carry;
+			res_digits[i + 5] = (NumericDigit) (product % NBASE);
+			carry = product / NBASE;
+		}
+		/* note: carry < PG_UINT64_MAX / NBASE in this branch */
+		res_digits[4] = (NumericDigit) (carry % NBASE);
+		carry = carry / NBASE;
+		res_digits[3] = (NumericDigit) (carry % NBASE);
+		carry = carry / NBASE;
+		res_digits[2] = (NumericDigit) (carry % NBASE);
+		carry = carry / NBASE;
+		Assert(carry < NBASE);
+		res_digits[1] = (NumericDigit) carry;
+		res_digits[0] = 0;
+	}
+	else
+	{
+		/* product may exceed 64 bits */
+		uint128		product;
+
+		for (int i = res_ndigits - 6; i >= 0; i--)
+		{
+			product = (uint128) factor * var_digits[i] + carry;
+			res_digits[i + 5] = (NumericDigit) (product % NBASE);
+			carry = (uint64) (product / NBASE);
+		}
+		res_digits[4] = (NumericDigit) (carry % NBASE);
+		carry = carry / NBASE;
+		res_digits[3] = (NumericDigit) (carry % NBASE);
+		carry = carry / NBASE;
+		res_digits[2] = (NumericDigit) (carry % NBASE);
+		carry = carry / NBASE;
+		res_digits[1] = (NumericDigit) (carry % NBASE);
+		res_digits[0] = (NumericDigit) (carry / NBASE);
+	}
+
+	/* Store the product in result */
+	digitbuf_free(result->buf);
+	result->ndigits = res_ndigits;
+	result->buf = res_buf;
+	result->digits = res_digits;
+	result->weight = res_weight;
+	result->sign = res_sign;
+
+	/* Round to target rscale (and set result->dscale) */
+	round_var(result, rscale);
+
+	/* Strip leading and trailing zeroes */
+	strip_var(result);
+}
+#endif
+
+
+/*
  * div_var() -
  *
  *	Division on variable level. Quotient of var1 / var2 is stored in result.

Reply via email to