This is an automated email from the ASF dual-hosted git repository. baunsgaard pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/systemds.git
commit e12d9d2bdf751fd9d535c8cab774a03c0ba0dbcd Author: Sebastian Baunsgaard <[email protected]> AuthorDate: Sun Dec 29 22:06:11 2024 +0100 [MINOR] Update cost estimation in CLA Closes #2168 --- .../compress/cost/ComputationCostEstimator.java | 44 +++++++++++++- .../runtime/compress/estim/ComEstCompressed.java | 10 ++-- ...Compressed.java => ComEstCompressedSample.java} | 53 +++++++++-------- .../runtime/compress/estim/ComEstFactory.java | 20 +++++-- .../sysds/runtime/compress/estim/ComEstSample.java | 67 ++++++++++++++++------ .../compress/estim/CompressedSizeInfoColGroup.java | 2 + .../runtime/compress/estim/EstimationFactors.java | 12 ++-- .../runtime/compress/utils/ACountHashMap.java | 2 +- .../compress/utils/DoubleIntListHashMap.java | 2 +- 9 files changed, 145 insertions(+), 67 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/compress/cost/ComputationCostEstimator.java b/src/main/java/org/apache/sysds/runtime/compress/cost/ComputationCostEstimator.java index 8ddc5169c0..db42f8bd6f 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/cost/ComputationCostEstimator.java +++ b/src/main/java/org/apache/sysds/runtime/compress/cost/ComputationCostEstimator.java @@ -64,7 +64,9 @@ public class ComputationCostEstimator extends ACostEstimate { else if(g.isEmpty() || g.isConst()) // const or densifying return getCost(nRows, 1, nCols, 1, 1); - if(commonFraction > cvThreshold) + else if(g.isIncompressable()) + return getCost(nRows* 3, nRows, nCols, nRows* 3, sparsity); // make incompressable very expensive. + else if(commonFraction > cvThreshold) return getCost(nRows, nRows - g.getLargestOffInstances(), nCols, nVals, sparsity); else return getCost(nRows, nRows, nCols, nVals, sparsity); @@ -142,8 +144,44 @@ public class ComputationCostEstimator extends ACostEstimate { } private double leftMultCost(double nRowsScanned, double nRows, double nCols, double nVals, double sparsity) { - // Plus nVals * 2 because of allocation of nVals array and scan of that - final double preScalingCost = Math.max(nRowsScanned, nRows / 10) + nVals * 2; + // left multiplication want more co-coding. + // therefore, increase the cost if we have few columns + double preScalingCost = Math.max(nRowsScanned, nRows) * 2; + if ((nCols == nVals || nCols == nVals +1) && nVals > 1000){ + preScalingCost = 0; + } + // if(nCols == 1) { + // nCols *= 4; + // preScalingCost *= 5.0; + // } + // else if(nCols == 2) { + // nCols *= 3; + // preScalingCost *= 3.3; + // } + // else if(nCols == 3) { + // nCols *= 2; + // preScalingCost *= 1.6; + // } + // else if(nCols == 4) { + // nCols *= 1.5; + // preScalingCost *= 1.4; + // } + // else if(nCols > 1000) + // nCols *= 1.1; // more cost if lots and lots of columns + // else if(nCols > 5) + // nCols *= 0.7; // scale down cost of columns. + + // // if the number of unique values is low increase the cost. + // if(nVals < 10) + // nVals *= 10; + // else if(nVals < 256) + // nVals *= 5; + // else if(nVals < 1024) + // nVals *= 2; + // else if(nVals > 100000)// increase the cost if the number of distinct values is high. + // nVals *= 4; + // else if(nVals > 60000)// increase the cost if the number of distinct values is high. + // nVals *= 2; final double postScalingCost = sparsity * nVals * nCols; return leftMultCost(preScalingCost, postScalingCost); } diff --git a/src/main/java/org/apache/sysds/runtime/compress/estim/ComEstCompressed.java b/src/main/java/org/apache/sysds/runtime/compress/estim/ComEstCompressed.java index 5df202fcbb..1894d8489b 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/estim/ComEstCompressed.java +++ b/src/main/java/org/apache/sysds/runtime/compress/estim/ComEstCompressed.java @@ -50,14 +50,12 @@ public class ComEstCompressed extends AComEst { @Override public CompressedSizeInfoColGroup getColGroupInfo(IColIndex colIndexes, int estimate, int nrUniqueUpperBound) { - - // final IEncode map = - throw new UnsupportedOperationException("Unimplemented method 'getColGroupInfo'"); + return null; } @Override public CompressedSizeInfoColGroup getDeltaColGroupInfo(IColIndex colIndexes, int estimate, int nrUniqueUpperBound) { - throw new UnsupportedOperationException("Unimplemented method 'getDeltaColGroupInfo'"); + return null; } @Override @@ -69,11 +67,11 @@ public class ComEstCompressed extends AComEst { } else { List<AColGroup> groups = CLALibCombineGroups.findGroupsInIndex(columns, cData.getColGroups()); - int nVals = 1; + long nVals = 1; for(AColGroup g : groups) nVals *= g.getNumValues(); - return Math.min(_data.getNumRows(), nVals); + return Math.min(_data.getNumRows(), (int) Math.min(nVals, Integer.MAX_VALUE)); } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/estim/ComEstCompressed.java b/src/main/java/org/apache/sysds/runtime/compress/estim/ComEstCompressedSample.java similarity index 62% copy from src/main/java/org/apache/sysds/runtime/compress/estim/ComEstCompressed.java copy to src/main/java/org/apache/sysds/runtime/compress/estim/ComEstCompressedSample.java index 5df202fcbb..9f06848746 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/estim/ComEstCompressed.java +++ b/src/main/java/org/apache/sysds/runtime/compress/estim/ComEstCompressedSample.java @@ -26,42 +26,53 @@ import org.apache.sysds.runtime.compress.CompressedMatrixBlock; import org.apache.sysds.runtime.compress.CompressionSettings; import org.apache.sysds.runtime.compress.colgroup.AColGroup; import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; -import org.apache.sysds.runtime.compress.estim.encoding.IEncode; import org.apache.sysds.runtime.compress.lib.CLALibCombineGroups; -public class ComEstCompressed extends AComEst { +public class ComEstCompressedSample extends ComEstSample { - final CompressedMatrixBlock cData; + private static boolean loggedWarning = false; - protected ComEstCompressed(CompressedMatrixBlock data, CompressionSettings compSettings) { - super(data, compSettings); - cData = data; + + public ComEstCompressedSample(CompressedMatrixBlock sample, CompressionSettings cs, CompressedMatrixBlock full, + int k) { + super(sample, cs, full, k); + // cData = sample; } @Override protected List<CompressedSizeInfoColGroup> CompressedSizeInfoColGroup(int clen, int k) { List<CompressedSizeInfoColGroup> ret = new ArrayList<>(); - final int nRow = cData.getNumRows(); - for(AColGroup g : cData.getColGroups()) { - ret.add(g.getCompressionInfo(nRow)); + final int nRow = _data.getNumRows(); + final List<AColGroup> fg = ((CompressedMatrixBlock) _data).getColGroups(); + final List<AColGroup> sg = ((CompressedMatrixBlock) _sample).getColGroups(); + + for(int i = 0; i < fg.size(); i++) { + CompressedSizeInfoColGroup r = fg.get(i).getCompressionInfo(nRow); + r.setMap(sg.get(i).getCompressionInfo(_sampleSize).getMap()); + ret.add(r); } + return ret; } @Override public CompressedSizeInfoColGroup getColGroupInfo(IColIndex colIndexes, int estimate, int nrUniqueUpperBound) { - - // final IEncode map = - throw new UnsupportedOperationException("Unimplemented method 'getColGroupInfo'"); + if(!loggedWarning) + LOG.warn("Compressed input cannot fallback to resampling " + colIndexes); + loggedWarning = true; + return null; } @Override public CompressedSizeInfoColGroup getDeltaColGroupInfo(IColIndex colIndexes, int estimate, int nrUniqueUpperBound) { - throw new UnsupportedOperationException("Unimplemented method 'getDeltaColGroupInfo'"); + if(!loggedWarning) + LOG.warn("Compressed input cannot fallback to resampling " + colIndexes); + return null; } @Override protected int worstCaseUpperBound(IColIndex columns) { + CompressedMatrixBlock cData = ((CompressedMatrixBlock) _data); if(columns.size() == 1) { int id = columns.get(0); AColGroup g = cData.getColGroupForColumn(id); @@ -69,24 +80,12 @@ public class ComEstCompressed extends AComEst { } else { List<AColGroup> groups = CLALibCombineGroups.findGroupsInIndex(columns, cData.getColGroups()); - int nVals = 1; + long nVals = 1; for(AColGroup g : groups) nVals *= g.getNumValues(); - return Math.min(_data.getNumRows(), nVals); + return Math.min(_data.getNumRows(), (int) Math.min(nVals, Integer.MAX_VALUE)); } } - @Override - protected CompressedSizeInfoColGroup combine(IColIndex combinedColumns, CompressedSizeInfoColGroup g1, - CompressedSizeInfoColGroup g2, int maxDistinct) { - final IEncode map = g1.getMap().combine(g2.getMap()); - return getFacts(map, combinedColumns); - } - - protected CompressedSizeInfoColGroup getFacts(IEncode map, IColIndex colIndexes) { - final int _numRows = getNumRows(); - final EstimationFactors em = map.extractFacts(_numRows, _data.getSparsity(), _data.getSparsity(), _cs); - return new CompressedSizeInfoColGroup(colIndexes, em, _cs.validCompressions, map); - } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/estim/ComEstFactory.java b/src/main/java/org/apache/sysds/runtime/compress/estim/ComEstFactory.java index c20a481fc3..b49543455b 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/estim/ComEstFactory.java +++ b/src/main/java/org/apache/sysds/runtime/compress/estim/ComEstFactory.java @@ -23,6 +23,7 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.apache.sysds.runtime.compress.CompressedMatrixBlock; import org.apache.sysds.runtime.compress.CompressionSettings; +import org.apache.sysds.runtime.compress.lib.CLALibSlice; import org.apache.sysds.runtime.matrix.data.MatrixBlock; public interface ComEstFactory { @@ -37,13 +38,13 @@ public interface ComEstFactory { * @return A new CompressionSizeEstimator used to extract information of column groups */ public static AComEst createEstimator(MatrixBlock data, CompressionSettings cs, int k) { - if(data instanceof CompressedMatrixBlock) - return createCompressedEstimator((CompressedMatrixBlock) data, cs); - final int nRows = cs.transposed ? data.getNumColumns() : data.getNumRows(); final int nCols = cs.transposed ? data.getNumRows() : data.getNumColumns(); final double sparsity = data.getSparsity(); final int sampleSize = getSampleSize(cs, nRows, nCols, sparsity); + + if(data instanceof CompressedMatrixBlock) + return createCompressedEstimator((CompressedMatrixBlock) data, cs, sampleSize, k); if(data.isEmpty()) return createExactEstimator(data, cs); @@ -76,8 +77,17 @@ public interface ComEstFactory { return new ComEstExact(data, cs); } - private static ComEstCompressed createCompressedEstimator(CompressedMatrixBlock data, CompressionSettings cs) { - LOG.debug("Using Compressed Estimator"); + private static AComEst createCompressedEstimator(CompressedMatrixBlock data, CompressionSettings cs, int sampleSize, + int k) { + if(sampleSize < data.getNumRows()) { + LOG.debug("Trying to sample"); + final MatrixBlock slice = CLALibSlice.sliceRowsCompressed(data, 0, sampleSize); + if(slice instanceof CompressedMatrixBlock) { + LOG.debug("Using Sampled Compressed Estimator " + sampleSize); + return new ComEstCompressedSample((CompressedMatrixBlock) slice, cs, data, k); + } + } + LOG.debug("Using Full Compressed Estimator"); return new ComEstCompressed(data, cs); } diff --git a/src/main/java/org/apache/sysds/runtime/compress/estim/ComEstSample.java b/src/main/java/org/apache/sysds/runtime/compress/estim/ComEstSample.java index 1509575735..97b451daee 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/estim/ComEstSample.java +++ b/src/main/java/org/apache/sysds/runtime/compress/estim/ComEstSample.java @@ -23,6 +23,7 @@ import java.util.Arrays; import java.util.Random; import org.apache.sysds.runtime.compress.CompressionSettings; +import org.apache.sysds.runtime.compress.DMLCompressionException; import org.apache.sysds.runtime.compress.colgroup.AColGroup.CompressionType; import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex; import org.apache.sysds.runtime.compress.estim.encoding.EncodingFactory; @@ -42,13 +43,22 @@ import org.apache.sysds.utils.stats.Timing; public class ComEstSample extends AComEst { /** Sample extracted from the input data */ - private final MatrixBlock _sample; + protected final MatrixBlock _sample; /** Parallelization degree */ - private final int _k; + protected final int _k; /** Sample size */ - private final int _sampleSize; + protected final int _sampleSize; /** Boolean specifying if the sample is in transposed format. */ - private boolean _transposed; + protected boolean _transposed; + + public ComEstSample(MatrixBlock sample, CompressionSettings cs, MatrixBlock full, int k) { + super(full, cs); + _k = k; + _transposed = cs.transposed; + _sample = sample; + _sampleSize = sample.getNumRows(); + + } /** * CompressedSizeEstimatorSample, samples from the input data and estimates the size of the compressed matrix. @@ -95,22 +105,44 @@ public class ComEstSample extends AComEst { @Override protected int worstCaseUpperBound(IColIndex columns) { if(getNumColumns() == columns.size()) - return Math.min(getNumRows(), (int) _data.getNonZeros()); + return Math.min(getNumRows(), (int) Math.min(_data.getNonZeros(), Integer.MAX_VALUE)); return getNumRows(); } @Override protected CompressedSizeInfoColGroup combine(IColIndex combinedColumns, CompressedSizeInfoColGroup g1, CompressedSizeInfoColGroup g2, int maxDistinct) { - final IEncode map = g1.getMap().combine(g2.getMap()); - return extractInfo(map, combinedColumns, maxDistinct); + try { + final IEncode map = g1.getMap().combine(g2.getMap()); + return extractInfo(map, combinedColumns, maxDistinct); + } + catch(Exception e) { + + String s1 = g1.toString(); + if(s1.length() > 1000) + s1 = s1.substring(0, 1000); + + String s2 = g2.toString(); + if(s2.length() > 1000) + s2 = s2.substring(0, 1000); + + throw new DMLCompressionException("Failed to combine :\n" + s1 + "\n\n" + s2, e); + } } private CompressedSizeInfoColGroup extractInfo(IEncode map, IColIndex colIndexes, int maxDistinct) { - final double spar = _data.getSparsity(); - final EstimationFactors sampleFacts = map.extractFacts(_sampleSize, spar, spar, _cs); - final EstimationFactors em = scaleFactors(sampleFacts, colIndexes, maxDistinct, map.isDense()); - return new CompressedSizeInfoColGroup(colIndexes, em, _cs.validCompressions, map); + try { + final double spar = _data.getSparsity(); + final EstimationFactors sampleFacts = map.extractFacts(_sampleSize, spar, spar, _cs); + final EstimationFactors em = scaleFactors(sampleFacts, colIndexes, maxDistinct, map.isDense()); + return new CompressedSizeInfoColGroup(colIndexes, em, _cs.validCompressions, map); + } + catch(Exception e) { + String ms = map.toString(); + if(ms.length() > 1000) + ms = ms.substring(0, 1000); + throw new DMLCompressionException("Failed to extract info: \n" + ms, e); + } } private EstimationFactors scaleFactors(EstimationFactors sampleFacts, IColIndex colIndexes, int maxDistinct, @@ -125,6 +157,9 @@ public class ComEstSample extends AComEst { final long nnz = calculateNNZ(colIndexes, scalingFactor); final int numOffs = calculateOffs(sampleFacts, numRows, scalingFactor, colIndexes, (int) nnz); final int estDistinct = distinctCountScale(sampleFacts, numOffs, numRows, maxDistinct, dense, nCol); + // if(estDistinct < sampleFacts.numVals) + // throw new DMLCompressionException("Failed estimating distinct: " + estDistinct + " should have been above " + // + sampleFacts.numVals + "\n" + Arrays.toString(sampleFacts.frequencies)); // calculate the largest instance count. final int maxLargestInstanceCount = numRows - estDistinct + 1; @@ -133,11 +168,9 @@ public class ComEstSample extends AComEst { final int mostFrequentOffsetCount = Math.max(Math.min(maxLargestInstanceCount, scaledLargestInstanceCount), numRows - numOffs); - final double overallSparsity = calculateSparsity(colIndexes, nnz, scalingFactor, - sampleFacts.overAllSparsity); + final double overallSparsity = calculateSparsity(colIndexes, nnz, scalingFactor, sampleFacts.overAllSparsity); // For robustness safety add 10 percent more tuple sparsity final double tupleSparsity = Math.min(overallSparsity * 1.3, 1.0); // increase sparsity by 30%. - if(_cs.isRLEAllowed()) { final int scaledRuns = Math.max(estDistinct, calculateRuns(sampleFacts, scalingFactor, numOffs, estDistinct)); @@ -161,14 +194,14 @@ public class ComEstSample extends AComEst { final int[] freq = sampleFacts.frequencies; if(freq == null || freq.length == 0) return numOffs; // very aggressive number of distinct + maxDistinct = Math.max(maxDistinct, sampleFacts.numVals); // sampled size is smaller than actual if there was empty rows. // and the more we can reduce this value the more accurate the estimation will become. final int sampledSize = sampleFacts.numOffs; - int est = SampleEstimatorFactory.distinctCount(freq, dense ? numRows : numOffs, sampledSize, - _cs.estimationType); + int est = SampleEstimatorFactory.distinctCount(freq, dense ? numRows : numOffs, sampledSize, _cs.estimationType); if(est > 10000) est += est * 0.5; - if(nCol > 4) // Increase estimate if we get into many columns cocoding to be safe + if(nCol > 4 && est > 100) // Increase estimate if we get into many columns cocoding to be safe est += ((double) est) * ((double) nCol) / 10; // Bound the estimate with the maxDistinct. return Math.max(Math.min(est, Math.min(maxDistinct, numOffs)), 1); diff --git a/src/main/java/org/apache/sysds/runtime/compress/estim/CompressedSizeInfoColGroup.java b/src/main/java/org/apache/sysds/runtime/compress/estim/CompressedSizeInfoColGroup.java index 4fbf9b0ee4..6cc882cc2f 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/estim/CompressedSizeInfoColGroup.java +++ b/src/main/java/org/apache/sysds/runtime/compress/estim/CompressedSizeInfoColGroup.java @@ -220,6 +220,8 @@ public class CompressedSizeInfoColGroup { private static EnumMap<CompressionType, Double> calculateCompressionSizes(IColIndex cols, EstimationFactors fact, Set<CompressionType> validCompressionTypes) { + if(validCompressionTypes.size() > 10 ) + throw new DMLCompressionException("Invalid big number of compression types"); EnumMap<CompressionType, Double> res = new EnumMap<>(CompressionType.class); for(CompressionType ct : validCompressionTypes) { double compSize = getCompressionSize(cols, ct, fact); diff --git a/src/main/java/org/apache/sysds/runtime/compress/estim/EstimationFactors.java b/src/main/java/org/apache/sysds/runtime/compress/estim/EstimationFactors.java index 130d0f77f8..904a228ae7 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/estim/EstimationFactors.java +++ b/src/main/java/org/apache/sysds/runtime/compress/estim/EstimationFactors.java @@ -87,16 +87,14 @@ public class EstimationFactors { this.tupleSparsity = tupleSparsity; if(overAllSparsity > 1 || overAllSparsity < 0) - throw new DMLCompressionException("Invalid OverAllSparsity of: " + overAllSparsity); + overAllSparsity = Math.max(0, Math.min(1, overAllSparsity)); else if(tupleSparsity > 1 || tupleSparsity < 0) - throw new DMLCompressionException("Invalid TupleSparsity of:" + tupleSparsity); + tupleSparsity = Math.max(0, Math.min(1, tupleSparsity)); else if(largestOff > numRows) - throw new DMLCompressionException( - "Invalid number of instance of most common element should be lower than number of rows. " + largestOff - + " > numRows: " + numRows); + largestOff = numRows; else if(numVals > numOffs) - throw new DMLCompressionException( - "Num vals cannot be greater than num offs: vals: " + numVals + " offs: " + numOffs); + numVals = numOffs; + if(CompressedMatrixBlock.debug && frequencies != null) { for(int i = 0; i < frequencies.length; i++) { diff --git a/src/main/java/org/apache/sysds/runtime/compress/utils/ACountHashMap.java b/src/main/java/org/apache/sysds/runtime/compress/utils/ACountHashMap.java index b1d310939d..d30ffa4753 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/utils/ACountHashMap.java +++ b/src/main/java/org/apache/sysds/runtime/compress/utils/ACountHashMap.java @@ -54,7 +54,7 @@ public abstract class ACountHashMap<T> implements Cloneable { } /** - * Increment and return the id of the incremeted index. + * Increment and return the id of the incremented index. * * @param key The key to increment * @return The id of the incremented entry. diff --git a/src/main/java/org/apache/sysds/runtime/compress/utils/DoubleIntListHashMap.java b/src/main/java/org/apache/sysds/runtime/compress/utils/DoubleIntListHashMap.java index 1c9ef3082c..00fa67f6b6 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/utils/DoubleIntListHashMap.java +++ b/src/main/java/org/apache/sysds/runtime/compress/utils/DoubleIntListHashMap.java @@ -113,7 +113,7 @@ public class DoubleIntListHashMap { } else { for(DIListEntry e = _data[ix]; e != null; e = e.next) { - if(e.key == key) { + if(Util.eq(e.key , key)) { IntArrayList lstPtr = e.value; lstPtr.appendValue(value); break;
