Repository: spark Updated Branches: refs/heads/master c605fee01 -> c19680be1
[SPARK-19852][PYSPARK][ML] Python StringIndexer supports 'keep' to handle invalid data ## What changes were proposed in this pull request? This PR is to maintain API parity with changes made in SPARK-17498 to support a new option 'keep' in StringIndexer to handle unseen labels or NULL values with PySpark. Note: This is updated version of #17237 , the primary author of this PR is VinceShieh . ## How was this patch tested? Unit tests. Author: VinceShieh <vincent....@intel.com> Author: Yanbo Liang <yblia...@gmail.com> Closes #18453 from yanboliang/spark-19852. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/c19680be Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/c19680be Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/c19680be Branch: refs/heads/master Commit: c19680be1c532dded1e70edce7a981ba28af09ad Parents: c605fee Author: Yanbo Liang <yblia...@gmail.com> Authored: Sun Jul 2 16:17:03 2017 +0800 Committer: Yanbo Liang <yblia...@gmail.com> Committed: Sun Jul 2 16:17:03 2017 +0800 ---------------------------------------------------------------------- python/pyspark/ml/feature.py | 6 ++++++ python/pyspark/ml/tests.py | 21 +++++++++++++++++++++ 2 files changed, 27 insertions(+) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/c19680be/python/pyspark/ml/feature.py ---------------------------------------------------------------------- diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 77de1cc..25ad06f 100755 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -2132,6 +2132,12 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid, "frequencyDesc, frequencyAsc, alphabetDesc, alphabetAsc.", typeConverter=TypeConverters.toString) + handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid data (unseen " + + "labels or NULL values). Options are 'skip' (filter out rows with " + + "invalid data), error (throw an error), or 'keep' (put invalid data " + + "in a special additional bucket, at index numLabels).", + typeConverter=TypeConverters.toString) + @keyword_only def __init__(self, inputCol=None, outputCol=None, handleInvalid="error", stringOrderType="frequencyDesc"): http://git-wip-us.apache.org/repos/asf/spark/blob/c19680be/python/pyspark/ml/tests.py ---------------------------------------------------------------------- diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 17a3947..ffb8b0a 100755 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -551,6 +551,27 @@ class FeatureTests(SparkSessionTestCase): for i in range(0, len(expected)): self.assertTrue(all(observed[i]["features"].toArray() == expected[i])) + def test_string_indexer_handle_invalid(self): + df = self.spark.createDataFrame([ + (0, "a"), + (1, "d"), + (2, None)], ["id", "label"]) + + si1 = StringIndexer(inputCol="label", outputCol="indexed", handleInvalid="keep", + stringOrderType="alphabetAsc") + model1 = si1.fit(df) + td1 = model1.transform(df) + actual1 = td1.select("id", "indexed").collect() + expected1 = [Row(id=0, indexed=0.0), Row(id=1, indexed=1.0), Row(id=2, indexed=2.0)] + self.assertEqual(actual1, expected1) + + si2 = si1.setHandleInvalid("skip") + model2 = si2.fit(df) + td2 = model2.transform(df) + actual2 = td2.select("id", "indexed").collect() + expected2 = [Row(id=0, indexed=0.0), Row(id=1, indexed=1.0)] + self.assertEqual(actual2, expected2) + class HasInducedError(Params): --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org