This is an automated email from the ASF dual-hosted git repository.

sebwrede pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 2026cff  [SYSTEMDS-3018] Federated Reorg Operation FedOut Compilation
2026cff is described below

commit 2026cfff97fd992c75b6b56ac01d8d199f4f9db3
Author: sebwrede <[email protected]>
AuthorDate: Mon Oct 11 18:30:28 2021 +0200

    [SYSTEMDS-3018] Federated Reorg Operation FedOut Compilation
    
    This commit ensures that the Reorg operations rdiag and rev are compiled
    with the federated output flag FOUT/LOUT.
    Additionally, it removes rshape and rsort from the FEDInstructionParser,
    since federated parsing of these Reorg types is not yet supported.
    Closes #1414.
---
 .../hops/rewrite/IPAPassRewriteFederatedPlan.java  |  1 -
 src/main/java/org/apache/sysds/lops/Lop.java       |  6 +-
 .../runtime/instructions/FEDInstructionParser.java |  3 +-
 .../runtime/instructions/InstructionUtils.java     | 25 ++++++-
 .../instructions/fed/ReorgFEDInstruction.java      | 81 +++++++++++++---------
 .../instructions/fed/UnaryFEDInstruction.java      | 14 ++++
 .../federated/primitives/FederatedRdiagTest.java   | 17 +++++
 .../federated/primitives/FederatedRevTest.java     | 18 +++++
 8 files changed, 125 insertions(+), 40 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/hops/rewrite/IPAPassRewriteFederatedPlan.java 
b/src/main/java/org/apache/sysds/hops/rewrite/IPAPassRewriteFederatedPlan.java
index cbc21cf..377ebb1 100644
--- 
a/src/main/java/org/apache/sysds/hops/rewrite/IPAPassRewriteFederatedPlan.java
+++ 
b/src/main/java/org/apache/sysds/hops/rewrite/IPAPassRewriteFederatedPlan.java
@@ -252,7 +252,6 @@ public class IPAPassRewriteFederatedPlan extends IPAPass {
                if ( hopRels.isEmpty() )
                        hopRels.add(new HopRel(currentHop, 
FEDInstruction.FederatedOutput.NONE, hopRelMemo));
                hopRelMemo.put(currentHop.getHopID(), hopRels);
-               currentHop.setVisited();
        }
 
        /**
diff --git a/src/main/java/org/apache/sysds/lops/Lop.java 
b/src/main/java/org/apache/sysds/lops/Lop.java
index e014d3c..7da091f 100644
--- a/src/main/java/org/apache/sysds/lops/Lop.java
+++ b/src/main/java/org/apache/sysds/lops/Lop.java
@@ -117,9 +117,9 @@ public abstract class Lop
        protected PrivacyConstraint privacyConstraint;
 
        /**
-        * Boolean defining if the output of the operation should be federated.
-        * If it is true, the output should be kept at federated sites.
-        * If it is false, the output should be retrieved by the coordinator.
+        * Enum defining if the output of the operation should be forced 
federated, forced local or neither.
+        * If it is FOUT, the output should be kept at federated sites.
+        * If it is LOUT, the output should be retrieved by the coordinator.
         */
        protected FederatedOutput _fedOutput = null;
        
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java 
b/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
index 755287a..8000da7 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
@@ -63,8 +63,9 @@ public class FEDInstructionParser extends InstructionParser
                // Reorg Instruction Opcodes (repositioning of existing values)
                String2FEDInstructionType.put( "r'"     , FEDType.Reorg );
                String2FEDInstructionType.put( "rdiag"  , FEDType.Reorg );
-               String2FEDInstructionType.put( "rshape" , FEDType.Reorg );
                String2FEDInstructionType.put( "rev"    , FEDType.Reorg );
+               //String2FEDInstructionType.put( "rshape" , FEDType.Reorg ); 
Not supported by ReorgFEDInstruction parser!
+               //String2FEDInstructionType.put( "rsort"  , FEDType.Reorg ); 
Not supported by ReorgFEDInstruction parser!
 
                // Ternary Instruction Opcodes
                String2FEDInstructionType.put( "+*" , FEDType.Ternary);
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java 
b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
index 246ed87..9991edf 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
@@ -85,6 +85,7 @@ import 
org.apache.sysds.runtime.functionobjects.Builtin.BuiltinCode;
 import org.apache.sysds.runtime.instructions.cp.CPOperand;
 import org.apache.sysds.runtime.instructions.cp.CPInstruction.CPType;
 import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FEDType;
+import 
org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput;
 import 
org.apache.sysds.runtime.instructions.gpu.GPUInstruction.GPUINSTRUCTION_TYPE;
 import org.apache.sysds.runtime.instructions.spark.SPInstruction.SPType;
 import org.apache.sysds.runtime.matrix.data.LibCommonsMath;
@@ -1144,8 +1145,30 @@ public class InstructionUtils
                return linst;
        }
 
+       /**
+        * Removes federated output flag from the end of the instruction string 
if the flag is present.
+        * @param linst instruction string
+        * @return instruction string with no federated output flag
+        */
        public static String removeFEDOutputFlag(String linst){
-               return linst.substring(0, 
linst.lastIndexOf(Lop.OPERAND_DELIMITOR));
+               int lastOperandStartIndex = 
linst.lastIndexOf(Lop.OPERAND_DELIMITOR);
+               String lastOperand = linst.substring(lastOperandStartIndex);
+               if ( containsFEDOutputFlag(lastOperand) )
+                       return linst.substring(0, lastOperandStartIndex);
+               else return linst;
+       }
+
+       /**
+        * Checks whether the given operand string contains a federated output 
flag
+        * @param operandString which is checked for federated output flag
+        * @return true if the given operand string contains a federated output 
flag
+        */
+       private static boolean containsFEDOutputFlag(String operandString){
+               for (FederatedOutput fedOutput : FederatedOutput.values()){
+                       if ( operandString.contains(fedOutput.name()) )
+                               return true;
+               }
+               return false;
        }
 
        private static String replaceOperand(String linst, CPOperand 
oldOperand, String newOperandName){
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
index c32b15b..4202498 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/ReorgFEDInstruction.java
@@ -54,8 +54,6 @@ import 
org.apache.sysds.runtime.matrix.operators.ReorgOperator;
 import org.apache.sysds.runtime.meta.MatrixCharacteristics;
 
 public class ReorgFEDInstruction extends UnaryFEDInstruction {
-       @SuppressWarnings("unused")
-       private static boolean fedoutFlagInString = false;
 
        public ReorgFEDInstruction(Operator op, CPOperand in1, CPOperand out, 
String opcode, String istr, FederatedOutput fedOut) {
                super(FEDType.Reorg, op, in1, out, opcode, istr, fedOut);
@@ -71,23 +69,25 @@ public class ReorgFEDInstruction extends 
UnaryFEDInstruction {
 
                String[] parts = 
InstructionUtils.getInstructionPartsWithValueType(str);
                String opcode = parts[0];
+               FederatedOutput fedOut;
                if ( opcode.equalsIgnoreCase("r'") ) {
                        InstructionUtils.checkNumFields(str, 2, 3, 4);
                        in.split(parts[1]);
                        out.split(parts[2]);
                        int k = str.startsWith(Types.ExecMode.SPARK.name()) ? 0 
: Integer.parseInt(parts[3]);
-                       FederatedOutput fedOut = 
str.startsWith(Types.ExecMode.SPARK.name()) ?  
FederatedOutput.valueOf(parts[3]) :
-                               FederatedOutput.valueOf(parts[4]);
+                       fedOut = str.startsWith(Types.ExecMode.SPARK.name()) ?
+                               FederatedOutput.valueOf(parts[3]) : 
FederatedOutput.valueOf(parts[4]);
                        return new ReorgFEDInstruction(new 
ReorgOperator(SwapIndex.getSwapIndexFnObject(), k), in, out, opcode, str, 
fedOut);
                }
                else if ( opcode.equalsIgnoreCase("rdiag") ) {
                        parseUnaryInstruction(str, in, out); //max 2 operands
-                       return new ReorgFEDInstruction(new 
ReorgOperator(DiagIndex.getDiagIndexFnObject()), in, out, opcode, str);
+                       fedOut = parseFedOutFlag(str, 3);
+                       return new ReorgFEDInstruction(new 
ReorgOperator(DiagIndex.getDiagIndexFnObject()), in, out, opcode, str, fedOut);
                }
                else if ( opcode.equalsIgnoreCase("rev") ) {
-                       fedoutFlagInString = parts.length > 3;
                        parseUnaryInstruction(str, in, out); //max 2 operands
-                       return new ReorgFEDInstruction(new 
ReorgOperator(RevIndex.getRevIndexFnObject()), in, out, opcode, str);
+                       fedOut = parseFedOutFlag(str, 3);
+                       return new ReorgFEDInstruction(new 
ReorgOperator(RevIndex.getRevIndexFnObject()), in, out, opcode, str, fedOut);
                }
                else {
                        throw new DMLRuntimeException("ReorgFEDInstruction: 
unsupported opcode: "+opcode);
@@ -117,7 +117,6 @@ public class ReorgFEDInstruction extends 
UnaryFEDInstruction {
                        mo1.getFedMapping().execute(getTID(), true, fr, fr1);
 
                        if (_fedOut != null && !_fedOut.isForcedLocal()){
-                               mo1.getFedMapping().execute(getTID(), true, 
fr1);
                                //drive output federated mapping
                                MatrixObject out = ec.getMatrixObject(output);
                                
out.getDataCharacteristics().set(mo1.getNumColumns(), mo1.getNumRows(), (int) 
mo1.getBlocksize(), mo1.getNnz());
@@ -146,10 +145,7 @@ public class ReorgFEDInstruction extends 
UnaryFEDInstruction {
                        out.getDataCharacteristics().set(mo1.getNumRows(), 
mo1.getNumColumns(), (int) mo1.getBlocksize(), mo1.getNnz());
                        
out.setFedMapping(mo1.getFedMapping().copyWithNewID(fr1.getID()));
 
-                       if ( _fedOut != null && _fedOut.isForcedLocal() ){
-                               out.acquireReadAndRelease();
-                               out.getFedMapping().cleanup(getTID(), 
fr1.getID());
-                       }
+                       optionalForceLocal(out);
                }
                else if (instOpcode.equals("rdiag")) {
                        RdiagResult result;
@@ -160,24 +156,7 @@ public class ReorgFEDInstruction extends 
UnaryFEDInstruction {
                                result = rdiagM2V(mo1, r_op);
                        }
 
-                       FederationMap diagFedMap = result.getFedMap();
-                       Map<FederatedRange, int[]> dcs = result.getDcs();
-
-                       //update fed ranges
-                       for(int i = 0; i < 
diagFedMap.getFederatedRanges().length; i++) {
-                               int[] newRange = 
dcs.get(diagFedMap.getFederatedRanges()[i]);
-
-                               
diagFedMap.getFederatedRanges()[i].setBeginDim(0,
-                                       
(diagFedMap.getFederatedRanges()[i].getBeginDims()[0] == 0 ||
-                                               i == 0) ? 0 : 
diagFedMap.getFederatedRanges()[i - 1].getEndDims()[0]);
-                               diagFedMap.getFederatedRanges()[i].setEndDim(0,
-                                       
diagFedMap.getFederatedRanges()[i].getBeginDims()[0] + newRange[0]);
-                               
diagFedMap.getFederatedRanges()[i].setBeginDim(1,
-                                       
(diagFedMap.getFederatedRanges()[i].getBeginDims()[1] == 0 ||
-                                               i == 0) ? 0 : 
diagFedMap.getFederatedRanges()[i - 1].getEndDims()[1]);
-                               diagFedMap.getFederatedRanges()[i].setEndDim(1,
-                                       
diagFedMap.getFederatedRanges()[i].getBeginDims()[1] + newRange[1]);
-                       }
+                       FederationMap diagFedMap = updateFedRanges(result);
 
                        //update output mapping and data characteristics
                        MatrixObject rdiag = ec.getMatrixObject(output);
@@ -185,10 +164,44 @@ public class ReorgFEDInstruction extends 
UnaryFEDInstruction {
                                .set(diagFedMap.getMaxIndexInRange(0), 
diagFedMap.getMaxIndexInRange(1),
                                        (int) mo1.getBlocksize());
                        rdiag.setFedMapping(diagFedMap);
-                       if ( _fedOut != null && _fedOut.isForcedLocal() ){
-                               rdiag.acquireReadAndRelease();
-                               rdiag.getFedMapping().cleanup(getTID(), 
rdiag.getFedMapping().getID());
-                       }
+                       optionalForceLocal(rdiag);
+               }
+       }
+
+       /**
+        * Update the federated ranges of result and return the updated 
federation map.
+        * @param result RdiagResult for which the fedmap is updated
+        * @return updated federation map
+        */
+       private FederationMap updateFedRanges(RdiagResult result){
+               FederationMap diagFedMap = result.getFedMap();
+               Map<FederatedRange, int[]> dcs = result.getDcs();
+
+               for(int i = 0; i < diagFedMap.getFederatedRanges().length; i++) 
{
+                       int[] newRange = 
dcs.get(diagFedMap.getFederatedRanges()[i]);
+
+                       diagFedMap.getFederatedRanges()[i].setBeginDim(0,
+                               
(diagFedMap.getFederatedRanges()[i].getBeginDims()[0] == 0 ||
+                                       i == 0) ? 0 : 
diagFedMap.getFederatedRanges()[i - 1].getEndDims()[0]);
+                       diagFedMap.getFederatedRanges()[i].setEndDim(0,
+                               
diagFedMap.getFederatedRanges()[i].getBeginDims()[0] + newRange[0]);
+                       diagFedMap.getFederatedRanges()[i].setBeginDim(1,
+                               
(diagFedMap.getFederatedRanges()[i].getBeginDims()[1] == 0 ||
+                                       i == 0) ? 0 : 
diagFedMap.getFederatedRanges()[i - 1].getEndDims()[1]);
+                       diagFedMap.getFederatedRanges()[i].setEndDim(1,
+                               
diagFedMap.getFederatedRanges()[i].getBeginDims()[1] + newRange[1]);
+               }
+               return diagFedMap;
+       }
+
+       /**
+        * If federated output is forced local, the output will be retrieved 
and removed from federated workers.
+        * @param outputMatrixObject which will be retrieved and removed from 
federated workers
+        */
+       private void optionalForceLocal(MatrixObject outputMatrixObject){
+               if ( _fedOut != null && _fedOut.isForcedLocal() ){
+                       outputMatrixObject.acquireReadAndRelease();
+                       outputMatrixObject.getFedMapping().cleanup(getTID(), 
outputMatrixObject.getFedMapping().getID());
                }
        }
 
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/UnaryFEDInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/UnaryFEDInstruction.java
index 0ae3178..dea0acf 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/UnaryFEDInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/UnaryFEDInstruction.java
@@ -110,4 +110,18 @@ public abstract class UnaryFEDInstruction extends 
ComputationFEDInstruction {
                out.split(parts[parts.length - 2]);
                return opcode;
        }
+
+       /**
+        * Parse and return federated output flag from given instr string at 
given position.
+        * If the position given is greater than the length of the instruction, 
FederatedOutput.NONE is returned.
+        * @param instr instruction string to be parsed
+        * @param position of federated output flag
+        * @return parsed federated output flag or FederatedOutput.NONE
+        */
+       static FederatedOutput parseFedOutFlag(String instr, int position){
+               String[] parts = 
InstructionUtils.getInstructionPartsWithValueType(instr);
+               if ( parts.length > position )
+                       return FederatedOutput.valueOf(parts[position]);
+               else return FederatedOutput.NONE;
+       }
 }
diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRdiagTest.java
 
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRdiagTest.java
index cda9966..e4e7a88 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRdiagTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRdiagTest.java
@@ -24,6 +24,7 @@ import java.util.Collection;
 
 import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.common.Types;
+import org.apache.sysds.hops.OptimizerUtils;
 import org.apache.sysds.runtime.meta.MatrixCharacteristics;
 import org.apache.sysds.runtime.util.HDFSTool;
 import org.apache.sysds.test.AutomatedTestBase;
@@ -69,7 +70,21 @@ public class FederatedRdiagTest extends AutomatedTestBase {
        @Test
        public void federatedRdiagSP() { federatedRdiag(Types.ExecMode.SPARK); }
 
+       @Test
+       public void federatedCompilationRDiagCP(){
+               federatedRdiag(Types.ExecMode.SINGLE_NODE, true);
+       }
+
+       @Test
+       public void federatedCompilationRdiagSP(){
+               federatedRdiag(Types.ExecMode.SPARK, true);
+       }
+
        public void federatedRdiag(Types.ExecMode execMode) {
+               federatedRdiag(execMode, false);
+       }
+
+       public void federatedRdiag(Types.ExecMode execMode, boolean 
activateFedCompilation) {
                boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
                Types.ExecMode platformOld = rtplatform;
 
@@ -111,6 +126,7 @@ public class FederatedRdiagTest extends AutomatedTestBase {
                        input("X1"), input("X2"), input("X3"), input("X4"), 
expected("S")};
                runTest(null);
 
+               OptimizerUtils.FEDERATED_COMPILATION = activateFedCompilation;
                TestConfiguration config = 
availableTestConfigurations.get(TEST_NAME);
                loadTestConfiguration(config);
 
@@ -139,5 +155,6 @@ public class FederatedRdiagTest extends AutomatedTestBase {
                TestUtils.shutdownThreads(t1, t2, t3, t4);
                rtplatform = platformOld;
                DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+               OptimizerUtils.FEDERATED_COMPILATION = false;
        }
 }
diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRevTest.java
 
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRevTest.java
index 847f351..66f9c2f 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRevTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedRevTest.java
@@ -23,7 +23,9 @@ import java.util.Arrays;
 import java.util.Collection;
 
 import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types;
 import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.hops.OptimizerUtils;
 import org.apache.sysds.runtime.meta.MatrixCharacteristics;
 import org.apache.sysds.runtime.util.HDFSTool;
 import org.apache.sysds.test.AutomatedTestBase;
@@ -77,7 +79,21 @@ public class FederatedRevTest extends AutomatedTestBase {
                runRevTest(ExecMode.SPARK);
        }
 
+       @Test
+       public void federatedCompilationRevCP(){
+               runRevTest(Types.ExecMode.SINGLE_NODE, true);
+       }
+
+       @Test
+       public void federatedCompilationRevSP(){
+               runRevTest(Types.ExecMode.SPARK, true);
+       }
+
        private void runRevTest(ExecMode execMode) {
+               runRevTest(execMode, false);
+       }
+
+       private void runRevTest(ExecMode execMode, boolean 
activateFedCompilation) {
                boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
                ExecMode platformOld = rtplatform;
 
@@ -135,6 +151,7 @@ public class FederatedRevTest extends AutomatedTestBase {
 
                runTest(null);
 
+               OptimizerUtils.FEDERATED_COMPILATION = activateFedCompilation;
                fullDMLScriptName = HOME + TEST_NAME + ".dml";
                programArgs = new String[] {"-stats", "100", "-nvargs",
                        "in_X1=" + TestUtils.federatedAddress(port1, 
input("X1")),
@@ -160,6 +177,7 @@ public class FederatedRevTest extends AutomatedTestBase {
 
                rtplatform = platformOld;
                DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+               OptimizerUtils.FEDERATED_COMPILATION = false;
 
        }
 }

Reply via email to