Github user WeichenXu123 commented on a diff in the pull request: https://github.com/apache/spark/pull/20829#discussion_r174993897 --- Diff: mllib/src/main/scala/org/apache/spark/ml/feature/VectorAssembler.scala --- @@ -85,18 +120,34 @@ class VectorAssembler @Since("1.4.0") (@Since("1.4.0") override val uid: String) } else { // Otherwise, treat all attributes as numeric. If we cannot get the number of attributes // from metadata, check the first row. - val numAttrs = group.numAttributes.getOrElse(first.getAs[Vector](index).size) - Array.tabulate(numAttrs)(i => NumericAttribute.defaultAttr.withName(c + "_" + i)) + (0 until length).map { i => NumericAttribute.defaultAttr.withName(c + "_" + i) } + } + case DoubleType => + val attribute = Attribute.fromStructField(field) + attribute match { + case UnresolvedAttribute => + Seq(NumericAttribute.defaultAttr.withName(c)) + case _ => + Seq(attribute.withName(c)) } + case _ : NumericType | BooleanType => + // If the input column type is a compatible scalar type, assume numeric. + Seq(NumericAttribute.defaultAttr.withName(c)) case otherType => throw new SparkException(s"VectorAssembler does not support the $otherType type") } } - val metadata = new AttributeGroup($(outputCol), attrs).toMetadata() - + val featureAttributes = featureAttributesMap.flatten[Attribute] + val lengths = featureAttributesMap.map(a => a.length) + val metadata = new AttributeGroup($(outputCol), featureAttributes.toArray).toMetadata() + val (filteredDataset, keepInvalid) = $(handleInvalid) match { + case StringIndexer.SKIP_INVALID => (dataset.na.drop("any", $(inputCols)), false) --- End diff -- You can use `dataset.na.drop($(inputCols))` directly: the column-only overload already drops rows containing any null or NaN in those columns, so passing `"any"` explicitly is redundant.
--- --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org