This is an automated email from the ASF dual-hosted git repository. baunsgaard pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/systemds.git
commit 72379f9ccc8ce161fc81bf2040ebe3ca8abed71c Author: Sebastian Baunsgaard <[email protected]> AuthorDate: Thu Sep 7 13:42:43 2023 +0200 [SYSTEMDS-3490] Compressed Transform Encode Optimizations Optimized the transformation of the default case, and added a new approximate equi width scheme. Closes #1901 --- .../runtime/transform/encode/ColumnEncoder.java | 72 ++++--- .../runtime/transform/encode/ColumnEncoderBin.java | 210 +++++++++++++++------ .../transform/encode/ColumnEncoderComposite.java | 4 +- .../transform/encode/ColumnEncoderDummycode.java | 73 ++++--- .../transform/encode/ColumnEncoderFeatureHash.java | 9 +- .../transform/encode/ColumnEncoderPassThrough.java | 8 +- .../transform/encode/ColumnEncoderRecode.java | 7 +- .../runtime/transform/encode/ColumnEncoderUDF.java | 2 +- .../encode/ColumnEncoderWordEmbedding.java | 2 +- .../runtime/transform/encode/CompressedEncode.java | 9 +- .../runtime/transform/encode/EncoderFactory.java | 29 +-- .../runtime/transform/encode/LegacyEncoder.java | 1 - .../transform/encode/MultiColumnEncoder.java | 91 +++++---- .../java/org/apache/sysds/performance/Main.java | 8 +- .../performance/compression/TransformPerf.java | 130 +++++++++++++ .../performance/generators/FrameTransformFile.java | 83 ++++++++ ...ransformCompressedTestSingleColBinSpecific.java | 51 ++++- 17 files changed, 588 insertions(+), 201 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java index 356fa73eb1..ba048160ec 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java @@ -29,8 +29,8 @@ import java.io.ObjectOutput; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; -import java.util.List; import java.util.HashSet; +import java.util.List; import java.util.Map; import java.util.Set; import java.util.concurrent.Callable; @@ -40,10 +40,11 @@ import org.apache.commons.logging.LogFactory; import org.apache.sysds.api.DMLScript; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.controlprogram.caching.CacheBlock; -import org.apache.sysds.runtime.data.SparseRowVector; -import org.apache.sysds.runtime.frame.data.FrameBlock; +import org.apache.sysds.runtime.data.DenseBlock; import org.apache.sysds.runtime.data.SparseBlock; import org.apache.sysds.runtime.data.SparseBlockCSR; +import org.apache.sysds.runtime.data.SparseRowVector; +import org.apache.sysds.runtime.frame.data.FrameBlock; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.util.DependencyTask; import org.apache.sysds.runtime.util.DependencyThreadPool; @@ -129,25 +130,24 @@ public abstract class ColumnEncoder implements Encoder, Comparable<ColumnEncoder protected abstract double getCode(CacheBlock<?>in, int row); - protected abstract double[] getCodeCol(CacheBlock<?> in, int startInd, int blkSize); - - - /*protected void applySparse(CacheBlock in, MatrixBlock out, int outputCol, int rowStart, int blk){ - int index = _colID - 1; - for(int r = rowStart; r < getEndIndex(in.getNumRows(), rowStart, blk); r++) { - SparseRowVector row = (SparseRowVector) out.getSparseBlock().get(r); - row.values()[index] = getCode(in, r); - row.indexes()[index] = outputCol; - } - }*/ + /** + * Get the coded values for a given range from start to end. + * + * @param in The CacheBlock to extract the values from + * @param startInd The start Index + * @param rowEnd The end index + * @param tmp double tmp array to put the result into if valid. + * @return The encoded double values. + */ + protected abstract double[] getCodeCol(CacheBlock<?> in, int startInd, int rowEnd, double[] tmp); protected void applySparse(CacheBlock<?> in, MatrixBlock out, int outputCol, int rowStart, int blk){ boolean mcsr = MatrixBlock.DEFAULT_SPARSEBLOCK == SparseBlock.Type.MCSR; mcsr = false; //force CSR for transformencode int index = _colID - 1; // Apply loop tiling to exploit CPU caches - double[] codes = getCodeCol(in, rowStart, blk); int rowEnd = getEndIndex(in.getNumRows(), rowStart, blk); + double[] codes = getCodeCol(in, rowStart, rowEnd, null); int B = 32; //tile size for(int i = rowStart; i < rowEnd; i+=B) { int lim = Math.min(i+B, rowEnd); @@ -168,25 +168,39 @@ public abstract class ColumnEncoder implements Encoder, Comparable<ColumnEncoder } } - /*protected void applyDense(CacheBlock in, MatrixBlock out, int outputCol, int rowStart, int blk){ - for(int i = rowStart; i < getEndIndex(in.getNumRows(), rowStart, blk); i++) { - out.quickSetValue(i, outputCol, getCode(in, i)); - } - }*/ - protected void applyDense(CacheBlock<?> in, MatrixBlock out, int outputCol, int rowStart, int blk){ // Apply loop tiling to exploit CPU caches - double[] codes = getCodeCol(in, rowStart, blk); - int rowEnd = getEndIndex(in.getNumRows(), rowStart, blk); - int B = 32; //tile size - for(int i = rowStart; i < rowEnd; i+=B) { - int lim = Math.min(i+B, rowEnd); - for (int ii=i; ii<lim; ii++) - out.quickSetValue(ii, outputCol, codes[ii-rowStart]); - //out.denseSuperQuickSetValue(ii, outputCol, codes[ii-rowStart]); + final int rowEnd = getEndIndex(in.getNumRows(), rowStart, blk); + final int smallTile = 64; + final double[] tmp = new double[smallTile]; + final DenseBlock outB = out.getDenseBlock(); + if( outB.isContiguous(rowStart, blk)) + for(int i = rowStart; i < rowEnd; i += smallTile) + applyDenseTileContiguous(in, outB, outputCol, i, Math.min(i + smallTile, rowEnd),tmp); + else + for(int i = rowStart; i < rowEnd; i += smallTile) + applyDenseTileGeneric(in, outB, outputCol, i, Math.min(i + smallTile, rowEnd),tmp); + + } + + private void applyDenseTileContiguous(CacheBlock<?> in, DenseBlock out, int outputCol, int s, int e, double[] tmp) { + // these are codes for this block offset by rowStart + final double[] codes = getCodeCol(in, s, e, tmp); + final double[] vals = out.values(s); + int off = out.pos(s) + outputCol; + final int nCol = out.getDim(1); + for(int i = 0; i < e - s; i++, off += nCol){ + vals[off] = codes[i]; } } + private void applyDenseTileGeneric(CacheBlock<?> in, DenseBlock out, int outputCol, int s, int e, double[] tmp) { + // these are codes for this block offset by rowStart + final double[] codes = getCodeCol(in, s, e, tmp); + for(int i = s; i < e; i++) // rows + out.set(i, outputCol, codes[i - s]); + } + protected abstract TransformType getTransformType(); /** * Indicates if this encoder is applicable, i.e, if there is a column to encode. diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderBin.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderBin.java index b2c530a3e4..2df54bef20 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderBin.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderBin.java @@ -27,6 +27,7 @@ import java.io.ObjectOutput; import java.util.Arrays; import java.util.HashMap; import java.util.PriorityQueue; +import java.util.Random; import java.util.concurrent.Callable; import org.apache.commons.lang3.tuple.MutableTriple; @@ -34,6 +35,7 @@ import org.apache.sysds.api.DMLScript; import org.apache.sysds.lops.Lop; import org.apache.sysds.runtime.controlprogram.caching.CacheBlock; import org.apache.sysds.runtime.frame.data.FrameBlock; +import org.apache.sysds.runtime.frame.data.columns.Array; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.utils.stats.TransformStatistics; @@ -43,6 +45,9 @@ public class ColumnEncoderBin extends ColumnEncoder { public static final String NBINS_PREFIX = "nbins"; private static final long serialVersionUID = 1917445005206076078L; + public static final double SAMPLE_FRACTION = 0.1; + public static final int MINIMUM_SAMPLE_SIZE = 1000; + protected int _numBin = -1; private BinMethod _binMethod = BinMethod.EQUI_WIDTH; @@ -96,10 +101,14 @@ public class ColumnEncoderBin extends ColumnEncoder { } public void setBinMethod(String method) { - if (method.equalsIgnoreCase(BinMethod.EQUI_WIDTH.toString())) + if(method.equalsIgnoreCase(BinMethod.EQUI_WIDTH.toString())) _binMethod = BinMethod.EQUI_WIDTH; - if (method.equalsIgnoreCase(BinMethod.EQUI_HEIGHT.toString())) + else if(method.equalsIgnoreCase(BinMethod.EQUI_HEIGHT.toString())) _binMethod = BinMethod.EQUI_HEIGHT; + else if(method.equalsIgnoreCase(BinMethod.EQUI_HEIGHT_APPROX.toString())) + _binMethod = BinMethod.EQUI_HEIGHT_APPROX; + else + throw new RuntimeException(method + " is invalid"); } @Override @@ -107,7 +116,7 @@ public class ColumnEncoderBin extends ColumnEncoder { long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0; if(!isApplicable()) return; - if(_binMethod == BinMethod.EQUI_WIDTH) { + else if(_binMethod == BinMethod.EQUI_WIDTH) { double[] pairMinMax = getMinMaxOfCol(in, _colID, 0, -1); computeBins(pairMinMax[0], pairMinMax[1]); } @@ -115,20 +124,26 @@ public class ColumnEncoderBin extends ColumnEncoder { double[] sortedCol = prepareDataForEqualHeightBins(in, _colID, 0, -1); computeEqualHeightBins(sortedCol, false); } + else if(_binMethod == BinMethod.EQUI_HEIGHT_APPROX){ + double[] vals = sampleDoubleColumn(in, _colID, SAMPLE_FRACTION, MINIMUM_SAMPLE_SIZE); + Arrays.sort(vals); + computeEqualHeightBins(vals, false); + } if(DMLScript.STATISTICS) TransformStatistics.incBinningBuildTime(System.nanoTime()-t0); } + public void build(CacheBlock<?> in, double[] equiHeightMaxs) { long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0; if(!isApplicable()) return; - if(_binMethod == BinMethod.EQUI_WIDTH) { + else if(_binMethod == BinMethod.EQUI_WIDTH) { double[] pairMinMax = getMinMaxOfCol(in, _colID, 0, -1); computeBins(pairMinMax[0], pairMinMax[1]); } - else if(_binMethod == BinMethod.EQUI_HEIGHT) { + else if(_binMethod == BinMethod.EQUI_HEIGHT || _binMethod == BinMethod.EQUI_HEIGHT_APPROX) { computeEqualHeightBins(equiHeightMaxs, true); } @@ -148,56 +163,92 @@ public class ColumnEncoderBin extends ColumnEncoder { } @Override - protected double[] getCodeCol(CacheBlock<?> in, int startInd, int blkSize) { - // find the right bucket for a block of rows - final int endInd = getEndIndex(in.getNumRows(), startInd, blkSize); - final double[] codes = new double[endInd - startInd]; + protected final double[] getCodeCol(CacheBlock<?> in, int startInd, int endInd, double[] tmp) { + final int endLength = endInd - startInd; + final double[] codes = tmp != null && tmp.length == endLength ? tmp : new double[endLength]; if (_binMins == null || _binMins.length == 0 || _binMaxs.length == 0) { LOG.warn("ColumnEncoderBin: applyValue without bucket boundaries, assign 1"); - Arrays.fill(codes, startInd, endInd, 1.0); + Arrays.fill(codes, startInd, endInd, 1.0); return codes; } - for (int i=startInd; i<endInd; i++) { - double inVal = in.getDoubleNaN(i, _colID - 1); - codes[i- startInd] = getCodeIndex(inVal); + + if(in instanceof FrameBlock) + getCodeColFrame((FrameBlock) in, startInd, endInd, codes); + else{ + for (int i=startInd; i<endInd; i++) { + double inVal = in.getDoubleNaN(i, _colID - 1); + codes[i-startInd] = getCodeIndex(inVal);; + } } return codes; } - protected double getCodeIndex(double inVal){ - // throw new NotImplementedException("Intensional"); - if (Double.isNaN(inVal) || inVal < _binMins[0] || inVal > _binMaxs[_binMaxs.length-1]){ + protected final void getCodeColFrame(FrameBlock in, int startInd, int endInd, double[] codes) { + final Array<?> c = in.getColumn(_colID - 1); + final double mi = _binMins[0]; + final double mx = _binMaxs[_binMaxs.length-1]; + if(!c.containsNull()) + for(int i = startInd; i < endInd; i++) + codes[i - startInd] = getCodeIndex(c.getAsDouble(i), mi,mx); + else + for(int i = startInd; i < endInd; i++) + codes[i - startInd] = getCodeIndex(c.getAsNaNDouble(i),mi,mx); + } + + protected final double getCodeIndex(double inVal){ + return getCodeIndex(inVal, _binMins[0],_binMaxs[_binMaxs.length-1]); + } + + protected final double getCodeIndex(double inVal, double mi, double mx){ + final boolean nan = Double.isNaN(inVal); + if(nan || (_binMethod != BinMethod.EQUI_HEIGHT_APPROX && (inVal < mi || inVal > mx))) return Double.NaN; - } - else if (_binMethod == BinMethod.EQUI_HEIGHT) { - final int ix = Arrays.binarySearch(_binMaxs, inVal); - - if(ix < 0) // somewhere in between values - // +2 because negative values are found from binary search. - // plus 2 to correct for the absolute value of that. - return Math.abs(ix + 1) + 1; - else if (ix == 0) // If first bucket boundary add it there. - return 1; - else - // precisely at boundaries default to lower bucket - // This is done to avoid using an extra bucket for max value. - return Math.min(ix + 1, _binMaxs.length); - } - else { //if (_binMethod == BinMethod.EQUI_WIDTH) { - final double max = _binMaxs[_binMaxs.length-1]; - final double min = _binMins[0]; - - if(max == min){ - return 1; - } + else if(_binMethod == BinMethod.EQUI_WIDTH) + return getEqWidth(inVal); + else // if (_binMethod == BinMethod.EQUI_HEIGHT || _binMethod == BinMethod.EQUI_HEIGHT_APPROX) + return getCodeIndexEQHeight(inVal); + } - //TODO: Skip computing bin boundaries for equi-width - double binWidth = (max - min) / _numBin; - double code = Math.ceil((inVal - min) / binWidth); - return (code == 0) ? code + 1 : code; - } + private final double getEqWidth(double inVal) { + final double max = _binMaxs[_binMaxs.length - 1]; + final double min = _binMins[0]; + + if(max == min) + return 1; + + // TODO: Skip computing bin boundaries for equi-width + double binWidth = (max - min) / _numBin; + double code = Math.ceil((inVal - min) / binWidth); + return (code == 0) ? code + 1 : code; + } + + private final double getCodeIndexEQHeight(double inVal){ + if(_binMaxs.length <= 10) + return getCodeIndexEQHeightSmall(inVal); + else + return getCodeIndexEQHeightNormal(inVal); } + private final double getCodeIndexEQHeightSmall(double inVal) { + for(int i = 0; i < _binMaxs.length-1; i++) + if(inVal <= _binMaxs[i]) + return i + 1; + return _binMaxs.length; + } + + private final double getCodeIndexEQHeightNormal(double inVal) { + final int ix = Arrays.binarySearch(_binMaxs, inVal); + if(ix < 0) // somewhere in between values + // +2 because negative values are found from binary search. + // plus 2 to correct for the absolute value of that. + return Math.abs(ix + 1) + 1; + else if(ix == 0) // If first bucket boundary add it there. + return 1; + else + // precisely at boundaries default to lower bucket + // This is done to avoid using an extra bucket for max value. + return Math.min(ix + 1, _binMaxs.length); + } @Override protected TransformType getTransformType() { @@ -208,29 +259,63 @@ public class ColumnEncoderBin extends ColumnEncoder { // derive bin boundaries from min/max per column double min = Double.POSITIVE_INFINITY; double max = Double.NEGATIVE_INFINITY; - for(int i = startRow; i < getEndIndex(in.getNumRows(), startRow, blockSize); i++) { - double inVal = in.getDouble(i, colID - 1); - if(Double.isNaN(inVal)) - continue; - min = Math.min(min, inVal); - max = Math.max(max, inVal); + final int end = getEndIndex(in.getNumRows(), startRow, blockSize); + for(int i = startRow; i < end; i++) { + final double inVal = in.getDoubleNaN(i, colID - 1); + if(!Double.isNaN(inVal)){ + min = Math.min(min, inVal); + max = Math.max(max, inVal); + } } return new double[] {min, max}; } private static double[] prepareDataForEqualHeightBins(CacheBlock<?> in, int colID, int startRow, int blockSize) { + double[] vals = extractDoubleColumn(in, colID, startRow, blockSize); + Arrays.sort(vals); + return vals; + } + + private static double[] extractDoubleColumn(CacheBlock<?> in, int colID, int startRow, int blockSize) { int endRow = getEndIndex(in.getNumRows(), startRow, blockSize); - double[] vals = new double[endRow-startRow]; - for(int i = startRow; i < endRow; i++) { - double inVal = in.getDoubleNaN(i, colID - 1); - //FIXME current NaN handling introduces 0s and thus - // impacts the computation of bin boundaries - if(Double.isNaN(inVal)) - continue; - vals[i-startRow] = inVal; + double[] vals = new double[endRow - startRow]; + final int cid = colID -1; + if(in instanceof FrameBlock) { + // FrameBlock optimization + Array<?> a = ((FrameBlock) in).getColumn(cid); + for(int i = startRow; i < endRow; i++) { + double inVal = a.getAsNaNDouble(i); + if(Double.isNaN(inVal)) + continue; + vals[i - startRow] = inVal; + } + } + else { + for(int i = startRow; i < endRow; i++) { + double inVal = in.getDoubleNaN(i, cid); + // FIXME current NaN handling introduces 0s and thus + // impacts the computation of bin boundaries + if(Double.isNaN(inVal)) + continue; + vals[i - startRow] = inVal; + } + } + return vals; + } + + private static double[] sampleDoubleColumn(CacheBlock<?> in, int colID, double sampleFraction, int minimum_sample_size){ + final int nRow = in.getNumRows(); + int elm =(int) Math.min( nRow, Math.max(minimum_sample_size, Math.ceil(nRow * sampleFraction))); + double[] vals = new double[elm]; + Array<?> a = ((FrameBlock) in).getColumn(colID - 1); + int s = DMLScript.SEED; + Random r = s == -1 ? new Random() : new Random(s); + for(int i = 0; i < elm; i++) { + double inVal = a.getAsNaNDouble(r.nextInt(nRow)); + vals[i] = inVal; } - Arrays.sort(vals); return vals; + } @Override @@ -261,12 +346,12 @@ public class ColumnEncoderBin extends ColumnEncoder { } } - private void computeEqualHeightBins(double[] sortedCol, boolean isSorted) { + private void computeEqualHeightBins(double[] sortedCol, boolean doNotTakeQuantiles) { if(_binMins == null || _binMaxs == null) { _binMins = new double[_numBin]; _binMaxs = new double[_numBin]; } - if(!isSorted) { + if(!doNotTakeQuantiles) { int n = sortedCol.length; for(int i = 0; i < _numBin; i++) { double pos = n * (i + 1d) / _numBin; @@ -410,7 +495,7 @@ public class ColumnEncoderBin extends ColumnEncoder { } public enum BinMethod { - INVALID, EQUI_WIDTH, EQUI_HEIGHT + INVALID, EQUI_WIDTH, EQUI_HEIGHT, EQUI_HEIGHT_APPROX } private static class BinSparseApplyTask extends ColumnApplyTask<ColumnEncoderBin> { @@ -469,12 +554,13 @@ public class ColumnEncoderBin extends ColumnEncoder { _partialData.put(_startRow, minMax); } } - if (_method == BinMethod.EQUI_HEIGHT) { + else if (_method == BinMethod.EQUI_HEIGHT || _method == BinMethod.EQUI_HEIGHT_APPROX) { double[] sortedVals = prepareDataForEqualHeightBins(_input, _colID, _startRow, _blockSize); synchronized(_partialData) { _partialData.put(_startRow, sortedVals); } } + if (DMLScript.STATISTICS) TransformStatistics.incBinningBuildTime(System.nanoTime()-t0); return null; diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java index 84bf84b516..225f2db54c 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java @@ -228,7 +228,7 @@ public class ColumnEncoderComposite extends ColumnEncoder { } @Override - protected double[] getCodeCol(CacheBlock<?> in, int startInd, int blkSize) { + protected double[] getCodeCol(CacheBlock<?> in, int startInd, int endInd, double[] tmp) { throw new DMLRuntimeException("CompositeEncoder does not have a Code"); } @@ -421,7 +421,7 @@ public class ColumnEncoderComposite extends ColumnEncoder { @Override public int getDomainSize() { return _columnEncoders.stream()// - .map(ColumnEncoder::getDomainSize).reduce(Integer::max).get(); + .map(ColumnEncoder::getDomainSize).reduce((a,x) -> Integer.max(a,x)).get(); } diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderDummycode.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderDummycode.java index 160b716496..4b24652750 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderDummycode.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderDummycode.java @@ -77,38 +77,51 @@ public class ColumnEncoderDummycode extends ColumnEncoder { } @Override - protected double[] getCodeCol(CacheBlock<?> in, int startInd, int blkSize) { + protected double[] getCodeCol(CacheBlock<?> in, int startInd, int endInd, double[] tmp) { throw new DMLRuntimeException("DummyCoder does not have a code"); } - protected void applySparse(CacheBlock<?> in, MatrixBlock out, int outputCol, int rowStart, int blk){ - if (!(in instanceof MatrixBlock)){ - throw new DMLRuntimeException("ColumnEncoderDummycode called with: " + in.getClass().getSimpleName() + - " and not MatrixBlock"); + + + /** + * Since the recoded values are already offset in the output matrix (same as input at this point) the dummycoding + * only needs to offset them within their column domain. Which means that the indexes in the SparseRowVector do not + * need to be sorted anymore and can be updated directly. + * <p> + * Input: Output: + * + * <pre> + * <code> + *1 | 0 | 2 | 0 1 | 0 | 0 | 1 + *2 | 0 | 1 | 0 ===> 0 | 1 | 1 | 0 + *1 | 0 | 2 | 0 1 | 0 | 0 | 1 + *1 | 0 | 1 | 0 1 | 0 | 1 | 0 + * </code> + * </pre> + * + * Example SparseRowVector Internals (1. row): + * <p> + * indexes = [0,2] ===> indexes = [0,3] values = [1,2] values = [1,1] + * + * @param in Input block to apply to + * @param out Output in sparse format + * @param outputCol The column to output to + * @param rowStart Row start + * @param blk block size. + */ + protected void applySparse(CacheBlock<?> in, MatrixBlock out, int outputCol, int rowStart, int blk) { + if(!(in instanceof MatrixBlock)) { + throw new DMLRuntimeException( + "ColumnEncoderDummycode called with: " + in.getClass().getSimpleName() + " and not MatrixBlock"); } boolean mcsr = MatrixBlock.DEFAULT_SPARSEBLOCK == SparseBlock.Type.MCSR; - mcsr = false; //force CSR for transformencode + mcsr = false; // force CSR for transformencode ArrayList<Integer> sparseRowsWZeros = null; int index = _colID - 1; for(int r = rowStart; r < getEndIndex(in.getNumRows(), rowStart, blk); r++) { - // Since the recoded values are already offset in the output matrix (same as input at this point) - // the dummycoding only needs to offset them within their column domain. Which means that the - // indexes in the SparseRowVector do not need to be sorted anymore and can be updated directly. - // - // Input: Output: - // - // 1 | 0 | 2 | 0 1 | 0 | 0 | 1 - // 2 | 0 | 1 | 0 ===> 0 | 1 | 1 | 0 - // 1 | 0 | 2 | 0 1 | 0 | 0 | 1 - // 1 | 0 | 1 | 0 1 | 0 | 1 | 0 - // - // Example SparseRowVector Internals (1. row): - // - // indexes = [0,2] ===> indexes = [0,3] - // values = [1,2] values = [1,1] - if (mcsr) { + if(mcsr) { double val = out.getSparseBlock().get(r).values()[index]; - if(Double.isNaN(val)){ + if(Double.isNaN(val)) { if(sparseRowsWZeros == null) sparseRowsWZeros = new ArrayList<>(); sparseRowsWZeros.add(r); @@ -119,21 +132,21 @@ public class ColumnEncoderDummycode extends ColumnEncoder { out.getSparseBlock().get(r).indexes()[index] = nCol; out.getSparseBlock().get(r).values()[index] = 1; } - else { //csr - SparseBlockCSR csrblock = (SparseBlockCSR)out.getSparseBlock(); + else { // csr + SparseBlockCSR csrblock = (SparseBlockCSR) out.getSparseBlock(); int rptr[] = csrblock.rowPointers(); - double val = csrblock.values()[rptr[r]+index]; - if(Double.isNaN(val)){ + double val = csrblock.values()[rptr[r] + index]; + if(Double.isNaN(val)) { if(sparseRowsWZeros == null) sparseRowsWZeros = new ArrayList<>(); sparseRowsWZeros.add(r); - csrblock.values()[rptr[r]+index] = 0; //test + csrblock.values()[rptr[r] + index] = 0; // test continue; } // Manually fill the column-indexes and values array int nCol = outputCol + (int) val - 1; - csrblock.indexes()[rptr[r]+index] = nCol; - csrblock.values()[rptr[r]+index] = 1; + csrblock.indexes()[rptr[r] + index] = nCol; + csrblock.values()[rptr[r] + index] = 1; } } if(sparseRowsWZeros != null) { diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderFeatureHash.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderFeatureHash.java index 467d16ae7c..c57c72f459 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderFeatureHash.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderFeatureHash.java @@ -19,8 +19,6 @@ package org.apache.sysds.runtime.transform.encode; -import static org.apache.sysds.runtime.util.UtilFunctions.getEndIndex; - import java.io.IOException; import java.io.ObjectInput; import java.io.ObjectOutput; @@ -85,10 +83,9 @@ public class ColumnEncoderFeatureHash extends ColumnEncoder { return Math.abs(a.hashDouble(row) % _K + 1); } - protected double[] getCodeCol(CacheBlock<?> in, int startInd, int blkSize) { - // hash a block of rows - int endInd = getEndIndex(in.getNumRows(), startInd, blkSize); - double codes[] = new double[endInd-startInd]; + protected double[] getCodeCol(CacheBlock<?> in, int startInd, int endInd, double[] tmp) { + final int endLength = endInd - startInd; + final double[] codes = tmp != null && tmp.length == endLength ? tmp : new double[endLength]; if( in instanceof FrameBlock) { Array<?> a = ((FrameBlock) in).getColumn(_colID-1); for(int i = startInd; i < endInd; i++) diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderPassThrough.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderPassThrough.java index 9d775a7a5f..d40359c484 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderPassThrough.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderPassThrough.java @@ -67,9 +67,9 @@ public class ColumnEncoderPassThrough extends ColumnEncoder { } @Override - protected double[] getCodeCol(CacheBlock<?> in, int startInd, int blkSize) { - int endInd = getEndIndex(in.getNumRows(), startInd, blkSize); - double[] codes = new double[endInd-startInd]; + protected double[] getCodeCol(CacheBlock<?> in, int startInd, int endInd, double[] tmp) { + final int endLength = endInd - startInd; + final double[] codes = tmp != null && tmp.length == endLength ? tmp : new double[endLength]; for (int i=startInd; i<endInd; i++) { codes[i-startInd] = in.getDoubleNaN(i, _colID-1); } @@ -83,8 +83,8 @@ public class ColumnEncoderPassThrough extends ColumnEncoder { mcsr = false; //force CSR for transformencode int index = _colID - 1; // Apply loop tiling to exploit CPU caches - double[] codes = getCodeCol(in, rowStart, blk); int rowEnd = getEndIndex(in.getNumRows(), rowStart, blk); + double[] codes = getCodeCol(in, rowStart, rowEnd, null); int B = 32; //tile size for(int i = rowStart; i < rowEnd; i+=B) { int lim = Math.min(i+B, rowEnd); diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderRecode.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderRecode.java index cbb4f79664..e013e7ccf0 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderRecode.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderRecode.java @@ -226,10 +226,9 @@ public class ColumnEncoderRecode extends ColumnEncoder { } @Override - protected double[] getCodeCol(CacheBlock<?> in, int startInd, int blkSize) { - // lookup for a block of rows - int endInd = getEndIndex(in.getNumRows(), startInd, blkSize); - double codes[] = new double[endInd-startInd]; + protected double[] getCodeCol(CacheBlock<?> in, int startInd, int endInd, double[] tmp) { + final int endLength = endInd - startInd; + final double[] codes = tmp != null && tmp.length == endLength ? tmp : new double[endLength]; for (int i=startInd; i<endInd; i++) { String key = in.getString(i, _colID-1); if(key == null || key.isEmpty()) { diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderUDF.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderUDF.java index 469414b397..756a1fdc5c 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderUDF.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderUDF.java @@ -162,7 +162,7 @@ public class ColumnEncoderUDF extends ColumnEncoder { } @Override - protected double[] getCodeCol(CacheBlock<?> in, int startInd, int blkSize) { + protected double[] getCodeCol(CacheBlock<?> in, int startInd, int endInd, double[] tmp) { throw new DMLRuntimeException("UDF encoders only support full column access."); } } diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderWordEmbedding.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderWordEmbedding.java index b6faf4d00b..d2909f3e01 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderWordEmbedding.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderWordEmbedding.java @@ -57,7 +57,7 @@ public class ColumnEncoderWordEmbedding extends ColumnEncoder { } @Override - protected double[] getCodeCol(CacheBlock<?> in, int startInd, int blkSize) { + protected double[] getCodeCol(CacheBlock<?> in, int startInd, int endInd, double[] tmp) { throw new NotImplementedException(); } diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java b/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java index 6c8b40c405..8c359aaa83 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java @@ -197,8 +197,13 @@ public class CompressedEncode { } } else { - for(int i = 0; i < a.size(); i++) - m.set(i, (int) b.getCodeIndex(a.getAsDouble(i)) - 1); + + for(int i = 0; i < a.size(); i++){ + int idx = (int) b.getCodeIndex(a.getAsDouble(i)) - 1; + if(idx < 0) + throw new RuntimeException(a.getAsDouble(i) + " is invalid value for " + b + "\n" + idx); + m.set(i, idx); + } } return m; } diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java index 4357b35307..640cd54d58 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java @@ -49,19 +49,22 @@ import org.apache.wink.json4j.JSONObject; public interface EncoderFactory { final static Log LOG = LogFactory.getLog(EncoderFactory.class.getName()); + public static MultiColumnEncoder createEncoder(String spec, int clen) { + return createEncoder(spec, null, clen, null, null, -1, -1); + } + public static MultiColumnEncoder createEncoder(String spec, String[] colnames, int clen, FrameBlock meta) { - return createEncoder(spec, colnames, UtilFunctions.nCopies(clen, ValueType.STRING), meta); + return createEncoder(spec, colnames, clen, meta, null, -1, -1); } public static MultiColumnEncoder createEncoder(String spec, String[] colnames, int clen, FrameBlock meta, int minCol, int maxCol) { - return createEncoder(spec, colnames, UtilFunctions.nCopies(clen, ValueType.STRING), meta, minCol, maxCol); + return createEncoder(spec, colnames, clen, meta, null, minCol, maxCol); } public static MultiColumnEncoder createEncoder(String spec, String[] colnames, ValueType[] schema, int clen, FrameBlock meta) { - ValueType[] lschema = (schema == null) ? UtilFunctions.nCopies(clen, ValueType.STRING) : schema; - return createEncoder(spec, colnames, lschema, meta); + return createEncoder(spec, colnames, clen, meta); } public static MultiColumnEncoder createEncoder(String spec, String[] colnames, ValueType[] schema, @@ -70,8 +73,8 @@ public interface EncoderFactory { } public static MultiColumnEncoder createEncoder(String spec, String[] colnames, ValueType[] schema, FrameBlock meta, - int minCol, int maxCol){ - return createEncoder(spec, colnames, schema, meta, null, minCol, maxCol); + int minCol, int maxCol) { + return createEncoder(spec, colnames, schema.length, meta, null, minCol, maxCol); } public static MultiColumnEncoder createEncoder(String spec, String[] colnames, int clen, FrameBlock meta, MatrixBlock embeddings) { @@ -80,13 +83,15 @@ public interface EncoderFactory { public static MultiColumnEncoder createEncoder(String spec, String[] colnames, ValueType[] schema, FrameBlock meta, MatrixBlock embeddings) { - return createEncoder(spec, colnames, schema, meta, embeddings, -1, -1); + return createEncoder(spec, colnames, schema.length, meta, embeddings, -1, -1); } - public static MultiColumnEncoder createEncoder(String spec, String[] colnames, ValueType[] schema, FrameBlock meta, + public static MultiColumnEncoder createEncoder(String spec, String[] colnames, int clen, FrameBlock meta, MatrixBlock embeddings, int minCol, int maxCol) { + + MultiColumnEncoder encoder; - int clen = schema.length; + // int clen = schema.length; try { // parse transform specification @@ -155,6 +160,8 @@ public interface EncoderFactory { binMethod = ColumnEncoderBin.BinMethod.EQUI_WIDTH; else if ("EQUI-HEIGHT".equals(method)) binMethod = ColumnEncoderBin.BinMethod.EQUI_HEIGHT; + else if ("EQUI-HEIGHT-APPROX".equals(method)) + binMethod = ColumnEncoderBin.BinMethod.EQUI_HEIGHT_APPROX; else throw new DMLRuntimeException("Unsupported binning method: " + method); ColumnEncoderBin bin = new ColumnEncoderBin(id, numBins, binMethod); @@ -177,12 +184,12 @@ public interface EncoderFactory { } encoder = new MultiColumnEncoder(lencoders); if(!oIDs.isEmpty()) { - encoder.addReplaceLegacyEncoder(new EncoderOmit(jSpec, colnames, schema.length, minCol, maxCol)); + encoder.addReplaceLegacyEncoder(new EncoderOmit(jSpec, colnames, clen, minCol, maxCol)); if(DMLScript.STATISTICS) TransformStatistics.incEncoderCount(1); } if(!mvIDs.isEmpty()) { - EncoderMVImpute ma = new EncoderMVImpute(jSpec, colnames, schema.length, minCol, maxCol); + EncoderMVImpute ma = new EncoderMVImpute(jSpec, colnames, clen, minCol, maxCol); ma.initRecodeIDList(rcIDs); encoder.addReplaceLegacyEncoder(ma); if(DMLScript.STATISTICS) diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/LegacyEncoder.java b/src/main/java/org/apache/sysds/runtime/transform/encode/LegacyEncoder.java index 24c7227455..3fe6cef42e 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/LegacyEncoder.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/LegacyEncoder.java @@ -29,7 +29,6 @@ import java.util.HashSet; import java.util.List; import java.util.Set; - import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.sysds.runtime.DMLRuntimeException; diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java b/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java index def429d57b..43ab2492ad 100644 --- a/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java +++ b/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java @@ -117,7 +117,7 @@ public class MultiColumnEncoder implements Encoder { finally{ pool.shutdown(); } - outputMatrixPostProcessing(out); + outputMatrixPostProcessing(out, k); return out; } else { @@ -367,7 +367,7 @@ public class MultiColumnEncoder implements Encoder { } // Recomputing NNZ since we access the block directly // TODO set NNZ explicit count them in the encoders - outputMatrixPostProcessing(out); + outputMatrixPostProcessing(out, k); if(_legacyOmit != null) out = _legacyOmit.apply((FrameBlock) in, out); if(_legacyMVImpute != null) @@ -597,11 +597,26 @@ public class MultiColumnEncoder implements Encoder { } } - private void outputMatrixPostProcessing(MatrixBlock output){ + private void outputMatrixPostProcessing(MatrixBlock output, int k){ long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0; - int k = OptimizerUtils.getTransformNumThreads(); - if (k == 1) { - Set<Integer> indexSet = _columnEncoders.stream() + if(output.isInSparseFormat()){ + if (k == 1) + outputMatrixPostProcessingSingleThread(output); + else + outputMatrixPostProcessingParallel(output, k); + } + else { + output.recomputeNonZeros(k); + } + + + if(DMLScript.STATISTICS) + TransformStatistics.incOutMatrixPostProcessingTime(System.nanoTime()-t0); + } + + + private void outputMatrixPostProcessingSingleThread(MatrixBlock output){ + Set<Integer> indexSet = _columnEncoders.stream() .map(ColumnEncoderComposite::getSparseRowsWZeros).flatMap(l -> { if(l == null) return null; @@ -612,42 +627,42 @@ public class MultiColumnEncoder implements Encoder { for(Integer row : indexSet) output.getSparseBlock().get(row).compact(); } - } - else { - ExecutorService myPool = CommonThreadPool.get(k); - try { - // Collect the row indices that need compaction - Set<Integer> indexSet = myPool.submit(() -> - _columnEncoders.stream().parallel() - .map(ColumnEncoderComposite::getSparseRowsWZeros).flatMap(l -> { - if(l == null) - return null; - return l.stream(); - }).collect(Collectors.toSet()) - ).get(); - - // Check if the set is empty - boolean emptySet = myPool.submit(() -> - indexSet.stream().parallel().allMatch(Objects::isNull) - ).get(); - - // Concurrently compact the rows - if (emptySet) { - myPool.submit(() -> { - indexSet.stream().parallel().forEach(row -> { - output.getSparseBlock().get(row).compact(); - }); - }).get(); - } - } - catch(Exception ex) { - throw new DMLRuntimeException(ex); + + output.recomputeNonZeros(); + } + + + private void outputMatrixPostProcessingParallel(MatrixBlock output, int k) { + ExecutorService myPool = CommonThreadPool.get(k); + try { + // Collect the row indices that need compaction + Set<Integer> indexSet = myPool.submit(() -> _columnEncoders.stream().parallel() + .map(ColumnEncoderComposite::getSparseRowsWZeros).flatMap(l -> { + if(l == null) + return null; + return l.stream(); + }).collect(Collectors.toSet())).get(); + + // Check if the set is empty + boolean emptySet = myPool.submit(() -> indexSet.stream().parallel().allMatch(Objects::isNull)).get(); + + // Concurrently compact the rows + if(emptySet) { + myPool.submit(() -> { + indexSet.stream().parallel().forEach(row -> { + output.getSparseBlock().get(row).compact(); + }); + }).get(); } + } + catch(Exception ex) { + throw new DMLRuntimeException(ex); + } + finally { myPool.shutdown(); } + output.recomputeNonZeros(); - if(DMLScript.STATISTICS) - TransformStatistics.incOutMatrixPostProcessingTime(System.nanoTime()-t0); } @Override diff --git a/src/test/java/org/apache/sysds/performance/Main.java b/src/test/java/org/apache/sysds/performance/Main.java index 3fd2def237..4e8f566a30 100644 --- a/src/test/java/org/apache/sysds/performance/Main.java +++ b/src/test/java/org/apache/sysds/performance/Main.java @@ -23,8 +23,10 @@ import org.apache.sysds.performance.compression.IOBandwidth; import org.apache.sysds.performance.compression.SchemaTest; import org.apache.sysds.performance.compression.Serialize; import org.apache.sysds.performance.compression.StreamCompress; +import org.apache.sysds.performance.compression.TransformPerf; import org.apache.sysds.performance.generators.ConstMatrix; import org.apache.sysds.performance.generators.FrameFile; +import org.apache.sysds.performance.generators.FrameTransformFile; import org.apache.sysds.performance.generators.GenMatrices; import org.apache.sysds.performance.generators.IGenerate; import org.apache.sysds.performance.generators.MatrixFile; @@ -153,7 +155,7 @@ public class Main { String p = args[3]; // input frame String s = args[4]; // spec int id = Integer.parseInt(args[5]); - // run13A(n, FrameTransformFile.create(p, s), k, id); + run13A(n, FrameTransformFile.create(p, s), k, id); } private static void run13A(int n, IGenerate<MatrixBlock> g, int k, int id) throws Exception { @@ -171,7 +173,8 @@ public class Main { int n = Integer.parseInt(args[2]); IGenerate<FrameBlock> g = FrameFile.create(args[3]); String spec = args[4]; - // new TransformPerf(n, k, g, spec).run(); + new TransformPerf(n, k, g, spec).run(); + } private static void run16(String[] args) { @@ -180,7 +183,6 @@ public class Main { System.out.println(mb); } - public static void main(String[] args) { try { exec(Integer.parseInt(args[0]), args); diff --git a/src/test/java/org/apache/sysds/performance/compression/TransformPerf.java b/src/test/java/org/apache/sysds/performance/compression/TransformPerf.java new file mode 100644 index 0000000000..3edc164e63 --- /dev/null +++ b/src/test/java/org/apache/sysds/performance/compression/TransformPerf.java @@ -0,0 +1,130 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.performance.compression; + +import org.apache.sysds.performance.PerfUtil; +import org.apache.sysds.performance.compression.Serialize.InOut; +import org.apache.sysds.performance.generators.ConstFrame; +import org.apache.sysds.performance.generators.IGenerate; +import org.apache.sysds.runtime.compress.CompressedMatrixBlock; +import org.apache.sysds.runtime.frame.data.FrameBlock; +import org.apache.sysds.runtime.frame.data.lib.FrameLibApplySchema; +import org.apache.sysds.runtime.frame.data.lib.FrameLibDetectSchema; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.transform.encode.EncoderFactory; +import org.apache.sysds.runtime.transform.encode.MultiColumnEncoder; + +public class TransformPerf extends APerfTest<Serialize.InOut, FrameBlock> { + + private final String file; + private final String spec; + private final String specPath; + private final int k; + + public TransformPerf(int n, int k, IGenerate<FrameBlock> gen, String specPath) throws Exception { + super(n, gen); + this.file = "tmp/perf-tmp.bin"; + this.k = k; + this.spec = PerfUtil.readSpec(specPath); + this.specPath = specPath; + } + + public void run() throws Exception { + System.out.println(this); + CompressedMatrixBlock.debug = true; + + // execute(() -> detectSchema(k), "Detect Schema"); + // execute(() -> detectAndApply(k), "Detect&Apply Frame Schema"); + + updateGen(); + + // execute(() -> detectAndApply(k), "Detect&Apply Frame Schema Known"); + + // execute(() -> transformEncode(k), "TransformEncode Def"); + execute(() -> transformEncodeCompressed(k), "TransformEncode Comp"); + + } + + private void updateGen() { + if(gen instanceof ConstFrame) { + FrameBlock fb = gen.take(); + FrameBlock r = FrameLibDetectSchema.detectSchema(fb, k); + FrameBlock out = FrameLibApplySchema.applySchema(fb, r, k); + ((ConstFrame) gen).change(out); + } + } + + private void detectSchema(int k) { + FrameBlock fb = gen.take(); + long in = fb.getInMemorySize(); + FrameBlock r = FrameLibDetectSchema.detectSchema(fb, k); + long out = r.getInMemorySize(); + ret.add(new InOut(in, out)); + } + + private void detectAndApply(int k) { + FrameBlock fb = gen.take(); + long in = fb.getInMemorySize(); + FrameBlock r = FrameLibDetectSchema.detectSchema(fb, k); + FrameBlock out = FrameLibApplySchema.applySchema(fb, r, k); + long outS = out.getInMemorySize(); + ret.add(new InOut(in, outS)); + } + + private void transformEncode(int k) { + FrameBlock fb = gen.take(); + long in = fb.getInMemorySize(); + MultiColumnEncoder e = EncoderFactory.createEncoder(spec, fb.getNumColumns()); + MatrixBlock r = e.encode(fb, k); + long out = r.getInMemorySize(); + ret.add(new InOut(in, out)); + } + + private void transformEncodeCompressed(int k) { + FrameBlock fb = gen.take(); + long in = fb.getInMemorySize(); + MultiColumnEncoder e = EncoderFactory.createEncoder(spec, fb.getNumColumns()); + MatrixBlock r = e.encode(fb, k, true); + long out = r.getInMemorySize(); + ret.add(new InOut(in, out)); + } + + @Override + protected String makeResString() { + throw new RuntimeException("Do not call"); + } + + @Override + protected String makeResString(double[] times) { + return Serialize.makeResString(ret, times); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(super.toString()); + sb.append(" File: "); + sb.append(file); + sb.append(" Spec: "); + sb.append(specPath); + return sb.toString(); + } + +} diff --git a/src/test/java/org/apache/sysds/performance/generators/FrameTransformFile.java b/src/test/java/org/apache/sysds/performance/generators/FrameTransformFile.java new file mode 100644 index 0000000000..359cbd2381 --- /dev/null +++ b/src/test/java/org/apache/sysds/performance/generators/FrameTransformFile.java @@ -0,0 +1,83 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.performance.generators; + +import java.io.IOException; + +import org.apache.sysds.common.Types.FileFormat; +import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.performance.PerfUtil; +import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer; +import org.apache.sysds.runtime.frame.data.FrameBlock; +import org.apache.sysds.runtime.io.FileFormatPropertiesCSV; +import org.apache.sysds.runtime.io.FrameReader; +import org.apache.sysds.runtime.io.FrameReaderFactory; +import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.transform.encode.EncoderFactory; +import org.apache.sysds.runtime.transform.encode.MultiColumnEncoder; + +public class FrameTransformFile extends ConstMatrix { + + final private String path; + final private String specPath; + + private FrameTransformFile(String path, String specPath, MatrixBlock mb) throws IOException { + super(mb); + this.path = path; + this.specPath = specPath; + } + + // example: + // src/test/resources/datasets/titanic/tfspec.json + // src/test/resources/datasets/titanic/titanic.csv + public static FrameTransformFile create(String path, String specPath) throws IOException { + // read spec + final String spec = PerfUtil.readSpec(specPath); + + // MetaDataAll mba = new MetaDataAll(path + ".mtd", false, true); + // DataCharacteristics ds = mba.getDataCharacteristics(); + // FileFormat f = FileFormat.valueOf(mba.getFormatTypeString().toUpperCase()); + + FileFormatPropertiesCSV csvP = new FileFormatPropertiesCSV(); + csvP.setHeader(true); + FrameReader r = FrameReaderFactory.createFrameReader(FileFormat.CSV, csvP); + FrameBlock fb = r.readFrameFromHDFS(path, new ValueType[] {ValueType.STRING}, -1, -1); + + int k = InfrastructureAnalyzer.getLocalParallelism(); + FrameBlock sc = fb.detectSchema(k); + fb = fb.applySchema(sc, k); + MultiColumnEncoder encoder = EncoderFactory.createEncoder(spec, fb.getColumnNames(), fb.getNumColumns(), null); + MatrixBlock mb = encoder.encode(fb, k); + + return new FrameTransformFile(path, specPath, mb); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(this.getClass().getSimpleName()); + sb.append(" From file: "); + sb.append(path); + sb.append(" -- Transformed with: "); + sb.append(specPath); + return sb.toString(); + } + +} diff --git a/src/test/java/org/apache/sysds/test/component/frame/transform/TransformCompressedTestSingleColBinSpecific.java b/src/test/java/org/apache/sysds/test/component/frame/transform/TransformCompressedTestSingleColBinSpecific.java index fffe66ab95..fb622a216e 100644 --- a/src/test/java/org/apache/sysds/test/component/frame/transform/TransformCompressedTestSingleColBinSpecific.java +++ b/src/test/java/org/apache/sysds/test/component/frame/transform/TransformCompressedTestSingleColBinSpecific.java @@ -27,6 +27,7 @@ import java.util.Collection; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.apache.sysds.api.DMLScript; import org.apache.sysds.common.Types.ValueType; import org.apache.sysds.runtime.frame.data.FrameBlock; import org.apache.sysds.runtime.matrix.data.MatrixBlock; @@ -95,11 +96,46 @@ public class TransformCompressedTestSingleColBinSpecific { test(4); } + @Test + public void test30() { + test(30); + } + public void test(int bin) { - test("{ids:true, bin:[{id:1, method:equi-height, numbins:" + bin + "}], dummycode:[1] }"); + test("{ids:true, bin:[{id:1, method:equi-height, numbins:" + bin + "}], dummycode:[1] }", true); + } + + @Test + public void testAP1() { + testAP(1); + } + + @Test + public void testAP2() { + testAP(2); + } + + @Test + public void testAP3() { + testAP(3); + } + + @Test + public void testAP4() { + testAP(4); } - public void test(String spec) { + @Test + public void testAP30() { + testAP(30); + } + + public void testAP(int bin) { + DMLScript.SEED = 132; + test("{ids:true, bin:[{id:1, method:equi-height-approx, numbins:" + bin + "}], dummycode:[1] }", false); + } + + public void test(String spec, boolean EQ) { try { FrameBlock meta = null; @@ -112,14 +148,15 @@ public class TransformCompressedTestSingleColBinSpecific { data.getNumColumns(), meta); MatrixBlock outNormal = encoderNormal.encode(data, k); FrameBlock outNormalMD = encoderNormal.getMetaData(null); - TestUtils.compareMatrices(outNormal, outCompressed, 0, "Not Equal after apply"); TestUtils.compareFrames(outNormalMD, outCompressedMD, true); - // Assert that each bucket has the same number of elements - MatrixBlock colSum = outNormal.colSum(); - for(int i = 0; i < colSum.getNumColumns(); i++) - assertEquals(colSum.quickGetValue(0, 0), colSum.quickGetValue(0, i), 0.001); + if(EQ){ + // Assert that each bucket has the same number of elements + MatrixBlock colSum = outNormal.colSum(); + for(int i = 0; i < colSum.getNumColumns(); i++) + assertEquals(colSum.quickGetValue(0, 0), colSum.quickGetValue(0, i), 0.001); + } } catch(Exception e) {
