Repository: spark
Updated Branches:
  refs/heads/branch-1.6 38fe092ff -> fbe65c592


[SPARK-11754][SQL] consolidate `ExpressionEncoder.tuple` and `Encoders.tuple`

These two methods are very similar, so we can consolidate them into one.

This also adds tests for the consolidated `ExpressionEncoder.tuple` and fixes a bug.

Author: Wenchen Fan <wenc...@databricks.com>

Closes #9729 from cloud-fan/tuple.

(cherry picked from commit b1a9662623951079e80bd7498e064c4cae4977e9)
Signed-off-by: Michael Armbrust <mich...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/fbe65c59
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/fbe65c59
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/fbe65c59

Branch: refs/heads/branch-1.6
Commit: fbe65c5924d2f5f4789bf54a1da0a7b6bbf1eb42
Parents: 38fe092
Author: Wenchen Fan <wenc...@databricks.com>
Authored: Mon Nov 16 12:45:34 2015 -0800
Committer: Michael Armbrust <mich...@databricks.com>
Committed: Mon Nov 16 12:46:26 2015 -0800

----------------------------------------------------------------------
 .../scala/org/apache/spark/sql/Encoder.scala    |  95 ++++-------------
 .../catalyst/encoders/ExpressionEncoder.scala   | 104 ++++++++++---------
 .../catalyst/encoders/ProductEncoderSuite.scala |  29 ++++++
 3 files changed, 108 insertions(+), 120 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/fbe65c59/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
index 5f619d6..c8b017e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/Encoder.scala
@@ -19,10 +19,8 @@ package org.apache.spark.sql
 
 import scala.reflect.ClassTag
 
-import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.types.{ObjectType, StructField, StructType}
-import org.apache.spark.util.Utils
+import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor}
+import org.apache.spark.sql.types.StructType
 
 /**
  * Used to convert a JVM object of type `T` to and from the internal Spark SQL 
representation.
@@ -49,83 +47,34 @@ object Encoders {
   def DOUBLE: Encoder[java.lang.Double] = ExpressionEncoder(flat = true)
   def STRING: Encoder[java.lang.String] = ExpressionEncoder(flat = true)
 
-  def tuple[T1, T2](enc1: Encoder[T1], enc2: Encoder[T2]): Encoder[(T1, T2)] = 
{
-    tuple(Seq(enc1, enc2).map(_.asInstanceOf[ExpressionEncoder[_]]))
-      .asInstanceOf[ExpressionEncoder[(T1, T2)]]
+  def tuple[T1, T2](
+      e1: Encoder[T1],
+      e2: Encoder[T2]): Encoder[(T1, T2)] = {
+    ExpressionEncoder.tuple(encoderFor(e1), encoderFor(e2))
   }
 
   def tuple[T1, T2, T3](
-      enc1: Encoder[T1],
-      enc2: Encoder[T2],
-      enc3: Encoder[T3]): Encoder[(T1, T2, T3)] = {
-    tuple(Seq(enc1, enc2, enc3).map(_.asInstanceOf[ExpressionEncoder[_]]))
-      .asInstanceOf[ExpressionEncoder[(T1, T2, T3)]]
+      e1: Encoder[T1],
+      e2: Encoder[T2],
+      e3: Encoder[T3]): Encoder[(T1, T2, T3)] = {
+    ExpressionEncoder.tuple(encoderFor(e1), encoderFor(e2), encoderFor(e3))
   }
 
   def tuple[T1, T2, T3, T4](
-      enc1: Encoder[T1],
-      enc2: Encoder[T2],
-      enc3: Encoder[T3],
-      enc4: Encoder[T4]): Encoder[(T1, T2, T3, T4)] = {
-    tuple(Seq(enc1, enc2, enc3, 
enc4).map(_.asInstanceOf[ExpressionEncoder[_]]))
-      .asInstanceOf[ExpressionEncoder[(T1, T2, T3, T4)]]
+      e1: Encoder[T1],
+      e2: Encoder[T2],
+      e3: Encoder[T3],
+      e4: Encoder[T4]): Encoder[(T1, T2, T3, T4)] = {
+    ExpressionEncoder.tuple(encoderFor(e1), encoderFor(e2), encoderFor(e3), 
encoderFor(e4))
   }
 
   def tuple[T1, T2, T3, T4, T5](
-      enc1: Encoder[T1],
-      enc2: Encoder[T2],
-      enc3: Encoder[T3],
-      enc4: Encoder[T4],
-      enc5: Encoder[T5]): Encoder[(T1, T2, T3, T4, T5)] = {
-    tuple(Seq(enc1, enc2, enc3, enc4, 
enc5).map(_.asInstanceOf[ExpressionEncoder[_]]))
-      .asInstanceOf[ExpressionEncoder[(T1, T2, T3, T4, T5)]]
-  }
-
-  private def tuple(encoders: Seq[ExpressionEncoder[_]]): ExpressionEncoder[_] 
= {
-    assert(encoders.length > 1)
-    // make sure all encoders are resolved, i.e. `Attribute` has been resolved 
to `BoundReference`.
-    
assert(encoders.forall(_.fromRowExpression.find(_.isInstanceOf[Attribute]).isEmpty))
-
-    val schema = StructType(encoders.zipWithIndex.map {
-      case (e, i) => StructField(s"_${i + 1}", if (e.flat) 
e.schema.head.dataType else e.schema)
-    })
-
-    val cls = 
Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}")
-
-    val extractExpressions = encoders.map {
-      case e if e.flat => e.toRowExpressions.head
-      case other => CreateStruct(other.toRowExpressions)
-    }.zipWithIndex.map { case (expr, index) =>
-      expr.transformUp {
-        case BoundReference(0, t: ObjectType, _) =>
-          Invoke(
-            BoundReference(0, ObjectType(cls), nullable = true),
-            s"_${index + 1}",
-            t)
-      }
-    }
-
-    val constructExpressions = encoders.zipWithIndex.map { case (enc, index) =>
-      if (enc.flat) {
-        enc.fromRowExpression.transform {
-          case b: BoundReference => b.copy(ordinal = index)
-        }
-      } else {
-        enc.fromRowExpression.transformUp {
-          case BoundReference(ordinal, dt, _) =>
-            GetInternalRowField(BoundReference(index, enc.schema, nullable = 
true), ordinal, dt)
-        }
-      }
-    }
-
-    val constructExpression =
-      NewInstance(cls, constructExpressions, propagateNull = false, 
ObjectType(cls))
-
-    new ExpressionEncoder[Any](
-      schema,
-      flat = false,
-      extractExpressions,
-      constructExpression,
-      ClassTag(cls))
+      e1: Encoder[T1],
+      e2: Encoder[T2],
+      e3: Encoder[T3],
+      e4: Encoder[T4],
+      e5: Encoder[T5]): Encoder[(T1, T2, T3, T4, T5)] = {
+    ExpressionEncoder.tuple(
+      encoderFor(e1), encoderFor(e2), encoderFor(e3), encoderFor(e4), 
encoderFor(e5))
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/fbe65c59/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
index 0d3e4aa..9a1a8f5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala
@@ -67,47 +67,77 @@ object ExpressionEncoder {
   def tuple(encoders: Seq[ExpressionEncoder[_]]): ExpressionEncoder[_] = {
     encoders.foreach(_.assertUnresolved())
 
-    val schema =
-      StructType(
-        encoders.zipWithIndex.map {
-          case (e, i) => StructField(s"_${i + 1}", if (e.flat) 
e.schema.head.dataType else e.schema)
-        })
+    val schema = StructType(encoders.zipWithIndex.map {
+      case (e, i) => StructField(s"_${i + 1}", if (e.flat) 
e.schema.head.dataType else e.schema)
+    })
+
     val cls = 
Utils.getContextOrSparkClassLoader.loadClass(s"scala.Tuple${encoders.size}")
 
-    // Rebind the encoders to the nested schema.
-    val newConstructExpressions = encoders.zipWithIndex.map {
-      case (e, i) if !e.flat => e.nested(i).fromRowExpression
-      case (e, i) => e.shift(i).fromRowExpression
+    val toRowExpressions = encoders.map {
+      case e if e.flat => e.toRowExpressions.head
+      case other => CreateStruct(other.toRowExpressions)
+    }.zipWithIndex.map { case (expr, index) =>
+      expr.transformUp {
+        case BoundReference(0, t, _) =>
+          Invoke(
+            BoundReference(0, ObjectType(cls), nullable = true),
+            s"_${index + 1}",
+            t)
+      }
     }
 
-    val constructExpression =
-      NewInstance(cls, newConstructExpressions, false, ObjectType(cls))
-
-    val input = BoundReference(0, ObjectType(cls), false)
-    val extractExpressions = encoders.zipWithIndex.map {
-      case (e, i) if !e.flat => CreateStruct(e.toRowExpressions.map(_ 
transformUp {
-        case b: BoundReference =>
-          Invoke(input, s"_${i + 1}", b.dataType, Nil)
-      }))
-      case (e, i) => e.toRowExpressions.head transformUp {
-        case b: BoundReference =>
-          Invoke(input, s"_${i + 1}", b.dataType, Nil)
+    val fromRowExpressions = encoders.zipWithIndex.map { case (enc, index) =>
+      if (enc.flat) {
+        enc.fromRowExpression.transform {
+          case b: BoundReference => b.copy(ordinal = index)
+        }
+      } else {
+        val input = BoundReference(index, enc.schema, nullable = true)
+        enc.fromRowExpression.transformUp {
+          case UnresolvedAttribute(nameParts) =>
+            assert(nameParts.length == 1)
+            UnresolvedExtractValue(input, Literal(nameParts.head))
+          case BoundReference(ordinal, dt, _) => GetInternalRowField(input, 
ordinal, dt)
+        }
       }
     }
 
+    val fromRowExpression =
+      NewInstance(cls, fromRowExpressions, propagateNull = false, 
ObjectType(cls))
+
     new ExpressionEncoder[Any](
       schema,
-      false,
-      extractExpressions,
-      constructExpression,
-      ClassTag.apply(cls))
+      flat = false,
+      toRowExpressions,
+      fromRowExpression,
+      ClassTag(cls))
   }
 
-  /** A helper for producing encoders of Tuple2 from other encoders. */
   def tuple[T1, T2](
       e1: ExpressionEncoder[T1],
       e2: ExpressionEncoder[T2]): ExpressionEncoder[(T1, T2)] =
-    tuple(e1 :: e2 :: Nil).asInstanceOf[ExpressionEncoder[(T1, T2)]]
+    tuple(Seq(e1, e2)).asInstanceOf[ExpressionEncoder[(T1, T2)]]
+
+  def tuple[T1, T2, T3](
+      e1: ExpressionEncoder[T1],
+      e2: ExpressionEncoder[T2],
+      e3: ExpressionEncoder[T3]): ExpressionEncoder[(T1, T2, T3)] =
+    tuple(Seq(e1, e2, e3)).asInstanceOf[ExpressionEncoder[(T1, T2, T3)]]
+
+  def tuple[T1, T2, T3, T4](
+      e1: ExpressionEncoder[T1],
+      e2: ExpressionEncoder[T2],
+      e3: ExpressionEncoder[T3],
+      e4: ExpressionEncoder[T4]): ExpressionEncoder[(T1, T2, T3, T4)] =
+    tuple(Seq(e1, e2, e3, e4)).asInstanceOf[ExpressionEncoder[(T1, T2, T3, 
T4)]]
+
+  def tuple[T1, T2, T3, T4, T5](
+      e1: ExpressionEncoder[T1],
+      e2: ExpressionEncoder[T2],
+      e3: ExpressionEncoder[T3],
+      e4: ExpressionEncoder[T4],
+      e5: ExpressionEncoder[T5]): ExpressionEncoder[(T1, T2, T3, T4, T5)] =
+    tuple(Seq(e1, e2, e3, e4, e5)).asInstanceOf[ExpressionEncoder[(T1, T2, T3, 
T4, T5)]]
 }
 
 /**
@@ -208,26 +238,6 @@ case class ExpressionEncoder[T](
     })
   }
 
-  /**
-   * Returns a copy of this encoder where the expressions used to create an 
object given an
-   * input row have been modified to pull the object out from a nested struct, 
instead of the
-   * top level fields.
-   */
-  private def nested(i: Int): ExpressionEncoder[T] = {
-    // We don't always know our input type at this point since it might be 
unresolved.
-    // We fill in null and it will get unbound to the actual attribute at this 
position.
-    val input = BoundReference(i, NullType, nullable = true)
-    copy(fromRowExpression = fromRowExpression transformUp {
-      case u: Attribute =>
-        UnresolvedExtractValue(input, Literal(u.name))
-      case b: BoundReference =>
-        GetStructField(
-          input,
-          StructField(s"i[${b.ordinal}]", b.dataType),
-          b.ordinal)
-    })
-  }
-
   protected val attrs = toRowExpressions.flatMap(_.collect {
     case _: UnresolvedAttribute => ""
     case a: Attribute => s"#${a.exprId}"

http://git-wip-us.apache.org/repos/asf/spark/blob/fbe65c59/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala
index fda978e..bc539d6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/encoders/ProductEncoderSuite.scala
@@ -117,6 +117,35 @@ class ProductEncoderSuite extends ExpressionEncoderSuite {
   productTest(("Seq[Seq[(Int, Int)]]",
     Seq(Seq((1, 2)))))
 
+  encodeDecodeTest(
+    1 -> 10L,
+    ExpressionEncoder.tuple(FlatEncoder[Int], FlatEncoder[Long]),
+    "tuple with 2 flat encoders")
+
+  encodeDecodeTest(
+    (PrimitiveData(1, 1, 1, 1, 1, 1, true), (3, 30L)),
+    ExpressionEncoder.tuple(ProductEncoder[PrimitiveData], 
ProductEncoder[(Int, Long)]),
+    "tuple with 2 product encoders")
+
+  encodeDecodeTest(
+    (PrimitiveData(1, 1, 1, 1, 1, 1, true), 3),
+    ExpressionEncoder.tuple(ProductEncoder[PrimitiveData], FlatEncoder[Int]),
+    "tuple with flat encoder and product encoder")
+
+  encodeDecodeTest(
+    (3, PrimitiveData(1, 1, 1, 1, 1, 1, true)),
+    ExpressionEncoder.tuple(FlatEncoder[Int], ProductEncoder[PrimitiveData]),
+    "tuple with product encoder and flat encoder")
+
+  encodeDecodeTest(
+    (1, (10, 100L)),
+    {
+      val intEnc = FlatEncoder[Int]
+      val longEnc = FlatEncoder[Long]
+      ExpressionEncoder.tuple(intEnc, ExpressionEncoder.tuple(intEnc, longEnc))
+    },
+    "nested tuple encoder")
+
   private def productTest[T <: Product : TypeTag](input: T): Unit = {
     encodeDecodeTest(input, ProductEncoder[T], input.getClass.getSimpleName)
   }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to