This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new be191a2  [SYSTEMDS-3313] Misc fixes federated planning (planner, lops, rewrites)
be191a2 is described below

commit be191a244f4b0871b68daf3084ca65ec0d9fb7ec
Author: Matthias Boehm <[email protected]>
AuthorDate: Sun Mar 13 00:18:04 2022 +0100

    [SYSTEMDS-3313] Misc fixes federated planning (planner, lops, rewrites)
    
    1) Extended supported operators during planning
    2) Improved tak+ rewrites during lop construction of unary aggregates
    3) Extended FED instruction parsing (relational operators)
    4) Fixed FED tsmm instruction generation
---
 src/main/java/org/apache/sysds/hops/AggUnaryOp.java  |  9 +++++----
 .../sysds/hops/fedplanner/AFederatedPlanner.java     | 13 +++++++++++--
 .../org/apache/sysds/hops/fedplanner/FTypes.java     |  4 ++--
 ...lannerAllFed.java => FederatedPlannerFedAll.java} |  2 +-
 ...ristic.java => FederatedPlannerFedHeuristic.java} |  2 +-
 src/main/java/org/apache/sysds/lops/MMTSJ.java       |  2 +-
 .../runtime/instructions/FEDInstructionParser.java   | 20 +++++++++++++-------
 7 files changed, 34 insertions(+), 18 deletions(-)

diff --git a/src/main/java/org/apache/sysds/hops/AggUnaryOp.java b/src/main/java/org/apache/sysds/hops/AggUnaryOp.java
index 1ad6ac7..923503b 100644
--- a/src/main/java/org/apache/sysds/hops/AggUnaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/AggUnaryOp.java
@@ -38,7 +38,6 @@ import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
 import org.apache.sysds.runtime.meta.DataCharacteristics;
 import org.apache.sysds.runtime.meta.MatrixCharacteristics;
 
-
 // Aggregate unary (cell) operation: Sum (aij), col_sum, row_sum
 
 public class AggUnaryOp extends MultiThreadedHop
@@ -553,7 +552,8 @@ public class AggUnaryOp extends MultiThreadedHop
                        in2 = in1;
                        in3 = in1;
                        handled = true;
-               } else if (input11 instanceof BinaryOp ) {
+               }
+               else if (HopRewriteUtils.isBinary(input11, OpOp2.MULT, 
OpOp2.POW) ) {
                        BinaryOp b11 = (BinaryOp)input11;
                        switch( b11.getOp() ) {
                        case MULT: // A*B*C case
@@ -574,7 +574,8 @@ public class AggUnaryOp extends MultiThreadedHop
                                break;
                        default: break;
                        }
-               } else if( input12 instanceof BinaryOp ) {
+               }
+               else if( HopRewriteUtils.isBinary(input12, OpOp2.MULT, 
OpOp2.POW) ) {
                        BinaryOp b12 = (BinaryOp)input12;
                        switch (b12.getOp()) {
                        case MULT: // A*B*C case
@@ -668,7 +669,7 @@ public class AggUnaryOp extends MultiThreadedHop
                if( !(that instanceof AggUnaryOp) )
                        return false;
                
-               AggUnaryOp that2 = (AggUnaryOp)that;            
+               AggUnaryOp that2 = (AggUnaryOp)that;
                return (   _op == that2._op
                                && _direction == that2._direction
                                && _maxNumThreads == that2._maxNumThreads
diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/AFederatedPlanner.java b/src/main/java/org/apache/sysds/hops/fedplanner/AFederatedPlanner.java
index 50c5f46..97d4939 100644
--- a/src/main/java/org/apache/sysds/hops/fedplanner/AFederatedPlanner.java
+++ b/src/main/java/org/apache/sysds/hops/fedplanner/AFederatedPlanner.java
@@ -21,11 +21,13 @@ package org.apache.sysds.hops.fedplanner;
 
 import java.util.Map;
 
+import org.apache.sysds.common.Types.AggOp;
 import org.apache.sysds.common.Types.ReOrgOp;
 import org.apache.sysds.hops.AggBinaryOp;
 import org.apache.sysds.hops.BinaryOp;
 import org.apache.sysds.hops.DataOp;
 import org.apache.sysds.hops.Hop;
+import org.apache.sysds.hops.TernaryOp;
 import org.apache.sysds.hops.fedplanner.FTypes.FType;
 import org.apache.sysds.hops.ipa.FunctionCallGraph;
 import org.apache.sysds.hops.ipa.FunctionCallSizeInfo;
@@ -56,15 +58,20 @@ public abstract class AFederatedPlanner {
                //handle specific operators
                if( hop instanceof AggBinaryOp ) {
                        return (ft[0] != null && ft[1] == null)
-                               || (ft[0] == null && ft[1] != null);
+                               || (ft[0] == null && ft[1] != null)
+                               || (ft[0] == FType.COL && ft[1] == FType.ROW);
                }
                else if( hop instanceof BinaryOp && 
!hop.getDataType().isScalar() ) {
                        return (ft[0] != null && ft[1] == null)
                                || (ft[0] == null && ft[1] != null)
                                || (ft[0] != null && ft[0] == ft[1]);
                }
+               else if( hop instanceof TernaryOp && 
!hop.getDataType().isScalar() ) {
+                       return (ft[0] != null || ft[1] != null || ft[2] != 
null);
+               }
                else if(ft.length==1 && ft[0] != null) {
-                       return HopRewriteUtils.isReorg(hop, ReOrgOp.TRANS);
+                       return HopRewriteUtils.isReorg(hop, ReOrgOp.TRANS)
+                               || HopRewriteUtils.isAggUnaryOp(hop, AggOp.SUM, 
AggOp.MIN, AggOp.MAX);
                }
                
                return false;
@@ -85,6 +92,8 @@ public abstract class AFederatedPlanner {
                }
                else if( hop instanceof BinaryOp ) 
                        return ft[0] != null ? ft[0] : ft[1];
+               else if( hop instanceof TernaryOp )
+                       return ft[0] != null ? ft[0] : ft[1] != null ? ft[1] : 
ft[2];
                else if( HopRewriteUtils.isReorg(hop, ReOrgOp.TRANS) )
                        return ft[0] == FType.ROW ? FType.COL : FType.COL;
                
diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/FTypes.java b/src/main/java/org/apache/sysds/hops/fedplanner/FTypes.java
index 98de495..7efabc8 100644
--- a/src/main/java/org/apache/sysds/hops/fedplanner/FTypes.java
+++ b/src/main/java/org/apache/sysds/hops/fedplanner/FTypes.java
@@ -30,9 +30,9 @@ public class FTypes
                public AFederatedPlanner getPlanner() {
                        switch( this ) {
                                case COMPILE_FED_ALL:
-                                       return new FederatedPlannerAllFed();
+                                       return new FederatedPlannerFedAll();
                                case COMPILE_FED_HEURISTIC:
-                                       return new FederatedPlannerHeuristic();
+                                       return new 
FederatedPlannerFedHeuristic();
                                case COMPILE_COST_BASED:
                                        return new FederatedPlannerCostbased();
                                case NONE:
diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerAllFed.java b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerFedAll.java
similarity index 98%
rename from src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerAllFed.java
rename to src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerFedAll.java
index a35d94c..f11bbd1 100644
--- a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerAllFed.java
+++ b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerFedAll.java
@@ -48,7 +48,7 @@ import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput;
  * that support federated execution on federated inputs to
  * forced federated operations.
  */
-public class FederatedPlannerAllFed extends AFederatedPlanner {
+public class FederatedPlannerFedAll extends AFederatedPlanner {
        
        @Override
        public void rewriteProgram( DMLProgram prog,
diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerHeuristic.java b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerFedHeuristic.java
similarity index 94%
rename from src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerHeuristic.java
rename to src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerFedHeuristic.java
index 15b12ac..4bc0b88 100644
--- a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerHeuristic.java
+++ b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerFedHeuristic.java
@@ -25,7 +25,7 @@ import org.apache.sysds.hops.AggBinaryOp;
 import org.apache.sysds.hops.Hop;
 import org.apache.sysds.hops.fedplanner.FTypes.FType;
 
-public class FederatedPlannerHeuristic extends FederatedPlannerAllFed {
+public class FederatedPlannerFedHeuristic extends FederatedPlannerFedAll {
        
        @Override
        protected FType getFederatedOut(Hop hop, Map<Long, FType> fedHops) {
diff --git a/src/main/java/org/apache/sysds/lops/MMTSJ.java b/src/main/java/org/apache/sysds/lops/MMTSJ.java
index 27876ed..45ad196 100644
--- a/src/main/java/org/apache/sysds/lops/MMTSJ.java
+++ b/src/main/java/org/apache/sysds/lops/MMTSJ.java
@@ -92,7 +92,7 @@ public class MMTSJ extends Lop
                sb.append( _type );
                
                //append degree of parallelism for matrix multiplications
-               if( getExecType()==ExecType.CP ) {
+               if( getExecType()==ExecType.CP || getExecType()==ExecType.FED ) 
{
                        sb.append( OPERAND_DELIMITOR );
                        sb.append( _numThreads );
                }
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java b/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
index 11ea4e0..4992426 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
@@ -43,7 +43,7 @@ public class FEDInstructionParser extends InstructionParser
                String2FEDInstructionType.put( "fedinit"  , FEDType.Init );
                String2FEDInstructionType.put( "tsmm"     , FEDType.Tsmm );
                String2FEDInstructionType.put( "ba+*"     , 
FEDType.AggregateBinary );
-               String2FEDInstructionType.put( "tak+*"    , 
FEDType.AggregateTernary);
+               String2FEDInstructionType.put( "tak+*"    , 
FEDType.AggregateTernary);
 
                String2FEDInstructionType.put( "uak+"    , 
FEDType.AggregateUnary );
                String2FEDInstructionType.put( "uark+"   , 
FEDType.AggregateUnary );
@@ -56,12 +56,18 @@ public class FEDInstructionParser extends InstructionParser
                String2FEDInstructionType.put( "uacvar"  , 
FEDType.AggregateUnary);
 
                // Arithmetic Instruction Opcodes
-               String2FEDInstructionType.put( "+" , FEDType.Binary );
-               String2FEDInstructionType.put( "-" , FEDType.Binary );
-               String2FEDInstructionType.put( "*" , FEDType.Binary );
-               String2FEDInstructionType.put( "/" , FEDType.Binary );
-               String2FEDInstructionType.put( "1-*" , FEDType.Binary); 
//special * case
-               String2FEDInstructionType.put( "max" , FEDType.Binary );
+               String2FEDInstructionType.put( "+" ,  FEDType.Binary );
+               String2FEDInstructionType.put( "-" ,  FEDType.Binary );
+               String2FEDInstructionType.put( "*" ,  FEDType.Binary );
+               String2FEDInstructionType.put( "/" ,  FEDType.Binary );
+               String2FEDInstructionType.put( "1-*", FEDType.Binary); 
//special * case
+               String2FEDInstructionType.put( "max", FEDType.Binary );
+               String2FEDInstructionType.put( "==",  FEDType.Binary);
+               String2FEDInstructionType.put( "!=",  FEDType.Binary);
+               String2FEDInstructionType.put( "<",   FEDType.Binary);
+               String2FEDInstructionType.put( ">",   FEDType.Binary);
+               String2FEDInstructionType.put( "<=",  FEDType.Binary);
+               String2FEDInstructionType.put( ">=",  FEDType.Binary);
 
                // Reorg Instruction Opcodes (repositioning of existing values)
                String2FEDInstructionType.put( "r'"     , FEDType.Reorg );

Reply via email to