Github user viirya commented on a diff in the pull request: https://github.com/apache/spark/pull/19621#discussion_r151344393 --- Diff: mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala --- @@ -217,69 +289,94 @@ class StringIndexerModel ( @Since("1.4.0") def setOutputCol(value: String): this.type = set(outputCol, value) + /** @group setParam */ + @Since("2.3.0") + def setInputCols(value: Array[String]): this.type = set(inputCols, value) + + /** @group setParam */ + @Since("2.3.0") + def setOutputCols(value: Array[String]): this.type = set(outputCols, value) + @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { - if (!dataset.schema.fieldNames.contains($(inputCol))) { - logInfo(s"Input column ${$(inputCol)} does not exist during transformation. " + - "Skip StringIndexerModel.") - return dataset.toDF - } transformSchema(dataset.schema, logging = true) - val filteredLabels = getHandleInvalid match { - case StringIndexer.KEEP_INVALID => labels :+ "__unknown" - case _ => labels - } + var (inputColNames, outputColNames) = getInOutCols + + val outputColumns = new Array[Column](outputColNames.length) - val metadata = NominalAttribute.defaultAttr - .withName($(outputCol)).withValues(filteredLabels).toMetadata() + var filteredDataset = dataset // If we are skipping invalid records, filter them out. 
- val (filteredDataset, keepInvalid) = getHandleInvalid match { - case StringIndexer.SKIP_INVALID => + if (getHandleInvalid == StringIndexer.SKIP_INVALID) { + filteredDataset = dataset.na.drop(inputColNames.filter( + dataset.schema.fieldNames.contains(_))) + for (i <- 0 until inputColNames.length) { + val inputColName = inputColNames(i) + val labelToIndex = labelToIndexArray(i) val filterer = udf { label: String => labelToIndex.contains(label) } - (dataset.na.drop(Array($(inputCol))).where(filterer(dataset($(inputCol)))), false) - case _ => (dataset, getHandleInvalid == StringIndexer.KEEP_INVALID) + filteredDataset = filteredDataset.where(filterer(dataset(inputColName))) + } } - val indexer = udf { label: String => - if (label == null) { - if (keepInvalid) { - labels.length - } else { - throw new SparkException("StringIndexer encountered NULL value. To handle or skip " + - "NULLS, try setting StringIndexer.handleInvalid.") - } + for (i <- 0 until outputColNames.length) { + val inputColName = inputColNames(i) + val outputColName = outputColNames(i) + val labelToIndex = labelToIndexArray(i) + val labels = labelsArray(i) + + if (!dataset.schema.fieldNames.contains(inputColName)) { + logInfo(s"Input column ${inputColName} does not exist during transformation. " + + "Skip this column StringIndexerModel transform.") + outputColNames(i) = null } else { - if (labelToIndex.contains(label)) { - labelToIndex(label) - } else if (keepInvalid) { - labels.length - } else { - throw new SparkException(s"Unseen label: $label. 
To handle unseen labels, " + - s"set Param handleInvalid to ${StringIndexer.KEEP_INVALID}.") + val filteredLabels = getHandleInvalid match { + case StringIndexer.KEEP_INVALID => labelsArray(i) :+ "__unknown" + case _ => labelsArray(i) } - } - }.asNondeterministic() - filteredDataset.select(col("*"), - indexer(dataset($(inputCol)).cast(StringType)).as($(outputCol), metadata)) + val metadata = NominalAttribute.defaultAttr + .withName(outputColName).withValues(filteredLabels).toMetadata() + + val keepInvalid = (getHandleInvalid == StringIndexer.KEEP_INVALID) + + val indexer = udf { label: String => + if (label == null) { + if (keepInvalid) { + labels.length + } else { + throw new SparkException("StringIndexer encountered NULL value. To handle or skip " + + "NULLS, try setting StringIndexer.handleInvalid.") + } + } else { + if (labelToIndex.contains(label)) { + labelToIndex(label) + } else if (keepInvalid) { + labels.length + } else { + throw new SparkException(s"Unseen label: $label. To handle unseen labels, " + + s"set Param handleInvalid to ${StringIndexer.KEEP_INVALID}.") + } + } + }.asNondeterministic() + + outputColumns(i) = indexer(dataset(inputColName).cast(StringType)) + .as(outputColName, metadata) + } + } + filteredDataset.withColumns(outputColNames.filter(_ != null), + outputColumns.filter(_ != null)) --- End diff -- In case `outputColNames` and `outputColumns` are empty, `withColumns` might return an empty dataset, not the original dataset.
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org For additional commands, e-mail: reviews-help@spark.apache.org