Repository: spark
Updated Branches:
  refs/heads/master bcceab649 -> f5f8e84d9


[SPARK-22614] Dataset API: repartitionByRange(...)

## What changes were proposed in this pull request?

This PR introduces a way to explicitly range-partition a Dataset. So far, only 
round-robin and hash partitioning have been possible via `df.repartition(...)`, 
but sometimes range partitioning is desirable: e.g. when writing to disk, it can 
yield better compression without the cost of a global sort.
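
For illustration, here is a minimal usage sketch of the new API (the dataset, 
column name, and output path are hypothetical, not part of this patch):

```scala
import org.apache.spark.sql.SparkSession

// Hypothetical example: cluster rows with nearby `ts` values into the same
// output files for better compression, without globally sorting the data.
val spark = SparkSession.builder().appName("repartitionByRange-example").getOrCreate()
import spark.implicits._

val df = spark.range(0, 1000000).toDF("ts")

df.repartitionByRange(10, $"ts")       // "ascending nulls first" is assumed
  .write.mode("overwrite").parquet("/tmp/range_partitioned")

df.repartitionByRange(10, $"ts".desc)  // an explicit sort order also works
```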

The current implementation piggybacks on the existing `RepartitionByExpression` 
`LogicalPlan` and simply adds the following logic: if all of its expressions are 
of type `SortOrder`, it produces `RangePartitioning`; otherwise 
`HashPartitioning`. This was by far the least intrusive solution I could come 
up with.
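
In terms of user-visible behavior, the change boils down to the following 
(condensed from the `Dataset` changes and tests below; the sample data is 
made up):

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().getOrCreate()
import spark.implicits._

val df = Seq((0, 9), (1, 8), (2, 7)).toDF("a", "b")

df.repartition(2, $"a")                    // hash partitioning, unchanged
df.repartitionByRange(2, $"a")             // range partitioning; `a` ascending nulls first
df.repartitionByRange(2, $"a".desc, $"b")  // explicit and default sort orders can be mixed

// Passing a SortOrder to .repartition() now fails fast with a message
// pointing at repartitionByRange(...):
// df.repartition(2, $"a".asc)             // would throw IllegalArgumentException
```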

## How was this patch tested?
A unit test for the `RepartitionByExpression` changes, a test to ensure the 
behavior of the existing `.repartition()` is unchanged, and a few end-to-end 
tests in `DataFrameSuite`.

Author: Adrian Ionescu <adr...@databricks.com>

Closes #19828 from adrian-ionescu/repartitionByRange.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/f5f8e84d
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/f5f8e84d
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/f5f8e84d

Branch: refs/heads/master
Commit: f5f8e84d9d35751dad51490b6ae22931aa88db7b
Parents: bcceab6
Author: Adrian Ionescu <adr...@databricks.com>
Authored: Thu Nov 30 15:41:34 2017 -0800
Committer: gatorsmile <gatorsm...@gmail.com>
Committed: Thu Nov 30 15:41:34 2017 -0800

----------------------------------------------------------------------
 .../plans/logical/basicLogicalOperators.scala   | 20 +++++++
 .../sql/catalyst/analysis/AnalysisSuite.scala   | 26 +++++++++
 .../scala/org/apache/spark/sql/Dataset.scala    | 57 ++++++++++++++++++--
 .../spark/sql/execution/SparkStrategies.scala   |  5 +-
 .../org/apache/spark/sql/DataFrameSuite.scala   | 57 ++++++++++++++++++++
 5 files changed, 157 insertions(+), 8 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/f5f8e84d/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index c2750c3..93de7c1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical.statsEstimation._
+import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning}
 import org.apache.spark.sql.types._
 import org.apache.spark.util.Utils
 import org.apache.spark.util.random.RandomSampler
@@ -838,6 +839,25 @@ case class RepartitionByExpression(
 
   require(numPartitions > 0, s"Number of partitions ($numPartitions) must be positive.")
 
+  val partitioning: Partitioning = {
+    val (sortOrder, nonSortOrder) = partitionExpressions.partition(_.isInstanceOf[SortOrder])
+
+    require(sortOrder.isEmpty || nonSortOrder.isEmpty,
+      s"${getClass.getSimpleName} expects that either all its 
`partitionExpressions` are of type " +
+        "`SortOrder`, which means `RangePartitioning`, or none of them are 
`SortOrder`, which " +
+        "means `HashPartitioning`. In this case we have:" +
+      s"""
+         |SortOrder: ${sortOrder}
+         |NonSortOrder: ${nonSortOrder}
+       """.stripMargin)
+
+    if (sortOrder.nonEmpty) {
+      RangePartitioning(sortOrder.map(_.asInstanceOf[SortOrder]), numPartitions)
+    } else {
+      HashPartitioning(nonSortOrder, numPartitions)
+    }
+  }
+
   override def maxRows: Option[Long] = child.maxRows
   override def shuffle: Boolean = true
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/f5f8e84d/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index e56a5d6..0e2e706 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.{Cross, Inner}
 import org.apache.spark.sql.catalyst.plans.logical._
+import org.apache.spark.sql.catalyst.plans.physical.{HashPartitioning, Partitioning, RangePartitioning}
 import org.apache.spark.sql.types._
 
 
@@ -514,4 +515,29 @@ class AnalysisSuite extends AnalysisTest with Matchers {
       Seq("Number of column aliases does not match number of columns. " +
         "Number of column aliases: 5; number of columns: 4."))
   }
+
+  test("SPARK-22614 RepartitionByExpression partitioning") {
+    def checkPartitioning[T <: Partitioning](numPartitions: Int, exprs: Expression*): Unit = {
+      val partitioning = RepartitionByExpression(exprs, testRelation2, numPartitions).partitioning
+      assert(partitioning.isInstanceOf[T])
+    }
+
+    checkPartitioning[HashPartitioning](numPartitions = 10, exprs = Literal(20))
+    checkPartitioning[HashPartitioning](numPartitions = 10, exprs = 'a.attr, 'b.attr)
+
+    checkPartitioning[RangePartitioning](numPartitions = 10,
+      exprs = SortOrder(Literal(10), Ascending))
+    checkPartitioning[RangePartitioning](numPartitions = 10,
+      exprs = SortOrder('a.attr, Ascending), SortOrder('b.attr, Descending))
+
+    intercept[IllegalArgumentException] {
+      checkPartitioning(numPartitions = 0, exprs = Literal(20))
+    }
+    intercept[IllegalArgumentException] {
+      checkPartitioning(numPartitions = -1, exprs = Literal(20))
+    }
+    intercept[IllegalArgumentException] {
+      checkPartitioning(numPartitions = 10, exprs = SortOrder('a.attr, Ascending), 'b.attr)
+    }
+  }
 }

http://git-wip-us.apache.org/repos/asf/spark/blob/f5f8e84d/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 1620ab3..167c9d0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -2732,8 +2732,18 @@ class Dataset[T] private[sql](
    * @since 2.0.0
    */
   @scala.annotation.varargs
-  def repartition(numPartitions: Int, partitionExprs: Column*): Dataset[T] = withTypedPlan {
-    RepartitionByExpression(partitionExprs.map(_.expr), logicalPlan, numPartitions)
+  def repartition(numPartitions: Int, partitionExprs: Column*): Dataset[T] = {
+    // The underlying `LogicalPlan` operator special-cases all-`SortOrder` arguments.
+    // However, we don't want to complicate the semantics of this API method.
+    // Instead, let's give users a friendly error message, pointing them to the new method.
+    val sortOrders = partitionExprs.filter(_.expr.isInstanceOf[SortOrder])
+    if (sortOrders.nonEmpty) throw new IllegalArgumentException(
+      s"""Invalid partitionExprs specified: $sortOrders
+         |For range partitioning use repartitionByRange(...) instead.
+       """.stripMargin)
+    withTypedPlan {
+      RepartitionByExpression(partitionExprs.map(_.expr), logicalPlan, numPartitions)
+    }
   }
 
   /**
@@ -2747,9 +2757,46 @@ class Dataset[T] private[sql](
    * @since 2.0.0
    */
   @scala.annotation.varargs
-  def repartition(partitionExprs: Column*): Dataset[T] = withTypedPlan {
-    RepartitionByExpression(
-      partitionExprs.map(_.expr), logicalPlan, sparkSession.sessionState.conf.numShufflePartitions)
+  def repartition(partitionExprs: Column*): Dataset[T] = {
+    repartition(sparkSession.sessionState.conf.numShufflePartitions, partitionExprs: _*)
+  }
+
+  /**
+   * Returns a new Dataset partitioned by the given partitioning expressions into
+   * `numPartitions`. The resulting Dataset is range partitioned.
+   *
+   * At least one partition-by expression must be specified.
+   * When no explicit sort order is specified, "ascending nulls first" is assumed.
+   *
+   * @group typedrel
+   * @since 2.3.0
+   */
+  @scala.annotation.varargs
+  def repartitionByRange(numPartitions: Int, partitionExprs: Column*): Dataset[T] = {
+    require(partitionExprs.nonEmpty, "At least one partition-by expression must be specified.")
+    val sortOrder: Seq[SortOrder] = partitionExprs.map(_.expr match {
+      case expr: SortOrder => expr
+      case expr: Expression => SortOrder(expr, Ascending)
+    })
+    withTypedPlan {
+      RepartitionByExpression(sortOrder, logicalPlan, numPartitions)
+    }
+  }
+
+  /**
+   * Returns a new Dataset partitioned by the given partitioning expressions, using
+   * `spark.sql.shuffle.partitions` as number of partitions.
+   * The resulting Dataset is range partitioned.
+   *
+   * At least one partition-by expression must be specified.
+   * When no explicit sort order is specified, "ascending nulls first" is assumed.
+   *
+   * @group typedrel
+   * @since 2.3.0
+   */
+  @scala.annotation.varargs
+  def repartitionByRange(partitionExprs: Column*): Dataset[T] = {
+    repartitionByRange(sparkSession.sessionState.conf.numShufflePartitions, partitionExprs: _*)
   }
 
   /**

http://git-wip-us.apache.org/repos/asf/spark/blob/f5f8e84d/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 1fe3cb1..9e713cd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -482,9 +482,8 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         execution.RDDScanExec(Nil, singleRowRdd, "OneRowRelation") :: Nil
       case r: logical.Range =>
         execution.RangeExec(r) :: Nil
-      case logical.RepartitionByExpression(expressions, child, numPartitions) =>
-        exchange.ShuffleExchangeExec(HashPartitioning(
-          expressions, numPartitions), planLater(child)) :: Nil
+      case r: logical.RepartitionByExpression =>
+        exchange.ShuffleExchangeExec(r.partitioning, planLater(r.child)) :: Nil
       case ExternalRDD(outputObjAttr, rdd) => ExternalRDDScanExec(outputObjAttr, rdd) :: Nil
       case r: LogicalRDD =>
         RDDScanExec(r.output, r.rdd, "ExistingRDD", r.outputPartitioning, r.outputOrdering) :: Nil

http://git-wip-us.apache.org/repos/asf/spark/blob/f5f8e84d/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 72a5cc9..5e4c1a6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -358,6 +358,63 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
       testData.select('key).collect().toSeq)
   }
 
+  test("repartition with SortOrder") {
+    // passing SortOrder expressions to .repartition() should result in an informative error
+
+    def checkSortOrderErrorMsg[T](data: => Dataset[T]): Unit = {
+      val ex = intercept[IllegalArgumentException](data)
+      assert(ex.getMessage.contains("repartitionByRange"))
+    }
+
+    checkSortOrderErrorMsg {
+      Seq(0).toDF("a").repartition(2, $"a".asc)
+    }
+
+    checkSortOrderErrorMsg {
+      Seq((0, 0)).toDF("a", "b").repartition(2, $"a".asc, $"b")
+    }
+  }
+
+  test("repartitionByRange") {
+    val data1d = Random.shuffle(0.to(9))
+    val data2d = data1d.map(i => (i, data1d.size - i))
+
+    checkAnswer(
+      data1d.toDF("val").repartitionByRange(data1d.size, $"val".asc)
+        .select(spark_partition_id().as("id"), $"val"),
+      data1d.map(i => Row(i, i)))
+
+    checkAnswer(
+      data1d.toDF("val").repartitionByRange(data1d.size, $"val".desc)
+        .select(spark_partition_id().as("id"), $"val"),
+      data1d.map(i => Row(i, data1d.size - 1 - i)))
+
+    checkAnswer(
+      data1d.toDF("val").repartitionByRange(data1d.size, lit(42))
+        .select(spark_partition_id().as("id"), $"val"),
+      data1d.map(i => Row(0, i)))
+
+    checkAnswer(
+      data1d.toDF("val").repartitionByRange(data1d.size, lit(null), 
$"val".asc, rand())
+        .select(spark_partition_id().as("id"), $"val"),
+      data1d.map(i => Row(i, i)))
+
+    // .repartitionByRange() assumes .asc by default if no explicit sort order is specified
+    checkAnswer(
+      data2d.toDF("a", "b").repartitionByRange(data2d.size, $"a".desc, $"b")
+        .select(spark_partition_id().as("id"), $"a", $"b"),
+      data2d.toDF("a", "b").repartitionByRange(data2d.size, $"a".desc, 
$"b".asc)
+        .select(spark_partition_id().as("id"), $"a", $"b"))
+
+    // at least one partition-by expression must be specified
+    intercept[IllegalArgumentException] {
+      data1d.toDF("val").repartitionByRange(data1d.size)
+    }
+    intercept[IllegalArgumentException] {
+      data1d.toDF("val").repartitionByRange(data1d.size, Seq.empty: _*)
+    }
+  }
+
   test("coalesce") {
     intercept[IllegalArgumentException] {
       testData.select('key).coalesce(0)

