Repository: systemml Updated Branches: refs/heads/master b84a4933c -> 352c256a3
[SYSTEMML-1755] Fix simplification rewrite for binary matrix-scalar ops. This patch fixes the rewrite that simplifies matrix-scalar operations to scalar-scalar operations so that it correctly checks whether a binary operation is supported over scalars. Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/352c256a Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/352c256a Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/352c256a Branch: refs/heads/master Commit: 352c256a3d71bb587162120134f87e4a9a2df507 Parents: b84a493 Author: Matthias Boehm <mboe...@gmail.com> Authored: Sun Jul 9 00:32:47 2017 -0700 Committer: Matthias Boehm <mboe...@gmail.com> Committed: Sun Jul 9 00:32:47 2017 -0700 ---------------------------------------------------------------------- src/main/java/org/apache/sysml/hops/Hop.java | 92 ++++++++++---------- .../RewriteAlgebraicSimplificationStatic.java | 8 +- 2 files changed, 54 insertions(+), 46 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/352c256a/src/main/java/org/apache/sysml/hops/Hop.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/Hop.java b/src/main/java/org/apache/sysml/hops/Hop.java index 8f8afde..80d33f1 100644 --- a/src/main/java/org/apache/sysml/hops/Hop.java +++ b/src/main/java/org/apache/sysml/hops/Hop.java @@ -28,6 +28,8 @@ import org.apache.commons.logging.LogFactory; import org.apache.sysml.api.DMLScript; import org.apache.sysml.api.DMLScript.RUNTIME_PLATFORM; import org.apache.sysml.conf.ConfigurationManager; +import org.apache.sysml.lops.Binary; +import org.apache.sysml.lops.BinaryScalar; import org.apache.sysml.lops.CSVReBlock; import org.apache.sysml.lops.Checkpoint; import org.apache.sysml.lops.Compression; @@ -1143,53 +1145,53 @@ public abstract class Hop } - protected static final 
HashMap<Hop.OpOp2, org.apache.sysml.lops.Binary.OperationTypes> HopsOpOp2LopsB; + protected static final HashMap<Hop.OpOp2, Binary.OperationTypes> HopsOpOp2LopsB; static { - HopsOpOp2LopsB = new HashMap<Hop.OpOp2, org.apache.sysml.lops.Binary.OperationTypes>(); - HopsOpOp2LopsB.put(OpOp2.PLUS, org.apache.sysml.lops.Binary.OperationTypes.ADD); - HopsOpOp2LopsB.put(OpOp2.MINUS, org.apache.sysml.lops.Binary.OperationTypes.SUBTRACT); - HopsOpOp2LopsB.put(OpOp2.MULT, org.apache.sysml.lops.Binary.OperationTypes.MULTIPLY); - HopsOpOp2LopsB.put(OpOp2.DIV, org.apache.sysml.lops.Binary.OperationTypes.DIVIDE); - HopsOpOp2LopsB.put(OpOp2.MODULUS, org.apache.sysml.lops.Binary.OperationTypes.MODULUS); - HopsOpOp2LopsB.put(OpOp2.INTDIV, org.apache.sysml.lops.Binary.OperationTypes.INTDIV); - HopsOpOp2LopsB.put(OpOp2.MINUS1_MULT, org.apache.sysml.lops.Binary.OperationTypes.MINUS1_MULTIPLY); - HopsOpOp2LopsB.put(OpOp2.LESS, org.apache.sysml.lops.Binary.OperationTypes.LESS_THAN); - HopsOpOp2LopsB.put(OpOp2.LESSEQUAL, org.apache.sysml.lops.Binary.OperationTypes.LESS_THAN_OR_EQUALS); - HopsOpOp2LopsB.put(OpOp2.GREATER, org.apache.sysml.lops.Binary.OperationTypes.GREATER_THAN); - HopsOpOp2LopsB.put(OpOp2.GREATEREQUAL, org.apache.sysml.lops.Binary.OperationTypes.GREATER_THAN_OR_EQUALS); - HopsOpOp2LopsB.put(OpOp2.EQUAL, org.apache.sysml.lops.Binary.OperationTypes.EQUALS); - HopsOpOp2LopsB.put(OpOp2.NOTEQUAL, org.apache.sysml.lops.Binary.OperationTypes.NOT_EQUALS); - HopsOpOp2LopsB.put(OpOp2.MIN, org.apache.sysml.lops.Binary.OperationTypes.MIN); - HopsOpOp2LopsB.put(OpOp2.MAX, org.apache.sysml.lops.Binary.OperationTypes.MAX); - HopsOpOp2LopsB.put(OpOp2.AND, org.apache.sysml.lops.Binary.OperationTypes.OR); - HopsOpOp2LopsB.put(OpOp2.OR, org.apache.sysml.lops.Binary.OperationTypes.AND); - HopsOpOp2LopsB.put(OpOp2.SOLVE, org.apache.sysml.lops.Binary.OperationTypes.SOLVE); - HopsOpOp2LopsB.put(OpOp2.POW, org.apache.sysml.lops.Binary.OperationTypes.POW); - HopsOpOp2LopsB.put(OpOp2.LOG, 
org.apache.sysml.lops.Binary.OperationTypes.NOTSUPPORTED); - } - - protected static final HashMap<Hop.OpOp2, org.apache.sysml.lops.BinaryScalar.OperationTypes> HopsOpOp2LopsBS; + HopsOpOp2LopsB = new HashMap<Hop.OpOp2, Binary.OperationTypes>(); + HopsOpOp2LopsB.put(OpOp2.PLUS, Binary.OperationTypes.ADD); + HopsOpOp2LopsB.put(OpOp2.MINUS, Binary.OperationTypes.SUBTRACT); + HopsOpOp2LopsB.put(OpOp2.MULT, Binary.OperationTypes.MULTIPLY); + HopsOpOp2LopsB.put(OpOp2.DIV, Binary.OperationTypes.DIVIDE); + HopsOpOp2LopsB.put(OpOp2.MODULUS, Binary.OperationTypes.MODULUS); + HopsOpOp2LopsB.put(OpOp2.INTDIV, Binary.OperationTypes.INTDIV); + HopsOpOp2LopsB.put(OpOp2.MINUS1_MULT, Binary.OperationTypes.MINUS1_MULTIPLY); + HopsOpOp2LopsB.put(OpOp2.LESS, Binary.OperationTypes.LESS_THAN); + HopsOpOp2LopsB.put(OpOp2.LESSEQUAL, Binary.OperationTypes.LESS_THAN_OR_EQUALS); + HopsOpOp2LopsB.put(OpOp2.GREATER, Binary.OperationTypes.GREATER_THAN); + HopsOpOp2LopsB.put(OpOp2.GREATEREQUAL, Binary.OperationTypes.GREATER_THAN_OR_EQUALS); + HopsOpOp2LopsB.put(OpOp2.EQUAL, Binary.OperationTypes.EQUALS); + HopsOpOp2LopsB.put(OpOp2.NOTEQUAL, Binary.OperationTypes.NOT_EQUALS); + HopsOpOp2LopsB.put(OpOp2.MIN, Binary.OperationTypes.MIN); + HopsOpOp2LopsB.put(OpOp2.MAX, Binary.OperationTypes.MAX); + HopsOpOp2LopsB.put(OpOp2.AND, Binary.OperationTypes.OR); + HopsOpOp2LopsB.put(OpOp2.OR, Binary.OperationTypes.AND); + HopsOpOp2LopsB.put(OpOp2.SOLVE, Binary.OperationTypes.SOLVE); + HopsOpOp2LopsB.put(OpOp2.POW, Binary.OperationTypes.POW); + HopsOpOp2LopsB.put(OpOp2.LOG, Binary.OperationTypes.NOTSUPPORTED); + } + + protected static final HashMap<Hop.OpOp2, BinaryScalar.OperationTypes> HopsOpOp2LopsBS; static { - HopsOpOp2LopsBS = new HashMap<Hop.OpOp2, org.apache.sysml.lops.BinaryScalar.OperationTypes>(); - HopsOpOp2LopsBS.put(OpOp2.PLUS, org.apache.sysml.lops.BinaryScalar.OperationTypes.ADD); - HopsOpOp2LopsBS.put(OpOp2.MINUS, org.apache.sysml.lops.BinaryScalar.OperationTypes.SUBTRACT); - 
HopsOpOp2LopsBS.put(OpOp2.MULT, org.apache.sysml.lops.BinaryScalar.OperationTypes.MULTIPLY); - HopsOpOp2LopsBS.put(OpOp2.DIV, org.apache.sysml.lops.BinaryScalar.OperationTypes.DIVIDE); - HopsOpOp2LopsBS.put(OpOp2.MODULUS, org.apache.sysml.lops.BinaryScalar.OperationTypes.MODULUS); - HopsOpOp2LopsBS.put(OpOp2.INTDIV, org.apache.sysml.lops.BinaryScalar.OperationTypes.INTDIV); - HopsOpOp2LopsBS.put(OpOp2.LESS, org.apache.sysml.lops.BinaryScalar.OperationTypes.LESS_THAN); - HopsOpOp2LopsBS.put(OpOp2.LESSEQUAL, org.apache.sysml.lops.BinaryScalar.OperationTypes.LESS_THAN_OR_EQUALS); - HopsOpOp2LopsBS.put(OpOp2.GREATER, org.apache.sysml.lops.BinaryScalar.OperationTypes.GREATER_THAN); - HopsOpOp2LopsBS.put(OpOp2.GREATEREQUAL, org.apache.sysml.lops.BinaryScalar.OperationTypes.GREATER_THAN_OR_EQUALS); - HopsOpOp2LopsBS.put(OpOp2.EQUAL, org.apache.sysml.lops.BinaryScalar.OperationTypes.EQUALS); - HopsOpOp2LopsBS.put(OpOp2.NOTEQUAL, org.apache.sysml.lops.BinaryScalar.OperationTypes.NOT_EQUALS); - HopsOpOp2LopsBS.put(OpOp2.MIN, org.apache.sysml.lops.BinaryScalar.OperationTypes.MIN); - HopsOpOp2LopsBS.put(OpOp2.MAX, org.apache.sysml.lops.BinaryScalar.OperationTypes.MAX); - HopsOpOp2LopsBS.put(OpOp2.AND, org.apache.sysml.lops.BinaryScalar.OperationTypes.AND); - HopsOpOp2LopsBS.put(OpOp2.OR, org.apache.sysml.lops.BinaryScalar.OperationTypes.OR); - HopsOpOp2LopsBS.put(OpOp2.LOG, org.apache.sysml.lops.BinaryScalar.OperationTypes.LOG); - HopsOpOp2LopsBS.put(OpOp2.POW, org.apache.sysml.lops.BinaryScalar.OperationTypes.POW); - HopsOpOp2LopsBS.put(OpOp2.PRINT, org.apache.sysml.lops.BinaryScalar.OperationTypes.PRINT); + HopsOpOp2LopsBS = new HashMap<Hop.OpOp2, BinaryScalar.OperationTypes>(); + HopsOpOp2LopsBS.put(OpOp2.PLUS, BinaryScalar.OperationTypes.ADD); + HopsOpOp2LopsBS.put(OpOp2.MINUS, BinaryScalar.OperationTypes.SUBTRACT); + HopsOpOp2LopsBS.put(OpOp2.MULT, BinaryScalar.OperationTypes.MULTIPLY); + HopsOpOp2LopsBS.put(OpOp2.DIV, BinaryScalar.OperationTypes.DIVIDE); + 
HopsOpOp2LopsBS.put(OpOp2.MODULUS, BinaryScalar.OperationTypes.MODULUS); + HopsOpOp2LopsBS.put(OpOp2.INTDIV, BinaryScalar.OperationTypes.INTDIV); + HopsOpOp2LopsBS.put(OpOp2.LESS, BinaryScalar.OperationTypes.LESS_THAN); + HopsOpOp2LopsBS.put(OpOp2.LESSEQUAL, BinaryScalar.OperationTypes.LESS_THAN_OR_EQUALS); + HopsOpOp2LopsBS.put(OpOp2.GREATER, BinaryScalar.OperationTypes.GREATER_THAN); + HopsOpOp2LopsBS.put(OpOp2.GREATEREQUAL, BinaryScalar.OperationTypes.GREATER_THAN_OR_EQUALS); + HopsOpOp2LopsBS.put(OpOp2.EQUAL, BinaryScalar.OperationTypes.EQUALS); + HopsOpOp2LopsBS.put(OpOp2.NOTEQUAL, BinaryScalar.OperationTypes.NOT_EQUALS); + HopsOpOp2LopsBS.put(OpOp2.MIN, BinaryScalar.OperationTypes.MIN); + HopsOpOp2LopsBS.put(OpOp2.MAX, BinaryScalar.OperationTypes.MAX); + HopsOpOp2LopsBS.put(OpOp2.AND, BinaryScalar.OperationTypes.AND); + HopsOpOp2LopsBS.put(OpOp2.OR, BinaryScalar.OperationTypes.OR); + HopsOpOp2LopsBS.put(OpOp2.LOG, BinaryScalar.OperationTypes.LOG); + HopsOpOp2LopsBS.put(OpOp2.POW, BinaryScalar.OperationTypes.POW); + HopsOpOp2LopsBS.put(OpOp2.PRINT, BinaryScalar.OperationTypes.PRINT); } protected static final HashMap<Hop.OpOp2, org.apache.sysml.lops.Unary.OperationTypes> HopsOpOp2LopsU; http://git-wip-us.apache.org/repos/asf/systemml/blob/352c256a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java index b8f9369..53359cc 100644 --- a/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java +++ b/src/main/java/org/apache/sysml/hops/rewrite/RewriteAlgebraicSimplificationStatic.java @@ -846,8 +846,14 @@ public class RewriteAlgebraicSimplificationStatic extends HopRewriteRule private Hop simplifyBinaryMatrixScalarOperation( Hop parent, Hop hi, int 
pos ) throws HopsException { + // Note: This rewrite is not applicable for all binary operations because some of them + // are undefined over scalars. We explicitly exclude potential conflicting matrix-scalar binary + // operations; other operations like cbind/rbind will never occur as matrix-scalar operations. + if( HopRewriteUtils.isUnary(hi, OpOp1.CAST_AS_SCALAR) - && hi.getInput().get(0) instanceof BinaryOp ) + && hi.getInput().get(0) instanceof BinaryOp + && !HopRewriteUtils.isBinary(hi.getInput().get(0), OpOp2.QUANTILE, + OpOp2.CENTRALMOMENT, OpOp2.MINUS1_MULT, OpOp2.MINUS_NZ, OpOp2.LOG_NZ)) { BinaryOp bin = (BinaryOp) hi.getInput().get(0); BinaryOp bout = null;