This is an automated email from the ASF dual-hosted git repository.

dongjoon pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 667c0a9dbbe0 [SPARK-46791][SQL] Support Java Set in JavaTypeInference
667c0a9dbbe0 is described below

commit 667c0a9dbbe045c73842a345c1b3897b155564d4
Author: Liang-Chi Hsieh <vii...@gmail.com>
AuthorDate: Mon Jan 22 02:13:12 2024 -0800

    [SPARK-46791][SQL] Support Java Set in JavaTypeInference
    
    ### What changes were proposed in this pull request?
    
    This patch adds support for Java `Set` as a bean field in
    `JavaTypeInference`.
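    
    For illustration, a minimal sketch of the kind of bean this enables,
    mirroring the `BeanWithSet` test bean added in this patch (the class
    and field names are illustrative, not part of any API):
    
    ```java
    import java.io.Serializable;
    import java.util.Set;
    
    // A Java bean with a java.util.Set field; with this patch the bean
    // encoder (Encoders.bean) can infer an encoder for it.
    public class BeanWithSet implements Serializable {
      private Set<Long> fields;
    
      public Set<Long> getFields() { return fields; }
    
      public void setFields(Set<Long> fields) { this.fields = fields; }
    }
    ```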
    
    ### Why are the changes needed?
    
    Scala `Set` (`scala.collection.Set`) is supported in `ScalaReflection`,
    so users can encode a Scala `Set` in a Dataset, but Java `Set` is not
    supported by the bean encoder (i.e., `JavaTypeInference`). This
    inconsistency means Java users cannot use `Set` the way Scala users can.
    
    ### Does this PR introduce _any_ user-facing change?
    
    Yes. A Java `Set` can now be part of a Java bean when encoding with the
    bean encoder.
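    
    A hedged usage sketch (assuming a running `SparkSession` named `spark`
    and the illustrative `BeanWithSet` bean above; this mirrors the
    `testBeanWithSet` case added to `JavaDatasetSuite` below):
    
    ```java
    // Create a Dataset of beans whose field is a java.util.Set.
    BeanWithSet bean = new BeanWithSet();
    bean.setFields(new java.util.HashSet<>(java.util.Arrays.asList(1L, 2L, 3L)));
    
    Dataset<BeanWithSet> ds = spark.createDataset(
        java.util.Collections.singletonList(bean),
        Encoders.bean(BeanWithSet.class));  // previously unsupported for Set fields
    
    // The Set field round-trips through the bean encoder.
    ds.collectAsList();
    ```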
    
    ### How was this patch tested?
    
    Added tests.
    
    ### Was this patch authored or co-authored using generative AI tooling?
    
    No
    
    Closes #44828 from viirya/java_set.
    
    Authored-by: Liang-Chi Hsieh <vii...@gmail.com>
    Signed-off-by: Dongjoon Hyun <dh...@apple.com>
---
 .../spark/sql/catalyst/JavaTypeInference.scala     |  6 ++-
 .../sql/catalyst/expressions/objects/objects.scala | 50 ++++++++++++++++++++++
 .../sql/catalyst/JavaTypeInferenceSuite.scala      | 26 +++++++++--
 .../expressions/ObjectExpressionsSuite.scala       |  5 ++-
 .../org/apache/spark/sql/JavaDatasetSuite.java     | 45 +++++++++++++++++++
 .../scala/org/apache/spark/sql/DatasetSuite.scala  |  9 ++++
 6 files changed, 136 insertions(+), 5 deletions(-)

diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
index a945cb720b01..f85e96da2be1 100644
--- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
+++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
@@ -18,7 +18,7 @@ package org.apache.spark.sql.catalyst
 
 import java.beans.{Introspector, PropertyDescriptor}
 import java.lang.reflect.{ParameterizedType, Type, TypeVariable}
-import java.util.{List => JList, Map => JMap}
+import java.util.{List => JList, Map => JMap, Set => JSet}
 import javax.annotation.Nonnull
 
 import scala.jdk.CollectionConverters._
@@ -112,6 +112,10 @@ object JavaTypeInference {
      val element = encoderFor(c.getTypeParameters.array(0), seenTypeSet, typeVariables)
      IterableEncoder(ClassTag(c), element, element.nullable, lenientSerialization = false)
 
+    case c: Class[_] if classOf[JSet[_]].isAssignableFrom(c) =>
+      val element = encoderFor(c.getTypeParameters.array(0), seenTypeSet, typeVariables)
+      IterableEncoder(ClassTag(c), element, element.nullable, lenientSerialization = false)
+
     case c: Class[_] if classOf[JMap[_, _]].isAssignableFrom(c) =>
      val keyEncoder = encoderFor(c.getTypeParameters.array(0), seenTypeSet, typeVariables)
      val valueEncoder = encoderFor(c.getTypeParameters.array(1), seenTypeSet, typeVariables)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index bae2922cf921..a684ca18435e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -907,6 +907,8 @@ case class MapObjects private(
       _.asInstanceOf[Array[_]].toImmutableArraySeq
     case ObjectType(cls) if classOf[java.util.List[_]].isAssignableFrom(cls) =>
       _.asInstanceOf[java.util.List[_]].asScala.toSeq
+    case ObjectType(cls) if classOf[java.util.Set[_]].isAssignableFrom(cls) =>
+      _.asInstanceOf[java.util.Set[_]].asScala.toSeq
     case ObjectType(cls) if cls == classOf[Object] =>
       (inputCollection) => {
         if (inputCollection.getClass.isArray) {
@@ -982,6 +984,34 @@ case class MapObjects private(
           builder
         }
       }
+    case Some(cls) if classOf[java.util.Set[_]].isAssignableFrom(cls) =>
+      // Java set
+      if (cls == classOf[java.util.Set[_]] || cls == classOf[java.util.AbstractSet[_]]) {
+        // Specifying non concrete implementations of `java.util.Set`
+        executeFuncOnCollection(_).toSet.asJava
+      } else {
+        val constructors = cls.getConstructors()
+        val intParamConstructor = constructors.find { constructor =>
+          constructor.getParameterCount == 1 && constructor.getParameterTypes()(0) == classOf[Int]
+        }
+        val noParamConstructor = constructors.find { constructor =>
+          constructor.getParameterCount == 0
+        }
+
+        val constructor = intParamConstructor.map { intConstructor =>
+          (len: Int) => intConstructor.newInstance(len.asInstanceOf[Object])
+        }.getOrElse {
+          (_: Int) => noParamConstructor.get.newInstance()
+        }
+
+        // Specifying concrete implementations of `java.util.Set`
+        (inputs) => {
+          val results = executeFuncOnCollection(inputs)
+          val builder = constructor(inputs.length).asInstanceOf[java.util.Set[Any]]
+          results.foreach(builder.add(_))
+          builder
+        }
+      }
     case None =>
       // array
       x => new GenericArrayData(executeFuncOnCollection(x).toArray)
@@ -1067,6 +1097,13 @@ case class MapObjects private(
           s"java.util.Iterator $it = ${genInputData.value}.iterator();",
           s"$it.next()"
         )
+      case ObjectType(cls) if classOf[java.util.Set[_]].isAssignableFrom(cls) =>
+        val it = ctx.freshName("it")
+        (
+          s"${genInputData.value}.size()",
+          s"java.util.Iterator $it = ${genInputData.value}.iterator();",
+          s"$it.next()"
+        )
       case ArrayType(et, _) =>
         (
           s"${genInputData.value}.numElements()",
@@ -1158,6 +1195,19 @@ case class MapObjects private(
             (genValue: String) => s"$builder.add($genValue);",
             s"$builder;"
           )
+        case Some(cls) if classOf[java.util.Set[_]].isAssignableFrom(cls) =>
+          // Java set
+          val builder = ctx.freshName("collectionBuilder")
+          (
+            if (cls == classOf[java.util.Set[_]] || cls == classOf[java.util.AbstractSet[_]]) {
+              s"${cls.getName} $builder = new java.util.HashSet($dataLength);"
+            } else {
+              val param = Try(cls.getConstructor(Integer.TYPE)).map(_ => dataLength).getOrElse("")
+              s"${cls.getName} $builder = new ${cls.getName}($param);"
+            },
+            (genValue: String) => s"$builder.add($genValue);",
+            s"$builder;"
+          )
         case _ =>
           // array
           (
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/JavaTypeInferenceSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/JavaTypeInferenceSuite.scala
index bef0cf8831eb..c785c71428ca 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/JavaTypeInferenceSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/JavaTypeInferenceSuite.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.catalyst
 
 import java.math.BigInteger
-import java.util.{LinkedList, List => JList, Map => JMap}
+import java.util.{HashSet, LinkedList, List => JList, Map => JMap, Set => JSet}
 
 import scala.beans.{BeanProperty, BooleanBeanProperty}
 import scala.reflect.{classTag, ClassTag}
@@ -37,6 +37,8 @@ class GenericCollectionBean {
   @BeanProperty var listOfListOfStrings: JList[JList[String]] = _
   @BeanProperty var mapOfDummyBeans: JMap[String, DummyBean] = _
   @BeanProperty var linkedListOfStrings: LinkedList[String] = _
+  @BeanProperty var hashSetOfString: HashSet[String] = _
+  @BeanProperty var setOfSetOfStrings: JSet[JSet[String]] = _
 }
 
 class LeafBean {
@@ -139,9 +141,16 @@ class JavaTypeInferenceSuite extends SparkFunSuite {
     assert(schema === expected)
   }
 
-  test("resolve type parameters for map and list") {
+  test("resolve type parameters for map, list and set") {
     val encoder = JavaTypeInference.encoderFor(classOf[GenericCollectionBean])
    val expected = JavaBeanEncoder(ClassTag(classOf[GenericCollectionBean]), Seq(
+      encoderField(
+        "hashSetOfString",
+        IterableEncoder(
+          ClassTag(classOf[HashSet[_]]),
+          StringEncoder,
+          containsNull = true,
+          lenientSerialization = false)),
       encoderField(
         "linkedListOfStrings",
         IterableEncoder(
@@ -166,7 +175,18 @@ class JavaTypeInferenceSuite extends SparkFunSuite {
           ClassTag(classOf[JMap[_, _]]),
           StringEncoder,
           expectedDummyBeanEncoder,
-          valueContainsNull = true))))
+          valueContainsNull = true)),
+      encoderField(
+        "setOfSetOfStrings",
+        IterableEncoder(
+          ClassTag(classOf[JSet[_]]),
+          IterableEncoder(
+            ClassTag(classOf[JSet[_]]),
+            StringEncoder,
+            containsNull = true,
+            lenientSerialization = false),
+          containsNull = true,
+          lenientSerialization = false))))
     assert(encoder === expected)
   }
 
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
index 538a7600b02a..7f58516cf4eb 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ObjectExpressionsSuite.scala
@@ -362,6 +362,8 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
           assert(result.asInstanceOf[ArrayData].array.toSeq == expected)
         case l if classOf[java.util.List[_]].isAssignableFrom(l) =>
           assert(result.asInstanceOf[java.util.List[_]].asScala == expected)
+        case s if classOf[java.util.Set[_]].isAssignableFrom(s) =>
+          assert(result.asInstanceOf[java.util.Set[_]].asScala == expected.toSet)
         case a if classOf[mutable.ArraySeq[Int]].isAssignableFrom(a) =>
           assert(result == mutable.ArraySeq.make[Int](expected.toArray))
         case a if classOf[immutable.ArraySeq[Int]].isAssignableFrom(a) =>
@@ -379,7 +381,8 @@ class ObjectExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper {
       classOf[Seq[Int]], classOf[scala.collection.Set[Int]],
       classOf[java.util.List[Int]], classOf[java.util.AbstractList[Int]],
      classOf[java.util.AbstractSequentialList[Int]], classOf[java.util.Vector[Int]],
-      classOf[java.util.Stack[Int]], null)
+      classOf[java.util.Stack[Int]], null,
+      classOf[java.util.Set[Int]])
 
     val list = new java.util.ArrayList[Int]()
     list.add(1)
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index 254c6df28209..bd776300bd5e 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -111,6 +111,26 @@ public class JavaDatasetSuite implements Serializable {
     Assertions.assertEquals(ds.schema(), ds2.schema());
   }
 
+  @Test
+  public void testBeanWithSet() {
+    BeanWithSet bean = new BeanWithSet();
+    Set<Long> fields = asSet(1L, 2L, 3L);
+    bean.setFields(fields);
+    List<BeanWithSet> objects = Collections.singletonList(bean);
+
+    Dataset<BeanWithSet> ds = spark.createDataset(objects, Encoders.bean(BeanWithSet.class));
+    Dataset<Row> df = ds.toDF();
+
+    Dataset<BeanWithSet> mapped =
+            df.map((MapFunction<Row, BeanWithSet>) row -> {
+              BeanWithSet obj = new BeanWithSet();
+              obj.setFields(new HashSet<>(row.<Long>getList(row.fieldIndex("fields"))));
+              return obj;
+            }, Encoders.bean(BeanWithSet.class));
+
+    Assertions.assertEquals(objects, mapped.collectAsList());
+  }
+
   @Test
   public void testCommonOperation() {
     List<String> data = Arrays.asList("hello", "world");
@@ -1989,6 +2009,31 @@ public class JavaDatasetSuite implements Serializable {
     Assertions.assertEquals(expected, df.collectAsList());
   }
 
+  public static class BeanWithSet implements Serializable {
+    private Set<Long> fields;
+
+    public Set<Long> getFields() {
+      return fields;
+    }
+
+    public void setFields(Set<Long> fields) {
+      this.fields = fields;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+      if (this == o) return true;
+      if (o == null || getClass() != o.getClass()) return false;
+      BeanWithSet that = (BeanWithSet) o;
+      return Objects.equal(fields, that.fields);
+    }
+
+    @Override
+    public int hashCode() {
+      return Objects.hashCode(fields);
+    }
+  }
+
   public static class SpecificListsBean implements Serializable {
     private ArrayList<Integer> arrayList;
     private LinkedList<Integer> linkedList;
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index cd28c60d83c7..f0f48026a4a0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql
 import java.io.{Externalizable, ObjectInput, ObjectOutput}
 import java.sql.{Date, Timestamp}
 
+import scala.collection.immutable.HashSet
 import scala.reflect.ClassTag
 import scala.util.Random
 
@@ -2706,6 +2707,12 @@ class DatasetSuite extends QueryTest
      assert(exception.context.head.asInstanceOf[DataFrameQueryContext].stackTrace.length == 2)
     }
   }
+
+  test("SPARK-46791: Dataset with set field") {
+    val ds = Seq(WithSet(0, HashSet("foo", "bar")), WithSet(1, HashSet("bar", "zoo"))).toDS()
+    checkDataset(ds.map(t => t),
+      WithSet(0, HashSet("foo", "bar")), WithSet(1, HashSet("bar", "zoo")))
+  }
 }
 
 class DatasetLargeResultCollectingSuite extends QueryTest
@@ -2759,6 +2766,8 @@ case class WithImmutableMap(id: String, map_test: scala.collection.immutable.Map
 case class WithMap(id: String, map_test: scala.collection.Map[Long, String])
 case class WithMapInOption(m: Option[scala.collection.Map[Int, Int]])
 
+case class WithSet(id: Int, values: Set[String])
+
 case class Generic[T](id: T, value: Double)
 
 case class OtherTuple(_1: String, _2: Int)

