This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 282e20d  [SYSTEMDS-3055] Frame replace support
282e20d is described below

commit 282e20d73fadc98b4afef1959870f728072b8418
Author: baunsgaard <[email protected]>
AuthorDate: Mon Jul 12 18:12:02 2021 +0200

    [SYSTEMDS-3055] Frame replace support
    
    Add support for replace on frames, for both CP and SP instructions.
    Simply provide a frame target together with a string pattern and replacement:
    
    X = replace(target=X, pattern ="REPLACE_ME", replacement = "SOMETHING_ELSE")
    
    Closes #1344
---
 .../ParameterizedBuiltinFunctionExpression.java    | 11 ++-
 .../cp/ParameterizedBuiltinCPInstruction.java      | 22 ++++--
 .../spark/ParameterizedBuiltinSPInstruction.java   | 65 ++++++++++-----
 .../sysds/runtime/matrix/data/FrameBlock.java      | 13 +++
 .../test/functions/frame/FrameReplaceTest.java     | 92 ++++++++++++++++++++++
 src/test/scripts/functions/frame/ReplaceTest.dml   | 28 +++++++
 6 files changed, 204 insertions(+), 27 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
 
b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
index ec731e6..d074d0d 100644
--- 
a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
+++ 
b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
@@ -478,7 +478,9 @@ public class ParameterizedBuiltinFunctionExpression extends 
DataIdentifier
        private void validateReplace(DataIdentifier output, boolean 
conditional) {
                //check existence and correctness of arguments
                Expression target = getVarParam("target");
-               checkTargetParam(target, conditional);
+               if( target.getOutput().getDataType() != DataType.FRAME ){
+                       checkTargetParam(target, conditional);
+               }
                
                Expression pattern = getVarParam("pattern");
                if( pattern==null ) {
@@ -497,8 +499,11 @@ public class ParameterizedBuiltinFunctionExpression 
extends DataIdentifier
                }       
                
                // Output is a matrix with same dims as input
-               output.setDataType(DataType.MATRIX);
-               output.setValueType(ValueType.FP64);
+               output.setDataType(target.getOutput().getDataType());
+               if(target.getOutput().getDataType() == DataType.FRAME)
+                       output.setValueType(ValueType.STRING);
+               else
+                       output.setValueType(ValueType.FP64);
                output.setDimensions(target.getOutput().getDim1(), 
target.getOutput().getDim2());
        }
 
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
index 54a5339..f115b52 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
@@ -225,12 +225,22 @@ public class ParameterizedBuiltinCPInstruction extends 
ComputationCPInstruction
                                ec.releaseMatrixInput(params.get("select"));
                }
                else if(opcode.equalsIgnoreCase("replace")) {
-                       MatrixBlock target = 
ec.getMatrixInput(params.get("target"));
-                       double pattern = 
Double.parseDouble(params.get("pattern"));
-                       double replacement = 
Double.parseDouble(params.get("replacement"));
-                       MatrixBlock ret = target.replaceOperations(new 
MatrixBlock(), pattern, replacement);
-                       ec.setMatrixOutput(output.getName(), ret);
-                       ec.releaseMatrixInput(params.get("target"));
+                       if(ec.isFrameObject(params.get("target"))){
+                               FrameBlock target = 
ec.getFrameInput(params.get("target"));
+                               String pattern = params.get("pattern");
+                               String replacement = params.get("replacement");
+                               FrameBlock ret = 
target.replaceOperations(pattern, replacement);
+                               ec.setFrameOutput(output.getName(), ret);
+                               ec.releaseFrameInput(params.get("target"));
+                       }else{
+                               MatrixBlock target = 
ec.getMatrixInput(params.get("target"));
+                               double pattern = 
Double.parseDouble(params.get("pattern"));
+                               double replacement = 
Double.parseDouble(params.get("replacement"));
+                               MatrixBlock ret = target.replaceOperations(new 
MatrixBlock(), pattern, replacement);
+                               ec.setMatrixOutput(output.getName(), ret);
+                               ec.releaseMatrixInput(params.get("target"));
+                       }
+                       
                }
                else if(opcode.equals("lowertri") || opcode.equals("uppertri")) 
{
                        MatrixBlock target = 
ec.getMatrixInput(params.get("target"));
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
index 9975925..40e152f 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
@@ -358,25 +358,38 @@ public class ParameterizedBuiltinSPInstruction extends 
ComputationSPInstruction
                        }
                }
                else if(opcode.equalsIgnoreCase("replace")) {
-                       JavaPairRDD<MatrixIndexes, MatrixBlock> in1 = sec
-                               
.getBinaryMatrixBlockRDDHandleForVariable(params.get("target"));
-                       DataCharacteristics mcIn = 
sec.getDataCharacteristics(params.get("target"));
-
-                       // execute replace operation
-                       double pattern = 
Double.parseDouble(params.get("pattern"));
-                       double replacement = 
Double.parseDouble(params.get("replacement"));
-                       JavaPairRDD<MatrixIndexes, MatrixBlock> out = 
in1.mapValues(new RDDReplaceFunction(pattern, replacement));
-
-                       // store output rdd handle
-                       sec.setRDDHandleForVariable(output.getName(), out);
-                       sec.addLineageRDD(output.getName(), 
params.get("target"));
+                       if(sec.isFrameObject(params.get("target"))){
+                               params.get("target");
+                               JavaPairRDD<Long, FrameBlock> in1 = 
sec.getFrameBinaryBlockRDDHandleForVariable(params.get("target"));
+                               DataCharacteristics mcIn = 
sec.getDataCharacteristics(params.get("target"));
+                               String pattern = params.get("pattern");
+                               String replacement = params.get("replacement");
+                               JavaPairRDD<Long, FrameBlock> out = 
in1.mapValues(new RDDFrameReplaceFunction(pattern, replacement));
+                               sec.setRDDHandleForVariable(output.getName(), 
out);
+                               sec.addLineageRDD(output.getName(), 
params.get("target"));
+                               
sec.getDataCharacteristics(output.getName()).set(mcIn.getRows(), 
mcIn.getCols(), mcIn.getBlocksize(), mcIn.getNonZeros());
+                       }
+                       else {
+                               JavaPairRDD<MatrixIndexes, MatrixBlock> in1 = 
sec
+                                       
.getBinaryMatrixBlockRDDHandleForVariable(params.get("target"));
+                               DataCharacteristics mcIn = 
sec.getDataCharacteristics(params.get("target"));
+       
+                               // execute replace operation
+                               double pattern = 
Double.parseDouble(params.get("pattern"));
+                               double replacement = 
Double.parseDouble(params.get("replacement"));
+                               JavaPairRDD<MatrixIndexes, MatrixBlock> out = 
in1.mapValues(new RDDReplaceFunction(pattern, replacement));
+       
+                               // store output rdd handle
+                               sec.setRDDHandleForVariable(output.getName(), 
out);
+                               sec.addLineageRDD(output.getName(), 
params.get("target"));
+       
+                               // update output statistics (required for 
correctness)
+                               
sec.getDataCharacteristics(output.getName()).set(mcIn.getRows(),
+                                       mcIn.getCols(),
+                                       mcIn.getBlocksize(),
+                                       (pattern != 0 && replacement != 0) ? 
mcIn.getNonZeros() : -1);
+                       }
 
-                       // update output statistics (required for correctness)
-                       DataCharacteristics mcOut = 
sec.getDataCharacteristics(output.getName());
-                       mcOut.set(mcIn.getRows(),
-                               mcIn.getCols(),
-                               mcIn.getBlocksize(),
-                               (pattern != 0 && replacement != 0) ? 
mcIn.getNonZeros() : -1);
                }
                else if(opcode.equalsIgnoreCase("lowertri") || 
opcode.equalsIgnoreCase("uppertri")) {
                        JavaPairRDD<MatrixIndexes, MatrixBlock> in1 = sec
@@ -544,6 +557,22 @@ public class ParameterizedBuiltinSPInstruction extends 
ComputationSPInstruction
                }
        }
 
+       /**
+        * Spark value function that applies a frame replace to every FrameBlock of a
+        * frame RDD by delegating to {@link FrameBlock#replaceOperations(String, String)}.
+        * The function returns the block produced by that call rather than mutating
+        * its argument in place.
+        */
+       public static class RDDFrameReplaceFunction implements Function<FrameBlock, FrameBlock> {
+               private static final long serialVersionUID = 6576713401901671660L;
+
+               /** Cell value to search for. */
+               private final String _pattern;
+               /** Value written into matching cells. */
+               private final String _replacement;
+
+               /**
+                * @param pattern     cell value to search for
+                * @param replacement value to write into matching cells
+                */
+               public RDDFrameReplaceFunction(String pattern, String replacement) {
+                       this._pattern = pattern;
+                       this._replacement = replacement;
+               }
+
+               @Override
+               public FrameBlock call(FrameBlock block) {
+                       final FrameBlock replaced = block.replaceOperations(_pattern, _replacement);
+                       return replaced;
+               }
+       }
+
        private static class RDDExtractTriangularFunction
                implements PairFlatMapFunction<Iterator<Tuple2<MatrixIndexes, 
MatrixBlock>>, MatrixIndexes, MatrixBlock> {
                private static final long serialVersionUID = 
2754868819184155702L;
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/FrameBlock.java 
b/src/main/java/org/apache/sysds/runtime/matrix/data/FrameBlock.java
index 322cfad..8ee6f33 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/FrameBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/FrameBlock.java
@@ -2237,4 +2237,17 @@ public class FrameBlock implements CacheBlock, 
Externalizable  {
                public String apply(String input) {return null;}
                public String apply(String input1, String input2) {     return 
null;}
        }
+
+       public FrameBlock replaceOperations(String pattern, String replacement){
+               FrameBlock ret = new FrameBlock(this);
+               for(int i = 0; i < ret.getNumColumns(); i++){
+                       Array colData = ret._coldata[i];
+                       for(int j = 0; j < colData._size; j++){
+                               Object ent = colData.get(j);
+                               if(ent != null && ent.equals(pattern))
+                                       colData.set(j,replacement); 
+                       }
+               }
+               return ret;
+       }
 }
diff --git 
a/src/test/java/org/apache/sysds/test/functions/frame/FrameReplaceTest.java 
b/src/test/java/org/apache/sysds/test/functions/frame/FrameReplaceTest.java
new file mode 100644
index 0000000..73868e3
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/frame/FrameReplaceTest.java
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.frame;
+
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+
+import java.util.regex.Pattern;
+
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.common.Types.ExecType;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+
+/**
+ * Tests the DML {@code replace} builtin on frames, in both CP (hybrid) and Spark
+ * execution modes, by running ReplaceTest.dml and inspecting its printed output.
+ */
+public class FrameReplaceTest extends AutomatedTestBase {
+    private final static String TEST_DIR = "functions/frame/";
+    private final static String TEST_NAME = "ReplaceTest";
+    private final static String TEST_CLASS_DIR = TEST_DIR + FrameReplaceTest.class.getSimpleName() + "/";
+
+    /** Values ReplaceTest.dml rewrites; must stay in sync with the script. */
+    private final static String[] PATTERNS = {"north", "east", "west"};
+    /** Value ReplaceTest.dml writes in place of every pattern. */
+    private final static String REPLACEMENT = "south";
+
+    @Override
+    public void setUp() {
+        TestUtils.clearAssertionInformation();
+        addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME));
+    }
+
+    @Test
+    public void testFrameReplaceCP() {
+        runReplaceTest(ExecType.CP);
+    }
+
+    @Test
+    public void testFrameReplaceSpark() {
+        runReplaceTest(ExecType.SPARK);
+    }
+
+    /**
+     * Runs ReplaceTest.dml and checks that every pattern was replaced.
+     *
+     * @param et execution type; SPARK forces Spark execution, anything else runs HYBRID
+     */
+    private void runReplaceTest(ExecType et) {
+        ExecMode platformOld = rtplatform;
+        rtplatform = (et == ExecType.SPARK) ? ExecMode.SPARK : ExecMode.HYBRID;
+
+        boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+        // both SPARK and HYBRID may launch Spark jobs, so the local config is always needed
+        DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+
+        try {
+            getAndLoadTestConfiguration(TEST_NAME);
+            String HOME = SCRIPT_DIR + TEST_DIR;
+            fullDMLScriptName = HOME + TEST_NAME + ".dml";
+            programArgs = new String[] {};
+
+            String out = runTest(null).toString();
+
+            assertTrue("replacement '" + REPLACEMENT + "' missing from output", out.contains(REPLACEMENT));
+            for(String pattern : PATTERNS)
+                assertFalse("'" + pattern + "' was not replaced", containsWord(out, pattern));
+        }
+        catch(Exception ex) {
+            throw new RuntimeException(ex);
+        }
+        finally {
+            rtplatform = platformOld;
+            DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+        }
+    }
+
+    /** Returns true if {@code word} occurs in {@code text} as a whole word (so "least" does not match "east"). */
+    private static boolean containsWord(String text, String word) {
+        return Pattern.compile("\\b" + Pattern.quote(word) + "\\b").matcher(text).find();
+    }
+}
diff --git a/src/test/scripts/functions/frame/ReplaceTest.dml 
b/src/test/scripts/functions/frame/ReplaceTest.dml
new file mode 100644
index 0000000..2a12b48
--- /dev/null
+++ b/src/test/scripts/functions/frame/ReplaceTest.dml
@@ -0,0 +1,28 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+# 
+#   http://www.apache.org/licenses/LICENSE-2.0
+# 
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Read the homes dataset. No data_type/format is passed here, so it is presumably
+# loaded as a frame via its accompanying .mtd metadata file -- TODO confirm.
+# The path is relative, so the script assumes it runs from the repository root.
+X = read("src/test/resources/datasets/homes/homes.csv") 
+
+# Overwrite every cell exactly equal to "north", "east" or "west" with "south".
+X = replace(target = X, pattern="north", replacement="south")
+X = replace(target = X, pattern="east", replacement="south")
+X = replace(target = X, pattern="west", replacement="south")
+
+# Print the resulting frame; the calling Java test inspects this output.
+print(toString(X))
\ No newline at end of file

Reply via email to