This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 9446bf6362 [SYSTEMDS-3500] Fix perftest regression via new 
contains-value function
9446bf6362 is described below

commit 9446bf6362cd46cc4049bcab361f7cc6388809b2
Author: Matthias Boehm <[email protected]>
AuthorDate: Thu Feb 23 22:37:19 2023 +0100

    [SYSTEMDS-3500] Fix perftest regression via new contains-value function
    
    A while ago the multiLogReg script was extended with robustness checks for
    NaN inputs. In the perftest MLogReg 1M_1K_dense (8GB), this led to a
    performance regression due to unnecessary Spark operations with a 20GB
    driver, because input and output (16GB) exceed the 70% memory budget.
    Given that sum(isNaN(X)) > 0 is likely false, we now expose the already
    existing block operation contains(X, pattern), which has only half the
    memory requirements.
    We added the CP, SPARK, and FED instructions as well as related tests.
---
 scripts/builtin/multiLogReg.dml                    |   6 +-
 scripts/builtin/multiLogRegPredict.dml             |   6 +-
 .../java/org/apache/sysds/common/Builtins.java     |   1 +
 src/main/java/org/apache/sysds/common/Types.java   |   2 +-
 .../apache/sysds/hops/ParameterizedBuiltinOp.java  |   3 +-
 .../apache/sysds/lops/ParameterizedBuiltin.java    |  21 +--
 .../org/apache/sysds/parser/DMLTranslator.java     |   1 +
 .../ParameterizedBuiltinFunctionExpression.java    |  45 ++++---
 .../controlprogram/federated/FederationUtils.java  |  14 +-
 .../runtime/instructions/CPInstructionParser.java  |   1 +
 .../runtime/instructions/SPInstructionParser.java  |   3 +-
 .../cp/ParameterizedBuiltinCPInstruction.java      |  18 ++-
 .../fed/ParameterizedBuiltinFEDInstruction.java    |  35 +++--
 .../spark/ParameterizedBuiltinSPInstruction.java   |  30 ++++-
 .../test/functions/aggregate/ContainsTest.java     | 142 +++++++++++++++++++++
 .../federated/algorithms/FederatedLogRegTest.java  |   2 +-
 src/test/scripts/functions/aggregate/Contains.dml  |  24 ++++
 17 files changed, 294 insertions(+), 60 deletions(-)

diff --git a/scripts/builtin/multiLogReg.dml b/scripts/builtin/multiLogReg.dml
index 9b7d7da79e..528931ad8e 100644
--- a/scripts/builtin/multiLogReg.dml
+++ b/scripts/builtin/multiLogReg.dml
@@ -59,10 +59,10 @@ m_multiLogReg = function(Matrix[Double] X, Matrix[Double] 
Y, Int icpt = 2,
   D = ncol (X);
 
   # Robustness for datasets with missing values (causing NaN gradients)
-  numNaNs = sum(isNaN(X))
-  if( numNaNs > 0 ) {
+  hasNaNs = contains(target=X, pattern=NaN);
+  if( hasNaNs > 0 ) {
     if(verbose)
-      print("multiLogReg: matrix X contains "+numNaNs+" missing values, 
replacing with 0.")
+      print("multiLogReg: matrix X contains "+sum(isNaN(X))+" missing values, 
replacing with 0.")
     X = replace(target=X, pattern=NaN, replacement=0);
   }
 
diff --git a/scripts/builtin/multiLogRegPredict.dml 
b/scripts/builtin/multiLogRegPredict.dml
index dc5c0332ab..16bf08316a 100644
--- a/scripts/builtin/multiLogRegPredict.dml
+++ b/scripts/builtin/multiLogRegPredict.dml
@@ -49,9 +49,9 @@ m_multiLogRegPredict = function(Matrix[Double] X, 
Matrix[Double] B, Matrix[Doubl
     stop("multiLogRegPredict: mismatching ncol(X) and nrow(B): "+ncol(X)+" 
"+nrow(B));
   
   # Robustness for datasets with missing values (causing NaN probabilities)
-  numNaNs = sum(isNaN(X))
-  if( numNaNs > 0 ) {
-    print("multiLogRegPredict: matrix X contains "+numNaNs+" missing values, 
replacing with 0.")
+  hasNaNs = contains(target=X, pattern=NaN);
+  if( hasNaNs > 0 ) {
+    print("multiLogRegPredict: matrix X contains "+sum(isNaN(X))+" missing 
values, replacing with 0.")
     X = replace(target=X, pattern=NaN, replacement=0);
   }
   accuracy = 0.0 # initialize variable 
diff --git a/src/main/java/org/apache/sysds/common/Builtins.java 
b/src/main/java/org/apache/sysds/common/Builtins.java
index f7cbb972df..e627adb286 100644
--- a/src/main/java/org/apache/sysds/common/Builtins.java
+++ b/src/main/java/org/apache/sysds/common/Builtins.java
@@ -310,6 +310,7 @@ public enum Builtins {
        // Parameterized functions with parameters
        AUTODIFF("autoDiff", false, true),
        CDF("cdf", false, true),
+       CONTAINS("contains", false, true),
        COUNT_DISTINCT("countDistinct",false, true),
        COUNT_DISTINCT_APPROX("countDistinctApprox", false, true),
        COUNT_DISTINCT_APPROX_ROW("rowCountDistinctApprox", false, true),
diff --git a/src/main/java/org/apache/sysds/common/Types.java 
b/src/main/java/org/apache/sysds/common/Types.java
index ab81ff4e31..49221cee89 100644
--- a/src/main/java/org/apache/sysds/common/Types.java
+++ b/src/main/java/org/apache/sysds/common/Types.java
@@ -569,7 +569,7 @@ public class Types
        }
        
        public enum ParamBuiltinOp {
-               AUTODIFF, INVALID, CDF, INVCDF, GROUPEDAGG, RMEMPTY, REPLACE, 
REXPAND,
+               AUTODIFF, CDF, CONTAINS, INVALID, INVCDF, GROUPEDAGG, RMEMPTY, 
REPLACE, REXPAND,
                LOWER_TRI, UPPER_TRI,
                TRANSFORMAPPLY, TRANSFORMDECODE, TRANSFORMCOLMAP, TRANSFORMMETA,
                TOKENIZE, TOSTRING, LIST, PARAMSERV
diff --git a/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java 
b/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java
index 55e6d79c7b..4404579894 100644
--- a/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java
+++ b/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java
@@ -181,7 +181,8 @@ public class ParameterizedBuiltinOp extends 
MultiThreadedHop {
                        case REXPAND: {
                                constructLopsRExpand(inputlops, et);
                                break;
-                       } 
+                       }
+                       case CONTAINS:
                        case CDF:
                        case INVCDF: 
                        case REPLACE:
diff --git a/src/main/java/org/apache/sysds/lops/ParameterizedBuiltin.java 
b/src/main/java/org/apache/sysds/lops/ParameterizedBuiltin.java
index a0f9331adf..eb8174dbca 100644
--- a/src/main/java/org/apache/sysds/lops/ParameterizedBuiltin.java
+++ b/src/main/java/org/apache/sysds/lops/ParameterizedBuiltin.java
@@ -131,13 +131,6 @@ public class ParameterizedBuiltin extends Lop
                                
                                break;
                        
-                       case REPLACE: {
-                               sb.append( "replace" );
-                               sb.append( OPERAND_DELIMITOR );
-                               sb.append(compileGenericParamMap(_inputParams));
-                               break;
-                       }
-                       
                        case LOWER_TRI: {
                                sb.append( "lowertri" );
                                sb.append( OPERAND_DELIMITOR );
@@ -174,11 +167,14 @@ public class ParameterizedBuiltin extends Lop
 
                                break;
 
+                       case CONTAINS:
+                       case REPLACE:
                        case TOKENIZE:
                        case TRANSFORMAPPLY:
                        case TRANSFORMDECODE:
                        case TRANSFORMCOLMAP:
-                       case TRANSFORMMETA:{ 
+                       case TRANSFORMMETA:
+                       case PARAMSERV: { 
                                sb.append(_operation.name().toLowerCase()); 
//opcode
                                sb.append(OPERAND_DELIMITOR);
                                sb.append(compileGenericParamMap(_inputParams));
@@ -202,14 +198,7 @@ public class ParameterizedBuiltin extends Lop
                                sb.append(compileGenericParamMap(_inputParams));
                                break;
                        }
-
-                       case PARAMSERV: {
-                               sb.append("paramserv");
-                               sb.append(OPERAND_DELIMITOR);
-                               sb.append(compileGenericParamMap(_inputParams));
-                               break;
-                       }
-                               
+                       
                        default:
                                throw new 
LopsException(this.printErrorLocation() + "In ParameterizedBuiltin Lop, Unknown 
operation: " + _operation);
                }
diff --git a/src/main/java/org/apache/sysds/parser/DMLTranslator.java 
b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
index c6ed2c5b84..98eaf0bbfb 100644
--- a/src/main/java/org/apache/sysds/parser/DMLTranslator.java
+++ b/src/main/java/org/apache/sysds/parser/DMLTranslator.java
@@ -2007,6 +2007,7 @@ public class DMLTranslator
                                        target.getValueType(), 
source.getOpCode(), paramHops);
                                break;
                        
+                       case CONTAINS:
                        case GROUPEDAGG:
                        case RMEMPTY:
                        case REPLACE:
diff --git 
a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
 
b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
index 293ca7312e..1d30d13fea 100644
--- 
a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
+++ 
b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
@@ -202,6 +202,10 @@ public class ParameterizedBuiltinFunctionExpression 
extends DataIdentifier
                        validateReplace(output, conditional);
                        break;
                
+               case CONTAINS:
+                       validateContains(output, conditional);
+                       break;
+               
                case ORDER:
                        validateOrder(output, conditional);
                        break;
@@ -725,28 +729,24 @@ public class ParameterizedBuiltinFunctionExpression 
extends DataIdentifier
                output.setDimensions(in.getDim1(), in.getDim2());
        }
        
+       private void validateContains(DataIdentifier output, boolean 
conditional) {
+               //check existence and correctness of arguments
+               Expression target = getVarParam("target");
+               checkTargetParam(target, conditional);
+               checkScalarParam("contains", "pattern", conditional);
+               
+               //set boolean scalar 
+               output.setBooleanProperties();
+       }
+       
        private void validateReplace(DataIdentifier output, boolean 
conditional) {
                //check existence and correctness of arguments
                Expression target = getVarParam("target");
                if( target.getOutput().getDataType() != DataType.FRAME ){
                        checkTargetParam(target, conditional);
                }
-               
-               Expression pattern = getVarParam("pattern");
-               if( pattern==null ) {
-                       raiseValidateError("Named parameter 'pattern' missing. 
Please specify the replacement pattern.", conditional, 
LanguageErrorCodes.INVALID_PARAMETERS);
-               }
-               else if( pattern.getOutput().getDataType() != DataType.SCALAR 
){                                
-                       raiseValidateError("Replacement pattern 'pattern' is of 
type '"+pattern.getOutput().getDataType()+"'. Please, specify a scalar 
replacement pattern.", conditional, LanguageErrorCodes.INVALID_PARAMETERS);
-               }       
-               
-               Expression replacement = getVarParam("replacement");
-               if( replacement==null ) {
-                       raiseValidateError("Named parameter 'replacement' 
missing. Please specify the replacement value.", conditional, 
LanguageErrorCodes.INVALID_PARAMETERS);
-               }
-               else if( replacement.getOutput().getDataType() != 
DataType.SCALAR ){    
-                       raiseValidateError("Replacement value 'replacement' is 
of type '"+replacement.getOutput().getDataType()+"'. Please, specify a scalar 
replacement value.", conditional, LanguageErrorCodes.INVALID_PARAMETERS);
-               }       
+               checkScalarParam("replace", "pattern", conditional);
+               checkScalarParam("replace", "replacement", conditional);
                
                // Output is a matrix with same dims as input
                output.setDataType(target.getOutput().getDataType());
@@ -756,6 +756,19 @@ public class ParameterizedBuiltinFunctionExpression 
extends DataIdentifier
                        output.setValueType(ValueType.FP64);
                output.setDimensions(target.getOutput().getDim1(), 
target.getOutput().getDim2());
        }
+       
+       private void checkScalarParam(String group, String param, boolean 
conditional) {
+               Expression eparam = getVarParam(param);
+               if( eparam==null ) {
+                       raiseValidateError("Named parameter '"+param+"' 
missing. Please specify the "+group+" pattern.",
+                               conditional, 
LanguageErrorCodes.INVALID_PARAMETERS);
+               }
+               else if( eparam.getOutput().getDataType() != DataType.SCALAR ){
+                       raiseValidateError(group + " parameter '"+param+"' is 
of type '"
+                               + eparam.getOutput().getDataType()+"'. Please, 
specify a scalar "+param+".",
+                               conditional, 
LanguageErrorCodes.INVALID_PARAMETERS);
+               }
+       }
 
        private void validateOrder(DataIdentifier output, boolean conditional) {
                //check existence and correctness of arguments
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
index 73939117ce..cabf4887a6 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
@@ -487,7 +487,19 @@ public class FederationUtils {
                        throw new DMLRuntimeException(ex);
                }
        }
-
+       
+       public static boolean aggBooleanScalar(Future<FederatedResponse>[] tmp) 
{
+               boolean ret = false;
+               try {
+                       for( Future<FederatedResponse> fr : tmp )
+                               ret |= 
((ScalarObject)fr.get().getData()[0]).getBooleanValue();
+               }
+               catch (Exception e) {
+                       throw new DMLRuntimeException(e);
+               }
+               return ret;
+       }
+       
        public static MatrixBlock aggMatrix(AggregateUnaryOperator aop, 
Future<FederatedResponse>[] ffr, FederationMap map) {
                if (aop.isRowAggregate() && map.getType() == FType.ROW)
                        return bind(ffr, false);
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java 
b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
index 1dc7b068b8..07ce7d620b 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
@@ -222,6 +222,7 @@ public class CPInstructionParser extends InstructionParser {
                
                // Parameterized Builtin Functions
                String2CPInstructionType.put( "autoDiff" ,      
CPType.ParameterizedBuiltin);
+               String2CPInstructionType.put( "contains",       
CPType.ParameterizedBuiltin);
                String2CPInstructionType.put("paramserv",       
CPType.ParameterizedBuiltin);
                String2CPInstructionType.put( "nvlist",         
CPType.ParameterizedBuiltin);
                String2CPInstructionType.put( "cdf",            
CPType.ParameterizedBuiltin);
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java 
b/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java
index 773153d6d4..06e68a63d5 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/SPInstructionParser.java
@@ -275,7 +275,8 @@ public class SPInstructionParser extends InstructionParser
                String2SPInstructionType.put( "isinf", SPType.Unary);
 
                // Parameterized Builtin Functions
-               String2SPInstructionType.put( "autoDiff"   , 
SPType.ParameterizedBuiltin);
+               String2SPInstructionType.put( "autoDiff",       
SPType.ParameterizedBuiltin);
+               String2SPInstructionType.put( "contains",       
SPType.ParameterizedBuiltin);
                String2SPInstructionType.put( "groupedagg",     
SPType.ParameterizedBuiltin);
                String2SPInstructionType.put( "mapgroupedagg",  
SPType.ParameterizedBuiltin);
                String2SPInstructionType.put( "rmempty",        
SPType.ParameterizedBuiltin);
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
index baf4f25139..d3c88fd5ff 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
@@ -138,13 +138,14 @@ public class ParameterizedBuiltinCPInstruction extends 
ComputationCPInstruction
                }
                else if(opcode.equalsIgnoreCase("rmempty") || 
opcode.equalsIgnoreCase("replace") ||
                        opcode.equalsIgnoreCase("rexpand") || 
opcode.equalsIgnoreCase("lowertri") ||
-                       opcode.equalsIgnoreCase("uppertri")) {
+                       opcode.equalsIgnoreCase("uppertri") ) {
                        func = 
ParameterizedBuiltin.getParameterizedBuiltinFnObject(opcode);
                        return new ParameterizedBuiltinCPInstruction(new 
SimpleOperator(func), paramsMap, out, opcode, str);
                }
-               else if(opcode.equals("transformapply") || 
opcode.equals("transformdecode") ||
-                       opcode.equals("transformcolmap") || 
opcode.equals("transformmeta") || opcode.equals("tokenize") ||
-                       opcode.equals("toString") || opcode.equals("nvlist") || 
opcode.equals("autoDiff")) {
+               else if(opcode.equals("transformapply") || 
opcode.equals("transformdecode")
+                       || opcode.equalsIgnoreCase("contains") || 
opcode.equals("transformcolmap")
+                       || opcode.equals("transformmeta") || 
opcode.equals("tokenize")
+                       || opcode.equals("toString") || opcode.equals("nvlist") 
|| opcode.equals("autoDiff")) {
                        return new ParameterizedBuiltinCPInstruction(null, 
paramsMap, out, opcode, str);
                }
                else if("paramserv".equals(opcode)) {
@@ -235,6 +236,14 @@ public class ParameterizedBuiltinCPInstruction extends 
ComputationCPInstruction
                                        
ec.releaseMatrixInput(params.get("select"));
                        }
                }
+               else if(opcode.equalsIgnoreCase("contains")) {
+                       String varName = params.get("target");
+                       MatrixBlock target = ec.getMatrixInput(varName);
+                       double pattern = 
Double.parseDouble(params.get("pattern"));
+                       boolean ret = target.containsValue(pattern);
+                       ec.releaseMatrixInput(varName);
+                       ec.setScalarOutput(output.getName(), new 
BooleanObject(ret));
+               }
                else if(opcode.equalsIgnoreCase("replace")) {
                        if(ec.isFrameObject(params.get("target"))){
                                FrameBlock target = 
ec.getFrameInput(params.get("target"));
@@ -255,7 +264,6 @@ public class ParameterizedBuiltinCPInstruction extends 
ComputationCPInstruction
                                        ec.setMatrixOutput(output.getName(), 
ret);
                                targetObj.release();
                        }
-                       
                }
                else if(opcode.equals("lowertri") || opcode.equals("uppertri")) 
{
                        MatrixBlock target = 
ec.getMatrixInput(params.get("target"));
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
index 12f2e597ef..7654b92ecc 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
@@ -58,9 +58,11 @@ import 
org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
 import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
 import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
 import org.apache.sysds.runtime.controlprogram.federated.MatrixLineagePair;
+import 
org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
 import org.apache.sysds.runtime.functionobjects.ParameterizedBuiltin;
 import org.apache.sysds.runtime.functionobjects.ValueFunction;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.BooleanObject;
 import org.apache.sysds.runtime.instructions.cp.CPOperand;
 import org.apache.sysds.runtime.instructions.cp.Data;
 import 
org.apache.sysds.runtime.instructions.cp.ParameterizedBuiltinCPInstruction;
@@ -85,8 +87,8 @@ public class ParameterizedBuiltinFEDInstruction extends 
ComputationFEDInstructio
        protected final HashMap<String, String> params;
 
        private static final String[] PARAM_BUILTINS = new String[]{
-               "replace", "rmempty", "lowertri", "uppertri", 
"transformdecode", "transformapply", "tokenize"};
-
+               "contains", "replace", "rmempty", "lowertri", "uppertri",
+               "transformdecode", "transformapply", "tokenize"};
 
        protected ParameterizedBuiltinFEDInstruction(Operator op, 
HashMap<String, String> paramsMap, CPOperand out,
                String opcode, String istr) {
@@ -110,7 +112,8 @@ public class ParameterizedBuiltinFEDInstruction extends 
ComputationFEDInstructio
                        ValueFunction func = 
ParameterizedBuiltin.getParameterizedBuiltinFnObject(opcode);
                        return new ParameterizedBuiltinFEDInstruction(new 
SimpleOperator(func), paramsMap, out, opcode, str);
                }
-               else if(opcode.equals("transformapply") || 
opcode.equals("transformdecode") || opcode.equals("tokenize")) {
+               else if(opcode.equals("transformapply") || 
opcode.equals("transformdecode")
+                       || opcode.equals("tokenize") || 
opcode.equals("contains") ) {
                        return new ParameterizedBuiltinFEDInstruction(null, 
paramsMap, out, opcode, str);
                }
                else {
@@ -140,15 +143,17 @@ public class ParameterizedBuiltinFEDInstruction extends 
ComputationFEDInstructio
                return paramMap;
        }
 
-       public static ParameterizedBuiltinFEDInstruction 
parseInstruction(ParameterizedBuiltinCPInstruction inst,
-               ExecutionContext ec) {
+       public static ParameterizedBuiltinFEDInstruction parseInstruction(
+               ParameterizedBuiltinCPInstruction inst, ExecutionContext ec)
+       {
                if(ArrayUtils.contains(PARAM_BUILTINS, inst.getOpcode()) && 
inst.getTarget(ec).isFederatedExcept(FType.BROADCAST))
                        return 
ParameterizedBuiltinFEDInstruction.parseInstruction(inst);
                return null;
        }
 
-       public static ParameterizedBuiltinFEDInstruction 
parseInstruction(ParameterizedBuiltinSPInstruction inst,
-               ExecutionContext ec) {
+       public static ParameterizedBuiltinFEDInstruction parseInstruction(
+               ParameterizedBuiltinSPInstruction inst, ExecutionContext ec)
+       {
                if( inst.getOpcode().equalsIgnoreCase("replace") && 
inst.getTarget(ec).isFederatedExcept(FType.BROADCAST) )
                        return 
ParameterizedBuiltinFEDInstruction.parseInstruction(inst);
                return null;
@@ -167,13 +172,21 @@ public class ParameterizedBuiltinFEDInstruction extends 
ComputationFEDInstructio
        @Override
        public void processInstruction(ExecutionContext ec) {
                String opcode = getOpcode();
-               if(opcode.equalsIgnoreCase("replace")) {
+               if(opcode.equalsIgnoreCase("contains")) {
+                       FederationMap map = getTarget(ec).getFedMapping();
+                       FederatedRequest fr1 = 
FederationUtils.callInstruction(instString,
+                               output, new CPOperand[] {getTargetOperand()}, 
new long[] {map.getID()});
+                       FederatedRequest fr2 = new 
FederatedRequest(RequestType.GET_VAR, fr1.getID());
+                       Future<FederatedResponse>[] tmp = map.execute(getTID(), 
fr1, fr2);
+                       boolean ret = FederationUtils.aggBooleanScalar(tmp);
+                       ec.setVariable(output.getName(), new 
BooleanObject(ret));
+               }
+               else if(opcode.equalsIgnoreCase("replace")) {
                        // similar to unary federated instructions, get 
federated input
                        // execute instruction, and derive federated output 
matrix
                        CacheableData<?> mo = getTarget(ec);
-                       FederatedRequest fr1 = 
FederationUtils.callInstruction(instString,
-                               output,
-                               new CPOperand[] {getTargetOperand()},
+                       FederatedRequest fr1 = FederationUtils.callInstruction(
+                               instString, output, new CPOperand[] 
{getTargetOperand()},
                                new long[] {mo.getFedMapping().getID()});
                        Future<FederatedResponse>[] ret = 
mo.getFedMapping().execute(getTID(), true, fr1);
 
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
index e5b8fea07a..cc3ce6d03f 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
@@ -47,6 +47,7 @@ import org.apache.sysds.runtime.functionobjects.KahanPlus;
 import org.apache.sysds.runtime.functionobjects.ParameterizedBuiltin;
 import org.apache.sysds.runtime.functionobjects.ValueFunction;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.BooleanObject;
 import org.apache.sysds.runtime.instructions.cp.CPOperand;
 import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
 import org.apache.sysds.runtime.instructions.spark.data.LazyIterableIterator;
@@ -169,6 +170,9 @@ public class ParameterizedBuiltinSPInstruction extends 
ComputationSPInstruction
                                func = 
ParameterizedBuiltin.getParameterizedBuiltinFnObject(opcode);
                                return new 
ParameterizedBuiltinSPInstruction(new SimpleOperator(func), paramsMap, out, 
opcode, str);
                        }
+                       else if(opcode.equalsIgnoreCase("contains")) {
+                               return new 
ParameterizedBuiltinSPInstruction(null, paramsMap, out, opcode, str);
+                       }
                        else {
                                throw new DMLRuntimeException("Unknown opcode 
(" + opcode + ") for ParameterizedBuiltin Instruction.");
                        }
@@ -363,6 +367,17 @@ public class ParameterizedBuiltinSPInstruction extends 
ComputationSPInstruction
                                sec.setMatrixOutput(output.getName(), out);
                        }
                }
+               else if(opcode.equalsIgnoreCase("contains")) {
+                       JavaPairRDD<MatrixIndexes, MatrixBlock> in1 = sec
+                               
.getBinaryMatrixBlockRDDHandleForVariable(params.get("target"));
+                       
+                       // execute contains operation 
+                       double pattern = 
Double.parseDouble(params.get("pattern"));
+                       Double ret = in1.values() //num blocks containing 
pattern
+                               .map(new RDDContainsFunction(pattern))
+                               .reduce((a,b) -> a+b);
+                       ec.setScalarOutput(output.getName(), new 
BooleanObject(ret>0));
+               }
                else if(opcode.equalsIgnoreCase("replace")) {
                        if(sec.isFrameObject(params.get("target"))){
                                params.get("target");
@@ -395,7 +410,6 @@ public class ParameterizedBuiltinSPInstruction extends 
ComputationSPInstruction
                                        mcIn.getBlocksize(),
                                        (pattern != 0 && replacement != 0) ? 
mcIn.getNonZeros() : -1);
                        }
-
                }
                else if(opcode.equalsIgnoreCase("lowertri") || 
opcode.equalsIgnoreCase("uppertri")) {
                        JavaPairRDD<MatrixIndexes, MatrixBlock> in1 = sec
@@ -566,6 +580,20 @@ public class ParameterizedBuiltinSPInstruction extends 
ComputationSPInstruction
                        return arg0.replaceOperations(new MatrixBlock(), 
_pattern, _replacement);
                }
        }
+       
+       public static class RDDContainsFunction implements 
Function<MatrixBlock, Double> {
+               private static final long serialVersionUID = 
6576713401901671659L;
+               private final double _pattern;
+
+               public RDDContainsFunction(double pattern) {
+                       _pattern = pattern;
+               }
+
+               @Override
+               public Double call(MatrixBlock arg0) {
+                       return arg0.containsValue(_pattern) ? 1d : 0d;
+               }
+       }
 
        public static class RDDFrameReplaceFunction implements 
Function<FrameBlock, FrameBlock>{
                private static final long serialVersionUID = 
6576713401901671660L;
diff --git 
a/src/test/java/org/apache/sysds/test/functions/aggregate/ContainsTest.java 
b/src/test/java/org/apache/sysds/test/functions/aggregate/ContainsTest.java
new file mode 100644
index 0000000000..4ea6d917d8
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/aggregate/ContainsTest.java
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.aggregate;
+
+import org.junit.Assert;
+import org.junit.Test;
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.common.Types.ExecType;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.apache.sysds.utils.Statistics;
+
+/**
+ * Tests the {@code contains(target, pattern)} builtin, which returns whether a
+ * matrix holds the given pattern value (here NaN or +Inf), for inputs generated
+ * with low and high sparsity, in CP and Spark execution modes.
+ */
+public class ContainsTest extends AutomatedTestBase
+{
+       private final static String TEST_NAME = "Contains";
+
+       private final static String TEST_DIR = "functions/aggregate/";
+       // use this class' own name so its test directory does not collide with other aggregate tests
+       private static final String TEST_CLASS_DIR = TEST_DIR + ContainsTest.class.getSimpleName() + "/";
+
+       private final static int rows = 1205;
+       private final static int cols = 1179;
+       private final static double sparsity1 = 0.1; // used by the *Sparse* tests
+       private final static double sparsity2 = 0.7; // used by the *Dense* tests
+
+       @Override
+       public void setUp() {
+               addTestConfiguration(TEST_NAME,
+                       new TestConfiguration(TEST_CLASS_DIR, TEST_NAME, new String[]{"B"}));
+       }
+
+       @Test
+       public void testNaNTrueDenseCP() {
+               runContainsTest(Double.NaN, true, false, ExecType.CP);
+       }
+
+       @Test
+       public void testNaNFalseDenseCP() {
+               runContainsTest(Double.NaN, false, false, ExecType.CP);
+       }
+
+       @Test
+       public void testNaNTrueSparseCP() {
+               runContainsTest(Double.NaN, true, true, ExecType.CP);
+       }
+
+       @Test
+       public void testNaNFalseSparseCP() {
+               runContainsTest(Double.NaN, false, true, ExecType.CP);
+       }
+
+       @Test
+       public void testInfTrueDenseCP() {
+               runContainsTest(Double.POSITIVE_INFINITY, true, false, ExecType.CP);
+       }
+
+       @Test
+       public void testInfFalseDenseCP() {
+               runContainsTest(Double.POSITIVE_INFINITY, false, false, ExecType.CP);
+       }
+
+       @Test
+       public void testInfTrueSparseCP() {
+               runContainsTest(Double.POSITIVE_INFINITY, true, true, ExecType.CP);
+       }
+
+       @Test
+       public void testInfFalseSparseCP() {
+               runContainsTest(Double.POSITIVE_INFINITY, false, true, ExecType.CP);
+       }
+
+       @Test
+       public void testNaNTrueDenseSpark() {
+               runContainsTest(Double.NaN, true, false, ExecType.SPARK);
+       }
+
+       @Test
+       public void testNaNFalseDenseSpark() {
+               runContainsTest(Double.NaN, false, false, ExecType.SPARK);
+       }
+
+       @Test
+       public void testNaNTrueSparseSpark() {
+               runContainsTest(Double.NaN, true, true, ExecType.SPARK);
+       }
+
+       @Test
+       public void testNaNFalseSparseSpark() {
+               runContainsTest(Double.NaN, false, true, ExecType.SPARK);
+       }
+
+       @Test
+       public void testInfTrueDenseSpark() {
+               runContainsTest(Double.POSITIVE_INFINITY, true, false, ExecType.SPARK);
+       }
+
+       @Test
+       public void testInfFalseDenseSpark() {
+               runContainsTest(Double.POSITIVE_INFINITY, false, false, ExecType.SPARK);
+       }
+
+       @Test
+       public void testInfTrueSparseSpark() {
+               runContainsTest(Double.POSITIVE_INFINITY, true, true, ExecType.SPARK);
+       }
+
+       @Test
+       public void testInfFalseSparseSpark() {
+               runContainsTest(Double.POSITIVE_INFINITY, false, true, ExecType.SPARK);
+       }
+
+       /**
+        * Runs Contains.dml on a random matrix and checks the boolean result.
+        *
+        * @param check    value passed as the pattern of contains(target=A, pattern=check)
+        * @param expected whether the generated input should contain {@code check}
+        * @param sparse   whether to generate the input with {@code sparsity1} (else {@code sparsity2})
+        * @param instType execution type to force (CP or SPARK)
+        */
+       private void runContainsTest( double check, boolean expected, boolean sparse, ExecType instType)
+       {
+               ExecMode oldMode = setExecMode(instType);
+
+               try
+               {
+                       double sparsity = (sparse) ? sparsity1 : sparsity2;
+                       getAndLoadTestConfiguration(TEST_NAME);
+
+                       String HOME = SCRIPT_DIR + TEST_DIR;
+                       fullDMLScriptName = HOME + TEST_NAME + ".dml";
+                       programArgs = new String[]{"-args",
+                               input("A"), String.valueOf(check), output("B") };
+
+                       //generate actual dataset of finite random values in [-0.05, 1];
+                       //the searched value is only present if explicitly injected below
+                       double[][] A = getRandomMatrix(rows, cols, -0.05, 1, sparsity, 7);
+                       A[7][7] = expected ? check : 7;
+                       writeInputMatrixWithMTD("A", A, false);
+
+                       //run test
+                       runTest(true, false, null, -1);
+                       boolean ret = TestUtils.readDMLBoolean(output("B"));
+                       Assert.assertEquals("contains result", expected, ret);
+                       if( instType == ExecType.CP ) {
+                               //JUnit order is (expected, actual)
+                               Assert.assertEquals(1, Statistics.getNoOfCompiledSPInst()); //reblock
+                               Assert.assertEquals(0, Statistics.getNoOfExecutedSPInst());
+                       }
+               }
+               finally {
+                       resetExecMode(oldMode);
+               }
+       }
+}
diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedLogRegTest.java
 
b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedLogRegTest.java
index 7abb1a8125..a3e91ef37d 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedLogRegTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/algorithms/FederatedLogRegTest.java
@@ -123,7 +123,7 @@ public class FederatedLogRegTest extends AutomatedTestBase {
                Assert.assertTrue("contains fed_ba+*", 
heavyHittersContainsString("fed_ba+*"));
                Assert.assertTrue("contains fed_uar", 
heavyHittersContainsString("fed_uark+", "fed_uarsqk+"));
                Assert.assertTrue("contains fed_mmchain & r'", 
heavyHittersContainsString("fed_mmchain", "fed_r'"));
-               Assert.assertTrue("contains fed_isnan", 
heavyHittersContainsString("fed_isnan"));
+               Assert.assertTrue("contains fed_contains", 
heavyHittersContainsString("fed_contains"));
                
                // check that federated input files are still existing
                Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
diff --git a/src/test/scripts/functions/aggregate/Contains.dml 
b/src/test/scripts/functions/aggregate/Contains.dml
new file mode 100644
index 0000000000..0576b6e1cd
--- /dev/null
+++ b/src/test/scripts/functions/aggregate/Contains.dml
@@ -0,0 +1,24 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Reads matrix $1, checks via contains() whether it holds the pattern
+# value $2, and writes the resulting boolean scalar to $3.
+X = read($1);
+found = contains(target=X, pattern=$2);
+write(found, $3);
\ No newline at end of file

Reply via email to