This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 1357ae1  [SYSTEMDS-3241] Federated quantiles via recursive histograms
1357ae1 is described below

commit 1357ae10fe8ffa48cf1398f740695d14258ec1d7
Author: OlgaOvcharenko <[email protected]>
AuthorDate: Thu Mar 3 21:46:59 2022 +0100

    [SYSTEMDS-3241] Federated quantiles via recursive histograms
    
    Closes #1477.
---
 .../instructions/fed/FEDInstructionUtils.java      |   5 +-
 .../fed/QuantilePickFEDInstruction.java            | 441 ++++++++++++++++++++-
 .../fed/QuantileSortFEDInstruction.java            |  33 +-
 .../sysds/runtime/matrix/data/MatrixBlock.java     |  20 +
 .../primitives/FederatedQuantileTest.java          | 109 +++--
 .../primitives/FederatedQuantileWeightsTest.java   |  99 ++++-
 .../aggregate/FederatedMeanTestReference.dml       |   3 +-
 .../federated/quantile/FederatedIQRTest.dml        |   9 +-
 .../quantile/FederatedIQRTestReference.dml         |   5 +-
 .../federated/quantile/FederatedMedianTest.dml     |   8 +-
 .../quantile/FederatedMedianTestReference.dml      |   6 +-
 .../quantile/FederatedMedianWeightsTest.dml        |   8 +-
 .../FederatedMedianWeightsTestReference.dml        |   5 +-
 .../federated/quantile/FederatedQuantileTest.dml   |   9 +-
 .../quantile/FederatedQuantileTestReference.dml    |   7 +-
 .../quantile/FederatedQuantileWeightsTest.dml      |   8 +-
 .../FederatedQuantileWeightsTestReference.dml      |   8 +-
 .../federated/quantile/FederatedQuantilesTest.dml  |  10 +-
 .../quantile/FederatedQuantilesTestReference.dml   |   8 +-
 19 files changed, 723 insertions(+), 78 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
index 1192ff7..e1c4587 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
@@ -38,7 +38,6 @@ import 
org.apache.sysds.runtime.instructions.cp.AggregateTernaryCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.AggregateUnaryCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.BinaryCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.CentralMomentCPInstruction;
-import org.apache.sysds.runtime.instructions.cp.CovarianceCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.CtableCPInstruction;
 import org.apache.sysds.runtime.instructions.cp.Data;
 import org.apache.sysds.runtime.instructions.cp.IndexingCPInstruction;
@@ -159,7 +158,7 @@ public class FEDInstructionUtils {
                                                
if(instruction.getOpcode().equalsIgnoreCase("cm"))
                                                        fedinst = 
CentralMomentFEDInstruction.parseInstruction(inst.getInstructionString());
                                                else 
if(inst.getOpcode().equalsIgnoreCase("qsort")) {
-                                                       
if(mo1.getFedMapping().getFederatedRanges().length == 1)
+                                                       
if(mo1.isFederated(FType.ROW) || 
mo1.getFedMapping().getFederatedRanges().length == 1 && 
mo1.isFederated(FType.COL))
                                                                fedinst = 
QuantileSortFEDInstruction.parseInstruction(inst.getInstructionString());
                                                }
                                                else 
if(inst.getOpcode().equalsIgnoreCase("rshape"))
@@ -186,7 +185,7 @@ public class FEDInstructionUtils {
                                                fedinst = 
QuantilePickFEDInstruction.parseInstruction(inst.getInstructionString());
                                        else 
if("cov".equals(instruction.getOpcode()) && 
(ec.getMatrixObject(instruction.input1).isFederated(FType.ROW) ||
                                                
ec.getMatrixObject(instruction.input2).isFederated(FType.ROW)))
-                                               fedinst = 
CovarianceFEDInstruction.parseInstruction((CovarianceCPInstruction)inst);
+                                               fedinst = 
CovarianceFEDInstruction.parseInstruction(inst.getInstructionString());
                                        else
                                                fedinst = 
BinaryFEDInstruction.parseInstruction(
                                                        
InstructionUtils.concatOperands(inst.getInstructionString(),FederatedOutput.NONE.name()));
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantilePickFEDInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantilePickFEDInstruction.java
index f984967..83a1360 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantilePickFEDInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantilePickFEDInstruction.java
@@ -20,8 +20,16 @@
 package org.apache.sysds.runtime.instructions.fed;
 
 import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.LinkedHashMap;
 import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.stream.Collectors;
 
+import org.apache.commons.lang3.tuple.ImmutablePair;
+import org.apache.commons.lang3.tuple.ImmutableTriple;
 import org.apache.commons.lang3.tuple.Pair;
 import org.apache.sysds.lops.PickByCount.OperationTypes;
 import org.apache.sysds.runtime.DMLRuntimeException;
@@ -41,6 +49,7 @@ import org.apache.sysds.runtime.lineage.LineageItem;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.matrix.operators.Operator;
 
+@SuppressWarnings("unchecked")
 public class QuantilePickFEDInstruction extends BinaryFEDInstruction {
 
        private final OperationTypes _type;
@@ -61,6 +70,10 @@ public class QuantilePickFEDInstruction extends 
BinaryFEDInstruction {
                this(op, in, in2, out, type, inmem, opcode, istr, 
FederatedOutput.NONE);
        }
 
+       /**
+        * Returns the qpick operation type (e.g. MEDIAN or IQM) executed by this instruction.
+        *
+        * @return the configured operation type
+        */
+       public OperationTypes getQPickType() {
+               return this._type;
+       }
+
        public static QuantilePickFEDInstruction parseInstruction ( String str 
) {
                String[] parts = 
InstructionUtils.getInstructionPartsWithValueType(str);
                String opcode = parts[0];
@@ -69,7 +82,6 @@ public class QuantilePickFEDInstruction extends 
BinaryFEDInstruction {
                //instruction parsing
                if( parts.length == 4 ) {
                        //instructions of length 4 originate from unary - mr-iqm
-                       //TODO this should be refactored to use pickvaluecount 
lops
                        CPOperand in1 = new CPOperand(parts[1]);
                        CPOperand in2 = new CPOperand(parts[2]);
                        CPOperand out = new CPOperand(parts[3]);
@@ -97,6 +109,380 @@ public class QuantilePickFEDInstruction extends 
BinaryFEDInstruction {
 
        @Override
        public void processInstruction(ExecutionContext ec) {
+               // Inputs federated as COL or FULL are handled by the column path,
+               // every other federation type by the row-wise histogram search.
+               MatrixObject in = ec.getMatrixObject(input1);
+               boolean colOrFull = in.isFederated(FederationMap.FType.COL)
+                       || in.isFederated(FederationMap.FType.FULL);
+               if(!colOrFull) {
+                       processRowQPick(ec);
+                       return;
+               }
+               processColumnQPick(ec);
+       }
+
+       /**
+        * Computes qpick results (median, IQM, single or multiple quantiles) for a row-federated input.
+        * First the per-worker min/max (and, for two-column inputs, column-1 sums and quantile weights)
+        * are collected, then equi-width histograms are refined recursively until the requested
+        * quantile(s) are isolated.
+        *
+        * @param ec  execution context holding the inputs and receiving the output
+        * @param <T> result type of the histogram search (bucket map, bucket range or value)
+        */
+       public <T> void processRowQPick(ExecutionContext ec) {
+               MatrixObject in = ec.getMatrixObject(input1);
+               FederationMap fedMap = in.getFedMapping();
+               boolean average = _type == OperationTypes.MEDIAN;
+               boolean isWeighted = in.getNumColumns() == 2;
+
+               double[] quantiles = input2 != null ? (input2.isMatrix() ? ec.getMatrixInput(input2).getDenseBlockValues() :
+                       input2.isScalar() ? new double[] {ec.getScalarInput(input2).getDoubleValue()} : null) :
+                       (average ? new double[] {0.5} : _type == OperationTypes.IQM ? new double[] {0.25, 0.75} : null);
+
+               if(input2 != null && input2.isMatrix())
+                       ec.releaseMatrixInput(input2.getName());
+
+               // Without quantiles the histogram search returns the raw bucket map, which cannot be
+               // turned into a result; fail with a clear message instead of a later ClassCastException.
+               if(quantiles == null || quantiles.length == 0)
+                       throw new DMLRuntimeException("Unsupported qpick operation type: " + _type);
+
+               // Find min and max (and column sums / weights for two-column inputs) on every worker
+               long varID = FederationUtils.getNextFedDataID();
+               List<double[]> minMax = new ArrayList<>();
+               fedMap.mapParallel(varID, (range, data) -> {
+                       try {
+                               FederatedResponse response = data.executeFederatedOperation(new FederatedRequest(
+                                       FederatedRequest.RequestType.EXEC_UDF, -1,
+                                       new QuantilePickFEDInstruction.MinMax(data.getVarID()))).get();
+                               if(!response.isSuccessful())
+                                       response.throwExceptionFromResponse();
+                               double[] rangeMinMax = (double[]) response.getData()[0];
+                               // callbacks run concurrently and ArrayList is not thread-safe
+                               synchronized(minMax) {
+                                       minMax.add(rangeMinMax);
+                               }
+                               return null;
+                       }
+                       catch(Exception e) {
+                               throw new DMLRuntimeException(e);
+                       }
+               });
+
+               // Aggregate weights sum, min and max. Double.MIN_VALUE is the smallest *positive*
+               // double and therefore not a valid starting value for a maximum (negative data).
+               double globalMin = Double.MAX_VALUE, globalMax = -Double.MAX_VALUE;
+               double vectorLength = isWeighted ? 0 : in.getNumRows(), sumWeights = 0.0;
+               for(double[] values : minMax) {
+                       globalMin = Math.min(globalMin, values[0]);
+                       globalMax = Math.max(globalMax, values[1]);
+                       if(isWeighted)
+                               vectorLength += values[2];
+                       sumWeights += values[3];
+               }
+
+               // Median averages the two middle values only for an even number of elements
+               average = average && (isWeighted ? sumWeights : in.getNumRows()) % 2 == 0;
+
+               // For multiple quantiles build the first histogram once and reuse its bins,
+               // otherwise recursively refine the bin that contains the single quantile
+               int numBuckets = 256;
+               int quantileIndex = quantiles.length == 1 ? (int) Math.round(vectorLength * quantiles[0]) : -1;
+
+               T ret = createHistogram(in, (int) vectorLength, globalMin, globalMax, numBuckets, quantileIndex, average);
+
+               // Compute and set results
+               if(quantiles.length > 1)
+                       computeMultipleQuantiles(ec, in, (Map<ImmutablePair<Double, Double>, Integer>) ret, quantiles,
+                               (int) vectorLength, varID, _type);
+               else
+                       getSingleQuantileResult(ret, ec, fedMap, varID, average, false, (int) vectorLength);
+       }
+
+       /**
+        * Computes several quantiles from one first-level histogram. For each quantile, the bin
+        * containing it is located and refined recursively. The results are written either into a
+        * (numQuantiles x 1) matrix output or, for IQM, into a scalar output.
+        *
+        * @param ec           execution context receiving the output
+        * @param in           federated input matrix
+        * @param buckets      merged first-level histogram (bin range -> count), iterated in bin order
+        * @param quantiles    requested quantiles
+        * @param vectorLength number of elements (sum of the second column for two-column inputs)
+        * @param varID        federated data id used for follow-up requests
+        * @param type         qpick operation type
+        * @param <T>          result type of the recursive histogram search
+        */
+       private <T> void computeMultipleQuantiles(ExecutionContext ec, MatrixObject in, Map<ImmutablePair<Double, Double>, Integer> buckets,
+               double[] quantiles, int vectorLength, long varID, OperationTypes type) {
+               MatrixBlock out = new MatrixBlock(quantiles.length, 1, false);
+               ImmutableTriple<Integer, Integer, ImmutablePair<Double, Double>>[] bucketsWithIndex = new ImmutableTriple[quantiles.length];
+
+               // 1-based element index per quantile, clamped to 1 so that a rounded index of 0
+               // (e.g. quantile 0) maps to the first element instead of never being found
+               int[] quantileIndexes = new int[quantiles.length];
+               for(int i = 0; i < quantiles.length; i++)
+                       quantileIndexes[i] = Math.max(1, (int) Math.round(vectorLength * quantiles[i]));
+
+               // Find the bin containing each quantile in the first histogram
+               int sizeBefore = 0, countFoundBins = 0;
+               for(Map.Entry<ImmutablePair<Double, Double>, Integer> entry : buckets.entrySet()) {
+                       int sizeAfter = sizeBefore + entry.getValue();
+                       for(int i = 0; i < quantiles.length; i++) {
+                               int quantileIndex = quantileIndexes[i];
+                               if(quantileIndex > sizeBefore && quantileIndex <= sizeAfter) {
+                                       // store the index relative to the start of the bin
+                                       bucketsWithIndex[i] = new ImmutableTriple<>(quantileIndex == 1 ? 1 : quantileIndex - sizeBefore,
+                                               entry.getValue(), entry.getKey());
+                                       countFoundBins++;
+                               }
+                       }
+                       sizeBefore = sizeAfter;
+                       if(countFoundBins == quantiles.length)
+                               break;
+               }
+
+               for(int i = 0; i < bucketsWithIndex.length; i++)
+                       if(bucketsWithIndex[i] == null)
+                               throw new DMLRuntimeException("Failed to locate histogram bin for quantile " + quantiles[i]);
+
+               // Refine each quantile bin recursively
+               HashMap<Integer, ImmutablePair<Double, Double>> retBuckets = new HashMap<>();
+               double left = 0, right = 0;
+               for(int i = 0; i < bucketsWithIndex.length; i++) {
+                       ImmutableTriple<Integer, Integer, ImmutablePair<Double, Double>> bucket = bucketsWithIndex[i];
+                       int nextNumBuckets = bucket.middle < 100 ? bucket.middle * 2 : (int) Math.round(bucket.middle / 2.0);
+                       T hist = createHistogram(in, vectorLength, bucket.right.left, bucket.right.right, nextNumBuckets, bucket.left, false);
+
+                       if(type == OperationTypes.IQM) {
+                               // IQM range: lower bound from the first quantile, upper bound from the second
+                               if(i == 0)
+                                       left = hist instanceof ImmutablePair ? ((ImmutablePair<Double, Double>) hist).right : (Double) hist;
+                               else if(i == 1)
+                                       right = hist instanceof ImmutablePair ? ((ImmutablePair<Double, Double>) hist).left : (Double) hist;
+                       }
+                       else if(hist instanceof ImmutablePair)
+                               retBuckets.put(i, (ImmutablePair<Double, Double>) hist); // bin, resolved to a value below
+                       else
+                               out.setValue(i, 0, (Double) hist); // search already returned the value
+               }
+
+               if(type == OperationTypes.IQM) {
+                       ImmutablePair<Double, Double> iqmRange = new ImmutablePair<>(left, right);
+                       getSingleQuantileResult(iqmRange, ec, in.getFedMapping(), varID, false, true, vectorLength);
+                       return;
+               }
+
+               if(!retBuckets.isEmpty()) {
+                       // Search for values within the bin ranges where a bin was returned.
+                       // NOTE(review): partial results are summed across workers, which assumes at most one
+                       // worker reports a value per range - confirm.
+                       in.getFedMapping().mapParallel(varID, (range, data) -> {
+                               try {
+                                       FederatedResponse response = data.executeFederatedOperation(new FederatedRequest(
+                                               FederatedRequest.RequestType.EXEC_UDF, -1,
+                                               new QuantilePickFEDInstruction.GetValuesInRanges(data.getVarID(), quantiles.length, retBuckets))).get();
+                                       if(!response.isSuccessful())
+                                               response.throwExceptionFromResponse();
+
+                                       // Add results by row
+                                       MatrixBlock tmp = (MatrixBlock) response.getData()[0];
+                                       synchronized(out) {
+                                               out.binaryOperationsInPlace(InstructionUtils.parseBinaryOperator("+"), tmp);
+                                       }
+                                       return null;
+                               }
+                               catch(Exception e) {
+                                       throw new DMLRuntimeException(e);
+                               }
+                       });
+               }
+
+               ec.setMatrixOutput(output.getName(), out);
+       }
+
+       /**
+        * Resolves a single quantile result and sets it as scalar output. If {@code ret} is a bin range
+        * ({@link ImmutablePair}), the workers are asked for the values inside that range and the partial
+        * results are summed. Otherwise {@code ret} already is the result value.
+        *
+        * @param ret          bin range or result value
+        * @param ec           execution context receiving the scalar output
+        * @param fedMap       federation map of the input
+        * @param varID        federated data id used for the requests
+        * @param average      whether the result is the average of two values (even-sized median)
+        * @param isIQM        whether to compute the interquartile mean over the range
+        * @param vectorLength number of elements, used for the IQM divisor
+        * @param <T>          type of {@code ret}
+        */
+       private <T> void getSingleQuantileResult(T ret, ExecutionContext ec, FederationMap fedMap, long varID, boolean average,
+               boolean isIQM, int vectorLength) {
+               double result;
+               if(ret instanceof ImmutablePair) {
+                       // Search for values within bucket range
+                       List<Double> values = new ArrayList<>();
+                       fedMap.mapParallel(varID, (range, data) -> {
+                               try {
+                                       FederatedResponse response = data.executeFederatedOperation(new FederatedRequest(
+                                               FederatedRequest.RequestType.EXEC_UDF, -1,
+                                               new QuantilePickFEDInstruction.GetValuesInRange(data.getVarID(), (ImmutablePair<Double, Double>) ret, isIQM))).get();
+                                       if(!response.isSuccessful())
+                                               response.throwExceptionFromResponse();
+                                       double value = (double) response.getData()[0];
+                                       // callbacks run concurrently and ArrayList is not thread-safe
+                                       synchronized(values) {
+                                               values.add(value);
+                                       }
+                                       return null;
+                               }
+                               catch(Exception e) {
+                                       throw new DMLRuntimeException(e);
+                               }
+                       });
+
+                       // Sum of the partial results of all workers
+                       result = values.stream().reduce(0.0, Double::sum);
+               }
+               else
+                       result = (Double) ret;
+
+               result /= (average ? 2 : isIQM ? ((int) Math.round(vectorLength * 0.75) - (int) Math.round(vectorLength * 0.25)) : 1);
+
+               ec.setScalarOutput(output.getName(), new DoubleObject(result));
+       }
+
+       /**
+        * Builds an equi-width histogram with {@code numBuckets} bins over [globalMin, globalMax] across all
+        * federated workers. If a quantile index is given, it then recursively refines the bin that contains it.
+        *
+        * @param in            federated input matrix
+        * @param vectorLength  number of elements (used for the even/odd check)
+        * @param globalMin     lower bound of the histogram range
+        * @param globalMax     upper bound of the histogram range
+        * @param numBuckets    number of bins
+        * @param quantileIndex 1-based element index of the quantile, or -1 to return the merged bins
+        * @param average       whether the two middle elements are needed (even-sized median)
+        * @return the merged bucket map if {@code quantileIndex == -1}; otherwise the result value
+        *         ({@link Double}) or the final bin ({@link ImmutablePair}) containing the quantile
+        */
+       public <T> T createHistogram(MatrixObject in, int vectorLength, double globalMin, double globalMax, int numBuckets,
+               int quantileIndex, boolean average) {
+               FederationMap fedMap = in.getFedMapping();
+
+               Map<ImmutablePair<Double, Double>, Integer> buckets = new LinkedHashMap<>();
+               List<Map<ImmutablePair<Double, Double>, Integer>> hists = new ArrayList<>();
+               List<Set<Double>> distincts = new ArrayList<>();
+
+               double bucketRange = (globalMax - globalMin) / numBuckets;
+               boolean isEvenNumRows = vectorLength % 2 == 0;
+
+               // Create buckets according to min and max. Edges are derived from the bucket index (no
+               // accumulated rounding error) and the last bucket ends exactly at globalMax, because
+               // the workers assign values equal to the maximum to the bucket whose right edge equals it.
+               for(int i = 0; i < numBuckets; i++) {
+                       double lower = globalMin + i * bucketRange;
+                       double upper = (i == numBuckets - 1) ? globalMax : globalMin + (i + 1) * bucketRange;
+                       buckets.put(new ImmutablePair<>(lower, upper), 0);
+               }
+
+               // Create histograms
+               long varID = FederationUtils.getNextFedDataID();
+               fedMap.mapParallel(varID, (range, data) -> {
+                       try {
+                               FederatedResponse response = data.executeFederatedOperation(new FederatedRequest(
+                                       FederatedRequest.RequestType.EXEC_UDF, -1,
+                                       new QuantilePickFEDInstruction.GetHistogram(data.getVarID(), buckets, globalMax))).get();
+                               if(!response.isSuccessful())
+                                       response.throwExceptionFromResponse();
+                               Map<ImmutablePair<Double, Double>, Integer> rangeHist = (Map<ImmutablePair<Double, Double>, Integer>) response.getData()[0];
+                               Set<Double> rangeDistinct = (Set<Double>) response.getData()[1];
+                               // callbacks run concurrently; add both under one lock so indexes stay aligned
+                               synchronized(hists) {
+                                       hists.add(rangeHist);
+                                       distincts.add(rangeDistinct);
+                               }
+                               return null;
+                       }
+                       catch(Exception e) {
+                               throw new DMLRuntimeException(e);
+                       }
+               });
+
+               // Merge results into one histogram
+               for(Map.Entry<ImmutablePair<Double, Double>, Integer> bucket : buckets.entrySet()) {
+                       int value = 0;
+                       for(Map<ImmutablePair<Double, Double>, Integer> hist : hists)
+                               value += hist.getOrDefault(bucket.getKey(), 0);
+                       bucket.setValue(value);
+               }
+
+               if(quantileIndex == -1)
+                       return (T) buckets;
+
+               // A worker returns an empty distinct set if it has no values in range OR if it has three or
+               // more distinct values (see GetHistogram). In the latter case the union below is incomplete
+               // and must not be used to stop the search early.
+               boolean distinctComplete = true;
+               for(int w = 0; w < hists.size(); w++) {
+                       int count = hists.get(w).values().stream().mapToInt(Integer::intValue).sum();
+                       if(count > 0 && distincts.get(w).isEmpty())
+                               distinctComplete = false;
+               }
+
+               // Find bucket with quantile
+               ImmutableTriple<Integer, Integer, ImmutablePair<Double, Double>> bucketWithIndex = getBucketWithIndex(buckets, quantileIndex, average, isEvenNumRows);
+
+               // Check if can terminate
+               Set<Double> distinctValues = distincts.stream().flatMap(Set::stream).collect(Collectors.toSet());
+               if(distinctComplete && ((distinctValues.size() == 1 && !average) || (distinctValues.size() == 2 && average)))
+                       return (T) distinctValues.stream().reduce(0.0, (a, b) -> a + b);
+
+               ImmutablePair<Double, Double> finalBucketWithQ = bucketWithIndex.right;
+               if(finalBucketWithQ == null)
+                       throw new DMLRuntimeException("Failed to locate histogram bin for quantile index " + quantileIndex);
+
+               List<Double> distinctInNewBucket = distinctValues.stream()
+                       .filter(e -> e >= finalBucketWithQ.left && e <= finalBucketWithQ.right).collect(Collectors.toList());
+               if(distinctComplete && ((distinctInNewBucket.size() == 1 && !average) || (average && distinctInNewBucket.size() == 2)))
+                       return (T) distinctInNewBucket.stream().reduce(0.0, (a, b) -> a + b);
+
+               if((distinctComplete && distinctValues.size() == 1) || (bucketWithIndex.middle == 1 && !average)
+                       || (bucketWithIndex.middle == 2 && isEvenNumRows && average) || globalMin == globalMax)
+                       return (T) bucketWithIndex.right;
+
+               int nextNumBuckets = bucketWithIndex.middle < 100 ? bucketWithIndex.middle * 2 : (int) Math.round(bucketWithIndex.middle / 2.0);
+
+               // Add more bins to avoid getting stuck on the same range
+               if(numBuckets == nextNumBuckets && globalMin == bucketWithIndex.right.left && globalMax == bucketWithIndex.right.right)
+                       nextNumBuckets *= 2;
+
+               return createHistogram(in, vectorLength, bucketWithIndex.right.left, bucketWithIndex.right.right, nextNumBuckets,
+                       bucketWithIndex.left, average);
+       }
+
+       /**
+        * Locates the bucket containing the element with 1-based index {@code quantileIndex}, based on the
+        * cumulative bucket counts in map iteration order. If {@code average && isEvenNumRows} and the element
+        * at {@code quantileIndex + 1} is not in that bucket, the range is widened up to the next non-empty
+        * bucket.
+        *
+        * @param buckets       merged histogram (bin range -> count)
+        * @param quantileIndex 1-based element index of the quantile
+        * @param average       whether the two middle elements are needed
+        * @param isEvenNumRows whether the number of elements is even
+        * @return (index relative to the start of the found bucket, count of the found bucket(s), bin range
+        *         of the found bucket(s)); the range is {@code null} if no bucket was found
+        */
+       private ImmutableTriple<Integer, Integer, ImmutablePair<Double, Double>> getBucketWithIndex(Map<ImmutablePair<Double, Double>, Integer> buckets, int quantileIndex, boolean average, boolean isEvenNumRows) {
+               int sizeBeforeTmp = 0, sizeBefore = 0, bucketWithQSize = 0;
+               ImmutablePair<Double, Double> bucketWithQ = null;
+               for(Map.Entry<ImmutablePair<Double, Double>, Integer> range : buckets.entrySet()) {
+                       // cumulative count up to and including this bucket
+                       sizeBeforeTmp += range.getValue();
+                       // first bucket whose cumulative count reaches the quantile index
+                       if(quantileIndex <= sizeBeforeTmp && bucketWithQSize == 0) {
+                               bucketWithQ = range.getKey();
+                               bucketWithQSize = range.getValue();
+                               // number of elements before the found bucket
+                               sizeBeforeTmp -= bucketWithQSize;
+                               sizeBefore = sizeBeforeTmp;
+
+                               // done unless the second middle element (quantileIndex + 1) lies beyond this bucket
+                               if(!average || sizeBefore + bucketWithQSize >= quantileIndex + 1)
+                                       break;
+                       } else if(quantileIndex + 1 <= sizeBeforeTmp + bucketWithQSize && isEvenNumRows && average) {
+                               // Add right bin that contains second index
+                               int bucket2Size = range.getValue();
+                               if (bucket2Size != 0) {
+                                       bucketWithQ = new ImmutablePair<>(bucketWithQ.left, range.getKey().right);
+                                       bucketWithQSize += bucket2Size;
+                                       break;
+                               }
+                       }
+               }
+               // convert to an index relative to the start of the found bucket
+               quantileIndex = quantileIndex == 1 ? 1 : quantileIndex - sizeBefore;
+               return new ImmutableTriple<>(quantileIndex, bucketWithQSize, bucketWithQ);
+       }
+
+       /**
+        * Federated UDF that counts the (weighted) number of local values per bucket. Buckets are half-open
+        * [left, right), except that a value equal to {@code _max} is counted in the bucket whose right edge
+        * equals it. It also returns the set of distinct counted values if there are fewer than three,
+        * and an empty set otherwise.
+        */
+       public static class GetHistogram extends FederatedUDF {
+               private static final long serialVersionUID = 5413355823424777742L;
+               private final Map<ImmutablePair<Double, Double>, Integer> _buckets;
+               private final double _max;
+
+               private GetHistogram(long input, Map<ImmutablePair<Double, Double>, Integer> buckets, double max) {
+                       super(new long[] {input});
+                       _buckets = buckets;
+                       _max = max;
+               }
+
+               @Override
+               public FederatedResponse execute(ExecutionContext ec, Data... data) {
+                       MatrixBlock mb = ((MatrixObject) data[0]).acquireReadAndRelease();
+                       // NOTE(review): assumes a dense block laid out as (value[, weight]) rows - confirm for sparse inputs
+                       double[] values = mb.getDenseBlockValues();
+                       boolean isWeighted = mb.getNumColumns() == 2;
+
+                       // Count into a copy so the bucket template held by this UDF is never modified
+                       Map<ImmutablePair<Double, Double>, Integer> hist = new LinkedHashMap<>(_buckets);
+                       Set<Double> distinct = new HashSet<>();
+
+                       for(int i = 0; i < values.length - (isWeighted ? 1 : 0); i += (isWeighted ? 2 : 1)) {
+                               double val = values[i];
+                               int weight = isWeighted ? (int) values[i + 1] : 1;
+                               for(Map.Entry<ImmutablePair<Double, Double>, Integer> range : hist.entrySet()) {
+                                       ImmutablePair<Double, Double> bounds = range.getKey();
+                                       if((val >= bounds.left && val < bounds.right) || (val == _max && val == bounds.right)) {
+                                               range.setValue(range.getValue() + weight);
+                                               // the caller only needs the set if it has fewer than three entries
+                                               if(distinct.size() < 3)
+                                                       distinct.add(val);
+                                               break; // buckets are contiguous and non-overlapping
+                                       }
+                               }
+                       }
+
+                       Object[] ret = new Object[] {hist, distinct.size() < 3 ? distinct : new HashSet<>()};
+                       return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS, ret);
+               }
+
+               @Override
+               public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
+                       return null;
+               }
+       }
+
+       /**
+        * Federated UDF that, for every quantile position in {@code _ranges}, reports the last local value
+        * lying in the corresponding closed range [left, right]. The result is a (numQuantiles x 1) matrix.
+        */
+       public static class GetValuesInRanges extends FederatedUDF {
+               private static final long serialVersionUID = 8663298932616139153L;
+               private final int _numQuantiles;
+               private final HashMap<Integer, ImmutablePair<Double, Double>> _ranges;
+
+               private GetValuesInRanges(long input, int numQuantiles, HashMap<Integer, ImmutablePair<Double, Double>> ranges) {
+                       super(new long[] {input});
+                       _ranges = ranges;
+                       _numQuantiles = numQuantiles;
+               }
+
+               @Override
+               public FederatedResponse execute(ExecutionContext ec, Data... data) {
+                       MatrixBlock mb = ((MatrixObject) data[0]).acquireReadAndRelease();
+                       double[] values = mb.getDenseBlockValues();
+                       // Two-column inputs interleave (value, weight); only the values may be matched,
+                       // otherwise a weight falling into a range would be reported as the quantile
+                       int step = mb.getNumColumns() == 2 ? 2 : 1;
+
+                       // FIXME rewrite - see binning encode
+                       MatrixBlock res = new MatrixBlock(_numQuantiles, 1, false);
+                       for(int i = 0; i < values.length; i += step) {
+                               double val = values[i];
+                               for(Map.Entry<Integer, ImmutablePair<Double, Double>> entry : _ranges.entrySet()) {
+                                       if(entry.getValue().left <= val && val <= entry.getValue().right) {
+                                               res.setValue(entry.getKey(), 0, val);
+                                               break;
+                                       }
+                               }
+                       }
+
+                       return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS, res);
+               }
+
+               @Override
+               public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
+                       return null;
+               }
+       }
+
+       /**
+        * Federated UDF returning four statistics of the local partition: {min, max, second-column sum,
+        * quantile weight sum}. Two-column inputs use min/max of the first column, the sum of the second
+        * column and {@code sumWeightForQuantile()}. All other inputs use the block min/max and 0 for the
+        * last two entries.
+        */
+       public static class MinMax extends FederatedUDF {
+               private static final long serialVersionUID = -3906698363866500744L;
+
+               private MinMax(long input) {
+                       super(new long[] {input});
+               }
+
+               @Override
+               public FederatedResponse execute(ExecutionContext ec, Data... data) {
+                       MatrixBlock mb = ((MatrixObject) data[0]).acquireReadAndRelease();
+                       double[] stats;
+                       if(mb.getNumColumns() != 2) {
+                               stats = new double[] {mb.min(), mb.max(), 0, 0};
+                       }
+                       else {
+                               double firstColMin = mb.colMin().quickGetValue(0, 0);
+                               double firstColMax = mb.colMax().quickGetValue(0, 0);
+                               double secondColSum = mb.colSum().quickGetValue(0, 1);
+                               double quantileWeightSum = mb.sumWeightForQuantile();
+                               stats = new double[] {firstColMin, firstColMax, secondColSum, quantileWeightSum};
+                       }
+                       return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS, stats);
+               }
+
+               @Override
+               public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
+                       return null;
+               }
+       }
+
+       public void processColumnQPick(ExecutionContext ec) {
                MatrixObject in = ec.getMatrixObject(input1);
                FederationMap fedMapping = in.getFedMapping();
 
@@ -125,12 +511,12 @@ public class QuantilePickFEDInstruction extends 
BinaryFEDInstruction {
                                                response = data
                                                        
.executeFederatedOperation(
                                                                new 
FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1,
-                                                               new 
QuantilePickFEDInstruction.IQM(data.getVarID()))).get();
+                                                               new 
QuantilePickFEDInstruction.ColIQM(data.getVarID()))).get();
                                                break;
                                        case MEDIAN:
                                                response = data
                                                        
.executeFederatedOperation(new 
FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1,
-                                                               new 
QuantilePickFEDInstruction.Median(data.getVarID()))).get();
+                                                               new 
QuantilePickFEDInstruction.ColMedian(data.getVarID()))).get();
                                                break;
                                        default:
                                                throw new 
DMLRuntimeException("Unsupported qpick operation type: "+_type);
@@ -149,6 +535,9 @@ public class QuantilePickFEDInstruction extends 
BinaryFEDInstruction {
 
                assert res.size() == 1;
 
+               if (input2 != null && input2.isMatrix())
+                       ec.releaseMatrixInput(input2.getName());
+
                if(output.isScalar())
                        ec.setScalarOutput(output.getName(), new 
DoubleObject((double) res.get(0)));
                else
@@ -190,11 +579,49 @@ public class QuantilePickFEDInstruction extends 
BinaryFEDInstruction {
                }
        }
 
-       private static class IQM extends FederatedUDF {
+       /**
+        * Federated UDF that scans the worker-local block and sums the values falling
+        * inside {@code _range}. For simple qpick ({@code sumInRange == false}) the range
+        * is closed, {@code [left, right]}; for IQM ({@code sumInRange == true}) it is
+        * half-open, {@code (left, right]}. The sum is sent back as the response payload.
+        */
+       public static class GetValuesInRange extends FederatedUDF {
+               private static final long serialVersionUID = 5413355823424777742L;
+               private final ImmutablePair<Double, Double> _range;
+               private final boolean _sumInRange;
+
+               private GetValuesInRange(long input, ImmutablePair<Double, Double> range, boolean sumInRange) {
+                       super(new long[] {input});
+                       _range = range;
+                       _sumInRange = sumInRange;
+               }
+
+               @Override
+               public FederatedResponse execute(ExecutionContext ec, Data... data) {
+                       MatrixBlock mb = ((MatrixObject) data[0]).acquireReadAndRelease();
+                       // Visit the logical rows x cols cells in row-major order via quickGetValue
+                       // rather than the raw dense value array: sparse or empty blocks may have no
+                       // dense array at all, which made the previous for-each over
+                       // getDenseBlockValues() throw a NullPointerException.
+                       final int nRows = mb.getNumRows();
+                       final int nCols = mb.getNumColumns();
+
+                       double res = 0.0;
+                       int i = 0;
+
+                       // FIXME better search, e.g. sort in QSort and binary search
+                       scan:
+                       for(int r = 0; r < nRows; r++) {
+                               for(int c = 0; c < nCols; c++) {
+                                       double val = mb.quickGetValue(r, c);
+                                       // different conditions for IQM and simple QPICK
+                                       if((!_sumInRange && _range.left <= val && val <= _range.right) ||
+                                               (_sumInRange && _range.left < val && val <= _range.right))
+                                               res += val;
+                                       // NOTE(review): in simple-qpick mode this stops after the first four
+                                       // cells whether or not they matched, so matches further down an
+                                       // unsorted partition (and, for a 2-column weighted block, the weight
+                                       // cells) are treated oddly; confirm the intended behavior.
+                                       if(i++ > 2 && !_sumInRange)
+                                               break scan;
+                               }
+                       }
+
+                       return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS, res);
+               }
+
+               @Override
+               public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
+                       return null;
+               }
+       }
+
+       private static class ColIQM extends FederatedUDF {
 
                private static final long serialVersionUID = 
2223186699111957677L;
 
-               protected IQM(long input) {
+               protected ColIQM(long input) {
                        super(new long[] {input});
                }
                @Override
@@ -209,11 +636,11 @@ public class QuantilePickFEDInstruction extends 
BinaryFEDInstruction {
                }
        }
 
-       private static class Median extends FederatedUDF {
+       private static class ColMedian extends FederatedUDF {
 
                private static final long serialVersionUID = 
-2808597461054603816L;
 
-               protected Median(long input) {
+               protected ColMedian(long input) {
                        super(new long[] {input});
                }
                @Override
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantileSortFEDInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantileSortFEDInstruction.java
index 0a545bc..cb76404 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantileSortFEDInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantileSortFEDInstruction.java
@@ -75,9 +75,40 @@ public class QuantileSortFEDInstruction extends 
UnaryFEDInstruction{
                        throw new DMLRuntimeException("Unknown opcode while 
parsing a QuantileSortFEDInstruction: " + str);
                }
        }
-
        @Override
        public void processInstruction(ExecutionContext ec) {
+               // Column- and fully-partitioned federated inputs take the column path;
+               // every other federation type is routed to the row path.
+               MatrixObject mo = ec.getMatrixObject(input1);
+               boolean columnPath = mo.isFederated(FederationMap.FType.COL)
+                       || mo.isFederated(FederationMap.FType.FULL);
+               if(columnPath)
+                       processColumnQSort(ec);
+               else
+                       processRowQSort(ec);
+       }
+
+       /**
+        * Row-path qsort: no sorting is done on the workers here. Without weights the
+        * federated input is copied under a new ID; with weights, the instruction is
+        * rewritten into an append of input1 and the per-worker sliced weights,
+        * producing a two-column federated output.
+        */
+       public void processRowQSort(ExecutionContext ec) {
+               MatrixObject in = ec.getMatrixObject(input1);
+               MatrixObject out = ec.getMatrixObject(output);
+               FederationMap inMap = in.getFedMapping();
+
+               // TODO make sure that qsort result is used by qpick only where the main operation happens
+               if(input2 == null) {
+                       // make a copy without sorting
+                       long id = FederationUtils.getNextFedDataID();
+                       out.getDataCharacteristics().set(in.getDataCharacteristics());
+                       out.setFedMapping(inMap.identCopy(getTID(), id));
+                       return;
+               }
+
+               MatrixObject weights = ec.getMatrixObject(input2);
+               String appendInst = InstructionUtils.concatOperands(
+                       InstructionUtils.replaceOperand(instString, 1, "append"), "true");
+               FederatedRequest[] sliceReqs = inMap.broadcastSliced(weights, false);
+               FederatedRequest appendReq = FederationUtils.callInstruction(appendInst, output,
+                       new CPOperand[]{input1, input2}, new long[]{inMap.getID(), sliceReqs[0].getID()});
+               inMap.execute(getTID(), true, sliceReqs, appendReq);
+
+               // output keeps the input's row layout but now carries two columns
+               out.getDataCharacteristics().set(in.getDataCharacteristics());
+               out.getDataCharacteristics().setCols(2);
+               out.setFedMapping(inMap.copyWithNewID(appendReq.getID(), 2));
+       }
+
+       public void processColumnQSort(ExecutionContext ec) {
                MatrixObject in = ec.getMatrixObject(input1);
                FederationMap fedMapping = in.getFedMapping();
 
diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java 
b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
index 989c8c4..eb152ab 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
@@ -952,6 +952,26 @@ public class MatrixBlock extends MatrixValue implements 
CacheBlock, Externalizab
                        
InstructionUtils.parseBasicAggregateUnaryOperator("uamin", 1));
                return out.quickGetValue(0, 0);
        }
+
+       /**
+        * Wrapper method for the column-wise min aggregate (opcode {@code uacmin}) of a matrix.
+        *
+        * @return A new MatrixBlock containing the column mins of this matrix (one value per column)
+        */
+       public MatrixBlock colMin() {
+               AggregateUnaryOperator op = InstructionUtils.parseBasicAggregateUnaryOperator("uacmin", 1);
+               return aggregateUnaryOperations(op, null, 1000, null, true);
+       }
+
+       /**
+        * Wrapper method for the column-wise max aggregate (opcode {@code uacmax}) of a matrix.
+        *
+        * @return A new MatrixBlock containing the column maxes of this matrix (one value per column)
+        */
+       public MatrixBlock colMax() {
+               AggregateUnaryOperator op = InstructionUtils.parseBasicAggregateUnaryOperator("uacmax", 1);
+               return aggregateUnaryOperations(op, null, 1000, null, true);
+       }
        
        /**
         * Wrapper method for reduceall-max of a matrix.
diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedQuantileTest.java
 
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedQuantileTest.java
index 226ad53..598256a 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedQuantileTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedQuantileTest.java
@@ -41,17 +41,24 @@ public class FederatedQuantileTest extends 
AutomatedTestBase {
        private final static String TEST_DIR = "functions/federated/quantile/";
        private final static String TEST_NAME1 = "FederatedQuantileTest";
        private final static String TEST_NAME2 = "FederatedMedianTest";
-       private final static String TEST_NAME3 = "FederatedIQMTest";
+       private final static String TEST_NAME3 = "FederatedIQRTest";
        private final static String TEST_NAME4 = "FederatedQuantilesTest";
        private final static String TEST_CLASS_DIR = TEST_DIR + 
FederatedQuantileTest.class.getSimpleName() + "/";
 
        private final static int blocksize = 1024;
        @Parameterized.Parameter()
        public int rows;
+       @Parameterized.Parameter(1)
+       public int cols;
+       @Parameterized.Parameter(2)
+       public boolean rowPartitioned;
 
        @Parameterized.Parameters
        public static Collection<Object[]> data() {
-               return Arrays.asList(new Object[][] {{1000}});
+               return Arrays.asList(new Object[][] {
+                       {1000, 1, false},
+                       {16, 1, true}
+               });
        }
 
        @Override
@@ -76,38 +83,28 @@ public class FederatedQuantileTest extends 
AutomatedTestBase {
        public void federatedMedianCP() { 
federatedQuartile(Types.ExecMode.SINGLE_NODE, TEST_NAME2, -1); }
 
        @Test
-       public void federatedIQMCP() { 
federatedQuartile(Types.ExecMode.SINGLE_NODE, TEST_NAME1, -1); }
+       public void federatedIQRCP() { 
federatedQuartile(Types.ExecMode.SINGLE_NODE, TEST_NAME3, -1); }
 
        @Test
-       public void federatedQuantilesCP() { 
federatedQuartile(Types.ExecMode.SINGLE_NODE, TEST_NAME1, -1); }
+       public void federatedQuantilesCP() { 
federatedQuartile(Types.ExecMode.SINGLE_NODE, TEST_NAME4, -1); }
 
        @Test
-//     @Ignore
        public void federatedQuantile1SP() { 
federatedQuartile(Types.ExecMode.SPARK, TEST_NAME1, 0.25); }
 
        @Test
-//     @Ignore
        public void federatedQuantile2SP() { 
federatedQuartile(Types.ExecMode.SPARK, TEST_NAME1, 0.5); }
 
        @Test
-//     @Ignore
        public void federatedQuantile3SP() { 
federatedQuartile(Types.ExecMode.SPARK, TEST_NAME1, 0.75); }
 
        @Test
-//     @Ignore
        public void federatedMedianSP() { 
federatedQuartile(Types.ExecMode.SPARK, TEST_NAME2, -1); }
 
        @Test
-//     @Ignore
-       public void federatedIQMSP() { federatedQuartile(Types.ExecMode.SPARK, 
TEST_NAME1, -1); }
+       public void federatedIQRSP() { federatedQuartile(Types.ExecMode.SPARK, 
TEST_NAME3, -1); }
 
        @Test
-//     @Ignore
-       public void federatedQuantilesSP() { 
federatedQuartile(Types.ExecMode.SPARK, TEST_NAME1, -1); }
-
-
-
-
+       public void federatedQuantilesSP() { 
federatedQuartile(Types.ExecMode.SPARK, TEST_NAME4, -1); }
 
        public void federatedQuartile(Types.ExecMode execMode, String 
TEST_NAME, double p) {
                boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
@@ -116,21 +113,72 @@ public class FederatedQuantileTest extends 
AutomatedTestBase {
                getAndLoadTestConfiguration(TEST_NAME);
                String HOME = SCRIPT_DIR + TEST_DIR;
 
-               double[][] X1 = getRandomMatrix(rows, 1, 1, 5, 1, 3);
-
-               MatrixCharacteristics mc = new MatrixCharacteristics(rows, 1, 
blocksize, rows);
-               writeInputMatrixWithMTD("X1", X1, false, mc);
+               double[][] X1, X2, X3, X4;
+               int port1, port2, port3, port4;
+               Thread t1 = null, t2 = null, t3 = null, t4 = null;
+               String[] programArgs1, programArgs2;
+               if(rowPartitioned) {
+                       X1 = getRandomMatrix(rows / 4, cols, 1, 12, 1, 3);
+                       X2 = getRandomMatrix(rows / 4, cols, 1, 12, 1, 7);
+                       X3 = getRandomMatrix(rows / 4, cols, 1, 12, 1, 8);
+                       X4 = getRandomMatrix(rows / 4, cols, 1, 12, 1, 9);
+
+                       MatrixCharacteristics mc1 = new 
MatrixCharacteristics(rows / 4, 1, blocksize, rows);
+                       writeInputMatrixWithMTD("X1", X1, false, mc1);
+                       writeInputMatrixWithMTD("X2", X2, false, mc1);
+                       writeInputMatrixWithMTD("X3", X3, false, mc1);
+                       writeInputMatrixWithMTD("X4", X4, false, mc1);
+
+                       port1 = getRandomAvailablePort();
+                       port2 = getRandomAvailablePort();
+                       port3 = getRandomAvailablePort();
+                       port4 = getRandomAvailablePort();
+                       t1 = startLocalFedWorkerThread(port1, 
FED_WORKER_WAIT_S);
+                       t2 = startLocalFedWorkerThread(port2, 
FED_WORKER_WAIT_S);
+                       t3 = startLocalFedWorkerThread(port3, 
FED_WORKER_WAIT_S);
+                       t4 = startLocalFedWorkerThread(port4);
+
+                       programArgs1 = new String[] {"-explain", "-stats", 
"100", "-args",
+                               String.valueOf(p), expected("S"), 
Boolean.toString(rowPartitioned).toUpperCase(),
+                               input("X1"), input("X2"), input("X3"), 
input("X4")};
+                       programArgs2 = new String[] {"-explain","-stats", 
"100", "-nvargs",
+                               "in_X1=" + TestUtils.federatedAddress(port1, 
input("X1")),
+                               "in_X2=" + TestUtils.federatedAddress(port2, 
input("X2")),
+                               "in_X3=" + TestUtils.federatedAddress(port3, 
input("X3")),
+                               "in_X4=" + TestUtils.federatedAddress(port4, 
input("X4")), "rows=" + rows, "cols=" + cols,
+                               "rP=" + 
Boolean.toString(rowPartitioned).toUpperCase(), "p=" + String.valueOf(p),
+                               "out_S=" + output("S")};
+               }
+               else {
+                       X1 = getRandomMatrix(rows, 1, 1, 12, 1, 3);
+                       MatrixCharacteristics mc = new 
MatrixCharacteristics(rows, 1, blocksize, rows);
+                       writeInputMatrixWithMTD("X1", X1, false, mc);
+
+                       port1 = getRandomAvailablePort();
+                       t1 = startLocalFedWorkerThread(port1);
+
+                       programArgs1 = new String[] {"-explain", "-stats", 
"100", "-args",
+                               String.valueOf(p), expected("S"), 
Boolean.toString(rowPartitioned).toUpperCase(), input("X1"),
+                               input("X1"), input("X1"), input("X1")};
+                       programArgs2 = new String[] {"-explain", "-stats", 
"100", "-nvargs",
+                               "in_X1=" + TestUtils.federatedAddress(port1, 
input("X1")),
+                               "in_X2=" + TestUtils.federatedAddress(port1, 
input("X1")),
+                               "in_X3=" + TestUtils.federatedAddress(port1, 
input("X1")),
+                               "in_X4=" + TestUtils.federatedAddress(port1, 
input("X1")),
+                               "rows=" + rows, "cols=" + cols, "p=" + 
String.valueOf(p),
+                               "out_S=" + output("S"), "rP=" + 
Boolean.toString(rowPartitioned).toUpperCase()
+                       };
+               }
 
                // empty script name because we don't execute any script, just 
start the worker
                fullDMLScriptName = "";
-               int port1 = getRandomAvailablePort();
-               Thread t1 = startLocalFedWorkerThread(port1);
 
                // we need the reference file to not be written to hdfs, so we 
get the correct format
                rtplatform = Types.ExecMode.SINGLE_NODE;
                // Run reference dml script with normal matrix for Row/Col
                fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
-               programArgs = new String[] {"-explain", "-stats", "100", 
"-args", input("X1"), expected("S"), String.valueOf(p)};
+
+               programArgs = programArgs1;
                runTest(true, false, null, -1);
 
                // reference file should not be written to hdfs, so we set 
platform here
@@ -142,11 +190,7 @@ public class FederatedQuantileTest extends 
AutomatedTestBase {
                loadTestConfiguration(config);
 
                fullDMLScriptName = HOME + TEST_NAME + ".dml";
-               programArgs = new String[] {"-explain", "-stats", "100", 
"-nvargs",
-                       "in_X1=" + TestUtils.federatedAddress(port1, 
input("X1")),
-                       "rows=" + rows, "cols=" + 1, "p=" + String.valueOf(p),
-                       "out_S=" + output("S")
-               };
+               programArgs = programArgs2;
                runTest(true, false, null, -1);
 
                // compare all sums via files
@@ -156,8 +200,15 @@ public class FederatedQuantileTest extends 
AutomatedTestBase {
 
                // check that federated input files are still existing
                Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
-
                TestUtils.shutdownThreads(t1);
+               if(rowPartitioned) {
+                       
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2")));
+                       
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X3")));
+                       
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X4")));
+
+                       TestUtils.shutdownThreads(t2, t3, t4);
+               }
+
                rtplatform = platformOld;
                DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
        }
diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedQuantileWeightsTest.java
 
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedQuantileWeightsTest.java
index 4511c6d..9ac0b5d 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedQuantileWeightsTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedQuantileWeightsTest.java
@@ -41,17 +41,19 @@ public class FederatedQuantileWeightsTest extends 
AutomatedTestBase {
        private final static String TEST_DIR = "functions/federated/quantile/";
        private final static String TEST_NAME1 = "FederatedQuantileWeightsTest";
        private final static String TEST_NAME2 = "FederatedMedianWeightsTest";
-       private final static String TEST_NAME3 = "FederatedIQMWeightsTest";
-       private final static String TEST_NAME4 = 
"FederatedQuantilesWeightsTest";
        private final static String TEST_CLASS_DIR = TEST_DIR + 
FederatedQuantileWeightsTest.class.getSimpleName() + "/";
 
        private final static int blocksize = 1024;
        @Parameterized.Parameter()
        public int rows;
+       @Parameterized.Parameter(1)
+       public boolean rowPartitioned;
 
        @Parameterized.Parameters
        public static Collection<Object[]> data() {
-               return Arrays.asList(new Object[][] {{1000}});
+               return Arrays.asList(new Object[][] {
+                       {1000, false},
+                       {12, true}});
        }
 
        @Override
@@ -59,8 +61,6 @@ public class FederatedQuantileWeightsTest extends 
AutomatedTestBase {
                TestUtils.clearAssertionInformation();
                addTestConfiguration(TEST_NAME1, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"S.scalar"}));
                addTestConfiguration(TEST_NAME2, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] {"S.scalar"}));
-               addTestConfiguration(TEST_NAME3, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] {"S.scalar"}));
-               addTestConfiguration(TEST_NAME4, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME4, new String[] {"S"}));
        }
 
        @Test
@@ -88,24 +88,81 @@ public class FederatedQuantileWeightsTest extends 
AutomatedTestBase {
                getAndLoadTestConfiguration(TEST_NAME);
                String HOME = SCRIPT_DIR + TEST_DIR;
 
-               double[][] X1 = getRandomMatrix(rows, 1, 1, 5, 1, 3);
+               double[][] X1, X2, X3, X4;
+               int port1, port2, port3, port4;
+               Thread t1 = null, t2 = null, t3 = null, t4 = null;
+               String[] programArgs1, programArgs2;
 
-               MatrixCharacteristics mc = new MatrixCharacteristics(rows, 1, 
blocksize, rows);
-               writeInputMatrixWithMTD("X1", X1, false, mc);
-
-               double[][] W = getRandomMatrix(rows, 1, 1, 1, 1.0, 1);
+               double[][] W = getRandomMatrix(rows, 1, 1, 5, 1.0, 1);
+               for(int i = 0; i < W.length; i++){
+                       for(int y = 0; y < W[0].length; y++){
+                               W[i][y] = (double) Math.round(W[i][y]);
+                       }
+               }
                writeInputMatrixWithMTD("W", W, false);
 
+               if(rowPartitioned) {
+                       X1 = getRandomMatrix(rows / 4, 1, 1, 12, 1, 3);
+                       X2 = getRandomMatrix(rows / 4, 1, 1, 12, 1, 7);
+                       X3 = getRandomMatrix(rows / 4, 1, 1, 12, 1, 8);
+                       X4 = getRandomMatrix(rows / 4, 1, 1, 12, 1, 9);
+
+                       MatrixCharacteristics mc1 = new 
MatrixCharacteristics(rows / 4, 1, blocksize, rows);
+                       writeInputMatrixWithMTD("X1", X1, false, mc1);
+                       writeInputMatrixWithMTD("X2", X2, false, mc1);
+                       writeInputMatrixWithMTD("X3", X3, false, mc1);
+                       writeInputMatrixWithMTD("X4", X4, false, mc1);
+
+                       port1 = getRandomAvailablePort();
+                       port2 = getRandomAvailablePort();
+                       port3 = getRandomAvailablePort();
+                       port4 = getRandomAvailablePort();
+                       t1 = startLocalFedWorkerThread(port1, 
FED_WORKER_WAIT_S);
+                       t2 = startLocalFedWorkerThread(port2, 
FED_WORKER_WAIT_S);
+                       t3 = startLocalFedWorkerThread(port3, 
FED_WORKER_WAIT_S);
+                       t4 = startLocalFedWorkerThread(port4);
+
+                       programArgs1 = new String[] {"-explain", "-stats", 
"100", "-args",
+                               String.valueOf(p), expected("S"), 
Boolean.toString(rowPartitioned).toUpperCase(), input("W"),
+                               input("X1"), input("X2"), input("X3"), 
input("X4")};
+                       programArgs2 = new String[] {"-explain","-stats", 
"100", "-nvargs",
+                               "in_X1=" + TestUtils.federatedAddress(port1, 
input("X1")),
+                               "in_X2=" + TestUtils.federatedAddress(port2, 
input("X2")),
+                               "in_X3=" + TestUtils.federatedAddress(port3, 
input("X3")),
+                               "in_X4=" + TestUtils.federatedAddress(port4, 
input("X4")), "rows=" + rows, "cols=" + 1,
+                               "rP=" + 
Boolean.toString(rowPartitioned).toUpperCase(), "p=" + String.valueOf(p), "W=" 
+ input("W"),
+                               "out_S=" + output("S")};
+               }
+               else {
+                       X1 = getRandomMatrix(rows, 1, 1, 12, 1, 3);
+                       MatrixCharacteristics mc = new 
MatrixCharacteristics(rows, 1, blocksize, rows);
+                       writeInputMatrixWithMTD("X1", X1, false, mc);
+
+                       port1 = getRandomAvailablePort();
+                       t1 = startLocalFedWorkerThread(port1);
+
+                       programArgs1 = new String[] {"-explain", "-stats", 
"100", "-args",
+                               String.valueOf(p), expected("S"), 
Boolean.toString(rowPartitioned).toUpperCase(), input("W"), input("X1"),
+                               input("X1"), input("X1"), input("X1")};
+                       programArgs2 = new String[] {"-explain", "-stats", 
"100", "-nvargs",
+                               "in_X1=" + TestUtils.federatedAddress(port1, 
input("X1")),
+                               "in_X2=" + TestUtils.federatedAddress(port1, 
input("X1")),
+                               "in_X3=" + TestUtils.federatedAddress(port1, 
input("X1")),
+                               "in_X4=" + TestUtils.federatedAddress(port1, 
input("X1")),
+                               "rows=" + rows, "cols=" + 1, "p=" + 
String.valueOf(p), "W=" + input("W"),
+                               "out_S=" + output("S"), "rP=" + 
Boolean.toString(rowPartitioned).toUpperCase()
+                       };
+               }
+
                // empty script name because we don't execute any script, just 
start the worker
                fullDMLScriptName = "";
-               int port1 = getRandomAvailablePort();
-               Thread t1 = startLocalFedWorkerThread(port1);
 
                // we need the reference file to not be written to hdfs, so we 
get the correct format
                rtplatform = Types.ExecMode.SINGLE_NODE;
                // Run reference dml script with normal matrix for Row/Col
                fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
-               programArgs = new String[] {"-explain", "-stats", "100", 
"-args", input("X1"), expected("S"), String.valueOf(p), input("W")};
+
+               programArgs = programArgs1;
                runTest(true, false, null, -1);
 
                // reference file should not be written to hdfs, so we set 
platform here
@@ -117,12 +174,7 @@ public class FederatedQuantileWeightsTest extends 
AutomatedTestBase {
                loadTestConfiguration(config);
 
                fullDMLScriptName = HOME + TEST_NAME + ".dml";
-               programArgs = new String[] {"-explain", "-stats", "100", 
"-nvargs",
-                       "in_X1=" + TestUtils.federatedAddress(port1, 
input("X1")),
-                       "rows=" + rows, "cols=" + 1,
-                       "p=" + String.valueOf(p), "W=" + input("W"),
-                       "out_S=" + output("S")
-               };
+               programArgs = programArgs2;
                runTest(true, false, null, -1);
 
                // compare all sums via files
@@ -132,8 +184,15 @@ public class FederatedQuantileWeightsTest extends 
AutomatedTestBase {
 
                // check that federated input files are still existing
                Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
-
                TestUtils.shutdownThreads(t1);
+               if(rowPartitioned) {
+                       
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2")));
+                       
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X3")));
+                       
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X4")));
+
+                       TestUtils.shutdownThreads(t2, t3, t4);
+               }
+
                rtplatform = platformOld;
                DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
        }
diff --git 
a/src/test/scripts/functions/federated/aggregate/FederatedMeanTestReference.dml 
b/src/test/scripts/functions/federated/aggregate/FederatedMeanTestReference.dml
index cb17fc3..bdd2518 100644
--- 
a/src/test/scripts/functions/federated/aggregate/FederatedMeanTestReference.dml
+++ 
b/src/test/scripts/functions/federated/aggregate/FederatedMeanTestReference.dml
@@ -23,7 +23,6 @@ if($6) {
   A = rbind(read($1), read($2), read($3), read($4));
 }
 else { A = cbind(read($1), read($2), read($3), read($4)); }
-#A = read($1)
-s = mean(A);
 
+s = mean(A);
 write(s, $5);
diff --git a/src/test/scripts/functions/federated/quantile/FederatedIQRTest.dml 
b/src/test/scripts/functions/federated/quantile/FederatedIQRTest.dml
index f50e72e..e6bed3f 100644
--- a/src/test/scripts/functions/federated/quantile/FederatedIQRTest.dml
+++ b/src/test/scripts/functions/federated/quantile/FederatedIQRTest.dml
@@ -19,6 +19,13 @@
 #
 #-------------------------------------------------------------
 
-A = federated(addresses=list($in_X1), ranges=list(list(0, 0), list($rows, 
$cols)));
+if ($rP) {
+    A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+        ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), 
list(2*$rows/4, $cols),
+               list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), 
list($rows, $cols)));
+} else {
+    A = federated(addresses=list($in_X1), ranges=list(list(0, 0), list($rows, 
$cols)));
+}
+
 s = interQuartileMean(A);
 write(s, $out_S);
diff --git 
a/src/test/scripts/functions/federated/quantile/FederatedIQRTestReference.dml 
b/src/test/scripts/functions/federated/quantile/FederatedIQRTestReference.dml
index 84fcace..c6ca4db 100644
--- 
a/src/test/scripts/functions/federated/quantile/FederatedIQRTestReference.dml
+++ 
b/src/test/scripts/functions/federated/quantile/FederatedIQRTestReference.dml
@@ -19,6 +19,9 @@
 #
 #-------------------------------------------------------------
 
-A = read($1);
+if($3) {
+       A = rbind(read($4), read($5), read($6), read($7));
+}
+else { A = read($4); }
 s = interQuartileMean(A);
 write(s, $2);
diff --git 
a/src/test/scripts/functions/federated/quantile/FederatedMedianTest.dml 
b/src/test/scripts/functions/federated/quantile/FederatedMedianTest.dml
index 22f2504..25800b5 100644
--- a/src/test/scripts/functions/federated/quantile/FederatedMedianTest.dml
+++ b/src/test/scripts/functions/federated/quantile/FederatedMedianTest.dml
@@ -19,6 +19,12 @@
 #
 #-------------------------------------------------------------
 
-A = federated(addresses=list($in_X1), ranges=list(list(0, 0), list($rows, 
$cols)));
+if ($rP) {
+    A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+        ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), 
list(2*$rows/4, $cols),
+               list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), 
list($rows, $cols)));
+} else {
+    A = federated(addresses=list($in_X1), ranges=list(list(0, 0), list($rows, 
$cols)));
+}
 s = median(A);
 write(s, $out_S);
diff --git 
a/src/test/scripts/functions/federated/quantile/FederatedMedianTestReference.dml
 
b/src/test/scripts/functions/federated/quantile/FederatedMedianTestReference.dml
index 1544987..ee1ede8 100644
--- 
a/src/test/scripts/functions/federated/quantile/FederatedMedianTestReference.dml
+++ 
b/src/test/scripts/functions/federated/quantile/FederatedMedianTestReference.dml
@@ -19,6 +19,10 @@
 #
 #-------------------------------------------------------------
 
-A = read($1);
+# $3: TRUE if A is split row-wise into the four files $4-$7, FALSE if A is read from $4 alone
+if($3) {
+       A = rbind(read($4), read($5), read($6), read($7));
+}
+else { A = read($4); }
 s = median(A);
 write(s, $2);
diff --git 
a/src/test/scripts/functions/federated/quantile/FederatedMedianWeightsTest.dml 
b/src/test/scripts/functions/federated/quantile/FederatedMedianWeightsTest.dml
index 58ec328..34fc30c 100644
--- 
a/src/test/scripts/functions/federated/quantile/FederatedMedianWeightsTest.dml
+++ 
b/src/test/scripts/functions/federated/quantile/FederatedMedianWeightsTest.dml
@@ -19,7 +19,13 @@
 #
 #-------------------------------------------------------------
 
-A = federated(addresses=list($in_X1), ranges=list(list(0, 0), list($rows, 
$cols)));
+if ($rP) {
+    A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+        ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), 
list(2*$rows/4, $cols),
+               list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), 
list($rows, $cols)));
+} else {
+    A = federated(addresses=list($in_X1), ranges=list(list(0, 0), list($rows, 
$cols)));
+}
 W = read($W);
 s = median(A, W);
 write(s, $out_S);
diff --git 
a/src/test/scripts/functions/federated/quantile/FederatedMedianWeightsTestReference.dml
 
b/src/test/scripts/functions/federated/quantile/FederatedMedianWeightsTestReference.dml
index 6b9e3de..8f9cf3c 100644
--- 
a/src/test/scripts/functions/federated/quantile/FederatedMedianWeightsTestReference.dml
+++ 
b/src/test/scripts/functions/federated/quantile/FederatedMedianWeightsTestReference.dml
@@ -19,7 +19,10 @@
 #
 #-------------------------------------------------------------
 
-A = read($1);
 W = read($4);
+if($3) {
+       A = rbind(read($5), read($6), read($7), read($8));
+}
+else { A = read($5); }
 s = median(A, W);
 write(s, $2);
diff --git 
a/src/test/scripts/functions/federated/quantile/FederatedQuantileTest.dml 
b/src/test/scripts/functions/federated/quantile/FederatedQuantileTest.dml
index 1c84330..ebe1f27 100644
--- a/src/test/scripts/functions/federated/quantile/FederatedQuantileTest.dml
+++ b/src/test/scripts/functions/federated/quantile/FederatedQuantileTest.dml
@@ -19,6 +19,13 @@
 #
 #-------------------------------------------------------------
 
-A = federated(addresses=list($in_X1), ranges=list(list(0, 0), list($rows, 
$cols)));
+if ($rP) {
+  A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+    ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), 
list(2*$rows/4, $cols),
+    list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), 
list($rows, $cols)));
+} else {
+  A = federated(addresses=list($in_X1), ranges=list(list(0, 0), list($rows, 
$cols)));
+}
+
 s = quantile(A, $p);
 write(s, $out_S);
diff --git 
a/src/test/scripts/functions/federated/quantile/FederatedQuantileTestReference.dml
 
b/src/test/scripts/functions/federated/quantile/FederatedQuantileTestReference.dml
index 7a1fb36..4e371b0 100644
--- 
a/src/test/scripts/functions/federated/quantile/FederatedQuantileTestReference.dml
+++ 
b/src/test/scripts/functions/federated/quantile/FederatedQuantileTestReference.dml
@@ -19,6 +19,9 @@
 #
 #-------------------------------------------------------------
 
-A = read($1);
-s = quantile (A, $3);
+if($3) {
+  A = rbind(read($4), read($5), read($6), read($7));
+}
+else { A = read($4); }
+s = quantile (A, $1);
 write(s, $2);
diff --git 
a/src/test/scripts/functions/federated/quantile/FederatedQuantileWeightsTest.dml
 
b/src/test/scripts/functions/federated/quantile/FederatedQuantileWeightsTest.dml
index c423a65..3eedb7e 100644
--- 
a/src/test/scripts/functions/federated/quantile/FederatedQuantileWeightsTest.dml
+++ 
b/src/test/scripts/functions/federated/quantile/FederatedQuantileWeightsTest.dml
@@ -19,7 +19,13 @@
 #
 #-------------------------------------------------------------
 
-A = federated(addresses=list($in_X1), ranges=list(list(0, 0), list($rows, 
$cols)));
+if ($rP) {
+    A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+        ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), 
list(2*$rows/4, $cols),
+               list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), 
list($rows, $cols)));
+} else {
+    A = federated(addresses=list($in_X1), ranges=list(list(0, 0), list($rows, 
$cols)));
+}
 W = read($W);
 s = quantile(A, W, $p);
 write(s, $out_S);
diff --git 
a/src/test/scripts/functions/federated/quantile/FederatedQuantileWeightsTestReference.dml
 
b/src/test/scripts/functions/federated/quantile/FederatedQuantileWeightsTestReference.dml
index 6796757..733e68d 100644
--- 
a/src/test/scripts/functions/federated/quantile/FederatedQuantileWeightsTestReference.dml
+++ 
b/src/test/scripts/functions/federated/quantile/FederatedQuantileWeightsTestReference.dml
@@ -19,7 +19,11 @@
 #
 #-------------------------------------------------------------
 
-A = read($1);
 W = read($4);
-s = quantile (A, W, $3);
+if($3) {
+       A = rbind(read($5), read($6), read($7), read($8));
+}
+else { A = read($5); }
+
+s = quantile (A, W, $1);
 write(s, $2);
diff --git 
a/src/test/scripts/functions/federated/quantile/FederatedQuantilesTest.dml 
b/src/test/scripts/functions/federated/quantile/FederatedQuantilesTest.dml
index f5f22fa..8f11bc6 100644
--- a/src/test/scripts/functions/federated/quantile/FederatedQuantilesTest.dml
+++ b/src/test/scripts/functions/federated/quantile/FederatedQuantilesTest.dml
@@ -19,9 +19,15 @@
 #
 #-------------------------------------------------------------
 
-A = federated(addresses=list($in_X1), ranges=list(list(0, 0), list($rows, 
$cols)));
+if ($rP) {
+    A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+        ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), 
list(2*$rows/4, $cols),
+               list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), 
list($rows, $cols)));
+} else {
+    A = federated(addresses=list($in_X1), ranges=list(list(0, 0), list($rows, 
$cols)));
+}
 P = matrix(0.25, 3, 1);
 P[2,1] = 0.5;
 P[3,1] = 0.75;
-s = quantile(X=A, P=P);
+s = quantile(A, P);
 write(s, $out_S);
diff --git 
a/src/test/scripts/functions/federated/quantile/FederatedQuantilesTestReference.dml
 
b/src/test/scripts/functions/federated/quantile/FederatedQuantilesTestReference.dml
index 96970d0..e4d6b71 100644
--- 
a/src/test/scripts/functions/federated/quantile/FederatedQuantilesTestReference.dml
+++ 
b/src/test/scripts/functions/federated/quantile/FederatedQuantilesTestReference.dml
@@ -19,9 +19,13 @@
 #
 #-------------------------------------------------------------
 
-A = read($1);
+if($3) {
+  A = rbind(read($4), read($5), read($6), read($7));
+}
+else { A = read($4); }
 P = matrix(0.25, 3, 1);
 P[2,1] = 0.5;
 P[3,1] = 0.75;
-s = quantile(X=A, P=P);
+s = quantile(A, P);
+print(toString(s))
 write(s, $2);

Reply via email to