This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
     new d2ff10c  [SPARK-23674][ML] Adds Spark ML Events to Instrumentation

d2ff10c is described below

commit d2ff10cbe1c22f919a7b1999fe54db13f4178979
Author: Hyukjin Kwon <gurwls...@apache.org>
AuthorDate: Fri Jan 25 10:11:49 2019 +0800

[SPARK-23674][ML] Adds Spark ML Events to Instrumentation

## What changes were proposed in this pull request?

This PR proposes to add ML events to `Instrumentation` and to use them in `Pipeline`, so that other developers can track ML operations and attach their own actions to them.

## Introduction

ML events (like SQL events) can be quite useful when people want to track ML operations and take actions on them. For instance, I have been working on integrating Apache Spark with [Apache Atlas](https://atlas.apache.org/QuickStart.html). With some custom changes on top of this PR, I can visualise an ML pipeline as below:

![spark_ml_streaming_lineage](https://user-images.githubusercontent.com/6477701/49682779-394bca80-faf5-11e8-85b8-5fae28b784b3.png)

Another benefit worth considering is that these events can be combined with the existing SQL/Streaming events, for instance, to find where the input `Dataset` originated. With current Apache Spark, I can already visualise SQL operations as below:

![screen shot 2018-12-10 at 9 41 36 am](https://user-images.githubusercontent.com/6477701/49706269-d9bdfe00-fc5f-11e8-943a-3309d1856ba5.png)

I think we can combine those existing lineages to easily understand where the data comes from and where it goes. Currently, the ML side is a missing piece, so these lineages cannot be connected in Apache Spark.

It goes without saying how useful it is to track SQL/Streaming operations. Likewise, I would like to propose ML events as well (as lowest-stability `Unstable` APIs for now, with no guarantee about stability).
## Implementation Details

### Send events (without exposing an ML-specific listener)

**`mllib/src/main/scala/org/apache/spark/ml/events.scala`**

```scala
@Unstable case class ...StartEvent(caller, input)
@Unstable case class ...EndEvent(caller, output)

trait MLEvents {
  // Wrappers to send events:
  // def with...Event(body) = {
  //   SparkContext.getOrCreate().listenerBus.post(startEvent)
  //   val result = body
  //   SparkContext.getOrCreate().listenerBus.post(endEvent)
  //   result
  // }
}
```

This trait is used by `Instrumentation`:

```scala
class Instrumentation ... with MLEvents {
```

and is used as below:

```scala
instrumented { instr =>
  instr.with...Event(...) {
    ...
  }
}
```

This approach mimics both:

**1. Catalog events (see `org/apache/spark/sql/catalyst/catalog/events.scala`)**

- This allows a catalog-specific listener, `ExternalCatalogEventListener`, to be added.
- It is implemented by wrapping the whole `ExternalCatalog` in `ExternalCatalogWithListener`, which delegates the operations to `ExternalCatalog`.

This is not quite possible here because most instances (like `Pipeline`) are created directly by users. We might be able to do that by extending `ListenerBus` for all possible instances, but IMHO that is too invasive. Also, exposing another ML-specific listener seems a bit too much at this stage. Therefore, I simply borrowed the file name and structure here.

**2. SQL execution events (see `org/apache/spark/sql/execution/SQLExecution.scala`)**

- Adds an object that wraps a body to send events.

The current approach is rather close to this: it has a `with...` wrapper to send events. I borrowed this approach to be consistent.

## Usage

It needs a custom listener implementation. For instance, with the custom listener below:

```scala
import org.apache.spark.ml.MLEvent
import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent}

class CustomMLListener extends SparkListener {
  // MLEvents are not dispatched to a dedicated callback, so they arrive here.
  override def onOtherEvent(event: SparkListenerEvent): Unit = event match {
    case e: MLEvent => // do something
    case _ => // pass
  }
}
```

There are two (existing) ways to use this.
```scala
spark.sparkContext.addSparkListener(new CustomMLListener)
```

```bash
spark-submit ... \
  --conf spark.extraListeners=CustomMLListener \
  ...
```

This is similar to other existing implementations on the SQL side.

## Target users

1. I think users in general would likely utilise this feature like other event listeners. At least, I can see some interest in such listeners elsewhere:
    - SQL Listener
      - https://stackoverflow.com/questions/46409339/spark-listener-to-an-sql-query
      - http://apache-spark-user-list.1001560.n3.nabble.com/spark-sql-Custom-Query-Execution-listener-via-conf-properties-td30979.html
    - Streaming Query Listener
      - https://jhui.github.io/2017/01/15/Apache-Spark-Streaming/
      - http://apache-spark-developers-list.1001551.n3.nabble.com/Structured-Streaming-with-Watermark-td25413.html#a25416
2. Someone would likely run this via Atlas. The plugin mirror is intentionally exposed at [spark-atlas-connector](https://github.com/hortonworks-spark/spark-atlas-connector) so that anyone can work on lineage and governance in Atlas. I'm trying to show integrated lineages in Apache Spark, but ML events are the missing piece.

## How was this patch tested?

Manually tested, and unit tests were added.

Closes #23263 from HyukjinKwon/SPARK-23674-1.
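The nesting order that `MLEventsSuite` asserts (for example, `TransformStart(newPipelineModel, dataset1)` first, `TransformEnd(newPipelineModel, output)` last, and each stage's start/end pair in between) falls out of how the `with...Event` wrappers compose. Below is a minimal Spark-free sketch of that composition, assuming a synchronous in-memory bus; `SketchEvent`, `OpStart`, `OpEnd`, `RecordingBus` and `EventOrderingSketch` are hypothetical names for illustration, not Spark APIs.

```scala
import scala.collection.mutable.ArrayBuffer

// Hypothetical stand-ins for the real MLEvent case classes: they keep only the
// start/end shape (operation name plus input or output) and are not Spark classes.
sealed trait SketchEvent
final case class OpStart(op: String, input: String) extends SketchEvent
final case class OpEnd(op: String, output: String) extends SketchEvent

// Hypothetical synchronous stand-in for the listener bus. Spark's real bus
// delivers events asynchronously, which is why MLEventsSuite uses `eventually`.
final class RecordingBus {
  val posted: ArrayBuffer[SketchEvent] = ArrayBuffer.empty[SketchEvent]
  def post(event: SketchEvent): Unit = posted += event
}

object EventOrderingSketch {
  // Same shape as the with...Event wrappers: post the start event, evaluate the
  // by-name body exactly once, then post the end event carrying the result.
  def withOpEvent(bus: RecordingBus, op: String, input: String)(body: => String): String = {
    bus.post(OpStart(op, input))
    val output = body
    bus.post(OpEnd(op, output))
    output
  }

  // Mimics the shape of PipelineModel.transform: one outer event pair wrapping
  // one inner event pair per stage, threaded through foldLeft.
  def runPipeline(bus: RecordingBus, stages: Seq[String], input: String): String =
    withOpEvent(bus, "pipeline", input) {
      stages.foldLeft(input) { (cur, stage) =>
        withOpEvent(bus, stage, cur)(s"$cur>$stage")
      }
    }

  def main(args: Array[String]): Unit = {
    val bus = new RecordingBus
    runPipeline(bus, Seq("tokenizer", "hashingTF"), "raw")
    bus.posted.foreach(println)
  }
}
```

Because neither this sketch nor the wrappers in the diff use `try`/`finally`, a body that throws leaves a start event without a matching end event, so a listener that pairs events should tolerate unmatched starts.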
Authored-by: Hyukjin Kwon <gurwls...@apache.org>
Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../main/scala/org/apache/spark/ml/Pipeline.scala  |  48 ++--
 .../main/scala/org/apache/spark/ml/events.scala    | 137 +++++++++++
 .../org/apache/spark/ml/util/Instrumentation.scala |   9 +-
 .../scala/org/apache/spark/ml/util/ReadWrite.scala |  11 +-
 .../scala/org/apache/spark/ml/MLEventsSuite.scala  | 255 +++++++++++++++++++++
 5 files changed, 436 insertions(+), 24 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
index 103082b..69a4dbe 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -31,6 +31,7 @@ import org.apache.spark.annotation.{DeveloperApi, Since}
 import org.apache.spark.internal.Logging
 import org.apache.spark.ml.param.{Param, ParamMap, Params}
 import org.apache.spark.ml.util._
+import org.apache.spark.ml.util.Instrumentation.instrumented
 import org.apache.spark.sql.{DataFrame, Dataset}
 import org.apache.spark.sql.types.StructType
 
@@ -132,7 +133,8 @@ class Pipeline @Since("1.4.0") (
    * @return fitted pipeline
    */
   @Since("2.0.0")
-  override def fit(dataset: Dataset[_]): PipelineModel = {
+  override def fit(dataset: Dataset[_]): PipelineModel = instrumented(
+      instr => instr.withFitEvent(this, dataset) {
     transformSchema(dataset.schema, logging = true)
     val theStages = $(stages)
     // Search for the last estimator.
@@ -150,7 +152,7 @@ class Pipeline @Since("1.4.0") (
         if (index <= indexOfLastEstimator) {
           val transformer = stage match {
             case estimator: Estimator[_] =>
-              estimator.fit(curDataset)
+              instr.withFitEvent(estimator, curDataset)(estimator.fit(curDataset))
             case t: Transformer =>
               t
             case _ =>
@@ -158,7 +160,8 @@ class Pipeline @Since("1.4.0") (
                 s"Does not support stage $stage of type ${stage.getClass}")
           }
           if (index < indexOfLastEstimator) {
-            curDataset = transformer.transform(curDataset)
+            curDataset = instr.withTransformEvent(
+              transformer, curDataset)(transformer.transform(curDataset))
           }
           transformers += transformer
         } else {
@@ -167,7 +170,7 @@ class Pipeline @Since("1.4.0") (
     }
 
     new PipelineModel(uid, transformers.toArray).setParent(this)
-  }
+  })
 
   @Since("1.4.0")
   override def copy(extra: ParamMap): Pipeline = {
@@ -197,10 +200,12 @@ object Pipeline extends MLReadable[Pipeline] {
   @Since("1.6.0")
   override def load(path: String): Pipeline = super.load(path)
 
-  private[Pipeline] class PipelineWriter(instance: Pipeline) extends MLWriter {
+  private[Pipeline] class PipelineWriter(val instance: Pipeline) extends MLWriter {
 
     SharedReadWrite.validateStages(instance.getStages)
 
+    override def save(path: String): Unit =
+      instrumented(_.withSaveInstanceEvent(this, path)(super.save(path)))
     override protected def saveImpl(path: String): Unit =
       SharedReadWrite.saveImpl(instance, instance.getStages, sc, path)
   }
@@ -210,10 +215,10 @@ object Pipeline extends MLReadable[Pipeline] {
     /** Checked against metadata when loading model */
     private val className = classOf[Pipeline].getName
 
-    override def load(path: String): Pipeline = {
+    override def load(path: String): Pipeline = instrumented(_.withLoadInstanceEvent(this, path) {
       val (uid: String, stages: Array[PipelineStage]) = SharedReadWrite.load(className, sc, path)
       new Pipeline(uid).setStages(stages)
-    }
+    })
   }
 
   /**
@@ -243,7 +248,7 @@ object Pipeline extends MLReadable[Pipeline] {
         instance: Params,
         stages: Array[PipelineStage],
         sc: SparkContext,
-        path: String): Unit = {
+        path: String): Unit = instrumented { instr =>
       val stageUids = stages.map(_.uid)
       val jsonParams = List("stageUids" -> parse(compact(render(stageUids.toSeq))))
       DefaultParamsWriter.saveMetadata(instance, path, sc, paramMap = Some(jsonParams))
@@ -251,8 +256,9 @@ object Pipeline extends MLReadable[Pipeline] {
       // Save stages
       val stagesDir = new Path(path, "stages").toString
       stages.zipWithIndex.foreach { case (stage, idx) =>
-        stage.asInstanceOf[MLWritable].write.save(
-          getStagePath(stage.uid, idx, stages.length, stagesDir))
+        val writer = stage.asInstanceOf[MLWritable].write
+        val stagePath = getStagePath(stage.uid, idx, stages.length, stagesDir)
+        instr.withSaveInstanceEvent(writer, stagePath)(writer.save(stagePath))
       }
     }
 
@@ -263,7 +269,7 @@ object Pipeline extends MLReadable[Pipeline] {
     def load(
         expectedClassName: String,
         sc: SparkContext,
-        path: String): (String, Array[PipelineStage]) = {
+        path: String): (String, Array[PipelineStage]) = instrumented { instr =>
       val metadata = DefaultParamsReader.loadMetadata(path, sc, expectedClassName)
 
       implicit val format = DefaultFormats
@@ -271,7 +277,8 @@ object Pipeline extends MLReadable[Pipeline] {
       val stageUids: Array[String] = (metadata.params \ "stageUids").extract[Seq[String]].toArray
       val stages: Array[PipelineStage] = stageUids.zipWithIndex.map { case (stageUid, idx) =>
         val stagePath = SharedReadWrite.getStagePath(stageUid, idx, stageUids.length, stagesDir)
-        DefaultParamsReader.loadParamsInstance[PipelineStage](stagePath, sc)
+        val reader = DefaultParamsReader.loadParamsInstanceReader[PipelineStage](stagePath, sc)
+        instr.withLoadInstanceEvent(reader, stagePath)(reader.load(stagePath))
       }
       (metadata.uid, stages)
     }
@@ -301,10 +308,12 @@ class PipelineModel private[ml] (
   }
 
   @Since("2.0.0")
-  override def transform(dataset: Dataset[_]): DataFrame = {
+  override def transform(dataset: Dataset[_]): DataFrame = instrumented(instr =>
+      instr.withTransformEvent(this, dataset) {
     transformSchema(dataset.schema, logging = true)
-    stages.foldLeft(dataset.toDF)((cur, transformer) => transformer.transform(cur))
-  }
+    stages.foldLeft(dataset.toDF)((cur, transformer) =>
+      instr.withTransformEvent(transformer, cur)(transformer.transform(cur)))
+  })
 
   @Since("1.2.0")
   override def transformSchema(schema: StructType): StructType = {
@@ -331,10 +340,12 @@ object PipelineModel extends MLReadable[PipelineModel] {
   @Since("1.6.0")
   override def load(path: String): PipelineModel = super.load(path)
 
-  private[PipelineModel] class PipelineModelWriter(instance: PipelineModel) extends MLWriter {
+  private[PipelineModel] class PipelineModelWriter(val instance: PipelineModel) extends MLWriter {
 
     SharedReadWrite.validateStages(instance.stages.asInstanceOf[Array[PipelineStage]])
 
+    override def save(path: String): Unit =
+      instrumented(_.withSaveInstanceEvent(this, path)(super.save(path)))
     override protected def saveImpl(path: String): Unit =
       SharedReadWrite.saveImpl(instance, instance.stages.asInstanceOf[Array[PipelineStage]], sc, path)
   }
@@ -344,7 +355,8 @@ object PipelineModel extends MLReadable[PipelineModel] {
     /** Checked against metadata when loading model */
     private val className = classOf[PipelineModel].getName
 
-    override def load(path: String): PipelineModel = {
+    override def load(path: String): PipelineModel = instrumented(_.withLoadInstanceEvent(
+      this, path) {
       val (uid: String, stages: Array[PipelineStage]) = SharedReadWrite.load(className, sc, path)
       val transformers = stages map {
         case stage: Transformer => stage
@@ -352,6 +364,6 @@ object PipelineModel extends MLReadable[PipelineModel] {
           s" was not a Transformer. Bad stage ${other.uid} of type ${other.getClass}")
       }
       new PipelineModel(uid, transformers)
-    }
+    })
   }
 }
diff --git a/mllib/src/main/scala/org/apache/spark/ml/events.scala b/mllib/src/main/scala/org/apache/spark/ml/events.scala
new file mode 100644
index 0000000..c51600f
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/events.scala
@@ -0,0 +1,137 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml
+
+import org.apache.spark.SparkContext
+import org.apache.spark.annotation.Unstable
+import org.apache.spark.internal.Logging
+import org.apache.spark.ml.util.{MLReader, MLWriter}
+import org.apache.spark.scheduler.SparkListenerEvent
+import org.apache.spark.sql.{DataFrame, Dataset}
+
+/**
+ * Event emitted by ML operations. Events are either fired before and/or
+ * after each operation (the event should document this).
+ *
+ * @note This is supported via [[Pipeline]] and [[PipelineModel]].
+ */
+@Unstable
+sealed trait MLEvent extends SparkListenerEvent
+
+/**
+ * Event fired before `Transformer.transform`.
+ */
+@Unstable
+case class TransformStart(transformer: Transformer, input: Dataset[_]) extends MLEvent
+/**
+ * Event fired after `Transformer.transform`.
+ */
+@Unstable
+case class TransformEnd(transformer: Transformer, output: Dataset[_]) extends MLEvent
+
+/**
+ * Event fired before `Estimator.fit`.
+ */
+@Unstable
+case class FitStart[M <: Model[M]](estimator: Estimator[M], dataset: Dataset[_]) extends MLEvent
+/**
+ * Event fired after `Estimator.fit`.
+ */
+@Unstable
+case class FitEnd[M <: Model[M]](estimator: Estimator[M], model: M) extends MLEvent
+
+/**
+ * Event fired before `MLReader.load`.
+ */
+@Unstable
+case class LoadInstanceStart[T](reader: MLReader[T], path: String) extends MLEvent
+/**
+ * Event fired after `MLReader.load`.
+ */
+@Unstable
+case class LoadInstanceEnd[T](reader: MLReader[T], instance: T) extends MLEvent
+
+/**
+ * Event fired before `MLWriter.save`.
+ */
+@Unstable
+case class SaveInstanceStart(writer: MLWriter, path: String) extends MLEvent
+/**
+ * Event fired after `MLWriter.save`.
+ */
+@Unstable
+case class SaveInstanceEnd(writer: MLWriter, path: String) extends MLEvent
+
+/**
+ * A small trait that defines some methods to send [[org.apache.spark.ml.MLEvent]].
+ */
+private[ml] trait MLEvents extends Logging {
+
+  private def listenerBus = SparkContext.getOrCreate().listenerBus
+
+  /**
+   * Log [[MLEvent]] to send. By default, it emits a debug-level log.
+   */
+  def logEvent(event: MLEvent): Unit = logDebug(s"Sending an MLEvent: $event")
+
+  def withFitEvent[M <: Model[M]](
+      estimator: Estimator[M], dataset: Dataset[_])(func: => M): M = {
+    val startEvent = FitStart(estimator, dataset)
+    logEvent(startEvent)
+    listenerBus.post(startEvent)
+    val model: M = func
+    val endEvent = FitEnd(estimator, model)
+    logEvent(endEvent)
+    listenerBus.post(endEvent)
+    model
+  }
+
+  def withTransformEvent(
+      transformer: Transformer, input: Dataset[_])(func: => DataFrame): DataFrame = {
+    val startEvent = TransformStart(transformer, input)
+    logEvent(startEvent)
+    listenerBus.post(startEvent)
+    val output: DataFrame = func
+    val endEvent = TransformEnd(transformer, output)
+    logEvent(endEvent)
+    listenerBus.post(endEvent)
+    output
+  }
+
+  def withLoadInstanceEvent[T](reader: MLReader[T], path: String)(func: => T): T = {
+    val startEvent = LoadInstanceStart(reader, path)
+    logEvent(startEvent)
+    listenerBus.post(startEvent)
+    val instance: T = func
+    val endEvent = LoadInstanceEnd(reader, instance)
+    logEvent(endEvent)
+    listenerBus.post(endEvent)
+    instance
+  }
+
+  def withSaveInstanceEvent(writer: MLWriter, path: String)(func: => Unit): Unit = {
+    val startEvent = SaveInstanceStart(writer, path)
+    logEvent(startEvent)
+    listenerBus.post(startEvent)
+    func
+    val endEvent = SaveInstanceEnd(writer, path)
+    logEvent(endEvent)
+    listenerBus.post(endEvent)
+  }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala
index 4965491..780650d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/Instrumentation.scala
@@ -27,17 +27,18 @@ import org.json4s.JsonDSL._
 import org.json4s.jackson.JsonMethods._
 
 import org.apache.spark.internal.Logging
-import org.apache.spark.ml.PipelineStage
+import org.apache.spark.ml.{MLEvents, PipelineStage}
 import org.apache.spark.ml.param.{Param, Params}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.Dataset
 import org.apache.spark.util.Utils
 
 /**
- * A small wrapper that defines a training session for an estimator, and some methods to log
- * useful information during this session.
+ * A small wrapper that defines a training session for an estimator, some methods to log
+ * useful information during this session, and some methods to send
+ * [[org.apache.spark.ml.MLEvent]].
  */
-private[spark] class Instrumentation private () extends Logging {
+private[spark] class Instrumentation private () extends Logging with MLEvents {
 
   private val id = UUID.randomUUID()
   private val shortId = id.toString.take(8)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
index fbc7be2..ce8f346 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
@@ -624,10 +624,17 @@ private[ml] object DefaultParamsReader {
    * Load a `Params` instance from the given path, and return it.
    * This assumes the instance implements [[MLReadable]].
    */
-  def loadParamsInstance[T](path: String, sc: SparkContext): T = {
+  def loadParamsInstance[T](path: String, sc: SparkContext): T =
+    loadParamsInstanceReader(path, sc).load(path)
+
+  /**
+   * Load a `Params` instance reader from the given path, and return it.
+   * This assumes the instance implements [[MLReadable]].
+   */
+  def loadParamsInstanceReader[T](path: String, sc: SparkContext): MLReader[T] = {
     val metadata = DefaultParamsReader.loadMetadata(path, sc)
     val cls = Utils.classForName(metadata.className)
-    cls.getMethod("read").invoke(null).asInstanceOf[MLReader[T]].load(path)
+    cls.getMethod("read").invoke(null).asInstanceOf[MLReader[T]]
   }
 }
 
diff --git a/mllib/src/test/scala/org/apache/spark/ml/MLEventsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/MLEventsSuite.scala
new file mode 100644
index 0000000..0a87328
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/MLEventsSuite.scala
@@ -0,0 +1,255 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml
+
+import scala.collection.mutable
+import scala.concurrent.duration._
+import scala.language.postfixOps
+
+import org.apache.hadoop.fs.Path
+import org.mockito.ArgumentMatchers.{any, eq => meq}
+import org.mockito.Mockito.when
+import org.scalatest.BeforeAndAfterEach
+import org.scalatest.concurrent.Eventually
+import org.scalatest.mockito.MockitoSugar.mock
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.util.{DefaultParamsReader, DefaultParamsWriter, MLWriter}
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent}
+import org.apache.spark.sql._
+
+
+class MLEventsSuite
+  extends SparkFunSuite with BeforeAndAfterEach with MLlibTestSparkContext with Eventually {
+
+  private val events = mutable.ArrayBuffer.empty[MLEvent]
+  private val listener: SparkListener = new SparkListener {
+    override def onOtherEvent(event: SparkListenerEvent): Unit = event match {
+      case e: MLEvent => events.append(e)
+      case _ =>
+    }
+  }
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    spark.sparkContext.addSparkListener(listener)
+  }
+
+  override def afterEach(): Unit = {
+    try {
+      events.clear()
+    } finally {
+      super.afterEach()
+    }
+  }
+
+  override def afterAll(): Unit = {
+    try {
+      if (spark != null) {
+        spark.sparkContext.removeSparkListener(listener)
+      }
+    } finally {
+      super.afterAll()
+    }
+  }
+
+  abstract class MyModel extends Model[MyModel]
+
+  test("pipeline fit events") {
+    val estimator1 = mock[Estimator[MyModel]]
+    val model1 = mock[MyModel]
+    val transformer1 = mock[Transformer]
+    val estimator2 = mock[Estimator[MyModel]]
+    val model2 = mock[MyModel]
+
+    when(estimator1.copy(any[ParamMap])).thenReturn(estimator1)
+    when(model1.copy(any[ParamMap])).thenReturn(model1)
+    when(transformer1.copy(any[ParamMap])).thenReturn(transformer1)
+    when(estimator2.copy(any[ParamMap])).thenReturn(estimator2)
+    when(model2.copy(any[ParamMap])).thenReturn(model2)
+
+    val dataset1 = mock[DataFrame]
+    val dataset2 = mock[DataFrame]
+    val dataset3 = mock[DataFrame]
+    val dataset4 = mock[DataFrame]
+    val dataset5 = mock[DataFrame]
+
+    when(dataset1.toDF).thenReturn(dataset1)
+    when(dataset2.toDF).thenReturn(dataset2)
+    when(dataset3.toDF).thenReturn(dataset3)
+    when(dataset4.toDF).thenReturn(dataset4)
+    when(dataset5.toDF).thenReturn(dataset5)
+
+    when(estimator1.fit(meq(dataset1))).thenReturn(model1)
+    when(model1.transform(meq(dataset1))).thenReturn(dataset2)
+    when(model1.parent).thenReturn(estimator1)
+    when(transformer1.transform(meq(dataset2))).thenReturn(dataset3)
+    when(estimator2.fit(meq(dataset3))).thenReturn(model2)
+
+    val pipeline = new Pipeline()
+      .setStages(Array(estimator1, transformer1, estimator2))
+    assert(events.isEmpty)
+    val pipelineModel = pipeline.fit(dataset1)
+    val expected =
+      FitStart(pipeline, dataset1) ::
+      FitStart(estimator1, dataset1) ::
+      FitEnd(estimator1, model1) ::
+      TransformStart(model1, dataset1) ::
+      TransformEnd(model1, dataset2) ::
+      TransformStart(transformer1, dataset2) ::
+      TransformEnd(transformer1, dataset3) ::
+      FitStart(estimator2, dataset3) ::
+      FitEnd(estimator2, model2) ::
+      FitEnd(pipeline, pipelineModel) :: Nil
+    eventually(timeout(10 seconds), interval(1 second)) {
+      assert(events === expected)
+    }
+  }
+
+  test("pipeline model transform events") {
+    val dataset1 = mock[DataFrame]
+    val dataset2 = mock[DataFrame]
+    val dataset3 = mock[DataFrame]
+    val dataset4 = mock[DataFrame]
+    when(dataset1.toDF).thenReturn(dataset1)
+    when(dataset2.toDF).thenReturn(dataset2)
+    when(dataset3.toDF).thenReturn(dataset3)
+    when(dataset4.toDF).thenReturn(dataset4)
+
+    val transformer1 = mock[Transformer]
+    val model = mock[MyModel]
+    val transformer2 = mock[Transformer]
+    when(transformer1.transform(meq(dataset1))).thenReturn(dataset2)
+    when(model.transform(meq(dataset2))).thenReturn(dataset3)
+    when(transformer2.transform(meq(dataset3))).thenReturn(dataset4)
+
+    val newPipelineModel = new PipelineModel(
+      "pipeline0", Array(transformer1, model, transformer2))
+    assert(events.isEmpty)
+    val output = newPipelineModel.transform(dataset1)
+    val expected =
+      TransformStart(newPipelineModel, dataset1) ::
+      TransformStart(transformer1, dataset1) ::
+      TransformEnd(transformer1, dataset2) ::
+      TransformStart(model, dataset2) ::
+      TransformEnd(model, dataset3) ::
+      TransformStart(transformer2, dataset3) ::
+      TransformEnd(transformer2, dataset4) ::
+      TransformEnd(newPipelineModel, output) :: Nil
+    eventually(timeout(10 seconds), interval(1 second)) {
+      assert(events === expected)
+    }
+  }
+
+  test("pipeline read/write events") {
+    def getInstance(w: MLWriter): AnyRef =
+      w.getClass.getDeclaredMethod("instance").invoke(w)
+
+    withTempDir { dir =>
+      val path = new Path(dir.getCanonicalPath, "pipeline").toUri.toString
+      val writableStage = new WritableStage("writableStage")
+      val newPipeline = new Pipeline().setStages(Array(writableStage))
+      val pipelineWriter = newPipeline.write
+      assert(events.isEmpty)
+      pipelineWriter.save(path)
+      eventually(timeout(10 seconds), interval(1 second)) {
+        events.foreach {
+          case e: SaveInstanceStart if e.writer.isInstanceOf[DefaultParamsWriter] =>
+            assert(e.path.endsWith("writableStage"))
+          case e: SaveInstanceEnd if e.writer.isInstanceOf[DefaultParamsWriter] =>
+            assert(e.path.endsWith("writableStage"))
+          case e: SaveInstanceStart if getInstance(e.writer).isInstanceOf[Pipeline] =>
+            assert(getInstance(e.writer).asInstanceOf[Pipeline].uid === newPipeline.uid)
+          case e: SaveInstanceEnd if getInstance(e.writer).isInstanceOf[Pipeline] =>
+            assert(getInstance(e.writer).asInstanceOf[Pipeline].uid === newPipeline.uid)
+          case e => fail(s"Unexpected event thrown: $e")
+        }
+      }
+
+      events.clear()
+      val pipelineReader = Pipeline.read
+      assert(events.isEmpty)
+      pipelineReader.load(path)
+      eventually(timeout(10 seconds), interval(1 second)) {
+        events.foreach {
+          case e: LoadInstanceStart[PipelineStage]
+              if e.reader.isInstanceOf[DefaultParamsReader[PipelineStage]] =>
+            assert(e.path.endsWith("writableStage"))
+          case e: LoadInstanceEnd[PipelineStage]
+              if e.reader.isInstanceOf[DefaultParamsReader[PipelineStage]] =>
+            assert(e.instance.isInstanceOf[PipelineStage])
+          case e: LoadInstanceStart[Pipeline] =>
+            assert(e.reader === pipelineReader)
+          case e: LoadInstanceEnd[Pipeline] =>
+            assert(e.instance.uid === newPipeline.uid)
+          case e => fail(s"Unexpected event thrown: $e")
+        }
+      }
+    }
+  }
+
+  test("pipeline model read/write events") {
+    def getInstance(w: MLWriter): AnyRef =
+      w.getClass.getDeclaredMethod("instance").invoke(w)
+
+    withTempDir { dir =>
+      val path = new Path(dir.getCanonicalPath, "pipeline").toUri.toString
+      val writableStage = new WritableStage("writableStage")
+      val pipelineModel =
+        new PipelineModel("pipeline_89329329", Array(writableStage.asInstanceOf[Transformer]))
+      val pipelineWriter = pipelineModel.write
+      assert(events.isEmpty)
+      pipelineWriter.save(path)
+      eventually(timeout(10 seconds), interval(1 second)) {
+        events.foreach {
+          case e: SaveInstanceStart if e.writer.isInstanceOf[DefaultParamsWriter] =>
+            assert(e.path.endsWith("writableStage"))
+          case e: SaveInstanceEnd if e.writer.isInstanceOf[DefaultParamsWriter] =>
+            assert(e.path.endsWith("writableStage"))
+          case e: SaveInstanceStart if getInstance(e.writer).isInstanceOf[PipelineModel] =>
+            assert(getInstance(e.writer).asInstanceOf[PipelineModel].uid === pipelineModel.uid)
+          case e: SaveInstanceEnd if getInstance(e.writer).isInstanceOf[PipelineModel] =>
+            assert(getInstance(e.writer).asInstanceOf[PipelineModel].uid === pipelineModel.uid)
+          case e => fail(s"Unexpected event thrown: $e")
+        }
+      }
+
+      events.clear()
+      val pipelineModelReader = PipelineModel.read
+      assert(events.isEmpty)
+      pipelineModelReader.load(path)
+      eventually(timeout(10 seconds), interval(1 second)) {
+        events.foreach {
+          case e: LoadInstanceStart[PipelineStage]
+              if e.reader.isInstanceOf[DefaultParamsReader[PipelineStage]] =>
+            assert(e.path.endsWith("writableStage"))
+          case e: LoadInstanceEnd[PipelineStage]
+              if e.reader.isInstanceOf[DefaultParamsReader[PipelineStage]] =>
+            assert(e.instance.isInstanceOf[PipelineStage])
+          case e: LoadInstanceStart[PipelineModel] =>
+            assert(e.reader === pipelineModelReader)
+          case e: LoadInstanceEnd[PipelineModel] =>
+            assert(e.instance.uid === pipelineModel.uid)
+          case e => fail(s"Unexpected event thrown: $e")
+        }
+      }
+    }
+  }
+}

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org