This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git

commit 8a52a6832cbb4b0cefab8c96d0776c9b089f17f3
Author: Matthias Boehm <[email protected]>
AuthorDate: Wed Jan 3 15:08:57 2024 +0100

    [SYSTEMDS-3661] New contains-vector parameterized built-in function
    
    We already supported a contains(target=X, pattern=s) function for
    checking if a pattern scalar s is included in the target matrix X.
    For a new external use case, we now overload this function to also
    allow the pattern to be a row vector (i.e., check whether an exact
    duplicate of this row exists in X). For correctness, we compare doubles
    via their long bit representations (so NaN matches NaN), and implement
    block operations with and without early abort for efficient distributed
    Spark operations over multiple column blocks.
---
 .../ParameterizedBuiltinFunctionExpression.java    |   9 +-
 .../org/apache/sysds/runtime/data/DenseBlock.java  |  15 +++
 .../org/apache/sysds/runtime/data/SparseBlock.java |  22 ++++
 .../cp/ParameterizedBuiltinCPInstruction.java      |  10 +-
 .../spark/ParameterizedBuiltinSPInstruction.java   |  55 ++++++++-
 .../sysds/runtime/matrix/data/MatrixBlock.java     |  17 +++
 .../test/functions/aggregate/ContainsTest.java     | 127 ++++++++++++++++++++-
 .../aggregate/{Contains.dml => ContainsVal.dml}    |   0
 .../aggregate/{Contains.dml => ContainsVect.dml}   |  13 ++-
 9 files changed, 252 insertions(+), 16 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
 
b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
index 1906ee818e..4ee92e783b 100644
--- 
a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
+++ 
b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
@@ -737,8 +737,18 @@ public class ParameterizedBuiltinFunctionExpression extends DataIdentifier
                //check existence and correctness of arguments
                Expression target = getVarParam("target");
                checkTargetParam(target, conditional);
-               checkScalarParam("contains", "pattern", conditional);
                
+               //pattern is either a scalar value or a row vector (exact row match)
+               Expression pattern = getVarParam("pattern");
+               if(pattern == null)
+                       raiseValidateError("Named parameter 'pattern' missing. Please specify the pattern scalar or row vector.",
+                               conditional, LanguageErrorCodes.INVALID_PARAMETERS);
+               //else-if: for conditional validation raiseValidateError presumably only
+               //warns instead of throwing, so never dereference a missing pattern below
+               else if(!(pattern.getOutput().getDataType().isScalar()
+                       ||pattern.getOutput().getDataType().isMatrix()) )
+                       raiseValidateError("Named parameter 'pattern' must be a scalar or matrix.",
+                               conditional, LanguageErrorCodes.INVALID_PARAMETERS);
                //set boolean scalar 
                output.setBooleanProperties();
        }
diff --git a/src/main/java/org/apache/sysds/runtime/data/DenseBlock.java 
b/src/main/java/org/apache/sysds/runtime/data/DenseBlock.java
index e4a55ab10c..64e3789d4a 100644
--- a/src/main/java/org/apache/sysds/runtime/data/DenseBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/data/DenseBlock.java
@@ -20,7 +20,9 @@
 package org.apache.sysds.runtime.data;
 
 import java.io.Serializable;
+import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.List;
 
 import org.apache.sysds.common.Types.ValueType;
 import org.apache.sysds.runtime.DMLRuntimeException;
@@ -677,6 +679,19 @@ public abstract class DenseBlock implements Serializable, 
Block
                return false;
        }
        
+       public List<Integer> contains(double[] pattern, boolean earlyAbort) { //indexes of rows equal to pattern (only the first one if earlyAbort)
+               List<Integer> ret = new ArrayList<>();
+               int clen = _odims[0]; //presumably the row length (number of columns); compared against pattern[0..clen)
+               for( int i=0; i<_rlen; i++ ) {
+                       //exact row-range comparison via Arrays.equals on long bit representations (so NaN matches NaN)
+                       if(Arrays.equals(values(i), pos(i), pos(i)+clen, 
pattern, 0, clen))
+                               ret.add(i);
+                       if(earlyAbort && ret.size()>0) //stop after the first match
+                               return ret;
+               }
+               return ret;
+       }
+       
        @Override
        public boolean equals(Object o) {
                if(o instanceof DenseBlock)
diff --git a/src/main/java/org/apache/sysds/runtime/data/SparseBlock.java 
b/src/main/java/org/apache/sysds/runtime/data/SparseBlock.java
index bbddb9a178..bc6d4727d1 100644
--- a/src/main/java/org/apache/sysds/runtime/data/SparseBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/data/SparseBlock.java
@@ -20,8 +20,10 @@
 package org.apache.sysds.runtime.data;
 
 import java.io.Serializable;
+import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Iterator;
+import java.util.List;
 
 import org.apache.sysds.runtime.matrix.data.IJV;
 
@@ -491,6 +493,53 @@ public abstract class SparseBlock implements Serializable, Block
                return false;
        }
        
+       /**
+        * Finds all rows of this sparse block that exactly match the given dense
+        * pattern row, i.e., rows with the same non-zero positions and values
+        * and zeros everywhere else.
+        * 
+        * @param pattern    dense pattern row with one value per column
+        * @param earlyAbort if true, return as soon as the first matching row is found
+        * @return block-local indexes of the matching rows (empty if none match)
+        */
+       public List<Integer> contains(double[] pattern, boolean earlyAbort) {
+               List<Integer> ret = new ArrayList<>();
+               int rlen = numRows();
+               //number of pattern non-zeros: rows only store non-zero cells, so a
+               //matching row must contain exactly this many non-zeros as well
+               int pnnz = 0;
+               for( double v : pattern )
+                       pnnz += (v != 0) ? 1 : 0;
+               for( int i=0; i<rlen; i++ ) {
+                       boolean lret;
+                       if( isEmpty(i) ) //all-zero row: skip per-row arrays, which may be unallocated
+                               lret = (pnnz == 0);
+                       else {
+                               int apos = pos(i);
+                               int alen = size(i);
+                               int[] aix = indexes(i);
+                               double[] avals = values(i);
+                               int lnnz = 0;
+                               lret = true;
+                               //safe comparison on long representations, incl NaN,
+                               //with early exit on the first mismatching cell
+                               for( int k=apos; k<apos+alen && lret; k++ ) {
+                                       lret = Double.compare(avals[k], pattern[aix[k]]) == 0;
+                                       lnnz += (avals[k] != 0) ? 1 : 0;
+                               }
+                               //all stored cells matched; equal non-zero counts further
+                               //ensure that no pattern non-zero is missing in this row
+                               lret &= (lnnz == pnnz);
+                       }
+                       if( lret ) {
+                               ret.add(i);
+                               if( earlyAbort )
+                                       return ret;
+                       }
+               }
+               return ret;
+       }
+       
        ////////////////////////
        //iterators
        
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
index b0c86efdbd..0307fbb03b 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
@@ -250,9 +250,15 @@ public class ParameterizedBuiltinCPInstruction extends 
ComputationCPInstruction
                else if(opcode.equalsIgnoreCase("contains")) {
                        String varName = params.get("target");
                        MatrixBlock target = ec.getMatrixInput(varName);
-                       double pattern = 
Double.parseDouble(params.get("pattern"));
-                       boolean ret = target.containsValue(pattern);
+                       Data pattern = ec.getVariable(params.get("pattern"));
+                       if( pattern == null ) //no such variable: pattern is a literal scalar value
+                               pattern = 
ScalarObjectFactory.createScalarObject(ValueType.FP64, params.get("pattern"));
+                       boolean ret = pattern.getDataType().isScalar() ? //scalar: value search, matrix: row-vector search with early abort
+                               
target.containsValue(((ScalarObject)pattern).getDoubleValue()) : 
+                               
(target.containsVector(((MatrixObject)pattern).acquireRead(), true).size()>0);
                        ec.releaseMatrixInput(varName);
+                       if(!pattern.getDataType().isScalar()) //release the pattern matrix acquired via acquireRead above
+                               ec.releaseMatrixInput(params.get("pattern"));
                        ec.setScalarOutput(output.getName(), new 
BooleanObject(ret));
                }
                else if(opcode.equalsIgnoreCase("replace")) {
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
index 197283e0b5..3b61b768b0 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
@@ -23,6 +23,7 @@ import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.HashMap;
 import java.util.Iterator;
+import java.util.List;
 
 import org.apache.commons.lang3.tuple.Pair;
 import org.apache.spark.api.java.JavaPairRDD;
@@ -51,6 +52,9 @@ import org.apache.sysds.runtime.functionobjects.ValueFunction;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
 import org.apache.sysds.runtime.instructions.cp.BooleanObject;
 import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.instructions.cp.Data;
+import org.apache.sysds.runtime.instructions.cp.ScalarObjectFactory;
+import org.apache.sysds.runtime.instructions.cp.ScalarObject;
 import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
 import org.apache.sysds.runtime.instructions.spark.data.LazyIterableIterator;
 import org.apache.sysds.runtime.instructions.spark.data.PartitionedBroadcast;
@@ -374,15 +378,31 @@ public class ParameterizedBuiltinSPInstruction extends 
ComputationSPInstruction
                        }
                }
                else if(opcode.equalsIgnoreCase("contains")) {
+                       
                        JavaPairRDD<MatrixIndexes, MatrixBlock> in1 = sec
                                
.getBinaryMatrixBlockRDDHandleForVariable(params.get("target"));
                        
+                       Data pattern = ec.getVariable(params.get("pattern"));
+                       if( pattern == null ) //no such variable: pattern is a literal scalar value
+                               pattern = 
ScalarObjectFactory.createScalarObject(ValueType.FP64, params.get("pattern"));
+                       
+                       boolean ret = false;
+                       if( pattern.getDataType().isScalar() ) {
+                               double dpattern = 
((ScalarObject)pattern).getDoubleValue();
+                               ret = in1.values() //num blocks containing 
pattern
+                                       .map(new RDDContainsFunction(dpattern))
+                                       .reduce((a,b) -> a+b) > 0;
+                       }
+                       else { //row-vector pattern, broadcast and sliced per column block
+                               PartitionedBroadcast<MatrixBlock> bc = 
sec.getBroadcastForVariable(params.get("pattern"));
+                               DataCharacteristics dc = 
sec.getDataCharacteristics(params.get("target"));
+                               ret = in1.flatMapToPair(new 
RDDContainsVectFunction(bc, dc.getBlocksize())) //(row key, match count) pairs
+                                       .reduceByKey((a,b) -> a+b) //number of matching column blocks per row
+                                       .values().reduce((a,b) -> 
Math.max(a,b)) == dc.getNumColBlocks(); //true iff some row matches in all column blocks
+                       }
+                       
                        // execute contains operation 
-                       double pattern = 
Double.parseDouble(params.get("pattern"));
-                       Double ret = in1.values() //num blocks containing 
pattern
-                               .map(new RDDContainsFunction(pattern))
-                               .reduce((a,b) -> a+b);
-                       ec.setScalarOutput(output.getName(), new 
BooleanObject(ret>0));
+                       ec.setScalarOutput(output.getName(), new 
BooleanObject(ret));
                }
                else if(opcode.equalsIgnoreCase("replace")) {
                        if(sec.isFrameObject(params.get("target"))){
@@ -688,6 +708,31 @@ public class ParameterizedBuiltinSPInstruction extends 
ComputationSPInstruction
                        return arg0.containsValue(_pattern) ? 1d : 0d;
                }
        }
+       
+       public static class RDDContainsVectFunction implements 
PairFlatMapFunction<Tuple2<MatrixIndexes,MatrixBlock>, Long, Integer>
+       { //emits (row key, 1) for every block row that matches its column-block slice of the pattern
+               private static final long serialVersionUID = 
2228503788469700742L;
+               private final PartitionedBroadcast<MatrixBlock> _pbcPattern;
+               private final int _blocksize;
+               
+               public 
RDDContainsVectFunction(PartitionedBroadcast<MatrixBlock> bc, int blocksize) {
+                       _pbcPattern = bc;
+                       _blocksize = blocksize;
+               }
+
+               @Override
+               public Iterator<Tuple2<Long, Integer>> 
call(Tuple2<MatrixIndexes, MatrixBlock> input) throws Exception {
+                       MatrixIndexes ix = input._1();
+                       MatrixBlock pattern = _pbcPattern.getBlock(1, 
(int)ix.getColumnIndex()); //pattern slice aligned with this column block
+                       List<Integer> tmp = input._2().containsVector(pattern, 
false); //all matching rows, no early abort
+                       
+                       List<Tuple2<Long,Integer>> ret = new ArrayList<>(); 
+                       ret.add(new Tuple2<>(ix.getRowIndex()*_blocksize, 0)); 
//dummy entry with count 0 so the RDD is never empty (reduce fails on empty RDDs)
+                       for(int rix : tmp)
+                               ret.add(new 
Tuple2<>(UtilFunctions.computeCellIndex(ix.getRowIndex(), _blocksize, rix), 1));
+                       return ret.iterator();
+               }
+       }

        public static class RDDFrameReplaceFunction implements 
Function<FrameBlock, FrameBlock>{
                private static final long serialVersionUID = 
6576713401901671660L;
diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java 
b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
index 2e821740f4..00f8bcbbfd 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
@@ -760,6 +760,44 @@ public class MatrixBlock extends MatrixValue implements CacheBlock<MatrixBlock>,
                        getDenseBlock().contains(pattern);
        }
        
+       /**
+        * Finds the rows of this block that exactly match the given row vector.
+        * 
+        * @param pattern    row vector (1 x clen) with the same number of columns as this block
+        * @param earlyAbort if true, return as soon as the first matching row is found
+        * @return 0-based indexes of the matching rows (empty if none match)
+        * @throws DMLRuntimeException if pattern is not a row vector of matching width
+        */
+       public List<Integer> containsVector(MatrixBlock pattern, boolean earlyAbort) {
+               //note: in contrast to containsValue, we return the row indexes where a match
+               //was found in order to reuse these block operations for Spark ops as well
+               
+               //basic error handling
+               if( clen != pattern.clen || pattern.rlen != 1 )
+                       throw new DMLRuntimeException("contains only supports pattern row vectors of matching "
+                               + "number of columns: " + getDataCharacteristics()+" vs "+pattern.getDataCharacteristics());
+               
+               //make a pass over the data to determine if it includes the
+               //pattern, with early abort as soon as the pattern is found
+               double[] dpattern = DataConverter.convertToDoubleVector(pattern, false, false);
+               
+               //empty blocks might not have an allocated sparse/dense block; all their
+               //rows are zero, so every row matches iff the pattern is all-zero
+               if( isEmptyBlock(false) ) {
+                       List<Integer> ret = new java.util.ArrayList<>();
+                       boolean zeroPattern = true;
+                       for( double v : dpattern )
+                               zeroPattern &= (v == 0);
+                       int nmatch = !zeroPattern ? 0 : earlyAbort ? Math.min(rlen, 1) : rlen;
+                       for( int i=0; i<nmatch; i++ )
+                               ret.add(i);
+                       return ret;
+               }
+               return isInSparseFormat() ?
+                       getSparseBlock().contains(dpattern, earlyAbort) :
+                       getDenseBlock().contains(dpattern, earlyAbort);
+       }
+       
        /**
         * <p>Append value is only used when values are appended at the end of each row for the sparse representation</p>
         * 
diff --git 
a/src/test/java/org/apache/sysds/test/functions/aggregate/ContainsTest.java 
b/src/test/java/org/apache/sysds/test/functions/aggregate/ContainsTest.java
index 4ea6d917d8..ee217bbec9 100644
--- a/src/test/java/org/apache/sysds/test/functions/aggregate/ContainsTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/aggregate/ContainsTest.java
@@ -30,8 +30,9 @@ import org.apache.sysds.utils.Statistics;
 
 public class ContainsTest extends AutomatedTestBase 
 {
-       private final static String TEST_NAME = "Contains";
-
+       private final static String TEST_NAME1 = "ContainsVal"; //scalar pattern script
+       private final static String TEST_NAME2 = "ContainsVect"; //row-vector pattern script
+       
        private final static String TEST_DIR = "functions/aggregate/";
        private static final String TEST_CLASS_DIR = TEST_DIR + 
AggregateInfTest.class.getSimpleName() + "/";
        
@@ -42,8 +43,10 @@ public class ContainsTest extends AutomatedTestBase
        
        @Override
        public void setUp() {
-               addTestConfiguration(TEST_NAME,
-                       new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new 
String[]{"B"})); 
+               addTestConfiguration(TEST_NAME1, //scalar pattern tests
+                       new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new 
String[]{"B"}));
+               addTestConfiguration(TEST_NAME2, //row-vector pattern tests
+                       new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new 
String[]{"B"}));
        }

        
@@ -107,6 +110,86 @@ public class ContainsTest extends AutomatedTestBase
                runContainsTest(Double.NaN, false, true, ExecType.SPARK);
        }
        
+       @Test
+       public void testVectTrueDenseDenseCP() {
+               runContainsVectorTest(true, false, false, ExecType.CP); //expected, sparse A, sparse b, exec type
+       }
+       
+       @Test
+       public void testVectFalseDenseDenseCP() {
+               runContainsVectorTest(false, false, false, ExecType.CP);
+       }
+       
+       @Test
+       public void testVectTrueDenseSparseCP() {
+               runContainsVectorTest(true, false, true, ExecType.CP);
+       }
+       
+       @Test
+       public void testVectFalseDenseSparseCP() {
+               runContainsVectorTest(false, false, true, ExecType.CP);
+       }
+       
+       @Test
+       public void testVectTrueSparseDenseCP() {
+               runContainsVectorTest(true, true, false, ExecType.CP);
+       }
+       
+       @Test
+       public void testVectFalseSparseDenseCP() {
+               runContainsVectorTest(false, true, false, ExecType.CP);
+       }
+       
+       @Test
+       public void testVectTrueSparseSparseCP() {
+               runContainsVectorTest(true, true, true, ExecType.CP);
+       }
+       
+       @Test
+       public void testVectFalseSparseSparseCP() {
+               runContainsVectorTest(false, true, true, ExecType.CP);
+       }
+       
+       @Test
+       public void testVectTrueDenseDenseSpark() {
+               runContainsVectorTest(true, false, false, ExecType.SPARK);
+       }
+       
+       @Test
+       public void testVectFalseDenseDenseSpark() {
+               runContainsVectorTest(false, false, false, ExecType.SPARK);
+       }
+       
+       @Test
+       public void testVectTrueDenseSparseSpark() {
+               runContainsVectorTest(true, false, true, ExecType.SPARK);
+       }
+       
+       @Test
+       public void testVectFalseDenseSparseSpark() {
+               runContainsVectorTest(false, false, true, ExecType.SPARK);
+       }
+       
+       @Test
+       public void testVectTrueSparseDenseSpark() {
+               runContainsVectorTest(true, true, false, ExecType.SPARK);
+       }
+       
+       @Test
+       public void testVectFalseSparseDenseSpark() {
+               runContainsVectorTest(false, true, false, ExecType.SPARK);
+       }
+       
+       @Test
+       public void testVectTrueSparseSparseSpark() {
+               runContainsVectorTest(true, true, true, ExecType.SPARK);
+       }
+       
+       @Test
+       public void testVectFalseSparseSparseSpark() {
+               runContainsVectorTest(false, true, true, ExecType.SPARK);
+       }
+       
        private void runContainsTest( double check, boolean expected, boolean sparse, ExecType instType)
        {
                ExecMode oldMode = setExecMode(instType);
@@ -114,10 +197,10 @@ public class ContainsTest extends AutomatedTestBase
                try
                {
                        double sparsity = (sparse) ? sparsity1 : sparsity2;
-                       getAndLoadTestConfiguration(TEST_NAME);
+                       getAndLoadTestConfiguration(TEST_NAME1); //scalar pattern configuration
                        
                        String HOME = SCRIPT_DIR + TEST_DIR;
-                       fullDMLScriptName = HOME + TEST_NAME + ".dml";
+                       fullDMLScriptName = HOME + TEST_NAME1 + ".dml"; //ContainsVal.dml
                        programArgs = new String[]{"-args",
                                input("A"), String.valueOf(check), output("B") 
};
                        
@@ -139,4 +222,39 @@ public class ContainsTest extends AutomatedTestBase
                        resetExecMode(oldMode);
                }
        }
+       
+       //runs ContainsVect.dml, which copies the probe vector b into a row of A
+       //iff expected, and compares the result of contains(target=A, pattern=b)
+       private void runContainsVectorTest( boolean expected, boolean sparse1, boolean sparse2, ExecType instType)
+       {
+               ExecMode oldMode = setExecMode(instType);
+       
+               try
+               {
+                       double sparsity = (sparse1) ? sparsity1 : sparsity2;
+                       getAndLoadTestConfiguration(TEST_NAME2);
+                       
+                       String HOME = SCRIPT_DIR + TEST_DIR;
+                       fullDMLScriptName = HOME + TEST_NAME2 + ".dml";
+                       programArgs = new String[]{"-args", input("A"),
+                               String.valueOf(expected).toUpperCase(), String.valueOf(sparse2).toUpperCase(), output("B") };
+                       
+                       //generate actual dataset 
+                       double[][] A = getRandomMatrix(rows, cols, -0.05, 1, sparsity, 7); 
+                       writeInputMatrixWithMTD("A", A, false);
+       
+                       //run test
+                       runTest(true, false, null, -1); 
+                       boolean ret = TestUtils.readDMLBoolean(output("B"));
+                       Assert.assertEquals(expected, ret);
+                       if( instType == ExecType.CP ) {
+                               //JUnit argument order is (expected, actual)
+                               Assert.assertEquals(1, Statistics.getNoOfCompiledSPInst()); //reblock
+                               Assert.assertEquals(0, Statistics.getNoOfExecutedSPInst());
+                       }
+               }
+               finally {
+                       resetExecMode(oldMode);
+               }
+       }
 }
diff --git a/src/test/scripts/functions/aggregate/Contains.dml 
b/src/test/scripts/functions/aggregate/ContainsVal.dml
similarity index 100%
copy from src/test/scripts/functions/aggregate/Contains.dml
copy to src/test/scripts/functions/aggregate/ContainsVal.dml
diff --git a/src/test/scripts/functions/aggregate/Contains.dml 
b/src/test/scripts/functions/aggregate/ContainsVect.dml
similarity index 81%
rename from src/test/scripts/functions/aggregate/Contains.dml
rename to src/test/scripts/functions/aggregate/ContainsVect.dml
index 0576b6e1cd..ebd929aec5 100644
--- a/src/test/scripts/functions/aggregate/Contains.dml
+++ b/src/test/scripts/functions/aggregate/ContainsVect.dml
@@ -20,5 +20,14 @@
 #-------------------------------------------------------------
 
 A = read($1);
-ret = contains(target=A, pattern=$2);
-write(ret, $3);
\ No newline at end of file
+expected = $2;
+
+# generate probe vector b (sparse iff $3), copied into row nrow(A)-100 of A iff expected
+sp2 = ifelse($3, 0.01, 0.9);
+b = round(rand(rows=1, cols=ncol(A), sparsity=sp2));
+if( expected )
+  A[nrow(A)-100, ] = b;
+
+ret = contains(target=A, pattern=b);
+write(ret, $4);
+

Reply via email to