Github user dbtsai commented on a diff in the pull request:

    https://github.com/apache/spark/pull/20146#discussion_r183257799

    --- Diff: mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala ---
    @@ -217,33 +295,32 @@ class StringIndexerModel (
       @Since("1.4.0")
       def setOutputCol(value: String): this.type = set(outputCol, value)
    -  @Since("2.0.0")
    -  override def transform(dataset: Dataset[_]): DataFrame = {
    -    if (!dataset.schema.fieldNames.contains($(inputCol))) {
    -      logInfo(s"Input column ${$(inputCol)} does not exist during transformation. " +
    -        "Skip StringIndexerModel.")
    -      return dataset.toDF
    -    }
    -    transformSchema(dataset.schema, logging = true)
    +  /** @group setParam */
    +  @Since("2.4.0")
    +  def setInputCols(value: Array[String]): this.type = set(inputCols, value)
    -    val filteredLabels = getHandleInvalid match {
    -      case StringIndexer.KEEP_INVALID => labels :+ "__unknown"
    -      case _ => labels
    +  /** @group setParam */
    +  @Since("2.4.0")
    +  def setOutputCols(value: Array[String]): this.type = set(outputCols, value)
    +
    +  private def filterInvalidData(dataset: Dataset[_], inputColNames: Seq[String]): Dataset[_] = {
    +    var filteredDataset = dataset.na.drop(inputColNames.filter(
    +      dataset.schema.fieldNames.contains(_)))
    +    for (i <- 0 until inputColNames.length) {
    +      val inputColName = inputColNames(i)
    +      val labelToIndex = labelsToIndexArray(i)
    +      val filterer = udf { label: String =>
    +        labelToIndex.contains(label)
    +      }
    +      filteredDataset = filteredDataset.where(filterer(dataset(inputColName)))
    --- End diff --

    is it possible to not use `var filteredDataset`?
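For reference, one way to drop the `var` would be to fold over the input columns and thread the `Dataset` through the fold. The sketch below is only an illustration of that idea, not the change made in the PR; it assumes the method still lives inside `StringIndexerModel`, where `labelsToIndexArray` and `org.apache.spark.sql.functions.udf` from the diff above are in scope, and the names `nonNullDataset`/`filtered` are made up here.

```scala
// Hypothetical var-free version of filterInvalidData (a sketch, not the PR's code).
// Assumes the same StringIndexerModel members and imports used in the diff above.
private def filterInvalidData(dataset: Dataset[_], inputColNames: Seq[String]): Dataset[_] = {
  // Drop rows that are null in any input column that actually exists in the schema.
  val nonNullDataset = dataset.na.drop(
    inputColNames.filter(dataset.schema.fieldNames.contains(_)))

  // Fold each (column name, index) pair into an additional WHERE clause,
  // passing the filtered Dataset along instead of reassigning a var.
  inputColNames.zipWithIndex.foldLeft(nonNullDataset) {
    case (filtered, (inputColName, i)) =>
      val labelToIndex = labelsToIndexArray(i)
      val filterer = udf { label: String => labelToIndex.contains(label) }
      filtered.where(filterer(filtered(inputColName)))
  }
}
```

The behavior should match the loop in the diff; whether it reads better is a matter of taste.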