This is an automated email from the ASF dual-hosted git repository.

ywcb00 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 1581744  [SYSTEMDS-3301] Support Federated CtableExpand
1581744 is described below

commit 1581744755617824ba79bd7a571858cf7d8e2038
Author: ywcb00 <[email protected]>
AuthorDate: Wed Mar 2 20:57:20 2022 +0100

    [SYSTEMDS-3301] Support Federated CtableExpand
    
    This patch extends the federated ctable instruction to also support
    the ctableexpand opcode, and adds the corresponding tests.
    
    Closes #1555.
---
 .../instructions/fed/CtableFEDInstruction.java     |  8 ++++--
 .../instructions/fed/FEDInstructionUtils.java      |  2 +-
 .../federated/primitives/FederatedCtableTest.java  | 31 +++++++++++++++++-----
 .../federated/FederatedCtableFedOutput.dml         |  2 ++
 .../FederatedCtableFedOutputReference.dml          |  1 +
 ...dOutput.dml => FederatedCtableSeqVecFedOut.dml} | 15 ++++++-----
 ...ml => FederatedCtableSeqVecFedOutReference.dml} |  6 +++--
 7 files changed, 48 insertions(+), 17 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/CtableFEDInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/CtableFEDInstruction.java
index 8795308..e644ef1 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/CtableFEDInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/CtableFEDInstruction.java
@@ -72,7 +72,7 @@ public class CtableFEDInstruction extends 
ComputationFEDInstruction {
                String opcode = parts[0];
 
                //handle opcode
-               if(!(opcode.equalsIgnoreCase("ctable"))) {
+               if(!(opcode.equalsIgnoreCase("ctable")) && 
!(opcode.equalsIgnoreCase("ctableexpand"))) {
                        throw new DMLRuntimeException("Unexpected opcode in 
CtableFEDInstruction: " + inst);
                }
 
@@ -380,7 +380,11 @@ public class CtableFEDInstruction extends 
ComputationFEDInstruction {
        }
 
        private String constructMaxInstString(String in, String out) {
-               String maxInstrString = instString.replace("ctable", "uamax");
+               String maxInstrString;
+               if(instString.contains("ctableexpand"))
+                       maxInstrString = instString.replace("ctableexpand", 
"uamax");
+               else
+                       maxInstrString = instString.replace("ctable", "uamax");
                String[] instParts = 
maxInstrString.split(Lop.OPERAND_DELIMITOR);
                String[] maxInstParts = new String[] {instParts[0], 
instParts[1],
                        InstructionUtils.concatOperandParts(in, 
DataType.MATRIX.name(), (ValueType.FP64).name()),
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
index 5915033..1192ff7 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
@@ -273,7 +273,7 @@ public class FEDInstructionUtils {
                        }
                        else if(inst instanceof CtableCPInstruction) {
                                CtableCPInstruction cinst = 
(CtableCPInstruction) inst;
-                               if(inst.getOpcode().equalsIgnoreCase("ctable")
+                               if((inst.getOpcode().equalsIgnoreCase("ctable") 
|| inst.getOpcode().equalsIgnoreCase("ctableexpand"))
                                        && ( 
ec.getCacheableData(cinst.input1).isFederated(FType.ROW)
                                        || (cinst.input2.isMatrix() && 
ec.getCacheableData(cinst.input2).isFederated(FType.ROW))
                                        || (cinst.input3.isMatrix() && 
ec.getCacheableData(cinst.input3).isFederated(FType.ROW))))
diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedCtableTest.java
 
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedCtableTest.java
index ab04efc..f3d2974 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedCtableTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedCtableTest.java
@@ -29,6 +29,7 @@ import org.apache.sysds.test.AutomatedTestBase;
 import org.apache.sysds.test.TestConfiguration;
 import org.apache.sysds.test.TestUtils;
 import org.junit.Assert;
+import org.junit.Ignore;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
@@ -39,8 +40,11 @@ public class FederatedCtableTest extends AutomatedTestBase {
        private final static String TEST_DIR = "functions/federated/";
        private final static String TEST_NAME1 = "FederatedCtableTest";
        private final static String TEST_NAME2 = "FederatedCtableFedOutput";
+       private final static String TEST_NAME3 = "FederatedCtableSeqVecFedOut";
        private final static String TEST_CLASS_DIR = TEST_DIR + 
FederatedCtableTest.class.getSimpleName() + "/";
 
+       private final static double TOLERANCE = 1e-12;
+
        private final static int blocksize = 1024;
        @Parameterized.Parameter()
        public int rows;
@@ -60,6 +64,7 @@ public class FederatedCtableTest extends AutomatedTestBase {
                TestUtils.clearAssertionInformation();
                addTestConfiguration(TEST_NAME1, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"F"}));
                addTestConfiguration(TEST_NAME2, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] {"F"}));
+               addTestConfiguration(TEST_NAME3, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] {"F"}));
        }
 
        @Parameterized.Parameters
@@ -88,9 +93,20 @@ public class FederatedCtableTest extends AutomatedTestBase {
        @Test
        public void federatedCtableMatrixInputFedOutputSingleNode() { 
runCtable(Types.ExecMode.SINGLE_NODE, true, true); }
 
+       @Test
+       @Ignore
+       public void federatedCtableSeqVecFedOutputSingleNode() { 
runCtable(Types.ExecMode.SINGLE_NODE, true, false, true); }
+
+       @Test
+       public void federatedCtableSeqVecSliceFedOutputSingleNode() { 
runCtable(Types.ExecMode.SINGLE_NODE, true, true, true); }
+
 
        public void runCtable(Types.ExecMode execMode, boolean fedOutput, 
boolean matrixInput) {
-               String TEST_NAME = fedOutput ? TEST_NAME2 : TEST_NAME1;
+               runCtable(execMode, fedOutput, matrixInput, false);
+       }
+
+       public void runCtable(Types.ExecMode execMode, boolean fedOutput, 
boolean matrixInput, boolean seqVec) {
+               String TEST_NAME = fedOutput ? (seqVec ? TEST_NAME3 : 
TEST_NAME2) : TEST_NAME1;
                Types.ExecMode platformOld = setExecMode(execMode);
 
                getAndLoadTestConfiguration(TEST_NAME);
@@ -174,7 +190,7 @@ public class FederatedCtableTest extends AutomatedTestBase {
                writeInputMatrixWithMTD("X4", X4, false, mc);
 
                //execute main test
-               fullDMLScriptName = HOME + TEST_NAME2 + "Reference.dml";
+               fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
                programArgs = new String[]{"-stats", "100", "-args",
                        input("X1"), input("X2"), input("X3"), input("X4"), 
Boolean.toString(reversedInputs).toUpperCase(),
                        Boolean.toString(weighted).toUpperCase(), 
Boolean.toString(matrixInput).toUpperCase(),
@@ -182,7 +198,7 @@ public class FederatedCtableTest extends AutomatedTestBase {
                runTest(true, false, null, -1);
 
                // Run actual dml script with federated matrix
-               fullDMLScriptName = HOME + TEST_NAME2 + ".dml";
+               fullDMLScriptName = HOME + TEST_NAME + ".dml";
                programArgs = new String[] {"-stats", "100", "-nvargs",
                        "in_X1=" + TestUtils.federatedAddress(port1, 
input("X1")),
                        "in_X2=" + TestUtils.federatedAddress(port2, 
input("X2")),
@@ -197,12 +213,15 @@ public class FederatedCtableTest extends 
AutomatedTestBase {
 
        void checkResults(boolean fedOutput) {
                // compare via files
-               compareResults(0);
+               compareResults(TOLERANCE);
 
                // check for federated operations
-               Assert.assertTrue(heavyHittersContainsString("fed_ctable"));
-               if(fedOutput) // verify output is federated
+               Assert.assertTrue(heavyHittersContainsString("fed_ctable")
+                       || heavyHittersContainsString("fed_ctableexpand"));
+               if(fedOutput) { // verify output is federated
                        
Assert.assertTrue(heavyHittersContainsString("fed_uak+"));
+                       Assert.assertTrue(heavyHittersContainsString("fed_*"));
+               }
 
                // check that federated input files are still existing
                Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
diff --git a/src/test/scripts/functions/federated/FederatedCtableFedOutput.dml 
b/src/test/scripts/functions/federated/FederatedCtableFedOutput.dml
index 448c919..3b4c8ed 100644
--- a/src/test/scripts/functions/federated/FederatedCtableFedOutput.dml
+++ b/src/test/scripts/functions/federated/FederatedCtableFedOutput.dml
@@ -51,6 +51,8 @@ else
       X2 = table(rix, cix);
 
 while(FALSE) { }
+X2 = X2 * (seq(1, nrow(X2)) / nrow(X2));
+while(FALSE) { }
 Z = as.matrix(sum(X2));
 
 write(Z, $out);
diff --git 
a/src/test/scripts/functions/federated/FederatedCtableFedOutputReference.dml 
b/src/test/scripts/functions/federated/FederatedCtableFedOutputReference.dml
index 79ed98e..fb9a2de 100644
--- a/src/test/scripts/functions/federated/FederatedCtableFedOutputReference.dml
+++ b/src/test/scripts/functions/federated/FederatedCtableFedOutputReference.dml
@@ -49,6 +49,7 @@ else
   else
     X2 = table(rix, cix);
 
+X2 = X2 * (seq(1, nrow(X2)) / nrow(X2));
 Z = as.matrix(sum(X2));
 
 write(Z, $8);
diff --git a/src/test/scripts/functions/federated/FederatedCtableFedOutput.dml 
b/src/test/scripts/functions/federated/FederatedCtableSeqVecFedOut.dml
similarity index 82%
copy from src/test/scripts/functions/federated/FederatedCtableFedOutput.dml
copy to src/test/scripts/functions/federated/FederatedCtableSeqVecFedOut.dml
index 448c919..9904170 100644
--- a/src/test/scripts/functions/federated/FederatedCtableFedOutput.dml
+++ b/src/test/scripts/functions/federated/FederatedCtableSeqVecFedOut.dml
@@ -21,7 +21,7 @@
 
 X = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
         ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), 
list(2*$rows/4, $cols),
-               list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), 
list($rows, $cols)));
+        list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), 
list($rows, $cols)));
 
 m = nrow(X);
 n = ncol(X);
@@ -29,14 +29,15 @@ n = ncol(X);
 # prepare offset vectors and one-hot encoded X
 maxs = colMaxs(X);
 if($matrixInput) {
-  rix = matrix(seq(1,m)%*%matrix(1,1,n), m, n);
   cix = matrix(X + (t(cumsum(t(maxs))) - maxs), m, n);
 }
 else {
-  rix = matrix(seq(1,m)%*%matrix(1,1,n), m*n, 1);
   cix = matrix(X + (t(cumsum(t(maxs))) - maxs), m*n, 1);
 }
 
+rix = seq(1, nrow(cix));
+cix = cix[ , 1] + 1; # slice row partitioned federated vector cix and add 1
+
 W = rix + cix;
 
 if($revIn)
@@ -46,11 +47,13 @@ if($revIn)
     X2 = table(cix, rix);
 else
   if($weighted)
-      X2 = table(rix, cix, W);
-    else
-      X2 = table(rix, cix);
+    X2 = table(rix, cix, W);
+  else
+    X2 = table(rix, cix);
 
 while(FALSE) { }
+X2 = X2 * (seq(1, nrow(X2)) / nrow(X2));
+while(FALSE) { }
 Z = as.matrix(sum(X2));
 
 write(Z, $out);
diff --git 
a/src/test/scripts/functions/federated/FederatedCtableFedOutputReference.dml 
b/src/test/scripts/functions/federated/FederatedCtableSeqVecFedOutReference.dml
similarity index 90%
copy from 
src/test/scripts/functions/federated/FederatedCtableFedOutputReference.dml
copy to 
src/test/scripts/functions/federated/FederatedCtableSeqVecFedOutReference.dml
index 79ed98e..bdbea6f 100644
--- a/src/test/scripts/functions/federated/FederatedCtableFedOutputReference.dml
+++ 
b/src/test/scripts/functions/federated/FederatedCtableSeqVecFedOutReference.dml
@@ -28,14 +28,15 @@ n = ncol(X);
 maxs = colMaxs(X);
 
 if($7) { # matrix input
-  rix = matrix(seq(1,m)%*%matrix(1,1,n), m, n);
   cix = matrix(X + (t(cumsum(t(maxs))) - maxs), m, n);
 }
 else {
-  rix = matrix(seq(1,m)%*%matrix(1,1,n), m*n, 1);
   cix = matrix(X + (t(cumsum(t(maxs))) - maxs), m*n, 1);
 }
 
+rix = seq(1, nrow(cix));
+cix = cix[ , 1] + 1; # slice row partitioned federated vector cix and add 1
+
 W = rix + cix;
 
 if($5)
@@ -49,6 +50,7 @@ else
   else
     X2 = table(rix, cix);
 
+X2 = X2 * (seq(1, nrow(X2)) / nrow(X2));
 Z = as.matrix(sum(X2));
 
 write(Z, $8);

Reply via email to