On Tue, 2 Jul 2024 at 21:10, Joel Jacobson <[email protected]> wrote:
>
> I found the bug in the case 3 code,
> and it turns out the same type of bug also exists in the case 2 code:
>
> case 2:
> newdig = (int) var1digits[1] *
> var2digits[res_ndigits - 4];
>
> The problem here is that res_ndigits could become less than 4,
Yes. It can't be less than 3 though (per the earlier res_ndigits < 3 test), so the case
2 code was correct.
I've been hacking on this a bit and trying to tidy it up. Firstly, I
moved it to a separate function, because it was starting to look messy
having so much extra code in mul_var(). Then I added a bunch more
comments to explain what's going on, and the limits of the various
variables. Note that most of the boundary checks are actually
unnecessary -- in particular all the ones in or after the main loop,
provided you pull out the first 2 result digits from the main loop in
the 3-digit case. That does seem to work very well, but...
I wasn't entirely happy with how messy that code was getting, so I
tried a different approach. Similar to div_var_int(), I tried writing
a mul_var_int() function instead. This can be used for 1 and 2 digit
factors, and we could add a similar mul_var_int64() function on
platforms with 128-bit integers. The code looks quite a lot neater, so
it's probably less likely to contain bugs (though I have just written
it in a hurry, so it might still have bugs). In testing, it seemed to
give a decent speedup, though perhaps a little less than the first patch. But
that's to be balanced against having more maintainable code, and also
a function that might be useful elsewhere in numeric.c.
Anyway, here are both patches for comparison. I'll stop hacking for a
while and let you see what you make of these.
Regards,
Dean
diff --git a/src/backend/utils/adt/numeric.c b/src/backend/utils/adt/numeric.c
new file mode 100644
index 5510a20..81600b3
--- a/src/backend/utils/adt/numeric.c
+++ b/src/backend/utils/adt/numeric.c
@@ -551,6 +551,8 @@ static void sub_var(const NumericVar *va
static void mul_var(const NumericVar *var1, const NumericVar *var2,
NumericVar *result,
int rscale);
+static void mul_var_small(const NumericVar *var1, const NumericVar *var2,
+ NumericVar *result, int rscale);
static void div_var(const NumericVar *var1, const NumericVar *var2,
NumericVar *result,
int rscale, bool round);
@@ -8707,7 +8709,7 @@ mul_var(const NumericVar *var1, const Nu
var1digits = var1->digits;
var2digits = var2->digits;
- if (var1ndigits == 0 || var2ndigits == 0)
+ if (var1ndigits == 0)
{
/* one or both inputs is zero; so is result */
zero_var(result);
@@ -8715,6 +8717,16 @@ mul_var(const NumericVar *var1, const Nu
return;
}
+ /*
+ * If var1 has 3 digits or fewer, delegate to mul_var_small() which uses a
+ * faster short multiplication algorithm.
+ */
+ if (var1ndigits <= 3)
+ {
+ mul_var_small(var1, var2, result, rscale);
+ return;
+ }
+
/* Determine result sign and (maximum possible) weight */
if (var1->sign == var2->sign)
res_sign = NUMERIC_POS;
@@ -8858,6 +8870,188 @@ mul_var(const NumericVar *var1, const Nu
result->sign = res_sign;
/* Round to target rscale (and set result->dscale) */
+ round_var(result, rscale);
+
+ /* Strip leading and trailing zeroes */
+ strip_var(result);
+}
+
+
+/*
+ * mul_var_small() -
+ *
+ * This has the same API as mul_var, but it assumes that var1 has no more
+ * than 3 digits and var2 has at least as many digits as var1. For variables
+ * satisfying these conditions, the product can be computed more quickly than
+ * the general algorithm used in mul_var.
+ */
+static void
+mul_var_small(const NumericVar *var1, const NumericVar *var2,
+ NumericVar *result, int rscale)
+{
+ int var1ndigits = var1->ndigits;
+ int var2ndigits = var2->ndigits;
+ NumericDigit *var1digits = var1->digits;
+ NumericDigit *var2digits = var2->digits;
+ int res_sign;
+ int res_weight;
+ int res_ndigits;
+ int maxdigits;
+ NumericDigit *res_buf;
+ NumericDigit *res_digits;
+ int carry;
+ int term;
+
+ /* Check preconditions */
+ Assert(var1ndigits <= 3);
+ Assert(var2ndigits >= var1ndigits);
+
+ /* Determine result sign and (maximum possible) weight */
+ if (var1->sign == var2->sign)
+ res_sign = NUMERIC_POS;
+ else
+ res_sign = NUMERIC_NEG;
+ res_weight = var1->weight + var2->weight + 2;
+
+ /* Determine the number of result digits to compute - see mul_var() */
+ res_ndigits = var1ndigits + var2ndigits + 1;
+ maxdigits = res_weight + 1 + (rscale + DEC_DIGITS - 1) / DEC_DIGITS +
+ MUL_GUARD_DIGITS;
+ res_ndigits = Min(res_ndigits, maxdigits);
+
+ if (res_ndigits < 3)
+ {
+ /* All input digits will be ignored; so result is zero */
+ zero_var(result);
+ result->dscale = rscale;
+ return;
+ }
+
+ /* Allocate result digit array */
+ res_buf = digitbuf_alloc(res_ndigits);
+ res_buf[0] = 0; /* spare digit for later rounding */
+ res_digits = res_buf + 1;
+
+ /*
+ * Compute the result digits in reverse, in one pass, propagating the
+ * carry up as we go.
+ *
+ * This computes res_digits[res_ndigits - 2], ... res_digits[0] by summing
+ * the products var1digits[i1] * var2digits[i2] for which i1 + i2 + 1 is
+ * the result index.
+ */
+ switch (var1ndigits)
+ {
+ case 1:
+ /* ---------
+ * 1-digit case:
+ * var1ndigits = 1
+ * var2ndigits >= 1
+ * 3 <= res_ndigits <= var2ndigits + 2
+ * ----------
+ */
+ carry = 0;
+ for (int i = res_ndigits - 3; i >= 0; i--)
+ {
+ term = (int) var1digits[0] * var2digits[i] + carry;
+ res_digits[i + 1] = (NumericDigit) (term % NBASE);
+ carry = term / NBASE;
+ }
+ res_digits[0] = (NumericDigit) carry;
+ break;
+
+ case 2:
+ /* ---------
+ * 2-digit case:
+ * var1ndigits = 2
+ * var2ndigits >= 2
+ * 3 <= res_ndigits <= var2ndigits + 3
+ * ----------
+ */
+ /* last result digit and carry */
+ term = 0;
+ if (res_ndigits - 3 < var2ndigits)
+ term += (int) var1digits[0] * var2digits[res_ndigits - 3];
+ if (res_ndigits > 3)
+ term += (int) var1digits[1] * var2digits[res_ndigits - 4];
+ res_digits[res_ndigits - 2] = (NumericDigit) (term % NBASE);
+ carry = term / NBASE;
+
+ /* remaining digits, except for the first two */
+ for (int i = res_ndigits - 4; i >= 1; i--)
+ {
+ term = (int) var1digits[0] * var2digits[i] +
+ (int) var1digits[1] * var2digits[i - 1] + carry;
+ res_digits[i + 1] = (NumericDigit) (term % NBASE);
+ carry = term / NBASE;
+ }
+
+ /* first two digits */
+ term = (int) var1digits[0] * var2digits[0] + carry;
+ res_digits[1] = (NumericDigit) (term % NBASE);
+ res_digits[0] = (NumericDigit) (term / NBASE);
+ break;
+
+ case 3:
+ /* ---------
+ * 3-digit case:
+ * var1ndigits = 3
+ * var2ndigits >= 3
+ * 3 <= res_ndigits <= var2ndigits + 4
+ * ----------
+ */
+ /* last result digit and carry */
+ term = 0;
+ if (res_ndigits - 3 < var2ndigits)
+ term += (int) var1digits[0] * var2digits[res_ndigits - 3];
+ if (res_ndigits > 3 && res_ndigits - 4 < var2ndigits)
+ term += (int) var1digits[1] * var2digits[res_ndigits - 4];
+ if (res_ndigits > 4)
+ term += (int) var1digits[2] * var2digits[res_ndigits - 5];
+ res_digits[res_ndigits - 2] = (NumericDigit) (term % NBASE);
+ carry = term / NBASE;
+
+ /* penultimate result digit */
+ term = carry;
+ if (res_ndigits > 3 && res_ndigits - 4 < var2ndigits)
+ term += (int) var1digits[0] * var2digits[res_ndigits - 4];
+ if (res_ndigits > 4)
+ term += (int) var1digits[1] * var2digits[res_ndigits - 5];
+ if (res_ndigits > 5)
+ term += (int) var1digits[2] * var2digits[res_ndigits - 6];
+ res_digits[res_ndigits - 3] = (NumericDigit) (term % NBASE);
+ carry = term / NBASE;
+
+ /* remaining digits, except for the first three */
+ for (int i = res_ndigits - 5; i >= 2; i--)
+ {
+ term = (int) var1digits[0] * var2digits[i] +
+ (int) var1digits[1] * var2digits[i - 1] +
+ (int) var1digits[2] * var2digits[i - 2] + carry;
+ res_digits[i + 1] = (NumericDigit) (term % NBASE);
+ carry = term / NBASE;
+ }
+
+ /* first three digits */
+ term = (int) var1digits[0] * var2digits[1] +
+ (int) var1digits[1] * var2digits[0] + carry;
+ res_digits[2] = (NumericDigit) (term % NBASE);
+ carry = term / NBASE;
+ term = (int) var1digits[0] * var2digits[0] + carry;
+ res_digits[1] = (NumericDigit) (term % NBASE);
+ res_digits[0] = (NumericDigit) (term / NBASE);
+ break;
+ }
+
+ /* Store the product in result (minus extra rounding digit) */
+ digitbuf_free(result->buf);
+ result->ndigits = res_ndigits - 1;
+ result->buf = res_buf;
+ result->digits = res_digits;
+ result->weight = res_weight - 1;
+ result->sign = res_sign;
+
+ /* Round to target rscale (and set result->dscale) */
round_var(result, rscale);
/* Strip leading and trailing zeroes */
diff --git a/src/backend/utils/adt/numeric.c b/src/backend/utils/adt/numeric.c
new file mode 100644
index 5510a20..9e50ea7
--- a/src/backend/utils/adt/numeric.c
+++ b/src/backend/utils/adt/numeric.c
@@ -551,6 +551,8 @@ static void sub_var(const NumericVar *va
static void mul_var(const NumericVar *var1, const NumericVar *var2,
NumericVar *result,
int rscale);
+static void mul_var_int(const NumericVar *var, int ival, int ival_weight,
+ NumericVar *result, int rscale);
static void div_var(const NumericVar *var1, const NumericVar *var2,
NumericVar *result,
int rscale, bool round);
@@ -8707,7 +8709,7 @@ mul_var(const NumericVar *var1, const Nu
var1digits = var1->digits;
var2digits = var2->digits;
- if (var1ndigits == 0 || var2ndigits == 0)
+ if (var1ndigits == 0)
{
/* one or both inputs is zero; so is result */
zero_var(result);
@@ -8715,6 +8717,31 @@ mul_var(const NumericVar *var1, const Nu
return;
}
+ /*
+ * If var1 has just one or two digits, delegate to mul_var_int(), which
+ * uses a faster direct multiplication algorithm.
+ *
+ * TODO: Similarly, on platforms with 128-bit integers ...
+ */
+ if (var1ndigits <= 2)
+ {
+ int ifactor;
+ int ifactor_weight;
+
+ ifactor = var1->digits[0];
+ ifactor_weight = var1->weight;
+ if (var1ndigits == 2)
+ {
+ ifactor = ifactor * NBASE + var1->digits[1];
+ ifactor_weight--;
+ }
+ if (var1->sign == NUMERIC_NEG)
+ ifactor = -ifactor;
+
+ mul_var_int(var2, ifactor, ifactor_weight, result, rscale);
+ return;
+ }
+
/* Determine result sign and (maximum possible) weight */
if (var1->sign == var2->sign)
res_sign = NUMERIC_POS;
@@ -8857,6 +8884,123 @@ mul_var(const NumericVar *var1, const Nu
result->weight = res_weight;
result->sign = res_sign;
+ /* Round to target rscale (and set result->dscale) */
+ round_var(result, rscale);
+
+ /* Strip leading and trailing zeroes */
+ strip_var(result);
+}
+
+
+/*
+ * mul_var_int() -
+ *
+ * Multiply a numeric variable by a 32-bit integer with the specified weight.
+ * The product var * ival * NBASE^ival_weight is stored in result.
+ */
+static void
+mul_var_int(const NumericVar *var, int ival, int ival_weight,
+ NumericVar *result, int rscale)
+{
+ NumericDigit *var_digits = var->digits;
+ int var_ndigits = var->ndigits;
+ int res_sign;
+ int res_weight;
+ int res_ndigits;
+ int maxdigits;
+ NumericDigit *res_buf;
+ NumericDigit *res_digits;
+ uint32 factor;
+ uint32 carry;
+
+ if (ival == 0 || var_ndigits == 0)
+ {
+ zero_var(result);
+ result->dscale = rscale;
+ return;
+ }
+
+ /*
+ * Determine the result sign, (maximum possible) weight and number of
+ * digits to calculate. The weight figured here is correct if the emitted
+ * product has no leading zero digits; otherwise strip_var() will fix
+ * things up.
+ */
+ if (var->sign == NUMERIC_POS)
+ res_sign = ival > 0 ? NUMERIC_POS : NUMERIC_NEG;
+ else
+ res_sign = ival > 0 ? NUMERIC_NEG : NUMERIC_POS;
+ res_weight = var->weight + ival_weight + 3;
+ /* The number of accurate result digits we need to produce: */
+ res_ndigits = var_ndigits + 3;
+ maxdigits = res_weight + 1 + (rscale + DEC_DIGITS - 1) / DEC_DIGITS +
+ MUL_GUARD_DIGITS;
+ res_ndigits = Min(res_ndigits, maxdigits);
+
+ if (res_ndigits < 3)
+ {
+ /* All input digits will be ignored; so result is zero */
+ zero_var(result);
+ result->dscale = rscale;
+ return;
+ }
+
+ res_buf = digitbuf_alloc(res_ndigits + 1);
+ res_buf[0] = 0; /* spare digit for later rounding */
+ res_digits = res_buf + 1;
+
+ /*
+ * Now compute the product digits by procssing the input digits in reverse
+ * and propagating the carry up as we go.
+ *
+ * In this algorithm, the carry from one digit to the next is at most
+ * factor - 1, and product is at most factor * NBASE - 1, and so it needs
+ * to be a 64-bit integer if this exceeds UINT_MAX.
+ */
+ factor = abs(ival);
+ carry = 0;
+
+ if (factor <= UINT_MAX / NBASE)
+ {
+ /* product cannot overflow 32 bits */
+ uint32 product;
+
+ for (int i = res_ndigits - 4; i >= 0; i--)
+ {
+ product = factor * var_digits[i] + carry;
+ res_digits[i + 3] = (NumericDigit) (product % NBASE);
+ carry = product / NBASE;
+ }
+ res_digits[2] = (NumericDigit) (carry % NBASE);
+ carry = carry / NBASE;
+ res_digits[1] = (NumericDigit) (carry % NBASE);
+ res_digits[0] = (NumericDigit) (carry / NBASE);
+ }
+ else
+ {
+ /* product may exceed 32 bits */
+ uint64 product;
+
+ for (int i = res_ndigits - 4; i >= 0; i--)
+ {
+ product = (uint64) factor * var_digits[i] + carry;
+ res_digits[i + 3] = (NumericDigit) (product % NBASE);
+ carry = (uint32) (product / NBASE);
+ }
+ res_digits[2] = (NumericDigit) (carry % NBASE);
+ carry = carry / NBASE;
+ res_digits[1] = (NumericDigit) (carry % NBASE);
+ res_digits[0] = (NumericDigit) (carry / NBASE);
+ }
+
+ /* Store the product in result */
+ digitbuf_free(result->buf);
+ result->ndigits = res_ndigits;
+ result->buf = res_buf;
+ result->digits = res_digits;
+ result->weight = res_weight;
+ result->sign = res_sign;
+
/* Round to target rscale (and set result->dscale) */
round_var(result, rscale);