This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 4fc8691  [SYSTEMDS-2955] Fix federated binary matrix-vector operators
4fc8691 is described below

commit 4fc8691aa6b8cb986cc8457eeec4ab6e5ba6da1a
Author: Matthias Boehm <[email protected]>
AuthorDate: Tue Apr 27 22:28:07 2021 +0200

    [SYSTEMDS-2955] Fix federated binary matrix-vector operators
    
    This patch fixes a special case of federated matrix-vector operators
    with row/column vector broadcasting of 1x1 vectors and adds a related
    test. In addition, it cleans up various warnings and fixes the federated
    CSV read so that it works correctly (using the new metadata handling).
---
 .../runtime/compress/CompressedMatrixBlockFactory.java   |  6 +++---
 .../sysds/runtime/compress/cocode/PlanningCoCoder.java   |  2 +-
 .../controlprogram/federated/FederatedWorkerHandler.java |  4 +---
 .../fed/BinaryMatrixMatrixFEDInstruction.java            |  7 +++----
 .../runtime/instructions/fed/CtableFEDInstruction.java   | 16 ++++++++--------
 .../runtime/instructions/fed/TernaryFEDInstruction.java  |  2 +-
 .../java/org/apache/sysds/runtime/meta/MetaDataAll.java  |  1 +
 .../java/org/apache/sysds/test/AutomatedTestBase.java    |  1 -
 .../federated/primitives/FederatedBinaryVectorTest.java  |  6 ++----
 .../sysds/test/functions/privacy/ReadWriteTest.java      |  3 ---
 .../test/functions/privacy/ScalarPropagationTest.java    |  1 -
 11 files changed, 20 insertions(+), 29 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlockFactory.java
 
b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlockFactory.java
index 3da325a..bb0cf8a 100644
--- 
a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlockFactory.java
+++ 
b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlockFactory.java
@@ -232,11 +232,11 @@ public class CompressedMatrixBlockFactory {
                logPhase();
        }
 
-       private AColGroup combineEmpty(List<AColGroup> e) {
+       private static AColGroup combineEmpty(List<AColGroup> e) {
                return new ColGroupEmpty(combineColIndexes(e), 
e.get(0).getNumRows());
        }
 
-       private AColGroup combineConst(List<AColGroup> c) {
+       private static AColGroup combineConst(List<AColGroup> c) {
                int[] resCols = combineColIndexes(c);
 
                double[] values = new double[resCols.length];
@@ -257,7 +257,7 @@ public class CompressedMatrixBlockFactory {
                return new ColGroupConst(resCols, c.get(0).getNumRows(), dict);
        }
 
-       private int[] combineColIndexes(List<AColGroup> gs) {
+       private static int[] combineColIndexes(List<AColGroup> gs) {
                int numCols = 0;
                for(AColGroup g : gs)
                        numCols += g.getNumCols();
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/cocode/PlanningCoCoder.java 
b/src/main/java/org/apache/sysds/runtime/compress/cocode/PlanningCoCoder.java
index 45f72ce..46bf988 100644
--- 
a/src/main/java/org/apache/sysds/runtime/compress/cocode/PlanningCoCoder.java
+++ 
b/src/main/java/org/apache/sysds/runtime/compress/cocode/PlanningCoCoder.java
@@ -200,7 +200,7 @@ public class PlanningCoCoder {
                private int st1 = 0, st2 = 0, st3 = 0, st4 = 0;
 
                public Memorizer() {
-                       mem = new HashMap<ColIndexes, 
CompressedSizeInfoColGroup>();
+                       mem = new HashMap<>();
                }
 
                public void put(CompressedSizeInfoColGroup g) {
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
index b0cc075..b3acf18 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
@@ -227,10 +227,8 @@ public class FederatedWorkerHandler extends 
ChannelInboundHandlerAdapter {
 
                // put meta data object in symbol table, read on first operation
                cd.setMetaData(new MetaDataFormat(mc, fmt));
-               // TODO send FileFormatProperties with request and use them for 
CSV, this is currently a workaround so reading
-               // of CSV files works
                if(fmt == FileFormat.CSV)
-                       cd.setFileFormatProperties(new 
FileFormatPropertiesCSV(header, DataExpression.DEFAULT_DELIM_DELIMITER,
+                       cd.setFileFormatProperties(new 
FileFormatPropertiesCSV(header, delim,
                                DataExpression.DEFAULT_DELIM_SPARSE));
                cd.enableCleanup(false); // guard against deletion
                _ecm.get(tid).setVariable(String.valueOf(id), cd);
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
index 77ade9a..e67a353 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java
@@ -29,7 +29,6 @@ import org.apache.sysds.runtime.instructions.cp.CPOperand;
 import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
 import org.apache.sysds.runtime.matrix.operators.Operator;
 
-
 public class BinaryMatrixMatrixFEDInstruction extends BinaryFEDInstruction
 {
        protected BinaryMatrixMatrixFEDInstruction(Operator op,
@@ -41,7 +40,7 @@ public class BinaryMatrixMatrixFEDInstruction extends 
BinaryFEDInstruction
        public void processInstruction(ExecutionContext ec) {
                MatrixObject mo1 = ec.getMatrixObject(input1);
                MatrixObject mo2 = ec.getMatrixObject(input2);
-
+               
                //canonicalization for federated lhs
                if( !mo1.isFederated() && mo2.isFederated()
                        && 
mo1.getDataCharacteristics().equalDims(mo2.getDataCharacteristics())
@@ -85,8 +84,8 @@ public class BinaryMatrixMatrixFEDInstruction extends 
BinaryFEDInstruction
                                        throw new 
DMLRuntimeException("Matrix-matrix binary operations with a full partitioned 
federated input with multiple partitions are not supported yet.");
                                }
                        }
-                       else if((mo1.isFederated(FType.ROW) && mo2.getNumRows() 
== 1 && mo2.getNumColumns() > 1)
-                               || (mo1.isFederated(FType.COL) && 
mo2.getNumRows() > 1 && mo2.getNumColumns() == 1)) {
+                       else if((mo1.isFederated(FType.ROW) && mo2.getNumRows() 
== 1)      //matrix-rowVect
+                               || (mo1.isFederated(FType.COL) && 
mo2.getNumColumns() == 1)) { //matrix-colVect
                                // MV row partitioned row vector, MV col 
partitioned col vector
                                FederatedRequest fr1 = 
mo1.getFedMapping().broadcast(mo2);
                                fr2 = 
FederationUtils.callInstruction(instString, output, new CPOperand[]{input1, 
input2},
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/CtableFEDInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/CtableFEDInstruction.java
index 2681759..00137e8 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/CtableFEDInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/CtableFEDInstruction.java
@@ -50,16 +50,16 @@ import 
org.apache.sysds.runtime.matrix.operators.BinaryOperator;
 public class CtableFEDInstruction extends ComputationFEDInstruction {
        private final CPOperand _outDim1;
        private final CPOperand _outDim2;
-       private final boolean _isExpand;
-       private final boolean _ignoreZeros;
+       //private final boolean _isExpand;
+       //private final boolean _ignoreZeros;
 
        private CtableFEDInstruction(CPOperand in1, CPOperand in2, CPOperand 
in3, CPOperand out, String outputDim1, boolean dim1Literal, String outputDim2, 
boolean dim2Literal, boolean isExpand,
                boolean ignoreZeros, String opcode, String istr) {
                super(FEDType.Ctable, null, in1, in2, in3, out, opcode, istr);
                _outDim1 = new CPOperand(outputDim1, ValueType.FP64, 
DataType.SCALAR, dim1Literal);
                _outDim2 = new CPOperand(outputDim2, ValueType.FP64, 
DataType.SCALAR, dim2Literal);
-               _isExpand = isExpand;
-               _ignoreZeros = ignoreZeros;
+               //_isExpand = isExpand;
+               //_ignoreZeros = ignoreZeros;
        }
 
        public static CtableFEDInstruction parseInstruction(String inst) {
@@ -199,7 +199,7 @@ public class CtableFEDInstruction extends 
ComputationFEDInstruction {
        }
 
 
-       private void setFedOutput(MatrixObject mo1, MatrixObject out, 
FederationMap fedMap, Long[] dims1, long outId) {
+       private static void setFedOutput(MatrixObject mo1, MatrixObject out, 
FederationMap fedMap, Long[] dims1, long outId) {
                long fedSize = Collections.max(Arrays.asList(dims1), 
Long::compare) / dims1.length;
 
                long d1 = Collections.max(Arrays.asList(dims1), Long::compare);
@@ -225,7 +225,7 @@ public class CtableFEDInstruction extends 
ComputationFEDInstruction {
                });
        }
 
-       private MatrixBlock aggResult(Future<FederatedResponse>[] ffr) {
+       private static MatrixBlock aggResult(Future<FederatedResponse>[] ffr) {
                MatrixBlock resultBlock = new MatrixBlock(1, 1, 0);
                int dim1 = 0, dim2 = 0;
                for(int i = 0; i < ffr.length; i++) {
@@ -252,7 +252,7 @@ public class CtableFEDInstruction extends 
ComputationFEDInstruction {
                return resultBlock;
        }
 
-       private FederationMap modifyFedRanges(FederationMap fedMap, Long[] 
dims1, Long[] dims2) {
+       private static FederationMap modifyFedRanges(FederationMap fedMap, 
Long[] dims1, Long[] dims2) {
                IntStream.range(0, 
fedMap.getFederatedRanges().length).forEach(i -> {
                        fedMap.getFederatedRanges()[i]
                                .setBeginDim(0, i == 0 ? 0 : 
fedMap.getFederatedRanges()[i - 1].getEndDims()[0]);
@@ -291,7 +291,7 @@ public class CtableFEDInstruction extends 
ComputationFEDInstruction {
                return computeOutputDims(tmp);
        }
 
-       private Long[] computeOutputDims(Future<FederatedResponse>[] tmp) {
+       private static Long[] computeOutputDims(Future<FederatedResponse>[] 
tmp) {
                Long[] fedDims = new Long[tmp.length];
                for(int i = 0; i < tmp.length; i ++)
                        try {
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/TernaryFEDInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/TernaryFEDInstruction.java
index 5b52dbe..2805107 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/TernaryFEDInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/TernaryFEDInstruction.java
@@ -199,7 +199,7 @@ public class TernaryFEDInstruction extends 
ComputationFEDInstruction {
         * @param fedRequest1 federated request to occur after array
         * @return federated requests collected in a single array
         */
-       private FederatedRequest[] collectRequests(FederatedRequest[] 
fedRequests, FederatedRequest fedRequest1){
+       private static FederatedRequest[] collectRequests(FederatedRequest[] 
fedRequests, FederatedRequest fedRequest1){
                FederatedRequest[] allRequests = new 
FederatedRequest[fedRequests.length + 1];
                for ( int i = 0; i < fedRequests.length; i++ )
                        allRequests[i] = fedRequests[i];
diff --git a/src/main/java/org/apache/sysds/runtime/meta/MetaDataAll.java 
b/src/main/java/org/apache/sysds/runtime/meta/MetaDataAll.java
index df25887..bca861a 100644
--- a/src/main/java/org/apache/sysds/runtime/meta/MetaDataAll.java
+++ b/src/main/java/org/apache/sysds/runtime/meta/MetaDataAll.java
@@ -139,6 +139,7 @@ public class MetaDataAll extends DataIdentifier {
                return retVal;
        }
 
+       @SuppressWarnings("unchecked")
        private void parseMetaDataParams()
        {
                for( Object obj : _metaObj.entrySet() ){
diff --git a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java 
b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
index 5929fc6..c7821af 100644
--- a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
+++ b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
@@ -87,7 +87,6 @@ import org.apache.sysds.runtime.util.HDFSTool;
 import org.apache.sysds.utils.ParameterBuilder;
 import org.apache.sysds.utils.Statistics;
 import org.apache.wink.json4j.JSONException;
-import org.apache.wink.json4j.JSONObject;
 import org.junit.After;
 import org.junit.Assert;
 import org.junit.Before;
diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedBinaryVectorTest.java
 
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedBinaryVectorTest.java
index 089c23e..b652228 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedBinaryVectorTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedBinaryVectorTest.java
@@ -58,10 +58,8 @@ public class FederatedBinaryVectorTest extends 
AutomatedTestBase {
        public static Collection<Object[]> data() {
                // rows have to be even and > 1
                return Arrays.asList(new Object[][] {
-                       // {2, 1000}, 
-                       // {10, 100}, 
-                       {100, 10}, 
-                       // {1000, 1}, {10, 2000}, {2000, 10}
+                       {100, 10},
+                       {100, 1},
                });
        }
 
diff --git 
a/src/test/java/org/apache/sysds/test/functions/privacy/ReadWriteTest.java 
b/src/test/java/org/apache/sysds/test/functions/privacy/ReadWriteTest.java
index 850032f..6fc57ef 100644
--- a/src/test/java/org/apache/sysds/test/functions/privacy/ReadWriteTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/privacy/ReadWriteTest.java
@@ -21,7 +21,6 @@ package org.apache.sysds.test.functions.privacy;
 
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNotNull;
-import static org.junit.Assert.assertTrue;
 
 import java.io.FileInputStream;
 import java.io.FileOutputStream;
@@ -29,7 +28,6 @@ import java.io.IOException;
 import java.io.ObjectInputStream;
 import java.io.ObjectOutputStream;
 
-import org.apache.sysds.parser.DataExpression;
 import org.apache.sysds.runtime.meta.MatrixCharacteristics;
 import org.apache.sysds.runtime.meta.MetaDataAll;
 import org.apache.sysds.runtime.privacy.PrivacyConstraint;
@@ -39,7 +37,6 @@ import 
org.apache.sysds.runtime.privacy.finegrained.FineGrainedPrivacy;
 import org.apache.sysds.runtime.privacy.finegrained.FineGrainedPrivacyList;
 import org.apache.sysds.test.AutomatedTestBase;
 import org.apache.sysds.test.TestConfiguration;
-import org.apache.wink.json4j.JSONObject;
 import org.junit.Assert;
 import org.junit.Test;
 
diff --git 
a/src/test/java/org/apache/sysds/test/functions/privacy/ScalarPropagationTest.java
 
b/src/test/java/org/apache/sysds/test/functions/privacy/ScalarPropagationTest.java
index a48c6d8..ec79b57 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/privacy/ScalarPropagationTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/privacy/ScalarPropagationTest.java
@@ -33,7 +33,6 @@ import 
org.apache.sysds.runtime.privacy.PrivacyConstraint.PrivacyLevel;
 import org.apache.sysds.test.AutomatedTestBase;
 import org.apache.sysds.test.TestConfiguration;
 import org.apache.sysds.test.TestUtils;
-import org.apache.wink.json4j.JSONObject;
 
 public class ScalarPropagationTest extends AutomatedTestBase 
 {

Reply via email to