Repository: spark Updated Branches: refs/heads/master d4a637cd4 -> 85941ecf2
[SPARK-11569][ML] Fix StringIndexer to handle null value properly ## What changes were proposed in this pull request? This PR is to enhance StringIndexer with NULL values handling. Before the PR, StringIndexer will throw an exception when encounters NULL values. With this PR: - handleInvalid=error: Throw an exception as before - handleInvalid=skip: Skip null values as well as unseen labels - handleInvalid=keep: Give null values an additional index as well as unseen labels BTW, I noticed someone was trying to solve the same problem ( #9920 ) but seems getting no progress or response for a long time. Would you mind to give me a chance to solve it ? I'm eager to help. :-) ## How was this patch tested? new unit tests Author: Menglong TAN <tanmengl...@renrenche.com> Author: Menglong TAN <tanmengl...@gmail.com> Closes #17233 from crackcell/11569_StringIndexer_NULL. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/85941ecf Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/85941ecf Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/85941ecf Branch: refs/heads/master Commit: 85941ecf28362f35718ebcd3a22dbb17adb49154 Parents: d4a637c Author: Menglong TAN <tanmengl...@renrenche.com> Authored: Tue Mar 14 07:45:42 2017 -0700 Committer: Joseph K. Bradley <jos...@databricks.com> Committed: Tue Mar 14 07:45:42 2017 -0700 ---------------------------------------------------------------------- .../apache/spark/ml/feature/StringIndexer.scala | 54 ++++++++++++-------- .../spark/ml/feature/StringIndexerSuite.scala | 45 ++++++++++++++++ 2 files changed, 77 insertions(+), 22 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/85941ecf/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala index 810b02f..99321bc 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala @@ -39,20 +39,21 @@ import org.apache.spark.util.collection.OpenHashMap private[feature] trait StringIndexerBase extends Params with HasInputCol with HasOutputCol { /** - * Param for how to handle unseen labels. Options are 'skip' (filter out rows with - * unseen labels), 'error' (throw an error), or 'keep' (put unseen labels in a special additional - * bucket, at index numLabels. + * Param for how to handle invalid data (unseen labels or NULL values). + * Options are 'skip' (filter out rows with invalid data), + * 'error' (throw an error), or 'keep' (put invalid data in a special additional + * bucket, at index numLabels). * Default: "error" * @group param */ @Since("1.6.0") val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", "how to handle " + - "unseen labels. Options are 'skip' (filter out rows with unseen labels), " + - "error (throw an error), or 'keep' (put unseen labels in a special additional bucket, " + - "at index numLabels).", + "invalid data (unseen labels or NULL values). " + + "Options are 'skip' (filter out rows with invalid data), error (throw an error), " + + "or 'keep' (put invalid data in a special additional bucket, at index numLabels).", ParamValidators.inArray(StringIndexer.supportedHandleInvalids)) - setDefault(handleInvalid, StringIndexer.ERROR_UNSEEN_LABEL) + setDefault(handleInvalid, StringIndexer.ERROR_INVALID) /** @group getParam */ @Since("1.6.0") @@ -106,7 +107,7 @@ class StringIndexer @Since("1.4.0") ( @Since("2.0.0") override def fit(dataset: Dataset[_]): StringIndexerModel = { transformSchema(dataset.schema, logging = true) - val counts = dataset.select(col($(inputCol)).cast(StringType)) + val counts = dataset.na.drop(Array($(inputCol))).select(col($(inputCol)).cast(StringType)) .rdd .map(_.getString(0)) .countByValue() @@ -125,11 +126,11 @@ class StringIndexer @Since("1.4.0") ( @Since("1.6.0") object StringIndexer extends DefaultParamsReadable[StringIndexer] { - private[feature] val SKIP_UNSEEN_LABEL: String = "skip" - private[feature] val ERROR_UNSEEN_LABEL: String = "error" - private[feature] val KEEP_UNSEEN_LABEL: String = "keep" + private[feature] val SKIP_INVALID: String = "skip" + private[feature] val ERROR_INVALID: String = "error" + private[feature] val KEEP_INVALID: String = "keep" private[feature] val supportedHandleInvalids: Array[String] = - Array(SKIP_UNSEEN_LABEL, ERROR_UNSEEN_LABEL, KEEP_UNSEEN_LABEL) + Array(SKIP_INVALID, ERROR_INVALID, KEEP_INVALID) @Since("1.6.0") override def load(path: String): StringIndexer = super.load(path) @@ -188,7 +189,7 @@ class StringIndexerModel ( transformSchema(dataset.schema, logging = true) val filteredLabels = getHandleInvalid match { - case StringIndexer.KEEP_UNSEEN_LABEL => labels :+ "__unknown" + case StringIndexer.KEEP_INVALID => labels :+ "__unknown" case _ => labels } @@ -196,22 +197,31 @@ class StringIndexerModel ( .withName($(outputCol)).withValues(filteredLabels).toMetadata() // If we are skipping invalid records, filter them out. val (filteredDataset, keepInvalid) = getHandleInvalid match { - case StringIndexer.SKIP_UNSEEN_LABEL => + case StringIndexer.SKIP_INVALID => val filterer = udf { label: String => labelToIndex.contains(label) } - (dataset.where(filterer(dataset($(inputCol)))), false) - case _ => (dataset, getHandleInvalid == StringIndexer.KEEP_UNSEEN_LABEL) + (dataset.na.drop(Array($(inputCol))).where(filterer(dataset($(inputCol)))), false) + case _ => (dataset, getHandleInvalid == StringIndexer.KEEP_INVALID) } val indexer = udf { label: String => - if (labelToIndex.contains(label)) { - labelToIndex(label) - } else if (keepInvalid) { - labels.length + if (label == null) { + if (keepInvalid) { + labels.length + } else { + throw new SparkException("StringIndexer encountered NULL value. To handle or skip " + + "NULLS, try setting StringIndexer.handleInvalid.") + } } else { - throw new SparkException(s"Unseen label: $label. To handle unseen labels, " + - s"set Param handleInvalid to ${StringIndexer.KEEP_UNSEEN_LABEL}.") + if (labelToIndex.contains(label)) { + labelToIndex(label) + } else if (keepInvalid) { + labels.length + } else { + throw new SparkException(s"Unseen label: $label. To handle unseen labels, " + + s"set Param handleInvalid to ${StringIndexer.KEEP_INVALID}.") + } } } http://git-wip-us.apache.org/repos/asf/spark/blob/85941ecf/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala index 188dffb..8d9042b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/StringIndexerSuite.scala @@ -122,6 +122,51 @@ class StringIndexerSuite assert(output === expected) } + test("StringIndexer with NULLs") { + val data: Seq[(Int, String)] = Seq((0, "a"), (1, "b"), (2, "b"), (3, null)) + val data2: Seq[(Int, String)] = Seq((0, "a"), (1, "b"), (3, null)) + val df = data.toDF("id", "label") + val df2 = data2.toDF("id", "label") + + val indexer = new StringIndexer() + .setInputCol("label") + .setOutputCol("labelIndex") + + withClue("StringIndexer should throw error when setHandleInvalid=error " + + "when given NULL values") { + intercept[SparkException] { + indexer.setHandleInvalid("error") + indexer.fit(df).transform(df2).collect() + } + } + + indexer.setHandleInvalid("skip") + val transformedSkip = indexer.fit(df).transform(df2) + val attrSkip = Attribute + .fromStructField(transformedSkip.schema("labelIndex")) + .asInstanceOf[NominalAttribute] + assert(attrSkip.values.get === Array("b", "a")) + val outputSkip = transformedSkip.select("id", "labelIndex").rdd.map { r => + (r.getInt(0), r.getDouble(1)) + }.collect().toSet + // a -> 1, b -> 0 + val expectedSkip = Set((0, 1.0), (1, 0.0)) + assert(outputSkip === expectedSkip) + + indexer.setHandleInvalid("keep") + val transformedKeep = indexer.fit(df).transform(df2) + val attrKeep = Attribute + .fromStructField(transformedKeep.schema("labelIndex")) + .asInstanceOf[NominalAttribute] + assert(attrKeep.values.get === Array("b", "a", "__unknown")) + val outputKeep = transformedKeep.select("id", "labelIndex").rdd.map { r => + (r.getInt(0), r.getDouble(1)) + }.collect().toSet + // a -> 1, b -> 0, null -> 2 + val expectedKeep = Set((0, 1.0), (1, 0.0), (3, 2.0)) + assert(outputKeep === expectedKeep) + } + test("StringIndexerModel should keep silent if the input column does not exist.") { val indexerModel = new StringIndexerModel("indexer", Array("a", "b", "c")) .setInputCol("label") --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org