This is an automated email from the ASF dual-hosted git repository. rameshkumar pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/hive.git
The following commit(s) were added to refs/heads/master by this push: new 10805bc997d HIVE-26683: Sum windowing function returns wrong value when all nulls. (#3800) 10805bc997d is described below commit 10805bc997d7cd136b85fca9200cf165ffe2eae5 Author: scarlin-cloudera <55709772+scarlin-cloud...@users.noreply.github.com> AuthorDate: Mon Dec 5 08:58:15 2022 -0800 HIVE-26683: Sum windowing function returns wrong value when all nulls. (#3800) * HIVE-26683: Sum windowing function returns wrong value when all nulls. The sum windowing function is returning an incorrect value when all the "following" rows are null. The correct value for sum when all the rows are null is "null". A new member variable had to be added to track nulls. It uses the same algorithm that is used for sums. The sums are tracked by keeping a running sum across all the rows and subtracting off the running sum outside the window. Likewise, we keep track of a running non-null row count for the current row and subtract the non-null row count of the row that is leaving the window. 
* empty --- .../hadoop/hive/ql/udf/generic/GenericUDAFSum.java | 106 +++++++++++------- .../clientpositive/windowing_sum_following_null.q | 30 +++++ .../llap/windowing_sum_following_null.q.out | 124 +++++++++++++++++++++ 3 files changed, 220 insertions(+), 40 deletions(-) diff --git a/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFSum.java b/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFSum.java index 6ce8734e8f0..40c7a7d7b5e 100644 --- a/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFSum.java +++ b/ql/src/java/org/apache/hadoop/hive/ql/udf/generic/GenericUDAFSum.java @@ -139,9 +139,17 @@ public class GenericUDAFSum extends AbstractGenericUDAFResolver { */ public static abstract class GenericUDAFSumEvaluator<ResultType extends Writable> extends GenericUDAFEvaluator { static abstract class SumAgg<T> extends AbstractAggregationBuffer { - boolean empty; T sum; HashSet<ObjectInspectorObject> uniqueObjects; // Unique rows. + // HIVE-26683: Tracks the number of non null rows. If all the rows are null, then the sum of + // them is null. The count is needed for tracking in windowing frames. Windowing frames + // keep a running count of the sum and subtract off entries as the window moves. In order + // to process nulls within this same framework, we track the number of non null rows and + // also subtract off the number of entries as the window moves. If the current running count + // of non null rows is <n> and the number of non null rows in the entry leaving the window + // is also <n> then we know all the entries within the window are null and can return null + // for the sum. 
+ long nonNullCount; } protected PrimitiveObjectInspector inputOI; @@ -267,9 +275,9 @@ public class GenericUDAFSum extends AbstractGenericUDAFResolver { @Override public void reset(AggregationBuffer agg) throws HiveException { SumAgg<HiveDecimalWritable> bdAgg = (SumAgg<HiveDecimalWritable>) agg; - bdAgg.empty = true; bdAgg.sum = new HiveDecimalWritable(0); bdAgg.uniqueObjects = null; + bdAgg.nonNullCount = 0; } boolean warned = false; @@ -279,7 +287,7 @@ public class GenericUDAFSum extends AbstractGenericUDAFResolver { assert (parameters.length == 1); try { if (isEligibleValue((SumHiveDecimalWritableAgg) agg, parameters[0])) { - ((SumHiveDecimalWritableAgg)agg).empty = false; + ((SumHiveDecimalWritableAgg)agg).nonNullCount++; ((SumHiveDecimalWritableAgg)agg).sum.mutateAdd( PrimitiveObjectInspectorUtils.getHiveDecimal(parameters[0], inputOI)); } @@ -303,12 +311,12 @@ public class GenericUDAFSum extends AbstractGenericUDAFResolver { return; } - myagg.empty = false; if (isWindowingDistinct()) { throw new HiveException("Distinct windowing UDAF doesn't support merge and terminatePartial"); } else { // If partial is NULL, then there was an overflow and myagg.sum will be marked as not set. 
myagg.sum.mutateAdd(PrimitiveObjectInspectorUtils.getHiveDecimal(partial, inputOI)); + myagg.nonNullCount++; } } } @@ -316,7 +324,7 @@ public class GenericUDAFSum extends AbstractGenericUDAFResolver { @Override public Object terminate(AggregationBuffer agg) throws HiveException { SumHiveDecimalWritableAgg myagg = (SumHiveDecimalWritableAgg) agg; - if (myagg.empty || myagg.sum == null || !myagg.sum.isSet()) { + if (myagg.nonNullCount == 0 || myagg.sum == null || !myagg.sum.isSet()) { return null; } DecimalTypeInfo decimalTypeInfo = (DecimalTypeInfo)outputOI.getTypeInfo(); @@ -337,29 +345,35 @@ public class GenericUDAFSum extends AbstractGenericUDAFResolver { return null; } - return new GenericUDAFStreamingEvaluator.SumAvgEnhancer<HiveDecimalWritable, HiveDecimal>( + return new GenericUDAFStreamingEvaluator.SumAvgEnhancer<HiveDecimalWritable, Object[]>( this, wFrameDef) { @Override protected HiveDecimalWritable getNextResult( - org.apache.hadoop.hive.ql.udf.generic.GenericUDAFStreamingEvaluator.SumAvgEnhancer<HiveDecimalWritable, HiveDecimal>.SumAvgStreamingState ss) + org.apache.hadoop.hive.ql.udf.generic.GenericUDAFStreamingEvaluator.SumAvgEnhancer<HiveDecimalWritable, Object[]>.SumAvgStreamingState ss) throws HiveException { SumHiveDecimalWritableAgg myagg = (SumHiveDecimalWritableAgg) ss.wrappedBuf; - HiveDecimal r = myagg.empty ? null : myagg.sum.getHiveDecimal(); - HiveDecimal d = ss.retrieveNextIntermediateValue(); - if (d != null ) { + long nonNullCount = myagg.nonNullCount; + HiveDecimal r = nonNullCount == 0 ? null : myagg.sum.getHiveDecimal(); + Object[] o = ss.retrieveNextIntermediateValue(); + if (o != null) { + HiveDecimal d = (HiveDecimal) o[0]; r = r == null ? null : r.subtract(d); + // nonNullCount keeps track of the running count of non null rows. If the number of + // non null rows dropping out of the window frame is the same as the current number + // of non null rows, then the sum should be returned as null. 
+ nonNullCount = nonNullCount - ((Long) o[1]); } - return r == null ? null : new HiveDecimalWritable(r); + return nonNullCount == 0 ? null : new HiveDecimalWritable(r); } @Override - protected HiveDecimal getCurrentIntermediateResult( - org.apache.hadoop.hive.ql.udf.generic.GenericUDAFStreamingEvaluator.SumAvgEnhancer<HiveDecimalWritable, HiveDecimal>.SumAvgStreamingState ss) + protected Object[] getCurrentIntermediateResult( + org.apache.hadoop.hive.ql.udf.generic.GenericUDAFStreamingEvaluator.SumAvgEnhancer<HiveDecimalWritable, Object[]>.SumAvgStreamingState ss) throws HiveException { SumHiveDecimalWritableAgg myagg = (SumHiveDecimalWritableAgg) ss.wrappedBuf; - return myagg.empty ? null : myagg.sum.getHiveDecimal(); + return myagg.nonNullCount == 0 ? null : new Object[] { myagg.sum.getHiveDecimal(), myagg.nonNullCount}; } }; @@ -413,9 +427,9 @@ public class GenericUDAFSum extends AbstractGenericUDAFResolver { @Override public void reset(AggregationBuffer agg) throws HiveException { SumDoubleAgg myagg = (SumDoubleAgg) agg; - myagg.empty = true; myagg.sum = 0.0; myagg.uniqueObjects = null; + myagg.nonNullCount = 0; } boolean warned = false; @@ -425,7 +439,7 @@ public class GenericUDAFSum extends AbstractGenericUDAFResolver { assert (parameters.length == 1); try { if (isEligibleValue((SumDoubleAgg) agg, parameters[0])) { - ((SumDoubleAgg)agg).empty = false; + ((SumDoubleAgg)agg).nonNullCount++; ((SumDoubleAgg)agg).sum += PrimitiveObjectInspectorUtils.getDouble(parameters[0], inputOI); } } catch (NumberFormatException e) { @@ -444,11 +458,11 @@ public class GenericUDAFSum extends AbstractGenericUDAFResolver { public void merge(AggregationBuffer agg, Object partial) throws HiveException { if (partial != null) { SumDoubleAgg myagg = (SumDoubleAgg) agg; - myagg.empty = false; if (isWindowingDistinct()) { throw new HiveException("Distinct windowing UDAF doesn't support merge and terminatePartial"); } else { myagg.sum += PrimitiveObjectInspectorUtils.getDouble(partial, 
inputOI); + myagg.nonNullCount++; } } } @@ -456,7 +470,7 @@ public class GenericUDAFSum extends AbstractGenericUDAFResolver { @Override public Object terminate(AggregationBuffer agg) throws HiveException { SumDoubleAgg myagg = (SumDoubleAgg) agg; - if (myagg.empty) { + if (myagg.nonNullCount == 0) { return null; } result.set(myagg.sum); @@ -470,29 +484,35 @@ public class GenericUDAFSum extends AbstractGenericUDAFResolver { return null; } - return new GenericUDAFStreamingEvaluator.SumAvgEnhancer<DoubleWritable, Double>(this, + return new GenericUDAFStreamingEvaluator.SumAvgEnhancer<DoubleWritable, Object[]>(this, wFrameDef) { @Override protected DoubleWritable getNextResult( - org.apache.hadoop.hive.ql.udf.generic.GenericUDAFStreamingEvaluator.SumAvgEnhancer<DoubleWritable, Double>.SumAvgStreamingState ss) + org.apache.hadoop.hive.ql.udf.generic.GenericUDAFStreamingEvaluator.SumAvgEnhancer<DoubleWritable, Object[]>.SumAvgStreamingState ss) throws HiveException { SumDoubleAgg myagg = (SumDoubleAgg) ss.wrappedBuf; - Double r = myagg.empty ? null : myagg.sum; - Double d = ss.retrieveNextIntermediateValue(); - if (d != null) { + long nonNullCount = myagg.nonNullCount; + Double r = nonNullCount == 0 ? null : myagg.sum; + Object[] o = ss.retrieveNextIntermediateValue(); + if (o != null) { + Double d = (Double) o[0]; r = r == null ? null : r - d; + // nonNullCount keeps track of the running count of non null rows. If the number of + // non null rows dropping out of the window frame is the same as the current number + // of non null rows, then the sum should be returned as null. + nonNullCount = nonNullCount - ((Long) o[1]); } - return r == null ? null : new DoubleWritable(r); + return nonNullCount == 0 ? 
null : new DoubleWritable(r); } @Override - protected Double getCurrentIntermediateResult( - org.apache.hadoop.hive.ql.udf.generic.GenericUDAFStreamingEvaluator.SumAvgEnhancer<DoubleWritable, Double>.SumAvgStreamingState ss) + protected Object[] getCurrentIntermediateResult( + org.apache.hadoop.hive.ql.udf.generic.GenericUDAFStreamingEvaluator.SumAvgEnhancer<DoubleWritable, Object[]>.SumAvgStreamingState ss) throws HiveException { SumDoubleAgg myagg = (SumDoubleAgg) ss.wrappedBuf; - return myagg.empty ? null : myagg.sum; + return myagg.nonNullCount == 0 ? null : new Object[] { myagg.sum, myagg.nonNullCount}; } }; @@ -545,9 +565,9 @@ public class GenericUDAFSum extends AbstractGenericUDAFResolver { @Override public void reset(AggregationBuffer agg) throws HiveException { SumLongAgg myagg = (SumLongAgg) agg; - myagg.empty = true; myagg.sum = 0L; myagg.uniqueObjects = null; + myagg.nonNullCount = 0; } private boolean warned = false; @@ -557,7 +577,7 @@ public class GenericUDAFSum extends AbstractGenericUDAFResolver { assert (parameters.length == 1); try { if (isEligibleValue((SumLongAgg) agg, parameters[0])) { - ((SumLongAgg)agg).empty = false; + ((SumLongAgg)agg).nonNullCount++; ((SumLongAgg)agg).sum += PrimitiveObjectInspectorUtils.getLong(parameters[0], inputOI); } } catch (NumberFormatException e) { @@ -573,11 +593,11 @@ public class GenericUDAFSum extends AbstractGenericUDAFResolver { public void merge(AggregationBuffer agg, Object partial) throws HiveException { if (partial != null) { SumLongAgg myagg = (SumLongAgg) agg; - myagg.empty = false; if (isWindowingDistinct()) { throw new HiveException("Distinct windowing UDAF doesn't support merge and terminatePartial"); } else { myagg.sum += PrimitiveObjectInspectorUtils.getLong(partial, inputOI); + myagg.nonNullCount++; } } } @@ -585,7 +605,7 @@ public class GenericUDAFSum extends AbstractGenericUDAFResolver { @Override public Object terminate(AggregationBuffer agg) throws HiveException { SumLongAgg myagg = 
(SumLongAgg) agg; - if (myagg.empty) { + if (myagg.nonNullCount == 0) { return null; } result.set(myagg.sum); @@ -599,29 +619,35 @@ public class GenericUDAFSum extends AbstractGenericUDAFResolver { return null; } - return new GenericUDAFStreamingEvaluator.SumAvgEnhancer<LongWritable, Long>(this, + return new GenericUDAFStreamingEvaluator.SumAvgEnhancer<LongWritable, Object[]>(this, wFrameDef) { @Override protected LongWritable getNextResult( - org.apache.hadoop.hive.ql.udf.generic.GenericUDAFStreamingEvaluator.SumAvgEnhancer<LongWritable, Long>.SumAvgStreamingState ss) + org.apache.hadoop.hive.ql.udf.generic.GenericUDAFStreamingEvaluator.SumAvgEnhancer<LongWritable, Object[]>.SumAvgStreamingState ss) throws HiveException { SumLongAgg myagg = (SumLongAgg) ss.wrappedBuf; - Long r = myagg.empty ? null : myagg.sum; - Long d = ss.retrieveNextIntermediateValue(); - if (d != null) { + long nonNullCount = myagg.nonNullCount; + Long r = nonNullCount == 0 ? null : myagg.sum; + Object[] o = ss.retrieveNextIntermediateValue(); + if (o != null) { + Long d = (Long) o[0]; r = r == null ? null : r - d; + // nonNullCount keeps track of the running count of non null rows. If the number of + // non null rows dropping out of the window frame is the same as the current number + // of non null rows, then the sum should be returned as null. + nonNullCount = nonNullCount - ((Long) o[1]); } - return r == null ? null : new LongWritable(r); + return nonNullCount == 0 ? null : new LongWritable(r); } @Override - protected Long getCurrentIntermediateResult( - org.apache.hadoop.hive.ql.udf.generic.GenericUDAFStreamingEvaluator.SumAvgEnhancer<LongWritable, Long>.SumAvgStreamingState ss) + protected Object[] getCurrentIntermediateResult( + org.apache.hadoop.hive.ql.udf.generic.GenericUDAFStreamingEvaluator.SumAvgEnhancer<LongWritable, Object[]>.SumAvgStreamingState ss) throws HiveException { SumLongAgg myagg = (SumLongAgg) ss.wrappedBuf; - return myagg.empty ? 
null : myagg.sum; + return myagg.nonNullCount == 0 ? null : new Object[] { myagg.sum, myagg.nonNullCount}; } }; } diff --git a/ql/src/test/queries/clientpositive/windowing_sum_following_null.q b/ql/src/test/queries/clientpositive/windowing_sum_following_null.q new file mode 100644 index 00000000000..501f969cfcc --- /dev/null +++ b/ql/src/test/queries/clientpositive/windowing_sum_following_null.q @@ -0,0 +1,30 @@ + +create table sum_window_test_small (id int, tinyint_col tinyint, double_col double, decimal_col decimal(12,2)); +insert into sum_window_test_small values (3, 17, 17.1, 17.1), (4, 14, 14.1, 14.1), (6, 18, 18.1, 18.1), + (7, 19, 19.1, 19.1), (8,NULL, NULL, NULL), (10, NULL, NULL, NULL), (11, 22, 22.0, 22.1); +select id, +tinyint_col, +sum(tinyint_col) over (order by id nulls last rows between 1 following and 2 following), +sum(double_col) over (order by id nulls last rows between 1 following and 2 following), +sum(decimal_col) over (order by id nulls last rows between 1 following and 2 following) +from sum_window_test_small order by id; + +-- check if it works with a null at the end +insert into sum_window_test_small values (12,NULL, NULL, NULL); + +select id, +tinyint_col, +sum(tinyint_col) over (order by id nulls last rows between 1 following and 2 following), +sum(double_col) over (order by id nulls last rows between 1 following and 2 following), +sum(decimal_col) over (order by id nulls last rows between 1 following and 2 following) +from sum_window_test_small order by id; + +-- check if it works with two nulls at the end +insert into sum_window_test_small values (13,NULL, NULL, NULL); + +select id, +tinyint_col, +sum(tinyint_col) over (order by id nulls last rows between 1 following and 2 following), +sum(double_col) over (order by id nulls last rows between 1 following and 2 following), +sum(decimal_col) over (order by id nulls last rows between 1 following and 2 following) +from sum_window_test_small order by id; diff --git 
a/ql/src/test/results/clientpositive/llap/windowing_sum_following_null.q.out b/ql/src/test/results/clientpositive/llap/windowing_sum_following_null.q.out new file mode 100644 index 00000000000..68534468f8f --- /dev/null +++ b/ql/src/test/results/clientpositive/llap/windowing_sum_following_null.q.out @@ -0,0 +1,124 @@ +PREHOOK: query: create table sum_window_test_small (id int, tinyint_col tinyint, double_col double, decimal_col decimal(12,2)) +PREHOOK: type: CREATETABLE +PREHOOK: Output: database:default +PREHOOK: Output: default@sum_window_test_small +POSTHOOK: query: create table sum_window_test_small (id int, tinyint_col tinyint, double_col double, decimal_col decimal(12,2)) +POSTHOOK: type: CREATETABLE +POSTHOOK: Output: database:default +POSTHOOK: Output: default@sum_window_test_small +PREHOOK: query: insert into sum_window_test_small values (3, 17, 17.1, 17.1), (4, 14, 14.1, 14.1), (6, 18, 18.1, 18.1), + (7, 19, 19.1, 19.1), (8,NULL, NULL, NULL), (10, NULL, NULL, NULL), (11, 22, 22.0, 22.1) +PREHOOK: type: QUERY +PREHOOK: Input: _dummy_database@_dummy_table +PREHOOK: Output: default@sum_window_test_small +POSTHOOK: query: insert into sum_window_test_small values (3, 17, 17.1, 17.1), (4, 14, 14.1, 14.1), (6, 18, 18.1, 18.1), + (7, 19, 19.1, 19.1), (8,NULL, NULL, NULL), (10, NULL, NULL, NULL), (11, 22, 22.0, 22.1) +POSTHOOK: type: QUERY +POSTHOOK: Input: _dummy_database@_dummy_table +POSTHOOK: Output: default@sum_window_test_small +POSTHOOK: Lineage: sum_window_test_small.decimal_col SCRIPT [] +POSTHOOK: Lineage: sum_window_test_small.double_col SCRIPT [] +POSTHOOK: Lineage: sum_window_test_small.id SCRIPT [] +POSTHOOK: Lineage: sum_window_test_small.tinyint_col SCRIPT [] +PREHOOK: query: select id, +tinyint_col, +sum(tinyint_col) over (order by id nulls last rows between 1 following and 2 following), +sum(double_col) over (order by id nulls last rows between 1 following and 2 following), +sum(decimal_col) over (order by id nulls last rows between 1 following 
and 2 following) +from sum_window_test_small order by id +PREHOOK: type: QUERY +PREHOOK: Input: default@sum_window_test_small +#### A masked pattern was here #### +POSTHOOK: query: select id, +tinyint_col, +sum(tinyint_col) over (order by id nulls last rows between 1 following and 2 following), +sum(double_col) over (order by id nulls last rows between 1 following and 2 following), +sum(decimal_col) over (order by id nulls last rows between 1 following and 2 following) +from sum_window_test_small order by id +POSTHOOK: type: QUERY +POSTHOOK: Input: default@sum_window_test_small +#### A masked pattern was here #### +3 17 32 32.2 32.20 +4 14 37 37.2 37.20 +6 18 19 19.1 19.10 +7 19 NULL NULL NULL +8 NULL 22 22.0 22.10 +10 NULL 22 22.0 22.10 +11 22 NULL NULL NULL +PREHOOK: query: insert into sum_window_test_small values (12,NULL, NULL, NULL) +PREHOOK: type: QUERY +PREHOOK: Input: _dummy_database@_dummy_table +PREHOOK: Output: default@sum_window_test_small +POSTHOOK: query: insert into sum_window_test_small values (12,NULL, NULL, NULL) +POSTHOOK: type: QUERY +POSTHOOK: Input: _dummy_database@_dummy_table +POSTHOOK: Output: default@sum_window_test_small +POSTHOOK: Lineage: sum_window_test_small.decimal_col EXPRESSION [] +POSTHOOK: Lineage: sum_window_test_small.double_col EXPRESSION [] +POSTHOOK: Lineage: sum_window_test_small.id SCRIPT [] +POSTHOOK: Lineage: sum_window_test_small.tinyint_col EXPRESSION [] +PREHOOK: query: select id, +tinyint_col, +sum(tinyint_col) over (order by id nulls last rows between 1 following and 2 following), +sum(double_col) over (order by id nulls last rows between 1 following and 2 following), +sum(decimal_col) over (order by id nulls last rows between 1 following and 2 following) +from sum_window_test_small order by id +PREHOOK: type: QUERY +PREHOOK: Input: default@sum_window_test_small +#### A masked pattern was here #### +POSTHOOK: query: select id, +tinyint_col, +sum(tinyint_col) over (order by id nulls last rows between 1 following and 
2 following), +sum(double_col) over (order by id nulls last rows between 1 following and 2 following), +sum(decimal_col) over (order by id nulls last rows between 1 following and 2 following) +from sum_window_test_small order by id +POSTHOOK: type: QUERY +POSTHOOK: Input: default@sum_window_test_small +#### A masked pattern was here #### +3 17 32 32.2 32.20 +4 14 37 37.2 37.20 +6 18 19 19.1 19.10 +7 19 NULL NULL NULL +8 NULL 22 22.0 22.10 +10 NULL 22 22.0 22.10 +11 22 NULL NULL NULL +12 NULL NULL NULL NULL +PREHOOK: query: insert into sum_window_test_small values (13,NULL, NULL, NULL) +PREHOOK: type: QUERY +PREHOOK: Input: _dummy_database@_dummy_table +PREHOOK: Output: default@sum_window_test_small +POSTHOOK: query: insert into sum_window_test_small values (13,NULL, NULL, NULL) +POSTHOOK: type: QUERY +POSTHOOK: Input: _dummy_database@_dummy_table +POSTHOOK: Output: default@sum_window_test_small +POSTHOOK: Lineage: sum_window_test_small.decimal_col EXPRESSION [] +POSTHOOK: Lineage: sum_window_test_small.double_col EXPRESSION [] +POSTHOOK: Lineage: sum_window_test_small.id SCRIPT [] +POSTHOOK: Lineage: sum_window_test_small.tinyint_col EXPRESSION [] +PREHOOK: query: select id, +tinyint_col, +sum(tinyint_col) over (order by id nulls last rows between 1 following and 2 following), +sum(double_col) over (order by id nulls last rows between 1 following and 2 following), +sum(decimal_col) over (order by id nulls last rows between 1 following and 2 following) +from sum_window_test_small order by id +PREHOOK: type: QUERY +PREHOOK: Input: default@sum_window_test_small +#### A masked pattern was here #### +POSTHOOK: query: select id, +tinyint_col, +sum(tinyint_col) over (order by id nulls last rows between 1 following and 2 following), +sum(double_col) over (order by id nulls last rows between 1 following and 2 following), +sum(decimal_col) over (order by id nulls last rows between 1 following and 2 following) +from sum_window_test_small order by id +POSTHOOK: type: QUERY 
+POSTHOOK: Input: default@sum_window_test_small +#### A masked pattern was here #### +3 17 32 32.2 32.20 +4 14 37 37.2 37.20 +6 18 19 19.1 19.10 +7 19 NULL NULL NULL +8 NULL 22 22.0 22.10 +10 NULL 22 22.0 22.10 +11 22 NULL NULL NULL +12 NULL NULL NULL NULL +13 NULL NULL NULL NULL