This is an automated email from the ASF dual-hosted git repository.
jackietien pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iotdb.git
The following commit(s) were added to refs/heads/master by this push:
new b938a67c7d Revert "[IOTDB-5586] Reduce the scope of lock in MemoryPool"
b938a67c7d is described below
commit b938a67c7dc100b18adcd97cf862b9f656445016
Author: Liao Lanyu <[email protected]>
AuthorDate: Wed Mar 22 20:19:02 2023 +0800
Revert "[IOTDB-5586] Reduce the scope of lock in MemoryPool"
---
.../iotdb/db/mpp/execution/memory/MemoryPool.java | 237 ++++++++++-----------
.../iotdb/db/mpp/execution/exchange/Utils.java | 4 +-
.../db/mpp/execution/memory/MemoryPoolTest.java | 26 ++-
3 files changed, 132 insertions(+), 135 deletions(-)
diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/execution/memory/MemoryPool.java b/server/src/main/java/org/apache/iotdb/db/mpp/execution/memory/MemoryPool.java
index 7d4cda4b34..3c9dbb10e3 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/execution/memory/MemoryPool.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/execution/memory/MemoryPool.java
@@ -33,13 +33,12 @@ import javax.annotation.Nullable;
import java.util.ArrayList;
import java.util.Collections;
+import java.util.HashMap;
import java.util.Iterator;
+import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Queue;
-import java.util.concurrent.ConcurrentHashMap;
-import java.util.concurrent.ConcurrentLinkedQueue;
-import java.util.concurrent.atomic.AtomicLong;
/** A thread-safe memory pool. */
public class MemoryPool {
@@ -58,8 +57,6 @@ public class MemoryPool {
*/
private final long maxBytesCanReserve;
- private boolean isMarked = false;
-
private MemoryReservationFuture(
String queryId,
String fragmentInstanceId,
@@ -80,14 +77,6 @@ public class MemoryPool {
return queryId;
}
- public boolean isMarked() {
- return isMarked;
- }
-
- public void setMarked(boolean marked) {
- isMarked = marked;
- }
-
public String getFragmentInstanceId() {
return fragmentInstanceId;
}
@@ -124,13 +113,12 @@ public class MemoryPool {
private final long maxBytes;
private final long maxBytesPerFragmentInstance;
- private final AtomicLong remainingBytes;
+ private long reservedBytes = 0L;
/** queryId -> fragmentInstanceId -> planNodeId -> bytesReserved */
private final Map<String, Map<String, Map<String, Long>>>
queryMemoryReservations =
- new ConcurrentHashMap<>();
+ new HashMap<>();
- private final Queue<MemoryReservationFuture<Void>> memoryReservationFutures =
- new ConcurrentLinkedQueue<>();
+ private final Queue<MemoryReservationFuture<Void>> memoryReservationFutures
= new LinkedList<>();
public MemoryPool(String id, long maxBytes, long
maxBytesPerFragmentInstance) {
this.id = Validate.notNull(id);
@@ -142,13 +130,16 @@ public class MemoryPool {
maxBytesPerFragmentInstance,
maxBytes);
this.maxBytesPerFragmentInstance = maxBytesPerFragmentInstance;
- this.remainingBytes = new AtomicLong(maxBytes);
}
public String getId() {
return id;
}
+ public long getMaxBytes() {
+ return maxBytes;
+ }
+
/**
* Reserve memory with bytesToReserve.
*
@@ -178,23 +169,37 @@ public class MemoryPool {
}
ListenableFuture<Void> result;
- if (tryReserve(queryId, fragmentInstanceId, planNodeId, bytesToReserve,
maxBytesCanReserve)) {
- result = Futures.immediateFuture(null);
- return new Pair<>(result, Boolean.TRUE);
- } else {
- LOGGER.debug(
- "Blocked reserve request: {} bytes memory for planNodeId{}",
bytesToReserve, planNodeId);
- rollbackReserve(queryId, fragmentInstanceId, planNodeId, bytesToReserve);
- result =
- MemoryReservationFuture.create(
- queryId, fragmentInstanceId, planNodeId, bytesToReserve,
maxBytesCanReserve);
- memoryReservationFutures.add((MemoryReservationFuture<Void>) result);
- return new Pair<>(result, Boolean.FALSE);
+ synchronized (this) {
+ if (maxBytes - reservedBytes < bytesToReserve
+ || maxBytesCanReserve
+ - queryMemoryReservations
+ .getOrDefault(queryId, Collections.emptyMap())
+ .getOrDefault(fragmentInstanceId, Collections.emptyMap())
+ .getOrDefault(planNodeId, 0L)
+ < bytesToReserve) {
+ LOGGER.debug(
+ "Blocked reserve request: {} bytes memory for planNodeId{}",
+ bytesToReserve,
+ planNodeId);
+ result =
+ MemoryReservationFuture.create(
+ queryId, fragmentInstanceId, planNodeId, bytesToReserve,
maxBytesCanReserve);
+ memoryReservationFutures.add((MemoryReservationFuture<Void>) result);
+ return new Pair<>(result, Boolean.FALSE);
+ } else {
+ reservedBytes += bytesToReserve;
+ queryMemoryReservations
+ .computeIfAbsent(queryId, x -> new HashMap<>())
+ .computeIfAbsent(fragmentInstanceId, x -> new HashMap<>())
+ .merge(planNodeId, bytesToReserve, Long::sum);
+ result = Futures.immediateFuture(null);
+ return new Pair<>(result, Boolean.TRUE);
+ }
}
}
@TestOnly
- public boolean tryReserveForTest(
+ public boolean tryReserve(
String queryId,
String fragmentInstanceId,
String planNodeId,
@@ -208,12 +213,32 @@ public class MemoryPool {
"bytes should be greater than zero while less than or equal to max
bytes per fragment instance: %d",
bytesToReserve);
- if (tryReserve(queryId, fragmentInstanceId, planNodeId, bytesToReserve,
maxBytesCanReserve)) {
- return true;
- } else {
- rollbackReserve(queryId, fragmentInstanceId, planNodeId, bytesToReserve);
+ if (maxBytes - reservedBytes < bytesToReserve
+ || maxBytesCanReserve
+ - queryMemoryReservations
+ .getOrDefault(queryId, Collections.emptyMap())
+ .getOrDefault(fragmentInstanceId, Collections.emptyMap())
+ .getOrDefault(planNodeId, 0L)
+ < bytesToReserve) {
return false;
}
+ synchronized (this) {
+ if (maxBytes - reservedBytes < bytesToReserve
+ || maxBytesCanReserve
+ - queryMemoryReservations
+ .getOrDefault(queryId, Collections.emptyMap())
+ .getOrDefault(fragmentInstanceId, Collections.emptyMap())
+ .getOrDefault(planNodeId, 0L)
+ < bytesToReserve) {
+ return false;
+ }
+ reservedBytes += bytesToReserve;
+ queryMemoryReservations
+ .computeIfAbsent(queryId, x -> new HashMap<>())
+ .computeIfAbsent(fragmentInstanceId, x -> new HashMap<>())
+ .merge(planNodeId, bytesToReserve, Long::sum);
+ }
+ return true;
}
/**
@@ -257,46 +282,58 @@ public class MemoryPool {
}
public void free(String queryId, String fragmentInstanceId, String
planNodeId, long bytes) {
- Validate.notNull(queryId);
- Validate.isTrue(bytes > 0L);
-
- Long queryReservedBytes =
- queryMemoryReservations
- .getOrDefault(queryId, Collections.emptyMap())
- .getOrDefault(fragmentInstanceId, Collections.emptyMap())
- .computeIfPresent(
- planNodeId,
- (k, reservedMemory) -> {
- if (reservedMemory < bytes) {
- throw new IllegalArgumentException("Free more memory than
has been reserved.");
- }
- return reservedMemory - bytes;
- });
- remainingBytes.addAndGet(bytes);
-
List<MemoryReservationFuture<Void>> futureList = new ArrayList<>();
- if (memoryReservationFutures.isEmpty()) {
- return;
- }
- Iterator<MemoryReservationFuture<Void>> iterator =
memoryReservationFutures.iterator();
- while (iterator.hasNext()) {
- MemoryReservationFuture<Void> future = iterator.next();
- synchronized (future) {
- if (future.isCancelled() || future.isDone() || future.isMarked()) {
+ synchronized (this) {
+ Validate.notNull(queryId);
+ Validate.isTrue(bytes > 0L);
+
+ Long queryReservedBytes =
+ queryMemoryReservations
+ .getOrDefault(queryId, Collections.emptyMap())
+ .getOrDefault(fragmentInstanceId, Collections.emptyMap())
+ .get(planNodeId);
+ Validate.notNull(queryReservedBytes);
+ Validate.isTrue(bytes <= queryReservedBytes);
+
+ queryReservedBytes -= bytes;
+ queryMemoryReservations
+ .get(queryId)
+ .get(fragmentInstanceId)
+ .put(planNodeId, queryReservedBytes);
+
+ reservedBytes -= bytes;
+
+ if (memoryReservationFutures.isEmpty()) {
+ return;
+ }
+ Iterator<MemoryReservationFuture<Void>> iterator =
memoryReservationFutures.iterator();
+ while (iterator.hasNext()) {
+ MemoryReservationFuture<Void> future = iterator.next();
+ if (future.isCancelled() || future.isDone()) {
continue;
}
long bytesToReserve = future.getBytesToReserve();
String curQueryId = future.getQueryId();
String curFragmentInstanceId = future.getFragmentInstanceId();
String curPlanNodeId = future.getPlanNodeId();
- long maxBytesCanReserve = future.getMaxBytesCanReserve();
- if (tryReserve(
- curQueryId, curFragmentInstanceId, curPlanNodeId, bytesToReserve,
maxBytesCanReserve)) {
+ // check total reserved bytes in memory pool
+ if (maxBytes - reservedBytes < bytesToReserve) {
+ continue;
+ }
+ // check total reserved bytes of one Sink/Source handle
+ if (future.getMaxBytesCanReserve()
+ - queryMemoryReservations
+ .getOrDefault(curQueryId, Collections.emptyMap())
+ .getOrDefault(curFragmentInstanceId,
Collections.emptyMap())
+ .getOrDefault(curPlanNodeId, 0L)
+ >= bytesToReserve) {
+ reservedBytes += bytesToReserve;
+ queryMemoryReservations
+ .computeIfAbsent(curQueryId, x -> new HashMap<>())
+ .computeIfAbsent(curFragmentInstanceId, x -> new HashMap<>())
+ .merge(curPlanNodeId, bytesToReserve, Long::sum);
futureList.add(future);
- future.setMarked(true);
iterator.remove();
- } else {
- rollbackReserve(curQueryId, curFragmentInstanceId, curPlanNodeId,
bytesToReserve);
}
}
}
@@ -305,10 +342,10 @@ public class MemoryPool {
// If we put this block inside the MemoryPool's lock, we will get deadlock
case like the
// following:
// Assuming that thread-A: LocalSourceHandle.receive() ->
A-SharedTsBlockQueue.remove() ->
- // MemoryPool.free() (hold future's lock) -> future.set(null) -> try to get
+ // MemoryPool.free() (hold MemoryPool's lock) -> future.set(null) -> try
to get
// B-SharedTsBlockQueue's lock
// thread-B: LocalSourceHandle.receive() -> B-SharedTsBlockQueue.remove()
(hold
- // B-SharedTsBlockQueue's lock) -> try to get future's lock
+ // B-SharedTsBlockQueue's lock) -> try to get MemoryPool's lock
for (MemoryReservationFuture<Void> future : futureList) {
try {
future.set(null);
@@ -331,64 +368,26 @@ public class MemoryPool {
}
public long getReservedBytes() {
- return maxBytes - remainingBytes.get();
+ return reservedBytes;
}
- public void clearMemoryReservationMap(
+ public synchronized void clearMemoryReservationMap(
String queryId, String fragmentInstanceId, String planNodeId) {
- Map<String, Map<String, Long>> instanceBytesReserved =
queryMemoryReservations.get(queryId);
- Map<String, Long> planNodeIdToBytesReserved =
- queryMemoryReservations
- .getOrDefault(queryId, Collections.emptyMap())
- .get(fragmentInstanceId);
- if (instanceBytesReserved == null || planNodeIdToBytesReserved == null) {
+ if (queryMemoryReservations.get(queryId) == null
+ || queryMemoryReservations.get(queryId).get(fragmentInstanceId) ==
null) {
return;
}
-
+ Map<String, Long> planNodeIdToBytesReserved =
+ queryMemoryReservations.get(queryId).get(fragmentInstanceId);
if (planNodeIdToBytesReserved.get(planNodeId) == null
- || planNodeIdToBytesReserved.get(planNodeId) == 0) {
+ || planNodeIdToBytesReserved.get(planNodeId) <= 0) {
planNodeIdToBytesReserved.remove(planNodeId);
- instanceBytesReserved.computeIfPresent(
- fragmentInstanceId,
- (k, kPlanNodeBytesReserved) -> {
- if (kPlanNodeBytesReserved.isEmpty()) {
- return null;
- }
- return kPlanNodeBytesReserved;
- });
- queryMemoryReservations.computeIfPresent(
- queryId,
- (k, kInstanceBytesReserved) -> {
- if (kInstanceBytesReserved.isEmpty()) {
- return null;
- }
- return kInstanceBytesReserved;
- });
+ if (planNodeIdToBytesReserved.isEmpty()) {
+ queryMemoryReservations.get(queryId).remove(fragmentInstanceId);
+ }
+ if (queryMemoryReservations.get(queryId).isEmpty()) {
+ queryMemoryReservations.remove(queryId);
+ }
}
}
-
- private boolean tryReserve(
- String queryId,
- String fragmentInstanceId,
- String planNodeId,
- long bytesToReserve,
- long maxBytesCanReserve) {
- long tryRemainingBytes = remainingBytes.addAndGet(-bytesToReserve);
- long queryRemainingBytes =
- maxBytesCanReserve
- - queryMemoryReservations
- .computeIfAbsent(queryId, x -> new ConcurrentHashMap<>())
- .computeIfAbsent(fragmentInstanceId, x -> new
ConcurrentHashMap<>())
- .merge(planNodeId, bytesToReserve, Long::sum);
- return tryRemainingBytes >= 0 && queryRemainingBytes >= 0;
- }
-
- private void rollbackReserve(
- String queryId, String fragmentInstanceId, String planNodeId, long
bytesToReserve) {
- queryMemoryReservations
- .computeIfAbsent(queryId, x -> new ConcurrentHashMap<>())
- .computeIfAbsent(fragmentInstanceId, x -> new ConcurrentHashMap<>())
- .merge(planNodeId, -bytesToReserve, Long::sum);
- remainingBytes.addAndGet(bytesToReserve);
- }
}
diff --git a/server/src/test/java/org/apache/iotdb/db/mpp/execution/exchange/Utils.java b/server/src/test/java/org/apache/iotdb/db/mpp/execution/exchange/Utils.java
index c843a77621..0ea87d58fd 100644
--- a/server/src/test/java/org/apache/iotdb/db/mpp/execution/exchange/Utils.java
+++ b/server/src/test/java/org/apache/iotdb/db/mpp/execution/exchange/Utils.java
@@ -100,7 +100,7 @@ public class Utils {
Mockito.eq(planNodeId),
Mockito.anyLong());
Mockito.when(
- mockMemoryPool.tryReserveForTest(
+ mockMemoryPool.tryReserve(
Mockito.eq(queryId),
Mockito.eq(fragmentInstanceId),
Mockito.eq(planNodeId),
@@ -130,7 +130,7 @@ public class Utils {
Mockito.anyLong()))
.thenReturn(new Pair<>(immediateFuture(null), true));
Mockito.when(
- mockMemoryPool.tryReserveForTest(
+ mockMemoryPool.tryReserve(
Mockito.anyString(),
Mockito.anyString(),
Mockito.anyString(),
diff --git a/server/src/test/java/org/apache/iotdb/db/mpp/execution/memory/MemoryPoolTest.java b/server/src/test/java/org/apache/iotdb/db/mpp/execution/memory/MemoryPoolTest.java
index d82f2760d8..94a7c81999 100644
--- a/server/src/test/java/org/apache/iotdb/db/mpp/execution/memory/MemoryPoolTest.java
+++ b/server/src/test/java/org/apache/iotdb/db/mpp/execution/memory/MemoryPoolTest.java
@@ -42,7 +42,7 @@ public class MemoryPoolTest {
public void testTryReserve() {
Assert.assertTrue(
- pool.tryReserveForTest(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID,
256L, Long.MAX_VALUE));
+ pool.tryReserve(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 256L,
Long.MAX_VALUE));
Assert.assertEquals(256L, pool.getQueryMemoryReservedBytes(QUERY_ID));
Assert.assertEquals(256L, pool.getReservedBytes());
}
@@ -51,7 +51,7 @@ public class MemoryPoolTest {
public void testTryReserveZero() {
try {
- pool.tryReserveForTest(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 0L,
Long.MAX_VALUE);
+ pool.tryReserve(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 0L,
Long.MAX_VALUE);
Assert.fail("Expect IllegalArgumentException");
} catch (IllegalArgumentException ignore) {
}
@@ -61,7 +61,7 @@ public class MemoryPoolTest {
public void testTryReserveNegative() {
try {
- pool.tryReserveForTest(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID,
-1L, Long.MAX_VALUE);
+ pool.tryReserve(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, -1L,
Long.MAX_VALUE);
Assert.fail("Expect IllegalArgumentException");
} catch (IllegalArgumentException ignore) {
}
@@ -71,7 +71,7 @@ public class MemoryPoolTest {
public void testTryReserveAll() {
Assert.assertTrue(
- pool.tryReserveForTest(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID,
512L, Long.MAX_VALUE));
+ pool.tryReserve(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 512L,
Long.MAX_VALUE));
Assert.assertEquals(512L, pool.getQueryMemoryReservedBytes(QUERY_ID));
Assert.assertEquals(512L, pool.getReservedBytes());
}
@@ -79,12 +79,10 @@ public class MemoryPoolTest {
@Test
public void testOverTryReserve() {
- Assert.assertTrue(
- pool.tryReserveForTest(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID,
256L, 512L));
+ Assert.assertTrue(pool.tryReserve(QUERY_ID, FRAGMENT_INSTANCE_ID,
PLAN_NODE_ID, 256L, 512L));
Assert.assertEquals(256L, pool.getQueryMemoryReservedBytes(QUERY_ID));
Assert.assertEquals(256L, pool.getReservedBytes());
- Assert.assertFalse(
- pool.tryReserveForTest(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID,
512L, 511L));
+ Assert.assertFalse(pool.tryReserve(QUERY_ID, FRAGMENT_INSTANCE_ID,
PLAN_NODE_ID, 512L, 511L));
Assert.assertEquals(256L, pool.getQueryMemoryReservedBytes(QUERY_ID));
Assert.assertEquals(256L, pool.getReservedBytes());
}
@@ -208,7 +206,7 @@ public class MemoryPoolTest {
public void testFree() {
Assert.assertTrue(
- pool.tryReserveForTest(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID,
512L, Long.MAX_VALUE));
+ pool.tryReserve(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 512L,
Long.MAX_VALUE));
Assert.assertEquals(512L, pool.getQueryMemoryReservedBytes(QUERY_ID));
Assert.assertEquals(512L, pool.getReservedBytes());
@@ -221,7 +219,7 @@ public class MemoryPoolTest {
public void testFreeAll() {
Assert.assertTrue(
- pool.tryReserveForTest(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID,
512L, Long.MAX_VALUE));
+ pool.tryReserve(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 512L,
Long.MAX_VALUE));
Assert.assertEquals(512L, pool.getQueryMemoryReservedBytes(QUERY_ID));
Assert.assertEquals(512L, pool.getReservedBytes());
@@ -234,7 +232,7 @@ public class MemoryPoolTest {
public void testFreeZero() {
Assert.assertTrue(
- pool.tryReserveForTest(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID,
512L, Long.MAX_VALUE));
+ pool.tryReserve(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 512L,
Long.MAX_VALUE));
Assert.assertEquals(512L, pool.getQueryMemoryReservedBytes(QUERY_ID));
Assert.assertEquals(512L, pool.getReservedBytes());
@@ -249,7 +247,7 @@ public class MemoryPoolTest {
public void testFreeNegative() {
Assert.assertTrue(
- pool.tryReserveForTest(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID,
512L, Long.MAX_VALUE));
+ pool.tryReserve(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 512L,
Long.MAX_VALUE));
Assert.assertEquals(512L, pool.getQueryMemoryReservedBytes(QUERY_ID));
Assert.assertEquals(512L, pool.getReservedBytes());
@@ -264,7 +262,7 @@ public class MemoryPoolTest {
public void testOverFree() {
Assert.assertTrue(
- pool.tryReserveForTest(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID,
512L, Long.MAX_VALUE));
+ pool.tryReserve(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 512L,
Long.MAX_VALUE));
Assert.assertEquals(512L, pool.getQueryMemoryReservedBytes(QUERY_ID));
Assert.assertEquals(512L, pool.getReservedBytes());
@@ -280,7 +278,7 @@ public class MemoryPoolTest {
// Run out of memory.
Assert.assertTrue(
- pool.tryReserveForTest(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID,
512L, Long.MAX_VALUE));
+ pool.tryReserve(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 512L,
Long.MAX_VALUE));
ListenableFuture<Void> f =
pool.reserve(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 256L,
512L).left;