Repository: spark
Updated Branches:
  refs/heads/master f3676d639 -> 7d16776d2


[SPARK-21255][SQL][WIP] Fixed NPE when creating encoder for enum

## What changes were proposed in this pull request?

Fixed NPE when creating encoder for enum.

When you try to create an encoder for an Enum type (or a bean with an enum property) 
via Encoders.bean(...), it fails with a NullPointerException at TypeToken:495.
I did a little research, and it turns out that in JavaTypeInference the following 
code
```
  def getJavaBeanReadableProperties(beanClass: Class[_]): Array[PropertyDescriptor] = {
    val beanInfo = Introspector.getBeanInfo(beanClass)
    beanInfo.getPropertyDescriptors.filterNot(_.getName == "class")
      .filter(_.getReadMethod != null)
  }
```
filters out properties named "class", because we wouldn't want to serialize 
those. But enum types have another property of type Class, named 
"declaringClass", which we then try to inspect recursively. Eventually we end up 
inspecting the ClassLoader class, which has a property "defaultAssertionStatus" with 
no read method, and that leads to the NPE at TypeToken:495.

I added the property name "declaringClass" to the filter to resolve this.
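
For illustration, here is a minimal Scala sketch (not part of the patch) that makes 
the problematic property visible; `java.time.DayOfWeek` is used purely as a stand-in enum:

```
import java.beans.Introspector

// Introspecting any enum class exposes a "declaringClass" property of type
// Class (from Enum.getDeclaringClass), which the recursive inspection then
// follows all the way into ClassLoader.
object EnumIntrospectionDemo {
  def main(args: Array[String]): Unit = {
    val beanInfo = Introspector.getBeanInfo(classOf[java.time.DayOfWeek])
    beanInfo.getPropertyDescriptors
      .filterNot(_.getName == "class")
      .foreach(pd => println(s"${pd.getName} -> ${pd.getPropertyType}"))
  }
}
```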

## How was this patch tested?
A unit test in JavaDatasetSuite that creates an encoder for an enum type (and for a 
bean with an enum property).
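
For reference, a minimal Scala usage sketch of the fixed code path, mirroring the 
Java test in the diff below; `spark` is assumed to be an active SparkSession, and 
EnumBean stands for any Java enum:

```
import org.apache.spark.sql.{Dataset, Encoder, Encoders}

// Minimal sketch, assuming `spark` is an active SparkSession and EnumBean
// is a Java enum such as the one defined in JavaDatasetSuite below.
val encoder: Encoder[EnumBean] = Encoders.bean(classOf[EnumBean])
val ds: Dataset[EnumBean] = spark.createDataset(Seq(EnumBean.B))(encoder)
assert(ds.collectAsList().get(0) == EnumBean.B)
```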

Author: mike <mike...@gmail.com>
Author: Mikhail Sveshnikov <mike...@gmail.com>

Closes #18488 from mike0sv/enum-support.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/7d16776d
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/7d16776d
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/7d16776d

Branch: refs/heads/master
Commit: 7d16776d28da5bcf656f0d8556b15ed3a5edca44
Parents: f3676d6
Author: mike <mike...@gmail.com>
Authored: Fri Aug 25 07:22:34 2017 +0100
Committer: Sean Owen <so...@cloudera.com>
Committed: Fri Aug 25 07:22:34 2017 +0100

----------------------------------------------------------------------
 .../spark/sql/catalyst/JavaTypeInference.scala  | 40 ++++++++++
 .../catalyst/encoders/ExpressionEncoder.scala   | 14 +++-
 .../catalyst/expressions/objects/objects.scala  |  4 +-
 .../org/apache/spark/sql/JavaDatasetSuite.java  | 77 ++++++++++++++++++++
 4 files changed, 131 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/7d16776d/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
index 21363d3..33f6ce0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
@@ -32,6 +32,7 @@ import org.apache.spark.sql.catalyst.expressions.objects._
 import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, DateTimeUtils, GenericArrayData}
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.util.Utils
 
 /**
  * Type-inference utilities for POJOs and Java collections.
@@ -118,6 +119,10 @@ object JavaTypeInference {
         val (valueDataType, nullable) = inferDataType(valueType, seenTypeSet)
         (MapType(keyDataType, valueDataType, nullable), true)
 
+      case other if other.isEnum =>
+        (StructType(Seq(StructField(typeToken.getRawType.getSimpleName,
+          StringType, nullable = false))), true)
+
       case other =>
         if (seenTypeSet.contains(other)) {
           throw new UnsupportedOperationException(
@@ -140,6 +145,7 @@ object JavaTypeInference {
   def getJavaBeanReadableProperties(beanClass: Class[_]): Array[PropertyDescriptor] = {
     val beanInfo = Introspector.getBeanInfo(beanClass)
     beanInfo.getPropertyDescriptors.filterNot(_.getName == "class")
+      .filterNot(_.getName == "declaringClass")
       .filter(_.getReadMethod != null)
   }
 
@@ -303,6 +309,11 @@ object JavaTypeInference {
           keyData :: valueData :: Nil,
           returnNullable = false)
 
+      case other if other.isEnum =>
+        StaticInvoke(JavaTypeInference.getClass, ObjectType(other), "deserializeEnumName",
+          expressions.Literal.create(other.getEnumConstants.apply(0), ObjectType(other))
+            :: getPath :: Nil)
+
       case other =>
         val properties = getJavaBeanReadableAndWritableProperties(other)
         val setters = properties.map { p =>
@@ -345,6 +356,30 @@ object JavaTypeInference {
     }
   }
 
+  /** Returns a function mapping an enum value to its name as a UTF8String, for the given enum type */
+  def enumSerializer[T <: Enum[T]](enum: Class[T]): T => UTF8String = {
+    assert(enum.isEnum)
+    inputObject: T =>
+      UTF8String.fromString(inputObject.name())
+  }
+
+  /** Serializes the given enum value to its name, given the enum class name as a UTF8String */
+  def serializeEnumName[T <: Enum[T]](enum: UTF8String, inputObject: T): UTF8String = {
+    enumSerializer(Utils.classForName(enum.toString).asInstanceOf[Class[T]])(inputObject)
+  }
+
+  /** Returns a function mapping an InternalRow holding an enum name to the enum value, for the given enum type */
+  def enumDeserializer[T <: Enum[T]](enum: Class[T]): InternalRow => T = {
+    assert(enum.isEnum)
+    value: InternalRow =>
+      Enum.valueOf(enum, value.getUTF8String(0).toString)
+  }
+
+  /** Returns the enum value whose name is held in the given InternalRow, for the enum type of the dummy argument */
+  def deserializeEnumName[T <: Enum[T]](typeDummy: T, inputObject: InternalRow): T = {
+    enumDeserializer(typeDummy.getClass.asInstanceOf[Class[T]])(inputObject)
+  }
+
   private def serializerFor(inputObject: Expression, typeToken: TypeToken[_]): Expression = {
 
     def toCatalystArray(input: Expression, elementType: TypeToken[_]): Expression = {
@@ -429,6 +464,11 @@ object JavaTypeInference {
             valueNullable = true
           )
 
+        case other if other.isEnum =>
+          CreateNamedStruct(expressions.Literal("enum") ::
+          StaticInvoke(JavaTypeInference.getClass, StringType, "serializeEnumName",
+          expressions.Literal.create(other.getName, StringType) :: inputObject :: Nil) :: Nil)
+
         case other =>
           val properties = getJavaBeanReadableAndWritableProperties(other)
           val nonNullOutput = CreateNamedStruct(properties.flatMap { p =>

http://git-wip-us.apache.org/repos/asf/spark/blob/7d16776d/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index efc2882..9ed5e12 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.{GenerateSafeProjection
 import org.apache.spark.sql.catalyst.expressions.objects.{AssertNotNull, Invoke, NewInstance}
 import org.apache.spark.sql.catalyst.optimizer.SimplifyCasts
 import org.apache.spark.sql.catalyst.plans.logical.{CatalystSerde, DeserializeToObject, LocalRelation}
-import org.apache.spark.sql.types.{BooleanType, ObjectType, StructField, StructType}
+import org.apache.spark.sql.types.{BooleanType, DataType, ObjectType, StringType, StructField, StructType}
 import org.apache.spark.util.Utils
 
 /**
@@ -81,9 +81,19 @@ object ExpressionEncoder {
       ClassTag[T](cls))
   }
 
+  def javaEnumSchema[T](beanClass: Class[T]): DataType = {
+    StructType(Seq(StructField("enum",
+      StructType(Seq(StructField(beanClass.getSimpleName, StringType, nullable = false))),
+      nullable = false)))
+  }
+
   // TODO: improve error message for java bean encoder.
   def javaBean[T](beanClass: Class[T]): ExpressionEncoder[T] = {
-    val schema = JavaTypeInference.inferDataType(beanClass)._1
+    val schema = if (beanClass.isEnum) {
+      javaEnumSchema(beanClass)
+    } else {
+      JavaTypeInference.inferDataType(beanClass)._1
+    }
     assert(schema.isInstanceOf[StructType])
 
     val serializer = JavaTypeInference.serializerFor(beanClass)

http://git-wip-us.apache.org/repos/asf/spark/blob/7d16776d/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index 9b28a18..7c466fe 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -154,13 +154,13 @@ case class StaticInvoke(
     val evaluate = if (returnNullable) {
       if (ctx.defaultValue(dataType) == "null") {
         s"""
-          ${ev.value} = $callFunc;
+          ${ev.value} = (($javaType) ($callFunc));
           ${ev.isNull} = ${ev.value} == null;
         """
       } else {
         val boxedResult = ctx.freshName("boxedResult")
         s"""
-          ${ctx.boxedType(dataType)} $boxedResult = $callFunc;
+          ${ctx.boxedType(dataType)} $boxedResult = (($javaType) ($callFunc));
           ${ev.isNull} = $boxedResult == null;
           if (!${ev.isNull}) {
             ${ev.value} = $boxedResult;

http://git-wip-us.apache.org/repos/asf/spark/blob/7d16776d/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
----------------------------------------------------------------------
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index 4ca3b64..a344746 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -1283,6 +1283,83 @@ public class JavaDatasetSuite implements Serializable {
     ds.collectAsList();
   }
 
+  public enum EnumBean {
+    A("www.elgoog.com"),
+    B("www.google.com");
+
+    private String url;
+
+    EnumBean(String url) {
+      this.url = url;
+    }
+
+    public String getUrl() {
+      return url;
+    }
+
+    public void setUrl(String url) {
+      this.url = url;
+    }
+  }
+
+  @Test
+  public void testEnum() {
+    List<EnumBean> data = Arrays.asList(EnumBean.B);
+    Encoder<EnumBean> encoder = Encoders.bean(EnumBean.class);
+    Dataset<EnumBean> ds = spark.createDataset(data, encoder);
+    Assert.assertEquals(ds.collectAsList(), data);
+  }
+
+  public static class BeanWithEnum {
+    EnumBean enumField;
+    String regularField;
+
+    public String getRegularField() {
+      return regularField;
+    }
+
+    public void setRegularField(String regularField) {
+      this.regularField = regularField;
+    }
+
+    public EnumBean getEnumField() {
+      return enumField;
+    }
+
+    public void setEnumField(EnumBean field) {
+      this.enumField = field;
+    }
+
+    public BeanWithEnum(EnumBean enumField, String regularField) {
+      this.enumField = enumField;
+      this.regularField = regularField;
+    }
+
+    public BeanWithEnum() {
+    }
+
+    public String toString() {
+      return "BeanWithEnum(" + enumField  + ", " + regularField + ")";
+    }
+
+    public boolean equals(Object other) {
+      if (other instanceof BeanWithEnum) {
+        BeanWithEnum beanWithEnum = (BeanWithEnum) other;
+        return beanWithEnum.regularField.equals(regularField) && beanWithEnum.enumField.equals(enumField);
+      }
+      return false;
+    }
+  }
+
+  @Test
+  public void testBeanWithEnum() {
+    List<BeanWithEnum> data = Arrays.asList(new BeanWithEnum(EnumBean.A, "mira avenue"),
+            new BeanWithEnum(EnumBean.B, "flower boulevard"));
+    Encoder<BeanWithEnum> encoder = Encoders.bean(BeanWithEnum.class);
+    Dataset<BeanWithEnum> ds = spark.createDataset(data, encoder);
+    Assert.assertEquals(ds.collectAsList(), data);
+  }
+
   public static class EmptyBean implements Serializable {}
 
   @Test

