Repository: spark
Updated Branches:
  refs/heads/branch-2.3 566ef93a6 -> 7241556d8


[SPARK-22389][SQL] data source v2 partitioning reporting interface

## What changes were proposed in this pull request?

A new interface that allows a data source to report its partitioning and avoid a 
shuffle on the Spark side.

The design closely follows the internal distribution/partitioning framework: 
Spark defines a `Distribution` interface and several concrete implementations, 
and asks the data source to report a `Partitioning`; the `Partitioning` tells 
Spark whether it can satisfy a given `Distribution`.
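
As a rough sketch of the data-source side of this contract (mirroring the `PartitionAwareDataSource` added in the new tests; the class name `ByColumnAPartitioning` and the column `a` are illustrative, not part of the patch), a source whose read tasks already cluster records by column `a` can report that as follows, letting Spark skip the exchange for a `GROUP BY a`:

```scala
import org.apache.spark.sql.sources.v2.reader.{ClusteredDistribution, Distribution, Partitioning}

// Illustrative only: assumes the reader emits its ReadTasks such that all records
// sharing the same value of column `a` end up in the same task.
class ByColumnAPartitioning extends Partitioning {
  // Must match the number of ReadTasks returned by the reader.
  override def numPartitions(): Int = 2

  // Spark passes in the Distribution required by an operator (e.g. an aggregation).
  // Return true only for distributions this source actually provides; unrecognized
  // Distribution implementations must get false so Spark falls back to a shuffle.
  override def satisfy(distribution: Distribution): Boolean = distribution match {
    case c: ClusteredDistribution => c.clusteredColumns.contains("a")
    case _ => false
  }
}
```

A reader exposes this by also implementing `SupportsReportPartitioning` and returning such an object from `outputPartitioning()`.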

## How was this patch tested?

New tests in `DataSourceV2Suite`, including a Java test data source (`JavaPartitionAwareDataSource`).

Author: Wenchen Fan <wenc...@databricks.com>

Closes #20201 from cloud-fan/partition-reporting.

(cherry picked from commit 51eb750263dd710434ddb60311571fa3dcec66eb)
Signed-off-by: gatorsmile <gatorsm...@gmail.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/7241556d
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/7241556d
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/7241556d

Branch: refs/heads/branch-2.3
Commit: 7241556d8b550e22eed2341287812ea373dc1cb2
Parents: 566ef93
Author: Wenchen Fan <wenc...@databricks.com>
Authored: Mon Jan 22 15:21:09 2018 -0800
Committer: gatorsmile <gatorsm...@gmail.com>
Committed: Mon Jan 22 15:21:19 2018 -0800

----------------------------------------------------------------------
 .../catalyst/plans/physical/partitioning.scala  |   2 +-
 .../v2/reader/ClusteredDistribution.java        |  38 +++++++
 .../sql/sources/v2/reader/Distribution.java     |  39 +++++++
 .../sql/sources/v2/reader/Partitioning.java     |  46 ++++++++
 .../v2/reader/SupportsReportPartitioning.java   |  33 ++++++
 .../datasources/v2/DataSourcePartitioning.scala |  56 ++++++++++
 .../datasources/v2/DataSourceV2ScanExec.scala   |   9 ++
 .../v2/JavaPartitionAwareDataSource.java        | 110 +++++++++++++++++++
 .../sql/sources/v2/DataSourceV2Suite.scala      |  79 +++++++++++++
 9 files changed, 411 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/7241556d/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
----------------------------------------------------------------------
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index 0189bd7..4d9a992 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -153,7 +153,7 @@ case class BroadcastDistribution(mode: BroadcastMode) extends Distribution {
  *   1. number of partitions.
  *   2. if it can satisfy a given distribution.
  */
-sealed trait Partitioning {
+trait Partitioning {
   /** Returns the number of partitions that the data is split across */
   val numPartitions: Int
 

http://git-wip-us.apache.org/repos/asf/spark/blob/7241556d/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ClusteredDistribution.java
----------------------------------------------------------------------
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ClusteredDistribution.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ClusteredDistribution.java
new file mode 100644
index 0000000..7346500
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/ClusteredDistribution.java
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources.v2.reader;
+
+import org.apache.spark.annotation.InterfaceStability;
+
+/**
+ * A concrete implementation of {@link Distribution}. Represents a 
distribution where records that
+ * share the same values for the {@link #clusteredColumns} will be produced by 
the same
+ * {@link ReadTask}.
+ */
+@InterfaceStability.Evolving
+public class ClusteredDistribution implements Distribution {
+
+  /**
+   * The names of the clustered columns. Note that they are order insensitive.
+   */
+  public final String[] clusteredColumns;
+
+  public ClusteredDistribution(String[] clusteredColumns) {
+    this.clusteredColumns = clusteredColumns;
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/7241556d/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Distribution.java
----------------------------------------------------------------------
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Distribution.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Distribution.java
new file mode 100644
index 0000000..a6201a2
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Distribution.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources.v2.reader;
+
+import org.apache.spark.annotation.InterfaceStability;
+
+/**
+ * An interface to represent a data distribution requirement, which specifies how the records
+ * should be distributed among the {@link ReadTask}s that are returned by
+ * {@link DataSourceV2Reader#createReadTasks()}. Note that this interface has nothing to do with
+ * the data ordering inside one partition (the output records of a single {@link ReadTask}).
+ *
+ * The instance of this interface is created and provided by Spark, then consumed by
+ * {@link Partitioning#satisfy(Distribution)}. This means data source developers don't need to
+ * implement this interface, but do need to handle as many concrete implementations of this
+ * interface as possible in {@link Partitioning#satisfy(Distribution)}.
+ *
+ * Concrete implementations until now:
+ * <ul>
+ *   <li>{@link ClusteredDistribution}</li>
+ * </ul>
+ */
+@InterfaceStability.Evolving
+public interface Distribution {}

http://git-wip-us.apache.org/repos/asf/spark/blob/7241556d/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Partitioning.java
----------------------------------------------------------------------
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Partitioning.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Partitioning.java
new file mode 100644
index 0000000..199e45d
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/Partitioning.java
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources.v2.reader;
+
+import org.apache.spark.annotation.InterfaceStability;
+
+/**
+ * An interface to represent the output data partitioning for a data source, which is returned by
+ * {@link SupportsReportPartitioning#outputPartitioning()}. Note that this should work like a
+ * snapshot. Once created, it should be deterministic and always report the same number of
+ * partitions and the same "satisfy" result for a certain distribution.
+ */
+@InterfaceStability.Evolving
+public interface Partitioning {
+
+  /**
+   * Returns the number of partitions (i.e., {@link ReadTask}s) the data source outputs.
+   */
+  int numPartitions();
+
+  /**
+   * Returns true if this partitioning can satisfy the given distribution, which means Spark does
+   * not need to shuffle the output data of this data source for certain operations.
+   *
+   * Note that Spark may add new concrete implementations of {@link Distribution} in new releases.
+   * This method should be aware of it and always return false for unrecognized distributions.
+   * It's recommended to check each new Spark release and support new distributions if possible,
+   * to avoid a shuffle on the Spark side in more cases.
+   */
+  boolean satisfy(Distribution distribution);
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/7241556d/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java
----------------------------------------------------------------------
diff --git a/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java
new file mode 100644
index 0000000..f786472
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/sources/v2/reader/SupportsReportPartitioning.java
@@ -0,0 +1,33 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.sources.v2.reader;
+
+import org.apache.spark.annotation.InterfaceStability;
+
+/**
+ * A mix in interface for {@link DataSourceV2Reader}. Data source readers can implement this
+ * interface to report data partitioning and try to avoid shuffle at Spark side.
+ */
+@InterfaceStability.Evolving
+public interface SupportsReportPartitioning {
+
+  /**
+   * Returns the output data partitioning that this reader guarantees.
+   */
+  Partitioning outputPartitioning();
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/7241556d/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourcePartitioning.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourcePartitioning.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourcePartitioning.scala
new file mode 100644
index 0000000..943d010
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourcePartitioning.scala
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.v2
+
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeMap, Expression}
+import org.apache.spark.sql.catalyst.plans.physical
+import org.apache.spark.sql.sources.v2.reader.{ClusteredDistribution, Partitioning}
+
+/**
+ * An adapter from public data source partitioning to catalyst internal `Partitioning`.
+ */
+class DataSourcePartitioning(
+    partitioning: Partitioning,
+    colNames: AttributeMap[String]) extends physical.Partitioning {
+
+  override val numPartitions: Int = partitioning.numPartitions()
+
+  override def satisfies(required: physical.Distribution): Boolean = {
+    super.satisfies(required) || {
+      required match {
+        case d: physical.ClusteredDistribution if isCandidate(d.clustering) =>
+          val attrs = d.clustering.map(_.asInstanceOf[Attribute])
+          partitioning.satisfy(
+            new ClusteredDistribution(attrs.map { a =>
+              val name = colNames.get(a)
+            assert(name.isDefined, s"Attribute ${a.name} is not found in the data source output")
+              name.get
+            }.toArray))
+
+        case _ => false
+      }
+    }
+  }
+
+  private def isCandidate(clustering: Seq[Expression]): Boolean = {
+    clustering.forall {
+      case a: Attribute => colNames.contains(a)
+      case _ => false
+    }
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/7241556d/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala
index beb6673..69d871d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2ScanExec.scala
@@ -24,6 +24,7 @@ import org.apache.spark.sql.Row
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, RowEncoder}
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.physical
 import org.apache.spark.sql.execution.{ColumnarBatchScan, LeafExecNode, WholeStageCodegenExec}
 import org.apache.spark.sql.execution.streaming.continuous._
 import org.apache.spark.sql.sources.v2.reader._
@@ -42,6 +43,14 @@ case class DataSourceV2ScanExec(
 
   override def producedAttributes: AttributeSet = AttributeSet(fullOutput)
 
+  override def outputPartitioning: physical.Partitioning = reader match {
+    case s: SupportsReportPartitioning =>
+      new DataSourcePartitioning(
+        s.outputPartitioning(), AttributeMap(output.map(a => a -> a.name)))
+
+    case _ => super.outputPartitioning
+  }
+
   private lazy val readTasks: java.util.List[ReadTask[UnsafeRow]] = reader match {
     case r: SupportsScanUnsafeRow => r.createUnsafeRowReadTasks()
     case _ =>

http://git-wip-us.apache.org/repos/asf/spark/blob/7241556d/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java
----------------------------------------------------------------------
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java
new file mode 100644
index 0000000..806d0bc
--- /dev/null
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/v2/JavaPartitionAwareDataSource.java
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package test.org.apache.spark.sql.sources.v2;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.List;
+
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.catalyst.expressions.GenericRow;
+import org.apache.spark.sql.sources.v2.DataSourceV2;
+import org.apache.spark.sql.sources.v2.DataSourceV2Options;
+import org.apache.spark.sql.sources.v2.ReadSupport;
+import org.apache.spark.sql.sources.v2.reader.*;
+import org.apache.spark.sql.types.StructType;
+
+public class JavaPartitionAwareDataSource implements DataSourceV2, ReadSupport {
+
+  class Reader implements DataSourceV2Reader, SupportsReportPartitioning {
+    private final StructType schema = new StructType().add("a", "int").add("b", "int");
+
+    @Override
+    public StructType readSchema() {
+      return schema;
+    }
+
+    @Override
+    public List<ReadTask<Row>> createReadTasks() {
+      return java.util.Arrays.asList(
+        new SpecificReadTask(new int[]{1, 1, 3}, new int[]{4, 4, 6}),
+        new SpecificReadTask(new int[]{2, 4, 4}, new int[]{6, 2, 2}));
+    }
+
+    @Override
+    public Partitioning outputPartitioning() {
+      return new MyPartitioning();
+    }
+  }
+
+  static class MyPartitioning implements Partitioning {
+
+    @Override
+    public int numPartitions() {
+      return 2;
+    }
+
+    @Override
+    public boolean satisfy(Distribution distribution) {
+      if (distribution instanceof ClusteredDistribution) {
+        String[] clusteredCols = ((ClusteredDistribution) distribution).clusteredColumns;
+        return Arrays.asList(clusteredCols).contains("a");
+      }
+
+      return false;
+    }
+  }
+
+  static class SpecificReadTask implements ReadTask<Row>, DataReader<Row> {
+    private int[] i;
+    private int[] j;
+    private int current = -1;
+
+    SpecificReadTask(int[] i, int[] j) {
+      assert i.length == j.length;
+      this.i = i;
+      this.j = j;
+    }
+
+    @Override
+    public boolean next() throws IOException {
+      current += 1;
+      return current < i.length;
+    }
+
+    @Override
+    public Row get() {
+      return new GenericRow(new Object[] {i[current], j[current]});
+    }
+
+    @Override
+    public void close() throws IOException {
+
+    }
+
+    @Override
+    public DataReader<Row> createDataReader() {
+      return this;
+    }
+  }
+
+  @Override
+  public DataSourceV2Reader createReader(DataSourceV2Options options) {
+    return new Reader();
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/7241556d/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala
----------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala
index 0ca2952..0620693 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/v2/DataSourceV2Suite.scala
@@ -24,6 +24,7 @@ import test.org.apache.spark.sql.sources.v2._
 import org.apache.spark.SparkException
 import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
 import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
 import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
 import org.apache.spark.sql.sources.{Filter, GreaterThan}
 import org.apache.spark.sql.sources.v2.reader._
@@ -95,6 +96,40 @@ class DataSourceV2Suite extends QueryTest with SharedSQLContext {
     }
   }
 
+  test("partitioning reporting") {
+    import org.apache.spark.sql.functions.{count, sum}
+    Seq(classOf[PartitionAwareDataSource], classOf[JavaPartitionAwareDataSource]).foreach { cls =>
+      withClue(cls.getName) {
+        val df = spark.read.format(cls.getName).load()
+        checkAnswer(df, Seq(Row(1, 4), Row(1, 4), Row(3, 6), Row(2, 6), Row(4, 2), Row(4, 2)))
+
+        val groupByColA = df.groupBy('a).agg(sum('b))
+        checkAnswer(groupByColA, Seq(Row(1, 8), Row(2, 6), Row(3, 6), Row(4, 4)))
+        assert(groupByColA.queryExecution.executedPlan.collectFirst {
+          case e: ShuffleExchangeExec => e
+        }.isEmpty)
+
+        val groupByColAB = df.groupBy('a, 'b).agg(count("*"))
+        checkAnswer(groupByColAB, Seq(Row(1, 4, 2), Row(2, 6, 1), Row(3, 6, 1), Row(4, 2, 2)))
+        assert(groupByColAB.queryExecution.executedPlan.collectFirst {
+          case e: ShuffleExchangeExec => e
+        }.isEmpty)
+
+        val groupByColB = df.groupBy('b).agg(sum('a))
+        checkAnswer(groupByColB, Seq(Row(2, 8), Row(4, 2), Row(6, 5)))
+        assert(groupByColB.queryExecution.executedPlan.collectFirst {
+          case e: ShuffleExchangeExec => e
+        }.isDefined)
+
+        val groupByAPlusB = df.groupBy('a + 'b).agg(count("*"))
+        checkAnswer(groupByAPlusB, Seq(Row(5, 2), Row(6, 2), Row(8, 1), Row(9, 1)))
+        assert(groupByAPlusB.queryExecution.executedPlan.collectFirst {
+          case e: ShuffleExchangeExec => e
+        }.isDefined)
+      }
+    }
+  }
+
   test("simple writable data source") {
     // TODO: java implementation.
     Seq(classOf[SimpleWritableDataSource]).foreach { cls =>
@@ -365,3 +400,47 @@ class BatchReadTask(start: Int, end: Int)
 
   override def close(): Unit = batch.close()
 }
+
+class PartitionAwareDataSource extends DataSourceV2 with ReadSupport {
+
+  class Reader extends DataSourceV2Reader with SupportsReportPartitioning {
+    override def readSchema(): StructType = new StructType().add("a", "int").add("b", "int")
+
+    override def createReadTasks(): JList[ReadTask[Row]] = {
+      // Note that we don't have same value of column `a` across partitions.
+      java.util.Arrays.asList(
+        new SpecificReadTask(Array(1, 1, 3), Array(4, 4, 6)),
+        new SpecificReadTask(Array(2, 4, 4), Array(6, 2, 2)))
+    }
+
+    override def outputPartitioning(): Partitioning = new MyPartitioning
+  }
+
+  class MyPartitioning extends Partitioning {
+    override def numPartitions(): Int = 2
+
+    override def satisfy(distribution: Distribution): Boolean = distribution match {
+      case c: ClusteredDistribution => c.clusteredColumns.contains("a")
+      case _ => false
+    }
+  }
+
+  override def createReader(options: DataSourceV2Options): DataSourceV2Reader = new Reader
+}
+
+class SpecificReadTask(i: Array[Int], j: Array[Int]) extends ReadTask[Row] with DataReader[Row] {
+  assert(i.length == j.length)
+
+  private var current = -1
+
+  override def createDataReader(): DataReader[Row] = this
+
+  override def next(): Boolean = {
+    current += 1
+    current < i.length
+  }
+
+  override def get(): Row = Row(i(current), j(current))
+
+  override def close(): Unit = {}
+}

