GitHub user sethah commented on a diff in the pull request: https://github.com/apache/spark/pull/19876#discussion_r160506592 --- Diff: mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala --- @@ -1044,6 +1056,50 @@ class LinearRegressionSuite extends MLTest with DefaultReadWriteTest { LinearRegressionSuite.allParamSettings, checkModelData) } + test("pmml export") { + val lr = new LinearRegression() + val model = lr.fit(datasetWithWeight) + def checkModel(pmml: PMML): Unit = { + val dd = pmml.getDataDictionary + assert(dd.getNumberOfFields === 3) + val fields = dd.getDataFields.asScala + assert(fields(0).getName().toString === "field_0") + assert(fields(0).getOpType() == OpType.CONTINUOUS) + val pmmlRegressionModel = pmml.getModels().get(0).asInstanceOf[PMMLRegressionModel] + val pmmlPredictors = pmmlRegressionModel.getRegressionTables.get(0).getNumericPredictors + val pmmlWeights = pmmlPredictors.asScala.map(_.getCoefficient()).toList + assert(pmmlWeights(0) ~== model.coefficients(0) relTol 1E-3) + assert(pmmlWeights(1) ~== model.coefficients(1) relTol 1E-3) + } + testPMMLWrite(sc, model, checkModel) + } + + test("unsupported export format") { + val lr = new LinearRegression() + val model = lr.fit(datasetWithWeight) + intercept[SparkException] { + model.write.format("boop").save("boop") + } + intercept[SparkException] { + model.write.format("com.holdenkarau.boop").save("boop") + } + withClue("ML source org.apache.spark.SparkContext is not a valid MLWriterFormat") { + intercept[SparkException] { + model.write.format("org.apache.spark.SparkContext").save("boop2") + } + } + } + + test("dummy export format is called") { --- End diff -- We can also add tests for the `MLFormatRegister` similar to `DDLSourceLoadSuite`. Just add a `META-INF/services/` directory to `src/test/resources/`, containing a provider-configuration file that registers the test `MLFormatRegister` implementations (the same way `DDLSourceLoadSuite` registers its test data sources).
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org