This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 0ba2aa994f [SYSTEMDS-3643] Fused Scaling Compressed Multiplication
0ba2aa994f is described below

commit 0ba2aa994f8f3006a2a660c8cad4fdd8e78ac94f
Author: Sebastian Baunsgaard <[email protected]>
AuthorDate: Mon Oct 30 13:30:15 2023 +0100

    [SYSTEMDS-3643] Fused Scaling Compressed Multiplication
    
    This commit fuses the scaling step into the matrix multiplication
    kernels of CLA. This avoids allocating new dictionaries when the two
    column group sides have identical index structures.
    
    The change improves instructions such as MMChain and TSMM. The
    improvements are largest when there are few column groups.
    
    Closes #1936
---
 .../sysds/runtime/compress/colgroup/APreAgg.java   |   5 +-
 .../colgroup/dictionary/DictLibMatrixMult.java     | 127 +++++++++--
 .../compress/colgroup/dictionary/Dictionary.java   |  48 ++++-
 .../compress/colgroup/dictionary/IDictionary.java  |  94 ++++++---
 .../colgroup/dictionary/IdentityDictionary.java    | 168 +++++++++++++--
 .../dictionary/IdentityDictionarySlice.java        |  23 +-
 .../colgroup/dictionary/MatrixBlockDictionary.java |  71 ++++++-
 .../colgroup/dictionary/PlaceHolderDict.java       |  18 ++
 .../compress/colgroup/dictionary/QDictionary.java  |  18 ++
 .../sysds/runtime/data/SparseBlockFactory.java     |  45 +++-
 src/test/java/org/apache/sysds/test/TestUtils.java |  11 +
 .../compress/dictionary/DictionaryTests.java       | 232 ++++++++++++++++++++-
 .../sysds/test/component/matrix/SparseFactory.java |  42 ++++
 13 files changed, 821 insertions(+), 81 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/APreAgg.java 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/APreAgg.java
index 8b8a7b7df0..7f585f2d7a 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/APreAgg.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/APreAgg.java
@@ -85,9 +85,12 @@ public abstract class APreAgg extends AColGroupValue {
         * @return A aggregate dictionary
         */
        public final IDictionary preAggregateThatIndexStructure(APreAgg that) {
-               long outputLength = (long)that._colIndexes.size() * 
this.getNumValues();
+               final long outputLength = (long)that._colIndexes.size() * 
this.getNumValues();
                if(outputLength > Integer.MAX_VALUE)
                        throw new NotImplementedException("Not supported pre 
aggregate of above integer length");
+               if(outputLength <= 0) // if the pre aggregate output is empty 
or nothing, return null
+                       return null;
+               
                // create empty Dictionary that we slowly fill, hence the 
dictionary is empty and no check
                final Dictionary ret = Dictionary.createNoCheck(new 
double[(int)outputLength]);
 
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DictLibMatrixMult.java
 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DictLibMatrixMult.java
index 240e57cc12..9aba711a30 100644
--- 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DictLibMatrixMult.java
+++ 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/DictLibMatrixMult.java
@@ -65,11 +65,7 @@ public class DictLibMatrixMult {
         */
        public static void MMDictsWithScaling(IDictionary left, IDictionary 
right, IColIndex leftRows,
                IColIndex rightColumns, MatrixBlock result, int[] counts) {
-               LOG.warn("Inefficient double allocation of dictionary");
-               final boolean modifyRight = right.getInMemorySize() > 
left.getInMemorySize();
-               final IDictionary rightM = modifyRight ? 
right.scaleTuples(counts, rightColumns.size()) : right;
-               final IDictionary leftM = modifyRight ? left : 
left.scaleTuples(counts, leftRows.size());
-               MMDicts(leftM, rightM, leftRows, rightColumns, result);
+               left.MMDictScaling(right, leftRows, rightColumns, result, 
counts);
        }
 
        /**
@@ -198,17 +194,43 @@ public class DictLibMatrixMult {
 
        protected static void MMDictsDenseDense(double[] left, double[] right, 
IColIndex rowsLeft, IColIndex colsRight,
                MatrixBlock result) {
-               final int commonDim = Math.min(left.length / rowsLeft.size(), 
right.length / colsRight.size());
+               final int leftSide = rowsLeft.size();
+               final int rightSide = colsRight.size();
+               final int commonDim = Math.min(left.length / leftSide, 
right.length / rightSide);
                final int resCols = result.getNumColumns();
                final double[] resV = result.getDenseBlockValues();
+
                for(int k = 0; k < commonDim; k++) {
-                       final int offL = k * rowsLeft.size();
-                       final int offR = k * colsRight.size();
-                       for(int i = 0; i < rowsLeft.size(); i++) {
+                       final int offL = k * leftSide;
+                       final int offR = k * rightSide;
+                       for(int i = 0; i < leftSide; i++) {
                                final int offOut = rowsLeft.get(i) * resCols;
                                final double vl = left[offL + i];
                                if(vl != 0) {
-                                       for(int j = 0; j < colsRight.size(); 
j++)
+                                       for(int j = 0; j < rightSide; j++)
+                                               resV[offOut + colsRight.get(j)] 
+= vl * right[offR + j];
+                               }
+                       }
+               }
+       }
+
+       protected static void MMDictsScalingDenseDense(double[] left, double[] 
right, IColIndex rowsLeft,
+               IColIndex colsRight, MatrixBlock result, int[] scaling) {
+               final int leftSide = rowsLeft.size();
+               final int rightSide = colsRight.size();
+               final int commonDim = Math.min(left.length / leftSide, 
right.length / rightSide);
+               final int resCols = result.getNumColumns();
+               final double[] resV = result.getDenseBlockValues();
+
+               for(int k = 0; k < commonDim; k++) {
+                       final int offL = k * leftSide;
+                       final int offR = k * rightSide;
+                       final int s = scaling[k];
+                       for(int i = 0; i < leftSide; i++) {
+                               final int offOut = rowsLeft.get(i) * resCols;
+                               final double vl = left[offL + i] * s;
+                               if(vl != 0) {
+                                       for(int j = 0; j < rightSide; j++)
                                                resV[offOut + colsRight.get(j)] 
+= vl * right[offR + j];
                                }
                        }
@@ -236,10 +258,34 @@ public class DictLibMatrixMult {
                }
        }
 
+       protected static void MMDictsScalingSparseDense(SparseBlock left, 
double[] right, IColIndex rowsLeft,
+               IColIndex colsRight, MatrixBlock result, int[] scaling) {
+               final double[] resV = result.getDenseBlockValues();
+               final int commonDim = Math.min(left.numRows(), right.length / 
colsRight.size());
+               for(int i = 0; i < commonDim; i++) {
+                       if(left.isEmpty(i))
+                               continue;
+                       final int apos = left.pos(i);
+                       final int alen = left.size(i) + apos;
+                       final int[] aix = left.indexes(i);
+                       final double[] leftVals = left.values(i);
+                       final int offRight = i * colsRight.size();
+                       final int s = scaling[i];
+                       for(int k = apos; k < alen; k++) {
+                               final int offOut = rowsLeft.get(aix[k]) * 
result.getNumColumns();
+                               final double v = leftVals[k] * s;
+                               for(int j = 0; j < colsRight.size(); j++)
+                                       resV[offOut + colsRight.get(j)] += v * 
right[offRight + j];
+                       }
+               }
+       }
+
        protected static void MMDictsDenseSparse(double[] left, SparseBlock 
right, IColIndex rowsLeft, IColIndex colsRight,
                MatrixBlock result) {
                final double[] resV = result.getDenseBlockValues();
-               final int commonDim = Math.min(left.length / rowsLeft.size(), 
right.numRows());
+               final int leftSize = rowsLeft.size();
+               final int commonDim = Math.min(left.length / leftSize, 
right.numRows());
+
                for(int i = 0; i < commonDim; i++) {
                        if(right.isEmpty(i))
                                continue;
@@ -247,8 +293,8 @@ public class DictLibMatrixMult {
                        final int alen = right.size(i) + apos;
                        final int[] aix = right.indexes(i);
                        final double[] rightVals = right.values(i);
-                       final int offLeft = i * rowsLeft.size();
-                       for(int j = 0; j < rowsLeft.size(); j++) {
+                       final int offLeft = i * leftSize;
+                       for(int j = 0; j < leftSize; j++) {
                                final int offOut = rowsLeft.get(j) * 
result.getNumColumns();
                                final double v = left[offLeft + j];
                                if(v != 0) {
@@ -259,6 +305,32 @@ public class DictLibMatrixMult {
                }
        }
 
+               protected static void MMDictsScalingDenseSparse(double[] left, 
SparseBlock right, IColIndex rowsLeft, IColIndex colsRight,
+               MatrixBlock result, int[] scaling) {
+               final double[] resV = result.getDenseBlockValues();
+               final int leftSize = rowsLeft.size();
+               final int commonDim = Math.min(left.length / leftSize, 
right.numRows());
+
+               for(int i = 0; i < commonDim; i++) {
+                       if(right.isEmpty(i))
+                               continue;
+                       final int apos = right.pos(i);
+                       final int alen = right.size(i) + apos;
+                       final int[] aix = right.indexes(i);
+                       final double[] rightVals = right.values(i);
+                       final int offLeft = i * leftSize;
+                       final int s = scaling[i];
+                       for(int j = 0; j < leftSize; j++) {
+                               final int offOut = rowsLeft.get(j) * 
result.getNumColumns();
+                               final double v = left[offLeft + j] * s;
+                               if(v != 0) {
+                                       for(int k = apos; k < alen; k++)
+                                               resV[offOut + 
colsRight.get(aix[k])] += v * rightVals[k];
+                               }
+                       }
+               }
+       }
+
        protected static void MMDictsSparseSparse(SparseBlock left, SparseBlock 
right, IColIndex rowsLeft,
                IColIndex colsRight, MatrixBlock result) {
                final int commonDim = Math.min(left.numRows(), right.numRows());
@@ -286,6 +358,35 @@ public class DictLibMatrixMult {
                }
        }
 
+       protected static void MMDictsScalingSparseSparse(SparseBlock left, 
SparseBlock right, IColIndex rowsLeft,
+               IColIndex colsRight, MatrixBlock result, int[] scaling) {
+               final int commonDim = Math.min(left.numRows(), right.numRows());
+               final double[] resV = result.getDenseBlockValues();
+               final int resCols = result.getNumColumns();
+               // remember that the left side is transposed...
+               for(int i = 0; i < commonDim; i++) {
+                       if(left.isEmpty(i) || right.isEmpty(i))
+                               continue;
+                       final int leftAPos = left.pos(i);
+                       final int leftAlen = left.size(i) + leftAPos;
+                       final int[] leftAix = left.indexes(i);
+                       final double[] leftVals = left.values(i);
+                       final int rightAPos = right.pos(i);
+                       final int rightAlen = right.size(i) + rightAPos;
+                       final int[] rightAix = right.indexes(i);
+                       final double[] rightVals = right.values(i);
+
+                       final int s = scaling[i];
+
+                       for(int k = leftAPos; k < leftAlen; k++) {
+                               final int offOut = rowsLeft.get(leftAix[k]) * 
resCols;
+                               final double v = leftVals[k] * s;
+                               for(int j = rightAPos; j < rightAlen; j++)
+                                       resV[offOut + 
colsRight.get(rightAix[j])] += v * rightVals[j];
+                       }
+               }
+       }
+
        protected static void MMToUpperTriangleSparseSparse(SparseBlock left, 
SparseBlock right, IColIndex rowsLeft,
                IColIndex colsRight, MatrixBlock result) {
                final int commonDim = Math.min(left.numRows(), right.numRows());
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java
 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java
index 983dc84b50..4f0bbfbee1 100644
--- 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java
+++ 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/Dictionary.java
@@ -22,13 +22,16 @@ package 
org.apache.sysds.runtime.compress.colgroup.dictionary;
 import java.io.DataInput;
 import java.io.DataOutput;
 import java.io.IOException;
+import java.lang.ref.SoftReference;
 import java.math.BigDecimal;
 import java.math.MathContext;
 import java.util.Arrays;
 
+import org.apache.commons.lang3.NotImplementedException;
 import org.apache.sysds.hops.OptimizerUtils;
 import org.apache.sysds.runtime.compress.DMLCompressionException;
 import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex;
+import org.apache.sysds.runtime.compress.utils.Util;
 import org.apache.sysds.runtime.data.SparseBlock;
 import org.apache.sysds.runtime.functionobjects.Builtin;
 import org.apache.sysds.runtime.functionobjects.Plus;
@@ -51,6 +54,8 @@ public class Dictionary extends ADictionary {
        private static final long serialVersionUID = -6517136537249507753L;
 
        protected final double[] _values;
+       /** A Cache to contain a MatrixBlock version of the dictionary. */
+       protected volatile SoftReference<MatrixBlockDictionary> cache = null;
 
        protected Dictionary(double[] values) {
                _values = values;
@@ -799,7 +804,14 @@ public class Dictionary extends ADictionary {
 
        @Override
        public MatrixBlockDictionary getMBDict(int nCol) {
-               return MatrixBlockDictionary.createDictionary(_values, nCol, 
true);
+               if(cache != null) {
+                       MatrixBlockDictionary r = cache.get();
+                       if(r != null)
+                               return r;
+               }
+               MatrixBlockDictionary ret = 
MatrixBlockDictionary.createDictionary(_values, nCol, true);
+               cache = new SoftReference<>(ret);
+               return ret;
        }
 
        @Override
@@ -843,13 +855,15 @@ public class Dictionary extends ADictionary {
        @Override
        public Dictionary preaggValuesFromDense(int numVals, IColIndex 
colIndexes, IColIndex aggregateColumns, double[] b,
                int cut) {
-               double[] ret = new double[numVals * aggregateColumns.size()];
-               for(int k = 0, off = 0; k < numVals * colIndexes.size(); k += 
colIndexes.size(), off += aggregateColumns.size()) {
-                       for(int h = 0; h < colIndexes.size(); h++) {
-                               int idb = colIndexes.get(h) * cut;
+               final int cz = colIndexes.size();
+               final int az = aggregateColumns.size();
+               final double[] ret = new double[numVals * az];
+               for(int k = 0, off = 0; k < numVals * cz; k += cz, off += az) {
+                       for(int h = 0; h < cz; h++) {
+                               final int idb = colIndexes.get(h) * cut;
                                double v = _values[k + h];
                                if(v != 0)
-                                       for(int i = 0; i < 
aggregateColumns.size(); i++)
+                                       for(int i = 0; i < az; i++)
                                                ret[off + i] += v * b[idb + 
aggregateColumns.get(i)];
                        }
                }
@@ -861,13 +875,15 @@ public class Dictionary extends ADictionary {
                double[] retV = new double[_values.length];
                for(int i = 0; i < _values.length; i++) {
                        final double v = _values[i];
-                       retV[i] = v == pattern ? replace : v;
+                       retV[i] = Util.eq(v, pattern) ? replace : v;
                }
                return create(retV);
        }
 
        @Override
        public IDictionary replaceWithReference(double pattern, double replace, 
double[] reference) {
+               if(Util.eq(pattern, Double.NaN))
+                       throw new NotImplementedException();
                final double[] retV = new double[_values.length];
                final int nCol = reference.length;
                final int nRow = _values.length / nCol;
@@ -1040,16 +1056,34 @@ public class Dictionary extends ADictionary {
                right.MMDictDense(_values, rowsLeft, colsRight, result);
        }
 
+       @Override
+       public void MMDictScaling(IDictionary right, IColIndex rowsLeft, 
IColIndex colsRight, MatrixBlock result,
+               int[] scaling) {
+               right.MMDictScalingDense(_values, rowsLeft, colsRight, result, 
scaling);
+       }
+
        @Override
        public void MMDictDense(double[] left, IColIndex rowsLeft, IColIndex 
colsRight, MatrixBlock result) {
                DictLibMatrixMult.MMDictsDenseDense(left, _values, rowsLeft, 
colsRight, result);
        }
 
+       @Override
+       public void MMDictScalingDense(double[] left, IColIndex rowsLeft, 
IColIndex colsRight, MatrixBlock result,
+               int[] scaling) {
+               DictLibMatrixMult.MMDictsScalingDenseDense(left, _values, 
rowsLeft, colsRight, result, scaling);
+       }
+
        @Override
        public void MMDictSparse(SparseBlock left, IColIndex rowsLeft, 
IColIndex colsRight, MatrixBlock result) {
                DictLibMatrixMult.MMDictsSparseDense(left, _values, rowsLeft, 
colsRight, result);
        }
 
+       @Override
+       public void MMDictScalingSparse(SparseBlock left, IColIndex rowsLeft, 
IColIndex colsRight, MatrixBlock result,
+               int[] scaling) {
+               DictLibMatrixMult.MMDictsScalingSparseDense(left, _values, 
rowsLeft, colsRight, result, scaling);
+       }
+
        @Override
        public void TSMMToUpperTriangle(IDictionary right, IColIndex rowsLeft, 
IColIndex colsRight, MatrixBlock result) {
                right.TSMMToUpperTriangleDense(_values, rowsLeft, colsRight, 
result);
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IDictionary.java
 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IDictionary.java
index 2f3d435673..1047692f50 100644
--- 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IDictionary.java
+++ 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IDictionary.java
@@ -524,46 +524,46 @@ public interface IDictionary {
        public long getNumberNonZerosWithReference(int[] counts, double[] 
reference, int nRows);
 
        /**
-        * Copies and adds the dictionary entry from this dictionary to the d 
dictionary
+        * Adds the dictionary entry from this dictionary to the d dictionary
         * 
-        * @param v    the target dictionary (dense double array)
-        * @param fr   the from index
-        * @param to   the to index
-        * @param nCol the number of columns
+        * @param v    The target dictionary (dense double array)
+        * @param fr   The from index is the tuple index to copy from.
+        * @param to   The to index is the row index to copy into.
+        * @param nCol The number of columns in both cases
         */
        public void addToEntry(double[] v, int fr, int to, int nCol);
 
        /**
-        * copies and adds the dictonary entry from this dictionary yo the d 
dictionary rep times.
+        * Adds the dictionary entry from this dictionary to the v dictionary 
rep times.
         * 
-        * @param v    the target dictionary (dense double array)
-        * @param fr   the from index
-        * @param to   the to index
-        * @param nCol the number of columns
-        * @param rep  the number of repetitions to apply (simply multiply do 
not loop)
+        * @param v    The target dictionary (dense double array)
+        * @param fr   The from index is the tuple index to copy from.
+        * @param to   The to index is the row index to copy into.
+        * @param nCol The number of columns in both cases
+        * @param rep  The number of repetitions to apply (simply multiply do 
not loop)
         */
        public void addToEntry(double[] v, int fr, int to, int nCol, int rep);
 
        /**
         * Vectorized add to entry, this call helps with a bit of locality for 
the cache.
         * 
-        * @param v    THe target dictionary (dense double array)
-        * @param f1   from index 1
-        * @param f2   from index 2
-        * @param f3   from index 3
-        * @param f4   from index 4
-        * @param f5   from index 5
-        * @param f6   from index 6
-        * @param f7   from index 7
-        * @param f8   from index 8
-        * @param t1   to index 1
-        * @param t2   to index 2
-        * @param t3   to index 3
-        * @param t4   to index 4
-        * @param t5   to index 5
-        * @param t6   to index 6
-        * @param t7   to index 7
-        * @param t8   to index 8
+        * @param v    The target dictionary (dense double array)
+        * @param f1   From index 1
+        * @param f2   From index 2
+        * @param f3   From index 3
+        * @param f4   From index 4
+        * @param f5   From index 5
+        * @param f6   From index 6
+        * @param f7   From index 7
+        * @param f8   From index 8
+        * @param t1   To index 1
+        * @param t2   To index 2
+        * @param t3   To index 3
+        * @param t4   To index 4
+        * @param t5   To index 5
+        * @param t6   To index 6
+        * @param t7   To index 7
+        * @param t8   To index 8
         * @param nCol Number of columns in the dictionary
         */
        public void addToEntryVectorized(double[] v, int f1, int f2, int f3, 
int f4, int f5, int f6, int f7, int f8, int t1,
@@ -820,6 +820,20 @@ public interface IDictionary {
         */
        public void MMDict(IDictionary right, IColIndex rowsLeft, IColIndex 
colsRight, MatrixBlock result);
 
+       /**
+        * Matrix multiplication of dictionaries
+        * 
+        * Note the left is this, and it is transposed
+        * 
+        * @param right     Right hand side of multiplication
+        * @param rowsLeft  Offset rows on the left
+        * @param colsRight Offset cols on the right
+        * @param result    The output matrix block
+        * @param scaling   The scaling
+        */
+       public void MMDictScaling(IDictionary right, IColIndex rowsLeft, 
IColIndex colsRight, MatrixBlock result,
+               int[] scaling);
+
        /**
         * Matrix multiplication of dictionaries left side dense and transposed 
right side is this.
         * 
@@ -830,6 +844,18 @@ public interface IDictionary {
         */
        public void MMDictDense(double[] left, IColIndex rowsLeft, IColIndex 
colsRight, MatrixBlock result);
 
+       /**
+        * Matrix multiplication of dictionaries left side dense and transposed 
right side is this.
+        * 
+        * @param left      Dense left side
+        * @param rowsLeft  Offset rows on the left
+        * @param colsRight Offset cols on the right
+        * @param result    The output matrix block
+        * @param scaling   The scaling
+        */
+       public void MMDictScalingDense(double[] left, IColIndex rowsLeft, 
IColIndex colsRight, MatrixBlock result,
+               int[] scaling);
+
        /**
         * Matrix multiplication of dictionaries left side sparse and 
transposed right side is this.
         * 
@@ -839,6 +865,18 @@ public interface IDictionary {
         * @param result    The output matrix block
         */
        public void MMDictSparse(SparseBlock left, IColIndex rowsLeft, 
IColIndex colsRight, MatrixBlock result);
+       
+/**
+        * Matrix multiplication of dictionaries left side sparse and 
transposed right side is this.
+        * 
+        * @param left      Sparse left side
+        * @param rowsLeft  Offset rows on the left
+        * @param colsRight Offset cols on the right
+        * @param result    The output matrix block
+        * @param scaling   The scaling
+        */
+       public void MMDictScalingSparse(SparseBlock left, IColIndex rowsLeft, 
IColIndex colsRight, MatrixBlock result,
+               int[] scaling);
 
        /**
         * Matrix multiplication but allocate output in upper triangle and 
twice if on diagonal, note this is left
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionary.java
 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionary.java
index 39712155e6..74f5e5b099 100644
--- 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionary.java
+++ 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionary.java
@@ -32,6 +32,9 @@ import org.apache.sysds.runtime.data.SparseBlock;
 import org.apache.sysds.runtime.data.SparseBlockFactory;
 import org.apache.sysds.runtime.functionobjects.Builtin;
 import org.apache.sysds.runtime.functionobjects.Builtin.BuiltinCode;
+import org.apache.sysds.runtime.functionobjects.Divide;
+import org.apache.sysds.runtime.functionobjects.Minus;
+import org.apache.sysds.runtime.functionobjects.Plus;
 import org.apache.sysds.runtime.functionobjects.ValueFunction;
 import org.apache.sysds.runtime.instructions.cp.CM_COV_Object;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
@@ -51,7 +54,7 @@ public class IdentityDictionary extends ADictionary {
        /** Specify if the Identity matrix should contain an empty row in the 
end. */
        protected final boolean withEmpty;
        /** A Cache to contain a materialized version of the identity matrix. */
-       protected SoftReference<MatrixBlockDictionary> cache = null;
+       protected volatile SoftReference<MatrixBlockDictionary> cache = null;
 
        /**
         * Create an identity matrix dictionary. It behaves as if allocated a 
Sparse Matrix block but exploits that the
@@ -212,7 +215,29 @@ public class IdentityDictionary extends ADictionary {
 
        @Override
        public IDictionary binOpRight(BinaryOperator op, double[] v, IColIndex 
colIndexes) {
-               return getMBDict().binOpRight(op, v, colIndexes);
+               boolean same = false;
+               if(op.fn instanceof Plus || op.fn instanceof Minus) {
+                       same = true;
+                       for(int i = 0; i < colIndexes.size(); i++) {
+                               if(v[colIndexes.get(i)] != 0.0) {
+                                       same = false;
+                                       break;
+                               }
+                       }
+               }
+               if(op.fn instanceof Divide) {
+                       same = true;
+                       for(int i = 0; i < colIndexes.size(); i++) {
+                               if(v[colIndexes.get(i)] != 1.0) {
+                                       same = false;
+                                       break;
+                               }
+                       }
+               }
+               if(same)
+                       return this;
+               MatrixBlockDictionary mb = getMBDict();
+               return mb.binOpRight(op, v, colIndexes);
        }
 
        @Override
@@ -243,22 +268,33 @@ public class IdentityDictionary extends ADictionary {
 
        @Override
        public int getNumberOfValues(int ncol) {
+               if(ncol != nRowCol)
+                       throw new DMLCompressionException("Invalid call to get 
Number of values assuming wrong number of columns");
                return nRowCol + (withEmpty ? 1 : 0);
        }
 
        @Override
        public double[] sumAllRowsToDouble(int nrColumns) {
-               double[] ret = new double[nRowCol];
-               Arrays.fill(ret, 1);
-               return ret;
+               if(withEmpty) {
+                       double[] ret = new double[nRowCol + 1];
+                       Arrays.fill(ret, 1);
+                       ret[ret.length - 1] = 0;
+                       return ret;
+               }
+               else {
+                       double[] ret = new double[nRowCol];
+                       Arrays.fill(ret, 1);
+                       return ret;
+               }
        }
 
        @Override
        public double[] sumAllRowsToDoubleWithDefault(double[] defaultTuple) {
-               double[] ret = new double[nRowCol];
-               Arrays.fill(ret, 1);
+               double[] ret = new double[defaultTuple.length];
                for(int i = 0; i < defaultTuple.length; i++)
-                       ret[i] += defaultTuple[i];
+                       ret[i] += 1 + defaultTuple[i];
+               if(withEmpty)
+                       ret[ret.length - 1] += -1;
                return ret;
        }
 
@@ -341,6 +377,8 @@ public class IdentityDictionary extends ADictionary {
                double s = 0.0;
                for(int v : counts)
                        s += v;
+               if(withEmpty)
+                       s -= counts[counts.length - 1];
                return s;
        }
 
@@ -389,13 +427,54 @@ public class IdentityDictionary extends ADictionary {
 
        @Override
        public void addToEntry(final double[] v, final int fr, final int to, 
final int nCol, int rep) {
-               getMBDict().addToEntry(v, fr, to, nCol, rep);
+               if(withEmpty) {
+                       if(fr < nRowCol)
+                               v[to * nCol + fr] += rep;
+               }
+               else {
+                       v[to * nCol + fr] += rep;
+               }
        }
 
        @Override
        public void addToEntryVectorized(double[] v, int f1, int f2, int f3, 
int f4, int f5, int f6, int f7, int f8, int t1,
                int t2, int t3, int t4, int t5, int t6, int t7, int t8, int 
nCol) {
-               getMBDict().addToEntryVectorized(v, f1, f2, f3, f4, f5, f6, f7, 
f8, t1, t2, t3, t4, t5, t6, t7, t8, nCol);
+               if(withEmpty)
+                       addToEntryVectorizedWithEmpty(v, f1, f2, f3, f4, f5, 
f6, f7, f8, t1, t2, t3, t4, t5, t6, t7, t8, nCol);
+               else
+                       addToEntryVectorizedNorm(v, f1, f2, f3, f4, f5, f6, f7, 
f8, t1, t2, t3, t4, t5, t6, t7, t8, nCol);
+       }
+
+       private void addToEntryVectorizedWithEmpty(double[] v, int f1, int f2, 
int f3, int f4, int f5, int f6, int f7,
+               int f8, int t1, int t2, int t3, int t4, int t5, int t6, int t7, 
int t8, int nCol) {
+               if(f1 < nRowCol)
+                       v[t1 * nCol + f1] += 1;
+               if(f2 < nRowCol)
+                       v[t2 * nCol + f2] += 1;
+               if(f3 < nRowCol)
+                       v[t3 * nCol + f3] += 1;
+               if(f4 < nRowCol)
+                       v[t4 * nCol + f4] += 1;
+               if(f5 < nRowCol)
+                       v[t5 * nCol + f5] += 1;
+               if(f6 < nRowCol)
+                       v[t6 * nCol + f6] += 1;
+               if(f7 < nRowCol)
+                       v[t7 * nCol + f7] += 1;
+               if(f8 < nRowCol)
+                       v[t8 * nCol + f8] += 1;
+       }
+
+       private void addToEntryVectorizedNorm(double[] v, int f1, int f2, int 
f3, int f4, int f5, int f6, int f7, int f8,
+               int t1, int t2, int t3, int t4, int t5, int t6, int t7, int t8, 
int nCol) {
+               v[t1 * nCol + f1] += 1;
+               v[t2 * nCol + f2] += 1;
+               v[t3 * nCol + f3] += 1;
+               v[t4 * nCol + f4] += 1;
+               v[t5 * nCol + f5] += 1;
+               v[t6 * nCol + f6] += 1;
+               v[t7 * nCol + f7] += 1;
+               v[t8 * nCol + f8] += 1;
        }
 
        @Override
@@ -466,7 +545,28 @@ public class IdentityDictionary extends ADictionary {
        @Override
        public IDictionary preaggValuesFromDense(final int numVals, final 
IColIndex colIndexes,
                final IColIndex aggregateColumns, final double[] b, final int 
cut) {
-               return getMBDict().preaggValuesFromDense(numVals, colIndexes, 
aggregateColumns, b, cut);
+               /**
+                * This operations is Essentially a Identity matrix 
multiplication with a right hand side dense matrix, but we
+                * need to slice out the right hand side from the input.
+                * 
+                * ColIndexes specify the rows to slice out of the right matrix.
+                * 
+                * aggregate columns specify the columns to slice out from the 
right.
+                */
+               final int cs = colIndexes.size();
+               final int s = aggregateColumns.size();
+
+               double[] ret = new double[s * numVals];
+               int off = 0;
+               for(int i = 0; i < cs; i++) {// rows on right
+                       final int offB = colIndexes.get(i) * cut;
+                       for(int j = 0; j < s; j++) {
+                               ret[off++] = b[offB + aggregateColumns.get(j)];
+                       }
+               }
+
+               MatrixBlock db = new MatrixBlock(numVals, s, ret);
+               return new MatrixBlockDictionary(db);
        }
 
        @Override
@@ -529,7 +629,10 @@ public class IdentityDictionary extends ADictionary {
 
        @Override
        public double getSparsity() {
-               return 1d / nRowCol;
+               if(withEmpty)
+                       return 1d / (nRowCol + 1);
+               else
+                       return 1d / nRowCol;
        }
 
        @Override
@@ -545,13 +648,44 @@ public class IdentityDictionary extends ADictionary {
        @Override
        public void MMDict(IDictionary right, IColIndex rowsLeft, IColIndex 
colsRight, MatrixBlock result) {
                getMBDict().MMDict(right, rowsLeft, colsRight, result);
-               // should replace with add to right to output cells.
+       }
+
+       public void MMDictScaling(IDictionary right, IColIndex rowsLeft, 
IColIndex colsRight, MatrixBlock result,
+               int[] scaling) {
+               getMBDict().MMDictScaling(right, rowsLeft, colsRight, result, 
scaling);
        }
 
        @Override
        public void MMDictDense(double[] left, IColIndex rowsLeft, IColIndex 
colsRight, MatrixBlock result) {
-               getMBDict().MMDictDense(left, rowsLeft, colsRight, result);
+               // getMBDict().MMDictDense(left, rowsLeft, colsRight, result);
                // should replace with add to right to output cells.
+               final int leftSide = rowsLeft.size();
+               final int resCols = result.getNumColumns();
+               final int commonDim = Math.min(left.length / leftSide, nRowCol);
+               final double[] resV = result.getDenseBlockValues();
+               for(int i = 0; i < leftSide; i++) {// rows in left side
+                       final int offOut = rowsLeft.get(i) * resCols;
+                       final int leftOff = i * leftSide;
+                       for(int j = 0; j < commonDim; j++) { // cols in left 
side skipping empty from identity
+                               resV[offOut + colsRight.get(j)] += left[leftOff 
+ j];
+                       }
+               }
+       }
+
+       @Override
+       public void MMDictScalingDense(double[] left, IColIndex rowsLeft, 
IColIndex colsRight, MatrixBlock result,
+               int[] scaling) {
+               final int leftSide = rowsLeft.size();
+               final int resCols = result.getNumColumns();
+               final int commonDim = Math.min(left.length / leftSide, nRowCol);
+               final double[] resV = result.getDenseBlockValues();
+               for(int i = 0; i < leftSide; i++) {// rows in left side
+                       final int offOut = rowsLeft.get(i) * resCols;
+                       final int leftOff = i * leftSide;
+                       for(int j = 0; j < commonDim; j++) { // cols in left 
side skipping empty from identity
+                               resV[offOut + colsRight.get(j)] += left[leftOff 
+ j] * scaling[j];
+                       }
+               }
        }
 
        @Override
@@ -559,6 +693,12 @@ public class IdentityDictionary extends ADictionary {
                getMBDict().MMDictSparse(left, rowsLeft, colsRight, result);
        }
 
+       @Override
+       public void MMDictScalingSparse(SparseBlock left, IColIndex rowsLeft, 
IColIndex colsRight, MatrixBlock result,
+               int[] scaling) {
+               getMBDict().MMDictScalingSparse(left, rowsLeft, colsRight, 
result, scaling);
+       }
+
        @Override
        public void TSMMToUpperTriangle(IDictionary right, IColIndex rowsLeft, 
IColIndex colsRight, MatrixBlock result) {
                getMBDict().TSMMToUpperTriangle(right, rowsLeft, colsRight, 
result);
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionarySlice.java
 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionarySlice.java
index 6a282e8b26..167328871b 100644
--- 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionarySlice.java
+++ 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/IdentityDictionarySlice.java
@@ -69,13 +69,13 @@ public class IdentityDictionarySlice extends 
IdentityDictionary {
        @Override
        public double getValue(int i) {
                throw new NotImplementedException();
-
        }
 
        @Override
        public final double getValue(int r, int c, int nCol) {
-               throw new NotImplementedException();
-
+               if(r < l || r > u)
+                       return 0;
+               return super.getValue(r - l, c, nCol);
        }
 
        @Override
@@ -278,6 +278,23 @@ public class IdentityDictionarySlice extends 
IdentityDictionary {
                return 1d / nRowCol;
        }
 
+       @Override
+       public IDictionary preaggValuesFromDense(final int numVals, final 
IColIndex colIndexes,
+               final IColIndex aggregateColumns, final double[] b, final int 
cut) {
+               return getMBDict().preaggValuesFromDense(numVals, colIndexes, 
aggregateColumns, b, cut);
+       }
+
+       @Override
+       public void addToEntryVectorized(double[] v, int f1, int f2, int f3, 
int f4, int f5, int f6, int f7, int f8, int t1,
+               int t2, int t3, int t4, int t5, int t6, int t7, int t8, int 
nCol) {
+               throw new NotImplementedException();
+       }
+
+       @Override
+       public void addToEntry(final double[] v, final int fr, final int to, 
final int nCol, int rep) {
+               throw new NotImplementedException();
+       }
+
        @Override
        public boolean equals(IDictionary o) {
                if(o instanceof IdentityDictionarySlice) {
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java
 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java
index 3995fc4e36..2a800837c7 100644
--- 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java
+++ 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/MatrixBlockDictionary.java
@@ -88,8 +88,27 @@ public class MatrixBlockDictionary extends ADictionary {
        }
 
        public static MatrixBlockDictionary createDictionary(double[] values, 
int nCol, boolean check) {
-               final MatrixBlock mb = Util.matrixBlockFromDenseArray(values, 
nCol, check);
-               return create(mb, check);
+               if(nCol <= 1) {
+                       final MatrixBlock mb = 
Util.matrixBlockFromDenseArray(values, nCol, check);
+                       return create(mb, check);
+               }
+               else {
+                       final int nnz = checkNNz(values);
+                       if((double) nnz / values.length < 0.4D) {
+                               SparseBlock sb = 
SparseBlockFactory.createFromArray(values, nCol, nnz);
+                               MatrixBlock mb = new MatrixBlock(values.length 
/ nCol, nCol, nnz, sb);
+                               return create(mb, false);
+                       }
+                       else
+                               return 
create(Util.matrixBlockFromDenseArray(values, nCol, check), false);
+               }
+       }
+
+       private static int checkNNz(double[] values) {
+               int nnz = 0;
+               for(int i = 0; i < values.length; i++)
+                       nnz += values[i] == 0 ? 0 : 1;
+               return nnz;
        }
 
        public MatrixBlock getMatrixBlock() {
@@ -837,6 +856,9 @@ public class MatrixBlockDictionary extends ADictionary {
 
        @Override
        public int getNumberOfValues(int ncol) {
+
+               if(ncol != _data.getNumColumns())
+                       throw new DMLCompressionException("Invalid call to get 
Number of values assuming wrong number of columns");
                return _data.getNumRows();
        }
 
@@ -1771,15 +1793,15 @@ public class MatrixBlockDictionary extends ADictionary {
                        }
                }
                else {
-                       double[] values = _data.getDenseBlockValues();
-                       for(int k = 0, off = 0;
-                               k < numVals * colIndexes.size();
-                               k += colIndexes.size(), off += 
aggregateColumns.size()) {
-                               for(int h = 0; h < colIndexes.size(); h++) {
-                                       int idb = colIndexes.get(h) * cut;
+                       final int cz = colIndexes.size();
+                       final int az = aggregateColumns.size();
+                       final double[] values = _data.getDenseBlockValues();
+                       for(int k = 0, off = 0; k < numVals * cz; k += cz, off 
+= az) {
+                               for(int h = 0; h < cz; h++) {
+                                       final int idb = colIndexes.get(h) * cut;
                                        double v = values[k + h];
                                        if(v != 0)
-                                               for(int i = 0; i < 
aggregateColumns.size(); i++)
+                                               for(int i = 0; i < az; i++)
                                                        ret[off + i] += v * 
b[idb + aggregateColumns.get(i)];
                                }
                        }
@@ -1801,10 +1823,14 @@ public class MatrixBlockDictionary extends ADictionary {
 
        @Override
        public IDictionary replaceWithReference(double pattern, double replace, 
double[] reference) {
+               if(Util.eq(pattern, Double.NaN))
+                       throw new NotImplementedException();
+
                final int nRow = _data.getNumRows();
                final int nCol = _data.getNumColumns();
                final MatrixBlock ret = new MatrixBlock(nRow, nCol, false);
                ret.allocateDenseBlock();
+
                final double[] retV = ret.getDenseBlockValues();
                int off = 0;
                if(_data.isInSparseFormat()) {
@@ -2030,6 +2056,15 @@ public class MatrixBlockDictionary extends ADictionary {
                        right.MMDictDense(_data.getDenseBlockValues(), 
rowsLeft, colsRight, result);
        }
 
+       @Override
+       public void MMDictScaling(IDictionary right, IColIndex rowsLeft, 
IColIndex colsRight, MatrixBlock result,
+               int[] scaling) {
+               if(_data.isInSparseFormat())
+                       right.MMDictScalingSparse(_data.getSparseBlock(), 
rowsLeft, colsRight, result, scaling);
+               else
+                       right.MMDictScalingDense(_data.getDenseBlockValues(), 
rowsLeft, colsRight, result, scaling);
+       }
+
        @Override
        public void MMDictDense(double[] left, IColIndex rowsLeft, IColIndex 
colsRight, MatrixBlock result) {
                if(_data.isInSparseFormat())
@@ -2038,6 +2073,15 @@ public class MatrixBlockDictionary extends ADictionary {
                        DictLibMatrixMult.MMDictsDenseDense(left, 
_data.getDenseBlockValues(), rowsLeft, colsRight, result);
        }
 
+       @Override
+       public void MMDictScalingDense(double[] left, IColIndex rowsLeft, 
IColIndex colsRight, MatrixBlock result,
+               int[] scaling) {
+               if(_data.isInSparseFormat())
+                       DictLibMatrixMult.MMDictsScalingDenseSparse(left, 
_data.getSparseBlock(), rowsLeft, colsRight, result, scaling);
+               else
+                       DictLibMatrixMult.MMDictsScalingDenseDense(left, 
_data.getDenseBlockValues(), rowsLeft, colsRight, result,scaling);
+       }
+
        @Override
        public void MMDictSparse(SparseBlock left, IColIndex rowsLeft, 
IColIndex colsRight, MatrixBlock result) {
 
@@ -2047,6 +2091,15 @@ public class MatrixBlockDictionary extends ADictionary {
                        DictLibMatrixMult.MMDictsSparseDense(left, 
_data.getDenseBlockValues(), rowsLeft, colsRight, result);
        }
 
+       @Override
+       public void MMDictScalingSparse(SparseBlock left, IColIndex rowsLeft, 
IColIndex colsRight, MatrixBlock result,
+               int[] scaling) {
+               if(_data.isInSparseFormat())
+                       DictLibMatrixMult.MMDictsScalingSparseSparse(left, 
_data.getSparseBlock(), rowsLeft, colsRight, result, scaling);
+               else
+                       DictLibMatrixMult.MMDictsScalingSparseDense(left, 
_data.getDenseBlockValues(), rowsLeft, colsRight, result, scaling);
+       }
+
        @Override
        public void TSMMToUpperTriangle(IDictionary right, IColIndex rowsLeft, 
IColIndex colsRight, MatrixBlock result) {
                if(_data.isInSparseFormat())
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/PlaceHolderDict.java
 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/PlaceHolderDict.java
index 94fa9ef528..51c41ffeec 100644
--- 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/PlaceHolderDict.java
+++ 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/PlaceHolderDict.java
@@ -507,4 +507,22 @@ public class PlaceHolderDict implements IDictionary, 
Serializable {
                return new PlaceHolderDict(nVal);
        }
 
+       @Override
+       public void MMDictScaling(IDictionary right, IColIndex rowsLeft, 
IColIndex colsRight, MatrixBlock result,
+               int[] scaling) {
+               throw new RuntimeException(errMessage);
+       }
+
+       @Override
+       public void MMDictScalingDense(double[] left, IColIndex rowsLeft, 
IColIndex colsRight, MatrixBlock result,
+               int[] scaling) {
+               throw new RuntimeException(errMessage);
+       }
+
+       @Override
+       public void MMDictScalingSparse(SparseBlock left, IColIndex rowsLeft, 
IColIndex colsRight, MatrixBlock result,
+               int[] scaling) {
+               throw new RuntimeException(errMessage);
+       }
+
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/QDictionary.java
 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/QDictionary.java
index b55a291ae3..ae833dd7a9 100644
--- 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/QDictionary.java
+++ 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/dictionary/QDictionary.java
@@ -613,4 +613,22 @@ public class QDictionary extends ADictionary {
        public IDictionary reorder(int[] reorder) {
                throw new NotImplementedException();
        }
+
+       @Override
+       public void MMDictScaling(IDictionary right, IColIndex rowsLeft, 
IColIndex colsRight, MatrixBlock result,
+               int[] scaling) {
+               throw new NotImplementedException();
+       }
+
+       @Override
+       public void MMDictScalingDense(double[] left, IColIndex rowsLeft, 
IColIndex colsRight, MatrixBlock result,
+               int[] scaling) {
+               throw new NotImplementedException();
+       }
+
+       @Override
+       public void MMDictScalingSparse(SparseBlock left, IColIndex rowsLeft, 
IColIndex colsRight, MatrixBlock result,
+               int[] scaling) {
+               throw new NotImplementedException();
+       }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/data/SparseBlockFactory.java 
b/src/main/java/org/apache/sysds/runtime/data/SparseBlockFactory.java
index 66f07ab6ad..6b04cf6d71 100644
--- a/src/main/java/org/apache/sysds/runtime/data/SparseBlockFactory.java
+++ b/src/main/java/org/apache/sysds/runtime/data/SparseBlockFactory.java
@@ -19,10 +19,16 @@
 
 package org.apache.sysds.runtime.data;
 
+import java.util.Arrays;
+
+import org.apache.commons.lang.NotImplementedException;
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 
-public abstract class SparseBlockFactory
-{
+public abstract class SparseBlockFactory{
+               protected static final Log LOG = 
LogFactory.getLog(SparseBlockFactory.class.getName());
+
 
        public static SparseBlock createSparseBlock(int rlen) {
                return createSparseBlock(MatrixBlock.DEFAULT_SPARSEBLOCK, rlen);
@@ -117,4 +123,39 @@ public abstract class SparseBlockFactory
                rowPtr[nRowCol+1] = nRowCol;
                return new SparseBlockCSR(rowPtr, colIdx, vals, nnz);
        }
+
+       /**
+        * Create a sparse block from an array. Note that the nnz count should 
be absolutely correct for this call to work.
+        * 
+        * @param valsDense a double array of values linearized.
+        * @param nCol The number of columns in reach row.
+        * @param nnz  The number of non zero values.
+        * @return A sparse block.
+        */
+       public static SparseBlock createFromArray(final double[] valsDense, 
final  int nCol, final int nnz) {
+               final int nRow = valsDense.length / nCol;
+               if(nnz > 0) {
+
+                       final int[] rowPtr = new int[nRow + 1];
+                       final int[] colIdx = new int[nnz];
+                       final double[] valsSparse = new double[nnz];
+                       int off = 0;
+                       for(int i = 0; i < valsDense.length; i++) {
+                               final int mod = i % nCol;
+                               if(mod == 0)
+                                       rowPtr[i / nCol] = off;
+                               if(valsDense[i] != 0) {
+                                       valsSparse[off] = valsDense[i];
+                                       colIdx[off] = mod;
+                                       off++;
+                               }
+                       }
+                       rowPtr[rowPtr.length -1] = off;
+
+                       return new SparseBlockCSR(rowPtr, colIdx, valsSparse, 
nnz);
+               }
+               else {
+                       return new SparseBlockMCSR(nRow); // empty MCSR block
+               }
+       }
 }
diff --git a/src/test/java/org/apache/sysds/test/TestUtils.java 
b/src/test/java/org/apache/sysds/test/TestUtils.java
index acda5eaf83..e090912e86 100644
--- a/src/test/java/org/apache/sysds/test/TestUtils.java
+++ b/src/test/java/org/apache/sysds/test/TestUtils.java
@@ -2044,6 +2044,17 @@ public class TestUtils {
                return matrix;
        }
 
+       public static double[] generateTestVector(int cols, double min, double 
max, double sparsity, long seed) {
+               double[] vector = new double[cols];
+               Random random = (seed == -1) ? TestUtils.random : new 
Random(seed);
+               for(int j = 0; j < cols; j++) {
+                       if(random.nextDouble() > sparsity)
+                               continue;
+                       vector[j] = (random.nextDouble() * (max - min) + min);
+               }
+               return vector;
+       }
+
        /**
         *
         * Generates a test matrix with the specified parameters as a 
MatrixBlock.
diff --git 
a/src/test/java/org/apache/sysds/test/component/compress/dictionary/DictionaryTests.java
 
b/src/test/java/org/apache/sysds/test/component/compress/dictionary/DictionaryTests.java
index 91707565f3..9307930f1d 100644
--- 
a/src/test/java/org/apache/sysds/test/component/compress/dictionary/DictionaryTests.java
+++ 
b/src/test/java/org/apache/sysds/test/component/compress/dictionary/DictionaryTests.java
@@ -21,21 +21,29 @@ package org.apache.sysds.test.component.compress.dictionary;
 
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNotEquals;
+import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
 
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collection;
 import java.util.List;
 
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.sysds.runtime.compress.DMLCompressionException;
-import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary;
 import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary;
+import org.apache.sysds.runtime.compress.colgroup.dictionary.IDictionary;
+import 
org.apache.sysds.runtime.compress.colgroup.dictionary.IdentityDictionary;
 import 
org.apache.sysds.runtime.compress.colgroup.dictionary.MatrixBlockDictionary;
+import org.apache.sysds.runtime.compress.colgroup.indexes.ColIndexFactory;
+import org.apache.sysds.runtime.compress.colgroup.indexes.IColIndex;
 import org.apache.sysds.runtime.functionobjects.Builtin;
 import org.apache.sysds.runtime.functionobjects.Builtin.BuiltinCode;
+import org.apache.sysds.runtime.functionobjects.Divide;
+import org.apache.sysds.runtime.functionobjects.Minus;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
 import org.apache.sysds.test.TestUtils;
 import org.junit.Test;
 import org.junit.runner.RunWith;
@@ -72,6 +80,30 @@ public class DictionaryTests {
                        addAll(tests, new double[] {1, 2, 3, 4, 5, 6}, 2);
                        addAll(tests, new double[] {1, 2.2, 3.3, 4.4, 5.5, 
6.6}, 3);
 
+                       tests.add(new Object[] {new IdentityDictionary(2), 
Dictionary.create(new double[] {1, 0, 0, 1}), 2, 2});
+                       tests.add(new Object[] {new IdentityDictionary(2, 
true), //
+                               Dictionary.create(new double[] {1, 0, 0, 1, 0, 
0}), 3, 2});
+                       tests.add(new Object[] {new IdentityDictionary(3), //
+                               Dictionary.create(new double[] {1, 0, 0, 0, 1, 
0, 0, 0, 1}), 3, 3});
+                       tests.add(new Object[] {new IdentityDictionary(3, 
true), //
+                               Dictionary.create(new double[] {1, 0, 0, 0, 1, 
0, 0, 0, 1, 0, 0, 0}), 4, 3});
+
+                       tests.add(new Object[] {new IdentityDictionary(4), //
+                               Dictionary.create(new double[] {//
+                                       1, 0, 0, 0, //
+                                       0, 1, 0, 0, //
+                                       0, 0, 1, 0, //
+                                       0, 0, 0, 1,//
+                               }), 4, 4});
+                       tests.add(new Object[] {new IdentityDictionary(4, 
true), //
+                               Dictionary.create(new double[] {//
+                                       1, 0, 0, 0, //
+                                       0, 1, 0, 0, //
+                                       0, 0, 1, 0, //
+                                       0, 0, 0, 1, //
+                                       0, 0, 0, 0}),
+                               5, 4});
+
                        create(tests, 30, 300, 0.2);
                }
                catch(Exception e) {
@@ -405,6 +437,170 @@ public class DictionaryTests {
                containsValueWithReference(1.0, getReference(nCol, 3241, -1.0, 
-1.0));
        }
 
+       @Test
+       public void equalsEl() {
+               assertEquals(a, b);
+       }
+
+       @Test
+       public void opRightMinus() {
+               BinaryOperator op = new 
BinaryOperator(Minus.getMinusFnObject());
+               double[] vals = TestUtils.generateTestVector(nCol, -1, 1, 1.0, 
132L);
+               opRight(op, vals, ColIndexFactory.create(0, nCol));
+       }
+
+       @Test
+       public void opRightMinusNoCol() {
+               BinaryOperator op = new 
BinaryOperator(Minus.getMinusFnObject());
+               double[] vals = TestUtils.generateTestVector(nCol, -1, 1, 1.0, 
132L);
+               opRight(op, vals);
+       }
+
+       @Test
+       public void opRightMinusZero() {
+               BinaryOperator op = new 
BinaryOperator(Minus.getMinusFnObject());
+               double[] vals = new double[nCol];
+               opRight(op, vals, ColIndexFactory.create(0, nCol));
+       }
+
+       @Test
+       public void opRightDivOne() {
+               BinaryOperator op = new 
BinaryOperator(Divide.getDivideFnObject());
+               double[] vals = new double[nCol];
+               Arrays.fill(vals, 1);
+               opRight(op, vals, ColIndexFactory.create(0, nCol));
+       }
+
+       @Test
+       public void opRightDiv() {
+               BinaryOperator op = new 
BinaryOperator(Divide.getDivideFnObject());
+               double[] vals = TestUtils.generateTestVector(nCol, -1, 1, 1.0, 
232L);
+               opRight(op, vals, ColIndexFactory.create(0, nCol));
+       }
+
+       private void opRight(BinaryOperator op, double[] vals, IColIndex cols) {
+               IDictionary aa = a.binOpRight(op, vals, cols);
+               IDictionary bb = b.binOpRight(op, vals, cols);
+               compare(aa, bb, nRow, nCol);
+       }
+
+       private void opRight(BinaryOperator op, double[] vals) {
+               IDictionary aa = a.binOpRight(op, vals);
+               IDictionary bb = b.binOpRight(op, vals);
+               compare(aa, bb, nRow, nCol);
+       }
+
+       @Test
+       public void testAddToEntry1() {
+               double[] ret1 = new double[nCol];
+               a.addToEntry(ret1, 0, 0, nCol);
+               double[] ret2 = new double[nCol];
+               b.addToEntry(ret2, 0, 0, nCol);
+               assertTrue(Arrays.equals(ret1, ret2));
+       }
+
+       @Test
+       public void testAddToEntry2() {
+               double[] ret1 = new double[nCol * 2];
+               a.addToEntry(ret1, 0, 1, nCol);
+               double[] ret2 = new double[nCol * 2];
+               b.addToEntry(ret2, 0, 1, nCol);
+               assertTrue(Arrays.equals(ret1, ret2));
+       }
+
+       @Test
+       public void testAddToEntry3() {
+               double[] ret1 = new double[nCol * 3];
+               a.addToEntry(ret1, 0, 2, nCol);
+               double[] ret2 = new double[nCol * 3];
+               b.addToEntry(ret2, 0, 2, nCol);
+               assertTrue(Arrays.equals(ret1, ret2));
+       }
+
+       @Test
+       public void testAddToEntry4() {
+               if(a.getNumberOfValues(nCol) > 2) {
+
+                       double[] ret1 = new double[nCol * 3];
+                       a.addToEntry(ret1, 2, 2, nCol);
+                       double[] ret2 = new double[nCol * 3];
+                       b.addToEntry(ret2, 2, 2, nCol);
+                       assertTrue(Arrays.equals(ret1, ret2));
+               }
+       }
+
+       @Test
+       public void testAddToEntryVectorized1() {
+               try {
+                       double[] ret1 = new double[nCol * 3];
+                       a.addToEntryVectorized(ret1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
1, 2, 0, 1, 2, 0, 1, nCol);
+                       double[] ret2 = new double[nCol * 3];
+                       b.addToEntryVectorized(ret2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 
1, 2, 0, 1, 2, 0, 1, nCol);
+                       assertTrue(Arrays.equals(ret1, ret2));
+               }
+               catch(Exception e) {
+                       e.printStackTrace();
+                       fail(e.getMessage());
+               }
+       }
+
+       @Test
+       public void testAddToEntryVectorized2() {
+               try {
+
+                       if(a.getNumberOfValues(nCol) > 1) {
+                               double[] ret1 = new double[nCol * 3];
+                               a.addToEntryVectorized(ret1, 1, 0, 1, 0, 1, 0, 
1, 0, 0, 1, 2, 0, 1, 2, 0, 1, nCol);
+                               double[] ret2 = new double[nCol * 3];
+                               b.addToEntryVectorized(ret2, 1, 0, 1, 0, 1, 0, 
1, 0, 0, 1, 2, 0, 1, 2, 0, 1, nCol);
+                               assertTrue("Error: " + 
a.getClass().getSimpleName() + " " + b.getClass().getSimpleName(),
+                                       Arrays.equals(ret1, ret2));
+                       }
+               }
+               catch(Exception e) {
+                       e.printStackTrace();
+                       fail(e.getMessage());
+               }
+       }
+
+       @Test
+       public void testAddToEntryVectorized3() {
+               try {
+
+                       if(a.getNumberOfValues(nCol) > 2) {
+                               double[] ret1 = new double[nCol * 3];
+                               a.addToEntryVectorized(ret1, 1, 2, 1, 2, 1, 0, 
1, 0, 0, 1, 2, 0, 1, 2, 0, 1, nCol);
+                               double[] ret2 = new double[nCol * 3];
+                               b.addToEntryVectorized(ret2, 1, 2, 1, 2, 1, 0, 
1, 0, 0, 1, 2, 0, 1, 2, 0, 1, nCol);
+                               assertTrue("Error: " + 
a.getClass().getSimpleName() + " " + b.getClass().getSimpleName(),
+                                       Arrays.equals(ret1, ret2));
+                       }
+               }
+               catch(Exception e) {
+                       e.printStackTrace();
+                       fail(e.getMessage());
+               }
+       }
+
+       @Test
+       public void testAddToEntryVectorized4() {
+               try {
+
+                       if(a.getNumberOfValues(nCol) > 3) {
+                               double[] ret1 = new double[nCol * 57];
+                               a.addToEntryVectorized(ret1, 3, 3, 0, 3, 0, 2, 
0, 3, 20, 1, 12, 2, 10, 3, 6, 56, nCol);
+                               double[] ret2 = new double[nCol * 57];
+                               b.addToEntryVectorized(ret2, 3, 3, 0, 3, 0, 2, 
0, 3, 20, 1, 12, 2, 10, 3, 6, 56, nCol);
+                               assertTrue("Error: " + 
a.getClass().getSimpleName() + " " + b.getClass().getSimpleName(),
+                                       Arrays.equals(ret1, ret2));
+                       }
+               }
+               catch(Exception e) {
+                       e.printStackTrace();
+                       fail(e.getMessage());
+               }
+       }
+
        public void containsValueWithReference(double value, double[] 
reference) {
                assertEquals(//
                        a.containsValueWithReference(value, reference), //
@@ -412,9 +608,37 @@ public class DictionaryTests {
        }
 
        private static void compare(IDictionary a, IDictionary b, int nRow, int nCol) {
-               for(int i = 0; i < nRow; i++)
-                       for(int j = 0; j < nCol; j++)
-                               assertEquals(a.getValue(i, j, nCol), b.getValue(i, j, nCol), 0.0001);
+               try {
+                       // Cell-by-cell equality within 1e-4; the failure message names both dictionary implementations.
+                       final String types = a.getClass().getSimpleName() + " " + b.getClass().getSimpleName();
+                       for(int row = 0; row < nRow; row++)
+                               for(int col = 0; col < nCol; col++)
+                                       assertEquals(types, a.getValue(row, col, nCol), b.getValue(row, col, nCol), 0.0001);
+               }
+               catch(Exception e) {
+                       e.printStackTrace();
+                       fail(e.getMessage());
+               }
+       }
+
+       @Test
+       public void preaggValuesFromDense() {
+               try {
+                       // Pre-aggregate both dictionaries against the same seeded vector of nCol * nCol values, then compare.
+                       final int numVals = a.getNumberOfValues(nCol);
+                       final IColIndex cols = ColIndexFactory.create(0, nCol);
+
+                       final double[] rhs = TestUtils.generateTestVector(nCol * nCol, -1, 1, 1.0, 321521L);
+                       final IDictionary fromA = a.preaggValuesFromDense(numVals, cols, cols, rhs, nCol);
+                       final IDictionary fromB = b.preaggValuesFromDense(numVals, cols, cols, rhs, nCol);
+
+                       final int rows = fromA.getNumberOfValues(nCol);
+                       compare(fromA, fromB, rows, nCol);
+               }
+               catch(Exception e) {
+                       e.printStackTrace();
+                       fail(e.getMessage());
+               }
        }
 
        public void productWithDefault(double retV, double[] def) {
diff --git a/src/test/java/org/apache/sysds/test/component/matrix/SparseFactory.java b/src/test/java/org/apache/sysds/test/component/matrix/SparseFactory.java
new file mode 100644
index 0000000000..6f80eb4f41
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/component/matrix/SparseFactory.java
@@ -0,0 +1,50 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.component.matrix;
+
+import static org.junit.Assert.assertEquals;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.runtime.data.SparseBlock;
+import org.apache.sysds.runtime.data.SparseBlockFactory;
+import org.junit.Test;
+
+/**
+ * Component tests for {@link SparseBlockFactory}.
+ */
+public class SparseFactory {
+       protected static final Log LOG = LogFactory.getLog(SparseFactory.class.getName());
+
+       /**
+        * Converts a 3x3 dense array that is non-zero only at linear indices 3-5 and checks the three diagonal cells.
+        */
+       @Test
+       public void testCreateFromArray() {
+               double[] dense = new double[] {0, 0, 0, 1, 1, 1, 0, 0, 0};
+               SparseBlock sb = SparseBlockFactory.createFromArray(dense, 3, 3);
+
+               // assertEquals(expected, actual, delta): compare exactly so a wrong value is not absorbed by the tolerance.
+               // (1,1) is linear index 4, so it must read back 1 under both row- and column-major layouts.
+               assertEquals(0.0, sb.get(0, 0), 0.0);
+               assertEquals(1.0, sb.get(1, 1), 0.0);
+               assertEquals(0.0, sb.get(2, 2), 0.0);
+       }
+}

Reply via email to