[GitHub] spark pull request #19876: [ML][SPARK-11171][SPARK-11239] Add PMML export to...
Github user holdenk commented on a diff in the pull request: https://github.com/apache/spark/pull/19876#discussion_r176826923 --- Diff: mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala --- @@ -86,7 +88,80 @@ private[util] sealed trait BaseReadWrite { } /** - * Abstract class for utility classes that can save ML instances. + * Implemented by objects that provide ML exportability. + * + * A new instance of this class will be instantiated each time a save call is made. + * + * Must have a valid zero argument constructor which will be called to instantiate. + * + * @since 2.3.0 + */ +@InterfaceStability.Unstable +@Since("2.3.0") +trait MLWriterFormat { + /** + * Function to write the provided pipeline stage out. + * + * @param path The path to write the result out to. + * @param session SparkSession associated with the write request. + * @param optionMap User provided options stored as strings. + * @param stage The pipeline stage to be saved. + */ + @Since("2.3.0") + def write(path: String, session: SparkSession, optionMap: mutable.Map[String, String], +stage: PipelineStage): Unit +} + +/** + * ML export formats for should implement this trait so that users can specify a shortname rather + * than the fully qualified class name of the exporter. + * + * A new instance of this class will be instantiated each time a save call is made. --- End diff -- done --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
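The `MLWriterFormat` contract quoted in the diff above can be sketched in plain Java for illustration. `WriterFormat`, `LoggingWriterFormat`, and the simplified `write` signature here are hypothetical stand-ins (Spark's real trait takes a `SparkSession` and a `PipelineStage`); the part that mirrors the doc comment is the zero-argument constructor, which the dispatcher needs so it can instantiate the class reflectively on every save call:

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical stand-in for MLWriterFormat: one write method taking the
// output path, user-provided options, and (here, just the name of) the stage.
interface WriterFormat {
    void write(String path, Map<String, String> optionMap, String stageName);
}

// Implementations must keep a public zero-argument constructor, because the
// save dispatch instantiates the class reflectively each time.
class LoggingWriterFormat implements WriterFormat {
    public LoggingWriterFormat() {}

    @Override
    public void write(String path, Map<String, String> optionMap, String stageName) {
        System.out.println("writing " + stageName + " to " + path
            + " with options " + optionMap);
    }
}

public class WriterFormatSketch {
    public static void main(String[] args) throws Exception {
        // Instantiate reflectively, the way the save dispatch does.
        WriterFormat w = (WriterFormat) LoggingWriterFormat.class
            .getDeclaredConstructor().newInstance();
        w.write("/tmp/model", new HashMap<>(), "LinearRegressionModel");
    }
}
```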
Github user holdenk commented on a diff in the pull request: https://github.com/apache/spark/pull/19876#discussion_r176821909 --- Diff: mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala --- @@ -86,7 +88,80 @@ private[util] sealed trait BaseReadWrite { } /** - * Abstract class for utility classes that can save ML instances. + * Implemented by objects that provide ML exportability. + * + * A new instance of this class will be instantiated each time a save call is made. + * + * Must have a valid zero argument constructor which will be called to instantiate. + * + * @since 2.3.0 + */ +@InterfaceStability.Unstable +@Since("2.3.0") +trait MLWriterFormat { + /** + * Function to write the provided pipeline stage out. + * + * @param path The path to write the result out to. + * @param session SparkSession associated with the write request. + * @param optionMap User provided options stored as strings. + * @param stage The pipeline stage to be saved. + */ + @Since("2.3.0") + def write(path: String, session: SparkSession, optionMap: mutable.Map[String, String], +stage: PipelineStage): Unit +} + +/** + * ML export formats for should implement this trait so that users can specify a shortname rather + * than the fully qualified class name of the exporter. + * + * A new instance of this class will be instantiated each time a save call is made. --- End diff -- Add a comment about zero arg constructor requirement --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
Github user holdenk commented on a diff in the pull request: https://github.com/apache/spark/pull/19876#discussion_r176821060 --- Diff: mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala --- @@ -86,7 +88,80 @@ private[util] sealed trait BaseReadWrite { } /** - * Abstract class for utility classes that can save ML instances. + * Implemented by objects that provide ML exportability. + * + * A new instance of this class will be instantiated each time a save call is made. + * + * Must have a valid zero argument constructor which will be called to instantiate. + * + * @since 2.3.0 --- End diff -- Need to update the `@since` annotations to 2.4.0.
Github user holdenk commented on a diff in the pull request: https://github.com/apache/spark/pull/19876#discussion_r161944495 --- Diff: mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala --- @@ -126,15 +180,69 @@ abstract class MLWriter extends BaseReadWrite with Logging { this } + // override for Java compatibility + override def session(sparkSession: SparkSession): this.type = super.session(sparkSession) + + // override for Java compatibility + override def context(sqlContext: SQLContext): this.type = super.session(sqlContext.sparkSession) +} + +/** + * A ML Writer which delegates based on the requested format. + */ +class GeneralMLWriter(stage: PipelineStage) extends MLWriter with Logging { --- End diff -- The follow up issue to track this is https://issues.apache.org/jira/browse/SPARK-11241 --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
Github user holdenk commented on a diff in the pull request: https://github.com/apache/spark/pull/19876#discussion_r161391241 --- Diff: mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala --- @@ -1044,6 +1056,50 @@ class LinearRegressionSuite extends MLTest with DefaultReadWriteTest { LinearRegressionSuite.allParamSettings, checkModelData) } + test("pmml export") { +val lr = new LinearRegression() +val model = lr.fit(datasetWithWeight) +def checkModel(pmml: PMML): Unit = { + val dd = pmml.getDataDictionary + assert(dd.getNumberOfFields === 3) + val fields = dd.getDataFields.asScala + assert(fields(0).getName().toString === "field_0") + assert(fields(0).getOpType() == OpType.CONTINUOUS) + val pmmlRegressionModel = pmml.getModels().get(0).asInstanceOf[PMMLRegressionModel] + val pmmlPredictors = pmmlRegressionModel.getRegressionTables.get(0).getNumericPredictors + val pmmlWeights = pmmlPredictors.asScala.map(_.getCoefficient()).toList + assert(pmmlWeights(0) ~== model.coefficients(0) relTol 1E-3) + assert(pmmlWeights(1) ~== model.coefficients(1) relTol 1E-3) +} +testPMMLWrite(sc, model, checkModel) + } + + test("unsupported export format") { +val lr = new LinearRegression() +val model = lr.fit(datasetWithWeight) +intercept[SparkException] { + model.write.format("boop").save("boop") +} +intercept[SparkException] { + model.write.format("com.holdenkarau.boop").save("boop") +} +withClue("ML source org.apache.spark.SparkContext is not a valid MLWriterFormat") { + intercept[SparkException] { +model.write.format("org.apache.spark.SparkContext").save("boop2") + } +} + } + + test("dummy export format is called") { +val lr = new LinearRegression() +val model = lr.fit(datasetWithWeight) +withClue("Dummy writer doesn't write") { + intercept[Exception] { --- End diff -- good point --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
Github user holdenk commented on a diff in the pull request: https://github.com/apache/spark/pull/19876#discussion_r161390600 --- Diff: mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala --- @@ -126,15 +180,69 @@ abstract class MLWriter extends BaseReadWrite with Logging { this } + // override for Java compatibility + override def session(sparkSession: SparkSession): this.type = super.session(sparkSession) + + // override for Java compatibility + override def context(sqlContext: SQLContext): this.type = super.session(sqlContext.sparkSession) +} + +/** + * A ML Writer which delegates based on the requested format. + */ +class GeneralMLWriter(stage: PipelineStage) extends MLWriter with Logging { + private var source: String = "internal" + /** - * Overwrites if the output path already exists. + * Specifies the format of ML export (e.g. PMML, internal, or + * the fully qualified class name for export). */ - @Since("1.6.0") - def overwrite(): this.type = { -shouldOverwrite = true + @Since("2.3.0") + def format(source: String): this.type = { +this.source = source this } + /** + * Dispatches the save to the correct MLFormat. 
+ */ + @Since("2.3.0") + @throws[IOException]("If the input path already exists but overwrite is not enabled.") + @throws[SparkException]("If multiple sources for a given short name format are found.") + override protected def saveImpl(path: String) = { +val loader = Utils.getContextOrSparkClassLoader +val serviceLoader = ServiceLoader.load(classOf[MLFormatRegister], loader) +val stageName = stage.getClass.getName +val targetName = s"${source}+${stageName}" +val formats = serviceLoader.asScala.toList +val shortNames = formats.map(_.shortName()) +val writerCls = formats.filter(_.shortName().equalsIgnoreCase(targetName)) match { + // requested name did not match any given registered alias + case Nil => +Try(loader.loadClass(source)) match { + case Success(writer) => +// Found the ML writer using the fully qualified path +writer + case Failure(error) => +throw new SparkException( + s"Could not load requested format $source for $stageName ($targetName) had $formats" + + s"supporting $shortNames", error) +} + case head :: Nil => +head.getClass + case _ => +// Multiple sources +throw new SparkException( + s"Multiple writers found for $source+$stageName, try using the class name of the writer") +} +if (classOf[MLWriterFormat].isAssignableFrom(writerCls)) { + val writer = writerCls.newInstance().asInstanceOf[MLWriterFormat] --- End diff -- True, we have the same issue with the DataFormat provider though. I don't think there is a way around this while keeping the DF-like interface that lets us be pluggable in the way folks want (but if there is a way to require a zero-argument constructor in the concrete class with a trait, I'm interested). I think, given the folks whom we generally expect to be writing these formats, that's reasonable, but I'll add a comment about this in the doc?
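The failure mode under discussion (a format class without a zero-argument constructor blowing up only at `newInstance()` time) can be demonstrated with plain JVM reflection. `OneArgWriter` and `ZeroArgWriter` are made-up classes, not anything in Spark; `getDeclaredConstructor()` with no arguments throws `NoSuchMethodException` exactly when the zero-argument constructor is missing:

```java
// A hypothetical format with only a one-argument constructor, mirroring the
// DummyLinearRegressionWriter(someParam: Int) example from the review.
class OneArgWriter {
    OneArgWriter(int someParam) {}
}

class ZeroArgWriter {
    ZeroArgWriter() {}
}

public class ReflectionCheck {
    // True iff the class can be instantiated the way the dispatcher does it:
    // getDeclaredConstructor() with no arguments throws NoSuchMethodException
    // when the class lacks a zero-argument constructor.
    static boolean hasZeroArgConstructor(Class<?> cls) {
        try {
            cls.getDeclaredConstructor();
            return true;
        } catch (NoSuchMethodException e) {
            return false;
        }
    }

    public static void main(String[] args) {
        System.out.println(hasZeroArgConstructor(ZeroArgWriter.class)); // true
        System.out.println(hasZeroArgConstructor(OneArgWriter.class));  // false
    }
}
```

A registry could run a check like this when the format is first resolved, failing with a clearer message than the raw `NoSuchMethodException` the user would otherwise see at save time.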
Github user sethah commented on a diff in the pull request: https://github.com/apache/spark/pull/19876#discussion_r160503466 --- Diff: mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala --- @@ -126,15 +180,69 @@ abstract class MLWriter extends BaseReadWrite with Logging { this } + // override for Java compatibility + override def session(sparkSession: SparkSession): this.type = super.session(sparkSession) --- End diff -- since tags here --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
Github user sethah commented on a diff in the pull request: https://github.com/apache/spark/pull/19876#discussion_r160484001 --- Diff: mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala --- @@ -1044,6 +1056,50 @@ class LinearRegressionSuite extends MLTest with DefaultReadWriteTest { LinearRegressionSuite.allParamSettings, checkModelData) } + test("pmml export") { +val lr = new LinearRegression() +val model = lr.fit(datasetWithWeight) +def checkModel(pmml: PMML): Unit = { + val dd = pmml.getDataDictionary + assert(dd.getNumberOfFields === 3) + val fields = dd.getDataFields.asScala + assert(fields(0).getName().toString === "field_0") + assert(fields(0).getOpType() == OpType.CONTINUOUS) + val pmmlRegressionModel = pmml.getModels().get(0).asInstanceOf[PMMLRegressionModel] + val pmmlPredictors = pmmlRegressionModel.getRegressionTables.get(0).getNumericPredictors + val pmmlWeights = pmmlPredictors.asScala.map(_.getCoefficient()).toList + assert(pmmlWeights(0) ~== model.coefficients(0) relTol 1E-3) + assert(pmmlWeights(1) ~== model.coefficients(1) relTol 1E-3) +} +testPMMLWrite(sc, model, checkModel) + } + + test("unsupported export format") { +val lr = new LinearRegression() +val model = lr.fit(datasetWithWeight) +intercept[SparkException] { --- End diff -- Doesn't this and the one below it test the same thing? I think we could remove the first one. --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
Github user sethah commented on a diff in the pull request: https://github.com/apache/spark/pull/19876#discussion_r160461644 --- Diff: mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala --- @@ -126,15 +180,69 @@ abstract class MLWriter extends BaseReadWrite with Logging { this } + // override for Java compatibility + override def session(sparkSession: SparkSession): this.type = super.session(sparkSession) + + // override for Java compatibility + override def context(sqlContext: SQLContext): this.type = super.session(sqlContext.sparkSession) +} + +/** + * A ML Writer which delegates based on the requested format. + */ +class GeneralMLWriter(stage: PipelineStage) extends MLWriter with Logging { + private var source: String = "internal" + /** - * Overwrites if the output path already exists. + * Specifies the format of ML export (e.g. PMML, internal, or + * the fully qualified class name for export). */ - @Since("1.6.0") - def overwrite(): this.type = { -shouldOverwrite = true + @Since("2.3.0") + def format(source: String): this.type = { +this.source = source this } + /** + * Dispatches the save to the correct MLFormat. + */ + @Since("2.3.0") + @throws[IOException]("If the input path already exists but overwrite is not enabled.") + @throws[SparkException]("If multiple sources for a given short name format are found.") + override protected def saveImpl(path: String) = { +val loader = Utils.getContextOrSparkClassLoader +val serviceLoader = ServiceLoader.load(classOf[MLFormatRegister], loader) +val stageName = stage.getClass.getName +val targetName = s"${source}+${stageName}" --- End diff -- don't need brackets --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
Github user sethah commented on a diff in the pull request: https://github.com/apache/spark/pull/19876#discussion_r160503640 --- Diff: mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala --- @@ -85,12 +87,55 @@ private[util] sealed trait BaseReadWrite { protected final def sc: SparkContext = sparkSession.sparkContext } +/** + * ML export formats for should implement this trait so that users can specify a shortname rather + * than the fully qualified class name of the exporter. + * + * A new instance of this class will be instantiated each time a DDL call is made. + * + * @since 2.3.0 + */ +@InterfaceStability.Evolving +trait MLFormatRegister { + /** + * The string that represents the format that this data source provider uses. This is + * overridden by children to provide a nice alias for the data source. For example: + * + * {{{ + * override def shortName(): String = + * "pmml+org.apache.spark.ml.regression.LinearRegressionModel" + * }}} + * Indicates that this format is capable of saving Spark's own LinearRegressionModel in pmml. + * + * Format discovery is done using a ServiceLoader so make sure to list your format in + * META-INF/services. + * @since 2.3.0 + */ + def shortName(): String +} + +/** + * Implemented by objects that provide ML exportability. + * + * A new instance of this class will be instantiated each time a DDL call is made. + * + * @since 2.3.0 + */ +@InterfaceStability.Evolving +trait MLWriterFormat { --- End diff -- do we need the actual since annotations here, though? --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
Github user sethah commented on a diff in the pull request: https://github.com/apache/spark/pull/19876#discussion_r160496808 --- Diff: mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala --- @@ -126,15 +180,69 @@ abstract class MLWriter extends BaseReadWrite with Logging { this } + // override for Java compatibility + override def session(sparkSession: SparkSession): this.type = super.session(sparkSession) + + // override for Java compatibility + override def context(sqlContext: SQLContext): this.type = super.session(sqlContext.sparkSession) +} + +/** + * A ML Writer which delegates based on the requested format. + */ +class GeneralMLWriter(stage: PipelineStage) extends MLWriter with Logging { + private var source: String = "internal" + /** - * Overwrites if the output path already exists. + * Specifies the format of ML export (e.g. PMML, internal, or + * the fully qualified class name for export). */ - @Since("1.6.0") - def overwrite(): this.type = { -shouldOverwrite = true + @Since("2.3.0") + def format(source: String): this.type = { +this.source = source this } + /** + * Dispatches the save to the correct MLFormat. + */ + @Since("2.3.0") + @throws[IOException]("If the input path already exists but overwrite is not enabled.") + @throws[SparkException]("If multiple sources for a given short name format are found.") + override protected def saveImpl(path: String) = { --- End diff -- return type --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
Github user sethah commented on a diff in the pull request: https://github.com/apache/spark/pull/19876#discussion_r160462794 --- Diff: mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala --- @@ -85,12 +87,55 @@ private[util] sealed trait BaseReadWrite { protected final def sc: SparkContext = sparkSession.sparkContext } +/** + * ML export formats for should implement this trait so that users can specify a shortname rather + * than the fully qualified class name of the exporter. + * + * A new instance of this class will be instantiated each time a DDL call is made. + * + * @since 2.3.0 + */ +@InterfaceStability.Evolving +trait MLFormatRegister { + /** + * The string that represents the format that this data source provider uses. This is + * overridden by children to provide a nice alias for the data source. For example: + * + * {{{ + * override def shortName(): String = + * "pmml+org.apache.spark.ml.regression.LinearRegressionModel" --- End diff -- what about making a second abstract field `def stageName(): String`, instead of having it packed into one string? --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
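The `"format+stageName"` matching that the quoted `saveImpl` performs over the ServiceLoader results can be sketched as a plain-Java filter over registered formats. The `FormatRegister` interface and the example short names below are hypothetical analogues of `MLFormatRegister` (real discovery additionally requires the `META-INF/services` entry); the sketch only shows the three dispatch cases: no match, exactly one, or multiple:

```java
import java.util.List;
import java.util.stream.Collectors;

// Hypothetical analogue of MLFormatRegister.
interface FormatRegister {
    String shortName(); // e.g. "pmml+org.apache.spark.ml.regression.LinearRegressionModel"
}

public class Dispatch {
    // Mirrors the three cases in saveImpl: no match, exactly one, or multiple.
    static FormatRegister select(List<FormatRegister> formats, String source, String stageName) {
        String target = source + "+" + stageName;
        List<FormatRegister> matches = formats.stream()
            .filter(f -> f.shortName().equalsIgnoreCase(target))
            .collect(Collectors.toList());
        if (matches.isEmpty()) {
            throw new RuntimeException("Could not load requested format " + source
                + " for " + stageName);
        }
        if (matches.size() > 1) {
            throw new RuntimeException("Multiple writers found for " + target
                + ", try using the class name of the writer");
        }
        return matches.get(0);
    }

    public static void main(String[] args) {
        FormatRegister pmml = () -> "pmml+com.example.LinearRegressionModel";
        FormatRegister internal = () -> "internal+com.example.LinearRegressionModel";
        List<FormatRegister> formats = List.of(pmml, internal);
        System.out.println(select(formats, "pmml", "com.example.LinearRegressionModel").shortName());
    }
}
```

One side effect of packing both names into a single string, as the review notes, is that `equalsIgnoreCase` compares the class-name half case-insensitively too; a separate `stageName()` field would avoid that.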
Github user sethah commented on a diff in the pull request: https://github.com/apache/spark/pull/19876#discussion_r160502536 --- Diff: mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala --- @@ -85,12 +87,55 @@ private[util] sealed trait BaseReadWrite { protected final def sc: SparkContext = sparkSession.sparkContext } +/** + * ML export formats for should implement this trait so that users can specify a shortname rather + * than the fully qualified class name of the exporter. + * + * A new instance of this class will be instantiated each time a DDL call is made. + * + * @since 2.3.0 + */ +@InterfaceStability.Evolving +trait MLFormatRegister { + /** + * The string that represents the format that this data source provider uses. This is + * overridden by children to provide a nice alias for the data source. For example: + * + * {{{ + * override def shortName(): String = + * "pmml+org.apache.spark.ml.regression.LinearRegressionModel" + * }}} + * Indicates that this format is capable of saving Spark's own LinearRegressionModel in pmml. + * + * Format discovery is done using a ServiceLoader so make sure to list your format in + * META-INF/services. + * @since 2.3.0 + */ + def shortName(): String +} + +/** + * Implemented by objects that provide ML exportability. + * + * A new instance of this class will be instantiated each time a DDL call is made. + * + * @since 2.3.0 + */ +@InterfaceStability.Evolving +trait MLWriterFormat { + /** + * Function write the provided pipeline stage out. --- End diff -- Should add a full doc here with param annotations. Also should it be "Function to write ..."? --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
Github user sethah commented on a diff in the pull request: https://github.com/apache/spark/pull/19876#discussion_r160501723 --- Diff: mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala --- @@ -126,15 +180,69 @@ abstract class MLWriter extends BaseReadWrite with Logging { this } + // override for Java compatibility + override def session(sparkSession: SparkSession): this.type = super.session(sparkSession) + + // override for Java compatibility + override def context(sqlContext: SQLContext): this.type = super.session(sqlContext.sparkSession) +} + +/** + * A ML Writer which delegates based on the requested format. + */ +class GeneralMLWriter(stage: PipelineStage) extends MLWriter with Logging { + private var source: String = "internal" + /** - * Overwrites if the output path already exists. + * Specifies the format of ML export (e.g. PMML, internal, or + * the fully qualified class name for export). */ - @Since("1.6.0") - def overwrite(): this.type = { -shouldOverwrite = true + @Since("2.3.0") + def format(source: String): this.type = { +this.source = source this } + /** + * Dispatches the save to the correct MLFormat. 
+ */ + @Since("2.3.0") + @throws[IOException]("If the input path already exists but overwrite is not enabled.") + @throws[SparkException]("If multiple sources for a given short name format are found.") + override protected def saveImpl(path: String) = { +val loader = Utils.getContextOrSparkClassLoader +val serviceLoader = ServiceLoader.load(classOf[MLFormatRegister], loader) +val stageName = stage.getClass.getName +val targetName = s"${source}+${stageName}" +val formats = serviceLoader.asScala.toList +val shortNames = formats.map(_.shortName()) +val writerCls = formats.filter(_.shortName().equalsIgnoreCase(targetName)) match { + // requested name did not match any given registered alias + case Nil => +Try(loader.loadClass(source)) match { + case Success(writer) => +// Found the ML writer using the fully qualified path +writer + case Failure(error) => +throw new SparkException( + s"Could not load requested format $source for $stageName ($targetName) had $formats" + + s"supporting $shortNames", error) +} + case head :: Nil => +head.getClass + case _ => +// Multiple sources +throw new SparkException( + s"Multiple writers found for $source+$stageName, try using the class name of the writer") +} +if (classOf[MLWriterFormat].isAssignableFrom(writerCls)) { + val writer = writerCls.newInstance().asInstanceOf[MLWriterFormat] --- End diff -- This will fail, non-intuitively, if anyone ever extends `MLWriterFormat` with a constructor that has more than zero arguments. Meaning: ```scala class DummyLinearRegressionWriter(someParam: Int) extends MLWriterFormat ``` will raise `java.lang.NoSuchMethodException: org.apache.spark.ml.regression.DummyLinearRegressionWriter.()` --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
Github user sethah commented on a diff in the pull request: https://github.com/apache/spark/pull/19876#discussion_r160503322 --- Diff: mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala --- @@ -126,15 +180,69 @@ abstract class MLWriter extends BaseReadWrite with Logging { this } + // override for Java compatibility + override def session(sparkSession: SparkSession): this.type = super.session(sparkSession) + + // override for Java compatibility + override def context(sqlContext: SQLContext): this.type = super.session(sqlContext.sparkSession) +} + +/** + * A ML Writer which delegates based on the requested format. + */ +class GeneralMLWriter(stage: PipelineStage) extends MLWriter with Logging { --- End diff -- need `@Since("2.3.0")` here? --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
Github user sethah commented on a diff in the pull request: https://github.com/apache/spark/pull/19876#discussion_r160471845 --- Diff: mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala --- @@ -710,15 +711,57 @@ class LinearRegressionModel private[ml] ( } /** - * Returns a [[org.apache.spark.ml.util.MLWriter]] instance for this ML instance. + * Returns a [[org.apache.spark.ml.util.GeneralMLWriter]] instance for this ML instance. * * For [[LinearRegressionModel]], this does NOT currently save the training [[summary]]. * An option to save [[summary]] may be added in the future. * * This also does not save the [[parent]] currently. */ @Since("1.6.0") - override def write: MLWriter = new LinearRegressionModel.LinearRegressionModelWriter(this) + override def write: GeneralMLWriter = new GeneralMLWriter(this) +} + +/** A writer for LinearRegression that handles the "internal" (or default) format */ +private class InternalLinearRegressionModelWriter() + extends MLWriterFormat with MLFormatRegister { + + override def shortName(): String = +"internal+org.apache.spark.ml.regression.LinearRegressionModel" + + private case class Data(intercept: Double, coefficients: Vector, scale: Double) + + override def write(path: String, sparkSession: SparkSession, +optionMap: mutable.Map[String, String], stage: PipelineStage): Unit = { +val instance = stage.asInstanceOf[LinearRegressionModel] +val sc = sparkSession.sparkContext +// Save metadata and Params +DefaultParamsWriter.saveMetadata(instance, path, sc) +// Save model data: intercept, coefficients, scale +val data = Data(instance.intercept, instance.coefficients, instance.scale) +val dataPath = new Path(path, "data").toString + sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } +} + +/** A writer for LinearRegression that handles the "pmml" format */ +private class PMMLLinearRegressionModelWriter() --- End diff -- I could be wrong, but I think we prefer just omitting the `()`? 
Github user sethah commented on a diff in the pull request: https://github.com/apache/spark/pull/19876#discussion_r160463657 --- Diff: mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala --- @@ -126,15 +180,69 @@ abstract class MLWriter extends BaseReadWrite with Logging { this } + // override for Java compatibility + override def session(sparkSession: SparkSession): this.type = super.session(sparkSession) + + // override for Java compatibility + override def context(sqlContext: SQLContext): this.type = super.session(sqlContext.sparkSession) +} + +/** + * A ML Writer which delegates based on the requested format. + */ +class GeneralMLWriter(stage: PipelineStage) extends MLWriter with Logging { + private var source: String = "internal" + /** - * Overwrites if the output path already exists. + * Specifies the format of ML export (e.g. PMML, internal, or --- End diff -- change to `e.g. "pmml", "internal", or the fully qualified class name for export)."` --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
Github user sethah commented on a diff in the pull request: https://github.com/apache/spark/pull/19876#discussion_r160483562 --- Diff: mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala --- @@ -1044,6 +1056,50 @@ class LinearRegressionSuite extends MLTest with DefaultReadWriteTest { LinearRegressionSuite.allParamSettings, checkModelData) } + test("pmml export") { +val lr = new LinearRegression() +val model = lr.fit(datasetWithWeight) +def checkModel(pmml: PMML): Unit = { + val dd = pmml.getDataDictionary + assert(dd.getNumberOfFields === 3) + val fields = dd.getDataFields.asScala + assert(fields(0).getName().toString === "field_0") + assert(fields(0).getOpType() == OpType.CONTINUOUS) + val pmmlRegressionModel = pmml.getModels().get(0).asInstanceOf[PMMLRegressionModel] + val pmmlPredictors = pmmlRegressionModel.getRegressionTables.get(0).getNumericPredictors + val pmmlWeights = pmmlPredictors.asScala.map(_.getCoefficient()).toList + assert(pmmlWeights(0) ~== model.coefficients(0) relTol 1E-3) + assert(pmmlWeights(1) ~== model.coefficients(1) relTol 1E-3) +} +testPMMLWrite(sc, model, checkModel) + } + + test("unsupported export format") { +val lr = new LinearRegression() +val model = lr.fit(datasetWithWeight) +intercept[SparkException] { + model.write.format("boop").save("boop") +} +intercept[SparkException] { + model.write.format("com.holdenkarau.boop").save("boop") +} +withClue("ML source org.apache.spark.SparkContext is not a valid MLWriterFormat") { + intercept[SparkException] { +model.write.format("org.apache.spark.SparkContext").save("boop2") + } +} + } + + test("dummy export format is called") { +val lr = new LinearRegression() +val model = lr.fit(datasetWithWeight) +withClue("Dummy writer doesn't write") { + intercept[Exception] { --- End diff -- this just catches any exception. 
Can we do something like

```scala
val thrown = intercept[Exception] {
  model.write.format("org.apache.spark.ml.regression.DummyLinearRegressionWriter").save("")
}
assert(thrown.getMessage.contains("Dummy writer doesn't write."))
```
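The point above (assert on the exception message, not just the exception type, so an unrelated failure cannot make the test pass vacuously) can be sketched without ScalaTest's `intercept` using a small helper in plain Java. The helper and the `dummySave` runnable are hypothetical; only the message text mirrors the dummy writer in the test:

```java
public class InterceptSketch {
    // Runs the action and reports whether it failed with the expected message.
    // Returning false when nothing is thrown is what intercept-style helpers
    // turn into a test failure.
    static boolean failsWithMessage(Runnable action, String expected) {
        try {
            action.run();
            return false; // did not throw at all
        } catch (RuntimeException e) {
            return e.getMessage() != null && e.getMessage().contains(expected);
        }
    }

    public static void main(String[] args) {
        Runnable dummySave = () -> {
            throw new RuntimeException("Dummy writer doesn't write.");
        };
        System.out.println(failsWithMessage(dummySave, "Dummy writer doesn't write.")); // true
        System.out.println(failsWithMessage(dummySave, "some other failure"));          // false
    }
}
```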
Github user sethah commented on a diff in the pull request: https://github.com/apache/spark/pull/19876#discussion_r160461560 --- Diff: mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala --- @@ -85,12 +87,55 @@ private[util] sealed trait BaseReadWrite { protected final def sc: SparkContext = sparkSession.sparkContext } +/** + * ML export formats for should implement this trait so that users can specify a shortname rather + * than the fully qualified class name of the exporter. + * + * A new instance of this class will be instantiated each time a DDL call is made. --- End diff -- Was this supposed to be retained from the `DataSourceRegister`? --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
Github user sethah commented on a diff in the pull request: https://github.com/apache/spark/pull/19876#discussion_r160506592 --- Diff: mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala --- @@ -1044,6 +1056,50 @@ class LinearRegressionSuite extends MLTest with DefaultReadWriteTest { LinearRegressionSuite.allParamSettings, checkModelData) } + test("pmml export") { +val lr = new LinearRegression() +val model = lr.fit(datasetWithWeight) +def checkModel(pmml: PMML): Unit = { + val dd = pmml.getDataDictionary + assert(dd.getNumberOfFields === 3) + val fields = dd.getDataFields.asScala + assert(fields(0).getName().toString === "field_0") + assert(fields(0).getOpType() == OpType.CONTINUOUS) + val pmmlRegressionModel = pmml.getModels().get(0).asInstanceOf[PMMLRegressionModel] + val pmmlPredictors = pmmlRegressionModel.getRegressionTables.get(0).getNumericPredictors + val pmmlWeights = pmmlPredictors.asScala.map(_.getCoefficient()).toList + assert(pmmlWeights(0) ~== model.coefficients(0) relTol 1E-3) + assert(pmmlWeights(1) ~== model.coefficients(1) relTol 1E-3) +} +testPMMLWrite(sc, model, checkModel) + } + + test("unsupported export format") { +val lr = new LinearRegression() +val model = lr.fit(datasetWithWeight) +intercept[SparkException] { + model.write.format("boop").save("boop") +} +intercept[SparkException] { + model.write.format("com.holdenkarau.boop").save("boop") +} +withClue("ML source org.apache.spark.SparkContext is not a valid MLWriterFormat") { + intercept[SparkException] { +model.write.format("org.apache.spark.SparkContext").save("boop2") + } +} + } + + test("dummy export format is called") { --- End diff -- We can also add tests for the `MLFormatRegister` similar to `DDLSourceLoadSuite`. Just add a `META-INF/services/` directory to `src/test/resources/` --- - To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org
[GitHub] spark pull request #19876: [ML][SPARK-11171][SPARK-11239] Add PMML export to...
Github user sethah commented on a diff in the pull request: https://github.com/apache/spark/pull/19876#discussion_r160463225 --- Diff: mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala ---
@@ -85,12 +87,55 @@ private[util] sealed trait BaseReadWrite {
   protected final def sc: SparkContext = sparkSession.sparkContext
 }
 
+/**
+ * ML export formats for should implement this trait so that users can specify a shortname rather
+ * than the fully qualified class name of the exporter.
+ *
+ * A new instance of this class will be instantiated each time a DDL call is made.
+ *
+ * @since 2.3.0
+ */
+@InterfaceStability.Evolving
+trait MLFormatRegister {
+  /**
+   * The string that represents the format that this data source provider uses. This is
+   * overridden by children to provide a nice alias for the data source. For example:
--- End diff --

"data source" -> "model format"?
[GitHub] spark pull request #19876: [ML][SPARK-11171][SPARK-11239] Add PMML export to...
Github user holdenk commented on a diff in the pull request: https://github.com/apache/spark/pull/19876#discussion_r159018510 --- Diff: mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala ---
@@ -85,12 +87,55 @@ private[util] sealed trait BaseReadWrite {
   protected final def sc: SparkContext = sparkSession.sparkContext
 }
 
+/**
+ * ML export formats for should implement this trait so that users can specify a shortname rather
+ * than the fully qualified class name of the exporter.
+ *
+ * A new instance of this class will be instantiated each time a DDL call is made.
+ *
+ * @since 2.3.0
+ */
+@InterfaceStability.Evolving
+trait MLFormatRegister {
+  /**
+   * The string that represents the format that this data source provider uses. This is
+   * overridden by children to provide a nice alias for the data source. For example:
+   *
+   * {{{
+   *   override def shortName(): String =
+   *     "pmml+org.apache.spark.ml.regression.LinearRegressionModel"
+   * }}}
+   * Indicates that this format is capable of saving Spark's own LinearRegressionModel in pmml.
+   *
+   * Format discovery is done using a ServiceLoader so make sure to list your format in
+   * META-INF/services.
+   * @since 2.3.0
+   */
+  def shortName(): String
+}
+
+/**
+ * Implemented by objects that provide ML exportability.
+ *
+ * A new instance of this class will be instantiated each time a DDL call is made.
+ *
+ * @since 2.3.0
+ */
+@InterfaceStability.Evolving
+trait MLWriterFormat {
+  /**
+   * Function write the provided pipeline stage out.
+   */
+  def write(path: String, session: SparkSession, optionMap: mutable.Map[String, String],
+    stage: PipelineStage): Unit
+}
+
 /**
  * Abstract class for utility classes that can save ML instances.
  */
+@deprecated("Use GeneralMLWriter instead. Will be removed in Spark 3.0.0", "2.3.0")
--- End diff --

I'm debating if this should be deprecated in 2.4 and just have this as a new option in 2.3. What do you think @sethah / @MLnick ?
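[Editor's sketch to make the discovery mechanism concrete. This resolution helper is illustrative only, not the PR's actual implementation; it assumes the registered class implements both traits, and that a zero-argument constructor is available as discussed elsewhere in this review:]

```scala
import java.util.ServiceLoader

import scala.collection.JavaConverters._

// Illustrative resolution of a user-supplied short name (e.g. "pmml") plus
// a stage class name to a writer, matching the "pmml+<class>" shortName
// convention shown in the scaladoc above.
def resolveFormat(format: String, stageClass: String): MLWriterFormat = {
  val registered = ServiceLoader.load(classOf[MLFormatRegister]).asScala
  registered.find(_.shortName() == s"$format+$stageClass") match {
    case Some(register) =>
      // Instantiate a fresh writer per save call via the zero-arg constructor.
      register.getClass.getConstructor().newInstance().asInstanceOf[MLWriterFormat]
    case None =>
      // Fall back to treating `format` as a fully qualified class name.
      Class.forName(format).getConstructor().newInstance().asInstanceOf[MLWriterFormat]
  }
}
```

[Under this scheme a call like `model.write.format("pmml").save(path)` would route through such a lookup, and an unknown short name that is also not a loadable `MLWriterFormat` class would surface as an error, as the "unsupported export format" test elsewhere in this review exercises.]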
[GitHub] spark pull request #19876: [ML][SPARK-11171][SPARK-11239] Add PMML export to...
Github user holdenk commented on a diff in the pull request: https://github.com/apache/spark/pull/19876#discussion_r158434850 --- Diff: mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala ---
@@ -554,7 +555,49 @@ class LinearRegressionModel private[ml] (
    * This also does not save the [[parent]] currently.
    */
   @Since("1.6.0")
-  override def write: MLWriter = new LinearRegressionModel.LinearRegressionModelWriter(this)
+  override def write: GeneralMLWriter = new GeneralMLWriter(this)
--- End diff --

fixed