Repository: spark Updated Branches: refs/heads/branch-2.1 8520d7c6d -> 258ca40cf
Revert "[SPARK-21306][ML] OneVsRest should support setWeightCol" This reverts commit 8520d7c6d5e880dea3c1a8a874148c07222b4b4b. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/258ca40c Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/258ca40c Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/258ca40c Branch: refs/heads/branch-2.1 Commit: 258ca40cf43eedae59b014a41fc6197df9bde299 Parents: 8520d7c Author: Yanbo Liang <yblia...@gmail.com> Authored: Fri Jul 28 20:24:54 2017 +0800 Committer: Yanbo Liang <yblia...@gmail.com> Committed: Fri Jul 28 20:24:54 2017 +0800 ---------------------------------------------------------------------- .../spark/ml/classification/OneVsRest.scala | 39 ++------------------ .../ml/classification/OneVsRestSuite.scala | 10 ----- python/pyspark/ml/classification.py | 27 +++----------- python/pyspark/ml/tests.py | 14 ------- 4 files changed, 9 insertions(+), 81 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/258ca40c/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index c4a8f1f..e58b30d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -34,7 +34,6 @@ import org.apache.spark.ml._ import org.apache.spark.ml.attribute._ import org.apache.spark.ml.linalg.Vector import org.apache.spark.ml.param.{Param, ParamMap, ParamPair, Params} -import org.apache.spark.ml.param.shared.HasWeightCol import org.apache.spark.ml.util._ import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions._ @@ -54,8 +53,7 @@ private[ml] trait ClassifierTypeTrait { /** * Params for [[OneVsRest]]. */ -private[ml] trait OneVsRestParams extends PredictorParams - with ClassifierTypeTrait with HasWeightCol { +private[ml] trait OneVsRestParams extends PredictorParams with ClassifierTypeTrait { /** * param for the base binary classifier that we reduce multiclass classification into. @@ -301,18 +299,6 @@ final class OneVsRest @Since("1.4.0") ( @Since("1.5.0") def setPredictionCol(value: String): this.type = set(predictionCol, value) - /** - * Sets the value of param [[weightCol]]. - * - * This is ignored if weight is not supported by [[classifier]]. - * If this is not set or empty, we treat all instance weights as 1.0. - * Default is not set, so all instances have weight one. - * - * @group setParam - */ - @Since("2.3.0") - def setWeightCol(value: String): this.type = set(weightCol, value) - @Since("1.4.0") override def transformSchema(schema: StructType): StructType = { validateAndTransformSchema(schema, fitting = true, getClassifier.featuresDataType) @@ -331,20 +317,7 @@ final class OneVsRest @Since("1.4.0") ( } val numClasses = MetadataUtils.getNumClasses(labelSchema).fold(computeNumClasses())(identity) - val weightColIsUsed = isDefined(weightCol) && $(weightCol).nonEmpty && { - getClassifier match { - case _: HasWeightCol => true - case c => - logWarning(s"weightCol is ignored, as it is not supported by $c now.") - false - } - } - - val multiclassLabeled = if (weightColIsUsed) { - dataset.select($(labelCol), $(featuresCol), $(weightCol)) - } else { - dataset.select($(labelCol), $(featuresCol)) - } + val multiclassLabeled = dataset.select($(labelCol), $(featuresCol)) // persist if underlying dataset is not persistent. val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE @@ -364,13 +337,7 @@ final class OneVsRest @Since("1.4.0") ( paramMap.put(classifier.labelCol -> labelColName) paramMap.put(classifier.featuresCol -> getFeaturesCol) paramMap.put(classifier.predictionCol -> getPredictionCol) - if (weightColIsUsed) { - val classifier_ = classifier.asInstanceOf[ClassifierType with HasWeightCol] - paramMap.put(classifier_.weightCol -> getWeightCol) - classifier_.fit(trainingDataset, paramMap) - } else { - classifier.fit(trainingDataset, paramMap) - } + classifier.fit(trainingDataset, paramMap) }.toArray[ClassificationModel[_, _]] if (handlePersistence) { http://git-wip-us.apache.org/repos/asf/spark/blob/258ca40c/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala index 491eca5..aacb792 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala @@ -157,16 +157,6 @@ class OneVsRestSuite extends SparkFunSuite with MLlibTestSparkContext with Defau assert(output.schema.fieldNames.toSet === Set("label", "features", "prediction")) } - test("SPARK-21306: OneVsRest should support setWeightCol") { - val dataset2 = dataset.withColumn("weight", lit(1)) - // classifier inherits hasWeightCol - val ova = new OneVsRest().setWeightCol("weight").setClassifier(new LogisticRegression()) - assert(ova.fit(dataset2) !== null) - // classifier doesn't inherit hasWeightCol - val ova2 = new OneVsRest().setWeightCol("weight").setClassifier(new DecisionTreeClassifier()) - assert(ova2.fit(dataset2) !== null) - } - test("OneVsRest.copy and OneVsRestModel.copy") { val lr = new LogisticRegression() .setMaxIter(1) http://git-wip-us.apache.org/repos/asf/spark/blob/258ca40c/python/pyspark/ml/classification.py ---------------------------------------------------------------------- diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index f88be70..2b47c40 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -1331,7 +1331,7 @@ class MultilayerPerceptronClassificationModel(JavaModel, JavaPredictionModel, Ja return self._call_java("weights") -class OneVsRestParams(HasFeaturesCol, HasLabelCol, HasWeightCol, HasPredictionCol): +class OneVsRestParams(HasFeaturesCol, HasLabelCol, HasPredictionCol): """ Parameters for OneVsRest and OneVsRestModel. """ @@ -1394,10 +1394,10 @@ class OneVsRest(Estimator, OneVsRestParams, MLReadable, MLWritable): @keyword_only def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", - classifier=None, weightCol=None): + classifier=None): """ __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \ - classifier=None, weightCol=None) + classifier=None) """ super(OneVsRest, self).__init__() kwargs = self._input_kwargs @@ -1405,11 +1405,9 @@ class OneVsRest(Estimator, OneVsRestParams, MLReadable, MLWritable): @keyword_only @since("2.0.0") - def setParams(self, featuresCol=None, labelCol=None, predictionCol=None, - classifier=None, weightCol=None): + def setParams(self, featuresCol=None, labelCol=None, predictionCol=None, classifier=None): """ - setParams(self, featuresCol=None, labelCol=None, predictionCol=None, \ - classifier=None, weightCol=None): + setParams(self, featuresCol=None, labelCol=None, predictionCol=None, classifier=None): Sets params for OneVsRest. """ kwargs = self._input_kwargs @@ -1425,18 +1423,7 @@ class OneVsRest(Estimator, OneVsRestParams, MLReadable, MLWritable): numClasses = int(dataset.agg({labelCol: "max"}).head()["max("+labelCol+")"]) + 1 - weightCol = None - if (self.isDefined(self.weightCol) and self.getWeightCol()): - if isinstance(classifier, HasWeightCol): - weightCol = self.getWeightCol() - else: - warnings.warn("weightCol is ignored, " - "as it is not supported by {} now.".format(classifier)) - - if weightCol: - multiclassLabeled = dataset.select(labelCol, featuresCol, weightCol) - else: - multiclassLabeled = dataset.select(labelCol, featuresCol) + multiclassLabeled = dataset.select(labelCol, featuresCol) # persist if underlying dataset is not persistent. handlePersistence = \ @@ -1452,8 +1439,6 @@ class OneVsRest(Estimator, OneVsRestParams, MLReadable, MLWritable): paramMap = dict([(classifier.labelCol, binaryLabelCol), (classifier.featuresCol, featuresCol), (classifier.predictionCol, predictionCol)]) - if weightCol: - paramMap[classifier.weightCol] = weightCol return classifier.fit(trainingDataset, paramMap) # TODO: Parallel training for all classes. http://git-wip-us.apache.org/repos/asf/spark/blob/258ca40c/python/pyspark/ml/tests.py ---------------------------------------------------------------------- diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 9046e9f..7152036 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -1218,20 +1218,6 @@ class OneVsRestTests(SparkSessionTestCase): output = model.transform(df) self.assertEqual(output.columns, ["label", "features", "prediction"]) - def test_support_for_weightCol(self): - df = self.spark.createDataFrame([(0.0, Vectors.dense(1.0, 0.8), 1.0), - (1.0, Vectors.sparse(2, [], []), 1.0), - (2.0, Vectors.dense(0.5, 0.5), 1.0)], - ["label", "features", "weight"]) - # classifier inherits hasWeightCol - lr = LogisticRegression(maxIter=5, regParam=0.01) - ovr = OneVsRest(classifier=lr, weightCol="weight") - self.assertIsNotNone(ovr.fit(df)) - # classifier doesn't inherit hasWeightCol - dt = DecisionTreeClassifier() - ovr2 = OneVsRest(classifier=dt, weightCol="weight") - self.assertIsNotNone(ovr2.fit(df)) - class HashingTFTest(SparkSessionTestCase): --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org