This is an automated email from the ASF dual-hosted git repository.

srowen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 33c6d11  [SPARK-20384][SQL] Support value class in nested schema for Dataset
33c6d11 is described below

commit 33c6d1168c077630a8f81c1a4e153f862162f257
Author: Mick Jermsurawong <mickjermsuraw...@stripe.com>
AuthorDate: Mon Aug 9 08:47:35 2021 -0500

    [SPARK-20384][SQL] Support value class in nested schema for Dataset
    
    ### What changes were proposed in this pull request?
    
    - This PR revisits https://github.com/apache/spark/pull/22309 and [SPARK-20384](https://issues.apache.org/jira/browse/SPARK-20384). It solves the original problem and additionally avoids a backward-compatibility break in the schema of a top-level `AnyVal` value class.
    - Why did the previous PR break compatibility? We currently support top-level value classes just like any other case class: the field of the underlying type is present in the schema, so any DataFrame SQL filter on such a class expects that field name to be present. The previous PR changed this schema and would have broken current usage. See the test `"schema for case class that is a value class"`. This PR keeps the schema.
    - Collections of value classes are already supported prior to this change, but a case class with a nested value class is not. The schema of these collection classes therefore should not change either, to avoid breaking existing usage.
    - What we can change without breaking anything is the schema of a nested value class: encoding it currently fails because the generated code does not compile, so its schema is not actually usable today. After this change, the schema of a nested value class is flattened.
    - With this PR, flattening applies only to nested value classes (new), not to top-level value classes or collections of value classes (existing behavior).
    - This PR also revisits https://github.com/apache/spark/pull/27153 by handling tuples such as `Tuple2[AnyVal, AnyVal]`: the value classes are constructor parameters (that is, "nested"), but their types are generic type arguments, so they should not be flattened, similarly to `Seq[AnyVal]`.
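    
    To make the rule in the bullets above concrete, here is a minimal sketch of it as a toy model in plain Scala. None of these names (`Tpe`, `Col`, `SchemaRule`, `schemaOf`) exist in Spark; they are hypothetical stand-ins for illustration only, and the model deliberately ignores nullability and generic containers such as tuples.
    
    ```scala
    // Toy model only: these types are hypothetical and are NOT Spark's API.
    sealed trait Tpe
    case object IntT extends Tpe
    case object StrT extends Tpe
    // A value class with a single field of some underlying type.
    final case class ValueClassT(field: String, underlying: Tpe) extends Tpe
    // An ordinary (non-generic) case class with named constructor fields.
    final case class CaseClassT(fields: List[(String, Tpe)]) extends Tpe
    // A collection whose element type is a type argument.
    final case class SeqT(element: Tpe) extends Tpe

    sealed trait Col
    case object IntCol extends Col
    case object StrCol extends Col
    final case class StructCol(fields: List[(String, Col)]) extends Col
    final case class ArrayCol(element: Col) extends Col

    object SchemaRule {
      def schemaOf(t: Tpe): Col = t match {
        case IntT => IntCol
        case StrT => StrCol
        // Reached for top-level and collection-element positions:
        // the value class keeps its one-field struct (existing behavior).
        case ValueClassT(field, underlying) =>
          StructCol(List(field -> schemaOf(underlying)))
        case SeqT(element) => ArrayCol(schemaOf(element))
        case CaseClassT(fields) =>
          StructCol(fields.map {
            // A value class used directly as a field is flattened (new).
            case (name, ValueClassT(_, underlying)) => name -> schemaOf(underlying)
            case (name, other) => name -> schemaOf(other)
          })
      }
    }
    ```
    
    Under this model, a container field `wrapper` of a string value class becomes a plain string column, while a `Seq` of the same value class keeps the inner `s` field. That matches the shape the DataFrame tests in this commit rely on when they filter on `wrapper` and on `wrappers.s`.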
    
    ### Why are the changes needed?
    
    - Currently, a nested value class isn't supported. The generated code treats the `AnyVal` class in its unwrapped form, but we encode its type as the wrapped case class. This results in a compile error in the generated code.
    For example, given an `AnyVal` wrapper and its root-level container class:
    ```
    case class IntWrapper(i: Int) extends AnyVal
    case class ComplexValueClassContainer(c: IntWrapper)
    ```
    The problematic part of generated code:
    ```
        private InternalRow If_1(InternalRow i) {
            boolean isNull_42 = i.isNullAt(0);
            // 1) ******** The root-level case class we care about
            org.apache.spark.sql.catalyst.encoders.ComplexValueClassContainer value_46 = isNull_42 ?
                null : ((org.apache.spark.sql.catalyst.encoders.ComplexValueClassContainer) i.get(0, null));
            if (isNull_42) {
                throw new NullPointerException(((java.lang.String) references[5] /* errMsg */ ));
            }
            boolean isNull_39 = true;
            // 2) ******** Its member is declared as the wrapper case class extending `AnyVal`
            org.apache.spark.sql.catalyst.encoders.IntWrapper value_43 = null;
            if (!false) {
    
                isNull_39 = false;
                if (!isNull_39) {
                    // 3) ******** ERROR: the compiled `c()` returns `int`, hence the error
                    value_43 = value_46.c();
                }
            }
    ```
    We get this error: Assignment conversion not possible from type "int" to type "org.apache.spark.sql.catalyst.encoders.IntWrapper"
    ```
    java.util.concurrent.ExecutionException: org.codehaus.commons.compiler.CompileException:
    File 'generated.java', Line 159, Column 1: failed to compile: org.codehaus.commons.compiler.CompileException: File 'generated.java', Line 159, Column 1: Assignment conversion not possible from type "int" to type "org.apache.spark.sql.catalyst.encoders.IntWrapper"
    ```
    
    From the [doc](https://docs.scala-lang.org/overviews/core/value-classes.html) on value classes, given `class Wrapper(val underlying: Int) extends AnyVal`:
    1) "The type at compile time is `Wrapper`, but at runtime, the representation is an `Int`". This implies that when our struct has a value class field, the generated code should work with the underlying type at runtime.
    2) `Wrapper` "must be instantiated... when a value class is used as a type argument". This implies that `scala.Tuple[Wrapper, ...]`, `Seq[Wrapper]`, `Map[String, Wrapper]`, and `Option[Wrapper]` will still contain `Wrapper` as-is at runtime instead of `Int`.
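    
    Both points can be checked from plain Scala with Java reflection. This is a minimal sketch rather than anything taken from Spark: `Container`, `SeqContainer`, and `ValueClassErasure` are hypothetical names chosen for the illustration, and only `IntWrapper` mirrors the example above.
    
    ```scala
    final case class IntWrapper(i: Int) extends AnyVal
    // Value class used directly as a field (the "nested" case).
    final case class Container(c: IntWrapper)
    // Value class used as a type argument.
    final case class SeqContainer(s: Seq[IntWrapper])

    object ValueClassErasure {
      // Return type of the accessor `c()` as Java code sees it.
      // Generated Java code calling `c()` receives this type.
      val accessorType: Class[_] = classOf[Container].getMethod("c").getReturnType

      // An element stored in a collection, viewed without its static type.
      val element: Any =
        SeqContainer(Seq(IntWrapper(1))).s.asInstanceOf[Seq[Any]].head
    }
    ```
    
    Here `accessorType` is the primitive `int` class, which is why assigning the result of `c()` to an `IntWrapper` variable fails in Java, while `element` is a real `IntWrapper` instance because type arguments force the wrapper to be instantiated.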
    
    ### Does this PR introduce _any_ user-facing change?
    
    - Yes, this adds support for nested value classes.
    
    ### How was this patch tested?
    
    - Added unit tests to illustrate
      - raw schema
      - projection
      - round-trip encode/decode
    
    Closes #33205 from mickjermsurawong-stripe/SPARK-20384-2.
    
    Lead-authored-by: Mick Jermsurawong <mickjermsuraw...@stripe.com>
    Co-authored-by: Emil Ejbyfeldt <eejbyfe...@liveintent.com>
    Signed-off-by: Sean Owen <sro...@gmail.com>
---
 .../spark/sql/catalyst/ScalaReflection.scala       | 33 +++++---
 .../spark/sql/catalyst/ScalaReflectionSuite.scala  | 94 ++++++++++++++++++++++
 .../catalyst/encoders/ExpressionEncoderSuite.scala | 57 +++++++++++++
 .../org/apache/spark/sql/DataFrameSuite.scala      | 52 +++++++++++-
 .../org/apache/spark/sql/test/SQLTestData.scala    |  3 +
 5 files changed, 229 insertions(+), 10 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 4de7e5c..bab407b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -956,6 +956,19 @@ trait ScalaReflection extends Logging {
     tag.in(mirror).tpe.dealias
   }
 
+  private def isValueClass(tpe: Type): Boolean = {
+    tpe.typeSymbol.asClass.isDerivedValueClass
+  }
+
+  private def isTypeParameter(tpe: Type): Boolean = {
+    tpe.typeSymbol.isParameter
+  }
+
+  /** Returns the name and type of the underlying parameter of value class `tpe`. */
+  private def getUnderlyingTypeOfValueClass(tpe: `Type`): Type = {
+    getConstructorParameters(tpe).head._2
+  }
+
   /**
    * Returns the parameter names and types for the primary constructor of this type.
    *
@@ -967,15 +980,17 @@ trait ScalaReflection extends Logging {
     val formalTypeArgs = dealiasedTpe.typeSymbol.asClass.typeParams
     val TypeRef(_, _, actualTypeArgs) = dealiasedTpe
     val params = constructParams(dealiasedTpe)
-    // if there are type variables to fill in, do the substitution (SomeClass[T] -> SomeClass[Int])
-    if (actualTypeArgs.nonEmpty) {
-      params.map { p =>
-        p.name.decodedName.toString ->
-          p.typeSignature.substituteTypes(formalTypeArgs, actualTypeArgs)
-      }
-    } else {
-      params.map { p =>
-        p.name.decodedName.toString -> p.typeSignature
+    params.map { p =>
+      val paramTpe = p.typeSignature
+      if (isTypeParameter(paramTpe)) {
+        // if there are type variables to fill in, do the substitution
+        // (SomeClass[T] -> SomeClass[Int])
+        p.name.decodedName.toString -> paramTpe.substituteTypes(formalTypeArgs, actualTypeArgs)
+      } else if (isValueClass(paramTpe)) {
+        // Replace value class with underlying type
+        p.name.decodedName.toString -> getUnderlyingTypeOfValueClass(paramTpe)
+      } else {
+        p.name.decodedName.toString -> paramTpe
       }
     }
   }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
index 164bbd7..d86f986 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
@@ -156,8 +156,19 @@ object TraitProductWithNoConstructorCompanion {}
 
 trait TraitProductWithNoConstructorCompanion extends Product1[Int] {}
 
+object TestingValueClass {
+  case class IntWrapper(val i: Int) extends AnyVal
+  case class StrWrapper(s: String) extends AnyVal
+
+  case class ValueClassData(intField: Int,
+                            wrappedInt: IntWrapper, // an int column
+                            strField: String,
+                            wrappedStr: StrWrapper) // a string column
+}
+
 class ScalaReflectionSuite extends SparkFunSuite {
   import org.apache.spark.sql.catalyst.ScalaReflection._
+  import TestingValueClass._
 
   // A helper method used to test `ScalaReflection.serializerForType`.
   private def serializerFor[T: TypeTag]: Expression =
@@ -451,4 +462,87 @@ class ScalaReflectionSuite extends SparkFunSuite {
       StructField("e", StringType, true))))
     assert(deserializerFor[FooClassWithEnum].dataType == ObjectType(classOf[FooClassWithEnum]))
   }
+
+  test("schema for case class that is a value class") {
+    val schema = schemaFor[IntWrapper]
+    assert(
+      schema === Schema(StructType(Seq(StructField("i", IntegerType, false))), nullable = true))
+  }
+
+  test("SPARK-20384: schema for case class that contains value class fields") {
+    val schema = schemaFor[ValueClassData]
+    assert(
+      schema === Schema(
+        StructType(Seq(
+          StructField("intField", IntegerType, nullable = false),
+          StructField("wrappedInt", IntegerType, nullable = false),
+          StructField("strField", StringType),
+          StructField("wrappedStr", StringType)
+        )),
+        nullable = true))
+  }
+
+  test("SPARK-20384: schema for array of value class") {
+    val schema = schemaFor[Array[IntWrapper]]
+    assert(
+      schema === Schema(
+        ArrayType(StructType(Seq(StructField("i", IntegerType, false))), containsNull = true),
+        nullable = true))
+  }
+
+  test("SPARK-20384: schema for map of value class") {
+    val schema = schemaFor[Map[IntWrapper, StrWrapper]]
+    assert(
+      schema === Schema(
+        MapType(
+          StructType(Seq(StructField("i", IntegerType, false))),
+          StructType(Seq(StructField("s", StringType))),
+          valueContainsNull = true),
+        nullable = true))
+  }
+
+  test("SPARK-20384: schema for tuple_2 of value class") {
+    val schema = schemaFor[(IntWrapper, StrWrapper)]
+    assert(
+      schema === Schema(
+        StructType(
+          Seq(
+            StructField("_1", StructType(Seq(StructField("i", IntegerType, false)))),
+            StructField("_2", StructType(Seq(StructField("s", StringType))))
+          )
+        ),
+        nullable = true))
+  }
+
+  test("SPARK-20384: schema for tuple_3 of value class") {
+    val schema = schemaFor[(IntWrapper, StrWrapper, StrWrapper)]
+    assert(
+      schema === Schema(
+        StructType(
+          Seq(
+            StructField("_1", StructType(Seq(StructField("i", IntegerType, false)))),
+            StructField("_2", StructType(Seq(StructField("s", StringType)))),
+            StructField("_3", StructType(Seq(StructField("s", StringType))))
+          )
+        ),
+        nullable = true))
+  }
+
+  test("SPARK-20384: schema for nested tuple of value class") {
+    val schema = schemaFor[(IntWrapper, (StrWrapper, StrWrapper))]
+    assert(
+      schema === Schema(
+        StructType(
+          Seq(
+            StructField("_1", StructType(Seq(StructField("i", IntegerType, false)))),
+            StructField("_2", StructType(
+              Seq(
+                StructField("_1", StructType(Seq(StructField("s", StringType)))),
+                StructField("_2", StructType(Seq(StructField("s", StringType)))))
+              )
+            )
+          )
+        ),
+        nullable = true))
+  }
 }
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
index bf4afac..ae5ce60 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -116,6 +116,21 @@ object ReferenceValueClass {
 }
 case class IntAndString(i: Int, s: String)
 
+case class StringWrapper(s: String) extends AnyVal
+case class ValueContainer(
+                           a: Int,
+                           b: StringWrapper) // a string column
+case class IntWrapper(i: Int) extends AnyVal
+case class ComplexValueClassContainer(
+                                       a: Int,
+                                       b: ValueContainer,
+                                       c: IntWrapper)
+case class SeqOfValueClass(s: Seq[StringWrapper])
+case class MapOfValueClassKey(m: Map[IntWrapper, String])
+case class MapOfValueClassValue(m: Map[String, StringWrapper])
+case class OptionOfValueClassValue(o: Option[StringWrapper])
+case class CaseClassWithGeneric[T](generic: T, value: IntWrapper)
+
 class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTest {
   OuterScopes.addOuterScope(this)
 
@@ -391,12 +406,54 @@ class ExpressionEncoderSuite extends CodegenInterpretedPlanTest with AnalysisTes
     ExpressionEncoder.tuple(intEnc, ExpressionEncoder.tuple(intEnc, longEnc))
   }
 
+  // test for value classes
   encodeDecodeTest(
     PrimitiveValueClass(42), "primitive value class")
 
   encodeDecodeTest(
     ReferenceValueClass(ReferenceValueClass.Container(1)), "reference value class")
 
+  encodeDecodeTest(StringWrapper("a"), "string value class")
+  encodeDecodeTest(ValueContainer(1, StringWrapper("b")), "nested value class")
+  encodeDecodeTest(ValueContainer(1, StringWrapper(null)), "nested value class with null")
+  encodeDecodeTest(ComplexValueClassContainer(1, ValueContainer(2, StringWrapper("b")),
+    IntWrapper(3)), "complex value class")
+  encodeDecodeTest(
+    Array(IntWrapper(1), IntWrapper(2), IntWrapper(3)),
+    "array of value class")
+  encodeDecodeTest(Array.empty[IntWrapper], "empty array of value class")
+  encodeDecodeTest(
+    Seq(IntWrapper(1), IntWrapper(2), IntWrapper(3)),
+    "seq of value class")
+  encodeDecodeTest(Seq.empty[IntWrapper], "empty seq of value class")
+  encodeDecodeTest(
+    Map(IntWrapper(1) -> StringWrapper("a"), IntWrapper(2) -> StringWrapper("b")),
+    "map with value class")
+
+  // test for nested value class collections
+  encodeDecodeTest(
+    MapOfValueClassKey(Map(IntWrapper(1)-> "a")),
+    "case class with map of value class key")
+  encodeDecodeTest(
+    MapOfValueClassValue(Map("a"-> StringWrapper("b"))),
+    "case class with map of value class value")
+  encodeDecodeTest(
+    SeqOfValueClass(Seq(StringWrapper("a"))),
+    "case class with seq of class value")
+  encodeDecodeTest(
+    OptionOfValueClassValue(Some(StringWrapper("a"))),
+    "case class with option of class value")
+  encodeDecodeTest((StringWrapper("a_1"), StringWrapper("a_2")),
+    "tuple2 of class value")
+  encodeDecodeTest((StringWrapper("a_1"), StringWrapper("a_2"), StringWrapper("a_3")),
+    "tuple3 of class value")
+  encodeDecodeTest(((StringWrapper("a_1"), StringWrapper("a_2")), StringWrapper("b_2")),
+    "nested tuple._1 of class value")
+  encodeDecodeTest((StringWrapper("a_1"), (StringWrapper("b_1"), StringWrapper("b_2"))),
+    "nested tuple._2 of class value")
+  encodeDecodeTest(CaseClassWithGeneric(IntWrapper(1), IntWrapper(2)),
+    "case class with value class in generic parameter")
+
   encodeDecodeTest(Option(31), "option of int")
   encodeDecodeTest(Option.empty[Int], "empty option of int")
   encodeDecodeTest(Option("abc"), "option of string")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 2fd5993..f2d0b60 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -46,7 +46,7 @@ import org.apache.spark.sql.expressions.{Aggregator, Window}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.{ExamplePoint, ExamplePointUDT, SharedSparkSession}
-import org.apache.spark.sql.test.SQLTestData.{DecimalData, TestData2}
+import org.apache.spark.sql.test.SQLTestData.{ArrayStringWrapper, ContainerStringWrapper, DecimalData, StringWrapper, TestData2}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.CalendarInterval
 import org.apache.spark.util.Utils
@@ -834,6 +834,56 @@ class DataFrameSuite extends QueryTest
     assert(df.schema.map(_.name) === Seq("key", "valueRenamed", "newCol"))
   }
 
+  test("SPARK-20384: Value class filter") {
+    val df = spark.sparkContext
+      .parallelize(Seq(StringWrapper("a"), StringWrapper("b"), StringWrapper("c")))
+      .toDF()
+    val filtered = df.where("s = \"a\"")
+    checkAnswer(filtered, spark.sparkContext.parallelize(Seq(StringWrapper("a"))).toDF)
+  }
+
+  test("SPARK-20384: Tuple2 of value class filter") {
+    val df = spark.sparkContext
+      .parallelize(Seq(
+        (StringWrapper("a1"), StringWrapper("a2")),
+        (StringWrapper("b1"), StringWrapper("b2"))))
+      .toDF()
+    val filtered = df.where("_2.s = \"a2\"")
+    checkAnswer(filtered,
+      spark.sparkContext.parallelize(Seq((StringWrapper("a1"), StringWrapper("a2")))).toDF)
+  }
+
+  test("SPARK-20384: Tuple3 of value class filter") {
+    val df = spark.sparkContext
+      .parallelize(Seq(
+        (StringWrapper("a1"), StringWrapper("a2"), StringWrapper("a3")),
+        (StringWrapper("b1"), StringWrapper("b2"), StringWrapper("b3"))))
+      .toDF()
+    val filtered = df.where("_3.s = \"a3\"")
+    checkAnswer(filtered,
+      spark.sparkContext.parallelize(
+        Seq((StringWrapper("a1"), StringWrapper("a2"), StringWrapper("a3")))).toDF)
+  }
+
+  test("SPARK-20384: Array value class filter") {
+    val ab = ArrayStringWrapper(Seq(StringWrapper("a"), StringWrapper("b")))
+    val cd = ArrayStringWrapper(Seq(StringWrapper("c"), StringWrapper("d")))
+
+    val df = spark.sparkContext.parallelize(Seq(ab, cd)).toDF
+    val filtered = df.where(array_contains(col("wrappers.s"), "b"))
+    checkAnswer(filtered, spark.sparkContext.parallelize(Seq(ab)).toDF)
+  }
+
+  test("SPARK-20384: Nested value class filter") {
+    val a = ContainerStringWrapper(StringWrapper("a"))
+    val b = ContainerStringWrapper(StringWrapper("b"))
+
+    val df = spark.sparkContext.parallelize(Seq(a, b)).toDF
+    // flat value class, `s` field is not in schema
+    val filtered = df.where("wrapper = \"a\"")
+    checkAnswer(filtered, spark.sparkContext.parallelize(Seq(a)).toDF)
+  }
+
   private lazy val person2: DataFrame = Seq(
     ("Bob", 16, 176),
     ("Alice", 32, 164),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
index 307c4f3..21064b5 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
@@ -450,4 +450,7 @@ private[sql] object SQLTestData {
   case class CourseSales(course: String, year: Int, earnings: Double)
   case class TrainingSales(training: String, sales: CourseSales)
   case class IntervalData(data: CalendarInterval)
+  case class StringWrapper(s: String) extends AnyVal
+  case class ArrayStringWrapper(wrappers: Seq[StringWrapper])
+  case class ContainerStringWrapper(wrapper: StringWrapper)
 }

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
