Github user viirya commented on a diff in the pull request:

    https://github.com/apache/spark/pull/19621#discussion_r151344393
  
    --- Diff: mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala ---
    @@ -217,69 +289,94 @@ class StringIndexerModel (
       @Since("1.4.0")
       def setOutputCol(value: String): this.type = set(outputCol, value)
     
    +  /** @group setParam */
    +  @Since("2.3.0")
    +  def setInputCols(value: Array[String]): this.type = set(inputCols, value)
    +
    +  /** @group setParam */
    +  @Since("2.3.0")
    +  def setOutputCols(value: Array[String]): this.type = set(outputCols, value)
    +
       @Since("2.0.0")
       override def transform(dataset: Dataset[_]): DataFrame = {
    -    if (!dataset.schema.fieldNames.contains($(inputCol))) {
    -      logInfo(s"Input column ${$(inputCol)} does not exist during 
transformation. " +
    -        "Skip StringIndexerModel.")
    -      return dataset.toDF
    -    }
         transformSchema(dataset.schema, logging = true)
     
    -    val filteredLabels = getHandleInvalid match {
    -      case StringIndexer.KEEP_INVALID => labels :+ "__unknown"
    -      case _ => labels
    -    }
    +    var (inputColNames, outputColNames) = getInOutCols
    +
    +    val outputColumns = new Array[Column](outputColNames.length)
     
    -    val metadata = NominalAttribute.defaultAttr
    -      .withName($(outputCol)).withValues(filteredLabels).toMetadata()
    +    var filteredDataset = dataset
         // If we are skipping invalid records, filter them out.
    -    val (filteredDataset, keepInvalid) = getHandleInvalid match {
    -      case StringIndexer.SKIP_INVALID =>
    +    if (getHandleInvalid == StringIndexer.SKIP_INVALID) {
    +      filteredDataset = dataset.na.drop(inputColNames.filter(
    +        dataset.schema.fieldNames.contains(_)))
    +      for (i <- 0 until inputColNames.length) {
    +        val inputColName = inputColNames(i)
    +        val labelToIndex = labelToIndexArray(i)
             val filterer = udf { label: String =>
               labelToIndex.contains(label)
             }
    -        (dataset.na.drop(Array($(inputCol))).where(filterer(dataset($(inputCol)))), false)
    -      case _ => (dataset, getHandleInvalid == StringIndexer.KEEP_INVALID)
    +        filteredDataset = filteredDataset.where(filterer(dataset(inputColName)))
    +      }
         }
     
    -    val indexer = udf { label: String =>
    -      if (label == null) {
    -        if (keepInvalid) {
    -          labels.length
    -        } else {
    -          throw new SparkException("StringIndexer encountered NULL value. To handle or skip " +
    -            "NULLS, try setting StringIndexer.handleInvalid.")
    -        }
    +    for (i <- 0 until outputColNames.length) {
    +      val inputColName = inputColNames(i)
    +      val outputColName = outputColNames(i)
    +      val labelToIndex = labelToIndexArray(i)
    +      val labels = labelsArray(i)
    +
    +      if (!dataset.schema.fieldNames.contains(inputColName)) {
    +        logInfo(s"Input column ${inputColName} does not exist during 
transformation. " +
    +          "Skip this column StringIndexerModel transform.")
    +        outputColNames(i) = null
           } else {
    -        if (labelToIndex.contains(label)) {
    -          labelToIndex(label)
    -        } else if (keepInvalid) {
    -          labels.length
    -        } else {
    -          throw new SparkException(s"Unseen label: $label.  To handle 
unseen labels, " +
    -            s"set Param handleInvalid to ${StringIndexer.KEEP_INVALID}.")
    +        val filteredLabels = getHandleInvalid match {
    +          case StringIndexer.KEEP_INVALID => labelsArray(i) :+ "__unknown"
    +          case _ => labelsArray(i)
             }
    -      }
    -    }.asNondeterministic()
     
    -    filteredDataset.select(col("*"),
    -      indexer(dataset($(inputCol)).cast(StringType)).as($(outputCol), metadata))
    +        val metadata = NominalAttribute.defaultAttr
    +          .withName(outputColName).withValues(filteredLabels).toMetadata()
    +
    +        val keepInvalid = (getHandleInvalid == StringIndexer.KEEP_INVALID)
    +
    +        val indexer = udf { label: String =>
    +          if (label == null) {
    +            if (keepInvalid) {
    +              labels.length
    +            } else {
    +              throw new SparkException("StringIndexer encountered NULL 
value. To handle or skip " +
    +                "NULLS, try setting StringIndexer.handleInvalid.")
    +            }
    +          } else {
    +            if (labelToIndex.contains(label)) {
    +              labelToIndex(label)
    +            } else if (keepInvalid) {
    +              labels.length
    +            } else {
    +              throw new SparkException(s"Unseen label: $label.  To handle 
unseen labels, " +
    +                s"set Param handleInvalid to 
${StringIndexer.KEEP_INVALID}.")
    +            }
    +          }
    +        }.asNondeterministic()
    +
    +        outputColumns(i) = indexer(dataset(inputColName).cast(StringType))
    +          .as(outputColName, metadata)
    +      }
    +    }
    +    filteredDataset.withColumns(outputColNames.filter(_ != null),
    +      outputColumns.filter(_ != null))
    --- End diff ---
    
    In case `outputColNames` and `outputColumns` are empty, `withColumns` might return an empty dataset rather than the original dataset.
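    
    A minimal sketch of one possible guard (hypothetical, not part of this diff; it reuses `filteredDataset`, `outputColNames`, and `outputColumns` from the code above) would fall back to the original dataset when every input column was missing, so `withColumns` never receives empty name/column lists:
    
    ```scala
    // Hypothetical guard, assuming the names from the diff above:
    // return the (possibly filtered) dataset unchanged when no output
    // column survived, instead of calling withColumns with empty lists.
    val survivingNames = outputColNames.filter(_ != null)
    if (survivingNames.isEmpty) {
      filteredDataset.toDF()
    } else {
      filteredDataset.withColumns(survivingNames, outputColumns.filter(_ != null))
    }
    ```
    
    Since the nulls appear at the same indices in both arrays, filtering each independently keeps the name/column pairs aligned.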

