Repository: spark
Updated Branches:
  refs/heads/master 8fcbda9c9 -> 6327ea570


[SPARK-21255][SQL] simplify encoder for java enum

## What changes were proposed in this pull request?

This is a follow-up to https://github.com/apache/spark/pull/18488 that simplifies
the code.

The major change is that a Java enum is now mapped to a string type instead of
a struct type with a single string field.

## How was this patch tested?

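Existing tests. The standalone `testEnum` test, which encoded a top-level enum
with `Encoders.bean`, is removed. An enum class now infers a string type, but
`ExpressionEncoder.javaBean` still requires a struct schema. Enum fields inside
beans remain covered by `testBeanWithEnum`.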

Author: Wenchen Fan <wenc...@databricks.com>

Closes #19066 from cloud-fan/fix.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/6327ea57
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/6327ea57
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/6327ea57

Branch: refs/heads/master
Commit: 6327ea570bf542983081c5d1d3ee7e6123365c8f
Parents: 8fcbda9
Author: Wenchen Fan <wenc...@databricks.com>
Authored: Tue Aug 29 09:15:59 2017 -0700
Committer: gatorsmile <gatorsm...@gmail.com>
Committed: Tue Aug 29 09:15:59 2017 -0700

----------------------------------------------------------------------
 .../spark/sql/catalyst/JavaTypeInference.scala  | 46 ++++++--------------
 .../catalyst/encoders/ExpressionEncoder.scala   | 14 +-----
 .../catalyst/expressions/objects/objects.scala  |  4 +-
 .../org/apache/spark/sql/JavaDatasetSuite.java  | 24 ++++------
 4 files changed, 25 insertions(+), 63 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/6327ea57/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
index 33f6ce0..3ecc137 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
@@ -32,7 +32,6 @@ import org.apache.spark.sql.catalyst.expressions.objects._
 import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, 
GenericArrayData}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
-import org.apache.spark.util.Utils
 
 /**
  * Type-inference utilities for POJOs and Java collections.
@@ -120,8 +119,7 @@ object JavaTypeInference {
         (MapType(keyDataType, valueDataType, nullable), true)
 
       case other if other.isEnum =>
-        (StructType(Seq(StructField(typeToken.getRawType.getSimpleName,
-          StringType, nullable = false))), true)
+        (StringType, true)
 
       case other =>
         if (seenTypeSet.contains(other)) {
@@ -310,9 +308,12 @@ object JavaTypeInference {
           returnNullable = false)
 
       case other if other.isEnum =>
-        StaticInvoke(JavaTypeInference.getClass, ObjectType(other), 
"deserializeEnumName",
-          expressions.Literal.create(other.getEnumConstants.apply(0), 
ObjectType(other))
-            :: getPath :: Nil)
+        StaticInvoke(
+          other,
+          ObjectType(other),
+          "valueOf",
+          Invoke(getPath, "toString", ObjectType(classOf[String]), 
returnNullable = false) :: Nil,
+          returnNullable = false)
 
       case other =>
         val properties = getJavaBeanReadableAndWritableProperties(other)
@@ -356,30 +357,6 @@ object JavaTypeInference {
     }
   }
 
-  /** Returns a mapping from enum value to int for given enum type */
-  def enumSerializer[T <: Enum[T]](enum: Class[T]): T => UTF8String = {
-    assert(enum.isEnum)
-    inputObject: T =>
-      UTF8String.fromString(inputObject.name())
-  }
-
-  /** Returns value index for given enum type and value */
-  def serializeEnumName[T <: Enum[T]](enum: UTF8String, inputObject: T): 
UTF8String = {
-    
enumSerializer(Utils.classForName(enum.toString).asInstanceOf[Class[T]])(inputObject)
-  }
-
-  /** Returns a mapping from int to enum value for given enum type */
-  def enumDeserializer[T <: Enum[T]](enum: Class[T]): InternalRow => T = {
-    assert(enum.isEnum)
-    value: InternalRow =>
-      Enum.valueOf(enum, value.getUTF8String(0).toString)
-  }
-
-  /** Returns enum value for given enum type and value index */
-  def deserializeEnumName[T <: Enum[T]](typeDummy: T, inputObject: 
InternalRow): T = {
-    enumDeserializer(typeDummy.getClass.asInstanceOf[Class[T]])(inputObject)
-  }
-
   private def serializerFor(inputObject: Expression, typeToken: TypeToken[_]): 
Expression = {
 
     def toCatalystArray(input: Expression, elementType: TypeToken[_]): 
Expression = {
@@ -465,9 +442,12 @@ object JavaTypeInference {
           )
 
         case other if other.isEnum =>
-          CreateNamedStruct(expressions.Literal("enum") ::
-          StaticInvoke(JavaTypeInference.getClass, StringType, 
"serializeEnumName",
-          expressions.Literal.create(other.getName, StringType) :: inputObject 
:: Nil) :: Nil)
+          StaticInvoke(
+            classOf[UTF8String],
+            StringType,
+            "fromString",
+            Invoke(inputObject, "name", ObjectType(classOf[String]), 
returnNullable = false) :: Nil,
+            returnNullable = false)
 
         case other =>
           val properties = getJavaBeanReadableAndWritableProperties(other)

http://git-wip-us.apache.org/repos/asf/spark/blob/6327ea57/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index 9ed5e12..efc2882 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -28,7 +28,7 @@ import 
org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection
 import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, 
Invoke, NewInstance}
 import org.apache.spark.sql.catalyst.optimizer.SimplifyCasts
 import org.apache.spark.sql.catalyst.plans.logical.{CatalystSerde, 
DeserializeToObject, LocalRelation}
-import org.apache.spark.sql.types.{BooleanType, DataType, ObjectType, 
StringType, StructField, StructType}
+import org.apache.spark.sql.types.{BooleanType, ObjectType, StructField, 
StructType}
 import org.apache.spark.util.Utils
 
 /**
@@ -81,19 +81,9 @@ object ExpressionEncoder {
       ClassTag[T](cls))
   }
 
-  def javaEnumSchema[T](beanClass: Class[T]): DataType = {
-    StructType(Seq(StructField("enum",
-      StructType(Seq(StructField(beanClass.getSimpleName, StringType, nullable 
= false))),
-      nullable = false)))
-  }
-
   // TODO: improve error message for java bean encoder.
   def javaBean[T](beanClass: Class[T]): ExpressionEncoder[T] = {
-    val schema = if (beanClass.isEnum) {
-      javaEnumSchema(beanClass)
-    } else {
-      JavaTypeInference.inferDataType(beanClass)._1
-    }
+    val schema = JavaTypeInference.inferDataType(beanClass)._1
     assert(schema.isInstanceOf[StructType])
 
     val serializer = JavaTypeInference.serializerFor(beanClass)

http://git-wip-us.apache.org/repos/asf/spark/blob/6327ea57/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 7c466fe..9b28a18 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -154,13 +154,13 @@ case class StaticInvoke(
     val evaluate = if (returnNullable) {
       if (ctx.defaultValue(dataType) == "null") {
         s"""
-          ${ev.value} = (($javaType) ($callFunc));
+          ${ev.value} = $callFunc;
           ${ev.isNull} = ${ev.value} == null;
         """
       } else {
         val boxedResult = ctx.freshName("boxedResult")
         s"""
-          ${ctx.boxedType(dataType)} $boxedResult = (($javaType) ($callFunc));
+          ${ctx.boxedType(dataType)} $boxedResult = $callFunc;
           ${ev.isNull} = $boxedResult == null;
           if (!${ev.isNull}) {
             ${ev.value} = $boxedResult;

http://git-wip-us.apache.org/repos/asf/spark/blob/6327ea57/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java 
b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index a344746..3e57403 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -1283,13 +1283,13 @@ public class JavaDatasetSuite implements Serializable {
     ds.collectAsList();
   }
 
-  public enum EnumBean {
+  public enum MyEnum {
     A("www.elgoog.com"),
     B("www.google.com");
 
     private String url;
 
-    EnumBean(String url) {
+    MyEnum(String url) {
       this.url = url;
     }
 
@@ -1302,16 +1302,8 @@ public class JavaDatasetSuite implements Serializable {
     }
   }
 
-  @Test
-  public void testEnum() {
-    List<EnumBean> data = Arrays.asList(EnumBean.B);
-    Encoder<EnumBean> encoder = Encoders.bean(EnumBean.class);
-    Dataset<EnumBean> ds = spark.createDataset(data, encoder);
-    Assert.assertEquals(ds.collectAsList(), data);
-  }
-
   public static class BeanWithEnum {
-    EnumBean enumField;
+    MyEnum enumField;
     String regularField;
 
     public String getRegularField() {
@@ -1322,15 +1314,15 @@ public class JavaDatasetSuite implements Serializable {
       this.regularField = regularField;
     }
 
-    public EnumBean getEnumField() {
+    public MyEnum getEnumField() {
       return enumField;
     }
 
-    public void setEnumField(EnumBean field) {
+    public void setEnumField(MyEnum field) {
       this.enumField = field;
     }
 
-    public BeanWithEnum(EnumBean enumField, String regularField) {
+    public BeanWithEnum(MyEnum enumField, String regularField) {
       this.enumField = enumField;
       this.regularField = regularField;
     }
@@ -1353,8 +1345,8 @@ public class JavaDatasetSuite implements Serializable {
 
   @Test
   public void testBeanWithEnum() {
-    List<BeanWithEnum> data = Arrays.asList(new BeanWithEnum(EnumBean.A, "mira 
avenue"),
-            new BeanWithEnum(EnumBean.B, "flower boulevard"));
+    List<BeanWithEnum> data = Arrays.asList(new BeanWithEnum(MyEnum.A, "mira 
avenue"),
+            new BeanWithEnum(MyEnum.B, "flower boulevard"));
     Encoder<BeanWithEnum> encoder = Encoders.bean(BeanWithEnum.class);
     Dataset<BeanWithEnum> ds = spark.createDataset(data, encoder);
     Assert.assertEquals(ds.collectAsList(), data);


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to