This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new e3c60ff [SYSTEMDS-3305] Performance federated quantiles
e3c60ff is described below
commit e3c60ffab8f4f327f9f7c3af0c12ca777526f32e
Author: OlgaOvcharenko <[email protected]>
AuthorDate: Tue Mar 22 12:53:26 2022 +0100
[SYSTEMDS-3305] Performance federated quantiles
Closes #1558.
---
.../controlprogram/federated/FederationUtils.java | 21 ++-
.../fed/AggregateUnaryFEDInstruction.java | 4 +-
.../fed/QuantilePickFEDInstruction.java | 201 +++++++++++++--------
.../fed/QuantileSortFEDInstruction.java | 1 -
.../primitives/FederatedQuantileTest.java | 4 +-
.../primitives/FederatedQuantileWeightsTest.java | 10 +-
.../federated/quantile/FederatedIQRWeightsTest.dml | 9 +-
.../quantile/FederatedIQRWeightsTestReference.dml | 6 +-
.../quantile/FederatedQuantilesWeightsTest.dml | 10 +-
.../FederatedQuantilesWeightsTestReference.dml | 9 +-
10 files changed, 174 insertions(+), 101 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
index 606b6c0..a31d8be 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/federated/FederationUtils.java
@@ -371,16 +371,19 @@ public class FederationUtils {
return new DoubleObject(aggMean(ffr,
map).getValue(0,0));
}
else if(aop.aggOp.increOp.fn instanceof CM) {
- double var = ((ScalarObject)
ffr[0].get().getData()[0]).getDoubleValue();
- double mean = ((ScalarObject)
meanFfr[0].get().getData()[0]).getDoubleValue();
- long size =
map.getFederatedRanges()[0].getSize();
- for(int i = 0; i < ffr.length - 1; i++) {
- long l = size +
map.getFederatedRanges()[i+1].getSize();
- double k = ((size * var) +
(map.getFederatedRanges()[i+1].getSize() * ((ScalarObject)
ffr[i+1].get().getData()[0]).getDoubleValue())) / l;
- var = k + (size *
map.getFederatedRanges()[i+1].getSize()) * Math.pow((mean - ((ScalarObject)
meanFfr[i+1].get().getData()[0]).getDoubleValue()) / l, 2);
- mean = (mean * size + ((ScalarObject)
meanFfr[i+1].get().getData()[0]).getDoubleValue() *
(map.getFederatedRanges()[i+1].getSize())) / l;
- size = l;
+ long size1 =
map.getFederatedRanges()[0].getSize();
+ double mean1 = ((ScalarObject)
meanFfr[0].get().getData()[0]).getDoubleValue();
+ double squaredM1 = ((ScalarObject)
ffr[0].get().getData()[0]).getDoubleValue() * (size1 - 1);
+ for(int i = 1; i < ffr.length; i++) {
+ long size2 =
map.getFederatedRanges()[i].getSize();
+ double delta = ((ScalarObject)
meanFfr[i].get().getData()[0]).getDoubleValue() - mean1;
+ double squaredM2 = ((ScalarObject)
ffr[i].get().getData()[0]).getDoubleValue() * (size2 - 1);
+ squaredM1 = squaredM1 + squaredM2 +
(Math.pow(delta, 2) * size1 * size2 / (size1 + size2));
+
+ size1 += size2;
+ mean1 = mean1 + delta * size2 / size1;
}
+ double var = squaredM1 / (size1 - 1);
return new DoubleObject(var);
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
index 668fd0b..7e2ca2a 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/AggregateUnaryFEDInstruction.java
@@ -250,7 +250,7 @@ public class AggregateUnaryFEDInstruction extends
UnaryFEDInstruction {
FederatedRequest meanFr1 =
FederationUtils.callInstruction(meanInstr, output, id,
new CPOperand[]{input1}, new
long[]{in.getFedMapping().getID()}, isSpark ? ExecType.SPARK : ExecType.CP,
isSpark);
FederatedRequest meanFr2 = new
FederatedRequest(RequestType.GET_VAR, meanFr1.getID());
- meanTmp = map.execute(getTID(), isSpark ?
+ meanTmp = map.execute(getTID(), true, isSpark ?
new FederatedRequest[] {tmpRequest, meanFr1,
meanFr2} :
new FederatedRequest[] {meanFr1, meanFr2});
}
@@ -261,7 +261,7 @@ public class AggregateUnaryFEDInstruction extends
UnaryFEDInstruction {
FederatedRequest fr2 = new
FederatedRequest(RequestType.GET_VAR, fr1.getID());
//execute federated commands and cleanups
- Future<FederatedResponse>[] tmp = map.execute(getTID(), isSpark
?
+ Future<FederatedResponse>[] tmp = map.execute(getTID(), true,
isSpark ?
new FederatedRequest[] {tmpRequest, fr1, fr2} :
new FederatedRequest[] { fr1, fr2});
if( output.isScalar() )
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantilePickFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantilePickFEDInstruction.java
index c6a5b08..f082173 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantilePickFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantilePickFEDInstruction.java
@@ -22,11 +22,11 @@ package org.apache.sysds.runtime.instructions.fed;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
-import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
+import java.util.stream.Stream;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.ImmutableTriple;
@@ -71,10 +71,6 @@ public class QuantilePickFEDInstruction extends
BinaryFEDInstruction {
this(op, in, in2, out, type, inmem, opcode, istr,
FederatedOutput.NONE);
}
- public OperationTypes getQPickType() {
- return _type;
- }
-
public static QuantilePickFEDInstruction parseInstruction ( String str
) {
String[] parts =
InstructionUtils.getInstructionPartsWithValueType(str);
String opcode = parts[0];
@@ -169,27 +165,28 @@ public class QuantilePickFEDInstruction extends
BinaryFEDInstruction {
// Compute and set results
if(quantiles != null && quantiles.length > 1) {
- computeMultipleQuantiles(ec, in,
(Map<ImmutablePair<Double, Double>, Integer>) ret, quantiles, (int)
vectorLength, varID, _type);
+ computeMultipleQuantiles(ec, in, (int[]) ret,
quantiles, (int) vectorLength, varID, (globalMax-globalMin) / numBuckets,
globalMin, _type);
} else
- getSingleQuantileResult(ret, ec, fedMap, varID,
average, false, (int) vectorLength);
+ getSingleQuantileResult(ret, ec, fedMap, varID,
average, false, (int) vectorLength, null);
}
- private <T> void computeMultipleQuantiles(ExecutionContext ec,
MatrixObject in, Map<ImmutablePair<Double, Double>, Integer> buckets, double[]
quantiles, int vectorLength, long varID, OperationTypes type) {
+ private <T> void computeMultipleQuantiles(ExecutionContext ec,
MatrixObject in, int[] bucketsFrequencies, double[] quantiles,
+ int vectorLength, long varID, double bucketRange, double min,
OperationTypes type) {
MatrixBlock out = new MatrixBlock(quantiles.length, 1, false);
ImmutableTriple<Integer, Integer, ImmutablePair<Double,
Double>>[] bucketsWithIndex = new ImmutableTriple[quantiles.length];
// Find bins with each quantile for first histogram
int sizeBeforeTmp = 0, sizeBefore = 0, countFoundBins = 0;
- for(Map.Entry<ImmutablePair<Double, Double>, Integer> entry :
buckets.entrySet()) {
- sizeBeforeTmp += entry.getValue();
+ for(int j = 0; j < bucketsFrequencies.length; j++) {
+ sizeBeforeTmp += bucketsFrequencies[j];
for(int i = 0; i < quantiles.length; i++) {
int quantileIndex = (int)
Math.round(vectorLength * quantiles[i]);
- ImmutablePair<Double, Double> bucketWithQ =
null;
+ ImmutablePair<Double, Double> bucketWithQ;
if(quantileIndex > sizeBefore && quantileIndex
<= sizeBeforeTmp) {
- bucketWithQ = entry.getKey();
- bucketsWithIndex[i] = new
ImmutableTriple<>(quantileIndex == 1 ? 1 : quantileIndex - sizeBefore,
entry.getValue(), bucketWithQ);
+ bucketWithQ = new ImmutablePair<>(min +
(j * bucketRange), min + ((j+1) * bucketRange));
+ bucketsWithIndex[i] = new
ImmutableTriple<>(quantileIndex == 1 ? 1 : quantileIndex - sizeBefore,
bucketsFrequencies[j], bucketWithQ);
countFoundBins++;
}
}
@@ -202,14 +199,16 @@ public class QuantilePickFEDInstruction extends
BinaryFEDInstruction {
// Find each quantile bin recursively
Map<Integer, T> retBuckets = new HashMap<>();
- double left = 0, right = 0;
+ double q25Left = 0, q25Right = 0, q75Left = 0, q75Right = 0;
for(int i = 0; i < bucketsWithIndex.length; i++) {
int nextNumBuckets = bucketsWithIndex[i].middle < 100 ?
bucketsWithIndex[i].middle * 2 : (int) Math.round(bucketsWithIndex[i].middle /
2.0);
T hist = createHistogram(in, vectorLength,
bucketsWithIndex[i].right.left, bucketsWithIndex[i].right.right,
nextNumBuckets, bucketsWithIndex[i].left, false);
if(_type == OperationTypes.IQM) {
- left = i == 0 ? hist instanceof ImmutablePair ?
((ImmutablePair<Double, Double>)hist).right : (Double) hist : left;
- right = i == 1 ? hist instanceof ImmutablePair
? ((ImmutablePair<Double, Double>)hist).left : (Double) hist : right;
+ q25Right = i == 0 ? hist instanceof
ImmutablePair ? ((ImmutablePair<Double, Double>)hist).right : (Double) hist :
q25Right;
+ q25Left = i == 0 ? hist instanceof
ImmutablePair ? ((ImmutablePair<Double, Double>)hist).left : (Double) hist :
q25Left;
+ q75Right = i == 1 ? hist instanceof
ImmutablePair ? ((ImmutablePair<Double, Double>)hist).right : (Double) hist :
q75Right;
+ q75Left = i == 1 ? hist instanceof
ImmutablePair ? ((ImmutablePair<Double, Double>)hist).left : (Double) hist :
q75Left;
} else {
if(hist instanceof ImmutablePair)
retBuckets.put(i, hist); // set value
if returned double instead of bin
@@ -219,8 +218,11 @@ public class QuantilePickFEDInstruction extends
BinaryFEDInstruction {
}
if(type == OperationTypes.IQM) {
- ImmutablePair<Double, Double> IQMRange = new
ImmutablePair<>(left, right);
- getSingleQuantileResult(IQMRange, ec,
in.getFedMapping(), varID, false, true, vectorLength);
+ ImmutablePair<Double, Double> IQMRange = new
ImmutablePair<>(q25Right, q75Right);
+ if(q25Right == q75Right)
+ ec.setScalarOutput(output.getName(), new
DoubleObject(q25Left));
+ else
+ getSingleQuantileResult(IQMRange, ec,
in.getFedMapping(), varID, false, true, vectorLength, new
ImmutablePair<>(q25Left, q75Left));
}
else {
if(!retBuckets.isEmpty()) {
@@ -251,19 +253,22 @@ public class QuantilePickFEDInstruction extends
BinaryFEDInstruction {
}
}
- private <T> void getSingleQuantileResult(T ret, ExecutionContext ec,
FederationMap fedMap, long varID, boolean average, boolean isIQM, int
vectorLength) {
- double result = 0.0;
+ private <T> void getSingleQuantileResult(T ret, ExecutionContext ec,
FederationMap fedMap, long varID, boolean average, boolean isIQM, int
vectorLength, ImmutablePair<Double, Double> iqmRange) {
+ double result = 0.0, q25Part = 0, q25Val = 0, q75Val = 0,
q75Part = 0;
if(ret instanceof ImmutablePair) {
// Search for values within bucket range
List<Double> values = new ArrayList<>();
+ List<double[]> iqmValues = new ArrayList<>();
fedMap.mapParallel(varID, (range, data) -> {
try {
- FederatedResponse response =
data.executeFederatedOperation(new
FederatedRequest(FederatedRequest.RequestType.EXEC_UDF,
- -1,
- new
QuantilePickFEDInstruction.GetValuesInRange(data.getVarID(),
(ImmutablePair<Double, Double>) ret, isIQM))).get();
+ FederatedResponse response =
data.executeFederatedOperation(new
FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1,
+ new
QuantilePickFEDInstruction.GetValuesInRange(data.getVarID(),
(ImmutablePair<Double, Double>) ret, isIQM, iqmRange))).get();
if(!response.isSuccessful())
response.throwExceptionFromResponse();
- values.add((double)
response.getData()[0]);
+ if(isIQM)
+ iqmValues.add((double[])
response.getData()[0]);
+ else
+ values.add((double)
response.getData()[0]);
return null;
}
catch(Exception e) {
@@ -271,44 +276,46 @@ public class QuantilePickFEDInstruction extends
BinaryFEDInstruction {
}
});
- // Sum of 1 or 2 values
- result = values.stream().reduce(0.0, Double::sum);
+
+ if(isIQM) {
+ for(double[] vals : iqmValues) {
+ result += vals[0];
+ q25Part += vals[1];
+ q25Val += vals[2];
+ q75Part += vals[3];
+ q75Val += vals[4];
+ }
+ q25Part -= (0.25 * vectorLength);
+ q75Part -= (0.75 * vectorLength);
+ } else
+ result = values.stream().reduce(0.0,
Double::sum);
} else
result = (Double) ret;
- result /= (average ? 2 : isIQM ? ((int) Math.round(vectorLength
* 0.75) - (int) Math.round(vectorLength * 0.25)) : 1);
+ result = average ? result / 2 : (isIQM ? ((result +
q25Part*q25Val - q75Part*q75Val) / (vectorLength * 0.5)) : result);
ec.setScalarOutput(output.getName(), new DoubleObject(result));
}
public <T> T createHistogram(MatrixObject in, int vectorLength, double
globalMin, double globalMax, int numBuckets, int quantileIndex, boolean
average) {
FederationMap fedMap = in.getFedMapping();
-
- Map<ImmutablePair<Double, Double>, Integer> buckets = new
LinkedHashMap<>();
- List<Map<ImmutablePair<Double, Double>, Integer>> hists = new
ArrayList<>();
+ List<int[]> hists = new ArrayList<>();
List<Set<Double>> distincts = new ArrayList<>();
double bucketRange = (globalMax-globalMin) / numBuckets;
boolean isEvenNumRows = vectorLength % 2 == 0;
- // Create buckets according to min and max
- double tmpMin = globalMin, tmpMax = globalMax;
- for(int i = 0; i < numBuckets && tmpMin <= tmpMax; i++) {
- buckets.put(new ImmutablePair<>(tmpMin, tmpMin +
bucketRange), 0);
- tmpMin += bucketRange;
- }
-
// Create histograms
long varID = FederationUtils.getNextFedDataID();
fedMap.mapParallel(varID, (range, data) -> {
try {
FederatedResponse response =
data.executeFederatedOperation(new FederatedRequest(
FederatedRequest.RequestType.EXEC_UDF,
-1,
- new
QuantilePickFEDInstruction.GetHistogram(data.getVarID(), buckets,
globalMax))).get();
+ new
QuantilePickFEDInstruction.GetHistogram(data.getVarID(), globalMin, globalMax,
bucketRange, numBuckets))).get();
if(!response.isSuccessful())
response.throwExceptionFromResponse();
- Map<ImmutablePair<Double, Double>, Integer>
rangeHist = (Map<ImmutablePair<Double, Double>, Integer>) response.getData()[0];
+ int[] rangeHist = (int[]) response.getData()[0];
hists.add(rangeHist);
Set<Double> rangeDistinct = (Set<Double>)
response.getData()[1];
distincts.add(rangeDistinct);
@@ -320,28 +327,36 @@ public class QuantilePickFEDInstruction extends
BinaryFEDInstruction {
});
// Merge results into one histogram
- for(ImmutablePair<Double, Double> bucket : buckets.keySet()) {
- int value = 0;
- for(Map<ImmutablePair<Double, Double>, Integer> hist :
hists)
- value += hist.get(bucket);
- buckets.put(bucket, value);
- }
+ int[] bucketsFrequencies = new int[numBuckets];
+ for(int[] hist : hists)
+ for(int i = 0; i < hist.length; i++)
+ bucketsFrequencies[i] += hist[i];
if(quantileIndex == -1)
- return (T) buckets;
+ return (T) bucketsFrequencies;
// Find bucket with quantile
- ImmutableTriple<Integer, Integer, ImmutablePair<Double,
Double>> bucketWithIndex = getBucketWithIndex(buckets, quantileIndex, average,
isEvenNumRows);
+ ImmutableTriple<Integer, Integer, ImmutablePair<Double,
Double>> bucketWithIndex = getBucketWithIndex(bucketsFrequencies, globalMin,
quantileIndex, average, isEvenNumRows, bucketRange);
// Check if can terminate
Set<Double> distinctValues =
distincts.stream().flatMap(Set::stream).collect(Collectors.toSet());
- if((distinctValues.size() == 1 && !average) ||
(distinctValues.size() == 2 && average))
- return (T) distinctValues.stream().reduce(0.0, (a, b)
-> a + b);
+
+ if(distinctValues.size() > quantileIndex-1 && !average)
+ return (T)
distinctValues.stream().sorted().toArray()[quantileIndex-1];
+
+ if(average && distinctValues.size() > quantileIndex) {
+ Double[] distinctsSorted =
distinctValues.stream().flatMap(Stream::of).sorted().toArray(Double[]::new);
+ Double medianSum =
Double.sum(distinctsSorted[quantileIndex-1], distinctsSorted[quantileIndex]);
+ return (T) medianSum;
+ }
+
+ if(average && distinctValues.size() == 2)
+ return (T) distinctValues.stream().reduce(0.0,
Double::sum);
ImmutablePair<Double, Double> finalBucketWithQ =
bucketWithIndex.right;
List<Double> distinctInNewBucket =
distinctValues.stream().filter( e -> e >= finalBucketWithQ.left && e <=
finalBucketWithQ.right).collect(Collectors.toList());
if((distinctInNewBucket.size() == 1 && !average) || (average &&
distinctInNewBucket.size() == 2))
- return (T) distinctInNewBucket.stream().reduce(0.0, (a,
b) -> a + b);
+ return (T) distinctInNewBucket.stream().reduce(0.0,
Double::sum);
if(distinctValues.size() == 1 || (bucketWithIndex.middle == 1
&& !average) || (bucketWithIndex.middle == 2 && isEvenNumRows && average) ||
globalMin == globalMax)
@@ -357,14 +372,16 @@ public class QuantilePickFEDInstruction extends
BinaryFEDInstruction {
return createHistogram(in, vectorLength,
bucketWithIndex.right.left, bucketWithIndex.right.right, nextNumBuckets,
bucketWithIndex.left, average);
}
- private ImmutableTriple<Integer, Integer, ImmutablePair<Double,
Double>> getBucketWithIndex(Map<ImmutablePair<Double, Double>, Integer>
buckets, int quantileIndex, boolean average, boolean isEvenNumRows) {
+ private ImmutableTriple<Integer, Integer, ImmutablePair<Double,
Double>> getBucketWithIndex(int[] bucketFrequencies, double min, int
quantileIndex, boolean average, boolean isEvenNumRows, double bucketRange) {
int sizeBeforeTmp = 0, sizeBefore = 0, bucketWithQSize = 0;
ImmutablePair<Double, Double> bucketWithQ = null;
- for(Map.Entry<ImmutablePair<Double, Double>, Integer> range :
buckets.entrySet()) {
- sizeBeforeTmp += range.getValue();
+
+ double tmpBinLeft = min;
+ for(int i = 0; i < bucketFrequencies.length; i++) {
+ sizeBeforeTmp += bucketFrequencies[i];
if(quantileIndex <= sizeBeforeTmp && bucketWithQSize ==
0) {
- bucketWithQ = range.getKey();
- bucketWithQSize = range.getValue();
+ bucketWithQ = new ImmutablePair<>(tmpBinLeft,
tmpBinLeft + bucketRange);
+ bucketWithQSize = bucketFrequencies[i];
sizeBeforeTmp -= bucketWithQSize;
sizeBefore = sizeBeforeTmp;
@@ -372,13 +389,14 @@ public class QuantilePickFEDInstruction extends
BinaryFEDInstruction {
break;
} else if(quantileIndex + 1 <= sizeBeforeTmp +
bucketWithQSize && isEvenNumRows && average) {
// Add right bin that contains second index
- int bucket2Size = range.getValue();
+ int bucket2Size = bucketFrequencies[i];
if (bucket2Size != 0) {
- bucketWithQ = new
ImmutablePair<>(bucketWithQ.left, range.getKey().right);
+ bucketWithQ = new
ImmutablePair<>(bucketWithQ.left, tmpBinLeft + bucketRange);
bucketWithQSize += bucket2Size;
break;
}
}
+ tmpBinLeft += bucketRange;
}
quantileIndex = quantileIndex == 1 ? 1 : quantileIndex -
sizeBefore;
return new ImmutableTriple<>(quantileIndex, bucketWithQSize,
bucketWithQ);
@@ -386,13 +404,17 @@ public class QuantilePickFEDInstruction extends
BinaryFEDInstruction {
public static class GetHistogram extends FederatedUDF {
private static final long serialVersionUID =
5413355823424777742L;
- private final Map<ImmutablePair<Double, Double>, Integer>
_buckets;
private final double _max;
+ private final double _min;
+ private final double _range;
+ private final int _numBuckets;
- private GetHistogram(long input, Map<ImmutablePair<Double,
Double>, Integer> buckets, double max) {
+ private GetHistogram(long input, double min, double max, double
range, int numBuckets) {
super(new long[] {input});
- _buckets = buckets;
_max = max;
+ _min = min;
+ _range = range;
+ _numBuckets = numBuckets;
}
@Override
@@ -401,22 +423,23 @@ public class QuantilePickFEDInstruction extends
BinaryFEDInstruction {
double[] values = mb.getDenseBlockValues();
boolean isWeighted = mb.getNumColumns() == 2;
- Map<ImmutablePair<Double, Double>, Integer> hist =
_buckets;
Set<Double> distinct = new HashSet<>();
+ int[] frequencies = new int[_numBuckets];
+
+ // binning
for(int i = 0; i < values.length - (isWeighted ? 1 :
0); i += (isWeighted ? 2 : 1)) {
double val = values[i];
int weight = isWeighted ? (int) values[i+1] : 1;
- for (Map.Entry<ImmutablePair<Double, Double>,
Integer> range : _buckets.entrySet()) {
- if((val >= range.getKey().left && val <
range.getKey().right) || (val == _max && val == range.getKey().right)) {
- hist.put(range.getKey(),
range.getValue() + weight);
-
- distinct.add(val);
- }
+ int index = (int) (Math.ceil((val - _min) /
_range));
+ index = index == 0 ? 0 : index - 1;
+ if (val >= _min && val <= _max) {
+ frequencies[index] += weight;
+ distinct.add(val);
}
}
- Object[] ret = new Object[] {hist, distinct.size() < 3
? distinct : new HashSet<>()};
+ Object[] ret = new Object[] {frequencies,
distinct.size() < 3 ? distinct : new HashSet<>()};
return new
FederatedResponse(FederatedResponse.ResponseType.SUCCESS, ret);
}
@@ -442,10 +465,10 @@ public class QuantilePickFEDInstruction extends
BinaryFEDInstruction {
MatrixBlock mb = ((MatrixObject)
data[0]).acquireReadAndRelease();
double[] values = mb.getDenseBlockValues();
- // FIXME rewrite - see binning encode
MatrixBlock res = new MatrixBlock(_numQuantiles, 1,
false);
for(double val : values) {
for(Map.Entry<Integer, ImmutablePair<Double,
Double>> entry : _ranges.entrySet()) {
+ // Find value within computed bin
if(entry.getValue().left <= val && val
<= entry.getValue().right) {
res.setValue(entry.getKey(),
0,val);
break;
@@ -583,12 +606,14 @@ public class QuantilePickFEDInstruction extends
BinaryFEDInstruction {
public static class GetValuesInRange extends FederatedUDF {
private static final long serialVersionUID =
5413355823424777742L;
private final ImmutablePair<Double, Double> _range;
+ private final ImmutablePair<Double, Double> _iqmRange;
private final boolean _sumInRange;
- private GetValuesInRange(long input, ImmutablePair<Double,
Double> range, boolean sumInRange) {
+ private GetValuesInRange(long input, ImmutablePair<Double,
Double> range, boolean sumInRange, ImmutablePair<Double, Double> iqmRange) {
super(new long[] {input});
_range = range;
_sumInRange = sumInRange;
+ _iqmRange = iqmRange;
}
@Override
@@ -596,20 +621,42 @@ public class QuantilePickFEDInstruction extends
BinaryFEDInstruction {
MatrixBlock mb = ((MatrixObject)
data[0]).acquireReadAndRelease();
double[] values = mb.getDenseBlockValues();
+ boolean isWeighted = mb.getNumColumns() == 2;
+
double res = 0.0;
- int i = 0;
+ int counter = 0;
- // FIXME better search, e.g. sort in QSort and binary
search
- for(double val : values) {
+ double q25Part = 0, q25Val = 0, q75Val = 0, q75Part = 0;
+ for(int i = 0; i < values.length - (isWeighted ? 1 :
0); i += (isWeighted ? 2 : 1)) {
+ // get value within computed bin
// different conditions for IQM and simple QPICK
+ double val = values[i];
+ int weight = isWeighted ? (int) values[i+1] : 1;
+
+ if(_iqmRange != null && val <= _iqmRange.left) {
+ q25Part += weight;
+ }
+
+ if(_iqmRange != null && val >= _iqmRange.left
&& val <= _range.left) {
+ q25Val = val;
+ }
+ else if(_iqmRange != null && val <=
_iqmRange.right && val >= _range.right)
+ q75Val = val;
+
if((!_sumInRange && _range.left <= val && val
<= _range.right) ||
- (_sumInRange && _range.left < val &&
val <= _range.right))
- res += val;
- if(i++ > 2 && !_sumInRange)
+ (_sumInRange && _range.left < val &&
val <= _range.right)) {
+ res += (val * (!_sumInRange && weight >
1 ? 2 : weight));
+ counter += weight;
+ }
+
+ if(_iqmRange != null && val <= _range.right)
+ q75Part += weight;
+
+ if(!_sumInRange && counter > 2)
break;
}
- return new
FederatedResponse(FederatedResponse.ResponseType.SUCCESS, res);
+ return new
FederatedResponse(FederatedResponse.ResponseType.SUCCESS,!_sumInRange ? res :
new double[]{res, q25Part, q25Val, q75Part, q75Val});
}
@Override
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantileSortFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantileSortFEDInstruction.java
index ded83e7..f84be32 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantileSortFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantileSortFEDInstruction.java
@@ -133,7 +133,6 @@ public class QuantileSortFEDInstruction extends
UnaryFEDInstruction{
return null;
});
-
MatrixObject sorted = ec.getMatrixObject(output);
sorted.getDataCharacteristics().set(in.getDataCharacteristics());
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedQuantileTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedQuantileTest.java
index 598256a..3b34454 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedQuantileTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedQuantileTest.java
@@ -56,8 +56,8 @@ public class FederatedQuantileTest extends AutomatedTestBase {
@Parameterized.Parameters
public static Collection<Object[]> data() {
return Arrays.asList(new Object[][] {
- {1000, 1, false},
- {16, 1, true}
+// {1000, 1, false},
+ {128, 1, true}
});
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedQuantileWeightsTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedQuantileWeightsTest.java
index 9ac0b5d..d42f641 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedQuantileWeightsTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedQuantileWeightsTest.java
@@ -41,6 +41,8 @@ public class FederatedQuantileWeightsTest extends
AutomatedTestBase {
private final static String TEST_DIR = "functions/federated/quantile/";
private final static String TEST_NAME1 = "FederatedQuantileWeightsTest";
private final static String TEST_NAME2 = "FederatedMedianWeightsTest";
+ private final static String TEST_NAME3 = "FederatedIQRWeightsTest";
+ private final static String TEST_NAME4 =
"FederatedQuantilesWeightsTest";
private final static String TEST_CLASS_DIR = TEST_DIR +
FederatedQuantileWeightsTest.class.getSimpleName() + "/";
private final static int blocksize = 1024;
@@ -53,7 +55,7 @@ public class FederatedQuantileWeightsTest extends
AutomatedTestBase {
public static Collection<Object[]> data() {
return Arrays.asList(new Object[][] {
{1000, false},
- {12, true}});
+ {128, true}});
}
@Override
@@ -61,6 +63,8 @@ public class FederatedQuantileWeightsTest extends
AutomatedTestBase {
TestUtils.clearAssertionInformation();
addTestConfiguration(TEST_NAME1, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"S.scalar"}));
addTestConfiguration(TEST_NAME2, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] {"S.scalar"}));
+ addTestConfiguration(TEST_NAME3, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] {"S.scalar"}));
+ addTestConfiguration(TEST_NAME4, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME4, new String[] {"S"}));
}
@Test
@@ -76,10 +80,10 @@ public class FederatedQuantileWeightsTest extends
AutomatedTestBase {
public void federatedMedianCP() {
federatedQuartile(Types.ExecMode.SINGLE_NODE, TEST_NAME2, -1); }
@Test
- public void federatedIQMCP() {
federatedQuartile(Types.ExecMode.SINGLE_NODE, TEST_NAME1, -1); }
+ public void federatedIQMCP() {
federatedQuartile(Types.ExecMode.SINGLE_NODE, TEST_NAME3, -1); }
@Test
- public void federatedQuantilesCP() {
federatedQuartile(Types.ExecMode.SINGLE_NODE, TEST_NAME1, -1); }
+ public void federatedQuantilesCP() {
federatedQuartile(Types.ExecMode.SINGLE_NODE, TEST_NAME4, -1); }
public void federatedQuartile(Types.ExecMode execMode, String
TEST_NAME, double p) {
boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
diff --git
a/src/test/scripts/functions/federated/quantile/FederatedIQRWeightsTest.dml
b/src/test/scripts/functions/federated/quantile/FederatedIQRWeightsTest.dml
index 99b48c6..e2cffc7 100644
--- a/src/test/scripts/functions/federated/quantile/FederatedIQRWeightsTest.dml
+++ b/src/test/scripts/functions/federated/quantile/FederatedIQRWeightsTest.dml
@@ -19,7 +19,14 @@
#
#-------------------------------------------------------------
-A = federated(addresses=list($in_X1), ranges=list(list(0, 0), list($rows,
$cols)));
+if ($rP) {
+ A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+ ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0),
list(2*$rows/4, $cols),
+ list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0),
list($rows, $cols)));
+} else {
+ A = federated(addresses=list($in_X1), ranges=list(list(0, 0), list($rows,
$cols)));
+}
W = read($W);
s = interQuartileMean(A, W);
+
write(s, $out_S);
diff --git
a/src/test/scripts/functions/federated/quantile/FederatedIQRWeightsTestReference.dml
b/src/test/scripts/functions/federated/quantile/FederatedIQRWeightsTestReference.dml
index afc9a1f..d45f70c 100644
---
a/src/test/scripts/functions/federated/quantile/FederatedIQRWeightsTestReference.dml
+++
b/src/test/scripts/functions/federated/quantile/FederatedIQRWeightsTestReference.dml
@@ -19,7 +19,11 @@
#
#-------------------------------------------------------------
-A = read($1);
+if($3) {
+ A = rbind(read($5), read($6), read($7), read($8));
+}
+else { A = read($5); }
W = read($4);
s = interQuartileMean(A, W);
+
write(s, $2);
diff --git
a/src/test/scripts/functions/federated/quantile/FederatedQuantilesWeightsTest.dml
b/src/test/scripts/functions/federated/quantile/FederatedQuantilesWeightsTest.dml
index 86f9611..025cd7e 100644
---
a/src/test/scripts/functions/federated/quantile/FederatedQuantilesWeightsTest.dml
+++
b/src/test/scripts/functions/federated/quantile/FederatedQuantilesWeightsTest.dml
@@ -19,10 +19,16 @@
#
#-------------------------------------------------------------
-A = federated(addresses=list($in_X1), ranges=list(list(0, 0), list($rows,
$cols)));
+if ($rP) {
+ A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+ ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0),
list(2*$rows/4, $cols),
+ list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0),
list($rows, $cols)));
+} else {
+ A = federated(addresses=list($in_X1), ranges=list(list(0, 0), list($rows,
$cols)));
+}
P = matrix(0.25, 3, 1);
P[2,1] = 0.5;
P[3,1] = 0.75;
W = read($W);
-s = quantile(X=A, W=W, P=P);
+s = quantile(A, W, P);
write(s, $out_S);
diff --git
a/src/test/scripts/functions/federated/quantile/FederatedQuantilesWeightsTestReference.dml
b/src/test/scripts/functions/federated/quantile/FederatedQuantilesWeightsTestReference.dml
index 3c6d7fd..7b0023a 100644
---
a/src/test/scripts/functions/federated/quantile/FederatedQuantilesWeightsTestReference.dml
+++
b/src/test/scripts/functions/federated/quantile/FederatedQuantilesWeightsTestReference.dml
@@ -19,10 +19,13 @@
#
#-------------------------------------------------------------
-A = read($1);
+if($3) {
+ A = rbind(read($5), read($6), read($7), read($8));
+}
+else { A = read($5); }
P = matrix(0.25, 3, 1);
P[2,1] = 0.5;
P[3,1] = 0.75;
-W = read($W);
-s = quantile(X=A, W=W, P=P);
+W = read($4);
+s = quantile(A, W, P);
write(s, $2);