This is an automated email from the ASF dual-hosted git repository.

wenchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 887f83122a0 [SPARK-41513][SQL] Implement an accumulator to collect per mapper row count metrics
887f83122a0 is described below

commit 887f83122a0b5684b6a5b0fab7ca9768a6c184f0
Author: Rui Wang <rui.w...@databricks.com>
AuthorDate: Thu Dec 22 19:59:43 2022 +0800

    [SPARK-41513][SQL] Implement an accumulator to collect per mapper row count metrics
    
    ### What changes were proposed in this pull request?
    
    In the current Spark optimizer, a single-partition shuffle may be created for a
    limit if that limit is not the last non-action operation (e.g. a filter follows
    the limit and the data size exceeds a threshold). The output partitions feeding
    into this limit may also be sorted. The single-partition shuffle approach has a
    correctness bug in this case: shuffle read partitions could be out of partition
    order, and the limit exec just takes the first limit rows wh [...]
    
    So we propose a row-count-based AQE algorithm that improves on this in two
    ways:
    
    1. Avoid the extra sort on the shuffle read side (or with the limit exec) to
    achieve the correct result.
    2. Avoid reading all shuffle data from mappers for this single-partition
    shuffle, to reduce shuffle cost.
    
    Note that 1. only applies to the sorted-partition case, while 2. applies to
    the general single-partition shuffle + limit case.
    
    The algorithm works as follows:
    
    1. Each mapper records a row count when writing shuffle data.
    2. Since this is the single shuffle partition case, there is only one
    partition but N mappers.
    3. An AccumulatorV2 is implemented to collect a list of tuples recording the
    mapping between mapper id and the number of rows written by that mapper (the
    row count metrics).
    4. The AQE framework detects a plan shape of a shuffle plus a global limit.
    5. The AQE framework reads only the necessary data from mappers, based on the
    limit. For example, if mapper 1 writes 200 rows, mapper 2 writes 300 rows,
    and the limit is 500, AQE creates a shuffle read node that reads from
    mappers 1 and 2 only, skipping the remaining mappers (see the sketch after
    this list).
    6. This is correct for a limit over both sorted and non-sorted partitions.
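
    A minimal sketch of the mapper-selection step in 5., assuming the
    (mapper id, row count) pairs collected by the accumulator are already
    available on the driver (the helper name and its signature are illustrative,
    not part of this change):

        // Given (mapperId, rowCount) pairs in mapper order and a limit, keep the
        // shortest prefix of mappers whose cumulative row count reaches the limit.
        def selectMappers(rowCounts: Seq[(Int, Long)], limit: Long): Seq[Int] = {
          // Running totals *before* each mapper: a mapper is still needed while
          // the rows of the previous mappers have not yet reached the limit.
          val runningBefore = rowCounts.scanLeft(0L) { case (acc, (_, rows)) => acc + rows }
          rowCounts.zip(runningBefore)
            .takeWhile { case (_, before) => before < limit }
            .map { case ((mapperId, _), _) => mapperId }
        }

        // Example from step 5: mapper 1 wrote 200 rows and mapper 2 wrote 300;
        // with a limit of 500 only mappers 1 and 2 are read, mapper 3 is skipped.
        selectMappers(Seq((1, 200L), (2, 300L), (3, 400L)), 500L)  // Seq(1, 2)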
    
    This is the first step towards implementing the idea in
    https://issues.apache.org/jira/browse/SPARK-41512: a row count accumulator
    that will be used to collect the row count metrics.
    
    ### Why are the changes needed?
    
    An optimization algorithm for a global limit with a single-partition shuffle.
    
    ### Does this PR introduce _any_ user-facing change?
    
    NO

    ### How was this patch tested?
    
    UT
    
    Closes #39057 from amaliujia/add_row_counter.
    
    Authored-by: Rui Wang <rui.w...@databricks.com>
    Signed-off-by: Wenchen Fan <wenc...@databricks.com>
---
 .../apache/spark/sql/util/MapperRowCounter.scala   | 85 ++++++++++++++++++++++
 .../spark/sql/util/MapperRowCounterSuite.scala     | 54 ++++++++++++++
 2 files changed, 139 insertions(+)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/util/MapperRowCounter.scala b/sql/core/src/main/scala/org/apache/spark/sql/util/MapperRowCounter.scala
new file mode 100644
index 00000000000..7e1dfacca4a
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/util/MapperRowCounter.scala
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.util
+
+import java.{lang => jl}
+
+import org.apache.spark.util.AccumulatorV2
+
+/**
+ * An AccumulatorV2 counter for collecting a list of (mapper index, row count).
+ *
+ * @since 3.4.0
+ */
+class MapperRowCounter extends AccumulatorV2[jl.Long, java.util.List[(jl.Integer, jl.Long)]] {
+
+  private var _agg: java.util.List[(jl.Integer, jl.Long)] = _
+
+  private def getOrCreate = {
+    _agg = Option(_agg).getOrElse(new java.util.ArrayList[(jl.Integer, jl.Long)]())
+    _agg
+  }
+
+  /**
+   * Returns true if no (mapper index, row count) pair has been recorded by this accumulator yet.
+   */
+  override def isZero: Boolean = this.synchronized(getOrCreate.isEmpty)
+
+  override def copyAndReset(): MapperRowCounter = new MapperRowCounter
+
+  override def copy(): MapperRowCounter = {
+    val newAcc = new MapperRowCounter()
+    this.synchronized {
+      newAcc.getOrCreate.addAll(getOrCreate)
+    }
+    newAcc
+  }
+
+  override def reset(): Unit = {
+    this.synchronized {
+      _agg = null
+    }
+  }
+
+  override def add(v: jl.Long): Unit = {
+    this.synchronized {
+      assert(getOrCreate.size() == 1, "agg must have been initialized")
+      val p = getOrCreate.get(0)._1
+      val n = getOrCreate.get(0)._2 + 1
+      getOrCreate.set(0, (p, n))
+    }
+  }
+
+  def setPartitionId(id: jl.Integer): Unit = {
+    this.synchronized {
+      assert(isZero, "agg must not have been initialized")
+      getOrCreate.add((id, 0))
+    }
+  }
+
+  override def merge(
+      other: AccumulatorV2[jl.Long, java.util.List[(jl.Integer, jl.Long)]]): Unit
+  = other match {
+    case o: MapperRowCounter =>
+      this.synchronized(getOrCreate.addAll(o.value))
+    case _ =>
+      throw new UnsupportedOperationException(
+        s"Cannot merge ${this.getClass.getName} with 
${other.getClass.getName}")
+  }
+
+  override def value: java.util.List[(jl.Integer, jl.Long)] = this.synchronized(getOrCreate)
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/MapperRowCounterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/MapperRowCounterSuite.scala
new file mode 100644
index 00000000000..3f4d32ed9fa
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/util/MapperRowCounterSuite.scala
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.sql.util
+
+import org.apache.spark.SparkFunSuite
+
+class MapperRowCounterSuite extends SparkFunSuite {
+
+  test("Test MapperRowCounter") {
+    val counter = new MapperRowCounter()
+    assert(counter.isZero)
+
+    counter.setPartitionId(0)
+    counter.add(1L)
+    counter.add(1L)
+    assert(counter.value.get(0)._1 == 0L)
+    assert(counter.value.get(0)._2 == 2L)
+
+    counter.reset()
+    assert(counter.isZero)
+    counter.setPartitionId(100)
+    counter.add(1L)
+    assert(counter.value.get(0)._1 == 100L)
+    assert(counter.value.get(0)._2 == 1L)
+
+    val counter2 = new MapperRowCounter()
+    counter2.setPartitionId(40)
+    counter2.add(1L)
+    counter2.add(1L)
+    counter2.add(1L)
+
+    counter.merge(counter2)
+    assert(counter.value.size() == 2)
+    assert(counter.value.get(0)._1 == 100L)
+    assert(counter.value.get(0)._2 == 1L)
+    assert(counter.value.get(1)._1 == 40L)
+    assert(counter.value.get(1)._2 == 3L)
+  }
+}
+


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org
