[GitHub] spark pull request #19876: [WIP][ML][SPARK-11171][SPARK-11239] Add PMML expo...

2017-12-20 Thread holdenk
Github user holdenk commented on a diff in the pull request:

https://github.com/apache/spark/pull/19876#discussion_r158126871
  
--- Diff: mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala ---
@@ -126,15 +180,69 @@ abstract class MLWriter extends BaseReadWrite with Logging {
     this
   }
 
+  // override for Java compatibility
+  override def session(sparkSession: SparkSession): this.type = super.session(sparkSession)
+
+  // override for Java compatibility
+  override def context(sqlContext: SQLContext): this.type = super.session(sqlContext.sparkSession)
+}
+
+/**
+ * A ML Writer which delegates based on the requested format.
+ */
+class GeneralMLWriter(stage: PipelineStage) extends MLWriter with Logging {
+  private var source: String = "internal"
+
   /**
-   * Overwrites if the output path already exists.
+   * Specifies the format of ML export (e.g. PMML, internal, or
+   * the fully qualified class name for export).
    */
-  @Since("1.6.0")
-  def overwrite(): this.type = {
-    shouldOverwrite = true
+  @Since("2.3.0")
+  def format(source: String): this.type = {
+    this.source = source
     this
   }
 
+  /**
+   * Dispatches the save to the correct MLFormat.
+   */
+  @Since("2.3.0")
+  @throws[IOException]("If the input path already exists but overwrite is not enabled.")
+  @throws[SparkException]("If multiple sources for a given short name format are found.")
+  override protected def saveImpl(path: String) = {
+    val loader = Utils.getContextOrSparkClassLoader
+    val serviceLoader = ServiceLoader.load(classOf[MLFormatRegister], loader)
+    val stageName = stage.getClass.getName
+    val targetName = s"${source}+${stageName}"
+    val formats = serviceLoader.asScala.toList
+    val shortNames = formats.map(_.shortName())
+    val writerCls = formats.filter(_.shortName().equalsIgnoreCase(targetName)) match {
+      // requested name did not match any given registered alias
+      case Nil =>
+        Try(loader.loadClass(source)) match {
+          case Success(writer) =>
+            // Found the ML writer using the fully qualified path
+            writer
+          case Failure(error) =>
+            throw new SparkException(
+              s"Could not load requested format $source for $stageName ($targetName) had $formats" +
+              s"supporting $shortNames", error)
+        }
+      case head :: Nil =>
+        head.getClass
+      case _ =>
+        // Multiple sources
+        throw new SparkException(
+          s"Multiple writers found for $source+$stageName, try using the class name of the writer")
+    }
+    if (classOf[MLWriterFormat].isAssignableFrom(writerCls)) {
+      val writer = writerCls.newInstance().asInstanceOf[MLWriterFormat]
+      writer.write(path, sparkSession, optionMap, stage)
+    } else {
+      throw new SparkException("ML source $source is not a valid MLWriterFormat")
--- End diff --

Good catch, I've added a test for this error message.
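For readers following along, the three-way dispatch in `saveImpl` above can be distilled into a small standalone sketch. `FormatRegister` and `resolveWriter` below are illustrative stand-ins for `MLFormatRegister` and the real dispatch, not Spark's API:

```scala
import scala.util.Try

// Illustrative stand-in for MLFormatRegister (not Spark's actual trait).
trait FormatRegister { def shortName(): String }

// Mirrors the three cases in saveImpl: no registered alias falls back to
// loading `source` as a fully qualified class name; exactly one alias wins;
// more than one match is ambiguous and rejected.
def resolveWriter(source: String, stageName: String,
    formats: List[FormatRegister]): Either[String, Class[_]] = {
  val targetName = s"$source+$stageName"
  formats.filter(_.shortName().equalsIgnoreCase(targetName)) match {
    case Nil =>
      Try(Class.forName(source)).toEither.left.map(e =>
        s"Could not load requested format $source for $stageName: $e")
    case head :: Nil =>
      Right(head.getClass)
    case _ =>
      Left(s"Multiple writers found for $targetName")
  }
}
```

With one matching register the alias wins; with none, a fully qualified class name (e.g. `java.util.ArrayList`) still resolves via the classloader fallback.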


---

-
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org



[GitHub] spark pull request #19876: [WIP][ML][SPARK-11171][SPARK-11239] Add PMML expo...

2017-12-20 Thread holdenk
Github user holdenk commented on a diff in the pull request:

https://github.com/apache/spark/pull/19876#discussion_r158114844
  
--- Diff: mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala ---
@@ -126,15 +180,69 @@ abstract class MLWriter extends BaseReadWrite with Logging {
     this
   }
 
+  // override for Java compatibility
+  override def session(sparkSession: SparkSession): this.type = super.session(sparkSession)
+
+  // override for Java compatibility
+  override def context(sqlContext: SQLContext): this.type = super.session(sqlContext.sparkSession)
+}
+
+/**
+ * A ML Writer which delegates based on the requested format.
+ */
+class GeneralMLWriter(stage: PipelineStage) extends MLWriter with Logging {
--- End diff --

So I don't think that belongs in the base GeneralMLWriter, but we could 
make a mix-in trait for writers that support PMML?


---




[GitHub] spark pull request #19876: [WIP][ML][SPARK-11171][SPARK-11239] Add PMML expo...

2017-12-20 Thread holdenk
Github user holdenk commented on a diff in the pull request:

https://github.com/apache/spark/pull/19876#discussion_r158113236
  
--- Diff: mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala ---
@@ -994,6 +998,38 @@ class LinearRegressionSuite
       LinearRegressionSuite.allParamSettings, checkModelData)
   }
 
+  test("pmml export") {
+    val lr = new LinearRegression()
+    val model = lr.fit(datasetWithWeight)
+    def checkModel(pmml: PMML): Unit = {
+      val dd = pmml.getDataDictionary
+      assert(dd.getNumberOfFields === 3)
+      val fields = dd.getDataFields.asScala
+      assert(fields(0).getName().toString === "field_0")
+      assert(fields(0).getOpType() == OpType.CONTINUOUS)
+      val pmmlRegressionModel = pmml.getModels().get(0).asInstanceOf[PMMLRegressionModel]
+      val pmmlPredictors = pmmlRegressionModel.getRegressionTables.get(0).getNumericPredictors
+      val pmmlWeights = pmmlPredictors.asScala.map(_.getCoefficient()).toList
+      assert(pmmlWeights(0) ~== model.coefficients(0) relTol 1E-3)
+      assert(pmmlWeights(1) ~== model.coefficients(1) relTol 1E-3)
+    }
+    testPMMLWrite(sc, model, checkModel)
+  }
+
+  test("unsupported export format") {
+
+  test("unsupported export format") {
--- End diff --

Sure, I'll put a dummy writer in test so it doesn't clog up our class space.


---




[GitHub] spark pull request #19876: [WIP][ML][SPARK-11171][SPARK-11239] Add PMML expo...

2017-12-12 Thread sethah
Github user sethah commented on a diff in the pull request:

https://github.com/apache/spark/pull/19876#discussion_r156389238
  
--- Diff: mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala ---
@@ -126,15 +180,69 @@ abstract class MLWriter extends BaseReadWrite with 
Logging {
 this
   }
 
+  // override for Java compatibility
+  override def session(sparkSession: SparkSession): this.type = 
super.session(sparkSession)
+
+  // override for Java compatibility
+  override def context(sqlContext: SQLContext): this.type = 
super.session(sqlContext.sparkSession)
+}
+
+/**
+ * A ML Writer which delegates based on the requested format.
+ */
+class GeneralMLWriter(stage: PipelineStage) extends MLWriter with Logging {
--- End diff --

Perhaps for another PR, but maybe we could add a method here:

```scala
  def pmml(path: String): Unit = {
    this.source = "pmml"
    save(path)
  }
```
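A self-contained sketch of how such a convenience method could delegate to the existing `format`/`save` pair. The stub class below is an illustration, not Spark's `GeneralMLWriter`:

```scala
// Stub writer illustrating the proposed sugar; format/save mimic the real
// builder pattern, and pmml is just format("pmml") followed by save.
class WriterSketch {
  private var source: String = "internal"
  private val saves = scala.collection.mutable.Buffer.empty[(String, String)]

  def format(source: String): this.type = { this.source = source; this }
  def save(path: String): Unit = saves += ((source, path))
  // The proposed convenience method: equivalent to format("pmml").save(path).
  def pmml(path: String): Unit = format("pmml").save(path)

  def history: List[(String, String)] = saves.toList
}
```

Keeping `pmml` as a thin wrapper means the dispatch logic stays in one place and the sugar cannot drift out of sync with `format`.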


---




[GitHub] spark pull request #19876: [WIP][ML][SPARK-11171][SPARK-11239] Add PMML expo...

2017-12-12 Thread sethah
Github user sethah commented on a diff in the pull request:

https://github.com/apache/spark/pull/19876#discussion_r155157785
  
--- Diff: mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala ---
@@ -126,15 +180,69 @@ abstract class MLWriter extends BaseReadWrite with Logging {
     this
   }
 
+  // override for Java compatibility
+  override def session(sparkSession: SparkSession): this.type = super.session(sparkSession)
+
+  // override for Java compatibility
+  override def context(sqlContext: SQLContext): this.type = super.session(sqlContext.sparkSession)
+}
+
+/**
+ * A ML Writer which delegates based on the requested format.
+ */
+class GeneralMLWriter(stage: PipelineStage) extends MLWriter with Logging {
+  private var source: String = "internal"
+
   /**
-   * Overwrites if the output path already exists.
+   * Specifies the format of ML export (e.g. PMML, internal, or
+   * the fully qualified class name for export).
    */
-  @Since("1.6.0")
-  def overwrite(): this.type = {
-    shouldOverwrite = true
+  @Since("2.3.0")
+  def format(source: String): this.type = {
+    this.source = source
     this
   }
 
+  /**
+   * Dispatches the save to the correct MLFormat.
+   */
+  @Since("2.3.0")
+  @throws[IOException]("If the input path already exists but overwrite is not enabled.")
+  @throws[SparkException]("If multiple sources for a given short name format are found.")
+  override protected def saveImpl(path: String) = {
+    val loader = Utils.getContextOrSparkClassLoader
+    val serviceLoader = ServiceLoader.load(classOf[MLFormatRegister], loader)
+    val stageName = stage.getClass.getName
+    val targetName = s"${source}+${stageName}"
+    val formats = serviceLoader.asScala.toList
+    val shortNames = formats.map(_.shortName())
+    val writerCls = formats.filter(_.shortName().equalsIgnoreCase(targetName)) match {
+      // requested name did not match any given registered alias
+      case Nil =>
+        Try(loader.loadClass(source)) match {
+          case Success(writer) =>
+            // Found the ML writer using the fully qualified path
+            writer
+          case Failure(error) =>
+            throw new SparkException(
+              s"Could not load requested format $source for $stageName ($targetName) had $formats" +
+              s"supporting $shortNames", error)
+        }
+      case head :: Nil =>
+        head.getClass
+      case _ =>
+        // Multiple sources
+        throw new SparkException(
+          s"Multiple writers found for $source+$stageName, try using the class name of the writer")
+    }
+    if (classOf[MLWriterFormat].isAssignableFrom(writerCls)) {
+      val writer = writerCls.newInstance().asInstanceOf[MLWriterFormat]
+      writer.write(path, sparkSession, optionMap, stage)
+    } else {
+      throw new SparkException("ML source $source is not a valid MLWriterFormat")
--- End diff --

nit: need string interpolation here
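For reference, the difference the `s` prefix makes in plain Scala (no Spark involved):

```scala
val source = "pmml"
// Without the s prefix the placeholder is kept literally; with it, the
// variable is substituted into the string.
val literal = "ML source $source is not a valid MLWriterFormat"
val interpolated = s"ML source $source is not a valid MLWriterFormat"
```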


---




[GitHub] spark pull request #19876: [WIP][ML][SPARK-11171][SPARK-11239] Add PMML expo...

2017-12-12 Thread sethah
Github user sethah commented on a diff in the pull request:

https://github.com/apache/spark/pull/19876#discussion_r156370588
  
--- Diff: mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala ---
@@ -85,12 +87,55 @@ private[util] sealed trait BaseReadWrite {
   protected final def sc: SparkContext = sparkSession.sparkContext
 }
 
+/**
+ * ML export formats should implement this trait so that users can specify a shortname rather
+ * than the fully qualified class name of the exporter.
+ *
+ * A new instance of this class will be instantiated each time a DDL call is made.
+ *
+ * @since 2.3.0
+ */
+@InterfaceStability.Evolving
+trait MLFormatRegister {
+  /**
+   * The string that represents the format that this data source provider uses. This is
+   * overridden by children to provide a nice alias for the data source. For example:
+   *
+   * {{{
+   *   override def shortName(): String =
+   *     "pmml+org.apache.spark.ml.regression.LinearRegressionModel"
+   * }}}
+   * Indicates that this format is capable of saving Spark's own LinearRegressionModel in PMML.
+   *
+   * Format discovery is done using a ServiceLoader, so make sure to list your format in
+   * META-INF/services.
+   * @since 2.3.0
+   */
+  def shortName(): String
+}
+
+/**
+ * Implemented by objects that provide ML exportability.
+ *
+ * A new instance of this class will be instantiated each time a DDL call is made.
+ *
+ * @since 2.3.0
+ */
+@InterfaceStability.Evolving
+trait MLWriterFormat {
+  /**
+   * Function to write the provided pipeline stage out.
+   */
+  def write(path: String, session: SparkSession, optionMap: mutable.Map[String, String],
+    stage: PipelineStage)
--- End diff --

return type?
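A minimal sketch of the trait with an explicit `Unit` return type, which is what the question is driving at. `Stage` and the writer below are stubs standing in for Spark's `PipelineStage` and a concrete format, not the real types:

```scala
// Stub stand-in for PipelineStage.
final case class Stage(name: String)

// Declaring the result type explicitly avoids deprecated procedure syntax
// and documents that write is called purely for its side effect.
trait WriterFormat {
  def write(path: String, options: Map[String, String], stage: Stage): Unit
}

class CollectingWriter extends WriterFormat {
  val calls = scala.collection.mutable.Buffer.empty[String]
  override def write(path: String, options: Map[String, String], stage: Stage): Unit = {
    calls += s"${stage.name} -> $path"
  }
}
```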


---




[GitHub] spark pull request #19876: [WIP][ML][SPARK-11171][SPARK-11239] Add PMML expo...

2017-12-12 Thread sethah
Github user sethah commented on a diff in the pull request:

https://github.com/apache/spark/pull/19876#discussion_r156388871
  
--- Diff: mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala ---
@@ -994,6 +998,38 @@ class LinearRegressionSuite
       LinearRegressionSuite.allParamSettings, checkModelData)
   }
 
+  test("pmml export") {
+    val lr = new LinearRegression()
+    val model = lr.fit(datasetWithWeight)
+    def checkModel(pmml: PMML): Unit = {
+      val dd = pmml.getDataDictionary
+      assert(dd.getNumberOfFields === 3)
+      val fields = dd.getDataFields.asScala
+      assert(fields(0).getName().toString === "field_0")
+      assert(fields(0).getOpType() == OpType.CONTINUOUS)
+      val pmmlRegressionModel = pmml.getModels().get(0).asInstanceOf[PMMLRegressionModel]
+      val pmmlPredictors = pmmlRegressionModel.getRegressionTables.get(0).getNumericPredictors
+      val pmmlWeights = pmmlPredictors.asScala.map(_.getCoefficient()).toList
+      assert(pmmlWeights(0) ~== model.coefficients(0) relTol 1E-3)
+      assert(pmmlWeights(1) ~== model.coefficients(1) relTol 1E-3)
+    }
+    testPMMLWrite(sc, model, checkModel)
+  }
+
+  test("unsupported export format") {
--- End diff --

Would be great to have a test that verifies that this works with third-party 
implementations. Specifically, that something like 
`model.write.format("org.apache.spark.ml.MyDummyWriter").save(path)` works.
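One note for such a test: loading by fully qualified class name goes straight through the classloader, so no registration is needed for that path. A short-name alias, by contrast, is discovered via `ServiceLoader` and needs a provider-configuration file on the test classpath. A hypothetical registration (the dummy class name is taken from the example above; the file path is an assumption) would look like:

```
# src/test/resources/META-INF/services/org.apache.spark.ml.util.MLFormatRegister
org.apache.spark.ml.MyDummyWriter
```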


---




[GitHub] spark pull request #19876: [WIP][ML][SPARK-11171][SPARK-11239] Add PMML expo...

2017-12-12 Thread sethah
Github user sethah commented on a diff in the pull request:

https://github.com/apache/spark/pull/19876#discussion_r156381361
  
--- Diff: mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala ---
@@ -554,7 +555,49 @@ class LinearRegressionModel private[ml] (
    * This also does not save the [[parent]] currently.
    */
   @Since("1.6.0")
-  override def write: MLWriter = new LinearRegressionModel.LinearRegressionModelWriter(this)
+  override def write: GeneralMLWriter = new GeneralMLWriter(this)
--- End diff --

The doc above this is wrong.


---
