Repository: spark
Updated Branches:
  refs/heads/master a7317ccdc -> 34d610be8


[SPARK-9929] [SQL] support metadata in withColumn

In MLlib we sometimes need to set metadata for a new column, so we alias the new
column with metadata before calling `withColumn`, and `withColumn` then aliases
that column again. This change adds an internal overload of `withColumn` that
accepts the metadata directly, just as we did for `Column.as`.

Author: Wenchen Fan <cloud0...@outlook.com>

Closes #8159 from cloud-fan/withColumn.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/34d610be
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/34d610be
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/34d610be

Branch: refs/heads/master
Commit: 34d610be854d2a975d9c1e232d87433b85add6fd
Parents: a7317cc
Author: Wenchen Fan <cloud0...@outlook.com>
Authored: Fri Aug 14 12:00:01 2015 -0700
Committer: Reynold Xin <r...@databricks.com>
Committed: Fri Aug 14 12:00:01 2015 -0700

----------------------------------------------------------------------
 .../apache/spark/ml/classification/OneVsRest.scala |  6 +++---
 .../org/apache/spark/ml/feature/Bucketizer.scala   |  2 +-
 .../apache/spark/ml/feature/VectorIndexer.scala    |  2 +-
 .../org/apache/spark/ml/feature/VectorSlicer.scala |  3 +--
 .../scala/org/apache/spark/sql/DataFrame.scala     | 17 +++++++++++++++++
 5 files changed, 23 insertions(+), 7 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/34d610be/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
index 1132d80..c62e132 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/OneVsRest.scala
@@ -131,7 +131,7 @@ final class OneVsRestModel private[ml] (
 
     // output label and label metadata as prediction
     aggregatedDataset
-      .withColumn($(predictionCol), 
labelUDF(col(accColName)).as($(predictionCol), labelMetadata))
+      .withColumn($(predictionCol), labelUDF(col(accColName)), labelMetadata)
       .drop(accColName)
   }
 
@@ -203,8 +203,8 @@ final class OneVsRest(override val uid: String)
       // TODO: use when ... otherwise after SPARK-7321 is merged
       val newLabelMeta = 
BinaryAttribute.defaultAttr.withName("label").toMetadata()
       val labelColName = "mc2b$" + index
-      val labelUDFWithNewMeta = labelUDF(col($(labelCol))).as(labelColName, 
newLabelMeta)
-      val trainingDataset = multiclassLabeled.withColumn(labelColName, 
labelUDFWithNewMeta)
+      val trainingDataset =
+        multiclassLabeled.withColumn(labelColName, labelUDF(col($(labelCol))), 
newLabelMeta)
       val classifier = getClassifier
       val paramMap = new ParamMap()
       paramMap.put(classifier.labelCol -> labelColName)

http://git-wip-us.apache.org/repos/asf/spark/blob/34d610be/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
index cfca494..6fdf25b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Bucketizer.scala
@@ -75,7 +75,7 @@ final class Bucketizer(override val uid: String)
     }
     val newCol = bucketizer(dataset($(inputCol)))
     val newField = prepOutputField(dataset.schema)
-    dataset.withColumn($(outputCol), newCol.as($(outputCol), 
newField.metadata))
+    dataset.withColumn($(outputCol), newCol, newField.metadata)
   }
 
   private def prepOutputField(schema: StructType): StructField = {

http://git-wip-us.apache.org/repos/asf/spark/blob/34d610be/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
index 6875aef..61b925c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorIndexer.scala
@@ -341,7 +341,7 @@ class VectorIndexerModel private[ml] (
     val newField = prepOutputField(dataset.schema)
     val transformUDF = udf { (vector: Vector) => transformFunc(vector) }
     val newCol = transformUDF(dataset($(inputCol)))
-    dataset.withColumn($(outputCol), newCol.as($(outputCol), 
newField.metadata))
+    dataset.withColumn($(outputCol), newCol, newField.metadata)
   }
 
   override def transformSchema(schema: StructType): StructType = {

http://git-wip-us.apache.org/repos/asf/spark/blob/34d610be/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala
index 772bebe..c5c2272 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/VectorSlicer.scala
@@ -119,8 +119,7 @@ final class VectorSlicer(override val uid: String)
         case features: SparseVector => features.slice(inds)
       }
     }
-    dataset.withColumn($(outputCol),
-      slicer(dataset($(inputCol))).as($(outputCol), outputAttr.toMetadata()))
+    dataset.withColumn($(outputCol), slicer(dataset($(inputCol))), 
outputAttr.toMetadata())
   }
 
   /** Get the feature indices in order: indices, names */

http://git-wip-us.apache.org/repos/asf/spark/blob/34d610be/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index c466d9e..cf75e64 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -1150,6 +1150,23 @@ class DataFrame private[sql](
   }
 
   /**
+   * Returns a new [[DataFrame]] by adding a column with metadata.
+   */
+  private[spark] def withColumn(colName: String, col: Column, metadata: 
Metadata): DataFrame = {
+    val resolver = sqlContext.analyzer.resolver
+    val replaced = schema.exists(f => resolver(f.name, colName))
+    if (replaced) {
+      val colNames = schema.map { field =>
+        val name = field.name
+        if (resolver(name, colName)) col.as(colName, metadata) else 
Column(name)
+      }
+      select(colNames : _*)
+    } else {
+      select(Column("*"), col.as(colName, metadata))
+    }
+  }
+
+  /**
    * Returns a new [[DataFrame]] with a column renamed.
    * This is a no-op if schema doesn't contain existingName.
    * @group dfops


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to