This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 3f04b2fd21 [SYSTEMDS-3679] Multi-threaded contains-value operations
3f04b2fd21 is described below

commit 3f04b2fd21f33fd5281bae71bc5ca76298ec5169
Author: Matthias Boehm <[email protected]>
AuthorDate: Sat Mar 23 20:01:46 2024 +0100

    [SYSTEMDS-3679] Multi-threaded contains-value operations
    
    This patch extends the compilation and runtime of contains-value
    parameterized builtin operations to support multi-threading, because
    these operations are called in a number of algorithms and primitives
    to ensure valid input data. For an 8GB dense input matrix (100
    repetitions, tested on a two-socket Xeon Gold 6338 w/ 128 vcores), this
    patch improved performance from 2.027s to 0.052s (which is > 150GB/s).
---
 .../apache/sysds/hops/ParameterizedBuiltinOp.java  |  5 ++++-
 .../apache/sysds/hops/ipa/FunctionCallGraph.java   |  2 +-
 .../apache/sysds/hops/rewrite/HopRewriteUtils.java |  8 +++----
 .../RewriteSplitDagDataDependentOperators.java     |  2 +-
 .../apache/sysds/lops/ParameterizedBuiltin.java    |  9 ++++++--
 .../org/apache/sysds/runtime/data/DenseBlock.java  | 13 ++++++-----
 .../org/apache/sysds/runtime/data/SparseBlock.java |  7 +++---
 .../apache/sysds/runtime/data/SparseBlockCSR.java  |  6 +++---
 .../apache/sysds/runtime/data/SparseBlockDCSR.java |  6 +++---
 .../cp/ParameterizedBuiltinCPInstruction.java      |  3 ++-
 .../sysds/runtime/matrix/data/MatrixBlock.java     | 25 +++++++++++++++++++---
 .../apache/sysds/runtime/util/UtilFunctions.java   | 21 ++++++++++++++++++
 12 files changed, 80 insertions(+), 27 deletions(-)

diff --git a/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java 
b/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java
index a2bd1f188a..3fa9afed3d 100644
--- a/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java
+++ b/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java
@@ -157,7 +157,8 @@ public class ParameterizedBuiltinOp extends 
MultiThreadedHop {
        @Override
        public boolean isMultiThreadedOpType() {
                return HopRewriteUtils.isValidOp(_op, 
-                       ParamBuiltinOp.GROUPEDAGG, ParamBuiltinOp.REXPAND, 
ParamBuiltinOp.PARAMSERV);
+                       ParamBuiltinOp.GROUPEDAGG, ParamBuiltinOp.REXPAND,
+                       ParamBuiltinOp.PARAMSERV, ParamBuiltinOp.CONTAINS);
        }
        
        @Override
@@ -203,6 +204,8 @@ public class ParameterizedBuiltinOp extends 
MultiThreadedHop {
                        case AUTODIFF:{
                                ParameterizedBuiltin pbilop = new 
ParameterizedBuiltin(
                                        inputlops, _op, getDataType(), 
getValueType(), et);
+                               if( isMultiThreadedOpType() )
+                                       
pbilop.setNumThreads(OptimizerUtils.getConstrainedNumThreads(_maxNumThreads));
                                setOutputDimensions(pbilop);
                                setLineNumbers(pbilop);
                                setLops(pbilop);
diff --git a/src/main/java/org/apache/sysds/hops/ipa/FunctionCallGraph.java 
b/src/main/java/org/apache/sysds/hops/ipa/FunctionCallGraph.java
index feeafe83e1..177e13bb6c 100644
--- a/src/main/java/org/apache/sysds/hops/ipa/FunctionCallGraph.java
+++ b/src/main/java/org/apache/sysds/hops/ipa/FunctionCallGraph.java
@@ -438,7 +438,7 @@ public class FunctionCallGraph
                for( Hop h : hop.getInput() )
                        rConstructFunctionCallGraph(h, fkey, sb, fstack, lfset);
                
-               if( HopRewriteUtils.isParameterBuiltinOp(hop, 
ParamBuiltinOp.PARAMSERV)
+               if( HopRewriteUtils.isParameterizedBuiltinOp(hop, 
ParamBuiltinOp.PARAMSERV)
                        && HopRewriteUtils.knownParamservFunctions(hop, 
sb.getDMLProg()))
                {
                        ParameterizedBuiltinOp pop = (ParameterizedBuiltinOp) 
hop;
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java 
b/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java
index 61bd0921ce..144b331327 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java
@@ -1204,18 +1204,18 @@ public class HopRewriteUtils {
                return (hop instanceof AggUnaryOp && 
((AggUnaryOp)hop).getOp()==AggOp.SUM_SQ);
        }
 
-       public static boolean isParameterBuiltinOp(Hop hop, ParamBuiltinOp 
type) {
+       public static boolean isParameterizedBuiltinOp(Hop hop, ParamBuiltinOp 
type) {
                return hop instanceof ParameterizedBuiltinOp && 
((ParameterizedBuiltinOp) hop).getOp().equals(type);
        }
        
        public static boolean isRemoveEmpty(Hop hop, boolean rows) {
-               return isParameterBuiltinOp(hop, ParamBuiltinOp.RMEMPTY)
+               return isParameterizedBuiltinOp(hop, ParamBuiltinOp.RMEMPTY)
                        && HopRewriteUtils.isLiteralOfValue(
                                
((ParameterizedBuiltinOp)hop).getParameterHop("margin"), rows?"rows":"cols");
        }
 
        public static boolean isRemoveEmpty(Hop hop) {
-               return isParameterBuiltinOp(hop, ParamBuiltinOp.RMEMPTY);
+               return isParameterizedBuiltinOp(hop, ParamBuiltinOp.RMEMPTY);
        }
        
        public static boolean isNary(Hop hop, OpOpN type) {
@@ -1660,7 +1660,7 @@ public class HopRewriteUtils {
                if( hop.isVisited() ) return false;
                hop.setVisited();
                return HopRewriteUtils.isNary(hop, OpOpN.EVAL)
-                       || (HopRewriteUtils.isParameterBuiltinOp(hop, 
ParamBuiltinOp.PARAMSERV) 
+                       || (HopRewriteUtils.isParameterizedBuiltinOp(hop, 
ParamBuiltinOp.PARAMSERV) 
                                && !knownParamservFunctions(hop))
                        || hop.getInput().stream().anyMatch(c -> 
containsSecondOrderBuiltin(c));
        }
diff --git 
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteSplitDagDataDependentOperators.java
 
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteSplitDagDataDependentOperators.java
index 0c2801da97..3eaed10792 100644
--- 
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteSplitDagDataDependentOperators.java
+++ 
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteSplitDagDataDependentOperators.java
@@ -319,7 +319,7 @@ public class RewriteSplitDagDataDependentOperators extends 
StatementBlockRewrite
        private static boolean isBasicDataDependentOperator(Hop hop, boolean 
noSplitRequired) {
                return (HopRewriteUtils.isNary(hop, OpOpN.EVAL) & 
!noSplitRequired)
                        || (HopRewriteUtils.isData(hop, OpOpData.SQLREAD) & 
!noSplitRequired)
-                       || (HopRewriteUtils.isParameterBuiltinOp(hop, 
ParamBuiltinOp.GROUPEDAGG) 
+                       || (HopRewriteUtils.isParameterizedBuiltinOp(hop, 
ParamBuiltinOp.GROUPEDAGG) 
                                && 
!((ParameterizedBuiltinOp)hop).isKnownNGroups() && !noSplitRequired)
                        || ((HopRewriteUtils.isUnary(hop, OpOp1.COMPRESS) || 
hop.requiresCompression()) &&
                                (!HopRewriteUtils.hasOnlyWriteParents(hop, 
true, true)))
diff --git a/src/main/java/org/apache/sysds/lops/ParameterizedBuiltin.java 
b/src/main/java/org/apache/sysds/lops/ParameterizedBuiltin.java
index fc9e60419c..62f073a06a 100644
--- a/src/main/java/org/apache/sysds/lops/ParameterizedBuiltin.java
+++ b/src/main/java/org/apache/sysds/lops/ParameterizedBuiltin.java
@@ -72,6 +72,10 @@ public class ParameterizedBuiltin extends Lop
                return _operation; 
        }
        
+       public void setNumThreads(int k) {
+               _numThreads = k;
+       }
+       
        public int getInputIndex(String name) { 
                Lop n = _inputParams.get(name);
                for(int i=0; i<getInputs().size(); i++) 
@@ -211,10 +215,11 @@ public class ParameterizedBuiltin extends Lop
                        sb.append(OPERAND_DELIMITOR);
                }
                
-               if( getExecType()==ExecType.CP && _operation == 
ParamBuiltinOp.REXPAND ) {
+               if( getExecType()==ExecType.CP 
+                       && (_operation==ParamBuiltinOp.REXPAND || 
_operation==ParamBuiltinOp.CONTAINS ) ) {
                        sb.append( "k" );
                        sb.append( Lop.NAME_VALUE_SEPARATOR );
-                       sb.append( _numThreads );       
+                       sb.append( _numThreads );
                        sb.append(OPERAND_DELIMITOR);
                }
                
diff --git a/src/main/java/org/apache/sysds/runtime/data/DenseBlock.java 
b/src/main/java/org/apache/sysds/runtime/data/DenseBlock.java
index 037231fa0e..0a30d79250 100644
--- a/src/main/java/org/apache/sysds/runtime/data/DenseBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/data/DenseBlock.java
@@ -665,14 +665,17 @@ public abstract class DenseBlock implements Serializable, 
Block
         * (note that NaN==NaN yields false).
         * 
         * @param pattern checked pattern
+        * @param rl row lower bound (inclusive)
+        * @param ru row upper bound (exclusive)
         * @return true if pattern appears at least once, otherwise false
         */
-       public boolean contains(double pattern) {
+       public boolean contains(double pattern, int rl, int ru) {
                boolean NaNpattern = Double.isNaN(pattern);
-               for(int i=0; i<numBlocks(); i++) {
-                       double[] vals = valuesAt(i);
-                       int len = size(i);
-                       for(int j=0; j<len; j++)
+               int clen = _odims[0];
+               for(int i=rl; i<ru; i++) {
+                       double[] vals = values(i);
+                       int pos = pos(i);
+                       for(int j=pos; j<pos+clen; j++)
                                if(vals[j]==pattern || (NaNpattern && 
Double.isNaN(vals[j])))
                                        return true;
                }
diff --git a/src/main/java/org/apache/sysds/runtime/data/SparseBlock.java 
b/src/main/java/org/apache/sysds/runtime/data/SparseBlock.java
index b19d132503..bd3468531d 100644
--- a/src/main/java/org/apache/sysds/runtime/data/SparseBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/data/SparseBlock.java
@@ -474,12 +474,13 @@ public abstract class SparseBlock implements 
Serializable, Block
         * (note that NaN==NaN yields false).
         * 
         * @param pattern checked pattern
+        * @param rl row lower bound (inclusive)
+        * @param ru row upper bound (exclusive)
         * @return true if pattern appears at least once, otherwise false
         */
-       public boolean contains(double pattern) {
+       public boolean contains(double pattern, int rl, int ru) {
                boolean NaNpattern = Double.isNaN(pattern);
-               int rlen = numRows();
-               for(int i=0; i<rlen; i++) {
+               for(int i=rl; i<ru; i++) {
                        if( isEmpty(i) ) continue;
                        int apos = pos(i);
                        int alen = size(i);
diff --git a/src/main/java/org/apache/sysds/runtime/data/SparseBlockCSR.java 
b/src/main/java/org/apache/sysds/runtime/data/SparseBlockCSR.java
index 0aff4e08bb..ed00f15564 100644
--- a/src/main/java/org/apache/sysds/runtime/data/SparseBlockCSR.java
+++ b/src/main/java/org/apache/sysds/runtime/data/SparseBlockCSR.java
@@ -980,11 +980,11 @@ public class SparseBlockCSR extends SparseBlock
        }
        
        @Override //specialized for CSR
-       public boolean contains(double pattern) {
+       public boolean contains(double pattern, int rl, int ru) {
                boolean NaNpattern = Double.isNaN(pattern);
                double[] vals = _values;
-               int len = _size;
-               for(int i=0; i<len; i++)
+               int prl = pos(rl), pru = pos(ru);
+               for(int i=prl; i<pru; i++)
                        if(vals[i]==pattern || (NaNpattern && 
Double.isNaN(vals[i])))
                                return true;
                return false;
diff --git a/src/main/java/org/apache/sysds/runtime/data/SparseBlockDCSR.java 
b/src/main/java/org/apache/sysds/runtime/data/SparseBlockDCSR.java
index 447781a520..c5d4717e11 100644
--- a/src/main/java/org/apache/sysds/runtime/data/SparseBlockDCSR.java
+++ b/src/main/java/org/apache/sysds/runtime/data/SparseBlockDCSR.java
@@ -765,11 +765,11 @@ public class SparseBlockDCSR extends SparseBlock
        }
 
        @Override //specialized for CSR
-       public boolean contains(double pattern) {
+       public boolean contains(double pattern, int rl, int ru) {
                boolean NaNpattern = Double.isNaN(pattern);
                double[] vals = _values;
-               int len = _size;
-               for(int i=0; i<len; i++)
+               int prl = pos(rl), pru = pos(ru);
+               for(int i=prl; i<pru; i++)
                        if(vals[i]==pattern || (NaNpattern && 
Double.isNaN(vals[i])))
                                return true;
                return false;
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
index d0aea7bce9..6bfaf1f96f 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
@@ -249,12 +249,13 @@ public class ParameterizedBuiltinCPInstruction extends 
ComputationCPInstruction
                }
                else if(opcode.equalsIgnoreCase("contains")) {
                        String varName = params.get("target");
+                       int k = Integer.parseInt(params.get("k")); //num threads
                        MatrixBlock target = ec.getMatrixInput(varName);
                        Data pattern = ec.getVariable(params.get("pattern"));
                        if( pattern == null ) //literal
                                pattern = 
ScalarObjectFactory.createScalarObject(ValueType.FP64, params.get("pattern"));
                        boolean ret = pattern.getDataType().isScalar() ?
-                               
target.containsValue(((ScalarObject)pattern).getDoubleValue()) : 
+                               
target.containsValue(((ScalarObject)pattern).getDoubleValue(), k) : 
                                
(target.containsVector(((MatrixObject)pattern).acquireRead(), true).size()>0);
                        ec.releaseMatrixInput(varName);
                        if(!pattern.getDataType().isScalar())
diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java 
b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
index 701abd1c20..ea78af5759 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
@@ -35,6 +35,7 @@ import java.util.Iterator;
 import java.util.List;
 import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Future;
+import java.util.stream.Collectors;
 import java.util.stream.IntStream;
 
 import org.apache.commons.lang3.ArrayUtils;
@@ -766,6 +767,10 @@ public class MatrixBlock extends MatrixValue implements 
CacheBlock<MatrixBlock>,
        }
        
        public boolean containsValue(double pattern) {
+               return containsValue(pattern, 1);
+       }
+       
+       public boolean containsValue(double pattern, int k) {
                //fast paths: infer from meta data only
                if(isEmptyBlock(true))
                        return pattern==0;
@@ -774,9 +779,23 @@ public class MatrixBlock extends MatrixValue implements 
CacheBlock<MatrixBlock>,
                
                //make a pass over the data to determine if it includes the
                //pattern, with early abort as soon as the pattern is found
+               if( k == 1 ) {
+                       return containsValue(pattern, 0, rlen);
+               }
+               else {
+                       ExecutorService pool = CommonThreadPool.get(k);
+                       List<Future<Boolean>> tasks = 
UtilFunctions.getTaskRangesDefault(rlen, k).stream()
+                               .map(p -> pool.submit(() -> 
containsValue(pattern, p.getKey(), p.getValue())))
+                               .collect(Collectors.toList()); //submit all 
before waiting
+                       pool.shutdown();
+                       return tasks.stream().anyMatch(t -> 
UtilFunctions.getSafe(t));
+               }
+       }
+       
+       private boolean containsValue(double pattern, int rl, int ru) {
                return isInSparseFormat() ?
-                       getSparseBlock().contains(pattern) :
-                       getDenseBlock().contains(pattern);
+                       getSparseBlock().contains(pattern, rl, ru) :
+                       getDenseBlock().contains(pattern, rl, ru);
        }
        
        public List<Integer> containsVector(MatrixBlock pattern, boolean 
earlyAbort) {
@@ -3040,7 +3059,7 @@ public class MatrixBlock extends MatrixValue implements 
CacheBlock<MatrixBlock>,
                
                //early abort for comparisons w/ special values
                if( Builtin.isBuiltinCode(op.fn, BuiltinCode.ISNAN, 
BuiltinCode.ISNA))
-                       if( !containsValue(op.getPattern()) ) {
+                       if( !containsValue(op.getPattern(), op.getNumThreads()) 
) {
                                return new MatrixBlock(rlen, clen, true); 
//avoid unnecessary allocation
                        }
                
diff --git a/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java 
b/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java
index b46792da02..619dd4467e 100644
--- a/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java
+++ b/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java
@@ -33,6 +33,8 @@ import java.util.Map;
 import java.util.Random;
 import java.util.Set;
 import java.util.TimeZone;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.Future;
 import java.util.regex.Matcher;
 import java.util.regex.Pattern;
 
@@ -247,6 +249,17 @@ public class UtilFunctions {
                return pos;
        }
 
+       public static List<Pair<Integer,Integer>> getTaskRangesDefault(int len, 
int k) {
+               List<Pair<Integer,Integer>> ret = new ArrayList<>();
+               int nk = roundToNext(Math.min(8*k,len/32), k);
+               int beg = 0;
+               for(Integer blen : getBalancedBlockSizes(len, nk)) {
+                       ret.add(new Pair<>(beg, beg+blen)); 
+                       beg = beg+blen;
+               }
+               return ret;
+       }
+       
        public static ArrayList<Integer> getBalancedBlockSizesDefault(int len, 
int k, boolean constK) {
                int nk = constK ? k : roundToNext(Math.min(8*k,len/32), k);
                return getBalancedBlockSizes(len, nk);
@@ -1327,6 +1340,14 @@ public class UtilFunctions {
                        result[i] = String.valueOf(original[i]);
                return result;
        }
+       
+       public static <T> T getSafe(Future<T> task) {
+               try {
+                       return task.get();
+               } catch (InterruptedException | ExecutionException e) {
+                       throw new DMLRuntimeException(e);
+               }
+       }
 
        public static double[] convertStringToDoubleArray(String[] original) {
 //             double[] ret = new double[original.length];

Reply via email to