This is an automated email from the ASF dual-hosted git repository. mboehm7 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/systemds.git
commit 998d82e27b8add5a0ca55ac687f0bfd9abe54c8b Author: Matthias Boehm <[email protected]> AuthorDate: Sat Oct 31 20:25:59 2020 +0100 [SYSTEMDS-2549] Extended federated binary element-wise operations This patch generalizes the existing federated binary element-wise operations to avoid unsupported scenarios. Specifically, if the right-hand-side matrix (instead of the left-hand-side matrix) is federated and the operation is commutative (e.g., mult/add), we canonicalize the inputs accordingly. --- .../fed/BinaryMatrixMatrixFEDInstruction.java | 17 +++++++++++++---- .../sysds/runtime/matrix/operators/BinaryOperator.java | 7 +++++++ .../apache/sysds/runtime/meta/DataCharacteristics.java | 4 +++- .../sysds/runtime/meta/MatrixCharacteristics.java | 12 +++++++++++- .../sysds/runtime/meta/TensorCharacteristics.java | 9 +++++++++ .../federated/algorithms/FederatedGLMTest.java | 2 +- .../federated/algorithms/FederatedKmeansTest.java | 4 +++- 7 files changed, 47 insertions(+), 8 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java index bceb6ae..ea34df1 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/BinaryMatrixMatrixFEDInstruction.java @@ -25,6 +25,7 @@ import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest; import org.apache.sysds.runtime.controlprogram.federated.FederationUtils; import org.apache.sysds.runtime.instructions.cp.CPOperand; +import org.apache.sysds.runtime.matrix.operators.BinaryOperator; import org.apache.sysds.runtime.matrix.operators.Operator; public class BinaryMatrixMatrixFEDInstruction extends BinaryFEDInstruction @@ -39,8 +40,16 @@ public class BinaryMatrixMatrixFEDInstruction extends 
BinaryFEDInstruction MatrixObject mo1 = ec.getMatrixObject(input1); MatrixObject mo2 = ec.getMatrixObject(input2); + //canonicalization for federated lhs + if( !mo1.isFederated() && mo2.isFederated() + && mo1.getDataCharacteristics().equalDims(mo2.getDataCharacteristics()) + && ((BinaryOperator)_optr).isCommutative() ) { + mo1 = ec.getMatrixObject(input2); + mo2 = ec.getMatrixObject(input1); + } + + //execute federated operation on mo1 or mo2 FederatedRequest fr2 = null; - if( mo2.isFederated() ) { if(mo1.isFederated() && mo1.getFedMapping().isAligned(mo2.getFedMapping(), false)) { fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[]{input1, input2}, @@ -48,12 +57,12 @@ public class BinaryMatrixMatrixFEDInstruction extends BinaryFEDInstruction mo1.getFedMapping().execute(getTID(), true, fr2); } else { - throw new DMLRuntimeException("Matrix-matrix binary operations " - + " with a federated right input are not supported yet."); + throw new DMLRuntimeException("Matrix-matrix binary operations with a " + + "federated right input are only supported for special cases yet."); } } else { - //matrix-matrix binary oFederatedRequest fr2 = null;perations -> lhs fed input -> fed output + //matrix-matrix binary operations -> lhs fed input -> fed output if(mo2.getNumRows() > 1 && mo2.getNumColumns() == 1 ) { //MV row vector FederatedRequest[] fr1 = mo1.getFedMapping().broadcastSliced(mo2, false); fr2 = FederationUtils.callInstruction(instString, output, new CPOperand[]{input1, input2}, diff --git a/src/main/java/org/apache/sysds/runtime/matrix/operators/BinaryOperator.java b/src/main/java/org/apache/sysds/runtime/matrix/operators/BinaryOperator.java index beca629..bc4cdd0 100644 --- a/src/main/java/org/apache/sysds/runtime/matrix/operators/BinaryOperator.java +++ b/src/main/java/org/apache/sysds/runtime/matrix/operators/BinaryOperator.java @@ -56,6 +56,7 @@ public class BinaryOperator extends Operator implements Serializable private static final long 
serialVersionUID = -2547950181558989209L; public final ValueFunction fn; + public final boolean commutative; public BinaryOperator(ValueFunction p) { //binaryop is sparse-safe iff (0 op 0) == 0 @@ -65,6 +66,8 @@ public class BinaryOperator extends Operator implements Serializable || p instanceof BitwAnd || p instanceof BitwOr || p instanceof BitwXor || p instanceof BitwShiftL || p instanceof BitwShiftR); fn = p; + commutative = p instanceof Plus || p instanceof Multiply + || p instanceof And || p instanceof Or || p instanceof Xor; } /** @@ -111,6 +114,10 @@ public class BinaryOperator extends Operator implements Serializable return null; } + public boolean isCommutative() { + return commutative; + } + @Override public String toString() { return "BinaryOperator("+fn.getClass().getSimpleName()+")"; diff --git a/src/main/java/org/apache/sysds/runtime/meta/DataCharacteristics.java b/src/main/java/org/apache/sysds/runtime/meta/DataCharacteristics.java index d71ce9d..a28d98d 100644 --- a/src/main/java/org/apache/sysds/runtime/meta/DataCharacteristics.java +++ b/src/main/java/org/apache/sysds/runtime/meta/DataCharacteristics.java @@ -188,9 +188,11 @@ public abstract class DataCharacteristics implements Serializable { dimOut.set(dim1.getRows(), dim2.getCols(), dim1.getBlocksize()); } + public abstract boolean equalDims(Object anObject); + @Override public abstract boolean equals(Object anObject); - + @Override public abstract int hashCode(); } diff --git a/src/main/java/org/apache/sysds/runtime/meta/MatrixCharacteristics.java b/src/main/java/org/apache/sysds/runtime/meta/MatrixCharacteristics.java index 0b29cce..bdc4b21 100644 --- a/src/main/java/org/apache/sysds/runtime/meta/MatrixCharacteristics.java +++ b/src/main/java/org/apache/sysds/runtime/meta/MatrixCharacteristics.java @@ -229,7 +229,17 @@ public class MatrixCharacteristics extends DataCharacteristics return !nnzKnown() || numRows==0 || numColumns==0 || (nonZero < numRows*numColumns - singleBlk); } - + + @Override 
+ public boolean equalDims(Object anObject) { + if( !(anObject instanceof MatrixCharacteristics) ) + return false; + MatrixCharacteristics mc = (MatrixCharacteristics) anObject; + return dimsKnown() && mc.dimsKnown() + && numRows == mc.numRows + && numColumns == mc.numColumns; + } + @Override public boolean equals (Object anObject) { if( !(anObject instanceof MatrixCharacteristics) ) diff --git a/src/main/java/org/apache/sysds/runtime/meta/TensorCharacteristics.java b/src/main/java/org/apache/sysds/runtime/meta/TensorCharacteristics.java index 449cc2d..2b554a2 100644 --- a/src/main/java/org/apache/sysds/runtime/meta/TensorCharacteristics.java +++ b/src/main/java/org/apache/sysds/runtime/meta/TensorCharacteristics.java @@ -157,6 +157,15 @@ public class TensorCharacteristics extends DataCharacteristics } @Override + public boolean equalDims(Object anObject) { + if( !(anObject instanceof TensorCharacteristics) ) + return false; + TensorCharacteristics tc = (TensorCharacteristics) anObject; + return dimsKnown() && tc.dimsKnown() + && Arrays.equals(_dims, tc._dims); + } + + @Override public boolean equals (Object anObject) { if( !(anObject instanceof TensorCharacteristics) ) return false; diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedGLMTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedGLMTest.java index 2b9d287..44de28f 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedGLMTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedGLMTest.java @@ -123,7 +123,7 @@ public class FederatedGLMTest extends AutomatedTestBase { Assert.assertTrue(heavyHittersContainsString("fed_ba+*")); Assert.assertTrue(heavyHittersContainsString("fed_uark+","fed_uarsqk+")); Assert.assertTrue(heavyHittersContainsString("fed_uack+")); - Assert.assertTrue(heavyHittersContainsString("fed_uak+")); + 
//Assert.assertTrue(heavyHittersContainsString("fed_uak+")); Assert.assertTrue(heavyHittersContainsString("fed_mmchain")); //check that federated input files are still existing diff --git a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedKmeansTest.java b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedKmeansTest.java index eb70a4b..0dd339f 100644 --- a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedKmeansTest.java +++ b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedKmeansTest.java @@ -128,8 +128,10 @@ public class FederatedKmeansTest extends AutomatedTestBase { // check for federated operations Assert.assertTrue(heavyHittersContainsString("fed_ba+*")); - Assert.assertTrue(heavyHittersContainsString("fed_uasqk+")); + //Assert.assertTrue(heavyHittersContainsString("fed_uasqk+")); Assert.assertTrue(heavyHittersContainsString("fed_uarmin")); + Assert.assertTrue(heavyHittersContainsString("fed_uark+")); + Assert.assertTrue(heavyHittersContainsString("fed_uack+")); Assert.assertTrue(heavyHittersContainsString("fed_*")); Assert.assertTrue(heavyHittersContainsString("fed_+")); Assert.assertTrue(heavyHittersContainsString("fed_<="));
