This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new ba32d04998 [SYSTEMDS-3363] Netty Encoding Buffer Allocation
ba32d04998 is described below

commit ba32d0499857f244406b44889a9ffff791e9cf92
Author: ywcb00 <[email protected]>
AuthorDate: Sat Apr 23 18:07:01 2022 +0200

    [SYSTEMDS-3363] Netty Encoding Buffer Allocation
    
    This commit replaces the ObjectEncoders used for encoding federated
    communication messages (i.e., FederatedRequest and FederatedResponse)
    with a dedicated FederatedRequestEncoder and FederatedResponseEncoder,
    respectively. These two new classes extend ObjectEncoder and override
    the allocateBuffer method so that the buffer is allocated with an
    initial capacity estimated from the FederatedRequest/FederatedResponse
    instead of the default value of 256.
    I tried estimating the entire serialized size of the requests and
    responses in order to allocate exactly the size needed, but did not
    find a good solution because of the many cases to consider. Hence, the
    initial capacity in this PR is derived only from a static offset
    (FederatedRequest: 512, FederatedResponse: 312), the serialized size of
    the CacheBlock parameters, and the length of the lineage trace
    (FederatedRequests only).
    In my experiments with a 600,000 x 200 matrix, this reduces the
    encoding time from 50 seconds to 1.9 seconds.
---
 .../controlprogram/federated/FederatedData.java    | 31 ++++++++++++++++++++--
 .../controlprogram/federated/FederatedRequest.java | 13 +++++++++
 .../federated/FederatedResponse.java               | 12 +++++++++
 .../controlprogram/federated/FederatedWorker.java  | 25 ++++++++++++++++-
 4 files changed, 78 insertions(+), 3 deletions(-)

diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java
index 35f9543843..95901b4a2e 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java
@@ -19,6 +19,7 @@
 
 package org.apache.sysds.runtime.controlprogram.federated;
 
+import java.io.Serializable;
 import java.net.InetSocketAddress;
 import java.util.ArrayList;
 import java.util.HashSet;
@@ -33,11 +34,12 @@ import org.apache.commons.logging.LogFactory;
 import org.apache.sysds.common.Types;
 import org.apache.sysds.conf.ConfigurationManager;
 import org.apache.sysds.conf.DMLConfig;
-import org.apache.sysds.runtime.DMLRuntimeException;
 import 
org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
+import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.meta.MetaData;
 
 import io.netty.bootstrap.Bootstrap;
+import io.netty.buffer.ByteBuf;
 import io.netty.channel.ChannelFuture;
 import io.netty.channel.ChannelHandlerContext;
 import io.netty.channel.ChannelInboundHandlerAdapter;
@@ -184,7 +186,7 @@ public class FederatedData {
                                        cp.addLast("ObjectDecoder", new 
ObjectDecoder(Integer.MAX_VALUE,
                                                
ClassResolvers.weakCachingResolver(ClassLoader.getSystemClassLoader())));
                                        cp.addLast("FederatedOperationHandler", 
handler);
-                                       cp.addLast("ObjectEncoder", new 
ObjectEncoder());
+                                       cp.addLast("FederatedRequestEncoder", 
new FederatedRequestEncoder());
                                }
                        });
 
@@ -284,4 +286,29 @@ public class FederatedData {
                sb.append(":" + _filepath);
                return sb.toString();
        }
+
+       public static class FederatedRequestEncoder extends ObjectEncoder {
+               @Override
+               protected ByteBuf allocateBuffer(ChannelHandlerContext ctx, 
Serializable msg,
+               boolean preferDirect) throws Exception {
+                       int initCapacity = 256; // default initial capacity
+                       if(msg instanceof FederatedRequest[]) {
+                               initCapacity = 0;
+                               try {
+                                       for(FederatedRequest fr : 
(FederatedRequest[])msg) {
+                                               int frSize = 
Math.toIntExact(fr.estimateSerializationBufferSize());
+                                               if(Integer.MAX_VALUE - 
initCapacity < frSize) // summed sizes exceed integer limits
+                                                       throw new 
ArithmeticException("Overflow.");
+                                               initCapacity += frSize;
+                                       }
+                               } catch(ArithmeticException ae) { // size of 
federated request exceeds integer limits
+                                       initCapacity = Integer.MAX_VALUE;
+                               }
+                       }
+                       if(preferDirect)
+                               return ctx.alloc().ioBuffer(initCapacity);
+                       else
+                               return ctx.alloc().heapBuffer(initCapacity);
+               }
+       }
 }
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRequest.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRequest.java
index 610300c81c..0ae8ac033c 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRequest.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRequest.java
@@ -195,6 +195,19 @@ public class FederatedRequest implements Serializable {
                return _lineageTrace;
        }
 
+       public long estimateSerializationBufferSize() {
+               long minBufferSize = 512; // general offset for the 
FederatedRequest object
+               if(_data != null) {
+                       for(Object obj : _data) {
+                               if(obj instanceof CacheBlock)
+                                       minBufferSize += 
((CacheBlock)obj).getExactSerializedSize();
+                       }
+               }
+               if(_lineageTrace != null)
+                       minBufferSize += _lineageTrace.length();
+               return minBufferSize;
+       }
+
        @Override
        public String toString() {
                StringBuilder sb = new StringBuilder("FederatedRequest[");
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedResponse.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedResponse.java
index 5b0e4eb87b..b8cb55851c 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedResponse.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedResponse.java
@@ -25,6 +25,7 @@ import java.util.Map;
 import java.util.concurrent.atomic.LongAdder;
 
 import org.apache.commons.lang.exception.ExceptionUtils;
+import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.privacy.CheckedConstraintsLog;
 import org.apache.sysds.runtime.privacy.PrivacyConstraint.PrivacyLevel;
@@ -79,6 +80,17 @@ public class FederatedResponse implements Serializable {
                return _data;
        }
 
+       public long estimateSerializationBufferSize() {
+               long minBufferSize = 312; // general offset for the 
FederatedResponse object
+               if(_data != null) {
+                       for(Object obj : _data) {
+                               if(obj instanceof CacheBlock)
+                                       minBufferSize += 
((CacheBlock)obj).getExactSerializedSize();
+                       }
+               }
+               return minBufferSize;
+       }
+
        /**
         * Checks the data object array for exceptions that occurred in the 
federated worker
         * during handling of request. 
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorker.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorker.java
index b99fd41be1..9a8fe38f6e 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorker.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorker.java
@@ -19,6 +19,7 @@
 
 package org.apache.sysds.runtime.controlprogram.federated;
 
+import java.io.Serializable;
 import java.security.cert.CertificateException;
 import java.util.concurrent.SynchronousQueue;
 import java.util.concurrent.ThreadPoolExecutor;
@@ -27,7 +28,9 @@ import java.util.concurrent.TimeUnit;
 import javax.net.ssl.SSLException;
 
 import io.netty.bootstrap.ServerBootstrap;
+import io.netty.buffer.ByteBuf;
 import io.netty.channel.ChannelFuture;
+import io.netty.channel.ChannelHandlerContext;
 import io.netty.channel.ChannelInitializer;
 import io.netty.channel.ChannelOption;
 import io.netty.channel.ChannelPipeline;
@@ -99,7 +102,7 @@ public class FederatedWorker {
                                                cp.addLast("ObjectDecoder",
                                                        new 
ObjectDecoder(Integer.MAX_VALUE,
                                                                
ClassResolvers.weakCachingResolver(ClassLoader.getSystemClassLoader())));
-                                               cp.addLast("ObjectEncoder", new 
ObjectEncoder());
+                                               
cp.addLast("FederatedResponseEncoder", new FederatedResponseEncoder());
                                                
cp.addLast("FederatedWorkerHandler", new FederatedWorkerHandler(_flt, _frc, 
_fan));
                                        }
                                }).option(ChannelOption.SO_BACKLOG, 
128).childOption(ChannelOption.SO_KEEPALIVE, true);
@@ -121,4 +124,24 @@ public class FederatedWorker {
                        bossGroup.shutdownGracefully();
                }
        }
+
+       public static class FederatedResponseEncoder extends ObjectEncoder {
+               @Override
+               protected ByteBuf allocateBuffer(ChannelHandlerContext ctx, 
Serializable msg,
+                       boolean preferDirect) throws Exception {
+                       int initCapacity = 256; // default initial capacity
+                       if(msg instanceof FederatedResponse) {
+                               FederatedResponse response = 
(FederatedResponse)msg;
+                               try {
+                                       initCapacity = 
Math.toIntExact(response.estimateSerializationBufferSize());
+                               } catch(ArithmeticException ae) { // size of 
cache block exceeds integer limits
+                                       initCapacity = Integer.MAX_VALUE;
+                               }
+                       }
+                       if(preferDirect)
+                               return ctx.alloc().ioBuffer(initCapacity);
+                       else
+                               return ctx.alloc().heapBuffer(initCapacity);
+               }
+       }
 }

Reply via email to