Repository: spark
Updated Branches:
  refs/heads/master c7d014861 -> 0e36ba621


[SPARK-22644][ML][TEST] Make ML testsuite support StructuredStreaming test

## What changes were proposed in this pull request?

We need to add some helper code to make testing ML transformers & models easier 
with streaming data. These tests might help us catch any remaining issues and 
we could encourage future PRs to use these tests to prevent new Models & 
Transformers from having issues.

I added an `MLTest` trait which extends the `StreamTest` trait and overrides 
`createSparkSession`, so an ML test suite only needs to extend `MLTest` to use 
both the ML and streaming test utility functions.

I only modified one test case in `LinearRegressionSuite`, for a first-pass review.

Link to #19746

## How was this patch tested?

`MLTestSuite` added.

Author: WeichenXu <weichen...@databricks.com>

Closes #19843 from WeichenXu123/ml_stream_test_helper.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/0e36ba62
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/0e36ba62
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/0e36ba62

Branch: refs/heads/master
Commit: 0e36ba6212bc24b3185e385914fbf2d62cbfb6da
Parents: c7d0148
Author: WeichenXu <weichen...@databricks.com>
Authored: Tue Dec 12 21:28:24 2017 -0800
Committer: Joseph K. Bradley <jos...@databricks.com>
Committed: Tue Dec 12 21:28:24 2017 -0800

----------------------------------------------------------------------
 mllib/pom.xml                                   | 14 +++
 .../ml/regression/LinearRegressionSuite.scala   |  8 +-
 .../scala/org/apache/spark/ml/util/MLTest.scala | 91 ++++++++++++++++++++
 .../org/apache/spark/ml/util/MLTestSuite.scala  | 47 ++++++++++
 .../apache/spark/sql/streaming/StreamTest.scala | 67 +++++++++-----
 .../apache/spark/sql/test/TestSQLContext.scala  |  2 +-
 6 files changed, 203 insertions(+), 26 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/0e36ba62/mllib/pom.xml
----------------------------------------------------------------------
diff --git a/mllib/pom.xml b/mllib/pom.xml
index 925b542..a906c9e 100644
--- a/mllib/pom.xml
+++ b/mllib/pom.xml
@@ -62,6 +62,20 @@
     </dependency>
     <dependency>
       <groupId>org.apache.spark</groupId>
+      <artifactId>spark-catalyst_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
+      <artifactId>spark-sql_${scala.binary.version}</artifactId>
+      <version>${project.version}</version>
+      <type>test-jar</type>
+      <scope>test</scope>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.spark</groupId>
       <artifactId>spark-graphx_${scala.binary.version}</artifactId>
       <version>${project.version}</version>
     </dependency>

http://git-wip-us.apache.org/repos/asf/spark/blob/0e36ba62/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
index 0e0be58..aec5ac0 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -24,13 +24,12 @@ import org.apache.spark.ml.feature.Instance
 import org.apache.spark.ml.feature.LabeledPoint
 import org.apache.spark.ml.linalg.{DenseVector, Vector, Vectors}
 import org.apache.spark.ml.param.{ParamMap, ParamsSuite}
-import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTest, MLTestingUtils}
 import org.apache.spark.ml.util.TestingUtils._
 import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
 import org.apache.spark.sql.{DataFrame, Row}
 
-class LinearRegressionSuite
-  extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+class LinearRegressionSuite extends MLTest with DefaultReadWriteTest {
 
   import testImplicits._
 
@@ -233,7 +232,8 @@ class LinearRegressionSuite
       assert(model2.intercept ~== interceptR relTol 1E-3)
       assert(model2.coefficients ~= coefficientsR relTol 1E-3)
 
-      model1.transform(datasetWithDenseFeature).select("features", 
"prediction").collect().foreach {
+      testTransformer[(Double, Vector)](datasetWithDenseFeature, model1,
+        "features", "prediction") {
         case Row(features: DenseVector, prediction1: Double) =>
           val prediction2 =
             features(0) * model1.coefficients(0) + features(1) * 
model1.coefficients(1) +

http://git-wip-us.apache.org/repos/asf/spark/blob/0e36ba62/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala
new file mode 100644
index 0000000..7a5426e
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTest.scala
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.util
+
+import java.io.File
+
+import org.scalatest.Suite
+
+import org.apache.spark.SparkContext
+import org.apache.spark.ml.{PipelineModel, Transformer}
+import org.apache.spark.sql.{DataFrame, Encoder, Row}
+import org.apache.spark.sql.execution.streaming.MemoryStream
+import org.apache.spark.sql.streaming.StreamTest
+import org.apache.spark.sql.test.TestSparkSession
+import org.apache.spark.util.Utils
+
+trait MLTest extends StreamTest with TempDirectory { self: Suite =>
+
+  @transient var sc: SparkContext = _
+  @transient var checkpointDir: String = _
+
+  protected override def createSparkSession: TestSparkSession = {
+    new TestSparkSession(new SparkContext("local[2]", "MLlibUnitTest", 
sparkConf))
+  }
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    sc = spark.sparkContext
+    checkpointDir = Utils.createDirectory(tempDir.getCanonicalPath, 
"checkpoints").toString
+    sc.setCheckpointDir(checkpointDir)
+  }
+
+  override def afterAll() {
+    try {
+      Utils.deleteRecursively(new File(checkpointDir))
+    } finally {
+      super.afterAll()
+    }
+  }
+
+  def testTransformerOnStreamData[A : Encoder](
+      dataframe: DataFrame,
+      transformer: Transformer,
+      firstResultCol: String,
+      otherResultCols: String*)
+      (checkFunction: Row => Unit): Unit = {
+
+    val columnNames = dataframe.schema.fieldNames
+    val stream = MemoryStream[A]
+    val streamDF = stream.toDS().toDF(columnNames: _*)
+
+    val data = dataframe.as[A].collect()
+
+    val streamOutput = transformer.transform(streamDF)
+      .select(firstResultCol, otherResultCols: _*)
+    testStream(streamOutput) (
+      AddData(stream, data: _*),
+      CheckAnswer(checkFunction)
+    )
+  }
+
+  def testTransformer[A : Encoder](
+      dataframe: DataFrame,
+      transformer: Transformer,
+      firstResultCol: String,
+      otherResultCols: String*)
+      (checkFunction: Row => Unit): Unit = {
+    testTransformerOnStreamData(dataframe, transformer, firstResultCol,
+      otherResultCols: _*)(checkFunction)
+
+    val dfOutput = transformer.transform(dataframe)
+    dfOutput.select(firstResultCol, otherResultCols: _*).collect().foreach { 
row =>
+      checkFunction(row)
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/0e36ba62/mllib/src/test/scala/org/apache/spark/ml/util/MLTestSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/util/MLTestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestSuite.scala
new file mode 100644
index 0000000..56217ec
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/util/MLTestSuite.scala
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.util
+
+import org.apache.spark.ml.{PipelineModel, Transformer}
+import org.apache.spark.ml.feature.StringIndexer
+import org.apache.spark.sql.Row
+
+class MLTestSuite extends MLTest {
+
+  import testImplicits._
+
+  test("test transformer on stream data") {
+
+    val data = Seq((0, "a"), (1, "b"), (2, "c"), (3, "d"), (4, "e"), (5, "f"))
+      .toDF("id", "label")
+    val indexer = new StringIndexer().setStringOrderType("alphabetAsc")
+      .setInputCol("label").setOutputCol("indexed")
+    val indexerModel = indexer.fit(data)
+    testTransformerOnStreamData[(Int, String)](data, indexerModel, "id", 
"indexed") {
+      case Row(id: Int, indexed: Double) =>
+        assert(id === indexed.toInt)
+    }
+
+    intercept[Exception] {
+      testTransformerOnStreamData[(Int, String)](data, indexerModel, "id", 
"indexed") {
+        case Row(id: Int, indexed: Double) =>
+          assert(id != indexed.toInt)
+      }
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/0e36ba62/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
index e68fca0..dc5b998 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamTest.scala
@@ -133,6 +133,9 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
     }
 
     def apply(rows: Row*): CheckAnswerRows = CheckAnswerRows(rows, false, 
false)
+
+    def apply(checkFunction: Row => Unit): CheckAnswerRowsByFunc =
+      CheckAnswerRowsByFunc(checkFunction, false)
   }
 
   /**
@@ -154,6 +157,9 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
     }
 
     def apply(rows: Row*): CheckAnswerRows = CheckAnswerRows(rows, true, false)
+
+    def apply(checkFunction: Row => Unit): CheckAnswerRowsByFunc =
+      CheckAnswerRowsByFunc(checkFunction, true)
   }
 
   case class CheckAnswerRows(expectedAnswer: Seq[Row], lastOnly: Boolean, 
isSorted: Boolean)
@@ -162,6 +168,12 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
     private def operatorName = if (lastOnly) "CheckLastBatch" else 
"CheckAnswer"
   }
 
+  case class CheckAnswerRowsByFunc(checkFunction: Row => Unit, lastOnly: 
Boolean)
+      extends StreamAction with StreamMustBeRunning {
+    override def toString: String = s"$operatorName: 
${checkFunction.toString()}"
+    private def operatorName = if (lastOnly) "CheckLastBatchByFunc" else 
"CheckAnswerByFunc"
+  }
+
   /** Stops the stream. It must currently be running. */
   case object StopStream extends StreamAction with StreamMustBeRunning
 
@@ -352,6 +364,29 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
          """.stripMargin)
     }
 
+    def fetchStreamAnswer(currentStream: StreamExecution, lastOnly: Boolean) = 
{
+      verify(currentStream != null, "stream not running")
+      // Get the map of source index to the current source objects
+      val indexToSource = currentStream
+        .logicalPlan
+        .collect { case StreamingExecutionRelation(s, _) => s }
+        .zipWithIndex
+        .map(_.swap)
+        .toMap
+
+      // Block until all data added has been processed for all the source
+      awaiting.foreach { case (sourceIndex, offset) =>
+        failAfter(streamingTimeout) {
+          currentStream.awaitOffset(indexToSource(sourceIndex), offset)
+        }
+      }
+
+      try if (lastOnly) sink.latestBatchData else sink.allData catch {
+        case e: Exception =>
+          failTest("Exception while getting data from sink", e)
+      }
+    }
+
     var manualClockExpectedTime = -1L
     val defaultCheckpointLocation =
       Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
@@ -552,30 +587,20 @@ trait StreamTest extends QueryTest with SharedSQLContext with TimeLimits with Be
             e.runAction()
 
           case CheckAnswerRows(expectedAnswer, lastOnly, isSorted) =>
-            verify(currentStream != null, "stream not running")
-            // Get the map of source index to the current source objects
-            val indexToSource = currentStream
-              .logicalPlan
-              .collect { case StreamingExecutionRelation(s, _) => s }
-              .zipWithIndex
-              .map(_.swap)
-              .toMap
-
-            // Block until all data added has been processed for all the source
-            awaiting.foreach { case (sourceIndex, offset) =>
-              failAfter(streamingTimeout) {
-                currentStream.awaitOffset(indexToSource(sourceIndex), offset)
-              }
-            }
-
-            val sparkAnswer = try if (lastOnly) sink.latestBatchData else 
sink.allData catch {
-              case e: Exception =>
-                failTest("Exception while getting data from sink", e)
-            }
-
+            val sparkAnswer = fetchStreamAnswer(currentStream, lastOnly)
             QueryTest.sameRows(expectedAnswer, sparkAnswer, isSorted).foreach {
               error => failTest(error)
             }
+
+          case CheckAnswerRowsByFunc(checkFunction, lastOnly) =>
+            val sparkAnswer = fetchStreamAnswer(currentStream, lastOnly)
+            sparkAnswer.foreach { row =>
+              try {
+                checkFunction(row)
+              } catch {
+                case e: Throwable => failTest(e.toString)
+              }
+            }
         }
         pos += 1
       }

http://git-wip-us.apache.org/repos/asf/spark/blob/0e36ba62/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala
index 959edf9..4286e8a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.internal.{SessionState, SessionStateBuilder, SQLConf
 /**
  * A special `SparkSession` prepared for testing.
  */
-private[sql] class TestSparkSession(sc: SparkContext) extends SparkSession(sc) 
{ self =>
+private[spark] class TestSparkSession(sc: SparkContext) extends 
SparkSession(sc) { self =>
   def this(sparkConf: SparkConf) {
     this(new SparkContext("local[2]", "test-sql-context",
       sparkConf.set("spark.sql.testkey", "true")))


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to