Repository: spark Updated Branches: refs/heads/master 90d575421 -> f89808b0f
[SPARK-17499][SPARKR][ML][MLLIB] make the default params in sparkR spark.mlp consistent with MultilayerPerceptronClassifier ## What changes were proposed in this pull request? update `MultilayerPerceptronClassifierWrapper.fit` parameter types: `layers: Array[Int]` `seed: String` update several default params in sparkR `spark.mlp`: `tol` --> 1e-6 `stepSize` --> 0.03 `seed` --> NULL (when seed == NULL, the Scala-side wrapper treats it as a `null` value and the default seed is used) The R-side `seed` supports only 32-bit integers. remove the `layers` default value, and move it in front of the parameters that have default values. add `layers` parameter validation check. ## How was this patch tested? tests added. Author: WeichenXu <weichenxu...@outlook.com> Closes #15051 from WeichenXu123/update_py_mlp_default. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f89808b0 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f89808b0 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f89808b0 Branch: refs/heads/master Commit: f89808b0fdbc04e1bdff1489a6ec4c84ddb2adc4 Parents: 90d5754 Author: WeichenXu <weichenxu...@outlook.com> Authored: Fri Sep 23 11:14:22 2016 -0700 Committer: Felix Cheung <felixche...@apache.org> Committed: Fri Sep 23 11:14:22 2016 -0700 ---------------------------------------------------------------------- R/pkg/R/mllib.R | 13 ++++++++++--- R/pkg/inst/tests/testthat/test_mllib.R | 19 +++++++++++++++++++ .../MultilayerPerceptronClassifierWrapper.scala | 8 ++++---- 3 files changed, 33 insertions(+), 7 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/f89808b0/R/pkg/R/mllib.R ---------------------------------------------------------------------- diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R index 98db367..971c166 100644 --- a/R/pkg/R/mllib.R +++ b/R/pkg/R/mllib.R @@ -694,12 +694,19 @@ 
setMethod("predict", signature(object = "KMeansModel"), #' } #' @note spark.mlp since 2.1.0 setMethod("spark.mlp", signature(data = "SparkDataFrame"), - function(data, blockSize = 128, layers = c(3, 5, 2), solver = "l-bfgs", maxIter = 100, - tol = 0.5, stepSize = 1, seed = 1) { + function(data, layers, blockSize = 128, solver = "l-bfgs", maxIter = 100, + tol = 1E-6, stepSize = 0.03, seed = NULL) { + layers <- as.integer(na.omit(layers)) + if (length(layers) <= 1) { + stop ("layers must be a integer vector with length > 1.") + } + if (!is.null(seed)) { + seed <- as.character(as.integer(seed)) + } jobj <- callJStatic("org.apache.spark.ml.r.MultilayerPerceptronClassifierWrapper", "fit", data@sdf, as.integer(blockSize), as.array(layers), as.character(solver), as.integer(maxIter), as.numeric(tol), - as.numeric(stepSize), as.integer(seed)) + as.numeric(stepSize), seed) new("MultilayerPerceptronClassificationModel", jobj = jobj) }) http://git-wip-us.apache.org/repos/asf/spark/blob/f89808b0/R/pkg/inst/tests/testthat/test_mllib.R ---------------------------------------------------------------------- diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R index 24c40a8..a1eaaf2 100644 --- a/R/pkg/inst/tests/testthat/test_mllib.R +++ b/R/pkg/inst/tests/testthat/test_mllib.R @@ -391,6 +391,25 @@ test_that("spark.mlp", { unlink(modelPath) + # Test default parameter + model <- spark.mlp(df, layers = c(4, 5, 4, 3)) + mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction")) + expect_equal(head(mlpPredictions$prediction, 10), c(1, 1, 1, 1, 0, 1, 2, 2, 1, 0)) + + # Test illegal parameter + expect_error(spark.mlp(df, layers = NULL), "layers must be a integer vector with length > 1.") + expect_error(spark.mlp(df, layers = c()), "layers must be a integer vector with length > 1.") + expect_error(spark.mlp(df, layers = c(3)), "layers must be a integer vector with length > 1.") + + # Test random seed + # default seed + model <- 
spark.mlp(df, layers = c(4, 5, 4, 3), maxIter = 10) + mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction")) + expect_equal(head(mlpPredictions$prediction, 12), c(1, 1, 1, 1, 0, 1, 2, 2, 1, 2, 0, 1)) + # seed equals 10 + model <- spark.mlp(df, layers = c(4, 5, 4, 3), maxIter = 10, seed = 10) + mlpPredictions <- collect(select(predict(model, mlpTestDF), "prediction")) + expect_equal(head(mlpPredictions$prediction, 12), c(1, 1, 1, 1, 2, 1, 2, 2, 1, 0, 0, 1)) }) test_that("spark.naiveBayes", { http://git-wip-us.apache.org/repos/asf/spark/blob/f89808b0/mllib/src/main/scala/org/apache/spark/ml/r/MultilayerPerceptronClassifierWrapper.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/MultilayerPerceptronClassifierWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/MultilayerPerceptronClassifierWrapper.scala index be51e74..1067300 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/r/MultilayerPerceptronClassifierWrapper.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/r/MultilayerPerceptronClassifierWrapper.scala @@ -53,26 +53,26 @@ private[r] object MultilayerPerceptronClassifierWrapper def fit( data: DataFrame, blockSize: Int, - layers: Array[Double], + layers: Array[Int], solver: String, maxIter: Int, tol: Double, stepSize: Double, - seed: Int + seed: String ): MultilayerPerceptronClassifierWrapper = { // get labels and feature names from output schema val schema = data.schema // assemble and fit the pipeline val mlp = new MultilayerPerceptronClassifier() - .setLayers(layers.map(_.toInt)) + .setLayers(layers) .setBlockSize(blockSize) .setSolver(solver) .setMaxIter(maxIter) .setTol(tol) .setStepSize(stepSize) - .setSeed(seed) .setPredictionCol(PREDICTED_LABEL_COL) + if (seed != null && seed.length > 0) mlp.setSeed(seed.toInt) val pipeline = new Pipeline() .setStages(Array(mlp)) .fit(data) 
--------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org