This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 03508eec80 [SYSTEMDS-3451] Fix missing NNZ propagation in federated
instructions
03508eec80 is described below
commit 03508eec80f3a144c9d311e8d4092f979abb5dee
Author: Matthias Boehm <[email protected]>
AuthorDate: Wed Oct 19 15:11:19 2022 -0400
[SYSTEMDS-3451] Fix missing NNZ propagation in federated instructions
For many federated operations that output federated data, no
information on the number of non-zeros was propagated back to the
coordinator, and hence suboptimal plan choices were made during
dynamic recompilation. We now generally return the nnz of instruction
outputs from all EXEC_INST and EXEC_UDF types, and provide utils to
obtain this info from the federated responses. This patch already
exploits this info in right indexing, append, and replace, which are
often executed on the main federated matrix.
In a scenario of training MLogReg on 3 workers with a Criteo subset
of 1M rows and ~98M columns after one-hot encoding, this patch
reduced the end-to-end execution time by more than 4x (402.9s to
88.9s, see stats below). The improvement is, however, generally
applicable.
a) STATS BEFORE PATCH:
Total elapsed time: 402.882 sec.
Total compilation time: 2.046 sec.
Total execution time: 400.837 sec.
Cache hits (Mem/Li/WB/FS/HDFS): 4452/0/0/0/0.
Cache writes (Li/WB/FS/HDFS): 4/1576/0/0.
Cache times (ACQr/m, RLS, EXP): 0.085/0.051/0.389/0.000 sec.
HOP DAGs recompiled (PRED, SB): 0/248.
HOP DAGs recompile time: 1.810 sec.
Functions recompiled: 1.
Functions recompile time: 0.253 sec.
Federated I/O (Read, Put, Get): 3/409/134.
Federated Execute (Inst, UDF): 436/6.
Fed Put Count (Sc/Li/Ma/Fr/MC): 0/0/378/0/31.
Fed Put Bytes (Mat/Frame): 1507638888/0 Bytes.
Federated prefetch count: 0.
Total JIT compile time: 29.476 sec.
Total JVM GC count: 18.
Total JVM GC time: 0.442 sec.
Heavy hitter instructions:
1 m_multiLogReg 370.387 1
2 fed_r' 221.613 29
3 fed_mmchain 53.184 63
4 fed_ba+* 36.619 62
5 fed_transformencode 17.992 1
6 n+ 13.264 121
7 * 11.882 754
8 fed_fedinit 8.670 1
9 - 4.210 361
10 +* 3.892 127
b) STATS AFTER PATCH:
Total elapsed time: 88.907 sec.
Total compilation time: 2.065 sec.
Total execution time: 86.842 sec.
Cache hits (Mem/Li/WB/FS/HDFS): 4510/0/0/0/0.
Cache writes (Li/WB/FS/HDFS): 4/1634/0/0.
Cache times (ACQr/m, RLS, EXP): 0.049/0.034/0.311/0.000 sec.
HOP DAGs recompiled (PRED, SB): 0/248.
HOP DAGs recompile time: 0.506 sec.
Functions recompiled: 1.
Functions recompile time: 0.268 sec.
Federated I/O (Read, Put, Get): 3/380/134.
Federated Execute (Inst, UDF): 378/6.
Fed Put Count (Sc/Li/Ma/Fr/MC): 0/0/378/0/2.
Fed Put Bytes (Mat/Frame): 1507638888/0 Bytes.
Federated prefetch count: 0.
Total JIT compile time: 19.534 sec.
Total JVM GC count: 17.
Total JVM GC time: 0.436 sec.
Heavy hitter instructions:
1 m_multiLogReg 56.603 1
2 fed_mmchain 21.669 63
3 fed_transformencode 18.811 1
4 fed_ba+* 12.250 62
5 fed_fedinit 9.061 1
6 n+ 5.271 121
7 * 3.699 754
8 fed_rightIndex 1.557 2
9 fed_uack+ 1.475 2
10 +* 1.330 127
---
.../controlprogram/caching/CacheableData.java | 2 +-
.../controlprogram/caching/MatrixObject.java | 2 +-
.../federated/FederatedWorkerHandler.java | 29 ++++++++++++++++-----
.../controlprogram/federated/FederationUtils.java | 11 ++++++++
.../instructions/fed/AppendFEDInstruction.java | 9 +++++--
.../instructions/fed/IndexingFEDInstruction.java | 30 +++++++++++-----------
...tiReturnParameterizedBuiltinFEDInstruction.java | 21 +++++++++------
.../fed/ParameterizedBuiltinFEDInstruction.java | 6 +++--
8 files changed, 74 insertions(+), 36 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
index 6dd726db6f..c4d3939279 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/CacheableData.java
@@ -366,7 +366,7 @@ public abstract class CacheableData<T extends CacheBlock>
extends Data
return getDataCharacteristics().getCols();
}
- public long getBlocksize() {
+ public int getBlocksize() {
return getDataCharacteristics().getBlocksize();
}
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java
index 9e332f119e..c723cc56fa 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/caching/MatrixObject.java
@@ -194,7 +194,7 @@ public class MatrixObject extends
CacheableData<MatrixBlock> {
mc.setNonZeros(_data.getNonZeros());
}
- public long getBlocksize() {
+ public int getBlocksize() {
return getDataCharacteristics().getBlocksize();
}
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
index 0e8df94ec6..96abf41eb4 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
@@ -60,6 +60,7 @@ import org.apache.sysds.runtime.instructions.Instruction;
import org.apache.sysds.runtime.instructions.Instruction.IType;
import org.apache.sysds.runtime.instructions.InstructionParser;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.instructions.cp.ComputationCPInstruction;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.cp.ListObject;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
@@ -194,6 +195,7 @@ public class FederatedWorkerHandler extends
ChannelInboundHandlerAdapter {
FederatedResponse response = null; // last response
boolean containsCLEAR = false;
long clearReqPid = -1;
+ int numGETrequests = 0;
var event = new EventModel();
final String coordinatorHostIdFormat = "%s-%d";
event.setCoordinatorHostId(String.format(coordinatorHostIdFormat, remoteHost,
requests[0].getPID()));
@@ -229,7 +231,7 @@ public class FederatedWorkerHandler extends
ChannelInboundHandlerAdapter {
}
else if(t == RequestType.GET_VAR) {
// If any of the requests was a GET_VAR then
set it as output.
- if(response != null) {
+ if(response != null && numGETrequests > 0) {
String message = "Multiple GET_VAR are
not supported in single batch of requests.";
LOG.error(message);
if (DMLScript.STATISTICS)
@@ -237,6 +239,12 @@ public class FederatedWorkerHandler extends
ChannelInboundHandlerAdapter {
throw new
FederatedWorkerHandlerException(message);
}
response = tmp;
+ numGETrequests ++;
+ }
+ else if(response == null
+ && (t == RequestType.EXEC_INST || t ==
RequestType.EXEC_UDF)) {
+ // If there was no GET, use the EXEC INST or
UDF to obtain the returned nnz
+ response = tmp;
}
else if(response == null && i == requests.length - 1) {
response = tmp; // return last
@@ -244,16 +252,13 @@ public class FederatedWorkerHandler extends
ChannelInboundHandlerAdapter {
if (DMLScript.STATISTICS) {
if(t == RequestType.PUT_VAR || t ==
RequestType.EXEC_UDF) {
- for (int paramIndex = 0; paramIndex <
request.getNumParams(); paramIndex++) {
+ for (int paramIndex = 0; paramIndex <
request.getNumParams(); paramIndex++)
FederatedStatistics.incFedTransfer(request.getParam(paramIndex),
_remoteAddress, request.getPID());
- }
}
-
if(t == RequestType.GET_VAR) {
var data = response.getData();
- for (int dataObjIndex = 0; dataObjIndex
< Arrays.stream(data).count(); dataObjIndex++) {
+ for (int dataObjIndex = 0; dataObjIndex
< Arrays.stream(data).count(); dataObjIndex++)
FederatedStatistics.incFedTransfer(data[dataObjIndex], _remoteAddress,
request.getPID());
- }
}
}
@@ -577,7 +582,8 @@ public class FederatedWorkerHandler extends
ChannelInboundHandlerAdapter {
setThreads(ins);
exec(ec, ins);
adaptToWorkload(ec, _fan, tid, ins);
- return new FederatedResponse(ResponseType.SUCCESS_EMPTY);
+ return new FederatedResponse(
+ ResponseType.SUCCESS_EMPTY, getOutputNnz(ec, ins));
}
private static ExecutionContext getContextForInstruction(long id,
Instruction ins, ExecutionContextMap ecm){
@@ -625,6 +631,15 @@ public class FederatedWorkerHandler extends
ChannelInboundHandlerAdapter {
});
}
}
+
+	/**
+	 * Obtains the number of non-zeros of an executed instruction's output,
+	 * which is returned to the coordinator for plan decisions.
+	 *
+	 * @param ec  execution context holding the instruction's output variable
+	 * @param ins the instruction that was just executed
+	 * @return the output's nnz if the instruction is a ComputationCPInstruction
+	 *         whose output is a MatrixObject, otherwise -1 (unknown)
+	 */
+	private static long getOutputNnz(ExecutionContext ec, Instruction ins) {
+		// only computation CP instructions expose a single output operand
+		if( !(ins instanceof ComputationCPInstruction) )
+			return -1L;
+		Data outData = ec.getVariable(((ComputationCPInstruction) ins).getOutput());
+		return (outData instanceof MatrixObject) ?
+			((MatrixObject) outData).getNnz() : -1L;
+	}
private FederatedResponse execUDF(FederatedRequest request,
ExecutionContextMap ecm, EventStageModel eventStage) {
checkNumParams(request.getNumParams(), 1);
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
index 02aefb928c..73939117ce 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
@@ -577,4 +577,15 @@ public class FederationUtils {
return new ObjectDecoder(Integer.MAX_VALUE,
ClassResolvers.weakCachingResolver(ClassLoader.getSystemClassLoader()));
}
+
+	/**
+	 * Sums the number of non-zeros reported by the given federated responses,
+	 * where the nnz is expected as a Long in the first data element (as returned
+	 * for EXEC_INST and EXEC_UDF requests).
+	 *
+	 * @param responses futures of the federated responses, one per worker
+	 * @return the total number of non-zeros, or -1 if it is unknown, i.e., if any
+	 *         worker reported a negative (unknown) nnz, returned no Long nnz,
+	 *         or its response could not be obtained
+	 */
+	public static long sumNonZeros(Future<FederatedResponse>[] responses) {
+		long nnz = 0;
+		try {
+			for( Future<FederatedResponse> r : responses ) {
+				Object[] data = r.get().getData();
+				// responses without an nnz (e.g., GET_VAR results) make the total unknown
+				if( data == null || data.length == 0 || !(data[0] instanceof Long) )
+					return -1;
+				long partNnz = (Long) data[0];
+				// workers report unknown nnz as -1; adding it would yield
+				// a wrong, too-small total instead of an unknown one
+				if( partNnz < 0 )
+					return -1;
+				nnz += partNnz;
+			}
+			return nnz;
+		}
+		catch(InterruptedException ex) {
+			// restore the interrupt status for callers up the stack
+			Thread.currentThread().interrupt();
+		}
+		catch(Exception ex) {
+			// best effort: the nnz is only meta data for plan choices,
+			// hence we fall back to unknown instead of failing
+		}
+		return -1;
+	}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AppendFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AppendFEDInstruction.java
index 0b915c7436..d126cae1dc 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AppendFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AppendFEDInstruction.java
@@ -19,6 +19,8 @@
package org.apache.sysds.runtime.instructions.fed;
+import java.util.concurrent.Future;
+
import org.apache.sysds.hops.fedplanner.FTypes.AlignType;
import org.apache.sysds.hops.fedplanner.FTypes.FType;
import org.apache.sysds.runtime.DMLRuntimeException;
@@ -26,6 +28,7 @@ import
org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import
org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.controlprogram.federated.MatrixLineagePair;
@@ -166,17 +169,19 @@ public class AppendFEDInstruction extends
BinaryFEDInstruction {
new long[]{ fr1[0].getID(),
moFed.getFedMapping().getID()});
//execute federated operations and set output
+ Future<FederatedResponse>[] ret = null;
if(isSpark) {
FederatedRequest tmp = new
FederatedRequest(RequestType.PUT_VAR,
fr2.getID(), new
MatrixCharacteristics(-1, -1), mo1.getDataType());
- moFed.getFedMapping().execute(getTID(), true,
fr1, tmp, fr2);
+ ret = moFed.getFedMapping().execute(getTID(),
true, fr1, tmp, fr2);
} else {
- moFed.getFedMapping().execute(getTID(), true,
fr1, fr2);
+ ret = moFed.getFedMapping().execute(getTID(),
true, fr1, fr2);
}
int dim = (_cbind ? 1 : 0);
FederationMap newFedMap =
moFed.getFedMapping().copyWithNewID(fr2.getID())
.modifyFedRanges(moFed.getDim(dim) +
moLoc.getDim(dim), dim);
out.setFedMapping(newFedMap);
+
out.getDataCharacteristics().setNonZeros(FederationUtils.sumNonZeros(ret));
}
else {
throw new DMLRuntimeException("Unsupported federated
append: "
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/IndexingFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/IndexingFEDInstruction.java
index a544e966ac..60fb151280 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/IndexingFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/IndexingFEDInstruction.java
@@ -24,6 +24,7 @@ import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
+import java.util.concurrent.Future;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.api.DMLScript;
@@ -42,6 +43,7 @@ import
org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedData;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRange;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.instructions.InstructionUtils;
@@ -203,19 +205,17 @@ public final class IndexingFEDInstruction extends
UnaryFEDInstruction {
FederatedRequest[] fr1 =
FederationUtils.callInstruction(instStrings, output, id,
new CPOperand[] {input1}, new long[] {fedMap.getID()},
execType);
fedMap.execute(getTID(), true, tmp);
- fedMap.execute(getTID(), true, fr1, new FederatedRequest[0]);
-
- if(input1.isFrame()) {
- FrameObject out = ec.getFrameObject(output);
- out.setSchema(schema.toArray(new Types.ValueType[0]));
-
out.getDataCharacteristics().setDimension(fedMap.getMaxIndexInRange(0),
fedMap.getMaxIndexInRange(1));
- out.setFedMapping(fedMap.copyWithNewID(fr1[0].getID()));
- } else {
- MatrixObject out = ec.getMatrixObject(output);
-
out.getDataCharacteristics().set(fedMap.getMaxIndexInRange(0),
fedMap.getMaxIndexInRange(1),
- (int) ((MatrixObject)in).getBlocksize());
- out.setFedMapping(fedMap.copyWithNewID(fr1[0].getID()));
- }
+ Future<FederatedResponse>[] ret = fedMap.execute(getTID(),
true, fr1, new FederatedRequest[0]);
+
+ //set output characteristics for frames and matrices
+ CacheableData<?> out = ec.getCacheableData(output);
+ if(input1.isFrame())
+ ((FrameObject) out).setSchema(schema.toArray(new
Types.ValueType[0]));
+ out.getDataCharacteristics()
+ .setDimension(fedMap.getMaxIndexInRange(0),
fedMap.getMaxIndexInRange(1))
+ .setBlocksize(in.getBlocksize())
+ .setNonZeros(FederationUtils.sumNonZeros(ret));
+ out.setFedMapping(fedMap.copyWithNewID(fr1[0].getID()));
}
private void leftIndexing(ExecutionContext ec)
@@ -324,8 +324,8 @@ public final class IndexingFEDInstruction extends
UnaryFEDInstruction {
fedMap.execute(getTID(), true, tmp);
if(in2 != null) { // matrix, frame
- FederatedRequest[] fr1 = fedMap.broadcastSliced(in2,
DMLScript.LINEAGE ? ec.getLineageItem(input2) : null,
- input2.isFrame(), sliceIxs);
+ FederatedRequest[] fr1 = fedMap.broadcastSliced(in2,
+ DMLScript.LINEAGE ? ec.getLineageItem(input2) :
null, input2.isFrame(), sliceIxs);
FederatedRequest[] fr2 =
FederationUtils.callInstruction(instStrings, output, id, new
CPOperand[]{input1, input2},
new long[]{fedMap.getID(), fr1[0].getID()},
null);
FederatedRequest fr3 = fedMap.cleanup(getTID(),
fr1[0].getID());
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/MultiReturnParameterizedBuiltinFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/MultiReturnParameterizedBuiltinFEDInstruction.java
index 93a7e41291..c9c6725e81 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/MultiReturnParameterizedBuiltinFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/MultiReturnParameterizedBuiltinFEDInstruction.java
@@ -25,6 +25,7 @@ import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Future;
+import java.util.concurrent.atomic.LongAdder;
import java.util.stream.Stream;
import java.util.zip.Adler32;
import java.util.zip.Checksum;
@@ -287,12 +288,12 @@ public class
MultiReturnParameterizedBuiltinFEDInstruction extends ComputationFE
public static void encodeFederatedFrames(FederationMap fedMapping,
MultiColumnEncoder globalencoder,
MatrixObject transformedMat) {
long varID = FederationUtils.getNextFedDataID();
- FederationMap transformedFedMapping =
fedMapping.mapParallel(varID, (range, data) -> {
+ LongAdder nnz = new LongAdder();
+ FederationMap tfFedMap = fedMapping.mapParallel(varID, (range,
data) -> {
// copy because we reuse it
long[] beginDims = range.getBeginDims();
long[] endDims = range.getEndDims();
- IndexRange ixRange = new IndexRange(beginDims[0],
endDims[0], beginDims[1], endDims[1]).add(1);// make
-
// 1-based
+ IndexRange ixRange = new IndexRange(beginDims[0],
endDims[0], beginDims[1], endDims[1]).add(1);
IndexRange ixRangeInv = new IndexRange(0, beginDims[0],
0, beginDims[1]);
// get the encoder segment that is relevant for this
federated worker
@@ -301,10 +302,12 @@ public class
MultiReturnParameterizedBuiltinFEDInstruction extends ComputationFE
encoder.updateIndexRanges(beginDims, endDims,
globalencoder.getNumExtraCols(ixRangeInv));
try {
- FederatedResponse response =
data.executeFederatedOperation(new FederatedRequest(RequestType.EXEC_UDF,
+ FederatedResponse response =
data.executeFederatedOperation(
+ new
FederatedRequest(RequestType.EXEC_UDF,
-1, new
ExecuteFrameEncoder(data.getVarID(), varID, encoder))).get();
if(!response.isSuccessful())
response.throwExceptionFromResponse();
+ nnz.add((Long)response.getData()[0]);
}
catch(Exception e) {
throw new DMLRuntimeException(e);
@@ -313,9 +316,10 @@ public class MultiReturnParameterizedBuiltinFEDInstruction
extends ComputationFE
});
// construct a federated matrix with the encoded data
-
transformedMat.getDataCharacteristics().setDimension(transformedFedMapping.getMaxIndexInRange(0),
- transformedFedMapping.getMaxIndexInRange(1));
- transformedMat.setFedMapping(transformedFedMapping);
+ transformedMat.getDataCharacteristics()
+ .setDimension(tfFedMap.getMaxIndexInRange(0),
tfFedMap.getMaxIndexInRange(1))
+ .setNonZeros(nnz.longValue());
+ transformedMat.setFedMapping(tfFedMap);
}
public static class CreateFrameEncoder extends FederatedUDF {
@@ -380,7 +384,8 @@ public class MultiReturnParameterizedBuiltinFEDInstruction
extends ComputationFE
ec.setVariable(String.valueOf(_outputID), mo);
// return id handle
- return new
FederatedResponse(ResponseType.SUCCESS_EMPTY);
+ return new FederatedResponse(
+ ResponseType.SUCCESS_EMPTY,
mbout.getNonZeros());
}
@Override
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
index b9794b413e..fe0641e512 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
@@ -175,13 +175,15 @@ public class ParameterizedBuiltinFEDInstruction extends
ComputationFEDInstructio
output,
new CPOperand[] {getTargetOperand()},
new long[] {mo.getFedMapping().getID()});
- mo.getFedMapping().execute(getTID(), true, fr1);
+ Future<FederatedResponse>[] ret =
mo.getFedMapping().execute(getTID(), true, fr1);
// derive new fed mapping for output
CacheableData<?> out = ec.getCacheableData(output);
if(mo instanceof FrameObject)
((FrameObject)out).setSchema(((FrameObject)
mo).getSchema());
-
out.getDataCharacteristics().set(mo.getDataCharacteristics());
+ out.getDataCharacteristics()
+ .set(mo.getDataCharacteristics())
+ .setNonZeros(FederationUtils.sumNonZeros(ret));
out.setFedMapping(mo.getFedMapping().copyWithNewID(fr1.getID()));
}
else if(opcode.equals("rmempty"))