This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new cc8eb95135 [SYSTEMDS-2830] Functional Compression
cc8eb95135 is described below

commit cc8eb951358320a984ea6798d3d252b4eed5c80c
Author: wedenigt <[email protected]>
AuthorDate: Fri Jun 24 17:42:24 2022 +0200

    [SYSTEMDS-2830] Functional Compression
    
    This commit adds a new column group class for functional compression.
    Initial implementation covers a linear compression scheme.
    The new column group supports construction from a matrix block and
    most operations, and is covered by new tests.
    
    Closes #1634
    Closes #1645
---
 .../sysds/runtime/compress/colgroup/AColGroup.java |   5 +-
 .../runtime/compress/colgroup/ColGroupFactory.java |  26 +-
 .../colgroup/ColGroupLinearFunctional.java         | 665 +++++++++++++++++++++
 .../runtime/compress/colgroup/ColGroupSizes.java   |   9 +
 .../compress/colgroup/ColGroupUncompressed.java    |  30 +-
 .../runtime/compress/colgroup/ColGroupUtils.java   |  39 ++
 .../colgroup/functional/LinearRegression.java      |  75 +++
 .../compress/estim/CompressedSizeInfoColGroup.java |   2 +
 .../colgroup/ColGroupLinearFunctionalBase.java     | 252 ++++++++
 .../colgroup/ColGroupLinearFunctionalTest.java     | 346 +++++++++++
 .../compress/functional/LinearRegressionTests.java | 115 ++++
 11 files changed, 1526 insertions(+), 38 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroup.java 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroup.java
index 45a0d62df7..557c0269b3 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroup.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/AColGroup.java
@@ -52,7 +52,7 @@ public abstract class AColGroup implements Serializable {
 
        /** Public super types of compression ColGroups supported */
        public static enum CompressionType {
-               UNCOMPRESSED, RLE, OLE, DDC, CONST, EMPTY, SDC, SDCFOR, DDCFOR, 
DeltaDDC
+               UNCOMPRESSED, RLE, OLE, DDC, CONST, EMPTY, SDC, SDCFOR, DDCFOR, 
DeltaDDC, LinearFunctional;
        }
 
        /**
@@ -61,7 +61,8 @@ public abstract class AColGroup implements Serializable {
         * Protected such that outside the ColGroup package it should be 
unknown which specific subtype is used.
         */
        protected static enum ColGroupType {
-               UNCOMPRESSED, RLE, OLE, DDC, CONST, EMPTY, SDC, SDCSingle, 
SDCSingleZeros, SDCZeros, SDCFOR, DDCFOR, DeltaDDC;
+               UNCOMPRESSED, RLE, OLE, DDC, CONST, EMPTY, SDC, SDCSingle, 
SDCSingleZeros, SDCZeros, SDCFOR, DDCFOR, DeltaDDC,
+               LinearFunctional;
        }
 
        /** The ColGroup Indexes contained in the ColGroup */
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupFactory.java 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupFactory.java
index 8a73ec1b54..c9a8e894c7 100644
--- 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupFactory.java
+++ 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupFactory.java
@@ -43,6 +43,7 @@ import 
org.apache.sysds.runtime.compress.colgroup.dictionary.ADictionary;
 import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary;
 import org.apache.sysds.runtime.compress.colgroup.dictionary.DictionaryFactory;
 import 
org.apache.sysds.runtime.compress.colgroup.dictionary.MatrixBlockDictionary;
+import org.apache.sysds.runtime.compress.colgroup.functional.LinearRegression;
 import 
org.apache.sysds.runtime.compress.colgroup.insertionsort.AInsertionSorter;
 import 
org.apache.sysds.runtime.compress.colgroup.insertionsort.InsertionSorterFactory;
 import org.apache.sysds.runtime.compress.colgroup.mapping.AMapToData;
@@ -129,13 +130,13 @@ public class ColGroupFactory {
        }
 
        private List<AColGroup> compress() {
-               try{
+               try {
                        List<AColGroup> ret = compressExecute();
                        if(pool != null)
                                pool.shutdown();
-                               return ret;
+                       return ret;
                }
-               catch(Exception e ){
+               catch(Exception e) {
                        if(pool != null)
                                pool.shutdown();
                        throw new DMLCompressionException("Compression Failed", 
e);
@@ -359,6 +360,8 @@ public class ColGroupFactory {
                        return compressSDCFromSparseTransposedBlock(colIndexes, 
nrUniqueEstimate, cg.getTupleSparsity());
                else if(ct == CompressionType.DDC)
                        return directCompressDDC(colIndexes, cg);
+               else if(ct == CompressionType.LinearFunctional)
+                       return compressLinearFunctional(colIndexes, in, cs);
                else {
                        LOG.debug("Default slow path: " + ct + "  " + 
cs.transposed + " " + Arrays.toString(colIndexes));
                        final int numRows = cs.transposed ? in.getNumColumns() 
: in.getNumRows();
@@ -445,19 +448,20 @@ public class ColGroupFactory {
                if(dict == null)
                        // Again highly unlikely but possible.
                        return new ColGroupEmpty(colIndexes);
-               try{
+               try {
                        if(extra)
                                d.replace(fill, map.size());
-       
+
                        final int nUnique = map.size() + (extra ? 1 : 0);
-       
+
                        final AMapToData resData = MapToFactory.resize(d, 
nUnique);
                        return ColGroupDDC.create(colIndexes, nRow, dict, 
resData, null);
 
                }
-               catch(Exception e ){
+               catch(Exception e) {
                        ReaderColumnSelection reader = 
ReaderColumnSelection.createReader(in, colIndexes, cs.transposed, 0, nRow);
-                       throw new DMLCompressionException("direct compress DDC 
Multi col failed extra:" + extra + " with reader type:" + 
reader.getClass().getSimpleName(), e);
+                       throw new DMLCompressionException("direct compress DDC 
Multi col failed extra:" + extra + " with reader type:"
+                               + reader.getClass().getSimpleName(), e);
                }
        }
 
@@ -653,6 +657,12 @@ public class ColGroupFactory {
                return ColGroupSDCSingle.create(colIndexes, rlen, dict, 
defaultTuple, off, null);
        }
 
+       /**
+        * Compress the given columns into a linear functional column group, using the coefficients returned by
+        * LinearRegression.regressMatrixBlock.
+        *
+        * @param colIndexes The column indexes to compress together.
+        * @param in         The matrix block the column values are read from.
+        * @param cs         The compression settings; only the transposed flag is read here.
+        * @return The column group returned by ColGroupLinearFunctional.create for the fitted coefficients.
+        */
+       private static AColGroup compressLinearFunctional(int[] colIndexes, MatrixBlock in, CompressionSettings cs) {
+               final boolean transposed = cs.transposed;
+               final double[] fitted = LinearRegression.regressMatrixBlock(in, colIndexes, transposed);
+               // for a transposed input the number of encoded rows is the number of columns of the block
+               final int nRow;
+               if(transposed)
+                       nRow = in.getNumColumns();
+               else
+                       nRow = in.getNumRows();
+               return ColGroupLinearFunctional.create(colIndexes, fitted, nRow);
+       }
+
        private static AColGroup compressDDC(int[] colIndexes, int rlen, 
ABitmap ubm, CompressionSettings cs,
                double tupleSparsity) {
                boolean zeros = ubm.getNumOffsets() < rlen;
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupLinearFunctional.java
 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupLinearFunctional.java
new file mode 100644
index 0000000000..ecb516724a
--- /dev/null
+++ 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupLinearFunctional.java
@@ -0,0 +1,665 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.compress.colgroup;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+import java.util.Arrays;
+
+import org.apache.commons.lang.NotImplementedException;
+import org.apache.sysds.runtime.compress.DMLCompressionException;
+import org.apache.sysds.runtime.compress.cost.ComputationCostEstimator;
+import org.apache.sysds.runtime.compress.utils.Util;
+import 
org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
+import org.apache.sysds.runtime.data.DenseBlock;
+import org.apache.sysds.runtime.data.SparseBlock;
+import org.apache.sysds.runtime.functionobjects.Builtin;
+import org.apache.sysds.runtime.functionobjects.Divide;
+import org.apache.sysds.runtime.functionobjects.Minus;
+import org.apache.sysds.runtime.functionobjects.Multiply;
+import org.apache.sysds.runtime.functionobjects.Plus;
+import org.apache.sysds.runtime.instructions.cp.CM_COV_Object;
+import org.apache.sysds.runtime.matrix.data.LibMatrixMult;
+import org.apache.sysds.runtime.matrix.data.LibMatrixReorg;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
+import org.apache.sysds.runtime.matrix.operators.CMOperator;
+import org.apache.sysds.runtime.matrix.operators.ScalarOperator;
+import org.apache.sysds.runtime.matrix.operators.UnaryOperator;
+import org.apache.sysds.utils.MemoryEstimates;
+
+public class ColGroupLinearFunctional extends AColGroupCompressed {
+
+       private static final long serialVersionUID = -2811822570758221975L;
+
+       // Needed for numerical robustness when checking if a value is 
contained in a column
+       private final static double CONTAINS_VALUE_THRESHOLD = 1e-6;
+
+       protected double[] _coefficents;
+
+       protected int _numRows;
+
+       /**
+        * Constructor for serialization. Only calls the super constructor; the coefficient array and the row count are
+        * not assigned here.
+        */
+       protected ColGroupLinearFunctional() {
+               super();
+       }
+
+       /**
+        * Create a column group that represents each of its columns by a straight line over the row number.
+        *
+        * The coefficient array is stored as given, without copying.
+        *
+        * @param colIndices  The column indexes of the column group.
+        * @param coefficents Array holding `colIndices.length` intercepts followed by `colIndices.length` slopes.
+        * @param numRows     Number of rows encoded within this column group.
+        */
+       private ColGroupLinearFunctional(int[] colIndices, double[] coefficents, int numRows) {
+               super(colIndices);
+               _numRows = numRows;
+               _coefficents = coefficents;
+       }
+
+       /**
+        * Factory for linear functional column groups that falls back to simpler groups when possible.
+        *
+        * If every slope is zero the columns are constant: a ColGroupEmpty is returned when all intercepts are zero as
+        * well, otherwise a ColGroupConst built from a copy of the intercepts. In all other cases a
+        * ColGroupLinearFunctional is returned.
+        *
+        * @param colIndices  The column indexes contained in this column group.
+        * @param coefficents Array holding `colIndices.length` intercepts followed by `colIndices.length` slopes.
+        * @param numRows     Number of rows encoded within this column group.
+        * @return The column group chosen as described above.
+        * @throws DMLCompressionException if the coefficient array does not have length `2 * colIndices.length`
+        */
+       public static AColGroup create(int[] colIndices, double[] coefficents, int numRows) {
+               final int nCol = colIndices.length;
+               if(coefficents.length != 2 * nCol)
+                       throw new DMLCompressionException("Invalid size of values compared to columns");
+
+               // a single non-zero slope is enough to require the linear representation
+               for(int j = nCol; j < 2 * nCol; j++) {
+                       if(coefficents[j] != 0)
+                               return new ColGroupLinearFunctional(colIndices, coefficents, numRows);
+               }
+
+               // all slopes are zero: only the intercepts carry information
+               for(int j = 0; j < nCol; j++) {
+                       if(coefficents[j] != 0)
+                               return ColGroupConst.create(colIndices, Arrays.copyOf(coefficents, nCol));
+               }
+
+               return new ColGroupEmpty(colIndices);
+       }
+
+       /**
+        * Get the intercept of a column of this group.
+        *
+        * @param colIdx Position of the column inside this column group (not the global column index).
+        * @return The entry of the coefficient array at position colIdx.
+        */
+       public double getInterceptForColumn(int colIdx) {
+               final double intercept = _coefficents[colIdx];
+               return intercept;
+       }
+
+       /**
+        * Get the slope of a column of this group.
+        *
+        * @param colIdx Position of the column inside this column group (not the global column index).
+        * @return The entry of the coefficient array stored `_colIndexes.length` positions after the intercept.
+        */
+       public double getSlopeForColumn(int colIdx) {
+               final int slopeOffset = _colIndexes.length;
+               return _coefficents[slopeOffset + colIdx];
+       }
+
+       /**
+        * Get the number of rows this column group was created with.
+        *
+        * @return The stored row count.
+        */
+       public int getNumRows() {
+               return this._numRows;
+       }
+
+       /** Not implemented for linear functional column groups: every call throws NotImplementedException. */
+       @Override
+       protected void computeRowMxx(double[] c, Builtin builtin, int rl, int 
ru, double[] preAgg) {
+               throw new NotImplementedException();
+       }
+
+       /** @return Always CompressionType.LinearFunctional. */
+       @Override
+       public CompressionType getCompType() {
+               return CompressionType.LinearFunctional;
+       }
+
+       /** @return Always ColGroupType.LinearFunctional. */
+       @Override
+       public ColGroupType getColGroupType() {
+               return ColGroupType.LinearFunctional;
+       }
+
+       /**
+        * Smallest value over all columns of this group.
+        *
+        * Per column only one candidate is inspected: `intercept + slope` when the slope is not negative and
+        * `intercept + _numRows * slope` when it is negative.
+        *
+        * @return The smallest candidate, or positive infinity if no candidate is smaller than that.
+        */
+       @Override
+       public double getMin() {
+               double min = Double.POSITIVE_INFINITY;
+
+               final int nCol = getNumCols();
+               for(int j = 0; j < nCol; j++) {
+                       final double a = getInterceptForColumn(j);
+                       final double b = getSlopeForColumn(j);
+                       // a non-decreasing line is smallest at row number 1, a decreasing one at row number _numRows
+                       final double candidate = b >= 0 ? a + b : a + _numRows * b;
+                       if(candidate < min)
+                               min = candidate;
+               }
+
+               return min;
+       }
+
+       /**
+        * Largest value over all columns of this group.
+        *
+        * Per column only one candidate is inspected: `intercept + _numRows * slope` when the slope is not negative and
+        * `intercept + slope` when it is negative.
+        *
+        * @return The largest candidate, or negative infinity if no candidate is larger than that.
+        */
+       @Override
+       public double getMax() {
+               double max = Double.NEGATIVE_INFINITY;
+
+               final int nCol = getNumCols();
+               for(int j = 0; j < nCol; j++) {
+                       final double a = getInterceptForColumn(j);
+                       final double b = getSlopeForColumn(j);
+                       // a non-decreasing line is largest at row number _numRows, a decreasing one at row number 1
+                       final double candidate = b >= 0 ? a + _numRows * b : a + b;
+                       if(candidate > max)
+                               max = candidate;
+               }
+
+               return max;
+       }
+
+       /**
+        * Decompress the rows rl (inclusive) to ru (exclusive) of this column group and add them onto a dense block.
+        *
+        * Each value is computed through getIdx, so row r always decodes to `intercept + slope * (r + 1)`. The previous
+        * running-sum implementation started at the intercept for every call and therefore produced the values of rows
+        * 0..(ru - rl) whenever rl was greater than zero.
+        *
+        * @param db   Target dense block; decoded values are added to its current content.
+        * @param rl   First row of this column group to decompress.
+        * @param ru   Row after the last row of this column group to decompress.
+        * @param offR Row offset added to the row index when addressing the target.
+        * @param offC Column offset added to the column index when addressing the target.
+        */
+       @Override
+       public void decompressToDenseBlock(DenseBlock db, int rl, int ru, int offR, int offC) {
+               final int nCol = getNumCols();
+
+               int offT = rl + offR;
+               for(int row = rl; row < ru; row++, offT++) {
+                       final double[] c = db.values(offT);
+                       final int off = db.pos(offT) + offC;
+
+                       for(int j = 0; j < nCol; j++)
+                               c[off + _colIndexes[j]] += getIdx(row, j);
+               }
+       }
+
+       /**
+        * Decompress the rows rl (inclusive) to ru (exclusive) of this column group by appending every value to a
+        * sparse block.
+        *
+        * @param ret  Target sparse block that the values are appended to.
+        * @param rl   First row of this column group to decompress.
+        * @param ru   Row after the last row of this column group to decompress.
+        * @param offR Row offset added to the row index when addressing the target.
+        * @param offC Column offset added to the column index when addressing the target.
+        */
+       @Override
+       public void decompressToSparseBlock(SparseBlock ret, int rl, int ru, int offR, int offC) {
+               for(int r = rl; r < ru; r++) {
+                       final int targetRow = r + offR;
+                       for(int j = 0; j < _colIndexes.length; j++) {
+                               final int targetCol = _colIndexes[j] + offC;
+                               ret.append(targetRow, targetCol, getIdx(r, j));
+                       }
+               }
+       }
+
+       /**
+        * Decode a single value.
+        *
+        * @param r      Zero-based row index; the line is evaluated at row number `r + 1`.
+        * @param colIdx Position of the column inside this column group.
+        * @return `intercept + slope * (r + 1)` for the given column.
+        */
+       @Override
+       public double getIdx(int r, int colIdx) {
+               final double intercept = getInterceptForColumn(colIdx);
+               final double slope = getSlopeForColumn(colIdx);
+               return intercept + slope * (r + 1);
+       }
+
+       /**
+        * Apply a scalar operation to this column group directly on the coefficients.
+        *
+        * Plus and Minus are applied to the intercepts only and the slopes are carried over unchanged; Multiply and
+        * Divide are applied to intercepts and slopes alike. The previous implementation copied the intercepts into the
+        * slope half of the new array (source position 0 instead of the slope offset), which discarded the slopes.
+        *
+        * NOTE(review): if op can be a left-hand scalar operator (scalar - X or scalar / X), the slopes would have to be
+        * negated respectively the result is no longer linear; that case is not distinguished here - confirm against the
+        * callers.
+        *
+        * @param op The scalar operator to apply.
+        * @return The column group returned by create for the new coefficients.
+        * @throws NotImplementedException for any function other than Plus, Minus, Multiply and Divide
+        */
+       @Override
+       public AColGroup scalarOperation(ScalarOperator op) {
+               final int nCol = getNumCols();
+               final double[] coefficients_new = new double[_coefficents.length];
+
+               if(op.fn instanceof Plus || op.fn instanceof Minus) {
+                       // slopes are stored behind the nCol intercepts and do not change if we add/subtract a scalar
+                       System.arraycopy(_coefficents, nCol, coefficients_new, nCol, nCol);
+                       // absorb plus/minus into intercept
+                       for(int col = 0; col < nCol; col++)
+                               coefficients_new[col] = op.executeScalar(_coefficents[col]);
+               }
+               else if(op.fn instanceof Multiply || op.fn instanceof Divide) {
+                       // multiply/divide changes intercepts & slopes
+                       for(int j = 0; j < _coefficents.length; j++)
+                               coefficients_new[j] = op.executeScalar(_coefficents[j]);
+               }
+               else {
+                       throw new NotImplementedException();
+               }
+
+               return create(_colIndexes, coefficients_new, _numRows);
+       }
+
+       /** Not implemented for linear functional column groups: every call throws NotImplementedException. */
+       @Override
+       public AColGroup unaryOperation(UnaryOperator op) {
+               throw new NotImplementedException();
+       }
+
+       /** Row-vector operation with the vector as left operand; forwards to binaryRowOp with left set to true. */
+       @Override
+       public AColGroup binaryRowOpLeft(BinaryOperator op, double[] v, boolean 
isRowSafe) {
+               return binaryRowOp(op, v, isRowSafe, true);
+       }
+
+       /** Row-vector operation with the vector as right operand; forwards to binaryRowOp with left set to false. */
+       @Override
+       public AColGroup binaryRowOpRight(BinaryOperator op, double[] v, 
boolean isRowSafe) {
+               return binaryRowOp(op, v, isRowSafe, false);
+       }
+
+       /**
+        * Apply a binary operation between a row vector and this column group directly on the coefficients.
+        *
+        * With a column encoded as `a + b * i` and the vector entry `x` for that column:
+        * Plus keeps the slope and adds x to the intercept; Minus does the same, except that `x - (a + b * i)` also
+        * flips the sign of the slope; Multiply scales intercept and slope; Divide by x (vector on the right) scales
+        * intercept and slope. `x / (a + b * i)` is not a linear function of i, so a left-hand Divide is rejected.
+        *
+        * Fixes compared to the previous implementation: the slopes are now taken from the slope half of the
+        * coefficient array (the intercepts were copied there before), the slope sign is flipped for a left-hand Minus,
+        * and a left-hand Divide no longer returns a wrong linear group.
+        *
+        * @param op        The binary operator to apply.
+        * @param v         The row vector, indexed by global column index.
+        * @param isRowSafe Not used by this implementation.
+        * @param left      True if the vector is the left operand, false if it is the right operand.
+        * @return The column group returned by create for the new coefficients.
+        * @throws NotImplementedException for a left-hand Divide and for any function other than Plus, Minus, Multiply
+        *                                 and Divide
+        */
+       private AColGroup binaryRowOp(BinaryOperator op, double[] v, boolean isRowSafe, boolean left) {
+               final int nCol = getNumCols();
+               final double[] coefficients_new = new double[_coefficents.length];
+
+               if(op.fn instanceof Plus || op.fn instanceof Minus) {
+                       // x - (a + b * i) = (x - a) - b * i: only a left-hand subtraction changes the slope
+                       final boolean negateSlopes = left && op.fn instanceof Minus;
+                       for(int col = 0; col < nCol; col++) {
+                               final double x = v[_colIndexes[col]];
+                               final double intercept = _coefficents[col];
+                               final double slope = _coefficents[nCol + col];
+                               // absorb plus/minus into intercept
+                               coefficients_new[col] = left ? op.fn.execute(x, intercept) : op.fn.execute(intercept, x);
+                               coefficients_new[nCol + col] = negateSlopes ? -slope : slope;
+                       }
+               }
+               else if(op.fn instanceof Multiply || (op.fn instanceof Divide && !left)) {
+                       // multiply/divide changes intercepts & slopes
+                       for(int col = 0; col < nCol; col++) {
+                               final double x = v[_colIndexes[col]];
+                               final double intercept = _coefficents[col];
+                               final double slope = _coefficents[nCol + col];
+                               coefficients_new[col] = left ? op.fn.execute(x, intercept) : op.fn.execute(intercept, x);
+                               coefficients_new[nCol + col] = left ? op.fn.execute(x, slope) : op.fn.execute(slope, x);
+                       }
+               }
+               else {
+                       // includes x / (a + b * i), which cannot be represented by a linear functional
+                       throw new NotImplementedException();
+               }
+
+               return create(_colIndexes, coefficients_new, _numRows);
+       }
+
+       /** Not implemented for linear functional column groups: every call throws NotImplementedException. */
+       @Override
+       protected double computeMxx(double c, Builtin builtin) {
+               throw new NotImplementedException();
+       }
+
+       /** Not implemented for linear functional column groups: every call throws NotImplementedException. */
+       @Override
+       protected void computeColMxx(double[] c, Builtin builtin) {
+               throw new NotImplementedException();
+       }
+
+       /**
+        * Add the sum of all values of this column group to c[0].
+        *
+        * Uses the closed form \sum_{i=1}^n (intercept + slope * i) = n * (intercept + (n + 1) * slope / 2) per column.
+        *
+        * @param c     Output array; only c[0] is updated.
+        * @param nRows Number of rows n to sum over.
+        */
+       @Override
+       protected void computeSum(double[] c, int nRows) {
+               final int nCol = getNumCols();
+               for(int j = 0; j < nCol; j++) {
+                       final double a = getInterceptForColumn(j);
+                       final double b = getSlopeForColumn(j);
+                       final double columnSum = nRows * (a + (nRows + 1) * b / 2);
+                       c[0] += columnSum;
+               }
+       }
+
+       /**
+        * Add the sum of each column of this group to the entry of c at the column's global index.
+        *
+        * Uses the closed form \sum_{i=1}^n (intercept + slope * i) = n * (intercept + (n + 1) * slope / 2) per column.
+        *
+        * @param c     Output array indexed by global column index.
+        * @param nRows Number of rows n to sum over.
+        */
+       @Override
+       public void computeColSums(double[] c, int nRows) {
+               final int nCol = getNumCols();
+               for(int j = 0; j < nCol; j++) {
+                       final double a = getInterceptForColumn(j);
+                       final double b = getSlopeForColumn(j);
+                       final double columnSum = nRows * (a + (nRows + 1) * b / 2);
+                       c[_colIndexes[j]] += columnSum;
+               }
+       }
+
+       /**
+        * Add the sum of the squared values of this column group to c[0].
+        *
+        * Given the intercept and slope of a column, the sum of the squared components of the column reads
+        * \sum_{i=1}^n (intercept + slope * i)^2. Expanding the binomial and using \sum_{i=1}^n i = n(n+1)/2 and
+        * \sum_{i=1}^n i^2 = n(n+1)(2n+1)/6 gives the closed form evaluated below.
+        *
+        * The row count is converted to double first: the product (n + 1) * (2n + 1) exceeds the int range for n above
+        * roughly 32767 and silently wrapped around when it was computed with int arithmetic.
+        *
+        * @param c     Output array; only c[0] is updated.
+        * @param nRows Number of rows n to sum over.
+        */
+       @Override
+       protected void computeSumSq(double[] c, int nRows) {
+               final double n = nRows;
+               for(int col = 0; col < getNumCols(); col++) {
+                       final double intercept = getInterceptForColumn(col);
+                       final double slope = getSlopeForColumn(col);
+
+                       c[0] += n * (intercept * intercept + (n + 1) * slope * intercept +
+                               (n + 1) * (2 * n + 1) * slope * slope / 6);
+               }
+       }
+
+       /**
+        * Add the sum of the squared values of each column to the entry of c at the column's global index.
+        *
+        * Evaluates \sum_{i=1}^n (intercept + slope * i)^2 in closed form, using \sum_{i=1}^n i = n(n+1)/2 and
+        * \sum_{i=1}^n i^2 = n(n+1)(2n+1)/6.
+        *
+        * The row count is converted to double first: the product (n + 1) * (2n + 1) exceeds the int range for n above
+        * roughly 32767 and silently wrapped around when it was computed with int arithmetic.
+        *
+        * @param c     Output array indexed by global column index.
+        * @param nRows Number of rows n to sum over.
+        */
+       @Override
+       protected void computeColSumsSq(double[] c, int nRows) {
+               final double n = nRows;
+               for(int col = 0; col < getNumCols(); col++) {
+                       final double intercept = getInterceptForColumn(col);
+                       final double slope = getSlopeForColumn(col);
+                       c[_colIndexes[col]] += n * (intercept * intercept + (n + 1) * slope * intercept +
+                               (n + 1) * (2 * n + 1) * slope * slope / 6);
+               }
+       }
+
+       /**
+        * Add the row sums of this column group for the rows rl (inclusive) to ru (exclusive) to c.
+        *
+        * @param c      Output array indexed by row.
+        * @param rl     First row to process.
+        * @param ru     Row after the last row to process.
+        * @param preAgg Two values: preAgg[0] is read as the summed intercepts and preAgg[1] as the summed slopes.
+        */
+       @Override
+       protected void computeRowSums(double[] c, int rl, int ru, double[] preAgg) {
+               final double interceptTotal = preAgg[0];
+               final double slopeTotal = preAgg[1];
+
+               int r = rl;
+               while(r < ru) {
+                       // row index r corresponds to row number r + 1 on the line
+                       c[r] += interceptTotal + slopeTotal * (r + 1);
+                       r++;
+               }
+       }
+
+       /** @return Always 0 for this column group. */
+       @Override
+       public int getNumValues() {
+               return 0;
+       }
+
+       /**
+        * Multiply this column group with a matrix from the right.
+        *
+        * For every output column the rows of the right matrix selected by this group's column indexes are folded into
+        * one combined intercept and one combined slope; the output column is then that line evaluated at the row
+        * numbers 1.._numRows.
+        *
+        * @param right The right-hand side matrix.
+        * @return An uncompressed column group holding the _numRows x right.getNumColumns() result.
+        */
+       @Override
+       public AColGroup rightMultByMatrix(MatrixBlock right) {
+               final int nColR = right.getNumColumns();
+               final int[] outputCols = Util.genColsIndices(nColR);
+               final int nColThis = _colIndexes.length;
+
+               // TODO: add specialization for sparse/dense matrix blocks
+               final MatrixBlock result = new MatrixBlock(_numRows, nColR, false);
+               for(int outCol = 0; outCol < nColR; outCol++) {
+                       double interceptAcc = 0.0;
+                       double slopeAcc = 0.0;
+
+                       for(int k = 0; k < nColThis; k++) {
+                               final int rightRow = _colIndexes[k];
+                               interceptAcc += right.getValue(rightRow, outCol) * getInterceptForColumn(k);
+                               slopeAcc += right.getValue(rightRow, outCol) * getSlopeForColumn(k);
+                       }
+
+                       for(int r = 0; r < _numRows; r++)
+                               result.setValue(r, outCol, interceptAcc + (r + 1) * slopeAcc);
+               }
+
+               // returns an uncompressed ColGroup
+               return ColGroupUncompressed.create(result, outputCols);
+       }
+
+       /**
+        * Transpose-self matrix multiplication of this column group, added into ret.
+        *
+        * Runs in O(tCol^2) since the dot product of two linearly compressed columns has a closed form:
+        * \sum_{i=1}^n (a1 + b1 * i) * (a2 + b2 * i) = (n * a1 + S1 * b1) * a2 + (S1 * a1 + S2 * b1) * b2
+        * with S1 = n(n+1)/2 and S2 = n(n+1)(2n+1)/6. Only the entries with the second column at or after the first
+        * column (by position in this group) are written.
+        *
+        * The row count is converted to double before S1 and S2 are formed; with int arithmetic n * (n + 1) overflows
+        * for n above 46340 and n * (n + 1) * (2n + 1) already for n above roughly 1000.
+        *
+        * @param ret        Output array, addressed as a row-major matrix with numColumns columns.
+        * @param numColumns Number of columns of the output matrix.
+        * @param nRows      Number of rows n of this column group.
+        */
+       @Override
+       public void tsmm(double[] ret, int numColumns, int nRows) {
+               final int tCol = _colIndexes.length;
+
+               final double n = nRows;
+               final double sumIndices = n * (n + 1) / 2.0;
+               final double sumSquaredIndices = n * (n + 1) * (2 * n + 1) / 6.0;
+               for(int row = 0; row < tCol; row++) {
+                       final double alpha1 = n * getInterceptForColumn(row) + sumIndices * getSlopeForColumn(row);
+                       final double alpha2 = sumIndices * getInterceptForColumn(row) + sumSquaredIndices * getSlopeForColumn(row);
+                       final int offRet = _colIndexes[row] * numColumns;
+                       for(int col = row; col < tCol; col++) {
+                               ret[offRet + _colIndexes[col]] += alpha1 * getInterceptForColumn(col) + alpha2 * getSlopeForColumn(col);
+                       }
+               }
+       }
+
+       /** Not supported by this column group: every call throws a DMLCompressionException. */
+       @Override
+       public void leftMultByMatrixNoPreAgg(MatrixBlock matrix, MatrixBlock 
result, int rl, int ru, int cl, int cu) {
+               throw new DMLCompressionException("This method should never be 
called");
+       }
+
+       /**
+        * Compute t(lhs) %*% this in coefficient space and hand the block to
+        * ColGroupUtils.copyValuesColGroupMatrixBlocks together with result.
+        *
+        * Supported left-hand sides:
+        * - ColGroupEmpty: nothing to do.
+        * - ColGroupUncompressed: per left column the plain sum and the row-number weighted sum of its values are
+        * computed and multiplied with the 2 x n coefficient matrix of this group.
+        * - ColGroupLinearFunctional: t(coeff(L)) %*% W %*% coeff(R), where W is the 2 x 2 matrix
+        * [[n, S1], [S1, S2]] with S1 = n(n+1)/2 and S2 = n(n+1)(2n+1)/6.
+        *
+        * Changes compared to the previous implementation: S1 and S2 are formed in double arithmetic (the int products
+        * overflowed for larger row counts), the left coefficient array is cloned before it is transposed in place, and
+        * each left value is decoded once instead of twice.
+        *
+        * @param lhs    The left-hand side column group.
+        * @param result The matrix block receiving the result.
+        * @throws NotImplementedException for all other left-hand side column group types
+        */
+       @Override
+       public void leftMultByAColGroup(AColGroup lhs, MatrixBlock result) {
+               if(lhs instanceof ColGroupEmpty) {
+                       return;
+               }
+
+               final int nColL = lhs.getNumCols();
+               final int nColR = _colIndexes.length;
+               MatrixBlock tmpRet = new MatrixBlock(nColL, nColR, 0);
+
+               if(lhs instanceof ColGroupUncompressed) {
+                       ColGroupUncompressed lhsUC = (ColGroupUncompressed) lhs;
+                       int numRowsLeft = lhsUC.getData().getNumRows();
+
+                       // per left column two entries: [sum of values, sum of (row number * value)]
+                       double[] colSumsAndWeightedColSums = new double[2 * nColL];
+                       for(int j = 0, offTmp = 0; j < nColL; j++, offTmp += 2) {
+                               for(int i = 0; i < numRowsLeft; i++) {
+                                       final double value = lhs.getIdx(i, j);
+                                       colSumsAndWeightedColSums[offTmp] += value;
+                                       colSumsAndWeightedColSums[offTmp + 1] += (i + 1) * value;
+                               }
+                       }
+
+                       MatrixBlock sumMatrix = new MatrixBlock(nColL, 2, colSumsAndWeightedColSums);
+                       MatrixBlock coefficientMatrix = new MatrixBlock(2, nColR, _coefficents);
+
+                       LibMatrixMult.matrixMult(sumMatrix, coefficientMatrix, tmpRet);
+               }
+               else if(lhs instanceof ColGroupLinearFunctional) {
+                       ColGroupLinearFunctional lhsLF = (ColGroupLinearFunctional) lhs;
+
+                       // use double arithmetic: the int products overflow for larger row counts
+                       final double n = _numRows;
+                       final double sumIndices = n * (n + 1) / 2.0;
+                       final double sumSquaredIndices = n * (n + 1) * (2 * n + 1) / 6.0;
+
+                       MatrixBlock weightMatrix = new MatrixBlock(2, 2,
+                               new double[] {n, sumIndices, sumIndices, sumSquaredIndices});
+                       // clone so that the in-place transpose below cannot reorder the left group's own coefficient array
+                       // (assumes MatrixBlock keeps a reference to the array it is given - TODO confirm)
+                       MatrixBlock coefficientMatrixLhs = new MatrixBlock(2, lhsLF._colIndexes.length,
+                               lhsLF._coefficents.clone());
+                       MatrixBlock coefficientMatrixRhs = new MatrixBlock(2, nColR, _coefficents);
+
+                       coefficientMatrixLhs = LibMatrixReorg.transposeInPlace(coefficientMatrixLhs,
+                               InfrastructureAnalyzer.getLocalParallelism());
+
+                       // We simply compute a matrix multiplication chain in coefficient space, i.e.,
+                       // t(L) %*% R = t(coeff(L)) %*% W %*% coeff(R)
+                       // where W is a weight matrix capturing the size of the shared dimension (weightMatrix above)
+                       // and coeff(X) denotes the 2 x n matrix of the m x n matrix X.
+                       MatrixBlock tmp = new MatrixBlock(nColL, 2, false);
+                       LibMatrixMult.matrixMult(coefficientMatrixLhs, weightMatrix, tmp);
+                       LibMatrixMult.matrixMult(tmp, coefficientMatrixRhs, tmpRet);
+               }
+               else if(lhs instanceof APreAgg) {
+                       // TODO: implement
+                       throw new NotImplementedException();
+               }
+               else {
+                       throw new NotImplementedException();
+               }
+
+               ColGroupUtils.copyValuesColGroupMatrixBlocks(lhs, this, tmpRet, result);
+       }
+
+       /** Not supported by this column group: every call throws a DMLCompressionException. */
+       @Override
+       public void tsmmAColGroup(AColGroup other, MatrixBlock result) {
+               throw new DMLCompressionException("Should not be called");
+       }
+
+       /** Not implemented for linear functional column groups: every call throws NotImplementedException. */
+       @Override
+       protected AColGroup sliceSingleColumn(int idx) {
+               throw new NotImplementedException();
+       }
+
+       /** Not implemented for linear functional column groups: every call throws NotImplementedException. */
+       @Override
+       protected AColGroup sliceMultiColumns(int idStart, int idEnd, int[] 
outputCols) {
+               throw new NotImplementedException();
+       }
+
+       /**
+        * Returns this very instance; no new object is created.
+        *
+        * NOTE(review): presumably safe because the operations of this class build new groups via create instead of
+        * modifying the coefficients - confirm no caller mutates the returned group.
+        */
+       @Override
+       public AColGroup copy() {
+               return this;
+       }
+
+       /**
+        * Check whether any column of this group contains the given value, as decided by colContainsValue.
+        *
+        * @param pattern The value to look for.
+        * @return True as soon as one column reports the value, false if none does.
+        */
+       @Override
+       public boolean containsValue(double pattern) {
+               boolean found = false;
+               for(int j = 0; !found && j < getNumCols(); j++)
+                       found = colContainsValue(j, pattern);
+
+               return found;
+       }
+
+       /**
+        * Check whether a single column of this group contains the given value, up to CONTAINS_VALUE_THRESHOLD.
+        *
+        * A column holds the values `intercept + slope * i` for the row numbers i = 1.._numRows. The pattern is mapped
+        * back to its (fractional) row number; the column contains it if that number is within the threshold of an
+        * integer inside 1.._numRows. The previous implementation skipped the range check and therefore also reported
+        * values that lie on the line but before the first or after the last encoded row.
+        *
+        * @param col     Position of the column inside this column group.
+        * @param pattern The value to look for.
+        * @return True if the column contains the value, false otherwise.
+        */
+       public boolean colContainsValue(int col, double pattern) {
+               final double intercept = getInterceptForColumn(col);
+               final double slope = getSlopeForColumn(col);
+
+               // the intercept itself sits at row number 0, so it only occurs if the column is (numerically) constant
+               if(pattern == intercept)
+                       return Math.abs(slope) < CONTAINS_VALUE_THRESHOLD;
+
+               final double div = (pattern - intercept) / slope;
+               final double nearestRow = Math.rint(div);
+
+               // also rejects the infinite row number resulting from a zero slope
+               if(nearestRow < 1 || nearestRow > _numRows)
+                       return false;
+
+               return Math.abs(div - nearestRow) < CONTAINS_VALUE_THRESHOLD;
+       }
+
+       /** Not implemented for linear functional column groups: every call throws NotImplementedException. */
+       @Override
+       public long getNumberNonZeros(int nRows) {
+               throw new NotImplementedException();
+       }
+
+       /** Not implemented for linear functional column groups: every call throws NotImplementedException. */
+       @Override
+       public AColGroup replace(double pattern, double replace) {
+               throw new NotImplementedException();
+       }
+
+       /** Deserialization is not implemented: every call throws NotImplementedException. */
+       @Override
+       public void readFields(DataInput in) throws IOException {
+               throw new NotImplementedException();
+       }
+
+       /**
+        * Serialization is not implemented yet; always throws {@link NotImplementedException}.
+        *
+        * @param out unused
+        */
+       @Override
+       public void write(DataOutput out) throws IOException {
+               throw new NotImplementedException();
+       }
+
+       /**
+        * Size estimate of this group: the size reported by the super class, plus the cost of the
+        * coefficient array as computed by {@code MemoryEstimates.doubleArrayCost}, plus 4 bytes for the
+        * row count.
+        *
+        * @return the accumulated size
+        */
+       @Override
+       public long getExactSizeOnDisk() {
+               long size = super.getExactSizeOnDisk();
+               // intercepts and slopes of all columns
+               size += MemoryEstimates.doubleArrayCost(_coefficents.length);
+               // _numRows
+               size += 4L;
+               return size;
+       }
+
+       /**
+        * Multiply every cell of this group into {@code c[0]}.
+        * <p>
+        * If {@link #containsValue(double)} reports a zero, {@code c[0]} is set to 0 directly. Otherwise
+        * each column is evaluated as {@code intercept + slope * x} for the row numbers {@code x = 1..nRows}.
+        *
+        * @param c     result array; only position 0 is read and written
+        * @param nRows number of rows to evaluate per column
+        */
+       @Override
+       protected void computeProduct(double[] c, int nRows) {
+               if(containsValue(0)) {
+                       c[0] = 0;
+                       return;
+               }
+
+               for(int col = 0; col < getNumCols(); col++) {
+                       final double b0 = getInterceptForColumn(col);
+                       final double b1 = getSlopeForColumn(col);
+
+                       // x is the 1-based row number the linear function is evaluated at
+                       for(int x = 1; x <= nRows; x++)
+                               c[0] *= b0 + b1 * x;
+               }
+       }
+
+       /**
+        * Multiply, for each row in {@code [rl, ru)}, the values of all columns of this group into
+        * {@code c[row]}. Cell values are evaluated as {@code intercept + slope * (row + 1)}.
+        *
+        * @param c      per-row result array, indexed by row
+        * @param rl     first row (inclusive)
+        * @param ru     last row (exclusive)
+        * @param preAgg not used by this implementation
+        */
+       @Override
+       protected void computeRowProduct(double[] c, int rl, int ru, double[] preAgg) {
+               // columns in the outer loop: coefficients are looked up once per column, while every c[row]
+               // still receives its factors in ascending column order
+               for(int col = 0; col < getNumCols(); col++) {
+                       final double b0 = getInterceptForColumn(col);
+                       final double b1 = getSlopeForColumn(col);
+                       for(int row = rl; row < ru; row++)
+                               c[row] *= b0 + b1 * (row + 1);
+               }
+       }
+
+       /**
+        * Multiply the cells of each column into the result position of that column.
+        * <p>
+        * A column for which {@link #colContainsValue(int, double)} reports a zero gets its result set to 0;
+        * all other columns are evaluated as {@code intercept + slope * x} for {@code x = 1..nRows}.
+        *
+        * @param c     result array, indexed by the column indexes of this group
+        * @param nRows number of rows to evaluate per column
+        */
+       @Override
+       protected void computeColProduct(double[] c, int nRows) {
+               for(int col = 0; col < getNumCols(); col++) {
+                       final int target = _colIndexes[col];
+                       if(colContainsValue(col, 0)) {
+                               c[target] = 0;
+                               continue;
+                       }
+
+                       final double b0 = getInterceptForColumn(col);
+                       final double b1 = getSlopeForColumn(col);
+                       for(int x = 1; x <= nRows; x++)
+                               c[target] *= b0 + b1 * x;
+               }
+       }
+
+       /**
+        * Sum the intercepts and the slopes over all columns of this group.
+        *
+        * @return a two element array: {@code [sum of intercepts, sum of slopes]}
+        */
+       @Override
+       protected double[] preAggSumRows() {
+               double interceptTotal = 0;
+               double slopeTotal = 0;
+               // both totals are independent accumulators, so one pass over the columns is enough
+               for(int col = 0; col < getNumCols(); col++) {
+                       interceptTotal += getInterceptForColumn(col);
+                       slopeTotal += getSlopeForColumn(col);
+               }
+
+               return new double[] {interceptTotal, slopeTotal};
+       }
+
+       /**
+        * No pre-aggregate is computed for row-wise sums of squares.
+        *
+        * @return always {@code null}
+        */
+       @Override
+       protected double[] preAggSumSqRows() {
+               return null;
+       }
+
+       /**
+        * No pre-aggregate is computed for row-wise products.
+        *
+        * @return always {@code null}
+        */
+       @Override
+       protected double[] preAggProductRows() {
+               return null;
+       }
+
+       /**
+        * Pre-aggregation for builtin functions is not implemented yet; always throws
+        * {@link NotImplementedException}.
+        *
+        * @param builtin unused
+        */
+       @Override
+       protected double[] preAggBuiltinRows(Builtin builtin) {
+               throw new NotImplementedException();
+       }
+
+       /**
+        * Delegates to {@code ColGroupSizes.estimateInMemorySizeLinearFunctional} with the number of columns
+        * of this group.
+        *
+        * @return the value returned by that estimate
+        */
+       @Override
+       public long estimateInMemorySize() {
+               return 
ColGroupSizes.estimateInMemorySizeLinearFunctional(getNumCols());
+       }
+
+       /**
+        * Central moment computation is not implemented yet; always throws {@link NotImplementedException}.
+        *
+        * @param op    unused
+        * @param nRows unused
+        */
+       @Override
+       public CM_COV_Object centralMoment(CMOperator op, int nRows) {
+               throw new NotImplementedException();
+       }
+
+       /**
+        * Column re-expansion is not implemented yet; always throws {@link NotImplementedException}.
+        */
+       @Override
+       public AColGroup rexpandCols(int max, boolean ignore, boolean cast, int 
nRows) {
+               throw new NotImplementedException();
+       }
+
+       /**
+        * Ask the given estimator for the cost of this group. A warning is logged on every call because
+        * the numbers handed to the estimator are only a rough description of this group.
+        *
+        * @param e     the cost estimator to query
+        * @param nRows number of rows; passed as the first two arguments of {@code e.getCost}
+        * @return the cost returned by the estimator
+        */
+       @Override
+       public double getCost(ComputationCostEstimator e, int nRows) {
+               LOG.warn("Cost calculation for LinearFunctional ColGroup is not precise");
+               // We store 2 tuples in this column group, namely intercepts and slopes
+               final int nTuples = 2;
+               return e.getCost(nRows, nRows, getNumCols(), nTuples, 1.0);
+       }
+
+       /**
+        * Text form of this group: the super class representation followed by one line listing the
+        * intercepts and one line listing the slopes.
+        *
+        * @return the description string
+        */
+       @Override
+       public String toString() {
+               return new StringBuilder(super.toString())
+                       .append(String.format("\n%15s", " Intercepts: " + Arrays.toString(getIntercepts())))
+                       .append(String.format("\n%15s", " Slopes: " + Arrays.toString(getSlopes())))
+                       .toString();
+       }
+
+       /**
+        * Collect the intercept of every column of this group into a newly allocated array.
+        *
+        * @return the intercepts, in the column order of this group
+        */
+       public double[] getIntercepts() {
+               final double[] intercepts = new double[getNumCols()];
+               Arrays.setAll(intercepts, col -> getInterceptForColumn(col));
+               return intercepts;
+       }
+
+       /**
+        * Collect the slope of every column of this group into a newly allocated array.
+        *
+        * @return the slopes, in the column order of this group
+        */
+       public double[] getSlopes() {
+               final double[] slopes = new double[getNumCols()];
+               Arrays.setAll(slopes, col -> getSlopeForColumn(col));
+               return slopes;
+       }
+}
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSizes.java 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSizes.java
index dc69ae7e59..5f273e9242 100644
--- 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSizes.java
+++ 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSizes.java
@@ -107,4 +107,13 @@ public final class ColGroupSizes {
                size += MatrixBlock.estimateSizeInMemory(nrRows, nrColumns, 
(nrColumns > 1) ? sparsity : 1);
                return size;
        }
+
+       /**
+        * Size estimate for a linear functional column group.
+        *
+        * @param nrColumns number of columns in the group
+        * @return the generic column group overhead, plus the cost of a double array holding two
+        *         coefficients (slope and intercept) per column, plus 4 bytes for the row count
+        */
+       public static long estimateInMemorySizeLinearFunctional(int nrColumns) {
+               // per column, we store 2 doubles (slope & intercept)
+               final long nrCoefficients = 2L * nrColumns;
+
+               long total = 0;
+               // Since the Object is a col group the overhead from the Memory Size group is added
+               total += estimateInMemorySizeGroup(nrColumns);
+               total += MemoryEstimates.doubleArrayCost(nrCoefficients);
+               total += 4; // _numRows
+               return total;
+       }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java
 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java
index 862ee09bb1..0d78b95dc1 100644
--- 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java
+++ 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUncompressed.java
@@ -351,7 +351,7 @@ public class ColGroupUncompressed extends AColGroup {
                }
        }
 
-       // @Override
+//      @Override
        public void leftMultByMatrix(MatrixBlock matrix, MatrixBlock result, 
int rl, int ru) {
 
                final MatrixBlock tmpRet = new MatrixBlock(ru - rl, 
_data.getNumColumns(), false);
@@ -560,33 +560,7 @@ public class ColGroupUncompressed extends AColGroup {
                                LibMatrixMult.matrixMult(transposed, 
this._data, tmpRet);
                        }
 
-                       final double[] resV = result.getDenseBlockValues();
-                       if(tmpRet.isEmpty())
-                               return;
-                       else if(tmpRet.isInSparseFormat()) {
-                               SparseBlock sb = tmpRet.getSparseBlock();
-                               for(int row = 0; row < lhs._colIndexes.length; 
row++) {
-                                       if(sb.isEmpty(row))
-                                               continue;
-                                       final int apos = sb.pos(row);
-                                       final int alen = sb.size(row) + apos;
-                                       final int[] aix = sb.indexes(row);
-                                       final double[] avals = sb.values(row);
-                                       final int offRes = lhs._colIndexes[row] 
* result.getNumColumns();
-                                       for(int col = apos; col < alen; col++)
-                                               resV[offRes + 
_colIndexes[aix[col]]] += avals[col];
-                               }
-                       }
-                       else {
-                               double[] tmpRetV = tmpRet.getDenseBlockValues();
-                               for(int row = 0; row < lhs._colIndexes.length; 
row++) {
-                                       final int offRes = lhs._colIndexes[row] 
* result.getNumColumns();
-                                       final int offTmp = 
lhs._colIndexes.length * row;
-                                       for(int col = 0; col < 
_colIndexes.length; col++) {
-                                               resV[offRes + _colIndexes[col]] 
+= tmpRetV[offTmp + col];
-                                       }
-                               }
-                       }
+                       ColGroupUtils.copyValuesColGroupMatrixBlocks(lhs, this, 
tmpRet, result);
                }
                else if(lhs instanceof APreAgg) {
                        // throw new NotImplementedException();
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUtils.java 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUtils.java
index f33d2dee29..55b7be3243 100644
--- 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUtils.java
+++ 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupUtils.java
@@ -19,7 +19,9 @@
 
 package org.apache.sysds.runtime.compress.colgroup;
 
+import org.apache.sysds.runtime.data.SparseBlock;
 import org.apache.sysds.runtime.functionobjects.ValueFunction;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
 
 public class ColGroupUtils {
@@ -62,4 +64,41 @@ public class ColGroupUtils {
                return ret;
        }
 
+       /**
+        * Add the cells of a temporary product block onto the matching cells of the result block.
+        * <p>
+        * Row {@code r} of {@code tmpResult} is mapped to result row {@code lhs._colIndexes[r]}, and column
+        * {@code c} of {@code tmpResult} to result column {@code rhs._colIndexes[c]}. Values are accumulated
+        * ({@code +=}) into the dense values of {@code result}; nothing happens if {@code tmpResult} is empty.
+        *
+        * @param lhs       Left ColumnGroup, supplies the target row of each temporary row
+        * @param rhs       Right ColumnGroup, supplies the target column of each temporary column
+        * @param tmpResult The matrix block to move values from (sparse or dense)
+        * @param result    The result matrix block to move values to
+        */
+       protected final static void copyValuesColGroupMatrixBlocks(AColGroup lhs, AColGroup rhs, MatrixBlock tmpResult, MatrixBlock result) {
+               final double[] target = result.getDenseBlockValues();
+               if(tmpResult.isEmpty())
+                       return;
+
+               if(!tmpResult.isInSparseFormat()) {
+                       // dense temporary: rows are stored back to back, each rhs.getNumCols() wide
+                       final double[] source = tmpResult.getDenseBlockValues();
+                       for(int r = 0; r < lhs.getNumCols(); r++) {
+                               final int targetOff = lhs._colIndexes[r] * result.getNumColumns();
+                               final int sourceOff = r * rhs.getNumCols();
+                               for(int c = 0; c < rhs.getNumCols(); c++)
+                                       target[targetOff + rhs._colIndexes[c]] += source[sourceOff + c];
+                       }
+                       return;
+               }
+
+               // sparse temporary: only visit the stored cells of every non-empty row
+               final SparseBlock sparse = tmpResult.getSparseBlock();
+               for(int r = 0; r < lhs._colIndexes.length; r++) {
+                       if(sparse.isEmpty(r))
+                               continue;
+                       final int begin = sparse.pos(r);
+                       final int end = sparse.size(r) + begin;
+                       final int[] colIdx = sparse.indexes(r);
+                       final double[] vals = sparse.values(r);
+                       final int targetOff = lhs._colIndexes[r] * result.getNumColumns();
+                       for(int k = begin; k < end; k++)
+                               target[targetOff + rhs._colIndexes[colIdx[k]]] += vals[k];
+               }
+       }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/functional/LinearRegression.java
 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/functional/LinearRegression.java
new file mode 100644
index 0000000000..42a7cb7a51
--- /dev/null
+++ 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/functional/LinearRegression.java
@@ -0,0 +1,75 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.compress.colgroup.functional;
+
+import org.apache.sysds.runtime.compress.DMLCompressionException;
+import org.apache.sysds.runtime.compress.readers.ReaderColumnSelection;
+import org.apache.sysds.runtime.compress.utils.DblArray;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+
+public class LinearRegression {
+
+       /**
+        * Fit one least squares line per selected column, using the 1-based row number as the regressor.
+        * <p>
+        * Because the regressor values are always {@code 1..nRows}, their mean {@code (nRows + 1) / 2} and
+        * their sum of squared deviations {@code (nRows^3 - nRows) / 12} are computed in closed form. Only
+        * the per-column sum of values and the sum of values weighted by row number are gathered from the
+        * data.
+        *
+        * @param rawBlock   the matrix block holding the data
+        * @param colIndexes the columns to fit
+        * @param transposed if true, rows and columns of {@code rawBlock} are treated as swapped
+        * @return an array of length {@code 2 * colIndexes.length}: the first half holds the intercepts, the
+        *         second half the slopes, both in the order of {@code colIndexes}
+        * @throws DMLCompressionException if there are fewer than 2 rows or no column is given
+        */
+       public static double[] regressMatrixBlock(MatrixBlock rawBlock, int[] colIndexes, boolean transposed) {
+               final int nRows = transposed ? rawBlock.getNumColumns() : rawBlock.getNumRows();
+
+               if(nRows <= 1)
+                       throw new DMLCompressionException("At least 2 data points are required to fit a linear function.");
+               if(colIndexes.length < 1)
+                       throw new DMLCompressionException("At least 1 column must be specified for compression.");
+
+               final int nCols = colIndexes.length;
+
+               // closed forms for the regressor x = 1..nRows
+               final double sumSqDevX = (Math.pow(nRows, 3) - nRows) / 12;
+               final double meanX = (double) (nRows + 1) / 2;
+
+               final double[] sumY = new double[nCols];
+               final double[] sumXY = new double[nCols];
+
+               if(nCols == 1) {
+                       // single column: read the cells directly
+                       final int col = colIndexes[0];
+                       for(int r = 0; r < nRows; r++) {
+                               final double y = transposed ? rawBlock.getValue(col, r) : rawBlock.getValue(r, col);
+                               sumY[0] += y;
+                               sumXY[0] += (r + 1) * y;
+                       }
+               }
+               else {
+                       final ReaderColumnSelection reader = ReaderColumnSelection.createReader(rawBlock, colIndexes, transposed);
+
+                       for(DblArray cells = reader.nextRow(); cells != null; cells = reader.nextRow()) {
+                               final int x = reader.getCurrentRowIndex() + 1;
+                               final double[] y = cells.getData();
+
+                               for(int j = 0; j < nCols; j++) {
+                                       sumY[j] += y[j];
+                                       sumXY[j] += x * y[j];
+                               }
+                       }
+               }
+
+               // the first `colIndexes.length` entries represent the intercepts (beta0)
+               // the second `colIndexes.length` entries represent the slopes (beta1)
+               final double[] coefficients = new double[2 * nCols];
+               for(int j = 0; j < nCols; j++) {
+                       final double slope = (-meanX * sumY[j] + sumXY[j]) / sumSqDevX;
+                       coefficients[nCols + j] = slope;
+                       coefficients[j] = (sumY[j] / nRows) - slope * meanX;
+               }
+
+               return coefficients;
+       }
+}
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/estim/CompressedSizeInfoColGroup.java
 
b/src/main/java/org/apache/sysds/runtime/compress/estim/CompressedSizeInfoColGroup.java
index 7c330c1a1e..3b0667bd19 100644
--- 
a/src/main/java/org/apache/sysds/runtime/compress/estim/CompressedSizeInfoColGroup.java
+++ 
b/src/main/java/org/apache/sysds/runtime/compress/estim/CompressedSizeInfoColGroup.java
@@ -189,6 +189,8 @@ public class CompressedSizeInfoColGroup {
        private static long getCompressionSize(int numCols, CompressionType ct, 
EstimationFactors fact) {
                int nv;
                switch(ct) {
+                       case LinearFunctional:
+                               return 
ColGroupSizes.estimateInMemorySizeLinearFunctional(numCols);
                        case DeltaDDC: // TODO add proper extraction
                        case DDC:
                                nv = fact.numVals + (fact.numOffs < 
fact.numRows ? 1 : 0);
diff --git 
a/src/test/java/org/apache/sysds/test/component/compress/colgroup/ColGroupLinearFunctionalBase.java
 
b/src/test/java/org/apache/sysds/test/component/compress/colgroup/ColGroupLinearFunctionalBase.java
new file mode 100644
index 0000000000..96311c9111
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/component/compress/colgroup/ColGroupLinearFunctionalBase.java
@@ -0,0 +1,252 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.component.compress.colgroup;
+
+import static org.junit.Assert.fail;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.EnumSet;
+import java.util.Random;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.runtime.compress.CompressionSettings;
+import org.apache.sysds.runtime.compress.CompressionSettingsBuilder;
+import org.apache.sysds.runtime.compress.colgroup.AColGroup;
+import org.apache.sysds.runtime.compress.colgroup.ColGroupFactory;
+import org.apache.sysds.runtime.compress.colgroup.ColGroupLinearFunctional;
+import org.apache.sysds.runtime.compress.colgroup.ColGroupUncompressed;
+import org.apache.sysds.runtime.compress.estim.CompressedSizeEstimatorExact;
+import org.apache.sysds.runtime.compress.estim.CompressedSizeInfo;
+import org.apache.sysds.runtime.compress.estim.CompressedSizeInfoColGroup;
+import org.apache.sysds.runtime.compress.utils.Util;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.util.DataConverter;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+import org.junit.runners.Parameterized.Parameters;
+
+/**
+ * Parameterized base class for tests of {@link ColGroupLinearFunctional}.
+ * <p>
+ * Every parameter set consists of the same data once as uncompressed column group ({@code base}) and once
+ * linearly compressed ({@code lin}), a left operand (uncompressed as {@code baseLeft}, and either compressed
+ * or uncompressed as {@code cgLeft}) and an uncompressed right operand ({@code cgRight}), together with the
+ * dimensions of the operands and the tolerance used when comparing results.
+ */
+@RunWith(value = Parameterized.class)
+public abstract class ColGroupLinearFunctionalBase {
+
+       protected static final Log LOG = LogFactory.getLog(ColGroupLinearFunctionalBase.class.getName());
+       private final static Random random = new Random();
+       protected final AColGroup base;
+       protected final ColGroupLinearFunctional lin;
+       protected final AColGroup baseLeft;
+       protected final int nRowLeft;
+       protected final int nColLeft;
+
+       protected final int nRowRight;
+       protected final int nColRight;
+
+       protected final AColGroup cgLeft;
+       protected final ColGroupUncompressed cgRight;
+       protected final int nRow;
+       protected final double tolerance;
+
+       /**
+        * Build the parameter sets for the runner; the test run fails if their construction throws.
+        *
+        * @return one {@code Object[]} of constructor arguments per test case
+        */
+       @Parameters
+       public static Collection<Object[]> data() {
+               ArrayList<Object[]> tests = new ArrayList<>();
+
+               try {
+                       addLinCases(tests);
+               }
+               catch(Exception e) {
+                       e.printStackTrace();
+                       fail("failed constructing tests");
+               }
+
+               return tests;
+       }
+
+       /**
+        * Store one parameter set after checking that the operands are dimensionally compatible.
+        *
+        * @param base      uncompressed column group holding the reference data
+        * @param lin       linearly compressed column group built from the same data
+        * @param baseLeft  uncompressed left operand
+        * @param cgLeft    left operand, compressed or uncompressed depending on the test case
+        * @param nRowLeft  number of rows of the left operand
+        * @param nColLeft  number of columns of the left operand
+        * @param nRowRight number of rows of the right operand
+        * @param nColRight number of columns of the right operand
+        * @param cgRight   uncompressed right operand
+        * @param tolerance absolute tolerance for comparing results
+        */
+       public ColGroupLinearFunctionalBase(AColGroup base, ColGroupLinearFunctional lin, AColGroup baseLeft,
+               AColGroup cgLeft, int nRowLeft, int nColLeft, int nRowRight, int nColRight, ColGroupUncompressed cgRight,
+               double tolerance) {
+               if(lin.getNumCols() != base.getNumCols())
+                       fail("Linearly compressed ColGroup and Base ColGroup must have same number of columns");
+
+               if(nRowLeft != lin.getNumRows())
+                       fail("Transposed left ColGroup and center ColGroup (`lin`) must have compatible dimensions");
+
+               // column indexes are 0-based, so the largest one (the last entry is taken as the largest) has
+               // to be strictly smaller than the number of rows of the right operand
+               int[] colIndices = lin.getColIndices();
+               if(colIndices[colIndices.length - 1] >= nRowRight)
+                       fail("Right ColGroup must have more rows than the largest column index of center ColGroup (`lin`)");
+
+               this.base = base;
+               this.lin = lin;
+               this.baseLeft = baseLeft;
+               this.nRowLeft = nRowLeft;
+               this.nColLeft = nColLeft;
+               this.nRowRight = nRowRight;
+               this.nColRight = nColRight;
+               this.cgLeft = cgLeft;
+               this.cgRight = cgRight;
+               this.tolerance = tolerance;
+               this.nRow = lin.getNumRows();
+       }
+
+       /**
+        * Append the test cases: a few small hand-written matrices and one randomly generated linear matrix.
+        *
+        * @param tests the list to append the parameter sets to
+        */
+       protected static void addLinCases(ArrayList<Object[]> tests) {
+               double[][] data = new double[][] {{1, 2, 3, 4, 5}, {-4, 2, 8, 14, 20}};
+               double[][] dataRight = new double[][] {{1, -2, 23, 7}, {4, 11, -10, -2}};
+               double[][] dataLeft = new double[][] {{8, 3, 7, 12, -3}, {-1, 8, 4, -2, -2}, {3, 4, 2, 0, -1}};
+               int[] colIndexesLeft = new int[] {0, 2};
+
+               double[][] dataLeftCompressed = new double[][] {{8, 4, 0, -4, -8}, {-1, 0, 1, 2, 3}};
+               int[] colIndexesLeftCompressed = new int[] {0};
+
+               tests
+                       .add(createInitParams(data, true, null, dataLeft, true, colIndexesLeft, false, dataRight, true, null, 0.001));
+
+               tests.add(createInitParams(data, true, null, dataLeftCompressed, true, colIndexesLeftCompressed, true, dataRight,
+                       true, null, 0.001));
+
+               tests.add(createInitParams(new double[][] {{1, 2, 3, 4, 5}}, true, null, null, true, null, true, dataRight, true,
+                       null, 0.001));
+
+               tests.add(createInitParams(new double[][] {{1, 2, 3, 4, 5}}, true, null, null, true, null, true, dataRight, true,
+                       null, 0.001));
+
+               tests.add(createInitParams(new double[][] {{1, 2, 3, 4, 5}, {1, 1, 1, 1, 1}, {4, 2, 4, 2, 4}}, true,
+                       new int[] {0, 1}, null, true, null, true, dataRight, true, null, 0.001));
+
+               tests.add(createInitParams(new double[][] {{1, 2, 3, 4, 5}, {-1, -2, -3, -4, -5}}, true, null, null, true, null,
+                       true, dataRight, true, null, 0.001));
+
+               double[][] randomData = generateTestMatrixLinear(80, 100, -100, 100, -1, 1, 42);
+               double[][] randomDataLeft = generateTestMatrixLinear(80, 50, -100, 100, -1, 1, 43);
+               double[][] randomDataRight = generateTestMatrixLinear(100, 500, -100, 100, -1, 1, 44);
+
+               tests.add(createInitParams(randomData, false, null, randomDataLeft, false, null, true, randomDataRight, true,
+                       null, 0.001));
+       }
+
+       /**
+        * Turn raw arrays into the constructor arguments of one test case.
+        * <p>
+        * A {@code null} left data array is replaced by {@code data}; a {@code null} column index array is
+        * replaced by the indexes generated by {@code Util.genColsIndices} for all columns of the matrix.
+        *
+        * @param data            data of the center column group
+        * @param isTransposed    whether {@code data} is given column by column
+        * @param colIndexes      columns of {@code data} to use, or null for all
+        * @param dataLeft        data of the left operand, or null to reuse {@code data}
+        * @param transposedLeft  whether {@code dataLeft} is given column by column
+        * @param colIndexesLeft  columns of {@code dataLeft} to use, or null for all
+        * @param linCompressLeft whether the left operand should be linearly compressed
+        * @param dataRight       data of the right operand
+        * @param transposedRight whether {@code dataRight} is given column by column
+        * @param colIndexesRight columns of {@code dataRight} to use, or null for all
+        * @param tolerance       absolute tolerance for comparing results
+        * @return the arguments in the order expected by the constructor
+        */
+       protected static Object[] createInitParams(double[][] data, boolean isTransposed, int[] colIndexes,
+               double[][] dataLeft, boolean transposedLeft, int[] colIndexesLeft, boolean linCompressLeft, double[][] dataRight,
+               boolean transposedRight, int[] colIndexesRight, double tolerance) {
+               if(dataLeft == null)
+                       dataLeft = data;
+
+               int nCol = isTransposed ? data.length : data[0].length;
+               int nRowLeft = transposedLeft ? dataLeft[0].length : dataLeft.length;
+               int nColLeft = transposedLeft ? dataLeft.length : dataLeft[0].length;
+               int nRowRight = transposedRight ? dataRight[0].length : dataRight.length;
+               int nColRight = transposedRight ? dataRight.length : dataRight[0].length;
+
+               if(colIndexes == null)
+                       colIndexes = Util.genColsIndices(nCol);
+
+               if(colIndexesLeft == null)
+                       colIndexesLeft = Util.genColsIndices(nColLeft);
+
+               if(colIndexesRight == null)
+                       colIndexesRight = Util.genColsIndices(nColRight);
+
+               return new Object[] {cgUncompressed(data, colIndexes, isTransposed),
+                       cgLinCompressed(data, colIndexes, isTransposed), cgUncompressed(dataLeft, colIndexesLeft, transposedLeft),
+                       linCompressLeft ? cgLinCompressed(dataLeft, colIndexesLeft, transposedLeft) : cgUncompressed(dataLeft,
+                               colIndexesLeft, transposedLeft),
+                       nRowLeft, nColLeft, nRowRight, nColRight, cgUncompressed(dataRight, colIndexesRight, transposedRight),
+                       tolerance};
+       }
+
+       /** Create a column group of type UNCOMPRESSED over the given columns of {@code data}. */
+       protected static AColGroup cgUncompressed(double[][] data, int[] colIndexes, boolean isTransposed) {
+               MatrixBlock mbt = DataConverter.convertToMatrixBlock(data);
+               return createColGroup(mbt, colIndexes, isTransposed, AColGroup.CompressionType.UNCOMPRESSED);
+       }
+
+       /** Create a column group of type LinearFunctional over all columns of {@code data}. */
+       protected static AColGroup cgLinCompressed(double[][] data, boolean isTransposed) {
+               final int numCols = isTransposed ? data.length : data[0].length;
+               return cgLinCompressed(data, Util.genColsIndices(numCols), isTransposed);
+       }
+
+       /** Create a column group of type LinearFunctional over the given columns of {@code data}. */
+       protected static AColGroup cgLinCompressed(double[][] data, int[] colIndexes, boolean isTransposed) {
+               MatrixBlock mbt = DataConverter.convertToMatrixBlock(data);
+               return createColGroup(mbt, colIndexes, isTransposed, AColGroup.CompressionType.LinearFunctional);
+       }
+
+       /**
+        * Compress the selected columns of a matrix block into a single column group of the requested type.
+        * The compression settings use a sampling ratio of 1.0 and allow only {@code cgType}.
+        *
+        * @param mbt          the matrix block to compress
+        * @param colIndexes   the columns to put into the group
+        * @param isTransposed whether {@code mbt} is to be read transposed
+        * @param cgType       the compression type to use
+        * @return the first column group returned by the factory
+        */
+       public static AColGroup createColGroup(MatrixBlock mbt, int[] colIndexes, boolean isTransposed,
+               AColGroup.CompressionType cgType) {
+               CompressionSettings cs = new CompressionSettingsBuilder().setSamplingRatio(1.0)
+                       .setValidCompressions(EnumSet.of(cgType)).create();
+               cs.transposed = isTransposed;
+
+               final CompressedSizeInfoColGroup cgi = new CompressedSizeEstimatorExact(mbt, cs).getColGroupInfo(colIndexes);
+               CompressedSizeInfo csi = new CompressedSizeInfo(cgi);
+               return ColGroupFactory.compressColGroups(mbt, csi, cs, 1).get(0);
+       }
+
+       /**
+        * Generate the values {@code intercept + slope * x} for {@code x = 1..length}.
+        *
+        * @param intercept intercept of the line
+        * @param slope     slope of the line
+        * @param length    number of values to generate
+        * @return the generated column
+        */
+       public static double[] generateLinearColumn(double intercept, double slope, int length) {
+               double[] result = new double[length];
+               for(int i = 0; i < length; i++) {
+                       result[i] = intercept + slope * (i + 1);
+               }
+
+               return result;
+       }
+
+       /**
+        * Generate a {@code rows x cols} matrix in which every column is an exact linear function of the row
+        * number, with coefficients drawn by {@link #generateRandomInterceptsSlopes}.
+        */
+       public static double[][] generateTestMatrixLinear(int rows, int cols, double minIntercept, double maxIntercept,
+               double minSlope, double maxSlope, long seed) {
+               double[][] coefficients = generateRandomInterceptsSlopes(cols, minIntercept, maxIntercept, minSlope, maxSlope,
+                       seed);
+               return generateTestMatrixLinearColumns(rows, cols, coefficients[0], coefficients[1]);
+       }
+
+       /**
+        * Draw one intercept and one slope per column from the given ranges. The shared random generator is
+        * re-seeded with {@code seed} at the start of every call.
+        *
+        * @return a two row array: {@code [intercepts, slopes]}
+        */
+       public static double[][] generateRandomInterceptsSlopes(int cols, double minIntercept, double maxIntercept,
+               double minSlope, double maxSlope, long seed) {
+
+               double[] intercepts = new double[cols];
+               double[] slopes = new double[cols];
+
+               random.setSeed(seed);
+               for(int j = 0; j < cols; j++) {
+                       intercepts[j] = minIntercept + random.nextDouble() * (maxIntercept - minIntercept);
+                       slopes[j] = minSlope + random.nextDouble() * (maxSlope - minSlope);
+               }
+
+               return new double[][] {intercepts, slopes};
+       }
+
+       /**
+        * Generate a {@code rows x cols} matrix whose column {@code j} is the line given by
+        * {@code intercepts[j]} and {@code slopes[j]}; fails the test if the array lengths differ from
+        * {@code cols}.
+        */
+       public static double[][] generateTestMatrixLinearColumns(int rows, int cols, double[] intercepts, double[] slopes) {
+               if(intercepts.length != slopes.length || intercepts.length != cols)
+                       fail("Intercepts and slopes array must both have length `cols`");
+
+               double[][] data = new double[rows][cols];
+
+               for(int j = 0; j < cols; j++) {
+                       double[] linCol = generateLinearColumn(intercepts[j], slopes[j], rows);
+                       for(int i = 0; i < rows; i++) {
+                               data[i][j] = linCol[i];
+                       }
+               }
+
+               return data;
+       }
+
+       /**
+        * Decompress rows {@code [0, nRow)} of a column group into a newly allocated dense block.
+        *
+        * @param cg the column group to decompress
+        * @return the dense values of that block
+        */
+       protected double[] getValues(AColGroup cg) {
+               MatrixBlock mb = new MatrixBlock(nRow, cg.getNumCols(), false);
+               mb.allocateDenseBlock();
+               cg.decompressToDenseBlock(mb.getDenseBlock(), 0, nRow);
+               return mb.getDenseBlockValues();
+       }
+
+}
diff --git 
a/src/test/java/org/apache/sysds/test/component/compress/colgroup/ColGroupLinearFunctionalTest.java
 
b/src/test/java/org/apache/sysds/test/component/compress/colgroup/ColGroupLinearFunctionalTest.java
new file mode 100644
index 0000000000..9a37030d13
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/component/compress/colgroup/ColGroupLinearFunctionalTest.java
@@ -0,0 +1,346 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.component.compress.colgroup;
+
+import static org.junit.Assert.fail;
+
+import java.util.Arrays;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.runtime.compress.colgroup.AColGroup;
+import org.apache.sysds.runtime.compress.colgroup.AColGroupCompressed;
+import org.apache.sysds.runtime.compress.colgroup.ColGroupLinearFunctional;
+import org.apache.sysds.runtime.compress.colgroup.ColGroupUncompressed;
+import org.apache.sysds.runtime.functionobjects.KahanPlusSq;
+import org.apache.sysds.runtime.functionobjects.Multiply;
+import org.apache.sysds.runtime.functionobjects.Plus;
+import org.apache.sysds.runtime.functionobjects.ReduceAll;
+import org.apache.sysds.runtime.functionobjects.ReduceCol;
+import org.apache.sysds.runtime.functionobjects.ReduceRow;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.operators.AggregateOperator;
+import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
+import org.junit.Assert;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+@RunWith(value = Parameterized.class)
+public class ColGroupLinearFunctionalTest extends ColGroupLinearFunctionalBase 
{
+       protected static final Log LOG = 
LogFactory.getLog(ColGroupLinearFunctionalTest.class.getName());
+
+       /**
+        * Receives one parameter set from the runner and hands all arguments unchanged to the base class
+        * constructor.
+        */
+       public ColGroupLinearFunctionalTest(AColGroup base, 
ColGroupLinearFunctional lin, AColGroup baseLeft,
+               AColGroup cgLeft, int nRowLeft, int nColLeft, int nRowRight, 
int nColRight, ColGroupUncompressed cgRight,
+               double tolerance) {
+               super(base, lin, baseLeft, cgLeft, nRowLeft, nColLeft, 
nRowRight, nColRight, cgRight, tolerance);
+       }
+
+       /**
+        * Verifies that both column groups were initialized with the same values and that every one of these values is
+        * reported as contained by the base column group as well as by the linear functional column group.
+        */
+       @Test
+       public void testContainsValue() {
+               double[] linValues = getValues(lin);
+               double[] baseValues = getValues(base);
+
+               for(int i = 0; i < linValues.length; i++) {
+                       Assert.assertEquals("Base ColGroup and linear ColGroup must be initialized with the same values",
+                               linValues[i], baseValues[i], tolerance);
+
+                       // Report the offending value and index through the assertion message instead of stdout.
+                       final double value = baseValues[i];
+                       Assert.assertTrue("Base ColGroup does not contain value " + value + " (index " + i + ")",
+                               base.containsValue(value));
+                       Assert.assertTrue("Linear ColGroup does not contain value " + value + " (index " + i + ")",
+                               lin.containsValue(value));
+               }
+       }
+
+       /**
+        * Runs {@code tsmm} on the base and on the linear functional column group and compares the dense results.
+        */
+       @Test
+       public void testTsmm() {
+               final int numCols = lin.getNumCols();
+
+               // reference result produced by the base column group
+               final MatrixBlock expected = new MatrixBlock(numCols, numCols, false);
+               expected.allocateDenseBlock();
+               base.tsmm(expected, nRow);
+
+               // result produced by the column group under test
+               final MatrixBlock actual = new MatrixBlock(numCols, numCols, false);
+               actual.allocateDenseBlock();
+               lin.tsmm(actual, nRow);
+
+               final double[] expectedValues = expected.getDenseBlockValues();
+               final double[] actualValues = actual.getDenseBlockValues();
+               Assert.assertArrayEquals(expectedValues, actualValues, tolerance);
+       }
+
+       /**
+        * Right-multiplies both column groups with the data of {@code cgRight} and compares the dense results. Both
+        * results are cast to {@link ColGroupUncompressed} to access their data.
+        */
+       @Test
+       public void testRightMultByMatrix() {
+               final MatrixBlock right = cgRight.getData();
+
+               final ColGroupUncompressed expectedGroup = (ColGroupUncompressed) base.rightMultByMatrix(right);
+               final double[] expected = expectedGroup.getData().getDenseBlockValues();
+
+               final ColGroupUncompressed actualGroup = (ColGroupUncompressed) lin.rightMultByMatrix(right);
+               final double[] actual = actualGroup.getData().getDenseBlockValues();
+
+               Assert.assertArrayEquals(expected, actual, tolerance);
+       }
+
+       /**
+        * Dispatches the left multiplication test depending on the compression type of the left-hand column group.
+        * Only linear functional and uncompressed left-hand groups are supported; any other type fails the test.
+        */
+       @Test
+       public void testLeftMultByAColGroup() {
+               final AColGroup.CompressionType leftType = cgLeft.getCompType();
+
+               if(leftType == AColGroup.CompressionType.LinearFunctional)
+                       leftMultByAColGroup(true);
+               else if(leftType == AColGroup.CompressionType.UNCOMPRESSED)
+                       leftMultByAColGroup(false);
+               else
+                       fail("CompressionType not supported for leftMultByAColGroup: " + leftType);
+       }
+
+       /**
+        * Left-multiplies the base and the linear functional column group with their respective left-hand groups and
+        * compares the dense results.
+        *
+        * @param compressedLeft flag passed by the caller; it is not read inside this method
+        */
+       public void leftMultByAColGroup(boolean compressedLeft) {
+               // reference result: base column group multiplied with the base left-hand group
+               final MatrixBlock expected = new MatrixBlock(nRowLeft, nColRight, false);
+               expected.allocateDenseBlock();
+               base.leftMultByAColGroup(baseLeft, expected);
+
+               // result under test: linear functional column group multiplied with cgLeft
+               final MatrixBlock actual = new MatrixBlock(nRowLeft, nColRight, false);
+               actual.allocateDenseBlock();
+               lin.leftMultByAColGroup(cgLeft, actual);
+
+               final double[] expectedValues = expected.getDenseBlockValues();
+               final double[] actualValues = actual.getDenseBlockValues();
+               Assert.assertArrayEquals(expectedValues, actualValues, tolerance);
+       }
+
+       /**
+        * Compares the column-wise sums of squares of the linear functional column group with those of the base column
+        * group. For an uncompressed base the expected values are computed by hand from its matrix block.
+        */
+       @Test
+       public void testColSumsSq() {
+               final double[] expected = new double[base.getNumCols()];
+               final AggregateUnaryOperator colSumSqOp = new AggregateUnaryOperator(
+                       new AggregateOperator(0, KahanPlusSq.getKahanPlusSqFnObject()), ReduceRow.getReduceRowFnObject());
+
+               if(base instanceof AColGroupCompressed) {
+                       final AColGroupCompressed compressedBase = (AColGroupCompressed) base;
+                       compressedBase.unaryAggregateOperations(colSumSqOp, expected, nRow, 0, nRow,
+                               compressedBase.preAggRows(colSumSqOp));
+               }
+               else if(base instanceof ColGroupUncompressed) {
+                       final MatrixBlock block = ((ColGroupUncompressed) base).getData();
+                       final int numCols = base.getNumCols();
+
+                       for(int col = 0; col < numCols; col++) {
+                               double acc = 0;
+                               for(int row = 0; row < nRow; row++)
+                                       acc += Math.pow(block.getDouble(row, col), 2);
+                               expected[col] = acc;
+                       }
+               }
+               else {
+                       fail("Base ColGroup type does not support colSumSq.");
+               }
+
+               final double[] actual = new double[lin.getNumCols()];
+               lin.unaryAggregateOperations(colSumSqOp, actual, nRow, 0, nRow, lin.preAggRows(colSumSqOp));
+
+               Assert.assertArrayEquals(expected, actual, tolerance);
+       }
+
+       /**
+        * Compares the product over all cells of the linear functional column group with the product of the base
+        * column group, using a tolerance scaled by the magnitude of the expected product.
+        */
+       @Test
+       public void testProduct() {
+               final double[] expected = {1};
+               final AggregateUnaryOperator productOp = new AggregateUnaryOperator(
+                       new AggregateOperator(0, Multiply.getMultiplyFnObject()), ReduceAll.getReduceAllFnObject());
+
+               if(base instanceof AColGroupCompressed) {
+                       final AColGroupCompressed compressedBase = (AColGroupCompressed) base;
+                       compressedBase.unaryAggregateOperations(productOp, expected, nRow, 0, nRow,
+                               compressedBase.preAggRows(productOp));
+               }
+               else if(base instanceof ColGroupUncompressed) {
+                       final MatrixBlock block = ((ColGroupUncompressed) base).getData();
+                       final int numCols = base.getNumCols();
+
+                       // multiply up all cells column by column
+                       for(int col = 0; col < numCols; col++)
+                               for(int row = 0; row < nRow; row++)
+                                       expected[0] *= block.getDouble(row, col);
+               }
+               else {
+                       fail("Base ColGroup type does not support colProduct.");
+               }
+
+               final double[] actual = {1};
+               lin.unaryAggregateOperations(productOp, actual, nRow, 0, nRow, lin.preAggRows(productOp));
+
+               // products can get very large, hence the tolerance is relative to the expected value
+               final double scaledTolerance = tolerance * Math.abs(expected[0]);
+               Assert.assertEquals(expected[0], actual[0], scaledTolerance);
+       }
+
+       /** Checks that both column groups report the same maximum, within the configured tolerance. */
+       @Test
+       public void testMax() {
+               final double expectedMax = base.getMax();
+               final double actualMax = lin.getMax();
+               Assert.assertEquals(expectedMax, actualMax, tolerance);
+       }
+
+       /** Checks that both column groups report the same minimum, within the configured tolerance. */
+       @Test
+       public void testMin() {
+               final double expectedMin = base.getMin();
+               final double actualMin = lin.getMin();
+               Assert.assertEquals(expectedMin, actualMin, tolerance);
+       }
+
+       /**
+        * Compares the column-wise products of the linear functional column group with those of the base column group.
+        * For an uncompressed base the expected values are computed by hand from its matrix block. The comparison uses a
+        * tolerance scaled by the largest magnitude among the expected column products.
+        */
+       @Test
+       public void testColProducts() {
+               final int numCols = base.getNumCols();
+               double[] colProductsExpected = new double[numCols];
+
+               AggregateOperator aop = new AggregateOperator(0, Multiply.getMultiplyFnObject());
+               AggregateUnaryOperator auop = new AggregateUnaryOperator(aop, ReduceRow.getReduceRowFnObject());
+
+               if(base instanceof AColGroupCompressed) {
+                       // NOTE(review): the expected buffer starts at 0 here while the actual buffer below starts at 1 -
+                       // confirm that the compressed implementation does not multiply into the initial values.
+                       AColGroupCompressed baseComp = (AColGroupCompressed) base;
+                       baseComp.unaryAggregateOperations(auop, colProductsExpected, nRow, 0, nRow,
+                               baseComp.preAggRows(auop));
+               }
+               else if(base instanceof ColGroupUncompressed) {
+                       MatrixBlock mb = ((ColGroupUncompressed) base).getData();
+
+                       for(int j = 0; j < numCols; j++) {
+                               double colProduct = 1;
+                               for(int i = 0; i < nRow; i++) {
+                                       colProduct *= mb.getDouble(i, j);
+                               }
+                               colProductsExpected[j] = colProduct;
+                       }
+               }
+               else {
+                       fail("Base ColGroup type does not support colProduct.");
+               }
+
+               double[] colProducts = new double[numCols];
+               Arrays.fill(colProducts, 1);
+
+               lin.unaryAggregateOperations(auop, colProducts, nRow, 0, nRow, lin.preAggRows(auop));
+
+               // Use a relative tolerance since column products can get very large. The scale is the largest absolute
+               // value: taking abs() of the signed maximum would yield a far too small tolerance whenever the product
+               // with the largest magnitude is negative.
+               double maxMagnitude = Arrays.stream(colProductsExpected).map(Math::abs).max().orElse(0);
+               double relTolerance = tolerance * maxMagnitude;
+               Assert.assertArrayEquals(colProductsExpected, colProducts, relTolerance);
+       }
+
+       /**
+        * Compares the sum of squares over all cells of the linear functional column group with that of the base
+        * column group. For an uncompressed base the expected value is accumulated by hand from its matrix block.
+        */
+       @Test
+       public void testSumSq() {
+               final double[] expected = {0};
+               final AggregateUnaryOperator sumSqOp = new AggregateUnaryOperator(
+                       new AggregateOperator(0, KahanPlusSq.getKahanPlusSqFnObject()), ReduceAll.getReduceAllFnObject());
+
+               if(base instanceof AColGroupCompressed) {
+                       final AColGroupCompressed compressedBase = (AColGroupCompressed) base;
+                       compressedBase.unaryAggregateOperations(sumSqOp, expected, nRow, 0, nRow,
+                               compressedBase.preAggRows(sumSqOp));
+               }
+               else if(base instanceof ColGroupUncompressed) {
+                       final MatrixBlock block = ((ColGroupUncompressed) base).getData();
+                       final int numCols = base.getNumCols();
+
+                       // accumulate the squared cells column by column
+                       for(int col = 0; col < numCols; col++)
+                               for(int row = 0; row < nRow; row++)
+                                       expected[0] += Math.pow(block.getDouble(row, col), 2);
+               }
+               else {
+                       fail("Base ColGroup type does not support sumSq.");
+               }
+
+               final double[] actual = {0};
+               lin.unaryAggregateOperations(sumSqOp, actual, nRow, 0, nRow, lin.preAggRows(sumSqOp));
+
+               Assert.assertEquals(expected[0], actual[0], tolerance);
+       }
+
+       /**
+        * Compares the full sum of the linear functional column group with the sum of the column sums computed by the
+        * base column group.
+        */
+       @Test
+       public void testSum() {
+               // expected: total of the per-column sums reported by the base column group
+               final double[] baseColSums = new double[base.getNumCols()];
+               base.computeColSums(baseColSums, nRow);
+               final double expected = Arrays.stream(baseColSums).sum();
+
+               // actual: full aggregate computed by the column group under test
+               final double[] actual = new double[1];
+               final AggregateUnaryOperator sumOp = new AggregateUnaryOperator(
+                       new AggregateOperator(0, Plus.getPlusFnObject()), ReduceAll.getReduceAllFnObject());
+               lin.unaryAggregateOperations(sumOp, actual, nRow, 0, nRow, lin.preAggRows(sumOp));
+
+               Assert.assertEquals(expected, actual[0], tolerance);
+       }
+
+       /**
+        * Compares the row sums of the linear functional column group with those of the base column group. For an
+        * uncompressed base the expected values are computed by hand from its matrix block.
+        */
+       @Test
+       public void testRowSums() {
+               final double[] expected = new double[nRow];
+               final AggregateUnaryOperator rowSumOp = new AggregateUnaryOperator(
+                       new AggregateOperator(0, Plus.getPlusFnObject()), ReduceCol.getReduceColFnObject());
+
+               if(base instanceof AColGroupCompressed) {
+                       final AColGroupCompressed compressedBase = (AColGroupCompressed) base;
+                       compressedBase.unaryAggregateOperations(rowSumOp, expected, nRow, 0, nRow,
+                               compressedBase.preAggRows(rowSumOp));
+               }
+               else if(base instanceof ColGroupUncompressed) {
+                       final MatrixBlock block = ((ColGroupUncompressed) base).getData();
+                       final int numCols = base.getNumCols();
+
+                       for(int row = 0; row < nRow; row++) {
+                               double acc = 0;
+                               for(int col = 0; col < numCols; col++)
+                                       acc += block.getDouble(row, col);
+                               expected[row] = acc;
+                       }
+               }
+               else {
+                       fail("Base ColGroup type does not support rowSum.");
+               }
+
+               final double[] actual = new double[nRow];
+               lin.unaryAggregateOperations(rowSumOp, actual, nRow, 0, nRow, lin.preAggRows(rowSumOp));
+
+               Assert.assertArrayEquals(expected, actual, tolerance);
+       }
+
+       /** Compares the column sums computed by the linear functional column group with those of the base group. */
+       @Test
+       public void testColSums() {
+               final int numCols = base.getNumCols();
+
+               final double[] expected = new double[numCols];
+               base.computeColSums(expected, nRow);
+
+               final double[] actual = new double[numCols];
+               lin.computeColSums(actual, nRow);
+
+               Assert.assertArrayEquals(expected, actual, tolerance);
+       }
+
+       /**
+        * Checks the compression type returned by {@code cgLinCompressed} for two special inputs: an all-ones input is
+        * expected to yield a CONST group and an all-zero input an EMPTY group.
+        */
+       @Test
+       public void testColumnGroupConstruction() {
+               final AColGroup constGroup = cgLinCompressed(new double[][] {{1, 1, 1, 1, 1}}, true);
+               Assert.assertSame(AColGroup.CompressionType.CONST, constGroup.getCompType());
+
+               final AColGroup emptyGroup = cgLinCompressed(new double[][] {{0, 0, 0, 0, 0}}, true);
+               Assert.assertSame(AColGroup.CompressionType.EMPTY, emptyGroup.getCompType());
+       }
+
+       /**
+        * Decompresses both column groups into freshly allocated dense blocks and compares the resulting values.
+        */
+       @Test
+       public void testDecompressToDenseBlock() {
+               final int numCols = lin.getNumCols();
+
+               // reference values decompressed from the base column group
+               final MatrixBlock expected = new MatrixBlock(nRow, numCols, false);
+               expected.allocateDenseBlock();
+               base.decompressToDenseBlock(expected.getDenseBlock(), 0, nRow);
+
+               // values decompressed from the column group under test
+               final MatrixBlock actual = new MatrixBlock(nRow, numCols, false);
+               actual.allocateDenseBlock();
+               lin.decompressToDenseBlock(actual.getDenseBlock(), 0, nRow);
+
+               Assert.assertArrayEquals(expected.getDenseBlockValues(), actual.getDenseBlockValues(), tolerance);
+       }
+
+}
diff --git 
a/src/test/java/org/apache/sysds/test/component/compress/functional/LinearRegressionTests.java
 
b/src/test/java/org/apache/sysds/test/component/compress/functional/LinearRegressionTests.java
new file mode 100644
index 0000000000..b3b9a12fca
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/component/compress/functional/LinearRegressionTests.java
@@ -0,0 +1,115 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.component.compress.functional;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.fail;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.stream.DoubleStream;
+
+import org.apache.sysds.runtime.compress.DMLCompressionException;
+import org.apache.sysds.runtime.compress.colgroup.functional.LinearRegression;
+import org.apache.sysds.runtime.compress.utils.Util;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.util.DataConverter;
+import 
org.apache.sysds.test.component.compress.colgroup.ColGroupLinearFunctionalBase;
+import org.junit.Test;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
+
+@RunWith(value = Parameterized.class)
+public class LinearRegressionTests {
+       protected final double[][] data;
+       protected final int[] colIndexes;
+       protected final boolean isTransposed;
+       protected final double[] expectedCoefficients;
+       protected final Exception expectedException;
+
+       protected final double EQUALITY_TOLERANCE = 1e-4;
+
+       /**
+        * Supplies the parameter sets for the parameterized runner. If building the cases throws, the stack trace is
+        * printed and the run is failed.
+        *
+        * @return the constructor argument arrays, one per test instance
+        */
+       @Parameterized.Parameters
+       public static Collection<Object[]> data() {
+               final ArrayList<Object[]> cases = new ArrayList<>();
+
+               try {
+                       addCases(cases);
+               }
+               catch(Exception ex) {
+                       ex.printStackTrace();
+                       fail("failed constructing tests");
+               }
+
+               return cases;
+       }
+
+       /**
+        * Creates one parameterized test instance and stores the arguments in the corresponding fields.
+        *
+        * @param data                 input values that are converted to a matrix block by the test
+        * @param colIndexes           column indexes passed to the regression
+        * @param isTransposed         transposition flag passed to the regression
+        * @param expectedCoefficients coefficients the regression is expected to return; the cases that expect an
+        *                             exception pass {@code null}
+        * @param expectedException    exception whose class and message are expected, or {@code null} if the
+        *                             regression is expected to succeed
+        */
+       public LinearRegressionTests(double[][] data, int[] colIndexes, boolean 
isTransposed, double[] expectedCoefficients,
+               Exception expectedException) {
+               this.data = data;
+               this.colIndexes = colIndexes;
+               this.isTransposed = isTransposed;
+               this.expectedCoefficients = expectedCoefficients;
+               this.expectedException = expectedException;
+       }
+
+       /**
+        * Appends all test cases to the given list. Each case is an argument array in constructor order: data, column
+        * indexes, transposition flag, expected coefficients, expected exception.
+        *
+        * @param tests list the cases are appended to
+        */
+       protected static void addCases(ArrayList<Object[]> tests) {
+               // small hand-written matrix, regressed on a subset of its columns
+               final double[][] handWritten = {{1, 1, -3, 4, 5}, {2, 2, 3, 4, 5}, {3, 3, 3, 4, 5}};
+               final int[] selectedCols = {0, 1, 3, 4};
+               final double[] handWrittenCoefficients = {0, 0, 4, 5, 1, 1, 0, 0};
+               tests.add(new Object[] {handWritten, selectedCols, false, handWrittenCoefficients, null});
+
+               // expect exception if passing columns with single data points each
+               final double[][] singleRow = {{1, 2, 3}};
+               tests.add(new Object[] {singleRow, Util.genColsIndices(1), false, null,
+                       new DMLCompressionException("At least 2 data points are required to fit a linear function.")});
+
+               // expect exception if passing no colIndexes
+               final double[][] twoRows = {{1, 2, 3}, {2, 3, 4}};
+               tests.add(new Object[] {twoRows, Util.genColsIndices(0), false, null,
+                       new DMLCompressionException("At least 1 column must be specified for compression.")});
+
+               // random matrix
+               final int numRows = 100;
+               final int numCols = 200;
+               // TODO: move generateRandomInterceptsSlopes in an appropriate Util class
+               final double[][] coefficients = ColGroupLinearFunctionalBase.generateRandomInterceptsSlopes(numCols, -1000,
+                       1000, -20, 20, 42);
+               // TODO: move generateTestMatrixLinearColumns in an appropriate Util class
+               final double[][] randomData = ColGroupLinearFunctionalBase.generateTestMatrixLinearColumns(numRows,
+                       numCols, coefficients[0], coefficients[1]);
+               final int[] allCols = Util.genColsIndices(numCols);
+               // expected result: first generated coefficient array followed by the second one
+               final double[] flattened = DoubleStream.concat(Arrays.stream(coefficients[0]),
+                       Arrays.stream(coefficients[1])).toArray();
+               tests.add(new Object[] {randomData, allCols, false, flattened, null});
+       }
+
+       /**
+        * Fits the regression on the parameterized input and checks the outcome.
+        * <p>
+        * If no exception is expected, the returned coefficients must match the expected ones within
+        * {@code EQUALITY_TOLERANCE}, and any thrown exception fails the test with that exception as cause. If an
+        * exception is expected, its class and message must match and the test fails when nothing is thrown.
+        */
+       @Test
+       public void testLinearRegression() {
+               MatrixBlock mbt = DataConverter.convertToMatrixBlock(data);
+
+               final double[] coefficients;
+               try {
+                       coefficients = LinearRegression.regressMatrixBlock(mbt, colIndexes, isTransposed);
+               }
+               catch(Exception e) {
+                       // Without this guard an unexpected exception would surface as a NullPointerException below.
+                       if(expectedException == null)
+                               throw new AssertionError("Unexpected exception while fitting the linear regression", e);
+
+                       assertEquals(expectedException.getClass(), e.getClass());
+                       assertEquals(expectedException.getMessage(), e.getMessage());
+                       return;
+               }
+
+               if(expectedException != null)
+                       fail("Expected " + expectedException.getClass().getName() + " with message \""
+                               + expectedException.getMessage() + "\" but no exception was thrown");
+
+               assertArrayEquals(expectedCoefficients, coefficients, EQUALITY_TOLERANCE);
+       }
+
+}

Reply via email to