This is an automated email from the ASF dual-hosted git repository. dongjoon pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new 667c0a9dbbe0 [SPARK-46791][SQL] Support Java Set in JavaTypeInference 667c0a9dbbe0 is described below commit 667c0a9dbbe045c73842a345c1b3897b155564d4 Author: Liang-Chi Hsieh <vii...@gmail.com> AuthorDate: Mon Jan 22 02:13:12 2024 -0800 [SPARK-46791][SQL] Support Java Set in JavaTypeInference ### What changes were proposed in this pull request? This patch adds the support of Java `Set` as bean field in `JavaTypeInference`. ### Why are the changes needed? Scala `Set` (`scala.collection.Set`) is supported in `ScalaReflection` so users can encode Scala `Set` in Dataset. But Java `Set` is not supported in bean encoder (i.e., `JavaTypeInference`). This feature inconsistency means Java users cannot use `Set` like Scala users do. ### Does this PR introduce _any_ user-facing change? Yes. Java `Set` is supported to be part of Java bean when encoding with bean encoder. ### How was this patch tested? Added tests. ### Was this patch authored or co-authored using generative AI tooling? No Closes #44828 from viirya/java_set. 
Authored-by: Liang-Chi Hsieh <vii...@gmail.com> Signed-off-by: Dongjoon Hyun <dh...@apple.com> --- .../spark/sql/catalyst/JavaTypeInference.scala | 6 ++- .../sql/catalyst/expressions/objects/objects.scala | 50 ++++++++++++++++++++++ .../sql/catalyst/JavaTypeInferenceSuite.scala | 26 +++++++++-- .../expressions/ObjectExpressionsSuite.scala | 5 ++- .../org/apache/spark/sql/JavaDatasetSuite.java | 45 +++++++++++++++++++ .../scala/org/apache/spark/sql/DatasetSuite.scala | 9 ++++ 6 files changed, 136 insertions(+), 5 deletions(-) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala index a945cb720b01..f85e96da2be1 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst import java.beans.{Introspector, PropertyDescriptor} import java.lang.reflect.{ParameterizedType, Type, TypeVariable} -import java.util.{List => JList, Map => JMap} +import java.util.{List => JList, Map => JMap, Set => JSet} import javax.annotation.Nonnull import scala.jdk.CollectionConverters._ @@ -112,6 +112,10 @@ object JavaTypeInference { val element = encoderFor(c.getTypeParameters.array(0), seenTypeSet, typeVariables) IterableEncoder(ClassTag(c), element, element.nullable, lenientSerialization = false) + case c: Class[_] if classOf[JSet[_]].isAssignableFrom(c) => + val element = encoderFor(c.getTypeParameters.array(0), seenTypeSet, typeVariables) + IterableEncoder(ClassTag(c), element, element.nullable, lenientSerialization = false) + case c: Class[_] if classOf[JMap[_, _]].isAssignableFrom(c) => val keyEncoder = encoderFor(c.getTypeParameters.array(0), seenTypeSet, typeVariables) val valueEncoder = encoderFor(c.getTypeParameters.array(1), seenTypeSet, typeVariables) diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala index bae2922cf921..a684ca18435e 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala @@ -907,6 +907,8 @@ case class MapObjects private( _.asInstanceOf[Array[_]].toImmutableArraySeq case ObjectType(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) => _.asInstanceOf[java.util.List[_]].asScala.toSeq + case ObjectType(cls) if classOf[java.util.Set[_]].isAssignableFrom(cls) => + _.asInstanceOf[java.util.Set[_]].asScala.toSeq case ObjectType(cls) if cls == classOf[Object] => (inputCollection) => { if (inputCollection.getClass.isArray) { @@ -982,6 +984,34 @@ case class MapObjects private( builder } } + case Some(cls) if classOf[java.util.Set[_]].isAssignableFrom(cls) => + // Java set + if (cls == classOf[java.util.Set[_]] || cls == classOf[java.util.AbstractSet[_]]) { + // Specifying non concrete implementations of `java.util.Set` + executeFuncOnCollection(_).toSet.asJava + } else { + val constructors = cls.getConstructors() + val intParamConstructor = constructors.find { constructor => + constructor.getParameterCount == 1 && constructor.getParameterTypes()(0) == classOf[Int] + } + val noParamConstructor = constructors.find { constructor => + constructor.getParameterCount == 0 + } + + val constructor = intParamConstructor.map { intConstructor => + (len: Int) => intConstructor.newInstance(len.asInstanceOf[Object]) + }.getOrElse { + (_: Int) => noParamConstructor.get.newInstance() + } + + // Specifying concrete implementations of `java.util.Set` + (inputs) => { + val results = executeFuncOnCollection(inputs) + val builder = constructor(inputs.length).asInstanceOf[java.util.Set[Any]] + results.foreach(builder.add(_)) + builder + } + } 
case None => // array x => new GenericArrayData(executeFuncOnCollection(x).toArray) @@ -1067,6 +1097,13 @@ case class MapObjects private( s"java.util.Iterator $it = ${genInputData.value}.iterator();", s"$it.next()" ) + case ObjectType(cls) if classOf[java.util.Set[_]].isAssignableFrom(cls) => + val it = ctx.freshName("it") + ( + s"${genInputData.value}.size()", + s"java.util.Iterator $it = ${genInputData.value}.iterator();", + s"$it.next()" + ) case ArrayType(et, _) => ( s"${genInputData.value}.numElements()", @@ -1158,6 +1195,19 @@ case class MapObjects private( (genValue: String) => s"$builder.add($genValue);", s"$builder;" ) + case Some(cls) if classOf[java.util.Set[_]].isAssignableFrom(cls) => + // Java set + val builder = ctx.freshName("collectionBuilder") + ( + if (cls == classOf[java.util.Set[_]] || cls == classOf[java.util.AbstractSet[_]]) { + s"${cls.getName} $builder = new java.util.HashSet($dataLength);" + } else { + val param = Try(cls.getConstructor(Integer.TYPE)).map(_ => dataLength).getOrElse("") + s"${cls.getName} $builder = new ${cls.getName}($param);" + }, + (genValue: String) => s"$builder.add($genValue);", + s"$builder;" + ) case _ => // array ( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/JavaTypeInferenceSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/JavaTypeInferenceSuite.scala index bef0cf8831eb..c785c71428ca 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/JavaTypeInferenceSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/JavaTypeInferenceSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst import java.math.BigInteger -import java.util.{LinkedList, List => JList, Map => JMap} +import java.util.{HashSet, LinkedList, List => JList, Map => JMap, Set => JSet} import scala.beans.{BeanProperty, BooleanBeanProperty} import scala.reflect.{classTag, ClassTag} @@ -37,6 +37,8 @@ class GenericCollectionBean { @BeanProperty var 
listOfListOfStrings: JList[JList[String]] = _ @BeanProperty var mapOfDummyBeans: JMap[String, DummyBean] = _ @BeanProperty var linkedListOfStrings: LinkedList[String] = _ + @BeanProperty var hashSetOfString: HashSet[String] = _ + @BeanProperty var setOfSetOfStrings: JSet[JSet[String]] = _ } class LeafBean { @@ -139,9 +141,16 @@ class JavaTypeInferenceSuite extends SparkFunSuite { assert(schema === expected) } - test("resolve type parameters for map and list") { + test("resolve type parameters for map, list and set") { val encoder = JavaTypeInference.encoderFor(classOf[GenericCollectionBean]) val expected = JavaBeanEncoder(ClassTag(classOf[GenericCollectionBean]), Seq( + encoderField( + "hashSetOfString", + IterableEncoder( + ClassTag(classOf[HashSet[_]]), + StringEncoder, + containsNull = true, + lenientSerialization = false)), encoderField( "linkedListOfStrings", IterableEncoder( @@ -166,7 +175,18 @@ class JavaTypeInferenceSuite extends SparkFunSuite { ClassTag(classOf[JMap[_, _]]), StringEncoder, expectedDummyBeanEncoder, - valueContainsNull = true)))) + valueContainsNull = true)), + encoderField( + "setOfSetOfStrings", + IterableEncoder( + ClassTag(classOf[JSet[_]]), + IterableEncoder( + ClassTag(classOf[JSet[_]]), + StringEncoder, + containsNull = true, + lenientSerialization = false), + containsNull = true, + lenientSerialization = false)))) assert(encoder === expected) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala index 538a7600b02a..7f58516cf4eb 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala @@ -362,6 +362,8 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { 
assert(result.asInstanceOf[ArrayData].array.toSeq == expected) case l if classOf[java.util.List[_]].isAssignableFrom(l) => assert(result.asInstanceOf[java.util.List[_]].asScala == expected) + case s if classOf[java.util.Set[_]].isAssignableFrom(s) => + assert(result.asInstanceOf[java.util.Set[_]].asScala == expected.toSet) case a if classOf[mutable.ArraySeq[Int]].isAssignableFrom(a) => assert(result == mutable.ArraySeq.make[Int](expected.toArray)) case a if classOf[immutable.ArraySeq[Int]].isAssignableFrom(a) => @@ -379,7 +381,8 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { classOf[Seq[Int]], classOf[scala.collection.Set[Int]], classOf[java.util.List[Int]], classOf[java.util.AbstractList[Int]], classOf[java.util.AbstractSequentialList[Int]], classOf[java.util.Vector[Int]], - classOf[java.util.Stack[Int]], null) + classOf[java.util.Stack[Int]], null, + classOf[java.util.Set[Int]]) val list = new java.util.ArrayList[Int]() list.add(1) diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java index 254c6df28209..bd776300bd5e 100644 --- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java +++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java @@ -111,6 +111,26 @@ public class JavaDatasetSuite implements Serializable { Assertions.assertEquals(ds.schema(), ds2.schema()); } + @Test + public void testBeanWithSet() { + BeanWithSet bean = new BeanWithSet(); + Set<Long> fields = asSet(1L, 2L, 3L); + bean.setFields(fields); + List<BeanWithSet> objects = Collections.singletonList(bean); + + Dataset<BeanWithSet> ds = spark.createDataset(objects, Encoders.bean(BeanWithSet.class)); + Dataset<Row> df = ds.toDF(); + + Dataset<BeanWithSet> mapped = + df.map((MapFunction<Row, BeanWithSet>) row -> { + BeanWithSet obj = new BeanWithSet(); + obj.setFields(new 
HashSet<>(row.<Long>getList(row.fieldIndex("fields")))); + return obj; + }, Encoders.bean(BeanWithSet.class)); + + Assertions.assertEquals(objects, mapped.collectAsList()); + } + @Test public void testCommonOperation() { List<String> data = Arrays.asList("hello", "world"); @@ -1989,6 +2009,31 @@ public class JavaDatasetSuite implements Serializable { Assertions.assertEquals(expected, df.collectAsList()); } + public static class BeanWithSet implements Serializable { + private Set<Long> fields; + + public Set<Long> getFields() { + return fields; + } + + public void setFields(Set<Long> fields) { + this.fields = fields; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + BeanWithSet that = (BeanWithSet) o; + return Objects.equal(fields, that.fields); + } + + @Override + public int hashCode() { + return Objects.hashCode(fields); + } + } + public static class SpecificListsBean implements Serializable { private ArrayList<Integer> arrayList; private LinkedList<Integer> linkedList; diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index cd28c60d83c7..f0f48026a4a0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql import java.io.{Externalizable, ObjectInput, ObjectOutput} import java.sql.{Date, Timestamp} +import scala.collection.immutable.HashSet import scala.reflect.ClassTag import scala.util.Random @@ -2706,6 +2707,12 @@ class DatasetSuite extends QueryTest assert(exception.context.head.asInstanceOf[DataFrameQueryContext].stackTrace.length == 2) } } + + test("SPARK-46791: Dataset with set field") { + val ds = Seq(WithSet(0, HashSet("foo", "bar")), WithSet(1, HashSet("bar", "zoo"))).toDS() + checkDataset(ds.map(t => t), + WithSet(0, HashSet("foo", 
"bar")), WithSet(1, HashSet("bar", "zoo"))) + } } class DatasetLargeResultCollectingSuite extends QueryTest @@ -2759,6 +2766,8 @@ case class WithImmutableMap(id: String, map_test: scala.collection.immutable.Map case class WithMap(id: String, map_test: scala.collection.Map[Long, String]) case class WithMapInOption(m: Option[scala.collection.Map[Int, Int]]) +case class WithSet(id: Int, values: Set[String]) + case class Generic[T](id: T, value: Double) case class OtherTuple(_1: String, _2: Int) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org