Repository: spark
Updated Branches:
  refs/heads/master cdf9e9753 -> a6428292f
[SPARK-14931][ML][PYTHON] Mismatched default values between pipelines in Spark and PySpark - update

## What changes were proposed in this pull request?

This PR is an update to [https://github.com/apache/spark/pull/12738] which:
* Adds a generic unit test for JavaParams wrappers in pyspark.ml, checking default Param values against the defaults on the Scala side
* Fixes various bugs found by that test
* This includes changing classes that take weightCol to treat unset and empty String Param values the same way

Defaults changed:
* Scala
  * LogisticRegression: weightCol defaults to not set (instead of empty string)
  * StringIndexer: labels defaults to not set (instead of empty array)
  * GeneralizedLinearRegression:
    * maxIter always defaults to 25 (simpler than defaulting to 25 only for a particular solver)
    * weightCol defaults to not set (instead of empty string)
  * LinearRegression: weightCol defaults to not set (instead of empty string)
* Python
  * MultilayerPerceptron: layers defaults to not set (instead of [1, 1])
  * ChiSqSelector: numTopFeatures defaults to 50 (instead of not set)

## How was this patch tested?

Generic unit test. Also manually tested that unit test by changing the defaults and verifying that this broke the test.

Author: Joseph K. Bradley <jos...@databricks.com>
Author: yinxusen <yinxu...@gmail.com>

Closes #12816 from jkbradley/yinxusen-SPARK-14931.

Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/a6428292
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/a6428292
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/a6428292

Branch: refs/heads/master
Commit: a6428292f78fd594f41a4a7bf254d40268f46305
Parents: cdf9e97
Author: Xusen Yin <yinxu...@gmail.com>
Authored: Sun May 1 12:29:01 2016 -0700
Committer: Joseph K. Bradley <jos...@databricks.com>
Committed: Sun May 1 12:29:01 2016 -0700

----------------------------------------------------------------------
 .../ml/classification/LogisticRegression.scala  |  7 ++-
 .../apache/spark/ml/feature/StringIndexer.scala |  5 +-
 .../GeneralizedLinearRegression.scala           | 31 +++++++------
 .../spark/ml/regression/LinearRegression.scala  | 15 +++---
 .../LogisticRegressionSuite.scala               |  2 +-
 .../GeneralizedLinearRegressionSuite.scala      |  2 +-
 python/pyspark/ml/classification.py             | 13 ++----
 python/pyspark/ml/feature.py                    |  1 +
 python/pyspark/ml/regression.py                 |  9 ++--
 python/pyspark/ml/tests.py                      | 48 ++++++++++++++++++++
 python/pyspark/ml/wrapper.py                    |  3 +-
 11 files changed, 96 insertions(+), 40 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/a6428292/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 717e93c..d2d4e24 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -235,13 +235,12 @@ class LogisticRegression @Since("1.2.0") (
 
   /**
    * Whether to over-/under-sample training instances according to the given weights in weightCol.
-   * If empty, all instances are treated equally (weight 1.0).
-   * Default is empty, so all instances have weight one.
+   * If not set or empty String, all instances are treated equally (weight 1.0).
+   * Default is not set, so all instances have weight one.
    * @group setParam
    */
   @Since("1.6.0")
   def setWeightCol(value: String): this.type = set(weightCol, value)
-  setDefault(weightCol -> "")
 
   @Since("1.5.0")
   override def setThresholds(value: Array[Double]): this.type = super.setThresholds(value)
@@ -264,7 +263,7 @@ class LogisticRegression @Since("1.2.0") (
 
   protected[spark] def train(dataset: Dataset[_], handlePersistence: Boolean):
       LogisticRegressionModel = {
-    val w = if ($(weightCol).isEmpty) lit(1.0) else col($(weightCol))
+    val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
     val instances: RDD[Instance] =
       dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
         case Row(label: Double, weight: Double, features: Vector) =>

http://git-wip-us.apache.org/repos/asf/spark/blob/a6428292/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
index 7e0d374..cc0571f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala
@@ -263,13 +263,12 @@ class IndexToString private[ml] (override val uid: String)
 
   /**
    * Optional param for array of labels specifying index-string mapping.
    *
-   * Default: Empty array, in which case [[inputCol]] metadata is used for labels.
+   * Default: Not specified, in which case [[inputCol]] metadata is used for labels.
    * @group param
    */
   final val labels: StringArrayParam = new StringArrayParam(this, "labels",
     "Optional array of labels specifying index-string mapping."
+ " If not provided or if empty, then metadata from inputCol is used instead.") - setDefault(labels, Array.empty[String]) /** @group getParam */ final def getLabels: Array[String] = $(labels) @@ -292,7 +291,7 @@ class IndexToString private[ml] (override val uid: String) override def transform(dataset: Dataset[_]): DataFrame = { val inputColSchema = dataset.schema($(inputCol)) // If the labels array is empty use column metadata - val values = if ($(labels).isEmpty) { + val values = if (!isDefined(labels) || $(labels).isEmpty) { Attribute.fromStructField(inputColSchema) .asInstanceOf[NominalAttribute].values.get } else { http://git-wip-us.apache.org/repos/asf/spark/blob/a6428292/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala index bf9d3ff..c294ef3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala @@ -31,7 +31,7 @@ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ import org.apache.spark.mllib.linalg.{BLAS, Vector} import org.apache.spark.rdd.RDD -import org.apache.spark.sql.{DataFrame, Dataset, Row} +import org.apache.spark.sql.{Column, DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types.{DataType, DoubleType, StructType} @@ -101,9 +101,6 @@ private[regression] trait GeneralizedLinearRegressionBase extends PredictorParam schema: StructType, fitting: Boolean, featuresDataType: DataType): StructType = { - if ($(solver) == "irls") { - setDefault(maxIter -> 25) - } if (isDefined(link)) { require(supportedFamilyAndLinkPairs.contains( Family.fromName($(family)) -> Link.fromName($(link))), "Generalized Linear Regression " + @@ -171,13 +168,14 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value) /** - * Sets the maximum number of iterations. - * Default is 25 if the solver algorithm is "irls". + * Sets the maximum number of iterations (applicable for solver "irls"). + * Default is 25. * * @group setParam */ @Since("2.0.0") def setMaxIter(value: Int): this.type = set(maxIter, value) + setDefault(maxIter -> 25) /** * Sets the convergence tolerance of iterations. @@ -213,7 +211,6 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val */ @Since("2.0.0") def setWeightCol(value: String): this.type = set(weightCol, value) - setDefault(weightCol -> "") /** * Sets the solver algorithm used for optimization. 
@@ -252,7 +249,7 @@ class GeneralizedLinearRegression @Since("2.0.0") (@Since("2.0.0") override val
       throw new SparkException(msg)
     }
 
-    val w = if ($(weightCol).isEmpty) lit(1.0) else col($(weightCol))
+    val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
     val instances: RDD[Instance] =
       dataset.select(col($(labelCol)).cast(DoubleType), w, col($(featuresCol))).rdd.map {
         case Row(label: Double, weight: Double, features: Vector) =>
@@ -912,19 +909,27 @@ class GeneralizedLinearRegressionSummary private[regression] (
     numInstances
   }
 
+  private def weightCol: Column = {
+    if (!model.isDefined(model.weightCol) || model.getWeightCol.isEmpty) {
+      lit(1.0)
+    } else {
+      col(model.getWeightCol)
+    }
+  }
+
   private[regression] lazy val devianceResiduals: DataFrame = {
     val drUDF = udf { (y: Double, mu: Double, weight: Double) =>
       val r = math.sqrt(math.max(family.deviance(y, mu, weight), 0.0))
       if (y > mu) r else -1.0 * r
     }
-    val w = if (model.getWeightCol.isEmpty) lit(1.0) else col(model.getWeightCol)
+    val w = weightCol
     predictions.select(
       drUDF(col(model.getLabelCol), col(predictionCol), w).as("devianceResiduals"))
   }
 
   private[regression] lazy val pearsonResiduals: DataFrame = {
     val prUDF = udf { mu: Double => family.variance(mu) }
-    val w = if (model.getWeightCol.isEmpty) lit(1.0) else col(model.getWeightCol)
+    val w = weightCol
     predictions.select(col(model.getLabelCol).minus(col(predictionCol))
       .multiply(sqrt(w)).divide(sqrt(prUDF(col(predictionCol)))).as("pearsonResiduals"))
   }
@@ -967,7 +972,7 @@ class GeneralizedLinearRegressionSummary private[regression] (
    */
   @Since("2.0.0")
   lazy val nullDeviance: Double = {
-    val w = if (model.getWeightCol.isEmpty) lit(1.0) else col(model.getWeightCol)
+    val w = weightCol
     val wtdmu: Double = if (model.getFitIntercept) {
       val agg = predictions.agg(sum(w.multiply(col(model.getLabelCol))), sum(w)).first()
       agg.getDouble(0) / agg.getDouble(1)
@@ -985,7 +990,7 @@ class GeneralizedLinearRegressionSummary private[regression] (
    */
   @Since("2.0.0")
   lazy val deviance: Double = {
-    val w = if (model.getWeightCol.isEmpty) lit(1.0) else col(model.getWeightCol)
+    val w = weightCol
     predictions.select(col(model.getLabelCol), col(predictionCol), w).rdd.map {
       case Row(label: Double, pred: Double, weight: Double) =>
         family.deviance(label, pred, weight)
@@ -1010,7 +1015,7 @@ class GeneralizedLinearRegressionSummary private[regression] (
 
   /** Akaike's "An Information Criterion"(AIC) for the fitted model. */
   @Since("2.0.0")
   lazy val aic: Double = {
-    val w = if (model.getWeightCol.isEmpty) lit(1.0) else col(model.getWeightCol)
+    val w = weightCol
     val weightSum = predictions.select(w).agg(sum(w)).first().getDouble(0)
     val t = predictions.select(col(model.getLabelCol), col(predictionCol), w).rdd.map {
       case Row(label: Double, pred: Double, weight: Double) =>

http://git-wip-us.apache.org/repos/asf/spark/blob/a6428292/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index 5117ee1..d13b15f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -136,13 +136,12 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
 
   /**
    * Whether to over-/under-sample training instances according to the given weights in weightCol.
-   * If empty, all instances are treated equally (weight 1.0).
-   * Default is empty, so all instances have weight one.
+   * If not set or empty, all instances are treated equally (weight 1.0).
+   * Default is not set, so all instances have weight one.
    * @group setParam
    */
   @Since("1.6.0")
   def setWeightCol(value: String): this.type = set(weightCol, value)
-  setDefault(weightCol -> "")
 
   /**
    * Set the solver algorithm used for optimization.
@@ -163,7 +162,7 @@ class LinearRegression @Since("1.3.0") (@Since("1.3.0") override val uid: String
     val numFeatures = dataset.select(col($(featuresCol))).limit(1).rdd.map {
       case Row(features: Vector) => features.size
     }.first()
-    val w = if ($(weightCol).isEmpty) lit(1.0) else col($(weightCol))
+    val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
 
     if (($(solver) == "auto" && $(elasticNetParam) == 0.0 &&
       numFeatures <= WeightedLeastSquares.MAX_NUM_FEATURES) || $(solver) == "normal") {
@@ -643,7 +642,11 @@ class LinearRegressionSummary private[regression] (
    * the square root of the instance weights.
    */
   lazy val devianceResiduals: Array[Double] = {
-    val weighted = if (model.getWeightCol.isEmpty) lit(1.0) else sqrt(col(model.getWeightCol))
+    val weighted = if (!model.isDefined(model.weightCol) || model.getWeightCol.isEmpty) {
+      lit(1.0)
+    } else {
+      sqrt(col(model.getWeightCol))
+    }
     val dr = predictions.select(col(model.getLabelCol).minus(col(model.getPredictionCol))
       .multiply(weighted).as("weightedResiduals"))
       .select(min(col("weightedResiduals")).as("min"), max(col("weightedResiduals")).as("max"))
@@ -665,7 +668,7 @@ class LinearRegressionSummary private[regression] (
       throw new UnsupportedOperationException(
         "No Std. Error of coefficients available for this LinearRegressionModel")
     } else {
-      val rss = if (model.getWeightCol.isEmpty) {
+      val rss = if (!model.isDefined(model.weightCol) || model.getWeightCol.isEmpty) {
         meanSquaredError * numInstances
       } else {
         val t = udf { (pred: Double, label: Double, weight: Double) =>

http://git-wip-us.apache.org/repos/asf/spark/blob/a6428292/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index 48db428..73e961d 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -81,7 +81,7 @@ class LogisticRegressionSuite
     assert(lr.getPredictionCol === "prediction")
     assert(lr.getRawPredictionCol === "rawPrediction")
     assert(lr.getProbabilityCol === "probability")
-    assert(lr.getWeightCol === "")
+    assert(!lr.isDefined(lr.weightCol))
     assert(lr.getFitIntercept)
     assert(lr.getStandardization)
     val model = lr.fit(dataset)

http://git-wip-us.apache.org/repos/asf/spark/blob/a6428292/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
index e4c9a3b..b854be2 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
@@ -180,7 +180,7 @@ class GeneralizedLinearRegressionSuite
     assert(glr.getPredictionCol === "prediction")
     assert(glr.getFitIntercept)
     assert(glr.getTol === 1E-6)
-    assert(glr.getWeightCol === "")
+    assert(!glr.isDefined(glr.weightCol))
     assert(glr.getRegParam === 0.0)
     assert(glr.getSolver == "irls")
     // TODO: Construct model directly instead of via fitting.

http://git-wip-us.apache.org/repos/asf/spark/blob/a6428292/python/pyspark/ml/classification.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index f616c7f..4331f73 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -1056,7 +1056,7 @@ class MultilayerPerceptronClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol,
 
     layers = Param(Params._dummy(), "layers", "Sizes of layers from input layer to output layer " +
                    "E.g., Array(780, 100, 10) means 780 inputs, one hidden layer with 100 " +
-                   "neurons and output layer of 10 neurons, default is [1, 1].",
+                   "neurons and output layer of 10 neurons.",
                    typeConverter=TypeConverters.toListInt)
     blockSize = Param(Params._dummy(), "blockSize", "Block size for stacking input data in " +
                       "matrices. Data is stacked within partitions. If block size is more than " +
@@ -1069,12 +1069,12 @@ class MultilayerPerceptronClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol,
                  maxIter=100, tol=1e-4, seed=None, layers=None, blockSize=128):
         """
         __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
-                 maxIter=100, tol=1e-4, seed=None, layers=[1, 1], blockSize=128)
+                 maxIter=100, tol=1e-4, seed=None, layers=None, blockSize=128)
         """
         super(MultilayerPerceptronClassifier, self).__init__()
         self._java_obj = self._new_java_obj(
             "org.apache.spark.ml.classification.MultilayerPerceptronClassifier", self.uid)
-        self._setDefault(maxIter=100, tol=1E-4, layers=[1, 1], blockSize=128)
+        self._setDefault(maxIter=100, tol=1E-4, blockSize=128)
         kwargs = self.__init__._input_kwargs
         self.setParams(**kwargs)
@@ -1084,14 +1084,11 @@ class MultilayerPerceptronClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol,
                   maxIter=100, tol=1e-4, seed=None, layers=None, blockSize=128):
         """
         setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
-                  maxIter=100, tol=1e-4, seed=None, layers=[1, 1], blockSize=128)
+                  maxIter=100, tol=1e-4, seed=None, layers=None, blockSize=128)
         Sets params for MultilayerPerceptronClassifier.
         """
         kwargs = self.setParams._input_kwargs
-        if layers is None:
-            return self._set(**kwargs).setLayers([1, 1])
-        else:
-            return self._set(**kwargs)
+        return self._set(**kwargs)
 
     def _create_model(self, java_model):
         return MultilayerPerceptronClassificationModel(java_model)

http://git-wip-us.apache.org/repos/asf/spark/blob/a6428292/python/pyspark/ml/feature.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 1b059a7..b95d288 100644
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -2617,6 +2617,7 @@ class ChiSqSelector(JavaEstimator, HasFeaturesCol, HasOutputCol, HasLabelCol, Ja
         """
         super(ChiSqSelector, self).__init__()
         self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.ChiSqSelector", self.uid)
+        self._setDefault(numTopFeatures=50)
         kwargs = self.__init__._input_kwargs
         self.setParams(**kwargs)

http://git-wip-us.apache.org/repos/asf/spark/blob/a6428292/python/pyspark/ml/regression.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index d490953..0f08f9b 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -1080,7 +1080,8 @@ class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
     @keyword_only
     def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                  fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor",
-                 quantileProbabilities=None, quantilesCol=None):
+                 quantileProbabilities=list([0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99]),
+                 quantilesCol=None):
         """
         __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                  fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor", \
@@ -1091,7 +1092,8 @@ class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
         self._java_obj = self._new_java_obj(
             "org.apache.spark.ml.regression.AFTSurvivalRegression", self.uid)
         self._setDefault(censorCol="censor",
-                         quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99])
+                         quantileProbabilities=[0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99],
+                         maxIter=100, tol=1E-6)
         kwargs = self.__init__._input_kwargs
         self.setParams(**kwargs)
 
@@ -1099,7 +1101,8 @@ class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
     @since("1.6.0")
     def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
                   fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor",
-                  quantileProbabilities=None, quantilesCol=None):
+                  quantileProbabilities=list([0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99]),
+                  quantilesCol=None):
         """
         setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
                   fitIntercept=True, maxIter=100, tol=1E-6, censorCol="censor", \

http://git-wip-us.apache.org/repos/asf/spark/blob/a6428292/python/pyspark/ml/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index d5dd6d4..78ec96a 100644
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -41,6 +41,7 @@ else:
 from shutil import rmtree
 import tempfile
 import numpy as np
+import inspect
 
 from pyspark import keyword_only
 from pyspark.ml import Estimator, Model, Pipeline, PipelineModel, Transformer
@@ -54,6 +55,7 @@ from pyspark.ml.recommendation import ALS
 from pyspark.ml.regression import LinearRegression, DecisionTreeRegressor
 from pyspark.ml.tuning import *
 from pyspark.ml.wrapper import JavaParams
+from pyspark.mllib.common import _java2py
 from pyspark.mllib.linalg import Vectors, DenseVector, SparseVector
 from pyspark.sql import DataFrame, SQLContext, Row
 from pyspark.sql.functions import rand
@@ -1026,6 +1028,52 @@ class ALSTest(PySparkTestCase):
         self.assertEqual(als._java_obj.getFinalStorageLevel(), "DISK_ONLY")
 
 
+class DefaultValuesTests(PySparkTestCase):
+    """
+    Test :py:class:`JavaParams` classes to see if their default Param values match
+    those in their Scala counterparts.
+ """ + + def check_params(self, py_stage): + if not hasattr(py_stage, "_to_java"): + return + java_stage = py_stage._to_java() + if java_stage is None: + return + for p in py_stage.params: + java_param = java_stage.getParam(p.name) + py_has_default = py_stage.hasDefault(p) + java_has_default = java_stage.hasDefault(java_param) + self.assertEqual(py_has_default, java_has_default, + "Default value mismatch of param %s for Params %s" + % (p.name, str(py_stage))) + if py_has_default: + if p.name == "seed": + return # Random seeds between Spark and PySpark are different + java_default =\ + _java2py(self.sc, java_stage.clear(java_param).getOrDefault(java_param)) + py_stage._clear(p) + py_default = py_stage.getOrDefault(p) + self.assertEqual(java_default, py_default, + "Java default %s != python default %s of param %s for Params %s" + % (str(java_default), str(py_default), p.name, str(py_stage))) + + def test_java_params(self): + import pyspark.ml.feature + import pyspark.ml.classification + import pyspark.ml.clustering + import pyspark.ml.pipeline + import pyspark.ml.recommendation + import pyspark.ml.regression + modules = [pyspark.ml.feature, pyspark.ml.classification, pyspark.ml.clustering, + pyspark.ml.pipeline, pyspark.ml.recommendation, pyspark.ml.regression] + for module in modules: + for name, cls in inspect.getmembers(module, inspect.isclass): + if not name.endswith('Model') and issubclass(cls, JavaParams)\ + and not inspect.isabstract(cls): + self.check_params(cls()) + + if __name__ == "__main__": from pyspark.ml.tests import * if xmlrunner: http://git-wip-us.apache.org/repos/asf/spark/blob/a6428292/python/pyspark/ml/wrapper.py ---------------------------------------------------------------------- diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py index fef626c..fef0040 100644 --- a/python/pyspark/ml/wrapper.py +++ b/python/pyspark/ml/wrapper.py @@ -110,7 +110,8 @@ class JavaParams(JavaWrapper, Params): for param in self.params: if self._java_obj.hasParam(param.name): java_param = self._java_obj.getParam(param.name) - if self._java_obj.isDefined(java_param): + # SPARK-14931: Only check set params back to avoid default params mismatch. + if self._java_obj.isSet(java_param): value = _java2py(sc, self._java_obj.getOrDefault(java_param)) self._set(**{param.name: value}) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org