This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/branch-3.0 by this push:
     new 30e49c2  [SPARK-32159][SQL] Fix integration between Aggregator[Array[_], _, _] and UnresolvedMapObjects
30e49c2 is described below

commit 30e49c23249213a70266583566db93e113a9e23e
Author: Erik Erlandson <eerla...@redhat.com>
AuthorDate: Thu Jul 9 08:42:20 2020 +0000

    [SPARK-32159][SQL] Fix integration between Aggregator[Array[_], _, _] and UnresolvedMapObjects
    
    Context: The fix for SPARK-27296 introduced by #25024 allows `Aggregator` objects to appear in queries. This works fine for aggregators with atomic input types, e.g. `Aggregator[Double, _, _]`.
    
    However, it can cause a null pointer exception if the input type is `Array[_]`. This was historically considered an ignorable case for serialization of `UnresolvedMapObjects`, but the new `ScalaAggregator` class causes these expressions to be serialized over to executors because the resolve-and-bind is being deferred.
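    
    As an illustration, here is a minimal sketch of the kind of array-input aggregator that hits the problem when registered through `functions.udaf`. It loosely mirrors the `ArrayDataAgg` test object added below; `ArraySumAgg`, `sumArrays` and `df` are illustrative names, not part of this patch:
    
    ```scala
    import org.apache.spark.sql.{DataFrame, Encoder}
    import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
    import org.apache.spark.sql.expressions.Aggregator
    import org.apache.spark.sql.functions.{col, udaf}
    
    // Element-wise sum of fixed-length Array[Double] values.
    object ArraySumAgg extends Aggregator[Array[Double], Array[Double], Array[Double]] {
      def zero: Array[Double] = Array(0.0, 0.0, 0.0)
      def reduce(b: Array[Double], a: Array[Double]): Array[Double] = merge(b, a)
      def merge(b1: Array[Double], b2: Array[Double]): Array[Double] = {
        for (i <- b1.indices) b1(i) += b2(i)
        b1
      }
      def finish(b: Array[Double]): Array[Double] = b
      def bufferEncoder: Encoder[Array[Double]] = ExpressionEncoder[Array[Double]]()
      def outputEncoder: Encoder[Array[Double]] = ExpressionEncoder[Array[Double]]()
    }
    
    // `df` is assumed to have an Array[Double] column named "data".
    // Before this fix the query below failed with a null pointer exception in Catalyst;
    // with ResolveEncodersInScalaAgg the encoders are resolved during analysis.
    def sumArrays(df: DataFrame) = df.select(udaf(ArraySumAgg)(col("data")))
    ```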
    
    ### What changes were proposed in this pull request?
    A new rule, `ResolveEncodersInScalaAgg`, that resolves the expressions contained in the encoders so that properly resolved expressions are serialized over to executors.
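    
    Concretely, "resolving" the encoders means calling `resolveAndBind()` on the `ExpressionEncoder`s on the driver side, before their expressions are serialized. A rough sketch of the idea (not code from this patch):
    
    ```scala
    import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
    
    // For an Array[_] input the encoder's deserializer contains an UnresolvedMapObjects,
    // and UnresolvedMapObjects does not serialize its 'function' field (see objects.scala).
    // resolveAndBind() runs the analyzer over the encoder's expressions, turning it into
    // a concrete MapObjects, so the resolved encoder is safe to ship to executors.
    val enc = ExpressionEncoder[Array[Double]]()
    val resolvedEnc = enc.resolveAndBind()
    val deserializer = resolvedEnc.createDeserializer()
    ```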
    
    ### Why are the changes needed?
    Applying an aggregator of the form `Aggregator[Array[_], _, _]` using `functions.udaf()` currently causes a null pointer exception in Catalyst.
    
    ### Does this PR introduce _any_ user-facing change?
    No.
    
    ### How was this patch tested?
    A unit test has been added that does aggregation with array types for input, buffer, and output. I have done additional testing with my own custom aggregators in the Spark REPL.
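    
    For reference, a sketch of the scenario the new test covers (the `agg3` table contents and the expected sums are taken from `UDAQuerySuite` below; `spark` is an active `SparkSession`):
    
    ```scala
    import org.apache.spark.sql.functions.udaf
    
    // Register the array-input aggregator added in UDAQuerySuite under a SQL name.
    spark.udf.register("arraysum", udaf(ArrayDataAgg))
    
    // agg3 holds the rows [1,2,3], [4,5,6], [7,8,9]; the element-wise sum is [12,15,18].
    // Before this fix the query failed with the null pointer exception described above.
    val result = spark.sql("SELECT arraysum(data) FROM agg3").collect()
    // result is a single Row containing Seq(12.0, 15.0, 18.0)
    ```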
    
    Closes #28983 from erikerlandson/fix-spark-32159.
    
    Authored-by: Erik Erlandson <eerla...@redhat.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
    (cherry picked from commit 1cb5bfc47a2b4fff824433f8cecabfbac7e050b6)
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../sql/catalyst/expressions/objects/objects.scala |  7 ++++
 .../spark/sql/execution/aggregate/udaf.scala       | 27 ++++++++++---
 .../sql/expressions/UserDefinedFunction.scala      |  3 +-
 .../sql/internal/BaseSessionStateBuilder.scala     |  2 +
 .../spark/sql/hive/HiveSessionStateBuilder.scala   |  2 +
 .../spark/sql/hive/execution/UDAQuerySuite.scala   | 47 +++++++++++++++-------
 6 files changed, 68 insertions(+), 20 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index d5de95c..ab2f66b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -678,6 +678,13 @@ object MapObjects {
       elementType: DataType,
       elementNullable: Boolean = true,
       customCollectionCls: Option[Class[_]] = None): MapObjects = {
+    // UnresolvedMapObjects does not serialize its 'function' field.
+    // If an array expression or array Encoder is not correctly resolved before
+    // serialization, this exception condition may occur.
+    require(function != null,
+      "MapObjects applied with a null function. " +
+      "Likely cause is failure to resolve an array expression or encoder. " +
+      "(See UnresolvedMapObjects)")
     val loopVar = LambdaVariable("MapObject", elementType, elementNullable)
     MapObjects(loopVar, function(loopVar), inputData, customCollectionCls)
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
index 544b90a..44bc9c2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala
@@ -27,6 +27,8 @@ import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression
 import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete}
 import org.apache.spark.sql.catalyst.expressions.aggregate.{ImperativeAggregate, TypedImperativeAggregate}
 import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateMutableProjection, GenerateSafeProjection}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.expressions.{Aggregator, MutableAggregationBuffer, UserDefinedAggregateFunction}
 import org.apache.spark.sql.types._
 
@@ -458,7 +460,8 @@ case class ScalaUDAF(
 case class ScalaAggregator[IN, BUF, OUT](
     children: Seq[Expression],
     agg: Aggregator[IN, BUF, OUT],
-    inputEncoderNR: ExpressionEncoder[IN],
+    inputEncoder: ExpressionEncoder[IN],
+    bufferEncoder: ExpressionEncoder[BUF],
     nullable: Boolean = true,
     isDeterministic: Boolean = true,
     mutableAggBufferOffset: Int = 0,
@@ -469,9 +472,8 @@ case class ScalaAggregator[IN, BUF, OUT](
   with ImplicitCastInputTypes
   with Logging {
 
-  private[this] lazy val inputDeserializer = inputEncoderNR.resolveAndBind().createDeserializer()
-  private[this] lazy val bufferEncoder =
-    agg.bufferEncoder.asInstanceOf[ExpressionEncoder[BUF]].resolveAndBind()
+  // input and buffer encoders are resolved by ResolveEncodersInScalaAgg
+  private[this] lazy val inputDeserializer = inputEncoder.createDeserializer()
   private[this] lazy val bufferSerializer = bufferEncoder.createSerializer()
   private[this] lazy val bufferDeserializer = bufferEncoder.createDeserializer()
   private[this] lazy val outputEncoder = agg.outputEncoder.asInstanceOf[ExpressionEncoder[OUT]]
@@ -479,7 +481,7 @@ case class ScalaAggregator[IN, BUF, OUT](
 
   def dataType: DataType = outputEncoder.objSerializer.dataType
 
-  def inputTypes: Seq[DataType] = inputEncoderNR.schema.map(_.dataType)
+  def inputTypes: Seq[DataType] = inputEncoder.schema.map(_.dataType)
 
   override lazy val deterministic: Boolean = isDeterministic
 
@@ -517,3 +519,18 @@ case class ScalaAggregator[IN, BUF, OUT](
 
   override def nodeName: String = agg.getClass.getSimpleName
 }
+
+/**
+ * An extension rule to resolve encoder expressions from a [[ScalaAggregator]]
+ */
+object ResolveEncodersInScalaAgg extends Rule[LogicalPlan] {
+  override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp {
+    case p if !p.resolved => p
+    case p => p.transformExpressionsUp {
+      case agg: ScalaAggregator[_, _, _] =>
+        agg.copy(
+          inputEncoder = agg.inputEncoder.resolveAndBind(),
+          bufferEncoder = agg.bufferEncoder.resolveAndBind())
+    }
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
index 2ef6e3d..6a20a46 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
@@ -150,7 +150,8 @@ private[sql] case class UserDefinedAggregator[IN, BUF, OUT](
   // This is also used by udf.register(...) when it detects a UserDefinedAggregator
   def scalaAggregator(exprs: Seq[Expression]): ScalaAggregator[IN, BUF, OUT] = {
     val iEncoder = inputEncoder.asInstanceOf[ExpressionEncoder[IN]]
-    ScalaAggregator(exprs, aggregator, iEncoder, nullable, deterministic)
+    val bEncoder = aggregator.bufferEncoder.asInstanceOf[ExpressionEncoder[BUF]]
+    ScalaAggregator(exprs, aggregator, iEncoder, bEncoder, nullable, deterministic)
   }
 
   override def withName(name: String): UserDefinedAggregator[IN, BUF, OUT] = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
index 3bbdbb0..4ae12f8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.connector.catalog.CatalogManager
 import org.apache.spark.sql.execution.{ColumnarRule, QueryExecution, SparkOptimizer, SparkPlanner, SparkSqlParser}
+import org.apache.spark.sql.execution.aggregate.ResolveEncodersInScalaAgg
 import org.apache.spark.sql.execution.analysis.DetectAmbiguousSelfJoin
 import org.apache.spark.sql.execution.command.CommandCheck
 import org.apache.spark.sql.execution.datasources._
@@ -175,6 +176,7 @@ abstract class BaseSessionStateBuilder(
       new FindDataSourceTable(session) +:
         new ResolveSQLOnFile(session) +:
         new FallBackFileSourceV2(session) +:
+        ResolveEncodersInScalaAgg +:
         new ResolveSessionCatalog(
           catalogManager, conf, catalog.isTempView, catalog.isTempFunction) +:
         customResolutionRules
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala
index 6472675..e256107 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.optimizer.Optimizer
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.{SparkOptimizer, SparkPlanner}
+import org.apache.spark.sql.execution.aggregate.ResolveEncodersInScalaAgg
 import org.apache.spark.sql.execution.analysis.DetectAmbiguousSelfJoin
 import org.apache.spark.sql.execution.command.CommandCheck
 import org.apache.spark.sql.execution.datasources._
@@ -76,6 +77,7 @@ class HiveSessionStateBuilder(session: SparkSession, parentState: Option[Session
         new FindDataSourceTable(session) +:
         new ResolveSQLOnFile(session) +:
         new FallBackFileSourceV2(session) +:
+        ResolveEncodersInScalaAgg +:
         new ResolveSessionCatalog(
           catalogManager, conf, catalog.isTempView, catalog.isTempFunction) +:
         customResolutionRules
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/UDAQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/UDAQuerySuite.scala
index e6856a5..1f1a556 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/UDAQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/UDAQuerySuite.scala
@@ -119,6 +119,27 @@ object CountSerDeAgg extends Aggregator[Int, CountSerDeSQL, CountSerDeSQL] {
   def outputEncoder: Encoder[CountSerDeSQL] = ExpressionEncoder[CountSerDeSQL]()
 }
 
+object ArrayDataAgg extends Aggregator[Array[Double], Array[Double], Array[Double]] {
+  def zero: Array[Double] = Array(0.0, 0.0, 0.0)
+  def reduce(s: Array[Double], array: Array[Double]): Array[Double] = {
+    require(s.length == array.length)
+    for ( j <- 0 until s.length ) {
+      s(j) += array(j)
+    }
+    s
+  }
+  def merge(s1: Array[Double], s2: Array[Double]): Array[Double] = {
+    require(s1.length == s2.length)
+    for ( j <- 0 until s1.length ) {
+      s1(j) += s2(j)
+    }
+    s1
+  }
+  def finish(s: Array[Double]): Array[Double] = s
+  def bufferEncoder: Encoder[Array[Double]] = ExpressionEncoder[Array[Double]]
+  def outputEncoder: Encoder[Array[Double]] = ExpressionEncoder[Array[Double]]
+}
+
 abstract class UDAQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
   import testImplicits._
 
@@ -156,20 +177,11 @@ abstract class UDAQuerySuite extends QueryTest with SQLTestUtils with TestHiveSi
       (3, null, null)).toDF("key", "value1", "value2")
     data2.write.saveAsTable("agg2")
 
-    val data3 = Seq[(Seq[Integer], Integer, Integer)](
-      (Seq[Integer](1, 1), 10, -10),
-      (Seq[Integer](null), -60, 60),
-      (Seq[Integer](1, 1), 30, -30),
-      (Seq[Integer](1), 30, 30),
-      (Seq[Integer](2), 1, 1),
-      (null, -10, 10),
-      (Seq[Integer](2, 3), -1, null),
-      (Seq[Integer](2, 3), 1, 1),
-      (Seq[Integer](2, 3, 4), null, 1),
-      (Seq[Integer](null), 100, -10),
-      (Seq[Integer](3), null, 3),
-      (null, null, null),
-      (Seq[Integer](3), null, null)).toDF("key", "value1", "value2")
+    val data3 = Seq[(Seq[Double], Int)](
+      (Seq(1.0, 2.0, 3.0), 0),
+      (Seq(4.0, 5.0, 6.0), 0),
+      (Seq(7.0, 8.0, 9.0), 0)
+    ).toDF("data", "dummy")
     data3.write.saveAsTable("agg3")
 
     val data4 = Seq[Boolean](true, false, true).toDF("boolvalues")
@@ -184,6 +196,7 @@ abstract class UDAQuerySuite extends QueryTest with SQLTestUtils with TestHiveSi
     spark.udf.register("mydoublesum", udaf(MyDoubleSumAgg))
     spark.udf.register("mydoubleavg", udaf(MyDoubleAvgAgg))
     spark.udf.register("longProductSum", udaf(LongProductSumAgg))
+    spark.udf.register("arraysum", udaf(ArrayDataAgg))
   }
 
   override def afterAll(): Unit = {
@@ -354,6 +367,12 @@ abstract class UDAQuerySuite extends QueryTest with SQLTestUtils with TestHiveSi
         Row(3, 0, null, 1, 3, 0, 0, 0, null, 1, 3, 0, 2, 2) :: Nil)
   }
 
+  test("SPARK-32159: array encoders should be resolved in analyzer") {
+    checkAnswer(
+      spark.sql("SELECT arraysum(data) FROM agg3"),
+      Row(Seq(12.0, 15.0, 18.0)) :: Nil)
+  }
+
   test("verify aggregator ser/de behavior") {
     val data = sparkContext.parallelize((1 to 100).toSeq, 3).toDF("value1")
     val agg = udaf(CountSerDeAgg)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
