Repository: spark
Updated Branches:
  refs/heads/master 3bdbbc6c9 -> 5a5f65905


http://git-wip-us.apache.org/repos/asf/spark/blob/5a5f6590/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 0849624..aebb390 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -34,6 +34,13 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       data: _*)
   }
 
+  test("as tuple") {
+    val data = Seq(("a", 1), ("b", 2)).toDF("a", "b")
+    checkAnswer(
+      data.as[(String, Int)],
+      ("a", 1), ("b", 2))
+  }
+
   test("as case class / collect") {
     val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDF("a", "b").as[ClassData]
     checkAnswer(
@@ -61,14 +68,40 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       2, 3, 4)
   }
 
-  test("select 3") {
+  test("select 2") {
     val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS()
     checkAnswer(
       ds.select(
         expr("_1").as[String],
-        expr("_2").as[Int],
-        expr("_2 + 1").as[Int]),
-      ("a", 1, 2), ("b", 2, 3), ("c", 3, 4))
+        expr("_2").as[Int]) : Dataset[(String, Int)],
+      ("a", 1), ("b", 2), ("c", 3))
+  }
+
+  test("select 2, primitive and tuple") {
+    val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS()
+    checkAnswer(
+      ds.select(
+        expr("_1").as[String],
+        expr("struct(_2, _2)").as[(Int, Int)]),
+      ("a", (1, 1)), ("b", (2, 2)), ("c", (3, 3)))
+  }
+
+  test("select 2, primitive and class") {
+    val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS()
+    checkAnswer(
+      ds.select(
+        expr("_1").as[String],
+        expr("named_struct('a', _1, 'b', _2)").as[ClassData]),
+      ("a", ClassData("a", 1)), ("b", ClassData("b", 2)), ("c", ClassData("c", 
3)))
+  }
+
+  test("select 2, primitive and class, fields reordered") {
+    val ds = Seq(("a", 1) , ("b", 2), ("c", 3)).toDS()
+    checkDecoding(
+      ds.select(
+        expr("_1").as[String],
+        expr("named_struct('b', _2, 'a', _1)").as[ClassData]),
+      ("a", ClassData("a", 1)), ("b", ClassData("b", 2)), ("c", ClassData("c", 
3)))
   }
 
   test("filter") {
@@ -102,6 +135,54 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
     assert(ds.fold(("", 0))((a, b) => ("sum", a._2 + b._2)) == ("sum", 6))
   }
 
+  test("joinWith, flat schema") {
+    val ds1 = Seq(1, 2, 3).toDS().as("a")
+    val ds2 = Seq(1, 2).toDS().as("b")
+
+    checkAnswer(
+      ds1.joinWith(ds2, $"a.value" === $"b.value"),
+      (1, 1), (2, 2))
+  }
+
+  test("joinWith, expression condition") {
+    val ds1 = Seq(ClassData("a", 1), ClassData("b", 2)).toDS()
+    val ds2 = Seq(("a", 1), ("b", 2)).toDS()
+
+    checkAnswer(
+      ds1.joinWith(ds2, $"_1" === $"a"),
+      (ClassData("a", 1), ("a", 1)), (ClassData("b", 2), ("b", 2)))
+  }
+
+  test("joinWith tuple with primitive, expression") {
+    val ds1 = Seq(1, 1, 2).toDS()
+    val ds2 = Seq(("a", 1), ("b", 2)).toDS()
+
+    checkAnswer(
+      ds1.joinWith(ds2, $"value" === $"_2"),
+      (1, ("a", 1)), (1, ("a", 1)), (2, ("b", 2)))
+  }
+
+  test("joinWith class with primitive, toDF") {
+    val ds1 = Seq(1, 1, 2).toDS()
+    val ds2 = Seq(ClassData("a", 1), ClassData("b", 2)).toDS()
+
+    checkAnswer(
+      ds1.joinWith(ds2, $"value" === $"b").toDF().select($"_1", $"_2.a", 
$"_2.b"),
+      Row(1, "a", 1) :: Row(1, "a", 1) :: Row(2, "b", 2) :: Nil)
+  }
+
+  test("multi-level joinWith") {
+    val ds1 = Seq(("a", 1), ("b", 2)).toDS().as("a")
+    val ds2 = Seq(("a", 1), ("b", 2)).toDS().as("b")
+    val ds3 = Seq(("a", 1), ("b", 2)).toDS().as("c")
+
+    checkAnswer(
+      ds1.joinWith(ds2, $"a._2" === $"b._2").as("ab").joinWith(ds3, 
$"ab._1._2" === $"c._2"),
+      ((("a", 1), ("a", 1)), ("a", 1)),
+      ((("b", 2), ("b", 2)), ("b", 2)))
+
+  }
+
   test("groupBy function, keys") {
     val ds = Seq(("a", 1), ("b", 1)).toDS()
     val grouped = ds.groupBy(v => (1, v._2))

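A note on the API exercised above: `joinWith` differs from a plain `join` in that it keeps both sides as typed objects, producing a `Dataset[(L, R)]` rather than a flattened `DataFrame`. The following is a minimal sketch of the same pattern outside the test harness, assuming a Spark shell built from this branch with `sqlContext.implicits._` imported; the `left`/`right` names are illustrative only and not part of the commit:

    // Sketch only; assumes the implicits behind .toDS() and $"..." are in scope.
    case class ClassData(a: String, b: Int)

    val left  = Seq(ClassData("a", 1), ClassData("b", 2)).toDS()
    val right = Seq(("a", 10), ("b", 20)).toDS()

    // joinWith yields a typed Dataset[(ClassData, (String, Int))],
    // so each matched pair comes back as real objects, not Rows.
    val joined = left.joinWith(right, $"a" === $"_1")
    joined.collect().foreach { case (l, r) => println(s"$l matched $r") }
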
http://git-wip-us.apache.org/repos/asf/spark/blob/5a5f6590/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index aba5675..73e02eb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -20,12 +20,11 @@ package org.apache.spark.sql
 import java.util.{Locale, TimeZone}
 
 import scala.collection.JavaConverters._
-import scala.reflect.runtime.universe._
 
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.sql.columnar.InMemoryRelation
-import org.apache.spark.sql.catalyst.encoders.{ProductEncoder, Encoder}
+import org.apache.spark.sql.catalyst.encoders.Encoder
 
 abstract class QueryTest extends PlanTest {
 
@@ -55,10 +54,49 @@ abstract class QueryTest extends PlanTest {
     }
   }
 
-  protected def checkAnswer[T : Encoder](ds: => Dataset[T], expectedAnswer: T*): Unit = {
+  /**
+   * Evaluates a dataset to make sure that the result of calling collect matches the given
+   * expected answer.
+   *  - Special handling is done based on whether the query plan should be expected to return
+   *    the results in sorted order.
+   *  - This function also checks to make sure that the schema for serializing the expected
+   *    answer matches that produced by the dataset (i.e. does manual construction of objects
+   *    match the constructed encoder for cases like joins, etc).  Note that this means it will
+   *    fail for cases where reordering is done on fields.  For such tests, use `checkDecoding`
+   *    instead, which performs a subset of the checks done by this function.
+   */
+  protected def checkAnswer[T : Encoder](
+      ds: => Dataset[T],
+      expectedAnswer: T*): Unit = {
     checkAnswer(
       ds.toDF(),
       sqlContext.createDataset(expectedAnswer).toDF().collect().toSeq)
+
+    checkDecoding(ds, expectedAnswer: _*)
+  }
+
+  protected def checkDecoding[T](
+      ds: => Dataset[T],
+      expectedAnswer: T*): Unit = {
+    val decoded = try ds.collect().toSet catch {
+      case e: Exception =>
+        fail(
+          s"""
+             |Exception collecting dataset as objects
+             |${ds.encoder}
+             |${ds.encoder.constructExpression.treeString}
+             |${ds.queryExecution}
+           """.stripMargin, e)
+    }
+
+    if (decoded != expectedAnswer.toSet) {
+      fail(
+        s"""Decoded objects do not match expected objects:
+           |Expected: ${expectedAnswer.toSet.toSeq.map((a: Any) => a.toString).sorted}
+           |Actual: ${decoded.toSeq.map((a: Any) => a.toString).sorted}
+           |${ds.encoder.constructExpression.treeString}
+         """.stripMargin)
+    }
   }
 
   /**


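The relationship between the two helpers is easiest to see against the "fields reordered" test in DatasetSuite.scala above: `checkAnswer` additionally round-trips the expected answer through `sqlContext.createDataset` and compares its serialized form with the dataset's, so it would be expected to fail when an encoder reorders fields, while `checkDecoding` compares only the collected, decoded objects. A hedged sketch reusing the suite's own data:

    // Sketch, reusing names from the diff above. checkAnswer would be expected
    // to reject this select because named_struct('b', ..., 'a', ...) reorders
    // ClassData's fields relative to the expected answer's encoder, while
    // checkDecoding compares only the decoded objects and passes.
    val ds = Seq(("a", 1), ("b", 2), ("c", 3)).toDS()
    checkDecoding(
      ds.select(expr("named_struct('b', _2, 'a', _1)").as[ClassData]),
      ClassData("a", 1), ClassData("b", 2), ClassData("c", 3))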