DRILL-1019: Handle multiplication overflow for decimal38 data type
Project: http://git-wip-us.apache.org/repos/asf/incubator-drill/repo Commit: http://git-wip-us.apache.org/repos/asf/incubator-drill/commit/20605067 Tree: http://git-wip-us.apache.org/repos/asf/incubator-drill/tree/20605067 Diff: http://git-wip-us.apache.org/repos/asf/incubator-drill/diff/20605067 Branch: refs/heads/master Commit: 2060506781a6b5f9c2dc5ce244ffb532fed556a8 Parents: da61823 Author: Mehant Baid <[email protected]> Authored: Wed Jun 18 02:13:51 2014 -0700 Committer: Jacques Nadeau <[email protected]> Committed: Fri Jun 20 10:56:16 2014 -0700 ---------------------------------------------------------------------- .../templates/Decimal/DecimalFunctions.java | 63 ++++++++++++++++++-- .../drill/jdbc/test/TestFunctionsQuery.java | 17 +++++- 2 files changed, 73 insertions(+), 7 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/incubator-drill/blob/20605067/exec/java-exec/src/main/codegen/templates/Decimal/DecimalFunctions.java ---------------------------------------------------------------------- diff --git a/exec/java-exec/src/main/codegen/templates/Decimal/DecimalFunctions.java b/exec/java-exec/src/main/codegen/templates/Decimal/DecimalFunctions.java index 3f5b5cd..864c461 100644 --- a/exec/java-exec/src/main/codegen/templates/Decimal/DecimalFunctions.java +++ b/exec/java-exec/src/main/codegen/templates/Decimal/DecimalFunctions.java @@ -392,7 +392,6 @@ public class ${type.name}Functions { } public void eval() { - if (outputPrecision == Integer.MIN_VALUE) { org.apache.drill.common.util.DecimalScalePrecisionMulFunction resultScalePrec = new org.apache.drill.common.util.DecimalScalePrecisionMulFunction((int) left.precision, (int) left.scale, (int) right.precision, (int) right.scale); @@ -437,10 +436,6 @@ public class ${type.name}Functions { int resultIntegerSize = leftIntegerSize + rightIntegerSize; int resultScaleSize = org.apache.drill.common.util.DecimalUtility.roundUp(left.scale + 
right.scale); - if ((resultIntegerSize + resultScaleSize) > result.nDecimalDigits) { - throw new org.apache.drill.common.exceptions.DrillRuntimeException("Cannot fit multiplication result in the given Decimal type"); - } - int leftSize = left.nDecimalDigits - 1; int rightSize = right.nDecimalDigits - 1; @@ -475,6 +470,63 @@ public class ${type.name}Functions { resultIndex--; } + /* We have computed the result of the multiplication; check if we need to + * round a portion of the fractional part + */ + resultScaleSize = org.apache.drill.common.util.DecimalUtility.roundUp(result.scale); + + if (result.scale < (left.scale + right.scale)) { + /* The scale of the output data type is lesser than the scale + * we obtained as a result of multiplication, we need to round + * a chunk of the fractional part + */ + int lastScaleIndex = currentIndex + resultIntegerSize + resultScaleSize - 1; + + // Compute the power of 10 necessary to chop of the fractional part + int scaleFactor = (int) (org.apache.drill.common.util.DecimalUtility.getPowerOfTen( + org.apache.drill.common.util.DecimalUtility.MAX_DIGITS - (result.scale % org.apache.drill.common.util.DecimalUtility.MAX_DIGITS))); + + // compute the power of 10 necessary to find if we need to round up + int roundFactor = (int) (org.apache.drill.common.util.DecimalUtility.getPowerOfTen( + org.apache.drill.common.util.DecimalUtility.MAX_DIGITS - ((result.scale + 1) % org.apache.drill.common.util.DecimalUtility.MAX_DIGITS))); + + // index of rounding digit + int roundIndex = currentIndex + resultIntegerSize + org.apache.drill.common.util.DecimalUtility.roundUp(result.scale + 1) - 1; + + // Check the first chopped digit to see if we need to round up + int carry = ((tempResult[roundIndex] / roundFactor) % 10) > 4 ? 
1 : 0; + + // Adjust the carry so that it gets added to the correct digit + carry *= scaleFactor; + + // Chop the unwanted fractional part + tempResult[lastScaleIndex] /= scaleFactor; + tempResult[lastScaleIndex] *= scaleFactor; + + // propogate the carry + while (carry > 0 && lastScaleIndex >= 0) { + int tempSum = tempResult[lastScaleIndex] + carry; + if (tempSum >= org.apache.drill.common.util.DecimalUtility.DIGITS_BASE) { + tempResult[lastScaleIndex] = (tempSum % org.apache.drill.common.util.DecimalUtility.DIGITS_BASE); + carry = (int) (tempSum / org.apache.drill.common.util.DecimalUtility.DIGITS_BASE); + } else { + tempResult[lastScaleIndex] = tempSum; + carry = 0; + } + lastScaleIndex--; + } + + // check if carry has increased integer digit + if ((lastScaleIndex + 1) < currentIndex) { + resultIntegerSize++; + currentIndex = lastScaleIndex + 1; + } + } + + if (resultIntegerSize > result.nDecimalDigits) { + throw new org.apache.drill.common.exceptions.DrillRuntimeException("Cannot fit multiplication result in the given decimal type"); + } + int outputIndex = result.nDecimalDigits - 1; for (int i = (currentIndex + resultIntegerSize + resultScaleSize - 1); i >= currentIndex; i--) { @@ -485,7 +537,6 @@ public class ${type.name}Functions { while(outputIndex >= 0) { result.setInteger(outputIndex--, 0); } - result.setSign(left.getSign() != right.getSign()); } } http://git-wip-us.apache.org/repos/asf/incubator-drill/blob/20605067/exec/jdbc/src/test/java/org/apache/drill/jdbc/test/TestFunctionsQuery.java ---------------------------------------------------------------------- diff --git a/exec/jdbc/src/test/java/org/apache/drill/jdbc/test/TestFunctionsQuery.java b/exec/jdbc/src/test/java/org/apache/drill/jdbc/test/TestFunctionsQuery.java index 64bdf6d..8660579 100644 --- a/exec/jdbc/src/test/java/org/apache/drill/jdbc/test/TestFunctionsQuery.java +++ b/exec/jdbc/src/test/java/org/apache/drill/jdbc/test/TestFunctionsQuery.java @@ -562,7 +562,7 @@ public class 
TestFunctionsQuery { @Test public void testDecimal18Decimal38Comparison() throws Exception { - String query = "select cast('999999999.999999999' as decimal(18, 9)) = cast('999999999.999999999' as decimal(38, 18)) as CMP " + + String query = "select cast('-999999999.999999999' as decimal(18, 9)) = cast('-999999999.999999999' as decimal(38, 18)) as CMP " + "from cp.`employee.json` where employee_id = 1"; JdbcAssert.withNoDefaultSchema() @@ -570,4 +570,19 @@ public class TestFunctionsQuery { .returns( "CMP=true\n"); } + + @Test + public void testDecimalMultiplicationOverflowHandling() throws Exception { + String query = "select cast('1' as decimal(9, 5)) * cast ('999999999999999999999999999.999999999' as decimal(38, 9)) as DEC38_1, " + + "cast('1000000000000000001.000000000000000000' as decimal(38, 18)) * cast('0.999999999999999999' as decimal(38, 18)) as DEC38_2, " + + "cast('3' as decimal(9, 8)) * cast ('333333333.3333333333333333333' as decimal(38, 19)) as DEC38_3 " + + "from cp.`employee.json` where employee_id = 1"; + + JdbcAssert.withNoDefaultSchema() + .sql(query) + .returns( + "DEC38_1=1000000000000000000000000000.00000; " + + "DEC38_2=1000000000000000000; " + + "DEC38_3=1000000000.000000000000000000\n"); + } }
