Github user dbtsai commented on a diff in the pull request: https://github.com/apache/spark/pull/20146#discussion_r183258353 --- Diff: mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala --- @@ -217,33 +295,32 @@ class StringIndexerModel ( @Since("1.4.0") def setOutputCol(value: String): this.type = set(outputCol, value) - @Since("2.0.0") - override def transform(dataset: Dataset[_]): DataFrame = { - if (!dataset.schema.fieldNames.contains($(inputCol))) { - logInfo(s"Input column ${$(inputCol)} does not exist during transformation. " + - "Skip StringIndexerModel.") - return dataset.toDF - } - transformSchema(dataset.schema, logging = true) + /** @group setParam */ + @Since("2.4.0") + def setInputCols(value: Array[String]): this.type = set(inputCols, value) - val filteredLabels = getHandleInvalid match { - case StringIndexer.KEEP_INVALID => labels :+ "__unknown" - case _ => labels + /** @group setParam */ + @Since("2.4.0") + def setOutputCols(value: Array[String]): this.type = set(outputCols, value) + + private def filterInvalidData(dataset: Dataset[_], inputColNames: Seq[String]): Dataset[_] = { + var filteredDataset = dataset.na.drop(inputColNames.filter( + dataset.schema.fieldNames.contains(_))) + for (i <- 0 until inputColNames.length) { + val inputColName = inputColNames(i) + val labelToIndex = labelsToIndexArray(i) + val filterer = udf { label: String => + labelToIndex.contains(label) + } + filteredDataset = filteredDataset.where(filterer(dataset(inputColName))) } + filteredDataset + } - val metadata = NominalAttribute.defaultAttr - .withName($(outputCol)).withValues(filteredLabels).toMetadata() - // If we are skipping invalid records, filter them out. 
- val (filteredDataset, keepInvalid) = $(handleInvalid) match { - case StringIndexer.SKIP_INVALID => - val filterer = udf { label: String => - labelToIndex.contains(label) - } - (dataset.na.drop(Array($(inputCol))).where(filterer(dataset($(inputCol)))), false) - case _ => (dataset, getHandleInvalid == StringIndexer.KEEP_INVALID) - } + private def getIndexer(labels: Seq[String], labelToIndex: OpenHashMap[String, Double]) = { + val keepInvalid = (getHandleInvalid == StringIndexer.KEEP_INVALID) - val indexer = udf { label: String => + udf { label: String => --- End diff -- This requires calling many udf for different input columns. Should we combine them in one udf? The `filteredDataset` logic can be included as well to avoid multiple lookups.
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org