This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git

commit d61c3bffc677443c28d0fca27364267c1ca41111
Author: baunsgaard <[email protected]>
AuthorDate: Thu Nov 12 18:13:40 2020 +0100

    [SYSTEMDS-2724] Cast to matrix Federated
    
    Closes #1100
---
 .../instructions/fed/FEDInstructionUtils.java      |  26 ++--
 .../instructions/fed/VariableFEDInstruction.java   |  65 ++++++---
 .../primitives/FederetedCastToFrameTest.java       |   4 +-
 .../primitives/FederetedCastToMatrixTest.java      | 160 +++++++++++++++++++++
 .../test/functions/frame/DetectSchemaTest.java     |   6 +-
 .../test/functions/lineage/CacheEvictionTest.java  |   1 +
 .../primitives/FederatedCastToMatrixTest.dml       |  26 ++++
 .../FederatedCastToMatrixTestReference.dml         |  25 ++++
 8 files changed, 274 insertions(+), 39 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
index 68e1cee..ef66b66 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
@@ -36,8 +36,6 @@ import 
org.apache.sysds.runtime.instructions.cp.MatrixIndexingCPInstruction;
 import 
org.apache.sysds.runtime.instructions.cp.MultiReturnParameterizedBuiltinCPInstruction;
 import 
org.apache.sysds.runtime.instructions.cp.ParameterizedBuiltinCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.ReorgCPInstruction;
-import org.apache.sysds.runtime.instructions.cp.UnaryCPInstruction;
-import org.apache.sysds.runtime.instructions.cp.UnaryMatrixCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction;
 import 
org.apache.sysds.runtime.instructions.cp.VariableCPInstruction.VariableOperationCode;
 import org.apache.sysds.runtime.instructions.spark.AggregateUnarySPInstruction;
@@ -84,15 +82,13 @@ public class FEDInstructionUtils {
                        if( mo.isFederated() )
                                fedinst = 
TsmmFEDInstruction.parseInstruction(linst.getInstructionString());
                }
-               else if(inst instanceof UnaryCPInstruction){
-                       if (inst instanceof AggregateUnaryCPInstruction) {
-                               AggregateUnaryCPInstruction instruction = 
(AggregateUnaryCPInstruction) inst;
-                               if( instruction.input1.isMatrix() && 
ec.containsVariable(instruction.input1) ) {
-                                       MatrixObject mo1 = 
ec.getMatrixObject(instruction.input1);
-                                       if (mo1.isFederated() && 
instruction.getAUType() == AggregateUnaryCPInstruction.AUType.DEFAULT){
-                                               LOG.debug("Federated 
UnaryAggregate");
-                                               fedinst = 
AggregateUnaryFEDInstruction.parseInstruction(inst.getInstructionString());
-                                       }
+               else if (inst instanceof AggregateUnaryCPInstruction) {
+                       AggregateUnaryCPInstruction instruction = 
(AggregateUnaryCPInstruction) inst;
+                       if( instruction.input1.isMatrix() && 
ec.containsVariable(instruction.input1) ) {
+                               MatrixObject mo1 = 
ec.getMatrixObject(instruction.input1);
+                               if (mo1.isFederated() && 
instruction.getAUType() == AggregateUnaryCPInstruction.AUType.DEFAULT){
+                                       LOG.debug("Federated UnaryAggregate");
+                                       fedinst = 
AggregateUnaryFEDInstruction.parseInstruction(inst.getInstructionString());
                                }
                        }
                }
@@ -154,11 +150,13 @@ public class FEDInstructionUtils {
                                && 
ec.getCacheableData(ins.getInput1()).isFederated()){
                                fedinst = 
VariableFEDInstruction.parseInstruction(ins);
                        }
-
+                       else if(ins.getVariableOpcode() == 
VariableOperationCode.CastAsMatrixVariable 
+                               && ins.getInput1().isFrame() 
+                               && 
ec.getCacheableData(ins.getInput1()).isFederated()){
+                               fedinst = 
VariableFEDInstruction.parseInstruction(ins);
+                       }
                }
 
-
-               
                //set thread id for federated context management
                if( fedinst != null ) {
                        fedinst.setTID(ec.getTID());
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/VariableFEDInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/VariableFEDInstruction.java
index 7d39e9d..134a2e3 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/VariableFEDInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/VariableFEDInstruction.java
@@ -87,37 +87,62 @@ public class VariableFEDInstruction extends FEDInstruction 
implements LineageTra
                _in.processInstruction(ec);
        }
 
-       private void processCastAsMatrixVariableInstruction(ExecutionContext 
ec){
-               LOG.error("Not Implemented");
-               throw new DMLRuntimeException("Not Implemented Cast as Matrix");
+       private void processCastAsMatrixVariableInstruction(ExecutionContext 
ec) {
 
+               FrameObject mo1 = ec.getFrameObject(_in.getInput1());
+
+               if(!mo1.isFederated())
+                       throw new DMLRuntimeException(
+                               "Federated Reorg: " + "Federated input 
expected, but invoked w/ " + mo1.isFederated());
+
+               // execute function at federated site.
+               FederatedRequest fr1 = 
FederationUtils.callInstruction(_in.getInstructionString(),
+                       _in.getOutput(),
+                       new CPOperand[] {_in.getInput1()},
+                       new long[] {mo1.getFedMapping().getID()});
+               mo1.getFedMapping().execute(getTID(), true, fr1);
+
+               // Construct output local.
+
+               MatrixObject out = ec.getMatrixObject(_in.getOutput());
+               FederationMap outMap = 
mo1.getFedMapping().copyWithNewID(fr1.getID());
+               Map<FederatedRange, FederatedData> newMap = new HashMap<>();
+               for(Map.Entry<FederatedRange, FederatedData> pair : 
outMap.getFedMapping().entrySet()) {
+                       FederatedData om = pair.getValue();
+                       FederatedData nf = new 
FederatedData(Types.DataType.MATRIX, om.getAddress(), om.getFilepath(),
+                               om.getVarID());
+                       newMap.put(pair.getKey(), nf);
+               }
+               out.setFedMapping(outMap);
        }
 
-       private void processCastAsFrameVariableInstruction(ExecutionContext ec){
+       private void processCastAsFrameVariableInstruction(ExecutionContext ec) 
{
 
                MatrixObject mo1 = ec.getMatrixObject(_in.getInput1());
-               
-               if( !mo1.isFederated() )
-                       throw new DMLRuntimeException("Federated Reorg: "
-                               + "Federated input expected, but invoked w/ 
"+mo1.isFederated());
-       
-               //execute transpose at federated site
-               FederatedRequest fr1 = 
FederationUtils.callInstruction(_in.getInstructionString(), _in.getOutput(),
-                       new CPOperand[]{_in.getInput1()}, new 
long[]{mo1.getFedMapping().getID()});
+
+               if(!mo1.isFederated())
+                       throw new DMLRuntimeException(
+                               "Federated Reorg: " + "Federated input 
expected, but invoked w/ " + mo1.isFederated());
+
+               // execute function at federated site.
+               FederatedRequest fr1 = 
FederationUtils.callInstruction(_in.getInstructionString(),
+                       _in.getOutput(),
+                       new CPOperand[] {_in.getInput1()},
+                       new long[] {mo1.getFedMapping().getID()});
                mo1.getFedMapping().execute(getTID(), true, fr1);
-               
-               //drive output federated mapping
+
+               // Construct output local.
                FrameObject out = ec.getFrameObject(_in.getOutput());
-               out.getDataCharacteristics().set(mo1.getNumColumns(),
-                       mo1.getNumRows(), (int)mo1.getBlocksize(), 
mo1.getNnz());
-               FederationMap outMap =  
mo1.getFedMapping().copyWithNewID(fr1.getID());
+               out.getDataCharacteristics().set(mo1.getNumColumns(), 
mo1.getNumRows(), (int) mo1.getBlocksize(), mo1.getNnz());
+               FederationMap outMap = 
mo1.getFedMapping().copyWithNewID(fr1.getID());
                Map<FederatedRange, FederatedData> newMap = new HashMap<>();
-               for(Map.Entry<FederatedRange, FederatedData> pair : 
outMap.getFedMapping().entrySet()){
+               for(Map.Entry<FederatedRange, FederatedData> pair : 
outMap.getFedMapping().entrySet()) {
                        FederatedData om = pair.getValue();
-                       FederatedData nf = new 
FederatedData(Types.DataType.FRAME, 
om.getAddress(),om.getFilepath(),om.getVarID());
+                       FederatedData nf = new 
FederatedData(Types.DataType.FRAME, om.getAddress(), om.getFilepath(),
+                               om.getVarID());
                        newMap.put(pair.getKey(), nf);
                }
-               ValueType[] schema = new 
ValueType[(int)mo1.getDataCharacteristics().getCols()];
+               ValueType[] schema = new ValueType[(int) 
mo1.getDataCharacteristics().getCols()];
                Arrays.fill(schema, ValueType.FP64);
                out.setSchema(schema);
                out.setFedMapping(outMap);
diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederetedCastToFrameTest.java
 
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederetedCastToFrameTest.java
index bbef96e..5e05bf5 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederetedCastToFrameTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederetedCastToFrameTest.java
@@ -114,12 +114,12 @@ public class FederetedCastToFrameTest extends 
AutomatedTestBase {
                        "X2=" + TestUtils.federatedAddress(port2, input("X2")), 
"r=" + rows, "c=" + cols};
                String fedOut = runTest(null).toString();
 
-               LOG.error(fedOut);
+               LOG.debug(fedOut);
                fedOut = fedOut.split("SystemDS Statistics:")[0];
                Assert.assertTrue("Equal Printed Output", out.equals(fedOut));
                Assert.assertTrue("Contains federated Cast to frame", 
heavyHittersContainsString("fed_castdtf"));
                TestUtils.shutdownThreads(t1, t2);
-               
+
                rtplatform = platformOld;
                DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
        }
diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederetedCastToMatrixTest.java
 
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederetedCastToMatrixTest.java
new file mode 100644
index 0000000..b075e47
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederetedCastToMatrixTest.java
@@ -0,0 +1,160 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.federated.primitives;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collection;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.common.Types.DataType;
+import org.apache.sysds.common.Types.FileFormat;
+import org.apache.sysds.common.Types.ValueType;
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.runtime.io.FrameWriter;
+import org.apache.sysds.runtime.io.FrameWriterFactory;
+import org.apache.sysds.runtime.matrix.data.FrameBlock;
+import org.apache.sysds.runtime.meta.MatrixCharacteristics;
+import org.apache.sysds.runtime.util.HDFSTool;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.apache.sysds.test.functions.frame.DetectSchemaTest;
+import org.junit.Assert;
+import org.junit.Ignore;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+@RunWith(value = Parameterized.class)
+@net.jcip.annotations.NotThreadSafe
+public class FederetedCastToMatrixTest extends AutomatedTestBase {
+       private static final Log LOG = 
LogFactory.getLog(FederetedCastToMatrixTest.class.getName());
+
+       private final static String TEST_DIR = 
"functions/federated/primitives/";
+       private final static String TEST_NAME = "FederatedCastToMatrixTest";
+       private final static String TEST_CLASS_DIR = TEST_DIR + 
FederetedCastToMatrixTest.class.getSimpleName() + "/";
+
+       @Parameterized.Parameter()
+       public int rows;
+       @Parameterized.Parameter(1)
+       public int cols;
+
+       @Override
+       public void setUp() {
+               TestUtils.clearAssertionInformation();
+               addTestConfiguration(TEST_NAME, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME));
+       }
+
+       @Parameterized.Parameters
+       public static Collection<Object[]> data() {
+               // rows have to be even and > 1
+               return Arrays.asList(new Object[][] {{10, 32}});
+       }
+
+       @Test
+       public void federatedMultiplyCP() {
+               federatedMultiply(Types.ExecMode.SINGLE_NODE);
+       }
+
+       @Test
+       @Ignore
+       public void federatedMultiplySP() {
+               // TODO Fix me Spark execution error
+               federatedMultiply(Types.ExecMode.SPARK);
+       }
+
+       public void federatedMultiply(Types.ExecMode execMode) {
+               boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
+               Types.ExecMode platformOld = rtplatform;
+               rtplatform = execMode;
+               if(rtplatform == Types.ExecMode.SPARK) {
+                       DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+               }
+               try {
+                       getAndLoadTestConfiguration(TEST_NAME);
+                       String HOME = SCRIPT_DIR + TEST_DIR;
+
+                       ValueType[] schema = new ValueType[cols];
+                       Arrays.fill(schema, ValueType.FP64);
+                       FrameBlock frame1 = new FrameBlock(schema);
+                       FrameBlock frame2 = new FrameBlock(schema);
+                       FrameWriter writer = 
FrameWriterFactory.createFrameWriter(FileFormat.BINARY);
+
+                       // write input matrices
+                       int halfRows = rows / 2;
+                       // We have two matrices handled by a single federated 
worker
+                       double[][] X1 = getRandomMatrix(halfRows, cols, 0, 1, 
1, 42);
+                       double[][] X2 = getRandomMatrix(halfRows, cols, 0, 1, 
1, 1340);
+
+                       DetectSchemaTest.initFrameDataString(frame1, X1, 
schema, halfRows, cols);
+                       writer.writeFrameToHDFS(frame1.slice(0, halfRows - 1, 
0, schema.length - 1, new FrameBlock()),
+                               input("X1"),
+                               halfRows,
+                               schema.length);
+
+                       DetectSchemaTest.initFrameDataString(frame2, X2, 
schema, halfRows, cols);
+                       writer.writeFrameToHDFS(frame2.slice(0, halfRows - 1, 
0, schema.length - 1, new FrameBlock()),
+                               input("X2"),
+                               halfRows,
+                               schema.length);
+
+                       MatrixCharacteristics mc = new 
MatrixCharacteristics(X1.length, X1[0].length,
+                               OptimizerUtils.DEFAULT_BLOCKSIZE, -1);
+                       HDFSTool.writeMetaDataFile(input("X1") + ".mtd", null, 
schema, DataType.FRAME, mc, FileFormat.BINARY);
+                       HDFSTool.writeMetaDataFile(input("X2") + ".mtd", null, 
schema, DataType.FRAME, mc, FileFormat.BINARY);
+
+                       int port1 = getRandomAvailablePort();
+                       int port2 = getRandomAvailablePort();
+                       Thread t1 = startLocalFedWorkerThread(port1);
+                       Thread t2 = startLocalFedWorkerThread(port2);
+
+                       TestConfiguration config = 
availableTestConfigurations.get(TEST_NAME);
+                       loadTestConfiguration(config);
+
+                       // Run reference dml script with normal matrix
+                       fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
+                       programArgs = new String[] {"-nvargs", "X1=" + 
input("X1"), "X2=" + input("X2")};
+                       String out = runTest(null).toString().split("SystemDS 
Statistics:")[0];
+
+                       // Run actual dml script with federated matrix
+                       fullDMLScriptName = HOME + TEST_NAME + ".dml";
+                       programArgs = new String[] {"-stats", "100", "-nvargs",
+                               "X1=" + TestUtils.federatedAddress(port1, 
input("X1")),
+                               "X2=" + TestUtils.federatedAddress(port2, 
input("X2")), "r=" + rows, "c=" + cols};
+                       String fedOut = runTest(null).toString();
+
+                       LOG.debug(fedOut);
+                       fedOut = fedOut.split("SystemDS Statistics:")[0];
+                       Assert.assertTrue("Equal Printed Output", 
out.equals(fedOut));
+                       Assert.assertTrue("Contains federated Cast to frame", 
heavyHittersContainsString("fed_castdtm"));
+                       TestUtils.shutdownThreads(t1, t2);
+
+                       rtplatform = platformOld;
+                       DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
+               }
+               catch(IOException e) {
+                       Assert.fail("Error writing input frame.");
+               }
+       }
+}
diff --git 
a/src/test/java/org/apache/sysds/test/functions/frame/DetectSchemaTest.java 
b/src/test/java/org/apache/sysds/test/functions/frame/DetectSchemaTest.java
index 67d5626..69d3dc5 100644
--- a/src/test/java/org/apache/sysds/test/functions/frame/DetectSchemaTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/frame/DetectSchemaTest.java
@@ -117,7 +117,7 @@ public class DetectSchemaTest extends AutomatedTestBase {
                        }
                        else {
                                double[][] A = getRandomMatrix(rows, 3, 
-Float.MAX_VALUE, Float.MAX_VALUE, 0.7, 2373);
-                               initFrameDataString(frame1, A, schema);
+                               initFrameDataString(frame1, A, schema, rows, 3);
                                writer.writeFrameToHDFS(frame1.slice(0, rows-1, 
0, schema.length-1, new FrameBlock()), input("A"), rows, schema.length);
                                schema[schema.length-2] = Types.ValueType.FP64;
                        }
@@ -143,8 +143,8 @@ public class DetectSchemaTest extends AutomatedTestBase {
                }
        }
 
-       private static void initFrameDataString(FrameBlock frame1, double[][] 
data, Types.ValueType[] lschema) {
-               for (int j = 0; j < 3; j++) {
+       public static void initFrameDataString(FrameBlock frame1, double[][] 
data, Types.ValueType[] lschema, int rows, int cols) {
+               for (int j = 0; j < cols; j++) {
                        Types.ValueType vt = lschema[j];
                        switch (vt) {
                                case STRING:
diff --git 
a/src/test/java/org/apache/sysds/test/functions/lineage/CacheEvictionTest.java 
b/src/test/java/org/apache/sysds/test/functions/lineage/CacheEvictionTest.java
index 4f4d4a7..ac23d2e 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/lineage/CacheEvictionTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/lineage/CacheEvictionTest.java
@@ -17,6 +17,7 @@
  * under the License.
  */
 
+
 package org.apache.sysds.test.functions.lineage;
 
 import java.util.ArrayList;
diff --git 
a/src/test/scripts/functions/federated/primitives/FederatedCastToMatrixTest.dml 
b/src/test/scripts/functions/federated/primitives/FederatedCastToMatrixTest.dml
new file mode 100644
index 0000000..52b9889
--- /dev/null
+++ 
b/src/test/scripts/functions/federated/primitives/FederatedCastToMatrixTest.dml
@@ -0,0 +1,26 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = federated(type="frame", addresses=list($X1, $X2),
+    ranges=list(list(0, 0), list($r / 2, $c), list($r / 2, 0), list($r, $c)))
+
+Z = as.matrix(X)
+print(toString(Z[1]))
diff --git 
a/src/test/scripts/functions/federated/primitives/FederatedCastToMatrixTestReference.dml
 
b/src/test/scripts/functions/federated/primitives/FederatedCastToMatrixTestReference.dml
new file mode 100644
index 0000000..a0db27d
--- /dev/null
+++ 
b/src/test/scripts/functions/federated/primitives/FederatedCastToMatrixTestReference.dml
@@ -0,0 +1,25 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = rbind(read($X1), read($X2))
+
+Z = as.matrix(X) 
+print(toString(Z[1]))

Reply via email to