This is an automated email from the ASF dual-hosted git repository.

yangjie01 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new 5021638ee14  [SPARK-43717][CONNECT] Scala client reduce agg cannot handle null partitions for scala primitive inputs
5021638ee14 is described below

commit 5021638ee14758b92309942a1bcaed2b6554f810
Author: Zhen Li <zhenli...@users.noreply.github.com>
AuthorDate: Wed Jun 7 14:30:42 2023 +0800

    [SPARK-43717][CONNECT] Scala client reduce agg cannot handle null partitions for scala primitive inputs
    
    ### What changes were proposed in this pull request?
    The Scala client fails with an NPE when running the following reduce aggregation:
    ```
    spark.range(0, 5, 1, 10).as[Long].reduce(_ + _) == 10
    ```
    The reason is that `range` will produce null partitions and the Reduce encoder is not able to set the default value correctly for partitions that contain Scala primitives. In the example, we expect 0 but receive null. This causes the codegen to wrongly assume the input is nullable and to generate wrong code.
    
    ### Why are the changes needed?
    Bug fix.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    Unit and Scala Client E2E tests.
    
    Closes #41264 from zhenlineo/fix-agg-null.
    
    Authored-by: Zhen Li <zhenli...@users.noreply.github.com>
    Signed-off-by: yangjie01 <yangji...@baidu.com>
---
 .../spark/sql/UserDefinedFunctionE2ETestSuite.scala | 20 ++++++++++++++++----
 .../spark/sql/expressions/ReduceAggregator.scala    | 13 ++++++++++++-
 .../sql/expressions/ReduceAggregatorSuite.scala     | 10 ++++++++--
 3 files changed, 36 insertions(+), 7 deletions(-)

diff --git a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala
index b5bbee67803..ca1bcf3fe67 100644
--- a/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala
+++ b/connector/connect/client/jvm/src/test/scala/org/apache/spark/sql/UserDefinedFunctionE2ETestSuite.scala
@@ -198,18 +198,30 @@ class UserDefinedFunctionE2ETestSuite extends RemoteSparkSession {
     assert(sum.get() == 0) // The value is not 45
   }
 
-  test("Dataset reduce") {
+  test("Dataset reduce without null partition inputs") {
     val session: SparkSession = spark
     import session.implicits._
-    assert(spark.range(10).map(_ + 1).reduce(_ + _) == 55)
+    assert(spark.range(0, 10, 1, 5).map(_ + 1).reduce(_ + _) == 55)
   }
 
-  test("Dataset reduce - java") {
+  test("Dataset reduce with null partition inputs") {
+    val session: SparkSession = spark
+    import session.implicits._
+    assert(spark.range(0, 10, 1, 16).map(_ + 1).reduce(_ + _) == 55)
+  }
+
+  test("Dataset reduce with null partition inputs - java to scala long type") {
+    val session: SparkSession = spark
+    import session.implicits._
+    assert(spark.range(0, 5, 1, 10).as[Long].reduce(_ + _) == 10)
+  }
+
+  test("Dataset reduce with null partition inputs - java") {
     val session: SparkSession = spark
     import session.implicits._
     assert(
       spark
-        .range(10)
+        .range(0, 10, 1, 16)
         .map(_ + 1)
         .reduce(new ReduceFunction[Long] {
           override def call(v1: Long, v2: Long): Long = v1 + v2
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala
index 41306cd0a99..e897fdfe008 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/ReduceAggregator.scala
@@ -32,7 +32,18 @@ private[sql] class ReduceAggregator[T: Encoder](func: (T, T) => T)
 
   @transient private val encoder = implicitly[Encoder[T]]
 
-  override def zero: (Boolean, T) = (false, null.asInstanceOf[T])
+  private val _zero = encoder.clsTag.runtimeClass match {
+    case java.lang.Boolean.TYPE => false
+    case java.lang.Byte.TYPE => 0.toByte
+    case java.lang.Short.TYPE => 0.toShort
+    case java.lang.Integer.TYPE => 0
+    case java.lang.Long.TYPE => 0L
+    case java.lang.Float.TYPE => 0f
+    case java.lang.Double.TYPE => 0d
+    case _ => null
+  }
+
+  override def zero: (Boolean, T) = (false, _zero.asInstanceOf[T])
 
   override def bufferEncoder: Encoder[(Boolean, T)] =
     ExpressionEncoder.tuple(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ReduceAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ReduceAggregatorSuite.scala
index f65dcdf119c..c1071373287 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/expressions/ReduceAggregatorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/expressions/ReduceAggregatorSuite.scala
@@ -24,10 +24,16 @@ import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 
 class ReduceAggregatorSuite extends SparkFunSuite {
 
   test("zero value") {
-    val encoder: ExpressionEncoder[Int] = ExpressionEncoder()
     val func = (v1: Int, v2: Int) => v1 + v2
     val aggregator: ReduceAggregator[Int] = new ReduceAggregator(func)(Encoders.scalaInt)
-    assert(aggregator.zero == (false, null).asInstanceOf[(Boolean, Int)])
+    assert(aggregator.zero == (false, 0))
+  }
+
+  test("zero value boxed null") {
+    val func = (v1: java.lang.Integer, v2: java.lang.Integer) =>
+      (v1 + v2).asInstanceOf[java.lang.Integer]
+    val aggregator: ReduceAggregator[java.lang.Integer] = new ReduceAggregator(func)(Encoders.INT)
+    assert(aggregator.zero == (false, null).asInstanceOf[(Boolean, java.lang.Integer)])
   }
 
   test("reduce, merge and finish") {
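For readers outside the Spark codebase, the following is a minimal standalone sketch (plain Scala, no Spark dependency; `ReduceBufferSketch`, `zeroValue`, `reduceBuf` and `mergeBuf` are illustrative names, not Spark API) of the (initialized, value) buffer behaviour the commit message describes: the Boolean flag tells the merge step to discard buffers coming from empty partitions, so replacing the null zero with a type-appropriate primitive zero cannot change the reduced result, it only keeps the buffer decodable for non-nullable primitive encoders.

```
// Simplified model of the (initialized, value) reduce buffer, not the actual
// Spark implementation. All names below are illustrative.
object ReduceBufferSketch extends App {
  type Buf = (Boolean, Long)

  // With a primitive Long the fix uses 0L instead of null; the value is never
  // observed in a result because the flag stays false until the first element.
  val zeroValue: Buf = (false, 0L)

  def reduceBuf(b: Buf, v: Long): Buf =
    if (b._1) (true, b._2 + v) else (true, v)

  def mergeBuf(a: Buf, b: Buf): Buf =
    if (!a._1) b            // left partition contributed nothing
    else if (!b._1) a       // right partition contributed nothing
    else (true, a._2 + b._2)

  // Five elements spread over "ten partitions": the empty partitions contribute
  // only zeroValue, which mergeBuf discards via the flag.
  val nonEmpty = Seq(Seq(0L, 1L), Seq(2L, 3L), Seq(4L)).map(_.foldLeft(zeroValue)(reduceBuf))
  val buffers = nonEmpty ++ Seq.fill(7)(zeroValue)
  println(buffers.reduce(mergeBuf)._2) // 10, matching reduce(_ + _) in the commit message
}
```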