This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new f33be5b1d9 [MINOR] Support Non-literals in Federated Reshape Instructions
f33be5b1d9 is described below

commit f33be5b1d9100e781dfa8f0ebf63390817b606bb
Author: ywcb00 <[email protected]>
AuthorDate: Sat Jul 15 15:32:56 2023 +0200

    [MINOR] Support Non-literals in Federated Reshape Instructions
    
    AMLS project SoSe'23, part I
    Closes #1862.
---
 .../instructions/fed/ReshapeFEDInstruction.java        | 18 ++++++++++--------
 .../functions/federated/io/FederatedReaderTest.java    | 12 +++++-------
 .../federated/primitives/FederatedMisAlignedTest.java  |  8 ++++----
 .../functions/federated/FederatedReshapeTest.dml       |  7 ++++++-
 4 files changed, 25 insertions(+), 20 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/ReshapeFEDInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/ReshapeFEDInstruction.java
index 3d355cd1dd..521dbe8e51 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/ReshapeFEDInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/ReshapeFEDInstruction.java
@@ -22,6 +22,7 @@ package org.apache.sysds.runtime.instructions.fed;
 import java.util.Arrays;
 import java.util.stream.Collectors;
 
+import org.apache.commons.lang3.ArrayUtils;
 import org.apache.commons.lang3.tuple.Pair;
 import org.apache.sysds.common.Types;
 import org.apache.sysds.lops.Lop;
@@ -119,7 +120,7 @@ public class ReshapeFEDInstruction extends 
UnaryFEDInstruction {
                        mo1.getFedMapping().execute(getTID(), true, fr1, new 
FederatedRequest[0]);
 
                        // set new fed map
-                       FederationMap reshapedFedMap = mo1.getFedMapping();
+                       FederationMap reshapedFedMap = 
mo1.getFedMapping().copyWithNewID(fr1[0].getID());
                        for(int i = 0; i < 
reshapedFedMap.getFederatedRanges().length; i++) {
                                long cells = 
reshapedFedMap.getFederatedRanges()[i].getSize();
                                long row = byRow.getBooleanValue() ? cells / 
cols : rows;
@@ -140,7 +141,7 @@ public class ReshapeFEDInstruction extends 
UnaryFEDInstruction {
                        //derive output federated mapping
                        MatrixObject out = ec.getMatrixObject(output);
                        out.getDataCharacteristics().set(rows, cols, (int) 
mo1.getBlocksize(), mo1.getNnz());
-                       
out.setFedMapping(reshapedFedMap.copyWithNewID(fr1[0].getID()));
+                       out.setFedMapping(reshapedFedMap);
                }
                else {
                        // TODO support tensor out, frame and list
@@ -156,14 +157,15 @@ public class ReshapeFEDInstruction extends 
UnaryFEDInstruction {
                        .collect(Collectors.toSet()).size();
                sameFedSize = sameFedSize == 1 ? 1 : 
mo1.getFedMapping().getSize();
 
+               String execTypeName = 
InstructionUtils.getExecType(instString).name();
+               String[] instParts = 
InstructionUtils.getInstructionPartsWithValueType(instString);
                for(int i = 0; i < sameFedSize; i++) {
-                       String[] instParts = 
instString.split(Lop.OPERAND_DELIMITOR);
                        long size = 
mo1.getFedMapping().getFederatedRanges()[i].getSize();
-                       String oldInstStringPart = byRow ? instParts[3] : 
instParts[4];
-                       String newInstStringPart = byRow ? 
-                               oldInstStringPart.replace(String.valueOf(rows), 
String.valueOf(size/cols)) :
-                               oldInstStringPart.replace(String.valueOf(cols), 
String.valueOf(size/rows));
-                       instStrings[i] = instString.replace(oldInstStringPart, 
newInstStringPart);
+                       instParts[2] = InstructionUtils.createLiteralOperand(
+                               String.valueOf((int)(byRow ? size/cols : 
rows)), Types.ValueType.INT64);
+                       instParts[3] = InstructionUtils.createLiteralOperand(
+                               String.valueOf((int)(byRow ? cols : 
size/rows)), Types.ValueType.INT64);
+                       instStrings[i] = 
InstructionUtils.concatOperands(ArrayUtils.addFirst(instParts, execTypeName));
                }
 
                if(sameFedSize == 1)
diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedReaderTest.java
 
b/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedReaderTest.java
index ff68c8328e..295fe54770 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedReaderTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/io/FederatedReaderTest.java
@@ -40,7 +40,7 @@ import org.junit.runners.Parameterized;
 public class FederatedReaderTest extends AutomatedTestBase {
 
        private static final Log LOG = 
LogFactory.getLog(FederatedReaderTest.class.getName());
-       private final static String TEST_DIR = "functions/federated/ioR/";
+       private final static String TEST_DIR = "functions/federated/io/";
        private final static String TEST_NAME = "FederatedReaderTest";
        private final static String TEST_CLASS_DIR = TEST_DIR + 
FederatedReaderTest.class.getSimpleName() + "/";
        private final static int blocksize = 1024;
@@ -50,8 +50,6 @@ public class FederatedReaderTest extends AutomatedTestBase {
        public int cols;
        @Parameterized.Parameter(2)
        public boolean rowPartitioned;
-       @Parameterized.Parameter(3)
-       public int fedCount;
 
        @Override
        public void setUp() {
@@ -62,7 +60,7 @@ public class FederatedReaderTest extends AutomatedTestBase {
        @Parameterized.Parameters
        public static Collection<Object[]> data() {
                // number of rows or cols has to be >= number of federated 
locations.
-               return Arrays.asList(new Object[][] {{10, 13, true, 2}});
+               return Arrays.asList(new Object[][] {{10, 13, true}});
        }
 
        @Test
@@ -111,11 +109,11 @@ public class FederatedReaderTest extends 
AutomatedTestBase {
                        // Run reference dml script with normal matrix
 
                        if(workerCount == 1) {
-                               fullDMLScriptName = SCRIPT_DIR + 
"functions/federated/io/" + TEST_NAME + "1Reference.dml";
+                               fullDMLScriptName = SCRIPT_DIR + TEST_DIR + 
TEST_NAME + "1Reference.dml";
                                programArgs = new String[] {"-stats", "-args", 
input("X1")};
                        }
                        else {
-                               fullDMLScriptName = SCRIPT_DIR + 
"functions/federated/io/" + TEST_NAME
+                               fullDMLScriptName = SCRIPT_DIR + TEST_DIR + 
TEST_NAME
                                        + (rowPartitioned ? "Row" : "Col") + 
"2Reference.dml";
                                programArgs = new String[] {"-stats", "-args", 
input("X1"), input("X2")};
                        }
@@ -125,7 +123,7 @@ public class FederatedReaderTest extends AutomatedTestBase {
                        LOG.debug(refOut);
                        
                        // Run federated
-                       fullDMLScriptName = SCRIPT_DIR + 
"functions/federated/io/" + TEST_NAME + ".dml";
+                       fullDMLScriptName = SCRIPT_DIR + TEST_DIR + TEST_NAME + 
".dml";
                        programArgs = new String[] {"-stats", "-args", 
input("X.json")};
                        String out = runTest(null).toString();
 
diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedMisAlignedTest.java
 
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedMisAlignedTest.java
index ecc8a7b90f..5b4b350b08 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedMisAlignedTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedMisAlignedTest.java
@@ -205,10 +205,10 @@ public class FederatedMisAlignedTest extends 
AutomatedTestBase {
                        c = cols;
                }
 
-               double[][] X1 = getRandomMatrix(r, c, 3, 3, 1, 3);
-               double[][] X2 = getRandomMatrix(r, c, 3, 3, 1, 7);
-               double[][] X3 = getRandomMatrix(r, c, 3, 3, 1, 8);
-               double[][] X4 = getRandomMatrix(r, c, 3, 3, 1, 9);
+               double[][] X1 = getRandomMatrix(r, c, 3, 4, 1, 3);
+               double[][] X2 = getRandomMatrix(r, c, 3, 4, 1, 7);
+               double[][] X3 = getRandomMatrix(r, c, 3, 4, 1, 8);
+               double[][] X4 = getRandomMatrix(r, c, 3, 4, 1, 9);
 
                MatrixCharacteristics mc = new MatrixCharacteristics(r, c, 
blocksize, r * c);
                writeInputMatrixWithMTD("X1", X1, false, mc);
diff --git a/src/test/scripts/functions/federated/FederatedReshapeTest.dml 
b/src/test/scripts/functions/federated/FederatedReshapeTest.dml
index 6aa8a165b5..f133bcff17 100644
--- a/src/test/scripts/functions/federated/FederatedReshapeTest.dml
+++ b/src/test/scripts/functions/federated/FederatedReshapeTest.dml
@@ -27,5 +27,10 @@ A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
     ranges=list(list(0, 0), list(2, 12), list(2, 0), list(4, $cols),
     list(4, 0), list(10, $cols), list(10, 0), list(12, $cols)));
 
-s = matrix(A, rows=$r_rows, cols=$r_cols);
+# materialize the scalar input (non-literal)
+reshape_cols = $r_cols;
+while(FALSE) {}
+reshape_cols = reshape_cols;
+
+s = matrix(A, rows=$r_rows, cols=reshape_cols);
 write(s, $out_S);

Reply via email to