Repository: spark
Updated Branches:
  refs/heads/master 060a28c63 -> 43ebf7a9c


[SPARK-13456][SQL] fix creating encoders for case classes defined in Spark shell

## What changes were proposed in this pull request?

Case classes defined in the REPL are wrapped by line classes, and we have a 
trick for the Scala 2.10 REPL that automatically registers these wrapper 
classes with `OuterScopes`, so that we can use them when creating encoders. 
However, this trick stopped working after the upgrade to Scala 2.11, and 
unfortunately the tests exist only for Scala 2.10, which kept this bug hidden 
until now.
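
For context, the registration this trick performs boils down to a call like the
following (a minimal sketch; `LineWrapper` and `Point` are hypothetical stand-ins
for the REPL-generated classes):

```scala
import org.apache.spark.sql.catalyst.encoders.OuterScopes

class LineWrapper {        // stand-in for a REPL-generated line wrapper class
  case class Point(x: Int) // inner case class; its encoder needs an outer pointer

  // Register this instance so deserializers can instantiate Point on executors.
  OuterScopes.addOuterScope(this)
}
```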

This PR moves the encoder tests into the Scala 2.11 `ReplSuite` and fixes the 
bug with a different approach (the previous trick cannot be ported to the 
Scala 2.11 REPL): make `OuterScopes` smart enough to detect classes defined 
in the REPL and load the singleton instances of the line wrapper classes 
automatically.
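
Concretely, after this change the following works again in the Scala 2.11 shell
with no manual registration (adapted from the migrated test):

```scala
// inside spark-shell
case class TestCaseClass(value: Int)
Seq(TestCaseClass(1)).toDS().collect()
```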

## How was this patch tested?

The migrated encoder tests in `ReplSuite`.

Author: Wenchen Fan <wenc...@databricks.com>

Closes #11410 from cloud-fan/repl.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/43ebf7a9
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/43ebf7a9
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/43ebf7a9

Branch: refs/heads/master
Commit: 43ebf7a9cbd70d6af75e140a6fc91bf0ffc2b877
Parents: 060a28c
Author: Wenchen Fan <wenc...@databricks.com>
Authored: Mon Mar 21 10:37:24 2016 -0700
Committer: Yin Huai <yh...@databricks.com>
Committed: Mon Mar 21 10:37:24 2016 -0700

----------------------------------------------------------------------
 .../scala/org/apache/spark/repl/ReplSuite.scala |  4 +-
 .../scala/org/apache/spark/repl/ReplSuite.scala | 68 +++++++++++++++++++-
 .../spark/sql/catalyst/analysis/Analyzer.scala  |  2 +-
 .../sql/catalyst/encoders/OuterScopes.scala     | 47 +++++++++++++-
 4 files changed, 115 insertions(+), 6 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/43ebf7a9/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala
----------------------------------------------------------------------
diff --git a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala
index cbcccb1..6b9aa50 100644
--- a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala
+++ b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala
@@ -288,7 +288,7 @@ class ReplSuite extends SparkFunSuite {
         |import org.apache.spark.sql.Encoder
         |import org.apache.spark.sql.expressions.Aggregator
         |import org.apache.spark.sql.TypedColumn
-        |val simpleSum = new Aggregator[Int, Int, Int] with Serializable {
+        |val simpleSum = new Aggregator[Int, Int, Int] {
         |  def zero: Int = 0                     // The initial value.
         |  def reduce(b: Int, a: Int) = b + a    // Add an element to the running total
         |  def merge(b1: Int, b2: Int) = b1 + b2 // Merge intermediate values.
@@ -347,7 +347,7 @@ class ReplSuite extends SparkFunSuite {
         |import org.apache.spark.sql.expressions.Aggregator
         |import org.apache.spark.sql.TypedColumn
         |/** An `Aggregator` that adds up any numeric type returned by the given function. */
-        |class SumOf[I, N : Numeric](f: I => N) extends Aggregator[I, N, N] with Serializable {
+        |class SumOf[I, N : Numeric](f: I => N) extends Aggregator[I, N, N] {
         |  val numeric = implicitly[Numeric[N]]
         |  override def zero: N = numeric.zero
         |  override def reduce(b: N, a: I): N = numeric.plus(b, f(a))

http://git-wip-us.apache.org/repos/asf/spark/blob/43ebf7a9/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala
----------------------------------------------------------------------
diff --git a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala
index 6bee880..f148a6d 100644
--- a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala
+++ b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala
@@ -249,10 +249,32 @@ class ReplSuite extends SparkFunSuite {
     // We need to use local-cluster to test this case.
     val output = runInterpreter("local-cluster[1,1,1024]",
       """
-        |val sqlContext = new org.apache.spark.sql.SQLContext(sc)
-        |import sqlContext.implicits._
         |case class TestCaseClass(value: Int)
         |sc.parallelize(1 to 10).map(x => TestCaseClass(x)).toDF().collect()
+        |
+        |// Test Dataset Serialization in the REPL
+        |Seq(TestCaseClass(1)).toDS().collect()
+      """.stripMargin)
+    assertDoesNotContain("error:", output)
+    assertDoesNotContain("Exception", output)
+  }
+
+  test("Datasets and encoders") {
+    val output = runInterpreter("local",
+      """
+        |import org.apache.spark.sql.functions._
+        |import org.apache.spark.sql.Encoder
+        |import org.apache.spark.sql.expressions.Aggregator
+        |import org.apache.spark.sql.TypedColumn
+        |val simpleSum = new Aggregator[Int, Int, Int] {
+        |  def zero: Int = 0                     // The initial value.
+        |  def reduce(b: Int, a: Int) = b + a    // Add an element to the running total
+        |  def merge(b1: Int, b2: Int) = b1 + b2 // Merge intermediate values.
+        |  def finish(b: Int) = b                // Return the final result.
+        |}.toColumn
+        |
+        |val ds = Seq(1, 2, 3, 4).toDS()
+        |ds.select(simpleSum).collect
       """.stripMargin)
     assertDoesNotContain("error:", output)
     assertDoesNotContain("Exception", output)
@@ -295,6 +317,31 @@ class ReplSuite extends SparkFunSuite {
     }
   }
 
+  test("Datasets agg type-inference") {
+    val output = runInterpreter("local",
+      """
+        |import org.apache.spark.sql.functions._
+        |import org.apache.spark.sql.Encoder
+        |import org.apache.spark.sql.expressions.Aggregator
+        |import org.apache.spark.sql.TypedColumn
+        |/** An `Aggregator` that adds up any numeric type returned by the given function. */
+        |class SumOf[I, N : Numeric](f: I => N) extends
+        |  org.apache.spark.sql.expressions.Aggregator[I, N, N] {
+        |  val numeric = implicitly[Numeric[N]]
+        |  override def zero: N = numeric.zero
+        |  override def reduce(b: N, a: I): N = numeric.plus(b, f(a))
+        |  override def merge(b1: N, b2: N): N = numeric.plus(b1, b2)
+        |  override def finish(reduction: N): N = reduction
+        |}
+        |
+        |def sum[I, N : Numeric : Encoder](f: I => N): TypedColumn[I, N] = new SumOf(f).toColumn
+        |val ds = Seq((1, 1, 2L), (1, 2, 3L), (1, 3, 4L), (2, 1, 5L)).toDS()
+        |ds.groupByKey(_._1).agg(sum(_._2), sum(_._3)).collect()
+      """.stripMargin)
+    assertDoesNotContain("error:", output)
+    assertDoesNotContain("Exception", output)
+  }
+
   test("collecting objects of class defined in repl") {
     val output = runInterpreter("local[2]",
       """
@@ -317,4 +364,21 @@ class ReplSuite extends SparkFunSuite {
     assertDoesNotContain("Exception", output)
     assertContains("ret: Array[(Int, Iterable[Foo])] = Array((1,", output)
   }
+
+  test("line wrapper only initialized once when used as encoder outer scope") {
+    val output = runInterpreter("local",
+      """
+        |val fileName = "repl-test-" + System.currentTimeMillis
+        |val tmpDir = System.getProperty("java.io.tmpdir")
+        |val file = new java.io.File(tmpDir, fileName)
+        |def createFile(): Unit = file.createNewFile()
+        |
+        |createFile();case class TestCaseClass(value: Int)
+        |sc.parallelize(1 to 10).map(x => TestCaseClass(x)).collect()
+        |
+        |file.delete()
+      """.stripMargin)
+    assertDoesNotContain("error:", output)
+    assertDoesNotContain("Exception", output)
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/43ebf7a9/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index ccc65b4..ebb3a93 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -571,7 +571,7 @@ class Analyzer(
           if n.outerPointer.isEmpty &&
              n.cls.isMemberClass &&
              !Modifier.isStatic(n.cls.getModifiers) =>
-          val outer = OuterScopes.outerScopes.get(n.cls.getDeclaringClass.getName)
+          val outer = OuterScopes.getOuterScope(n.cls)
           if (outer == null) {
             throw new AnalysisException(
               s"Unable to generate an encoder for inner class 
`${n.cls.getName}` without " +

http://git-wip-us.apache.org/repos/asf/spark/blob/43ebf7a9/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala
index a753b18..c047e96 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/OuterScopes.scala
@@ -21,6 +21,8 @@ import java.util.concurrent.ConcurrentMap
 
 import com.google.common.collect.MapMaker
 
+import org.apache.spark.util.Utils
+
 object OuterScopes {
   @transient
   lazy val outerScopes: ConcurrentMap[String, AnyRef] =
@@ -28,7 +30,7 @@ object OuterScopes {
 
   /**
    * Adds a new outer scope to this context that can be used when instantiating an `inner class`
-   * during deserialialization. Inner classes are created when a case class is defined in the
+   * during deserialization. Inner classes are created when a case class is defined in the
    * Spark REPL and registering the outer scope that this class was defined in allows us to create
    * new instances on the spark executors.  In normal use, users should not need to call this
    * function.
@@ -39,4 +41,47 @@ object OuterScopes {
   def addOuterScope(outer: AnyRef): Unit = {
     outerScopes.putIfAbsent(outer.getClass.getName, outer)
   }
+
+  def getOuterScope(innerCls: Class[_]): AnyRef = {
+    assert(innerCls.isMemberClass)
+    val outerClassName = innerCls.getDeclaringClass.getName
+    val outer = outerScopes.get(outerClassName)
+    if (outer == null) {
+      outerClassName match {
+        // If the outer class is generated by the REPL, users don't need to register it, as it
+        // has only one instance and there is a way to retrieve it: get the `$read` object, call
+        // the `INSTANCE()` method to get the single instance of class `$read`, then call the
+        // `$iw()` method multiple times to get the single instance of the innermost `$iw` class.
+        case REPLClass(baseClassName) =>
+          val objClass = Utils.classForName(baseClassName + "$")
+          val objInstance = objClass.getField("MODULE$").get(null)
+          val baseInstance = objClass.getMethod("INSTANCE").invoke(objInstance)
+          val baseClass = Utils.classForName(baseClassName)
+
+          var getter = iwGetter(baseClass)
+          var obj = baseInstance
+          while (getter != null) {
+            obj = getter.invoke(obj)
+            getter = iwGetter(getter.getReturnType)
+          }
+
+          outerScopes.putIfAbsent(outerClassName, obj)
+          obj
+        case _ => null
+      }
+    } else {
+      outer
+    }
+  }
+
+  private def iwGetter(cls: Class[_]) = {
+    try {
+      cls.getMethod("$iw")
+    } catch {
+      case _: NoSuchMethodException => null
+    }
+  }
+
+  // The format of a REPL-generated wrapper class's name, e.g. `$line12.$read$$iw$$iw`
+  private[this] val REPLClass = """^(\$line(?:\d+)\.\$read)(?:\$\$iw)+$""".r
 }
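
As an aside, the `REPLClass` pattern above can be exercised on its own to see what
it captures (a self-contained sketch, independent of Spark):

```scala
object ReplClassNameDemo extends App {
  // Same pattern as OuterScopes.REPLClass: group 1 captures the `$lineN.$read` base name.
  val REPLClass = """^(\$line(?:\d+)\.\$read)(?:\$\$iw)+$""".r

  "$line12.$read$$iw$$iw" match {
    case REPLClass(base) => println(s"REPL wrapper; base class: $base")
    case other => println(s"not a REPL wrapper: $other")
  }
  // prints: REPL wrapper; base class: $line12.$read
}
```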

