This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new be4f940  [SYSTEMDS-2836] Extended update in-place for unary operators
be4f940 is described below

commit be4f9404a62b291e997ee5205db395d6ff1b2ae7
Author: Ismael Ibrahim <[email protected]>
AuthorDate: Sun Oct 31 20:08:34 2021 +0100

    [SYSTEMDS-2836] Extended update in-place for unary operators
    
    AMLS project SS2021.
    Closes #1406.
    
    Co-authored-by: Maximilian Theiner <[email protected]>
    Co-authored-by: Alexander Kropiunig <[email protected]>
    Co-authored-by: Matthias Boehm <[email protected]>
---
 src/main/java/org/apache/sysds/hops/Hop.java       |  2 +-
 .../java/org/apache/sysds/hops/OptimizerUtils.java |  8 +++
 src/main/java/org/apache/sysds/hops/UnaryOp.java   | 17 ++++-
 src/main/java/org/apache/sysds/lops/Unary.java     |  5 ++
 .../instructions/cp/UnaryCPInstruction.java        |  8 +--
 .../sysds/runtime/matrix/data/LibMatrixAgg.java    | 10 +--
 .../sysds/runtime/matrix/data/MatrixBlock.java     | 19 +++--
 .../updateinplace/UnaryUpdateInPlaceTest.java      | 80 ++++++++++++++++++++++
 .../functions/updateinplace/UnaryUpdateInplace.dml | 36 ++++++++++
 9 files changed, 168 insertions(+), 17 deletions(-)

diff --git a/src/main/java/org/apache/sysds/hops/Hop.java 
b/src/main/java/org/apache/sysds/hops/Hop.java
index a25cf10..9114c55 100644
--- a/src/main/java/org/apache/sysds/hops/Hop.java
+++ b/src/main/java/org/apache/sysds/hops/Hop.java
@@ -65,7 +65,7 @@ import org.apache.sysds.runtime.util.UtilFunctions;
 
 public abstract class Hop implements ParseInfo {
        private static final Log LOG =  LogFactory.getLog(Hop.class.getName());
-       
+
        public static final long CPThreshold = 2000;
 
        // static variable to assign an unique ID to every hop that is created
diff --git a/src/main/java/org/apache/sysds/hops/OptimizerUtils.java 
b/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
index be916d9..1b94413 100644
--- a/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
+++ b/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
@@ -191,6 +191,14 @@ public class OptimizerUtils
        public static boolean ALLOW_LOOP_UPDATE_IN_PLACE = true;
        
        /**
+        * Enables the update-in-place for all unary operators with a single
+        * consumer. In this case we do not allocate the output, but directly
+        * write the output values back to the input block.
+        */
+       //TODO enabling it by default requires modifications in lineage-based 
reuse
+       public static boolean ALLOW_UNARY_UPDATE_IN_PLACE = false;
+       
+       /**
         * Replace eval second-order function calls with normal function call
         * if the function name is a known string (after constant propagation).
         */
diff --git a/src/main/java/org/apache/sysds/hops/UnaryOp.java 
b/src/main/java/org/apache/sysds/hops/UnaryOp.java
index 38199b2..d4e8f34 100644
--- a/src/main/java/org/apache/sysds/hops/UnaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/UnaryOp.java
@@ -43,6 +43,7 @@ import org.apache.sysds.runtime.util.UtilFunctions;
 import java.util.ArrayList;
 
 
+
 /* Unary (cell operations): e.g, b_ij = round(a_ij)
  *             Semantic: given a value, perform the operation (independent of 
other values)
  */
@@ -57,7 +58,7 @@ public class UnaryOp extends MultiThreadedHop
        private UnaryOp() {
                //default constructor for clone
        }
-       
+
        public UnaryOp(String l, DataType dt, ValueType vt, OpOp1 o, Hop inp) {
                super(l, dt, vt);
 
@@ -130,7 +131,7 @@ public class UnaryOp extends MultiThreadedHop
                try 
                {
                        Hop input = getInput().get(0);
-                       
+
                        if(    getDataType() == DataType.SCALAR //value type 
casts or matrix to scalar
                                || (_op == OpOp1.CAST_AS_MATRIX && 
getInput().get(0).getDataType()==DataType.SCALAR)
                                || (_op == OpOp1.CAST_AS_FRAME && 
getInput().get(0).getDataType()==DataType.SCALAR))
@@ -165,10 +166,20 @@ public class UnaryOp extends MultiThreadedHop
                                }
                                else //default unary 
                                {
+                                       boolean inplace = false;
+
+                                       //check in-place
+                                       if 
(OptimizerUtils.ALLOW_UNARY_UPDATE_IN_PLACE
+                                               && input.getParent().size() == 
1)
+                                       {
+                                               inplace = !(input instanceof 
DataOp)
+                                                       || !((DataOp) 
input).isRead();
+                                       }
+
                                        int k = isCumulativeUnaryOperation() || 
isExpensiveUnaryOperation() ?
                                                
OptimizerUtils.getConstrainedNumThreads( _maxNumThreads ) : 1;
                                        Unary unary1 = new 
Unary(input.constructLops(),
-                                               _op, getDataType(), 
getValueType(), et, k, false);
+                                               _op, getDataType(), 
getValueType(), et, k, inplace);
                                        setOutputDimensions(unary1);
                                        setLineNumbers(unary1);
                                        setLops(unary1);
diff --git a/src/main/java/org/apache/sysds/lops/Unary.java 
b/src/main/java/org/apache/sysds/lops/Unary.java
index ad4b2b8..f0a59fa 100644
--- a/src/main/java/org/apache/sysds/lops/Unary.java
+++ b/src/main/java/org/apache/sysds/lops/Unary.java
@@ -122,6 +122,7 @@ public class Unary extends Lop
        }
        
        public static boolean isMultiThreadedOp(OpOp1 op) {
+               //TODO extend for all basic unary operations
                return op==OpOp1.CUMSUM
                        || op==OpOp1.CUMPROD
                        || op==OpOp1.CUMMIN
@@ -129,6 +130,10 @@ public class Unary extends Lop
                        || op==OpOp1.CUMSUMPROD
                        || op==OpOp1.EXP
                        || op==OpOp1.LOG
+                       || op==OpOp1.ABS
+                       || op==OpOp1.ROUND
+                       || op==OpOp1.FLOOR
+                       || op==OpOp1.CEIL
                        || op==OpOp1.SIGMOID
                        || op==OpOp1.POW2
                        || op==OpOp1.MULT2;
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/UnaryCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/UnaryCPInstruction.java
index 0c98e84..8f92c07 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/UnaryCPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/UnaryCPInstruction.java
@@ -19,10 +19,10 @@
 
 package org.apache.sysds.runtime.instructions.cp;
 
-import java.util.Arrays;
-
+import org.apache.sysds.common.Types;
 import org.apache.sysds.common.Types.DataType;
 import org.apache.sysds.common.Types.ValueType;
+import org.apache.sysds.lops.Unary;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.functionobjects.Builtin;
 import org.apache.sysds.runtime.functionobjects.ValueFunction;
@@ -61,8 +61,8 @@ public abstract class UnaryCPInstruction extends 
ComputationCPInstruction {
                        in.split(parts[1]);
                        out.split(parts[2]);
                        func = Builtin.getBuiltinFnObject(opcode);
-                       
-                       if( Arrays.asList(new 
String[]{"ucumk+","ucum*","ucumk+*","ucummin","ucummax","exp","log","sigmoid"}).contains(opcode)
 ){
+                       Types.OpOp1 op_type = 
Types.OpOp1.valueOfByOpcode(opcode);
+                       if( Unary.isMultiThreadedOp(op_type)){
                                UnaryOperator op = new UnaryOperator(func, 
Integer.parseInt(parts[3]),Boolean.parseBoolean(parts[4]));
                                return new UnaryMatrixCPInstruction(op, in, 
out, opcode, str);
                        }
diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixAgg.java 
b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixAgg.java
index 22c437d..0d3c007 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixAgg.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixAgg.java
@@ -294,12 +294,14 @@ public class LibMatrixAgg
        }
        
        public static MatrixBlock cumaggregateUnaryMatrix(MatrixBlock in, 
MatrixBlock out, UnaryOperator uop, double[] agg) {
-               //prepare meta data 
+               //Check this implementation, standard case for cumagg (single 
threaded)
+
+               //prepare meta data
                AggType aggtype = getAggType(uop);
                final int m = in.rlen;
                final int m2 = out.rlen;
                final int n2 = out.clen;
-               
+
                //filter empty input blocks (incl special handling for 
sparse-unsafe operations)
                if( in.isEmpty() && (agg == null || aggtype == 
AggType.CUM_SUM_PROD) ) {
                        return aggregateUnaryMatrixEmpty(in, out, aggtype, 
null);
@@ -317,7 +319,7 @@ public class LibMatrixAgg
                }
                
                //Timing time = new Timing(true);
-               
+
                if( !in.sparse )
                        cumaggregateUnaryMatrixDense(in, out, aggtype, uop.fn, 
agg, 0, m);
                else
@@ -336,7 +338,7 @@ public class LibMatrixAgg
                AggregateUnaryOperator uaop = 
InstructionUtils.parseBasicCumulativeAggregateUnaryOperator(uop);
                
                //fall back to sequential if necessary or agg not supported
-               if(    k <= 1 || (long)in.rlen*in.clen < PAR_NUMCELL_THRESHOLD1 
|| in.rlen <= k
+               if( k <= 1 || (long)in.rlen*in.clen < PAR_NUMCELL_THRESHOLD1 || 
in.rlen <= k
                        || out.clen*8*k > PAR_INTERMEDIATE_SIZE_THRESHOLD || 
uaop == null || !out.isThreadSafe()) {
                        return cumaggregateUnaryMatrix(in, out, uop);
                }
diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java 
b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
index 0538dd7..2f521b5 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
@@ -2755,6 +2755,7 @@ public class MatrixBlock extends MatrixValue implements 
CacheBlock, Externalizab
                return ret;
        }
 
+
        @Override
        public MatrixBlock unaryOperations(UnaryOperator op, MatrixValue 
result) {
                MatrixBlock ret = checkType(result);
@@ -2769,7 +2770,7 @@ public class MatrixBlock extends MatrixValue implements 
CacheBlock, Externalizab
                        ret = new MatrixBlock(rlen, n, sp, sp ? nonZeros : 
rlen*n);
                else
                        ret.reset(rlen, n, sp);
-               
+
                //early abort for comparisons w/ special values
                if( Builtin.isBuiltinCode(op.fn, BuiltinCode.ISNAN, 
BuiltinCode.ISNA))
                        if( !containsValue(op.getPattern()) )
@@ -2788,7 +2789,11 @@ public class MatrixBlock extends MatrixValue implements 
CacheBlock, Externalizab
                        //note: we apply multi-threading in a best-effort 
manner here
                        //only for expensive operators such as exp, log, 
sigmoid, because
                        //otherwise allocation, read and write anyway dominates
-                       ret.allocateDenseBlock(false);
+                       if (!op.isInplace() || isEmpty())
+                               ret.allocateDenseBlock(false);
+                       else
+                               ret = this;
+
                        DenseBlock a = getDenseBlock();
                        DenseBlock c = ret.getDenseBlock();
                        for(int bi=0; bi<a.numBlocks(); bi++) {
@@ -2797,7 +2802,11 @@ public class MatrixBlock extends MatrixValue implements 
CacheBlock, Externalizab
                        }
                        ret.recomputeNonZeros();
                }
-               else {
+               else
+               {
+                       if (op.isInplace() && !isInSparseFormat() )
+                               ret = this;
+                       
                        //default execute unary operations
                        if(op.sparseSafe)
                                sparseUnaryOperations(op, ret);
@@ -2870,8 +2879,8 @@ public class MatrixBlock extends MatrixValue implements 
CacheBlock, Externalizab
                }
                else //DENSE <- DENSE
                {
-                       //allocate dense output block
-                       ret.allocateDenseBlock(false);
+                       if( this != ret ) //!in-place
+                               ret.allocateDenseBlock(false);
                        DenseBlock da = getDenseBlock();
                        DenseBlock dc = ret.getDenseBlock();
                        
diff --git 
a/src/test/java/org/apache/sysds/test/functions/updateinplace/UnaryUpdateInPlaceTest.java
 
b/src/test/java/org/apache/sysds/test/functions/updateinplace/UnaryUpdateInPlaceTest.java
new file mode 100644
index 0000000..0c7c133
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/functions/updateinplace/UnaryUpdateInPlaceTest.java
@@ -0,0 +1,80 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.updateinplace;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.runtime.matrix.data.MatrixValue;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.apache.sysds.test.functions.builtin.BuiltinSplitTest;
+import org.junit.Test;
+
+import java.util.HashMap;
+
+
+public class UnaryUpdateInPlaceTest extends AutomatedTestBase{
+       private final static String TEST_NAME = "UnaryUpdateInplace";
+       private final static String TEST_DIR = "functions/updateinplace/";
+       private final static String TEST_CLASS_DIR = TEST_DIR + UnaryUpdateInPlaceTest.class.getSimpleName() + "/";
+       private final static double eps = 1e-3;
+
+       @Override
+       public void setUp() {
+               TestUtils.clearAssertionInformation();
+               addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[]{"Out"}));
+       }
+
+       @Test
+       public void testInPlace() {
+               runInPlaceTest(Types.ExecType.CP);
+       }
+
+       // Runs the script once with and once without unary update-in-place; both outputs must match.
+       private void runInPlaceTest(Types.ExecType instType) {
+               Types.ExecMode platformOld = setExecMode(instType);
+               boolean oldFlag = OptimizerUtils.ALLOW_UNARY_UPDATE_IN_PLACE;
+               // the global flag and exec mode are restored in finally, even if a run fails
+               try {
+                       loadTestConfiguration(getTestConfiguration(TEST_NAME));
+                       String HOME = SCRIPT_DIR + TEST_DIR;
+                       fullDMLScriptName = HOME + TEST_NAME + ".dml";
+                       programArgs = new String[]{"-explain","-nvargs","Out=" + output("Out") };
+
+                       OptimizerUtils.ALLOW_UNARY_UPDATE_IN_PLACE = true;
+                       runTest(true, false, null, -1);
+                       HashMap<MatrixValue.CellIndex, Double> dmlfileOut1 = readDMLMatrixFromOutputDir("Out");
+                       OptimizerUtils.ALLOW_UNARY_UPDATE_IN_PLACE = false;
+                       runTest(true, false, null, -1);
+                       HashMap<MatrixValue.CellIndex, Double> dmlfileOut2 = readDMLMatrixFromOutputDir("Out");
+
+                       // results with (1) and without (2) in-place updates must agree within eps
+                       TestUtils.compareMatrices(dmlfileOut1,dmlfileOut2,eps,"Stat-DML1","Stat-DML2");
+               }
+               catch(Exception e) {
+                       throw new RuntimeException(e); // propagate; printing alone would let the test pass
+               }
+               finally {
+                       rtplatform = platformOld;
+                       OptimizerUtils.ALLOW_UNARY_UPDATE_IN_PLACE = oldFlag;
+               }
+       }
+}
diff --git a/src/test/scripts/functions/updateinplace/UnaryUpdateInplace.dml 
b/src/test/scripts/functions/updateinplace/UnaryUpdateInplace.dml
new file mode 100644
index 0000000..957ffc5
--- /dev/null
+++ b/src/test/scripts/functions/updateinplace/UnaryUpdateInplace.dml
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+
+# Unary update-in-place check: the Java driver runs this script with the feature on and off and compares Out.
+# The empty while(FALSE){} loops presumably split the script into separate statement blocks (separate DAGs).
+
+A = matrix(1, 10, 10);
+C = matrix(1, 10, 10);
+while(FALSE){}
+A = A * seq(1.1,10.1);
+while(FALSE){}
+B = round(A) # does not apply: input A is presumably a read of a variable from an earlier block
+C = C * seq(1.1,10.1);
+D = log(C) # applies: input is an intermediate result whose only consumer is log
+while(FALSE){}
+C = A + B + D*3
+Out = C
+write(Out, $Out);
+print(as.scalar(C[2, 1]))

Reply via email to