http://git-wip-us.apache.org/repos/asf/spark/blob/931da5c8/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala new file mode 100644 index 0000000..9e1cff0 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala @@ -0,0 +1,292 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. You may obtain a copy of the License at +* +* http://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, +* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +* See the License for the specific language governing permissions and +* limitations under the License. +*/ + +package org.apache.spark.sql.execution + +import java.util.{List => JList, Map => JMap} + +import scala.collection.JavaConversions._ +import scala.collection.JavaConverters._ + +import net.razorvine.pickle.{Pickler, Unpickler} + +import org.apache.spark.{Accumulator, Logging => SparkLogging} +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.api.python.{PythonBroadcast, PythonRDD} +import org.apache.spark.broadcast.Broadcast +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.Row +import org.apache.spark.sql.catalyst.plans.logical +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.types._ +import org.apache.spark.unsafe.types.UTF8String + +/** + * A serialized version of a Python lambda function. Suitable for use in a [[PythonRDD]]. + */ +private[spark] case class PythonUDF( + name: String, + command: Array[Byte], + envVars: JMap[String, String], + pythonIncludes: JList[String], + pythonExec: String, + pythonVer: String, + broadcastVars: JList[Broadcast[PythonBroadcast]], + accumulator: Accumulator[JList[Array[Byte]]], + dataType: DataType, + children: Seq[Expression]) extends Expression with SparkLogging { + + override def toString: String = s"PythonUDF#$name(${children.mkString(",")})" + + override def nullable: Boolean = true + + override def eval(input: InternalRow): Any = { + throw new UnsupportedOperationException("PythonUDFs can not be directly evaluated.") + } +} + +/** + * Extracts PythonUDFs from operators, rewriting the query plan so that the UDF can be evaluated + * alone in a batch. + * + * This has the limitation that the input to the Python UDF is not allowed include attributes from + * multiple child operators. + */ +private[spark] object ExtractPythonUDFs extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transform { + // Skip EvaluatePython nodes. + case plan: EvaluatePython => plan + + case plan: LogicalPlan if plan.resolved => + // Extract any PythonUDFs from the current operator. 
+ val udfs = plan.expressions.flatMap(_.collect { case udf: PythonUDF => udf }) + if (udfs.isEmpty) { + // If there aren't any, we are done. + plan + } else { + // Pick the UDF we are going to evaluate (TODO: Support evaluating multiple UDFs at a time) + // If there is more than one, we will add another evaluation operator in a subsequent pass. + udfs.find(_.resolved) match { + case Some(udf) => + var evaluation: EvaluatePython = null + + // Rewrite the child that has the input required for the UDF + val newChildren = plan.children.map { child => + // Check to make sure that the UDF can be evaluated with only the input of this child. + // Other cases are disallowed as they are ambiguous or would require a cartesian + // product. + if (udf.references.subsetOf(child.outputSet)) { + evaluation = EvaluatePython(udf, child) + evaluation + } else if (udf.references.intersect(child.outputSet).nonEmpty) { + sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.") + } else { + child + } + } + + assert(evaluation != null, "Unable to evaluate PythonUDF. Missing input attributes.") + + // Trim away the new UDF value if it was only used for filtering or something. + logical.Project( + plan.output, + plan.transformExpressions { + case p: PythonUDF if p.fastEquals(udf) => evaluation.resultAttribute + }.withNewChildren(newChildren)) + + case None => + // If there is no Python UDF that is resolved, skip this round. + plan + } + } + } +} + +object EvaluatePython { + def apply(udf: PythonUDF, child: LogicalPlan): EvaluatePython = + new EvaluatePython(udf, child, AttributeReference("pythonUDF", udf.dataType)()) + + /** + * Helper for converting a Scala object to a java suitable for pyspark serialization. + */ + def toJava(obj: Any, dataType: DataType): Any = (obj, dataType) match { + case (null, _) => null + + case (row: Row, struct: StructType) => + val fields = struct.fields.map(field => field.dataType) + row.toSeq.zip(fields).map { + case (obj, dataType) => toJava(obj, dataType) + }.toArray + + case (seq: Seq[Any], array: ArrayType) => + seq.map(x => toJava(x, array.elementType)).asJava + case (list: JList[_], array: ArrayType) => + list.map(x => toJava(x, array.elementType)).asJava + case (arr, array: ArrayType) if arr.getClass.isArray => + arr.asInstanceOf[Array[Any]].map(x => toJava(x, array.elementType)) + + case (obj: Map[_, _], mt: MapType) => obj.map { + case (k, v) => (toJava(k, mt.keyType), toJava(v, mt.valueType)) + }.asJava + + case (ud, udt: UserDefinedType[_]) => toJava(udt.serialize(ud), udt.sqlType) + + case (date: Int, DateType) => DateTimeUtils.toJavaDate(date) + case (t: Long, TimestampType) => DateTimeUtils.toJavaTimestamp(t) + case (s: UTF8String, StringType) => s.toString + + // Pyrolite can handle Timestamp and Decimal + case (other, _) => other + } + + /** + * Convert Row into Java Array (for pickled into Python) + */ + def rowToArray(row: Row, fields: Seq[DataType]): Array[Any] = { + // TODO: this is slow! + row.toSeq.zip(fields).map {case (obj, dt) => toJava(obj, dt)}.toArray + } + + // Converts value to the type specified by the data type. + // Because Python does not have data types for TimestampType, FloatType, ShortType, and + // ByteType, we need to explicitly convert values in columns of these data types to the desired + // JVM data types. 
+ def fromJava(obj: Any, dataType: DataType): Any = (obj, dataType) match { + // TODO: We should check nullable + case (null, _) => null + + case (c: java.util.List[_], ArrayType(elementType, _)) => + c.map { e => fromJava(e, elementType)}: Seq[Any] + + case (c, ArrayType(elementType, _)) if c.getClass.isArray => + c.asInstanceOf[Array[_]].map(e => fromJava(e, elementType)): Seq[Any] + + case (c: java.util.Map[_, _], MapType(keyType, valueType, _)) => c.map { + case (key, value) => (fromJava(key, keyType), fromJava(value, valueType)) + }.toMap + + case (c, StructType(fields)) if c.getClass.isArray => + new GenericInternalRow(c.asInstanceOf[Array[_]].zip(fields).map { + case (e, f) => fromJava(e, f.dataType) + }) + + case (c: java.util.Calendar, DateType) => + DateTimeUtils.fromJavaDate(new java.sql.Date(c.getTimeInMillis)) + + case (c: java.util.Calendar, TimestampType) => + c.getTimeInMillis * 10000L + case (t: java.sql.Timestamp, TimestampType) => + DateTimeUtils.fromJavaTimestamp(t) + + case (_, udt: UserDefinedType[_]) => + fromJava(obj, udt.sqlType) + + case (c: Int, ByteType) => c.toByte + case (c: Long, ByteType) => c.toByte + case (c: Int, ShortType) => c.toShort + case (c: Long, ShortType) => c.toShort + case (c: Long, IntegerType) => c.toInt + case (c: Int, LongType) => c.toLong + case (c: Double, FloatType) => c.toFloat + case (c: String, StringType) => UTF8String.fromString(c) + case (c, StringType) => + // If we get here, c is not a string. Call toString on it. + UTF8String.fromString(c.toString) + + case (c, _) => c + } +} + +/** + * :: DeveloperApi :: + * Evaluates a [[PythonUDF]], appending the result to the end of the input tuple. + */ +@DeveloperApi +case class EvaluatePython( + udf: PythonUDF, + child: LogicalPlan, + resultAttribute: AttributeReference) + extends logical.UnaryNode { + + def output: Seq[Attribute] = child.output :+ resultAttribute + + // References should not include the produced attribute. + override def references: AttributeSet = udf.references +} + +/** + * :: DeveloperApi :: + * Uses PythonRDD to evaluate a [[PythonUDF]], one partition of tuples at a time. + * The input data is zipped with the result of the udf evaluation. 
+ */ +@DeveloperApi +case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: SparkPlan) + extends SparkPlan { + + def children: Seq[SparkPlan] = child :: Nil + + protected override def doExecute(): RDD[InternalRow] = { + val childResults = child.execute().map(_.copy()) + + val parent = childResults.mapPartitions { iter => + val pickle = new Pickler + val currentRow = newMutableProjection(udf.children, child.output)() + val fields = udf.children.map(_.dataType) + iter.grouped(1000).map { inputRows => + val toBePickled = inputRows.map { row => + EvaluatePython.rowToArray(currentRow(row), fields) + }.toArray + pickle.dumps(toBePickled) + } + } + + val pyRDD = new PythonRDD( + parent, + udf.command, + udf.envVars, + udf.pythonIncludes, + false, + udf.pythonExec, + udf.pythonVer, + udf.broadcastVars, + udf.accumulator + ).mapPartitions { iter => + val pickle = new Unpickler + iter.flatMap { pickedResult => + val unpickledBatch = pickle.loads(pickedResult) + unpickledBatch.asInstanceOf[java.util.ArrayList[Any]] + } + }.mapPartitions { iter => + val row = new GenericMutableRow(1) + iter.map { result => + row(0) = EvaluatePython.fromJava(result, udf.dataType) + row: InternalRow + } + } + + childResults.zip(pyRDD).mapPartitions { iter => + val joinedRow = new JoinedRow() + iter.map { + case (row, udfResult) => + joinedRow(row, udfResult) + } + } + } +}
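The core of the ExtractPythonUDFs rule above is a collect-then-substitute pattern: pull every PythonUDF out of the operator's expressions, evaluate it in its own EvaluatePython node, and project the result attribute back in its place. Below is a minimal standalone sketch of that pattern; MiniExpr, Attr, and MiniUDF are hypothetical stand-ins for Catalyst's Expression, AttributeReference, and PythonUDF, and the substitution is shown only at the top level rather than via a full transformExpressions.

    // Minimal, self-contained sketch of the extract-and-substitute pattern in
    // ExtractPythonUDFs. These are hypothetical stand-ins, not Spark classes.
    sealed trait MiniExpr {
      def children: Seq[MiniExpr]
      // Collect every node in the subtree matching the partial function,
      // mirroring TreeNode.collect in Catalyst.
      def collect[A](pf: PartialFunction[MiniExpr, A]): Seq[A] =
        pf.lift(this).toSeq ++ children.flatMap(_.collect(pf))
    }
    case class Attr(name: String) extends MiniExpr { def children = Nil }
    case class MiniUDF(name: String, children: Seq[MiniExpr]) extends MiniExpr

    object ExtractSketch {
      def main(args: Array[String]): Unit = {
        // An operator whose expressions hold one plain column and one UDF call.
        val exprs: Seq[MiniExpr] = Seq(Attr("id"), MiniUDF("squared", Seq(Attr("value"))))

        // Same shape as:
        //   plan.expressions.flatMap(_.collect { case udf: PythonUDF => udf })
        val udfs = exprs.flatMap(_.collect { case u: MiniUDF => u })
        println(udfs) // the one MiniUDF node

        // The rule would then evaluate the UDF in its own operator and swap a
        // result attribute in for it; shown here only at the top level.
        def substitute(e: MiniExpr): MiniExpr = e match {
          case u: MiniUDF => Attr(s"pythonUDF:${u.name}")
          case other      => other
        }
        println(exprs.map(substitute)) // List(Attr(id), Attr(pythonUDF:squared))
      }
    }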
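BatchPythonEvaluation above groups the input into batches of 1000 rows, pickles each batch for the Python worker, unpickles the flattened results, and zips them back against a copy of the input. Here is a sketch of that shape using plain Scala iterators, where the hypothetical evalBatch stands in for the whole pickle / PythonRDD / unpickle round trip:

    // Sketch of BatchPythonEvaluation's dataflow with plain Scala iterators.
    // evalBatch is a hypothetical stand-in for the Python worker round trip;
    // here it just squares each input value.
    object BatchEvalSketch {
      def main(args: Array[String]): Unit = {
        val rows: Iterator[Int] = (1 to 5).iterator

        // Keep a second pass over the input so it can be zipped with the
        // results, mirroring childResults.zip(pyRDD) in doExecute().
        val (forPython, forJoin) = rows.duplicate

        def evalBatch(batch: Seq[Int]): Seq[Int] = batch.map(x => x * x)

        // iter.grouped(1000).map(...) in the real code; tiny batches here.
        val udfResults: Iterator[Int] =
          forPython.grouped(2).flatMap(batch => evalBatch(batch))

        // Join each input row with its UDF result, as JoinedRow does.
        forJoin.zip(udfResults).foreach { case (row, res) =>
          println(s"row=$row udfResult=$res")
        }
      }
    }

Note that the real operator must call _.copy() on the child's rows first, because Spark reuses mutable row objects and the zip consumes them out of step with their production.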
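The comment on fromJava explains why the coercions exist: Pyrolite hands results back as Int, Long, Double, String, java.util.List and so on, so columns declared as ByteType, ShortType, FloatType, etc. have to be narrowed explicitly. A miniature, self-contained version of just the width coercions (MiniType is a hypothetical stand-in for Catalyst's DataType):

    // Miniature version of the fromJava width coercions. MiniType is a
    // hypothetical stand-in for Catalyst's DataType hierarchy.
    object FromJavaSketch {
      sealed trait MiniType
      case object ByteT  extends MiniType
      case object ShortT extends MiniType
      case object IntT   extends MiniType
      case object FloatT extends MiniType

      def fromJava(obj: Any, dt: MiniType): Any = (obj, dt) match {
        case (null, _)           => null
        case (c: Int,    ByteT)  => c.toByte
        case (c: Long,   ByteT)  => c.toByte
        case (c: Int,    ShortT) => c.toShort
        case (c: Long,   ShortT) => c.toShort
        case (c: Long,   IntT)   => c.toInt
        case (c: Double, FloatT) => c.toFloat
        case (c, _)              => c // already the expected shape
      }

      def main(args: Array[String]): Unit = {
        println(fromJava(300, ByteT))   // 44: narrowing wraps, like c.toByte
        println(fromJava(7L, IntT))     // 7
        println(fromJava(2.5d, FloatT)) // 2.5 (as Float)
      }
    }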
http://git-wip-us.apache.org/repos/asf/spark/blob/931da5c8/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala deleted file mode 100644 index 036f5d2..0000000 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala +++ /dev/null @@ -1,292 +0,0 @@ -/* -* Licensed to the Apache Software Foundation (ASF) under one or more -* contributor license agreements. See the NOTICE file distributed with -* this work for additional information regarding copyright ownership. -* The ASF licenses this file to You under the Apache License, Version 2.0 -* (the "License"); you may not use this file except in compliance with -* the License. You may obtain a copy of the License at -* -* http://www.apache.org/licenses/LICENSE-2.0 -* -* Unless required by applicable law or agreed to in writing, software -* distributed under the License is distributed on an "AS IS" BASIS, -* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -* See the License for the specific language governing permissions and -* limitations under the License. -*/ - -package org.apache.spark.sql.execution - -import java.util.{List => JList, Map => JMap} - -import scala.collection.JavaConversions._ -import scala.collection.JavaConverters._ - -import net.razorvine.pickle.{Pickler, Unpickler} - -import org.apache.spark.{Accumulator, Logging => SparkLogging} -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.api.python.{PythonBroadcast, PythonRDD} -import org.apache.spark.broadcast.Broadcast -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.Row -import org.apache.spark.sql.catalyst.plans.logical -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.catalyst.util.DateTimeUtils -import org.apache.spark.sql.types._ -import org.apache.spark.unsafe.types.UTF8String - -/** - * A serialized version of a Python lambda function. Suitable for use in a [[PythonRDD]]. - */ -private[spark] case class PythonUDF( - name: String, - command: Array[Byte], - envVars: JMap[String, String], - pythonIncludes: JList[String], - pythonExec: String, - pythonVer: String, - broadcastVars: JList[Broadcast[PythonBroadcast]], - accumulator: Accumulator[JList[Array[Byte]]], - dataType: DataType, - children: Seq[Expression]) extends Expression with SparkLogging { - - override def toString: String = s"PythonUDF#$name(${children.mkString(",")})" - - override def nullable: Boolean = true - - override def eval(input: InternalRow): Any = { - throw new UnsupportedOperationException("PythonUDFs can not be directly evaluated.") - } -} - -/** - * Extracts PythonUDFs from operators, rewriting the query plan so that the UDF can be evaluated - * alone in a batch. - * - * This has the limitation that the input to the Python UDF is not allowed include attributes from - * multiple child operators. - */ -private[spark] object ExtractPythonUdfs extends Rule[LogicalPlan] { - def apply(plan: LogicalPlan): LogicalPlan = plan transform { - // Skip EvaluatePython nodes. - case plan: EvaluatePython => plan - - case plan: LogicalPlan if plan.resolved => - // Extract any PythonUDFs from the current operator. 
- val udfs = plan.expressions.flatMap(_.collect { case udf: PythonUDF => udf }) - if (udfs.isEmpty) { - // If there aren't any, we are done. - plan - } else { - // Pick the UDF we are going to evaluate (TODO: Support evaluating multiple UDFs at a time) - // If there is more than one, we will add another evaluation operator in a subsequent pass. - udfs.find(_.resolved) match { - case Some(udf) => - var evaluation: EvaluatePython = null - - // Rewrite the child that has the input required for the UDF - val newChildren = plan.children.map { child => - // Check to make sure that the UDF can be evaluated with only the input of this child. - // Other cases are disallowed as they are ambiguous or would require a cartesian - // product. - if (udf.references.subsetOf(child.outputSet)) { - evaluation = EvaluatePython(udf, child) - evaluation - } else if (udf.references.intersect(child.outputSet).nonEmpty) { - sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.") - } else { - child - } - } - - assert(evaluation != null, "Unable to evaluate PythonUDF. Missing input attributes.") - - // Trim away the new UDF value if it was only used for filtering or something. - logical.Project( - plan.output, - plan.transformExpressions { - case p: PythonUDF if p.fastEquals(udf) => evaluation.resultAttribute - }.withNewChildren(newChildren)) - - case None => - // If there is no Python UDF that is resolved, skip this round. - plan - } - } - } -} - -object EvaluatePython { - def apply(udf: PythonUDF, child: LogicalPlan): EvaluatePython = - new EvaluatePython(udf, child, AttributeReference("pythonUDF", udf.dataType)()) - - /** - * Helper for converting a Scala object to a java suitable for pyspark serialization. - */ - def toJava(obj: Any, dataType: DataType): Any = (obj, dataType) match { - case (null, _) => null - - case (row: Row, struct: StructType) => - val fields = struct.fields.map(field => field.dataType) - row.toSeq.zip(fields).map { - case (obj, dataType) => toJava(obj, dataType) - }.toArray - - case (seq: Seq[Any], array: ArrayType) => - seq.map(x => toJava(x, array.elementType)).asJava - case (list: JList[_], array: ArrayType) => - list.map(x => toJava(x, array.elementType)).asJava - case (arr, array: ArrayType) if arr.getClass.isArray => - arr.asInstanceOf[Array[Any]].map(x => toJava(x, array.elementType)) - - case (obj: Map[_, _], mt: MapType) => obj.map { - case (k, v) => (toJava(k, mt.keyType), toJava(v, mt.valueType)) - }.asJava - - case (ud, udt: UserDefinedType[_]) => toJava(udt.serialize(ud), udt.sqlType) - - case (date: Int, DateType) => DateTimeUtils.toJavaDate(date) - case (t: Long, TimestampType) => DateTimeUtils.toJavaTimestamp(t) - case (s: UTF8String, StringType) => s.toString - - // Pyrolite can handle Timestamp and Decimal - case (other, _) => other - } - - /** - * Convert Row into Java Array (for pickled into Python) - */ - def rowToArray(row: Row, fields: Seq[DataType]): Array[Any] = { - // TODO: this is slow! - row.toSeq.zip(fields).map {case (obj, dt) => toJava(obj, dt)}.toArray - } - - // Converts value to the type specified by the data type. - // Because Python does not have data types for TimestampType, FloatType, ShortType, and - // ByteType, we need to explicitly convert values in columns of these data types to the desired - // JVM data types. 
- def fromJava(obj: Any, dataType: DataType): Any = (obj, dataType) match { - // TODO: We should check nullable - case (null, _) => null - - case (c: java.util.List[_], ArrayType(elementType, _)) => - c.map { e => fromJava(e, elementType)}: Seq[Any] - - case (c, ArrayType(elementType, _)) if c.getClass.isArray => - c.asInstanceOf[Array[_]].map(e => fromJava(e, elementType)): Seq[Any] - - case (c: java.util.Map[_, _], MapType(keyType, valueType, _)) => c.map { - case (key, value) => (fromJava(key, keyType), fromJava(value, valueType)) - }.toMap - - case (c, StructType(fields)) if c.getClass.isArray => - new GenericInternalRow(c.asInstanceOf[Array[_]].zip(fields).map { - case (e, f) => fromJava(e, f.dataType) - }) - - case (c: java.util.Calendar, DateType) => - DateTimeUtils.fromJavaDate(new java.sql.Date(c.getTimeInMillis)) - - case (c: java.util.Calendar, TimestampType) => - c.getTimeInMillis * 10000L - case (t: java.sql.Timestamp, TimestampType) => - DateTimeUtils.fromJavaTimestamp(t) - - case (_, udt: UserDefinedType[_]) => - fromJava(obj, udt.sqlType) - - case (c: Int, ByteType) => c.toByte - case (c: Long, ByteType) => c.toByte - case (c: Int, ShortType) => c.toShort - case (c: Long, ShortType) => c.toShort - case (c: Long, IntegerType) => c.toInt - case (c: Int, LongType) => c.toLong - case (c: Double, FloatType) => c.toFloat - case (c: String, StringType) => UTF8String.fromString(c) - case (c, StringType) => - // If we get here, c is not a string. Call toString on it. - UTF8String.fromString(c.toString) - - case (c, _) => c - } -} - -/** - * :: DeveloperApi :: - * Evaluates a [[PythonUDF]], appending the result to the end of the input tuple. - */ -@DeveloperApi -case class EvaluatePython( - udf: PythonUDF, - child: LogicalPlan, - resultAttribute: AttributeReference) - extends logical.UnaryNode { - - def output: Seq[Attribute] = child.output :+ resultAttribute - - // References should not include the produced attribute. - override def references: AttributeSet = udf.references -} - -/** - * :: DeveloperApi :: - * Uses PythonRDD to evaluate a [[PythonUDF]], one partition of tuples at a time. - * The input data is zipped with the result of the udf evaluation. 
- */ -@DeveloperApi -case class BatchPythonEvaluation(udf: PythonUDF, output: Seq[Attribute], child: SparkPlan) - extends SparkPlan { - - def children: Seq[SparkPlan] = child :: Nil - - protected override def doExecute(): RDD[InternalRow] = { - val childResults = child.execute().map(_.copy()) - - val parent = childResults.mapPartitions { iter => - val pickle = new Pickler - val currentRow = newMutableProjection(udf.children, child.output)() - val fields = udf.children.map(_.dataType) - iter.grouped(1000).map { inputRows => - val toBePickled = inputRows.map { row => - EvaluatePython.rowToArray(currentRow(row), fields) - }.toArray - pickle.dumps(toBePickled) - } - } - - val pyRDD = new PythonRDD( - parent, - udf.command, - udf.envVars, - udf.pythonIncludes, - false, - udf.pythonExec, - udf.pythonVer, - udf.broadcastVars, - udf.accumulator - ).mapPartitions { iter => - val pickle = new Unpickler - iter.flatMap { pickedResult => - val unpickledBatch = pickle.loads(pickedResult) - unpickledBatch.asInstanceOf[java.util.ArrayList[Any]] - } - }.mapPartitions { iter => - val row = new GenericMutableRow(1) - iter.map { result => - row(0) = EvaluatePython.fromJava(result, udf.dataType) - row: InternalRow - } - } - - childResults.zip(pyRDD).mapPartitions { iter => - val joinedRow = new JoinedRow() - iter.map { - case (row, udfResult) => - joinedRow(row, udfResult) - } - } - } -} http://git-wip-us.apache.org/repos/asf/spark/blob/931da5c8/sql/core/src/main/scala/org/apache/spark/sql/functions.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala index 5422e06..4d9a019 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/functions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/functions.scala @@ -1509,7 +1509,7 @@ object functions { (0 to 10).map { x => val args = (1 to x).map(i => s"arg$i: Column").mkString(", ") val fTypes = Seq.fill(x + 1)("_").mkString(", ") - val argsInUdf = (1 to x).map(i => s"arg$i.expr").mkString(", ") + val argsInUDF = (1 to x).map(i => s"arg$i.expr").mkString(", ") println(s""" /** * Call a Scala function of ${x} arguments as user-defined function (UDF). 
This requires @@ -1521,7 +1521,7 @@ object functions { */ @deprecated("Use udf", "1.5.0") def callUDF(f: Function$x[$fTypes], returnType: DataType${if (args.length > 0) ", " + args else ""}): Column = { - ScalaUdf(f, returnType, Seq($argsInUdf)) + ScalaUDF(f, returnType, Seq($argsInUDF)) }""") } } @@ -1659,7 +1659,7 @@ object functions { */ @deprecated("Use udf", "1.5.0") def callUDF(f: Function0[_], returnType: DataType): Column = { - ScalaUdf(f, returnType, Seq()) + ScalaUDF(f, returnType, Seq()) } /** @@ -1672,7 +1672,7 @@ object functions { */ @deprecated("Use udf", "1.5.0") def callUDF(f: Function1[_, _], returnType: DataType, arg1: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr)) + ScalaUDF(f, returnType, Seq(arg1.expr)) } /** @@ -1685,7 +1685,7 @@ object functions { */ @deprecated("Use udf", "1.5.0") def callUDF(f: Function2[_, _, _], returnType: DataType, arg1: Column, arg2: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr)) + ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr)) } /** @@ -1698,7 +1698,7 @@ object functions { */ @deprecated("Use udf", "1.5.0") def callUDF(f: Function3[_, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr)) + ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr)) } /** @@ -1711,7 +1711,7 @@ object functions { */ @deprecated("Use udf", "1.5.0") def callUDF(f: Function4[_, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr)) + ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr)) } /** @@ -1724,7 +1724,7 @@ object functions { */ @deprecated("Use udf", "1.5.0") def callUDF(f: Function5[_, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr)) + ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr)) } /** @@ -1737,7 +1737,7 @@ object functions { */ @deprecated("Use udf", "1.5.0") def callUDF(f: Function6[_, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr)) + ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr)) } /** @@ -1750,7 +1750,7 @@ object functions { */ @deprecated("Use udf", "1.5.0") def callUDF(f: Function7[_, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr)) + ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr)) } /** @@ -1763,7 +1763,7 @@ object functions { */ @deprecated("Use udf", "1.5.0") def callUDF(f: Function8[_, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr)) + ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, 
arg8.expr)) } /** @@ -1776,7 +1776,7 @@ object functions { */ @deprecated("Use udf", "1.5.0") def callUDF(f: Function9[_, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr)) + ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr)) } /** @@ -1789,7 +1789,7 @@ object functions { */ @deprecated("Use udf", "1.5.0") def callUDF(f: Function10[_, _, _, _, _, _, _, _, _, _, _], returnType: DataType, arg1: Column, arg2: Column, arg3: Column, arg4: Column, arg5: Column, arg6: Column, arg7: Column, arg8: Column, arg9: Column, arg10: Column): Column = { - ScalaUdf(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr)) + ScalaUDF(f, returnType, Seq(arg1.expr, arg2.expr, arg3.expr, arg4.expr, arg5.expr, arg6.expr, arg7.expr, arg8.expr, arg9.expr, arg10.expr)) } // scalastyle:on @@ -1802,8 +1802,8 @@ object functions { * * val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value") * val sqlContext = df.sqlContext - * sqlContext.udf.register("simpleUdf", (v: Int) => v * v) - * df.select($"id", callUDF("simpleUdf", $"value")) + * sqlContext.udf.register("simpleUDF", (v: Int) => v * v) + * df.select($"id", callUDF("simpleUDF", $"value")) * }}} * * @group udf_funcs @@ -1821,8 +1821,8 @@ object functions { * * val df = Seq(("id1", 1), ("id2", 4), ("id3", 5)).toDF("id", "value") * val sqlContext = df.sqlContext - * sqlContext.udf.register("simpleUdf", (v: Int) => v * v) - * df.select($"id", callUdf("simpleUdf", $"value")) + * sqlContext.udf.register("simpleUDF", (v: Int) => v * v) + * df.select($"id", callUdf("simpleUDF", $"value")) * }}} * * @group udf_funcs http://git-wip-us.apache.org/repos/asf/spark/blob/931da5c8/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 22c54e4..82dc0e9 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -140,9 +140,9 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { val df = Seq(Tuple1(1), Tuple1(2), Tuple1(3)).toDF("index") // we except the id is materialized once - val idUdf = udf(() => UUID.randomUUID().toString) + val idUDF = udf(() => UUID.randomUUID().toString) - val dfWithId = df.withColumn("id", idUdf()) + val dfWithId = df.withColumn("id", idUDF()) // Make a new DataFrame (actually the same reference to the old one) val cached = dfWithId.cache() // Trigger the cache http://git-wip-us.apache.org/repos/asf/spark/blob/931da5c8/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala ---------------------------------------------------------------------- diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala index 8021f91..b91242a 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala @@ 
-42,7 +42,7 @@ import org.apache.spark.sql.SQLConf.SQLConfEntry._ import org.apache.spark.sql.catalyst.ParserDialect import org.apache.spark.sql.catalyst.analysis._ import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.execution.{ExecutedCommand, ExtractPythonUdfs, SetCommand} +import org.apache.spark.sql.execution.{ExecutedCommand, ExtractPythonUDFs, SetCommand} import org.apache.spark.sql.hive.client._ import org.apache.spark.sql.hive.execution.{DescribeHiveTableCommand, HiveNativeCommand} import org.apache.spark.sql.sources.DataSourceStrategy @@ -381,7 +381,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) { catalog.ParquetConversions :: catalog.CreateTables :: catalog.PreInsertionCasts :: - ExtractPythonUdfs :: + ExtractPythonUDFs :: ResolveHiveWindowFunction :: sources.PreInsertCastAndRename :: Nil http://git-wip-us.apache.org/repos/asf/spark/blob/931da5c8/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala ---------------------------------------------------------------------- diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala index 7c46209..2de7a99 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala @@ -1638,7 +1638,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C sys.error(s"Couldn't find function $functionName")) val functionClassName = functionInfo.getFunctionClass.getName - (HiveGenericUdtf( + (HiveGenericUDTF( new HiveFunctionWrapper(functionClassName), children.map(nodeToExpr)), attributes) http://git-wip-us.apache.org/repos/asf/spark/blob/931da5c8/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala ---------------------------------------------------------------------- diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala new file mode 100644 index 0000000..d7827d5 --- /dev/null +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -0,0 +1,598 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.hive + +import scala.collection.mutable.ArrayBuffer +import scala.collection.JavaConversions._ +import scala.util.Try + +import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ConstantObjectInspector} +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.ObjectInspectorOptions +import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory +import org.apache.hadoop.hive.ql.exec._ +import org.apache.hadoop.hive.ql.udf.{UDFType => HiveUDFType} +import org.apache.hadoop.hive.ql.udf.generic._ +import org.apache.hadoop.hive.ql.udf.generic.GenericUDF._ +import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationBuffer +import org.apache.hadoop.hive.ql.udf.generic.GenericUDFUtils.ConversionHelper + +import org.apache.spark.Logging +import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.analysis +import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder +import org.apache.spark.sql.catalyst.errors.TreeNodeException +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.catalyst.rules.Rule +import org.apache.spark.sql.hive.HiveShim._ +import org.apache.spark.sql.types._ + + +private[hive] class HiveFunctionRegistry(underlying: analysis.FunctionRegistry) + extends analysis.FunctionRegistry with HiveInspectors { + + def getFunctionInfo(name: String): FunctionInfo = FunctionRegistry.getFunctionInfo(name) + + override def lookupFunction(name: String, children: Seq[Expression]): Expression = { + Try(underlying.lookupFunction(name, children)).getOrElse { + // We only look it up to see if it exists, but do not include it in the HiveUDF since it is + // not always serializable. 
+ val functionInfo: FunctionInfo = + Option(FunctionRegistry.getFunctionInfo(name.toLowerCase)).getOrElse( + throw new AnalysisException(s"undefined function $name")) + + val functionClassName = functionInfo.getFunctionClass.getName + + if (classOf[UDF].isAssignableFrom(functionInfo.getFunctionClass)) { + HiveSimpleUDF(new HiveFunctionWrapper(functionClassName), children) + } else if (classOf[GenericUDF].isAssignableFrom(functionInfo.getFunctionClass)) { + HiveGenericUDF(new HiveFunctionWrapper(functionClassName), children) + } else if ( + classOf[AbstractGenericUDAFResolver].isAssignableFrom(functionInfo.getFunctionClass)) { + HiveGenericUDAF(new HiveFunctionWrapper(functionClassName), children) + } else if (classOf[UDAF].isAssignableFrom(functionInfo.getFunctionClass)) { + HiveUDAF(new HiveFunctionWrapper(functionClassName), children) + } else if (classOf[GenericUDTF].isAssignableFrom(functionInfo.getFunctionClass)) { + HiveGenericUDTF(new HiveFunctionWrapper(functionClassName), children) + } else { + sys.error(s"No handler for udf ${functionInfo.getFunctionClass}") + } + } + } + + override def registerFunction(name: String, builder: FunctionBuilder): Unit = + throw new UnsupportedOperationException +} + +private[hive] case class HiveSimpleUDF(funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) + extends Expression with HiveInspectors with Logging { + + type UDFType = UDF + + override def deterministic: Boolean = isUDFDeterministic + + override def nullable: Boolean = true + + @transient + lazy val function = funcWrapper.createFunction[UDFType]() + + @transient + protected lazy val method = + function.getResolver.getEvalMethod(children.map(_.dataType.toTypeInfo)) + + @transient + protected lazy val arguments = children.map(toInspector).toArray + + @transient + protected lazy val isUDFDeterministic = { + val udfType = function.getClass().getAnnotation(classOf[HiveUDFType]) + udfType != null && udfType.deterministic() + } + + override def foldable: Boolean = isUDFDeterministic && children.forall(_.foldable) + + // Create parameter converters + @transient + protected lazy val conversionHelper = new ConversionHelper(method, arguments) + + @transient + lazy val dataType = javaClassToDataType(method.getReturnType) + + @transient + lazy val returnInspector = ObjectInspectorFactory.getReflectionObjectInspector( + method.getGenericReturnType(), ObjectInspectorOptions.JAVA) + + @transient + protected lazy val cached: Array[AnyRef] = new Array[AnyRef](children.length) + + override def isThreadSafe: Boolean = false + + // TODO: Finish input output types. 
+ override def eval(input: InternalRow): Any = { + unwrap( + FunctionRegistry.invoke(method, function, conversionHelper + .convertIfNecessary(wrap(children.map(c => c.eval(input)), arguments, cached): _*): _*), + returnInspector) + } + + override def toString: String = { + s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" + } +} + +// Adapter from Catalyst ExpressionResult to Hive DeferredObject +private[hive] class DeferredObjectAdapter(oi: ObjectInspector) + extends DeferredObject with HiveInspectors { + private var func: () => Any = _ + def set(func: () => Any): Unit = { + this.func = func + } + override def prepare(i: Int): Unit = {} + override def get(): AnyRef = wrap(func(), oi) +} + +private[hive] case class HiveGenericUDF(funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) + extends Expression with HiveInspectors with Logging { + type UDFType = GenericUDF + + override def deterministic: Boolean = isUDFDeterministic + + override def nullable: Boolean = true + + @transient + lazy val function = funcWrapper.createFunction[UDFType]() + + @transient + protected lazy val argumentInspectors = children.map(toInspector) + + @transient + protected lazy val returnInspector = { + function.initializeAndFoldConstants(argumentInspectors.toArray) + } + + @transient + protected lazy val isUDFDeterministic = { + val udfType = function.getClass().getAnnotation(classOf[HiveUDFType]) + (udfType != null && udfType.deterministic()) + } + + override def foldable: Boolean = + isUDFDeterministic && returnInspector.isInstanceOf[ConstantObjectInspector] + + @transient + protected lazy val deferedObjects = + argumentInspectors.map(new DeferredObjectAdapter(_)).toArray[DeferredObject] + + lazy val dataType: DataType = inspectorToDataType(returnInspector) + + override def isThreadSafe: Boolean = false + + override def eval(input: InternalRow): Any = { + returnInspector // Make sure initialized. + + var i = 0 + while (i < children.length) { + val idx = i + deferedObjects(i).asInstanceOf[DeferredObjectAdapter].set( + () => { + children(idx).eval(input) + }) + i += 1 + } + unwrap(function.evaluate(deferedObjects), returnInspector) + } + + override def toString: String = { + s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" + } +} + +/** + * Resolves [[UnresolvedWindowFunction]] to [[HiveWindowFunction]]. + */ +private[spark] object ResolveHiveWindowFunction extends Rule[LogicalPlan] { + def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { + case p: LogicalPlan if !p.childrenResolved => p + + // We are resolving WindowExpressions at here. When we get here, we have already + // replaced those WindowSpecReferences. + case p: LogicalPlan => + p transformExpressions { + case WindowExpression( + UnresolvedWindowFunction(name, children), + windowSpec: WindowSpecDefinition) => + // First, let's find the window function info. + val windowFunctionInfo: WindowFunctionInfo = + Option(FunctionRegistry.getWindowFunctionInfo(name.toLowerCase)).getOrElse( + throw new AnalysisException(s"Couldn't find window function $name")) + + // Get the class of this function. + // In Hive 0.12, there is no windowFunctionInfo.getFunctionClass. So, we use + // windowFunctionInfo.getfInfo().getFunctionClass for both Hive 0.13 and Hive 0.13.1. 
+ val functionClass = windowFunctionInfo.getfInfo().getFunctionClass + val newChildren = + // Rank(), DENSE_RANK(), CUME_DIST(), and PERCENT_RANK() do not take explicit + // input parameters and requires implicit parameters, which + // are expressions in Order By clause. + if (classOf[GenericUDAFRank].isAssignableFrom(functionClass)) { + if (children.nonEmpty) { + throw new AnalysisException(s"$name does not take input parameters.") + } + windowSpec.orderSpec.map(_.child) + } else { + children + } + + // If the class is UDAF, we need to use UDAFBridge. + val isUDAFBridgeRequired = + if (classOf[UDAF].isAssignableFrom(functionClass)) { + true + } else { + false + } + + // Create the HiveWindowFunction. For the meaning of isPivotResult, see the doc of + // HiveWindowFunction. + val windowFunction = + HiveWindowFunction( + new HiveFunctionWrapper(functionClass.getName), + windowFunctionInfo.isPivotResult, + isUDAFBridgeRequired, + newChildren) + + // Second, check if the specified window function can accept window definition. + windowSpec.frameSpecification match { + case frame: SpecifiedWindowFrame if !windowFunctionInfo.isSupportsWindow => + // This Hive window function does not support user-speficied window frame. + throw new AnalysisException( + s"Window function $name does not take a frame specification.") + case frame: SpecifiedWindowFrame if windowFunctionInfo.isSupportsWindow && + windowFunctionInfo.isPivotResult => + // These two should not be true at the same time when a window frame is defined. + // If so, throw an exception. + throw new AnalysisException(s"Could not handle Hive window function $name because " + + s"it supports both a user specified window frame and pivot result.") + case _ => // OK + } + // Resolve those UnspecifiedWindowFrame because the physical Window operator still needs + // a window frame specification to work. + val newWindowSpec = windowSpec.frameSpecification match { + case UnspecifiedFrame => + val newWindowFrame = + SpecifiedWindowFrame.defaultWindowFrame( + windowSpec.orderSpec.nonEmpty, + windowFunctionInfo.isSupportsWindow) + WindowSpecDefinition(windowSpec.partitionSpec, windowSpec.orderSpec, newWindowFrame) + case _ => windowSpec + } + + // Finally, we create a WindowExpression with the resolved window function and + // specified window spec. + WindowExpression(windowFunction, newWindowSpec) + } + } +} + +/** + * A [[WindowFunction]] implementation wrapping Hive's window function. + * @param funcWrapper The wrapper for the Hive Window Function. + * @param pivotResult If it is true, the Hive function will return a list of values representing + * the values of the added columns. Otherwise, a single value is returned for + * current row. + * @param isUDAFBridgeRequired If it is true, the function returned by functionWrapper's + * createFunction is UDAF, we need to use GenericUDAFBridge to wrap + * it as a GenericUDAFResolver2. + * @param children Input parameters. + */ +private[hive] case class HiveWindowFunction( + funcWrapper: HiveFunctionWrapper, + pivotResult: Boolean, + isUDAFBridgeRequired: Boolean, + children: Seq[Expression]) extends WindowFunction + with HiveInspectors { + + // Hive window functions are based on GenericUDAFResolver2. 
+ type UDFType = GenericUDAFResolver2 + + @transient + protected lazy val resolver: GenericUDAFResolver2 = + if (isUDAFBridgeRequired) { + new GenericUDAFBridge(funcWrapper.createFunction[UDAF]()) + } else { + funcWrapper.createFunction[GenericUDAFResolver2]() + } + + @transient + protected lazy val inputInspectors = children.map(toInspector).toArray + + // The GenericUDAFEvaluator used to evaluate the window function. + @transient + protected lazy val evaluator: GenericUDAFEvaluator = { + val parameterInfo = new SimpleGenericUDAFParameterInfo(inputInspectors, false, false) + resolver.getEvaluator(parameterInfo) + } + + // The object inspector of values returned from the Hive window function. + @transient + protected lazy val returnInspector = { + evaluator.init(GenericUDAFEvaluator.Mode.COMPLETE, inputInspectors) + } + + def dataType: DataType = + if (!pivotResult) { + inspectorToDataType(returnInspector) + } else { + // If pivotResult is true, we should take the element type out as the data type of this + // function. + inspectorToDataType(returnInspector) match { + case ArrayType(dt, _) => dt + case _ => + sys.error( + s"error resolve the data type of window function ${funcWrapper.functionClassName}") + } + } + + def nullable: Boolean = true + + override def eval(input: InternalRow): Any = + throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}") + + @transient + lazy val inputProjection = new InterpretedProjection(children) + + @transient + private var hiveEvaluatorBuffer: AggregationBuffer = _ + // Output buffer. + private var outputBuffer: Any = _ + + override def init(): Unit = { + evaluator.init(GenericUDAFEvaluator.Mode.COMPLETE, inputInspectors) + } + + // Reset the hiveEvaluatorBuffer and outputPosition + override def reset(): Unit = { + // We create a new aggregation buffer to workaround the bug in GenericUDAFRowNumber. + // Basically, GenericUDAFRowNumberEvaluator.reset calls RowNumberBuffer.init. + // However, RowNumberBuffer.init does not really reset this buffer. + hiveEvaluatorBuffer = evaluator.getNewAggregationBuffer + evaluator.reset(hiveEvaluatorBuffer) + } + + override def prepareInputParameters(input: InternalRow): AnyRef = { + wrap(inputProjection(input), inputInspectors, new Array[AnyRef](children.length)) + } + // Add input parameters for a single row. + override def update(input: AnyRef): Unit = { + evaluator.iterate(hiveEvaluatorBuffer, input.asInstanceOf[Array[AnyRef]]) + } + + override def batchUpdate(inputs: Array[AnyRef]): Unit = { + var i = 0 + while (i < inputs.length) { + evaluator.iterate(hiveEvaluatorBuffer, inputs(i).asInstanceOf[Array[AnyRef]]) + i += 1 + } + } + + override def evaluate(): Unit = { + outputBuffer = unwrap(evaluator.evaluate(hiveEvaluatorBuffer), returnInspector) + } + + override def get(index: Int): Any = { + if (!pivotResult) { + // if pivotResult is false, we will get a single value for all rows in the frame. + outputBuffer + } else { + // if pivotResult is true, we will get a Seq having the same size with the size + // of the window frame. At here, we will return the result at the position of + // index in the output buffer. 
+ outputBuffer.asInstanceOf[Seq[Any]].get(index) + } + } + + override def toString: String = { + s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" + } + + override def newInstance: WindowFunction = + new HiveWindowFunction(funcWrapper, pivotResult, isUDAFBridgeRequired, children) +} + +private[hive] case class HiveGenericUDAF( + funcWrapper: HiveFunctionWrapper, + children: Seq[Expression]) extends AggregateExpression + with HiveInspectors { + + type UDFType = AbstractGenericUDAFResolver + + @transient + protected lazy val resolver: AbstractGenericUDAFResolver = funcWrapper.createFunction() + + @transient + protected lazy val objectInspector = { + val parameterInfo = new SimpleGenericUDAFParameterInfo(inspectors.toArray, false, false) + resolver.getEvaluator(parameterInfo) + .init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors.toArray) + } + + @transient + protected lazy val inspectors = children.map(toInspector) + + def dataType: DataType = inspectorToDataType(objectInspector) + + def nullable: Boolean = true + + override def toString: String = { + s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" + } + + def newInstance(): HiveUDAFFunction = new HiveUDAFFunction(funcWrapper, children, this) +} + +/** It is used as a wrapper for the hive functions which uses UDAF interface */ +private[hive] case class HiveUDAF( + funcWrapper: HiveFunctionWrapper, + children: Seq[Expression]) extends AggregateExpression + with HiveInspectors { + + type UDFType = UDAF + + @transient + protected lazy val resolver: AbstractGenericUDAFResolver = + new GenericUDAFBridge(funcWrapper.createFunction()) + + @transient + protected lazy val objectInspector = { + val parameterInfo = new SimpleGenericUDAFParameterInfo(inspectors.toArray, false, false) + resolver.getEvaluator(parameterInfo) + .init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors.toArray) + } + + @transient + protected lazy val inspectors = children.map(toInspector) + + def dataType: DataType = inspectorToDataType(objectInspector) + + def nullable: Boolean = true + + override def toString: String = { + s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" + } + + def newInstance(): HiveUDAFFunction = new HiveUDAFFunction(funcWrapper, children, this, true) +} + +/** + * Converts a Hive Generic User Defined Table Generating Function (UDTF) to a + * [[catalyst.expressions.Generator Generator]]. Note that the semantics of Generators do not allow + * Generators to maintain state in between input rows. Thus UDTFs that rely on partitioning + * dependent operations like calls to `close()` before producing output will not operate the same as + * in Hive. However, in practice this should not affect compatibility for most sane UDTFs + * (e.g. explode or GenericUDTFParseUrlTuple). + * + * Operators that require maintaining state in between input rows should instead be implemented as + * user defined aggregations, which have clean semantics even in a partitioned execution. 
+ */ +private[hive] case class HiveGenericUDTF( + funcWrapper: HiveFunctionWrapper, + children: Seq[Expression]) + extends Generator with HiveInspectors { + + @transient + protected lazy val function: GenericUDTF = { + val fun: GenericUDTF = funcWrapper.createFunction() + fun.setCollector(collector) + fun + } + + @transient + protected lazy val inputInspectors = children.map(toInspector) + + @transient + protected lazy val outputInspector = function.initialize(inputInspectors.toArray) + + @transient + protected lazy val udtInput = new Array[AnyRef](children.length) + + @transient + protected lazy val collector = new UDTFCollector + + lazy val elementTypes = outputInspector.getAllStructFieldRefs.map { + field => (inspectorToDataType(field.getFieldObjectInspector), true) + } + + override def eval(input: InternalRow): TraversableOnce[InternalRow] = { + outputInspector // Make sure initialized. + + val inputProjection = new InterpretedProjection(children) + + function.process(wrap(inputProjection(input), inputInspectors, udtInput)) + collector.collectRows() + } + + protected class UDTFCollector extends Collector { + var collected = new ArrayBuffer[InternalRow] + + override def collect(input: java.lang.Object) { + // We need to clone the input here because implementations of + // GenericUDTF reuse the same object. Luckily they are always an array, so + // it is easy to clone. + collected += unwrap(input, outputInspector).asInstanceOf[InternalRow] + } + + def collectRows(): Seq[InternalRow] = { + val toCollect = collected + collected = new ArrayBuffer[InternalRow] + toCollect + } + } + + override def terminate(): TraversableOnce[InternalRow] = { + outputInspector // Make sure initialized. + function.close() + collector.collectRows() + } + + override def toString: String = { + s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" + } +} + +private[hive] case class HiveUDAFFunction( + funcWrapper: HiveFunctionWrapper, + exprs: Seq[Expression], + base: AggregateExpression, + isUDAFBridgeRequired: Boolean = false) + extends AggregateFunction + with HiveInspectors { + + def this() = this(null, null, null) + + private val resolver = + if (isUDAFBridgeRequired) { + new GenericUDAFBridge(funcWrapper.createFunction[UDAF]()) + } else { + funcWrapper.createFunction[AbstractGenericUDAFResolver]() + } + + private val inspectors = exprs.map(toInspector).toArray + + private val function = { + val parameterInfo = new SimpleGenericUDAFParameterInfo(inspectors, false, false) + resolver.getEvaluator(parameterInfo) + } + + private val returnInspector = function.init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors) + + private val buffer = + function.getNewAggregationBuffer + + override def eval(input: InternalRow): Any = unwrap(function.evaluate(buffer), returnInspector) + + @transient + val inputProjection = new InterpretedProjection(exprs) + + @transient + protected lazy val cached = new Array[AnyRef](exprs.length) + + def update(input: InternalRow): Unit = { + val inputs = inputProjection(input) + function.iterate(buffer, wrap(inputs, inspectors, cached)) + } +} + http://git-wip-us.apache.org/repos/asf/spark/blob/931da5c8/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala ---------------------------------------------------------------------- diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala deleted file mode 100644 index 4986b1e..0000000 --- 
a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala +++ /dev/null @@ -1,598 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive - -import scala.collection.mutable.ArrayBuffer -import scala.collection.JavaConversions._ -import scala.util.Try - -import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ConstantObjectInspector} -import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory.ObjectInspectorOptions -import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory -import org.apache.hadoop.hive.ql.exec._ -import org.apache.hadoop.hive.ql.udf.{UDFType => HiveUDFType} -import org.apache.hadoop.hive.ql.udf.generic._ -import org.apache.hadoop.hive.ql.udf.generic.GenericUDF._ -import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator.AggregationBuffer -import org.apache.hadoop.hive.ql.udf.generic.GenericUDFUtils.ConversionHelper - -import org.apache.spark.Logging -import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.catalyst.analysis -import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder -import org.apache.spark.sql.catalyst.errors.TreeNodeException -import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical._ -import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.hive.HiveShim._ -import org.apache.spark.sql.types._ - - -private[hive] class HiveFunctionRegistry(underlying: analysis.FunctionRegistry) - extends analysis.FunctionRegistry with HiveInspectors { - - def getFunctionInfo(name: String): FunctionInfo = FunctionRegistry.getFunctionInfo(name) - - override def lookupFunction(name: String, children: Seq[Expression]): Expression = { - Try(underlying.lookupFunction(name, children)).getOrElse { - // We only look it up to see if it exists, but do not include it in the HiveUDF since it is - // not always serializable. 
- val functionInfo: FunctionInfo = - Option(FunctionRegistry.getFunctionInfo(name.toLowerCase)).getOrElse( - throw new AnalysisException(s"undefined function $name")) - - val functionClassName = functionInfo.getFunctionClass.getName - - if (classOf[UDF].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveSimpleUdf(new HiveFunctionWrapper(functionClassName), children) - } else if (classOf[GenericUDF].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveGenericUdf(new HiveFunctionWrapper(functionClassName), children) - } else if ( - classOf[AbstractGenericUDAFResolver].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveGenericUdaf(new HiveFunctionWrapper(functionClassName), children) - } else if (classOf[UDAF].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveUdaf(new HiveFunctionWrapper(functionClassName), children) - } else if (classOf[GenericUDTF].isAssignableFrom(functionInfo.getFunctionClass)) { - HiveGenericUdtf(new HiveFunctionWrapper(functionClassName), children) - } else { - sys.error(s"No handler for udf ${functionInfo.getFunctionClass}") - } - } - } - - override def registerFunction(name: String, builder: FunctionBuilder): Unit = - throw new UnsupportedOperationException -} - -private[hive] case class HiveSimpleUdf(funcWrapper: HiveFunctionWrapper, children: Seq[Expression]) - extends Expression with HiveInspectors with Logging { - - type UDFType = UDF - - override def deterministic: Boolean = isUDFDeterministic - - override def nullable: Boolean = true - - @transient - lazy val function = funcWrapper.createFunction[UDFType]() - - @transient - protected lazy val method = - function.getResolver.getEvalMethod(children.map(_.dataType.toTypeInfo)) - - @transient - protected lazy val arguments = children.map(toInspector).toArray - - @transient - protected lazy val isUDFDeterministic = { - val udfType = function.getClass().getAnnotation(classOf[HiveUDFType]) - udfType != null && udfType.deterministic() - } - - override def foldable: Boolean = isUDFDeterministic && children.forall(_.foldable) - - // Create parameter converters - @transient - protected lazy val conversionHelper = new ConversionHelper(method, arguments) - - @transient - lazy val dataType = javaClassToDataType(method.getReturnType) - - @transient - lazy val returnInspector = ObjectInspectorFactory.getReflectionObjectInspector( - method.getGenericReturnType(), ObjectInspectorOptions.JAVA) - - @transient - protected lazy val cached: Array[AnyRef] = new Array[AnyRef](children.length) - - override def isThreadSafe: Boolean = false - - // TODO: Finish input output types. 
-  override def eval(input: InternalRow): Any = {
-    unwrap(
-      FunctionRegistry.invoke(method, function, conversionHelper
-        .convertIfNecessary(wrap(children.map(c => c.eval(input)), arguments, cached): _*): _*),
-      returnInspector)
-  }
-
-  override def toString: String = {
-    s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})"
-  }
-}
-
-// Adapter from Catalyst ExpressionResult to Hive DeferredObject
-private[hive] class DeferredObjectAdapter(oi: ObjectInspector)
-  extends DeferredObject with HiveInspectors {
-  private var func: () => Any = _
-  def set(func: () => Any): Unit = {
-    this.func = func
-  }
-  override def prepare(i: Int): Unit = {}
-  override def get(): AnyRef = wrap(func(), oi)
-}
-
-private[hive] case class HiveGenericUdf(funcWrapper: HiveFunctionWrapper, children: Seq[Expression])
-  extends Expression with HiveInspectors with Logging {
-  type UDFType = GenericUDF
-
-  override def deterministic: Boolean = isUDFDeterministic
-
-  override def nullable: Boolean = true
-
-  @transient
-  lazy val function = funcWrapper.createFunction[UDFType]()
-
-  @transient
-  protected lazy val argumentInspectors = children.map(toInspector)
-
-  @transient
-  protected lazy val returnInspector = {
-    function.initializeAndFoldConstants(argumentInspectors.toArray)
-  }
-
-  @transient
-  protected lazy val isUDFDeterministic = {
-    val udfType = function.getClass().getAnnotation(classOf[HiveUDFType])
-    (udfType != null && udfType.deterministic())
-  }
-
-  override def foldable: Boolean =
-    isUDFDeterministic && returnInspector.isInstanceOf[ConstantObjectInspector]
-
-  @transient
-  protected lazy val deferredObjects =
-    argumentInspectors.map(new DeferredObjectAdapter(_)).toArray[DeferredObject]
-
-  lazy val dataType: DataType = inspectorToDataType(returnInspector)
-
-  override def isThreadSafe: Boolean = false
-
-  override def eval(input: InternalRow): Any = {
-    returnInspector // Make sure initialized.
-
-    var i = 0
-    while (i < children.length) {
-      val idx = i
-      deferredObjects(i).asInstanceOf[DeferredObjectAdapter].set(
-        () => {
-          children(idx).eval(input)
-        })
-      i += 1
-    }
-    unwrap(function.evaluate(deferredObjects), returnInspector)
-  }
-
-  override def toString: String = {
-    s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})"
-  }
-}
-
-/**
- * Resolves [[UnresolvedWindowFunction]] to [[HiveWindowFunction]].
- */
-private[spark] object ResolveHiveWindowFunction extends Rule[LogicalPlan] {
-  def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
-    case p: LogicalPlan if !p.childrenResolved => p
-
-    // We are resolving WindowExpressions here. When we get here, we have already
-    // replaced those WindowSpecReferences.
-    case p: LogicalPlan =>
-      p transformExpressions {
-        case WindowExpression(
-          UnresolvedWindowFunction(name, children),
-          windowSpec: WindowSpecDefinition) =>
-          // First, let's find the window function info.
-          val windowFunctionInfo: WindowFunctionInfo =
-            Option(FunctionRegistry.getWindowFunctionInfo(name.toLowerCase)).getOrElse(
-              throw new AnalysisException(s"Couldn't find window function $name"))
-
-          // Get the class of this function.
-          // In Hive 0.12, there is no windowFunctionInfo.getFunctionClass. So, we use
-          // windowFunctionInfo.getfInfo().getFunctionClass for both Hive 0.13 and Hive 0.13.1.
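
// Illustrative example (table and column names assumed, not from this
// patch): for
//
//   SELECT key, rank() OVER (PARTITION BY key ORDER BY value) AS r FROM src
//
// the parser produces UnresolvedWindowFunction("rank", Nil). Since rank()
// takes no explicit arguments, the branch below substitutes the ORDER BY
// expression (value) as its implicit child before the HiveWindowFunction
// is built.
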
-          val functionClass = windowFunctionInfo.getfInfo().getFunctionClass
-          val newChildren =
-            // Rank(), DENSE_RANK(), CUME_DIST(), and PERCENT_RANK() do not take explicit
-            // input parameters and require implicit parameters, which are the expressions
-            // in the ORDER BY clause.
-            if (classOf[GenericUDAFRank].isAssignableFrom(functionClass)) {
-              if (children.nonEmpty) {
-                throw new AnalysisException(s"$name does not take input parameters.")
-              }
-              windowSpec.orderSpec.map(_.child)
-            } else {
-              children
-            }
-
-          // If the class is UDAF, we need to use UDAFBridge.
-          val isUDAFBridgeRequired =
-            if (classOf[UDAF].isAssignableFrom(functionClass)) {
-              true
-            } else {
-              false
-            }
-
-          // Create the HiveWindowFunction. For the meaning of isPivotResult, see the doc of
-          // HiveWindowFunction.
-          val windowFunction =
-            HiveWindowFunction(
-              new HiveFunctionWrapper(functionClass.getName),
-              windowFunctionInfo.isPivotResult,
-              isUDAFBridgeRequired,
-              newChildren)
-
-          // Second, check if the specified window function can accept window definition.
-          windowSpec.frameSpecification match {
-            case frame: SpecifiedWindowFrame if !windowFunctionInfo.isSupportsWindow =>
-              // This Hive window function does not support a user-specified window frame.
-              throw new AnalysisException(
-                s"Window function $name does not take a frame specification.")
-            case frame: SpecifiedWindowFrame if windowFunctionInfo.isSupportsWindow &&
-                windowFunctionInfo.isPivotResult =>
-              // These two should not be true at the same time when a window frame is defined.
-              // If so, throw an exception.
-              throw new AnalysisException(s"Could not handle Hive window function $name because " +
-                s"it supports both a user-specified window frame and pivot result.")
-            case _ => // OK
-          }
-          // Resolve those UnspecifiedWindowFrame because the physical Window operator still needs
-          // a window frame specification to work.
-          val newWindowSpec = windowSpec.frameSpecification match {
-            case UnspecifiedFrame =>
-              val newWindowFrame =
-                SpecifiedWindowFrame.defaultWindowFrame(
-                  windowSpec.orderSpec.nonEmpty,
-                  windowFunctionInfo.isSupportsWindow)
-              WindowSpecDefinition(windowSpec.partitionSpec, windowSpec.orderSpec, newWindowFrame)
-            case _ => windowSpec
-          }
-
-          // Finally, we create a WindowExpression with the resolved window function and
-          // specified window spec.
-          WindowExpression(windowFunction, newWindowSpec)
      }
-  }
-}
-
-/**
- * A [[WindowFunction]] implementation wrapping Hive's window function.
- * @param funcWrapper The wrapper for the Hive Window Function.
- * @param pivotResult If it is true, the Hive function will return a list of values representing
- *                    the values of the added columns. Otherwise, a single value is returned for
- *                    the current row.
- * @param isUDAFBridgeRequired If it is true, the function returned by functionWrapper's
- *                             createFunction is a UDAF, and we need to use GenericUDAFBridge to
- *                             wrap it as a GenericUDAFResolver2.
- * @param children Input parameters.
- */
-private[hive] case class HiveWindowFunction(
-    funcWrapper: HiveFunctionWrapper,
-    pivotResult: Boolean,
-    isUDAFBridgeRequired: Boolean,
-    children: Seq[Expression]) extends WindowFunction
-  with HiveInspectors {
-
-  // Hive window functions are based on GenericUDAFResolver2.
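
// Sketch of the two result shapes that get(index) below has to handle
// (values are illustrative, not from this patch):
//
//   val pivoted: Any = Seq(1, 1, 3) // e.g. rank() over a 3-row partition
//   val shared: Any  = 42L          // e.g. sum(...) over the frame
//   def result(buf: Any, pivot: Boolean, i: Int): Any =
//     if (pivot) buf.asInstanceOf[Seq[Any]](i) else buf
//
// A pivoting function evaluates once per partition into a Seq with one
// entry per row; a non-pivoting one yields a single value shared by every
// row in the frame.
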
-  type UDFType = GenericUDAFResolver2
-
-  @transient
-  protected lazy val resolver: GenericUDAFResolver2 =
-    if (isUDAFBridgeRequired) {
-      new GenericUDAFBridge(funcWrapper.createFunction[UDAF]())
-    } else {
-      funcWrapper.createFunction[GenericUDAFResolver2]()
-    }
-
-  @transient
-  protected lazy val inputInspectors = children.map(toInspector).toArray
-
-  // The GenericUDAFEvaluator used to evaluate the window function.
-  @transient
-  protected lazy val evaluator: GenericUDAFEvaluator = {
-    val parameterInfo = new SimpleGenericUDAFParameterInfo(inputInspectors, false, false)
-    resolver.getEvaluator(parameterInfo)
-  }
-
-  // The object inspector of values returned from the Hive window function.
-  @transient
-  protected lazy val returnInspector = {
-    evaluator.init(GenericUDAFEvaluator.Mode.COMPLETE, inputInspectors)
-  }
-
-  def dataType: DataType =
-    if (!pivotResult) {
-      inspectorToDataType(returnInspector)
-    } else {
-      // If pivotResult is true, we should take the element type out as the data type of this
-      // function.
-      inspectorToDataType(returnInspector) match {
-        case ArrayType(dt, _) => dt
-        case _ =>
-          sys.error(
-            s"error resolving the data type of window function ${funcWrapper.functionClassName}")
-      }
-    }
-
-  def nullable: Boolean = true
-
-  override def eval(input: InternalRow): Any =
-    throw new TreeNodeException(this, s"No function to evaluate expression. type: ${this.nodeName}")
-
-  @transient
-  lazy val inputProjection = new InterpretedProjection(children)
-
-  @transient
-  private var hiveEvaluatorBuffer: AggregationBuffer = _
-  // Output buffer.
-  private var outputBuffer: Any = _
-
-  override def init(): Unit = {
-    evaluator.init(GenericUDAFEvaluator.Mode.COMPLETE, inputInspectors)
-  }
-
-  // Reset the hiveEvaluatorBuffer and outputPosition
-  override def reset(): Unit = {
-    // We create a new aggregation buffer to work around a bug in GenericUDAFRowNumber:
-    // GenericUDAFRowNumberEvaluator.reset calls RowNumberBuffer.init, but
-    // RowNumberBuffer.init does not really reset this buffer.
-    hiveEvaluatorBuffer = evaluator.getNewAggregationBuffer
-    evaluator.reset(hiveEvaluatorBuffer)
-  }
-
-  override def prepareInputParameters(input: InternalRow): AnyRef = {
-    wrap(inputProjection(input), inputInspectors, new Array[AnyRef](children.length))
-  }
-  // Add input parameters for a single row.
-  override def update(input: AnyRef): Unit = {
-    evaluator.iterate(hiveEvaluatorBuffer, input.asInstanceOf[Array[AnyRef]])
-  }
-
-  override def batchUpdate(inputs: Array[AnyRef]): Unit = {
-    var i = 0
-    while (i < inputs.length) {
-      evaluator.iterate(hiveEvaluatorBuffer, inputs(i).asInstanceOf[Array[AnyRef]])
-      i += 1
-    }
-  }
-
-  override def evaluate(): Unit = {
-    outputBuffer = unwrap(evaluator.evaluate(hiveEvaluatorBuffer), returnInspector)
-  }
-
-  override def get(index: Int): Any = {
-    if (!pivotResult) {
-      // If pivotResult is false, we will get a single value for all rows in the frame.
-      outputBuffer
-    } else {
-      // If pivotResult is true, we will get a Seq having the same size as the window
-      // frame. Here, we return the result at position `index` in the output buffer.
-      outputBuffer.asInstanceOf[Seq[Any]].get(index)
-    }
-  }
-
-  override def toString: String = {
-    s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})"
-  }
-
-  override def newInstance: WindowFunction =
-    new HiveWindowFunction(funcWrapper, pivotResult, isUDAFBridgeRequired, children)
-}
-
-private[hive] case class HiveGenericUdaf(
-    funcWrapper: HiveFunctionWrapper,
-    children: Seq[Expression]) extends AggregateExpression
-  with HiveInspectors {
-
-  type UDFType = AbstractGenericUDAFResolver
-
-  @transient
-  protected lazy val resolver: AbstractGenericUDAFResolver = funcWrapper.createFunction()
-
-  @transient
-  protected lazy val objectInspector = {
-    val parameterInfo = new SimpleGenericUDAFParameterInfo(inspectors.toArray, false, false)
-    resolver.getEvaluator(parameterInfo)
-      .init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors.toArray)
-  }
-
-  @transient
-  protected lazy val inspectors = children.map(toInspector)
-
-  def dataType: DataType = inspectorToDataType(objectInspector)
-
-  def nullable: Boolean = true
-
-  override def toString: String = {
-    s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})"
-  }
-
-  def newInstance(): HiveUdafFunction = new HiveUdafFunction(funcWrapper, children, this)
-}
-
-/** A wrapper for Hive functions that use the UDAF interface. */
-private[hive] case class HiveUdaf(
-    funcWrapper: HiveFunctionWrapper,
-    children: Seq[Expression]) extends AggregateExpression
-  with HiveInspectors {
-
-  type UDFType = UDAF
-
-  @transient
-  protected lazy val resolver: AbstractGenericUDAFResolver =
-    new GenericUDAFBridge(funcWrapper.createFunction())
-
-  @transient
-  protected lazy val objectInspector = {
-    val parameterInfo = new SimpleGenericUDAFParameterInfo(inspectors.toArray, false, false)
-    resolver.getEvaluator(parameterInfo)
-      .init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors.toArray)
-  }
-
-  @transient
-  protected lazy val inspectors = children.map(toInspector)
-
-  def dataType: DataType = inspectorToDataType(objectInspector)
-
-  def nullable: Boolean = true
-
-  override def toString: String = {
-    s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})"
-  }
-
-  def newInstance(): HiveUdafFunction = new HiveUdafFunction(funcWrapper, children, this, true)
-}
-
-/**
- * Converts a Hive Generic User Defined Table Generating Function (UDTF) to a
- * [[catalyst.expressions.Generator Generator]]. Note that the semantics of Generators do not allow
- * Generators to maintain state in between input rows. Thus UDTFs that rely on partitioning
- * dependent operations like calls to `close()` before producing output will not operate the same
- * as in Hive. However, in practice this should not affect compatibility for most sane UDTFs
- * (e.g. explode or GenericUDTFParseUrlTuple).
- *
- * Operators that require maintaining state in between input rows should instead be implemented as
- * user defined aggregations, which have clean semantics even in a partitioned execution.
- */ -private[hive] case class HiveGenericUdtf( - funcWrapper: HiveFunctionWrapper, - children: Seq[Expression]) - extends Generator with HiveInspectors { - - @transient - protected lazy val function: GenericUDTF = { - val fun: GenericUDTF = funcWrapper.createFunction() - fun.setCollector(collector) - fun - } - - @transient - protected lazy val inputInspectors = children.map(toInspector) - - @transient - protected lazy val outputInspector = function.initialize(inputInspectors.toArray) - - @transient - protected lazy val udtInput = new Array[AnyRef](children.length) - - @transient - protected lazy val collector = new UDTFCollector - - lazy val elementTypes = outputInspector.getAllStructFieldRefs.map { - field => (inspectorToDataType(field.getFieldObjectInspector), true) - } - - override def eval(input: InternalRow): TraversableOnce[InternalRow] = { - outputInspector // Make sure initialized. - - val inputProjection = new InterpretedProjection(children) - - function.process(wrap(inputProjection(input), inputInspectors, udtInput)) - collector.collectRows() - } - - protected class UDTFCollector extends Collector { - var collected = new ArrayBuffer[InternalRow] - - override def collect(input: java.lang.Object) { - // We need to clone the input here because implementations of - // GenericUDTF reuse the same object. Luckily they are always an array, so - // it is easy to clone. - collected += unwrap(input, outputInspector).asInstanceOf[InternalRow] - } - - def collectRows(): Seq[InternalRow] = { - val toCollect = collected - collected = new ArrayBuffer[InternalRow] - toCollect - } - } - - override def terminate(): TraversableOnce[InternalRow] = { - outputInspector // Make sure initialized. - function.close() - collector.collectRows() - } - - override def toString: String = { - s"$nodeName#${funcWrapper.functionClassName}(${children.mkString(",")})" - } -} - -private[hive] case class HiveUdafFunction( - funcWrapper: HiveFunctionWrapper, - exprs: Seq[Expression], - base: AggregateExpression, - isUDAFBridgeRequired: Boolean = false) - extends AggregateFunction - with HiveInspectors { - - def this() = this(null, null, null) - - private val resolver = - if (isUDAFBridgeRequired) { - new GenericUDAFBridge(funcWrapper.createFunction[UDAF]()) - } else { - funcWrapper.createFunction[AbstractGenericUDAFResolver]() - } - - private val inspectors = exprs.map(toInspector).toArray - - private val function = { - val parameterInfo = new SimpleGenericUDAFParameterInfo(inspectors, false, false) - resolver.getEvaluator(parameterInfo) - } - - private val returnInspector = function.init(GenericUDAFEvaluator.Mode.COMPLETE, inspectors) - - private val buffer = - function.getNewAggregationBuffer - - override def eval(input: InternalRow): Any = unwrap(function.evaluate(buffer), returnInspector) - - @transient - val inputProjection = new InterpretedProjection(exprs) - - @transient - protected lazy val cached = new Array[AnyRef](exprs.length) - - def update(input: InternalRow): Unit = { - val inputs = inputProjection(input) - function.iterate(buffer, wrap(inputs, inspectors, cached)) - } -} - http://git-wip-us.apache.org/repos/asf/spark/blob/931da5c8/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala ---------------------------------------------------------------------- diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala index ea325cc..7978fda 100644 --- 
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/test/TestHive.scala
@@ -391,7 +391,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) {
    * Records the UDFs present when the server starts, so we can delete ones that are created by
    * tests.
    */
-  protected val originalUdfs: JavaSet[String] = FunctionRegistry.getFunctionNames
+  protected val originalUDFs: JavaSet[String] = FunctionRegistry.getFunctionNames
 
   /**
    * Resets the test instance by deleting any tables that have been created.
@@ -410,7 +410,7 @@ class TestHiveContext(sc: SparkContext) extends HiveContext(sc) {
     catalog.client.reset()
     catalog.unregisterAllTables()
 
-    FunctionRegistry.getFunctionNames.filterNot(originalUdfs.contains(_)).foreach { udfName =>
+    FunctionRegistry.getFunctionNames.filterNot(originalUDFs.contains(_)).foreach { udfName =>
       FunctionRegistry.unregisterTemporaryUDF(udfName)
     }

http://git-wip-us.apache.org/repos/asf/spark/blob/931da5c8/sql/hive/src/test/resources/data/files/testUDF/part-00000
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/resources/data/files/testUDF/part-00000 b/sql/hive/src/test/resources/data/files/testUDF/part-00000
new file mode 100755
index 0000000..240a5c1
Binary files /dev/null and b/sql/hive/src/test/resources/data/files/testUDF/part-00000 differ

http://git-wip-us.apache.org/repos/asf/spark/blob/931da5c8/sql/hive/src/test/resources/data/files/testUdf/part-00000
----------------------------------------------------------------------
diff --git a/sql/hive/src/test/resources/data/files/testUdf/part-00000 b/sql/hive/src/test/resources/data/files/testUdf/part-00000
deleted file mode 100755
index 240a5c1..0000000
Binary files a/sql/hive/src/test/resources/data/files/testUdf/part-00000 and /dev/null differ
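
The TestHive hunk above is only a casing rename (originalUdfs becomes originalUDFs); the mechanism is unchanged: snapshot the registered function names at startup, then unregister anything a test added. A minimal standalone sketch of that pattern, using plain Scala collections instead of Hive's FunctionRegistry:

  // Snapshot-and-reset, as in TestHiveContext: `current` lists the registered
  // function names and `drop` unregisters one by name.
  class UDFReset(current: () => Set[String], drop: String => Unit) {
    private val originalUDFs: Set[String] = current() // snapshot at startup
    def reset(): Unit = current().filterNot(originalUDFs.contains).foreach(drop)
  }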