Repository: spark
Updated Branches:
  refs/heads/master a88329d45 -> c4a7eef0c
[SPARK-18481][ML] ML 2.1 QA: Remove deprecated methods for ML

## What changes were proposed in this pull request?
Remove deprecated methods for ML.

## How was this patch tested?
Existing tests.

Author: Yanbo Liang <yblia...@gmail.com>

Closes #15913 from yanboliang/spark-18481.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/c4a7eef0
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/c4a7eef0
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/c4a7eef0

Branch: refs/heads/master
Commit: c4a7eef0ce2d305c5c90a0a9a73b5a32eccfba95
Parents: a88329d
Author: Yanbo Liang <yblia...@gmail.com>
Authored: Sat Nov 26 05:28:41 2016 -0800
Committer: Yanbo Liang <yblia...@gmail.com>
Committed: Sat Nov 26 05:28:41 2016 -0800

----------------------------------------------------------------------
 .../scala/org/apache/spark/ml/Pipeline.scala     |  4 +
 .../spark/ml/classification/GBTClassifier.scala  |  6 ++
 .../ml/classification/LogisticRegression.scala   |  8 +-
 .../classification/RandomForestClassifier.scala  | 11 +--
 .../apache/spark/ml/feature/ChiSqSelector.scala  |  7 --
 .../org/apache/spark/ml/param/params.scala       | 15 ----
 .../spark/ml/regression/GBTRegressor.scala       |  6 ++
 .../spark/ml/regression/LinearRegression.scala   |  3 -
 .../ml/regression/RandomForestRegressor.scala    | 10 +--
 .../org/apache/spark/ml/tree/treeModels.scala    |  5 --
 .../org/apache/spark/ml/tree/treeParams.scala    | 90 +++++++++-----------
 .../org/apache/spark/ml/util/ReadWrite.scala     |  2 +-
 .../ml/classification/GBTClassifierSuite.scala   |  8 ++
 .../LogisticRegressionSuite.scala                |  6 ++
 project/MimaExcludes.scala                       | 30 +++++++
 python/pyspark/ml/util.py                        | 40 ++++++++-
 16 files changed, 144 insertions(+), 107 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/c4a7eef0/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
index f406f8c..38176b9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -46,6 +46,10 @@ abstract class PipelineStage extends Params with Logging {
    *
    * Check transform validity and derive the output schema from the input schema.
    *
+   * We check validity for interactions between parameters during `transformSchema` and
+   * raise an exception if any parameter value is invalid. Parameter value checks which
+   * do not depend on other parameters are handled by `Param.validate()`.
+   *
    * Typical implementation should first conduct verification on schema change and parameter
    * validity, including complex parameter interaction checks.
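The doc addition above (together with the removal of Params.validateParams further down) moves parameter-interaction checks into transformSchema. A minimal sketch of the pattern for a custom stage; the ClippingTransformer class and its lowerBound/upperBound params are illustrative only and not part of this commit:

    import org.apache.spark.ml.Transformer
    import org.apache.spark.ml.param.{DoubleParam, ParamMap}
    import org.apache.spark.ml.util.Identifiable
    import org.apache.spark.sql.{DataFrame, Dataset}
    import org.apache.spark.sql.types.StructType

    class ClippingTransformer(override val uid: String) extends Transformer {
      def this() = this(Identifiable.randomUID("clip"))

      // Two params whose validity depends on each other.
      final val lowerBound = new DoubleParam(this, "lowerBound", "lower clip bound")
      final val upperBound = new DoubleParam(this, "upperBound", "upper clip bound")
      setDefault(lowerBound -> 0.0, upperBound -> 1.0)

      override def transformSchema(schema: StructType): StructType = {
        // The interaction check lives here now, instead of in a validateParams() override.
        val (lo, hi) = ($(lowerBound), $(upperBound))
        require(lo < hi, s"lowerBound ($lo) must be less than upperBound ($hi).")
        schema // this sketch adds no output column
      }

      override def transform(dataset: Dataset[_]): DataFrame = {
        transformSchema(dataset.schema, logging = true)
        dataset.toDF() // no-op transform, for illustration only
      }

      override def copy(extra: ParamMap): ClippingTransformer = defaultCopy(extra)
    }

Checks that depend only on a single value stay attached to the Param itself via its validator, as in the stepSize change further below.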
*/ http://git-wip-us.apache.org/repos/asf/spark/blob/c4a7eef0/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index 52f93f5..ca52231 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -203,6 +203,12 @@ class GBTClassificationModel private[ml]( @Since("1.4.0") override def trees: Array[DecisionTreeRegressionModel] = _trees + /** + * Number of trees in ensemble + */ + @Since("2.0.0") + val getNumTrees: Int = trees.length + @Since("1.4.0") override def treeWeights: Array[Double] = _treeWeights http://git-wip-us.apache.org/repos/asf/spark/blob/c4a7eef0/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index fe29926..41b84f4 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -40,7 +40,7 @@ import org.apache.spark.mllib.util.MLUtils import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions.{col, lit} -import org.apache.spark.sql.types.DoubleType +import org.apache.spark.sql.types.{DataType, DoubleType, StructType} import org.apache.spark.storage.StorageLevel import org.apache.spark.util.VersionUtils @@ -176,8 +176,12 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas } } - override def validateParams(): Unit = { + override protected def validateAndTransformSchema( + schema: StructType, + fitting: Boolean, + featuresDataType: DataType): StructType = { checkThresholdConsistency() + super.validateAndTransformSchema(schema, fitting, featuresDataType) } } http://git-wip-us.apache.org/repos/asf/spark/blob/c4a7eef0/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index 907c73e..d151213 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -158,7 +158,7 @@ class RandomForestClassificationModel private[ml] ( @Since("1.6.0") override val numFeatures: Int, @Since("1.5.0") override val numClasses: Int) extends ProbabilisticClassificationModel[Vector, RandomForestClassificationModel] - with RandomForestClassificationModelParams with TreeEnsembleModel[DecisionTreeClassificationModel] + with RandomForestClassifierParams with TreeEnsembleModel[DecisionTreeClassificationModel] with MLWritable with Serializable { require(_trees.nonEmpty, "RandomForestClassificationModel requires at least 1 tree.") @@ -221,15 +221,6 @@ class RandomForestClassificationModel private[ml] ( } } - /** - * Number of trees in 
ensemble - * - * @deprecated Use [[getNumTrees]] instead. This method will be removed in 2.1.0 - */ - // TODO: Once this is removed, then this class can inherit from RandomForestClassifierParams - @deprecated("Use getNumTrees instead. This method will be removed in 2.1.0.", "2.0.0") - val numTrees: Int = trees.length - @Since("1.4.0") override def copy(extra: ParamMap): RandomForestClassificationModel = { copyValues(new RandomForestClassificationModel(uid, _trees, numFeatures, numClasses), extra) http://git-wip-us.apache.org/repos/asf/spark/blob/c4a7eef0/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala index 653fa41..7cd0f15 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ChiSqSelector.scala @@ -216,13 +216,6 @@ final class ChiSqSelectorModel private[ml] ( @Since("1.6.0") def setOutputCol(value: String): this.type = set(outputCol, value) - /** - * @group setParam - */ - @Since("1.6.0") - @deprecated("labelCol is not used by ChiSqSelectorModel.", "2.0.0") - def setLabelCol(value: String): this.type = set(labelCol, value) - @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { val transformedSchema = transformSchema(dataset.schema, logging = true) http://git-wip-us.apache.org/repos/asf/spark/blob/c4a7eef0/mllib/src/main/scala/org/apache/spark/ml/param/params.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala index 96206e0..5bd8ebe 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala @@ -547,21 +547,6 @@ trait Params extends Identifiable with Serializable { } /** - * Validates parameter values stored internally. - * Raise an exception if any parameter value is invalid. - * - * This only needs to check for interactions between parameters. - * Parameter value checks which do not depend on other parameters are handled by - * `Param.validate()`. This method does not handle input/output column parameters; - * those are checked during schema validation. - * @deprecated Will be removed in 2.1.0. All the checks should be merged into transformSchema - */ - @deprecated("Will be removed in 2.1.0. Checks should be merged into transformSchema.", "2.0.0") - def validateParams(): Unit = { - // Do nothing by default. Override to handle Param interactions. - } - - /** * Explains a param. * @param param input param, must belong to this instance. 
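As a quick illustration of the explainParam/explainParams API documented in this hunk, a small usage sketch; the printed text is indicative of the format rather than an exact transcript:

    import org.apache.spark.ml.classification.LogisticRegression

    val lr = new LogisticRegression().setMaxIter(5)

    // Explain one param: its name, doc, and default/current values.
    println(lr.explainParam(lr.maxIter))
    // e.g. "maxIter: maximum number of iterations (>= 0) (default: 100, current: 5)"

    // Explain every param of the estimator in one string.
    println(lr.explainParams())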
* @return a string that contains the input param name, doc, and optionally its default value and http://git-wip-us.apache.org/repos/asf/spark/blob/c4a7eef0/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index ed2d055..6d8159a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -183,6 +183,12 @@ class GBTRegressionModel private[ml]( @Since("1.4.0") override def trees: Array[DecisionTreeRegressionModel] = _trees + /** + * Number of trees in ensemble + */ + @Since("2.0.0") + val getNumTrees: Int = trees.length + @Since("1.4.0") override def treeWeights: Array[Double] = _treeWeights http://git-wip-us.apache.org/repos/asf/spark/blob/c4a7eef0/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala index eb4e38c..19ddf36 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala @@ -611,9 +611,6 @@ class LinearRegressionSummary private[regression] ( private val privateModel: LinearRegressionModel, private val diagInvAtWA: Array[Double]) extends Serializable { - @deprecated("The model field is deprecated and will be removed in 2.1.0.", "2.0.0") - val model: LinearRegressionModel = privateModel - @transient private val metrics = new RegressionMetrics( predictions .select(col(predictionCol), col(labelCol).cast(DoubleType)) http://git-wip-us.apache.org/repos/asf/spark/blob/c4a7eef0/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index d60f05e..90d89c5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -145,7 +145,7 @@ class RandomForestRegressionModel private[ml] ( private val _trees: Array[DecisionTreeRegressionModel], override val numFeatures: Int) extends PredictionModel[Vector, RandomForestRegressionModel] - with RandomForestRegressionModelParams with TreeEnsembleModel[DecisionTreeRegressionModel] + with RandomForestRegressorParams with TreeEnsembleModel[DecisionTreeRegressionModel] with MLWritable with Serializable { require(_trees.nonEmpty, "RandomForestRegressionModel requires at least 1 tree.") @@ -182,14 +182,6 @@ class RandomForestRegressionModel private[ml] ( _trees.map(_.rootNode.predictImpl(features).prediction).sum / getNumTrees } - /** - * Number of trees in ensemble - * @deprecated Use [[getNumTrees]] instead. This method will be removed in 2.1.0 - */ - // TODO: Once this is removed, then this class can inherit from RandomForestRegressorParams - @deprecated("Use getNumTrees instead. 
This method will be removed in 2.1.0.", "2.0.0") - val numTrees: Int = trees.length - @Since("1.4.0") override def copy(extra: ParamMap): RandomForestRegressionModel = { copyValues(new RandomForestRegressionModel(uid, _trees, numFeatures), extra).setParent(parent) http://git-wip-us.apache.org/repos/asf/spark/blob/c4a7eef0/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala index d3cbc36..0d6e903 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeModels.scala @@ -95,11 +95,6 @@ private[ml] trait TreeEnsembleModel[M <: DecisionTreeModel] { /** Trees in this ensemble. Warning: These have null parent Estimators. */ def trees: Array[M] - /** - * Number of trees in ensemble - */ - val getNumTrees: Int = trees.length - /** Weights for each tree, zippable with [[trees]] */ def treeWeights: Array[Double] http://git-wip-us.apache.org/repos/asf/spark/blob/c4a7eef0/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index 40510ad8..83ab4b5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -319,8 +319,32 @@ private[ml] trait TreeEnsembleParams extends DecisionTreeParams { } } -/** Used for [[RandomForestParams]] */ -private[ml] trait HasFeatureSubsetStrategy extends Params { +/** + * Parameters for Random Forest algorithms. + */ +private[ml] trait RandomForestParams extends TreeEnsembleParams { + + /** + * Number of trees to train (>= 1). + * If 1, then no bootstrapping is used. If > 1, then bootstrapping is done. + * TODO: Change to always do bootstrapping (simpler). SPARK-7130 + * (default = 20) + * + * Note: The reason that we cannot add this to both GBT and RF (i.e. in TreeEnsembleParams) + * is the param `maxIter` controls how many trees a GBT has. The semantics in the algorithms + * are a bit different. + * @group param + */ + final val numTrees: IntParam = new IntParam(this, "numTrees", "Number of trees to train (>= 1)", + ParamValidators.gtEq(1)) + + setDefault(numTrees -> 20) + + /** @group setParam */ + def setNumTrees(value: Int): this.type = set(numTrees, value) + + /** @group getParam */ + final def getNumTrees: Int = $(numTrees) /** * The number of features to consider for splits at each tree node. @@ -366,38 +390,6 @@ private[ml] trait HasFeatureSubsetStrategy extends Params { final def getFeatureSubsetStrategy: String = $(featureSubsetStrategy).toLowerCase } -/** - * Used for [[RandomForestParams]]. - * This is separated out from [[RandomForestParams]] because of an issue with the - * `numTrees` method conflicting with this Param in the Estimator. - */ -private[ml] trait HasNumTrees extends Params { - - /** - * Number of trees to train (>= 1). - * If 1, then no bootstrapping is used. If > 1, then bootstrapping is done. - * TODO: Change to always do bootstrapping (simpler). 
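With the deprecated model-level numTrees field removed here (and on RandomForestClassificationModel above), ensemble size is read from a fitted model via getNumTrees, while numTrees remains the estimator param. A short sketch, assuming a DataFrame named training with the usual label/features columns:

    import org.apache.spark.ml.regression.RandomForestRegressor

    val rf = new RandomForestRegressor()
      .setNumTrees(50) // estimator param, unchanged by this commit

    val model = rf.fit(training)

    // 2.1+: read the ensemble size from the fitted model ...
    val numTreesInModel: Int = model.getNumTrees
    // ... instead of the removed model.numTrees val.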
SPARK-7130 - * (default = 20) - * @group param - */ - final val numTrees: IntParam = new IntParam(this, "numTrees", "Number of trees to train (>= 1)", - ParamValidators.gtEq(1)) - - setDefault(numTrees -> 20) - - /** @group setParam */ - def setNumTrees(value: Int): this.type = set(numTrees, value) - - /** @group getParam */ - final def getNumTrees: Int = $(numTrees) -} - -/** - * Parameters for Random Forest algorithms. - */ -private[ml] trait RandomForestParams extends TreeEnsembleParams - with HasFeatureSubsetStrategy with HasNumTrees - private[spark] object RandomForestParams { // These options should be lowercase. final val supportedFeatureSubsetStrategies: Array[String] = @@ -407,21 +399,15 @@ private[spark] object RandomForestParams { private[ml] trait RandomForestClassifierParams extends RandomForestParams with TreeClassifierParams -private[ml] trait RandomForestClassificationModelParams extends TreeEnsembleParams - with HasFeatureSubsetStrategy with TreeClassifierParams - private[ml] trait RandomForestRegressorParams extends RandomForestParams with TreeRegressorParams -private[ml] trait RandomForestRegressionModelParams extends TreeEnsembleParams - with HasFeatureSubsetStrategy with TreeRegressorParams - /** * Parameters for Gradient-Boosted Tree algorithms. * * Note: Marked as private and DeveloperApi since this may be made public in the future. */ -private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasStepSize { +private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter { /* TODO: Add this doc when we add this param. SPARK-7132 * Threshold for stopping early when runWithValidation is used. @@ -434,24 +420,26 @@ private[ml] trait GBTParams extends TreeEnsembleParams with HasMaxIter with HasS // final val validationTol: DoubleParam = new DoubleParam(this, "validationTol", "") // validationTol -> 1e-5 - setDefault(maxIter -> 20, stepSize -> 0.1) - /** @group setParam */ def setMaxIter(value: Int): this.type = set(maxIter, value) /** - * Step size (a.k.a. learning rate) in interval (0, 1] for shrinking the contribution of each - * estimator. + * Param for Step size (a.k.a. learning rate) in interval (0, 1] for shrinking + * the contribution of each estimator. * (default = 0.1) - * @group setParam + * @group param */ + final val stepSize: DoubleParam = new DoubleParam(this, "stepSize", "Step size " + + "(a.k.a. learning rate) in interval (0, 1] for shrinking the contribution of each estimator.", + ParamValidators.inRange(0, 1, lowerInclusive = false, upperInclusive = true)) + + /** @group getParam */ + final def getStepSize: Double = $(stepSize) + + /** @group setParam */ def setStepSize(value: Double): this.type = set(stepSize, value) - override def validateParams(): Unit = { - require(ParamValidators.inRange(0, 1, lowerInclusive = false, upperInclusive = true)( - getStepSize), "GBT parameter stepSize should be in interval (0, 1], " + - s"but it given invalid value $getStepSize.") - } + setDefault(maxIter -> 20, stepSize -> 0.1) /** (private[ml]) Create a BoostingStrategy instance to use with the old API. 
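Since stepSize is now declared with ParamValidators.inRange(0, 1, lowerInclusive = false, upperInclusive = true), an out-of-range value is rejected when the param is set rather than by the removed validateParams override; the new GBTClassifierSuite test further below exercises exactly this. A minimal sketch of the behavior:

    import org.apache.spark.ml.regression.GBTRegressor

    val gbt = new GBTRegressor()
    gbt.setStepSize(0.05) // accepted: 0.05 is inside (0, 1]

    // Rejected immediately with an IllegalArgumentException at set time:
    // gbt.setStepSize(10.0)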
*/ private[ml] def getOldBoostingStrategy( http://git-wip-us.apache.org/repos/asf/spark/blob/c4a7eef0/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index 5b7e5ec..bbb9886 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -46,7 +46,7 @@ private[util] sealed trait BaseReadWrite { * Sets the Spark SQLContext to use for saving/loading. */ @Since("1.6.0") - @deprecated("Use session instead", "2.0.0") + @deprecated("Use session instead, This method will be removed in 2.2.0.", "2.0.0") def context(sqlContext: SQLContext): this.type = { optionSparkSession = Option(sqlContext.sparkSession) this http://git-wip-us.apache.org/repos/asf/spark/blob/c4a7eef0/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index 3492709..7c36745 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -70,6 +70,14 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext ParamsSuite.checkParams(model) } + test("GBT parameter stepSize should be in interval (0, 1]") { + withClue("GBT parameter stepSize should be in interval (0, 1]") { + intercept[IllegalArgumentException] { + new GBTClassifier().setStepSize(10) + } + } + } + test("Binary classification with continuous features: Log Loss") { val categoricalFeatures = Map.empty[Int, Int] testCombinations.foreach { http://git-wip-us.apache.org/repos/asf/spark/blob/c4a7eef0/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index e360542..9c4c59a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -194,6 +194,12 @@ class LogisticRegressionSuite // thresholds and threshold must be consistent: values withClue("fit with ParamMap should throw error if threshold, thresholds do not match.") { intercept[IllegalArgumentException] { + lr2.fit(smallBinaryDataset, + lr2.thresholds -> Array(0.3, 0.7), lr2.threshold -> (expectedThreshold / 2.0)) + } + } + withClue("fit with ParamMap should throw error if threshold, thresholds do not match.") { + intercept[IllegalArgumentException] { val lr2model = lr2.fit(smallBinaryDataset, lr2.thresholds -> Array(0.3, 0.7), lr2.threshold -> (expectedThreshold / 2.0)) lr2model.getThreshold http://git-wip-us.apache.org/repos/asf/spark/blob/c4a7eef0/project/MimaExcludes.scala ---------------------------------------------------------------------- diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 12f7ed2..8401401 100644 --- a/project/MimaExcludes.scala +++ 
b/project/MimaExcludes.scala @@ -867,6 +867,36 @@ object MimaExcludes { // [SPARK-12221] Add CPU time to metrics ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskMetrics.this"), ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.status.api.v1.TaskMetricDistributions.this") + ) ++ Seq( + // [SPARK-18481] ML 2.1 QA: Remove deprecated methods for ML + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.PipelineStage.validateParams"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.param.JavaParams.validateParams"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.param.Params.validateParams"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassificationModel.validateParams"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegression.validateParams"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.GBTClassifier.validateParams"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.LogisticRegressionModel.validateParams"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.numTrees"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.feature.ChiSqSelectorModel.setLabelCol"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.evaluation.Evaluator.validateParams"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressor.validateParams"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.GBTRegressionModel.validateParams"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.LinearRegressionSummary.model"), + ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.numTrees"), + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.classification.RandomForestClassifier"), + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel"), + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.classification.GBTClassifier"), + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.classification.GBTClassificationModel"), + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.regression.RandomForestRegressor"), + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel"), + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.regression.GBTRegressor"), + ProblemFilters.exclude[MissingTypesProblem]("org.apache.spark.ml.regression.GBTRegressionModel"), + ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.getNumTrees"), + ProblemFilters.exclude[FinalMethodProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.getNumTrees"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.numTrees"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setFeatureSubsetStrategy"), + 
ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.numTrees"), + ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setFeatureSubsetStrategy") ) } http://git-wip-us.apache.org/repos/asf/spark/blob/c4a7eef0/python/pyspark/ml/util.py ---------------------------------------------------------------------- diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py index 7d39c30..bec4b28 100644 --- a/python/pyspark/ml/util.py +++ b/python/pyspark/ml/util.py @@ -78,7 +78,14 @@ class MLWriter(object): raise NotImplementedError("MLWriter is not yet implemented for type: %s" % type(self)) def context(self, sqlContext): - """Sets the SQL context to use for saving.""" + """ + Sets the SQL context to use for saving. + .. note:: Deprecated in 2.1 and will be removed in 2.2, use session instead. + """ + raise NotImplementedError("MLWriter is not yet implemented for type: %s" % type(self)) + + def session(self, sparkSession): + """Sets the Spark Session to use for saving.""" raise NotImplementedError("MLWriter is not yet implemented for type: %s" % type(self)) @@ -105,10 +112,19 @@ class JavaMLWriter(MLWriter): return self def context(self, sqlContext): - """Sets the SQL context to use for saving.""" + """ + Sets the SQL context to use for saving. + .. note:: Deprecated in 2.1 and will be removed in 2.2, use session instead. + """ + warnings.warn("Deprecated in 2.1 and will be removed in 2.2, use session instead.") self._jwrite.context(sqlContext._ssql_ctx) return self + def session(self, sparkSession): + """Sets the Spark Session to use for saving.""" + self._jwrite.session(sparkSession._jsparkSession) + return self + @inherit_doc class MLWritable(object): @@ -155,7 +171,14 @@ class MLReader(object): raise NotImplementedError("MLReader is not yet implemented for type: %s" % type(self)) def context(self, sqlContext): - """Sets the SQL context to use for loading.""" + """ + Sets the SQL context to use for loading. + .. note:: Deprecated in 2.1 and will be removed in 2.2, use session instead. + """ + raise NotImplementedError("MLReader is not yet implemented for type: %s" % type(self)) + + def session(self, sparkSession): + """Sets the Spark Session to use for loading.""" raise NotImplementedError("MLReader is not yet implemented for type: %s" % type(self)) @@ -180,10 +203,19 @@ class JavaMLReader(MLReader): return self._clazz._from_java(java_obj) def context(self, sqlContext): - """Sets the SQL context to use for loading.""" + """ + Sets the SQL context to use for loading. + .. note:: Deprecated in 2.1 and will be removed in 2.2, use session instead. + """ + warnings.warn("Deprecated in 2.1 and will be removed in 2.2, use session instead.") self._jread.context(sqlContext._ssql_ctx) return self + def session(self, sparkSession): + """Sets the Spark Session to use for loading.""" + self._jread.session(sparkSession._jsparkSession) + return self + @classmethod def _java_loader_class(cls, clazz): """ --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org