Repository: spark Updated Branches: refs/heads/master 26ac085de -> 48e44b24a
[SPARK-21204][SQL] Add support for Scala Set collection types in serialization ## What changes were proposed in this pull request? Currently, we can't produce a `Dataset` containing a `Set` in Spark SQL. This PR adds support for serialization/deserialization of `Set`. Because Spark SQL has no corresponding internal data type for `Set`, the most appropriate choice is to serialize a set as an array. When such an array is deserialized back to a `Set`, any duplicate elements are removed. ## How was this patch tested? Added unit tests. Author: Liang-Chi Hsieh <vii...@gmail.com> Closes #18416 from viirya/SPARK-21204. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/48e44b24 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/48e44b24 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/48e44b24 Branch: refs/heads/master Commit: 48e44b24a7663142176102ac4c6bf4242f103804 Parents: 26ac085 Author: Liang-Chi Hsieh <vii...@gmail.com> Authored: Fri Jul 7 01:07:45 2017 +0800 Committer: Wenchen Fan <wenc...@databricks.com> Committed: Fri Jul 7 01:07:45 2017 +0800 ---------------------------------------------------------------------- .../spark/sql/catalyst/ScalaReflection.scala | 28 ++++++++++++++++-- .../catalyst/expressions/objects/objects.scala | 5 ++-- .../org/apache/spark/sql/SQLImplicits.scala | 10 +++++++ .../spark/sql/DataFrameAggregateSuite.scala | 10 +++++++ .../spark/sql/DatasetPrimitiveSuite.scala | 31 ++++++++++++++++++++ 5 files changed, 79 insertions(+), 5 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/48e44b24/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala index 814f2c1..4d5401f 100644 --- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala @@ -309,7 +309,10 @@ object ScalaReflection extends ScalaReflection { Invoke(arrayData, primitiveMethod, arrayCls, returnNullable = false) } - case t if t <:< localTypeOf[Seq[_]] => + // We serialize a `Set` to Catalyst array. When we deserialize a Catalyst array + // to a `Set`, if there are duplicated elements, the elements will be de-duplicated. + case t if t <:< localTypeOf[Seq[_]] || + t <:< localTypeOf[scala.collection.Set[_]] => val TypeRef(_, _, Seq(elementType)) = t val Schema(dataType, elementNullable) = schemaFor(elementType) val className = getClassNameFromType(elementType) @@ -327,8 +330,10 @@ object ScalaReflection extends ScalaReflection { } val companion = t.normalize.typeSymbol.companionSymbol.typeSignature - val cls = companion.declaration(newTermName("newBuilder")) match { - case NoSymbol => classOf[Seq[_]] + val cls = companion.member(newTermName("newBuilder")) match { + case NoSymbol if t <:< localTypeOf[Seq[_]] => classOf[Seq[_]] + case NoSymbol if t <:< localTypeOf[scala.collection.Set[_]] => + classOf[scala.collection.Set[_]] case _ => mirror.runtimeClass(t.typeSymbol.asClass) } UnresolvedMapObjects(mapFunction, getPath, Some(cls)) @@ -502,6 +507,19 @@ object ScalaReflection extends ScalaReflection { serializerFor(_, valueType, valuePath, seenTypeSet), valueNullable = !valueType.typeSymbol.asClass.isPrimitive) + case t if t <:< localTypeOf[scala.collection.Set[_]] => + val TypeRef(_, _, Seq(elementType)) = t + + // There's no corresponding Catalyst type for `Set`, we serialize a `Set` to Catalyst array. + // Note that the property of `Set` is only kept when manipulating the data as domain object. 
+ val newInput = + Invoke( + inputObject, + "toSeq", + ObjectType(classOf[Seq[_]])) + + toCatalystArray(newInput, elementType) + case t if t <:< localTypeOf[String] => StaticInvoke( classOf[UTF8String], @@ -713,6 +731,10 @@ object ScalaReflection extends ScalaReflection { val Schema(valueDataType, valueNullable) = schemaFor(valueType) Schema(MapType(schemaFor(keyType).dataType, valueDataType, valueContainsNull = valueNullable), nullable = true) + case t if t <:< localTypeOf[Set[_]] => + val TypeRef(_, _, Seq(elementType)) = t + val Schema(dataType, nullable) = schemaFor(elementType) + Schema(ArrayType(dataType, containsNull = nullable), nullable = true) case t if t <:< localTypeOf[String] => Schema(StringType, nullable = true) case t if t <:< localTypeOf[java.sql.Timestamp] => Schema(TimestampType, nullable = true) case t if t <:< localTypeOf[java.sql.Date] => Schema(DateType, nullable = true) http://git-wip-us.apache.org/repos/asf/spark/blob/48e44b24/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index 24c06d8..9b28a18 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -627,8 +627,9 @@ case class MapObjects private( val (initCollection, addElement, getResult): (String, String => String, String) = customCollectionCls match { - case Some(cls) if classOf[Seq[_]].isAssignableFrom(cls) => - // Scala sequence + case Some(cls) if classOf[Seq[_]].isAssignableFrom(cls) || + classOf[scala.collection.Set[_]].isAssignableFrom(cls) => + // Scala sequence or set val getBuilder = 
s"${cls.getName}$$.MODULE$$.newBuilder()" val builder = ctx.freshName("collectionBuilder") ( http://git-wip-us.apache.org/repos/asf/spark/blob/48e44b24/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala index 86574e2..05db292 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala @@ -171,6 +171,16 @@ abstract class SQLImplicits extends LowPrioritySQLImplicits { /** @since 2.3.0 */ implicit def newMapEncoder[T <: Map[_, _] : TypeTag]: Encoder[T] = ExpressionEncoder() + /** + * Notice that we serialize `Set` to Catalyst array. The set property is only kept when + * manipulating the domain objects. The serialization format doesn't keep the set property. + * When we have a Catalyst array which contains duplicated elements and convert it to + * `Dataset[Set[T]]` by using the encoder, the elements will be de-duplicated. 
+ * + * @since 2.3.0 + */ + implicit def newSetEncoder[T <: Set[_] : TypeTag]: Encoder[T] = ExpressionEncoder() + // Arrays /** @since 1.6.1 */ http://git-wip-us.apache.org/repos/asf/spark/blob/48e44b24/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 5db354d..b52d50b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -460,6 +460,16 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { df.select(collect_set($"a"), collect_set($"b")), Seq(Row(Seq(1, 2, 3), Seq(2, 4))) ) + + checkDataset( + df.select(collect_set($"a").as("aSet")).as[Set[Int]], + Set(1, 2, 3)) + checkDataset( + df.select(collect_set($"b").as("bSet")).as[Set[Int]], + Set(2, 4)) + checkDataset( + df.select(collect_set($"a"), collect_set($"b")).as[(Set[Int], Set[Int])], + Seq(Set(1, 2, 3) -> Set(2, 4)): _*) } test("collect functions structs") { http://git-wip-us.apache.org/repos/asf/spark/blob/48e44b24/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala index a6847dc..f62f9e2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetPrimitiveSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql +import scala.collection.immutable.{HashSet => HSet} import scala.collection.immutable.Queue import scala.collection.mutable.{LinkedHashMap => LHMap} import 
scala.collection.mutable.ArrayBuffer @@ -342,6 +343,31 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext { LHMapClass(LHMap(1 -> 2)) -> LHMap("test" -> MapClass(Map(3 -> 4)))) } + test("arbitrary sets") { + checkDataset(Seq(Set(1, 2, 3, 4)).toDS(), Set(1, 2, 3, 4)) + checkDataset(Seq(Set(1.toLong, 2.toLong)).toDS(), Set(1.toLong, 2.toLong)) + checkDataset(Seq(Set(1.toDouble, 2.toDouble)).toDS(), Set(1.toDouble, 2.toDouble)) + checkDataset(Seq(Set(1.toFloat, 2.toFloat)).toDS(), Set(1.toFloat, 2.toFloat)) + checkDataset(Seq(Set(1.toByte, 2.toByte)).toDS(), Set(1.toByte, 2.toByte)) + checkDataset(Seq(Set(1.toShort, 2.toShort)).toDS(), Set(1.toShort, 2.toShort)) + checkDataset(Seq(Set(true, false)).toDS(), Set(true, false)) + checkDataset(Seq(Set("test1", "test2")).toDS(), Set("test1", "test2")) + checkDataset(Seq(Set(Tuple1(1), Tuple1(2))).toDS(), Set(Tuple1(1), Tuple1(2))) + + checkDataset(Seq(HSet(1, 2)).toDS(), HSet(1, 2)) + checkDataset(Seq(HSet(1.toLong, 2.toLong)).toDS(), HSet(1.toLong, 2.toLong)) + checkDataset(Seq(HSet(1.toDouble, 2.toDouble)).toDS(), HSet(1.toDouble, 2.toDouble)) + checkDataset(Seq(HSet(1.toFloat, 2.toFloat)).toDS(), HSet(1.toFloat, 2.toFloat)) + checkDataset(Seq(HSet(1.toByte, 2.toByte)).toDS(), HSet(1.toByte, 2.toByte)) + checkDataset(Seq(HSet(1.toShort, 2.toShort)).toDS(), HSet(1.toShort, 2.toShort)) + checkDataset(Seq(HSet(true, false)).toDS(), HSet(true, false)) + checkDataset(Seq(HSet("test1", "test2")).toDS(), HSet("test1", "test2")) + checkDataset(Seq(HSet(Tuple1(1), Tuple1(2))).toDS(), HSet(Tuple1(1), Tuple1(2))) + + checkDataset(Seq(Seq(Some(1), None), Seq(Some(2))).toDF("c").as[Set[Integer]], + Seq(Set[Integer](1, null), Set[Integer](2)): _*) + } + test("nested sequences") { checkDataset(Seq(Seq(Seq(1))).toDS(), Seq(Seq(1))) checkDataset(Seq(List(Queue(1))).toDS(), List(Queue(1))) @@ -352,6 +378,11 @@ class DatasetPrimitiveSuite extends QueryTest with SharedSQLContext { checkDataset(Seq(LHMap(Map(1 -> 2) -> 
3)).toDS(), LHMap(Map(1 -> 2) -> 3)) } + test("nested set") { + checkDataset(Seq(Set(HSet(1, 2), HSet(3, 4))).toDS(), Set(HSet(1, 2), HSet(3, 4))) + checkDataset(Seq(HSet(Set(1, 2), Set(3, 4))).toDS(), HSet(Set(1, 2), Set(3, 4))) + } + test("package objects") { import packageobject._ checkDataset(Seq(PackageClass(1)).toDS(), PackageClass(1)) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org