Github user jkbradley commented on a diff in the pull request:
    --- Diff: 
mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala ---
    @@ -49,32 +55,64 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") 
override val uid: String)
       def setOutputCol(value: String): this.type = set(outputCol, value)
    +  /** @group setParam */
    +  @Since("2.4.0")
    +  def setHandleInvalid(value: String): this.type = set(handleInvalid, 
    +  /**
    +   * Param for how to handle invalid data (NULL values). Options are 
'skip' (filter out rows with
    +   * invalid data), 'error' (throw an error), or 'keep' (return relevant 
number of NaN in the
    +   * output). Column lengths are taken from the size of ML Attribute 
Group, which can be set using
    +   * `VectorSizeHint` in a pipeline before `VectorAssembler`. Column 
lengths can also be inferred
    +   * from first rows of the data since it is safe to do so but only in 
case of 'error' or 'skip'.
    +   * Default: "error"
    +   * @group param
    +   */
    +  @Since("2.4.0")
    +  override val handleInvalid: Param[String] = new Param[String](this, 
    +    """
    +    | Param for how to handle invalid data (NULL values). Options are 
'skip' (filter out rows with
    +    | invalid data), 'error' (throw an error), or 'keep' (return relevant 
number of NaN in the
    +    | output). Column lengths are taken from the size of ML Attribute 
Group, which can be set using
    +    | `VectorSizeHint` in a pipeline before `VectorAssembler`. Column 
lengths can also be inferred
    +    | from first rows of the data since it is safe to do so but only in 
case of 'error' or 'skip'.
    +    | """.stripMargin.replaceAll("\n", " "),
    +    ParamValidators.inArray(VectorAssembler.supportedHandleInvalids))
    +  setDefault(handleInvalid, VectorAssembler.ERROR_INVALID)
       override def transform(dataset: Dataset[_]): DataFrame = {
         transformSchema(dataset.schema, logging = true)
         // Schema transformation.
         val schema = dataset.schema
    -    lazy val first = dataset.toDF.first()
    -    val attrs = $(inputCols).flatMap { c =>
    +    val vectorCols = $(inputCols).toSeq.filter { c =>
    +      schema(c).dataType match {
    +        case _: VectorUDT => true
    +        case _ => false
    +      }
    +    }
    +    val vectorColsLengths = VectorAssembler.getLengths(dataset, 
vectorCols, $(handleInvalid))
    +    val featureAttributesMap = $(inputCols) { c =>
           val field = schema(c)
    -      val index = schema.fieldIndex(c)
           field.dataType match {
             case DoubleType =>
    -          val attr = Attribute.fromStructField(field)
    -          // If the input column doesn't have ML attribute, assume numeric.
    -          if (attr == UnresolvedAttribute) {
    -            Some(NumericAttribute.defaultAttr.withName(c))
    -          } else {
    -            Some(attr.withName(c))
    +          val attribute = Attribute.fromStructField(field)
    +          attribute match {
    +            case UnresolvedAttribute =>
    +              Seq(NumericAttribute.defaultAttr.withName(c))
    +            case _ =>
    +              Seq(attribute.withName(c))
             case _: NumericType | BooleanType =>
               // If the input column type is a compatible scalar type, assume 
    -          Some(NumericAttribute.defaultAttr.withName(c))
    +          Seq(NumericAttribute.defaultAttr.withName(c))
             case _: VectorUDT =>
    -          val group = AttributeGroup.fromStructField(field)
    -          if (group.attributes.isDefined) {
    -            // If attributes are defined, copy them with updated names.
    -   { case (attr, i) =>
    +          val attributeGroup = AttributeGroup.fromStructField(field)
    --- End diff --
    For the future, I'd avoid renaming things like this unless the existing 
name is really unclear or the rename is needed; avoiding renames keeps diffs shorter.


To unsubscribe, e-mail:
For additional commands, e-mail:

Reply via email to