Repository: spark Updated Branches: refs/heads/master c8ae887ef -> 931da5c8a
http://git-wip-us.apache.org/repos/asf/spark/blob/931da5c8/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala ---------------------------------------------------------------------- diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala new file mode 100644 index 0000000..56b0bef --- /dev/null +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUDFSuite.scala @@ -0,0 +1,261 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.hive.execution + +import java.io.{DataInput, DataOutput} +import java.util +import java.util.Properties + +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.hive.ql.udf.generic.{GenericUDAFAverage, GenericUDF} +import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredObject +import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory +import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory} +import org.apache.hadoop.hive.serde2.{AbstractSerDe, SerDeStats} +import org.apache.hadoop.io.Writable +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.hive.test.TestHive + +import org.apache.spark.util.Utils + +import scala.collection.JavaConversions._ + +case class Fields(f1: Int, f2: Int, f3: Int, f4: Int, f5: Int) + +// Case classes for the custom UDF's. +case class IntegerCaseClass(i: Int) +case class ListListIntCaseClass(lli: Seq[(Int, Int, Int)]) +case class StringCaseClass(s: String) +case class ListStringCaseClass(l: Seq[String]) + +/** + * A test suite for Hive custom UDFs. + */ +class HiveUDFSuite extends QueryTest { + + import TestHive.{udf, sql} + import TestHive.implicits._ + + test("spark sql udf test that returns a struct") { + udf.register("getStruct", (_: Int) => Fields(1, 2, 3, 4, 5)) + assert(sql( + """ + |SELECT getStruct(1).f1, + | getStruct(1).f2, + | getStruct(1).f3, + | getStruct(1).f4, + | getStruct(1).f5 FROM src LIMIT 1 + """.stripMargin).head() === Row(1, 2, 3, 4, 5)) + } + + test("SPARK-4785 When called with arguments referring column fields, PMOD throws NPE") { + checkAnswer( + sql("SELECT PMOD(CAST(key as INT), 10) FROM src LIMIT 1"), + Row(8) + ) + } + + test("hive struct udf") { + sql( + """ + |CREATE EXTERNAL TABLE hiveUDFTestTable ( + | pair STRUCT<id: INT, value: INT> + |) + |PARTITIONED BY (partition STRING) + |ROW FORMAT SERDE '%s' + |STORED AS SEQUENCEFILE + """. + stripMargin.format(classOf[PairSerDe].getName)) + + val location = Utils.getSparkClassLoader.getResource("data/files/testUDF").getFile + sql(s""" + ALTER TABLE hiveUDFTestTable + ADD IF NOT EXISTS PARTITION(partition='testUDF') + LOCATION '$location'""") + + sql(s"CREATE TEMPORARY FUNCTION testUDF AS '${classOf[PairUDF].getName}'") + sql("SELECT testUDF(pair) FROM hiveUDFTestTable") + sql("DROP TEMPORARY FUNCTION IF EXISTS testUDF") + } + + test("SPARK-6409 UDAFAverage test") { + sql(s"CREATE TEMPORARY FUNCTION test_avg AS '${classOf[GenericUDAFAverage].getName}'") + checkAnswer( + sql("SELECT test_avg(1), test_avg(substr(value,5)) FROM src"), + Seq(Row(1.0, 260.182))) + sql("DROP TEMPORARY FUNCTION IF EXISTS test_avg") + TestHive.reset() + } + + test("SPARK-2693 udaf aggregates test") { + checkAnswer(sql("SELECT percentile(key, 1) FROM src LIMIT 1"), + sql("SELECT max(key) FROM src").collect().toSeq) + + checkAnswer(sql("SELECT percentile(key, array(1, 1)) FROM src LIMIT 1"), + sql("SELECT array(max(key), max(key)) FROM src").collect().toSeq) + } + + test("Generic UDAF aggregates") { + checkAnswer(sql("SELECT ceiling(percentile_approx(key, 0.99999)) FROM src LIMIT 1"), + sql("SELECT max(key) FROM src LIMIT 1").collect().toSeq) + + checkAnswer(sql("SELECT percentile_approx(100.0, array(0.9, 0.9)) FROM src LIMIT 1"), + sql("SELECT array(100, 100) FROM src LIMIT 1").collect().toSeq) + } + + test("UDFIntegerToString") { + val testData = TestHive.sparkContext.parallelize( + IntegerCaseClass(1) :: IntegerCaseClass(2) :: Nil).toDF() + testData.registerTempTable("integerTable") + + val udfName = classOf[UDFIntegerToString].getName + sql(s"CREATE TEMPORARY FUNCTION testUDFIntegerToString AS '$udfName'") + checkAnswer( + sql("SELECT testUDFIntegerToString(i) FROM integerTable"), + Seq(Row("1"), Row("2"))) + sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFIntegerToString") + + TestHive.reset() + } + + test("UDFListListInt") { + val testData = TestHive.sparkContext.parallelize( + ListListIntCaseClass(Nil) :: + ListListIntCaseClass(Seq((1, 2, 3))) :: + ListListIntCaseClass(Seq((4, 5, 6), (7, 8, 9))) :: Nil).toDF() + testData.registerTempTable("listListIntTable") + + sql(s"CREATE TEMPORARY FUNCTION testUDFListListInt AS '${classOf[UDFListListInt].getName}'") + checkAnswer( + sql("SELECT testUDFListListInt(lli) FROM listListIntTable"), + Seq(Row(0), Row(2), Row(13))) + sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFListListInt") + + TestHive.reset() + } + + test("UDFListString") { + val testData = TestHive.sparkContext.parallelize( + ListStringCaseClass(Seq("a", "b", "c")) :: + ListStringCaseClass(Seq("d", "e")) :: Nil).toDF() + testData.registerTempTable("listStringTable") + + sql(s"CREATE TEMPORARY FUNCTION testUDFListString AS '${classOf[UDFListString].getName}'") + checkAnswer( + sql("SELECT testUDFListString(l) FROM listStringTable"), + Seq(Row("a,b,c"), Row("d,e"))) + sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFListString") + + TestHive.reset() + } + + test("UDFStringString") { + val testData = TestHive.sparkContext.parallelize( + StringCaseClass("world") :: StringCaseClass("goodbye") :: Nil).toDF() + testData.registerTempTable("stringTable") + + sql(s"CREATE TEMPORARY FUNCTION testStringStringUDF AS '${classOf[UDFStringString].getName}'") + checkAnswer( + sql("SELECT testStringStringUDF(\"hello\", s) FROM stringTable"), + Seq(Row("hello world"), Row("hello goodbye"))) + sql("DROP TEMPORARY FUNCTION IF EXISTS testStringStringUDF") + + TestHive.reset() + } + + test("UDFTwoListList") { + val testData = TestHive.sparkContext.parallelize( + ListListIntCaseClass(Nil) :: + ListListIntCaseClass(Seq((1, 2, 3))) :: + ListListIntCaseClass(Seq((4, 5, 6), (7, 8, 9))) :: + Nil).toDF() + testData.registerTempTable("TwoListTable") + + sql(s"CREATE TEMPORARY FUNCTION testUDFTwoListList AS '${classOf[UDFTwoListList].getName}'") + checkAnswer( + sql("SELECT testUDFTwoListList(lli, lli) FROM TwoListTable"), + Seq(Row("0, 0"), Row("2, 2"), Row("13, 13"))) + sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFTwoListList") + + TestHive.reset() + } +} + +class TestPair(x: Int, y: Int) extends Writable with Serializable { + def this() = this(0, 0) + var entry: (Int, Int) = (x, y) + + override def write(output: DataOutput): Unit = { + output.writeInt(entry._1) + output.writeInt(entry._2) + } + + override def readFields(input: DataInput): Unit = { + val x = input.readInt() + val y = input.readInt() + entry = (x, y) + } +} + +class PairSerDe extends AbstractSerDe { + override def initialize(p1: Configuration, p2: Properties): Unit = {} + + override def getObjectInspector: ObjectInspector = { + ObjectInspectorFactory + .getStandardStructObjectInspector( + Seq("pair"), + Seq(ObjectInspectorFactory.getStandardStructObjectInspector( + Seq("id", "value"), + Seq(PrimitiveObjectInspectorFactory.javaIntObjectInspector, + PrimitiveObjectInspectorFactory.javaIntObjectInspector)) + )) + } + + override def getSerializedClass: Class[_ <: Writable] = classOf[TestPair] + + override def getSerDeStats: SerDeStats = null + + override def serialize(p1: scala.Any, p2: ObjectInspector): Writable = null + + override def deserialize(value: Writable): AnyRef = { + val pair = value.asInstanceOf[TestPair] + + val row = new util.ArrayList[util.ArrayList[AnyRef]] + row.add(new util.ArrayList[AnyRef](2)) + row(0).add(Integer.valueOf(pair.entry._1)) + row(0).add(Integer.valueOf(pair.entry._2)) + + row + } +} + +class PairUDF extends GenericUDF { + override def initialize(p1: Array[ObjectInspector]): ObjectInspector = + ObjectInspectorFactory.getStandardStructObjectInspector( + Seq("id", "value"), + Seq(PrimitiveObjectInspectorFactory.javaIntObjectInspector, + PrimitiveObjectInspectorFactory.javaIntObjectInspector) + ) + + override def evaluate(args: Array[DeferredObject]): AnyRef = { + println("Type = %s".format(args(0).getClass.getName)) + Integer.valueOf(args(0).get.asInstanceOf[TestPair].entry._2) + } + + override def getDisplayString(p1: Array[String]): String = "" +} http://git-wip-us.apache.org/repos/asf/spark/blob/931da5c8/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala ---------------------------------------------------------------------- diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala deleted file mode 100644 index ce59858..0000000 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveUdfSuite.scala +++ /dev/null @@ -1,261 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.hive.execution - -import java.io.{DataInput, DataOutput} -import java.util -import java.util.Properties - -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.hive.ql.udf.generic.{GenericUDAFAverage, GenericUDF} -import org.apache.hadoop.hive.ql.udf.generic.GenericUDF.DeferredObject -import org.apache.hadoop.hive.serde2.objectinspector.primitive.PrimitiveObjectInspectorFactory -import org.apache.hadoop.hive.serde2.objectinspector.{ObjectInspector, ObjectInspectorFactory} -import org.apache.hadoop.hive.serde2.{AbstractSerDe, SerDeStats} -import org.apache.hadoop.io.Writable -import org.apache.spark.sql.{QueryTest, Row} -import org.apache.spark.sql.hive.test.TestHive - -import org.apache.spark.util.Utils - -import scala.collection.JavaConversions._ - -case class Fields(f1: Int, f2: Int, f3: Int, f4: Int, f5: Int) - -// Case classes for the custom UDF's. -case class IntegerCaseClass(i: Int) -case class ListListIntCaseClass(lli: Seq[(Int, Int, Int)]) -case class StringCaseClass(s: String) -case class ListStringCaseClass(l: Seq[String]) - -/** - * A test suite for Hive custom UDFs. - */ -class HiveUdfSuite extends QueryTest { - - import TestHive.{udf, sql} - import TestHive.implicits._ - - test("spark sql udf test that returns a struct") { - udf.register("getStruct", (_: Int) => Fields(1, 2, 3, 4, 5)) - assert(sql( - """ - |SELECT getStruct(1).f1, - | getStruct(1).f2, - | getStruct(1).f3, - | getStruct(1).f4, - | getStruct(1).f5 FROM src LIMIT 1 - """.stripMargin).head() === Row(1, 2, 3, 4, 5)) - } - - test("SPARK-4785 When called with arguments referring column fields, PMOD throws NPE") { - checkAnswer( - sql("SELECT PMOD(CAST(key as INT), 10) FROM src LIMIT 1"), - Row(8) - ) - } - - test("hive struct udf") { - sql( - """ - |CREATE EXTERNAL TABLE hiveUdfTestTable ( - | pair STRUCT<id: INT, value: INT> - |) - |PARTITIONED BY (partition STRING) - |ROW FORMAT SERDE '%s' - |STORED AS SEQUENCEFILE - """. - stripMargin.format(classOf[PairSerDe].getName)) - - val location = Utils.getSparkClassLoader.getResource("data/files/testUdf").getFile - sql(s""" - ALTER TABLE hiveUdfTestTable - ADD IF NOT EXISTS PARTITION(partition='testUdf') - LOCATION '$location'""") - - sql(s"CREATE TEMPORARY FUNCTION testUdf AS '${classOf[PairUdf].getName}'") - sql("SELECT testUdf(pair) FROM hiveUdfTestTable") - sql("DROP TEMPORARY FUNCTION IF EXISTS testUdf") - } - - test("SPARK-6409 UDAFAverage test") { - sql(s"CREATE TEMPORARY FUNCTION test_avg AS '${classOf[GenericUDAFAverage].getName}'") - checkAnswer( - sql("SELECT test_avg(1), test_avg(substr(value,5)) FROM src"), - Seq(Row(1.0, 260.182))) - sql("DROP TEMPORARY FUNCTION IF EXISTS test_avg") - TestHive.reset() - } - - test("SPARK-2693 udaf aggregates test") { - checkAnswer(sql("SELECT percentile(key, 1) FROM src LIMIT 1"), - sql("SELECT max(key) FROM src").collect().toSeq) - - checkAnswer(sql("SELECT percentile(key, array(1, 1)) FROM src LIMIT 1"), - sql("SELECT array(max(key), max(key)) FROM src").collect().toSeq) - } - - test("Generic UDAF aggregates") { - checkAnswer(sql("SELECT ceiling(percentile_approx(key, 0.99999)) FROM src LIMIT 1"), - sql("SELECT max(key) FROM src LIMIT 1").collect().toSeq) - - checkAnswer(sql("SELECT percentile_approx(100.0, array(0.9, 0.9)) FROM src LIMIT 1"), - sql("SELECT array(100, 100) FROM src LIMIT 1").collect().toSeq) - } - - test("UDFIntegerToString") { - val testData = TestHive.sparkContext.parallelize( - IntegerCaseClass(1) :: IntegerCaseClass(2) :: Nil).toDF() - testData.registerTempTable("integerTable") - - val udfName = classOf[UDFIntegerToString].getName - sql(s"CREATE TEMPORARY FUNCTION testUDFIntegerToString AS '$udfName'") - checkAnswer( - sql("SELECT testUDFIntegerToString(i) FROM integerTable"), - Seq(Row("1"), Row("2"))) - sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFIntegerToString") - - TestHive.reset() - } - - test("UDFListListInt") { - val testData = TestHive.sparkContext.parallelize( - ListListIntCaseClass(Nil) :: - ListListIntCaseClass(Seq((1, 2, 3))) :: - ListListIntCaseClass(Seq((4, 5, 6), (7, 8, 9))) :: Nil).toDF() - testData.registerTempTable("listListIntTable") - - sql(s"CREATE TEMPORARY FUNCTION testUDFListListInt AS '${classOf[UDFListListInt].getName}'") - checkAnswer( - sql("SELECT testUDFListListInt(lli) FROM listListIntTable"), - Seq(Row(0), Row(2), Row(13))) - sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFListListInt") - - TestHive.reset() - } - - test("UDFListString") { - val testData = TestHive.sparkContext.parallelize( - ListStringCaseClass(Seq("a", "b", "c")) :: - ListStringCaseClass(Seq("d", "e")) :: Nil).toDF() - testData.registerTempTable("listStringTable") - - sql(s"CREATE TEMPORARY FUNCTION testUDFListString AS '${classOf[UDFListString].getName}'") - checkAnswer( - sql("SELECT testUDFListString(l) FROM listStringTable"), - Seq(Row("a,b,c"), Row("d,e"))) - sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFListString") - - TestHive.reset() - } - - test("UDFStringString") { - val testData = TestHive.sparkContext.parallelize( - StringCaseClass("world") :: StringCaseClass("goodbye") :: Nil).toDF() - testData.registerTempTable("stringTable") - - sql(s"CREATE TEMPORARY FUNCTION testStringStringUdf AS '${classOf[UDFStringString].getName}'") - checkAnswer( - sql("SELECT testStringStringUdf(\"hello\", s) FROM stringTable"), - Seq(Row("hello world"), Row("hello goodbye"))) - sql("DROP TEMPORARY FUNCTION IF EXISTS testStringStringUdf") - - TestHive.reset() - } - - test("UDFTwoListList") { - val testData = TestHive.sparkContext.parallelize( - ListListIntCaseClass(Nil) :: - ListListIntCaseClass(Seq((1, 2, 3))) :: - ListListIntCaseClass(Seq((4, 5, 6), (7, 8, 9))) :: - Nil).toDF() - testData.registerTempTable("TwoListTable") - - sql(s"CREATE TEMPORARY FUNCTION testUDFTwoListList AS '${classOf[UDFTwoListList].getName}'") - checkAnswer( - sql("SELECT testUDFTwoListList(lli, lli) FROM TwoListTable"), - Seq(Row("0, 0"), Row("2, 2"), Row("13, 13"))) - sql("DROP TEMPORARY FUNCTION IF EXISTS testUDFTwoListList") - - TestHive.reset() - } -} - -class TestPair(x: Int, y: Int) extends Writable with Serializable { - def this() = this(0, 0) - var entry: (Int, Int) = (x, y) - - override def write(output: DataOutput): Unit = { - output.writeInt(entry._1) - output.writeInt(entry._2) - } - - override def readFields(input: DataInput): Unit = { - val x = input.readInt() - val y = input.readInt() - entry = (x, y) - } -} - -class PairSerDe extends AbstractSerDe { - override def initialize(p1: Configuration, p2: Properties): Unit = {} - - override def getObjectInspector: ObjectInspector = { - ObjectInspectorFactory - .getStandardStructObjectInspector( - Seq("pair"), - Seq(ObjectInspectorFactory.getStandardStructObjectInspector( - Seq("id", "value"), - Seq(PrimitiveObjectInspectorFactory.javaIntObjectInspector, - PrimitiveObjectInspectorFactory.javaIntObjectInspector)) - )) - } - - override def getSerializedClass: Class[_ <: Writable] = classOf[TestPair] - - override def getSerDeStats: SerDeStats = null - - override def serialize(p1: scala.Any, p2: ObjectInspector): Writable = null - - override def deserialize(value: Writable): AnyRef = { - val pair = value.asInstanceOf[TestPair] - - val row = new util.ArrayList[util.ArrayList[AnyRef]] - row.add(new util.ArrayList[AnyRef](2)) - row(0).add(Integer.valueOf(pair.entry._1)) - row(0).add(Integer.valueOf(pair.entry._2)) - - row - } -} - -class PairUdf extends GenericUDF { - override def initialize(p1: Array[ObjectInspector]): ObjectInspector = - ObjectInspectorFactory.getStandardStructObjectInspector( - Seq("id", "value"), - Seq(PrimitiveObjectInspectorFactory.javaIntObjectInspector, - PrimitiveObjectInspectorFactory.javaIntObjectInspector) - ) - - override def evaluate(args: Array[DeferredObject]): AnyRef = { - println("Type = %s".format(args(0).getClass.getName)) - Integer.valueOf(args(0).get.asInstanceOf[TestPair].entry._2) - } - - override def getDisplayString(p1: Array[String]): String = "" -} --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org