This is an automated email from the ASF dual-hosted git repository.
ywcb00 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 1581744 [SYSTEMDS-3301] Support Federated CtableExpand
1581744 is described below
commit 1581744755617824ba79bd7a571858cf7d8e2038
Author: ywcb00 <[email protected]>
AuthorDate: Wed Mar 2 20:57:20 2022 +0100
[SYSTEMDS-3301] Support Federated CtableExpand
This patch extends the federated ctable instruction to also support
the ctableexpand opcode and adds the corresponding tests.
Closes #1555.
---
.../instructions/fed/CtableFEDInstruction.java | 8 ++++--
.../instructions/fed/FEDInstructionUtils.java | 2 +-
.../federated/primitives/FederatedCtableTest.java | 31 +++++++++++++++++-----
.../federated/FederatedCtableFedOutput.dml | 2 ++
.../FederatedCtableFedOutputReference.dml | 1 +
...dOutput.dml => FederatedCtableSeqVecFedOut.dml} | 15 ++++++-----
...ml => FederatedCtableSeqVecFedOutReference.dml} | 6 +++--
7 files changed, 48 insertions(+), 17 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/CtableFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/CtableFEDInstruction.java
index 8795308..e644ef1 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/CtableFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/CtableFEDInstruction.java
@@ -72,7 +72,7 @@ public class CtableFEDInstruction extends
ComputationFEDInstruction {
String opcode = parts[0];
//handle opcode
- if(!(opcode.equalsIgnoreCase("ctable"))) {
+ if(!(opcode.equalsIgnoreCase("ctable")) &&
!(opcode.equalsIgnoreCase("ctableexpand"))) {
throw new DMLRuntimeException("Unexpected opcode in
CtableFEDInstruction: " + inst);
}
@@ -380,7 +380,11 @@ public class CtableFEDInstruction extends
ComputationFEDInstruction {
}
private String constructMaxInstString(String in, String out) {
- String maxInstrString = instString.replace("ctable", "uamax");
+ String maxInstrString;
+ if(instString.contains("ctableexpand"))
+ maxInstrString = instString.replace("ctableexpand",
"uamax");
+ else
+ maxInstrString = instString.replace("ctable", "uamax");
String[] instParts =
maxInstrString.split(Lop.OPERAND_DELIMITOR);
String[] maxInstParts = new String[] {instParts[0],
instParts[1],
InstructionUtils.concatOperandParts(in,
DataType.MATRIX.name(), (ValueType.FP64).name()),
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
index 5915033..1192ff7 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
@@ -273,7 +273,7 @@ public class FEDInstructionUtils {
}
else if(inst instanceof CtableCPInstruction) {
CtableCPInstruction cinst =
(CtableCPInstruction) inst;
- if(inst.getOpcode().equalsIgnoreCase("ctable")
+ if((inst.getOpcode().equalsIgnoreCase("ctable")
|| inst.getOpcode().equalsIgnoreCase("ctableexpand"))
&& (
ec.getCacheableData(cinst.input1).isFederated(FType.ROW)
|| (cinst.input2.isMatrix() &&
ec.getCacheableData(cinst.input2).isFederated(FType.ROW))
|| (cinst.input3.isMatrix() &&
ec.getCacheableData(cinst.input3).isFederated(FType.ROW))))
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedCtableTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedCtableTest.java
index ab04efc..f3d2974 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedCtableTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedCtableTest.java
@@ -29,6 +29,7 @@ import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
import org.junit.Assert;
+import org.junit.Ignore;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;
@@ -39,8 +40,11 @@ public class FederatedCtableTest extends AutomatedTestBase {
private final static String TEST_DIR = "functions/federated/";
private final static String TEST_NAME1 = "FederatedCtableTest";
private final static String TEST_NAME2 = "FederatedCtableFedOutput";
+ private final static String TEST_NAME3 = "FederatedCtableSeqVecFedOut";
private final static String TEST_CLASS_DIR = TEST_DIR +
FederatedCtableTest.class.getSimpleName() + "/";
+ private final static double TOLERANCE = 1e-12;
+
private final static int blocksize = 1024;
@Parameterized.Parameter()
public int rows;
@@ -60,6 +64,7 @@ public class FederatedCtableTest extends AutomatedTestBase {
TestUtils.clearAssertionInformation();
addTestConfiguration(TEST_NAME1, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"F"}));
addTestConfiguration(TEST_NAME2, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] {"F"}));
+ addTestConfiguration(TEST_NAME3, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] {"F"}));
}
@Parameterized.Parameters
@@ -88,9 +93,20 @@ public class FederatedCtableTest extends AutomatedTestBase {
@Test
public void federatedCtableMatrixInputFedOutputSingleNode() {
runCtable(Types.ExecMode.SINGLE_NODE, true, true); }
+ @Test
+ @Ignore
+ public void federatedCtableSeqVecFedOutputSingleNode() {
runCtable(Types.ExecMode.SINGLE_NODE, true, false, true); }
+
+ @Test
+ public void federatedCtableSeqVecSliceFedOutputSingleNode() {
runCtable(Types.ExecMode.SINGLE_NODE, true, true, true); }
+
public void runCtable(Types.ExecMode execMode, boolean fedOutput,
boolean matrixInput) {
- String TEST_NAME = fedOutput ? TEST_NAME2 : TEST_NAME1;
+ runCtable(execMode, fedOutput, matrixInput, false);
+ }
+
+ public void runCtable(Types.ExecMode execMode, boolean fedOutput,
boolean matrixInput, boolean seqVec) {
+ String TEST_NAME = fedOutput ? (seqVec ? TEST_NAME3 :
TEST_NAME2) : TEST_NAME1;
Types.ExecMode platformOld = setExecMode(execMode);
getAndLoadTestConfiguration(TEST_NAME);
@@ -174,7 +190,7 @@ public class FederatedCtableTest extends AutomatedTestBase {
writeInputMatrixWithMTD("X4", X4, false, mc);
//execute main test
- fullDMLScriptName = HOME + TEST_NAME2 + "Reference.dml";
+ fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
programArgs = new String[]{"-stats", "100", "-args",
input("X1"), input("X2"), input("X3"), input("X4"),
Boolean.toString(reversedInputs).toUpperCase(),
Boolean.toString(weighted).toUpperCase(),
Boolean.toString(matrixInput).toUpperCase(),
@@ -182,7 +198,7 @@ public class FederatedCtableTest extends AutomatedTestBase {
runTest(true, false, null, -1);
// Run actual dml script with federated matrix
- fullDMLScriptName = HOME + TEST_NAME2 + ".dml";
+ fullDMLScriptName = HOME + TEST_NAME + ".dml";
programArgs = new String[] {"-stats", "100", "-nvargs",
"in_X1=" + TestUtils.federatedAddress(port1,
input("X1")),
"in_X2=" + TestUtils.federatedAddress(port2,
input("X2")),
@@ -197,12 +213,15 @@ public class FederatedCtableTest extends
AutomatedTestBase {
void checkResults(boolean fedOutput) {
// compare via files
- compareResults(0);
+ compareResults(TOLERANCE);
// check for federated operations
- Assert.assertTrue(heavyHittersContainsString("fed_ctable"));
- if(fedOutput) // verify output is federated
+ Assert.assertTrue(heavyHittersContainsString("fed_ctable")
+ || heavyHittersContainsString("fed_ctableexpand"));
+ if(fedOutput) { // verify output is federated
Assert.assertTrue(heavyHittersContainsString("fed_uak+"));
+ Assert.assertTrue(heavyHittersContainsString("fed_*"));
+ }
// check that federated input files are still existing
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
diff --git a/src/test/scripts/functions/federated/FederatedCtableFedOutput.dml
b/src/test/scripts/functions/federated/FederatedCtableFedOutput.dml
index 448c919..3b4c8ed 100644
--- a/src/test/scripts/functions/federated/FederatedCtableFedOutput.dml
+++ b/src/test/scripts/functions/federated/FederatedCtableFedOutput.dml
@@ -51,6 +51,8 @@ else
X2 = table(rix, cix);
while(FALSE) { }
+X2 = X2 * (seq(1, nrow(X2)) / nrow(X2));
+while(FALSE) { }
Z = as.matrix(sum(X2));
write(Z, $out);
diff --git
a/src/test/scripts/functions/federated/FederatedCtableFedOutputReference.dml
b/src/test/scripts/functions/federated/FederatedCtableFedOutputReference.dml
index 79ed98e..fb9a2de 100644
--- a/src/test/scripts/functions/federated/FederatedCtableFedOutputReference.dml
+++ b/src/test/scripts/functions/federated/FederatedCtableFedOutputReference.dml
@@ -49,6 +49,7 @@ else
else
X2 = table(rix, cix);
+X2 = X2 * (seq(1, nrow(X2)) / nrow(X2));
Z = as.matrix(sum(X2));
write(Z, $8);
diff --git a/src/test/scripts/functions/federated/FederatedCtableFedOutput.dml
b/src/test/scripts/functions/federated/FederatedCtableSeqVecFedOut.dml
similarity index 82%
copy from src/test/scripts/functions/federated/FederatedCtableFedOutput.dml
copy to src/test/scripts/functions/federated/FederatedCtableSeqVecFedOut.dml
index 448c919..9904170 100644
--- a/src/test/scripts/functions/federated/FederatedCtableFedOutput.dml
+++ b/src/test/scripts/functions/federated/FederatedCtableSeqVecFedOut.dml
@@ -21,7 +21,7 @@
X = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0),
list(2*$rows/4, $cols),
- list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0),
list($rows, $cols)));
+ list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0),
list($rows, $cols)));
m = nrow(X);
n = ncol(X);
@@ -29,14 +29,15 @@ n = ncol(X);
# prepare offset vectors and one-hot encoded X
maxs = colMaxs(X);
if($matrixInput) {
- rix = matrix(seq(1,m)%*%matrix(1,1,n), m, n);
cix = matrix(X + (t(cumsum(t(maxs))) - maxs), m, n);
}
else {
- rix = matrix(seq(1,m)%*%matrix(1,1,n), m*n, 1);
cix = matrix(X + (t(cumsum(t(maxs))) - maxs), m*n, 1);
}
+rix = seq(1, nrow(cix));
+cix = cix[ , 1] + 1; # slice row partitioned federated vector cix and add 1
+
W = rix + cix;
if($revIn)
@@ -46,11 +47,13 @@ if($revIn)
X2 = table(cix, rix);
else
if($weighted)
- X2 = table(rix, cix, W);
- else
- X2 = table(rix, cix);
+ X2 = table(rix, cix, W);
+ else
+ X2 = table(rix, cix);
while(FALSE) { }
+X2 = X2 * (seq(1, nrow(X2)) / nrow(X2));
+while(FALSE) { }
Z = as.matrix(sum(X2));
write(Z, $out);
diff --git
a/src/test/scripts/functions/federated/FederatedCtableFedOutputReference.dml
b/src/test/scripts/functions/federated/FederatedCtableSeqVecFedOutReference.dml
similarity index 90%
copy from
src/test/scripts/functions/federated/FederatedCtableFedOutputReference.dml
copy to
src/test/scripts/functions/federated/FederatedCtableSeqVecFedOutReference.dml
index 79ed98e..bdbea6f 100644
--- a/src/test/scripts/functions/federated/FederatedCtableFedOutputReference.dml
+++
b/src/test/scripts/functions/federated/FederatedCtableSeqVecFedOutReference.dml
@@ -28,14 +28,15 @@ n = ncol(X);
maxs = colMaxs(X);
if($7) { # matrix input
- rix = matrix(seq(1,m)%*%matrix(1,1,n), m, n);
cix = matrix(X + (t(cumsum(t(maxs))) - maxs), m, n);
}
else {
- rix = matrix(seq(1,m)%*%matrix(1,1,n), m*n, 1);
cix = matrix(X + (t(cumsum(t(maxs))) - maxs), m*n, 1);
}
+rix = seq(1, nrow(cix));
+cix = cix[ , 1] + 1; # slice row partitioned federated vector cix and add 1
+
W = rix + cix;
if($5)
@@ -49,6 +50,7 @@ else
else
X2 = table(rix, cix);
+X2 = X2 * (seq(1, nrow(X2)) / nrow(X2));
Z = as.matrix(sum(X2));
write(Z, $8);