Repository: spark
Updated Branches:
  refs/heads/master 648553d48 -> ca5d8b590
[SQL] Pass SQLContext instead of SparkContext into physical operators.

This makes it easier to use config options in operators.

Author: Reynold Xin <r...@apache.org>

Closes #1164 from rxin/sqlcontext and squashes the following commits:

797b2fd [Reynold Xin] Pass SQLContext instead of SparkContext into physical operators.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/ca5d8b59
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/ca5d8b59
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/ca5d8b59

Branch: refs/heads/master
Commit: ca5d8b5904dc6dd5b691af506d3a842e508b3673
Parents: 648553d
Author: Reynold Xin <r...@apache.org>
Authored: Fri Jun 20 22:49:48 2014 -0700
Committer: Reynold Xin <r...@apache.org>
Committed: Fri Jun 20 22:49:48 2014 -0700

----------------------------------------------------------------------
 .../scala/org/apache/spark/sql/SQLContext.scala |  4 +++-
 .../apache/spark/sql/execution/Aggregate.scala  |  5 +++--
 .../spark/sql/execution/SparkStrategies.scala   | 22 ++++++++++----------
 .../spark/sql/execution/basicOperators.scala    | 20 ++++++++++--------
 .../org/apache/spark/sql/execution/joins.scala  | 21 ++++++++++---------
 .../sql/parquet/ParquetTableOperations.scala    | 21 ++++++++++---------
 .../spark/sql/parquet/ParquetQuerySuite.scala   |  2 +-
 7 files changed, 51 insertions(+), 44 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/ca5d8b59/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index ab376e5..c60af28 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -221,7 +221,9 @@ class SQLContext(@transient val sparkContext: SparkContext)
   }
 
   protected[sql] class SparkPlanner extends SparkStrategies {
-    val sparkContext = self.sparkContext
+    val sparkContext: SparkContext = self.sparkContext
+
+    val sqlContext: SQLContext = self
 
     def numPartitions = self.numShufflePartitions
 

http://git-wip-us.apache.org/repos/asf/spark/blob/ca5d8b59/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
index 34d88fe..d85d2d7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
@@ -24,6 +24,7 @@ import org.apache.spark.SparkContext
 import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.SQLContext
 
 /**
  * :: DeveloperApi ::
@@ -41,7 +42,7 @@ case class Aggregate(
     partial: Boolean,
     groupingExpressions: Seq[Expression],
     aggregateExpressions: Seq[NamedExpression],
-    child: SparkPlan)(@transient sc: SparkContext)
+    child: SparkPlan)(@transient sqlContext: SQLContext)
   extends UnaryNode with NoBind {
 
   override def requiredChildDistribution =
@@ -55,7 +56,7 @@ case class Aggregate(
     }
   }
 
-  override def otherCopyArgs = sc :: Nil
+  override def otherCopyArgs = sqlContext :: Nil
 
   // HACK: Generators don't correctly preserve their output through serializations so we grab
   // out child's output attributes statically here.

http://git-wip-us.apache.org/repos/asf/spark/blob/ca5d8b59/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 4694f25..bd8ae4c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -40,7 +40,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       // no predicate can be evaluated by matching hash keys
       case logical.Join(left, right, LeftSemi, condition) =>
         execution.LeftSemiJoinBNL(
-          planLater(left), planLater(right), condition)(sparkContext) :: Nil
+          planLater(left), planLater(right), condition)(sqlContext) :: Nil
       case _ => Nil
     }
   }
@@ -103,7 +103,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
             partial = true,
             groupingExpressions,
             partialComputation,
-            planLater(child))(sparkContext))(sparkContext) :: Nil
+            planLater(child))(sqlContext))(sqlContext) :: Nil
       } else {
         Nil
       }
@@ -115,7 +115,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
       case logical.Join(left, right, joinType, condition) =>
         execution.BroadcastNestedLoopJoin(
-          planLater(left), planLater(right), joinType, condition)(sparkContext) :: Nil
+          planLater(left), planLater(right), joinType, condition)(sqlContext) :: Nil
       case _ => Nil
     }
   }
@@ -143,7 +143,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
   object TakeOrdered extends Strategy {
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
       case logical.Limit(IntegerLiteral(limit), logical.Sort(order, child)) =>
-        execution.TakeOrdered(limit, order, planLater(child))(sparkContext) :: Nil
+        execution.TakeOrdered(limit, order, planLater(child))(sqlContext) :: Nil
       case _ => Nil
     }
   }
@@ -155,9 +155,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         val relation =
           ParquetRelation.create(path, child, sparkContext.hadoopConfiguration)
         // Note: overwrite=false because otherwise the metadata we just created will be deleted
-        InsertIntoParquetTable(relation, planLater(child), overwrite=false)(sparkContext) :: Nil
+        InsertIntoParquetTable(relation, planLater(child), overwrite=false)(sqlContext) :: Nil
       case logical.InsertIntoTable(table: ParquetRelation, partition, child, overwrite) =>
-        InsertIntoParquetTable(table, planLater(child), overwrite)(sparkContext) :: Nil
+        InsertIntoParquetTable(table, planLater(child), overwrite)(sqlContext) :: Nil
       case PhysicalOperation(projectList, filters: Seq[Expression], relation: ParquetRelation) =>
         val prunePushedDownFilters =
           if (sparkContext.conf.getBoolean(ParquetFilters.PARQUET_FILTER_PUSHDOWN_ENABLED, true)) {
@@ -186,7 +186,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
           projectList,
           filters,
           prunePushedDownFilters,
-          ParquetTableScan(_, relation, filters)(sparkContext)) :: Nil
+          ParquetTableScan(_, relation, filters)(sqlContext)) :: Nil
 
       case _ => Nil
     }
@@ -211,7 +211,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
     def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
       case logical.Distinct(child) =>
         execution.Aggregate(
-          partial = false, child.output, child.output, planLater(child))(sparkContext) :: Nil
+          partial = false, child.output, child.output, planLater(child))(sqlContext) :: Nil
       case logical.Sort(sortExprs, child) =>
         // This sort is a global sort. Its requiredDistribution will be an OrderedDistribution.
         execution.Sort(sortExprs, global = true, planLater(child)):: Nil
@@ -224,7 +224,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       case logical.Filter(condition, child) =>
         execution.Filter(condition, planLater(child)) :: Nil
       case logical.Aggregate(group, agg, child) =>
-        execution.Aggregate(partial = false, group, agg, planLater(child))(sparkContext) :: Nil
+        execution.Aggregate(partial = false, group, agg, planLater(child))(sqlContext) :: Nil
       case logical.Sample(fraction, withReplacement, seed, child) =>
         execution.Sample(fraction, withReplacement, seed, planLater(child)) :: Nil
       case logical.LocalRelation(output, data) =>
@@ -233,9 +233,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
             new GenericRow(r.productIterator.map(convertToCatalyst).toArray): Row))
         execution.ExistingRdd(output, dataAsRdd) :: Nil
       case logical.Limit(IntegerLiteral(limit), child) =>
-        execution.Limit(limit, planLater(child))(sparkContext) :: Nil
+        execution.Limit(limit, planLater(child))(sqlContext) :: Nil
       case Unions(unionChildren) =>
-        execution.Union(unionChildren.map(planLater))(sparkContext) :: Nil
+        execution.Union(unionChildren.map(planLater))(sqlContext) :: Nil
       case logical.Generate(generator, join, outer, _, child) =>
         execution.Generate(generator, join = join, outer = outer, planLater(child)) :: Nil
       case logical.NoRelation =>

http://git-wip-us.apache.org/repos/asf/spark/blob/ca5d8b59/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 8969794..18f4a58 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -20,8 +20,9 @@ package org.apache.spark.sql.execution
 import scala.reflect.runtime.universe.TypeTag
 
 import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.{HashPartitioner, SparkConf, SparkContext}
+import org.apache.spark.{HashPartitioner, SparkConf}
 import org.apache.spark.rdd.{RDD, ShuffledRDD}
+import org.apache.spark.sql.SQLContext
 import org.apache.spark.sql.catalyst.ScalaReflection
 import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions._
@@ -70,12 +71,12 @@ case class Sample(fraction: Double, withReplacement: Boolean, seed: Long, child:
  * :: DeveloperApi ::
  */
 @DeveloperApi
-case class Union(children: Seq[SparkPlan])(@transient sc: SparkContext) extends SparkPlan {
+case class Union(children: Seq[SparkPlan])(@transient sqlContext: SQLContext) extends SparkPlan {
   // TODO: attributes output by union should be distinct for nullability purposes
   override def output = children.head.output
-  override def execute() = sc.union(children.map(_.execute()))
+  override def execute() = sqlContext.sparkContext.union(children.map(_.execute()))
 
-  override def otherCopyArgs = sc :: Nil
+  override def otherCopyArgs = sqlContext :: Nil
 }
 
 /**
@@ -87,11 +88,12 @@ case class Union(children: Seq[SparkPlan])(@transient sc: SparkContext) extends
  * data to a single partition to compute the global limit.
  */
 @DeveloperApi
-case class Limit(limit: Int, child: SparkPlan)(@transient sc: SparkContext) extends UnaryNode {
+case class Limit(limit: Int, child: SparkPlan)(@transient sqlContext: SQLContext)
+  extends UnaryNode {
   // TODO: Implement a partition local limit, and use a strategy to generate the proper limit plan:
   // partition local limit -> exchange into one partition -> partition local limit again
 
-  override def otherCopyArgs = sc :: Nil
+  override def otherCopyArgs = sqlContext :: Nil
 
   override def output = child.output
 
@@ -117,8 +119,8 @@ case class Limit(limit: Int, child: SparkPlan)(@transient sc: SparkContext) exte
  */
 @DeveloperApi
 case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan)
-                      (@transient sc: SparkContext) extends UnaryNode {
-  override def otherCopyArgs = sc :: Nil
+                      (@transient sqlContext: SQLContext) extends UnaryNode {
+  override def otherCopyArgs = sqlContext :: Nil
 
   override def output = child.output
 
@@ -129,7 +131,7 @@ case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan)
 
   // TODO: Terminal split should be implemented differently from non-terminal split.
   // TODO: Pick num splits based on |limit|.
-  override def execute() = sc.makeRDD(executeCollect(), 1)
+  override def execute() = sqlContext.sparkContext.makeRDD(executeCollect(), 1)
 }
 
 /**

http://git-wip-us.apache.org/repos/asf/spark/blob/ca5d8b59/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
index 8d7a5ba..84bdde3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
@@ -19,9 +19,8 @@ package org.apache.spark.sql.execution
 
 import scala.collection.mutable.{ArrayBuffer, BitSet}
 
-import org.apache.spark.SparkContext
-
 import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.sql.SQLContext
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Partitioning}
@@ -200,13 +199,13 @@ case class LeftSemiJoinHash(
 @DeveloperApi
 case class LeftSemiJoinBNL(
     streamed: SparkPlan, broadcast: SparkPlan, condition: Option[Expression])
-    (@transient sc: SparkContext)
+    (@transient sqlContext: SQLContext)
   extends BinaryNode {
   // TODO: Override requiredChildDistribution.
 
   override def outputPartitioning: Partitioning = streamed.outputPartitioning
 
-  override def otherCopyArgs = sc :: Nil
+  override def otherCopyArgs = sqlContext :: Nil
 
   def output = left.output
 
@@ -223,7 +222,8 @@ case class LeftSemiJoinBNL(
 
 
   def execute() = {
-    val broadcastedRelation = sc.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq)
+    val broadcastedRelation =
+      sqlContext.sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq)
 
     streamed.execute().mapPartitions { streamedIter =>
       val joinedRow = new JoinedRow
@@ -263,13 +263,13 @@ case class CartesianProduct(left: SparkPlan, right: SparkPlan) extends BinaryNod
 @DeveloperApi
 case class BroadcastNestedLoopJoin(
     streamed: SparkPlan, broadcast: SparkPlan, joinType: JoinType, condition: Option[Expression])
-    (@transient sc: SparkContext)
+    (@transient sqlContext: SQLContext)
   extends BinaryNode {
   // TODO: Override requiredChildDistribution.
 
   override def outputPartitioning: Partitioning = streamed.outputPartitioning
 
-  override def otherCopyArgs = sc :: Nil
+  override def otherCopyArgs = sqlContext :: Nil
 
   def output = left.output ++ right.output
 
@@ -286,7 +286,8 @@ case class BroadcastNestedLoopJoin(
 
 
   def execute() = {
-    val broadcastedRelation = sc.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq)
+    val broadcastedRelation =
+      sqlContext.sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq)
 
     val streamedPlusMatches = streamed.execute().mapPartitions { streamedIter =>
       val matchedRows = new ArrayBuffer[Row]
@@ -337,7 +338,7 @@ case class BroadcastNestedLoopJoin(
     }
 
     // TODO: Breaks lineage.
-    sc.union(
-      streamedPlusMatches.flatMap(_._1), sc.makeRDD(rightOuterMatches))
+    sqlContext.sparkContext.union(
+      streamedPlusMatches.flatMap(_._1), sqlContext.sparkContext.makeRDD(rightOuterMatches))
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/ca5d8b59/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
index 624f2e2..ade823b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetTableOperations.scala
@@ -33,10 +33,10 @@ import parquet.hadoop.util.ContextUtil
 import parquet.io.InvalidRecordException
 import parquet.schema.MessageType
 
-import org.apache.spark.{Logging, SerializableWritable, SparkContext, TaskContext}
+import org.apache.spark.{Logging, SerializableWritable, TaskContext}
 import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.SQLContext
 import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, Row}
-import org.apache.spark.sql.catalyst.types.StructType
 import org.apache.spark.sql.execution.{LeafNode, SparkPlan, UnaryNode}
 
 /**
@@ -49,10 +49,11 @@ case class ParquetTableScan(
     output: Seq[Attribute],
     relation: ParquetRelation,
     columnPruningPred: Seq[Expression])(
-    @transient val sc: SparkContext)
+    @transient val sqlContext: SQLContext)
   extends LeafNode {
 
   override def execute(): RDD[Row] = {
+    val sc = sqlContext.sparkContext
     val job = new Job(sc.hadoopConfiguration)
     ParquetInputFormat.setReadSupportClass(
       job,
@@ -93,7 +94,7 @@ case class ParquetTableScan(
       .filter(_ != null) // Parquet's record filters may produce null values
   }
 
-  override def otherCopyArgs = sc :: Nil
+  override def otherCopyArgs = sqlContext :: Nil
 
   /**
    * Applies a (candidate) projection.
@@ -104,7 +105,7 @@ case class ParquetTableScan(
   def pruneColumns(prunedAttributes: Seq[Attribute]): ParquetTableScan = {
     val success = validateProjection(prunedAttributes)
     if (success) {
-      ParquetTableScan(prunedAttributes, relation, columnPruningPred)(sc)
+      ParquetTableScan(prunedAttributes, relation, columnPruningPred)(sqlContext)
     } else {
       sys.error("Warning: Could not validate Parquet schema projection in pruneColumns")
       this
@@ -152,7 +153,7 @@ case class InsertIntoParquetTable(
     relation: ParquetRelation,
     child: SparkPlan,
     overwrite: Boolean = false)(
-    @transient val sc: SparkContext)
+    @transient val sqlContext: SQLContext)
   extends UnaryNode with SparkHadoopMapReduceUtil {
 
   /**
@@ -168,7 +169,7 @@ case class InsertIntoParquetTable(
     val childRdd = child.execute()
     assert(childRdd != null)
 
-    val job = new Job(sc.hadoopConfiguration)
+    val job = new Job(sqlContext.sparkContext.hadoopConfiguration)
 
     val writeSupport =
       if (child.output.map(_.dataType).forall(_.isPrimitive)) {
@@ -204,7 +205,7 @@ case class InsertIntoParquetTable(
 
   override def output = child.output
 
-  override def otherCopyArgs = sc :: Nil
+  override def otherCopyArgs = sqlContext :: Nil
 
   /**
    * Stores the given Row RDD as a Hadoop file.
@@ -231,7 +232,7 @@ case class InsertIntoParquetTable(
     val wrappedConf = new SerializableWritable(job.getConfiguration)
     val formatter = new SimpleDateFormat("yyyyMMddHHmm")
     val jobtrackerID = formatter.format(new Date())
-    val stageId = sc.newRddId()
+    val stageId = sqlContext.sparkContext.newRddId()
 
     val taskIdOffset =
       if (overwrite) {
@@ -270,7 +271,7 @@ case class InsertIntoParquetTable(
     val jobTaskContext = newTaskAttemptContext(wrappedConf.value, jobAttemptId)
     val jobCommitter = jobFormat.getOutputCommitter(jobTaskContext)
     jobCommitter.setupJob(jobTaskContext)
-    sc.runJob(rdd, writeShard _)
+    sqlContext.sparkContext.runJob(rdd, writeShard _)
     jobCommitter.commitJob(jobTaskContext)
   }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/ca5d8b59/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
index 7714eb1..2ca0c1c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/parquet/ParquetQuerySuite.scala
@@ -166,7 +166,7 @@ class ParquetQuerySuite extends QueryTest with FunSuiteLike with BeforeAndAfterA
     val scanner = new ParquetTableScan(
       ParquetTestData.testData.output,
       ParquetTestData.testData,
-      Seq())(TestSQLContext.sparkContext)
+      Seq())(TestSQLContext)
     val projected = scanner.pruneColumns(ParquetTypesConverter
       .convertToAttributes(MessageTypeParser
       .parseMessageType(ParquetTestData.subTestSchema)))
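
The mechanical pattern the patch applies to every operator is easier to see in one place. Below is a minimal sketch of the convention, not code from the patch: `DebugLimit` is a hypothetical operator invented for illustration, assumed to live in `org.apache.spark.sql.execution` alongside the operators changed above.

    package org.apache.spark.sql.execution

    import org.apache.spark.sql.SQLContext

    // Hypothetical operator following the convention from this commit: the
    // SQLContext arrives as a second, curried constructor argument, marked
    // @transient because plan nodes may be captured in closures shipped to
    // executors, while the context is only used on the driver.
    case class DebugLimit(limit: Int, child: SparkPlan)(@transient sqlContext: SQLContext)
      extends UnaryNode {

      override def output = child.output

      // Curried arguments are not part of the case class's productIterator,
      // so copies made during tree transformations would silently drop
      // sqlContext; returning it here lets makeCopy re-supply it.
      override def otherCopyArgs = sqlContext :: Nil

      // The SparkContext stays reachable one hop away, and config options on
      // the SQLContext (the motivation for this change) are reachable too.
      override def execute() =
        sqlContext.sparkContext.makeRDD(child.execute().take(limit), 1)
    }

TreeNode.makeCopy appends otherCopyArgs to the primary constructor arguments when it reflectively instantiates a copied node, which is why every operator in this diff that takes the context also overrides otherCopyArgs.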