This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git

commit e12d9d2bdf751fd9d535c8cab774a03c0ba0dbcd
Author: Sebastian Baunsgaard <[email protected]>
AuthorDate: Sun Dec 29 22:06:11 2024 +0100

    [MINOR] Update cost estimation in CLA
    
    Closes #2168
---
 .../compress/cost/ComputationCostEstimator.java    | 44 +++++++++++++-
 .../runtime/compress/estim/ComEstCompressed.java   | 10 ++--
 ...Compressed.java => ComEstCompressedSample.java} | 53 +++++++++--------
 .../runtime/compress/estim/ComEstFactory.java      | 20 +++++--
 .../sysds/runtime/compress/estim/ComEstSample.java | 67 ++++++++++++++++------
 .../compress/estim/CompressedSizeInfoColGroup.java |  2 +
 .../runtime/compress/estim/EstimationFactors.java  | 12 ++--
 .../runtime/compress/utils/ACountHashMap.java      |  2 +-
 .../compress/utils/DoubleIntListHashMap.java       |  2 +-
 9 files changed, 145 insertions(+), 67 deletions(-)

diff --git a/src/main/java/org/apache/sysds/runtime/compress/cost/ComputationCostEstimator.java b/src/main/java/org/apache/sysds/runtime/compress/cost/ComputationCostEstimator.java
index 8ddc5169c0..db42f8bd6f 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/cost/ComputationCostEstimator.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/cost/ComputationCostEstimator.java
@@ -64,7 +64,9 @@ public class ComputationCostEstimator extends ACostEstimate {
                else if(g.isEmpty() || g.isConst())
                        // const or densifying
                        return getCost(nRows, 1, nCols, 1, 1);
-               if(commonFraction > cvThreshold)
+               else if(g.isIncompressable())
+                       return getCost(nRows* 3, nRows, nCols, nRows* 3, 
sparsity); // make incompressable very expensive.
+               else if(commonFraction > cvThreshold)
                        return getCost(nRows, nRows - 
g.getLargestOffInstances(), nCols, nVals, sparsity);
                else
                        return getCost(nRows, nRows, nCols, nVals, sparsity);
@@ -142,8 +144,44 @@ public class ComputationCostEstimator extends ACostEstimate {
        }
 
        private double leftMultCost(double nRowsScanned, double nRows, double 
nCols, double nVals, double sparsity) {
-               // Plus nVals * 2 because of allocation of nVals array and scan 
of that
-               final double preScalingCost = Math.max(nRowsScanned, nRows / 
10) + nVals * 2;
+               // left multiplication want more co-coding.
+               // therefore, increase the cost if we have few columns
+               double preScalingCost = Math.max(nRowsScanned, nRows) * 2;
+               if ((nCols == nVals || nCols == nVals +1) && nVals > 1000){
+                               preScalingCost = 0;
+               }
+               // if(nCols == 1) {
+               //      nCols *= 4;
+               //      preScalingCost *= 5.0;
+               // }
+               // else if(nCols == 2) {
+               //      nCols *= 3;
+               //      preScalingCost *= 3.3;
+               // }
+               // else if(nCols == 3) {
+               //      nCols *= 2;
+               //      preScalingCost *= 1.6;
+               // }
+               // else if(nCols == 4) {
+               //      nCols *= 1.5;
+               //      preScalingCost *= 1.4;
+               // }
+               // else if(nCols > 1000)
+               //      nCols *= 1.1; // more cost if lots and lots of columns
+               // else if(nCols > 5)
+               //      nCols *= 0.7; // scale down cost of columns.
+
+               // // if the number of unique values is low increase the cost.
+               // if(nVals < 10)
+               //      nVals *= 10;
+               // else if(nVals < 256)
+               //      nVals *= 5;
+               // else if(nVals < 1024)
+               //      nVals *= 2;
+               // else if(nVals > 100000)// increase the cost if the number of 
distinct values is high.
+               //      nVals *= 4;
+               // else if(nVals > 60000)// increase the cost if the number of 
distinct values is high.
+               //      nVals *= 2;
                final double postScalingCost = sparsity * nVals * nCols;
                return leftMultCost(preScalingCost, postScalingCost);
        }
diff --git a/src/main/java/org/apache/sysds/runtime/compress/estim/ComEstCompressed.java b/src/main/java/org/apache/sysds/runtime/compress/estim/ComEstCompressed.java
index 5df202fcbb..1894d8489b 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/estim/ComEstCompressed.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/estim/ComEstCompressed.java
@@ -50,14 +50,12 @@ public class ComEstCompressed extends AComEst {
 
        @Override
        public CompressedSizeInfoColGroup getColGroupInfo(IColIndex colIndexes, 
int estimate, int nrUniqueUpperBound) {
-
-               // final IEncode map =
-               throw new UnsupportedOperationException("Unimplemented method 
'getColGroupInfo'");
+               return null;
        }
 
        @Override
        public CompressedSizeInfoColGroup getDeltaColGroupInfo(IColIndex 
colIndexes, int estimate, int nrUniqueUpperBound) {
-               throw new UnsupportedOperationException("Unimplemented method 
'getDeltaColGroupInfo'");
+               return null;
        }
 
        @Override
@@ -69,11 +67,11 @@ public class ComEstCompressed extends AComEst {
                }
                else {
                        List<AColGroup> groups = 
CLALibCombineGroups.findGroupsInIndex(columns, cData.getColGroups());
-                       int nVals = 1;
+                       long nVals = 1;
                        for(AColGroup g : groups)
                                nVals *= g.getNumValues();
 
-                       return Math.min(_data.getNumRows(), nVals);
+                       return Math.min(_data.getNumRows(), (int) 
Math.min(nVals, Integer.MAX_VALUE));
                }
        }
 
diff --git a/src/main/java/org/apache/sysds/runtime/compress/estim/ComEstCompressed.java b/src/main/java/org/apache/sysds/runtime/compress/estim/ComEstCompressedSample.java
similarity index 62%
copy from src/main/java/org/apache/sysds/runtime/compress/estim/ComEstCompressed.java
copy to src/main/java/org/apache/sysds/runtime/compress/estim/ComEstCompressedSample.java
index 5df202fcbb..9f06848746 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/estim/ComEstCompressed.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/estim/ComEstCompressedSample.java
@@ -26,42 +26,53 @@ import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
 import org.apache.sysds.runtime.compress.CompressionSettings;
 import org.apache.sysds.runtime.compress.colgroup.AColGroup;
 import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex;
-import org.apache.sysds.runtime.compress.estim.encoding.IEncode;
 import org.apache.sysds.runtime.compress.lib.CLALibCombineGroups;
 
-public class ComEstCompressed extends AComEst {
+public class ComEstCompressedSample extends ComEstSample {
 
-       final CompressedMatrixBlock cData;
+       private static boolean loggedWarning = false;
 
-       protected ComEstCompressed(CompressedMatrixBlock data, 
CompressionSettings compSettings) {
-               super(data, compSettings);
-               cData = data;
+
+       public ComEstCompressedSample(CompressedMatrixBlock sample, 
CompressionSettings cs, CompressedMatrixBlock full,
+               int k) {
+               super(sample, cs, full, k);
+               // cData = sample;
        }
 
        @Override
        protected List<CompressedSizeInfoColGroup> 
CompressedSizeInfoColGroup(int clen, int k) {
                List<CompressedSizeInfoColGroup> ret = new ArrayList<>();
-               final int nRow = cData.getNumRows();
-               for(AColGroup g : cData.getColGroups()) {
-                       ret.add(g.getCompressionInfo(nRow));
+               final int nRow = _data.getNumRows();
+               final List<AColGroup> fg = ((CompressedMatrixBlock) 
_data).getColGroups();
+               final List<AColGroup> sg = ((CompressedMatrixBlock) 
_sample).getColGroups();
+
+               for(int i = 0; i < fg.size(); i++) {
+                       CompressedSizeInfoColGroup r = 
fg.get(i).getCompressionInfo(nRow);
+                       
r.setMap(sg.get(i).getCompressionInfo(_sampleSize).getMap());
+                       ret.add(r);
                }
+
                return ret;
        }
 
        @Override
        public CompressedSizeInfoColGroup getColGroupInfo(IColIndex colIndexes, 
int estimate, int nrUniqueUpperBound) {
-
-               // final IEncode map =
-               throw new UnsupportedOperationException("Unimplemented method 
'getColGroupInfo'");
+               if(!loggedWarning)
+                       LOG.warn("Compressed input cannot fallback to 
resampling " + colIndexes);
+               loggedWarning = true;
+               return null;
        }
 
        @Override
        public CompressedSizeInfoColGroup getDeltaColGroupInfo(IColIndex 
colIndexes, int estimate, int nrUniqueUpperBound) {
-               throw new UnsupportedOperationException("Unimplemented method 
'getDeltaColGroupInfo'");
+               if(!loggedWarning)
+                       LOG.warn("Compressed input cannot fallback to 
resampling " + colIndexes);
+               return null;
        }
 
        @Override
        protected int worstCaseUpperBound(IColIndex columns) {
+               CompressedMatrixBlock cData = ((CompressedMatrixBlock) _data);
                if(columns.size() == 1) {
                        int id = columns.get(0);
                        AColGroup g = cData.getColGroupForColumn(id);
@@ -69,24 +80,12 @@ public class ComEstCompressed extends AComEst {
                }
                else {
                        List<AColGroup> groups = 
CLALibCombineGroups.findGroupsInIndex(columns, cData.getColGroups());
-                       int nVals = 1;
+                       long nVals = 1;
                        for(AColGroup g : groups)
                                nVals *= g.getNumValues();
 
-                       return Math.min(_data.getNumRows(), nVals);
+                       return Math.min(_data.getNumRows(), (int) 
Math.min(nVals, Integer.MAX_VALUE));
                }
        }
 
-       @Override
-       protected CompressedSizeInfoColGroup combine(IColIndex combinedColumns, 
CompressedSizeInfoColGroup g1,
-               CompressedSizeInfoColGroup g2, int maxDistinct) {
-               final IEncode map = g1.getMap().combine(g2.getMap());
-               return getFacts(map, combinedColumns);
-       }
-
-       protected CompressedSizeInfoColGroup getFacts(IEncode map, IColIndex 
colIndexes) {
-               final int _numRows = getNumRows();
-               final EstimationFactors em = map.extractFacts(_numRows, 
_data.getSparsity(), _data.getSparsity(), _cs);
-               return new CompressedSizeInfoColGroup(colIndexes, em, 
_cs.validCompressions, map);
-       }
 }
diff --git a/src/main/java/org/apache/sysds/runtime/compress/estim/ComEstFactory.java b/src/main/java/org/apache/sysds/runtime/compress/estim/ComEstFactory.java
index c20a481fc3..b49543455b 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/estim/ComEstFactory.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/estim/ComEstFactory.java
@@ -23,6 +23,7 @@ import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
 import org.apache.sysds.runtime.compress.CompressionSettings;
+import org.apache.sysds.runtime.compress.lib.CLALibSlice;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 
 public interface ComEstFactory {
@@ -37,13 +38,13 @@ public interface ComEstFactory {
         * @return A new CompressionSizeEstimator used to extract information 
of column groups
         */
        public static AComEst createEstimator(MatrixBlock data, 
CompressionSettings cs, int k) {
-               if(data instanceof CompressedMatrixBlock)
-                       return 
createCompressedEstimator((CompressedMatrixBlock) data, cs);
-
                final int nRows = cs.transposed ? data.getNumColumns() : 
data.getNumRows();
                final int nCols = cs.transposed ? data.getNumRows() : 
data.getNumColumns();
                final double sparsity = data.getSparsity();
                final int sampleSize = getSampleSize(cs, nRows, nCols, 
sparsity);
+               
+               if(data instanceof CompressedMatrixBlock)
+                       return 
createCompressedEstimator((CompressedMatrixBlock) data, cs, sampleSize, k);
 
                if(data.isEmpty())
                        return createExactEstimator(data, cs);
@@ -76,8 +77,17 @@ public interface ComEstFactory {
                return new ComEstExact(data, cs);
        }
 
-       private static ComEstCompressed 
createCompressedEstimator(CompressedMatrixBlock data, CompressionSettings cs) {
-               LOG.debug("Using Compressed Estimator");
+       private static AComEst createCompressedEstimator(CompressedMatrixBlock 
data, CompressionSettings cs, int sampleSize,
+               int k) {
+               if(sampleSize < data.getNumRows()) {
+                       LOG.debug("Trying to sample");
+                       final MatrixBlock slice = 
CLALibSlice.sliceRowsCompressed(data, 0, sampleSize);
+                       if(slice instanceof CompressedMatrixBlock) {
+                               LOG.debug("Using Sampled Compressed Estimator " 
+ sampleSize);
+                               return new 
ComEstCompressedSample((CompressedMatrixBlock) slice, cs, data, k);
+                       }
+               }
+               LOG.debug("Using Full Compressed Estimator");
                return new ComEstCompressed(data, cs);
        }
 
diff --git a/src/main/java/org/apache/sysds/runtime/compress/estim/ComEstSample.java b/src/main/java/org/apache/sysds/runtime/compress/estim/ComEstSample.java
index 1509575735..97b451daee 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/estim/ComEstSample.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/estim/ComEstSample.java
@@ -23,6 +23,7 @@ import java.util.Arrays;
 import java.util.Random;
 
 import org.apache.sysds.runtime.compress.CompressionSettings;
+import org.apache.sysds.runtime.compress.DMLCompressionException;
 import org.apache.sysds.runtime.compress.colgroup.AColGroup.CompressionType;
 import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex;
 import org.apache.sysds.runtime.compress.estim.encoding.EncodingFactory;
@@ -42,13 +43,22 @@ import org.apache.sysds.utils.stats.Timing;
 public class ComEstSample extends AComEst {
 
        /** Sample extracted from the input data */
-       private final MatrixBlock _sample;
+       protected final MatrixBlock _sample;
        /** Parallelization degree */
-       private final int _k;
+       protected final int _k;
        /** Sample size */
-       private final int _sampleSize;
+       protected final int _sampleSize;
        /** Boolean specifying if the sample is in transposed format. */
-       private boolean _transposed;
+       protected boolean _transposed;
+
+       public ComEstSample(MatrixBlock sample, CompressionSettings cs, 
MatrixBlock full, int k) {
+               super(full, cs);
+               _k = k;
+               _transposed = cs.transposed;
+               _sample = sample;
+               _sampleSize = sample.getNumRows();
+
+       }
 
        /**
         * CompressedSizeEstimatorSample, samples from the input data and 
estimates the size of the compressed matrix.
@@ -95,22 +105,44 @@ public class ComEstSample extends AComEst {
        @Override
        protected int worstCaseUpperBound(IColIndex columns) {
                if(getNumColumns() == columns.size())
-                       return Math.min(getNumRows(), (int) 
_data.getNonZeros());
+                       return Math.min(getNumRows(), (int) 
Math.min(_data.getNonZeros(), Integer.MAX_VALUE));
                return getNumRows();
        }
 
        @Override
        protected CompressedSizeInfoColGroup combine(IColIndex combinedColumns, 
CompressedSizeInfoColGroup g1,
                CompressedSizeInfoColGroup g2, int maxDistinct) {
-               final IEncode map = g1.getMap().combine(g2.getMap());
-               return extractInfo(map, combinedColumns, maxDistinct);
+               try {
+                       final IEncode map = g1.getMap().combine(g2.getMap());
+                       return extractInfo(map, combinedColumns, maxDistinct);
+               }
+               catch(Exception e) {
+
+                       String s1 = g1.toString();
+                       if(s1.length() > 1000)
+                               s1 = s1.substring(0, 1000);
+
+                       String s2 = g2.toString();
+                       if(s2.length() > 1000)
+                               s2 = s2.substring(0, 1000);
+
+                       throw new DMLCompressionException("Failed to combine 
:\n" + s1 + "\n\n" + s2, e);
+               }
        }
 
        private CompressedSizeInfoColGroup extractInfo(IEncode map, IColIndex 
colIndexes, int maxDistinct) {
-               final double spar = _data.getSparsity();
-               final EstimationFactors sampleFacts = 
map.extractFacts(_sampleSize, spar, spar, _cs);
-               final EstimationFactors em = scaleFactors(sampleFacts, 
colIndexes, maxDistinct, map.isDense());
-               return new CompressedSizeInfoColGroup(colIndexes, em, 
_cs.validCompressions, map);
+               try {
+                       final double spar = _data.getSparsity();
+                       final EstimationFactors sampleFacts = 
map.extractFacts(_sampleSize, spar, spar, _cs);
+                       final EstimationFactors em = scaleFactors(sampleFacts, 
colIndexes, maxDistinct, map.isDense());
+                       return new CompressedSizeInfoColGroup(colIndexes, em, 
_cs.validCompressions, map);
+               }
+               catch(Exception e) {
+                       String ms = map.toString();
+                       if(ms.length() > 1000)
+                               ms = ms.substring(0, 1000);
+                       throw new DMLCompressionException("Failed to extract 
info: \n" + ms, e);
+               }
        }
 
        private EstimationFactors scaleFactors(EstimationFactors sampleFacts, 
IColIndex colIndexes, int maxDistinct,
@@ -125,6 +157,9 @@ public class ComEstSample extends AComEst {
                        final long nnz = calculateNNZ(colIndexes, 
scalingFactor);
                        final int numOffs = calculateOffs(sampleFacts, numRows, 
scalingFactor, colIndexes, (int) nnz);
                        final int estDistinct = distinctCountScale(sampleFacts, 
numOffs, numRows, maxDistinct, dense, nCol);
+                       // if(estDistinct < sampleFacts.numVals)
+                       // throw new DMLCompressionException("Failed estimating 
distinct: " + estDistinct + " should have been above "
+                       // + sampleFacts.numVals + "\n" + 
Arrays.toString(sampleFacts.frequencies));
 
                        // calculate the largest instance count.
                        final int maxLargestInstanceCount = numRows - 
estDistinct + 1;
@@ -133,11 +168,9 @@ public class ComEstSample extends AComEst {
                        final int mostFrequentOffsetCount = 
Math.max(Math.min(maxLargestInstanceCount, scaledLargestInstanceCount),
                                numRows - numOffs);
 
-                       final double overallSparsity = 
calculateSparsity(colIndexes, nnz, scalingFactor,
-                               sampleFacts.overAllSparsity);
+                       final double overallSparsity = 
calculateSparsity(colIndexes, nnz, scalingFactor, sampleFacts.overAllSparsity);
                        // For robustness safety add 10 percent more tuple 
sparsity
                        final double tupleSparsity = Math.min(overallSparsity * 
1.3, 1.0); // increase sparsity by 30%.
-
                        if(_cs.isRLEAllowed()) {
                                final int scaledRuns = Math.max(estDistinct,
                                        calculateRuns(sampleFacts, 
scalingFactor, numOffs, estDistinct));
@@ -161,14 +194,14 @@ public class ComEstSample extends AComEst {
                final int[] freq = sampleFacts.frequencies;
                if(freq == null || freq.length == 0)
                        return numOffs; // very aggressive number of distinct
+               maxDistinct = Math.max(maxDistinct, sampleFacts.numVals);
                // sampled size is smaller than actual if there was empty rows.
                // and the more we can reduce this value the more accurate the 
estimation will become.
                final int sampledSize = sampleFacts.numOffs;
-               int est = SampleEstimatorFactory.distinctCount(freq, dense ? 
numRows : numOffs, sampledSize,
-                       _cs.estimationType);
+               int est = SampleEstimatorFactory.distinctCount(freq, dense ? 
numRows : numOffs, sampledSize, _cs.estimationType);
                if(est > 10000)
                        est += est * 0.5;
-               if(nCol > 4) // Increase estimate if we get into many columns 
cocoding to be safe
+               if(nCol > 4 && est > 100) // Increase estimate if we get into 
many columns cocoding to be safe
                        est += ((double) est) * ((double) nCol) / 10;
                // Bound the estimate with the maxDistinct.
                return Math.max(Math.min(est, Math.min(maxDistinct, numOffs)), 
1);
diff --git a/src/main/java/org/apache/sysds/runtime/compress/estim/CompressedSizeInfoColGroup.java b/src/main/java/org/apache/sysds/runtime/compress/estim/CompressedSizeInfoColGroup.java
index 4fbf9b0ee4..6cc882cc2f 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/estim/CompressedSizeInfoColGroup.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/estim/CompressedSizeInfoColGroup.java
@@ -220,6 +220,8 @@ public class CompressedSizeInfoColGroup {
 
        private static EnumMap<CompressionType, Double> 
calculateCompressionSizes(IColIndex cols, EstimationFactors fact,
                Set<CompressionType> validCompressionTypes) {
+               if(validCompressionTypes.size() > 10 )
+                       throw new DMLCompressionException("Invalid big number 
of compression types");
                EnumMap<CompressionType, Double> res = new 
EnumMap<>(CompressionType.class);
                for(CompressionType ct : validCompressionTypes) {
                        double compSize = getCompressionSize(cols, ct, fact);
diff --git a/src/main/java/org/apache/sysds/runtime/compress/estim/EstimationFactors.java b/src/main/java/org/apache/sysds/runtime/compress/estim/EstimationFactors.java
index 130d0f77f8..904a228ae7 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/estim/EstimationFactors.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/estim/EstimationFactors.java
@@ -87,16 +87,14 @@ public class EstimationFactors {
                this.tupleSparsity = tupleSparsity;
 
                if(overAllSparsity > 1 || overAllSparsity < 0)
-                       throw new DMLCompressionException("Invalid 
OverAllSparsity of: " + overAllSparsity);
+                       overAllSparsity = Math.max(0, Math.min(1, 
overAllSparsity));
                else if(tupleSparsity > 1 || tupleSparsity < 0)
-                       throw new DMLCompressionException("Invalid 
TupleSparsity of:" + tupleSparsity);
+                       tupleSparsity = Math.max(0, Math.min(1, tupleSparsity));
                else if(largestOff > numRows)
-                       throw new DMLCompressionException(
-                               "Invalid number of instance of most common 
element should be lower than number of rows. " + largestOff
-                                       + " > numRows: " + numRows);
+                       largestOff = numRows;
                else if(numVals > numOffs)
-                       throw new DMLCompressionException(
-                               "Num vals cannot be greater than num offs: 
vals: " + numVals + " offs: " + numOffs);
+                       numVals = numOffs;
+               
 
                if(CompressedMatrixBlock.debug && frequencies != null) {
                        for(int i = 0; i < frequencies.length; i++) {
diff --git a/src/main/java/org/apache/sysds/runtime/compress/utils/ACountHashMap.java b/src/main/java/org/apache/sysds/runtime/compress/utils/ACountHashMap.java
index b1d310939d..d30ffa4753 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/utils/ACountHashMap.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/utils/ACountHashMap.java
@@ -54,7 +54,7 @@ public abstract class ACountHashMap<T> implements Cloneable {
        }
 
        /**
-        * Increment and return the id of the incremeted index.
+        * Increment and return the id of the incremented index.
         * 
         * @param key The key to increment
         * @return The id of the incremented entry.
diff --git a/src/main/java/org/apache/sysds/runtime/compress/utils/DoubleIntListHashMap.java b/src/main/java/org/apache/sysds/runtime/compress/utils/DoubleIntListHashMap.java
index 1c9ef3082c..00fa67f6b6 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/utils/DoubleIntListHashMap.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/utils/DoubleIntListHashMap.java
@@ -113,7 +113,7 @@ public class DoubleIntListHashMap {
                }
                else {
                        for(DIListEntry e = _data[ix]; e != null; e = e.next) {
-                               if(e.key == key) {
+                               if(Util.eq(e.key , key)) {
                                        IntArrayList lstPtr = e.value;
                                        lstPtr.appendValue(value);
                                        break;

Reply via email to