This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new f1e877561a [SYSTEMDS-3328] Federated transform encode/apply w/ equi-height binning
f1e877561a is described below
commit f1e877561a2b19e12d726ee8b488cc2fcd1358c3
Author: OlgaOvcharenko <[email protected]>
AuthorDate: Sun May 8 19:41:07 2022 +0200
[SYSTEMDS-3328] Federated transform encode/apply w/ equi-height binning
Closes #1562.
---
.../org/apache/sysds/hops/fedplanner/FTypes.java | 1 -
...tiReturnParameterizedBuiltinFEDInstruction.java | 140 +++++++++++++---
.../fed/QuantilePickFEDInstruction.java | 179 +++++++++++++++++++--
.../runtime/transform/encode/ColumnEncoder.java | 10 ++
.../runtime/transform/encode/ColumnEncoderBin.java | 47 +++++-
.../transform/encode/ColumnEncoderComposite.java | 10 ++
.../transform/encode/MultiColumnEncoder.java | 19 ++-
.../TransformFederatedEncodeApplyTest.java | 70 +++++++-
.../transform/TransformFrameEncodeApply.dml | 1 +
9 files changed, 436 insertions(+), 41 deletions(-)
diff --git a/src/main/java/org/apache/sysds/hops/fedplanner/FTypes.java
b/src/main/java/org/apache/sysds/hops/fedplanner/FTypes.java
index d06debb43b..ec1da8b152 100644
--- a/src/main/java/org/apache/sysds/hops/fedplanner/FTypes.java
+++ b/src/main/java/org/apache/sysds/hops/fedplanner/FTypes.java
@@ -77,7 +77,6 @@ public class FTypes
OTHER(FPartitioning.MIXED, FReplication.NONE);
private final FPartitioning _partType;
- @SuppressWarnings("unused") //not yet
private final FReplication _repType;
private FType(FPartitioning ptype, FReplication rtype) {
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/MultiReturnParameterizedBuiltinFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/MultiReturnParameterizedBuiltinFEDInstruction.java
index 0b8111646d..00e7b6fafc 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/MultiReturnParameterizedBuiltinFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/MultiReturnParameterizedBuiltinFEDInstruction.java
@@ -21,7 +21,9 @@ package org.apache.sysds.runtime.instructions.fed;
import java.util.ArrayList;
import java.util.Arrays;
+import java.util.HashMap;
import java.util.List;
+import java.util.Map;
import java.util.concurrent.Future;
import java.util.stream.Stream;
import java.util.zip.Adler32;
@@ -32,15 +34,17 @@ import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.common.Types;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.ValueType;
+import org.apache.sysds.hops.fedplanner.FTypes;
+import org.apache.sysds.lops.PickByCount;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.FrameObject;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import
org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import
org.apache.sysds.runtime.controlprogram.federated.FederatedResponse.ResponseType;
import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
-import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.instructions.InstructionUtils;
@@ -51,7 +55,10 @@ import org.apache.sysds.runtime.lineage.LineageItemUtils;
import org.apache.sysds.runtime.matrix.data.FrameBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.Operator;
+import org.apache.sysds.runtime.transform.encode.ColumnEncoderBin;
+import org.apache.sysds.runtime.transform.encode.ColumnEncoderComposite;
import org.apache.sysds.runtime.transform.encode.ColumnEncoderRecode;
+import org.apache.sysds.runtime.transform.encode.Encoder;
import org.apache.sysds.runtime.transform.encode.EncoderFactory;
import org.apache.sysds.runtime.transform.encode.MultiColumnEncoder;
import org.apache.sysds.runtime.util.IndexRange;
@@ -99,20 +106,110 @@ public class
MultiReturnParameterizedBuiltinFEDInstruction extends ComputationFE
// the encoder in which the complete encoding information will
be aggregated
MultiColumnEncoder globalEncoder = new MultiColumnEncoder(new
ArrayList<>());
- // first create encoders at the federated workers, then collect
them and aggregate them to a single large
- // encoder
+ FederationMap fedMapping = fin.getFedMapping();
+
+ boolean containsEquiWidthEncoder =
!fin.isFederated(FTypes.FType.ROW) &&
spec.toLowerCase().contains("equi-height");
+ if(containsEquiWidthEncoder) {
+ EncoderColnames ret =
createGlobalEncoderWithEquiHeight(ec, fin, spec);
+ globalEncoder = ret._encoder;
+ colNames = ret._colnames;
+ } else {
+ // first create encoders at the federated workers, then
collect them and aggregate them to a single large
+ // encoder
+ MultiColumnEncoder finalGlobalEncoder = globalEncoder;
+ String[] finalColNames = colNames;
+ fedMapping.forEachParallel((range, data) -> {
+ int columnOffset = (int)
range.getBeginDims()[1];
+
+ // create an encoder with the given spec. The
columnOffset (which is 0 based) has to be used to
+ // tell the federated worker how much the
indexes in the spec have to be offset.
+ Future<FederatedResponse> responseFuture =
data.executeFederatedOperation(new FederatedRequest(
+ RequestType.EXEC_UDF,
+ -1,
+ new CreateFrameEncoder(data.getVarID(),
spec, columnOffset + 1)));
+ // collect responses with encoders
+ try {
+ FederatedResponse response =
responseFuture.get();
+ MultiColumnEncoder encoder =
(MultiColumnEncoder) response.getData()[0];
+ // merge this encoder into a composite
encoder
+ synchronized(finalGlobalEncoder) {
+
finalGlobalEncoder.mergeAt(encoder, columnOffset, (int)
(range.getBeginDims()[0] + 1));
+ }
+ // no synchronization necessary since
names should anyway match
+ String[] subRangeColNames = (String[])
response.getData()[1];
+ System.arraycopy(subRangeColNames, 0,
finalColNames, (int) range.getBeginDims()[1], subRangeColNames.length);
+ }
+ catch(Exception e) {
+ throw new
DMLRuntimeException("Federated encoder creation failed: ", e);
+ }
+ return null;
+ });
+ globalEncoder = finalGlobalEncoder;
+ colNames = finalColNames;
+ }
+
+ // sort for consistent encoding in local and federated
+ if(ColumnEncoderRecode.SORT_RECODE_MAP) {
+ globalEncoder.applyToAll(ColumnEncoderRecode.class,
ColumnEncoderRecode::sortCPRecodeMaps);
+ }
+
+ FrameBlock meta = new FrameBlock((int) fin.getNumColumns(),
Types.ValueType.STRING);
+ meta.setColumnNames(colNames);
+ globalEncoder.getMetaData(meta);
+ globalEncoder.initMetaData(meta);
+
+ encodeFederatedFrames(fedMapping, globalEncoder,
ec.getMatrixObject(getOutput(0)));
+
+ // release input and outputs
+ ec.setFrameOutput(getOutput(1).getName(), meta);
+ }
+
+ private class EncoderColnames {
+ public final MultiColumnEncoder _encoder;
+ public final String[] _colnames;
+
+ public EncoderColnames(MultiColumnEncoder encoder, String[]
colnames) {
+ _encoder = encoder;
+ _colnames = colnames;
+ }
+ }
+
+ public EncoderColnames
createGlobalEncoderWithEquiHeight(ExecutionContext ec, FrameObject fin, String
spec) {
+ // the encoder in which the complete encoding information will
be aggregated
+ MultiColumnEncoder globalEncoder = new MultiColumnEncoder(new
ArrayList<>());
+ String[] colNames = new String[(int) fin.getNumColumns()];
+
+ Map<Integer, double[]> quantilesPerColumn = new HashMap<>();
FederationMap fedMapping = fin.getFedMapping();
fedMapping.forEachParallel((range, data) -> {
int columnOffset = (int) range.getBeginDims()[1];
// create an encoder with the given spec. The
columnOffset (which is 0 based) has to be used to
// tell the federated worker how much the indexes in
the spec have to be offset.
- Future<FederatedResponse> responseFuture =
data.executeFederatedOperation(new FederatedRequest(
- RequestType.EXEC_UDF, -1, new
CreateFrameEncoder(data.getVarID(), spec, columnOffset + 1)));
+ Future<FederatedResponse> responseFuture =
data.executeFederatedOperation(
+ new FederatedRequest(RequestType.EXEC_UDF, -1,
+ new CreateFrameEncoder(data.getVarID(),
spec, columnOffset + 1)));
// collect responses with encoders
try {
FederatedResponse response =
responseFuture.get();
MultiColumnEncoder encoder =
(MultiColumnEncoder) response.getData()[0];
+
+ // put columns to equi-height
+ for(Encoder enc : encoder.getColumnEncoders()) {
+ if(enc instanceof
ColumnEncoderComposite) {
+ for(Encoder compositeEncoder :
((ColumnEncoderComposite) enc).getEncoders()) {
+ if(compositeEncoder
instanceof ColumnEncoderBin && ((ColumnEncoderBin)
compositeEncoder).getBinMethod() == ColumnEncoderBin.BinMethod.EQUI_HEIGHT) {
+ double
quantilrRange = (double) fin.getNumRows() / ((ColumnEncoderBin)
compositeEncoder).getNumBin();
+ double[]
quantiles = new double[((ColumnEncoderBin) compositeEncoder).getNumBin()];
+ for(int i = 0;
i < quantiles.length; i++) {
+
quantiles[i] = quantilrRange * (i + 1);
+ }
+
quantilesPerColumn.put(((ColumnEncoderBin) compositeEncoder).getColID() +
columnOffset - 1, quantiles);
+ }
+ }
+ }
+ }
+
// merge this encoder into a composite encoder
synchronized(globalEncoder) {
globalEncoder.mergeAt(encoder,
columnOffset, (int) (range.getBeginDims()[0] + 1));
@@ -127,20 +224,27 @@ public class
MultiReturnParameterizedBuiltinFEDInstruction extends ComputationFE
return null;
});
- // sort for consistent encoding in local and federated
- if(ColumnEncoderRecode.SORT_RECODE_MAP) {
- globalEncoder.applyToAll(ColumnEncoderRecode.class,
ColumnEncoderRecode::sortCPRecodeMaps);
+ // calculate all quantiles
+ Map<Integer, double[]> equiHeightBinsPerColumn = new
HashMap<>();
+ for(Map.Entry<Integer, double[]> colQuantiles :
quantilesPerColumn.entrySet()) {
+ QuantilePickFEDInstruction quantileInstr = new
QuantilePickFEDInstruction(
+ null, input1, output,
PickByCount.OperationTypes.VALUEPICK,true, "qpick", "");
+ MatrixBlock quantiles =
quantileInstr.getEquiHeightBins(ec, colQuantiles.getKey(),
colQuantiles.getValue());
+ equiHeightBinsPerColumn.put(colQuantiles.getKey(),
quantiles.getDenseBlockValues());
}
- FrameBlock meta = new FrameBlock((int) fin.getNumColumns(),
Types.ValueType.STRING);
- meta.setColumnNames(colNames);
- globalEncoder.getMetaData(meta);
- globalEncoder.initMetaData(meta);
-
- encodeFederatedFrames(fedMapping, globalEncoder,
ec.getMatrixObject(getOutput(0)));
-
- // release input and outputs
- ec.setFrameOutput(getOutput(1).getName(), meta);
+ // modify global encoder
+ for(Encoder enc : globalEncoder.getColumnEncoders()) {
+ if(enc instanceof ColumnEncoderComposite) {
+ for(Encoder compositeEncoder :
((ColumnEncoderComposite) enc).getEncoders())
+ if(compositeEncoder instanceof
ColumnEncoderBin && ((ColumnEncoderBin) compositeEncoder)
+ .getBinMethod() ==
ColumnEncoderBin.BinMethod.EQUI_HEIGHT)
+ ((ColumnEncoderBin)
compositeEncoder).buildEquiHeight(equiHeightBinsPerColumn
+
.get(((ColumnEncoderBin) compositeEncoder).getColID() - 1));
+ ((ColumnEncoderComposite)
enc).updateAllDCEncoders();
+ }
+ }
+ return new EncoderColnames(globalEncoder, colNames);
}
public static void encodeFederatedFrames(FederationMap fedMapping,
MultiColumnEncoder globalencoder,
@@ -199,7 +303,7 @@ public class MultiReturnParameterizedBuiltinFEDInstruction
extends ComputationFE
.createEncoder(_spec, colNames,
fb.getNumColumns(), null, _offset, _offset + fb.getNumColumns());
// build necessary structures for encoding
- encoder.build(fb);
+ encoder.build(fb); // FIXME skip equi-height sorting
fo.release();
// create federated response
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantilePickFEDInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantilePickFEDInstruction.java
index f082173547..cce571bbf7 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantilePickFEDInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/fed/QuantilePickFEDInstruction.java
@@ -20,6 +20,7 @@
package org.apache.sysds.runtime.instructions.fed;
import java.util.ArrayList;
+import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
@@ -28,14 +29,19 @@ import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
+import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.ImmutableTriple;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.hops.fedplanner.FTypes.FType;
import org.apache.sysds.lops.PickByCount.OperationTypes;
import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
+import org.apache.sysds.runtime.controlprogram.caching.FrameObject;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedData;
+import org.apache.sysds.runtime.controlprogram.federated.FederatedRange;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
@@ -47,6 +53,7 @@ import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.cp.DoubleObject;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
import org.apache.sysds.runtime.lineage.LineageItem;
+import org.apache.sysds.runtime.matrix.data.FrameBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.matrix.operators.Operator;
@@ -55,18 +62,18 @@ public class QuantilePickFEDInstruction extends
BinaryFEDInstruction {
private final OperationTypes _type;
- private QuantilePickFEDInstruction(Operator op, CPOperand in, CPOperand
out, OperationTypes type, boolean inmem,
+ public QuantilePickFEDInstruction(Operator op, CPOperand in, CPOperand
out, OperationTypes type, boolean inmem,
String opcode, String istr) {
this(op, in, null, out, type, inmem, opcode, istr);
}
- private QuantilePickFEDInstruction(Operator op, CPOperand in, CPOperand
in2, CPOperand out, OperationTypes type,
+ public QuantilePickFEDInstruction(Operator op, CPOperand in, CPOperand
in2, CPOperand out, OperationTypes type,
boolean inmem, String opcode, String istr,
FederatedOutput fedOut) {
super(FEDType.QPick, op, in, in2, out, opcode, istr, fedOut);
_type = type;
}
- private QuantilePickFEDInstruction(Operator op, CPOperand in, CPOperand
in2, CPOperand out, OperationTypes type,
+ public QuantilePickFEDInstruction(Operator op, CPOperand in, CPOperand
in2, CPOperand out, OperationTypes type,
boolean inmem, String opcode, String istr) {
this(op, in, in2, out, type, inmem, opcode, istr,
FederatedOutput.NONE);
}
@@ -112,6 +119,101 @@ public class QuantilePickFEDInstruction extends
BinaryFEDInstruction {
processRowQPick(ec);
}
+ public <T> MatrixBlock getEquiHeightBins(ExecutionContext ec, int
colID, double[] quantiles) {
+ FrameObject inFrame = ec.getFrameObject(input1);
+ FederationMap frameFedMap = inFrame.getFedMapping();
+
+ // Create vector
+ MatrixObject in = ExecutionContext.createMatrixObject(new
MatrixBlock((int) inFrame.getNumRows(), 1, false));
+ long varID = FederationUtils.getNextFedDataID();
+ ec.setVariable(String.valueOf(varID), in);
+
+ // modify map here
+ List<FederatedRange> ranges = new ArrayList<>();
+ FederationMap oldFedMap = frameFedMap.mapParallel(varID,
(range, data) -> {
+ try {
+ int colIDWorker = colID;
+ if(colID >= range.getBeginDims()[1] && colID <
range.getEndDims()[1]) {
+ if(range.getBeginDims()[1] > 1)
+ colIDWorker = colID - (int)
range.getBeginDims()[1];
+ FederatedResponse response =
data.executeFederatedOperation(
+ new
FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, -1,
+ new
QuantilePickFEDInstruction.CreateMatrixFromFrame(data.getVarID(), varID,
colIDWorker))).get();
+
+ synchronized(ranges) {
+ ranges.add(range);
+ }
+ if(!response.isSuccessful())
+
response.throwExceptionFromResponse();
+ }
+ }
+ catch(Exception e) {
+ throw new DMLRuntimeException(e);
+ }
+ return null;
+ });
+
+ //create one column federated object
+ List<Pair<FederatedRange, FederatedData>> newFedMapPairs = new
ArrayList<>();
+ for(Pair<FederatedRange, FederatedData> mapPair :
oldFedMap.getMap()) {
+ for(FederatedRange r : ranges) {
+ if(mapPair.getLeft().equals(r)) {
+ newFedMapPairs.add(mapPair);
+ }
+ }
+ }
+
+ FederationMap newFedMap = new FederationMap(varID,
newFedMapPairs, FType.COL);
+
+ // construct a federated matrix with the encoded data
+ in.getDataCharacteristics().setDimension(in.getNumRows(),1);
+ in.setFedMapping(newFedMap);
+
+
+ // Find min and max
+ List<double[]> minMax = new ArrayList<>();
+ newFedMap.mapParallel(varID, (range, data) -> {
+ try {
+ FederatedResponse response =
data.executeFederatedOperation(new FederatedRequest(
+ FederatedRequest.RequestType.EXEC_UDF,
-1,
+ new
QuantilePickFEDInstruction.MinMax(data.getVarID()))).get();
+ if(!response.isSuccessful())
+ response.throwExceptionFromResponse();
+ double[] rangeMinMax = (double[])
response.getData()[0];
+ minMax.add(rangeMinMax);
+
+ return null;
+ }
+ catch(Exception e) {
+ throw new DMLRuntimeException(e);
+ }
+ });
+
+ // Find weights sum, min and max
+ double globalMin = Double.MAX_VALUE, globalMax =
Double.MIN_VALUE, vectorLength = inFrame.getNumColumns() == 2 ? 0 :
inFrame.getNumRows();
+ for(double[] values : minMax) {
+ globalMin = Math.min(globalMin, values[0]);
+ globalMax = Math.max(globalMax, values[1]);
+ }
+
+ // If multiple quantiles take first histogram and reuse bins,
otherwise recursively get bin with result
+ int numBuckets = 256; // (int) Math.round(in.getNumRows() /
2.0);
+
+ T ret = createHistogram(in, (int) vectorLength, globalMin,
globalMax, numBuckets, -1, false);
+
+ // Compute and set results
+ MatrixBlock quantileValues = computeMultipleQuantiles(ec, in,
(int[]) ret, quantiles, (int) vectorLength, varID, (globalMax-globalMin) /
numBuckets, globalMin, _type, true);
+
+ ec.removeVariable(String.valueOf(varID));
+
+ // Add min to the result
+ MatrixBlock res = new MatrixBlock(quantileValues.getNumRows() +
1, 1, false);
+ res.setValue(0,0, globalMin);
+ res.copy(1, quantileValues.getNumRows(), 0, 0,
quantileValues,false);
+
+ return res;
+ }
+
public <T> void processRowQPick(ExecutionContext ec) {
MatrixObject in = ec.getMatrixObject(input1);
FederationMap fedMap = in.getFedMapping();
@@ -165,13 +267,16 @@ public class QuantilePickFEDInstruction extends
BinaryFEDInstruction {
// Compute and set results
if(quantiles != null && quantiles.length > 1) {
- computeMultipleQuantiles(ec, in, (int[]) ret,
quantiles, (int) vectorLength, varID, (globalMax-globalMin) / numBuckets,
globalMin, _type);
- } else
+ double finalVectorLength = vectorLength;
+ quantiles = Arrays.stream(quantiles).map(val -> (int)
Math.round(finalVectorLength * val)).toArray();
+ computeMultipleQuantiles(ec, in, (int[]) ret,
quantiles, (int) vectorLength, varID, (globalMax-globalMin) / numBuckets,
globalMin, _type, false);
+ }
+ else
getSingleQuantileResult(ret, ec, fedMap, varID,
average, false, (int) vectorLength, null);
}
- private <T> void computeMultipleQuantiles(ExecutionContext ec,
MatrixObject in, int[] bucketsFrequencies, double[] quantiles,
- int vectorLength, long varID, double bucketRange, double min,
OperationTypes type) {
+ private <T> MatrixBlock computeMultipleQuantiles(ExecutionContext ec,
MatrixObject in, int[] bucketsFrequencies, double[] quantiles,
+ int vectorLength, long varID, double bucketRange, double min,
OperationTypes type, boolean returnOutput) {
MatrixBlock out = new MatrixBlock(quantiles.length, 1, false);
ImmutableTriple<Integer, Integer, ImmutablePair<Double,
Double>>[] bucketsWithIndex = new ImmutableTriple[quantiles.length];
@@ -181,12 +286,12 @@ public class QuantilePickFEDInstruction extends
BinaryFEDInstruction {
sizeBeforeTmp += bucketsFrequencies[j];
for(int i = 0; i < quantiles.length; i++) {
- int quantileIndex = (int)
Math.round(vectorLength * quantiles[i]);
+
ImmutablePair<Double, Double> bucketWithQ;
- if(quantileIndex > sizeBefore && quantileIndex
<= sizeBeforeTmp) {
+ if(quantiles[i] > sizeBefore && quantiles[i] <=
sizeBeforeTmp) {
bucketWithQ = new ImmutablePair<>(min +
(j * bucketRange), min + ((j+1) * bucketRange));
- bucketsWithIndex[i] = new
ImmutableTriple<>(quantileIndex == 1 ? 1 : quantileIndex - sizeBefore,
bucketsFrequencies[j], bucketWithQ);
+ bucketsWithIndex[i] = new
ImmutableTriple<Integer, Integer, ImmutablePair<Double, Double>>(quantiles[i]
== 1 ? 1 : (int) quantiles[i] - sizeBefore, bucketsFrequencies[j], bucketWithQ);
countFoundBins++;
}
}
@@ -248,9 +353,12 @@ public class QuantilePickFEDInstruction extends
BinaryFEDInstruction {
}
});
}
-
- ec.setMatrixOutput(output.getName(), out);
+ if(returnOutput)
+ return out;
+ else
+ ec.setMatrixOutput(output.getName(), out);
}
+ return null;
}
private <T> void getSingleQuantileResult(T ret, ExecutionContext ec,
FederationMap fedMap, long varID, boolean average, boolean isIQM, int
vectorLength, ImmutablePair<Double, Double> iqmRange) {
@@ -298,7 +406,7 @@ public class QuantilePickFEDInstruction extends
BinaryFEDInstruction {
ec.setScalarOutput(output.getName(), new DoubleObject(result));
}
- public <T> T createHistogram(MatrixObject in, int vectorLength, double
globalMin, double globalMax, int numBuckets, int quantileIndex, boolean
average) {
+ public <T> T createHistogram(CacheableData<?> in, int vectorLength,
double globalMin, double globalMax, int numBuckets, int quantileIndex, boolean
average) {
FederationMap fedMap = in.getFedMapping();
List<int[]> hists = new ArrayList<>();
List<Set<Double>> distincts = new ArrayList<>();
@@ -342,7 +450,7 @@ public class QuantilePickFEDInstruction extends
BinaryFEDInstruction {
Set<Double> distinctValues =
distincts.stream().flatMap(Set::stream).collect(Collectors.toSet());
if(distinctValues.size() > quantileIndex-1 && !average)
- return (T)
distinctValues.stream().sorted().toArray()[quantileIndex-1];
+ return (T)
distinctValues.stream().sorted().toArray()[quantileIndex > 0 ? quantileIndex-1
: 0];
if(average && distinctValues.size() > quantileIndex) {
Double[] distinctsSorted =
distinctValues.stream().flatMap(Stream::of).sorted().toArray(Double[]::new);
@@ -350,7 +458,7 @@ public class QuantilePickFEDInstruction extends
BinaryFEDInstruction {
return (T) medianSum;
}
- if(average && distinctValues.size() == 2)
+ if((average && distinctValues.size() == 2) || (!average &&
distinctValues.size() == 1))
return (T) distinctValues.stream().reduce(0.0,
Double::sum);
ImmutablePair<Double, Double> finalBucketWithQ =
bucketWithIndex.right;
@@ -358,6 +466,12 @@ public class QuantilePickFEDInstruction extends
BinaryFEDInstruction {
if((distinctInNewBucket.size() == 1 && !average) || (average &&
distinctInNewBucket.size() == 2))
return (T) distinctInNewBucket.stream().reduce(0.0,
Double::sum);
+ if(!average) {
+ Set<Double> distinctsSet = new
HashSet<>(distinctInNewBucket);
+ if(distinctsSet.size() == 1)
+ return (T) distinctsSet.toArray()[0];
+ }
+
if(distinctValues.size() == 1 || (bucketWithIndex.middle == 1
&& !average) || (bucketWithIndex.middle == 2 && isEvenNumRows && average) ||
globalMin == globalMax)
return (T) bucketWithIndex.right;
@@ -402,6 +516,41 @@ public class QuantilePickFEDInstruction extends
BinaryFEDInstruction {
return new ImmutableTriple<>(quantileIndex, bucketWithQSize,
bucketWithQ);
}
+ public static class CreateMatrixFromFrame extends FederatedUDF {
+ private static final long serialVersionUID =
-6569370318237863595L;
+ private final long _outputID;
+ private final int _id;
+
+ public CreateMatrixFromFrame(long input, long output, int id) {
+ super(new long[] {input});
+ _outputID = output;
+ _id = id;
+ }
+
+ @Override
+ public FederatedResponse execute(ExecutionContext ec, Data...
data) {
+ FrameBlock fb = ((FrameObject)
data[0]).acquireReadAndRelease();
+
+ double[] colData =
ArrayUtils.toPrimitive(Arrays.stream((Object[]) fb.getColumnData(_id)).map(e ->
Double.valueOf(String.valueOf(e))).toArray(Double[] :: new));
+
+ MatrixBlock mbout = new MatrixBlock(fb.getNumRows(), 1,
colData);
+
+ // create output matrix object
+ MatrixObject mo =
ExecutionContext.createMatrixObject(mbout);
+
+ // add it to the list of variables
+ ec.setVariable(String.valueOf(_outputID), mo);
+
+ // return id handle
+ return new
FederatedResponse(FederatedResponse.ResponseType.SUCCESS_EMPTY);
+ }
+
+ @Override
+ public Pair<String, LineageItem>
getLineageItem(ExecutionContext ec) {
+ return null;
+ }
+ }
+
public static class GetHistogram extends FederatedUDF {
private static final long serialVersionUID =
5413355823424777742L;
private final double _max;
diff --git
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java
index b243c857c2..8db08b870d 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java
@@ -31,6 +31,7 @@ import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.HashSet;
+import java.util.Map;
import java.util.Set;
import java.util.concurrent.Callable;
@@ -215,6 +216,15 @@ public abstract class ColumnEncoder implements Encoder,
Comparable<ColumnEncoder
// do nothing
}
+ public void build(CacheBlock in, double[] equiHeightMaxs) {
+ // do nothing
+ }
+
+ public void build(CacheBlock in, Map<Integer, double[]> equiHeightMaxs)
{
+ // do nothing
+ }
+
+
/**
* Merges another encoder, of a compatible type, in after a certain
position. Resizes as necessary.
* <code>ColumnEncoders</code> are compatible with themselves and
<code>EncoderComposite</code> is compatible with
diff --git
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderBin.java
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderBin.java
index cb6e0afada..2f5f6d4297 100644
---
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderBin.java
+++
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderBin.java
@@ -19,8 +19,6 @@
package org.apache.sysds.runtime.transform.encode;
-import static org.apache.sysds.runtime.util.UtilFunctions.getEndIndex;
-
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
@@ -29,6 +27,7 @@ import java.util.HashMap;
import java.util.PriorityQueue;
import java.util.concurrent.Callable;
+import static org.apache.sysds.runtime.util.UtilFunctions.getEndIndex;
import org.apache.commons.lang3.tuple.MutableTriple;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.lops.Lop;
@@ -42,6 +41,11 @@ public class ColumnEncoderBin extends ColumnEncoder {
public static final String MAX_PREFIX = "max";
public static final String NBINS_PREFIX = "nbins";
private static final long serialVersionUID = 1917445005206076078L;
+
+ public int getNumBin() {
+ return _numBin;
+ }
+
protected int _numBin = -1;
private BinMethod _binMethod = BinMethod.EQUI_WIDTH;
@@ -115,6 +119,35 @@ public class ColumnEncoderBin extends ColumnEncoder {
TransformStatistics.incBinningBuildTime(System.nanoTime()-t0);
}
+ //TODO move federated things outside the location-agnostic encoder,
+ // and/or generalize to fit the existing mergeAt and similar methods
+ public void buildEquiHeight(double[] equiHeightMaxs) {
+ long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0;
+ if(!isApplicable())
+ return;
+ if(_binMethod == BinMethod.EQUI_HEIGHT)
+ computeFedEqualHeightBins(equiHeightMaxs);
+
+ if(DMLScript.STATISTICS)
+
TransformStatistics.incBinningBuildTime(System.nanoTime()-t0);
+ }
+
+ public void build(CacheBlock in, double[] equiHeightMaxs) {
+ long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0;
+ if(!isApplicable())
+ return;
+ if(_binMethod == BinMethod.EQUI_WIDTH) {
+ double[] pairMinMax = getMinMaxOfCol(in, _colID, 0, -1);
+ computeBins(pairMinMax[0], pairMinMax[1]);
+ }
+ else if(_binMethod == BinMethod.EQUI_HEIGHT) {
+ computeFedEqualHeightBins(equiHeightMaxs);
+ }
+
+ if(DMLScript.STATISTICS)
+
TransformStatistics.incBinningBuildTime(System.nanoTime()-t0);
+ }
+
protected double getCode(CacheBlock in, int row){
// find the right bucket for a single row
double bin = 0;
@@ -248,6 +281,16 @@ public class ColumnEncoderBin extends ColumnEncoder {
System.arraycopy(_binMaxs, 0, _binMins, 1, _numBin - 1);
}
+ private void computeFedEqualHeightBins(double[] binMaxs) {
+ if(_binMins == null || _binMaxs == null) {
+ _binMins = new double[_numBin];
+ _binMaxs = new double[_numBin];
+ }
+ System.arraycopy(binMaxs, 1, _binMaxs, 0, _numBin);
+ _binMins[0] = binMaxs[0];
+ System.arraycopy(_binMaxs, 0, _binMins, 1, _numBin - 1);
+ }
+
public void prepareBuildPartial() {
// ensure allocated min/max arrays
_colMins = -1f;
diff --git
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java
index 7194939853..243bfe7caa 100644
---
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java
+++
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java
@@ -105,6 +105,16 @@ public class ColumnEncoderComposite extends ColumnEncoder {
columnEncoder.build(in);
}
+ @Override
+ public void build(CacheBlock in, Map<Integer, double[]> equiHeightMaxs)
{
+ for(ColumnEncoder columnEncoder : _columnEncoders)
+ if(columnEncoder instanceof ColumnEncoderBin &&
((ColumnEncoderBin) columnEncoder).getBinMethod() ==
ColumnEncoderBin.BinMethod.EQUI_HEIGHT) {
+ columnEncoder.build(in,
equiHeightMaxs.get(columnEncoder.getColID()));
+ } else {
+ columnEncoder.build(in);
+ }
+ }
+
@Override
public List<DependencyTask<?>> getApplyTasks(CacheBlock in, MatrixBlock
out, int outputCol) {
List<DependencyTask<?>> tasks = new ArrayList<>();
diff --git
a/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
b/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
index 92cbc3fd8a..d84d00e531 100644
---
a/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
+++
b/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
@@ -262,6 +262,24 @@ public class MultiColumnEncoder implements Encoder {
legacyBuild((FrameBlock) in);
}
+ public void build(CacheBlock in, int k, Map<Integer, double[]>
equiHeightBinMaxs) {
+ if(hasLegacyEncoder() && !(in instanceof FrameBlock))
+ throw new DMLRuntimeException("LegacyEncoders do not
support non FrameBlock Inputs");
+ if(!_partitionDone) //happens if this method is directly called
+ deriveNumRowPartitions(in, k);
+ if(k > 1) {
+ buildMT(in, k);
+ }
+ else {
+ for(ColumnEncoderComposite columnEncoder :
_columnEncoders) {
+ columnEncoder.build(in, equiHeightBinMaxs);
+ columnEncoder.updateAllDCEncoders();
+ }
+ }
+ if(hasLegacyEncoder())
+ legacyBuild((FrameBlock) in);
+ }
+
private List<DependencyTask<?>> getBuildTasks(CacheBlock in) {
List<DependencyTask<?>> tasks = new ArrayList<>();
for(ColumnEncoderComposite columnEncoder : _columnEncoders) {
@@ -1197,5 +1215,4 @@ public class MultiColumnEncoder implements Encoder {
return getClass().getSimpleName() + "<ColId: " +
_colEncoder._colID + ">";
}
}
-
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeApplyTest.java
b/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeApplyTest.java
index f1fa7e647f..77ea36e20b 100644
---
a/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeApplyTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/federated/transform/TransformFederatedEncodeApplyTest.java
@@ -58,9 +58,13 @@ public class TransformFederatedEncodeApplyTest extends
AutomatedTestBase {
private final static String SPEC2b = "homes3/homes.tfspec_dummy2.json";
private final static String SPEC3 = "homes3/homes.tfspec_bin.json"; //
recode
private final static String SPEC3b = "homes3/homes.tfspec_bin2.json";
// recode
+ private final static String SPEC3c =
"homes3/homes.tfspec_bin_height.json"; // equi-height binning spec; selected for BIN_HEIGHT when colnames is false
+ private final static String SPEC3d =
"homes3/homes.tfspec_bin_height2.json"; // equi-height binning spec; selected for BIN_HEIGHT when colnames is true
private final static String SPEC6 =
"homes3/homes.tfspec_recode_dummy.json";
private final static String SPEC6b =
"homes3/homes.tfspec_recode_dummy2.json";
private final static String SPEC7 =
"homes3/homes.tfspec_binDummy.json"; // recode+dummy
+ private final static String SPEC7c =
"homes3/homes.tfspec_binHeightDummy.json"; // equi-height binning + dummy coding spec; selected for BIN_HEIGHT_DUMMY when colnames is false
+ private final static String SPEC7d =
"homes3/homes.tfspec_binHeightDummy2.json"; // equi-height binning + dummy coding spec; selected for BIN_HEIGHT_DUMMY when colnames is true
private final static String SPEC7b =
"homes3/homes.tfspec_binDummy2.json"; // recode+dummy
private final static String SPEC8 = "homes3/homes.tfspec_hash.json";
private final static String SPEC8b = "homes3/homes.tfspec_hash2.json";
@@ -77,8 +81,11 @@ public class TransformFederatedEncodeApplyTest extends
AutomatedTestBase {
private static final int[] BIN_col3 = new int[] {1, 4, 2, 3, 3, 2, 4};
private static final int[] BIN_col8 = new int[] {1, 2, 2, 2, 2, 2, 3};
+ private static final int[] BIN_HEIGHT_col3 = new int[]{1,3,1,3,3,2,3}; // expected bin ids for BIN_HEIGHT, compared against R1[i][2] for the first 7 rows
+ private static final int[] BIN_HEIGHT_col8 = new int[]{1,2,2,3,2,2,3}; // expected bin ids for BIN_HEIGHT, compared against R1[i][7] for the first 7 rows
+
public enum TransformType {
- RECODE, DUMMY, RECODE_DUMMY, BIN, BIN_DUMMY, IMPUTE, OMIT,
HASH, HASH_RECODE,
+ RECODE, DUMMY, RECODE_DUMMY, BIN, BIN_DUMMY, IMPUTE, OMIT,
HASH, HASH_RECODE, BIN_HEIGHT_DUMMY, BIN_HEIGHT,
}
@Override
@@ -187,6 +194,21 @@ public class TransformFederatedEncodeApplyTest extends
AutomatedTestBase {
runTransformTest(TransformType.RECODE_DUMMY, false, true);
}
+ @Test
+ public void testHomesEqualHeightBinningIDsSingleNodeCSV() {
+ runTransformTest(TransformType.BIN_HEIGHT, true, false); // colnames=true, lineage=false; NOTE(review): the method name says IDs, yet the Dummy IDs test passes colnames=false -- confirm which is intended
+ }
+
+ @Test
+ public void testHomesHeightBinningDummyIDsSingleNodeCSV() {
+ runTransformTest(TransformType.BIN_HEIGHT_DUMMY, false, false); // colnames=false (picks SPEC7c), lineage=false
+ }
+
+ @Test
+ public void testHomesHeightBinningDummyColnamesSingleNodeCSV() {
+ runTransformTest(TransformType.BIN_HEIGHT_DUMMY, true, false); // colnames=true (picks SPEC7d), lineage=false
+ }
+
private void runTransformTest(TransformType type, boolean colnames,
boolean lineage) {
ExecMode rtold = setExecMode(ExecMode.SINGLE_NODE);
@@ -197,10 +219,12 @@ public class TransformFederatedEncodeApplyTest extends
AutomatedTestBase {
case RECODE: SPEC = colnames ? SPEC1b : SPEC1; DATASET
= DATASET1; break;
case DUMMY: SPEC = colnames ? SPEC2b : SPEC2; DATASET =
DATASET1; break;
case BIN: SPEC = colnames ? SPEC3b : SPEC3; DATASET =
DATASET1; break;
+ case BIN_HEIGHT: SPEC = colnames?SPEC3d:SPEC3c;
DATASET = DATASET1; break;
case IMPUTE: SPEC = colnames ? SPEC4b : SPEC4; DATASET
= DATASET2; break;
case OMIT: SPEC = colnames ? SPEC5b : SPEC5; DATASET =
DATASET2; break;
case RECODE_DUMMY: SPEC = colnames ? SPEC6b : SPEC6;
DATASET = DATASET1; break;
case BIN_DUMMY: SPEC = colnames ? SPEC7b : SPEC7;
DATASET = DATASET1; break;
+ case BIN_HEIGHT_DUMMY: SPEC =
colnames?SPEC7d:SPEC7c; DATASET = DATASET1; break;
case HASH: SPEC = colnames ? SPEC8b : SPEC8; DATASET =
DATASET1; break;
case HASH_RECODE: SPEC = colnames ? SPEC9b : SPEC9;
DATASET = DATASET1; break;
}
@@ -256,7 +280,7 @@ public class TransformFederatedEncodeApplyTest extends
AutomatedTestBase {
fullDMLScriptName = HOME + TEST_NAME1 + ".dml";
String[] lineageArgs = new String[] {"-lineage",
"reuse_full", "-stats"};
- programArgs = new String[] {"-nvargs", "in_AH=" +
TestUtils.federatedAddress(port1, input("AH")),
+ programArgs = new String[] {"-explain", "-nvargs",
"in_AH=" + TestUtils.federatedAddress(port1, input("AH")),
"in_AL=" + TestUtils.federatedAddress(port2,
input("AL")),
"in_BH=" + TestUtils.federatedAddress(port3,
input("BH")),
"in_BL=" + TestUtils.federatedAddress(port4,
input("BL")), "rows=" + dataset.getNumRows(),
@@ -283,8 +307,12 @@ public class TransformFederatedEncodeApplyTest extends
AutomatedTestBase {
Assert.assertEquals(BIN_col3[i],
R1[i][2], 1e-8);
Assert.assertEquals(BIN_col8[i],
R1[i][7], 1e-8);
}
- }
- else if(type == TransformType.BIN_DUMMY) {
+ } else if (type == TransformType.BIN_HEIGHT) {
+ for(int i=0; i<7; i++) {
+ Assert.assertEquals(BIN_HEIGHT_col3[i],
R1[i][2], 1e-8);
+ Assert.assertEquals(BIN_HEIGHT_col8[i],
R1[i][7], 1e-8);
+ }
+ } else if(type == TransformType.BIN_DUMMY) {
Assert.assertEquals(14, R1[0].length);
for(int i = 0; i < 7; i++) {
for(int j = 0; j < 4; j++) { // check
dummy coded
@@ -294,7 +322,20 @@ public class TransformFederatedEncodeApplyTest extends
AutomatedTestBase {
Assert.assertEquals((j ==
BIN_col8[i] - 1) ? 1 : 0, R1[i][10 + j], 1e-8);
}
}
+ } else if (type == TransformType.BIN_HEIGHT_DUMMY) {
+ Assert.assertEquals(14, R1[0].length);
+ for(int i=0; i<7; i++) {
+ for(int j=0; j<4; j++) { //check dummy
coded
+
Assert.assertEquals((j==BIN_HEIGHT_col3[i]-1)?
+ 1:0, R1[i][2+j], 1e-8);
+ }
+ for(int j=0; j<3; j++) { //check dummy
coded
+
Assert.assertEquals((j==BIN_HEIGHT_col8[i]-1)?
+ 1:0, R1[i][10+j], 1e-8);
+ }
+ }
}
+
// assert reuse count
if (lineage)
Assert.assertTrue(LineageCacheStatistics.getInstHits() > 0);
@@ -318,3 +359,24 @@ public class TransformFederatedEncodeApplyTest extends
AutomatedTestBase {
FileFormat.CSV, ffpCSV);
}
}
+
+
+// 1,000 1,000 1,000 7,000 1,000 3,000 2,000 1,000 698,000
+// 2,000 2,000 4,000 6,000 2,000 2,000 2,000 2,000 906,000
+// 3,000 3,000 2,000 3,000 3,000 3,000 1,000 2,000 892,000
+// 1,000 4,000 3,000 6,000 2,500 2,000 1,000 2,000 932,000
+// 4,000 2,000 3,000 6,000 2,500 2,000 2,000 2,000 876,000
+// 4,000 3,000 2,000 5,000 2,500 2,000 2,000 2,000 803,000
+// 5,000 3,000 4,000 7,000 2,500 2,000 2,000 3,000 963,000
+// 4,000 1,000 1,000 7,000 1,500 2,000 1,000 2,000 760,000
+// 1,000 1,000 2,000 4,000 3,000 3,000 2,000 2,000 899,000
+// 2,000 1,000 1,000 4,000 1,000 1,000 2,000 1,000 549,000
+
+
+//Expected
+// 1,000 1,000 1,000 0,000 0,000 0,000 7,000 1,000 3,000 1,000 1,000 0,000
0,000 698,000
+// 2,000 2,000 0,000 0,000 1,000 0,000 6,000 2,000 2,000 1,000 0,000 1,000
0,000 906,000
+// 3,000 3,000 1,000 0,000 0,000 0,000 3,000 3,000 3,000 2,000 0,000 1,000
0,000 892,000
+// 1,000 4,000 0,000 0,000 1,000 0,000 6,000 2,500 2,000 2,000 0,000 0,000
1,000 932,000
+// 4,000 2,000 0,000 0,000 1,000 0,000 6,000 2,500 2,000 1,000 0,000 1,000
0,000 876,000
+// 4,000 3,000 0,000 1,000 0,000 0,000 5,000 2,500 2,000 1,000 0,000 1,000
0,000 803,000
\ No newline at end of file
diff --git a/src/test/scripts/functions/transform/TransformFrameEncodeApply.dml
b/src/test/scripts/functions/transform/TransformFrameEncodeApply.dml
index f7be1aaaaa..01e5ed745c 100644
--- a/src/test/scripts/functions/transform/TransformFrameEncodeApply.dml
+++ b/src/test/scripts/functions/transform/TransformFrameEncodeApply.dml
@@ -28,6 +28,7 @@ jspec = read($TFSPEC, data_type="scalar",
value_type="string");
while(FALSE){}
X2 = transformapply(target=F1, spec=jspec, meta=M);
+print(toString(X))
write(X, $TFDATA1, format=$OFMT);
write(X2, $TFDATA2, format=$OFMT);