This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 1357ae1 [SYSTEMDS-3241] Federated quantiles via recursive histograms
1357ae1 is described below
commit 1357ae10fe8ffa48cf1398f740695d14258ec1d7
Author: OlgaOvcharenko <[email protected]>
AuthorDate: Thu Mar 3 21:46:59 2022 +0100
[SYSTEMDS-3241] Federated quantiles via recursive histograms
Closes #1477.
---
.../instructions/fed/FEDInstructionUtils.java | 5 +-
.../fed/QuantilePickFEDInstruction.java | 441 ++++++++++++++++++++-
.../fed/QuantileSortFEDInstruction.java | 33 +-
.../sysds/runtime/matrix/data/MatrixBlock.java | 20 +
.../primitives/FederatedQuantileTest.java | 109 +++--
.../primitives/FederatedQuantileWeightsTest.java | 99 ++++-
.../aggregate/FederatedMeanTestReference.dml | 3 +-
.../federated/quantile/FederatedIQRTest.dml | 9 +-
.../quantile/FederatedIQRTestReference.dml | 5 +-
.../federated/quantile/FederatedMedianTest.dml | 8 +-
.../quantile/FederatedMedianTestReference.dml | 6 +-
.../quantile/FederatedMedianWeightsTest.dml | 8 +-
.../FederatedMedianWeightsTestReference.dml | 5 +-
.../federated/quantile/FederatedQuantileTest.dml | 9 +-
.../quantile/FederatedQuantileTestReference.dml | 7 +-
.../quantile/FederatedQuantileWeightsTest.dml | 8 +-
.../FederatedQuantileWeightsTestReference.dml | 8 +-
.../federated/quantile/FederatedQuantilesTest.dml | 10 +-
.../quantile/FederatedQuantilesTestReference.dml | 8 +-
19 files changed, 723 insertions(+), 78 deletions(-)
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
index 1192ff7..e1c4587 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/FEDInstructionUtils.java
@@ -38,7 +38,6 @@ import org.apache.sysds.runtime.instructions.cp.AggregateTernaryCPInstruction;
import org.apache.sysds.runtime.instructions.cp.AggregateUnaryCPInstruction;
import org.apache.sysds.runtime.instructions.cp.BinaryCPInstruction;
import org.apache.sysds.runtime.instructions.cp.CentralMomentCPInstruction;
-import org.apache.sysds.runtime.instructions.cp.CovarianceCPInstruction;
import org.apache.sysds.runtime.instructions.cp.CtableCPInstruction;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.cp.IndexingCPInstruction;
@@ -159,7 +158,7 @@ public class FEDInstructionUtils {
if(instruction.getOpcode().equalsIgnoreCase("cm"))
fedinst =
CentralMomentFEDInstruction.parseInstruction(inst.getInstructionString());
else
if(inst.getOpcode().equalsIgnoreCase("qsort")) {
-
if(mo1.getFedMapping().getFederatedRanges().length == 1)
+
if(mo1.isFederated(FType.ROW) ||
mo1.getFedMapping().getFederatedRanges().length == 1 &&
mo1.isFederated(FType.COL))
fedinst =
QuantileSortFEDInstruction.parseInstruction(inst.getInstructionString());
}
else
if(inst.getOpcode().equalsIgnoreCase("rshape"))
@@ -186,7 +185,7 @@ public class FEDInstructionUtils {
fedinst =
QuantilePickFEDInstruction.parseInstruction(inst.getInstructionString());
else
if("cov".equals(instruction.getOpcode()) &&
(ec.getMatrixObject(instruction.input1).isFederated(FType.ROW) ||
ec.getMatrixObject(instruction.input2).isFederated(FType.ROW)))
- fedinst =
CovarianceFEDInstruction.parseInstruction((CovarianceCPInstruction)inst);
+ fedinst =
CovarianceFEDInstruction.parseInstruction(inst.getInstructionString());
else
fedinst =
BinaryFEDInstruction.parseInstruction(
InstructionUtils.concatOperands(inst.getInstructionString(),FederatedOutput.NONE.name()));
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantilePickFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantilePickFEDInstruction.java
index f984967..83a1360 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantilePickFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantilePickFEDInstruction.java
@@ -20,8 +20,16 @@
package org.apache.sysds.runtime.instructions.fed;
import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.LinkedHashMap;
import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.stream.Collectors;
+import org.apache.commons.lang3.tuple.ImmutablePair;
+import org.apache.commons.lang3.tuple.ImmutableTriple;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.lops.PickByCount.OperationTypes;
import org.apache.sysds.runtime.DMLRuntimeException;
@@ -41,6 +49,7 @@ import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.Operator;
+@SuppressWarnings("unchecked")
public class QuantilePickFEDInstruction extends BinaryFEDInstruction {
private final OperationTypes _type;
@@ -61,6 +70,10 @@ public class QuantilePickFEDInstruction extends BinaryFEDInstruction {
this(op, in, in2, out, type, inmem, opcode, istr,
FederatedOutput.NONE);
}
+ public OperationTypes getQPickType() {
+ return _type;
+ }
+
public static QuantilePickFEDInstruction parseInstruction ( String str
) {
String[] parts =
InstructionUtils.getInstructionPartsWithValueType(str);
String opcode = parts[0];
@@ -69,7 +82,6 @@ public class QuantilePickFEDInstruction extends BinaryFEDInstruction {
//instruction parsing
if( parts.length == 4 ) {
//instructions of length 4 originate from unary - mr-iqm
- //TODO this should be refactored to use pickvaluecount
lops
CPOperand in1 = new CPOperand(parts[1]);
CPOperand in2 = new CPOperand(parts[2]);
CPOperand out = new CPOperand(parts[3]);
@@ -97,6 +109,380 @@ public class QuantilePickFEDInstruction extends BinaryFEDInstruction {
@Override
public void processInstruction(ExecutionContext ec) {
+
if(ec.getMatrixObject(input1).isFederated(FederationMap.FType.COL) ||
ec.getMatrixObject(input1).isFederated(FederationMap.FType.FULL))
+ processColumnQPick(ec);
+ else
+ processRowQPick(ec);
+ }
+
+ public <T> void processRowQPick(ExecutionContext ec) {
+ MatrixObject in = ec.getMatrixObject(input1);
+ FederationMap fedMap = in.getFedMapping();
+ boolean average = _type == OperationTypes.MEDIAN;
+
+ double[] quantiles = input2 != null ? (input2.isMatrix() ?
ec.getMatrixInput(input2).getDenseBlockValues() :
+ input2.isScalar() ? new double[]
{ec.getScalarInput(input2).getDoubleValue()} : null) :
+ (average ? new double[] {0.5} : _type ==
OperationTypes.IQM ? new double[] {0.25, 0.75} : null);
+
+ if (input2 != null && input2.isMatrix())
+ ec.releaseMatrixInput(input2.getName());
+
+ // Find min and max
+ long varID = FederationUtils.getNextFedDataID();
+ List<double[]> minMax = new ArrayList<>();
+ fedMap.mapParallel(varID, (range, data) -> {
+ try {
+ FederatedResponse response =
data.executeFederatedOperation(new FederatedRequest(
+ FederatedRequest.RequestType.EXEC_UDF,
-1,
+ new
QuantilePickFEDInstruction.MinMax(data.getVarID()))).get();
+ if(!response.isSuccessful())
+ response.throwExceptionFromResponse();
+ double[] rangeMinMax = (double[])
response.getData()[0];
+ minMax.add(rangeMinMax);
+
+ return null;
+ }
+ catch(Exception e) {
+ throw new DMLRuntimeException(e);
+ }
+ });
+
+ // Find weights sum, min and max
+ double globalMin = Double.MAX_VALUE, globalMax =
Double.MIN_VALUE, vectorLength = in.getNumColumns() == 2 ? 0 : in.getNumRows(),
sumWeights = 0.0;
+ for(double[] values : minMax) {
+ globalMin = Math.min(globalMin, values[0]);
+ globalMax = Math.max(globalMax, values[1]);
+ if(in.getNumColumns() == 2)
+ vectorLength += values[2];
+ sumWeights += values[3];
+ }
+
+ // Average for median
+ average = average && (in.getNumColumns() == 2 ? sumWeights :
in.getNumRows()) % 2 == 0;
+
+ // If multiple quantiles take first histogram and reuse bins,
otherwise recursively get bin with result
+ int numBuckets = 256; // (int) Math.round(in.getNumRows() /
2.0);
+ int quantileIndex = quantiles != null && quantiles.length == 1
? (int) Math.round(vectorLength * quantiles[0]) : -1;
+
+ T ret = createHistogram(in, (int) vectorLength, globalMin,
globalMax, numBuckets, quantileIndex, average);
+
+ // Compute and set results
+ if(quantiles != null && quantiles.length > 1) {
+ computeMultipleQuantiles(ec, in,
(Map<ImmutablePair<Double, Double>, Integer>) ret, quantiles, (int)
vectorLength, varID, _type);
+ } else
+ getSingleQuantileResult(ret, ec, fedMap, varID,
average, false, (int) vectorLength);
+ }
+
+ private <T> void computeMultipleQuantiles(ExecutionContext ec,
MatrixObject in, Map<ImmutablePair<Double, Double>, Integer> buckets, double[]
quantiles, int vectorLength, long varID, OperationTypes type) {
+ MatrixBlock out = new MatrixBlock(quantiles.length, 1, false);
+ ImmutableTriple<Integer, Integer, ImmutablePair<Double,
Double>>[] bucketsWithIndex = new ImmutableTriple[quantiles.length];
+
+ // Find bins with each quantile for first histogram
+ int sizeBeforeTmp = 0, sizeBefore = 0, countFoundBins = 0;
+ for(Map.Entry<ImmutablePair<Double, Double>, Integer> entry :
buckets.entrySet()) {
+ sizeBeforeTmp += entry.getValue();
+
+ for(int i = 0; i < quantiles.length; i++) {
+ int quantileIndex = (int)
Math.round(vectorLength * quantiles[i]);
+ ImmutablePair<Double, Double> bucketWithQ =
null;
+
+ if(quantileIndex > sizeBefore && quantileIndex
<= sizeBeforeTmp) {
+ bucketWithQ = entry.getKey();
+ bucketsWithIndex[i] = new
ImmutableTriple<>(quantileIndex == 1 ? 1 : quantileIndex - sizeBefore,
entry.getValue(), bucketWithQ);
+ countFoundBins++;
+ }
+ }
+
+ sizeBefore = sizeBeforeTmp;
+ if(countFoundBins == quantiles.length)
+ break;
+ }
+
+ // Find each quantile bin recursively
+ Map<Integer, T> retBuckets = new HashMap<>();
+
+ double left = 0, right = 0;
+ for(int i = 0; i < bucketsWithIndex.length; i++) {
+ int nextNumBuckets = bucketsWithIndex[i].middle < 100 ?
bucketsWithIndex[i].middle * 2 : (int) Math.round(bucketsWithIndex[i].middle /
2.0);
+ T hist = createHistogram(in, vectorLength,
bucketsWithIndex[i].right.left, bucketsWithIndex[i].right.right,
nextNumBuckets, bucketsWithIndex[i].left, false);
+
+ if(_type == OperationTypes.IQM) {
+ left = i == 0 ? hist instanceof ImmutablePair ?
((ImmutablePair<Double, Double>)hist).right : (Double) hist : left;
+ right = i == 1 ? hist instanceof ImmutablePair
? ((ImmutablePair<Double, Double>)hist).left : (Double) hist : right;
+ } else {
+ if(hist instanceof ImmutablePair)
+ retBuckets.put(i, hist); // set value
if returned double instead of bin
+ else
+ out.setValue(i, 0, (Double) hist);
+ }
+ }
+
+ if(type == OperationTypes.IQM) {
+ ImmutablePair<Double, Double> IQMRange = new
ImmutablePair<>(left, right);
+ getSingleQuantileResult(IQMRange, ec,
in.getFedMapping(), varID, false, true, vectorLength);
+ }
+ else {
+ if(!retBuckets.isEmpty()) {
+ // Search for values within bucket range where
it as returned
+ in.getFedMapping().mapParallel(varID, (range,
data) -> {
+ try {
+ FederatedResponse response =
data.executeFederatedOperation(new FederatedRequest(
+
FederatedRequest.RequestType.EXEC_UDF,
+ -1,
+ new
QuantilePickFEDInstruction.GetValuesInRanges(data.getVarID(), quantiles.length,
(HashMap<Integer, ImmutablePair<Double, Double>>) retBuckets))).get();
+ if(!response.isSuccessful())
+
response.throwExceptionFromResponse();
+
+ // Add results by row
+ MatrixBlock tmp = (MatrixBlock)
response.getData()[0];
+ synchronized(out) {
+
out.binaryOperationsInPlace(InstructionUtils.parseBinaryOperator("+"), tmp);
+ }
+ return null;
+ }
+ catch(Exception e) {
+ throw new
DMLRuntimeException(e);
+ }
+ });
+ }
+
+ ec.setMatrixOutput(output.getName(), out);
+ }
+ }
+
+ private <T> void getSingleQuantileResult(T ret, ExecutionContext ec,
FederationMap fedMap, long varID, boolean average, boolean isIQM, int
vectorLength) {
+ double result = 0.0;
+ if(ret instanceof ImmutablePair) {
+ // Search for values within bucket range
+ List<Double> values = new ArrayList<>();
+ fedMap.mapParallel(varID, (range, data) -> {
+ try {
+ FederatedResponse response =
data.executeFederatedOperation(new
FederatedRequest(FederatedRequest.RequestType.EXEC_UDF,
+ -1,
+ new
QuantilePickFEDInstruction.GetValuesInRange(data.getVarID(),
(ImmutablePair<Double, Double>) ret, isIQM))).get();
+ if(!response.isSuccessful())
+
response.throwExceptionFromResponse();
+ values.add((double)
response.getData()[0]);
+ return null;
+ }
+ catch(Exception e) {
+ throw new DMLRuntimeException(e);
+ }
+ });
+
+ // Sum of 1 or 2 values
+ result = values.stream().reduce(0.0, Double::sum);
+
+ } else
+ result = (Double) ret;
+
+ result /= (average ? 2 : isIQM ? ((int) Math.round(vectorLength
* 0.75) - (int) Math.round(vectorLength * 0.25)) : 1);
+
+ ec.setScalarOutput(output.getName(), new DoubleObject(result));
+ }
+
+ public <T> T createHistogram(MatrixObject in, int vectorLength, double
globalMin, double globalMax, int numBuckets, int quantileIndex, boolean
average) {
+ FederationMap fedMap = in.getFedMapping();
+
+ Map<ImmutablePair<Double, Double>, Integer> buckets = new
LinkedHashMap<>();
+ List<Map<ImmutablePair<Double, Double>, Integer>> hists = new
ArrayList<>();
+ List<Set<Double>> distincts = new ArrayList<>();
+
+ double bucketRange = (globalMax-globalMin) / numBuckets;
+ boolean isEvenNumRows = vectorLength % 2 == 0;
+
+ // Create buckets according to min and max
+ double tmpMin = globalMin, tmpMax = globalMax;
+ for(int i = 0; i < numBuckets && tmpMin <= tmpMax; i++) {
+ buckets.put(new ImmutablePair<>(tmpMin, tmpMin +
bucketRange), 0);
+ tmpMin += bucketRange;
+ }
+
+ // Create histograms
+ long varID = FederationUtils.getNextFedDataID();
+ fedMap.mapParallel(varID, (range, data) -> {
+ try {
+ FederatedResponse response =
data.executeFederatedOperation(new FederatedRequest(
+ FederatedRequest.RequestType.EXEC_UDF,
-1,
+ new
QuantilePickFEDInstruction.GetHistogram(data.getVarID(), buckets,
globalMax))).get();
+ if(!response.isSuccessful())
+ response.throwExceptionFromResponse();
+ Map<ImmutablePair<Double, Double>, Integer>
rangeHist = (Map<ImmutablePair<Double, Double>, Integer>) response.getData()[0];
+ hists.add(rangeHist);
+ Set<Double> rangeDistinct = (Set<Double>)
response.getData()[1];
+ distincts.add(rangeDistinct);
+ return null;
+ }
+ catch(Exception e) {
+ throw new DMLRuntimeException(e);
+ }
+ });
+
+ // Merge results into one histogram
+ for(ImmutablePair<Double, Double> bucket : buckets.keySet()) {
+ int value = 0;
+ for(Map<ImmutablePair<Double, Double>, Integer> hist :
hists)
+ value += hist.get(bucket);
+ buckets.put(bucket, value);
+ }
+
+ if(quantileIndex == -1)
+ return (T) buckets;
+
+ // Find bucket with quantile
+ ImmutableTriple<Integer, Integer, ImmutablePair<Double,
Double>> bucketWithIndex = getBucketWithIndex(buckets, quantileIndex, average,
isEvenNumRows);
+
+ // Check if can terminate
+ Set<Double> distinctValues =
distincts.stream().flatMap(Set::stream).collect(Collectors.toSet());
+ if((distinctValues.size() == 1 && !average) ||
(distinctValues.size() == 2 && average))
+ return (T) distinctValues.stream().reduce(0.0, (a, b)
-> a + b);
+
+ ImmutablePair<Double, Double> finalBucketWithQ =
bucketWithIndex.right;
+ List<Double> distinctInNewBucket =
distinctValues.stream().filter( e -> e >= finalBucketWithQ.left && e <=
finalBucketWithQ.right).collect(Collectors.toList());
+ if((distinctInNewBucket.size() == 1 && !average) || (average &&
distinctInNewBucket.size() == 2))
+ return (T) distinctInNewBucket.stream().reduce(0.0, (a,
b) -> a + b);
+
+ if(distinctValues.size() == 1 || (bucketWithIndex.middle == 1
&& !average) || (bucketWithIndex.middle == 2 && isEvenNumRows && average) ||
+ globalMin == globalMax)
+ return (T) bucketWithIndex.right;
+
+ int nextNumBuckets = bucketWithIndex.middle < 100 ?
bucketWithIndex.middle * 2 : (int) Math.round(bucketWithIndex.middle / 2.0);
+
+ // Add more bins to not stuck
+ if(numBuckets == nextNumBuckets && globalMin ==
bucketWithIndex.right.left && globalMax == bucketWithIndex.right.right) {
+ nextNumBuckets *= 2;
+ }
+
+ return createHistogram(in, vectorLength,
bucketWithIndex.right.left, bucketWithIndex.right.right, nextNumBuckets,
bucketWithIndex.left, average);
+ }
+
+ private ImmutableTriple<Integer, Integer, ImmutablePair<Double,
Double>> getBucketWithIndex(Map<ImmutablePair<Double, Double>, Integer>
buckets, int quantileIndex, boolean average, boolean isEvenNumRows) {
+ int sizeBeforeTmp = 0, sizeBefore = 0, bucketWithQSize = 0;
+ ImmutablePair<Double, Double> bucketWithQ = null;
+ for(Map.Entry<ImmutablePair<Double, Double>, Integer> range :
buckets.entrySet()) {
+ sizeBeforeTmp += range.getValue();
+ if(quantileIndex <= sizeBeforeTmp && bucketWithQSize ==
0) {
+ bucketWithQ = range.getKey();
+ bucketWithQSize = range.getValue();
+ sizeBeforeTmp -= bucketWithQSize;
+ sizeBefore = sizeBeforeTmp;
+
+ if(!average || sizeBefore + bucketWithQSize >=
quantileIndex + 1)
+ break;
+ } else if(quantileIndex + 1 <= sizeBeforeTmp +
bucketWithQSize && isEvenNumRows && average) {
+ // Add right bin that contains second index
+ int bucket2Size = range.getValue();
+ if (bucket2Size != 0) {
+ bucketWithQ = new
ImmutablePair<>(bucketWithQ.left, range.getKey().right);
+ bucketWithQSize += bucket2Size;
+ break;
+ }
+ }
+ }
+ quantileIndex = quantileIndex == 1 ? 1 : quantileIndex -
sizeBefore;
+ return new ImmutableTriple<>(quantileIndex, bucketWithQSize,
bucketWithQ);
+ }
+
+ public static class GetHistogram extends FederatedUDF {
+ private static final long serialVersionUID =
5413355823424777742L;
+ private final Map<ImmutablePair<Double, Double>, Integer>
_buckets;
+ private final double _max;
+
+ private GetHistogram(long input, Map<ImmutablePair<Double,
Double>, Integer> buckets, double max) {
+ super(new long[] {input});
+ _buckets = buckets;
+ _max = max;
+ }
+
+ @Override
+ public FederatedResponse execute(ExecutionContext ec, Data...
data) {
+ MatrixBlock mb = ((MatrixObject)
data[0]).acquireReadAndRelease();
+ double[] values = mb.getDenseBlockValues();
+ boolean isWeighted = mb.getNumColumns() == 2;
+
+ Map<ImmutablePair<Double, Double>, Integer> hist =
_buckets;
+ Set<Double> distinct = new HashSet<>();
+
+ for(int i = 0; i < values.length - (isWeighted ? 1 :
0); i += (isWeighted ? 2 : 1)) {
+ double val = values[i];
+ int weight = isWeighted ? (int) values[i+1] : 1;
+ for (Map.Entry<ImmutablePair<Double, Double>,
Integer> range : _buckets.entrySet()) {
+ if((val >= range.getKey().left && val <
range.getKey().right) || (val == _max && val == range.getKey().right)) {
+ hist.put(range.getKey(),
range.getValue() + weight);
+
+ distinct.add(val);
+ }
+ }
+ }
+
+ Object[] ret = new Object[] {hist, distinct.size() < 3
? distinct : new HashSet<>()};
+ return new
FederatedResponse(FederatedResponse.ResponseType.SUCCESS, ret);
+ }
+
+ @Override
+ public Pair<String, LineageItem>
getLineageItem(ExecutionContext ec) {
+ return null;
+ }
+ }
+
+ public static class GetValuesInRanges extends FederatedUDF {
+ private static final long serialVersionUID =
8663298932616139153L;
+ private final int _numQuantiles;
+ private final HashMap<Integer, ImmutablePair<Double, Double>>
_ranges;
+
+ private GetValuesInRanges(long input,int numQuantiles,
HashMap<Integer, ImmutablePair<Double, Double>> ranges) {
+ super(new long[] {input});
+ _ranges = ranges;
+ _numQuantiles = numQuantiles;
+ }
+
+ @Override
+ public FederatedResponse execute(ExecutionContext ec, Data...
data) {
+ MatrixBlock mb = ((MatrixObject)
data[0]).acquireReadAndRelease();
+ double[] values = mb.getDenseBlockValues();
+
+ // FIXME rewrite - see binning encode
+ MatrixBlock res = new MatrixBlock(_numQuantiles, 1,
false);
+ for(double val : values) {
+ for(Map.Entry<Integer, ImmutablePair<Double,
Double>> entry : _ranges.entrySet()) {
+ if(entry.getValue().left <= val && val
<= entry.getValue().right) {
+ res.setValue(entry.getKey(),
0,val);
+ break;
+ }
+ }
+ }
+
+ return new
FederatedResponse(FederatedResponse.ResponseType.SUCCESS, res);
+ }
+
+ @Override public Pair<String, LineageItem>
getLineageItem(ExecutionContext ec) {
+ return null;
+ }
+ }
+
+ public static class MinMax extends FederatedUDF {
+ private static final long serialVersionUID =
-3906698363866500744L;
+
+ private MinMax(long input) {
+ super(new long[] {input});
+ }
+
+ @Override
+ public FederatedResponse execute(ExecutionContext ec, Data...
data) {
+ MatrixBlock mb = ((MatrixObject)
data[0]).acquireReadAndRelease();
+ double[] ret = new double[]{mb.getNumColumns() == 2 ?
mb.colMin().quickGetValue(0, 0) : mb.min(),
+ mb.getNumColumns() == 2 ?
mb.colMax().quickGetValue(0, 0) : mb.max(),
+ mb.getNumColumns() == 2 ?
mb.colSum().quickGetValue(0, 1) : 0,
+ mb.getNumColumns() == 2 ?
mb.sumWeightForQuantile() : 0};
+ return new
FederatedResponse(FederatedResponse.ResponseType.SUCCESS, ret);
+ }
+
+ @Override public Pair<String, LineageItem>
getLineageItem(ExecutionContext ec) {
+ return null;
+ }
+ }
+
+ public void processColumnQPick(ExecutionContext ec) {
MatrixObject in = ec.getMatrixObject(input1);
FederationMap fedMapping = in.getFedMapping();
@@ -125,12 +511,12 @@ public class QuantilePickFEDInstruction extends BinaryFEDInstruction {
response = data
.executeFederatedOperation(
new
FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1,
- new
QuantilePickFEDInstruction.IQM(data.getVarID()))).get();
+ new
QuantilePickFEDInstruction.ColIQM(data.getVarID()))).get();
break;
case MEDIAN:
response = data
.executeFederatedOperation(new
FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1,
- new
QuantilePickFEDInstruction.Median(data.getVarID()))).get();
+ new
QuantilePickFEDInstruction.ColMedian(data.getVarID()))).get();
break;
default:
throw new
DMLRuntimeException("Unsupported qpick operation type: "+_type);
@@ -149,6 +535,9 @@ public class QuantilePickFEDInstruction extends BinaryFEDInstruction {
assert res.size() == 1;
+ if (input2 != null && input2.isMatrix())
+ ec.releaseMatrixInput(input2.getName());
+
if(output.isScalar())
ec.setScalarOutput(output.getName(), new
DoubleObject((double) res.get(0)));
else
@@ -190,11 +579,49 @@ public class QuantilePickFEDInstruction extends BinaryFEDInstruction {
}
}
- private static class IQM extends FederatedUDF {
+ public static class GetValuesInRange extends FederatedUDF {
+ private static final long serialVersionUID =
5413355823424777742L;
+ private final ImmutablePair<Double, Double> _range;
+ private final boolean _sumInRange;
+
+ private GetValuesInRange(long input, ImmutablePair<Double,
Double> range, boolean sumInRange) {
+ super(new long[] {input});
+ _range = range;
+ _sumInRange = sumInRange;
+ }
+
+ @Override
+ public FederatedResponse execute(ExecutionContext ec, Data...
data) {
+ MatrixBlock mb = ((MatrixObject)
data[0]).acquireReadAndRelease();
+ double[] values = mb.getDenseBlockValues();
+
+ double res = 0.0;
+ int i = 0;
+
+ // FIXME better search, e.g. sort in QSort and binary
search
+ for(double val : values) {
+ // different conditions for IQM and simple QPICK
+ if((!_sumInRange && _range.left <= val && val
<= _range.right) ||
+ (_sumInRange && _range.left < val &&
val <= _range.right))
+ res += val;
+ if(i++ > 2 && !_sumInRange)
+ break;
+ }
+
+ return new
FederatedResponse(FederatedResponse.ResponseType.SUCCESS, res);
+ }
+
+ @Override
+ public Pair<String, LineageItem>
getLineageItem(ExecutionContext ec) {
+ return null;
+ }
+ }
+
+ private static class ColIQM extends FederatedUDF {
private static final long serialVersionUID =
2223186699111957677L;
- protected IQM(long input) {
+ protected ColIQM(long input) {
super(new long[] {input});
}
@Override
@@ -209,11 +636,11 @@ public class QuantilePickFEDInstruction extends BinaryFEDInstruction {
}
}
- private static class Median extends FederatedUDF {
+ private static class ColMedian extends FederatedUDF {
private static final long serialVersionUID =
-2808597461054603816L;
- protected Median(long input) {
+ protected ColMedian(long input) {
super(new long[] {input});
}
@Override
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantileSortFEDInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantileSortFEDInstruction.java
index 0a545bc..cb76404 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantileSortFEDInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantileSortFEDInstruction.java
@@ -75,9 +75,40 @@ public class QuantileSortFEDInstruction extends UnaryFEDInstruction{
throw new DMLRuntimeException("Unknown opcode while
parsing a QuantileSortFEDInstruction: " + str);
}
}
-
@Override
public void processInstruction(ExecutionContext ec) {
+
if(ec.getMatrixObject(input1).isFederated(FederationMap.FType.COL) ||
ec.getMatrixObject(input1).isFederated(FederationMap.FType.FULL))
+ processColumnQSort(ec);
+ else
+ processRowQSort(ec);
+ }
+
+ public void processRowQSort(ExecutionContext ec) {
+ MatrixObject in = ec.getMatrixObject(input1);
+ MatrixObject out = ec.getMatrixObject(output);
+
+ // TODO make sure that qsort result is used by qpick only where
the main operation happens
+ if(input2 != null) {
+ MatrixObject weights = ec.getMatrixObject(input2);
+ String newInst =
InstructionUtils.replaceOperand(instString, 1, "append");
+ newInst = InstructionUtils.concatOperands(newInst,
"true");
+ FederatedRequest[] fr1 =
in.getFedMapping().broadcastSliced(weights, false);
+ FederatedRequest fr2 =
FederationUtils.callInstruction(newInst, output,
+ new CPOperand[]{input1, input2}, new long[]{
in.getFedMapping().getID(), fr1[0].getID()});
+ in.getFedMapping().execute(getTID(), true, fr1, fr2);
+
out.getDataCharacteristics().set(in.getDataCharacteristics());
+ out.getDataCharacteristics().setCols(2);
+
out.setFedMapping(in.getFedMapping().copyWithNewID(fr2.getID(), 2));
+ }
+ else {
+ // make a copy without sorting
+ long id = FederationUtils.getNextFedDataID();
+
out.getDataCharacteristics().set(in.getDataCharacteristics());
+
out.setFedMapping(in.getFedMapping().identCopy(getTID(), id));
+ }
+ }
+
+ public void processColumnQSort(ExecutionContext ec) {
MatrixObject in = ec.getMatrixObject(input1);
FederationMap fedMapping = in.getFedMapping();
diff --git a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
index 989c8c4..eb152ab 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
@@ -952,6 +952,26 @@ public class MatrixBlock extends MatrixValue implements CacheBlock, Externalizab
InstructionUtils.parseBasicAggregateUnaryOperator("uamin", 1));
return out.quickGetValue(0, 0);
}
+
+ /**
+ * Wrapper method for reduceall-colMin of a matrix.
+ *
+ * @return A new MatrixBlock containing the column mins of this matrix
+ */
+ public MatrixBlock colMin() {
+ AggregateUnaryOperator op =
InstructionUtils.parseBasicAggregateUnaryOperator("uacmin", 1);
+ return aggregateUnaryOperations(op, null, 1000, null, true);
+ }
+
+ /**
+ * Wrapper method for reduceall-colMin of a matrix.
+ *
+ * @return A new MatrixBlock containing the column mins of this matrix
+ */
+ public MatrixBlock colMax() {
+ AggregateUnaryOperator op =
InstructionUtils.parseBasicAggregateUnaryOperator("uacmax", 1);
+ return aggregateUnaryOperations(op, null, 1000, null, true);
+ }
/**
* Wrapper method for reduceall-max of a matrix.
diff --git a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedQuantileTest.java b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedQuantileTest.java
index 226ad53..598256a 100644
--- a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedQuantileTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedQuantileTest.java
@@ -41,17 +41,24 @@ public class FederatedQuantileTest extends AutomatedTestBase {
private final static String TEST_DIR = "functions/federated/quantile/";
private final static String TEST_NAME1 = "FederatedQuantileTest";
private final static String TEST_NAME2 = "FederatedMedianTest";
- private final static String TEST_NAME3 = "FederatedIQMTest";
+ private final static String TEST_NAME3 = "FederatedIQRTest";
private final static String TEST_NAME4 = "FederatedQuantilesTest";
private final static String TEST_CLASS_DIR = TEST_DIR +
FederatedQuantileTest.class.getSimpleName() + "/";
private final static int blocksize = 1024;
@Parameterized.Parameter()
public int rows;
+ @Parameterized.Parameter(1)
+ public int cols;
+ @Parameterized.Parameter(2)
+ public boolean rowPartitioned;
@Parameterized.Parameters
public static Collection<Object[]> data() {
- return Arrays.asList(new Object[][] {{1000}});
+ return Arrays.asList(new Object[][] {
+ {1000, 1, false},
+ {16, 1, true}
+ });
}
@Override
@@ -76,38 +83,28 @@ public class FederatedQuantileTest extends AutomatedTestBase {
public void federatedMedianCP() {
federatedQuartile(Types.ExecMode.SINGLE_NODE, TEST_NAME2, -1); }
@Test
- public void federatedIQMCP() {
federatedQuartile(Types.ExecMode.SINGLE_NODE, TEST_NAME1, -1); }
+ public void federatedIQRCP() {
federatedQuartile(Types.ExecMode.SINGLE_NODE, TEST_NAME3, -1); }
@Test
- public void federatedQuantilesCP() {
federatedQuartile(Types.ExecMode.SINGLE_NODE, TEST_NAME1, -1); }
+ public void federatedQuantilesCP() {
federatedQuartile(Types.ExecMode.SINGLE_NODE, TEST_NAME4, -1); }
@Test
-// @Ignore
public void federatedQuantile1SP() {
federatedQuartile(Types.ExecMode.SPARK, TEST_NAME1, 0.25); }
@Test
-// @Ignore
public void federatedQuantile2SP() {
federatedQuartile(Types.ExecMode.SPARK, TEST_NAME1, 0.5); }
@Test
-// @Ignore
public void federatedQuantile3SP() {
federatedQuartile(Types.ExecMode.SPARK, TEST_NAME1, 0.75); }
@Test
-// @Ignore
public void federatedMedianSP() {
federatedQuartile(Types.ExecMode.SPARK, TEST_NAME2, -1); }
@Test
-// @Ignore
- public void federatedIQMSP() { federatedQuartile(Types.ExecMode.SPARK,
TEST_NAME1, -1); }
+ public void federatedIQRSP() { federatedQuartile(Types.ExecMode.SPARK,
TEST_NAME3, -1); }
@Test
-// @Ignore
- public void federatedQuantilesSP() {
federatedQuartile(Types.ExecMode.SPARK, TEST_NAME1, -1); }
-
-
-
-
+ public void federatedQuantilesSP() {
federatedQuartile(Types.ExecMode.SPARK, TEST_NAME4, -1); }
public void federatedQuartile(Types.ExecMode execMode, String
TEST_NAME, double p) {
boolean sparkConfigOld = DMLScript.USE_LOCAL_SPARK_CONFIG;
@@ -116,21 +113,72 @@ public class FederatedQuantileTest extends
AutomatedTestBase {
getAndLoadTestConfiguration(TEST_NAME);
String HOME = SCRIPT_DIR + TEST_DIR;
- double[][] X1 = getRandomMatrix(rows, 1, 1, 5, 1, 3);
-
- MatrixCharacteristics mc = new MatrixCharacteristics(rows, 1,
blocksize, rows);
- writeInputMatrixWithMTD("X1", X1, false, mc);
+ double[][] X1, X2, X3, X4;
+ int port1, port2, port3, port4;
+ Thread t1 = null, t2 = null, t3 = null, t4 = null;
+ String[] programArgs1, programArgs2;
+ if(rowPartitioned) {
+ X1 = getRandomMatrix(rows / 4, cols, 1, 12, 1, 3);
+ X2 = getRandomMatrix(rows / 4, cols, 1, 12, 1, 7);
+ X3 = getRandomMatrix(rows / 4, cols, 1, 12, 1, 8);
+ X4 = getRandomMatrix(rows / 4, cols, 1, 12, 1, 9);
+
+ MatrixCharacteristics mc1 = new
MatrixCharacteristics(rows / 4, 1, blocksize, rows);
+ writeInputMatrixWithMTD("X1", X1, false, mc1);
+ writeInputMatrixWithMTD("X2", X2, false, mc1);
+ writeInputMatrixWithMTD("X3", X3, false, mc1);
+ writeInputMatrixWithMTD("X4", X4, false, mc1);
+
+ port1 = getRandomAvailablePort();
+ port2 = getRandomAvailablePort();
+ port3 = getRandomAvailablePort();
+ port4 = getRandomAvailablePort();
+ t1 = startLocalFedWorkerThread(port1,
FED_WORKER_WAIT_S);
+ t2 = startLocalFedWorkerThread(port2,
FED_WORKER_WAIT_S);
+ t3 = startLocalFedWorkerThread(port3,
FED_WORKER_WAIT_S);
+ t4 = startLocalFedWorkerThread(port4);
+
+ programArgs1 = new String[] {"-explain", "-stats",
"100", "-args",
+ String.valueOf(p), expected("S"),
Boolean.toString(rowPartitioned).toUpperCase(),
+ input("X1"), input("X2"), input("X3"),
input("X4")};
+ programArgs2 = new String[] {"-explain","-stats",
"100", "-nvargs",
+ "in_X1=" + TestUtils.federatedAddress(port1,
input("X1")),
+ "in_X2=" + TestUtils.federatedAddress(port2,
input("X2")),
+ "in_X3=" + TestUtils.federatedAddress(port3,
input("X3")),
+ "in_X4=" + TestUtils.federatedAddress(port4,
input("X4")), "rows=" + rows, "cols=" + cols,
+ "rP=" +
Boolean.toString(rowPartitioned).toUpperCase(), "p=" + String.valueOf(p),
+ "out_S=" + output("S")};
+ }
+ else {
+ X1 = getRandomMatrix(rows, 1, 1, 12, 1, 3);
+ MatrixCharacteristics mc = new
MatrixCharacteristics(rows, 1, blocksize, rows);
+ writeInputMatrixWithMTD("X1", X1, false, mc);
+
+ port1 = getRandomAvailablePort();
+ t1 = startLocalFedWorkerThread(port1);
+
+ programArgs1 = new String[] {"-explain", "-stats",
"100", "-args",
+ String.valueOf(p), expected("S"),
Boolean.toString(rowPartitioned).toUpperCase(), input("X1"),
+ input("X1"), input("X1"), input("X1")};
+ programArgs2 = new String[] {"-explain", "-stats",
"100", "-nvargs",
+ "in_X1=" + TestUtils.federatedAddress(port1,
input("X1")),
+ "in_X2=" + TestUtils.federatedAddress(port1,
input("X1")),
+ "in_X3=" + TestUtils.federatedAddress(port1,
input("X1")),
+ "in_X4=" + TestUtils.federatedAddress(port1,
input("X1")),
+ "rows=" + rows, "cols=" + cols, "p=" +
String.valueOf(p),
+ "out_S=" + output("S"), "rP=" +
Boolean.toString(rowPartitioned).toUpperCase()
+ };
+ }
// empty script name because we don't execute any script, just
start the worker
fullDMLScriptName = "";
- int port1 = getRandomAvailablePort();
- Thread t1 = startLocalFedWorkerThread(port1);
// we need the reference file to not be written to hdfs, so we
get the correct format
rtplatform = Types.ExecMode.SINGLE_NODE;
// Run reference dml script with normal matrix for Row/Col
fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
- programArgs = new String[] {"-explain", "-stats", "100",
"-args", input("X1"), expected("S"), String.valueOf(p)};
+
+ programArgs = programArgs1;
runTest(true, false, null, -1);
// reference file should not be written to hdfs, so we set
platform here
@@ -142,11 +190,7 @@ public class FederatedQuantileTest extends
AutomatedTestBase {
loadTestConfiguration(config);
fullDMLScriptName = HOME + TEST_NAME + ".dml";
- programArgs = new String[] {"-explain", "-stats", "100",
"-nvargs",
- "in_X1=" + TestUtils.federatedAddress(port1,
input("X1")),
- "rows=" + rows, "cols=" + 1, "p=" + String.valueOf(p),
- "out_S=" + output("S")
- };
+ programArgs = programArgs2;
runTest(true, false, null, -1);
// compare all sums via files
@@ -156,8 +200,15 @@ public class FederatedQuantileTest extends
AutomatedTestBase {
// check that federated input files are still existing
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
-
TestUtils.shutdownThreads(t1);
+ if(rowPartitioned) {
+
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2")));
+
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X3")));
+
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X4")));
+
+ TestUtils.shutdownThreads(t2, t3, t4);
+ }
+
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedQuantileWeightsTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedQuantileWeightsTest.java
index 4511c6d..9ac0b5d 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedQuantileWeightsTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/primitives/FederatedQuantileWeightsTest.java
@@ -41,17 +41,19 @@ public class FederatedQuantileWeightsTest extends
AutomatedTestBase {
private final static String TEST_DIR = "functions/federated/quantile/";
private final static String TEST_NAME1 = "FederatedQuantileWeightsTest";
private final static String TEST_NAME2 = "FederatedMedianWeightsTest";
- private final static String TEST_NAME3 = "FederatedIQMWeightsTest";
- private final static String TEST_NAME4 =
"FederatedQuantilesWeightsTest";
private final static String TEST_CLASS_DIR = TEST_DIR +
FederatedQuantileWeightsTest.class.getSimpleName() + "/";
private final static int blocksize = 1024;
@Parameterized.Parameter()
public int rows;
+ @Parameterized.Parameter(1)
+ public boolean rowPartitioned;
@Parameterized.Parameters
public static Collection<Object[]> data() {
- return Arrays.asList(new Object[][] {{1000}});
+ return Arrays.asList(new Object[][] {
+ {1000, false},
+ {12, true}});
}
@Override
@@ -59,8 +61,6 @@ public class FederatedQuantileWeightsTest extends
AutomatedTestBase {
TestUtils.clearAssertionInformation();
addTestConfiguration(TEST_NAME1, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"S.scalar"}));
addTestConfiguration(TEST_NAME2, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME2, new String[] {"S.scalar"}));
- addTestConfiguration(TEST_NAME3, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME3, new String[] {"S.scalar"}));
- addTestConfiguration(TEST_NAME4, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME4, new String[] {"S"}));
}
@Test
@@ -88,24 +88,81 @@ public class FederatedQuantileWeightsTest extends
AutomatedTestBase {
getAndLoadTestConfiguration(TEST_NAME);
String HOME = SCRIPT_DIR + TEST_DIR;
- double[][] X1 = getRandomMatrix(rows, 1, 1, 5, 1, 3);
+ double[][] X1, X2, X3, X4;
+ int port1, port2, port3, port4;
+ Thread t1 = null, t2 = null, t3 = null, t4 = null;
+ String[] programArgs1, programArgs2;
- MatrixCharacteristics mc = new MatrixCharacteristics(rows, 1,
blocksize, rows);
- writeInputMatrixWithMTD("X1", X1, false, mc);
-
- double[][] W = getRandomMatrix(rows, 1, 1, 1, 1.0, 1);
+ double[][] W = getRandomMatrix(rows, 1, 1, 5, 1.0, 1);
+ for(int i = 0; i < W.length; i++){
+ for(int y = 0; y < W[0].length; y++){
+ W[i][y] = (double) Math.round(W[i][y]);
+ }
+ }
writeInputMatrixWithMTD("W", W, false);
+ if(rowPartitioned) {
+ X1 = getRandomMatrix(rows / 4, 1, 1, 12, 1, 3);
+ X2 = getRandomMatrix(rows / 4, 1, 1, 12, 1, 7);
+ X3 = getRandomMatrix(rows / 4, 1, 1, 12, 1, 8);
+ X4 = getRandomMatrix(rows / 4, 1, 1, 12, 1, 9);
+
+ MatrixCharacteristics mc1 = new
MatrixCharacteristics(rows / 4, 1, blocksize, rows);
+ writeInputMatrixWithMTD("X1", X1, false, mc1);
+ writeInputMatrixWithMTD("X2", X2, false, mc1);
+ writeInputMatrixWithMTD("X3", X3, false, mc1);
+ writeInputMatrixWithMTD("X4", X4, false, mc1);
+
+ port1 = getRandomAvailablePort();
+ port2 = getRandomAvailablePort();
+ port3 = getRandomAvailablePort();
+ port4 = getRandomAvailablePort();
+ t1 = startLocalFedWorkerThread(port1,
FED_WORKER_WAIT_S);
+ t2 = startLocalFedWorkerThread(port2,
FED_WORKER_WAIT_S);
+ t3 = startLocalFedWorkerThread(port3,
FED_WORKER_WAIT_S);
+ t4 = startLocalFedWorkerThread(port4);
+
+ programArgs1 = new String[] {"-explain", "-stats",
"100", "-args",
+ String.valueOf(p), expected("S"),
Boolean.toString(rowPartitioned).toUpperCase(), input("W"),
+ input("X1"), input("X2"), input("X3"),
input("X4")};
+ programArgs2 = new String[] {"-explain","-stats",
"100", "-nvargs",
+ "in_X1=" + TestUtils.federatedAddress(port1,
input("X1")),
+ "in_X2=" + TestUtils.federatedAddress(port2,
input("X2")),
+ "in_X3=" + TestUtils.federatedAddress(port3,
input("X3")),
+ "in_X4=" + TestUtils.federatedAddress(port4,
input("X4")), "rows=" + rows, "cols=" + 1,
+ "rP=" +
Boolean.toString(rowPartitioned).toUpperCase(), "p=" + String.valueOf(p), "W="
+ input("W"),
+ "out_S=" + output("S")};
+ }
+ else {
+ X1 = getRandomMatrix(rows, 1, 1, 12, 1, 3);
+ MatrixCharacteristics mc = new
MatrixCharacteristics(rows, 1, blocksize, rows);
+ writeInputMatrixWithMTD("X1", X1, false, mc);
+
+ port1 = getRandomAvailablePort();
+ t1 = startLocalFedWorkerThread(port1);
+
+ programArgs1 = new String[] {"-explain", "-stats",
"100", "-args",
+ String.valueOf(p), expected("S"),
Boolean.toString(rowPartitioned).toUpperCase(), input("W"), input("X1"),
+ input("X1"), input("X1"), input("X1")};
+ programArgs2 = new String[] {"-explain", "-stats",
"100", "-nvargs",
+ "in_X1=" + TestUtils.federatedAddress(port1,
input("X1")),
+ "in_X2=" + TestUtils.federatedAddress(port1,
input("X1")),
+ "in_X3=" + TestUtils.federatedAddress(port1,
input("X1")),
+ "in_X4=" + TestUtils.federatedAddress(port1,
input("X1")),
+ "rows=" + rows, "cols=" + 1, "p=" +
String.valueOf(p), "W=" + input("W"),
+ "out_S=" + output("S"), "rP=" +
Boolean.toString(rowPartitioned).toUpperCase()
+ };
+ }
+
// empty script name because we don't execute any script, just
start the worker
fullDMLScriptName = "";
- int port1 = getRandomAvailablePort();
- Thread t1 = startLocalFedWorkerThread(port1);
// we need the reference file to not be written to hdfs, so we
get the correct format
rtplatform = Types.ExecMode.SINGLE_NODE;
// Run reference dml script with normal matrix for Row/Col
fullDMLScriptName = HOME + TEST_NAME + "Reference.dml";
- programArgs = new String[] {"-explain", "-stats", "100",
"-args", input("X1"), expected("S"), String.valueOf(p), input("W")};
+
+ programArgs = programArgs1;
runTest(true, false, null, -1);
// reference file should not be written to hdfs, so we set
platform here
@@ -117,12 +174,7 @@ public class FederatedQuantileWeightsTest extends
AutomatedTestBase {
loadTestConfiguration(config);
fullDMLScriptName = HOME + TEST_NAME + ".dml";
- programArgs = new String[] {"-explain", "-stats", "100",
"-nvargs",
- "in_X1=" + TestUtils.federatedAddress(port1,
input("X1")),
- "rows=" + rows, "cols=" + 1,
- "p=" + String.valueOf(p), "W=" + input("W"),
- "out_S=" + output("S")
- };
+ programArgs = programArgs2;
runTest(true, false, null, -1);
// compare all sums via files
@@ -132,8 +184,15 @@ public class FederatedQuantileWeightsTest extends
AutomatedTestBase {
// check that federated input files are still existing
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X1")));
-
TestUtils.shutdownThreads(t1);
+ if(rowPartitioned) {
+
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X2")));
+
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X3")));
+
Assert.assertTrue(HDFSTool.existsFileOnHDFS(input("X4")));
+
+ TestUtils.shutdownThreads(t2, t3, t4);
+ }
+
rtplatform = platformOld;
DMLScript.USE_LOCAL_SPARK_CONFIG = sparkConfigOld;
}
diff --git
a/src/test/scripts/functions/federated/aggregate/FederatedMeanTestReference.dml
b/src/test/scripts/functions/federated/aggregate/FederatedMeanTestReference.dml
index cb17fc3..bdd2518 100644
---
a/src/test/scripts/functions/federated/aggregate/FederatedMeanTestReference.dml
+++
b/src/test/scripts/functions/federated/aggregate/FederatedMeanTestReference.dml
@@ -23,7 +23,6 @@ if($6) {
A = rbind(read($1), read($2), read($3), read($4));
}
else { A = cbind(read($1), read($2), read($3), read($4)); }
-#A = read($1)
-s = mean(A);
+s = mean(A);
write(s, $5);
diff --git a/src/test/scripts/functions/federated/quantile/FederatedIQRTest.dml
b/src/test/scripts/functions/federated/quantile/FederatedIQRTest.dml
index f50e72e..e6bed3f 100644
--- a/src/test/scripts/functions/federated/quantile/FederatedIQRTest.dml
+++ b/src/test/scripts/functions/federated/quantile/FederatedIQRTest.dml
@@ -19,6 +19,13 @@
#
#-------------------------------------------------------------
-A = federated(addresses=list($in_X1), ranges=list(list(0, 0), list($rows,
$cols)));
+if ($rP) {
+ A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+ ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0),
list(2*$rows/4, $cols),
+ list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0),
list($rows, $cols)));
+} else {
+ A = federated(addresses=list($in_X1), ranges=list(list(0, 0), list($rows,
$cols)));
+}
+
s = interQuartileMean(A);
write(s, $out_S);
diff --git
a/src/test/scripts/functions/federated/quantile/FederatedIQRTestReference.dml
b/src/test/scripts/functions/federated/quantile/FederatedIQRTestReference.dml
index 84fcace..c6ca4db 100644
---
a/src/test/scripts/functions/federated/quantile/FederatedIQRTestReference.dml
+++
b/src/test/scripts/functions/federated/quantile/FederatedIQRTestReference.dml
@@ -19,6 +19,9 @@
#
#-------------------------------------------------------------
-A = read($1);
+if($3) {
+ A = rbind(read($4), read($5), read($6), read($7));
+}
+else { A = read($4); }
s = interQuartileMean(A);
write(s, $2);
diff --git
a/src/test/scripts/functions/federated/quantile/FederatedMedianTest.dml
b/src/test/scripts/functions/federated/quantile/FederatedMedianTest.dml
index 22f2504..25800b5 100644
--- a/src/test/scripts/functions/federated/quantile/FederatedMedianTest.dml
+++ b/src/test/scripts/functions/federated/quantile/FederatedMedianTest.dml
@@ -19,6 +19,12 @@
#
#-------------------------------------------------------------
-A = federated(addresses=list($in_X1), ranges=list(list(0, 0), list($rows,
$cols)));
+if ($rP) {
+ A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+ ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0),
list(2*$rows/4, $cols),
+ list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0),
list($rows, $cols)));
+} else {
+ A = federated(addresses=list($in_X1), ranges=list(list(0, 0), list($rows,
$cols)));
+}
s = median(A);
write(s, $out_S);
diff --git
a/src/test/scripts/functions/federated/quantile/FederatedMedianTestReference.dml
b/src/test/scripts/functions/federated/quantile/FederatedMedianTestReference.dml
index 1544987..ee1ede8 100644
---
a/src/test/scripts/functions/federated/quantile/FederatedMedianTestReference.dml
+++
b/src/test/scripts/functions/federated/quantile/FederatedMedianTestReference.dml
@@ -19,6 +19,10 @@
#
#-------------------------------------------------------------
-A = read($1);
+# $3: TRUE -> A is the row-wise concatenation of the four parts $4..$7, otherwise A is $4
+if($3) {
+ A = rbind(read($4), read($5), read($6), read($7));
+}
+else { A = read($4); }
s = median(A);
write(s, $2);
diff --git
a/src/test/scripts/functions/federated/quantile/FederatedMedianWeightsTest.dml
b/src/test/scripts/functions/federated/quantile/FederatedMedianWeightsTest.dml
index 58ec328..34fc30c 100644
---
a/src/test/scripts/functions/federated/quantile/FederatedMedianWeightsTest.dml
+++
b/src/test/scripts/functions/federated/quantile/FederatedMedianWeightsTest.dml
@@ -19,7 +19,13 @@
#
#-------------------------------------------------------------
-A = federated(addresses=list($in_X1), ranges=list(list(0, 0), list($rows,
$cols)));
+if ($rP) {
+ A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+ ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0),
list(2*$rows/4, $cols),
+ list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0),
list($rows, $cols)));
+} else {
+ A = federated(addresses=list($in_X1), ranges=list(list(0, 0), list($rows,
$cols)));
+}
W = read($W);
s = median(A, W);
write(s, $out_S);
diff --git
a/src/test/scripts/functions/federated/quantile/FederatedMedianWeightsTestReference.dml
b/src/test/scripts/functions/federated/quantile/FederatedMedianWeightsTestReference.dml
index 6b9e3de..8f9cf3c 100644
---
a/src/test/scripts/functions/federated/quantile/FederatedMedianWeightsTestReference.dml
+++
b/src/test/scripts/functions/federated/quantile/FederatedMedianWeightsTestReference.dml
@@ -19,7 +19,10 @@
#
#-------------------------------------------------------------
-A = read($1);
W = read($4);
+if($3) {
+ A = rbind(read($5), read($6), read($7), read($8));
+}
+else { A = read($5); }
s = median(A, W);
write(s, $2);
diff --git
a/src/test/scripts/functions/federated/quantile/FederatedQuantileTest.dml
b/src/test/scripts/functions/federated/quantile/FederatedQuantileTest.dml
index 1c84330..ebe1f27 100644
--- a/src/test/scripts/functions/federated/quantile/FederatedQuantileTest.dml
+++ b/src/test/scripts/functions/federated/quantile/FederatedQuantileTest.dml
@@ -19,6 +19,13 @@
#
#-------------------------------------------------------------
-A = federated(addresses=list($in_X1), ranges=list(list(0, 0), list($rows,
$cols)));
+if ($rP) {
+ A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+ ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0),
list(2*$rows/4, $cols),
+ list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0),
list($rows, $cols)));
+} else {
+ A = federated(addresses=list($in_X1), ranges=list(list(0, 0), list($rows,
$cols)));
+}
+
s = quantile(A, $p);
write(s, $out_S);
diff --git
a/src/test/scripts/functions/federated/quantile/FederatedQuantileTestReference.dml
b/src/test/scripts/functions/federated/quantile/FederatedQuantileTestReference.dml
index 7a1fb36..4e371b0 100644
---
a/src/test/scripts/functions/federated/quantile/FederatedQuantileTestReference.dml
+++
b/src/test/scripts/functions/federated/quantile/FederatedQuantileTestReference.dml
@@ -19,6 +19,9 @@
#
#-------------------------------------------------------------
-A = read($1);
-s = quantile (A, $3);
+if($3) {
+ A = rbind(read($4), read($5), read($6), read($7));
+}
+else { A = read($4); }
+s = quantile (A, $1);
write(s, $2);
diff --git
a/src/test/scripts/functions/federated/quantile/FederatedQuantileWeightsTest.dml
b/src/test/scripts/functions/federated/quantile/FederatedQuantileWeightsTest.dml
index c423a65..3eedb7e 100644
---
a/src/test/scripts/functions/federated/quantile/FederatedQuantileWeightsTest.dml
+++
b/src/test/scripts/functions/federated/quantile/FederatedQuantileWeightsTest.dml
@@ -19,7 +19,13 @@
#
#-------------------------------------------------------------
-A = federated(addresses=list($in_X1), ranges=list(list(0, 0), list($rows,
$cols)));
+if ($rP) {
+ A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+ ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0),
list(2*$rows/4, $cols),
+ list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0),
list($rows, $cols)));
+} else {
+ A = federated(addresses=list($in_X1), ranges=list(list(0, 0), list($rows,
$cols)));
+}
W = read($W);
s = quantile(A, W, $p);
write(s, $out_S);
diff --git
a/src/test/scripts/functions/federated/quantile/FederatedQuantileWeightsTestReference.dml
b/src/test/scripts/functions/federated/quantile/FederatedQuantileWeightsTestReference.dml
index 6796757..733e68d 100644
---
a/src/test/scripts/functions/federated/quantile/FederatedQuantileWeightsTestReference.dml
+++
b/src/test/scripts/functions/federated/quantile/FederatedQuantileWeightsTestReference.dml
@@ -19,7 +19,11 @@
#
#-------------------------------------------------------------
-A = read($1);
W = read($4);
-s = quantile (A, W, $3);
+if($3) {
+ A = rbind(read($5), read($6), read($7), read($8));
+}
+else { A = read($5); }
+
+s = quantile (A, W, $1);
write(s, $2);
diff --git
a/src/test/scripts/functions/federated/quantile/FederatedQuantilesTest.dml
b/src/test/scripts/functions/federated/quantile/FederatedQuantilesTest.dml
index f5f22fa..8f11bc6 100644
--- a/src/test/scripts/functions/federated/quantile/FederatedQuantilesTest.dml
+++ b/src/test/scripts/functions/federated/quantile/FederatedQuantilesTest.dml
@@ -19,9 +19,15 @@
#
#-------------------------------------------------------------
-A = federated(addresses=list($in_X1), ranges=list(list(0, 0), list($rows,
$cols)));
+if ($rP) {
+ A = federated(addresses=list($in_X1, $in_X2, $in_X3, $in_X4),
+ ranges=list(list(0, 0), list($rows/4, $cols), list($rows/4, 0),
list(2*$rows/4, $cols),
+ list(2*$rows/4, 0), list(3*$rows/4, $cols), list(3*$rows/4, 0),
list($rows, $cols)));
+} else {
+ A = federated(addresses=list($in_X1), ranges=list(list(0, 0), list($rows,
$cols)));
+}
P = matrix(0.25, 3, 1);
P[2,1] = 0.5;
P[3,1] = 0.75;
-s = quantile(X=A, P=P);
+s = quantile(A, P);
write(s, $out_S);
diff --git
a/src/test/scripts/functions/federated/quantile/FederatedQuantilesTestReference.dml
b/src/test/scripts/functions/federated/quantile/FederatedQuantilesTestReference.dml
index 96970d0..e4d6b71 100644
---
a/src/test/scripts/functions/federated/quantile/FederatedQuantilesTestReference.dml
+++
b/src/test/scripts/functions/federated/quantile/FederatedQuantilesTestReference.dml
@@ -19,9 +19,13 @@
#
#-------------------------------------------------------------
-A = read($1);
+# $3: TRUE -> A is the row-wise concatenation of the four parts $4..$7, otherwise A is $4
+if($3) {
+ A = rbind(read($4), read($5), read($6), read($7));
+}
+else { A = read($4); }
P = matrix(0.25, 3, 1);
P[2,1] = 0.5;
P[3,1] = 0.75;
-s = quantile(X=A, P=P);
+s = quantile(A, P);
write(s, $2);