This is an automated email from the ASF dual-hosted git repository.
hongze pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-gluten.git
The following commit(s) were added to refs/heads/main by this push:
new 88ec4469d8 [VL] Fix `RetryOnOomMemoryTarget` only spills one single
consumer on retrying (#8262)
88ec4469d8 is described below
commit 88ec4469d835884c431d62c3394a0e6fdf1d68a9
Author: Hongze Zhang <[email protected]>
AuthorDate: Wed Dec 18 11:38:05 2024 +0800
[VL] Fix `RetryOnOomMemoryTarget` only spills one single consumer on
retrying (#8262)
---
.../memory/memtarget/MemoryTargetVisitor.java | 2 +-
.../gluten/memory/memtarget/MemoryTargets.java | 9 +-
.../gluten/memory/memtarget/TreeMemoryTargets.java | 159 +--------------------
.../memory/memtarget/spark/TreeMemoryConsumer.java | 145 ++++++++++++++++++-
.../memtarget/spark/TreeMemoryConsumers.java | 22 +--
.../org/apache/spark/memory/SparkMemoryUtil.scala | 2 +-
.../memtarget/spark/TreeMemoryConsumerTest.java | 38 ++++-
7 files changed, 193 insertions(+), 184 deletions(-)
diff --git
a/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/MemoryTargetVisitor.java
b/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/MemoryTargetVisitor.java
index a42a51e0ce..f6ef49a789 100644
---
a/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/MemoryTargetVisitor.java
+++
b/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/MemoryTargetVisitor.java
@@ -28,7 +28,7 @@ public interface MemoryTargetVisitor<T> {
T visit(TreeMemoryConsumer treeMemoryConsumer);
- T visit(TreeMemoryTargets.Node node);
+ T visit(TreeMemoryConsumer.Node node);
T visit(LoggingMemoryTarget loggingMemoryTarget);
diff --git
a/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/MemoryTargets.java
b/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/MemoryTargets.java
index 1997ce61d2..c6f5b59de8 100644
---
a/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/MemoryTargets.java
+++
b/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/MemoryTargets.java
@@ -64,9 +64,11 @@ public final class MemoryTargets {
Map<String, MemoryUsageStatsBuilder> virtualChildren) {
final TreeMemoryConsumers.Factory factory =
TreeMemoryConsumers.factory(tmm);
if (GlutenConfig.getConf().memoryIsolation()) {
- return factory.newIsolatedConsumer(name, spiller, virtualChildren);
+ return TreeMemoryTargets.newChild(factory.isolatedRoot(), name, spiller,
virtualChildren);
}
- final TreeMemoryTarget consumer = factory.newLegacyConsumer(name, spiller,
virtualChildren);
+ final TreeMemoryTarget root = factory.legacyRoot();
+ final TreeMemoryTarget consumer =
+ TreeMemoryTargets.newChild(root, name, spiller, virtualChildren);
if (SparkEnv.get() == null) {
// We are likely in test code. Return the consumer directly.
LOGGER.info("SparkEnv not found. We are likely in test code.");
@@ -86,7 +88,8 @@ public final class MemoryTargets {
consumer,
() -> {
LOGGER.info("Request for spilling on consumer {}...",
consumer.name());
- long spilled = TreeMemoryTargets.spillTree(consumer, Long.MAX_VALUE);
+ // Note: Spill from root node so other consumers also get spilled.
+ long spilled = TreeMemoryTargets.spillTree(root, Long.MAX_VALUE);
LOGGER.info("Consumer {} spilled {} bytes.", consumer.name(),
spilled);
});
}
diff --git
a/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/TreeMemoryTargets.java
b/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/TreeMemoryTargets.java
index 598317a3c4..6d94e72069 100644
---
a/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/TreeMemoryTargets.java
+++
b/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/TreeMemoryTargets.java
@@ -17,18 +17,10 @@
package org.apache.gluten.memory.memtarget;
import org.apache.gluten.memory.MemoryUsageStatsBuilder;
-import org.apache.gluten.memory.SimpleMemoryUsageRecorder;
-import org.apache.gluten.proto.MemoryUsageStats;
-import com.google.common.base.Preconditions;
-import org.apache.spark.util.Utils;
-
-import java.util.Collections;
-import java.util.HashMap;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.Queue;
-import java.util.stream.Collectors;
public class TreeMemoryTargets {
@@ -36,13 +28,16 @@ public class TreeMemoryTargets {
// enclose factory ctor
}
- public static TreeMemoryTarget newChild(
+ /**
+ * A short-cut method to create a child target of `parent`. The child will follow the parent's
+ * maximum capacity.
+ */
+ static TreeMemoryTarget newChild(
TreeMemoryTarget parent,
String name,
- long capacity,
Spiller spiller,
Map<String, MemoryUsageStatsBuilder> virtualChildren) {
- return new Node(parent, name, capacity, spiller, virtualChildren);
+ return parent.newChild(name, TreeMemoryTarget.CAPACITY_UNLIMITED, spiller,
virtualChildren);
}
public static long spillTree(TreeMemoryTarget node, final long bytes) {
@@ -83,146 +78,4 @@ public class TreeMemoryTargets {
return bytes - remainingBytes;
}
-
- // non-root nodes are not Spark memory consumer
- public static class Node implements TreeMemoryTarget, KnownNameAndStats {
- private final Map<String, Node> children = new HashMap<>();
- private final TreeMemoryTarget parent;
- private final String name;
- private final long capacity;
- private final Spiller spiller;
- private final Map<String, MemoryUsageStatsBuilder> virtualChildren;
- private final SimpleMemoryUsageRecorder selfRecorder = new
SimpleMemoryUsageRecorder();
-
- private Node(
- TreeMemoryTarget parent,
- String name,
- long capacity,
- Spiller spiller,
- Map<String, MemoryUsageStatsBuilder> virtualChildren) {
- this.parent = parent;
- this.capacity = capacity;
- final String uniqueName = MemoryTargetUtil.toUniqueName(name);
- if (capacity == CAPACITY_UNLIMITED) {
- this.name = uniqueName;
- } else {
- this.name = String.format("%s, %s", uniqueName,
Utils.bytesToString(capacity));
- }
- this.spiller = spiller;
- this.virtualChildren = virtualChildren;
- }
-
- @Override
- public long borrow(long size) {
- if (size == 0) {
- return 0;
- }
- ensureFreeCapacity(size);
- return borrow0(Math.min(freeBytes(), size));
- }
-
- private long freeBytes() {
- return capacity - usedBytes();
- }
-
- private long borrow0(long size) {
- long granted = parent.borrow(size);
- selfRecorder.inc(granted);
- return granted;
- }
-
- @Override
- public Spiller getNodeSpiller() {
- return spiller;
- }
-
- private boolean ensureFreeCapacity(long bytesNeeded) {
- while (true) { // FIXME should we add retry limit?
- long freeBytes = freeBytes();
- Preconditions.checkState(freeBytes >= 0);
- if (freeBytes >= bytesNeeded) {
- // free bytes fit requirement
- return true;
- }
- // spill
- long bytesToSpill = bytesNeeded - freeBytes;
- long spilledBytes = TreeMemoryTargets.spillTree(this, bytesToSpill);
- Preconditions.checkState(spilledBytes >= 0);
- if (spilledBytes == 0) {
- // OOM
- return false;
- }
- }
- }
-
- @Override
- public long repay(long size) {
- if (size == 0) {
- return 0;
- }
- long toFree = Math.min(usedBytes(), size);
- long freed = parent.repay(toFree);
- selfRecorder.inc(-freed);
- return freed;
- }
-
- @Override
- public long usedBytes() {
- return selfRecorder.current();
- }
-
- @Override
- public <T> T accept(MemoryTargetVisitor<T> visitor) {
- return visitor.visit(this);
- }
-
- @Override
- public String name() {
- return name;
- }
-
- @Override
- public MemoryUsageStats stats() {
- final Map<String, MemoryUsageStats> childrenStats =
- new HashMap<>(
- children.entrySet().stream()
- .collect(Collectors.toMap(e -> e.getValue().name(), e ->
e.getValue().stats())));
-
- Preconditions.checkState(childrenStats.size() == children.size());
-
- // add virtual children
- for (Map.Entry<String, MemoryUsageStatsBuilder> entry :
virtualChildren.entrySet()) {
- if (childrenStats.containsKey(entry.getKey())) {
- throw new IllegalArgumentException("Child stats already exists: " +
entry.getKey());
- }
- childrenStats.put(entry.getKey(), entry.getValue().toStats());
- }
- return selfRecorder.toStats(childrenStats);
- }
-
- @Override
- public TreeMemoryTarget newChild(
- String name,
- long capacity,
- Spiller spiller,
- Map<String, MemoryUsageStatsBuilder> virtualChildren) {
- final Node child =
- new Node(this, name, Math.min(this.capacity, capacity), spiller,
virtualChildren);
- if (children.containsKey(child.name())) {
- throw new IllegalArgumentException("Child already registered: " +
child.name());
- }
- children.put(child.name(), child);
- return child;
- }
-
- @Override
- public Map<String, TreeMemoryTarget> children() {
- return Collections.unmodifiableMap(children);
- }
-
- @Override
- public TreeMemoryTarget parent() {
- return parent;
- }
- }
}
diff --git
a/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/spark/TreeMemoryConsumer.java
b/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/spark/TreeMemoryConsumer.java
index 1289a01c34..38ac7d9733 100644
---
a/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/spark/TreeMemoryConsumer.java
+++
b/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/spark/TreeMemoryConsumer.java
@@ -25,6 +25,7 @@ import com.google.common.base.Preconditions;
import org.apache.spark.memory.MemoryConsumer;
import org.apache.spark.memory.MemoryMode;
import org.apache.spark.memory.TaskMemoryManager;
+import org.apache.spark.util.Utils;
import java.io.IOException;
import java.util.Collections;
@@ -121,8 +122,7 @@ public class TreeMemoryConsumer extends MemoryConsumer
implements TreeMemoryTarg
long capacity,
Spiller spiller,
Map<String, MemoryUsageStatsBuilder> virtualChildren) {
- final TreeMemoryTarget child =
- TreeMemoryTargets.newChild(this, name, capacity, spiller,
virtualChildren);
+ final TreeMemoryTarget child = new Node(this, name, capacity, spiller,
virtualChildren);
if (children.containsKey(child.name())) {
throw new IllegalArgumentException("Child already registered: " +
child.name());
}
@@ -151,4 +151,145 @@ public class TreeMemoryConsumer extends MemoryConsumer
implements TreeMemoryTarg
public TaskMemoryManager getTaskMemoryManager() {
return taskMemoryManager;
}
+
+ public static class Node implements TreeMemoryTarget, KnownNameAndStats {
+ private final Map<String, Node> children = new HashMap<>();
+ private final TreeMemoryTarget parent;
+ private final String name;
+ private final long capacity;
+ private final Spiller spiller;
+ private final Map<String, MemoryUsageStatsBuilder> virtualChildren;
+ private final SimpleMemoryUsageRecorder selfRecorder = new
SimpleMemoryUsageRecorder();
+
+ private Node(
+ TreeMemoryTarget parent,
+ String name,
+ long capacity,
+ Spiller spiller,
+ Map<String, MemoryUsageStatsBuilder> virtualChildren) {
+ this.parent = parent;
+ this.capacity = capacity;
+ final String uniqueName = MemoryTargetUtil.toUniqueName(name);
+ if (capacity == TreeMemoryTarget.CAPACITY_UNLIMITED) {
+ this.name = uniqueName;
+ } else {
+ this.name = String.format("%s, %s", uniqueName,
Utils.bytesToString(capacity));
+ }
+ this.spiller = spiller;
+ this.virtualChildren = virtualChildren;
+ }
+
+ @Override
+ public long borrow(long size) {
+ if (size == 0) {
+ return 0;
+ }
+ ensureFreeCapacity(size);
+ return borrow0(Math.min(freeBytes(), size));
+ }
+
+ private long freeBytes() {
+ return capacity - usedBytes();
+ }
+
+ private long borrow0(long size) {
+ long granted = parent.borrow(size);
+ selfRecorder.inc(granted);
+ return granted;
+ }
+
+ @Override
+ public Spiller getNodeSpiller() {
+ return spiller;
+ }
+
+ private boolean ensureFreeCapacity(long bytesNeeded) {
+ while (true) { // FIXME should we add retry limit?
+ long freeBytes = freeBytes();
+ Preconditions.checkState(freeBytes >= 0);
+ if (freeBytes >= bytesNeeded) {
+ // free bytes fit requirement
+ return true;
+ }
+ // spill
+ long bytesToSpill = bytesNeeded - freeBytes;
+ long spilledBytes = TreeMemoryTargets.spillTree(this, bytesToSpill);
+ Preconditions.checkState(spilledBytes >= 0);
+ if (spilledBytes == 0) {
+ // OOM
+ return false;
+ }
+ }
+ }
+
+ @Override
+ public long repay(long size) {
+ if (size == 0) {
+ return 0;
+ }
+ long toFree = Math.min(usedBytes(), size);
+ long freed = parent.repay(toFree);
+ selfRecorder.inc(-freed);
+ return freed;
+ }
+
+ @Override
+ public long usedBytes() {
+ return selfRecorder.current();
+ }
+
+ @Override
+ public <T> T accept(MemoryTargetVisitor<T> visitor) {
+ return visitor.visit(this);
+ }
+
+ @Override
+ public String name() {
+ return name;
+ }
+
+ @Override
+ public MemoryUsageStats stats() {
+ final Map<String, MemoryUsageStats> childrenStats =
+ new HashMap<>(
+ children.entrySet().stream()
+ .collect(Collectors.toMap(e -> e.getValue().name(), e ->
e.getValue().stats())));
+
+ Preconditions.checkState(childrenStats.size() == children.size());
+
+ // add virtual children
+ for (Map.Entry<String, MemoryUsageStatsBuilder> entry :
virtualChildren.entrySet()) {
+ if (childrenStats.containsKey(entry.getKey())) {
+ throw new IllegalArgumentException("Child stats already exists: " +
entry.getKey());
+ }
+ childrenStats.put(entry.getKey(), entry.getValue().toStats());
+ }
+ return selfRecorder.toStats(childrenStats);
+ }
+
+ @Override
+ public TreeMemoryTarget newChild(
+ String name,
+ long capacity,
+ Spiller spiller,
+ Map<String, MemoryUsageStatsBuilder> virtualChildren) {
+ final Node child =
+ new Node(this, name, Math.min(this.capacity, capacity), spiller,
virtualChildren);
+ if (children.containsKey(child.name())) {
+ throw new IllegalArgumentException("Child already registered: " +
child.name());
+ }
+ children.put(child.name(), child);
+ return child;
+ }
+
+ @Override
+ public Map<String, TreeMemoryTarget> children() {
+ return Collections.unmodifiableMap(children);
+ }
+
+ @Override
+ public TreeMemoryTarget parent() {
+ return parent;
+ }
+ }
}
diff --git
a/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/spark/TreeMemoryConsumers.java
b/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/spark/TreeMemoryConsumers.java
index e8bfb5cf75..a11a4a3e4a 100644
---
a/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/spark/TreeMemoryConsumers.java
+++
b/gluten-core/src/main/java/org/apache/gluten/memory/memtarget/spark/TreeMemoryConsumers.java
@@ -17,8 +17,6 @@
package org.apache.gluten.memory.memtarget.spark;
import org.apache.gluten.GlutenConfig;
-import org.apache.gluten.memory.MemoryUsageStatsBuilder;
-import org.apache.gluten.memory.memtarget.Spiller;
import org.apache.gluten.memory.memtarget.Spillers;
import org.apache.gluten.memory.memtarget.TreeMemoryTarget;
@@ -61,22 +59,12 @@ public final class TreeMemoryConsumers {
Collections.emptyMap()));
}
- private TreeMemoryTarget legacyRoot() {
- return ofCapacity(TreeMemoryTarget.CAPACITY_UNLIMITED);
- }
-
- private TreeMemoryTarget isolatedRoot() {
- return
ofCapacity(GlutenConfig.getConf().conservativeTaskOffHeapMemorySize());
- }
-
/**
* This works as a legacy Spark memory consumer which grants as much as possible of memory
* capacity to each task.
*/
- public TreeMemoryTarget newLegacyConsumer(
- String name, Spiller spiller, Map<String, MemoryUsageStatsBuilder>
virtualChildren) {
- final TreeMemoryTarget parent = legacyRoot();
- return parent.newChild(name, TreeMemoryConsumer.CAPACITY_UNLIMITED,
spiller, virtualChildren);
+ public TreeMemoryTarget legacyRoot() {
+ return ofCapacity(TreeMemoryTarget.CAPACITY_UNLIMITED);
}
/**
@@ -88,10 +76,8 @@ public final class TreeMemoryConsumers {
*
* <p>See <a
href="https://github.com/oap-project/gluten/issues/3030">GLUTEN-3030</a>
*/
- public TreeMemoryTarget newIsolatedConsumer(
- String name, Spiller spiller, Map<String, MemoryUsageStatsBuilder>
virtualChildren) {
- final TreeMemoryTarget parent = isolatedRoot();
- return parent.newChild(name, TreeMemoryConsumer.CAPACITY_UNLIMITED,
spiller, virtualChildren);
+ public TreeMemoryTarget isolatedRoot() {
+ return
ofCapacity(GlutenConfig.getConf().conservativeTaskOffHeapMemorySize());
}
}
}
diff --git
a/gluten-core/src/main/scala/org/apache/spark/memory/SparkMemoryUtil.scala
b/gluten-core/src/main/scala/org/apache/spark/memory/SparkMemoryUtil.scala
index 637ef8b22f..338854cf08 100644
--- a/gluten-core/src/main/scala/org/apache/spark/memory/SparkMemoryUtil.scala
+++ b/gluten-core/src/main/scala/org/apache/spark/memory/SparkMemoryUtil.scala
@@ -111,7 +111,7 @@ object SparkMemoryUtil {
collectFromTaskMemoryManager(treeMemoryConsumer.getTaskMemoryManager)
}
- override def visit(node: TreeMemoryTargets.Node): String = {
+ override def visit(node: TreeMemoryConsumer.Node): String = {
node.parent().accept(this) // walk up to find the one bound with task memory manager
}
diff --git
a/gluten-core/src/test/java/org/apache/gluten/memory/memtarget/spark/TreeMemoryConsumerTest.java
b/gluten-core/src/test/java/org/apache/gluten/memory/memtarget/spark/TreeMemoryConsumerTest.java
index 934300a1ac..6cb38fe8d5 100644
---
a/gluten-core/src/test/java/org/apache/gluten/memory/memtarget/spark/TreeMemoryConsumerTest.java
+++
b/gluten-core/src/test/java/org/apache/gluten/memory/memtarget/spark/TreeMemoryConsumerTest.java
@@ -52,7 +52,13 @@ public class TreeMemoryConsumerTest {
final TreeMemoryConsumers.Factory factory =
TreeMemoryConsumers.factory(TaskContext.get().taskMemoryManager());
final TreeMemoryTarget consumer =
- factory.newIsolatedConsumer("FOO", Spillers.NOOP,
Collections.emptyMap());
+ factory
+ .isolatedRoot()
+ .newChild(
+ "FOO",
+ TreeMemoryTarget.CAPACITY_UNLIMITED,
+ Spillers.NOOP,
+ Collections.emptyMap());
Assert.assertEquals(20, consumer.borrow(20));
Assert.assertEquals(70, consumer.borrow(70));
Assert.assertEquals(10, consumer.borrow(20));
@@ -67,7 +73,13 @@ public class TreeMemoryConsumerTest {
final TreeMemoryConsumers.Factory factory =
TreeMemoryConsumers.factory(TaskContext.get().taskMemoryManager());
final TreeMemoryTarget consumer =
- factory.newLegacyConsumer("FOO", Spillers.NOOP,
Collections.emptyMap());
+ factory
+ .legacyRoot()
+ .newChild(
+ "FOO",
+ TreeMemoryTarget.CAPACITY_UNLIMITED,
+ Spillers.NOOP,
+ Collections.emptyMap());
Assert.assertEquals(20, consumer.borrow(20));
Assert.assertEquals(70, consumer.borrow(70));
Assert.assertEquals(20, consumer.borrow(20));
@@ -81,11 +93,21 @@ public class TreeMemoryConsumerTest {
() -> {
final TreeMemoryTarget legacy =
TreeMemoryConsumers.factory(TaskContext.get().taskMemoryManager())
- .newLegacyConsumer("FOO", Spillers.NOOP,
Collections.emptyMap());
+ .legacyRoot()
+ .newChild(
+ "FOO",
+ TreeMemoryTarget.CAPACITY_UNLIMITED,
+ Spillers.NOOP,
+ Collections.emptyMap());
Assert.assertEquals(110, legacy.borrow(110));
final TreeMemoryTarget isolated =
TreeMemoryConsumers.factory(TaskContext.get().taskMemoryManager())
- .newIsolatedConsumer("FOO", Spillers.NOOP,
Collections.emptyMap());
+ .isolatedRoot()
+ .newChild(
+ "FOO",
+ TreeMemoryTarget.CAPACITY_UNLIMITED,
+ Spillers.NOOP,
+ Collections.emptyMap());
Assert.assertEquals(100, isolated.borrow(110));
});
}
@@ -97,7 +119,9 @@ public class TreeMemoryConsumerTest {
final Spillers.AppendableSpillerList spillers =
Spillers.appendable();
final TreeMemoryTarget legacy =
TreeMemoryConsumers.factory(TaskContext.get().taskMemoryManager())
- .newLegacyConsumer("FOO", spillers, Collections.emptyMap());
+ .legacyRoot()
+ .newChild(
+ "FOO", TreeMemoryTarget.CAPACITY_UNLIMITED, spillers,
Collections.emptyMap());
final AtomicInteger numSpills = new AtomicInteger(0);
final AtomicLong numSpilledBytes = new AtomicLong(0L);
spillers.append(
@@ -131,7 +155,9 @@ public class TreeMemoryConsumerTest {
final Spillers.AppendableSpillerList spillers =
Spillers.appendable();
final TreeMemoryTarget legacy =
TreeMemoryConsumers.factory(TaskContext.get().taskMemoryManager())
- .newLegacyConsumer("FOO", spillers, Collections.emptyMap());
+ .legacyRoot()
+ .newChild(
+ "FOO", TreeMemoryTarget.CAPACITY_UNLIMITED, spillers,
Collections.emptyMap());
final AtomicInteger numSpills = new AtomicInteger(0);
final AtomicLong numSpilledBytes = new AtomicLong(0L);
spillers.append(
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]