GitHub user viirya commented on a diff in the pull request:

    https://github.com/apache/spark/pull/20146#discussion_r183344264
  
    --- Diff: mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala ---
    @@ -217,33 +295,32 @@ class StringIndexerModel (
       @Since("1.4.0")
       def setOutputCol(value: String): this.type = set(outputCol, value)
     
    -  @Since("2.0.0")
    -  override def transform(dataset: Dataset[_]): DataFrame = {
    -    if (!dataset.schema.fieldNames.contains($(inputCol))) {
    -      logInfo(s"Input column ${$(inputCol)} does not exist during transformation. " +
    -        "Skip StringIndexerModel.")
    -      return dataset.toDF
    -    }
    -    transformSchema(dataset.schema, logging = true)
    +  /** @group setParam */
    +  @Since("2.4.0")
    +  def setInputCols(value: Array[String]): this.type = set(inputCols, value)
     
    -    val filteredLabels = getHandleInvalid match {
    -      case StringIndexer.KEEP_INVALID => labels :+ "__unknown"
    -      case _ => labels
    +  /** @group setParam */
    +  @Since("2.4.0")
    +  def setOutputCols(value: Array[String]): this.type = set(outputCols, value)
    +
    +  private def filterInvalidData(dataset: Dataset[_], inputColNames: Seq[String]): Dataset[_] = {
    +    var filteredDataset = dataset.na.drop(inputColNames.filter(
    +      dataset.schema.fieldNames.contains(_)))
    +    for (i <- 0 until inputColNames.length) {
    +      val inputColName = inputColNames(i)
    +      val labelToIndex = labelsToIndexArray(i)
    +      val filterer = udf { label: String =>
    +        labelToIndex.contains(label)
    +      }
    +      filteredDataset = filteredDataset.where(filterer(dataset(inputColName)))
         }
    +    filteredDataset
    +  }
     
    -    val metadata = NominalAttribute.defaultAttr
    -      .withName($(outputCol)).withValues(filteredLabels).toMetadata()
    -    // If we are skipping invalid records, filter them out.
    -    val (filteredDataset, keepInvalid) = $(handleInvalid) match {
    -      case StringIndexer.SKIP_INVALID =>
    -        val filterer = udf { label: String =>
    -          labelToIndex.contains(label)
    -        }
    -        (dataset.na.drop(Array($(inputCol))).where(filterer(dataset($(inputCol)))), false)
    -      case _ => (dataset, getHandleInvalid == StringIndexer.KEEP_INVALID)
    -    }
    +  private def getIndexer(labels: Seq[String], labelToIndex: OpenHashMap[String, Double]) = {
    +    val keepInvalid = (getHandleInvalid == StringIndexer.KEEP_INVALID)
     
    -    val indexer = udf { label: String =>
    +    udf { label: String =>
    --- End diff ---
    
    > This requires calling many udfs for different input columns. Should we combine them in one udf?
    
    Then we would need to define a single UDF for each possible number of parameters (like the big pattern match in `ScalaUDF`). We also don't support UDFs with more than 22 parameters.
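
    To make the arity point concrete, here is a minimal, self-contained sketch of the per-column approach (not the PR's code; the column names, demo data, and per-column maps below are invented). Each column gets its own single-argument UDF, so the number of input columns is unbounded, whereas a single combined UDF would need one overload per arity and could never go past 22 parameters:

        import org.apache.spark.sql.SparkSession
        import org.apache.spark.sql.functions.udf

        object PerColumnUdfSketch {
          def main(args: Array[String]): Unit = {
            val spark = SparkSession.builder().master("local[1]").appName("sketch").getOrCreate()
            import spark.implicits._

            // Invented demo data: two string columns to index.
            val df = Seq(("a", "x"), ("b", "y")).toDF("c1", "c2")

            // One label-to-index map per column, mirroring how the multi-column
            // StringIndexerModel keeps a separate map per input column.
            val labelToIndexByCol = Map(
              "c1" -> Map("a" -> 0.0, "b" -> 1.0),
              "c2" -> Map("x" -> 0.0, "y" -> 1.0))

            // A separate single-argument udf per column scales to any number of
            // columns; one combined udf over all columns would be capped at 22 args.
            val indexed = labelToIndexByCol.foldLeft(df) { case (cur, (colName, labelToIndex)) =>
              val indexer = udf { label: String => labelToIndex.getOrElse(label, -1.0) }
              cur.withColumn(colName + "_idx", indexer(cur(colName)))
            }
            indexed.show()
            spark.stop()
          }
        }

    With N input columns this stays N independent one-argument UDFs, which is also what the per-column loop in the diff above does.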

