Repository: spark Updated Branches: refs/heads/master a7317ccdc -> 34d610be8
[SPARK-9929] [SQL] support metadata in withColumn. In MLlib, sometimes we need to set metadata for the new column, thus we will alias the new column with metadata before calling `withColumn`, and in `withColumn` we alias this column again. Here I overloaded `withColumn` to allow users to set metadata, just like what we did for `Column.as`. Author: Wenchen Fan <cloud0...@outlook.com> Closes #8159 from cloud-fan/withColumn. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/34d610be Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/34d610be Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/34d610be Branch: refs/heads/master Commit: 34d610be854d2a975d9c1e232d87433b85add6fd Parents: a7317cc Author: Wenchen Fan <cloud0...@outlook.com> Authored: Fri Aug 14 12:00:01 2015 -0700 Committer: Reynold Xin <r...@databricks.com> Committed: Fri Aug 14 12:00:01 2015 -0700 ---------------------------------------------------------------------- .../apache/spark/ml/classification/OneVsRest.scala | 6 +++--- .../org/apache/spark/ml/feature/Bucketizer.scala | 2 +- .../apache/spark/ml/feature/VectorIndexer.scala | 2 +- .../org/apache/spark/ml/feature/VectorSlicer.scala | 3 +-- .../scala/org/apache/spark/sql/DataFrame.scala | 17 +++++++++++++++++ 5 files changed, 23 insertions(+), 7 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/34d610be/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala index 1132d80..c62e132 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala @@ -131,7 +131,7 @@ 
final class OneVsRestModel private[ml] ( // output label and label metadata as prediction aggregatedDataset - .withColumn($(predictionCol), labelUDF(col(accColName)).as($(predictionCol), labelMetadata)) + .withColumn($(predictionCol), labelUDF(col(accColName)), labelMetadata) .drop(accColName) } @@ -203,8 +203,8 @@ final class OneVsRest(override val uid: String) // TODO: use when ... otherwise after SPARK-7321 is merged val newLabelMeta = BinaryAttribute.defaultAttr.withName("label").toMetadata() val labelColName = "mc2b$" + index - val labelUDFWithNewMeta = labelUDF(col($(labelCol))).as(labelColName, newLabelMeta) - val trainingDataset = multiclassLabeled.withColumn(labelColName, labelUDFWithNewMeta) + val trainingDataset = + multiclassLabeled.withColumn(labelColName, labelUDF(col($(labelCol))), newLabelMeta) val classifier = getClassifier val paramMap = new ParamMap() paramMap.put(classifier.labelCol -> labelColName) http://git-wip-us.apache.org/repos/asf/spark/blob/34d610be/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala index cfca494..6fdf25b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala @@ -75,7 +75,7 @@ final class Bucketizer(override val uid: String) } val newCol = bucketizer(dataset($(inputCol))) val newField = prepOutputField(dataset.schema) - dataset.withColumn($(outputCol), newCol.as($(outputCol), newField.metadata)) + dataset.withColumn($(outputCol), newCol, newField.metadata) } private def prepOutputField(schema: StructType): StructField = { http://git-wip-us.apache.org/repos/asf/spark/blob/34d610be/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala 
---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala index 6875aef..61b925c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala @@ -341,7 +341,7 @@ class VectorIndexerModel private[ml] ( val newField = prepOutputField(dataset.schema) val transformUDF = udf { (vector: Vector) => transformFunc(vector) } val newCol = transformUDF(dataset($(inputCol))) - dataset.withColumn($(outputCol), newCol.as($(outputCol), newField.metadata)) + dataset.withColumn($(outputCol), newCol, newField.metadata) } override def transformSchema(schema: StructType): StructType = { http://git-wip-us.apache.org/repos/asf/spark/blob/34d610be/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala index 772bebe..c5c2272 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala @@ -119,8 +119,7 @@ final class VectorSlicer(override val uid: String) case features: SparseVector => features.slice(inds) } } - dataset.withColumn($(outputCol), - slicer(dataset($(inputCol))).as($(outputCol), outputAttr.toMetadata())) + dataset.withColumn($(outputCol), slicer(dataset($(inputCol))), outputAttr.toMetadata()) } /** Get the feature indices in order: indices, names */ http://git-wip-us.apache.org/repos/asf/spark/blob/34d610be/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala ---------------------------------------------------------------------- diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index c466d9e..cf75e64 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -1150,6 +1150,23 @@ class DataFrame private[sql]( } /** + * Returns a new [[DataFrame]] by adding a column with metadata. + */ + private[spark] def withColumn(colName: String, col: Column, metadata: Metadata): DataFrame = { + val resolver = sqlContext.analyzer.resolver + val replaced = schema.exists(f => resolver(f.name, colName)) + if (replaced) { + val colNames = schema.map { field => + val name = field.name + if (resolver(name, colName)) col.as(colName, metadata) else Column(name) + } + select(colNames : _*) + } else { + select(Column("*"), col.as(colName, metadata)) + } + } + + /** * Returns a new [[DataFrame]] with a column renamed. * This is a no-op if schema doesn't contain existingName. * @group dfops --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org