This is an automated email from the ASF dual-hosted git repository.
baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new ba32d04998 [SYSTEMDS-3363] Netty Encoding Buffer Allocation
ba32d04998 is described below
commit ba32d0499857f244406b44889a9ffff791e9cf92
Author: ywcb00 <[email protected]>
AuthorDate: Sat Apr 23 18:07:01 2022 +0200
[SYSTEMDS-3363] Netty Encoding Buffer Allocation
This commit replaces the ObjectEncoders for encoding federated
communication messages (i.e., FederatedRequest and FederatedResponse)
with a respective FederatedRequestEncoder and FederatedResponseEncoder.
These two new classes extend ObjectEncoder and override the
allocateBuffer method in order to allocate the buffer with an initial
capacity corresponding to the FederatedRequest/FederatedResponse instead
of the default value of 256.
I've tried estimating the entire serialized size of the requests and
responses to allocate the exact size needed, but haven't found a good
solution for it because of the many cases to consider. Hence, the initial
capacity in this PR only results from taking into consideration a static
offset (FederatedRequest: 512, FederatedResponse: 312), the size of the
CacheBlock parameters, and the length of the lineage trace (only for
FederatedRequests).
In my experiments with a matrix of dimensions 600,000 x 200, it reduces
the encoding time from 50 seconds to 1.9 seconds.
---
.../controlprogram/federated/FederatedData.java | 31 ++++++++++++++++++++--
.../controlprogram/federated/FederatedRequest.java | 13 +++++++++
.../federated/FederatedResponse.java | 12 +++++++++
.../controlprogram/federated/FederatedWorker.java | 25 ++++++++++++++++-
4 files changed, 78 insertions(+), 3 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java
index 35f9543843..95901b4a2e 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedData.java
@@ -19,6 +19,7 @@
package org.apache.sysds.runtime.controlprogram.federated;
+import java.io.Serializable;
import java.net.InetSocketAddress;
import java.util.ArrayList;
import java.util.HashSet;
@@ -33,11 +34,12 @@ import org.apache.commons.logging.LogFactory;
import org.apache.sysds.common.Types;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.conf.DMLConfig;
-import org.apache.sysds.runtime.DMLRuntimeException;
import
org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
+import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.meta.MetaData;
import io.netty.bootstrap.Bootstrap;
+import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
@@ -184,7 +186,7 @@ public class FederatedData {
cp.addLast("ObjectDecoder", new
ObjectDecoder(Integer.MAX_VALUE,
ClassResolvers.weakCachingResolver(ClassLoader.getSystemClassLoader())));
cp.addLast("FederatedOperationHandler",
handler);
- cp.addLast("ObjectEncoder", new
ObjectEncoder());
+ cp.addLast("FederatedRequestEncoder",
new FederatedRequestEncoder());
}
});
@@ -284,4 +286,29 @@ public class FederatedData {
sb.append(":" + _filepath);
return sb.toString();
}
+
+ public static class FederatedRequestEncoder extends ObjectEncoder { // ObjectEncoder variant that pre-sizes its buffer from the outgoing FederatedRequest[]
+ @Override
+ protected ByteBuf allocateBuffer(ChannelHandlerContext ctx,
Serializable msg,
+ boolean preferDirect) throws Exception { // returns a buffer from ctx.alloc() with initial capacity initCapacity
+ int initCapacity = 256; // default initial capacity
+ if(msg instanceof FederatedRequest[]) {
+ initCapacity = 0; // request array: replace the default by the sum of the per-request estimates
+ try {
+ for(FederatedRequest fr :
(FederatedRequest[])msg) {
+ int frSize =
Math.toIntExact(fr.estimateSerializationBufferSize()); // toIntExact throws ArithmeticException if the long estimate does not fit in an int
+ if(Integer.MAX_VALUE -
initCapacity < frSize) // summed sizes exceed integer limits
+ throw new
ArithmeticException("Overflow.");
+ initCapacity += frSize;
+ }
+ } catch(ArithmeticException ae) { // size of
federated request exceeds integer limits
+ initCapacity = Integer.MAX_VALUE; // NOTE(review): asks the allocator for the largest int capacity up front - confirm the allocator can satisfy this
+ }
+ }
+ if(preferDirect)
+ return ctx.alloc().ioBuffer(initCapacity); // preferDirect: allocate through the allocator's ioBuffer
+ else
+ return ctx.alloc().heapBuffer(initCapacity); // otherwise allocate a heap buffer
+ }
+ }
}
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRequest.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRequest.java
index 610300c81c..0ae8ac033c 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRequest.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedRequest.java
@@ -195,6 +195,19 @@ public class FederatedRequest implements Serializable {
return _lineageTrace;
}
+ public long estimateSerializationBufferSize() { // estimate = fixed offset + exact serialized size of each CacheBlock in _data + length of the lineage trace
+ long minBufferSize = 512; // general offset for the
FederatedRequest object
+ if(_data != null) {
+ for(Object obj : _data) {
+ if(obj instanceof CacheBlock) // entries of _data that are not CacheBlocks add nothing to the estimate
+ minBufferSize +=
((CacheBlock)obj).getExactSerializedSize();
+ }
+ }
+ if(_lineageTrace != null)
+ minBufferSize += _lineageTrace.length(); // adds length() of the trace; presumably a character count - confirm against the field type
+ return minBufferSize;
+ }
+
@Override
public String toString() {
StringBuilder sb = new StringBuilder("FederatedRequest[");
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedResponse.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedResponse.java
index 5b0e4eb87b..b8cb55851c 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedResponse.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedResponse.java
@@ -25,6 +25,7 @@ import java.util.Map;
import java.util.concurrent.atomic.LongAdder;
import org.apache.commons.lang.exception.ExceptionUtils;
+import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.privacy.CheckedConstraintsLog;
import org.apache.sysds.runtime.privacy.PrivacyConstraint.PrivacyLevel;
@@ -79,6 +80,17 @@ public class FederatedResponse implements Serializable {
return _data;
}
+ public long estimateSerializationBufferSize() { // estimate = fixed offset + exact serialized size of each CacheBlock in _data
+ long minBufferSize = 312; // general offset for the
FederatedResponse object
+ if(_data != null) {
+ for(Object obj : _data) {
+ if(obj instanceof CacheBlock) // entries of _data that are not CacheBlocks add nothing to the estimate
+ minBufferSize +=
((CacheBlock)obj).getExactSerializedSize();
+ }
+ }
+ return minBufferSize;
+ }
+
/**
* Checks the data object array for exceptions that occurred in the
federated worker
* during handling of request.
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorker.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorker.java
index b99fd41be1..9a8fe38f6e 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorker.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorker.java
@@ -19,6 +19,7 @@
package org.apache.sysds.runtime.controlprogram.federated;
+import java.io.Serializable;
import java.security.cert.CertificateException;
import java.util.concurrent.SynchronousQueue;
import java.util.concurrent.ThreadPoolExecutor;
@@ -27,7 +28,9 @@ import java.util.concurrent.TimeUnit;
import javax.net.ssl.SSLException;
import io.netty.bootstrap.ServerBootstrap;
+import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelFuture;
+import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelOption;
import io.netty.channel.ChannelPipeline;
@@ -99,7 +102,7 @@ public class FederatedWorker {
cp.addLast("ObjectDecoder",
new
ObjectDecoder(Integer.MAX_VALUE,
ClassResolvers.weakCachingResolver(ClassLoader.getSystemClassLoader())));
- cp.addLast("ObjectEncoder", new
ObjectEncoder());
+
cp.addLast("FederatedResponseEncoder", new FederatedResponseEncoder());
cp.addLast("FederatedWorkerHandler", new FederatedWorkerHandler(_flt, _frc,
_fan));
}
}).option(ChannelOption.SO_BACKLOG,
128).childOption(ChannelOption.SO_KEEPALIVE, true);
@@ -121,4 +124,24 @@ public class FederatedWorker {
bossGroup.shutdownGracefully();
}
}
+
+ public static class FederatedResponseEncoder extends ObjectEncoder { // ObjectEncoder variant that pre-sizes its buffer from the outgoing FederatedResponse
+ @Override
+ protected ByteBuf allocateBuffer(ChannelHandlerContext ctx,
Serializable msg,
+ boolean preferDirect) throws Exception { // returns a buffer from ctx.alloc() with initial capacity initCapacity
+ int initCapacity = 256; // default initial capacity
+ if(msg instanceof FederatedResponse) {
+ FederatedResponse response =
(FederatedResponse)msg;
+ try {
+ initCapacity =
Math.toIntExact(response.estimateSerializationBufferSize()); // toIntExact throws ArithmeticException if the long estimate does not fit in an int
+ } catch(ArithmeticException ae) { // size of
cache block exceeds integer limits
+ initCapacity = Integer.MAX_VALUE; // NOTE(review): asks the allocator for the largest int capacity up front - confirm the allocator can satisfy this
+ }
+ }
+ if(preferDirect)
+ return ctx.alloc().ioBuffer(initCapacity); // preferDirect: allocate through the allocator's ioBuffer
+ else
+ return ctx.alloc().heapBuffer(initCapacity); // otherwise allocate a heap buffer
+ }
+ }
}