sohami commented on a change in pull request #1504: DRILL-6792: Find the right 
probe side fragment wrapper & fix DrillBuf…
URL: https://github.com/apache/drill/pull/1504#discussion_r233261620
 
 

 ##########
 File path: 
exec/java-exec/src/main/java/org/apache/drill/exec/work/filter/RuntimeFilterSink.java
 ##########
 @@ -17,206 +17,217 @@
  */
 package org.apache.drill.exec.work.filter;
 
-import org.apache.drill.exec.memory.BufferAllocator;
+import io.netty.buffer.DrillBuf;
+import org.apache.drill.exec.ops.AccountingDataTunnel;
+import org.apache.drill.exec.ops.Consumer;
+import org.apache.drill.exec.ops.SendingAccountor;
+import org.apache.drill.exec.ops.StatusHandler;
+import org.apache.drill.exec.proto.BitData;
+import org.apache.drill.exec.proto.CoordinationProtos;
+import org.apache.drill.exec.proto.GeneralRPCProtos;
+import org.apache.drill.exec.proto.UserBitShared;
+import org.apache.drill.exec.rpc.RpcException;
+import org.apache.drill.exec.rpc.RpcOutcomeListener;
+import org.apache.drill.exec.rpc.data.DataTunnel;
+import org.apache.drill.exec.server.DrillbitContext;
+import org.apache.drill.shaded.guava.com.google.common.base.Stopwatch;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
 import java.util.concurrent.BlockingQueue;
-import java.util.concurrent.ExecutorService;
-import java.util.concurrent.Future;
 import java.util.concurrent.LinkedBlockingQueue;
-import java.util.concurrent.atomic.AtomicBoolean;
-import java.util.concurrent.atomic.AtomicInteger;
+import java.util.concurrent.TimeUnit;
 import java.util.concurrent.locks.Condition;
 import java.util.concurrent.locks.ReentrantLock;
 
 /**
  * This sink receives the RuntimeFilters from the netty thread,
- * aggregates them in an async thread, supplies the aggregated
- * one to the fragment running thread.
+ * aggregates them in an async thread, broadcast the final aggregated
+ * one to the RuntimeFilterRecordBatch.
  */
-public class RuntimeFilterSink implements AutoCloseable {
-
-  private AtomicInteger currentBookId = new AtomicInteger(0);
-
-  private int staleBookId = 0;
-
-  /**
-   * RuntimeFilterWritable holding the aggregated version of all the received 
filter
-   */
-  private RuntimeFilterWritable aggregated = null;
+public class RuntimeFilterSink
+{
 
   private BlockingQueue<RuntimeFilterWritable> rfQueue = new 
LinkedBlockingQueue<>();
 
-  /**
-   * Flag used by Minor Fragment thread to indicate it has encountered error
-   */
-  private AtomicBoolean running = new AtomicBoolean(true);
-
-  /**
-   * Lock used to synchronize between producer (Netty Thread) and consumer 
(AsyncAggregateThread) of elements of this
-   * queue. This is needed because in error condition running flag can be 
consumed by producer and consumer thread at
-   * different times. Whoever sees it first will take this lock and clear all 
elements and set the queue to null to
-   * indicate producer not to put any new elements in it.
-   */
   private ReentrantLock queueLock = new ReentrantLock();
 
   private Condition notEmpty = queueLock.newCondition();
 
-  private ReentrantLock aggregatedRFLock = new ReentrantLock();
+  private Map<Integer, Integer> joinMjId2rfNumber;
+
+  //HashJoin node's major fragment id to its corresponding probe side nodes's 
endpoints
+  private Map<Integer, List<CoordinationProtos.DrillbitEndpoint>> 
joinMjId2probeScanEps = new HashMap<>();
 
-  private BufferAllocator bufferAllocator;
+  //HashJoin node's major fragment id to its corresponding probe side scan 
node's belonging major fragment id
+  private Map<Integer, Integer> joinMjId2ScanMjId = new HashMap<>();
 
-  private Future future;
+  //HashJoin node's major fragment id to its aggregated RuntimeFilterWritable
+  private Map<Integer, RuntimeFilterWritable> joinMjId2AggregatedRF = new 
HashMap<>();
+  //for debug usage
+  private Map<Integer, Stopwatch> joinMjId2Stopwatch = new HashMap<>();
+
+  private DrillbitContext drillbitContext;
+
+  private SendingAccountor sendingAccountor;
 
   private static final Logger logger = 
LoggerFactory.getLogger(RuntimeFilterSink.class);
 
 
-  public RuntimeFilterSink(BufferAllocator bufferAllocator, ExecutorService 
executorService) {
-    this.bufferAllocator = bufferAllocator;
+  public RuntimeFilterSink(DrillbitContext drillbitContext, SendingAccountor 
sendingAccountor)
+  {
+    this.drillbitContext = drillbitContext;
+    this.sendingAccountor = sendingAccountor;
     AsyncAggregateWorker asyncAggregateWorker = new AsyncAggregateWorker();
-    future = executorService.submit(asyncAggregateWorker);
+    drillbitContext.getExecutor().submit(asyncAggregateWorker);
   }
 
-  public void aggregate(RuntimeFilterWritable runtimeFilterWritable) {
-    if (running.get()) {
-      try {
-        aggregatedRFLock.lock();
-        if (containOne()) {
-          boolean same = aggregated.equals(runtimeFilterWritable);
-          if (!same) {
-            // This is to solve the only one fragment case that two 
RuntimeFilterRecordBatchs
-            // share the same FragmentContext.
-            aggregated.close();
-            currentBookId.set(0);
-            staleBookId = 0;
-            clearQueued(false);
-          }
-        }
-      } finally {
-        aggregatedRFLock.unlock();
-      }
-
-      try {
-        queueLock.lock();
-        if (rfQueue != null) {
-          rfQueue.add(runtimeFilterWritable);
-          notEmpty.signal();
-        } else {
-          runtimeFilterWritable.close();
-        }
-      } finally {
-        queueLock.unlock();
-      }
-    } else {
-      runtimeFilterWritable.close();
+  public void add(RuntimeFilterWritable runtimeFilterWritable)
+  {
+    runtimeFilterWritable.retainBuffers(1);
+    int joinMjId = 
runtimeFilterWritable.getRuntimeFilterBDef().getMajorFragmentId();
+    if (joinMjId2Stopwatch.get(joinMjId) == null) {
+      Stopwatch stopwatch = Stopwatch.createStarted();
+      joinMjId2Stopwatch.put(joinMjId, stopwatch);
     }
-  }
-
-  public RuntimeFilterWritable fetchLatestDuplicatedAggregatedOne() {
+    queueLock.lock();
     try {
-      aggregatedRFLock.lock();
-      return aggregated.duplicate(bufferAllocator);
-    } finally {
-      aggregatedRFLock.unlock();
+      rfQueue.add(runtimeFilterWritable);
+      notEmpty.signal();
     }
-  }
-
-  /**
-   * whether there's a fresh aggregated RuntimeFilter
-   *
-   * @return
-   */
-  public boolean hasFreshOne() {
-    if (currentBookId.get() > staleBookId) {
-      staleBookId = currentBookId.get();
-      return true;
+    finally {
+      queueLock.unlock();
     }
-    return false;
   }
 
-  /**
-   * whether there's a usable RuntimeFilter.
-   *
-   * @return
-   */
-  public boolean containOne() {
-    return aggregated != null;
+  private void aggregate(RuntimeFilterWritable srcRuntimeFilterWritable)
+  {
+    BitData.RuntimeFilterBDef runtimeFilterB = 
srcRuntimeFilterWritable.getRuntimeFilterBDef();
+    int joinMajorId = runtimeFilterB.getMajorFragmentId();
+    int buildSideRfNumber;
+    RuntimeFilterWritable toAggregated = null;
+    buildSideRfNumber = joinMjId2rfNumber.get(joinMajorId);
+    buildSideRfNumber--;
+    joinMjId2rfNumber.put(joinMajorId, buildSideRfNumber);
+    toAggregated = joinMjId2AggregatedRF.get(joinMajorId);
+    if (toAggregated == null) {
+      toAggregated = srcRuntimeFilterWritable;
+      toAggregated.retainBuffers(1);
+    } else {
+      toAggregated.aggregate(srcRuntimeFilterWritable);
+    }
+    joinMjId2AggregatedRF.put(joinMajorId, toAggregated);
+    if (buildSideRfNumber == 0) {
+      joinMjId2AggregatedRF.remove(joinMajorId);
+      route(toAggregated);
+      joinMjId2rfNumber.remove(joinMajorId);
+      Stopwatch stopwatch = joinMjId2Stopwatch.get(joinMajorId);
+      logger.info(
+          "received all the RFWs belonging to the majorId {}'s HashJoin nodes 
and flushed aggregated RFW out elapsed {} ms",
+          joinMajorId,
+          stopwatch.elapsed(TimeUnit.MILLISECONDS)
+      );
+    }
   }
 
-  private void doCleanup() {
-    running.compareAndSet(true, false);
-    try {
-      aggregatedRFLock.lock();
-      if (containOne()) {
-        aggregated.close();
-        aggregated = null;
+  private void route(RuntimeFilterWritable srcRuntimeFilterWritable)
+  {
+    BitData.RuntimeFilterBDef runtimeFilterB = 
srcRuntimeFilterWritable.getRuntimeFilterBDef();
+    int joinMajorId = runtimeFilterB.getMajorFragmentId();
+    UserBitShared.QueryId queryId = runtimeFilterB.getQueryId();
+    List<String> probeFields = runtimeFilterB.getProbeFieldsList();
+    List<Integer> sizeInBytes = runtimeFilterB.getBloomFilterSizeInBytesList();
+    long rfIdentifier = runtimeFilterB.getRfIdentifier();
+    DrillBuf[] data = srcRuntimeFilterWritable.getData();
+    List<CoordinationProtos.DrillbitEndpoint> scanNodeEps = 
joinMjId2probeScanEps.get(joinMajorId);
+    int scanNodeSize = scanNodeEps.size();
+    srcRuntimeFilterWritable.retainBuffers(scanNodeSize - 1);
+    int scanNodeMjId = joinMjId2ScanMjId.get(joinMajorId);
+    for (int minorId = 0; minorId < scanNodeEps.size(); minorId++) {
+      BitData.RuntimeFilterBDef.Builder builder = 
BitData.RuntimeFilterBDef.newBuilder();
+      for (String probeField : probeFields) {
+        builder.addProbeFields(probeField);
       }
-    } finally {
-      aggregatedRFLock.unlock();
+      BitData.RuntimeFilterBDef runtimeFilterBDef = builder.setQueryId(queryId)
+                                                           
.setMajorFragmentId(scanNodeMjId)
+                                                           
.setMinorFragmentId(minorId)
+                                                           .setToForeman(false)
+                                                           
.setRfIdentifier(rfIdentifier)
+                                                           
.addAllBloomFilterSizeInBytes(sizeInBytes)
+                                                           .build();
+      RuntimeFilterWritable runtimeFilterWritable = new 
RuntimeFilterWritable(runtimeFilterBDef, data);
+      CoordinationProtos.DrillbitEndpoint drillbitEndpoint = 
scanNodeEps.get(minorId);
+
+      DataTunnel dataTunnel = 
drillbitContext.getDataConnectionsPool().getTunnel(drillbitEndpoint);
+      Consumer<RpcException> exceptionConsumer = new Consumer<RpcException>()
+      {
+        @Override
+        public void accept(final RpcException e)
+        {
+          logger.warn("fail to broadcast a runtime filter to the probe side 
scan node", e);
+        }
+
+        @Override
+        public void interrupt(final InterruptedException e)
+        {
+          logger.warn("fail to broadcast a runtime filter to the probe side 
scan node", e);
+        }
+      };
+      RpcOutcomeListener<GeneralRPCProtos.Ack> statusHandler = new 
StatusHandler(exceptionConsumer, sendingAccountor);
+      AccountingDataTunnel accountingDataTunnel = new 
AccountingDataTunnel(dataTunnel, sendingAccountor, statusHandler);
+      accountingDataTunnel.sendRuntimeFilter(runtimeFilterWritable);
     }
   }
 
-  @Override
-  public void close() throws Exception {
-    future.cancel(true);
-    doCleanup();
+  public void setJoinMjId2rfNumber(Map<Integer, Integer> joinMjId2rfNumber)
+  {
+    this.joinMjId2rfNumber = joinMjId2rfNumber;
   }
 
-  private void clearQueued(boolean setToNull) {
-    RuntimeFilterWritable toClear;
-    try {
-      queueLock.lock();
-      while (rfQueue != null && (toClear = rfQueue.poll()) != null) {
-        toClear.close();
-      }
-      rfQueue = (setToNull) ? null : rfQueue;
-    } finally {
-      queueLock.unlock();
-    }
+  public void setJoinMjId2probeScanEps(Map<Integer, 
List<CoordinationProtos.DrillbitEndpoint>> joinMjId2probeScanEps)
+  {
+    this.joinMjId2probeScanEps = joinMjId2probeScanEps;
   }
 
-  private class AsyncAggregateWorker implements Runnable {
+  public void setJoinMjId2ScanMjId(Map<Integer, Integer> joinMjId2ScanMjId)
+  {
+    this.joinMjId2ScanMjId = joinMjId2ScanMjId;
+  }
 
+  private class AsyncAggregateWorker implements Runnable
+  {
     @Override
-    public void run() {
-      try {
-        RuntimeFilterWritable toAggregate = null;
-        while (running.get()) {
+    public void run()
+    {
+      while (joinMjId2rfNumber == null || !joinMjId2rfNumber.isEmpty()) {
+        RuntimeFilterWritable toAggregate = rfQueue.poll();
 
 Review comment:
   Also, how about simplifying this logic as shown below:
   ```
   @Override
   public void run() {
   while (joinMjId2rfNumber == null || !joinMjId2rfNumber.isEmpty()) {
   
           RuntimeFilterWritable toAggregate = null;
           synchronized (rfQueue) {
             try {
               toAggregate = rfQueue.poll();
               while (toAggregate == null) {
                 rfQueue.wait(5, TimeUnit.SECONDS);
                 toAggregate = rfQueue.poll();
               }
               aggregate(toAggregate);
             } catch (InterruptedException ex) {
               logger.error("RFW_Aggregator thread being interrupted", ex);
             } finally {
               if (toAggregate != null) {
                 toAggregate.close();
               }
             }
           }
         }
   }
   ```

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to