Repository: spark
Updated Branches:
  refs/heads/master 9dc0ca060 -> 44cbb61b3
[SPARK-15957][FOLLOW-UP][ML][PYSPARK] Add Python API for RFormula forceIndexLabel.

## What changes were proposed in this pull request?
Follow-up work of #13675, add Python API for ```RFormula forceIndexLabel```.

## How was this patch tested?
Unit test.

Author: Yanbo Liang <yblia...@gmail.com>

Closes #15430 from yanboliang/spark-15957-python.

Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/44cbb61b
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/44cbb61b
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/44cbb61b

Branch: refs/heads/master
Commit: 44cbb61b34a98e3e0d8e2543a4eb6e950e0019a5
Parents: 9dc0ca0
Author: Yanbo Liang <yblia...@gmail.com>
Authored: Thu Oct 13 19:44:24 2016 -0700
Committer: Yanbo Liang <yblia...@gmail.com>
Committed: Thu Oct 13 19:44:24 2016 -0700

----------------------------------------------------------------------
 python/pyspark/ml/feature.py | 31 +++++++++++++++++++++++++++----
 python/pyspark/ml/tests.py | 16 ++++++++++++++++
 2 files changed, 43 insertions(+), 4 deletions(-)
----------------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/spark/blob/44cbb61b/python/pyspark/ml/feature.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 64b21ca..a33c3e7 100755
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -2494,21 +2494,30 @@ class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, JavaMLReadable, JavaM
     formula = Param(Params._dummy(), "formula", "R model formula",
                     typeConverter=TypeConverters.toString)
 
+    forceIndexLabel = Param(Params._dummy(), "forceIndexLabel",
+                            "Force to index label whether it is numeric or string",
+                            typeConverter=TypeConverters.toBoolean)
+
     @keyword_only
-    def __init__(self, formula=None, featuresCol="features", labelCol="label"):
+    def __init__(self, formula=None,
featuresCol="features", labelCol="label",
+                 forceIndexLabel=False):
         """
-        __init__(self, formula=None, featuresCol="features", labelCol="label")
+        __init__(self, formula=None, featuresCol="features", labelCol="label", \
+                 forceIndexLabel=False)
         """
         super(RFormula, self).__init__()
         self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.RFormula", self.uid)
+        self._setDefault(forceIndexLabel=False)
         kwargs = self.__init__._input_kwargs
         self.setParams(**kwargs)
 
     @keyword_only
     @since("1.5.0")
-    def setParams(self, formula=None, featuresCol="features", labelCol="label"):
+    def setParams(self, formula=None, featuresCol="features", labelCol="label",
+                  forceIndexLabel=False):
         """
-        setParams(self, formula=None, featuresCol="features", labelCol="label")
+        setParams(self, formula=None, featuresCol="features", labelCol="label", \
+                  forceIndexLabel=False)
         Sets params for RFormula.
         """
         kwargs = self.setParams._input_kwargs
@@ -2528,6 +2537,20 @@ class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, JavaMLReadable, JavaM
         """
         return self.getOrDefault(self.formula)
 
+    @since("2.1.0")
+    def setForceIndexLabel(self, value):
+        """
+        Sets the value of :py:attr:`forceIndexLabel`.
+        """
+        return self._set(forceIndexLabel=value)
+
+    @since("2.1.0")
+    def getForceIndexLabel(self):
+        """
+        Gets the value of :py:attr:`forceIndexLabel`.
+        """
+        return self.getOrDefault(self.forceIndexLabel)
+
     def _create_model(self, java_model):
         return RFormulaModel(java_model)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/44cbb61b/python/pyspark/ml/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index e233549..9d46cc3 100755
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -477,6 +477,22 @@ class FeatureTests(SparkSessionTestCase):
             feature, expected = r
             self.assertEqual(feature, expected)
 
+    def test_rformula_force_index_label(self):
+        df = self.spark.createDataFrame([
+            (1.0, 1.0, "a"),
+            (0.0, 2.0, "b"),
+            (1.0, 0.0, "a")], ["y", "x", "s"])
+        # Does not index label by default since it's numeric type.
+        rf = RFormula(formula="y ~ x + s")
+        model = rf.fit(df)
+        transformedDF = model.transform(df)
+        self.assertEqual(transformedDF.head().label, 1.0)
+        # Force to index label.
+        rf2 = RFormula(formula="y ~ x + s").setForceIndexLabel(True)
+        model2 = rf2.fit(df)
+        transformedDF2 = model2.transform(df)
+        self.assertEqual(transformedDF2.head().label, 0.0)
+
 
 class HasInducedError(Params):
 

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org