This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/master by this push:
new 4fc8691 [SYSTEMDS-2955] Fix federated binary matrix-vector operators
4fc8691 is described below
commit 4fc8691aa6b8cb986cc8457eeec4ab6e5ba6da1a
Author: Matthias Boehm <[email protected]>
AuthorDate: Tue Apr 27 22:28:07 2021 +0200
[SYSTEMDS-2955] Fix federated binary matrix-vector operators
This patch fixes a special case of federated binary matrix-vector operators
in which the broadcast row/column vector is a 1x1 vector, and adds a related
test. Furthermore, it includes a cleanup of various warnings and a fix for
federated CSV reads (using the new metadata handling).
---
.../runtime/compress/CompressedMatrixBlockFactory.java | 6 +++---
.../sysds/runtime/compress/cocode/PlanningCoCoder.java | 2 +-
.../controlprogram/federated/FederatedWorkerHandler.java | 4 +---
.../fed/BinaryMatrixMatrixFEDInstruction.java | 7 +++----
.../runtime/instructions/fed/CtableFEDInstruction.java | 16 ++++++++--------
.../runtime/instructions/fed/TernaryFEDInstruction.java | 2 +-
.../java/org/apache/sysds/runtime/meta/MetaDataAll.java | 1 +
.../java/org/apache/sysds/test/AutomatedTestBase.java | 1 -
.../federated/primitives/FederatedBinaryVectorTest.java | 6 ++----
.../sysds/test/functions/privacy/ReadWriteTest.java | 3 ---
.../test/functions/privacy/ScalarPropagationTest.java | 1 -
11 files changed, 20 insertions(+), 29 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlockFactory.java
b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlockFactory.java
index 3da325a..bb0cf8a 100644
---
a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlockFactory.java
+++
b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlockFactory.java
@@ -232,11 +232,11 @@ public class CompressedMatrixBlockFactory {
logPhase();
}
- private AColGroup combineEmpty(List<AColGroup> e) {
+ private static AColGroup combineEmpty(List<AColGroup> e) {
return new ColGroupEmpty(combineColIndexes(e),
e.get(0).getNumRows());
}
- private AColGroup combineConst(List<AColGroup> c) {
+ private static AColGroup combineConst(List<AColGroup> c) {
int[] resCols = combineColIndexes(c);
double[] values = new double[resCols.length];
@@ -257,7 +257,7 @@ public class CompressedMatrixBlockFactory {
return new ColGroupConst(resCols, c.get(0).getNumRows(), dict);
}
- private int[] combineColIndexes(List<AColGroup> gs) {
+ private static int[] combineColIndexes(List<AColGroup> gs) {
int numCols = 0;
for(AColGroup g : gs)
numCols += g.getNumCols();
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/cocode/PlanningCoCoder.java
b/src/main/java/org/apache/sysds/runtime/compress/cocode/PlanningCoCoder.java
index 45f72ce..46bf988 100644
---
a/src/main/java/org/apache/sysds/runtime/compress/cocode/PlanningCoCoder.java
+++
b/src/main/java/org/apache/sysds/runtime/compress/cocode/PlanningCoCoder.java
@@ -200,7 +200,7 @@ public class PlanningCoCoder {
private int st1 = 0, st2 = 0, st3 = 0, st4 = 0;
public Memorizer() {
- mem = new HashMap<ColIndexes,
CompressedSizeInfoColGroup>();
+ mem = new HashMap<>();
}
public void put(CompressedSizeInfoColGroup g) {
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
index b0cc075..b3acf18 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
@@ -227,10 +227,8 @@ public class FederatedWorkerHandler extends
ChannelInboundHandlerAdapter {
// put meta data object in symbol table, read on first operation
cd.setMetaData(new MetaDataFormat(mc, fmt));
- // TODO send FileFormatProperties with request and use them for
CSV, this is currently a workaround so reading
- // of CSV files works
if(fmt == FileFormat.CSV)
- cd.setFileFormatProperties(new
FileFormatPropertiesCSV(header, DataExpression.DEFAULT_DELIM_DELIMITER,
+ cd.setFileFormatProperties(new
FileFormatPropertiesCSV(header, delim,
DataExpression.DEFAULT_DELIM_SPARSE));
cd.enableCleanup(false); // guard against deletion
_ecm.get(tid).setVariable(String.valueOf(id), cd);
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
index 77ade9a..e67a353 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
@@ -29,7 +29,6 @@ import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
import org.apache.sysds.runtime.matrix.operators.Operator;
-
public class BinaryMatrixMatrixFEDInstruction extends BinaryFEDInstruction
{
protected BinaryMatrixMatrixFEDInstruction(Operator op,
@@ -41,7 +40,7 @@ public class BinaryMatrixMatrixFEDInstruction extends
BinaryFEDInstruction
public void processInstruction(ExecutionContext ec) {
MatrixObject mo1 = ec.getMatrixObject(input1);
MatrixObject mo2 = ec.getMatrixObject(input2);
-
+
//canonicalization for federated lhs
if( !mo1.isFederated() && mo2.isFederated()
&&
mo1.getDataCharacteristics().equalDims(mo2.getDataCharacteristics())
@@ -85,8 +84,8 @@ public class BinaryMatrixMatrixFEDInstruction extends
BinaryFEDInstruction
throw new
DMLRuntimeException("Matrix-matrix binary operations with a full partitioned
federated input with multiple partitions are not supported yet.");
}
}
- else if((mo1.isFederated(FType.ROW) && mo2.getNumRows()
== 1 && mo2.getNumColumns() > 1)
- || (mo1.isFederated(FType.COL) &&
mo2.getNumRows() > 1 && mo2.getNumColumns() == 1)) {
+ else if((mo1.isFederated(FType.ROW) && mo2.getNumRows()
== 1) //matrix-rowVect
+ || (mo1.isFederated(FType.COL) &&
mo2.getNumColumns() == 1)) { //matrix-colVect
// MV row partitioned row vector, MV col
partitioned col vector
FederatedRequest fr1 =
mo1.getFedMapping().broadcast(mo2);
fr2 =
FederationUtils.callInstruction(instString, output, new CPOperand[]{input1,
input2},
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/CtableFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/CtableFEDInstruction.java
index 2681759..00137e8 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/CtableFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/CtableFEDInstruction.java
@@ -50,16 +50,16 @@ import
org.apache.sysds.runtime.matrix.operators.BinaryOperator;
public class CtableFEDInstruction extends ComputationFEDInstruction {
private final CPOperand _outDim1;
private final CPOperand _outDim2;
- private final boolean _isExpand;
- private final boolean _ignoreZeros;
+ //private final boolean _isExpand;
+ //private final boolean _ignoreZeros;
private CtableFEDInstruction(CPOperand in1, CPOperand in2, CPOperand
in3, CPOperand out, String outputDim1, boolean dim1Literal, String outputDim2,
boolean dim2Literal, boolean isExpand,
boolean ignoreZeros, String opcode, String istr) {
super(FEDType.Ctable, null, in1, in2, in3, out, opcode, istr);
_outDim1 = new CPOperand(outputDim1, ValueType.FP64,
DataType.SCALAR, dim1Literal);
_outDim2 = new CPOperand(outputDim2, ValueType.FP64,
DataType.SCALAR, dim2Literal);
- _isExpand = isExpand;
- _ignoreZeros = ignoreZeros;
+ //_isExpand = isExpand;
+ //_ignoreZeros = ignoreZeros;
}
public static CtableFEDInstruction parseInstruction(String inst) {
@@ -199,7 +199,7 @@ public class CtableFEDInstruction extends
ComputationFEDInstruction {
}
- private void setFedOutput(MatrixObject mo1, MatrixObject out,
FederationMap fedMap, Long[] dims1, long outId) {
+ private static void setFedOutput(MatrixObject mo1, MatrixObject out,
FederationMap fedMap, Long[] dims1, long outId) {
long fedSize = Collections.max(Arrays.asList(dims1),
Long::compare) / dims1.length;
long d1 = Collections.max(Arrays.asList(dims1), Long::compare);
@@ -225,7 +225,7 @@ public class CtableFEDInstruction extends
ComputationFEDInstruction {
});
}
- private MatrixBlock aggResult(Future<FederatedResponse>[] ffr) {
+ private static MatrixBlock aggResult(Future<FederatedResponse>[] ffr) {
MatrixBlock resultBlock = new MatrixBlock(1, 1, 0);
int dim1 = 0, dim2 = 0;
for(int i = 0; i < ffr.length; i++) {
@@ -252,7 +252,7 @@ public class CtableFEDInstruction extends
ComputationFEDInstruction {
return resultBlock;
}
- private FederationMap modifyFedRanges(FederationMap fedMap, Long[]
dims1, Long[] dims2) {
+ private static FederationMap modifyFedRanges(FederationMap fedMap,
Long[] dims1, Long[] dims2) {
IntStream.range(0,
fedMap.getFederatedRanges().length).forEach(i -> {
fedMap.getFederatedRanges()[i]
.setBeginDim(0, i == 0 ? 0 :
fedMap.getFederatedRanges()[i - 1].getEndDims()[0]);
@@ -291,7 +291,7 @@ public class CtableFEDInstruction extends
ComputationFEDInstruction {
return computeOutputDims(tmp);
}
- private Long[] computeOutputDims(Future<FederatedResponse>[] tmp) {
+ private static Long[] computeOutputDims(Future<FederatedResponse>[]
tmp) {
Long[] fedDims = new Long[tmp.length];
for(int i = 0; i < tmp.length; i ++)
try {
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/TernaryFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/TernaryFEDInstruction.java
index 5b52dbe..2805107 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/TernaryFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/TernaryFEDInstruction.java
@@ -199,7 +199,7 @@ public class TernaryFEDInstruction extends
ComputationFEDInstruction {
* @param fedRequest1 federated request to occur after array
* @return federated requests collected in a single array
*/
- private FederatedRequest[] collectRequests(FederatedRequest[]
fedRequests, FederatedRequest fedRequest1){
+ private static FederatedRequest[] collectRequests(FederatedRequest[]
fedRequests, FederatedRequest fedRequest1){
FederatedRequest[] allRequests = new
FederatedRequest[fedRequests.length + 1];
for ( int i = 0; i < fedRequests.length; i++ )
allRequests[i] = fedRequests[i];
diff --git a/src/main/java/org/apache/sysds/runtime/meta/MetaDataAll.java
b/src/main/java/org/apache/sysds/runtime/meta/MetaDataAll.java
index df25887..bca861a 100644
--- a/src/main/java/org/apache/sysds/runtime/meta/MetaDataAll.java
+++ b/src/main/java/org/apache/sysds/runtime/meta/MetaDataAll.java
@@ -139,6 +139,7 @@ public class MetaDataAll extends DataIdentifier {
return retVal;
}
+ @SuppressWarnings("unchecked")
private void parseMetaDataParams()
{
for( Object obj : _metaObj.entrySet() ){
diff --git a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
index 5929fc6..c7821af 100644
--- a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
+++ b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
@@ -87,7 +87,6 @@ import org.apache.sysds.runtime.util.HDFSTool;
import org.apache.sysds.utils.ParameterBuilder;
import org.apache.sysds.utils.Statistics;
import org.apache.wink.json4j.JSONException;
-import org.apache.wink.json4j.JSONObject;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedBinaryVectorTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedBinaryVectorTest.java
index 089c23e..b652228 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedBinaryVectorTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedBinaryVectorTest.java
@@ -58,10 +58,8 @@ public class FederatedBinaryVectorTest extends
AutomatedTestBase {
public static Collection<Object[]> data() {
// rows have to be even and > 1
return Arrays.asList(new Object[][] {
- // {2, 1000},
- // {10, 100},
- {100, 10},
- // {1000, 1}, {10, 2000}, {2000, 10}
+ {100, 10},
+ {100, 1},
});
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/privacy/ReadWriteTest.java
b/src/test/java/org/apache/sysds/test/functions/privacy/ReadWriteTest.java
index 850032f..6fc57ef 100644
--- a/src/test/java/org/apache/sysds/test/functions/privacy/ReadWriteTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/privacy/ReadWriteTest.java
@@ -21,7 +21,6 @@ package org.apache.sysds.test.functions.privacy;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
-import static org.junit.Assert.assertTrue;
import java.io.FileInputStream;
import java.io.FileOutputStream;
@@ -29,7 +28,6 @@ import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
-import org.apache.sysds.parser.DataExpression;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.runtime.meta.MetaDataAll;
import org.apache.sysds.runtime.privacy.PrivacyConstraint;
@@ -39,7 +37,6 @@ import
org.apache.sysds.runtime.privacy.finegrained.FineGrainedPrivacy;
import org.apache.sysds.runtime.privacy.finegrained.FineGrainedPrivacyList;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
-import org.apache.wink.json4j.JSONObject;
import org.junit.Assert;
import org.junit.Test;
diff --git
a/src/test/java/org/apache/sysds/test/functions/privacy/ScalarPropagationTest.java
b/src/test/java/org/apache/sysds/test/functions/privacy/ScalarPropagationTest.java
index a48c6d8..ec79b57 100644
---
a/src/test/java/org/apache/sysds/test/functions/privacy/ScalarPropagationTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/privacy/ScalarPropagationTest.java
@@ -33,7 +33,6 @@ import
org.apache.sysds.runtime.privacy.PrivacyConstraint.PrivacyLevel;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
-import org.apache.wink.json4j.JSONObject;
public class ScalarPropagationTest extends AutomatedTestBase
{