Github user viirya commented on a diff in the pull request:

    https://github.com/apache/spark/pull/20146#discussion_r183344264

--- Diff: mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala ---

@@ -217,33 +295,32 @@ class StringIndexerModel (
   @Since("1.4.0")
   def setOutputCol(value: String): this.type = set(outputCol, value)
 
-  @Since("2.0.0")
-  override def transform(dataset: Dataset[_]): DataFrame = {
-    if (!dataset.schema.fieldNames.contains($(inputCol))) {
-      logInfo(s"Input column ${$(inputCol)} does not exist during transformation. " +
-        "Skip StringIndexerModel.")
-      return dataset.toDF
-    }
-    transformSchema(dataset.schema, logging = true)
+  /** @group setParam */
+  @Since("2.4.0")
+  def setInputCols(value: Array[String]): this.type = set(inputCols, value)
 
-    val filteredLabels = getHandleInvalid match {
-      case StringIndexer.KEEP_INVALID => labels :+ "__unknown"
-      case _ => labels
+  /** @group setParam */
+  @Since("2.4.0")
+  def setOutputCols(value: Array[String]): this.type = set(outputCols, value)
+
+  private def filterInvalidData(dataset: Dataset[_], inputColNames: Seq[String]): Dataset[_] = {
+    var filteredDataset = dataset.na.drop(inputColNames.filter(
+      dataset.schema.fieldNames.contains(_)))
+    for (i <- 0 until inputColNames.length) {
+      val inputColName = inputColNames(i)
+      val labelToIndex = labelsToIndexArray(i)
+      val filterer = udf { label: String =>
+        labelToIndex.contains(label)
+      }
+      filteredDataset = filteredDataset.where(filterer(dataset(inputColName)))
     }
+    filteredDataset
+  }
 
-    val metadata = NominalAttribute.defaultAttr
-      .withName($(outputCol)).withValues(filteredLabels).toMetadata()
-    // If we are skipping invalid records, filter them out.
-    val (filteredDataset, keepInvalid) = $(handleInvalid) match {
-      case StringIndexer.SKIP_INVALID =>
-        val filterer = udf { label: String =>
-          labelToIndex.contains(label)
-        }
-        (dataset.na.drop(Array($(inputCol))).where(filterer(dataset($(inputCol)))), false)
-      case _ => (dataset, getHandleInvalid == StringIndexer.KEEP_INVALID)
-    }
+  private def getIndexer(labels: Seq[String], labelToIndex: OpenHashMap[String, Double]) = {
+    val keepInvalid = (getHandleInvalid == StringIndexer.KEEP_INVALID)
 
-    val indexer = udf { label: String =>
+    udf { label: String =>

--- End diff --

> This requires calling many UDFs for different input columns. Should we combine them in one UDF?

Then we would have to define a single UDF that takes a varying number of parameters (something like the big pattern match in `ScalaUDF`). We also don't support UDFs with more than 22 parameters.
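For reference, a minimal sketch of the per-column approach under discussion (purely illustrative, not the PR's code: `inputColNames`, `labelToIndexMaps`, and the fallback index `-1.0` are made-up stand-ins). Each column gets its own one-argument UDF that closes over that column's label map, so the UDF arity stays at 1 no matter how many input columns the model has:

```scala
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.udf

// Hypothetical per-column lookup maps, one per input column.
val inputColNames = Seq("col1", "col2")
val labelToIndexMaps: Seq[Map[String, Double]] = Seq(
  Map("a" -> 0.0, "b" -> 1.0),
  Map("x" -> 0.0, "y" -> 1.0))

def indexAll(df: DataFrame): DataFrame = {
  inputColNames.zip(labelToIndexMaps).foldLeft(df) {
    case (acc, (colName, labelToIndex)) =>
      // Each UDF captures its own map, so it always takes a single column.
      // A single combined UDF would instead need one parameter per input
      // column, and Scala UDFs cap out at 22 parameters.
      val indexer = udf { label: String => labelToIndex.getOrElse(label, -1.0) }
      acc.withColumn(colName + "_idx", indexer(acc(colName)))
  }
}
```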