Repository: spark
Updated Branches:
  refs/heads/master a63be1a18 -> 5ad78f620


[SQL] Various DataFrame DSL updates.

1. Added foreach, foreachPartition, and flatMap to DataFrame.
2. Added col() to the DSL.
3. Added support for renaming columns in toDataFrame.
4. Added support for type inference on arrays (in addition to Seq).
5. Updated mllib to use the new DSL.

Author: Reynold Xin <[email protected]>

Closes #4260 from rxin/sql-dsl-update and squashes the following commits:

73466c1 [Reynold Xin] Fixed LogisticRegression. Also added a better error message for resolve.
fab3ccc [Reynold Xin] Bug fix.
d31fcd2 [Reynold Xin] Style fix.
62608c4 [Reynold Xin] [SQL] Various DataFrame DSL update.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/5ad78f62
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/5ad78f62
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/5ad78f62

Branch: refs/heads/master
Commit: 5ad78f62056f2560cd371ee964111a646806d0ff
Parents: a63be1a
Author: Reynold Xin <[email protected]>
Authored: Thu Jan 29 00:01:10 2015 -0800
Committer: Reynold Xin <[email protected]>
Committed: Thu Jan 29 00:01:10 2015 -0800

----------------------------------------------------------------------
 .../scala/org/apache/spark/ml/Transformer.scala |  3 +-
 .../ml/classification/LogisticRegression.scala  | 12 +++--
 .../spark/ml/feature/StandardScaler.scala       |  3 +-
 .../apache/spark/ml/recommendation/ALS.scala    | 35 +++++----------
 .../org/apache/spark/mllib/linalg/Vectors.scala |  3 +-
 .../spark/sql/catalyst/ScalaReflection.scala    |  5 ++-
 .../sql/catalyst/ScalaReflectionSuite.scala     |  5 +++
 .../scala/org/apache/spark/sql/Column.scala     | 12 +++--
 .../scala/org/apache/spark/sql/DataFrame.scala  | 47 ++++++++++++++++++--
 .../main/scala/org/apache/spark/sql/api.scala   |  6 +++
 .../org/apache/spark/sql/api/java/dsl.java      |  7 +++
 .../spark/sql/api/scala/dsl/package.scala       | 21 +++++++++
 12 files changed, 114 insertions(+), 45 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/5ad78f62/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala 
b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
index 29cd981..6eb7ea6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
@@ -23,7 +23,6 @@ import org.apache.spark.Logging
 import org.apache.spark.annotation.AlphaComponent
 import org.apache.spark.ml.param._
 import org.apache.spark.sql.DataFrame
-import org.apache.spark.sql._
 import org.apache.spark.sql.api.scala.dsl._
 import org.apache.spark.sql.types._
 
@@ -99,6 +98,6 @@ private[ml] abstract class UnaryTransformer[IN, OUT, T <: 
UnaryTransformer[IN, O
     transformSchema(dataset.schema, paramMap, logging = true)
     val map = this.paramMap ++ paramMap
     dataset.select($"*", callUDF(
-      this.createTransformFunc(map), outputDataType, 
Column(map(inputCol))).as(map(outputCol)))
+      this.createTransformFunc(map), outputDataType, 
dataset(map(inputCol))).as(map(outputCol)))
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/5ad78f62/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
 
b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 101f6c8..d82360d 100644
--- 
a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ 
b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -25,7 +25,6 @@ import org.apache.spark.mllib.linalg.{BLAS, Vector, VectorUDT}
 import org.apache.spark.mllib.regression.LabeledPoint
 import org.apache.spark.sql._
 import org.apache.spark.sql.api.scala.dsl._
-import org.apache.spark.sql.catalyst.dsl._
 import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
 import org.apache.spark.storage.StorageLevel
 
@@ -133,15 +132,14 @@ class LogisticRegressionModel private[ml] (
   override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
     transformSchema(dataset.schema, paramMap, logging = true)
     val map = this.paramMap ++ paramMap
-    val score: Vector => Double = (v) => {
+    val scoreFunction: Vector => Double = (v) => {
       val margin = BLAS.dot(v, weights)
       1.0 / (1.0 + math.exp(-margin))
     }
     val t = map(threshold)
-    val predict: Double => Double = (score) => {
-      if (score > t) 1.0 else 0.0
-    }
-    dataset.select($"*", callUDF(score, 
Column(map(featuresCol))).as(map(scoreCol)))
-      .select($"*", callUDF(predict, 
Column(map(scoreCol))).as(map(predictionCol)))
+    val predictFunction: Double => Double = (score) => { if (score > t) 1.0 
else 0.0 }
+    dataset
+      .select($"*", callUDF(scoreFunction, 
col(map(featuresCol))).as(map(scoreCol)))
+      .select($"*", callUDF(predictFunction, 
col(map(scoreCol))).as(map(predictionCol)))
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/5ad78f62/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
----------------------------------------------------------------------
diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
index c456beb..78a4856 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/StandardScaler.scala
@@ -24,7 +24,6 @@ import org.apache.spark.mllib.feature
 import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
 import org.apache.spark.sql._
 import org.apache.spark.sql.api.scala.dsl._
-import org.apache.spark.sql.catalyst.dsl._
 import org.apache.spark.sql.types.{StructField, StructType}
 
 /**
@@ -85,7 +84,7 @@ class StandardScalerModel private[ml] (
     val scale: (Vector) => Vector = (v) => {
       scaler.transform(v)
     }
-    dataset.select($"*", callUDF(scale, 
Column(map(inputCol))).as(map(outputCol)))
+    dataset.select($"*", callUDF(scale, col(map(inputCol))).as(map(outputCol)))
   }
 
   private[ml] override def transformSchema(schema: StructType, paramMap: 
ParamMap): StructType = {

http://git-wip-us.apache.org/repos/asf/spark/blob/5ad78f62/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala 
b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 738b184..474d473 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -111,20 +111,10 @@ class ALSModel private[ml] (
   def setPredictionCol(value: String): this.type = set(predictionCol, value)
 
   override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
-    import dataset.sqlContext._
-    import org.apache.spark.ml.recommendation.ALSModel.Factor
+    import dataset.sqlContext.createDataFrame
     val map = this.paramMap ++ paramMap
-    // TODO: Add DSL to simplify the code here.
-    val instanceTable = s"instance_$uid"
-    val userTable = s"user_$uid"
-    val itemTable = s"item_$uid"
-    val instances = dataset.as(instanceTable)
-    val users = userFactors.map { case (id, features) =>
-      Factor(id, features)
-    }.as(userTable)
-    val items = itemFactors.map { case (id, features) =>
-      Factor(id, features)
-    }.as(itemTable)
+    val users = userFactors.toDataFrame("id", "features")
+    val items = itemFactors.toDataFrame("id", "features")
     val predict: (Seq[Float], Seq[Float]) => Float = (userFeatures, 
itemFeatures) => {
       if (userFeatures != null && itemFeatures != null) {
         blas.sdot(k, userFeatures.toArray, 1, itemFeatures.toArray, 1)
@@ -133,13 +123,14 @@ class ALSModel private[ml] (
       }
     }
     val inputColumns = dataset.schema.fieldNames
-    val prediction = callUDF(predict, $"$userTable.features", 
$"$itemTable.features")
-        .as(map(predictionCol))
-    val outputColumns = inputColumns.map(f => $"$instanceTable.$f".as(f)) :+ 
prediction
-    instances
-      .join(users, Column(map(userCol)) === $"$userTable.id", "left")
-      .join(items, Column(map(itemCol)) === $"$itemTable.id", "left")
+    val prediction = callUDF(predict, users("features"), 
items("features")).as(map(predictionCol))
+    val outputColumns = inputColumns.map(f => dataset(f)) :+ prediction
+    dataset
+      .join(users, dataset(map(userCol)) === users("id"), "left")
+      .join(items, dataset(map(itemCol)) === items("id"), "left")
       .select(outputColumns: _*)
+      // TODO: Just use a dataset("*")
+      // .select(dataset("*"), prediction)
   }
 
   override private[ml] def transformSchema(schema: StructType, paramMap: 
ParamMap): StructType = {
@@ -147,10 +138,6 @@ class ALSModel private[ml] (
   }
 }
 
-private object ALSModel {
-  /** Case class to convert factors to [[DataFrame]]s */
-  private case class Factor(id: Int, features: Seq[Float])
-}
 
 /**
  * Alternating Least Squares (ALS) matrix factorization.
@@ -210,7 +197,7 @@ class ALS extends Estimator[ALSModel] with ALSParams {
   override def fit(dataset: DataFrame, paramMap: ParamMap): ALSModel = {
     val map = this.paramMap ++ paramMap
     val ratings = dataset
-      .select(Column(map(userCol)), Column(map(itemCol)), 
Column(map(ratingCol)).cast(FloatType))
+      .select(col(map(userCol)), col(map(itemCol)), 
col(map(ratingCol)).cast(FloatType))
       .map { row =>
         new Rating(row.getInt(0), row.getInt(1), row.getFloat(2))
       }

http://git-wip-us.apache.org/repos/asf/spark/blob/5ad78f62/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala 
b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index 31c33f1..567a8a6 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -27,7 +27,8 @@ import breeze.linalg.{DenseVector => BDV, SparseVector => 
BSV, Vector => BV}
 
 import org.apache.spark.SparkException
 import org.apache.spark.mllib.util.NumericParser
-import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, Row}
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
 import org.apache.spark.sql.types._
 
 /**

http://git-wip-us.apache.org/repos/asf/spark/blob/5ad78f62/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
index 191d16f..4def65b 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/ScalaReflection.scala
@@ -57,6 +57,7 @@ trait ScalaReflection {
     case (obj, udt: UserDefinedType[_]) => udt.serialize(obj)
     case (o: Option[_], _) => o.map(convertToCatalyst(_, dataType)).orNull
     case (s: Seq[_], arrayType: ArrayType) => s.map(convertToCatalyst(_, 
arrayType.elementType))
+    case (s: Array[_], arrayType: ArrayType) => s.toSeq
     case (m: Map[_, _], mapType: MapType) => m.map { case (k, v) =>
       convertToCatalyst(k, mapType.keyType) -> convertToCatalyst(v, 
mapType.valueType)
     }
@@ -140,7 +141,9 @@ trait ScalaReflection {
       // Need to decide if we actually need a special type here.
       case t if t <:< typeOf[Array[Byte]] => Schema(BinaryType, nullable = 
true)
       case t if t <:< typeOf[Array[_]] =>
-        sys.error(s"Only Array[Byte] supported now, use Seq instead of $t")
+        val TypeRef(_, _, Seq(elementType)) = t
+        val Schema(dataType, nullable) = schemaFor(elementType)
+        Schema(ArrayType(dataType, containsNull = nullable), nullable = true)
       case t if t <:< typeOf[Seq[_]] =>
         val TypeRef(_, _, Seq(elementType)) = t
         val Schema(dataType, nullable) = schemaFor(elementType)

http://git-wip-us.apache.org/repos/asf/spark/blob/5ad78f62/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
index 5138942..4a66716 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/ScalaReflectionSuite.scala
@@ -60,6 +60,7 @@ case class OptionalData(
 
 case class ComplexData(
     arrayField: Seq[Int],
+    arrayField1: Array[Int],
     arrayFieldContainsNull: Seq[java.lang.Integer],
     mapField: Map[Int, Long],
     mapFieldValueContainsNull: Map[Int, java.lang.Long],
@@ -132,6 +133,10 @@ class ScalaReflectionSuite extends FunSuite {
           ArrayType(IntegerType, containsNull = false),
           nullable = true),
         StructField(
+          "arrayField1",
+          ArrayType(IntegerType, containsNull = false),
+          nullable = true),
+        StructField(
           "arrayFieldContainsNull",
           ArrayType(IntegerType, containsNull = true),
           nullable = true),

http://git-wip-us.apache.org/repos/asf/spark/blob/5ad78f62/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index 7f9a91a..9be2a03 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -22,15 +22,19 @@ import scala.language.implicitConversions
 import org.apache.spark.sql.api.scala.dsl.lit
 import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, Star}
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.{Literal => LiteralExpr}
 import org.apache.spark.sql.catalyst.plans.logical.{Project, LogicalPlan}
 import org.apache.spark.sql.types._
 
 
 object Column {
-  def unapply(col: Column): Option[Expression] = Some(col.expr)
-
+  /**
+   * Creates a [[Column]] based on the given column name.
+   * Same as [[api.scala.dsl.col]] and [[api.java.dsl.col]].
+   */
   def apply(colName: String): Column = new Column(colName)
+
+  /** For internal pattern matching. */
+  private[sql] def unapply(col: Column): Option[Expression] = Some(col.expr)
 }
 
 
@@ -438,7 +442,7 @@ class Column(
    * @param ordinal
    * @return
    */
-  override def getItem(ordinal: Int): Column = GetItem(expr, 
LiteralExpr(ordinal))
+  override def getItem(ordinal: Int): Column = GetItem(expr, Literal(ordinal))
 
   /**
    * An expression that gets a field by name in a [[StructField]].

http://git-wip-us.apache.org/repos/asf/spark/blob/5ad78f62/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index ceb5f86..050366a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -118,8 +118,8 @@ class DataFrame protected[sql](
 
   /** Resolves a column name into a Catalyst [[NamedExpression]]. */
   protected[sql] def resolve(colName: String): NamedExpression = {
-    logicalPlan.resolve(colName, sqlContext.analyzer.resolver).getOrElse(
-      throw new RuntimeException(s"""Cannot resolve column name "$colName""""))
+    logicalPlan.resolve(colName, sqlContext.analyzer.resolver).getOrElse(throw 
new RuntimeException(
+      s"""Cannot resolve column name "$colName" among 
(${schema.fieldNames.mkString(", ")})"""))
   }
 
   /** Left here for compatibility reasons. */
@@ -131,6 +131,29 @@ class DataFrame protected[sql](
    */
   def toDataFrame: DataFrame = this
 
+  /**
+   * Returns a new [[DataFrame]] with columns renamed. This can be quite 
convenient in conversion
+   * from a RDD of tuples into a [[DataFrame]] with meaningful names. For 
example:
+   * {{{
+   *   val rdd: RDD[(Int, String)] = ...
+   *   rdd.toDataFrame  // this implicit conversion creates a DataFrame with 
column name _1 and _2
+   *   rdd.toDataFrame("id", "name")  // this creates a DataFrame with column 
name "id" and "name"
+   * }}}
+   */
+  @scala.annotation.varargs
+  def toDataFrame(colName: String, colNames: String*): DataFrame = {
+    val newNames = colName +: colNames
+    require(schema.size == newNames.size,
+      "The number of columns doesn't match.\n" +
+      "Old column names: " + schema.fields.map(_.name).mkString(", ") + "\n" +
+      "New column names: " + newNames.mkString(", "))
+
+    val newCols = schema.fieldNames.zip(newNames).map { case (oldName, 
newName) =>
+      apply(oldName).as(newName)
+    }
+    select(newCols :_*)
+  }
+
   /** Returns the schema of this [[DataFrame]]. */
   override def schema: StructType = queryExecution.analyzed.schema
 
@@ -227,7 +250,7 @@ class DataFrame protected[sql](
   }
 
   /**
-   * Selects a single column and return it as a [[Column]].
+   * Selects column based on the column name and return it as a [[Column]].
    */
   override def apply(colName: String): Column = colName match {
     case "*" =>
@@ -467,6 +490,12 @@ class DataFrame protected[sql](
   }
 
   /**
+   * Returns a new RDD by first applying a function to all rows of this 
[[DataFrame]],
+   * and then flattening the results.
+   */
+  override def flatMap[R: ClassTag](f: Row => TraversableOnce[R]): RDD[R] = 
rdd.flatMap(f)
+
+  /**
    * Returns a new RDD by applying a function to each partition of this 
DataFrame.
    */
   override def mapPartitions[R: ClassTag](f: Iterator[Row] => Iterator[R]): 
RDD[R] = {
@@ -474,6 +503,16 @@ class DataFrame protected[sql](
   }
 
   /**
+   * Applies a function `f` to all rows.
+   */
+  override def foreach(f: Row => Unit): Unit = rdd.foreach(f)
+
+  /**
+   * Applies a function f to each partition of this [[DataFrame]].
+   */
+  override def foreachPartition(f: Iterator[Row] => Unit): Unit = 
rdd.foreachPartition(f)
+
+  /**
    * Returns the first `n` rows in the [[DataFrame]].
    */
   override def take(n: Int): Array[Row] = head(n)
@@ -520,7 +559,7 @@ class DataFrame protected[sql](
   /////////////////////////////////////////////////////////////////////////////
 
   /**
-   * Return the content of the [[DataFrame]] as a [[RDD]] of [[Row]]s.
+   * Returns the content of the [[DataFrame]] as an [[RDD]] of [[Row]]s.
    */
   override def rdd: RDD[Row] = {
     val schema = this.schema

http://git-wip-us.apache.org/repos/asf/spark/blob/5ad78f62/sql/core/src/main/scala/org/apache/spark/sql/api.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/api.scala
index 5eeaf17..5963408 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api.scala
@@ -44,8 +44,14 @@ private[sql] trait RDDApi[T] {
 
   def map[R: ClassTag](f: T => R): RDD[R]
 
+  def flatMap[R: ClassTag](f: T => TraversableOnce[R]): RDD[R]
+
   def mapPartitions[R: ClassTag](f: Iterator[T] => Iterator[R]): RDD[R]
 
+  def foreach(f: T => Unit): Unit
+
+  def foreachPartition(f: Iterator[T] => Unit): Unit
+
   def take(n: Int): Array[T]
 
   def collect(): Array[T]

http://git-wip-us.apache.org/repos/asf/spark/blob/5ad78f62/sql/core/src/main/scala/org/apache/spark/sql/api/java/dsl.java
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/dsl.java 
b/sql/core/src/main/scala/org/apache/spark/sql/api/java/dsl.java
index 74d7649..16702af 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/dsl.java
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/dsl.java
@@ -33,6 +33,13 @@ public class dsl {
   private static package$ scalaDsl = package$.MODULE$;
 
   /**
+   * Returns a {@link Column} based on the given column name.
+   */
+  public static Column col(String colName) {
+    return new Column(colName);
+  }
+
+  /**
    * Creates a column of literal value.
    */
   public static Column lit(Object literalValue) {

http://git-wip-us.apache.org/repos/asf/spark/blob/5ad78f62/sql/core/src/main/scala/org/apache/spark/sql/api/scala/dsl/package.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/main/scala/org/apache/spark/sql/api/scala/dsl/package.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/api/scala/dsl/package.scala
index 9f2d142..dc851fc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/scala/dsl/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/scala/dsl/package.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.api.scala
 import scala.language.implicitConversions
 import scala.reflect.runtime.universe.{TypeTag, typeTag}
 
+import org.apache.spark.rdd.RDD
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.expressions._
@@ -37,6 +38,21 @@ package object dsl {
   /** An implicit conversion that turns a Scala `Symbol` into a [[Column]]. */
   implicit def symbolToColumn(s: Symbol): ColumnName = new ColumnName(s.name)
 
+//  /**
+//   * An implicit conversion that turns a RDD of product into a [[DataFrame]].
+//   *
+//   * This method requires an implicit SQLContext in scope. For example:
+//   * {{{
+//   *   implicit val sqlContext: SQLContext = ...
+//   *   val rdd: RDD[(Int, String)] = ...
+//   *   rdd.toDataFrame  // triggers the implicit here
+//   * }}}
+//   */
+//  implicit def rddToDataFrame[A <: Product: TypeTag](rdd: RDD[A])(implicit 
context: SQLContext)
+//    : DataFrame = {
+//    context.createDataFrame(rdd)
+//  }
+
   /** Converts $"col name" into an [[Column]]. */
   implicit class StringToColumn(val sc: StringContext) extends AnyVal {
     def $(args: Any*): ColumnName = {
@@ -47,6 +63,11 @@ package object dsl {
   private[this] implicit def toColumn(expr: Expression): Column = new 
Column(expr)
 
   /**
+   * Returns a [[Column]] based on the given column name.
+   */
+  def col(colName: String): Column = new Column(colName)
+
+  /**
    * Creates a [[Column]] of literal value.
    */
   def lit(literal: Any): Column = {


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to