Github user dbtsai commented on a diff in the pull request:

    https://github.com/apache/spark/pull/20146#discussion_r183253932
  
    --- Diff: 
mllib/src/main/scala/org/apache/spark/ml/feature/StringIndexer.scala ---
    @@ -79,26 +80,56 @@ private[feature] trait StringIndexerBase extends Params 
with HasHandleInvalid wi
       @Since("2.3.0")
       def getStringOrderType: String = $(stringOrderType)
     
    -  /** Validates and transforms the input schema. */
    -  protected def validateAndTransformSchema(schema: StructType): StructType 
= {
    -    val inputColName = $(inputCol)
    +  /** Returns the input and output column names corresponding in pair. */
    +  private[feature] def getInOutCols(): (Array[String], Array[String]) = {
    +    ParamValidators.checkSingleVsMultiColumnParams(this, Seq(outputCol), 
Seq(outputCols))
    +
    +    if (isSet(inputCol)) {
    +      (Array($(inputCol)), Array($(outputCol)))
    +    } else {
    +      require($(inputCols).length == $(outputCols).length,
    +        "The number of input columns does not match output columns")
    +      ($(inputCols), $(outputCols))
    +    }
    +  }
    +
    +  private def validateAndTransformField(
    +      schema: StructType,
    +      inputColName: String,
    +      outputColName: String): StructField = {
         val inputDataType = schema(inputColName).dataType
         require(inputDataType == StringType || 
inputDataType.isInstanceOf[NumericType],
           s"The input column $inputColName must be either string type or 
numeric type, " +
             s"but got $inputDataType.")
    -    val inputFields = schema.fields
    -    val outputColName = $(outputCol)
    -    require(inputFields.forall(_.name != outputColName),
    +    require(schema.fields.forall(_.name != outputColName),
           s"Output column $outputColName already exists.")
    -    val attr = NominalAttribute.defaultAttr.withName($(outputCol))
    -    val outputFields = inputFields :+ attr.toStructField()
    -    StructType(outputFields)
    +    NominalAttribute.defaultAttr.withName($(outputCol)).toStructField()
    +  }
    +
    +  /** Validates and transforms the input schema. */
    +  protected def validateAndTransformSchema(
    +      schema: StructType,
    +      skipNonExistsCol: Boolean = false): StructType = {
    +    val (inputColNames, outputColNames) = getInOutCols()
    +
    +    val outputFields = for (i <- 0 until inputColNames.length) yield {
    +      if (schema.fieldNames.contains(inputColNames(i))) {
    +        validateAndTransformField(schema, inputColNames(i), 
outputColNames(i))
    +      } else {
    +        if (skipNonExistsCol) {
    +          null
    +        } else {
    +          throw new SparkException(s"Input column ${inputColNames(i)} does 
not exist.")
    +        }
    +      }
    +    }
    +    StructType(schema.fields ++ outputFields.filter(_ != null))
    --- End diff --
    
    With the change suggested above, the `filter(_ != null)` here would no longer be needed.


---

---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org

Reply via email to