Github user dbtsai commented on a diff in the pull request:

    https://github.com/apache/spark/pull/20146#discussion_r183258353
  
    --- Diff: mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala ---
    @@ -217,33 +295,32 @@ class StringIndexerModel (
       @Since("1.4.0")
       def setOutputCol(value: String): this.type = set(outputCol, value)
     
    -  @Since("2.0.0")
    -  override def transform(dataset: Dataset[_]): DataFrame = {
    -    if (!dataset.schema.fieldNames.contains($(inputCol))) {
    -      logInfo(s"Input column ${$(inputCol)} does not exist during transformation. " +
    -        "Skip StringIndexerModel.")
    -      return dataset.toDF
    -    }
    -    transformSchema(dataset.schema, logging = true)
    +  /** @group setParam */
    +  @Since("2.4.0")
    +  def setInputCols(value: Array[String]): this.type = set(inputCols, value)
     
    -    val filteredLabels = getHandleInvalid match {
    -      case StringIndexer.KEEP_INVALID => labels :+ "__unknown"
    -      case _ => labels
    +  /** @group setParam */
    +  @Since("2.4.0")
    +  def setOutputCols(value: Array[String]): this.type = set(outputCols, value)
    +
    +  private def filterInvalidData(dataset: Dataset[_], inputColNames: Seq[String]): Dataset[_] = {
    +    var filteredDataset = dataset.na.drop(inputColNames.filter(
    +      dataset.schema.fieldNames.contains(_)))
    +    for (i <- 0 until inputColNames.length) {
    +      val inputColName = inputColNames(i)
    +      val labelToIndex = labelsToIndexArray(i)
    +      val filterer = udf { label: String =>
    +        labelToIndex.contains(label)
    +      }
    +      filteredDataset = filteredDataset.where(filterer(dataset(inputColName)))
         }
    +    filteredDataset
    +  }
     
    -    val metadata = NominalAttribute.defaultAttr
    -      .withName($(outputCol)).withValues(filteredLabels).toMetadata()
    -    // If we are skipping invalid records, filter them out.
    -    val (filteredDataset, keepInvalid) = $(handleInvalid) match {
    -      case StringIndexer.SKIP_INVALID =>
    -        val filterer = udf { label: String =>
    -          labelToIndex.contains(label)
    -        }
    -        (dataset.na.drop(Array($(inputCol))).where(filterer(dataset($(inputCol)))), false)
    -      case _ => (dataset, getHandleInvalid == StringIndexer.KEEP_INVALID)
    -    }
    +  private def getIndexer(labels: Seq[String], labelToIndex: OpenHashMap[String, Double]) = {
    +    val keepInvalid = (getHandleInvalid == StringIndexer.KEEP_INVALID)
     
    -    val indexer = udf { label: String =>
    +    udf { label: String =>
    --- End diff ---
    
    This requires calling a separate UDF for each input column. Should we combine them into one UDF? The `filteredDataset` logic could be folded in as well, to avoid multiple lookups.
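
    A rough sketch of what that might look like, assuming the per-column label-to-index maps from this diff (shown below as plain `Map[String, Double]`); the helper name `indexAllColumns` and its wiring are hypothetical, not part of the PR:

    ```scala
    import org.apache.spark.sql.{DataFrame, Dataset}
    import org.apache.spark.sql.functions.{array, array_contains, col, udf}

    // Hypothetical sketch: index every input column with a single UDF call per
    // row, instead of one UDF per column. `maps` stands in for the PR's
    // labelsToIndexArray; null labels and KEEP_INVALID are omitted for brevity.
    def indexAllColumns(
        dataset: Dataset[_],
        inputColNames: Seq[String],
        outputColNames: Seq[String],
        maps: Array[Map[String, Double]],
        skipInvalid: Boolean): DataFrame = {
      // One map lookup per (row, column); unseen labels get -1.0 as a sentinel.
      val indexer = udf { labels: Seq[String] =>
        labels.zipWithIndex.map { case (label, i) =>
          maps(i).getOrElse(label, -1.0)
        }
      }
      var df = dataset.toDF
        .withColumn("__indices", indexer(array(inputColNames.map(col): _*)))
      if (skipInvalid) {
        // The filteredDataset step folded into the same pass: drop rows that
        // contain any unseen label, reusing the lookups already performed.
        df = df.where(!array_contains(col("__indices"), -1.0))
      }
      // Split the combined index array back out into the output columns.
      outputColNames.zipWithIndex.foldLeft(df) { case (acc, (outCol, i)) =>
        acc.withColumn(outCol, col("__indices")(i))
      }.drop("__indices")
    }
    ```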


---
