Repository: spark
Updated Branches:
  refs/heads/master 648553d48 -> ca5d8b590


[SQL] Pass SQLContext instead of SparkContext into physical operators.

This makes it easier to use config options in operators.
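
Every operator touched by the diff follows the same shape: the SQLContext rides in a second, @transient constructor parameter list, is handed back from otherCopyArgs so tree copies preserve it, and gives the operator one-hop access to both configuration and the SparkContext. Below is a minimal sketch of that pattern with a made-up operator name and config key; it is an illustration in the spirit of the in-tree operators, not code from the patch.

    // Hypothetical operator illustrating the pattern this commit applies.
    // (Declared in the execution package because UnaryNode is private[sql].)
    package org.apache.spark.sql.execution

    import org.apache.spark.sql.SQLContext

    case class ExampleLimit(limit: Int, child: SparkPlan)
        (@transient sqlContext: SQLContext)
      extends UnaryNode {

      // Curried constructor args are not case-class fields, so they must be
      // returned here for copies made during tree transformations.
      override def otherCopyArgs = sqlContext :: Nil

      override def output = child.output

      override def execute() = {
        // Config is now reachable through the SQLContext (the patch reads the
        // Parquet filter-pushdown flag the same way); this key is illustrative.
        val numSlices = sqlContext.sparkContext.conf.getInt("spark.sql.example.slices", 1)
        // The SparkContext itself is still one hop away when an RDD is needed.
        sqlContext.sparkContext.makeRDD(child.execute().map(_.copy()).take(limit), numSlices)
      }
    }

This is also why each hunk below swaps "sc :: Nil" for "sqlContext :: Nil" in otherCopyArgs: makeCopy re-invokes the constructor with those extra arguments, so the context survives plan transformations.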

Author: Reynold Xin <r...@apache.org>

Closes #1164 from rxin/sqlcontext and squashes the following commits:

797b2fd [Reynold Xin] Pass SQLContext instead of SparkContext into physical operators.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ca5d8b59
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ca5d8b59
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ca5d8b59

Branch: refs/heads/master
Commit: ca5d8b5904dc6dd5b691af506d3a842e508b3673
Parents: 648553d
Author: Reynold Xin <r...@apache.org>
Authored: Fri Jun 20 22:49:48 2014 -0700
Committer: Reynold Xin <r...@apache.org>
Committed: Fri Jun 20 22:49:48 2014 -0700

----------------------------------------------------------------------
 .../scala/org/apache/spark/sql/SQLContext.scala |  4 +++-
 .../apache/spark/sql/execution/Aggregate.scala  |  5 +++--
 .../spark/sql/execution/SparkStrategies.scala   | 22 ++++++++++----------
 .../spark/sql/execution/basicOperators.scala    | 20 ++++++++++--------
 .../org/apache/spark/sql/execution/joins.scala  | 21 ++++++++++---------
 .../sql/parquet/ParquetTableOperations.scala    | 21 ++++++++++---------
 .../spark/sql/parquet/ParquetQuerySuite.scala   |  2 +-
 7 files changed, 51 insertions(+), 44 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/ca5d8b59/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index ab376e5..c60af28 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -221,7 +221,9 @@ class SQLContext(@transient val sparkContext: SparkContext)
   }
 
   protected[sql] class SparkPlanner extends SparkStrategies {
-    val sparkContext = self.sparkContext
+    val sparkContext: SparkContext = self.sparkContext
+
+    val sqlContext: SQLContext = self
 
     def numPartitions = self.numShufflePartitions
 

http://git-wip-us.apache.org/repos/asf/spark/blob/ca5d8b59/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
index 34d88fe..d85d2d7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
@@ -24,6 +24,7 @@ import org.apache.spark.SparkContext
 import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.SQLContext
 
 /**
  * :: DeveloperApi ::
@@ -41,7 +42,7 @@ case class Aggregate(
     partial: Boolean,
     groupingExpressions: Seq[Expression],
     aggregateExpressions: Seq[NamedExpression],
-    child: SparkPlan)(@transient sc: SparkContext)
+    child: SparkPlan)(@transient sqlContext: SQLContext)
   extends UnaryNode with NoBind {
 
   override def requiredChildDistribution =
@@ -55,7 +56,7 @@ case class Aggregate(
       }
     }
 
-  override def otherCopyArgs = sc :: Nil
+  override def otherCopyArgs = sqlContext :: Nil
 
  // HACK: Generators don't correctly preserve their output through serializations so we grab
  // out child's output attributes statically here.

http://git-wip-us.apache.org/repos/asf/spark/blob/ca5d8b59/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 4694f25..bd8ae4c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -40,7 +40,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       // no predicate can be evaluated by matching hash keys
       case logical.Join(left, right, LeftSemi, condition) =>
         execution.LeftSemiJoinBNL(
-          planLater(left), planLater(right), condition)(sparkContext) :: Nil
+          planLater(left), planLater(right), condition)(sqlContext) :: Nil
       case _ => Nil
     }
   }
@@ -103,7 +103,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
               partial = true,
               groupingExpressions,
               partialComputation,
-              planLater(child))(sparkContext))(sparkContext) :: Nil
+              planLater(child))(sqlContext))(sqlContext) :: Nil
         } else {
           Nil
         }
@@ -115,7 +115,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
       case logical.Join(left, right, joinType, condition) =>
         execution.BroadcastNestedLoopJoin(
-          planLater(left), planLater(right), joinType, condition)(sparkContext) :: Nil
+          planLater(left), planLater(right), joinType, condition)(sqlContext) :: Nil
       case _ => Nil
     }
   }
@@ -143,7 +143,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
   object TakeOrdered extends Strategy {
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
       case logical.Limit(IntegerLiteral(limit), logical.Sort(order, child)) =>
-        execution.TakeOrdered(limit, order, planLater(child))(sparkContext) :: Nil
+        execution.TakeOrdered(limit, order, planLater(child))(sqlContext) :: Nil
       case _ => Nil
     }
   }
@@ -155,9 +155,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         val relation =
           ParquetRelation.create(path, child, sparkContext.hadoopConfiguration)
        // Note: overwrite=false because otherwise the metadata we just created will be deleted
-        InsertIntoParquetTable(relation, planLater(child), overwrite=false)(sparkContext) :: Nil
+        InsertIntoParquetTable(relation, planLater(child), overwrite=false)(sqlContext) :: Nil
      case logical.InsertIntoTable(table: ParquetRelation, partition, child, overwrite) =>
-        InsertIntoParquetTable(table, planLater(child), overwrite)(sparkContext) :: Nil
+        InsertIntoParquetTable(table, planLater(child), overwrite)(sqlContext) :: Nil
      case PhysicalOperation(projectList, filters: Seq[Expression], relation: ParquetRelation) =>
        val prunePushedDownFilters =
          if (sparkContext.conf.getBoolean(ParquetFilters.PARQUET_FILTER_PUSHDOWN_ENABLED, true)) {
@@ -186,7 +186,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
           projectList,
           filters,
           prunePushedDownFilters,
-          ParquetTableScan(_, relation, filters)(sparkContext)) :: Nil
+          ParquetTableScan(_, relation, filters)(sqlContext)) :: Nil
 
       case _ => Nil
     }
@@ -211,7 +211,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
       case logical.Distinct(child) =>
         execution.Aggregate(
-          partial = false, child.output, child.output, planLater(child))(sparkContext) :: Nil
+          partial = false, child.output, child.output, planLater(child))(sqlContext) :: Nil
      case logical.Sort(sortExprs, child) =>
        // This sort is a global sort. Its requiredDistribution will be an OrderedDistribution.
         execution.Sort(sortExprs, global = true, planLater(child)):: Nil
@@ -224,7 +224,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       case logical.Filter(condition, child) =>
         execution.Filter(condition, planLater(child)) :: Nil
       case logical.Aggregate(group, agg, child) =>
-        execution.Aggregate(partial = false, group, agg, planLater(child))(sparkContext) :: Nil
+        execution.Aggregate(partial = false, group, agg, planLater(child))(sqlContext) :: Nil
      case logical.Sample(fraction, withReplacement, seed, child) =>
        execution.Sample(fraction, withReplacement, seed, planLater(child)) :: Nil
       case logical.LocalRelation(output, data) =>
@@ -233,9 +233,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
            new GenericRow(r.productIterator.map(convertToCatalyst).toArray): Row))
         execution.ExistingRdd(output, dataAsRdd) :: Nil
       case logical.Limit(IntegerLiteral(limit), child) =>
-        execution.Limit(limit, planLater(child))(sparkContext) :: Nil
+        execution.Limit(limit, planLater(child))(sqlContext) :: Nil
       case Unions(unionChildren) =>
-        execution.Union(unionChildren.map(planLater))(sparkContext) :: Nil
+        execution.Union(unionChildren.map(planLater))(sqlContext) :: Nil
       case logical.Generate(generator, join, outer, _, child) =>
        execution.Generate(generator, join = join, outer = outer, planLater(child)) :: Nil
       case logical.NoRelation =>

http://git-wip-us.apache.org/repos/asf/spark/blob/ca5d8b59/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 8969794..18f4a58 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -20,8 +20,9 @@ package org.apache.spark.sql.execution
 import scala.reflect.runtime.universe.TypeTag
 
 import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.{HashPartitioner, SparkConf, SparkContext}
+import org.apache.spark.{HashPartitioner, SparkConf}
 import org.apache.spark.rdd.{RDD, ShuffledRDD}
+import org.apache.spark.sql.SQLContext
 import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions._
@@ -70,12 +71,12 @@ case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child:
  * :: DeveloperApi ::
  */
 @DeveloperApi
-case class Union(children: Seq[SparkPlan])(@transient sc: SparkContext) extends SparkPlan {
+case class Union(children: Seq[SparkPlan])(@transient sqlContext: SQLContext) extends SparkPlan {
  // TODO: attributes output by union should be distinct for nullability purposes
  override def output = children.head.output
-  override def execute() = sc.union(children.map(_.execute()))
+  override def execute() = sqlContext.sparkContext.union(children.map(_.execute()))
 
-  override def otherCopyArgs = sc :: Nil
+  override def otherCopyArgs = sqlContext :: Nil
 }
 
 /**
@@ -87,11 +88,12 @@ case class Union(children: Seq[SparkPlan])(@transient sc: SparkContext) extends
  * data to a single partition to compute the global limit.
  */
 @DeveloperApi
-case class Limit(limit: Int, child: SparkPlan)(@transient sc: SparkContext) extends UnaryNode {
+case class Limit(limit: Int, child: SparkPlan)(@transient sqlContext: SQLContext)
+  extends UnaryNode {
  // TODO: Implement a partition local limit, and use a strategy to generate the proper limit plan:
  // partition local limit -> exchange into one partition -> partition local limit again
 
-  override def otherCopyArgs = sc :: Nil
+  override def otherCopyArgs = sqlContext :: Nil
 
   override def output = child.output
 
@@ -117,8 +119,8 @@ case class Limit(limit: Int, child: SparkPlan)(@transient sc: SparkContext) exte
  */
 @DeveloperApi
 case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan)
-                      (@transient sc: SparkContext) extends UnaryNode {
-  override def otherCopyArgs = sc :: Nil
+                      (@transient sqlContext: SQLContext) extends UnaryNode {
+  override def otherCopyArgs = sqlContext :: Nil
 
   override def output = child.output
 
@@ -129,7 +131,7 @@ case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan)
 
  // TODO: Terminal split should be implemented differently from non-terminal split.
   // TODO: Pick num splits based on |limit|.
-  override def execute() = sc.makeRDD(executeCollect(), 1)
+  override def execute() = sqlContext.sparkContext.makeRDD(executeCollect(), 1)
 }
 
 /**

http://git-wip-us.apache.org/repos/asf/spark/blob/ca5d8b59/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
index 8d7a5ba..84bdde3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
@@ -19,9 +19,8 @@ package org.apache.spark.sql.execution
 
 import scala.collection.mutable.{ArrayBuffer, BitSet}
 
-import org.apache.spark.SparkContext
-
 import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.sql.SQLContext
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Partitioning}
@@ -200,13 +199,13 @@ case class LeftSemiJoinHash(
 @DeveloperApi
 case class LeftSemiJoinBNL(
     streamed: SparkPlan, broadcast: SparkPlan, condition: Option[Expression])
-    (@transient sc: SparkContext)
+    (@transient sqlContext: SQLContext)
   extends BinaryNode {
   // TODO: Override requiredChildDistribution.
 
   override def outputPartitioning: Partitioning = streamed.outputPartitioning
 
-  override def otherCopyArgs = sc :: Nil
+  override def otherCopyArgs = sqlContext :: Nil
 
   def output = left.output
 
@@ -223,7 +222,8 @@ case class LeftSemiJoinBNL(
 
 
   def execute() = {
-    val broadcastedRelation = sc.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq)
+    val broadcastedRelation =
+      sqlContext.sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq)
 
     streamed.execute().mapPartitions { streamedIter =>
       val joinedRow = new JoinedRow
@@ -263,13 +263,13 @@ case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNod
 @DeveloperApi
 case class BroadcastNestedLoopJoin(
    streamed: SparkPlan, broadcast: SparkPlan, joinType: JoinType, condition: Option[Expression])
-    (@transient sc: SparkContext)
+    (@transient sqlContext: SQLContext)
   extends BinaryNode {
   // TODO: Override requiredChildDistribution.
 
   override def outputPartitioning: Partitioning = streamed.outputPartitioning
 
-  override def otherCopyArgs = sc :: Nil
+  override def otherCopyArgs = sqlContext :: Nil
 
   def output = left.output ++ right.output
 
@@ -286,7 +286,8 @@ case class BroadcastNestedLoopJoin(
 
 
   def execute() = {
-    val broadcastedRelation = sc.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq)
+    val broadcastedRelation =
+      sqlContext.sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq)
 
    val streamedPlusMatches = streamed.execute().mapPartitions { streamedIter =>
       val matchedRows = new ArrayBuffer[Row]
@@ -337,7 +338,7 @@ case class BroadcastNestedLoopJoin(
       }
 
     // TODO: Breaks lineage.
-    sc.union(
-      streamedPlusMatches.flatMap(_._1), sc.makeRDD(rightOuterMatches))
+    sqlContext.sparkContext.union(
+      streamedPlusMatches.flatMap(_._1), sqlContext.sparkContext.makeRDD(rightOuterMatches))
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/ca5d8b59/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
index 624f2e2..ade823b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
@@ -33,10 +33,10 @@ import parquet.hadoop.util.ContextUtil
 import parquet.io.InvalidRecordException
 import parquet.schema.MessageType
 
-import org.apache.spark.{Logging, SerializableWritable, SparkContext, TaskContext}
+import org.apache.spark.{Logging, SerializableWritable, TaskContext}
 import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.SQLContext
 import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Row}
-import org.apache.spark.sql.catalyst.types.StructType
 import org.apache.spark.sql.execution.{LeafNode, SparkPlan, UnaryNode}
 
 /**
@@ -49,10 +49,11 @@ case class ParquetTableScan(
     output: Seq[Attribute],
     relation: ParquetRelation,
     columnPruningPred: Seq[Expression])(
-    @transient val sc: SparkContext)
+    @transient val sqlContext: SQLContext)
   extends LeafNode {
 
   override def execute(): RDD[Row] = {
+    val sc = sqlContext.sparkContext
     val job = new Job(sc.hadoopConfiguration)
     ParquetInputFormat.setReadSupportClass(
       job,
@@ -93,7 +94,7 @@ case class ParquetTableScan(
       .filter(_ != null) // Parquet's record filters may produce null values
   }
 
-  override def otherCopyArgs = sc :: Nil
+  override def otherCopyArgs = sqlContext :: Nil
 
   /**
    * Applies a (candidate) projection.
@@ -104,7 +105,7 @@ case class ParquetTableScan(
   def pruneColumns(prunedAttributes: Seq[Attribute]): ParquetTableScan = {
     val success = validateProjection(prunedAttributes)
     if (success) {
-      ParquetTableScan(prunedAttributes, relation, columnPruningPred)(sc)
+      ParquetTableScan(prunedAttributes, relation, columnPruningPred)(sqlContext)
     } else {
       sys.error("Warning: Could not validate Parquet schema projection in 
pruneColumns")
       this
@@ -152,7 +153,7 @@ case class InsertIntoParquetTable(
     relation: ParquetRelation,
     child: SparkPlan,
     overwrite: Boolean = false)(
-    @transient val sc: SparkContext)
+    @transient val sqlContext: SQLContext)
   extends UnaryNode with SparkHadoopMapReduceUtil {
 
   /**
@@ -168,7 +169,7 @@ case class InsertIntoParquetTable(
     val childRdd = child.execute()
     assert(childRdd != null)
 
-    val job = new Job(sc.hadoopConfiguration)
+    val job = new Job(sqlContext.sparkContext.hadoopConfiguration)
 
     val writeSupport =
       if (child.output.map(_.dataType).forall(_.isPrimitive)) {
@@ -204,7 +205,7 @@ case class InsertIntoParquetTable(
 
   override def output = child.output
 
-  override def otherCopyArgs = sc :: Nil
+  override def otherCopyArgs = sqlContext :: Nil
 
   /**
    * Stores the given Row RDD as a Hadoop file.
@@ -231,7 +232,7 @@ case class InsertIntoParquetTable(
     val wrappedConf = new SerializableWritable(job.getConfiguration)
     val formatter = new SimpleDateFormat("yyyyMMddHHmm")
     val jobtrackerID = formatter.format(new Date())
-    val stageId = sc.newRddId()
+    val stageId = sqlContext.sparkContext.newRddId()
 
     val taskIdOffset =
       if (overwrite) {
@@ -270,7 +271,7 @@ case class InsertIntoParquetTable(
     val jobTaskContext = newTaskAttemptContext(wrappedConf.value, jobAttemptId)
     val jobCommitter = jobFormat.getOutputCommitter(jobTaskContext)
     jobCommitter.setupJob(jobTaskContext)
-    sc.runJob(rdd, writeShard _)
+    sqlContext.sparkContext.runJob(rdd, writeShard _)
     jobCommitter.commitJob(jobTaskContext)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/ca5d8b59/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
index 7714eb1..2ca0c1c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
@@ -166,7 +166,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
     val scanner = new ParquetTableScan(
       ParquetTestData.testData.output,
       ParquetTestData.testData,
-      Seq())(TestSQLContext.sparkContext)
+      Seq())(TestSQLContext)
     val projected = scanner.pruneColumns(ParquetTypesConverter
       .convertToAttributes(MessageTypeParser
       .parseMessageType(ParquetTestData.subTestSchema)))
