This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new fa81f6a40a [SYSTEMDS-1622] Fix federated left indexing with scalar inputs
fa81f6a40a is described below

commit fa81f6a40ae15c9f50e00cf8ec96af2626684e4f
Author: ywcb00 <[email protected]>
AuthorDate: Sun Jun 5 23:16:07 2022 +0200

    [SYSTEMDS-1622] Fix federated left indexing with scalar inputs
    
    This patch generalizes the federated left indexing instruction
    to support scalar inputs, and fixes a more general issue in
    replacing instruction operands for edge cases where a scalar
    literal matches a federated input or output variable name.
    
    Closes #1622.
    
    Co-authored-by: Matthias Boehm <[email protected]>
---
 .../federated/FederatedLookupTable.java            |   4 +
 .../federated/FederatedWorkerHandler.java          |  10 +-
 .../controlprogram/federated/FederationMap.java    |   5 +
 .../controlprogram/federated/FederationUtils.java  |  21 ++-
 .../instructions/fed/IndexingFEDInstruction.java   | 143 +++++++++-----
 .../primitives/FederatedLeftIndexTest.java         | 205 +++++++++++----------
 .../federated/FederatedLeftIndexFrameFullTest.dml  |   2 -
 .../FederatedLeftIndexFrameFullTestReference.dml   |   2 -
 .../federated/FederatedLeftIndexFullTest.dml       |   2 -
 .../FederatedLeftIndexFullTestReference.dml        |   2 -
 ...llTest.dml => FederatedLeftIndexScalarTest.dml} |  17 +-
 ...l => FederatedLeftIndexScalarTestReference.dml} |  17 +-
 12 files changed, 252 insertions(+), 178 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedLookupTable.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedLookupTable.java
index 55ab9715e2..afba8ac42a 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedLookupTable.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedLookupTable.java
@@ -47,6 +47,10 @@ public class FederatedLookupTable {
                _lookup_table = new ConcurrentHashMap<>();
        }
 
+       public void clear() {
+               _lookup_table.clear();
+       }
+       
        /**
         * Get the ExecutionContextMap corresponding to the given host and pid 
of the
         * requesting coordinator from the lookup table. Create a new
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
index 592f77ccce..47cedd739c 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederatedWorkerHandler.java
@@ -220,8 +220,10 @@ public class FederatedWorkerHandler extends 
ChannelInboundHandlerAdapter {
                                containsCLEAR = true;
                }
 
-               if(containsCLEAR)
+               if(containsCLEAR) {
+                       _flt.clear();
                        printStatistics();
+               }
 
                return response;
        }
@@ -398,7 +400,7 @@ public class FederatedWorkerHandler extends 
ChannelInboundHandlerAdapter {
                checkNumParams(request.getNumParams(), 1, 2);
                final String varName = String.valueOf(request.getID());
                ExecutionContext ec = ecm.get(request.getTID());
-
+               
                if(ec.containsVariable(varName)) {
                        final Data tgtData = ec.removeVariable(varName);
                        if(tgtData != null)
@@ -450,7 +452,6 @@ public class FederatedWorkerHandler extends 
ChannelInboundHandlerAdapter {
 
        private FederatedResponse getVariable(FederatedRequest request, 
ExecutionContextMap ecm) {
                try{
-
                        checkNumParams(request.getNumParams(), 0);
                        ExecutionContext ec = ecm.get(request.getTID());
                        
if(!ec.containsVariable(String.valueOf(request.getID())))
@@ -494,7 +495,8 @@ public class FederatedWorkerHandler extends 
ChannelInboundHandlerAdapter {
                //handle missing spark execution context
                //TODO handling of spark instructions should be under control 
of federated site not coordinator
                if(ins.getType() == IType.SPARK
-               && !(ec instanceof SparkExecutionContext) ) {
+                       && !(ec instanceof SparkExecutionContext) )
+               {
                        ecm.convertToSparkCtx();
                        return ecm.get(id);
                }
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
index 0053a8b2fe..fcef0d7984 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationMap.java
@@ -351,6 +351,11 @@ public class FederationMap {
                return ret.toArray(new Future[0]);
        }
 
+       public Future<FederatedResponse>[] execute(long tid, boolean wait, 
FederatedRange[] fedRange1,
+               FederatedRequest elseFr, FederatedRequest frSlice1, 
FederatedRequest frSlice2, FederatedRequest fr) {
+               return execute(tid, wait, fedRange1, elseFr, new 
FederatedRequest[]{frSlice1}, new FederatedRequest[]{frSlice2}, fr);
+       }
+
        @SuppressWarnings("unchecked")
        public Future<FederatedResponse>[] execute(long tid, boolean wait, 
FederatedRange[] fedRange1, FederatedRequest elseFr, FederatedRequest[] 
frSlices1, FederatedRequest[] frSlices2, FederatedRequest... fr) {
                // executes step1[] - step 2 - ... step4 (only first step 
federated-data-specific)
diff --git 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
index 671cd0b744..82f78fe176 100644
--- 
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
+++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
@@ -126,19 +126,26 @@ public class FederationUtils {
                String[] linst = inst;
                FederatedRequest[] fr = new FederatedRequest[inst.length];
                for(int j=0; j<inst.length; j++) {
+                       linst[j] = InstructionUtils.replaceOperand(linst[j], 0, 
type == null ?
+                               InstructionUtils.getExecType(linst[j]).name() : 
type.name());
+                       // replace inputs before before outputs in order to 
prevent conflicts
+                       // on outputId matching input literals (due to a mix of 
input instructions,
+                       // have to apply this replacement even for literal 
inputs)
                        for(int i = 0; i < varOldIn.length; i++) {
-                               linst[j] = 
InstructionUtils.replaceOperand(linst[j], 0, type == null ? 
InstructionUtils.getExecType(linst[j]).name() : type.name());
-                               linst[j] = linst[j].replace(
-                                       Lop.OPERAND_DELIMITOR + 
varOldOut.getName() + Lop.DATATYPE_PREFIX,
-                                       Lop.OPERAND_DELIMITOR + 
String.valueOf(outputId) + Lop.DATATYPE_PREFIX);
-
-                               if(varOldIn[i] != null) {
+                               if( varOldIn[i] != null ) {
                                        linst[j] = linst[j].replace(
                                                Lop.OPERAND_DELIMITOR + 
varOldIn[i].getName() + Lop.DATATYPE_PREFIX,
                                                Lop.OPERAND_DELIMITOR + 
String.valueOf(varNewIn[i]) + Lop.DATATYPE_PREFIX);
-                                       linst[j] = linst[j].replace("=" + 
varOldIn[i].getName(), "=" + String.valueOf(varNewIn[i])); //parameterized
+                                       // handle parameterized builtin 
functions
+                                       linst[j] = linst[j].replace("=" + 
varOldIn[i].getName(), "=" + String.valueOf(varNewIn[i]));
                                }
                        }
+                       for(int i = 0; i < varOldIn.length; i++) {
+                               linst[j] = linst[j].replace(
+                                       Lop.OPERAND_DELIMITOR + 
varOldOut.getName() + Lop.DATATYPE_PREFIX,
+                                       Lop.OPERAND_DELIMITOR + 
String.valueOf(outputId) + Lop.DATATYPE_PREFIX);
+                       }
+                       
                        fr[j] = new FederatedRequest(RequestType.EXEC_INST, 
outputId, (Object) linst[j]);
                }
                return fr;
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/IndexingFEDInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/IndexingFEDInstruction.java
index bc70b398f9..4e4448ba97 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/IndexingFEDInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/IndexingFEDInstruction.java
@@ -25,8 +25,10 @@ import java.util.Collections;
 import java.util.List;
 import java.util.Objects;
 
+import org.apache.commons.lang3.tuple.Pair;
 import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.common.Types;
+import org.apache.sysds.common.Types.DataType;
 import org.apache.sysds.common.Types.ValueType;
 import org.apache.sysds.hops.fedplanner.FTypes.FType;
 import org.apache.sysds.lops.LeftIndex;
@@ -44,6 +46,7 @@ import 
org.apache.sysds.runtime.controlprogram.federated.FederationMap;
 import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
 import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.instructions.cp.ScalarObject;
 import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction;
 import org.apache.sysds.runtime.meta.MatrixCharacteristics;
 import org.apache.sysds.runtime.util.IndexRange;
@@ -150,7 +153,7 @@ public final class IndexingFEDInstruction extends 
UnaryFEDInstruction {
                List<Types.ValueType> schema = new ArrayList<>();
                // replace old reshape values for each worker
                int i = 0;
-               for(org.apache.commons.lang3.tuple.Pair<FederatedRange, 
FederatedData> e : fedMap.getMap()) {
+               for(Pair<FederatedRange, FederatedData> e : fedMap.getMap()) {
                        FederatedRange range = e.getKey();
                        long rs = range.getBeginDims()[0], re = 
range.getEndDims()[0],
                                cs = range.getBeginDims()[1], ce = 
range.getEndDims()[1];
@@ -204,7 +207,8 @@ public final class IndexingFEDInstruction extends 
UnaryFEDInstruction {
        {
                //get input and requested index range
                CacheableData<?> in1 = ec.getCacheableData(input1);
-               CacheableData<?> in2 = ec.getCacheableData(input2);
+               CacheableData<?> in2 = null; // either in2 or scalar is set
+               ScalarObject scalar = null;
                IndexRange ixrange = getIndexRange(ec);
 
                //check bounds
@@ -213,11 +217,21 @@ public final class IndexingFEDInstruction extends 
UnaryFEDInstruction {
                        throw new DMLRuntimeException("Invalid values for 
matrix indexing: ["+(ixrange.rowStart+1)+":"+(ixrange.rowEnd+1)+","
                                + 
(ixrange.colStart+1)+":"+(ixrange.colEnd+1)+"] " + "must be within matrix 
dimensions ["+in1.getNumRows()+","+in1.getNumColumns()+"].");
                }
-               if( (ixrange.rowEnd-ixrange.rowStart+1) != in2.getNumRows() || 
(ixrange.colEnd-ixrange.colStart+1) != in2.getNumColumns()) {
-                       throw new DMLRuntimeException("Invalid values for 
matrix indexing: " +
-                               "dimensions of the source matrix 
["+in2.getNumRows()+"x" + in2.getNumColumns() + "] " +
-                               "do not match the shape of the matrix specified 
by indices [" +
-                               (ixrange.rowStart+1) +":" + (ixrange.rowEnd+1) 
+ ", " + (ixrange.colStart+1) + ":" + (ixrange.colEnd+1) + "].");
+
+               if(input2.getDataType() == DataType.SCALAR) {
+                       if(!ixrange.isScalar())
+                               throw new DMLRuntimeException("Invalid index 
range for leftindexing with scalar: " + ixrange.toString() + ".");
+
+                       scalar = ec.getScalarInput(input2);
+               }
+               else {
+                       in2 = ec.getCacheableData(input2);
+                       if( (ixrange.rowEnd-ixrange.rowStart+1) != 
in2.getNumRows() || (ixrange.colEnd-ixrange.colStart+1) != in2.getNumColumns()) 
{
+                               throw new DMLRuntimeException("Invalid values 
for matrix indexing: " +
+                                       "dimensions of the source matrix 
["+in2.getNumRows()+"x" + in2.getNumColumns() + "] " +
+                                       "do not match the shape of the matrix 
specified by indices [" +
+                                       (ixrange.rowStart+1) +":" + 
(ixrange.rowEnd+1) + ", " + (ixrange.colStart+1) + ":" + (ixrange.colEnd+1) + 
"].");
+                       }
                }
 
                FederationMap fedMap = in1.getFedMapping();
@@ -226,9 +240,13 @@ public final class IndexingFEDInstruction extends 
UnaryFEDInstruction {
                int[][] sliceIxs = new int[fedMap.getSize()][];
                FederatedRange[] ranges = new FederatedRange[fedMap.getSize()];
 
+               // instruction string for copying a partition at the federated 
site
+               int cpVarInstIx = fedMap.getSize();
+               String cpVarInstString = createCopyInstString();
+
                // replace old reshape values for each worker
                int i = 0, prev = 0, from = fedMap.getSize();
-               for(org.apache.commons.lang3.tuple.Pair<FederatedRange, 
FederatedData> e : fedMap.getMap()) {
+               for(Pair<FederatedRange, FederatedData> e : fedMap.getMap()) {
                        FederatedRange range = e.getKey();
                        long rs = range.getBeginDims()[0], re = 
range.getEndDims()[0],
                                cs = range.getBeginDims()[1], ce = 
range.getEndDims()[1];
@@ -239,29 +257,46 @@ public final class IndexingFEDInstruction extends 
UnaryFEDInstruction {
 
                        long[] newIx = new long[]{(int) rsn, (int) ren, (int) 
csn, (int) cen};
 
-                       // find ranges where to apply  leftIndex
-                       long to;
-                       if(in1.isFederated(FType.ROW) && (to = (prev + ren - 
rsn)) >= 0 &&
-                               to < in2.getNumRows() && ixrange.rowStart <= 
re) {
-                               sliceIxs[i] = new int[] { prev, (int) to, 0, 
(int) in2.getNumColumns()-1};
-                               prev = (int) (to + 1);
-
-                               instStrings[i] = modifyIndices(newIx, 4, 8);
-                               ranges[i] = range;
-                               from = Math.min(i, from);
+                       if(in2 != null) { // matrix, frame
+                               // find ranges where to apply leftIndex
+                               long to;
+                               if(in1.isFederated(FType.ROW) && (to = (prev + 
ren - rsn)) >= 0 &&
+                                       to < in2.getNumRows() && 
ixrange.rowStart <= re) {
+                                       sliceIxs[i] = new int[] { prev, (int) 
to, 0, (int) in2.getNumColumns()-1};
+                                       prev = (int) (to + 1);
+
+                                       instStrings[i] = modifyIndices(newIx, 
4, 8);
+                                       ranges[i] = range;
+                                       from = Math.min(i, from);
+                               }
+                               else if(in1.isFederated(FType.COL) && (to = 
(prev + cen - csn)) >= 0 &&
+                                       to < in2.getNumColumns() && 
ixrange.colStart <= ce) {
+                                       sliceIxs[i] = new int[] {0, (int) 
in2.getNumRows() - 1, prev, (int) to};
+                                       prev = (int) (to + 1);
+
+                                       instStrings[i] = modifyIndices(newIx, 
4, 8);
+                                       ranges[i] = range;
+                                       from = Math.min(i, from);
+                               }
+                               else {
+                                       // TODO shallow copy, add more advanced 
update in place for federated
+                                       cpVarInstIx = Math.min(i, cpVarInstIx);
+                                       instStrings[i] = cpVarInstString;
+                               }
                        }
-                       else if(in1.isFederated(FType.COL) && (to = (prev + cen 
- csn)) >= 0 &&
-                               to < in2.getNumColumns() && ixrange.colStart <= 
ce) {
-                               sliceIxs[i] = new int[] {0, (int) 
in2.getNumRows() - 1, prev, (int) to};
-                               prev = (int) (to + 1);
-
-                               instStrings[i] = modifyIndices(newIx, 4, 8);
-                               ranges[i] = range;
-                               from = Math.min(i, from);
+                       else { // scalar
+                               if(ixrange.rowStart >= rs && ixrange.rowEnd < re
+                                       && ixrange.colStart >= cs && 
ixrange.colEnd < ce) {
+                                       instStrings[i] = modifyIndices(newIx, 
4, 8);
+                                       instStrings[i] = 
changeScalarLiteralFlag(instStrings[i], 3);
+                                       ranges[i] = range;
+                                       from = Math.min(i, from);
+                               }
+                               else {
+                                       cpVarInstIx = Math.min(i, cpVarInstIx);
+                                       instStrings[i] = cpVarInstString;
+                               }
                        }
-                       else
-                               // TODO shallow copy, add more advanced update 
in place for federated
-                               instStrings[i] = createCopyInstString();
 
                        i++;
                }
@@ -269,35 +304,44 @@ public final class IndexingFEDInstruction extends 
UnaryFEDInstruction {
                sliceIxs = 
Arrays.stream(sliceIxs).filter(Objects::nonNull).toArray(int[][] :: new);
 
                long id = FederationUtils.getNextFedDataID();
+               //TODO remove explicit put (unnecessary in CP, only spark which 
is about to be cleaned up)
                FederatedRequest tmp = new 
FederatedRequest(FederatedRequest.RequestType.PUT_VAR, id, new 
MatrixCharacteristics(-1, -1), in1.getDataType());
                fedMap.execute(getTID(), true, tmp);
 
-               FederatedRequest[] fr1 = fedMap.broadcastSliced(in2, 
DMLScript.LINEAGE ? ec.getLineageItem(input2) : null,
-                       input2.isFrame(), sliceIxs);
-               FederatedRequest[] fr2 = 
FederationUtils.callInstruction(instStrings, output, id, new 
CPOperand[]{input1, input2},
-                       new long[]{fedMap.getID(), fr1[0].getID()}, null);
-               FederatedRequest fr3 = fedMap.cleanup(getTID(), fr1[0].getID());
+               if(in2 != null) { // matrix, frame
+                       FederatedRequest[] fr1 = fedMap.broadcastSliced(in2, 
DMLScript.LINEAGE ? ec.getLineageItem(input2) : null,
+                               input2.isFrame(), sliceIxs);
+                       FederatedRequest[] fr2 = 
FederationUtils.callInstruction(instStrings, output, id, new 
CPOperand[]{input1, input2},
+                               new long[]{fedMap.getID(), fr1[0].getID()}, 
null);
+                       FederatedRequest fr3 = fedMap.cleanup(getTID(), 
fr1[0].getID());
 
-               //execute federated instruction and cleanup intermediates
-               if(sliceIxs.length == fedMap.getSize())
-                       fedMap.execute(getTID(), true, fr2, fr1, fr3);
-               else {
-                       // get index of cpvar request
-                       for(i = 0; i < fr2.length; i++)
-                               if(i < from || i >= from + sliceIxs.length)
-                                       break;
-                       fedMap.execute(getTID(), true, ranges, (fr2[i]), 
Arrays.copyOfRange(fr2, from, from + sliceIxs.length), fr1, fr3);
+                       //execute federated instruction and cleanup 
intermediates
+                       if(sliceIxs.length == fedMap.getSize())
+                               fedMap.execute(getTID(), true, fr2, fr1, fr3);
+                       else
+                               fedMap.execute(getTID(), true, ranges, 
fr2[cpVarInstIx], Arrays.copyOfRange(fr2, from, from + sliceIxs.length), fr1, 
fr3);
+               }
+               else { // scalar
+                       FederatedRequest fr1 = fedMap.broadcast(scalar);
+                       FederatedRequest[] fr2 = 
FederationUtils.callInstruction(instStrings, output, id, new 
CPOperand[]{input1, input2},
+                               new long[]{fedMap.getID(), fr1.getID()}, null);
+                       FederatedRequest fr3 = fedMap.cleanup(getTID(), 
fr1.getID());
+
+                       if(fr2.length == 1)
+                               fedMap.execute(getTID(), true, fr2, fr1, fr3);
+                       else
+                               fedMap.execute(getTID(), true, ranges, 
fr2[cpVarInstIx], fr2[from], fr1, fr3);
                }
 
                if(input1.isFrame()) {
                        FrameObject out = ec.getFrameObject(output);
                        out.setSchema(((FrameObject) in1).getSchema());
                        
out.getDataCharacteristics().set(in1.getDataCharacteristics());
-                       out.setFedMapping(fedMap.copyWithNewID(fr2[0].getID()));
+                       out.setFedMapping(fedMap.copyWithNewID(id));
                } else {
                        MatrixObject out = ec.getMatrixObject(output);
-                       
out.getDataCharacteristics().set(in1.getDataCharacteristics());;
-                       out.setFedMapping(fedMap.copyWithNewID(fr2[0].getID()));
+                       
out.getDataCharacteristics().set(in1.getDataCharacteristics());
+                       out.setFedMapping(fedMap.copyWithNewID(id));
                }
        }
 
@@ -309,6 +353,13 @@ public final class IndexingFEDInstruction extends 
UnaryFEDInstruction {
                return String.join(Lop.OPERAND_DELIMITOR, instParts);
        }
 
+       private String changeScalarLiteralFlag(String inst, int partIx) {
+               // change the literal flag of the broadcast scalar
+               String[] instParts = inst.split(Lop.OPERAND_DELIMITOR);
+               instParts[partIx] = instParts[partIx].replace("true", "false");
+               return String.join(Lop.OPERAND_DELIMITOR, instParts);
+       }
+
        private String createCopyInstString() {
                String[] instParts = instString.split(Lop.OPERAND_DELIMITOR);
                return 
VariableCPInstruction.prepareCopyInstruction(instParts[2], 
instParts[8]).toString();
diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedLeftIndexTest.java
 
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedLeftIndexTest.java
index 3c337286a3..6686bdd298 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedLeftIndexTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedLeftIndexTest.java
@@ -22,7 +22,6 @@ package org.apache.sysds.test.functions.federated.primitives;
 import java.util.Arrays;
 import java.util.Collection;
 
-import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.common.Types.ExecMode;
 import org.apache.sysds.runtime.meta.MatrixCharacteristics;
 import org.apache.sysds.runtime.util.HDFSTool;
@@ -40,6 +39,7 @@ public class FederatedLeftIndexTest extends AutomatedTestBase 
{
 
        private final static String TEST_NAME1 = "FederatedLeftIndexFullTest";
        private final static String TEST_NAME2 = 
"FederatedLeftIndexFrameFullTest";
+       private final static String TEST_NAME3 = "FederatedLeftIndexScalarTest";
 
        private final static String TEST_DIR = "functions/federated/";
        private static final String TEST_CLASS_DIR = TEST_DIR + 
FederatedLeftIndexTest.class.getSimpleName() + "/";
@@ -81,7 +81,7 @@ public class FederatedLeftIndexTest extends AutomatedTestBase 
{
        }
 
        private enum DataType {
-               MATRIX, FRAME
+               MATRIX, FRAME, SCALAR
        }
 
        @Override
@@ -89,6 +89,7 @@ public class FederatedLeftIndexTest extends AutomatedTestBase 
{
                TestUtils.clearAssertionInformation();
                addTestConfiguration(TEST_NAME1, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"S"}));
                addTestConfiguration(TEST_NAME2, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] {"S"}));
+               addTestConfiguration(TEST_NAME3, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] {"S"}));
        }
 
        @Test
@@ -102,108 +103,122 @@ public class FederatedLeftIndexTest extends 
AutomatedTestBase {
        }
 
        @Test
-       public void testLeftIndexFullDenseMatrixSP() { 
runAggregateOperationTest(DataType.MATRIX, ExecMode.SPARK); }
+       public void testLeftIndexFullDenseMatrixSP() {
+               runAggregateOperationTest(DataType.MATRIX, ExecMode.SPARK);
+       }
 
        @Test
        public void testLeftIndexFullDenseFrameSP() {
                runAggregateOperationTest(DataType.FRAME, ExecMode.SPARK);
        }
 
-       private void runAggregateOperationTest(DataType dataType, ExecMode 
execMode) {
-               setExecMode(execMode);
-
-               String TEST_NAME = null;
-
-               if(dataType == DataType.MATRIX)
-                       TEST_NAME = TEST_NAME1;
-               else
-                       TEST_NAME = TEST_NAME2;
-
+       @Test
+       public void testLeftIndexScalarCP() {
+               runAggregateOperationTest(DataType.SCALAR, 
ExecMode.SINGLE_NODE);
+       }
 
-               getAndLoadTestConfiguration(TEST_NAME);
-               String HOME = SCRIPT_DIR + TEST_DIR;
+       @Test
+       public void testLeftIndexScalarSP() {
+               runAggregateOperationTest(DataType.SCALAR, ExecMode.SPARK);
+       }
 
-               // write input matrices
-               int r1 = rows1;
-               int c1 = cols1 / 4;
-               if(rowPartitioned) {
-                       r1 = rows1 / 4;
-                       c1 = cols1;
+       private void runAggregateOperationTest(DataType dataType, ExecMode 
execMode) {
+               ExecMode oldPlatform = setExecMode(execMode);
+               
+               try {
+                       String TEST_NAME = null;
+       
+                       if(dataType == DataType.MATRIX)
+                               TEST_NAME = TEST_NAME1;
+                       else if(dataType == DataType.FRAME)
+                               TEST_NAME = TEST_NAME2;
+                       else
+                               TEST_NAME = TEST_NAME3;
+       
+                       getAndLoadTestConfiguration(TEST_NAME);
+                       String HOME = SCRIPT_DIR + TEST_DIR;
+       
+                       // write input matrices
+                       int r1 = rows1;
+                       int c1 = cols1 / 4;
+                       if(rowPartitioned) {
+                               r1 = rows1 / 4;
+                               c1 = cols1;
+                       }
+       
+                       double[][] X1 = getRandomMatrix(r1, c1, 1, 5, 1, 3);
+                       double[][] X2 = getRandomMatrix(r1, c1, 1, 5, 1, 7);
+                       double[][] X3 = getRandomMatrix(r1, c1,  1, 5, 1, 8);
+                       double[][] X4 = getRandomMatrix(r1, c1, 1, 5, 1, 9);
+       
+                       MatrixCharacteristics mc = new 
MatrixCharacteristics(r1, c1,  blocksize, r1 * c1);
+                       writeInputMatrixWithMTD("X1", X1, false, mc);
+                       writeInputMatrixWithMTD("X2", X2, false, mc);
+                       writeInputMatrixWithMTD("X3", X3, false, mc);
+                       writeInputMatrixWithMTD("X4", X4, false, mc);
+       
+                       if(dataType != DataType.SCALAR) {
+                               double[][] Y = getRandomMatrix(rows2, cols2, 1, 
5, 1, 3);
+       
+                               MatrixCharacteristics mc2 = new 
MatrixCharacteristics(rows2, cols2, blocksize, rows2 * cols2);
+                               writeInputMatrixWithMTD("Y", Y, false, mc2);
+                       }
+       
+                       // empty script name because we don't execute any 
script, just start the worker
+                       fullDMLScriptName = "";
+                       int port1 = getRandomAvailablePort();
+                       int port2 = getRandomAvailablePort();
+                       int port3 = getRandomAvailablePort();
+                       int port4 = getRandomAvailablePort();
+                       Thread t1 = startLocalFedWorkerThread(port1, 
FED_WORKER_WAIT_S);
+                       Thread t2 = startLocalFedWorkerThread(port2, 
FED_WORKER_WAIT_S);
+                       Thread t3 = startLocalFedWorkerThread(port3, 
FED_WORKER_WAIT_S);
+                       Thread t4 = startLocalFedWorkerThread(port4);
+       
+                       TestConfiguration config = 
availableTestConfigurations.get(TEST_NAME);
+                       loadTestConfiguration(config);
+       
+                       var lfrom = Math.min(from, to);
+                       var lfrom2 = Math.min(from2, to2);
+       
+                       // Run reference dml script with normal matrix
+                       fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
+                       programArgs = new String[] {"-args", input("X1"), 
input("X2"), input("X3"), input("X4"),
+                               input("Y"), String.valueOf(lfrom), 
String.valueOf(to),
+                               String.valueOf(lfrom2), String.valueOf(to2),
+                               Boolean.toString(rowPartitioned).toUpperCase(), 
expected("S")};
+                       runTest(null);
+                       // Run actual dml script with federated matrix
+       
+                       fullDMLScriptName = HOME + TEST_NAME + ".dml";
+                       programArgs = new String[] {"-stats", "100", "-nvargs",
+                               "in_X1=" + TestUtils.federatedAddress(port1, 
input("X1")),
+                               "in_X2=" + TestUtils.federatedAddress(port2, 
input("X2")),
+                               "in_X3=" + TestUtils.federatedAddress(port3, 
input("X3")),
+                               "in_X4=" + TestUtils.federatedAddress(port4, 
input("X4")),
+                               "in_Y=" + input("Y"), "rows=" + rows1, "cols=" 
+ cols1,
+                               "rows2=" + rows2, "cols2=" + cols2,
+                               "from=" + from, "to=" + to,"from2=" + from2, 
"to2=" + to2,
+                               "rP=" + 
Boolean.toString(rowPartitioned).toUpperCase(), "out_S=" + output("S")};
+       
+                       runTest(null);
+       
+                       // compare via files
+                       compareResults(1e-9, "Stat-DML1", "Stat-DML2");
+       
+                       Assert.assertTrue(rtplatform ==ExecMode.SPARK ?
+                               heavyHittersContainsString("fed_mapLeftIndex") 
: heavyHittersContainsString("fed_leftIndex"));
+       
+                       // check that federated input files are still existing
+                       
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
+                       
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2")));
+                       
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X3")));
+                       
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X4")));
+       
+                       TestUtils.shutdownThreads(t1, t2, t3, t4);
                }
-
-               double[][] X1 = getRandomMatrix(r1, c1, 1, 5, 1, 3);
-               double[][] X2 = getRandomMatrix(r1, c1, 1, 5, 1, 7);
-               double[][] X3 = getRandomMatrix(r1, c1,  1, 5, 1, 8);
-               double[][] X4 = getRandomMatrix(r1, c1, 1, 5, 1, 9);
-
-               MatrixCharacteristics mc = new MatrixCharacteristics(r1, c1,  
blocksize, r1 * c1);
-               writeInputMatrixWithMTD("X1", X1, false, mc);
-               writeInputMatrixWithMTD("X2", X2, false, mc);
-               writeInputMatrixWithMTD("X3", X3, false, mc);
-               writeInputMatrixWithMTD("X4", X4, false, mc);
-
-               double[][] Y = getRandomMatrix(rows2, cols2, 1, 5, 1, 3);
-
-               MatrixCharacteristics mc2 = new MatrixCharacteristics(rows2, 
cols2, blocksize, rows2 * cols2);
-               writeInputMatrixWithMTD("Y", Y, false, mc2);
-
-               // empty script name because we don't execute any script, just 
start the worker
-               fullDMLScriptName = "";
-               int port1 = getRandomAvailablePort();
-               int port2 = getRandomAvailablePort();
-               int port3 = getRandomAvailablePort();
-               int port4 = getRandomAvailablePort();
-               Thread t1 = startLocalFedWorkerThread(port1, FED_WORKER_WAIT_S);
-               Thread t2 = startLocalFedWorkerThread(port2, FED_WORKER_WAIT_S);
-               Thread t3 = startLocalFedWorkerThread(port3, FED_WORKER_WAIT_S);
-               Thread t4 = startLocalFedWorkerThread(port4);
-
-               rtplatform = execMode;
-               if(rtplatform == ExecMode.SPARK) {
-                       System.out.println(7);
-                       DMLScript.USE_LOCAL_SPARK_CONFIG = true;
+               finally {
+                       resetExecMode(oldPlatform);
                }
-               TestConfiguration config = 
availableTestConfigurations.get(TEST_NAME);
-               loadTestConfiguration(config);
-
-               if(from > to)
-                       from = to;
-               if(from2 > to2)
-                       from2 = to2;
-
-               // Run reference dml script with normal matrix
-               fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
-               programArgs = new String[] {"-explain", "-args", input("X1"), 
input("X2"), input("X3"), input("X4"),
-                       input("Y"), String.valueOf(from), String.valueOf(to),
-                       String.valueOf(from2), String.valueOf(to2),
-                       Boolean.toString(rowPartitioned).toUpperCase(), 
expected("S")};
-               runTest(null);
-               // Run actual dml script with federated matrix
-
-               fullDMLScriptName = HOME + TEST_NAME + ".dml";
-               programArgs = new String[] {"-stats", "100", "-nvargs",
-                       "in_X1=" + TestUtils.federatedAddress(port1, 
input("X1")),
-                       "in_X2=" + TestUtils.federatedAddress(port2, 
input("X2")),
-                       "in_X3=" + TestUtils.federatedAddress(port3, 
input("X3")),
-                       "in_X4=" + TestUtils.federatedAddress(port4, 
input("X4")),
-                       "in_Y=" + input("Y"), "rows=" + rows1, "cols=" + cols1,
-                       "rows2=" + rows2, "cols2=" + cols2,
-                       "from=" + from, "to=" + to,"from2=" + from2, "to2=" + 
to2,
-                       "rP=" + Boolean.toString(rowPartitioned).toUpperCase(), 
"out_S=" + output("S")};
-
-               runTest(null);
-
-               // compare via files
-               compareResults(1e-9, "Stat-DML1", "Stat-DML2");
-
-               Assert.assertTrue(rtplatform ==ExecMode.SPARK ? 
heavyHittersContainsString("fed_mapLeftIndex") : 
heavyHittersContainsString("fed_leftIndex"));
-
-               // check that federated input files are still existing
-               Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
-               Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2")));
-               Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X3")));
-               Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X4")));
-
-               TestUtils.shutdownThreads(t1, t2, t3, t4);
        }
 }
diff --git 
a/src/test/scripts/functions/federated/FederatedLeftIndexFrameFullTest.dml 
b/src/test/scripts/functions/federated/FederatedLeftIndexFrameFullTest.dml
index ca9fe81f40..a10bb72f77 100644
--- a/src/test/scripts/functions/federated/FederatedLeftIndexFrameFullTest.dml
+++ b/src/test/scripts/functions/federated/FederatedLeftIndexFrameFullTest.dml
@@ -41,5 +41,3 @@ A = as.frame(A)
 
 A[from:to, from2:to2] = B;
 write(A, $out_S);
-
-print(toString(A))
diff --git 
a/src/test/scripts/functions/federated/FederatedLeftIndexFrameFullTestReference.dml
 
b/src/test/scripts/functions/federated/FederatedLeftIndexFrameFullTestReference.dml
index 4b5a85234c..6589134273 100644
--- 
a/src/test/scripts/functions/federated/FederatedLeftIndexFrameFullTestReference.dml
+++ 
b/src/test/scripts/functions/federated/FederatedLeftIndexFrameFullTestReference.dml
@@ -37,5 +37,3 @@ A = as.frame(A)
 
 A[from:to, from2:to2] = B;
 write(A, $11);
-
-print(toString(A))
diff --git 
a/src/test/scripts/functions/federated/FederatedLeftIndexFullTest.dml 
b/src/test/scripts/functions/federated/FederatedLeftIndexFullTest.dml
index a201f7bfe3..c048cb77c2 100644
--- a/src/test/scripts/functions/federated/FederatedLeftIndexFullTest.dml
+++ b/src/test/scripts/functions/federated/FederatedLeftIndexFullTest.dml
@@ -38,5 +38,3 @@ B = read($in_Y)
 
 A[from:to, from2:to2] = B;
 write(A, $out_S);
-
-print(toString(A))
diff --git 
a/src/test/scripts/functions/federated/FederatedLeftIndexFullTestReference.dml 
b/src/test/scripts/functions/federated/FederatedLeftIndexFullTestReference.dml
index 2cc29f7ca8..ecd123254e 100644
--- 
a/src/test/scripts/functions/federated/FederatedLeftIndexFullTestReference.dml
+++ 
b/src/test/scripts/functions/federated/FederatedLeftIndexFullTestReference.dml
@@ -34,5 +34,3 @@ B = read($5)
 
 A[from:to, from2:to2] = B;
 write(A, $11);
-
-print(toString(A))
diff --git 
a/src/test/scripts/functions/federated/FederatedLeftIndexFrameFullTest.dml 
b/src/test/scripts/functions/federated/FederatedLeftIndexScalarTest.dml
similarity index 90%
copy from 
src/test/scripts/functions/federated/FederatedLeftIndexFrameFullTest.dml
copy to src/test/scripts/functions/federated/FederatedLeftIndexScalarTest.dml
index ca9fe81f40..71a9f93490 100644
--- a/src/test/scripts/functions/federated/FederatedLeftIndexFrameFullTest.dml
+++ b/src/test/scripts/functions/federated/FederatedLeftIndexScalarTest.dml
@@ -19,10 +19,10 @@
 #
 #-------------------------------------------------------------
 
-from = $from;
-to = $to;
-from2 = $from2;
-to2 = $to2;
+row1 = $from;
+row2 = $to;
+col1 = $from2;
+col2 = $to2;
 
 if ($rP) {
   A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
@@ -34,12 +34,11 @@ if ($rP) {
     list(0,$cols/2), list($rows, 3*($cols/4)), list(0, 3*($cols/4)), 
list($rows, $cols)));
 }
 
-B = read($in_Y)
+b = 13;
+c = as.scalar(rand(rows=1, cols=1, seed=456));
 
-B = as.frame(B)
-A = as.frame(A)
+A[row1, col1] = b;
+A[row2, col2] = c;
 
-A[from:to, from2:to2] = B;
 write(A, $out_S);
 
-print(toString(A))
diff --git 
a/src/test/scripts/functions/federated/FederatedLeftIndexFrameFullTestReference.dml
 
b/src/test/scripts/functions/federated/FederatedLeftIndexScalarTestReference.dml
similarity index 88%
copy from 
src/test/scripts/functions/federated/FederatedLeftIndexFrameFullTestReference.dml
copy to 
src/test/scripts/functions/federated/FederatedLeftIndexScalarTestReference.dml
index 4b5a85234c..14ea17fbda 100644
--- 
a/src/test/scripts/functions/federated/FederatedLeftIndexFrameFullTestReference.dml
+++ 
b/src/test/scripts/functions/federated/FederatedLeftIndexScalarTestReference.dml
@@ -19,10 +19,10 @@
 #
 #-------------------------------------------------------------
 
-from = $6;
-to = $7;
-from2 = $8;
-to2 = $9;
+row1 = $6;
+row2 = $7;
+col1 = $8;
+col2 = $9;
 if($10) {
   A = rbind(read($1), read($2), read($3), read($4));
 }
@@ -30,12 +30,11 @@ else {
   A = cbind(read($1), read($2), read($3), read($4));
 }
 
-B = read($5)
+b = 13;
+c = as.scalar(rand(rows=1, cols=1, seed=456));
 
-B = as.frame(B)
-A = as.frame(A)
+A[row1, col1] = b;
+A[row2, col2] = c;
 
-A[from:to, from2:to2] = B;
 write(A, $11);
 
-print(toString(A))

Reply via email to