This is an automated email from the ASF dual-hosted git repository. baunsgaard pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/systemds.git
commit dadfbbbeda2eedc45a39ad2573fde22ba8ed6daa Author: baunsgaard <[email protected]> AuthorDate: Mon Jun 5 23:12:30 2023 +0200 [SYSTEMDS-3574] Recompress Compressed Matrix This commit adds support for recompressing an already compressed matrix by following the standard compression plan: statistics are extracted from the existing column groups, these statistics are used in the co-coding phase, and the resulting co-coding plan is applied to the compressed column groups via CLALibCombineGroups. The implementation is currently naive in the sense that it does not guarantee that the output column groups are of the expected and wanted type, since the combine algorithms are simplified to only return groups that are easy to combine, which does not allow a specific plan to be followed. Closes #1836 --- .../compress/CompressedMatrixBlockFactory.java | 92 +++++++++----------- .../runtime/compress/CompressionStatistics.java | 2 + .../sysds/runtime/compress/colgroup/AColGroup.java | 17 +++- .../compress/colgroup/AMorphingMMColGroup.java | 1 - .../sysds/runtime/compress/colgroup/ASDC.java | 2 +- .../sysds/runtime/compress/colgroup/ASDCZero.java | 2 +- .../runtime/compress/colgroup/ColGroupConst.java | 8 +- .../runtime/compress/colgroup/ColGroupDDC.java | 3 +- .../runtime/compress/colgroup/ColGroupDDCFOR.java | 5 +- .../runtime/compress/colgroup/ColGroupEmpty.java | 4 +- .../runtime/compress/colgroup/ColGroupFactory.java | 14 ++-- .../runtime/compress/colgroup/ColGroupSDCFOR.java | 2 +- .../compress/colgroup/ColGroupSDCSingle.java | 7 -- .../compress/colgroup/IFrameOfReferenceGroup.java | 33 ++++++++ .../colgroup/dictionary/DictionaryFactory.java | 84 +++++++++++++++---- .../compress/colgroup/indexes/AColIndex.java | 9 ++ .../compress/colgroup/indexes/IColIndex.java | 8 ++ .../sysds/runtime/compress/estim/AComEst.java | 1 - .../runtime/compress/estim/ComEstCompressed.java | 21 ++++- .../sysds/runtime/compress/estim/ComEstExact.java | 2 +- .../compress/estim/CompressedSizeInfoColGroup.java | 17 +++- 
.../compress/estim/encoding/ConstEncoding.java | 8 +- .../compress/estim/encoding/DenseEncoding.java | 54 +++++++----- .../compress/estim/encoding/EmptyEncoding.java | 12 ++- .../runtime/compress/estim/encoding/IEncode.java | 7 +- .../compress/estim/encoding/SparseEncoding.java | 17 ++-- .../runtime/compress/io/ReaderCompressed.java | 3 + .../runtime/compress/lib/CLALibCombineGroups.java | 97 +++++++++++++++++++--- .../sysds/runtime/compress/lib/CLALibStack.java | 23 +++-- .../sysds/runtime/compress/lib/CLALibUtils.java | 26 +++++- .../component/compress/CompressedLoggingTests.java | 18 ++-- 31 files changed, 438 insertions(+), 161 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlockFactory.java b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlockFactory.java index 8a6478f8c9..5074a695df 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlockFactory.java +++ b/src/main/java/org/apache/sysds/runtime/compress/CompressedMatrixBlockFactory.java @@ -267,10 +267,13 @@ public class CompressedMatrixBlockFactory { private Pair<MatrixBlock, CompressionStatistics> compressMatrix() { if(mb.getNonZeros() < 0) throw new DMLCompressionException("Invalid to compress matrices with unknown nonZeros"); - else if(mb instanceof CompressedMatrixBlock) // Redundant compression - return recompress((CompressedMatrixBlock) mb); + else if(mb instanceof CompressedMatrixBlock && ((CompressedMatrixBlock) mb).isOverlapping()) { + LOG.warn("Unsupported recompression of overlapping compression"); + return new ImmutablePair<>(mb, null); + } _stats.denseSize = MatrixBlock.estimateSizeInMemory(mb.getNumRows(), mb.getNumColumns(), 1.0); + _stats.sparseSize = MatrixBlock.estimateSizeSparseInMemory(mb.getNumRows(), mb.getNumColumns(), mb.getSparsity()); _stats.originalSize = mb.getInMemorySize(); _stats.originalCost = costEstimator.getCost(mb); @@ -427,11 +430,15 @@ public class CompressedMatrixBlockFactory { final 
double ratio = _stats.getRatio(); final double denseRatio = _stats.getDenseRatio(); + _stats.setColGroupsCounts(res.getColGroups()); if(ratio < 1 && denseRatio < 100.0) { LOG.info("--dense size: " + _stats.denseSize); LOG.info("--original size: " + _stats.originalSize); LOG.info("--compressed size: " + _stats.compressedSize); LOG.info("--compression ratio: " + ratio); + LOG.debug("--col groups types " + _stats.getGroupsTypesString()); + LOG.debug("--col groups sizes " + _stats.getGroupsSizesString()); + logLengths(); LOG.info("Abort block compression because compression ratio is less than 1."); res = null; setNextTimePhase(time.stop()); @@ -439,8 +446,6 @@ public class CompressedMatrixBlockFactory { return; } - _stats.setColGroupsCounts(res.getColGroups()); - if(compSettings.isInSparkInstruction) res.clearSoftReferenceToDecompressed(); @@ -454,30 +459,6 @@ public class CompressedMatrixBlockFactory { return new ImmutablePair<>(mb, _stats); } - private Pair<MatrixBlock, CompressionStatistics> recompress(CompressedMatrixBlock cmb) { - LOG.debug("Recompressing an already compressed MatrixBlock"); - LOG.warn("Not Implemented Recompress yet"); - - classifyPhase(); - // informationExtractor = ComEstFactory.createEstimator(mb, compSettings, k); - - // compressionGroups = informationExtractor.computeCompressedSizeInfos(k); - - // _stats.estimatedSizeCols = compressionGroups.memoryEstimate(); - // _stats.estimatedCostCols = costEstimator.getCost(compressionGroups); - - // logPhase(); - - - - return new ImmutablePair<>(cmb, null); - // _stats.originalSize = cmb.getInMemorySize(); - // CompressedMatrixBlock combined = CLALibCombineGroups.combine(cmb, k); - // CompressedMatrixBlock squashed = CLALibSquash.squash(combined, k); - // _stats.compressedSize = squashed.getInMemorySize(); - // return new ImmutablePair<>(squashed, _stats); - } - private void logPhase() { setNextTimePhase(time.stop()); DMLCompressionStatistics.addCompressionTime(getLastTimePhase(), phase); @@ -492,6 
+473,9 @@ public class CompressedMatrixBlockFactory { LOG.debug("--Seed used for comp : " + compSettings.seed); LOG.debug("--compression phase " + phase + " Classify : " + getLastTimePhase()); LOG.debug("--Individual Columns Estimated Compression: " + _stats.estimatedSizeCols); + if(mb instanceof CompressedMatrixBlock) { + LOG.debug("--Recompressing already compressed MatrixBlock"); + } break; case 1: LOG.debug("--compression phase " + phase + " Grouping : " + getLastTimePhase()); @@ -521,7 +505,9 @@ public class CompressedMatrixBlockFactory { LOG.debug("--compression phase " + phase + " Cleanup : " + getLastTimePhase()); LOG.debug("--col groups types " + _stats.getGroupsTypesString()); LOG.debug("--col groups sizes " + _stats.getGroupsSizesString()); + LOG.debug("--input was compressed " + (mb instanceof CompressedMatrixBlock)); LOG.debug(String.format("--dense size: %16d", _stats.denseSize)); + LOG.debug(String.format("--sparse size: %16d", _stats.sparseSize)); LOG.debug(String.format("--original size: %16d", _stats.originalSize)); LOG.debug(String.format("--compressed size: %16d", _stats.compressedSize)); LOG.debug(String.format("--compression ratio: %4.3f", _stats.getRatio())); @@ -534,35 +520,39 @@ public class CompressedMatrixBlockFactory { LOG.debug( String.format("--relative cost: %1.4f", (_stats.compressedCost / _stats.originalCost))); } - if(compressionGroups != null && compressionGroups.getInfo().size() < 1000) { - int[] lengths = new int[res.getColGroups().size()]; - int i = 0; - for(AColGroup colGroup : res.getColGroups()) - lengths[i++] = colGroup.getNumValues(); - - LOG.debug("--compressed colGroup dictionary sizes: " + Arrays.toString(lengths)); - LOG.debug( - "--compressed colGroup nr columns : " + constructNrColumnString(res.getColGroups())); - } - if(LOG.isTraceEnabled()) { - for(AColGroup colGroup : res.getColGroups()) { - if(colGroup.estimateInMemorySize() < 1000) - LOG.trace(colGroup); - else { - LOG.trace("--colGroups type : " + 
colGroup.getClass().getSimpleName() + " size: " - + colGroup.estimateInMemorySize() - + ((colGroup instanceof AColGroupValue) ? " numValues :" - + ((AColGroupValue) colGroup).getNumValues() : "") - + " colIndexes : " + colGroup.getColIndices()); - } - } - } + logLengths(); } } } phase++; } + private void logLengths() { + if(compressionGroups != null && compressionGroups.getInfo().size() < 1000) { + int[] lengths = new int[res.getColGroups().size()]; + int i = 0; + for(AColGroup colGroup : res.getColGroups()) + lengths[i++] = colGroup.getNumValues(); + + LOG.debug("--compressed colGroup dictionary sizes: " + Arrays.toString(lengths)); + LOG.debug("--compressed colGroup nr columns : " + constructNrColumnString(res.getColGroups())); + } + if(LOG.isTraceEnabled()) { + for(AColGroup colGroup : res.getColGroups()) { + if(colGroup.estimateInMemorySize() < 1000) + LOG.trace(colGroup); + else { + LOG.trace( + "--colGroups type : " + colGroup.getClass().getSimpleName() + " size: " + + colGroup.estimateInMemorySize() + + ((colGroup instanceof AColGroupValue) ? 
" numValues :" + + ((AColGroupValue) colGroup).getNumValues() : "") + + " colIndexes : " + colGroup.getColIndices()); + } + } + } + } + private void setNextTimePhase(double time) { lastPhase = time; } diff --git a/src/main/java/org/apache/sysds/runtime/compress/CompressionStatistics.java b/src/main/java/org/apache/sysds/runtime/compress/CompressionStatistics.java index e9abe835b4..d54eb2c352 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/CompressionStatistics.java +++ b/src/main/java/org/apache/sysds/runtime/compress/CompressionStatistics.java @@ -35,6 +35,8 @@ public class CompressionStatistics { public long originalSize; /** Size if the input is dense */ public long denseSize; + /** Size if the input is sparse */ + public long sparseSize; /** Estimated size of compressing individual columns */ public long estimatedSizeCols; /** Estimated size of compressing after co-coding */ diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroup.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroup.java index b3356db9da..e85b220d4b 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroup.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroup.java @@ -36,6 +36,7 @@ import org.apache.sysds.runtime.compress.estim.encoding.IEncode; import org.apache.sysds.runtime.compress.lib.CLALibCombineGroups; import org.apache.sysds.runtime.data.DenseBlock; import org.apache.sysds.runtime.data.SparseBlock; +import org.apache.sysds.runtime.functionobjects.Plus; import org.apache.sysds.runtime.instructions.cp.CM_COV_Object; import org.apache.sysds.runtime.matrix.data.MatrixBlock; import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator; @@ -418,17 +419,29 @@ public abstract class AColGroup implements Serializable { * Perform a binary row operation. * * @param op The operation to execute - * @param v The vector of values to apply, should be same length as dictionary length. 
+ * @param v The vector of values to apply the values contained should be at least the length of the highest + * value in the column index * @param isRowSafe True if the binary op is applied to an entire zero row and all results are zero * @return A updated column group with the new values. */ public abstract AColGroup binaryRowOpLeft(BinaryOperator op, double[] v, boolean isRowSafe); + /** + * Short hand add operator call on column group to add a row vector to the column group + * + * @param v The vector to add + * @return A new column group where the vector is added. + */ + public AColGroup addVector(double[] v) { + return binaryRowOpRight(new BinaryOperator(Plus.getPlusFnObject(), 1), v, false); + } + /** * Perform a binary row operation. * * @param op The operation to execute - * @param v The vector of values to apply, should be same length as dictionary length. + * @param v The vector of values to apply the values contained should be at least the length of the highest + * value in the column index * @param isRowSafe True if the binary op is applied to an entire zero row and all results are zero * @return A updated column group with the new values. 
*/ diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/AMorphingMMColGroup.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/AMorphingMMColGroup.java index 7f18843316..bb85e786e2 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/AMorphingMMColGroup.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/AMorphingMMColGroup.java @@ -197,7 +197,6 @@ public abstract class AMorphingMMColGroup extends AColGroupValue { protected abstract AColGroup allocateRightMultiplicationCommon(double[] common, IColIndex colIndexes, ADictionary preAgg); - /** extract common value from group and return non morphing group */ /** * extract common value from group and return non morphing group * diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ASDC.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ASDC.java index 964dc083ca..96c3dda02d 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ASDC.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ASDC.java @@ -60,6 +60,6 @@ public abstract class ASDC extends AMorphingMMColGroup implements AOffsetsGroup @Override public final CompressedSizeInfoColGroup getCompressionInfo(int nRow) { EstimationFactors ef = new EstimationFactors(getNumValues(), _numRows, getNumberOffsets(), _dict.getSparsity()); - return new CompressedSizeInfoColGroup(_colIndexes, ef, nRow, getCompType()); + return new CompressedSizeInfoColGroup(_colIndexes, ef, nRow, getCompType(),getEncoding()); } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ASDCZero.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ASDCZero.java index 4ce3946c28..041458621d 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ASDCZero.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ASDCZero.java @@ -228,6 +228,6 @@ public abstract class ASDCZero extends APreAgg implements AOffsetsGroup, IContai 
@Override public final CompressedSizeInfoColGroup getCompressionInfo(int nRow) { EstimationFactors ef = new EstimationFactors(getNumValues(), _numRows, getNumberOffsets(), _dict.getSparsity()); - return new CompressedSizeInfoColGroup(_colIndexes, ef, nRow, getCompType()); + return new CompressedSizeInfoColGroup(_colIndexes, ef, nRow, getCompType(),getEncoding()); } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupConst.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupConst.java index 87044290db..e3454886d6 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupConst.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupConst.java @@ -34,6 +34,7 @@ import org.apache.sysds.runtime.compress.colgroup.scheme.ConstScheme; import org.apache.sysds.runtime.compress.colgroup.scheme.ICLAScheme; import org.apache.sysds.runtime.compress.cost.ComputationCostEstimator; import org.apache.sysds.runtime.compress.estim.CompressedSizeInfoColGroup; +import org.apache.sysds.runtime.compress.estim.EstimationFactors; import org.apache.sysds.runtime.compress.estim.encoding.EncodingFactory; import org.apache.sysds.runtime.compress.estim.encoding.IEncode; import org.apache.sysds.runtime.compress.lib.CLALibLeftMultBy; @@ -576,7 +577,8 @@ public class ColGroupConst extends ADictBasedColGroup implements IContainDefault @Override public CompressedSizeInfoColGroup getCompressionInfo(int nRow) { - return new CompressedSizeInfoColGroup(_colIndexes, 1, nRow, CompressionType.CONST); + EstimationFactors ef = new EstimationFactors(1, 1, 1, _dict.getSparsity()); + return new CompressedSizeInfoColGroup(_colIndexes, ef, estimateInMemorySize(), CompressionType.CONST, getEncoding()); } @Override @@ -594,8 +596,8 @@ public class ColGroupConst extends ADictBasedColGroup implements IContainDefault return _dict.getValues(); } - @Override - protected AColGroup fixColIndexes(IColIndex newColIndex, int[] reordering){ 
+ @Override + protected AColGroup fixColIndexes(IColIndex newColIndex, int[] reordering) { return ColGroupConst.create(newColIndex, _dict.reorder(reordering)); } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java index f88579000b..78a0bdba51 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java @@ -555,8 +555,9 @@ public class ColGroupDDC extends APreAgg implements IMapToDataGroup { @Override public CompressedSizeInfoColGroup getCompressionInfo(int nRow) { + IEncode enc = getEncoding(); EstimationFactors ef = new EstimationFactors(getNumValues(), _data.size(), _data.size(), _dict.getSparsity()); - return new CompressedSizeInfoColGroup(_colIndexes, ef, nRow, getCompType()); + return new CompressedSizeInfoColGroup(_colIndexes, ef, estimateInMemorySize(), getCompType(), enc); } @Override diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDCFOR.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDCFOR.java index 39399b5854..c1f99d0202 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDCFOR.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDCFOR.java @@ -54,7 +54,7 @@ import org.apache.sysds.runtime.matrix.operators.UnaryOperator; /** * Class to encapsulate information about a column group that is encoded with dense dictionary encoding (DDC). 
*/ -public class ColGroupDDCFOR extends AMorphingMMColGroup { +public class ColGroupDDCFOR extends AMorphingMMColGroup implements IFrameOfReferenceGroup { private static final long serialVersionUID = -5769772089913918987L; /** Pointers to row indexes in the dictionary */ @@ -469,8 +469,9 @@ public class ColGroupDDCFOR extends AMorphingMMColGroup { @Override public CompressedSizeInfoColGroup getCompressionInfo(int nRow) { + IEncode enc = getEncoding(); EstimationFactors ef = new EstimationFactors(getNumValues(), _data.size(), _data.size(), _dict.getSparsity()); - return new CompressedSizeInfoColGroup(_colIndexes, ef, nRow, getCompType()); + return new CompressedSizeInfoColGroup(_colIndexes, ef, estimateInMemorySize(), getCompType(), enc); } @Override diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupEmpty.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupEmpty.java index 1e1b847782..c908b267e3 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupEmpty.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupEmpty.java @@ -33,6 +33,7 @@ import org.apache.sysds.runtime.compress.colgroup.scheme.EmptyScheme; import org.apache.sysds.runtime.compress.colgroup.scheme.ICLAScheme; import org.apache.sysds.runtime.compress.cost.ComputationCostEstimator; import org.apache.sysds.runtime.compress.estim.CompressedSizeInfoColGroup; +import org.apache.sysds.runtime.compress.estim.EstimationFactors; import org.apache.sysds.runtime.compress.estim.encoding.EncodingFactory; import org.apache.sysds.runtime.compress.estim.encoding.IEncode; import org.apache.sysds.runtime.data.DenseBlock; @@ -350,7 +351,8 @@ public class ColGroupEmpty extends AColGroupCompressed implements IContainADicti @Override public CompressedSizeInfoColGroup getCompressionInfo(int nRow) { - return new CompressedSizeInfoColGroup(_colIndexes, 0, nRow, CompressionType.CONST); + EstimationFactors ef = new 
EstimationFactors(getNumValues(), 1, 0, 0.0); + return new CompressedSizeInfoColGroup(_colIndexes, ef, estimateInMemorySize(), CompressionType.CONST, getEncoding()); } @Override diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupFactory.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupFactory.java index bf4900e1fb..3358517b38 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupFactory.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupFactory.java @@ -32,6 +32,7 @@ import java.util.concurrent.Future; import org.apache.commons.lang3.NotImplementedException; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.apache.sysds.runtime.compress.CompressedMatrixBlock; import org.apache.sysds.runtime.compress.CompressionSettings; import org.apache.sysds.runtime.compress.DMLCompressionException; import org.apache.sysds.runtime.compress.bitmap.ABitmap; @@ -53,6 +54,7 @@ import org.apache.sysds.runtime.compress.colgroup.offset.OffsetFactory; import org.apache.sysds.runtime.compress.cost.ACostEstimate; import org.apache.sysds.runtime.compress.estim.CompressedSizeInfo; import org.apache.sysds.runtime.compress.estim.CompressedSizeInfoColGroup; +import org.apache.sysds.runtime.compress.lib.CLALibCombineGroups; import org.apache.sysds.runtime.compress.readers.ReaderColumnSelection; import org.apache.sysds.runtime.compress.utils.ACount.DCounts; import org.apache.sysds.runtime.compress.utils.DblArray; @@ -149,15 +151,17 @@ public class ColGroupFactory { private List<AColGroup> compress() { try { - List<AColGroup> ret = compressExecute(); - if(pool != null) - pool.shutdown(); - return ret; + if(in instanceof CompressedMatrixBlock) + return CLALibCombineGroups.combine((CompressedMatrixBlock) in, csi, pool); + else + return compressExecute(); } catch(Exception e) { + throw new DMLCompressionException("Compression Failed", e); + } + finally { 
if(pool != null) pool.shutdown(); - throw new DMLCompressionException("Compression Failed", e); } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCFOR.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCFOR.java index d05930dd2c..d2e7b549d8 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCFOR.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCFOR.java @@ -62,7 +62,7 @@ import org.apache.sysds.runtime.matrix.operators.UnaryOperator; * with no modifications. * */ -public class ColGroupSDCFOR extends ASDC implements IMapToDataGroup { +public class ColGroupSDCFOR extends ASDC implements IMapToDataGroup , IFrameOfReferenceGroup{ private static final long serialVersionUID = 3883228464052204203L; diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java index f4c11f2d8c..7f43df5f8f 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java @@ -24,7 +24,6 @@ import java.io.DataOutput; import java.io.IOException; import java.util.Arrays; -import org.apache.commons.lang.NotImplementedException; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.compress.colgroup.dictionary.ADictionary; import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary; @@ -116,12 +115,6 @@ public class ColGroupSDCSingle extends ASDC { return _dict.getValue(colIdx); } - @Override - public ADictionary getDictionary() { - throw new NotImplementedException( - "Not implemented getting the dictionary out, and i think we should consider removing the option"); - } - @Override protected double[] preAggSumRows() { return _dict.sumAllRowsToDoubleWithDefault(_defaultTuple); diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/IFrameOfReferenceGroup.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/IFrameOfReferenceGroup.java new file mode 100644 index 0000000000..dc55b56364 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/IFrameOfReferenceGroup.java @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.compress.colgroup; + +/** + * Interface for frame of reference groups. + */ +public interface IFrameOfReferenceGroup { + /** + * extract common value from group and return non morphing group + * + * @param constV a vector to contain all values, note length = nCols in total matrix. + * @return A non morphing column group with decompression instructions. 
+ */ + public AColGroup extractCommon(double[] constV); +} diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DictionaryFactory.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DictionaryFactory.java index da5c317874..37ed289862 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DictionaryFactory.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DictionaryFactory.java @@ -22,6 +22,7 @@ package org.apache.sysds.runtime.compress.colgroup.dictionary; import java.io.DataInput; import java.io.IOException; import java.util.ArrayList; +import java.util.Map; import org.apache.commons.lang.NotImplementedException; import org.apache.commons.logging.Log; @@ -238,6 +239,11 @@ public interface DictionaryFactory { } public static ADictionary combineDictionaries(AColGroupCompressed a, AColGroupCompressed b) { + return combineDictionaries(a, b, null); + } + + public static ADictionary combineDictionaries(AColGroupCompressed a, AColGroupCompressed b, + Map<Integer, Integer> filter) { if(a instanceof ColGroupEmpty && b instanceof ColGroupEmpty) return null; // null return is handled elsewhere. 
@@ -266,23 +272,31 @@ public interface DictionaryFactory { return combineSparseConstSparseRet(ad, a.getNumCols(), bt); } else if(bc.isDense()) - return combineFullDictionaries(ad, a.getNumCols(), bd, b.getNumCols()); + return combineFullDictionaries(ad, a.getNumCols(), bd, b.getNumCols(), filter); else if(bc.isSDC()) { double[] tuple = ((IContainDefaultTuple) b).getDefaultTuple(); - return combineSDCRight(ad, a.getNumCols(), bd, tuple); + return combineSDCRight(ad, a.getNumCols(), bd, tuple, filter); } } else if(ac.isSDC()) { if(bc.isSDC()) { final double[] at = ((IContainDefaultTuple) a).getDefaultTuple(); final double[] bt = ((IContainDefaultTuple) b).getDefaultTuple(); - return combineSDC(ad, at, bd, bt); + return combineSDC(ad, at, bd, bt, filter); } } } throw new NotImplementedException("Not supporting combining dense: " + a + " " + b); } + /** + * Combine the dictionaries assuming a sparse combination where each dictionary can be a SDC containing a default + * element that have to be introduced into the combined dictionary. + * + * @param a A Dictionary can be SDC or const + * @param b A Dictionary can be Const or SDC. + * @return The combined dictionary + */ public static ADictionary combineDictionariesSparse(AColGroupCompressed a, AColGroupCompressed b) { CompressionType ac = a.getCompType(); CompressionType bc = b.getCompType(); @@ -298,9 +312,7 @@ public interface DictionaryFactory { if(a.sameIndexStructure(b)) { return ad.cbind(bd, b.getNumCols()); } - // real combine extract default and combine like dense but with default before. - } } else if(ac.isConst()) { @@ -315,7 +327,7 @@ public interface DictionaryFactory { } /** - * Combine the dictionaries as if the dictionaries contain the full spectrum of the data contained. + * Combine the dictionaries as if the dictionaries contain the full spectrum of the combined data. 
* * @param a Left side dictionary * @param nca Number of columns left dictionary @@ -324,6 +336,22 @@ public interface DictionaryFactory { * @return A combined dictionary */ public static ADictionary combineFullDictionaries(ADictionary a, int nca, ADictionary b, int ncb) { + return combineFullDictionaries(a, nca, b, ncb, null); + } + + /** + * Combine the dictionaries as if the dictionaries only contain the values in the specified filter. + * + * @param a Left side dictionary + * @param nca Number of columns left dictionary + * @param b Right side dictionary + * @param ncb Number of columns right dictionary + * @param filter The mapping filter to not include all possible combinations in the output, this filter is allowed to + * be null, that means the output is defaulting back to a full combine + * @return A combined dictionary + */ + public static ADictionary combineFullDictionaries(ADictionary a, int nca, ADictionary b, int ncb, + Map<Integer, Integer> filter) { final int ra = a.getNumberOfValues(nca); final int rb = b.getNumberOfValues(ncb); @@ -333,24 +361,45 @@ public interface DictionaryFactory { if(ra == 1 && rb == 1) return new MatrixBlockDictionary(ma.append(mb)); - MatrixBlock out = new MatrixBlock(ra * rb, nca + ncb, false); + MatrixBlock out = new MatrixBlock(filter != null ? 
filter.size() : ra * rb, nca + ncb, false); out.allocateBlock(); - for(int r = 0; r < out.getNumRows(); r++) { - int ia = r % ra; - int ib = r / ra; - for(int c = 0; c < nca; c++) - out.quickSetValue(r, c, ma.quickGetValue(ia, c)); + if(filter != null) { + for(int r : filter.keySet()) { + int o = filter.get(r); + int ia = r % ra; + int ib = r / ra; + for(int c = 0; c < nca; c++) + out.quickSetValue(o, c, ma.quickGetValue(ia, c)); - for(int c = 0; c < ncb; c++) - out.quickSetValue(r, c + nca, mb.quickGetValue(ib, c)); + for(int c = 0; c < ncb; c++) + out.quickSetValue(o, c + nca, mb.quickGetValue(ib, c)); + + } + } + else { + + for(int r = 0; r < out.getNumRows(); r++) { + int ia = r % ra; + int ib = r / ra; + for(int c = 0; c < nca; c++) + out.quickSetValue(r, c, ma.quickGetValue(ia, c)); + for(int c = 0; c < ncb; c++) + out.quickSetValue(r, c + nca, mb.quickGetValue(ib, c)); + + } } return new MatrixBlockDictionary(out); } public static ADictionary combineSDCRight(ADictionary a, int nca, ADictionary b, double[] tub) { + return combineSDCRight(a, nca, b, tub, null); + } + + public static ADictionary combineSDCRight(ADictionary a, int nca, ADictionary b, double[] tub, + Map<Integer, Integer> filter) { final int ncb = tub.length; final int ra = a.getNumberOfValues(nca); final int rb = b.getNumberOfValues(ncb); @@ -384,6 +433,13 @@ public interface DictionaryFactory { } public static ADictionary combineSDC(ADictionary a, double[] tua, ADictionary b, double[] tub) { + return combineSDC(a, tua, b, tub, null); + } + + public static ADictionary combineSDC(ADictionary a, double[] tua, ADictionary b, double[] tub, + Map<Integer, Integer> filter) { + if(filter != null) + throw new NotImplementedException(); final int nca = tua.length; final int ncb = tub.length; final int ra = a.getNumberOfValues(nca); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/AColIndex.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/AColIndex.java 
index f8eb0bfedb..cf22ba0d7b 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/AColIndex.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/AColIndex.java @@ -66,4 +66,13 @@ public abstract class AColIndex implements IColIndex { res = 31 * res + it.next(); return res; } + + @Override + public boolean containsAny(IColIndex idx) { + IIterate it = idx.iterator(); + while(it.hasNext()) + if(contains(it.next())) + return true; + return false; + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/IColIndex.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/IColIndex.java index 8b73abfa0d..5163998ef8 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/IColIndex.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/IColIndex.java @@ -178,6 +178,14 @@ public interface IColIndex { */ public boolean contains(int i); + /** + * Analyze if this column group contain any of the given column Ids. + * + * @param idx A List of indexes + * @return If it is contained + */ + public boolean containsAny(IColIndex idx); + /** A Class for slice results containing indexes for the slicing of dictionaries, and the resulting column index */ public static class SliceResult { /** Start index to slice inside the dictionary */ diff --git a/src/main/java/org/apache/sysds/runtime/compress/estim/AComEst.java b/src/main/java/org/apache/sysds/runtime/compress/estim/AComEst.java index 145682f3f2..cf55e6180d 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/estim/AComEst.java +++ b/src/main/java/org/apache/sysds/runtime/compress/estim/AComEst.java @@ -184,7 +184,6 @@ public abstract class AComEst { final int worstCase = worstCaseUpperBound(combinedColumns); // Get max number of tuples based on the above. 
final long max = Math.min((long) g1V * g2V, worstCase); - if(max > 1000000) // set the max combination to a million distinct return null; // This combination is clearly not a good idea return null to indicate that. else if(g1.getMap() == null || g2.getMap() == null) diff --git a/src/main/java/org/apache/sysds/runtime/compress/estim/ComEstCompressed.java b/src/main/java/org/apache/sysds/runtime/compress/estim/ComEstCompressed.java index 7f8ccfef73..5a1a1332d3 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/estim/ComEstCompressed.java +++ b/src/main/java/org/apache/sysds/runtime/compress/estim/ComEstCompressed.java @@ -22,11 +22,12 @@ package org.apache.sysds.runtime.compress.estim; import java.util.ArrayList; import java.util.List; -import org.apache.commons.lang.NotImplementedException; import org.apache.sysds.runtime.compress.CompressedMatrixBlock; import org.apache.sysds.runtime.compress.CompressionSettings; import org.apache.sysds.runtime.compress.colgroup.AColGroup; import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; +import org.apache.sysds.runtime.compress.estim.encoding.IEncode; +import org.apache.sysds.runtime.compress.lib.CLALibCombineGroups; public class ComEstCompressed extends AComEst { @@ -66,14 +67,26 @@ public class ComEstCompressed extends AComEst { AColGroup g = cData.getColGroupForColumn(id); return g.getNumValues(); } - else - throw new UnsupportedOperationException("Unimplemented method 'worstCaseUpperBound'"); + else { + List<AColGroup> groups = CLALibCombineGroups.findGroupsInIndex(columns, cData.getColGroups()); + int nVals = 1; + for(AColGroup g : groups) + nVals *= g.getNumValues(); + + return Math.min(_data.getNumRows(), nVals); + } } @Override protected CompressedSizeInfoColGroup combine(IColIndex combinedColumns, CompressedSizeInfoColGroup g1, CompressedSizeInfoColGroup g2, int maxDistinct) { - throw new UnsupportedOperationException("Unimplemented method 'combine'"); + final IEncode map = 
g1.getMap().combine(g2.getMap()); + return getFacts(map, combinedColumns); } + protected CompressedSizeInfoColGroup getFacts(IEncode map, IColIndex colIndexes) { + final int _numRows = getNumRows(); + final EstimationFactors em = map.extractFacts(_numRows, _data.getSparsity(), _data.getSparsity(), _cs); + return new CompressedSizeInfoColGroup(colIndexes, em, _cs.validCompressions, map); + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/estim/ComEstExact.java b/src/main/java/org/apache/sysds/runtime/compress/estim/ComEstExact.java index e9ac7a9708..63af720223 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/estim/ComEstExact.java +++ b/src/main/java/org/apache/sysds/runtime/compress/estim/ComEstExact.java @@ -56,7 +56,7 @@ public class ComEstExact extends AComEst { return getFacts(map, combinedColumns); } - private CompressedSizeInfoColGroup getFacts(IEncode map, IColIndex colIndexes) { + protected CompressedSizeInfoColGroup getFacts(IEncode map, IColIndex colIndexes) { final int _numRows = getNumRows(); final EstimationFactors em = map.extractFacts(_numRows, _data.getSparsity(), _data.getSparsity(), _cs); return new CompressedSizeInfoColGroup(colIndexes, em, _cs.validCompressions, map); diff --git a/src/main/java/org/apache/sysds/runtime/compress/estim/CompressedSizeInfoColGroup.java b/src/main/java/org/apache/sysds/runtime/compress/estim/CompressedSizeInfoColGroup.java index 8de7360e62..b9a1f2482a 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/estim/CompressedSizeInfoColGroup.java +++ b/src/main/java/org/apache/sysds/runtime/compress/estim/CompressedSizeInfoColGroup.java @@ -79,6 +79,16 @@ public class CompressedSizeInfoColGroup { _sizes.put(bestCompressionType, _minSize); } + public CompressedSizeInfoColGroup(IColIndex columns, EstimationFactors facts, long minSize, CompressionType bestCompression, IEncode map){ + _cols = columns; + _facts = facts; + _minSize = minSize; + _bestCompressionType = bestCompression; + _sizes 
= new EnumMap<>(CompressionType.class); + _sizes.put(bestCompression, _minSize); + _map = map; + } + public CompressedSizeInfoColGroup(IColIndex columns, EstimationFactors facts, Set<CompressionType> validCompressionTypes, IEncode map) { _cols = columns; @@ -260,12 +270,11 @@ public class CompressedSizeInfoColGroup { public String toString() { StringBuilder sb = new StringBuilder(); sb.append(this.getClass().getSimpleName()); - sb.append("cols: " + _cols); + sb.append(" cols: " + _cols); sb.append(String.format(" common: %4.3f", getMostCommonFraction())); - sb.append(" Sizes: "); - sb.append(_sizes); + sb.append(" Sizes: " + _sizes); sb.append(" facts: " + _facts); - // sb.append("\n" + _map); + sb.append(" mapIsNull: " + (_map == null)); return sb.toString(); } diff --git a/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/ConstEncoding.java b/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/ConstEncoding.java index b3839c9e95..1a120772b8 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/ConstEncoding.java +++ b/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/ConstEncoding.java @@ -19,6 +19,10 @@ package org.apache.sysds.runtime.compress.estim.encoding; +import java.util.Map; + +import org.apache.commons.lang3.tuple.ImmutablePair; +import org.apache.commons.lang3.tuple.Pair; import org.apache.sysds.runtime.compress.CompressionSettings; import org.apache.sysds.runtime.compress.estim.EstimationFactors; @@ -37,8 +41,8 @@ public class ConstEncoding implements IEncode { } @Override - public IEncode combineNoResize(IEncode e){ - return e; + public Pair<IEncode, Map<Integer, Integer>> combineWithMap(IEncode e) { + return new ImmutablePair<>(e, null); } @Override diff --git a/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/DenseEncoding.java b/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/DenseEncoding.java index 5504e1c22d..f68dd3d674 100644 --- 
a/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/DenseEncoding.java +++ b/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/DenseEncoding.java @@ -22,12 +22,17 @@ package org.apache.sysds.runtime.compress.estim.encoding; import java.util.HashMap; import java.util.Map; +import org.apache.commons.lang3.tuple.ImmutablePair; +import org.apache.commons.lang3.tuple.Pair; import org.apache.sysds.runtime.compress.CompressionSettings; import org.apache.sysds.runtime.compress.colgroup.mapping.AMapToData; import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory; import org.apache.sysds.runtime.compress.colgroup.offset.AIterator; import org.apache.sysds.runtime.compress.estim.EstimationFactors; +/** + * An Encoding that contains a value on each row of the input. + */ public class DenseEncoding implements IEncode { private final AMapToData map; @@ -38,7 +43,7 @@ public class DenseEncoding implements IEncode { } @Override - public DenseEncoding combine(IEncode e) { + public IEncode combine(IEncode e) { if(e instanceof EmptyEncoding || e instanceof ConstEncoding) return this; else if(e instanceof SparseEncoding) @@ -48,16 +53,16 @@ public class DenseEncoding implements IEncode { } @Override - public IEncode combineNoResize(IEncode e) { + public Pair<IEncode, Map<Integer, Integer>> combineWithMap(IEncode e) { if(e instanceof EmptyEncoding || e instanceof ConstEncoding) - return this; + return new ImmutablePair<>(this, null); else if(e instanceof SparseEncoding) return combineSparseNoResize((SparseEncoding) e); else return combineDenseNoResize((DenseEncoding) e); } - protected DenseEncoding combineSparse(SparseEncoding e) { + protected IEncode combineSparse(SparseEncoding e) { final int maxUnique = e.getUnique() * getUnique(); final int size = map.size(); final int nVl = getUnique(); @@ -66,7 +71,7 @@ public class DenseEncoding implements IEncode { final AMapToData ret = assignSparse(e); // Iteration 2 reassign indexes. 
if(maxUnique + nVl > size) - return combineSparseHashMap(ret); + return combineSparseHashMap(ret).getLeft(); else return combineSparseMapToData(ret, maxUnique, nVl); } @@ -92,7 +97,7 @@ public class DenseEncoding implements IEncode { return ret; } - private final DenseEncoding combineSparseHashMap(final AMapToData ret) { + private final Pair<IEncode, Map<Integer, Integer>> combineSparseHashMap(final AMapToData ret) { final int size = ret.size(); final Map<Integer, Integer> m = new HashMap<>(size); for(int r = 0; r < size; r++) { @@ -104,7 +109,7 @@ public class DenseEncoding implements IEncode { else ret.set(r, mv); } - return new DenseEncoding(MapToFactory.resize(ret, m.size())); + return new ImmutablePair<>(new DenseEncoding(MapToFactory.resize(ret, m.size())), m); } private final DenseEncoding combineSparseMapToData(final AMapToData ret, final int maxUnique, final int nVl) { @@ -136,15 +141,20 @@ public class DenseEncoding implements IEncode { final AMapToData ret = MapToFactory.create(size, maxUnique); - if(maxUnique > size) - return combineDenseWithHashMap(lm, rm, size, nVL, ret); - else - return combineDenseWithMapToData(lm, rm, size, nVL, ret, maxUnique); + if(maxUnique > size) { + // aka there is more maxUnique than rows. 
+ final Map<Integer, Integer> m = new HashMap<>(size); + return combineDenseWithHashMap(lm, rm, size, nVL, ret, m); + } + else { + final AMapToData m = MapToFactory.create(maxUnique, maxUnique + 1); + return combineDenseWithMapToData(lm, rm, size, nVL, ret, maxUnique, m); + } } - private DenseEncoding combineDenseNoResize(final DenseEncoding other) { + private Pair<IEncode, Map<Integer, Integer>> combineDenseNoResize(final DenseEncoding other) { if(map == other.map) - return this; // same object + return new ImmutablePair<>(this, null); // same object final AMapToData lm = map; final AMapToData rm = other.map; @@ -156,20 +166,21 @@ public class DenseEncoding implements IEncode { final AMapToData ret = MapToFactory.create(size, maxUnique); - for(int r = 0; r < size; r++) - ret.set(r, lm.getIndex(r) + rm.getIndex(r) * nVL); + final Map<Integer, Integer> m = new HashMap<>(Math.min(size, maxUnique)); + return new ImmutablePair<>(combineDenseWithHashMap(lm, rm, size, nVL, ret, m), m); + // there can be less unique. 
- return new DenseEncoding(ret); + // return new DenseEncoding(ret); } - private DenseEncoding combineSparseNoResize(final SparseEncoding other) { - return new DenseEncoding(assignSparse(other)); + private Pair<IEncode, Map<Integer, Integer>> combineSparseNoResize(final SparseEncoding other) { + final AMapToData a = assignSparse(other); + return combineSparseHashMap(a); } protected final DenseEncoding combineDenseWithHashMap(final AMapToData lm, final AMapToData rm, final int size, - final int nVL, final AMapToData ret) { - final Map<Integer, Integer> m = new HashMap<>(size); + final int nVL, final AMapToData ret, Map<Integer, Integer> m) { for(int r = 0; r < size; r++) addValHashMap(lm.getIndex(r) + rm.getIndex(r) * nVL, r, m, ret); @@ -177,8 +188,7 @@ public class DenseEncoding implements IEncode { } protected final DenseEncoding combineDenseWithMapToData(final AMapToData lm, final AMapToData rm, final int size, - final int nVL, final AMapToData ret, final int maxUnique) { - final AMapToData m = MapToFactory.create(maxUnique, maxUnique + 1); + final int nVL, final AMapToData ret, final int maxUnique, final AMapToData m) { int newUID = 1; for(int r = 0; r < size; r++) newUID = addValMapToData(lm.getIndex(r) + rm.getIndex(r) * nVL, r, m, newUID, ret); diff --git a/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/EmptyEncoding.java b/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/EmptyEncoding.java index 6aa9f91ca6..9e12654c77 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/EmptyEncoding.java +++ b/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/EmptyEncoding.java @@ -19,10 +19,16 @@ package org.apache.sysds.runtime.compress.estim.encoding; +import java.util.Map; + +import org.apache.commons.lang3.tuple.ImmutablePair; +import org.apache.commons.lang3.tuple.Pair; import org.apache.sysds.runtime.compress.CompressionSettings; import org.apache.sysds.runtime.compress.estim.EstimationFactors; -/** 
Empty encoding for cases where the entire group of columns is zero */ +/** + * Empty encoding for cases where the entire group of columns is zero + */ public class EmptyEncoding implements IEncode { // empty constructor @@ -35,8 +41,8 @@ public class EmptyEncoding implements IEncode { } @Override - public IEncode combineNoResize(IEncode e){ - return e; + public Pair<IEncode, Map<Integer, Integer>> combineWithMap(IEncode e) { + return new ImmutablePair<>(e, null); } @Override diff --git a/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/IEncode.java b/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/IEncode.java index dcff899217..e7202a19c4 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/IEncode.java +++ b/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/IEncode.java @@ -19,6 +19,9 @@ package org.apache.sysds.runtime.compress.estim.encoding; +import java.util.Map; + +import org.apache.commons.lang3.tuple.Pair; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.sysds.runtime.compress.CompressionSettings; @@ -42,13 +45,13 @@ public interface IEncode { public IEncode combine(IEncode e); /** - * Combine two encodings without resizing the output. meaning the mapping of the indexes should be consistant with + * Combine two encodings without resizing the output. meaning the mapping of the indexes should be consistent with * left hand side Dictionary indexes and right hand side indexes. 
* * @param e The other side to combine with * @return The combined encoding */ - public IEncode combineNoResize(IEncode e); + public Pair<IEncode, Map<Integer,Integer>> combineWithMap(IEncode e); /** * Get the number of unique values in this encoding diff --git a/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/SparseEncoding.java b/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/SparseEncoding.java index ff3b3285af..872e512adf 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/SparseEncoding.java +++ b/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/SparseEncoding.java @@ -19,6 +19,10 @@ package org.apache.sysds.runtime.compress.estim.encoding; +import java.util.Map; + +import org.apache.commons.lang3.tuple.ImmutablePair; +import org.apache.commons.lang3.tuple.Pair; import org.apache.sysds.runtime.compress.CompressionSettings; import org.apache.sysds.runtime.compress.DMLCompressionException; import org.apache.sysds.runtime.compress.colgroup.mapping.AMapToData; @@ -29,7 +33,10 @@ import org.apache.sysds.runtime.compress.colgroup.offset.OffsetFactory; import org.apache.sysds.runtime.compress.estim.EstimationFactors; import org.apache.sysds.runtime.compress.utils.IntArrayList; -/** Most common is zero encoding */ +/** + * An Encoding that contains a default value that is not encoded and every other value is encoded in the map. The logic is + * similar to the SDC column group. 
+ */ public class SparseEncoding implements IEncode { /** A map to the distinct values contained */ @@ -63,14 +70,14 @@ public class SparseEncoding implements IEncode { } @Override - public IEncode combineNoResize(IEncode e) { + public Pair<IEncode, Map<Integer, Integer>> combineWithMap(IEncode e) { if(e instanceof EmptyEncoding || e instanceof ConstEncoding) - return this; + return new ImmutablePair<>(this, null); else if(e instanceof SparseEncoding) { SparseEncoding es = (SparseEncoding) e; if(es.off == off && es.map == map) - return this; - return combineSparseNoResize(es); + return new ImmutablePair<>(this, null); + return new ImmutablePair<>(combineSparseNoResize(es), null); } else throw new DMLCompressionException("Not allowed other to be dense"); diff --git a/src/main/java/org/apache/sysds/runtime/compress/io/ReaderCompressed.java b/src/main/java/org/apache/sysds/runtime/compress/io/ReaderCompressed.java index f1f45abc98..c3493df65d 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/io/ReaderCompressed.java +++ b/src/main/java/org/apache/sysds/runtime/compress/io/ReaderCompressed.java @@ -35,6 +35,7 @@ import org.apache.sysds.hops.OptimizerUtils; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.compress.CompressedMatrixBlock; import org.apache.sysds.runtime.compress.CompressedMatrixBlockFactory; +import org.apache.sysds.runtime.compress.DMLCompressionException; import org.apache.sysds.runtime.compress.lib.CLALibStack; import org.apache.sysds.runtime.io.IOUtilFunctions; import org.apache.sysds.runtime.io.MatrixReader; @@ -78,6 +79,8 @@ public final class ReaderCompressed extends MatrixReader { for(Path subPath : IOUtilFunctions.getSequenceFilePaths(fs, path)){ read(subPath, job, data); } + if(data.containsValue(null)) + throw new DMLCompressionException("Invalid read data contains null:"); if(data.size() == 1) return data.entrySet().iterator().next().getValue(); diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCombineGroups.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCombineGroups.java index 27ea3f56b9..3d112a57b0 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCombineGroups.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibCombineGroups.java @@ -19,10 +19,17 @@ package org.apache.sysds.runtime.compress.lib; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ExecutorService; + import org.apache.commons.lang.NotImplementedException; +import org.apache.commons.lang3.tuple.Pair; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.sysds.runtime.compress.CompressedMatrixBlock; +import org.apache.sysds.runtime.compress.DMLCompressionException; import org.apache.sysds.runtime.compress.colgroup.AColGroup; import org.apache.sysds.runtime.compress.colgroup.AColGroupCompressed; import org.apache.sysds.runtime.compress.colgroup.ColGroupConst; @@ -31,10 +38,13 @@ import org.apache.sysds.runtime.compress.colgroup.ColGroupEmpty; import org.apache.sysds.runtime.compress.colgroup.ColGroupSDC; import org.apache.sysds.runtime.compress.colgroup.ColGroupUncompressed; import org.apache.sysds.runtime.compress.colgroup.IContainDefaultTuple; +import org.apache.sysds.runtime.compress.colgroup.IFrameOfReferenceGroup; import org.apache.sysds.runtime.compress.colgroup.dictionary.ADictionary; import org.apache.sysds.runtime.compress.colgroup.dictionary.DictionaryFactory; import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory; import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; +import org.apache.sysds.runtime.compress.estim.CompressedSizeInfo; +import org.apache.sysds.runtime.compress.estim.CompressedSizeInfoColGroup; import org.apache.sysds.runtime.compress.estim.encoding.ConstEncoding; import 
org.apache.sysds.runtime.compress.estim.encoding.DenseEncoding; import org.apache.sysds.runtime.compress.estim.encoding.EmptyEncoding; @@ -42,6 +52,7 @@ import org.apache.sysds.runtime.compress.estim.encoding.IEncode; import org.apache.sysds.runtime.compress.estim.encoding.SparseEncoding; import org.apache.sysds.runtime.data.DenseBlock; import org.apache.sysds.runtime.matrix.data.MatrixBlock; +import org.apache.sysds.runtime.util.CommonThreadPool; /** * Library functions to combine column groups inside a compressed matrix. @@ -53,8 +64,66 @@ public final class CLALibCombineGroups { // private constructor } - public static CompressedMatrixBlock combine(CompressedMatrixBlock cmb, int k) { - throw new NotImplementedException(); + public static List<AColGroup> combine(CompressedMatrixBlock cmb, int k) { + ExecutorService pool = null; + try { + pool = (k > 1) ? CommonThreadPool.get(k) : null; + return combine(cmb, null, pool); + } + catch(Exception e) { + throw new DMLCompressionException("Compression Failed", e); + } + finally { + if(pool != null) + pool.shutdown(); + } + } + + public static List<AColGroup> combine(CompressedMatrixBlock cmb, CompressedSizeInfo csi, ExecutorService pool) { + + List<AColGroup> input = cmb.getColGroups(); + boolean filterFor = CLALibUtils.shouldFilterFOR(input); + double[] c = filterFor ? 
new double[cmb.getNumColumns()] : null; + if(filterFor) + input = CLALibUtils.filterFOR(input, c); + + + List<List<AColGroup>> combinations = new ArrayList<>(); + for(CompressedSizeInfoColGroup gi : csi.getInfo()) { + combinations.add(findGroupsInIndex(gi.getColumns(), input)); + } + + List<AColGroup> ret = new ArrayList<>(); + if(filterFor) + for(List<AColGroup> combine : combinations) + ret.add(combineN(combine).addVector(c)); + else + for(List<AColGroup> combine : combinations) + ret.add(combineN(combine)); + + + + return ret; + } + + public static List<AColGroup> findGroupsInIndex(IColIndex idx, List<AColGroup> groups) { + List<AColGroup> ret = new ArrayList<>(); + for(AColGroup g : groups) { + if(g.getColIndices().containsAny(idx)) + ret.add(g); + } + return ret; + } + + public static AColGroup combineN(List<AColGroup> groups) { + + AColGroup base = groups.get(0); + // Inefficient combine N but base line + for(int i = 1; i < groups.size(); i++) { + base = combine(base, groups.get(i)); + } + + return base; } /** @@ -62,11 +131,16 @@ public final class CLALibCombineGroups { * * The number of rows should be equal, and it is not verified so there will be unexpected behavior in such cases. * + * It is assumed that this method is not called with FOR groups + * * @param a The first group to combine. * @param b The second group to combine. * @return A new column group containing the two. 
*/ public static AColGroup combine(AColGroup a, AColGroup b) { + if(a instanceof IFrameOfReferenceGroup || b instanceof IFrameOfReferenceGroup) + throw new DMLCompressionException("Invalid call with frame of reference group to combine"); + IColIndex combinedColumns = ColIndexFactory.combine(a, b); // try to recompress a and b if uncompressed @@ -96,18 +170,19 @@ public final class CLALibCombineGroups { return combineCompressed(combinedColumns, bc, ac); } - IEncode ce = ae.combineNoResize(be); - + Pair<IEncode, Map<Integer,Integer>> cec = ae.combineWithMap(be); + IEncode ce = cec.getLeft(); + Map<Integer,Integer> filter = cec.getRight(); if(ce instanceof DenseEncoding) { - DenseEncoding ced = (DenseEncoding) ce; - ADictionary cd = DictionaryFactory.combineDictionaries(ac, bc); + DenseEncoding ced = (DenseEncoding) (ce); + ADictionary cd = DictionaryFactory.combineDictionaries(ac, bc, filter); return ColGroupDDC.create(combinedColumns, cd, ced.getMap(), null); } else if(ce instanceof EmptyEncoding) { return new ColGroupEmpty(combinedColumns); } else if(ce instanceof ConstEncoding) { - ADictionary cd = DictionaryFactory.combineDictionaries(ac, bc); + ADictionary cd = DictionaryFactory.combineDictionaries(ac, bc, filter); return ColGroupConst.create(combinedColumns, cd); } else if(ce instanceof SparseEncoding) { @@ -145,12 +220,12 @@ public final class CLALibCombineGroups { public static double[] constructDefaultTuple(AColGroupCompressed ac, AColGroupCompressed bc) { double[] ret = new double[ac.getNumCols() + bc.getNumCols()]; - if(ac instanceof IContainDefaultTuple ){ - double[] defa = ((IContainDefaultTuple)ac).getDefaultTuple(); + if(ac instanceof IContainDefaultTuple) { + double[] defa = ((IContainDefaultTuple) ac).getDefaultTuple(); System.arraycopy(defa, 0, ret, 0, defa.length); } - if(bc instanceof IContainDefaultTuple){ - double[] defb = ((IContainDefaultTuple)bc).getDefaultTuple(); + if(bc instanceof IContainDefaultTuple) { + double[] defb = 
((IContainDefaultTuple) bc).getDefaultTuple(); System.arraycopy(defb, 0, ret, ac.getNumCols(), defb.length); } return ret; diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibStack.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibStack.java index adce71a24a..1dd76483c6 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibStack.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibStack.java @@ -192,6 +192,7 @@ public final class CLALibStack { final AColGroup[][] finalCols = new AColGroup[clen][]; // temp array for combining final int blocksInColumn = (rlen - 1) / blen + 1; + // Add all the blocks into linear structure. for(int br = 0; br * blen < rlen; br++) { for(int bc = 0; bc * blen < clen; bc++) { @@ -208,24 +209,28 @@ public final class CLALibStack { return combineViaDecompression(m, rlen, clen, blen, k); } finalCols[c][br] = gc; + if(br != 0 && (finalCols[c][0] == null || !finalCols[c][br].getColIndices().equals(finalCols[c][0].getColIndices()))){ + LOG.warn("Combining via decompression. There was an column with different index"); + return combineViaDecompression(m, rlen, clen, blen, k); + } } } } + + final ExecutorService pool = CommonThreadPool.get(Math.max(Math.min(clen / 500, k), 1)); try { List<AColGroup> finalGroups = pool.submit(() -> { return Arrays// .stream(finalCols)// - .filter(x -> x != null)// + .filter(x -> x != null)// filter all columns that are contained in other groups. .parallel()// .map(x -> { return combineN(x); }).collect(Collectors.toList()); }).get(); - - pool.shutdown(); if(finalGroups.contains(null)) { LOG.warn("Combining via decompression. 
There was a column group that did not append "); return combineViaDecompression(m, rlen, clen, blen, k); @@ -233,12 +238,20 @@ public final class CLALibStack { return new CompressedMatrixBlock(rlen, clen, -1, false, finalGroups); } catch(InterruptedException | ExecutionException e) { - pool.shutdown(); throw new DMLRuntimeException("Failed to combine column groups", e); } + finally { + pool.shutdown(); + } } private static AColGroup combineN(AColGroup[] groups) { - return AColGroup.appendN(groups); + try { + return AColGroup.appendN(groups); + + } + catch(Exception e) { + throw new DMLCompressionException("Failed to combine groups:\n" + Arrays.toString(groups), e); + } } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibUtils.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibUtils.java index 06da63d283..20262e7437 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibUtils.java +++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibUtils.java @@ -34,6 +34,7 @@ import org.apache.sysds.runtime.compress.colgroup.AMorphingMMColGroup; import org.apache.sysds.runtime.compress.colgroup.APreAgg; import org.apache.sysds.runtime.compress.colgroup.ColGroupConst; import org.apache.sysds.runtime.compress.colgroup.ColGroupEmpty; +import org.apache.sysds.runtime.compress.colgroup.IFrameOfReferenceGroup; import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory; import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; import org.apache.sysds.runtime.compress.colgroup.indexes.IIterate; @@ -93,6 +94,29 @@ public final class CLALibUtils { return false; } + /** + * Detect if the list of groups contains FOR. + * + * @param groups the groups + * @return If it contains FOR. 
+ */ + protected static boolean shouldFilterFOR(List<AColGroup> groups) { + for(AColGroup g : groups) + if(g instanceof IFrameOfReferenceGroup) + return true; + return false; + } + + protected static List<AColGroup> filterFOR(List<AColGroup> groups, double[] constV) { + if(constV == null) + return groups; + final List<AColGroup> filteredGroups = new ArrayList<>(); + for(AColGroup g : groups) + if(g instanceof IFrameOfReferenceGroup) + filteredGroups.add(((IFrameOfReferenceGroup) g).extractCommon(constV)); + return filteredGroups; + } + /** * Helper method to filter out SDC Groups and remove all constant groups, to reduce computation. * @@ -166,7 +190,7 @@ public final class CLALibUtils { final double[] colVals = cg.getValues(); for(int i = 0; i < colIdx.size(); i++) { // Find the index in the result columns to add the value into. - int outId = resCols.findIndex(colIdx.get(i)); + int outId = resCols.findIndex(colIdx.get(i)); values[outId] = colVals[i]; } } diff --git a/src/test/java/org/apache/sysds/test/component/compress/CompressedLoggingTests.java b/src/test/java/org/apache/sysds/test/component/compress/CompressedLoggingTests.java index 4373aa43df..650ad8091f 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/CompressedLoggingTests.java +++ b/src/test/java/org/apache/sysds/test/component/compress/CompressedLoggingTests.java @@ -275,7 +275,7 @@ public class CompressedLoggingTests { if(l.getMessage().toString().contains("--colGroups type")) return; } - fail("Log did not contain Compressed Size"); + fail("Log did not contain colgroups type"); } catch(Exception e) { e.printStackTrace(); @@ -303,7 +303,7 @@ public class CompressedLoggingTests { if(l.getMessage().toString().contains("--colGroups type")) return; } - fail("Log did not contain Compressed Size"); + fail("Log did not contain colgroups type"); } catch(Exception e) { e.printStackTrace(); @@ -331,7 +331,7 @@ public class CompressedLoggingTests { if(l.getMessage().toString().contains("Empty 
input to compress")) return; } - fail("Log Did not contain Empty"); + fail("Log did not contain Empty"); } catch(Exception e) { e.printStackTrace(); @@ -349,17 +349,15 @@ public class CompressedLoggingTests { try { Logger.getLogger(CompressedMatrixBlockFactory.class).setLevel(Level.DEBUG); - MatrixBlock mb = TestUtils.generateTestMatrixBlock(10, 1000, 1, 1, 0.0, 235); - mb = TestUtils.round(mb); + MatrixBlock mb = TestUtils.generateTestMatrixBlock(100, 3, 1, 1, 0.5, 235); MatrixBlock m2 = CompressedMatrixBlockFactory.compress(mb).getLeft(); CompressedMatrixBlockFactory.compress(m2).getLeft(); final List<LoggingEvent> log = LoggingUtils.reinsert(appender); for(LoggingEvent l : log) { - // LOG.error(l.getMessage()); if(l.getMessage().toString().contains("Recompressing")) return; } - fail("Log Did not contain Recompressing"); + fail("Log did not contain Recompressing"); } catch(Exception e) { e.printStackTrace(); @@ -390,7 +388,7 @@ public class CompressedLoggingTests { if(l.getMessage().toString().contains("Abort block compression")) return; } - fail("Log Did not contain Recompressing"); + fail("Log did not contain abort block compression"); } catch(Exception e) { e.printStackTrace(); @@ -415,7 +413,7 @@ public class CompressedLoggingTests { if(l.getMessage().toString().contains("CompressionSettings")) return; } - fail("failed to get Compressionsetting to log"); + fail("failed to get Compression setting to log"); } catch(Exception e) { e.printStackTrace(); @@ -439,7 +437,7 @@ public class CompressedLoggingTests { if(l.getMessage().toString().contains("Estimation Type")) return; } - fail("failed to get Compressionsetting to log"); + fail("failed to get estimation type"); } catch(Exception e) { e.printStackTrace();
