This is an automated email from the ASF dual-hosted git repository.

gurwls223 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new dfb8809  [SPARK-26818][ML] Make MLEvents JSON ser/de safe
dfb8809 is described below

commit dfb880951a8de55c587c1bf8b696df50eae6e68a
Author: Hyukjin Kwon <gurwls...@apache.org>
AuthorDate: Sun Feb 3 21:19:35 2019 +0800

    [SPARK-26818][ML] Make MLEvents JSON ser/de safe
    
    ## What changes were proposed in this pull request?
    
    Currently, this does not appear to cause any practical problem (unless I have misread the code).
    
    I see one place where JSON-formatted events are used:
    
    https://github.com/apache/spark/blob/ec506bd30c2ca324c12c9ec811764081c2eb8c42/core/src/main/scala/org/apache/spark/scheduler/EventLoggingListener.scala#L148
    
    That call is okay for now because, when the exception is ignorable, it is just logged:
    
    https://github.com/apache/spark/blob/9690eba16efe6d25261934d8b73a221972b684f3/core/src/main/scala/org/apache/spark/util/ListenerBus.scala#L111
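    
    For context, here is a minimal sketch (not part of this patch) of the round trip
    the event log effectively performs, using the same `JsonProtocol` helpers as the
    new tests. An event field holding a `Transformer` or `Dataset` would be handed to
    Jackson at this point and could throw:
    
    ```scala
    import org.apache.spark.scheduler.SparkListenerEvent
    import org.apache.spark.util.JsonProtocol
    import org.json4s.jackson.JsonMethods.{compact, render}
    
    def roundTrip(event: SparkListenerEvent): SparkListenerEvent = {
      // Serialization may fail if the event carries non-serializable fields.
      val json = JsonProtocol.sparkEventToJson(event)
      // This compact string is what would end up in the event log file.
      println(compact(render(json)))
      JsonProtocol.sparkEventFromJson(json)
    }
    ```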
    
    Still, I guess it is best to stay safe - I don't want this unstable, experimental feature to break anything in any case. This change also disables `logEvent` from `SparkListenerEvent` for the same reason.
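    
    ML events still reach registered listeners through the listener bus; a minimal
    sketch of a listener consuming them (the same approach the new tests take) could
    look like this:
    
    ```scala
    import org.apache.spark.ml.MLEvent
    import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent}
    
    class MLEventLogger extends SparkListener {
      // ML events arrive via onOtherEvent, like other custom SparkListenerEvents.
      override def onOtherEvent(event: SparkListenerEvent): Unit = event match {
        case e: MLEvent => println(s"ML event: $e")
        case _ => // not an ML event; ignore
      }
    }
    // Register with: spark.sparkContext.addSparkListener(new MLEventLogger)
    ```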
    
    This also matches the SQL execution events side:
    
    https://github.com/apache/spark/blob/ca545f79410a464ef24e3986fac225f53bb2ef02/sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SQLListener.scala#L41-L57
    
    making ML events JSON ser/de safe in the same way.
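    
    A self-contained sketch of the pattern applied here - `Payload` is a made-up
    stand-in for non-serializable references like `Transformer` or `Dataset` -
    showing that Jackson skips `JsonIgnore`d vars:
    
    ```scala
    import com.fasterxml.jackson.annotation.JsonIgnore
    import com.fasterxml.jackson.databind.ObjectMapper
    import com.fasterxml.jackson.module.scala.DefaultScalaModule
    
    class Payload  // hypothetical non-serializable object
    
    case class SafeEvent(path: String) {
      // Carried for listeners at runtime, but invisible to Jackson.
      @JsonIgnore var payload: Payload = _
    }
    
    object JsonIgnoreDemo extends App {
      val mapper = new ObjectMapper().registerModule(DefaultScalaModule)
      val event = SafeEvent("/tmp/model")
      event.payload = new Payload
      // Prints {"path":"/tmp/model"}; payload is not serialized.
      println(mapper.writeValueAsString(event))
    }
    ```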
    
    ## How was this patch tested?
    
    Manually tested, and unit tests were added.
    
    Closes #23728 from HyukjinKwon/SPARK-26818.
    
    Authored-by: Hyukjin Kwon <gurwls...@apache.org>
    Signed-off-by: Hyukjin Kwon <gurwls...@apache.org>
---
 .../main/scala/org/apache/spark/ml/events.scala    |  81 +++++++++++----
 .../scala/org/apache/spark/ml/MLEventsSuite.scala  | 112 +++++++++++++++++----
 2 files changed, 155 insertions(+), 38 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/events.scala b/mllib/src/main/scala/org/apache/spark/ml/events.scala
index c51600f..dc4be4d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/events.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/events.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.ml
 
+import com.fasterxml.jackson.annotation.JsonIgnore
+
 import org.apache.spark.SparkContext
 import org.apache.spark.annotation.Unstable
 import org.apache.spark.internal.Logging
@@ -29,53 +31,84 @@ import org.apache.spark.sql.{DataFrame, Dataset}
  * after each operation (the event should document this).
  *
  * @note This is supported via [[Pipeline]] and [[PipelineModel]].
+ * @note This is experimental and unstable. Do not use this unless you fully
+ *   understand what `Unstable` means.
  */
 @Unstable
-sealed trait MLEvent extends SparkListenerEvent
+sealed trait MLEvent extends SparkListenerEvent {
+  // Do not log ML events in event log. It should be revisited to see
+  // how it works with history server.
+  protected[spark] override def logEvent: Boolean = false
+}
 
 /**
  * Event fired before `Transformer.transform`.
  */
 @Unstable
-case class TransformStart(transformer: Transformer, input: Dataset[_]) extends MLEvent
+case class TransformStart() extends MLEvent {
+  @JsonIgnore var transformer: Transformer = _
+  @JsonIgnore var input: Dataset[_] = _
+}
+
 /**
  * Event fired after `Transformer.transform`.
  */
 @Unstable
-case class TransformEnd(transformer: Transformer, output: Dataset[_]) extends MLEvent
+case class TransformEnd() extends MLEvent {
+  @JsonIgnore var transformer: Transformer = _
+  @JsonIgnore var output: Dataset[_] = _
+}
 
 /**
  * Event fired before `Estimator.fit`.
  */
 @Unstable
-case class FitStart[M <: Model[M]](estimator: Estimator[M], dataset: Dataset[_]) extends MLEvent
+case class FitStart[M <: Model[M]]() extends MLEvent {
+  @JsonIgnore var estimator: Estimator[M] = _
+  @JsonIgnore var dataset: Dataset[_] = _
+}
+
 /**
  * Event fired after `Estimator.fit`.
  */
 @Unstable
-case class FitEnd[M <: Model[M]](estimator: Estimator[M], model: M) extends MLEvent
+case class FitEnd[M <: Model[M]]() extends MLEvent {
+  @JsonIgnore var estimator: Estimator[M] = _
+  @JsonIgnore var model: M = _
+}
 
 /**
  * Event fired before `MLReader.load`.
  */
 @Unstable
-case class LoadInstanceStart[T](reader: MLReader[T], path: String) extends MLEvent
+case class LoadInstanceStart[T](path: String) extends MLEvent {
+  @JsonIgnore var reader: MLReader[T] = _
+}
+
 /**
  * Event fired after `MLReader.load`.
  */
 @Unstable
-case class LoadInstanceEnd[T](reader: MLReader[T], instance: T) extends MLEvent
+case class LoadInstanceEnd[T]() extends MLEvent {
+  @JsonIgnore var reader: MLReader[T] = _
+  @JsonIgnore var instance: T = _
+}
 
 /**
  * Event fired before `MLWriter.save`.
  */
 @Unstable
-case class SaveInstanceStart(writer: MLWriter, path: String) extends MLEvent
+case class SaveInstanceStart(path: String) extends MLEvent {
+  @JsonIgnore var writer: MLWriter = _
+}
+
 /**
  * Event fired after `MLWriter.save`.
  */
 @Unstable
-case class SaveInstanceEnd(writer: MLWriter, path: String) extends MLEvent
+case class SaveInstanceEnd(path: String) extends MLEvent {
+  @JsonIgnore var writer: MLWriter = _
+}
 
 /**
  * A small trait that defines some methods to send [[org.apache.spark.ml.MLEvent]].
@@ -91,11 +124,15 @@ private[ml] trait MLEvents extends Logging {
 
   def withFitEvent[M <: Model[M]](
       estimator: Estimator[M], dataset: Dataset[_])(func: => M): M = {
-    val startEvent = FitStart(estimator, dataset)
+    val startEvent = FitStart[M]()
+    startEvent.estimator = estimator
+    startEvent.dataset = dataset
     logEvent(startEvent)
     listenerBus.post(startEvent)
     val model: M = func
-    val endEvent = FitEnd(estimator, model)
+    val endEvent = FitEnd[M]()
+    endEvent.estimator = estimator
+    endEvent.model = model
     logEvent(endEvent)
     listenerBus.post(endEvent)
     model
@@ -103,34 +140,42 @@ private[ml] trait MLEvents extends Logging {
 
   def withTransformEvent(
       transformer: Transformer, input: Dataset[_])(func: => DataFrame): DataFrame = {
-    val startEvent = TransformStart(transformer, input)
+    val startEvent = TransformStart()
+    startEvent.transformer = transformer
+    startEvent.input = input
     logEvent(startEvent)
     listenerBus.post(startEvent)
     val output: DataFrame = func
-    val endEvent = TransformEnd(transformer, output)
+    val endEvent = TransformEnd()
+    endEvent.transformer = transformer
+    endEvent.output = output
     logEvent(endEvent)
     listenerBus.post(endEvent)
     output
   }
 
   def withLoadInstanceEvent[T](reader: MLReader[T], path: String)(func: => T): T = {
-    val startEvent = LoadInstanceStart(reader, path)
+    val startEvent = LoadInstanceStart[T](path)
+    startEvent.reader = reader
     logEvent(startEvent)
     listenerBus.post(startEvent)
     val instance: T = func
-    val endEvent = LoadInstanceEnd(reader, instance)
+    val endEvent = LoadInstanceEnd[T]()
+    endEvent.reader = reader
+    endEvent.instance = instance
     logEvent(endEvent)
     listenerBus.post(endEvent)
     instance
   }
 
   def withSaveInstanceEvent(writer: MLWriter, path: String)(func: => Unit): Unit = {
-    listenerBus.post(SaveInstanceEnd(writer, path))
-    val startEvent = SaveInstanceStart(writer, path)
+    val startEvent = SaveInstanceStart(path)
+    startEvent.writer = writer
     logEvent(startEvent)
     listenerBus.post(startEvent)
     func
-    val endEvent = SaveInstanceEnd(writer, path)
+    val endEvent = SaveInstanceEnd(path)
+    endEvent.writer = writer
     logEvent(endEvent)
     listenerBus.post(endEvent)
   }
diff --git a/mllib/src/test/scala/org/apache/spark/ml/MLEventsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/MLEventsSuite.scala
index 0a87328..80ae0c7 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/MLEventsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/MLEventsSuite.scala
@@ -34,6 +34,7 @@ import org.apache.spark.ml.util.{DefaultParamsReader, DefaultParamsWriter, MLWri
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.scheduler.{SparkListener, SparkListenerEvent}
 import org.apache.spark.sql._
+import org.apache.spark.util.JsonProtocol
 
 
 class MLEventsSuite
@@ -107,20 +108,48 @@ class MLEventsSuite
       .setStages(Array(estimator1, transformer1, estimator2))
     assert(events.isEmpty)
     val pipelineModel = pipeline.fit(dataset1)
-    val expected =
-      FitStart(pipeline, dataset1) ::
-      FitStart(estimator1, dataset1) ::
-      FitEnd(estimator1, model1) ::
-      TransformStart(model1, dataset1) ::
-      TransformEnd(model1, dataset2) ::
-      TransformStart(transformer1, dataset2) ::
-      TransformEnd(transformer1, dataset3) ::
-      FitStart(estimator2, dataset3) ::
-      FitEnd(estimator2, model2) ::
-      FitEnd(pipeline, pipelineModel) :: Nil
+
+    val event0 = FitStart[PipelineModel]()
+    event0.estimator = pipeline
+    event0.dataset = dataset1
+    val event1 = FitStart[MyModel]()
+    event1.estimator = estimator1
+    event1.dataset = dataset1
+    val event2 = FitEnd[MyModel]()
+    event2.estimator = estimator1
+    event2.model = model1
+    val event3 = TransformStart()
+    event3.transformer = model1
+    event3.input = dataset1
+    val event4 = TransformEnd()
+    event4.transformer = model1
+    event4.output = dataset2
+    val event5 = TransformStart()
+    event5.transformer = transformer1
+    event5.input = dataset2
+    val event6 = TransformEnd()
+    event6.transformer = transformer1
+    event6.output = dataset3
+    val event7 = FitStart[MyModel]()
+    event7.estimator = estimator2
+    event7.dataset = dataset3
+    val event8 = FitEnd[MyModel]()
+    event8.estimator = estimator2
+    event8.model = model2
+    val event9 = FitEnd[PipelineModel]()
+    event9.estimator = pipeline
+    event9.model = pipelineModel
+
+    val expected = Seq(
+      event0, event1, event2, event3, event4, event5, event6, event7, event8, event9)
     eventually(timeout(10 seconds), interval(1 second)) {
       assert(events === expected)
     }
+    // Test if they can be ser/de via JSON protocol.
+    assert(events.nonEmpty)
+    events.map(JsonProtocol.sparkEventToJson).foreach { event =>
+      assert(JsonProtocol.sparkEventFromJson(event).isInstanceOf[MLEvent])
+    }
   }
 
   test("pipeline model transform events") {
@@ -144,18 +173,41 @@ class MLEventsSuite
       "pipeline0", Array(transformer1, model, transformer2))
     assert(events.isEmpty)
     val output = newPipelineModel.transform(dataset1)
-    val expected =
-      TransformStart(newPipelineModel, dataset1) ::
-      TransformStart(transformer1, dataset1) ::
-      TransformEnd(transformer1, dataset2) ::
-      TransformStart(model, dataset2) ::
-      TransformEnd(model, dataset3) ::
-      TransformStart(transformer2, dataset3) ::
-      TransformEnd(transformer2, dataset4) ::
-      TransformEnd(newPipelineModel, output) :: Nil
+
+    val event0 = TransformStart()
+    event0.transformer = newPipelineModel
+    event0.input = dataset1
+    val event1 = TransformStart()
+    event1.transformer = transformer1
+    event1.input = dataset1
+    val event2 = TransformEnd()
+    event2.transformer = transformer1
+    event2.output = dataset2
+    val event3 = TransformStart()
+    event3.transformer = model
+    event3.input = dataset2
+    val event4 = TransformEnd()
+    event4.transformer = model
+    event4.output = dataset3
+    val event5 = TransformStart()
+    event5.transformer = transformer2
+    event5.input = dataset3
+    val event6 = TransformEnd()
+    event6.transformer = transformer2
+    event6.output = dataset4
+    val event7 = TransformEnd()
+    event7.transformer = newPipelineModel
+    event7.output = output
+
+    val expected = Seq(event0, event1, event2, event3, event4, event5, event6, event7)
     eventually(timeout(10 seconds), interval(1 second)) {
       assert(events === expected)
     }
+    // Test if they can be ser/de via JSON protocol.
+    assert(events.nonEmpty)
+    events.map(JsonProtocol.sparkEventToJson).foreach { event =>
+      assert(JsonProtocol.sparkEventFromJson(event).isInstanceOf[MLEvent])
+    }
   }
 
   test("pipeline read/write events") {
@@ -182,6 +234,11 @@ class MLEventsSuite
           case e => fail(s"Unexpected event thrown: $e")
         }
       }
+      // Test if they can be ser/de via JSON protocol.
+      assert(events.nonEmpty)
+      events.map(JsonProtocol.sparkEventToJson).foreach { event =>
+        assert(JsonProtocol.sparkEventFromJson(event).isInstanceOf[MLEvent])
+      }
 
       events.clear()
       val pipelineReader = Pipeline.read
@@ -202,6 +259,11 @@ class MLEventsSuite
           case e => fail(s"Unexpected event thrown: $e")
         }
       }
+      // Test if they can be ser/de via JSON protocol.
+      assert(events.nonEmpty)
+      events.map(JsonProtocol.sparkEventToJson).foreach { event =>
+        assert(JsonProtocol.sparkEventFromJson(event).isInstanceOf[MLEvent])
+      }
     }
   }
 
@@ -230,6 +292,11 @@ class MLEventsSuite
           case e => fail(s"Unexpected event thrown: $e")
         }
       }
+      // Test if they can be ser/de via JSON protocol.
+      assert(events.nonEmpty)
+      events.map(JsonProtocol.sparkEventToJson).foreach { event =>
+        assert(JsonProtocol.sparkEventFromJson(event).isInstanceOf[MLEvent])
+      }
 
       events.clear()
       val pipelineModelReader = PipelineModel.read
@@ -250,6 +317,11 @@ class MLEventsSuite
           case e => fail(s"Unexpected event thrown: $e")
         }
       }
+      // Test if they can be ser/de via JSON protocol.
+      assert(events.nonEmpty)
+      events.map(JsonProtocol.sparkEventToJson).foreach { event =>
+        assert(JsonProtocol.sparkEventFromJson(event).isInstanceOf[MLEvent])
+      }
     }
   }
 }

