This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 88341e8  [SYSTEMDS-2978] Federated frame tokenization (parameterized builtin)
88341e8 is described below

commit 88341e889666fc1cf3e9407886969afdfbf098e6
Author: Olga <[email protected]>
AuthorDate: Sun May 30 01:02:42 2021 +0200

    [SYSTEMDS-2978] Federated frame tokenization (parameterized builtin)
    
    Closes #1284.
---
 .../runtime/instructions/InstructionUtils.java     |   6 +
 .../instructions/fed/FEDInstructionUtils.java      |  16 +--
 .../fed/ParameterizedBuiltinFEDInstruction.java    |  50 ++++++-
 .../primitives/FederatedTokenizeTest.java          | 152 +++++++++++++++++++++
 .../functions/federated/FederatedTokenizeTest.dml  |  36 +++++
 .../federated/FederatedTokenizeTestReference.dml   |  36 +++++
 6 files changed, 286 insertions(+), 10 deletions(-)

diff --git a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
index 4c6d1fa..5269e79 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
@@ -1089,6 +1089,12 @@ public class InstructionUtils
                return InstructionUtils.concatOperands(parts[0], parts[1], 
createOperand(op1), createOperand(op2), createOperand(out));
        }
 
+       /**
+        * Derives a unary CP instruction string from an existing instruction string.
+        * The leading field of {@code instString} is kept, the opcode field is replaced
+        * by {@code opcode}, and the operand list becomes {@code op1} followed by {@code out}.
+        *
+        * @param instString instruction string whose leading field is reused
+        * @param op1 single input operand of the new instruction
+        * @param opcode opcode of the new instruction (e.g. "ncol")
+        * @param out output operand of the new instruction
+        * @return the assembled instruction string
+        */
+       public static String constructUnaryInstString(String instString, CPOperand op1, String opcode, CPOperand out) {
+               String[] fields = instString.split(Lop.OPERAND_DELIMITOR);
+               fields[1] = opcode; // overwrite the opcode slot; the leading field stays as-is
+               String inputOperand = createOperand(op1);
+               String outputOperand = createOperand(out);
+               return InstructionUtils.concatOperands(fields[0], fields[1], inputOperand, outputOperand);
+       }
+
        /**
         * Prepare instruction string for sending in a FederatedRequest as a CP 
instruction.
         * This involves replacing the coordinator operand names with the 
worker operand names,
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
index 15409d9..4ab080d 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
@@ -19,6 +19,7 @@
 
 package org.apache.sysds.runtime.instructions.fed;
 
+import org.apache.commons.lang3.ArrayUtils;
 import org.apache.sysds.runtime.codegen.SpoofCellwise;
 import org.apache.sysds.runtime.codegen.SpoofMultiAggregate;
 import org.apache.sysds.runtime.codegen.SpoofOuterProduct;
@@ -71,6 +72,10 @@ import org.apache.sysds.runtime.instructions.spark.UnarySPInstruction;
 import org.apache.sysds.runtime.instructions.spark.WriteSPInstruction;
 
 public class FEDInstructionUtils {
+       
+       // Opcodes of parameterized builtins that are rewritten into a
+       // ParameterizedBuiltinFEDInstruction when their target is federated.
+       // Declared final so the lookup table cannot be reassigned at runtime.
+       private static final String[] PARAM_BUILTINS = new String[] {
+               "replace", "rmempty", "lowertri", "uppertri", "transformdecode", "transformapply", "tokenize"};
+       
        // private static final Log LOG = 
LogFactory.getLog(FEDInstructionUtils.class.getName());
 
        // This is currently a rather simplistic to our solution of replacing 
instructions with their correct federated
@@ -164,15 +169,8 @@ public class FEDInstructionUtils {
                }
                else if( inst instanceof ParameterizedBuiltinCPInstruction ) {
                        ParameterizedBuiltinCPInstruction pinst = 
(ParameterizedBuiltinCPInstruction) inst;
-                       if((pinst.getOpcode().equals("replace") || 
pinst.getOpcode().equals("rmempty")
-                               || pinst.getOpcode().equals("lowertri") || 
pinst.getOpcode().equals("uppertri"))
-                               && pinst.getTarget(ec).isFederated()) {
+                       if( ArrayUtils.contains(PARAM_BUILTINS, 
pinst.getOpcode()) && pinst.getTarget(ec).isFederated() )
                                fedinst = 
ParameterizedBuiltinFEDInstruction.parseInstruction(pinst.getInstructionString());
-                       }
-                       else if((pinst.getOpcode().equals("transformdecode") || 
pinst.getOpcode().equals("transformapply")) &&
-                               pinst.getTarget(ec).isFederated()) {
-                               return 
ParameterizedBuiltinFEDInstruction.parseInstruction(pinst.getInstructionString());
-                       }
                }
                else if (inst instanceof 
MultiReturnParameterizedBuiltinCPInstruction) {
                        MultiReturnParameterizedBuiltinCPInstruction minst = 
(MultiReturnParameterizedBuiltinCPInstruction) inst;
@@ -235,7 +233,7 @@ public class FEDInstructionUtils {
                        SpoofCPInstruction instruction = (SpoofCPInstruction) 
inst;
                        Class<?> scla = 
instruction.getOperatorClass().getSuperclass();
                        if(((scla == SpoofCellwise.class || scla == 
SpoofMultiAggregate.class
-                                               || scla == 
SpoofOuterProduct.class) && instruction.isFederated(ec))
+                               || scla == SpoofOuterProduct.class) && 
instruction.isFederated(ec))
                                || (scla == SpoofRowwise.class && 
instruction.isFederated(ec, FType.ROW))) {
                                fedinst = 
SpoofFEDInstruction.parseInstruction(instruction.getInstructionString());
                        }
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
index 51b8d9b..c64817f 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/ParameterizedBuiltinFEDInstruction.java
@@ -27,6 +27,7 @@ import java.util.HashMap;
 import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Map;
+import java.util.concurrent.Future;
 import java.util.stream.Stream;
 import java.util.zip.Adler32;
 import java.util.zip.Checksum;
@@ -57,6 +58,7 @@ import org.apache.sysds.runtime.functionobjects.ValueFunction;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
 import org.apache.sysds.runtime.instructions.cp.CPOperand;
 import org.apache.sysds.runtime.instructions.cp.Data;
+import org.apache.sysds.runtime.instructions.cp.ScalarObject;
 import org.apache.sysds.runtime.lineage.LineageItem;
 import org.apache.sysds.runtime.lineage.LineageItemUtils;
 import org.apache.sysds.runtime.matrix.data.FrameBlock;
@@ -119,7 +121,7 @@ public class ParameterizedBuiltinFEDInstruction extends ComputationFEDInstructio
                        ValueFunction func = 
ParameterizedBuiltin.getParameterizedBuiltinFnObject(opcode);
                        return new ParameterizedBuiltinFEDInstruction(new 
SimpleOperator(func), paramsMap, out, opcode, str);
                }
-               else if(opcode.equals("transformapply") || 
opcode.equals("transformdecode")) {
+               else if(opcode.equals("transformapply") || 
opcode.equals("transformdecode") || opcode.equals("tokenize")) {
                        return new ParameterizedBuiltinFEDInstruction(null, 
paramsMap, out, opcode, str);
                }
                else {
@@ -154,11 +156,57 @@ public class ParameterizedBuiltinFEDInstruction extends ComputationFEDInstructio
                        transformDecode(ec);
                else if(opcode.equalsIgnoreCase("transformapply"))
                        transformApply(ec);
+               else if(opcode.equals("tokenize"))
+                       tokenize(ec);
                else {
                        throw new DMLRuntimeException("Unknown opcode : " + 
opcode);
                }
        }
 
+       /**
+        * Tokenizes a federated frame. Each worker tokenizes its local partition. The number of
+        * output columns depends on the data, so the coordinator then asks every worker for the
+        * column count of its result. It uses these counts to set the column ranges of the output
+        * federation map, the output schema (all STRING) and the output dimensions.
+        *
+        * @param ec execution context providing the federated target frame and receiving the output
+        */
+       private void tokenize(ExecutionContext ec)
+       {
+               FrameObject in = ec.getFrameObject(getTargetOperand());
+               FederationMap fedMap = in.getFedMapping();
+
+               // execute the tokenize instruction on every worker
+               FederatedRequest fr1 = FederationUtils.callInstruction(instString, output,
+                       new CPOperand[] {getTargetOperand()}, new long[] {fedMap.getID()});
+               fedMap.execute(getTID(), true, fr1);
+
+               FrameObject out = ec.getFrameObject(output);
+               out.setFedMapping(fedMap.copyWithNewID(fr1.getID()));
+
+               // get new dims and fed mapping: compute ncol of every worker result and fetch it
+               long ncolId = FederationUtils.getNextFedDataID();
+               CPOperand ncolOp = new CPOperand(String.valueOf(ncolId), ValueType.INT64, DataType.SCALAR);
+
+               String unaryString = InstructionUtils.constructUnaryInstString(instString, output, "ncol", ncolOp);
+               FederatedRequest fr2 = FederationUtils.callInstruction(unaryString, ncolOp,
+                       new CPOperand[] {output}, new long[] {out.getFedMapping().getID()});
+               FederatedRequest fr3 = new FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr2.getID());
+               Future<FederatedResponse>[] ffr = out.getFedMapping().execute(getTID(), true, fr2, fr3);
+
+               boolean colPartitioned = in.isFederated(FederationMap.FType.COL);
+               boolean rowPartitioned = in.isFederated(FederationMap.FType.ROW);
+               long cols = 0;
+               for(int i = 0; i < ffr.length; i++) {
+                       long ncol;
+                       try {
+                               ncol = ((ScalarObject) ffr[i].get().getData()[0]).getLongValue();
+                       }
+                       catch(Exception e) {
+                               throw new DMLRuntimeException(e);
+                       }
+                       if(colPartitioned) {
+                               // column partitions are adjacent: partition i begins where partition i-1
+                               // ends, so set its begin before accumulating its own width (never touch
+                               // range i+1, which does not exist for the last partition)
+                               out.getFedMapping().getFederatedRanges()[i].setBeginDim(1, cols);
+                               cols += ncol;
+                       }
+                       else if(rowPartitioned) {
+                               // row partitions are assumed to yield the same number of output columns
+                               cols = ncol;
+                       }
+                       out.getFedMapping().getFederatedRanges()[i].setEndDim(1, cols);
+               }
+
+               Types.ValueType[] schema = new Types.ValueType[(int) cols];
+               Arrays.fill(schema, ValueType.STRING);
+               out.setSchema(schema);
+               // NOTE(review): the row count is copied from the input, i.e. this assumes tokenize keeps
+               // the number of rows for the given spec -- confirm for non-wide output formats
+               out.getDataCharacteristics().setDimension(in.getNumRows(), cols);
+       }
+
        private void triangle(ExecutionContext ec, String opcode) {
                boolean lower = opcode.equals("lowertri");
                boolean diag = Boolean.parseBoolean(params.get("diag"));
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedTokenizeTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedTokenizeTest.java
new file mode 100644
index 0000000..e7c408d
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedTokenizeTest.java
@@ -0,0 +1,152 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.federated.primitives;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collection;
+
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.parser.DataExpression;
+import org.apache.sysds.runtime.io.FileFormatPropertiesCSV;
+import org.apache.sysds.runtime.io.FrameReaderFactory;
+import org.apache.sysds.runtime.io.FrameWriter;
+import org.apache.sysds.runtime.io.FrameWriterFactory;
+import org.apache.sysds.runtime.matrix.data.FrameBlock;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.util.HDFSTool;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Assert;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+@RunWith(value = Parameterized.class)
+@net.jcip.annotations.NotThreadSafe
+public class FederatedTokenizeTest extends AutomatedTestBase {
+
+       private final static String TEST_NAME1 = "FederatedTokenizeTest";
+
+       private final static String TEST_DIR = "functions/federated/";
+       private static final String TEST_CLASS_DIR = TEST_DIR + FederatedTokenizeTest.class.getSimpleName() + "/";
+
+       private static final String DATASET = "20news/20news_subset_untokenized.csv";
+
+       @Parameterized.Parameter()
+       public int rows;
+       @Parameterized.Parameter(1)
+       public int cols;
+
+       @Parameterized.Parameter(2)
+       public boolean rowPartitioned;
+
+       @Parameterized.Parameters
+       public static Collection<Object[]> data() {
+               return Arrays.asList(new Object[][] {
+                       {3, 4, true},
+               });
+       }
+
+       @Override
+       public void setUp() {
+               TestUtils.clearAssertionInformation();
+               addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"S"}));
+       }
+
+       @Test
+       public void testTokenizeFullDenseFrameCP() {
+               runTokenizeTest(ExecMode.SINGLE_NODE);
+       }
+
+       /**
+        * Runs the local reference script and the federated script (three workers) on the same
+        * dataset, compares their outputs, and checks that a fed_tokenize instruction was executed.
+        *
+        * @param execMode execution mode used for both script runs
+        */
+       private void runTokenizeTest(ExecMode execMode) {
+               setExecMode(execMode);
+
+               String TEST_NAME = TEST_NAME1;
+
+               getAndLoadTestConfiguration(TEST_NAME);
+               String HOME = SCRIPT_DIR + TEST_DIR;
+
+               // empty script name because we don't execute any script, just start the workers
+               fullDMLScriptName = "";
+               int port1 = getRandomAvailablePort();
+               int port2 = getRandomAvailablePort();
+               int port3 = getRandomAvailablePort();
+               Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
+               Thread t2 = startLocalFedWorkerThread(port2, FED_WORKER_WAIT_S);
+               // last worker started with the default wait (assumed >= FED_WORKER_WAIT_S) before the scripts run
+               Thread t3 = startLocalFedWorkerThread(port3);
+
+               try {
+                       FileFormatPropertiesCSV ffpCSV = new FileFormatPropertiesCSV(false, DataExpression.DEFAULT_DELIM_DELIMITER, false);
+
+                       // write one input file per federated worker
+                       try {
+                               FrameBlock dataset = FrameReaderFactory.createFrameReader(Types.FileFormat.CSV, ffpCSV)
+                                       .readFrameFromHDFS(DATASET_DIR + DATASET, -1, -1);
+
+                               // default for write
+                               FrameWriter fw = FrameWriterFactory.createFrameWriter(Types.FileFormat.CSV, ffpCSV);
+                               writeDatasetSlice(dataset, fw, ffpCSV, "AH");
+                               writeDatasetSlice(dataset, fw, ffpCSV, "AL");
+                               writeDatasetSlice(dataset, fw, ffpCSV, "BH");
+                       }
+                       catch(IOException e) {
+                               // fail fast instead of running the scripts against missing inputs
+                               throw new RuntimeException("Failed to prepare federated test inputs", e);
+                       }
+
+                       rtplatform = execMode;
+                       if(rtplatform == ExecMode.SPARK) {
+                               DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+                       }
+                       TestConfiguration config = availableTestConfigurations.get(TEST_NAME);
+                       loadTestConfiguration(config);
+
+                       // Run reference dml script with a local frame
+                       fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
+                       programArgs = new String[] {"-explain", "-args", DATASET_DIR + DATASET, HOME + TEST_NAME + ".json", expected("S")};
+                       runTest(null);
+
+                       // Run actual dml script with a federated frame
+                       fullDMLScriptName = HOME + TEST_NAME + ".dml";
+                       programArgs = new String[] {"-stats", "100", "-nvargs",
+                               "in_X1=" + TestUtils.federatedAddress(port1, input("AH")),
+                               "in_X2=" + TestUtils.federatedAddress(port2, input("AL")),
+                               "in_X3=" + TestUtils.federatedAddress(port3, input("BH")),
+                               "in_S=" + input(HOME + TEST_NAME + ".json"), "rows=" + rows, "cols=" + cols,
+                               "out_R=" + output("S")};
+                       runTest(null);
+
+                       compareResults(1e-9);
+                       Assert.assertTrue(heavyHittersContainsString("fed_tokenize"));
+               }
+               finally {
+                       // always stop the workers, even if a run or an assertion fails
+                       TestUtils.shutdownThreads(t1, t2, t3);
+               }
+       }
+
+       /**
+        * Writes the complete dataset (not a subset of it) as CSV test input {@code name},
+        * together with the corresponding metadata file.
+        */
+       private void writeDatasetSlice(FrameBlock dataset, FrameWriter fw, FileFormatPropertiesCSV ffpCSV, String name) throws IOException {
+               fw.writeFrameToHDFS(dataset, input(name), dataset.getNumRows(), dataset.getNumColumns());
+               HDFSTool.writeMetaDataFile(input(DataExpression.getMTDFileName(name)), null, dataset.getSchema(),
+                       Types.DataType.FRAME, new MatrixCharacteristics(dataset.getNumRows(), dataset.getNumColumns()),
+                       Types.FileFormat.CSV, ffpCSV);
+       }
+}
diff --git a/src/test/scripts/functions/federated/FederatedTokenizeTest.dml b/src/test/scripts/functions/federated/FederatedTokenizeTest.dml
new file mode 100644
index 0000000..e293306
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedTokenizeTest.dml
@@ -0,0 +1,36 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Federated frame over three workers; the ranges are begin/end pairs giving each
+# worker one row block (rows 0-2, 2-4, 4-6) spanning columns 0-$cols.
+F1 = federated(type="frame", addresses=list($in_X1, $in_X2, $in_X3),
+  ranges=list(list(0, 0), list(2, $cols), list(2, 0), list(4, $cols),
+  list(4, 0), list(6, $cols)));
+
+# passed as max_tokens to tokenize
+max_token = 2000;
+
+# Tokenizer spec: 1- to 3-grams, hashed output with 128 features, wide format,
+# id columns [2,1], tokenizing column 3 of the target passed to tokenize.
+jspec = "{algo:ngram, algo_params: {min_gram: 1,max_gram: 3}, out:hash, 
out_params:
+  {num_features: 128}, format_wide: true, id_cols: [2,1], tokenize_col: 3}";
+
+# tokenize columns 2-4 of the federated frame (the Java test expects a fed_tokenize instruction)
+F2 = tokenize(target=F1[,2:4], spec=jspec, max_tokens=max_token);
+
+# recode the first two output columns so the encoded matrix X can be written
+# and compared against the reference script's output
+jspec2 = "{ids: true, recode: [1,2]}";
+[X, M] = transformencode(target=F2, spec=jspec2);
+write(X, $out_R);
diff --git a/src/test/scripts/functions/federated/FederatedTokenizeTestReference.dml b/src/test/scripts/functions/federated/FederatedTokenizeTestReference.dml
new file mode 100644
index 0000000..f6dae76
--- /dev/null
+++ b/src/test/scripts/functions/federated/FederatedTokenizeTestReference.dml
@@ -0,0 +1,36 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Local reference for FederatedTokenizeTest.dml: read the CSV dataset as a frame.
+F = read($1, data_type="frame", format="csv", sep=",");
+# keep rows 2-3 and columns 1-4, then stack three copies to mirror the three
+# 2-row partitions of the federated frame
+F = F[2:3, 1:4];
+F = rbind(F, rbind(F, F));
+
+# passed as max_tokens to tokenize
+max_token = 2000;
+
+# Tokenizer spec; should match the spec used in FederatedTokenizeTest.dml
+jspec = "{algo:ngram, algo_params:{min_gram:1, max_gram:3}, out:hash, 
out_params:
+  {num_features: 128},format_wide: true,id_cols: [2,1],tokenize_col: 3}";
+
+res = tokenize(target=F[,2:4], spec=jspec, max_tokens=max_token);
+
+# recode the first two output columns, as in the federated script
+jspec2 = "{ids: true, recode: [1,2]}";
+[X, M] = transformencode(target=res, spec=jspec2);
+write(X, $3);

Reply via email to