Repository: spark
Updated Branches:
  refs/heads/master 9dc0ca060 -> 44cbb61b3


[SPARK-15957][FOLLOW-UP][ML][PYSPARK] Add Python API for RFormula forceIndexLabel.

## What changes were proposed in this pull request?
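Follow-up to #13675: add the Python API for `RFormula.forceIndexLabel`, i.e. a `forceIndexLabel` keyword on `RFormula.__init__` and `setParams` (default `False`), plus the `setForceIndexLabel` and `getForceIndexLabel` accessors.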

## How was this patch tested?
Added a unit test, `test_rformula_force_index_label`, in `python/pyspark/ml/tests.py`.

Author: Yanbo Liang <yblia...@gmail.com>

Closes #15430 from yanboliang/spark-15957-python.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/44cbb61b
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/44cbb61b
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/44cbb61b

Branch: refs/heads/master
Commit: 44cbb61b34a98e3e0d8e2543a4eb6e950e0019a5
Parents: 9dc0ca0
Author: Yanbo Liang <yblia...@gmail.com>
Authored: Thu Oct 13 19:44:24 2016 -0700
Committer: Yanbo Liang <yblia...@gmail.com>
Committed: Thu Oct 13 19:44:24 2016 -0700

----------------------------------------------------------------------
 python/pyspark/ml/feature.py | 31 +++++++++++++++++++++++++++----
 python/pyspark/ml/tests.py   | 16 ++++++++++++++++
 2 files changed, 43 insertions(+), 4 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/44cbb61b/python/pyspark/ml/feature.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 64b21ca..a33c3e7 100755
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -2494,21 +2494,30 @@ class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, JavaMLReadable, JavaM
     formula = Param(Params._dummy(), "formula", "R model formula",
                     typeConverter=TypeConverters.toString)
 
+    forceIndexLabel = Param(Params._dummy(), "forceIndexLabel",
+                            "Force to index label whether it is numeric or 
string",
+                            typeConverter=TypeConverters.toBoolean)
+
     @keyword_only
-    def __init__(self, formula=None, featuresCol="features", labelCol="label"):
+    def __init__(self, formula=None, featuresCol="features", labelCol="label",
+                 forceIndexLabel=False):
         """
-        __init__(self, formula=None, featuresCol="features", labelCol="label")
+        __init__(self, formula=None, featuresCol="features", labelCol="label", 
\
+                 forceIndexLabel=False)
         """
         super(RFormula, self).__init__()
         self._java_obj = 
self._new_java_obj("org.apache.spark.ml.feature.RFormula", self.uid)
+        self._setDefault(forceIndexLabel=False)
         kwargs = self.__init__._input_kwargs
         self.setParams(**kwargs)
 
     @keyword_only
     @since("1.5.0")
-    def setParams(self, formula=None, featuresCol="features", 
labelCol="label"):
+    def setParams(self, formula=None, featuresCol="features", labelCol="label",
+                  forceIndexLabel=False):
         """
-        setParams(self, formula=None, featuresCol="features", labelCol="label")
+        setParams(self, formula=None, featuresCol="features", 
labelCol="label", \
+                  forceIndexLabel=False)
         Sets params for RFormula.
         """
         kwargs = self.setParams._input_kwargs
@@ -2528,6 +2537,20 @@ class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, JavaMLReadable, JavaM
         """
         return self.getOrDefault(self.formula)
 
+    @since("2.1.0")
+    def setForceIndexLabel(self, value):
+        """
+        Sets the value of :py:attr:`forceIndexLabel`.
+        """
+        return self._set(forceIndexLabel=value)
+
+    @since("2.1.0")
+    def getForceIndexLabel(self):
+        """
+        Gets the value of :py:attr:`forceIndexLabel`.
+        """
+        return self.getOrDefault(self.forceIndexLabel)
+
     def _create_model(self, java_model):
         return RFormulaModel(java_model)
 

http://git-wip-us.apache.org/repos/asf/spark/blob/44cbb61b/python/pyspark/ml/tests.py
----------------------------------------------------------------------
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index e233549..9d46cc3 100755
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -477,6 +477,22 @@ class FeatureTests(SparkSessionTestCase):
             feature, expected = r
             self.assertEqual(feature, expected)
 
+    def test_rformula_force_index_label(self):
+        df = self.spark.createDataFrame([
+            (1.0, 1.0, "a"),
+            (0.0, 2.0, "b"),
+            (1.0, 0.0, "a")], ["y", "x", "s"])
+        # Does not index label by default since it's numeric type.
+        rf = RFormula(formula="y ~ x + s")
+        model = rf.fit(df)
+        transformedDF = model.transform(df)
+        self.assertEqual(transformedDF.head().label, 1.0)
+        # Force to index label.
+        rf2 = RFormula(formula="y ~ x + s").setForceIndexLabel(True)
+        model2 = rf2.fit(df)
+        transformedDF2 = model2.transform(df)
+        self.assertEqual(transformedDF2.head().label, 0.0)
+
 
 class HasInducedError(Params):
 


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to