This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/branch-3.0 by this push:
     new 30e49c2  [SPARK-32159][SQL] Fix integration between Aggregator[Array[_], _, _] and UnresolvedMapObjects
30e49c2 is described below

commit 30e49c23249213a70266583566db93e113a9e23e
Author: Erik Erlandson <eerla...@redhat.com>
AuthorDate: Thu Jul 9 08:42:20 2020 +0000

    [SPARK-32159][SQL] Fix integration between Aggregator[Array[_], _, _] and UnresolvedMapObjects

    Context: The fix for SPARK-27296, introduced by #25024, allows `Aggregator`
    objects to appear in queries. This works fine for aggregators with atomic
    input types, e.g. `Aggregator[Double, _, _]`. However, it can cause a null
    pointer exception if the input type is `Array[_]`. This was historically
    considered an ignorable case for serialization of `UnresolvedMapObjects`,
    but the new `ScalaAggregator` class causes these expressions to be
    serialized over to executors, because the resolve-and-bind is deferred.

    ### What changes were proposed in this pull request?

    A new rule, `ResolveEncodersInScalaAgg`, that resolves the expressions
    contained in the encoders, so that properly resolved expressions are
    serialized over to executors.

    ### Why are the changes needed?

    Applying an aggregator of the form `Aggregator[Array[_], _, _]` using
    `functions.udaf()` currently causes a null pointer exception in Catalyst.

    ### Does this PR introduce _any_ user-facing change?

    No.

    ### How was this patch tested?

    A unit test has been added that does aggregation with array types for
    input, buffer, and output. I have done additional testing with my own
    custom aggregators in the Spark REPL.

    Closes #28983 from erikerlandson/fix-spark-32159.

    Authored-by: Erik Erlandson <eerla...@redhat.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
    (cherry picked from commit 1cb5bfc47a2b4fff824433f8cecabfbac7e050b6)
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
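For illustration, a minimal sketch of the usage pattern that used to fail; the
ArraySum object, the data, and the column name here are hypothetical, not part
of this commit:

    import org.apache.spark.sql.{Encoder, SparkSession}
    import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
    import org.apache.spark.sql.expressions.Aggregator
    import org.apache.spark.sql.functions.udaf

    // Element-wise sum over Array[Double] input: an Aggregator[Array[_], _, _].
    object ArraySum extends Aggregator[Array[Double], Array[Double], Array[Double]] {
      def zero: Array[Double] = Array(0.0, 0.0, 0.0)
      def reduce(b: Array[Double], a: Array[Double]): Array[Double] = {
        for (j <- b.indices) b(j) += a(j)
        b
      }
      def merge(b1: Array[Double], b2: Array[Double]): Array[Double] = {
        for (j <- b1.indices) b1(j) += b2(j)
        b1
      }
      def finish(b: Array[Double]): Array[Double] = b
      def bufferEncoder: Encoder[Array[Double]] = ExpressionEncoder[Array[Double]]()
      def outputEncoder: Encoder[Array[Double]] = ExpressionEncoder[Array[Double]]()
    }

    val spark = SparkSession.builder().getOrCreate()
    import spark.implicits._

    val df = Seq(Array(1.0, 2.0, 3.0), Array(4.0, 5.0, 6.0)).toDF("data")
    // Before this fix, evaluating the aggregate threw a NullPointerException on
    // executors: the unresolved encoders were serialized with a null
    // UnresolvedMapObjects 'function' field. With the fix, this shows [5.0, 7.0, 9.0].
    df.agg(udaf(ArraySum)($"data")).show()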
" + + "(See UnresolvedMapObjects)") val loopVar = LambdaVariable("MapObject", elementType, elementNullable) MapObjects(loopVar, function(loopVar), inputData, customCollectionCls) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala index 544b90a..44bc9c2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala @@ -27,6 +27,8 @@ import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete} import org.apache.spark.sql.catalyst.expressions.aggregate.{ImperativeAggregate, TypedImperativeAggregate} import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateMutableProjection, GenerateSafeProjection} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.expressions.{Aggregator, MutableAggregationBuffer, UserDefinedAggregateFunction} import org.apache.spark.sql.types._ @@ -458,7 +460,8 @@ case class ScalaUDAF( case class ScalaAggregator[IN, BUF, OUT]( children: Seq[Expression], agg: Aggregator[IN, BUF, OUT], - inputEncoderNR: ExpressionEncoder[IN], + inputEncoder: ExpressionEncoder[IN], + bufferEncoder: ExpressionEncoder[BUF], nullable: Boolean = true, isDeterministic: Boolean = true, mutableAggBufferOffset: Int = 0, @@ -469,9 +472,8 @@ case class ScalaAggregator[IN, BUF, OUT]( with ImplicitCastInputTypes with Logging { - private[this] lazy val inputDeserializer = inputEncoderNR.resolveAndBind().createDeserializer() - private[this] lazy val bufferEncoder = - agg.bufferEncoder.asInstanceOf[ExpressionEncoder[BUF]].resolveAndBind() + // input and buffer encoders are resolved by ResolveEncodersInScalaAgg + private[this] lazy val inputDeserializer = inputEncoder.createDeserializer() private[this] lazy val bufferSerializer = bufferEncoder.createSerializer() private[this] lazy val bufferDeserializer = bufferEncoder.createDeserializer() private[this] lazy val outputEncoder = agg.outputEncoder.asInstanceOf[ExpressionEncoder[OUT]] @@ -479,7 +481,7 @@ case class ScalaAggregator[IN, BUF, OUT]( def dataType: DataType = outputEncoder.objSerializer.dataType - def inputTypes: Seq[DataType] = inputEncoderNR.schema.map(_.dataType) + def inputTypes: Seq[DataType] = inputEncoder.schema.map(_.dataType) override lazy val deterministic: Boolean = isDeterministic @@ -517,3 +519,18 @@ case class ScalaAggregator[IN, BUF, OUT]( override def nodeName: String = agg.getClass.getSimpleName } + +/** + * An extension rule to resolve encoder expressions from a [[ScalaAggregator]] + */ +object ResolveEncodersInScalaAgg extends Rule[LogicalPlan] { + override def apply(plan: LogicalPlan): LogicalPlan = plan.resolveOperatorsUp { + case p if !p.resolved => p + case p => p.transformExpressionsUp { + case agg: ScalaAggregator[_, _, _] => + agg.copy( + inputEncoder = agg.inputEncoder.resolveAndBind(), + bufferEncoder = agg.bufferEncoder.resolveAndBind()) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala index 2ef6e3d..6a20a46 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala +++ 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
index 2ef6e3d..6a20a46 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/expressions/UserDefinedFunction.scala
@@ -150,7 +150,8 @@ private[sql] case class UserDefinedAggregator[IN, BUF, OUT](
   // This is also used by udf.register(...) when it detects a UserDefinedAggregator
   def scalaAggregator(exprs: Seq[Expression]): ScalaAggregator[IN, BUF, OUT] = {
     val iEncoder = inputEncoder.asInstanceOf[ExpressionEncoder[IN]]
-    ScalaAggregator(exprs, aggregator, iEncoder, nullable, deterministic)
+    val bEncoder = aggregator.bufferEncoder.asInstanceOf[ExpressionEncoder[BUF]]
+    ScalaAggregator(exprs, aggregator, iEncoder, bEncoder, nullable, deterministic)
   }

   override def withName(name: String): UserDefinedAggregator[IN, BUF, OUT] = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
index 3bbdbb0..4ae12f8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/BaseSessionStateBuilder.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.connector.catalog.CatalogManager
 import org.apache.spark.sql.execution.{ColumnarRule, QueryExecution, SparkOptimizer, SparkPlanner, SparkSqlParser}
+import org.apache.spark.sql.execution.aggregate.ResolveEncodersInScalaAgg
 import org.apache.spark.sql.execution.analysis.DetectAmbiguousSelfJoin
 import org.apache.spark.sql.execution.command.CommandCheck
 import org.apache.spark.sql.execution.datasources._
@@ -175,6 +176,7 @@ abstract class BaseSessionStateBuilder(
       new FindDataSourceTable(session) +:
       new ResolveSQLOnFile(session) +:
       new FallBackFileSourceV2(session) +:
+      ResolveEncodersInScalaAgg +:
       new ResolveSessionCatalog(
         catalogManager, conf, catalog.isTempView, catalog.isTempFunction) +:
       customResolutionRules
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala
index 6472675..e256107 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionStateBuilder.scala
@@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.optimizer.Optimizer
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.{SparkOptimizer, SparkPlanner}
+import org.apache.spark.sql.execution.aggregate.ResolveEncodersInScalaAgg
 import org.apache.spark.sql.execution.analysis.DetectAmbiguousSelfJoin
 import org.apache.spark.sql.execution.command.CommandCheck
 import org.apache.spark.sql.execution.datasources._
@@ -76,6 +77,7 @@ class HiveSessionStateBuilder(session: SparkSession, parentState: Option[Session
       new FindDataSourceTable(session) +:
       new ResolveSQLOnFile(session) +:
       new FallBackFileSourceV2(session) +:
+      ResolveEncodersInScalaAgg +:
       new ResolveSessionCatalog(
         catalogManager, conf, catalog.isTempView, catalog.isTempFunction) +:
       customResolutionRules
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/UDAQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/UDAQuerySuite.scala
index e6856a5..1f1a556 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/UDAQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/UDAQuerySuite.scala
@@ -119,6 +119,27 @@ object CountSerDeAgg extends Aggregator[Int, CountSerDeSQL, CountSerDeSQL] {
   def outputEncoder: Encoder[CountSerDeSQL] = ExpressionEncoder[CountSerDeSQL]()
 }

+object ArrayDataAgg extends Aggregator[Array[Double], Array[Double], Array[Double]] {
+  def zero: Array[Double] = Array(0.0, 0.0, 0.0)
+  def reduce(s: Array[Double], array: Array[Double]): Array[Double] = {
+    require(s.length == array.length)
+    for ( j <- 0 until s.length ) {
+      s(j) += array(j)
+    }
+    s
+  }
+  def merge(s1: Array[Double], s2: Array[Double]): Array[Double] = {
+    require(s1.length == s2.length)
+    for ( j <- 0 until s1.length ) {
+      s1(j) += s2(j)
+    }
+    s1
+  }
+  def finish(s: Array[Double]): Array[Double] = s
+  def bufferEncoder: Encoder[Array[Double]] = ExpressionEncoder[Array[Double]]
+  def outputEncoder: Encoder[Array[Double]] = ExpressionEncoder[Array[Double]]
+}
+
 abstract class UDAQuerySuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
   import testImplicits._

@@ -156,20 +177,11 @@ abstract class UDAQuerySuite extends QueryTest with SQLTestUtils with TestHiveSi
       (3, null, null)).toDF("key", "value1", "value2")
     data2.write.saveAsTable("agg2")

-    val data3 = Seq[(Seq[Integer], Integer, Integer)](
-      (Seq[Integer](1, 1), 10, -10),
-      (Seq[Integer](null), -60, 60),
-      (Seq[Integer](1, 1), 30, -30),
-      (Seq[Integer](1), 30, 30),
-      (Seq[Integer](2), 1, 1),
-      (null, -10, 10),
-      (Seq[Integer](2, 3), -1, null),
-      (Seq[Integer](2, 3), 1, 1),
-      (Seq[Integer](2, 3, 4), null, 1),
-      (Seq[Integer](null), 100, -10),
-      (Seq[Integer](3), null, 3),
-      (null, null, null),
-      (Seq[Integer](3), null, null)).toDF("key", "value1", "value2")
+    val data3 = Seq[(Seq[Double], Int)](
+      (Seq(1.0, 2.0, 3.0), 0),
+      (Seq(4.0, 5.0, 6.0), 0),
+      (Seq(7.0, 8.0, 9.0), 0)
+    ).toDF("data", "dummy")
     data3.write.saveAsTable("agg3")

     val data4 = Seq[Boolean](true, false, true).toDF("boolvalues")
@@ -184,6 +196,7 @@ abstract class UDAQuerySuite extends QueryTest with SQLTestUtils with TestHiveSi
     spark.udf.register("mydoublesum", udaf(MyDoubleSumAgg))
     spark.udf.register("mydoubleavg", udaf(MyDoubleAvgAgg))
     spark.udf.register("longProductSum", udaf(LongProductSumAgg))
+    spark.udf.register("arraysum", udaf(ArrayDataAgg))
   }

   override def afterAll(): Unit = {
@@ -354,6 +367,12 @@ abstract class UDAQuerySuite extends QueryTest with SQLTestUtils with TestHiveSi
       Row(3, 0, null, 1, 3, 0, 0, 0, null, 1, 3, 0, 2, 2) :: Nil)
   }

+  test("SPARK-32159: array encoders should be resolved in analyzer") {
+    checkAnswer(
+      spark.sql("SELECT arraysum(data) FROM agg3"),
+      Row(Seq(12.0, 15.0, 18.0)) :: Nil)
+  }
+
   test("verify aggregator ser/de behavior") {
     val data = sparkContext.parallelize((1 to 100).toSeq, 3).toDF("value1")
     val agg = udaf(CountSerDeAgg)

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org