Repository: spark
Updated Branches:
  refs/heads/master 7b06a8948 -> 6959061f0


[SPARK-16706][SQL] support java map in encoder

## What changes were proposed in this pull request?

Finish the TODO: create a new expression, `ExternalMapToCatalyst`, that iterates the map directly.
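
As background, here is a minimal sketch of the kind of Java bean usage this change enables. It is not part of the patch; the `MapBean` class and the local `SparkSession` setup are illustrative assumptions.

```java
import java.io.Serializable;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.SparkSession;

public class MapBeanExample {
  // A Java bean with a map-typed property; before this patch the bean encoder rejected it with
  // UnsupportedOperationException("map type is not supported currently").
  public static class MapBean implements Serializable {
    private Map<Integer, String> tags;
    public Map<Integer, String> getTags() { return tags; }
    public void setTags(Map<Integer, String> tags) { this.tags = tags; }
  }

  public static void main(String[] args) {
    SparkSession spark = SparkSession.builder()
      .master("local[*]").appName("java-map-encoder").getOrCreate();

    MapBean bean = new MapBean();
    Map<Integer, String> tags = new HashMap<>();
    tags.put(1, "a");
    tags.put(2, "b");
    bean.setTags(tags);

    // The bean encoder now serializes the map field via the new ExternalMapToCatalyst expression.
    Dataset<MapBean> ds = spark.createDataset(Arrays.asList(bean), Encoders.bean(MapBean.class));
    ds.show();

    spark.stop();
  }
}
```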

## How was this patch tested?

new test in `JavaDatasetSuite`
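
Besides the `JavaDatasetSuite` round trip, the diff below also adds a null-map-key check covered by a new `ExpressionEncoderSuite` test. A hedged Java sketch of that behavior, reusing the imports and the hypothetical `MapBean` from the example above and assuming an existing `SparkSession` named `spark`:

```java
// Illustrative only: encoding a map with a null key is expected to fail, because the
// generated ExternalMapToCatalyst code throws RuntimeException("Cannot use null as map key!").
static void nullKeyExample(SparkSession spark) {
  Map<Integer, String> badTags = new HashMap<>();
  badTags.put(null, "oops");
  MapBean bad = new MapBean();
  bad.setTags(badTags);
  try {
    spark.createDataset(Arrays.asList(bad), Encoders.bean(MapBean.class)).collectAsList();
  } catch (RuntimeException e) {
    System.out.println(e.getMessage());  // expected to mention the null map key
  }
}
```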

Author: Wenchen Fan <wenc...@databricks.com>

Closes #14344 from cloud-fan/java-map.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/6959061f
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/6959061f
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/6959061f

Branch: refs/heads/master
Commit: 6959061f02b02afd4cef683b5eea0b7097eedee7
Parents: 7b06a89
Author: Wenchen Fan <wenc...@databricks.com>
Authored: Tue Jul 26 15:33:05 2016 +0800
Committer: Cheng Lian <l...@databricks.com>
Committed: Tue Jul 26 15:33:05 2016 +0800

----------------------------------------------------------------------
 .../spark/sql/catalyst/JavaTypeInference.scala  |  12 +-
 .../spark/sql/catalyst/ScalaReflection.scala    |  34 ++--
 .../catalyst/expressions/objects/objects.scala  | 158 ++++++++++++++++++-
 .../encoders/ExpressionEncoderSuite.scala       |   6 +
 .../org/apache/spark/sql/JavaDatasetSuite.java  |  58 ++++++-
 5 files changed, 236 insertions(+), 32 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/6959061f/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
index b3a233a..e6f61b0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
@@ -395,10 +395,14 @@ object JavaTypeInference {
           toCatalystArray(inputObject, elementType(typeToken))
 
         case _ if mapType.isAssignableFrom(typeToken) =>
-          // TODO: for java map, if we get the keys and values by `keySet` and `values`, we can
-          // not guarantee they have same iteration order(which is different from scala map).
-          // A possible solution is creating a new `MapObjects` that can iterate a map directly.
-          throw new UnsupportedOperationException("map type is not supported currently")
+          val (keyType, valueType) = mapKeyValueType(typeToken)
+          ExternalMapToCatalyst(
+            inputObject,
+            ObjectType(keyType.getRawType),
+            serializerFor(_, keyType),
+            ObjectType(valueType.getRawType),
+            serializerFor(_, valueType)
+          )
 
         case other =>
           val properties = getJavaBeanProperties(other)

http://git-wip-us.apache.org/repos/asf/spark/blob/6959061f/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 8affb03..76f87f6 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -472,29 +472,17 @@ object ScalaReflection extends ScalaReflection {
 
       case t if t <:< localTypeOf[Map[_, _]] =>
         val TypeRef(_, _, Seq(keyType, valueType)) = t
-
-        val keys =
-          Invoke(
-            Invoke(inputObject, "keysIterator",
-              ObjectType(classOf[scala.collection.Iterator[_]])),
-            "toSeq",
-            ObjectType(classOf[scala.collection.Seq[_]]))
-        val convertedKeys = toCatalystArray(keys, keyType)
-
-        val values =
-          Invoke(
-            Invoke(inputObject, "valuesIterator",
-              ObjectType(classOf[scala.collection.Iterator[_]])),
-            "toSeq",
-            ObjectType(classOf[scala.collection.Seq[_]]))
-        val convertedValues = toCatalystArray(values, valueType)
-
-        val Schema(keyDataType, _) = schemaFor(keyType)
-        val Schema(valueDataType, valueNullable) = schemaFor(valueType)
-        NewInstance(
-          classOf[ArrayBasedMapData],
-          convertedKeys :: convertedValues :: Nil,
-          dataType = MapType(keyDataType, valueDataType, valueNullable))
+        val keyClsName = getClassNameFromType(keyType)
+        val valueClsName = getClassNameFromType(valueType)
+        val keyPath = s"""- map key class: "$keyClsName"""" +: walkedTypePath
+        val valuePath = s"""- map value class: "$valueClsName"""" +: walkedTypePath
+
+        ExternalMapToCatalyst(
+          inputObject,
+          dataTypeFor(keyType),
+          serializerFor(_, keyType, keyPath),
+          dataTypeFor(valueType),
+          serializerFor(_, valueType, valuePath))
 
       case t if t <:< localTypeOf[String] =>
         StaticInvoke(

http://git-wip-us.apache.org/repos/asf/spark/blob/6959061f/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
index d6863ed..0658941 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/objects/objects.scala
@@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.encoders.RowEncoder
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
-import org.apache.spark.sql.catalyst.util.GenericArrayData
+import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, GenericArrayData}
 import org.apache.spark.sql.types._
 
 /**
@@ -501,6 +501,162 @@ case class MapObjects private(
   }
 }
 
+object ExternalMapToCatalyst {
+  private val curId = new java.util.concurrent.atomic.AtomicInteger()
+
+  def apply(
+      inputMap: Expression,
+      keyType: DataType,
+      keyConverter: Expression => Expression,
+      valueType: DataType,
+      valueConverter: Expression => Expression): ExternalMapToCatalyst = {
+    val id = curId.getAndIncrement()
+    val keyName = "ExternalMapToCatalyst_key" + id
+    val valueName = "ExternalMapToCatalyst_value" + id
+    val valueIsNull = "ExternalMapToCatalyst_value_isNull" + id
+
+    ExternalMapToCatalyst(
+      keyName,
+      keyType,
+      keyConverter(LambdaVariable(keyName, "false", keyType)),
+      valueName,
+      valueIsNull,
+      valueType,
+      valueConverter(LambdaVariable(valueName, valueIsNull, valueType)),
+      inputMap
+    )
+  }
+}
+
+/**
+ * Converts a Scala/Java map object into catalyst format, by applying the key/value converters
+ * while iterating over the map.
+ *
+ * @param key the name of the map key variable used while iterating over the map, and used as
+ *            input for the `keyConverter`
+ * @param keyType the data type of the map key variable used while iterating over the map, and
+ *                used as input for the `keyConverter`
+ * @param keyConverter a function that takes the `key` as input and converts it to catalyst format
+ * @param value the name of the map value variable used while iterating over the map, and used as
+ *              input for the `valueConverter`
+ * @param valueIsNull the nullability of the map value variable used while iterating over the map,
+ *                    and used as input for the `valueConverter`
+ * @param valueType the data type of the map value variable used while iterating over the map,
+ *                  and used as input for the `valueConverter`
+ * @param valueConverter a function that takes the `value` as input and converts it to catalyst
+ *                       format
+ * @param child an expression that, when evaluated, returns the input map object
+ */
+case class ExternalMapToCatalyst private(
+    key: String,
+    keyType: DataType,
+    keyConverter: Expression,
+    value: String,
+    valueIsNull: String,
+    valueType: DataType,
+    valueConverter: Expression,
+    child: Expression)
+  extends UnaryExpression with NonSQLExpression {
+
+  override def foldable: Boolean = false
+
+  override def dataType: MapType = MapType(keyConverter.dataType, valueConverter.dataType)
+
+  override def eval(input: InternalRow): Any =
+    throw new UnsupportedOperationException("Only code-generated evaluation is supported")
+
+  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    val inputMap = child.genCode(ctx)
+    val genKeyConverter = keyConverter.genCode(ctx)
+    val genValueConverter = valueConverter.genCode(ctx)
+    val length = ctx.freshName("length")
+    val index = ctx.freshName("index")
+    val convertedKeys = ctx.freshName("convertedKeys")
+    val convertedValues = ctx.freshName("convertedValues")
+    val entry = ctx.freshName("entry")
+    val entries = ctx.freshName("entries")
+
+    val (defineEntries, defineKeyValue) = child.dataType match {
+      case ObjectType(cls) if classOf[java.util.Map[_, _]].isAssignableFrom(cls) =>
+        val javaIteratorCls = classOf[java.util.Iterator[_]].getName
+        val javaMapEntryCls = classOf[java.util.Map.Entry[_, _]].getName
+
+        val defineEntries =
+          s"final $javaIteratorCls $entries = ${inputMap.value}.entrySet().iterator();"
+
+        val defineKeyValue =
+          s"""
+            final $javaMapEntryCls $entry = ($javaMapEntryCls) $entries.next();
+            ${ctx.javaType(keyType)} $key = (${ctx.boxedType(keyType)}) $entry.getKey();
+            ${ctx.javaType(valueType)} $value = (${ctx.boxedType(valueType)}) $entry.getValue();
+          """
+
+        defineEntries -> defineKeyValue
+
+      case ObjectType(cls) if classOf[scala.collection.Map[_, _]].isAssignableFrom(cls) =>
+        val scalaIteratorCls = classOf[Iterator[_]].getName
+        val scalaMapEntryCls = classOf[Tuple2[_, _]].getName
+
+        val defineEntries = s"final $scalaIteratorCls $entries = ${inputMap.value}.iterator();"
+
+        val defineKeyValue =
+          s"""
+            final $scalaMapEntryCls $entry = ($scalaMapEntryCls) $entries.next();
+            ${ctx.javaType(keyType)} $key = (${ctx.boxedType(keyType)}) $entry._1();
+            ${ctx.javaType(valueType)} $value = (${ctx.boxedType(valueType)}) $entry._2();
+          """
+
+        defineEntries -> defineKeyValue
+    }
+
+    val valueNullCheck = if (ctx.isPrimitiveType(valueType)) {
+      s"boolean $valueIsNull = false;"
+    } else {
+      s"boolean $valueIsNull = $value == null;"
+    }
+
+    val arrayCls = classOf[GenericArrayData].getName
+    val mapCls = classOf[ArrayBasedMapData].getName
+    val convertedKeyType = ctx.boxedType(keyConverter.dataType)
+    val convertedValueType = ctx.boxedType(valueConverter.dataType)
+    val code =
+      s"""
+        ${inputMap.code}
+        ${ctx.javaType(dataType)} ${ev.value} = ${ctx.defaultValue(dataType)};
+        if (!${inputMap.isNull}) {
+          final int $length = ${inputMap.value}.size();
+          final Object[] $convertedKeys = new Object[$length];
+          final Object[] $convertedValues = new Object[$length];
+          int $index = 0;
+          $defineEntries
+          while($entries.hasNext()) {
+            $defineKeyValue
+            $valueNullCheck
+
+            ${genKeyConverter.code}
+            if (${genKeyConverter.isNull}) {
+              throw new RuntimeException("Cannot use null as map key!");
+            } else {
+              $convertedKeys[$index] = ($convertedKeyType) ${genKeyConverter.value};
+            }
+
+            ${genValueConverter.code}
+            if (${genValueConverter.isNull}) {
+              $convertedValues[$index] = null;
+            } else {
+              $convertedValues[$index] = ($convertedValueType) ${genValueConverter.value};
+            }
+
+            $index++;
+          }
+
+          ${ev.value} = new $mapCls(new $arrayCls($convertedKeys), new $arrayCls($convertedValues));
+        }
+      """
+    ev.copy(code = code, isNull = inputMap.isNull)
+  }
+}
+
 /**
  * Constructs a new external row, using the result of evaluating the specified expressions
  * as content.

http://git-wip-us.apache.org/repos/asf/spark/blob/6959061f/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
index a1f9259..4df9062 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoderSuite.scala
@@ -328,6 +328,12 @@ class ExpressionEncoderSuite extends PlanTest with AnalysisTest {
     }
   }
 
+  test("null check for map key") {
+    val encoder = ExpressionEncoder[Map[String, Int]]()
+    val e = intercept[RuntimeException](encoder.toRow(Map(("a", 1), (null, 2))))
+    assert(e.getMessage.contains("Cannot use null as map key"))
+  }
+
   private def encodeDecodeTest[T : ExpressionEncoder](
       input: T,
       testName: String): Unit = {

http://git-wip-us.apache.org/repos/asf/spark/blob/6959061f/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
----------------------------------------------------------------------
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index a711811..96e8fb06 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -497,6 +497,8 @@ public class JavaDatasetSuite implements Serializable {
     private String[] d;
     private List<String> e;
     private List<Long> f;
+    private Map<Integer, String> g;
+    private Map<List<Long>, Map<String, String>> h;
 
     public boolean isA() {
       return a;
@@ -546,6 +548,22 @@ public class JavaDatasetSuite implements Serializable {
       this.f = f;
     }
 
+    public Map<Integer, String> getG() {
+      return g;
+    }
+
+    public void setG(Map<Integer, String> g) {
+      this.g = g;
+    }
+
+    public Map<List<Long>, Map<String, String>> getH() {
+      return h;
+    }
+
+    public void setH(Map<List<Long>, Map<String, String>> h) {
+      this.h = h;
+    }
+
     @Override
     public boolean equals(Object o) {
       if (this == o) return true;
@@ -558,7 +576,10 @@ public class JavaDatasetSuite implements Serializable {
       if (!Arrays.equals(c, that.c)) return false;
       if (!Arrays.equals(d, that.d)) return false;
       if (!e.equals(that.e)) return false;
-      return f.equals(that.f);
+      if (!f.equals(that.f)) return false;
+      if (!g.equals(that.g)) return false;
+      return h.equals(that.h);
+
     }
 
     @Override
@@ -569,6 +590,8 @@ public class JavaDatasetSuite implements Serializable {
       result = 31 * result + Arrays.hashCode(d);
       result = 31 * result + e.hashCode();
       result = 31 * result + f.hashCode();
+      result = 31 * result + g.hashCode();
+      result = 31 * result + h.hashCode();
       return result;
     }
   }
@@ -648,6 +671,17 @@ public class JavaDatasetSuite implements Serializable {
     obj1.setD(new String[]{"hello", null});
     obj1.setE(Arrays.asList("a", "b"));
     obj1.setF(Arrays.asList(100L, null, 200L));
+    Map<Integer, String> map1 = new HashMap<Integer, String>();
+    map1.put(1, "a");
+    map1.put(2, "b");
+    obj1.setG(map1);
+    Map<String, String> nestedMap1 = new HashMap<String, String>();
+    nestedMap1.put("x", "1");
+    nestedMap1.put("y", "2");
+    Map<List<Long>, Map<String, String>> complexMap1 = new HashMap<>();
+    complexMap1.put(Arrays.asList(1L, 2L), nestedMap1);
+    obj1.setH(complexMap1);
+
     SimpleJavaBean obj2 = new SimpleJavaBean();
     obj2.setA(false);
     obj2.setB(30);
@@ -655,6 +689,16 @@ public class JavaDatasetSuite implements Serializable {
     obj2.setD(new String[]{null, "world"});
     obj2.setE(Arrays.asList("x", "y"));
     obj2.setF(Arrays.asList(300L, null, 400L));
+    Map<Integer, String> map2 = new HashMap<Integer, String>();
+    map2.put(3, "c");
+    map2.put(4, "d");
+    obj2.setG(map2);
+    Map<String, String> nestedMap2 = new HashMap<String, String>();
+    nestedMap2.put("q", "1");
+    nestedMap2.put("w", "2");
+    Map<List<Long>, Map<String, String>> complexMap2 = new HashMap<>();
+    complexMap2.put(Arrays.asList(3L, 4L), nestedMap2);
+    obj2.setH(complexMap2);
 
     List<SimpleJavaBean> data = Arrays.asList(obj1, obj2);
     Dataset<SimpleJavaBean> ds = spark.createDataset(data, Encoders.bean(SimpleJavaBean.class));
@@ -673,21 +717,27 @@ public class JavaDatasetSuite implements Serializable {
       new byte[]{1, 2},
       new String[]{"hello", null},
       Arrays.asList("a", "b"),
-      Arrays.asList(100L, null, 200L)});
+      Arrays.asList(100L, null, 200L),
+      map1,
+      complexMap1});
     Row row2 = new GenericRow(new Object[]{
       false,
       30,
       new byte[]{3, 4},
       new String[]{null, "world"},
       Arrays.asList("x", "y"),
-      Arrays.asList(300L, null, 400L)});
+      Arrays.asList(300L, null, 400L),
+      map2,
+      complexMap2});
     StructType schema = new StructType()
       .add("a", BooleanType, false)
       .add("b", IntegerType, false)
       .add("c", BinaryType)
       .add("d", createArrayType(StringType))
       .add("e", createArrayType(StringType))
-      .add("f", createArrayType(LongType));
+      .add("f", createArrayType(LongType))
+      .add("g", createMapType(IntegerType, StringType))
+      .add("h", createMapType(createArrayType(LongType), createMapType(StringType, StringType)));
     Dataset<SimpleJavaBean> ds3 = spark.createDataFrame(Arrays.asList(row1, row2), schema)
       .as(Encoders.bean(SimpleJavaBean.class));
     Assert.assertEquals(data, ds3.collectAsList());

