This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new d0597c0f62 [SYSTEMDS-3490] Compressed Transform Tests
d0597c0f62 is described below

commit d0597c0f62ca35dc6f99235bd7cfffa2421c6ab4
Author: baunsgaard <[email protected]>
AuthorDate: Wed May 17 10:24:38 2023 +0200

    [SYSTEMDS-3490] Compressed Transform Tests
    
    This commit updates the compressed transform tests to reach 100%
    coverage and fixes some edge cases in binning and hashing.
    
    Closes #1826
---
 .../sysds/runtime/compress/lib/CLALibUtils.java    | 11 +---
 .../runtime/transform/encode/ColumnEncoderBin.java | 15 ++---
 .../runtime/transform/encode/CompressedEncode.java | 66 +++++++++-------------
 .../runtime/transform/encode/EncoderFactory.java   |  4 +-
 src/test/java/org/apache/sysds/test/TestUtils.java | 33 +----------
 .../transform/TransformCompressedTestMultiCol.java | 51 ++++++++++-------
 .../TransformCompressedTestSingleCol.java          | 30 +++++++---
 .../frame/transform/TransformCustomTest.java       | 22 ++++++++
 8 files changed, 112 insertions(+), 120 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibUtils.java 
b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibUtils.java
index 3e6837d490..1dfe2a0575 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibUtils.java
@@ -20,7 +20,6 @@
 package org.apache.sysds.runtime.compress.lib;
 
 import java.util.ArrayList;
-import java.util.Arrays;
 import java.util.HashSet;
 import java.util.List;
 import java.util.Set;
@@ -112,7 +111,7 @@ public final class CLALibUtils {
                        else
                                filteredGroups.add(g);
                }
-               return returnGroupIfFiniteNumbers(groups, filteredGroups, 
constV);
+               return filteredGroups;
        }
 
        protected static void filterGroupsAndSplitPreAgg(List<AColGroup> 
groups, double[] constV,
@@ -150,14 +149,6 @@ public final class CLALibUtils {
                }
        }
 
-       private static List<AColGroup> 
returnGroupIfFiniteNumbers(List<AColGroup> groups, List<AColGroup> 
filteredGroups,
-               double[] constV) {
-               for(double v : constV)
-                       if(!Double.isFinite(v))
-                               throw new NotImplementedException("Not handling 
if the values are not finite: " + Arrays.toString(constV));
-               return filteredGroups;
-       }
-
        private static AColGroup combineEmpty(List<AColGroup> e) {
                return new ColGroupEmpty(combineColIndexes(e));
        }
diff --git 
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderBin.java 
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderBin.java
index 895141db07..b2c530a3e4 100644
--- 
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderBin.java
+++ 
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderBin.java
@@ -43,10 +43,6 @@ public class ColumnEncoderBin extends ColumnEncoder {
        public static final String NBINS_PREFIX = "nbins";
        private static final long serialVersionUID = 1917445005206076078L;
 
-       public int getNumBin() {
-               return _numBin;
-       }
-
        protected int _numBin = -1;
        private BinMethod _binMethod = BinMethod.EQUI_WIDTH;
 
@@ -75,6 +71,10 @@ public class ColumnEncoderBin extends ColumnEncoder {
                _binMaxs = binMaxs;
        }
 
+       public int getNumBin() {
+               return _numBin;
+       }
+
        public double getColMins() {
                return _colMins;
        }
@@ -404,15 +404,8 @@ public class ColumnEncoderBin extends ColumnEncoder {
                sb.append(": ");
                sb.append(_colID);
                sb.append(" --- Method: " + _binMethod + " num Bin: " + 
_numBin);
-               // if(_binMethod == BinMethod.EQUI_WIDTH) {
                sb.append("\n---- BinMin: " + Arrays.toString(_binMins));
                sb.append("\n---- BinMax: " + Arrays.toString(_binMaxs));
-               // }
-               // else {
-               // // sb.append(" --- MinMax: "+ _colMins + " " + _colMaxs);
-               // sb.append("\n---- BinMin: " + Arrays.toString(_binMins));
-               // sb.append("\n---- BinMax: " + Arrays.toString(_binMaxs));
-               // }
                return sb.toString();
        }
 
diff --git 
a/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java 
b/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java
index 63eb81e008..150133c469 100644
--- 
a/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java
+++ 
b/src/main/java/org/apache/sysds/runtime/transform/encode/CompressedEncode.java
@@ -33,11 +33,11 @@ import org.apache.commons.logging.LogFactory;
 import org.apache.sysds.common.Types.ValueType;
 import org.apache.sysds.conf.ConfigurationManager;
 import org.apache.sysds.conf.DMLConfig;
-import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.compress.CompressedMatrixBlock;
 import org.apache.sysds.runtime.compress.colgroup.AColGroup;
 import org.apache.sysds.runtime.compress.colgroup.ColGroupConst;
 import org.apache.sysds.runtime.compress.colgroup.ColGroupDDC;
+import org.apache.sysds.runtime.compress.colgroup.ColGroupEmpty;
 import org.apache.sysds.runtime.compress.colgroup.ColGroupUncompressed;
 import org.apache.sysds.runtime.compress.colgroup.dictionary.ADictionary;
 import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary;
@@ -69,11 +69,12 @@ public class CompressedEncode {
                this.k = k;
        }
 
-       public static MatrixBlock encode(MultiColumnEncoder enc, FrameBlock in, 
int k) {
+       public static MatrixBlock encode(MultiColumnEncoder enc, FrameBlock in, 
int k)
+               throws InterruptedException, ExecutionException {
                return new CompressedEncode(enc, in, k).apply();
        }
 
-       private MatrixBlock apply() {
+       private MatrixBlock apply() throws InterruptedException, 
ExecutionException {
                final List<ColumnEncoderComposite> encoders = 
enc.getColumnEncoders();
                final List<AColGroup> groups = isParallel() ? 
multiThread(encoders) : singleThread(encoders);
                final int cols = shiftGroups(groups);
@@ -94,7 +95,8 @@ public class CompressedEncode {
                return groups;
        }
 
-       private List<AColGroup> multiThread(List<ColumnEncoderComposite> 
encoders) {
+       private List<AColGroup> multiThread(List<ColumnEncoderComposite> 
encoders)
+               throws InterruptedException, ExecutionException {
 
                final ExecutorService pool = CommonThreadPool.get(k);
                try {
@@ -106,13 +108,10 @@ public class CompressedEncode {
                        List<AColGroup> groups = new 
ArrayList<>(encoders.size());
                        for(Future<AColGroup> t : pool.invokeAll(tasks))
                                groups.add(t.get());
-
-                       pool.shutdown();
                        return groups;
                }
-               catch(InterruptedException | ExecutionException ex) {
+               finally {
                        pool.shutdown();
-                       throw new DMLRuntimeException("Failed parallel 
compressed transform encode", ex);
                }
        }
 
@@ -157,8 +156,10 @@ public class CompressedEncode {
                boolean containsNull = a.containsNull();
                HashMap<?, Long> map = a.getRecodeMap();
                int domain = map.size();
+               if(containsNull && domain == 0)
+                       return new ColGroupEmpty(ColIndexFactory.create(1));
                IColIndex colIndexes = ColIndexFactory.create(0, domain);
-               if(domain == 1)
+               if(domain == 1 && !containsNull)
                        return ColGroupConst.create(colIndexes, new double[] 
{1});
                ADictionary d = new IdentityDictionary(colIndexes.size(), 
containsNull);
                AMapToData m = createMappingAMapToData(a, map, containsNull);
@@ -180,12 +181,6 @@ public class CompressedEncode {
                AMapToData m = binEncode(a, b, containsNull);
 
                AColGroup ret = ColGroupDDC.create(colIndexes, d, m, null);
-               try {
-                       ret.getNumberNonZeros(a.size());
-               }
-               catch(Exception e) {
-                       throw new DMLRuntimeException("Failed binning \n\n" + a 
+ "\n" + b + "\n" + d + "\n" + m, e);
-               }
                return ret;
        }
 
@@ -230,7 +225,6 @@ public class CompressedEncode {
                ADictionary d = new IdentityDictionary(colIndexes.size(), 
containsNull);
                AMapToData m = binEncode(a, b, containsNull);
                AColGroup ret = ColGroupDDC.create(colIndexes, d, m, null);
-               ret.getNumberNonZeros(a.size());
                return ret;
        }
 
@@ -246,11 +240,11 @@ public class CompressedEncode {
                IColIndex colIndexes = ColIndexFactory.create(1);
                if(domain == 1)
                        return ColGroupConst.create(colIndexes, new double[] 
{1});
-               MatrixBlock incrementing = new MatrixBlock(domain + 
(containsNull ? 1 : 0) , 1, false);
+               MatrixBlock incrementing = new MatrixBlock(domain + 
(containsNull ? 1 : 0), 1, false);
                for(int i = 0; i < domain; i++)
                        incrementing.quickSetValue(i, 0, i + 1);
                if(containsNull)
-                       incrementing.quickSetValue(domain, 0 , Double.NaN);
+                       incrementing.quickSetValue(domain, 0, Double.NaN);
 
                ADictionary d = MatrixBlockDictionary.create(incrementing);
 
@@ -258,7 +252,6 @@ public class CompressedEncode {
 
                List<ColumnEncoder> r = c.getEncoders();
                r.set(0, new ColumnEncoderRecode(colId, (HashMap<Object, Long>) 
map));
-
                return ColGroupDDC.create(colIndexes, d, m, null);
 
        }
@@ -283,7 +276,7 @@ public class CompressedEncode {
                        if(containsNull)
                                vals[map.size()] = Double.NaN;
                        ValueType t = a.getValueType();
-                       map.forEach((k,v) -> vals[v.intValue()] = 
UtilFunctions.objectToDouble(t,k));
+                       map.forEach((k, v) -> vals[v.intValue()] = 
UtilFunctions.objectToDouble(t, k));
                        ADictionary d = Dictionary.create(vals);
                        AMapToData m = createMappingAMapToData(a, map, 
containsNull);
                        return ColGroupDDC.create(colIndexes, d, m, null);
@@ -295,17 +288,17 @@ public class CompressedEncode {
                final int si = map.size();
                AMapToData m = MapToFactory.create(in.getNumRows(), si + 
(containsNull ? 1 : 0));
                Array<?>.ArrayIterator it = a.getIterator();
-               if(containsNull){
+               if(containsNull) {
 
                        while(it.hasNext()) {
                                Object v = it.next();
                                if(v != null)
                                        m.set(it.getIndex(), 
map.get(v).intValue());
                                else
-                                       m.set(it.getIndex(),si);
+                                       m.set(it.getIndex(), si);
                        }
                }
-               else{
+               else {
                        while(it.hasNext()) {
                                Object v = it.next();
                                m.set(it.getIndex(), map.get(v).intValue());
@@ -340,25 +333,22 @@ public class CompressedEncode {
                int colId = c._colID;
                Array<?> a = in.getColumn(colId - 1);
                ColumnEncoderFeatureHash CEHash = (ColumnEncoderFeatureHash) 
c.getEncoders().get(0);
-
-               // HashMap<?, Long> map = a.getRecodeMap();
                int domain = (int) CEHash.getK();
                boolean nulls = a.containsNull();
                IColIndex colIndexes = ColIndexFactory.create(0, 1);
-               if(domain == 1)
+               if(domain == 1 && ! nulls)
                        return ColGroupConst.create(colIndexes, new double[] 
{1});
 
-                       MatrixBlock incrementing = new MatrixBlock(domain + 
(nulls ? 1 : 0), 1, false);
-                       for(int i = 0; i < domain; i++)
-                               incrementing.quickSetValue(i, 0, i + 1);
-                       if(nulls)
-                               incrementing.quickSetValue(domain, 0, 
Double.NaN);
-       
-                       ADictionary d = 
MatrixBlockDictionary.create(incrementing);
-
-               AMapToData m = createHashMappingAMapToData(a, domain , nulls);
-               AColGroup ret =  ColGroupDDC.create(colIndexes, d, m, null);
-               ret.getNumberNonZeros(a.size());
+               MatrixBlock incrementing = new MatrixBlock(domain + (nulls ? 1 
: 0), 1, false);
+               for(int i = 0; i < domain; i++)
+                       incrementing.quickSetValue(i, 0, i + 1);
+               if(nulls)
+                       incrementing.quickSetValue(domain, 0, Double.NaN);
+
+               ADictionary d = MatrixBlockDictionary.create(incrementing);
+
+               AMapToData m = createHashMappingAMapToData(a, domain, nulls);
+               AColGroup ret = ColGroupDDC.create(colIndexes, d, m, null);
                return ret;
        }
 
@@ -369,7 +359,7 @@ public class CompressedEncode {
                int domain = (int) CEHash.getK();
                boolean nulls = a.containsNull();
                IColIndex colIndexes = ColIndexFactory.create(0, domain);
-               if(domain == 1)
+               if(domain == 1 && !nulls)
                        return ColGroupConst.create(colIndexes, new double[] 
{1});
                ADictionary d = new IdentityDictionary(colIndexes.size(), 
nulls);
                AMapToData m = createHashMappingAMapToData(a, domain, nulls);
diff --git 
a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java 
b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java
index 41e16d6e6e..075b6fbdd4 100644
--- 
a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java
+++ 
b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java
@@ -44,8 +44,8 @@ import org.apache.sysds.utils.stats.TransformStatistics;
 import org.apache.wink.json4j.JSONArray;
 import org.apache.wink.json4j.JSONObject;
 
-public class EncoderFactory {
-       protected static final Log LOG = 
LogFactory.getLog(EncoderFactory.class.getName());
+public interface EncoderFactory {
+       final static Log LOG = 
LogFactory.getLog(EncoderFactory.class.getName());
 
        public static MultiColumnEncoder createEncoder(String spec, String[] 
colnames, int clen, FrameBlock meta) {
                return createEncoder(spec, colnames, 
UtilFunctions.nCopies(clen, ValueType.STRING), meta);
diff --git a/src/test/java/org/apache/sysds/test/TestUtils.java 
b/src/test/java/org/apache/sysds/test/TestUtils.java
index bc5d48b8a2..23da1a8fc3 100644
--- a/src/test/java/org/apache/sysds/test/TestUtils.java
+++ b/src/test/java/org/apache/sysds/test/TestUtils.java
@@ -91,9 +91,6 @@ import org.apache.sysds.runtime.util.DataConverter;
 import org.apache.sysds.runtime.util.UtilFunctions;
 import org.junit.Assert;
 
-//import jcuda.runtime.JCuda;
-
-
 /**
  * <p>
  * Provides methods to easily create tests. Implemented methods can be used for
@@ -106,8 +103,7 @@ import org.junit.Assert;
  * <li>clean up</li>
  * </ul>
  */
-public class TestUtils
-{
+public class TestUtils {
 
        private static final Log LOG = 
LogFactory.getLog(TestUtils.class.getName());
 
@@ -1604,16 +1600,6 @@ public class TestUtils
                return false;
        }
 
-
-       /**
-        *
-        * @param vt
-        * @param in1
-        * @param in2
-        * @param tolerance
-        *
-        * @return
-        */
        public static int compareTo(ValueType vt, Object in1, Object in2, 
double tolerance) {
                if(in1 == null && in2 == null) return 0;
                else if(in1 == null) return -1;
@@ -1659,12 +1645,6 @@ public class TestUtils
                }
        }
 
-       /**
-        * Converts a 2D array into a sparse hashmap matrix.
-        *
-        * @param matrix
-        * @return
-        */
        public static HashMap<CellIndex, Double> 
convert2DDoubleArrayToHashMap(double[][] matrix) {
                HashMap<CellIndex, Double> hmMatrix = new HashMap<>();
                for (int i = 0; i < matrix.length; i++) {
@@ -1677,11 +1657,6 @@ public class TestUtils
                return hmMatrix;
        }
 
-       /**
-        * Method to convert a hashmap of matrix entries into a double array
-        * @param matrix
-        * @return
-        */
        public static double[][] convertHashMapToDoubleArray(HashMap 
<CellIndex, Double> matrix) {
                int max_rows = -1, max_cols= -1;
                for(CellIndex ix : matrix.keySet()) {
@@ -1701,12 +1676,6 @@ public class TestUtils
                return ret_arr;
        }
 
-       /**
-        * Converts a 2D double array into a 1D double array.
-        *
-        * @param array
-        * @return
-        */
        public static double[] convert2Dto1DDoubleArray(double[][] array) {
                double[] ret = new double[array.length * array[0].length];
                int c = 0;
diff --git 
a/src/test/java/org/apache/sysds/test/component/frame/transform/TransformCompressedTestMultiCol.java
 
b/src/test/java/org/apache/sysds/test/component/frame/transform/TransformCompressedTestMultiCol.java
index 2cea03489d..a592fc7477 100644
--- 
a/src/test/java/org/apache/sysds/test/component/frame/transform/TransformCompressedTestMultiCol.java
+++ 
b/src/test/java/org/apache/sysds/test/component/frame/transform/TransformCompressedTestMultiCol.java
@@ -22,6 +22,7 @@ package org.apache.sysds.test.component.frame.transform;
 import static org.junit.Assert.fail;
 
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.Collection;
 
 import org.apache.commons.logging.Log;
@@ -55,24 +56,37 @@ public class TransformCompressedTestMultiCol {
                final int[] threads = new int[] {1, 4};
                try {
 
-                       FrameBlock data = 
TestUtils.generateRandomFrameBlock(100, new ValueType[] {ValueType.UINT4, 
ValueType.UINT8, ValueType.UINT4}, 231);
-                       data.setSchema(new ValueType[] {ValueType.INT32, 
ValueType.INT32, ValueType.INT32});
-                       for(int k : threads) {
-                               tests.add(new Object[] {data, k});
+                       ValueType[] kPlusCols = new ValueType[1002];
+
+                       Arrays.fill(kPlusCols, ValueType.BOOLEAN);
+
+                       FrameBlock[] blocks = new FrameBlock[] {//
+                               TestUtils.generateRandomFrameBlock(100, //
+                                       new ValueType[] {ValueType.UINT4, 
ValueType.UINT8, ValueType.UINT4}, 231), //
+                               TestUtils.generateRandomFrameBlock(100, //
+                                       new ValueType[] {ValueType.BOOLEAN, 
ValueType.UINT8, ValueType.UINT4}, 231), //
+                               new FrameBlock(new ValueType[] 
{ValueType.BOOLEAN, ValueType.INT32, ValueType.INT32}, 100), //
+                               TestUtils.generateRandomFrameBlock(100, //
+                                       new ValueType[] {ValueType.UINT4, 
ValueType.BOOLEAN, ValueType.FP32}, 231, 0.2),
+                               TestUtils.generateRandomFrameBlock(432, //
+                                       new ValueType[] {ValueType.UINT4, 
ValueType.BOOLEAN, ValueType.FP32}, 231, 0.2),
+                               TestUtils.generateRandomFrameBlock(100, //
+                                       new ValueType[] {ValueType.UINT4, 
ValueType.BOOLEAN, ValueType.FP32}, 231, 0.9),
+                               TestUtils.generateRandomFrameBlock(100, //
+                                       new ValueType[] {ValueType.UINT4, 
ValueType.BOOLEAN, ValueType.FP32}, 231, 0.99),
+
+                               TestUtils.generateRandomFrameBlock(5, 
kPlusCols, 322),
+                               TestUtils.generateRandomFrameBlock(1020, 
kPlusCols, 322),
+
+                       };
+                       blocks[2].ensureAllocatedColumns(100);
+
+                       for(FrameBlock block : blocks) {
+                               for(int k : threads) {
+                                       tests.add(new Object[] {block, k});
+                               }
                        }
 
-                       FrameBlock data2 = 
TestUtils.generateRandomFrameBlock(100, new ValueType[] {ValueType.BOOLEAN, 
ValueType.UINT8, ValueType.UINT4}, 231);
-                       data2.setSchema(new ValueType[] {ValueType.BOOLEAN, 
ValueType.INT32, ValueType.INT32});
-                       for(int k : threads) {
-                               tests.add(new Object[] {data2, k});
-                       }
-
-                       FrameBlock data3 = new FrameBlock(
-                               new ValueType[] {ValueType.BOOLEAN, 
ValueType.INT32, ValueType.INT32}, 100) ;
-                       data3.ensureAllocatedColumns(100);
-                       for(int k : threads) 
-                               tests.add(new Object[] {data3, k});
-                       
                }
                catch(Exception e) {
                        e.printStackTrace();
@@ -114,12 +128,12 @@ public class TransformCompressedTestMultiCol {
        }
 
        @Test
-       public void testHash(){
+       public void testHash() {
                test("{ids:true, hash:[1,2,3], K:10}");
        }
 
        @Test
-       public void testHashToDummy(){
+       public void testHashToDummy() {
                test("{ids:true, hash:[1,2,3], K:10, dummycode:[1,2]}");
        }
 
@@ -137,7 +151,6 @@ public class TransformCompressedTestMultiCol {
                        MatrixBlock outNormal = encoderNormal.encode(data, k);
                        FrameBlock outNormalMD = 
encoderNormal.getMetaData(null);
 
-
                        TestUtils.compareMatrices(outNormal, outCompressed, 0, 
"Not Equal after apply");
                        TestUtils.compareFrames(outNormalMD, outCompressedMD, 
true);
                }
diff --git 
a/src/test/java/org/apache/sysds/test/component/frame/transform/TransformCompressedTestSingleCol.java
 
b/src/test/java/org/apache/sysds/test/component/frame/transform/TransformCompressedTestSingleCol.java
index a573783f6e..7b37ba1413 100644
--- 
a/src/test/java/org/apache/sysds/test/component/frame/transform/TransformCompressedTestSingleCol.java
+++ 
b/src/test/java/org/apache/sysds/test/component/frame/transform/TransformCompressedTestSingleCol.java
@@ -54,14 +54,18 @@ public class TransformCompressedTestSingleCol {
                final ArrayList<Object[]> tests = new ArrayList<>();
                final int[] threads = new int[] {1, 4};
                try {
-
-                       FrameBlock data = 
TestUtils.generateRandomFrameBlock(100, new ValueType[] {ValueType.UINT4}, 231);
-                       for(int k : threads)
-                               tests.add(new Object[] {data, k});
-
-                       data = TestUtils.generateRandomFrameBlock(100, new 
ValueType[] {ValueType.UINT4}, 231, 0.2);
-                       for(int k : threads)
-                               tests.add(new Object[] {data, k});
+                       FrameBlock[] blocks = new FrameBlock[] {
+                               TestUtils.generateRandomFrameBlock(100, new 
ValueType[] {ValueType.UINT4}, 231),
+                               TestUtils.generateRandomFrameBlock(100, new 
ValueType[] {ValueType.UINT4}, 231, 0.2),
+                               TestUtils.generateRandomFrameBlock(100, new 
ValueType[] {ValueType.UINT4}, 231, 1.0),
+                               TestUtils.generateRandomFrameBlock(100, new 
ValueType[] {ValueType.UINT4}, 231, 1.0),
+                               // Above block size of number of unique elements
+                               TestUtils.generateRandomFrameBlock(1200, new 
ValueType[] {ValueType.FP32}, 231, 0.1),};
+
+                       blocks[3].set(40, 0, "14");
+                       for(FrameBlock block : blocks)
+                               for(int k : threads)
+                                       tests.add(new Object[] {block, k});
                }
                catch(Exception e) {
                        e.printStackTrace();
@@ -120,11 +124,21 @@ public class TransformCompressedTestSingleCol {
                test("{ids:true, hash:[1], K:10}");
        }
 
+       @Test
+       public void testHashDomain1() {
+               test("{ids:true, hash:[1], K:1}");
+       }
+
        @Test
        public void testHashToDummy() {
                test("{ids:true, hash:[1], K:10, dummycode:[1]}");
        }
 
+       @Test
+       public void testHashToDummyDomain1() {
+               test("{ids:true, hash:[1], K:1, dummycode:[1]}");
+       }
+
        public void test(String spec) {
                try {
 
diff --git 
a/src/test/java/org/apache/sysds/test/component/frame/transform/TransformCustomTest.java
 
b/src/test/java/org/apache/sysds/test/component/frame/transform/TransformCustomTest.java
index ce7c5d17d9..d1b2479375 100644
--- 
a/src/test/java/org/apache/sysds/test/component/frame/transform/TransformCustomTest.java
+++ 
b/src/test/java/org/apache/sysds/test/component/frame/transform/TransformCustomTest.java
@@ -21,11 +21,20 @@ package org.apache.sysds.test.component.frame.transform;
 
 import static org.junit.Assert.fail;
 
+import java.util.ArrayList;
+import java.util.List;
+
+import org.apache.commons.lang.NotImplementedException;
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.sysds.common.Types.ValueType;
 import org.apache.sysds.runtime.frame.data.FrameBlock;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.transform.encode.ColumnEncoder;
+import org.apache.sysds.runtime.transform.encode.ColumnEncoderComposite;
+import org.apache.sysds.runtime.transform.encode.ColumnEncoderDummycode;
+import org.apache.sysds.runtime.transform.encode.ColumnEncoderPassThrough;
+import org.apache.sysds.runtime.transform.encode.CompressedEncode;
 import org.apache.sysds.runtime.transform.encode.EncoderFactory;
 import org.apache.sysds.runtime.transform.encode.MultiColumnEncoder;
 import org.apache.sysds.test.TestUtils;
@@ -71,6 +80,19 @@ public class TransformCustomTest {
                test("{ids:true, bin:[{id:1, method:equi-height, 
numbins:10}]}");
        }
 
+       @Test(expected = NotImplementedException.class)
+       public void testInvalidEncodeCompressed() throws Exception {
+               List<ColumnEncoderComposite> columnEncoders = new ArrayList<>();
+               List<ColumnEncoder> encoders = new ArrayList<>();
+               // create a nonsense sequence of encoders.
+               encoders.add(new ColumnEncoderDummycode());
+               encoders.add(new ColumnEncoderPassThrough());
+               encoders.add(new ColumnEncoderDummycode());
+               columnEncoders.add(new ColumnEncoderComposite(encoders));
+               MultiColumnEncoder enc = new MultiColumnEncoder(columnEncoders);
+               CompressedEncode.encode(enc, data, 1);
+       }
+
        public void test(String spec) {
                try {
 

Reply via email to