This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 3f04b2fd21 [SYSTEMDS-3679] Multi-threaded contains-value operations
3f04b2fd21 is described below
commit 3f04b2fd21f33fd5281bae71bc5ca76298ec5169
Author: Matthias Boehm <[email protected]>
AuthorDate: Sat Mar 23 20:01:46 2024 +0100
[SYSTEMDS-3679] Multi-threaded contains-value operations
This patch extends the compilation and runtime of contains-value
parameterized builtin operations for multi-threading, because this
operation is called in a number of algorithms and primitives to ensure
valid input data. For an 8GB dense input matrix (100 repetitions, tested
on a two-socket Xeon Gold 6338 w/ 128 vcores), this patch improved
performance from 2.027s to 0.052s per operation (i.e., > 150GB/s).
---
.../apache/sysds/hops/ParameterizedBuiltinOp.java | 5 ++++-
.../apache/sysds/hops/ipa/FunctionCallGraph.java | 2 +-
.../apache/sysds/hops/rewrite/HopRewriteUtils.java | 8 +++----
.../RewriteSplitDagDataDependentOperators.java | 2 +-
.../apache/sysds/lops/ParameterizedBuiltin.java | 9 ++++++--
.../org/apache/sysds/runtime/data/DenseBlock.java | 13 ++++++-----
.../org/apache/sysds/runtime/data/SparseBlock.java | 7 +++---
.../apache/sysds/runtime/data/SparseBlockCSR.java | 6 +++---
.../apache/sysds/runtime/data/SparseBlockDCSR.java | 6 +++---
.../cp/ParameterizedBuiltinCPInstruction.java | 3 ++-
.../sysds/runtime/matrix/data/MatrixBlock.java | 25 +++++++++++++++++++---
.../apache/sysds/runtime/util/UtilFunctions.java | 21 ++++++++++++++++++
12 files changed, 80 insertions(+), 27 deletions(-)
diff --git a/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java
b/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java
index a2bd1f188a..3fa9afed3d 100644
--- a/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java
+++ b/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java
@@ -157,7 +157,8 @@ public class ParameterizedBuiltinOp extends
MultiThreadedHop {
@Override
public boolean isMultiThreadedOpType() {
return HopRewriteUtils.isValidOp(_op,
- ParamBuiltinOp.GROUPEDAGG, ParamBuiltinOp.REXPAND,
ParamBuiltinOp.PARAMSERV);
+ ParamBuiltinOp.GROUPEDAGG, ParamBuiltinOp.REXPAND,
+ ParamBuiltinOp.PARAMSERV, ParamBuiltinOp.CONTAINS);
}
@Override
@@ -203,6 +204,8 @@ public class ParameterizedBuiltinOp extends
MultiThreadedHop {
case AUTODIFF:{
ParameterizedBuiltin pbilop = new
ParameterizedBuiltin(
inputlops, _op, getDataType(),
getValueType(), et);
+ if( isMultiThreadedOpType() )
+
pbilop.setNumThreads(OptimizerUtils.getConstrainedNumThreads(_maxNumThreads));
setOutputDimensions(pbilop);
setLineNumbers(pbilop);
setLops(pbilop);
diff --git a/src/main/java/org/apache/sysds/hops/ipa/FunctionCallGraph.java
b/src/main/java/org/apache/sysds/hops/ipa/FunctionCallGraph.java
index feeafe83e1..177e13bb6c 100644
--- a/src/main/java/org/apache/sysds/hops/ipa/FunctionCallGraph.java
+++ b/src/main/java/org/apache/sysds/hops/ipa/FunctionCallGraph.java
@@ -438,7 +438,7 @@ public class FunctionCallGraph
for( Hop h : hop.getInput() )
rConstructFunctionCallGraph(h, fkey, sb, fstack, lfset);
- if( HopRewriteUtils.isParameterBuiltinOp(hop,
ParamBuiltinOp.PARAMSERV)
+ if( HopRewriteUtils.isParameterizedBuiltinOp(hop,
ParamBuiltinOp.PARAMSERV)
&& HopRewriteUtils.knownParamservFunctions(hop,
sb.getDMLProg()))
{
ParameterizedBuiltinOp pop = (ParameterizedBuiltinOp)
hop;
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java
b/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java
index 61bd0921ce..144b331327 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java
@@ -1204,18 +1204,18 @@ public class HopRewriteUtils {
return (hop instanceof AggUnaryOp &&
((AggUnaryOp)hop).getOp()==AggOp.SUM_SQ);
}
- public static boolean isParameterBuiltinOp(Hop hop, ParamBuiltinOp
type) {
+ public static boolean isParameterizedBuiltinOp(Hop hop, ParamBuiltinOp
type) {
return hop instanceof ParameterizedBuiltinOp &&
((ParameterizedBuiltinOp) hop).getOp().equals(type);
}
public static boolean isRemoveEmpty(Hop hop, boolean rows) {
- return isParameterBuiltinOp(hop, ParamBuiltinOp.RMEMPTY)
+ return isParameterizedBuiltinOp(hop, ParamBuiltinOp.RMEMPTY)
&& HopRewriteUtils.isLiteralOfValue(
((ParameterizedBuiltinOp)hop).getParameterHop("margin"), rows?"rows":"cols");
}
public static boolean isRemoveEmpty(Hop hop) {
- return isParameterBuiltinOp(hop, ParamBuiltinOp.RMEMPTY);
+ return isParameterizedBuiltinOp(hop, ParamBuiltinOp.RMEMPTY);
}
public static boolean isNary(Hop hop, OpOpN type) {
@@ -1660,7 +1660,7 @@ public class HopRewriteUtils {
if( hop.isVisited() ) return false;
hop.setVisited();
return HopRewriteUtils.isNary(hop, OpOpN.EVAL)
- || (HopRewriteUtils.isParameterBuiltinOp(hop,
ParamBuiltinOp.PARAMSERV)
+ || (HopRewriteUtils.isParameterizedBuiltinOp(hop,
ParamBuiltinOp.PARAMSERV)
&& !knownParamservFunctions(hop))
|| hop.getInput().stream().anyMatch(c ->
containsSecondOrderBuiltin(c));
}
diff --git
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteSplitDagDataDependentOperators.java
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteSplitDagDataDependentOperators.java
index 0c2801da97..3eaed10792 100644
---
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteSplitDagDataDependentOperators.java
+++
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteSplitDagDataDependentOperators.java
@@ -319,7 +319,7 @@ public class RewriteSplitDagDataDependentOperators extends
StatementBlockRewrite
private static boolean isBasicDataDependentOperator(Hop hop, boolean
noSplitRequired) {
return (HopRewriteUtils.isNary(hop, OpOpN.EVAL) &
!noSplitRequired)
|| (HopRewriteUtils.isData(hop, OpOpData.SQLREAD) &
!noSplitRequired)
- || (HopRewriteUtils.isParameterBuiltinOp(hop,
ParamBuiltinOp.GROUPEDAGG)
+ || (HopRewriteUtils.isParameterizedBuiltinOp(hop,
ParamBuiltinOp.GROUPEDAGG)
&&
!((ParameterizedBuiltinOp)hop).isKnownNGroups() && !noSplitRequired)
|| ((HopRewriteUtils.isUnary(hop, OpOp1.COMPRESS) ||
hop.requiresCompression()) &&
(!HopRewriteUtils.hasOnlyWriteParents(hop,
true, true)))
diff --git a/src/main/java/org/apache/sysds/lops/ParameterizedBuiltin.java
b/src/main/java/org/apache/sysds/lops/ParameterizedBuiltin.java
index fc9e60419c..62f073a06a 100644
--- a/src/main/java/org/apache/sysds/lops/ParameterizedBuiltin.java
+++ b/src/main/java/org/apache/sysds/lops/ParameterizedBuiltin.java
@@ -72,6 +72,10 @@ public class ParameterizedBuiltin extends Lop
return _operation;
}
+ public void setNumThreads(int k) {
+ _numThreads = k;
+ }
+
public int getInputIndex(String name) {
Lop n = _inputParams.get(name);
for(int i=0; i<getInputs().size(); i++)
@@ -211,10 +215,11 @@ public class ParameterizedBuiltin extends Lop
sb.append(OPERAND_DELIMITOR);
}
- if( getExecType()==ExecType.CP && _operation ==
ParamBuiltinOp.REXPAND ) {
+ if( getExecType()==ExecType.CP
+ && (_operation==ParamBuiltinOp.REXPAND ||
_operation==ParamBuiltinOp.CONTAINS ) ) {
sb.append( "k" );
sb.append( Lop.NAME_VALUE_SEPARATOR );
- sb.append( _numThreads );
+ sb.append( _numThreads );
sb.append(OPERAND_DELIMITOR);
}
diff --git a/src/main/java/org/apache/sysds/runtime/data/DenseBlock.java
b/src/main/java/org/apache/sysds/runtime/data/DenseBlock.java
index 037231fa0e..0a30d79250 100644
--- a/src/main/java/org/apache/sysds/runtime/data/DenseBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/data/DenseBlock.java
@@ -665,14 +665,17 @@ public abstract class DenseBlock implements Serializable,
Block
* (note that NaN==NaN yields false).
*
* @param pattern checked pattern
+ * @param rl row lower bound (inclusive)
+ * @param ru row upper bound (exclusive)
* @return true if pattern appears at least once, otherwise false
*/
- public boolean contains(double pattern) {
+ public boolean contains(double pattern, int rl, int ru) {
boolean NaNpattern = Double.isNaN(pattern);
- for(int i=0; i<numBlocks(); i++) {
- double[] vals = valuesAt(i);
- int len = size(i);
- for(int j=0; j<len; j++)
+ int clen = _odims[0];
+ for(int i=rl; i<ru; i++) {
+ double[] vals = values(i);
+ int pos = pos(i);
+ for(int j=pos; j<pos+clen; j++)
if(vals[j]==pattern || (NaNpattern &&
Double.isNaN(vals[j])))
return true;
}
diff --git a/src/main/java/org/apache/sysds/runtime/data/SparseBlock.java
b/src/main/java/org/apache/sysds/runtime/data/SparseBlock.java
index b19d132503..bd3468531d 100644
--- a/src/main/java/org/apache/sysds/runtime/data/SparseBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/data/SparseBlock.java
@@ -474,12 +474,13 @@ public abstract class SparseBlock implements
Serializable, Block
* (note that NaN==NaN yields false).
*
* @param pattern checked pattern
+ * @param rl row lower bound (inclusive)
+ * @param ru row upper bound (exclusive)
* @return true if pattern appears at least once, otherwise false
*/
- public boolean contains(double pattern) {
+ public boolean contains(double pattern, int rl, int ru) {
boolean NaNpattern = Double.isNaN(pattern);
- int rlen = numRows();
- for(int i=0; i<rlen; i++) {
+ for(int i=rl; i<ru; i++) {
if( isEmpty(i) ) continue;
int apos = pos(i);
int alen = size(i);
diff --git a/src/main/java/org/apache/sysds/runtime/data/SparseBlockCSR.java
b/src/main/java/org/apache/sysds/runtime/data/SparseBlockCSR.java
index 0aff4e08bb..ed00f15564 100644
--- a/src/main/java/org/apache/sysds/runtime/data/SparseBlockCSR.java
+++ b/src/main/java/org/apache/sysds/runtime/data/SparseBlockCSR.java
@@ -980,11 +980,11 @@ public class SparseBlockCSR extends SparseBlock
}
@Override //specialized for CSR
- public boolean contains(double pattern) {
+ public boolean contains(double pattern, int rl, int ru) {
boolean NaNpattern = Double.isNaN(pattern);
double[] vals = _values;
- int len = _size;
- for(int i=0; i<len; i++)
+ int prl = pos(rl), pru = pos(ru);
+ for(int i=prl; i<pru; i++)
if(vals[i]==pattern || (NaNpattern &&
Double.isNaN(vals[i])))
return true;
return false;
diff --git a/src/main/java/org/apache/sysds/runtime/data/SparseBlockDCSR.java
b/src/main/java/org/apache/sysds/runtime/data/SparseBlockDCSR.java
index 447781a520..c5d4717e11 100644
--- a/src/main/java/org/apache/sysds/runtime/data/SparseBlockDCSR.java
+++ b/src/main/java/org/apache/sysds/runtime/data/SparseBlockDCSR.java
@@ -765,11 +765,11 @@ public class SparseBlockDCSR extends SparseBlock
}
@Override //specialized for CSR
- public boolean contains(double pattern) {
+ public boolean contains(double pattern, int rl, int ru) {
boolean NaNpattern = Double.isNaN(pattern);
double[] vals = _values;
- int len = _size;
- for(int i=0; i<len; i++)
+ int prl = pos(rl), pru = pos(ru);
+ for(int i=prl; i<pru; i++)
if(vals[i]==pattern || (NaNpattern &&
Double.isNaN(vals[i])))
return true;
return false;
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
index d0aea7bce9..6bfaf1f96f 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
@@ -249,12 +249,13 @@ public class ParameterizedBuiltinCPInstruction extends
ComputationCPInstruction
}
else if(opcode.equalsIgnoreCase("contains")) {
String varName = params.get("target");
+ int k = Integer.parseInt(params.get("k")); //num threads
MatrixBlock target = ec.getMatrixInput(varName);
Data pattern = ec.getVariable(params.get("pattern"));
if( pattern == null ) //literal
pattern =
ScalarObjectFactory.createScalarObject(ValueType.FP64, params.get("pattern"));
boolean ret = pattern.getDataType().isScalar() ?
-
target.containsValue(((ScalarObject)pattern).getDoubleValue()) :
+
target.containsValue(((ScalarObject)pattern).getDoubleValue(), k) :
(target.containsVector(((MatrixObject)pattern).acquireRead(), true).size()>0);
ec.releaseMatrixInput(varName);
if(!pattern.getDataType().isScalar())
diff --git
a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
index 701abd1c20..ea78af5759 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
@@ -35,6 +35,7 @@ import java.util.Iterator;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
+import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.commons.lang3.ArrayUtils;
@@ -766,6 +767,10 @@ public class MatrixBlock extends MatrixValue implements
CacheBlock<MatrixBlock>,
}
public boolean containsValue(double pattern) {
+ return containsValue(pattern, 1);
+ }
+
+ public boolean containsValue(double pattern, int k) {
//fast paths: infer from meta data only
if(isEmptyBlock(true))
return pattern==0;
@@ -774,9 +779,23 @@ public class MatrixBlock extends MatrixValue implements
CacheBlock<MatrixBlock>,
//make a pass over the data to determine if it includes the
//pattern, with early abort as soon as the pattern is found
+ if( k == 1 ) {
+ return containsValue(pattern, 0, rlen);
+ }
+ else {
+ ExecutorService pool = CommonThreadPool.get(k);
+ List<Future<Boolean>> tasks =
UtilFunctions.getTaskRangesDefault(rlen, k).stream()
+ .map(p -> pool.submit(() ->
containsValue(pattern, p.getKey(), p.getValue())))
+ .collect(Collectors.toList()); //submit all
before waiting
+ pool.shutdown();
+ return tasks.stream().anyMatch(t ->
UtilFunctions.getSafe(t));
+ }
+ }
+
+ private boolean containsValue(double pattern, int rl, int ru) {
return isInSparseFormat() ?
- getSparseBlock().contains(pattern) :
- getDenseBlock().contains(pattern);
+ getSparseBlock().contains(pattern, rl, ru) :
+ getDenseBlock().contains(pattern, rl, ru);
}
public List<Integer> containsVector(MatrixBlock pattern, boolean
earlyAbort) {
@@ -3040,7 +3059,7 @@ public class MatrixBlock extends MatrixValue implements
CacheBlock<MatrixBlock>,
//early abort for comparisons w/ special values
if( Builtin.isBuiltinCode(op.fn, BuiltinCode.ISNAN,
BuiltinCode.ISNA))
- if( !containsValue(op.getPattern()) ) {
+ if( !containsValue(op.getPattern(), op.getNumThreads())
) {
return new MatrixBlock(rlen, clen, true);
//avoid unnecessary allocation
}
diff --git a/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java
b/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java
index b46792da02..619dd4467e 100644
--- a/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java
+++ b/src/main/java/org/apache/sysds/runtime/util/UtilFunctions.java
@@ -33,6 +33,8 @@ import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.TimeZone;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.Future;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
@@ -247,6 +249,17 @@ public class UtilFunctions {
return pos;
}
+ public static List<Pair<Integer,Integer>> getTaskRangesDefault(int len,
int k) {
+ List<Pair<Integer,Integer>> ret = new ArrayList<>();
+ int nk = roundToNext(Math.min(8*k,len/32), k);
+ int beg = 0;
+ for(Integer blen : getBalancedBlockSizes(len, nk)) {
+ ret.add(new Pair<>(beg, beg+blen));
+ beg = beg+blen;
+ }
+ return ret;
+ }
+
public static ArrayList<Integer> getBalancedBlockSizesDefault(int len,
int k, boolean constK) {
int nk = constK ? k : roundToNext(Math.min(8*k,len/32), k);
return getBalancedBlockSizes(len, nk);
@@ -1327,6 +1340,14 @@ public class UtilFunctions {
result[i] = String.valueOf(original[i]);
return result;
}
+
+ public static <T> T getSafe(Future<T> task) {
+ try {
+ return task.get();
+ } catch (InterruptedException | ExecutionException e) {
+ throw new DMLRuntimeException(e);
+ }
+ }
public static double[] convertStringToDoubleArray(String[] original) {
// double[] ret = new double[original.length];