This is an automated email from the ASF dual-hosted git repository.

sebwrede pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 0fa4463  [SYSTEMDS-3085] FederatedCTable - Keep Output Federated 
Closes #1371.
0fa4463 is described below

commit 0fa4463b42dce5a65bdeb3f9d7a1c422db174cda
Author: ywcb00 <[email protected]>
AuthorDate: Wed Aug 11 12:50:36 2021 +0200

    [SYSTEMDS-3085] FederatedCTable - Keep Output Federated
    Closes #1371.
---
 .../instructions/fed/CtableFEDInstruction.java     | 306 ++++++++++++++-------
 .../federated/primitives/FederatedCtableTest.java  |  11 +-
 .../federated/FederatedCtableFedOutput.dml         |  10 +-
 .../FederatedCtableFedOutputReference.dml          |  12 +-
 4 files changed, 234 insertions(+), 105 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/CtableFEDInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/CtableFEDInstruction.java
index a7d4cb6..2a1cbb1 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/CtableFEDInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/CtableFEDInstruction.java
@@ -22,7 +22,10 @@ package org.apache.sysds.runtime.instructions.fed;
 import java.util.Arrays;
 import java.util.Collections;
 import java.util.concurrent.Future;
+import java.util.Iterator;
+import java.util.SortedMap;
 import java.util.stream.IntStream;
+import java.util.TreeMap;
 
 import org.apache.commons.lang3.tuple.Pair;
 import org.apache.sysds.common.Types.DataType;
@@ -38,7 +41,6 @@ import 
org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
 import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
 import 
org.apache.sysds.runtime.controlprogram.federated.FederationMap.AlignType;
 import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
-import org.apache.sysds.runtime.functionobjects.And;
 import org.apache.sysds.runtime.instructions.Instruction;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
 import org.apache.sysds.runtime.instructions.cp.CPOperand;
@@ -105,8 +107,10 @@ public class CtableFEDInstruction extends 
ComputationFEDInstruction {
                }
 
                // get new output dims
-               Long[] dims1 = getOutputDimension(mo1, input1, _outDim1, 
mo1.getFedMapping().getFederatedRanges());
-               Long[] dims2 = getOutputDimension(mo2, input2, _outDim2, 
mo1.getFedMapping().getFederatedRanges());
+               Long[] dims1 = getOutputDimension(mo1, reversed ? input2 : 
input1, reversed ? _outDim2 : _outDim1,
+                       mo1.getFedMapping().getFederatedRanges());
+               Long[] dims2 = getOutputDimension(mo2, reversed ? input1 : 
input2, reversed ? _outDim1 : _outDim2,
+                       mo1.getFedMapping().getFederatedRanges());
 
                MatrixObject mo3 = input3 != null && input3.isMatrix() ? 
ec.getMatrixObject(input3) : null;
 
@@ -116,119 +120,157 @@ public class CtableFEDInstruction extends 
ComputationFEDInstruction {
                        mo1 = ec.getMatrixObject(input3);
                }
 
-               long dim1 = Collections.max(Arrays.asList(dims1), 
Long::compare);
-               boolean fedOutput = dim1 % mo1.getFedMapping().getSize() == 0 
&& dims1.length == Arrays.stream(dims1).distinct().count();
+               // static non-partitioned output dimension (same for all 
federated partitions)
+               long staticDim = Collections.max(Arrays.asList(dims1), 
Long::compare);
+               boolean fedOutput = isFedOutput(mo1.getFedMapping(), mo2);
 
-               processRequest(ec, mo1, mo2, mo3, reversed, reversedWeights, 
fedOutput, dims1, dims2);
+               processRequest(ec, mo1, mo2, mo3, reversed, reversedWeights, 
fedOutput, staticDim, dims2);
        }
 
+       /**
+        * Broadcast, execute, and finalize the federated instruction according 
to
+        * the specified inputs.
+        *
+        * @param ec execution context
+        * @param mo1 input matrix object 1
+        * @param mo2 input matrix object 2
+        * @param mo3 input matrix object 3 or null
+        * @param reversed boolean indicating if inputs mo1 and mo2 are reversed
+        * @param reversedWeights boolean indicating if inputs mo1 and mo3 are 
reversed
+        * @param fedOutput boolean indicating if output can be kept federated
+        * @param staticDim static non-partitioned dimension of the output
+        * @param dims2 dimensions of the partial outputs along the federated 
partitioning
+        */
        private void processRequest(ExecutionContext ec, MatrixObject mo1, 
MatrixObject mo2, MatrixObject mo3,
-               boolean reversed, boolean reversedWeights, boolean fedOutput, 
Long[] dims1, Long[] dims2) {
+               boolean reversed, boolean reversedWeights, boolean fedOutput, 
long staticDim, Long[] dims2) {
+
+               FederationMap fedMap = mo1.getFedMapping();
+
+               FederatedRequest[] fr1 = fedMap.broadcastSliced(mo2, false);
+               FederatedRequest[] fr2 = null;
+               FederatedRequest fr3, fr4, fr5;
                Future<FederatedResponse>[] ffr;
 
-               FederatedRequest[] fr1 = 
mo1.getFedMapping().broadcastSliced(mo2, false);
-               FederatedRequest fr2, fr3;
                if(mo3 != null && mo1.isFederated() && mo3.isFederated()
-               && mo1.getFedMapping().isAligned(mo3.getFedMapping(), 
AlignType.FULL)) { // mo1 and mo3 federated and aligned
+                       && fedMap.isAligned(mo3.getFedMapping(), 
AlignType.FULL)) { // mo1 and mo3 federated and aligned
                        if(!reversed)
-                               fr2 = 
FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, 
input2, input3},
-                                       new long[] 
{mo1.getFedMapping().getID(), fr1[0].getID(), mo3.getFedMapping().getID()});
+                               fr3 = 
FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, 
input2, input3},
+                                       new long[] {fedMap.getID(), 
fr1[0].getID(), mo3.getFedMapping().getID()});
                        else
-                               fr2 = 
FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, 
input2, input3},
-                                       new long[] {fr1[0].getID(), 
mo1.getFedMapping().getID(), mo3.getFedMapping().getID()});
-
-                       fr3 = new 
FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr2.getID());
-                       ffr = mo1.getFedMapping().execute(getTID(), true, fr1, 
fr2, fr3);
+                               fr3 = 
FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, 
input2, input3},
+                                       new long[] {fr1[0].getID(), 
fedMap.getID(), mo3.getFedMapping().getID()});
                }
                else if(mo3 == null) {
                        if(!reversed)
-                               fr2 = 
FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, 
input2},
-                                       new long[] 
{mo1.getFedMapping().getID(), fr1[0].getID()});
+                               fr3 = 
FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, 
input2},
+                                       new long[] {fedMap.getID(), 
fr1[0].getID()});
                        else
-                               fr2 = 
FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, 
input2},
-                                       new long[] {fr1[0].getID(), 
mo1.getFedMapping().getID()});
-
-                       fr3 = new 
FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr2.getID());
-                       ffr = mo1.getFedMapping().execute(getTID(), true, fr1, 
fr2, fr3);
-
-               } else {
-                       FederatedRequest[] fr4 = 
mo1.getFedMapping().broadcastSliced(mo3, false);
+                               fr3 = 
FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, 
input2},
+                                       new long[] {fr1[0].getID(), 
fedMap.getID()});
+               }
+               else {
+                       fr2 = fedMap.broadcastSliced(mo3, false);
                        if(!reversed && !reversedWeights)
-                               fr2 = 
FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, 
input2, input3},
-                                       new long[] 
{mo1.getFedMapping().getID(), fr1[0].getID(), fr4[0].getID()});
+                               fr3 = 
FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, 
input2, input3},
+                                       new long[] {fedMap.getID(), 
fr1[0].getID(), fr2[0].getID()});
                        else if(reversed && !reversedWeights)
-                               fr2 = 
FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, 
input2, input3},
-                                       new long[] {fr1[0].getID(), 
mo1.getFedMapping().getID(), fr4[0].getID()});
+                               fr3 = 
FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, 
input2, input3},
+                                       new long[] {fr1[0].getID(), 
fedMap.getID(), fr2[0].getID()});
                        else
-                               fr2 = 
FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, 
input2, input3},
-                                       new long[] {fr1[0].getID(), 
fr4[0].getID(), mo1.getFedMapping().getID()});
-
-                       fr3 = new 
FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr2.getID());
-                       ffr = mo1.getFedMapping().execute(getTID(), true, fr1, 
fr4, fr2, fr3);
+                               fr3 = 
FederationUtils.callInstruction(instString, output, new CPOperand[] {input1, 
input2, input3},
+                                       new long[] {fr1[0].getID(), 
fr2[0].getID(), fedMap.getID()});
                }
 
-               if(fedOutput && isFedOutput(ffr, dims1)) {
+               if(fedOutput) {
+                       if(fr2 != null) // broadcasted mo3
+                               fedMap.execute(getTID(), true, fr1, fr2, fr3);
+                       else
+                               fedMap.execute(getTID(), true, fr1, fr3);
+
                        MatrixObject out = ec.getMatrixObject(output);
-                       FederationMap newFedMap = 
modifyFedRanges(mo1.getFedMapping(), dims1, dims2);
-                       setFedOutput(mo1, out, newFedMap, dims1, fr2.getID());
+                       FederationMap newFedMap = 
modifyFedRanges(fedMap.copyWithNewID(fr3.getID()),
+                               staticDim, dims2, reversed);
+                       setFedOutput(mo1, out, newFedMap, staticDim, dims2, 
reversed);
                } else {
+                       fr4 = new 
FederatedRequest(FederatedRequest.RequestType.GET_VAR, fr3.getID());
+                       fr5 = fedMap.cleanup(getTID(), fr3.getID());
+                       if(fr2 != null) // broadcasted mo3
+                               ffr = fedMap.execute(getTID(), true, fr1, fr2, 
fr3, fr4, fr5);
+                       else
+                               ffr = fedMap.execute(getTID(), true, fr1, fr3, 
fr4, fr5);
+
                        ec.setMatrixOutput(output.getName(), aggResult(ffr));
                }
        }
 
-       boolean isFedOutput(Future<FederatedResponse>[] ffr,  Long[] dims1) {
-               boolean fedOutput = true;
-
-               long fedSize = Collections.max(Arrays.asList(dims1), 
Long::compare) / ffr.length;
-               try {
-                       MatrixBlock curr;
-                       MatrixBlock prev =(MatrixBlock) 
ffr[0].get().getData()[0];
-                       for(int i = 1; i < ffr.length && fedOutput; i++) {
-                               curr = (MatrixBlock) ffr[i].get().getData()[0];
-                               MatrixBlock sliced = curr.slice((int) 
(curr.getNumRows() - fedSize), curr.getNumRows() - 1);
-
-                               if(curr.getNumColumns() != prev.getNumColumns())
-                                       return false;
-
-                               // no intersection
-                               if(curr.getNumRows() == (i+1) * 
prev.getNumRows() && curr.getNonZeros() <= prev.getLength()
-                                       && (curr.getNumRows() - 
sliced.getNumRows()) == i * prev.getNumRows()
-                                       && curr.getNonZeros() - 
sliced.getNonZeros() == 0)
-                                       continue;
-
-                               // check intersect with AND and compare number 
of nnz
-                               MatrixBlock prevExtend = new 
MatrixBlock(curr.getNumRows(), curr.getNumColumns(), true, 0);
-                               prevExtend.copy(0, prev.getNumRows()-1, 0, 
prev.getNumColumns()-1, prev, true);
-
-                               MatrixBlock  intersect = 
curr.binaryOperationsInPlace(new BinaryOperator(And.getAndFnObject()), 
prevExtend);
-                               if(intersect.getNonZeros() != 0)
-                                       fedOutput = false;
-                               prev = sliced;
-                       }
-               }
-               catch(Exception e) {
-                       e.printStackTrace();
-               }
-               return fedOutput;
-       }
+       /**
+        * Evaluate if the output can be kept federated on the different 
federated
+        * sites or if the output needs to be aggregated on the coordinator, 
based
+        * on the output ranges of mo2.
+        * The output can be kept federated if the slices of mo2, sliced 
corresponding
+        * to the federated ranges of mo1, have strictly separable and ascending 
value
+        * ranges. From this property it follows that the partial outputs can 
also
+        * be separated, and hence the overall output can be created by a simple
+        * binding through a federated mapping.
+        *
+        * @param fedMap the federation map of the federated matrix input mo1
+        * @param mo2 input matrix object mo2
+        * @return boolean indicating if the output can be kept on the 
federated sites
+        */
+       private boolean isFedOutput(FederationMap fedMap, MatrixObject mo2) {
+               MatrixBlock mb = mo2.acquireReadAndRelease();
+               FederatedRange[] fedRanges = fedMap.getFederatedRanges(); // 
federated ranges of mo1
+               SortedMap<Double, Double> fedDims = new TreeMap<Double, 
Double>(); // <beginDim, endDim>
+
+               // collect min and max of the corresponding slices of mo2
+               IntStream.range(0, fedRanges.length).forEach(i -> {
+                       MatrixBlock sliced = mb.slice(
+                               fedRanges[i].getBeginDimsInt()[0], 
fedRanges[i].getEndDimsInt()[0] - 1,
+                               fedRanges[i].getBeginDimsInt()[1], 
fedRanges[i].getEndDimsInt()[1] - 1);
+                       fedDims.put(sliced.min(), sliced.max());
+               });
 
+               boolean retVal = (fedDims.size() == fedRanges.length); // no 
duplicate begin dimension entries
 
-       private static void setFedOutput(MatrixObject mo1, MatrixObject out, 
FederationMap fedMap, Long[] dims1, long outId) {
-               long fedSize = Collections.max(Arrays.asList(dims1), 
Long::compare) / dims1.length;
+               Iterator<SortedMap.Entry<Double, Double>> iter = 
fedDims.entrySet().iterator();
+               SortedMap.Entry<Double, Double> entry = iter.next(); // first 
entry does not have to be checked
+               double prevEndDim = entry.getValue();
+               while(iter.hasNext() && retVal) {
+                       entry = iter.next();
+                       // previous end dimension must be less than current 
begin dimension (no overlaps of ranges)
+                       retVal &= (prevEndDim < entry.getKey());
+                       prevEndDim = entry.getValue();
+               }
 
-               long d1 = Collections.max(Arrays.asList(dims1), Long::compare);
-               long d2 = Collections.max(Arrays.asList(dims1), Long::compare);
+               return retVal;
+       }
+
+       /**
+        * Set the output and its data characteristics on the federated sites.
+        *
+        * @param mo1 input matrix object mo1
+        * @param out output matrix object
+        * @param fedMap the federation map of the federated matrix input mo1
+        * @param staticDim static non-partitioned dimension of the output
+        * @param dims2 dimensions of the partial outputs along the federated 
partitioning
+        * @param reversed boolean indicating if inputs mo1 and mo2 are reversed
+        */
+       private static void setFedOutput(MatrixObject mo1, MatrixObject out, 
FederationMap fedMap,
+               long staticDim, Long[] dims2, boolean reversed) {
+               // get the final output dimensions
+               final long d1 = (reversed ? 
Collections.max(Arrays.asList(dims2)) : staticDim);
+               final long d2 = (reversed ? staticDim : 
Collections.max(Arrays.asList(dims2)));
 
                // set output
                out.getDataCharacteristics().set(d1, d2, (int) 
mo1.getBlocksize(), mo1.getNnz());
-               out.setFedMapping(fedMap.copyWithNewID(outId));
+               out.setFedMapping(fedMap);
 
                long varID = FederationUtils.getNextFedDataID();
-               out.getFedMapping().mapParallel(varID, (range, data) -> {
+               fedMap.mapParallel(varID, (range, data) -> {
                        try {
                                FederatedResponse response = 
data.executeFederatedOperation(new FederatedRequest(
                                        FederatedRequest.RequestType.EXEC_UDF, 
-1,
-                                       new SliceOutput(data.getVarID(), 
fedSize))).get();
+                                       new SliceOutput(data.getVarID(), 
staticDim, dims2, reversed))).get();
                                if(!response.isSuccessful())
                                        response.throwExceptionFromResponse();
                        }
@@ -239,6 +281,9 @@ public class CtableFEDInstruction extends 
ComputationFEDInstruction {
                });
        }
 
+       /**
+        * Aggregate the partial outputs locally.
+        */
        private static MatrixBlock aggResult(Future<FederatedResponse>[] ffr) {
                MatrixBlock resultBlock = new MatrixBlock(1, 1, true, 0);
                int dim1 = 0, dim2 = 0;
@@ -266,27 +311,44 @@ public class CtableFEDInstruction extends 
ComputationFEDInstruction {
                return resultBlock;
        }
 
-       private static FederationMap modifyFedRanges(FederationMap fedMap, 
Long[] dims1, Long[] dims2) {
-               IntStream.range(0, 
fedMap.getFederatedRanges().length).forEach(i -> {
-                       fedMap.getFederatedRanges()[i]
-                               .setBeginDim(0, i == 0 ? 0 : 
fedMap.getFederatedRanges()[i - 1].getEndDims()[0]);
-                       fedMap.getFederatedRanges()[i].setEndDim(0, dims1[i]);
-                       fedMap.getFederatedRanges()[i]
-                               .setBeginDim(1, i == 0 ? 0 : 
fedMap.getFederatedRanges()[i - 1].getBeginDims()[1]);
-                       fedMap.getFederatedRanges()[i].setEndDim(1, dims2[i]);
+       /**
+        * Set the ranges of the federation map according to the static 
dimension and
+        * the individual dimensions of the partial output matrices.
+        *
+        * @param fedMap the federation map of the federated matrix input mo1
+        * @param staticDim static non-partitioned dimension of the output
+        * @param dims2 dimensions of the partial outputs along the federated 
partitioning
+        * @param reversed boolean indicating if inputs mo1 and mo2 are reversed
+        * @return FederationMap the modified federation map
+        */
+       private static FederationMap modifyFedRanges(FederationMap fedMap, long 
staticDim,
+               Long[] dims2, boolean reversed) {
+               // set the federated ranges to the individual partition sizes
+               IntStream.range(0, 
fedMap.getFederatedRanges().length).forEach(counter -> {
+                       FederatedRange fedRange = 
fedMap.getFederatedRanges()[counter];
+                       fedRange.setBeginDim(reversed ? 1 : 0, 0);
+                       fedRange.setEndDim(reversed ? 1 : 0, staticDim);
+                       fedRange.setBeginDim(reversed ? 0 : 1, counter == 0 ? 0 
: dims2[counter-1]);
+                       fedRange.setEndDim(reversed ? 0 : 1, dims2[counter]);
                });
                return fedMap;
        }
 
-       private Long[] getOutputDimension(MatrixObject in, CPOperand inOp, 
CPOperand outOp, FederatedRange[] federatedRanges) {
+       /**
+        * Compute the output dimensions of the partial outputs according to the
+        * federated ranges.
+        */
+       private Long[] getOutputDimension(MatrixObject in, CPOperand inOp, 
CPOperand outOp,
+               FederatedRange[] federatedRanges) {
                Long[] fedDims = new Long[federatedRanges.length];
 
                if(!in.isFederated()) {
                        //slice
                        MatrixBlock mb = in.acquireReadAndRelease();
                        IntStream.range(0, federatedRanges.length).forEach(i -> 
{
-                               MatrixBlock sliced = mb
-                                       
.slice(federatedRanges[i].getBeginDimsInt()[0], 
federatedRanges[i].getEndDimsInt()[0] - 1);
+                               MatrixBlock sliced = mb.slice(
+                                       
federatedRanges[i].getBeginDimsInt()[0], federatedRanges[i].getEndDimsInt()[0] 
- 1,
+                                       
federatedRanges[i].getBeginDimsInt()[1], federatedRanges[i].getEndDimsInt()[1] 
- 1);
                                fedDims[i] = (long) sliced.max();
                        });
                        return fedDims;
@@ -326,29 +388,79 @@ public class CtableFEDInstruction extends 
ComputationFEDInstruction {
                return String.join(Lop.OPERAND_DELIMITOR, maxInstParts);
        }
 
+       /**
+        * Static class which extends FederatedUDF to modify the partial 
outputs on
+        * the federated sites such that they can be bound without any local
+        * aggregation.
+        */
        private static class SliceOutput extends FederatedUDF {
 
                private static final long serialVersionUID = 
-2808597461054603816L;
-               private final long _fedSize;
+               private final int _staticDim;
+               private final Long[] _fedDims;
+               private final boolean _reversed;
 
-               protected SliceOutput(long input, long fedSize) {
+               protected SliceOutput(long input, long staticDim, Long[] 
fedDims, boolean reversed) {
                        super(new long[] {input});
-                       _fedSize = fedSize;
+                       _staticDim = (int)staticDim;
+                       _fedDims = fedDims;
+                       _reversed = reversed;
                }
 
+               /**
+                * Find the dimensions of the partial output matrix and expand 
it to the
+                * global static dimension along the non-partitioned axis and 
crop it
+                * along the partitioned axis.
+                *
+                * @param ec the execution context
+                * @param data the input data objects; data[0] is the partial output matrix object
+                * @return FederatedResponse with status SUCCESS and an empty 
object
+                */
                public FederatedResponse execute(ExecutionContext ec, Data... 
data) {
                        MatrixObject mo = (MatrixObject) data[0];
                        MatrixBlock mb = mo.acquireReadAndRelease();
 
-                       MatrixBlock sliced = mb.slice((int) 
(mb.getNumRows()-_fedSize), mb.getNumRows()-1);
+                       int beginDim = 0;
+                       int endDim = (_reversed ? mb.getNumRows() : 
mb.getNumColumns());
+                       int localStaticDim = (_reversed ? mb.getNumColumns() : 
mb.getNumRows());
+                       for(int counter = 0; counter < _fedDims.length; 
counter++) {
+                               if(_fedDims[counter] == endDim) {
+                                       beginDim = (counter == 0 ? 0 : 
_fedDims[counter - 1].intValue());
+                                       break;
+                               }
+                       }
+
+                       mb = expandMatrix(mb, localStaticDim);
+
+                       // crop the output
+                       MatrixBlock sliced = _reversed ? mb.slice(beginDim, 
endDim - 1, 0, _staticDim - 1)
+                               : mb.slice(0, _staticDim - 1, beginDim, endDim 
- 1);
                        mo.acquireModify(sliced);
                        mo.release();
 
                        return new 
FederatedResponse(FederatedResponse.ResponseType.SUCCESS, new Object[] {});
                }
+
+               /**
+                * Expand the matrix with zeros up to the specified static 
dimension.
+                *
+                * @param mb the matrix block of the partial output
+                * @param localStaticDim the static dimension of the output 
matrix block
+                * @return MatrixBlock the output matrix block expanded to the 
global static dimension
+                */
+               private MatrixBlock expandMatrix(MatrixBlock mb, int 
localStaticDim) {
+                       int diff = _staticDim - localStaticDim;
+                       if(diff > 0) {
+                               MatrixBlock tmpMb = (_reversed ? new 
MatrixBlock(mb.getNumRows(), diff, (double) 0)
+                                       : new MatrixBlock(diff, 
mb.getNumColumns(), (double) 0));
+                               mb = mb.append(tmpMb, null, _reversed);
+                       }
+                       return mb;
+               }
+
                @Override
                public Pair<String, LineageItem> 
getLineageItem(ExecutionContext ec) {
                        return null;
                }
        }
-}
\ No newline at end of file
+}
diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedCtableTest.java
 
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedCtableTest.java
index a5793b5..9aeb776 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedCtableTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedCtableTest.java
@@ -85,6 +85,9 @@ public class FederatedCtableTest extends AutomatedTestBase {
        @Test
        public void federatedCtableMatrixInputSinglenode() { 
runCtable(Types.ExecMode.SINGLE_NODE, false, true); }
 
+       @Test
+       public void federatedCtableMatrixInputFedOutputSingleNode() { 
runCtable(Types.ExecMode.SINGLE_NODE, true, true); }
+
 
        public void runCtable(Types.ExecMode execMode, boolean fedOutput, 
boolean matrixInput) {
                String TEST_NAME = fedOutput ? TEST_NAME2 : TEST_NAME1;
@@ -108,7 +111,7 @@ public class FederatedCtableTest extends AutomatedTestBase {
                loadTestConfiguration(config);
 
                if(fedOutput)
-                       runFedCtable(HOME, TEST_NAME, port1, port2, port3, 
port4);
+                       runFedCtable(HOME, TEST_NAME, matrixInput, port1, 
port2, port3, port4);
                else
                        runNonFedCtable(HOME, TEST_NAME, matrixInput, port1, 
port2, port3, port4);
                checkResults();
@@ -155,7 +158,7 @@ public class FederatedCtableTest extends AutomatedTestBase {
                runTest(true, false, null, -1);
        }
 
-       private void runFedCtable(String HOME, String TEST_NAME, int port1, int 
port2, int port3, int port4) {
+       private void runFedCtable(String HOME, String TEST_NAME, boolean 
matrixInput, int port1, int port2, int port3, int port4) {
                int r = rows / 4;
                int c = cols;
 
@@ -174,7 +177,8 @@ public class FederatedCtableTest extends AutomatedTestBase {
                fullDMLScriptName = HOME + TEST_NAME2 + "Reference.dml";
                programArgs = new String[]{"-stats", "100", "-args",
                        input("X1"), input("X2"), input("X3"), input("X4"), 
Boolean.toString(reversedInputs).toUpperCase(),
-                       Boolean.toString(weighted).toUpperCase(), 
expected("F")};
+                       Boolean.toString(weighted).toUpperCase(), 
Boolean.toString(matrixInput).toUpperCase(),
+                       expected("F")};
                runTest(true, false, null, -1);
 
                // Run actual dml script with federated matrix
@@ -185,6 +189,7 @@ public class FederatedCtableTest extends AutomatedTestBase {
                        "in_X3=" + TestUtils.federatedAddress(port3, 
input("X3")),
                        "in_X4=" + TestUtils.federatedAddress(port4, 
input("X4")),
                        "rows=" + rows, "cols=" + cols, "revIn=" + 
Boolean.toString(reversedInputs).toUpperCase(),
+                       "matrixInput=" + 
Boolean.toString(matrixInput).toUpperCase(),
                        "weighted=" + Boolean.toString(weighted).toUpperCase(), 
"out=" + output("F")
                };
                runTest(true, false, null, -1);
diff --git a/src/test/scripts/functions/federated/FederatedCtableFedOutput.dml 
b/src/test/scripts/functions/federated/FederatedCtableFedOutput.dml
index 9c21ed5..a2eda9d 100644
--- a/src/test/scripts/functions/federated/FederatedCtableFedOutput.dml
+++ b/src/test/scripts/functions/federated/FederatedCtableFedOutput.dml
@@ -28,8 +28,14 @@ n = ncol(X);
 
 # prepare offset vectors and one-hot encoded X
 maxs = colMaxs(X);
-rix = matrix(seq(1,m)%*%matrix(1,1,n), m*n, 1);
-cix = matrix(X + (t(cumsum(t(maxs))) - maxs), m*n, 1);
+if($matrixInput) {
+  rix = matrix(seq(1,m)%*%matrix(1,1,n), m, n);
+  cix = matrix(X + (t(cumsum(t(maxs))) - maxs), m, n);
+}
+else {
+  rix = matrix(seq(1,m)%*%matrix(1,1,n), m*n, 1);
+  cix = matrix(X + (t(cumsum(t(maxs))) - maxs), m*n, 1);
+}
 
 W = rix + cix;
 
diff --git 
a/src/test/scripts/functions/federated/FederatedCtableFedOutputReference.dml 
b/src/test/scripts/functions/federated/FederatedCtableFedOutputReference.dml
index e0721df..4fc6852 100644
--- a/src/test/scripts/functions/federated/FederatedCtableFedOutputReference.dml
+++ b/src/test/scripts/functions/federated/FederatedCtableFedOutputReference.dml
@@ -27,8 +27,14 @@ n = ncol(X);
 # prepare offset vectors and one-hot encoded X
 maxs = colMaxs(X);
 
-rix = matrix(seq(1,m)%*%matrix(1,1,n), m*n, 1)
-cix = matrix(X + (t(cumsum(t(maxs))) - maxs), m*n, 1);
+if($7) { # matrix input
+  rix = matrix(seq(1,m)%*%matrix(1,1,n), m, n);
+  cix = matrix(X + (t(cumsum(t(maxs))) - maxs), m, n);
+}
+else {
+  rix = matrix(seq(1,m)%*%matrix(1,1,n), m*n, 1);
+  cix = matrix(X + (t(cumsum(t(maxs))) - maxs), m*n, 1);
+}
 
 W = rix + cix;
 
@@ -43,4 +49,4 @@ else
   else
     X2 = table(rix, cix);
 
-write(X2, $7);
+write(X2, $8);

Reply via email to