This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new e3c60ff  [SYSTEMDS-3305] Performance federated quantiles
e3c60ff is described below

commit e3c60ffab8f4f327f9f7c3af0c12ca777526f32e
Author: OlgaOvcharenko <[email protected]>
AuthorDate: Tue Mar 22 12:53:26 2022 +0100

    [SYSTEMDS-3305] Performance federated quantiles
    
    Closes #1558.
---
 .../controlprogram/federated/FederationUtils.java  |  21 ++-
 .../fed/AggregateUnaryFEDInstruction.java          |   4 +-
 .../fed/QuantilePickFEDInstruction.java            | 201 +++++++++++++--------
 .../fed/QuantileSortFEDInstruction.java            |   1 -
 .../primitives/FederatedQuantileTest.java          |   4 +-
 .../primitives/FederatedQuantileWeightsTest.java   |  10 +-
 .../federated/quantile/FederatedIQRWeightsTest.dml |   9 +-
 .../quantile/FederatedIQRWeightsTestReference.dml  |   6 +-
 .../quantile/FederatedQuantilesWeightsTest.dml     |  10 +-
 .../FederatedQuantilesWeightsTestReference.dml     |   9 +-
 10 files changed, 174 insertions(+), 101 deletions(-)

diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
index 606b6c0..a31d8be 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
@@ -371,16 +371,19 @@ public class FederationUtils {
                                return new DoubleObject(aggMean(ffr, 
map).getValue(0,0));
                        }
                        else if(aop.aggOp.increOp.fn instanceof CM) {
-                               double var = ((ScalarObject) 
ffr[0].get().getData()[0]).getDoubleValue();
-                               double mean = ((ScalarObject) 
meanFfr[0].get().getData()[0]).getDoubleValue();
-                               long size = 
map.getFederatedRanges()[0].getSize();
-                               for(int i = 0; i < ffr.length - 1; i++) {
-                                       long l = size + 
map.getFederatedRanges()[i+1].getSize();
-                                       double k = ((size * var) + 
(map.getFederatedRanges()[i+1].getSize() * ((ScalarObject) 
ffr[i+1].get().getData()[0]).getDoubleValue())) / l;
-                                       var = k + (size * 
map.getFederatedRanges()[i+1].getSize()) * Math.pow((mean - ((ScalarObject) 
meanFfr[i+1].get().getData()[0]).getDoubleValue()) / l, 2);
-                                       mean = (mean *  size + ((ScalarObject) 
meanFfr[i+1].get().getData()[0]).getDoubleValue() * 
(map.getFederatedRanges()[i+1].getSize())) / l;
-                                       size = l;
+                               long size1 = 
map.getFederatedRanges()[0].getSize();
+                               double mean1 = ((ScalarObject) 
meanFfr[0].get().getData()[0]).getDoubleValue();
+                               double squaredM1 = ((ScalarObject) 
ffr[0].get().getData()[0]).getDoubleValue() * (size1 - 1);
+                               for(int i = 1; i < ffr.length; i++) {
+                                       long size2 = 
map.getFederatedRanges()[i].getSize();
+                                       double delta = ((ScalarObject) 
meanFfr[i].get().getData()[0]).getDoubleValue() - mean1;
+                                       double squaredM2 =  ((ScalarObject) 
ffr[i].get().getData()[0]).getDoubleValue() * (size2 - 1);
+                                       squaredM1 = squaredM1 + squaredM2 + 
(Math.pow(delta, 2) * size1 * size2 / (size1 + size2));
+
+                                       size1 += size2;
+                                       mean1 = mean1 + delta * size2 / size1;
                                }
+                               double var = squaredM1 / (size1 - 1);
                                return new DoubleObject(var);
 
                        }
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
index 668fd0b..7e2ca2a 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
@@ -250,7 +250,7 @@ public class AggregateUnaryFEDInstruction extends UnaryFEDInstruction {
                        FederatedRequest meanFr1 =  
FederationUtils.callInstruction(meanInstr, output, id,
                                new CPOperand[]{input1}, new 
long[]{in.getFedMapping().getID()}, isSpark ? ExecType.SPARK : ExecType.CP, 
isSpark);
                        FederatedRequest meanFr2 = new 
FederatedRequest(RequestType.GET_VAR, meanFr1.getID());
-                       meanTmp = map.execute(getTID(), isSpark ?
+                       meanTmp = map.execute(getTID(), true, isSpark ?
                                new FederatedRequest[] {tmpRequest, meanFr1, 
meanFr2} :
                                new FederatedRequest[] {meanFr1, meanFr2});
                }
@@ -261,7 +261,7 @@ public class AggregateUnaryFEDInstruction extends 
UnaryFEDInstruction {
                FederatedRequest fr2 = new 
FederatedRequest(RequestType.GET_VAR, fr1.getID());
                
                //execute federated commands and cleanups
-               Future<FederatedResponse>[] tmp = map.execute(getTID(), isSpark 
?
+               Future<FederatedResponse>[] tmp = map.execute(getTID(), true, 
isSpark ?
                        new FederatedRequest[] {tmpRequest, fr1, fr2} :
                        new FederatedRequest[] { fr1, fr2});
                if( output.isScalar() )
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantilePickFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantilePickFEDInstruction.java
index c6a5b08..f082173 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantilePickFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantilePickFEDInstruction.java
@@ -22,11 +22,11 @@ package org.apache.sysds.runtime.instructions.fed;
 import java.util.ArrayList;
 import java.util.HashMap;
 import java.util.HashSet;
-import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.Set;
 import java.util.stream.Collectors;
+import java.util.stream.Stream;
 
 import org.apache.commons.lang3.tuple.ImmutablePair;
 import org.apache.commons.lang3.tuple.ImmutableTriple;
@@ -71,10 +71,6 @@ public class QuantilePickFEDInstruction extends 
BinaryFEDInstruction {
                this(op, in, in2, out, type, inmem, opcode, istr, 
FederatedOutput.NONE);
        }
 
-       public OperationTypes getQPickType() {
-               return _type;
-       }
-
        public static QuantilePickFEDInstruction parseInstruction ( String str 
) {
                String[] parts = 
InstructionUtils.getInstructionPartsWithValueType(str);
                String opcode = parts[0];
@@ -169,27 +165,28 @@ public class QuantilePickFEDInstruction extends 
BinaryFEDInstruction {
 
                // Compute and set results
                if(quantiles != null && quantiles.length > 1) {
-                       computeMultipleQuantiles(ec, in, 
(Map<ImmutablePair<Double, Double>, Integer>) ret, quantiles, (int) 
vectorLength, varID, _type);
+                       computeMultipleQuantiles(ec, in, (int[]) ret, 
quantiles, (int) vectorLength, varID, (globalMax-globalMin) / numBuckets, 
globalMin, _type);
                } else
-                       getSingleQuantileResult(ret, ec, fedMap, varID, 
average, false, (int) vectorLength);
+                       getSingleQuantileResult(ret, ec, fedMap, varID, 
average, false, (int) vectorLength, null);
        }
 
-       private <T> void computeMultipleQuantiles(ExecutionContext ec, 
MatrixObject in, Map<ImmutablePair<Double, Double>, Integer> buckets, double[] 
quantiles, int vectorLength, long varID, OperationTypes type) {
+       private <T> void computeMultipleQuantiles(ExecutionContext ec, 
MatrixObject in, int[] bucketsFrequencies, double[] quantiles,
+               int vectorLength, long varID, double bucketRange, double min, 
OperationTypes type) {
                MatrixBlock out = new MatrixBlock(quantiles.length, 1, false);
                ImmutableTriple<Integer, Integer, ImmutablePair<Double, 
Double>>[] bucketsWithIndex = new ImmutableTriple[quantiles.length];
 
                // Find bins with each quantile for first histogram
                int sizeBeforeTmp = 0, sizeBefore = 0, countFoundBins = 0;
-               for(Map.Entry<ImmutablePair<Double, Double>, Integer> entry : 
buckets.entrySet()) {
-                       sizeBeforeTmp += entry.getValue();
+               for(int j = 0; j < bucketsFrequencies.length; j++) {
+                       sizeBeforeTmp += bucketsFrequencies[j];
 
                        for(int i = 0; i < quantiles.length; i++) {
                                int quantileIndex = (int) 
Math.round(vectorLength * quantiles[i]);
-                               ImmutablePair<Double, Double> bucketWithQ = 
null;
+                               ImmutablePair<Double, Double> bucketWithQ;
 
                                if(quantileIndex > sizeBefore && quantileIndex 
<= sizeBeforeTmp) {
-                                       bucketWithQ = entry.getKey();
-                                       bucketsWithIndex[i] = new 
ImmutableTriple<>(quantileIndex == 1 ? 1 : quantileIndex - sizeBefore, 
entry.getValue(), bucketWithQ);
+                                       bucketWithQ = new ImmutablePair<>(min + 
(j * bucketRange), min + ((j+1) * bucketRange));
+                                       bucketsWithIndex[i] = new 
ImmutableTriple<>(quantileIndex == 1 ? 1 : quantileIndex - sizeBefore, 
bucketsFrequencies[j], bucketWithQ);
                                        countFoundBins++;
                                }
                        }
@@ -202,14 +199,16 @@ public class QuantilePickFEDInstruction extends 
BinaryFEDInstruction {
                // Find each quantile bin recursively
                Map<Integer, T> retBuckets = new HashMap<>();
 
-               double left = 0, right = 0;
+               double q25Left = 0, q25Right = 0, q75Left = 0, q75Right = 0;
                for(int i = 0; i < bucketsWithIndex.length; i++) {
                        int nextNumBuckets = bucketsWithIndex[i].middle < 100 ? 
bucketsWithIndex[i].middle * 2 : (int) Math.round(bucketsWithIndex[i].middle / 
2.0);
                        T hist = createHistogram(in, vectorLength, 
bucketsWithIndex[i].right.left, bucketsWithIndex[i].right.right, 
nextNumBuckets, bucketsWithIndex[i].left, false);
 
                        if(_type == OperationTypes.IQM) {
-                               left = i == 0 ? hist instanceof ImmutablePair ? 
 ((ImmutablePair<Double, Double>)hist).right : (Double) hist : left;
-                               right = i == 1 ? hist instanceof ImmutablePair 
? ((ImmutablePair<Double, Double>)hist).left : (Double) hist : right;
+                               q25Right = i == 0 ? hist instanceof 
ImmutablePair ?  ((ImmutablePair<Double, Double>)hist).right : (Double) hist : 
q25Right;
+                               q25Left = i == 0 ? hist instanceof 
ImmutablePair ?  ((ImmutablePair<Double, Double>)hist).left : (Double) hist : 
q25Left;
+                               q75Right = i == 1 ? hist instanceof 
ImmutablePair ? ((ImmutablePair<Double, Double>)hist).right : (Double) hist : 
q75Right;
+                               q75Left = i == 1 ? hist instanceof 
ImmutablePair ?  ((ImmutablePair<Double, Double>)hist).left : (Double) hist : 
q75Left;
                        } else {
                                if(hist instanceof ImmutablePair)
                                        retBuckets.put(i, hist); // set value 
if returned double instead of bin
@@ -219,8 +218,11 @@ public class QuantilePickFEDInstruction extends 
BinaryFEDInstruction {
                }
 
                if(type == OperationTypes.IQM) {
-                       ImmutablePair<Double, Double> IQMRange = new 
ImmutablePair<>(left, right);
-                       getSingleQuantileResult(IQMRange, ec, 
in.getFedMapping(), varID, false, true, vectorLength);
+                       ImmutablePair<Double, Double> IQMRange = new 
ImmutablePair<>(q25Right, q75Right);
+                       if(q25Right == q75Right)
+                               ec.setScalarOutput(output.getName(), new 
DoubleObject(q25Left));
+                       else
+                               getSingleQuantileResult(IQMRange, ec, 
in.getFedMapping(), varID, false, true, vectorLength, new 
ImmutablePair<>(q25Left, q75Left));
                }
                else {
                        if(!retBuckets.isEmpty()) {
@@ -251,19 +253,22 @@ public class QuantilePickFEDInstruction extends 
BinaryFEDInstruction {
                }
        }
 
-       private <T> void getSingleQuantileResult(T ret, ExecutionContext ec, 
FederationMap fedMap, long varID, boolean average, boolean isIQM, int 
vectorLength) {
-               double result = 0.0;
+       private <T> void getSingleQuantileResult(T ret, ExecutionContext ec, 
FederationMap fedMap, long varID, boolean average, boolean isIQM, int 
vectorLength, ImmutablePair<Double, Double> iqmRange) {
+               double result = 0.0, q25Part = 0, q25Val = 0, q75Val = 0, 
q75Part = 0;
                if(ret instanceof ImmutablePair) {
                        // Search for values within bucket range
                        List<Double> values = new ArrayList<>();
+                       List<double[]> iqmValues = new ArrayList<>();
                        fedMap.mapParallel(varID, (range, data) -> {
                                try {
-                                       FederatedResponse response = 
data.executeFederatedOperation(new 
FederatedRequest(FederatedRequest.RequestType.EXEC_UDF,
-                                               -1,
-                                               new 
QuantilePickFEDInstruction.GetValuesInRange(data.getVarID(), 
(ImmutablePair<Double, Double>) ret, isIQM))).get();
+                                       FederatedResponse response = 
data.executeFederatedOperation(new 
FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1,
+                                               new 
QuantilePickFEDInstruction.GetValuesInRange(data.getVarID(), 
(ImmutablePair<Double, Double>) ret, isIQM, iqmRange))).get();
                                        if(!response.isSuccessful())
                                                
response.throwExceptionFromResponse();
-                                       values.add((double) 
response.getData()[0]);
+                                       if(isIQM)
+                                               iqmValues.add((double[]) 
response.getData()[0]);
+                                       else
+                                               values.add((double) 
response.getData()[0]);
                                        return null;
                                }
                                catch(Exception e) {
@@ -271,44 +276,46 @@ public class QuantilePickFEDInstruction extends 
BinaryFEDInstruction {
                                }
                        });
 
-                       // Sum of 1 or 2 values
-                       result = values.stream().reduce(0.0, Double::sum);
+
+                       if(isIQM) {
+                               for(double[] vals : iqmValues) {
+                                       result += vals[0];
+                                       q25Part += vals[1];
+                                       q25Val += vals[2];
+                                       q75Part += vals[3];
+                                       q75Val += vals[4];
+                               }
+                               q25Part -= (0.25 * vectorLength);
+                               q75Part -= (0.75 * vectorLength);
+                       } else
+                               result = values.stream().reduce(0.0, 
Double::sum);
 
                } else
                        result = (Double) ret;
 
-               result /= (average ? 2 : isIQM ? ((int) Math.round(vectorLength 
* 0.75) - (int) Math.round(vectorLength * 0.25)) : 1);
+               result = average ? result / 2 : (isIQM ? ((result + 
q25Part*q25Val - q75Part*q75Val) / (vectorLength * 0.5)) : result);
 
                ec.setScalarOutput(output.getName(), new DoubleObject(result));
        }
 
        public <T> T createHistogram(MatrixObject in, int vectorLength,  double 
globalMin, double globalMax, int numBuckets, int quantileIndex, boolean 
average) {
                FederationMap fedMap = in.getFedMapping();
-
-               Map<ImmutablePair<Double, Double>, Integer> buckets = new 
LinkedHashMap<>();
-               List<Map<ImmutablePair<Double, Double>, Integer>> hists = new 
ArrayList<>();
+               List<int[]> hists = new ArrayList<>();
                List<Set<Double>> distincts = new ArrayList<>();
 
                double bucketRange = (globalMax-globalMin) / numBuckets;
                boolean isEvenNumRows = vectorLength % 2 == 0;
 
-               // Create buckets according to min and max
-               double tmpMin = globalMin, tmpMax = globalMax;
-               for(int i = 0; i < numBuckets && tmpMin <= tmpMax; i++) {
-                       buckets.put(new ImmutablePair<>(tmpMin, tmpMin + 
bucketRange), 0);
-                       tmpMin += bucketRange;
-               }
-
                // Create histograms
                long varID = FederationUtils.getNextFedDataID();
                fedMap.mapParallel(varID, (range, data) -> {
                        try {
                                FederatedResponse response = 
data.executeFederatedOperation(new FederatedRequest(
                                        FederatedRequest.RequestType.EXEC_UDF, 
-1,
-                                       new 
QuantilePickFEDInstruction.GetHistogram(data.getVarID(), buckets, 
globalMax))).get();
+                                       new 
QuantilePickFEDInstruction.GetHistogram(data.getVarID(), globalMin, globalMax, 
bucketRange, numBuckets))).get();
                                if(!response.isSuccessful())
                                        response.throwExceptionFromResponse();
-                               Map<ImmutablePair<Double, Double>, Integer> 
rangeHist = (Map<ImmutablePair<Double, Double>, Integer>) response.getData()[0];
+                               int[] rangeHist = (int[]) response.getData()[0];
                                hists.add(rangeHist);
                                Set<Double> rangeDistinct = (Set<Double>) 
response.getData()[1];
                                distincts.add(rangeDistinct);
@@ -320,28 +327,36 @@ public class QuantilePickFEDInstruction extends 
BinaryFEDInstruction {
                });
 
                // Merge results into one histogram
-               for(ImmutablePair<Double, Double> bucket : buckets.keySet()) {
-                       int value = 0;
-                       for(Map<ImmutablePair<Double, Double>, Integer> hist : 
hists)
-                               value += hist.get(bucket);
-                       buckets.put(bucket, value);
-               }
+               int[] bucketsFrequencies = new int[numBuckets];
+               for(int[] hist : hists)
+                       for(int i = 0; i < hist.length; i++)
+                               bucketsFrequencies[i] += hist[i];
 
                if(quantileIndex == -1)
-                       return (T) buckets;
+                       return (T) bucketsFrequencies;
 
                // Find bucket with quantile
-               ImmutableTriple<Integer, Integer, ImmutablePair<Double, 
Double>> bucketWithIndex = getBucketWithIndex(buckets, quantileIndex, average, 
isEvenNumRows);
+               ImmutableTriple<Integer, Integer, ImmutablePair<Double, 
Double>> bucketWithIndex = getBucketWithIndex(bucketsFrequencies, globalMin, 
quantileIndex, average, isEvenNumRows, bucketRange);
 
                // Check if can terminate
                Set<Double> distinctValues = 
distincts.stream().flatMap(Set::stream).collect(Collectors.toSet());
-               if((distinctValues.size() == 1 && !average) || 
(distinctValues.size() == 2 && average))
-                       return (T) distinctValues.stream().reduce(0.0, (a, b) 
-> a + b);
+
+               if(distinctValues.size() > quantileIndex-1 && !average)
+                       return (T) 
distinctValues.stream().sorted().toArray()[quantileIndex-1];
+
+               if(average && distinctValues.size() > quantileIndex) {
+                       Double[] distinctsSorted = 
distinctValues.stream().flatMap(Stream::of).sorted().toArray(Double[]::new);
+                       Double medianSum = 
Double.sum(distinctsSorted[quantileIndex-1], distinctsSorted[quantileIndex]);
+                       return (T) medianSum;
+               }
+
+               if(average && distinctValues.size() == 2)
+                       return (T) distinctValues.stream().reduce(0.0, 
Double::sum);
 
                ImmutablePair<Double, Double> finalBucketWithQ = 
bucketWithIndex.right;
                List<Double> distinctInNewBucket = 
distinctValues.stream().filter( e -> e >= finalBucketWithQ.left && e <= 
finalBucketWithQ.right).collect(Collectors.toList());
                if((distinctInNewBucket.size() == 1 && !average) || (average && 
distinctInNewBucket.size() == 2))
-                       return (T) distinctInNewBucket.stream().reduce(0.0, (a, 
b) -> a + b);
+                       return (T) distinctInNewBucket.stream().reduce(0.0, 
Double::sum);
 
                if(distinctValues.size() == 1 || (bucketWithIndex.middle == 1 
&& !average) || (bucketWithIndex.middle == 2 && isEvenNumRows && average) ||
                        globalMin == globalMax)
@@ -357,14 +372,16 @@ public class QuantilePickFEDInstruction extends 
BinaryFEDInstruction {
                return createHistogram(in, vectorLength, 
bucketWithIndex.right.left, bucketWithIndex.right.right, nextNumBuckets, 
bucketWithIndex.left, average);
        }
 
-       private ImmutableTriple<Integer, Integer, ImmutablePair<Double, 
Double>> getBucketWithIndex(Map<ImmutablePair<Double, Double>, Integer> 
buckets, int quantileIndex, boolean average, boolean isEvenNumRows) {
+       private ImmutableTriple<Integer, Integer, ImmutablePair<Double, 
Double>> getBucketWithIndex(int[] bucketFrequencies, double min, int 
quantileIndex, boolean average, boolean isEvenNumRows, double bucketRange) {
                int sizeBeforeTmp = 0, sizeBefore = 0, bucketWithQSize = 0;
                ImmutablePair<Double, Double> bucketWithQ = null;
-               for(Map.Entry<ImmutablePair<Double, Double>, Integer> range : 
buckets.entrySet()) {
-                       sizeBeforeTmp += range.getValue();
+
+               double tmpBinLeft = min;
+               for(int i = 0; i < bucketFrequencies.length; i++) {
+                       sizeBeforeTmp += bucketFrequencies[i];
                        if(quantileIndex <= sizeBeforeTmp && bucketWithQSize == 
0) {
-                               bucketWithQ = range.getKey();
-                               bucketWithQSize = range.getValue();
+                               bucketWithQ = new ImmutablePair<>(tmpBinLeft, 
tmpBinLeft + bucketRange);
+                               bucketWithQSize = bucketFrequencies[i];
                                sizeBeforeTmp -= bucketWithQSize;
                                sizeBefore = sizeBeforeTmp;
 
@@ -372,13 +389,14 @@ public class QuantilePickFEDInstruction extends 
BinaryFEDInstruction {
                                        break;
                        } else if(quantileIndex + 1 <= sizeBeforeTmp + 
bucketWithQSize && isEvenNumRows && average) {
                                // Add right bin that contains second index
-                               int bucket2Size = range.getValue();
+                               int bucket2Size = bucketFrequencies[i];
                                if (bucket2Size != 0) {
-                                       bucketWithQ = new 
ImmutablePair<>(bucketWithQ.left, range.getKey().right);
+                                       bucketWithQ = new 
ImmutablePair<>(bucketWithQ.left, tmpBinLeft + bucketRange);
                                        bucketWithQSize += bucket2Size;
                                        break;
                                }
                        }
+                       tmpBinLeft += bucketRange;
                }
                quantileIndex = quantileIndex == 1 ? 1 : quantileIndex - 
sizeBefore;
                return new ImmutableTriple<>(quantileIndex, bucketWithQSize, 
bucketWithQ);
@@ -386,13 +404,17 @@ public class QuantilePickFEDInstruction extends 
BinaryFEDInstruction {
 
        public static class GetHistogram extends FederatedUDF {
                private static final long serialVersionUID = 
5413355823424777742L;
-               private final Map<ImmutablePair<Double, Double>, Integer> 
_buckets;
                private final double _max;
+               private final double _min;
+               private final double _range;
+               private final int _numBuckets;
 
-               private GetHistogram(long input, Map<ImmutablePair<Double, 
Double>, Integer> buckets, double max) {
+               private GetHistogram(long input, double min, double max, double 
range, int numBuckets) {
                        super(new long[] {input});
-                       _buckets = buckets;
                        _max = max;
+                       _min = min;
+                       _range = range;
+                       _numBuckets = numBuckets;
                }
 
                @Override
@@ -401,22 +423,23 @@ public class QuantilePickFEDInstruction extends 
BinaryFEDInstruction {
                        double[] values = mb.getDenseBlockValues();
                        boolean isWeighted  = mb.getNumColumns() == 2;
 
-                       Map<ImmutablePair<Double, Double>, Integer> hist = 
_buckets;
                        Set<Double> distinct = new HashSet<>();
 
+                       int[] frequencies = new int[_numBuckets];
+
+                       // binning
                        for(int i = 0; i < values.length - (isWeighted ? 1 : 
0); i += (isWeighted ? 2 : 1)) {
                                double val = values[i];
                                int weight = isWeighted ? (int) values[i+1] : 1;
-                               for (Map.Entry<ImmutablePair<Double, Double>, 
Integer> range : _buckets.entrySet()) {
-                                       if((val >= range.getKey().left && val < 
range.getKey().right) || (val == _max && val == range.getKey().right)) {
-                                               hist.put(range.getKey(), 
range.getValue() + weight);
-
-                                               distinct.add(val);
-                                       }
+                               int index = (int) (Math.ceil((val - _min) / 
_range));
+                               index = index == 0 ? 0 : index - 1;
+                               if (val >= _min && val <= _max) {
+                                       frequencies[index] += weight;
+                                       distinct.add(val);
                                }
                        }
 
-                       Object[] ret = new Object[] {hist, distinct.size() < 3 
? distinct : new HashSet<>()};
+                       Object[] ret = new Object[] {frequencies, 
distinct.size() < 3 ? distinct : new HashSet<>()};
                        return new 
FederatedResponse(FederatedResponse.ResponseType.SUCCESS, ret);
                }
 
@@ -442,10 +465,10 @@ public class QuantilePickFEDInstruction extends 
BinaryFEDInstruction {
                        MatrixBlock mb = ((MatrixObject) 
data[0]).acquireReadAndRelease();
                        double[] values = mb.getDenseBlockValues();
 
-                       // FIXME rewrite - see binning encode
                        MatrixBlock res = new MatrixBlock(_numQuantiles, 1, 
false);
                        for(double val : values) {
                                for(Map.Entry<Integer, ImmutablePair<Double, 
Double>> entry : _ranges.entrySet()) {
+                                       // Find value within computed bin
                                        if(entry.getValue().left <= val && val 
<= entry.getValue().right) {
                                                res.setValue(entry.getKey(), 
0,val);
                                                break;
@@ -583,12 +606,14 @@ public class QuantilePickFEDInstruction extends 
BinaryFEDInstruction {
        public static class GetValuesInRange extends FederatedUDF {
                private static final long serialVersionUID = 
5413355823424777742L;
                private final ImmutablePair<Double, Double> _range;
+               private final ImmutablePair<Double, Double> _iqmRange;
                private final boolean _sumInRange;
 
-               private GetValuesInRange(long input, ImmutablePair<Double, 
Double> range, boolean sumInRange) {
+               private GetValuesInRange(long input, ImmutablePair<Double, 
Double> range, boolean sumInRange, ImmutablePair<Double, Double> iqmRange) {
                        super(new long[] {input});
                        _range = range;
                        _sumInRange = sumInRange;
+                       _iqmRange = iqmRange;
                }
 
                @Override
@@ -596,20 +621,42 @@ public class QuantilePickFEDInstruction extends 
BinaryFEDInstruction {
                        MatrixBlock mb = ((MatrixObject) 
data[0]).acquireReadAndRelease();
                        double[] values = mb.getDenseBlockValues();
 
+                       boolean isWeighted  = mb.getNumColumns() == 2;
+
                        double res = 0.0;
-                       int i = 0;
+                       int counter = 0;
 
-                       // FIXME better search, e.g. sort in QSort and binary 
search
-                       for(double val : values) {
+                       double q25Part = 0, q25Val = 0, q75Val = 0, q75Part = 0;
+                       for(int i = 0; i < values.length - (isWeighted ? 1 : 
0); i += (isWeighted ? 2 : 1)) {
+                               // get value within computed bin
                                // different conditions for IQM and simple QPICK
+                               double val = values[i];
+                               int weight = isWeighted ? (int) values[i+1] : 1;
+
+                               if(_iqmRange != null && val <= _iqmRange.left) {
+                                       q25Part += weight;
+                               }
+
+                               if(_iqmRange != null && val >= _iqmRange.left 
&& val <= _range.left) {
+                                       q25Val = val;
+                               }
+                               else if(_iqmRange != null && val <= 
_iqmRange.right && val >= _range.right)
+                                       q75Val = val;
+
                                if((!_sumInRange && _range.left <= val && val 
<= _range.right) ||
-                                       (_sumInRange && _range.left < val && 
val <= _range.right))
-                                       res += val;
-                               if(i++ > 2 && !_sumInRange)
+                                       (_sumInRange && _range.left < val && 
val <= _range.right)) {
+                                       res += (val * (!_sumInRange && weight > 
1 ? 2 : weight));
+                                       counter += weight;
+                               }
+
+                               if(_iqmRange != null && val <= _range.right)
+                                       q75Part += weight;
+
+                               if(!_sumInRange && counter > 2)
                                        break;
                        }
 
-                       return new 
FederatedResponse(FederatedResponse.ResponseType.SUCCESS, res);
+                       return new 
FederatedResponse(FederatedResponse.ResponseType.SUCCESS,!_sumInRange ? res : 
new double[]{res, q25Part, q25Val, q75Part, q75Val});
                }
 
                @Override
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantileSortFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantileSortFEDInstruction.java
index ded83e7..f84be32 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantileSortFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantileSortFEDInstruction.java
@@ -133,7 +133,6 @@ public class QuantileSortFEDInstruction extends UnaryFEDInstruction{
                        return null;
                });
 
-
                MatrixObject sorted = ec.getMatrixObject(output);
                
sorted.getDataCharacteristics().set(in.getDataCharacteristics());
 
diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedQuantileTest.java
 
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedQuantileTest.java
index 598256a..3b34454 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedQuantileTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedQuantileTest.java
@@ -56,8 +56,8 @@ public class FederatedQuantileTest extends AutomatedTestBase {
        @Parameterized.Parameters
        public static Collection<Object[]> data() {
                return Arrays.asList(new Object[][] {
-                       {1000, 1, false},
-                       {16, 1, true}
+//                     {1000, 1, false},
+                       {128, 1, true}
                });
        }
 
diff --git 
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedQuantileWeightsTest.java
 
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedQuantileWeightsTest.java
index 9ac0b5d..d42f641 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedQuantileWeightsTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedQuantileWeightsTest.java
@@ -41,6 +41,8 @@ public class FederatedQuantileWeightsTest extends 
AutomatedTestBase {
        private final static String TEST_DIR = "functions/federated/quantile/";
        private final static String TEST_NAME1 = "FederatedQuantileWeightsTest";
        private final static String TEST_NAME2 = "FederatedMedianWeightsTest";
+       private final static String TEST_NAME3 = "FederatedIQRWeightsTest";
+       private final static String TEST_NAME4 = 
"FederatedQuantilesWeightsTest";
        private final static String TEST_CLASS_DIR = TEST_DIR + 
FederatedQuantileWeightsTest.class.getSimpleName() + "/";
 
        private final static int blocksize = 1024;
@@ -53,7 +55,7 @@ public class FederatedQuantileWeightsTest extends 
AutomatedTestBase {
        public static Collection<Object[]> data() {
                return Arrays.asList(new Object[][] {
                        {1000, false},
-                       {12, true}});
+                       {128, true}});
        }
 
        @Override
@@ -61,6 +63,8 @@ public class FederatedQuantileWeightsTest extends 
AutomatedTestBase {
                TestUtils.clearAssertionInformation();
                addTestConfiguration(TEST_NAME1, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"S.scalar"}));
                addTestConfiguration(TEST_NAME2, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] {"S.scalar"}));
+               addTestConfiguration(TEST_NAME3, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] {"S.scalar"}));
+               addTestConfiguration(TEST_NAME4, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME4, new String[] {"S"}));
        }
 
        @Test
@@ -76,10 +80,10 @@ public class FederatedQuantileWeightsTest extends 
AutomatedTestBase {
        public void federatedMedianCP() { 
federatedQuartile(Types.ExecMode.SINGLE_NODE, TEST_NAME2, -1); }
 
        @Test
-       public void federatedIQMCP() { 
federatedQuartile(Types.ExecMode.SINGLE_NODE, TEST_NAME1, -1); }
+       public void federatedIQMCP() { 
federatedQuartile(Types.ExecMode.SINGLE_NODE, TEST_NAME3, -1); }
 
        @Test
-       public void federatedQuantilesCP() { 
federatedQuartile(Types.ExecMode.SINGLE_NODE, TEST_NAME1, -1); }
+       public void federatedQuantilesCP() { 
federatedQuartile(Types.ExecMode.SINGLE_NODE, TEST_NAME4, -1); }
 
        public void federatedQuartile(Types.ExecMode execMode, String 
TEST_NAME, double p) {
                boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
diff --git 
a/src/test/scripts/functions/federated/quantile/FederatedIQRWeightsTest.dml 
b/src/test/scripts/functions/federated/quantile/FederatedIQRWeightsTest.dml
index 99b48c6..e2cffc7 100644
--- a/src/test/scripts/functions/federated/quantile/FederatedIQRWeightsTest.dml
+++ b/src/test/scripts/functions/federated/quantile/FederatedIQRWeightsTest.dml
@@ -19,7 +19,14 @@
 #
 #-------------------------------------------------------------
 
-A = federated(addresses=list($in_X1), ranges=list(list(0, 0), list($rows, 
$cols)));
+if ($rP) {
+    A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+        ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), 
list(2*$rows/4, $cols),
+               list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), 
list($rows, $cols)));
+} else {
+    A = federated(addresses=list($in_X1), ranges=list(list(0, 0), list($rows, 
$cols)));
+}
 W = read($W);
 s = interQuartileMean(A, W);
+
 write(s, $out_S);
diff --git 
a/src/test/scripts/functions/federated/quantile/FederatedIQRWeightsTestReference.dml
 
b/src/test/scripts/functions/federated/quantile/FederatedIQRWeightsTestReference.dml
index afc9a1f..d45f70c 100644
--- 
a/src/test/scripts/functions/federated/quantile/FederatedIQRWeightsTestReference.dml
+++ 
b/src/test/scripts/functions/federated/quantile/FederatedIQRWeightsTestReference.dml
@@ -19,7 +19,11 @@
 #
 #-------------------------------------------------------------
 
-A = read($1);
+if($3) {
+  A = rbind(read($5), read($6), read($7), read($8));
+}
+else { A = read($5); }
 W = read($4);
 s = interQuartileMean(A, W);
+
 write(s, $2);
diff --git 
a/src/test/scripts/functions/federated/quantile/FederatedQuantilesWeightsTest.dml
 
b/src/test/scripts/functions/federated/quantile/FederatedQuantilesWeightsTest.dml
index 86f9611..025cd7e 100644
--- 
a/src/test/scripts/functions/federated/quantile/FederatedQuantilesWeightsTest.dml
+++ 
b/src/test/scripts/functions/federated/quantile/FederatedQuantilesWeightsTest.dml
@@ -19,10 +19,16 @@
 #
 #-------------------------------------------------------------
 
-A = federated(addresses=list($in_X1), ranges=list(list(0, 0), list($rows, 
$cols)));
+if ($rP) {
+    A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+        ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0), 
list(2*$rows/4, $cols),
+               list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0), 
list($rows, $cols)));
+} else {
+    A = federated(addresses=list($in_X1), ranges=list(list(0, 0), list($rows, 
$cols)));
+}
 P = matrix(0.25, 3, 1);
 P[2,1] = 0.5;
 P[3,1] = 0.75;
 W = read($W);
-s = quantile(X=A, W=W, P=P);
+s = quantile(A, W, P);
 write(s, $out_S);
diff --git 
a/src/test/scripts/functions/federated/quantile/FederatedQuantilesWeightsTestReference.dml
 
b/src/test/scripts/functions/federated/quantile/FederatedQuantilesWeightsTestReference.dml
index 3c6d7fd..7b0023a 100644
--- 
a/src/test/scripts/functions/federated/quantile/FederatedQuantilesWeightsTestReference.dml
+++ 
b/src/test/scripts/functions/federated/quantile/FederatedQuantilesWeightsTestReference.dml
@@ -19,10 +19,13 @@
 #
 #-------------------------------------------------------------
 
-A = read($1);
+if($3) {
+  A = rbind(read($5), read($6), read($7), read($8));
+}
+else { A = read($5); }
 P = matrix(0.25, 3, 1);
 P[2,1] = 0.5;
 P[3,1] = 0.75;
-W = read($W);
-s = quantile(X=A, W=W, P=P);
+W = read($4);
+s = quantile(A, W, P);
 write(s, $2);

Reply via email to