This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new be191a2 [SYSTEMDS-3313] Misc fixes federated planning (planner, lops, rewrites)
be191a2 is described below
commit be191a244f4b0871b68daf3084ca65ec0d9fb7ec
Author: Matthias Boehm <[email protected]>
AuthorDate: Sun Mar 13 00:18:04 2022 +0100
[SYSTEMDS-3313] Misc fixes federated planning (planner, lops, rewrites)
1) Extended supported operators during planning
2) Improved tak+ rewrites during lop construction of unary aggregates
3) Extended FED instruction parsing (relational operators)
4) Fixed FED tsmm instruction generation
---
src/main/java/org/apache/sysds/hops/AggUnaryOp.java | 9 +++++----
.../sysds/hops/fedplanner/AFederatedPlanner.java | 13 +++++++++++--
.../org/apache/sysds/hops/fedplanner/FTypes.java | 4 ++--
...lannerAllFed.java => FederatedPlannerFedAll.java} | 2 +-
...ristic.java => FederatedPlannerFedHeuristic.java} | 2 +-
src/main/java/org/apache/sysds/lops/MMTSJ.java | 2 +-
.../runtime/instructions/FEDInstructionParser.java | 20 +++++++++++++-------
7 files changed, 34 insertions(+), 18 deletions(-)
diff --git a/src/main/java/org/apache/sysds/hops/AggUnaryOp.java b/src/main/java/org/apache/sysds/hops/AggUnaryOp.java
index 1ad6ac7..923503b 100644
--- a/src/main/java/org/apache/sysds/hops/AggUnaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/AggUnaryOp.java
@@ -38,7 +38,6 @@ import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
-
// Aggregate unary (cell) operation: Sum (aij), col_sum, row_sum
public class AggUnaryOp extends MultiThreadedHop
@@ -553,7 +552,8 @@ public class AggUnaryOp extends MultiThreadedHop
in2 = in1;
in3 = in1;
handled = true;
- } else if (input11 instanceof BinaryOp ) {
+ }
+ else if (HopRewriteUtils.isBinary(input11, OpOp2.MULT,
OpOp2.POW) ) {
BinaryOp b11 = (BinaryOp)input11;
switch( b11.getOp() ) {
case MULT: // A*B*C case
@@ -574,7 +574,8 @@ public class AggUnaryOp extends MultiThreadedHop
break;
default: break;
}
- } else if( input12 instanceof BinaryOp ) {
+ }
+ else if( HopRewriteUtils.isBinary(input12, OpOp2.MULT,
OpOp2.POW) ) {
BinaryOp b12 = (BinaryOp)input12;
switch (b12.getOp()) {
case MULT: // A*B*C case
@@ -668,7 +669,7 @@ public class AggUnaryOp extends MultiThreadedHop
if( !(that instanceof AggUnaryOp) )
return false;
- AggUnaryOp that2 = (AggUnaryOp)that;
+ AggUnaryOp that2 = (AggUnaryOp)that;
return ( _op == that2._op
&& _direction == that2._direction
&& _maxNumThreads == that2._maxNumThreads
diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/AFederatedPlanner.java b/src/main/java/org/apache/sysds/hops/fedplanner/AFederatedPlanner.java
index 50c5f46..97d4939 100644
--- a/src/main/java/org/apache/sysds/hops/fedplanner/AFederatedPlanner.java
+++ b/src/main/java/org/apache/sysds/hops/fedplanner/AFederatedPlanner.java
@@ -21,11 +21,13 @@ package org.apache.sysds.hops.fedplanner;
import java.util.Map;
+import org.apache.sysds.common.Types.AggOp;
import org.apache.sysds.common.Types.ReOrgOp;
import org.apache.sysds.hops.AggBinaryOp;
import org.apache.sysds.hops.BinaryOp;
import org.apache.sysds.hops.DataOp;
import org.apache.sysds.hops.Hop;
+import org.apache.sysds.hops.TernaryOp;
import org.apache.sysds.hops.fedplanner.FTypes.FType;
import org.apache.sysds.hops.ipa.FunctionCallGraph;
import org.apache.sysds.hops.ipa.FunctionCallSizeInfo;
@@ -56,15 +58,20 @@ public abstract class AFederatedPlanner {
//handle specific operators
if( hop instanceof AggBinaryOp ) {
return (ft[0] != null && ft[1] == null)
- || (ft[0] == null && ft[1] != null);
+ || (ft[0] == null && ft[1] != null)
+ || (ft[0] == FType.COL && ft[1] == FType.ROW);
}
else if( hop instanceof BinaryOp &&
!hop.getDataType().isScalar() ) {
return (ft[0] != null && ft[1] == null)
|| (ft[0] == null && ft[1] != null)
|| (ft[0] != null && ft[0] == ft[1]);
}
+ else if( hop instanceof TernaryOp &&
!hop.getDataType().isScalar() ) {
+ return (ft[0] != null || ft[1] != null || ft[2] !=
null);
+ }
else if(ft.length==1 && ft[0] != null) {
- return HopRewriteUtils.isReorg(hop, ReOrgOp.TRANS);
+ return HopRewriteUtils.isReorg(hop, ReOrgOp.TRANS)
+ || HopRewriteUtils.isAggUnaryOp(hop, AggOp.SUM,
AggOp.MIN, AggOp.MAX);
}
return false;
@@ -85,6 +92,8 @@ public abstract class AFederatedPlanner {
}
else if( hop instanceof BinaryOp )
return ft[0] != null ? ft[0] : ft[1];
+ else if( hop instanceof TernaryOp )
+ return ft[0] != null ? ft[0] : ft[1] != null ? ft[1] :
ft[2];
else if( HopRewriteUtils.isReorg(hop, ReOrgOp.TRANS) )
return ft[0] == FType.ROW ? FType.COL : FType.COL;
diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/FTypes.java
b/src/main/java/org/apache/sysds/hops/fedplanner/FTypes.java
index 98de495..7efabc8 100644
--- a/src/main/java/org/apache/sysds/hops/fedplanner/FTypes.java
+++ b/src/main/java/org/apache/sysds/hops/fedplanner/FTypes.java
@@ -30,9 +30,9 @@ public class FTypes
public AFederatedPlanner getPlanner() {
switch( this ) {
case COMPILE_FED_ALL:
- return new FederatedPlannerAllFed();
+ return new FederatedPlannerFedAll();
case COMPILE_FED_HEURISTIC:
- return new FederatedPlannerHeuristic();
+ return new
FederatedPlannerFedHeuristic();
case COMPILE_COST_BASED:
return new FederatedPlannerCostbased();
case NONE:
diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerAllFed.java b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerFedAll.java
similarity index 98%
rename from src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerAllFed.java
rename to src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerFedAll.java
index a35d94c..f11bbd1 100644
--- a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerAllFed.java
+++ b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerFedAll.java
@@ -48,7 +48,7 @@ import org.apache.sysds.runtime.instructions.fed.FEDInstruction.FederatedOutput;
* that support federated execution on federated inputs to
* forced federated operations.
*/
-public class FederatedPlannerAllFed extends AFederatedPlanner {
+public class FederatedPlannerFedAll extends AFederatedPlanner {
@Override
public void rewriteProgram( DMLProgram prog,
diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerHeuristic.java b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerFedHeuristic.java
similarity index 94%
rename from src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerHeuristic.java
rename to src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerFedHeuristic.java
index 15b12ac..4bc0b88 100644
--- a/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerHeuristic.java
+++ b/src/main/java/org/apache/sysds/hops/fedplanner/FederatedPlannerFedHeuristic.java
@@ -25,7 +25,7 @@ import org.apache.sysds.hops.AggBinaryOp;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.fedplanner.FTypes.FType;
-public class FederatedPlannerHeuristic extends FederatedPlannerAllFed {
+public class FederatedPlannerFedHeuristic extends FederatedPlannerFedAll {
@Override
protected FType getFederatedOut(Hop hop, Map<Long, FType> fedHops) {
diff --git a/src/main/java/org/apache/sysds/lops/MMTSJ.java b/src/main/java/org/apache/sysds/lops/MMTSJ.java
index 27876ed..45ad196 100644
--- a/src/main/java/org/apache/sysds/lops/MMTSJ.java
+++ b/src/main/java/org/apache/sysds/lops/MMTSJ.java
@@ -92,7 +92,7 @@ public class MMTSJ extends Lop
sb.append( _type );
//append degree of parallelism for matrix multiplications
- if( getExecType()==ExecType.CP ) {
+ if( getExecType()==ExecType.CP || getExecType()==ExecType.FED )
{
sb.append( OPERAND_DELIMITOR );
sb.append( _numThreads );
}
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java b/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
index 11ea4e0..4992426 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/FEDInstructionParser.java
@@ -43,7 +43,7 @@ public class FEDInstructionParser extends InstructionParser
String2FEDInstructionType.put( "fedinit" , FEDType.Init );
String2FEDInstructionType.put( "tsmm" , FEDType.Tsmm );
String2FEDInstructionType.put( "ba+*" ,
FEDType.AggregateBinary );
- String2FEDInstructionType.put( "tak+*" ,
FEDType.AggregateTernary);
+ String2FEDInstructionType.put( "tak+*" ,
FEDType.AggregateTernary);
String2FEDInstructionType.put( "uak+" ,
FEDType.AggregateUnary );
String2FEDInstructionType.put( "uark+" ,
FEDType.AggregateUnary );
@@ -56,12 +56,18 @@ public class FEDInstructionParser extends InstructionParser
String2FEDInstructionType.put( "uacvar" ,
FEDType.AggregateUnary);
// Arithmetic Instruction Opcodes
- String2FEDInstructionType.put( "+" , FEDType.Binary );
- String2FEDInstructionType.put( "-" , FEDType.Binary );
- String2FEDInstructionType.put( "*" , FEDType.Binary );
- String2FEDInstructionType.put( "/" , FEDType.Binary );
- String2FEDInstructionType.put( "1-*" , FEDType.Binary);
//special * case
- String2FEDInstructionType.put( "max" , FEDType.Binary );
+ String2FEDInstructionType.put( "+" , FEDType.Binary );
+ String2FEDInstructionType.put( "-" , FEDType.Binary );
+ String2FEDInstructionType.put( "*" , FEDType.Binary );
+ String2FEDInstructionType.put( "/" , FEDType.Binary );
+ String2FEDInstructionType.put( "1-*", FEDType.Binary);
//special * case
+ String2FEDInstructionType.put( "max", FEDType.Binary );
+ String2FEDInstructionType.put( "==", FEDType.Binary);
+ String2FEDInstructionType.put( "!=", FEDType.Binary);
+ String2FEDInstructionType.put( "<", FEDType.Binary);
+ String2FEDInstructionType.put( ">", FEDType.Binary);
+ String2FEDInstructionType.put( "<=", FEDType.Binary);
+ String2FEDInstructionType.put( ">=", FEDType.Binary);
// Reorg Instruction Opcodes (repositioning of existing values)
String2FEDInstructionType.put( "r'" , FEDType.Reorg );