Github user yogeshg commented on a diff in the pull request: https://github.com/apache/spark/pull/20829#discussion_r175153265 --- Diff: mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala --- @@ -49,32 +51,65 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String) @Since("1.4.0") def setOutputCol(value: String): this.type = set(outputCol, value) + /** @group setParam */ + @Since("1.6.0") + def setHandleInvalid(value: String): this.type = set(handleInvalid, value) + + /** + * Param for how to handle invalid data (NULL values). Options are 'skip' (filter out rows with + * invalid data), 'error' (throw an error), or 'keep' (return relevant number of NaN in the + * output). + * Default: "error" + * @group param + */ + @Since("1.6.0") + override val handleInvalid: Param[String] = new Param[String](this, "handleInvalid", + "Hhow to handle invalid data (NULL values). Options are 'skip' (filter out rows with " + + "invalid data), 'error' (throw an error), or 'keep' (return relevant number of NaN " + + "in the * output).", ParamValidators.inArray(VectorAssembler.supportedHandleInvalids)) + + setDefault(handleInvalid, VectorAssembler.ERROR_INVALID) + @Since("2.0.0") override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) // Schema transformation. val schema = dataset.schema - lazy val first = dataset.toDF.first() - val attrs = $(inputCols).flatMap { c => + + val featureAttributesMap: Seq[Seq[Attribute]] = $(inputCols).toSeq.map { c => val field = schema(c) - val index = schema.fieldIndex(c) field.dataType match { - case DoubleType => - val attr = Attribute.fromStructField(field) - // If the input column doesn't have ML attribute, assume numeric. 
- if (attr == UnresolvedAttribute) { - Some(NumericAttribute.defaultAttr.withName(c)) - } else { - Some(attr.withName(c)) - } - case _: NumericType | BooleanType => - // If the input column type is a compatible scalar type, assume numeric. - Some(NumericAttribute.defaultAttr.withName(c)) case _: VectorUDT => - val group = AttributeGroup.fromStructField(field) - if (group.attributes.isDefined) { - // If attributes are defined, copy them with updated names. - group.attributes.get.zipWithIndex.map { case (attr, i) => + val attributeGroup = AttributeGroup.fromStructField(field) + var length = attributeGroup.size + val isMissingNumAttrs = -1 == length + if (isMissingNumAttrs && dataset.isStreaming) { + // this condition is checked for every column, but should be cheap + throw new RuntimeException( + s""" + |VectorAssembler cannot dynamically determine the size of vectors for streaming + |data. Consider applying VectorSizeHint to ${c} so that this transformer can be + |used to transform streaming inputs. + """.stripMargin.replaceAll("\n", " ")) + } + if (isMissingNumAttrs) { + val column = dataset.select(c).na.drop() --- End diff -- Good catch! That name was bothering me too :P @MrBago and I are thinking of another way to do this more efficiently.
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org For additional commands, e-mail: reviews-help@spark.apache.org