Repository: spark Updated Branches: refs/heads/master 7634fe951 -> 4883a5087
[SPARK-12374][SPARK-12150][SQL] Adding logical/physical operators for Range Based on the suggestions from marmbrus, added logical/physical operators for Range to improve performance. Also added another API to resolve JIRA SPARK-12150. Could you take a look at my implementation, marmbrus? If it is not good, I can rework it. :) Thank you very much! Author: gatorsmile <gatorsm...@gmail.com> Closes #10335 from gatorsmile/rangeOperators. Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/4883a508 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/4883a508 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/4883a508 Branch: refs/heads/master Commit: 4883a5087d481d4de5d3beabbd709853de01399a Parents: 7634fe9 Author: gatorsmile <gatorsm...@gmail.com> Authored: Mon Dec 21 13:46:58 2015 -0800 Committer: Michael Armbrust <mich...@databricks.com> Committed: Mon Dec 21 13:46:58 2015 -0800 ---------------------------------------------------------------------- .../scala/org/apache/spark/SparkContext.scala | 2 +- .../catalyst/plans/logical/basicOperators.scala | 32 ++++++++++ .../scala/org/apache/spark/sql/SQLContext.scala | 23 +++++--- .../spark/sql/execution/SparkStrategies.scala | 2 + .../spark/sql/execution/basicOperators.scala | 62 ++++++++++++++++++++ .../org/apache/spark/sql/DataFrameSuite.scala | 5 ++ .../execution/ExchangeCoordinatorSuite.scala | 1 + 7 files changed, 119 insertions(+), 8 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/4883a508/core/src/main/scala/org/apache/spark/SparkContext.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 194ecc0..81a4d0a 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ 
b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -759,7 +759,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli val numElements: BigInt = { val safeStart = BigInt(start) val safeEnd = BigInt(end) - if ((safeEnd - safeStart) % step == 0 || safeEnd > safeStart ^ step > 0) { + if ((safeEnd - safeStart) % step == 0 || (safeEnd > safeStart) != (step > 0)) { (safeEnd - safeStart) / step } else { // the remainder has the same sign with range, could add 1 more http://git-wip-us.apache.org/repos/asf/spark/blob/4883a508/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala ---------------------------------------------------------------------- diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala index ec42b76..64ef4d7 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala @@ -210,6 +210,38 @@ case class Sort( override def output: Seq[Attribute] = child.output } +/** Factory for constructing new `Range` nodes. 
*/ +object Range { + def apply(start: Long, end: Long, step: Long, numSlices: Int): Range = { + val output = StructType(StructField("id", LongType, nullable = false) :: Nil).toAttributes + new Range(start, end, step, numSlices, output) + } +} + +case class Range( + start: Long, + end: Long, + step: Long, + numSlices: Int, + output: Seq[Attribute]) extends LeafNode { + require(step != 0, "step cannot be 0") + val numElements: BigInt = { + val safeStart = BigInt(start) + val safeEnd = BigInt(end) + if ((safeEnd - safeStart) % step == 0 || (safeEnd > safeStart) != (step > 0)) { + (safeEnd - safeStart) / step + } else { + // the remainder has the same sign with range, could add 1 more + (safeEnd - safeStart) / step + 1 + } + } + + override def statistics: Statistics = { + val sizeInBytes = LongType.defaultSize * numElements + Statistics( sizeInBytes = sizeInBytes ) + } +} + case class Aggregate( groupingExpressions: Seq[Expression], aggregateExpressions: Seq[NamedExpression], http://git-wip-us.apache.org/repos/asf/spark/blob/4883a508/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index db286ea..eadf5cb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -36,7 +36,7 @@ import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.errors.DialectException import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.optimizer.{DefaultOptimizer, Optimizer} -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Range} import org.apache.spark.sql.catalyst.rules.RuleExecutor import 
org.apache.spark.sql.catalyst.{InternalRow, ParserDialect, _} import org.apache.spark.sql.execution._ @@ -785,9 +785,20 @@ class SQLContext private[sql]( */ @Experimental def range(start: Long, end: Long): DataFrame = { - createDataFrame( - sparkContext.range(start, end).map(Row(_)), - StructType(StructField("id", LongType, nullable = false) :: Nil)) + range(start, end, step = 1, numPartitions = sparkContext.defaultParallelism) + } + + /** + * :: Experimental :: + * Creates a [[DataFrame]] with a single [[LongType]] column named `id`, containing elements + * in an range from `start` to `end` (exclusive) with an step value. + * + * @since 2.0.0 + * @group dataframe + */ + @Experimental + def range(start: Long, end: Long, step: Long): DataFrame = { + range(start, end, step, numPartitions = sparkContext.defaultParallelism) } /** @@ -801,9 +812,7 @@ class SQLContext private[sql]( */ @Experimental def range(start: Long, end: Long, step: Long, numPartitions: Int): DataFrame = { - createDataFrame( - sparkContext.range(start, end, step, numPartitions).map(Row(_)), - StructType(StructField("id", LongType, nullable = false) :: Nil)) + DataFrame(this, Range(start, end, step, numPartitions)) } /** http://git-wip-us.apache.org/repos/asf/spark/blob/4883a508/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index 688555c..183d9b6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -358,6 +358,8 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { generator, join = join, outer = outer, g.output, planLater(child)) :: Nil case logical.OneRowRelation => 
execution.PhysicalRDD(Nil, singleRowRdd, "OneRowRelation") :: Nil + case r @ logical.Range(start, end, step, numSlices, output) => + execution.Range(start, step, numSlices, r.numElements, output) :: Nil case logical.RepartitionByExpression(expressions, child, nPartitions) => execution.Exchange(HashPartitioning( expressions, nPartitions.getOrElse(numPartitions)), planLater(child)) :: Nil http://git-wip-us.apache.org/repos/asf/spark/blob/4883a508/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala index b3e4688..21325be 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala @@ -25,6 +25,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeRowJoiner import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.types.LongType import org.apache.spark.util.MutablePair import org.apache.spark.util.random.PoissonSampler import org.apache.spark.{HashPartitioner, SparkEnv} @@ -126,6 +127,67 @@ case class Sample( } } +case class Range( + start: Long, + step: Long, + numSlices: Int, + numElements: BigInt, + output: Seq[Attribute]) + extends LeafNode { + + override def outputsUnsafeRows: Boolean = true + + protected override def doExecute(): RDD[InternalRow] = { + sqlContext + .sparkContext + .parallelize(0 until numSlices, numSlices) + .mapPartitionsWithIndex((i, _) => { + val partitionStart = (i * numElements) / numSlices * step + start + val partitionEnd = (((i + 1) * numElements) / numSlices) * step + start + def getSafeMargin(bi: BigInt): Long = + if 
(bi.isValidLong) { + bi.toLong + } else if (bi > 0) { + Long.MaxValue + } else { + Long.MinValue + } + val safePartitionStart = getSafeMargin(partitionStart) + val safePartitionEnd = getSafeMargin(partitionEnd) + val rowSize = UnsafeRow.calculateBitSetWidthInBytes(1) + LongType.defaultSize + val unsafeRow = UnsafeRow.createFromByteArray(rowSize, 1) + + new Iterator[InternalRow] { + private[this] var number: Long = safePartitionStart + private[this] var overflow: Boolean = false + + override def hasNext = + if (!overflow) { + if (step > 0) { + number < safePartitionEnd + } else { + number > safePartitionEnd + } + } else false + + override def next() = { + val ret = number + number += step + if (number < ret ^ step < 0) { + // we have Long.MaxValue + Long.MaxValue < Long.MaxValue + // and Long.MinValue + Long.MinValue > Long.MinValue, so iff the step causes a step + // back, we are pretty sure that we have an overflow. + overflow = true + } + + unsafeRow.setLong(0, ret) + unsafeRow + } + } + }) + } +} + /** * Union two plans, without a distinct. This is UNION ALL in SQL. 
*/ http://git-wip-us.apache.org/repos/asf/spark/blob/4883a508/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 1a0f1b6..ad478b0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -769,6 +769,11 @@ class DataFrameSuite extends QueryTest with SharedSQLContext { val res11 = sqlContext.range(-1).select("id") assert(res11.count == 0) + + // using the default slice number + val res12 = sqlContext.range(3, 15, 3).select("id") + assert(res12.count == 4) + assert(res12.agg(sum("id")).as("sumid").collect() === Seq(Row(30))) } test("SPARK-8621: support empty string column name") { http://git-wip-us.apache.org/repos/asf/spark/blob/4883a508/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala ---------------------------------------------------------------------- diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala index 180050b..101cf50 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeCoordinatorSuite.scala @@ -260,6 +260,7 @@ class ExchangeCoordinatorSuite extends SparkFunSuite with BeforeAndAfterAll { .set("spark.driver.allowMultipleContexts", "true") .set(SQLConf.SHUFFLE_PARTITIONS.key, "5") .set(SQLConf.ADAPTIVE_EXECUTION_ENABLED.key, "true") + .set(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key, "-1") .set( SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key, targetNumPostShufflePartitions.toString) 
--------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org