This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new b0586d4a00 [SYSTEMDS-3029] Fix memory estimates of row-wise codegen operators
b0586d4a00 is described below

commit b0586d4a003accbe6173b845b5b46a16d24010c4
Author: Matthias Boehm <[email protected]>
AuthorDate: Wed Jul 26 18:34:16 2023 +0200

    [SYSTEMDS-3029] Fix memory estimates of row-wise codegen operators
    
    This patch adds the missing memory estimates of intermediates in
    row-wise codegen operators, which allocate a number of vectors per
    thread for row intermediates. The decisions on local vs. distributed
    execution and on local CPU vs. GPU execution must take these
    intermediates into account.
---
 .../org/apache/sysds/hops/codegen/SpoofFusedOp.java     | 10 ++++++++++
 .../org/apache/sysds/runtime/codegen/SpoofRowwise.java  | 17 ++++++++++++++++-
 2 files changed, 26 insertions(+), 1 deletion(-)

diff --git a/src/main/java/org/apache/sysds/hops/codegen/SpoofFusedOp.java b/src/main/java/org/apache/sysds/hops/codegen/SpoofFusedOp.java
index a6dcb199af..d5bb1f0e72 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/SpoofFusedOp.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/SpoofFusedOp.java
@@ -30,11 +30,14 @@ import org.apache.sysds.hops.OptimizerUtils;
 import org.apache.sysds.lops.Lop;
 import org.apache.sysds.common.Types.ExecType;
 import org.apache.sysds.lops.SpoofFused;
+import org.apache.sysds.runtime.codegen.CodegenUtils;
+import org.apache.sysds.runtime.codegen.SpoofOperator;
 import org.apache.sysds.runtime.codegen.SpoofRowwise;
 import org.apache.sysds.runtime.meta.DataCharacteristics;
 import org.apache.sysds.runtime.meta.MatrixCharacteristics;
 
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Objects;
 
 public class SpoofFusedOp extends MultiThreadedHop
@@ -119,6 +122,13 @@ public class SpoofFusedOp extends MultiThreadedHop
 
        @Override
        protected double computeIntermediateMemEstimate(long dim1, long dim2, 
long nnz) {
+               if( _class.getGenericSuperclass().equals(SpoofRowwise.class) ) {
+                       long[] cols = new long[getInput().size()];
+                       Arrays.setAll(cols, i -> getInput(i).getDim2());
+                       SpoofOperator op = CodegenUtils.createInstance(_class);
+                       int k = 
OptimizerUtils.getConstrainedNumThreads(_maxNumThreads);
+                       return ((SpoofRowwise)op).getTmpMemoryReq(k, cols[0], 
cols);
+               }
                return 0;
        }
        
diff --git a/src/main/java/org/apache/sysds/runtime/codegen/SpoofRowwise.java b/src/main/java/org/apache/sysds/runtime/codegen/SpoofRowwise.java
index 02cb4f9dff..0b316959d8 100644
--- a/src/main/java/org/apache/sysds/runtime/codegen/SpoofRowwise.java
+++ b/src/main/java/org/apache/sysds/runtime/codegen/SpoofRowwise.java
@@ -42,6 +42,7 @@ import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.util.CommonThreadPool;
 import org.apache.sysds.runtime.util.UtilFunctions;
+import org.apache.sysds.utils.MemoryEstimates;
 
 
 public abstract class SpoofRowwise extends SpoofOperator
@@ -126,13 +127,27 @@ public abstract class SpoofRowwise extends SpoofOperator
        public int getNumIntermediates() {
                return _reqVectMem;
        }
+       
+       public long getTmpMemoryReq(int k, long cols, long... cols2) {
+               boolean hasMatrixSideInputs = IntStream.range(1, cols2.length)
+                       .mapToLong(i -> cols2[i]).anyMatch(n -> n > 1);
+               long minCols = IntStream.range(1, cols2.length)
+                       .mapToLong(i -> cols2[i]).filter(c -> c > 
1).min().orElse(1);
+               long n = cols;
+               long n2 = _type.isConstDim2(_constDim2) ? (int)_constDim2 : 
+                       _type.isRowTypeB1() || hasMatrixSideInputs ? minCols : 
-1;
+               return (long)(k * _reqVectMem * ((n2>0 && n!=n2) ?
+                       (MemoryEstimates.doubleArrayCost(n) + 
MemoryEstimates.doubleArrayCost(n2)) :
+                       MemoryEstimates.doubleArrayCost(n)));
+       }
 
        @Override
        public String getSpoofType() {
                return "RA" +  getClass().getName().split("\\.")[1];
        }
        
-       @Override public SpoofCUDAOperator createCUDAInstrcution(Integer opID, 
SpoofCUDAOperator.PrecisionProxy ep) {
+       @Override
+       public SpoofCUDAOperator createCUDAInstrcution(Integer opID, 
SpoofCUDAOperator.PrecisionProxy ep) {
                return new SpoofCUDARowwise(_type, _constDim2, _tB1, 
_reqVectMem, opID, ep);
        }
        

Reply via email to