This is an automated email from the ASF dual-hosted git repository.

jackietien pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/iotdb.git


The following commit(s) were added to refs/heads/master by this push:
     new b938a67c7d Revert "[IOTDB-5586] Reduce the scope of lock in MemoryPool"
b938a67c7d is described below

commit b938a67c7dc100b18adcd97cf862b9f656445016
Author: Liao Lanyu <[email protected]>
AuthorDate: Wed Mar 22 20:19:02 2023 +0800

    Revert "[IOTDB-5586] Reduce the scope of lock in MemoryPool"
---
 .../iotdb/db/mpp/execution/memory/MemoryPool.java  | 237 ++++++++++-----------
 .../iotdb/db/mpp/execution/exchange/Utils.java     |   4 +-
 .../db/mpp/execution/memory/MemoryPoolTest.java    |  26 ++-
 3 files changed, 132 insertions(+), 135 deletions(-)

diff --git a/server/src/main/java/org/apache/iotdb/db/mpp/execution/memory/MemoryPool.java b/server/src/main/java/org/apache/iotdb/db/mpp/execution/memory/MemoryPool.java
index 7d4cda4b34..3c9dbb10e3 100644
--- a/server/src/main/java/org/apache/iotdb/db/mpp/execution/memory/MemoryPool.java
+++ b/server/src/main/java/org/apache/iotdb/db/mpp/execution/memory/MemoryPool.java
@@ -33,13 +33,12 @@ import javax.annotation.Nullable;
 
 import java.util.ArrayList;
 import java.util.Collections;
+import java.util.HashMap;
 import java.util.Iterator;
+import java.util.LinkedList;
 import java.util.List;
 import java.util.Map;
 import java.util.Queue;
-import java.util.concurrent.ConcurrentHashMap;
-import java.util.concurrent.ConcurrentLinkedQueue;
-import java.util.concurrent.atomic.AtomicLong;
 
 /** A thread-safe memory pool. */
 public class MemoryPool {
@@ -58,8 +57,6 @@ public class MemoryPool {
      */
     private final long maxBytesCanReserve;
 
-    private boolean isMarked = false;
-
     private MemoryReservationFuture(
         String queryId,
         String fragmentInstanceId,
@@ -80,14 +77,6 @@ public class MemoryPool {
       return queryId;
     }
 
-    public boolean isMarked() {
-      return isMarked;
-    }
-
-    public void setMarked(boolean marked) {
-      isMarked = marked;
-    }
-
     public String getFragmentInstanceId() {
       return fragmentInstanceId;
     }
@@ -124,13 +113,12 @@ public class MemoryPool {
   private final long maxBytes;
   private final long maxBytesPerFragmentInstance;
 
-  private final AtomicLong remainingBytes;
+  private long reservedBytes = 0L;
   /** queryId -> fragmentInstanceId -> planNodeId -> bytesReserved */
   private final Map<String, Map<String, Map<String, Long>>> 
queryMemoryReservations =
-      new ConcurrentHashMap<>();
+      new HashMap<>();
 
-  private final Queue<MemoryReservationFuture<Void>> memoryReservationFutures =
-      new ConcurrentLinkedQueue<>();
+  private final Queue<MemoryReservationFuture<Void>> memoryReservationFutures 
= new LinkedList<>();
 
   public MemoryPool(String id, long maxBytes, long 
maxBytesPerFragmentInstance) {
     this.id = Validate.notNull(id);
@@ -142,13 +130,16 @@ public class MemoryPool {
         maxBytesPerFragmentInstance,
         maxBytes);
     this.maxBytesPerFragmentInstance = maxBytesPerFragmentInstance;
-    this.remainingBytes = new AtomicLong(maxBytes);
   }
 
   public String getId() {
     return id;
   }
 
+  public long getMaxBytes() {
+    return maxBytes;
+  }
+
   /**
    * Reserve memory with bytesToReserve.
    *
@@ -178,23 +169,37 @@ public class MemoryPool {
     }
 
     ListenableFuture<Void> result;
-    if (tryReserve(queryId, fragmentInstanceId, planNodeId, bytesToReserve, 
maxBytesCanReserve)) {
-      result = Futures.immediateFuture(null);
-      return new Pair<>(result, Boolean.TRUE);
-    } else {
-      LOGGER.debug(
-          "Blocked reserve request: {} bytes memory for planNodeId{}", 
bytesToReserve, planNodeId);
-      rollbackReserve(queryId, fragmentInstanceId, planNodeId, bytesToReserve);
-      result =
-          MemoryReservationFuture.create(
-              queryId, fragmentInstanceId, planNodeId, bytesToReserve, 
maxBytesCanReserve);
-      memoryReservationFutures.add((MemoryReservationFuture<Void>) result);
-      return new Pair<>(result, Boolean.FALSE);
+    synchronized (this) {
+      if (maxBytes - reservedBytes < bytesToReserve
+          || maxBytesCanReserve
+                  - queryMemoryReservations
+                      .getOrDefault(queryId, Collections.emptyMap())
+                      .getOrDefault(fragmentInstanceId, Collections.emptyMap())
+                      .getOrDefault(planNodeId, 0L)
+              < bytesToReserve) {
+        LOGGER.debug(
+            "Blocked reserve request: {} bytes memory for planNodeId{}",
+            bytesToReserve,
+            planNodeId);
+        result =
+            MemoryReservationFuture.create(
+                queryId, fragmentInstanceId, planNodeId, bytesToReserve, 
maxBytesCanReserve);
+        memoryReservationFutures.add((MemoryReservationFuture<Void>) result);
+        return new Pair<>(result, Boolean.FALSE);
+      } else {
+        reservedBytes += bytesToReserve;
+        queryMemoryReservations
+            .computeIfAbsent(queryId, x -> new HashMap<>())
+            .computeIfAbsent(fragmentInstanceId, x -> new HashMap<>())
+            .merge(planNodeId, bytesToReserve, Long::sum);
+        result = Futures.immediateFuture(null);
+        return new Pair<>(result, Boolean.TRUE);
+      }
     }
   }
 
   @TestOnly
-  public boolean tryReserveForTest(
+  public boolean tryReserve(
       String queryId,
       String fragmentInstanceId,
       String planNodeId,
@@ -208,12 +213,32 @@ public class MemoryPool {
         "bytes should be greater than zero while less than or equal to max 
bytes per fragment instance: %d",
         bytesToReserve);
 
-    if (tryReserve(queryId, fragmentInstanceId, planNodeId, bytesToReserve, 
maxBytesCanReserve)) {
-      return true;
-    } else {
-      rollbackReserve(queryId, fragmentInstanceId, planNodeId, bytesToReserve);
+    if (maxBytes - reservedBytes < bytesToReserve
+        || maxBytesCanReserve
+                - queryMemoryReservations
+                    .getOrDefault(queryId, Collections.emptyMap())
+                    .getOrDefault(fragmentInstanceId, Collections.emptyMap())
+                    .getOrDefault(planNodeId, 0L)
+            < bytesToReserve) {
       return false;
     }
+    synchronized (this) {
+      if (maxBytes - reservedBytes < bytesToReserve
+          || maxBytesCanReserve
+                  - queryMemoryReservations
+                      .getOrDefault(queryId, Collections.emptyMap())
+                      .getOrDefault(fragmentInstanceId, Collections.emptyMap())
+                      .getOrDefault(planNodeId, 0L)
+              < bytesToReserve) {
+        return false;
+      }
+      reservedBytes += bytesToReserve;
+      queryMemoryReservations
+          .computeIfAbsent(queryId, x -> new HashMap<>())
+          .computeIfAbsent(fragmentInstanceId, x -> new HashMap<>())
+          .merge(planNodeId, bytesToReserve, Long::sum);
+    }
+    return true;
   }
 
   /**
@@ -257,46 +282,58 @@ public class MemoryPool {
   }
 
   public void free(String queryId, String fragmentInstanceId, String 
planNodeId, long bytes) {
-    Validate.notNull(queryId);
-    Validate.isTrue(bytes > 0L);
-
-    Long queryReservedBytes =
-        queryMemoryReservations
-            .getOrDefault(queryId, Collections.emptyMap())
-            .getOrDefault(fragmentInstanceId, Collections.emptyMap())
-            .computeIfPresent(
-                planNodeId,
-                (k, reservedMemory) -> {
-                  if (reservedMemory < bytes) {
-                    throw new IllegalArgumentException("Free more memory than 
has been reserved.");
-                  }
-                  return reservedMemory - bytes;
-                });
-    remainingBytes.addAndGet(bytes);
-
     List<MemoryReservationFuture<Void>> futureList = new ArrayList<>();
-    if (memoryReservationFutures.isEmpty()) {
-      return;
-    }
-    Iterator<MemoryReservationFuture<Void>> iterator = 
memoryReservationFutures.iterator();
-    while (iterator.hasNext()) {
-      MemoryReservationFuture<Void> future = iterator.next();
-      synchronized (future) {
-        if (future.isCancelled() || future.isDone() || future.isMarked()) {
+    synchronized (this) {
+      Validate.notNull(queryId);
+      Validate.isTrue(bytes > 0L);
+
+      Long queryReservedBytes =
+          queryMemoryReservations
+              .getOrDefault(queryId, Collections.emptyMap())
+              .getOrDefault(fragmentInstanceId, Collections.emptyMap())
+              .get(planNodeId);
+      Validate.notNull(queryReservedBytes);
+      Validate.isTrue(bytes <= queryReservedBytes);
+
+      queryReservedBytes -= bytes;
+      queryMemoryReservations
+          .get(queryId)
+          .get(fragmentInstanceId)
+          .put(planNodeId, queryReservedBytes);
+
+      reservedBytes -= bytes;
+
+      if (memoryReservationFutures.isEmpty()) {
+        return;
+      }
+      Iterator<MemoryReservationFuture<Void>> iterator = 
memoryReservationFutures.iterator();
+      while (iterator.hasNext()) {
+        MemoryReservationFuture<Void> future = iterator.next();
+        if (future.isCancelled() || future.isDone()) {
           continue;
         }
         long bytesToReserve = future.getBytesToReserve();
         String curQueryId = future.getQueryId();
         String curFragmentInstanceId = future.getFragmentInstanceId();
         String curPlanNodeId = future.getPlanNodeId();
-        long maxBytesCanReserve = future.getMaxBytesCanReserve();
-        if (tryReserve(
-            curQueryId, curFragmentInstanceId, curPlanNodeId, bytesToReserve, 
maxBytesCanReserve)) {
+        // check total reserved bytes in memory pool
+        if (maxBytes - reservedBytes < bytesToReserve) {
+          continue;
+        }
+        // check total reserved bytes of one Sink/Source handle
+        if (future.getMaxBytesCanReserve()
+                - queryMemoryReservations
+                    .getOrDefault(curQueryId, Collections.emptyMap())
+                    .getOrDefault(curFragmentInstanceId, 
Collections.emptyMap())
+                    .getOrDefault(curPlanNodeId, 0L)
+            >= bytesToReserve) {
+          reservedBytes += bytesToReserve;
+          queryMemoryReservations
+              .computeIfAbsent(curQueryId, x -> new HashMap<>())
+              .computeIfAbsent(curFragmentInstanceId, x -> new HashMap<>())
+              .merge(curPlanNodeId, bytesToReserve, Long::sum);
           futureList.add(future);
-          future.setMarked(true);
           iterator.remove();
-        } else {
-          rollbackReserve(curQueryId, curFragmentInstanceId, curPlanNodeId, 
bytesToReserve);
         }
       }
     }
@@ -305,10 +342,10 @@ public class MemoryPool {
     // If we put this block inside the MemoryPool's lock, we will get deadlock 
case like the
     // following:
     // Assuming that thread-A: LocalSourceHandle.receive() -> 
A-SharedTsBlockQueue.remove() ->
-    // MemoryPool.free() (hold future's lock) -> future.set(null) -> try to get
+    // MemoryPool.free() (hold MemoryPool's lock) -> future.set(null) -> try 
to get
     // B-SharedTsBlockQueue's lock
     // thread-B: LocalSourceHandle.receive() -> B-SharedTsBlockQueue.remove() 
(hold
-    // B-SharedTsBlockQueue's lock) -> try to get future's lock
+    // B-SharedTsBlockQueue's lock) -> try to get MemoryPool's lock
     for (MemoryReservationFuture<Void> future : futureList) {
       try {
         future.set(null);
@@ -331,64 +368,26 @@ public class MemoryPool {
   }
 
   public long getReservedBytes() {
-    return maxBytes - remainingBytes.get();
+    return reservedBytes;
   }
 
-  public void clearMemoryReservationMap(
+  public synchronized void clearMemoryReservationMap(
       String queryId, String fragmentInstanceId, String planNodeId) {
-    Map<String, Map<String, Long>> instanceBytesReserved = 
queryMemoryReservations.get(queryId);
-    Map<String, Long> planNodeIdToBytesReserved =
-        queryMemoryReservations
-            .getOrDefault(queryId, Collections.emptyMap())
-            .get(fragmentInstanceId);
-    if (instanceBytesReserved == null || planNodeIdToBytesReserved == null) {
+    if (queryMemoryReservations.get(queryId) == null
+        || queryMemoryReservations.get(queryId).get(fragmentInstanceId) == 
null) {
       return;
     }
-
+    Map<String, Long> planNodeIdToBytesReserved =
+        queryMemoryReservations.get(queryId).get(fragmentInstanceId);
     if (planNodeIdToBytesReserved.get(planNodeId) == null
-        || planNodeIdToBytesReserved.get(planNodeId) == 0) {
+        || planNodeIdToBytesReserved.get(planNodeId) <= 0) {
       planNodeIdToBytesReserved.remove(planNodeId);
-      instanceBytesReserved.computeIfPresent(
-          fragmentInstanceId,
-          (k, kPlanNodeBytesReserved) -> {
-            if (kPlanNodeBytesReserved.isEmpty()) {
-              return null;
-            }
-            return kPlanNodeBytesReserved;
-          });
-      queryMemoryReservations.computeIfPresent(
-          queryId,
-          (k, kInstanceBytesReserved) -> {
-            if (kInstanceBytesReserved.isEmpty()) {
-              return null;
-            }
-            return kInstanceBytesReserved;
-          });
+      if (planNodeIdToBytesReserved.isEmpty()) {
+        queryMemoryReservations.get(queryId).remove(fragmentInstanceId);
+      }
+      if (queryMemoryReservations.get(queryId).isEmpty()) {
+        queryMemoryReservations.remove(queryId);
+      }
     }
   }
-
-  private boolean tryReserve(
-      String queryId,
-      String fragmentInstanceId,
-      String planNodeId,
-      long bytesToReserve,
-      long maxBytesCanReserve) {
-    long tryRemainingBytes = remainingBytes.addAndGet(-bytesToReserve);
-    long queryRemainingBytes =
-        maxBytesCanReserve
-            - queryMemoryReservations
-                .computeIfAbsent(queryId, x -> new ConcurrentHashMap<>())
-                .computeIfAbsent(fragmentInstanceId, x -> new 
ConcurrentHashMap<>())
-                .merge(planNodeId, bytesToReserve, Long::sum);
-    return tryRemainingBytes >= 0 && queryRemainingBytes >= 0;
-  }
-
-  private void rollbackReserve(
-      String queryId, String fragmentInstanceId, String planNodeId, long 
bytesToReserve) {
-    queryMemoryReservations
-        .computeIfAbsent(queryId, x -> new ConcurrentHashMap<>())
-        .computeIfAbsent(fragmentInstanceId, x -> new ConcurrentHashMap<>())
-        .merge(planNodeId, -bytesToReserve, Long::sum);
-    remainingBytes.addAndGet(bytesToReserve);
-  }
 }
diff --git a/server/src/test/java/org/apache/iotdb/db/mpp/execution/exchange/Utils.java b/server/src/test/java/org/apache/iotdb/db/mpp/execution/exchange/Utils.java
index c843a77621..0ea87d58fd 100644
--- a/server/src/test/java/org/apache/iotdb/db/mpp/execution/exchange/Utils.java
+++ b/server/src/test/java/org/apache/iotdb/db/mpp/execution/exchange/Utils.java
@@ -100,7 +100,7 @@ public class Utils {
             Mockito.eq(planNodeId),
             Mockito.anyLong());
     Mockito.when(
-            mockMemoryPool.tryReserveForTest(
+            mockMemoryPool.tryReserve(
                 Mockito.eq(queryId),
                 Mockito.eq(fragmentInstanceId),
                 Mockito.eq(planNodeId),
@@ -130,7 +130,7 @@ public class Utils {
                 Mockito.anyLong()))
         .thenReturn(new Pair<>(immediateFuture(null), true));
     Mockito.when(
-            mockMemoryPool.tryReserveForTest(
+            mockMemoryPool.tryReserve(
                 Mockito.anyString(),
                 Mockito.anyString(),
                 Mockito.anyString(),
diff --git a/server/src/test/java/org/apache/iotdb/db/mpp/execution/memory/MemoryPoolTest.java b/server/src/test/java/org/apache/iotdb/db/mpp/execution/memory/MemoryPoolTest.java
index d82f2760d8..94a7c81999 100644
--- a/server/src/test/java/org/apache/iotdb/db/mpp/execution/memory/MemoryPoolTest.java
+++ b/server/src/test/java/org/apache/iotdb/db/mpp/execution/memory/MemoryPoolTest.java
@@ -42,7 +42,7 @@ public class MemoryPoolTest {
   public void testTryReserve() {
 
     Assert.assertTrue(
-        pool.tryReserveForTest(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 
256L, Long.MAX_VALUE));
+        pool.tryReserve(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 256L, 
Long.MAX_VALUE));
     Assert.assertEquals(256L, pool.getQueryMemoryReservedBytes(QUERY_ID));
     Assert.assertEquals(256L, pool.getReservedBytes());
   }
@@ -51,7 +51,7 @@ public class MemoryPoolTest {
   public void testTryReserveZero() {
 
     try {
-      pool.tryReserveForTest(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 0L, 
Long.MAX_VALUE);
+      pool.tryReserve(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 0L, 
Long.MAX_VALUE);
       Assert.fail("Expect IllegalArgumentException");
     } catch (IllegalArgumentException ignore) {
     }
@@ -61,7 +61,7 @@ public class MemoryPoolTest {
   public void testTryReserveNegative() {
 
     try {
-      pool.tryReserveForTest(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 
-1L, Long.MAX_VALUE);
+      pool.tryReserve(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, -1L, 
Long.MAX_VALUE);
       Assert.fail("Expect IllegalArgumentException");
     } catch (IllegalArgumentException ignore) {
     }
@@ -71,7 +71,7 @@ public class MemoryPoolTest {
   public void testTryReserveAll() {
 
     Assert.assertTrue(
-        pool.tryReserveForTest(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 
512L, Long.MAX_VALUE));
+        pool.tryReserve(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 512L, 
Long.MAX_VALUE));
     Assert.assertEquals(512L, pool.getQueryMemoryReservedBytes(QUERY_ID));
     Assert.assertEquals(512L, pool.getReservedBytes());
   }
@@ -79,12 +79,10 @@ public class MemoryPoolTest {
   @Test
   public void testOverTryReserve() {
 
-    Assert.assertTrue(
-        pool.tryReserveForTest(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 
256L, 512L));
+    Assert.assertTrue(pool.tryReserve(QUERY_ID, FRAGMENT_INSTANCE_ID, 
PLAN_NODE_ID, 256L, 512L));
     Assert.assertEquals(256L, pool.getQueryMemoryReservedBytes(QUERY_ID));
     Assert.assertEquals(256L, pool.getReservedBytes());
-    Assert.assertFalse(
-        pool.tryReserveForTest(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 
512L, 511L));
+    Assert.assertFalse(pool.tryReserve(QUERY_ID, FRAGMENT_INSTANCE_ID, 
PLAN_NODE_ID, 512L, 511L));
     Assert.assertEquals(256L, pool.getQueryMemoryReservedBytes(QUERY_ID));
     Assert.assertEquals(256L, pool.getReservedBytes());
   }
@@ -208,7 +206,7 @@ public class MemoryPoolTest {
   public void testFree() {
 
     Assert.assertTrue(
-        pool.tryReserveForTest(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 
512L, Long.MAX_VALUE));
+        pool.tryReserve(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 512L, 
Long.MAX_VALUE));
     Assert.assertEquals(512L, pool.getQueryMemoryReservedBytes(QUERY_ID));
     Assert.assertEquals(512L, pool.getReservedBytes());
 
@@ -221,7 +219,7 @@ public class MemoryPoolTest {
   public void testFreeAll() {
 
     Assert.assertTrue(
-        pool.tryReserveForTest(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 
512L, Long.MAX_VALUE));
+        pool.tryReserve(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 512L, 
Long.MAX_VALUE));
     Assert.assertEquals(512L, pool.getQueryMemoryReservedBytes(QUERY_ID));
     Assert.assertEquals(512L, pool.getReservedBytes());
 
@@ -234,7 +232,7 @@ public class MemoryPoolTest {
   public void testFreeZero() {
 
     Assert.assertTrue(
-        pool.tryReserveForTest(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 
512L, Long.MAX_VALUE));
+        pool.tryReserve(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 512L, 
Long.MAX_VALUE));
     Assert.assertEquals(512L, pool.getQueryMemoryReservedBytes(QUERY_ID));
     Assert.assertEquals(512L, pool.getReservedBytes());
 
@@ -249,7 +247,7 @@ public class MemoryPoolTest {
   public void testFreeNegative() {
 
     Assert.assertTrue(
-        pool.tryReserveForTest(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 
512L, Long.MAX_VALUE));
+        pool.tryReserve(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 512L, 
Long.MAX_VALUE));
     Assert.assertEquals(512L, pool.getQueryMemoryReservedBytes(QUERY_ID));
     Assert.assertEquals(512L, pool.getReservedBytes());
 
@@ -264,7 +262,7 @@ public class MemoryPoolTest {
   public void testOverFree() {
 
     Assert.assertTrue(
-        pool.tryReserveForTest(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 
512L, Long.MAX_VALUE));
+        pool.tryReserve(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 512L, 
Long.MAX_VALUE));
     Assert.assertEquals(512L, pool.getQueryMemoryReservedBytes(QUERY_ID));
     Assert.assertEquals(512L, pool.getReservedBytes());
 
@@ -280,7 +278,7 @@ public class MemoryPoolTest {
 
     // Run out of memory.
     Assert.assertTrue(
-        pool.tryReserveForTest(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 
512L, Long.MAX_VALUE));
+        pool.tryReserve(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 512L, 
Long.MAX_VALUE));
 
     ListenableFuture<Void> f =
         pool.reserve(QUERY_ID, FRAGMENT_INSTANCE_ID, PLAN_NODE_ID, 256L, 
512L).left;

Reply via email to