Repository: spark Updated Branches: refs/heads/master 8d0d2a65e -> f33d55046
[SPARK-3891][SQL] Add array support to percentile, percentile_approx and constant inspectors support Supported passing array to percentile and percentile_approx UDAFs To support percentile_approx, constant inspectors are supported for GenericUDAF Constant folding support added to CreateArray expression Avoided constant udf expression re-evaluation Author: Venkata Ramana G <ramana.gollamudi@huawei.com> Author: Venkata Ramana Gollamudi <ramana.gollam...@huawei.com> Closes #2802 from gvramana/percentile_array_support and squashes the following commits: a0182e5 [Venkata Ramana Gollamudi] fixed review comment a18f917 [Venkata Ramana Gollamudi] avoid constant udf expression re-evaluation - fixes failure due to return iterator and value type mismatch c46db0f [Venkata Ramana Gollamudi] Removed TestHive reset 4d39105 [Venkata Ramana Gollamudi] Unified inspector creation, style check fixes f37fd69 [Venkata Ramana Gollamudi] Fixed review comments 47f6365 [Venkata Ramana Gollamudi] fixed test cb7c61e [Venkata Ramana Gollamudi] Supported ConstantInspector for UDAF Fixed HiveUdaf wrap object issue. 
7f94aff [Venkata Ramana Gollamudi] Added foldable support to CreateArray Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f33d5504 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f33d5504 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f33d5504 Branch: refs/heads/master Commit: f33d55046427b8594fd19bda5fd2214eeeab1a95 Parents: 8d0d2a6 Author: Venkata Ramana Gollamudi <ramana.gollam...@huawei.com> Authored: Wed Dec 17 15:41:35 2014 -0800 Committer: Michael Armbrust <mich...@databricks.com> Committed: Wed Dec 17 15:41:35 2014 -0800 ---------------------------------------------------------------------- .../sql/catalyst/expressions/complexTypes.scala | 4 ++- .../org/apache/spark/sql/hive/hiveUdfs.scala | 35 ++++++++++++++------ .../spark/sql/hive/execution/HiveUdfSuite.scala | 13 +++++++- 3 files changed, 40 insertions(+), 12 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/f33d5504/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala index b12821d..9aec601 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/complexTypes.scala @@ -113,7 +113,9 @@ case class GetField(child: Expression, fieldName: String) extends UnaryExpressio */ case class CreateArray(children: Seq[Expression]) extends Expression { override type EvaluatedType = Any - + + override def foldable = !children.exists(!_.foldable) + lazy val childTypes = children.map(_.dataType).distinct override lazy val resolved = 
http://git-wip-us.apache.org/repos/asf/spark/blob/f33d5504/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala ---------------------------------------------------------------------- diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala index ed2e96d..93b6ef9 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala @@ -159,6 +159,11 @@ private[hive] case class HiveGenericUdf(funcWrapper: HiveFunctionWrapper, childr isUDFDeterministic && returnInspector.isInstanceOf[ConstantObjectInspector] @transient + protected def constantReturnValue = unwrap( + returnInspector.asInstanceOf[ConstantObjectInspector].getWritableConstantValue(), + returnInspector) + + @transient protected lazy val deferedObjects = argumentInspectors.map(new DeferredObjectAdapter(_)).toArray[DeferredObject] @@ -166,6 +171,8 @@ private[hive] case class HiveGenericUdf(funcWrapper: HiveFunctionWrapper, childr override def eval(input: Row): Any = { returnInspector // Make sure initialized. 
+ if(foldable) return constantReturnValue + var i = 0 while (i < children.length) { val idx = i @@ -193,12 +200,13 @@ private[hive] case class HiveGenericUdaf( @transient protected lazy val objectInspector = { - resolver.getEvaluator(children.map(_.dataType.toTypeInfo).toArray) + val parameterInfo = new SimpleGenericUDAFParameterInfo(inspectors.toArray, false, false) + resolver.getEvaluator(parameterInfo) .init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors.toArray) } @transient - protected lazy val inspectors = children.map(_.dataType).map(toInspector) + protected lazy val inspectors = children.map(toInspector) def dataType: DataType = inspectorToDataType(objectInspector) @@ -223,12 +231,13 @@ private[hive] case class HiveUdaf( @transient protected lazy val objectInspector = { - resolver.getEvaluator(children.map(_.dataType.toTypeInfo).toArray) + val parameterInfo = new SimpleGenericUDAFParameterInfo(inspectors.toArray, false, false) + resolver.getEvaluator(parameterInfo) .init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors.toArray) } @transient - protected lazy val inspectors = children.map(_.dataType).map(toInspector) + protected lazy val inspectors = children.map(toInspector) def dataType: DataType = inspectorToDataType(objectInspector) @@ -261,7 +270,7 @@ private[hive] case class HiveGenericUdtf( protected lazy val function: GenericUDTF = funcWrapper.createFunction() @transient - protected lazy val inputInspectors = children.map(_.dataType).map(toInspector) + protected lazy val inputInspectors = children.map(toInspector) @transient protected lazy val outputInspector = function.initialize(inputInspectors.toArray) @@ -334,10 +343,13 @@ private[hive] case class HiveUdafFunction( } else { funcWrapper.createFunction[AbstractGenericUDAFResolver]() } - - private val inspectors = exprs.map(_.dataType).map(toInspector).toArray - - private val function = resolver.getEvaluator(exprs.map(_.dataType.toTypeInfo).toArray) + + private val inspectors = 
exprs.map(toInspector).toArray + + private val function = { + val parameterInfo = new SimpleGenericUDAFParameterInfo(inspectors, false, false) + resolver.getEvaluator(parameterInfo) + } private val returnInspector = function.init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors) @@ -350,9 +362,12 @@ private[hive] case class HiveUdafFunction( @transient val inputProjection = new InterpretedProjection(exprs) + @transient + protected lazy val cached = new Array[AnyRef](exprs.length) + def update(input: Row): Unit = { val inputs = inputProjection(input).asInstanceOf[Seq[AnyRef]].toArray - function.iterate(buffer, inputs) + function.iterate(buffer, wrap(inputs, inspectors, cached)) } } http://git-wip-us.apache.org/repos/asf/spark/blob/f33d5504/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala ---------------------------------------------------------------------- diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala index 5fcaf67..5fc8d8d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala @@ -92,10 +92,21 @@ class HiveUdfSuite extends QueryTest { } test("SPARK-2693 udaf aggregates test") { - checkAnswer(sql("SELECT percentile(key,1) FROM src LIMIT 1"), + checkAnswer(sql("SELECT percentile(key, 1) FROM src LIMIT 1"), sql("SELECT max(key) FROM src").collect().toSeq) + + checkAnswer(sql("SELECT percentile(key, array(1, 1)) FROM src LIMIT 1"), + sql("SELECT array(max(key), max(key)) FROM src").collect().toSeq) } + test("Generic UDAF aggregates") { + checkAnswer(sql("SELECT ceiling(percentile_approx(key, 0.99999)) FROM src LIMIT 1"), + sql("SELECT max(key) FROM src LIMIT 1").collect().toSeq) + + checkAnswer(sql("SELECT percentile_approx(100.0, array(0.9, 0.9)) FROM src LIMIT 1"), + sql("SELECT array(100, 100) 
FROM src LIMIT 1").collect().toSeq) + } + test("UDFIntegerToString") { val testData = TestHive.sparkContext.parallelize( IntegerCaseClass(1) :: IntegerCaseClass(2) :: Nil) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org