This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 56ac3a82f9 [SYSTEMDS-3436] CLA ArrayOutOfBounds in sample
56ac3a82f9 is described below

commit 56ac3a82f91281cfd540d555f92bfc6282184263
Author: baunsgaard <[email protected]>
AuthorDate: Wed Sep 14 18:45:37 2022 +0200

    [SYSTEMDS-3436] CLA ArrayOutOfBounds in sample
    
    More sparse-specific tests and edge-case fixes.
    
    Closes #1695
---
 .../compress/colgroup/offset/OffsetFactory.java    |  61 ++++----
 .../compress/estim/encoding/EncodingFactory.java   |  32 ++---
 .../runtime/compress/estim/encoding/IEncode.java   |   4 -
 .../compress/estim/encoding/SparseEncoding.java    | 157 +++++++++++++--------
 .../estim/encoding/EncodeNegativeTest.java         |   4 +-
 .../estim/encoding/EncodeSampleCustom.java         | 154 ++++++++++++++++++++
 .../estim/encoding/EncodeSampleUnbalancedTest.java |   1 +
 .../estim/encoding/EncodeSampleUniformTest.java    |  14 +-
 .../component/compress/offset/OffsetTests.java     |  13 +-
 9 files changed, 323 insertions(+), 117 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/OffsetFactory.java
 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/OffsetFactory.java
index 5e76e1c2a5..c1f555ecef 100644
--- 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/OffsetFactory.java
+++ 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/offset/OffsetFactory.java
@@ -21,6 +21,7 @@ package org.apache.sysds.runtime.compress.colgroup.offset;
 
 import java.io.DataInput;
 import java.io.IOException;
+import java.util.Arrays;
 
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
@@ -77,30 +78,42 @@ public interface OffsetFactory {
         * @return A new Offset.
         */
        public static AOffset createOffset(int[] indexes, int apos, int alen) {
-
-               final int endLength = alen - apos - 1;
-               if(endLength < 0)
-                       throw new DMLCompressionException("Invalid empty offset 
to create");
-               else if(endLength == 0) // means size of 1 since we store the 
first offset outside the list
-                       return new OffsetSingle(indexes[apos]);
-               else if(endLength == 1)
-                       return new OffsetTwo(indexes[apos], indexes[apos + 1]);
-
-               final int minValue = indexes[apos];
-               final int maxValue = indexes[alen - 1];
-               final int range = maxValue - minValue;
-               // -1 because one index is skipped using a first idex allocated 
as a int.
-
-               final int correctionByte = correctionByte(range, endLength);
-               final int correctionChar = correctionChar(range, endLength);
-
-               final long byteSize = OffsetByte.estimateInMemorySize(endLength 
+ correctionByte);
-               final long charSize = OffsetChar.estimateInMemorySize(endLength 
+ correctionChar);
-
-               if(byteSize < charSize)
-                       return new OffsetByte(indexes, apos, alen);
-               else
-                       return new OffsetChar(indexes, apos, alen);
+               try {
+                       final int endLength = alen - apos - 1;
+                       if(endLength < 0)
+                               throw new DMLCompressionException("Invalid 
empty offset to create");
+                       else if(endLength == 0) // means size of 1 since we 
store the first offset outside the list
+                               return new OffsetSingle(indexes[apos]);
+                       else if(endLength == 1)
+                               return new OffsetTwo(indexes[apos], 
indexes[apos + 1]);
+
+                       final int minValue = indexes[apos];
+                       final int maxValue = indexes[alen - 1];
+                       final int range = maxValue - minValue;
+                       // -1 because one index is skipped using a first idex 
allocated as a int.
+
+                       final int correctionByte = correctionByte(range, 
endLength);
+                       final int correctionChar = correctionChar(range, 
endLength);
+
+                       final long byteSize = 
OffsetByte.estimateInMemorySize(endLength + correctionByte);
+                       final long charSize = 
OffsetChar.estimateInMemorySize(endLength + correctionChar);
+
+                       if(byteSize < charSize)
+                               return new OffsetByte(indexes, apos, alen);
+                       else
+                               return new OffsetChar(indexes, apos, alen);
+               }
+               catch(Exception e) {
+                       for(int i = apos+1; i < alen ; i++){
+                               if(indexes[i] <= indexes[i-1]){
+                                       String message = "Invalid input to 
create offset, all values should be continuously increasing.\n";
+                                       message += "Index " + (i-1) + " and 
Index " + i + " are wrong with values: " + indexes[i-1] + " and " + indexes[i]; 
+                                       throw new 
DMLCompressionException(message , e);
+                               }
+                       }
+                       throw new DMLCompressionException(
+                               "Failed to create offset with input:" + 
Arrays.toString(indexes) + " Apos: " + apos + " Alen: " + alen, e);
+               }
        }
 
        /**
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/EncodingFactory.java
 
b/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/EncodingFactory.java
index e246817582..2fe6d5e6e5 100644
--- 
a/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/EncodingFactory.java
+++ 
b/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/EncodingFactory.java
@@ -66,8 +66,9 @@ public interface EncodingFactory {
         * @return A delta encoded encoding.
         */
        public static IEncode createFromMatrixBlockDelta(MatrixBlock m, boolean 
transposed, int[] rowCols) {
-               final int sampleSize = transposed ? m.getNumColumns() : 
m.getNumRows();
-               return createFromMatrixBlockDelta(m, transposed, rowCols, 
sampleSize);
+               throw new NotImplementedException();
+               // final int sampleSize = transposed ? m.getNumColumns() : 
m.getNumRows();
+               // return createFromMatrixBlockDelta(m, transposed, rowCols, 
sampleSize);
        }
 
        /**
@@ -145,7 +146,7 @@ public interface EncodingFactory {
                        }
 
                        final AOffset o = OffsetFactory.createOffset(offsets);
-                       return new SparseEncoding(d, o, zeroCount, nCol);
+                       return new SparseEncoding(d, o, nCol);
                }
                else {
                        map.replaceWithUIDs();
@@ -203,8 +204,7 @@ public interface EncodingFactory {
                        // Iteration 3 of non zero indexes, make a Offset 
Encoding to know what cells are zero and not.
                        // not done yet
                        final AOffset o = OffsetFactory.createOffset(aix, apos, 
alen);
-                       final int zero = m.getNumColumns() - o.getSize();
-                       return new SparseEncoding(d, o, zero, 
m.getNumColumns());
+                       return new SparseEncoding(d, o, m.getNumColumns());
                }
        }
 
@@ -244,7 +244,7 @@ public interface EncodingFactory {
 
                        final AOffset o = OffsetFactory.createOffset(offsets);
 
-                       return new SparseEncoding(d, o, zeroCount, nRow);
+                       return new SparseEncoding(d, o, nRow);
                }
                else {
                        // Allocate counts, and iterate once to replace counts 
with u ids
@@ -300,10 +300,8 @@ public interface EncodingFactory {
                }
 
                // Iteration 3 of non zero indexes, make a Offset Encoding to 
know what cells are zero and not.
-               AOffset o = OffsetFactory.createOffset(offsets);
-
-               final int zero = m.getNumRows() - offsets.size();
-               return new SparseEncoding(d, o, zero, m.getNumRows());
+               final AOffset o = OffsetFactory.createOffset(offsets);
+               return new SparseEncoding(d, o, m.getNumRows());
        }
 
        private static IEncode createWithReader(MatrixBlock m, int[] rowCols, 
boolean transposed) {
@@ -326,11 +324,9 @@ public interface EncodingFactory {
                        return new ConstEncoding(nRows);
 
                map.replaceWithUIDs();
-               if(offsets.size() < nRows / 4) {
+               if(offsets.size() < nRows / 4)
                        // Output encoded sparse since there is very empty.
-                       final int zeros = nRows - offsets.size();
-                       return createWithReaderSparse(m, map, zeros, rowCols, 
offsets, nRows, transposed);
-               }
+                       return createWithReaderSparse(m, map, rowCols, offsets, 
nRows, transposed);
                else
                        return createWithReaderDense(m, map, rowCols, nRows, 
transposed, offsets.size() < nRows);
 
@@ -354,7 +350,7 @@ public interface EncodingFactory {
                return new DenseEncoding(d);
        }
 
-       private static IEncode createWithReaderSparse(MatrixBlock m, 
DblArrayCountHashMap map, int zeros, int[] rowCols,
+       private static IEncode createWithReaderSparse(MatrixBlock m, 
DblArrayCountHashMap map, int[] rowCols,
                IntArrayList offsets, int nRows, boolean transposed) {
                final ReaderColumnSelection reader2 = 
ReaderColumnSelection.createReader(m, rowCols, transposed);
                DblArray cellVals = reader2.nextRow();
@@ -370,6 +366,10 @@ public interface EncodingFactory {
 
                final AOffset o = OffsetFactory.createOffset(offsets);
 
-               return new SparseEncoding(d, o, zeros, nRows);
+               return new SparseEncoding(d, o, nRows);
+       }
+
+       public static SparseEncoding createSparse(AMapToData map, AOffset off, 
int nRows){
+               return new SparseEncoding(map, off, nRows);
        }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/IEncode.java 
b/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/IEncode.java
index b3daf3b5c3..5f15c147ac 100644
--- 
a/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/IEncode.java
+++ 
b/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/IEncode.java
@@ -19,8 +19,6 @@
 
 package org.apache.sysds.runtime.compress.estim.encoding;
 
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
 import org.apache.sysds.runtime.compress.CompressionSettings;
 import org.apache.sysds.runtime.compress.estim.EstimationFactors;
 
@@ -29,8 +27,6 @@ import 
org.apache.sysds.runtime.compress.estim.EstimationFactors;
  * column groups.
  */
 public interface IEncode {
-       static final Log LOG = LogFactory.getLog(IEncode.class.getName());
-
        /**
         * Combine two encodings, note it should be guaranteed by the caller 
that the number of unique multiplied does not
         * overflow Integer.
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/SparseEncoding.java
 
b/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/SparseEncoding.java
index 3d6c4180e6..ae4eb1f40f 100644
--- 
a/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/SparseEncoding.java
+++ 
b/src/main/java/org/apache/sysds/runtime/compress/estim/encoding/SparseEncoding.java
@@ -19,6 +19,8 @@
 
 package org.apache.sysds.runtime.compress.estim.encoding;
 
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
 import org.apache.sysds.runtime.compress.CompressionSettings;
 import org.apache.sysds.runtime.compress.colgroup.mapping.AMapToData;
 import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory;
@@ -31,6 +33,8 @@ import org.apache.sysds.runtime.compress.utils.IntArrayList;
 /** Most common is zero encoding */
 public class SparseEncoding implements IEncode {
 
+       static final Log LOG = 
LogFactory.getLog(SparseEncoding.class.getName());
+
        /** A map to the distinct values contained */
        protected final AMapToData map;
 
@@ -40,13 +44,9 @@ public class SparseEncoding implements IEncode {
        /** Total number of rows encoded */
        protected final int nRows;
 
-       /** Count of Zero tuples in this encoding */
-       protected final int zeroCount;
-
-       protected SparseEncoding(AMapToData map, AOffset off, int zeroCount, 
int nRows) {
+       protected SparseEncoding(AMapToData map, AOffset off, int nRows) {
                this.map = map;
                this.off = off;
-               this.zeroCount = zeroCount;
                this.nRows = nRows;
        }
 
@@ -90,7 +90,7 @@ public class SparseEncoding implements IEncode {
                if(retOff.size() < nRows / 4) {
                        final AOffset o = OffsetFactory.createOffset(retOff);
                        final AMapToData retMap = 
MapToFactory.create(tmpVals.size(), tmpVals.extractValues(), unique - 1);
-                       return new SparseEncoding(retMap, o, nRows - 
retOff.size(), nRows);
+                       return new SparseEncoding(retMap, o, nRows);
                }
                else {
                        // there will always be a zero therefore unique is not 
subtracted one.
@@ -112,26 +112,6 @@ public class SparseEncoding implements IEncode {
                int il = itl.value();
                int ir = itr.value();
 
-               if(il == fl && ir == fr) { // easy both only have one value
-                       tmpVals.appendValue(0);
-                       if(fl == fr) { // both on same row
-                               retOff.appendValue(fl);
-                               return 2;
-                       }
-                       // Known two locations to add.
-                       tmpVals.appendValue(1);
-                       if(fl < fr) {// fl is first
-                               retOff.appendValue(fl);
-                               retOff.appendValue(fr);
-                               return 3;
-                       }
-                       else {// fl is last
-                               retOff.appendValue(fr);
-                               retOff.appendValue(fl);
-                               return 3;
-                       }
-               }
-
                while(il < fl && ir < fr) {
                        if(il == ir) {// Both sides have a value same row.
                                final int nv = 
lMap.getIndex(itl.getDataIndex()) + rMap.getIndex(itr.getDataIndex()) * nVl;
@@ -164,61 +144,110 @@ public class SparseEncoding implements IEncode {
                int il = itl.value();
                int ir = itr.value();
 
-               if(il < fl) {
-                       while(il < fr && il < fl) {
-                               final int nv = 
lMap.getIndex(itl.getDataIndex()) + defR;
-                               newUID = addVal(nv, il, d, newUID, tmpVals, 
retOff);
-                               il = itl.next();
-                       }
+               if(il == fl && ir == fr) {
                        if(fl == fr) {
                                final int nv = 
lMap.getIndex(itl.getDataIndex()) + rMap.getIndex(itr.getDataIndex()) * nVl;
                                return addVal(nv, il, d, newUID, tmpVals, 
retOff);
                        }
-                       else if(il == fr) {
-                               final int nv = 
lMap.getIndex(itl.getDataIndex()) + rMap.getIndex(itr.getDataIndex()) * nVl;
+                       else if(fl < fr) {// fl is first
+                               int nv = lMap.getIndex(itl.getDataIndex()) + 
defR;
                                newUID = addVal(nv, il, d, newUID, tmpVals, 
retOff);
-                               il = itl.next();
-                       }
-                       else {
-                               final int nv = 
rMap.getIndex(itr.getDataIndex()) * nVl + defL;
+                               nv = rMap.getIndex(itr.getDataIndex()) * nVl + 
defL;
                                newUID = addVal(nv, fr, d, newUID, tmpVals, 
retOff);
                        }
-                       while(il < fl) {
-                               final int nv = 
lMap.getIndex(itl.getDataIndex()) + defR;
+                       else {// fl is last
+                               int nv = rMap.getIndex(itr.getDataIndex()) * 
nVl + defL;
                                newUID = addVal(nv, il, d, newUID, tmpVals, 
retOff);
-                               il = itl.next();
+                               nv = lMap.getIndex(itl.getDataIndex()) + defR;
+                               newUID = addVal(nv, fr, d, newUID, tmpVals, 
retOff);
                        }
-                       final int nv = lMap.getIndex(itl.getDataIndex()) + defR;
-                       newUID = addVal(nv, il, d, newUID, tmpVals, retOff);
                }
-               else if(ir < fr) {
-                       while(ir < fl && ir < fr) {
-                               final int nv = 
rMap.getIndex(itr.getDataIndex()) * nVl + defL;
-                               newUID = addVal(nv, ir, d, newUID, tmpVals, 
retOff);
-                               ir = itr.next();
+               else if(il < fl) {
+                       if(fl < fr) {
+                               while(il < fl) {
+                                       final int nv = 
lMap.getIndex(itl.getDataIndex()) + defR;
+                                       newUID = addVal(nv, il, d, newUID, 
tmpVals, retOff);
+                                       il = itl.next();
+                               }
+                               int nv = lMap.getIndex(itl.getDataIndex()) + 
defR;
+                               newUID = addVal(nv, il, d, newUID, tmpVals, 
retOff);
+                               nv = rMap.getIndex(itr.getDataIndex()) * nVl + 
defL;
+                               newUID = addVal(nv, fr, d, newUID, tmpVals, 
retOff);
+                               return newUID;
                        }
+                       else {
+                               while(il < fr) {
+                                       final int nv = 
lMap.getIndex(itl.getDataIndex()) + defR;
+                                       newUID = addVal(nv, il, d, newUID, 
tmpVals, retOff);
+                                       il = itl.next();
+                               }
+                               if(fl == fr) {
+                                       final int nv = 
lMap.getIndex(itl.getDataIndex()) + rMap.getIndex(itr.getDataIndex()) * nVl;
+                                       return addVal(nv, il, d, newUID, 
tmpVals, retOff);
+                               }
+                               else if(il == fr) {
+                                       final int nv = 
lMap.getIndex(itl.getDataIndex()) + rMap.getIndex(itr.getDataIndex()) * nVl;
+                                       newUID = addVal(nv, il, d, newUID, 
tmpVals, retOff);
+                                       il = itl.next();
+                               }
+                               else {
+                                       final int nv = 
rMap.getIndex(itr.getDataIndex()) * nVl + defL;
+                                       newUID = addVal(nv, fr, d, newUID, 
tmpVals, retOff);
+                               }
+
+                               while(il < fl) {
+                                       final int nv = 
lMap.getIndex(itl.getDataIndex()) + defR;
+                                       newUID = addVal(nv, il, d, newUID, 
tmpVals, retOff);
+                                       il = itl.next();
+                               }
+                               final int nv = 
lMap.getIndex(itl.getDataIndex()) + defR;
+                               newUID = addVal(nv, il, d, newUID, tmpVals, 
retOff);
 
-                       if(fr == fl) {
-                               final int nv = 
lMap.getIndex(itl.getDataIndex()) + rMap.getIndex(itr.getDataIndex()) * nVl;
-                               return addVal(nv, ir, d, newUID, tmpVals, 
retOff);
                        }
-                       else if(ir == fl) {
-                               final int nv = 
lMap.getIndex(itl.getDataIndex()) + rMap.getIndex(itr.getDataIndex()) * nVl;
+               }
+               else { // if(ir < fr)
+                       if(fr < fl) {
+                               while(ir < fr) {
+                                       final int nv = 
rMap.getIndex(itr.getDataIndex()) * nVl + defL;
+                                       newUID = addVal(nv, ir, d, newUID, 
tmpVals, retOff);
+                                       ir = itr.next();
+                               }
+                               int nv = rMap.getIndex(itr.getDataIndex()) * 
nVl + defL;
                                newUID = addVal(nv, ir, d, newUID, tmpVals, 
retOff);
-                               ir = itr.next();
-                       }
-                       else {
-                               final int nv = 
lMap.getIndex(itl.getDataIndex()) + defR;
+                               nv = lMap.getIndex(itl.getDataIndex()) + defR;
                                newUID = addVal(nv, fl, d, newUID, tmpVals, 
retOff);
+                               return newUID;
                        }
-
-                       while(ir < fr) {
+                       else {
+                               while(ir < fl) {
+                                       final int nv = 
rMap.getIndex(itr.getDataIndex()) * nVl + defL;
+                                       newUID = addVal(nv, ir, d, newUID, 
tmpVals, retOff);
+                                       ir = itr.next();
+                               }
+
+                               if(fr == fl) {
+                                       final int nv = 
lMap.getIndex(itl.getDataIndex()) + rMap.getIndex(itr.getDataIndex()) * nVl;
+                                       return addVal(nv, ir, d, newUID, 
tmpVals, retOff);
+                               }
+                               else if(ir == fl) {
+                                       final int nv = 
lMap.getIndex(itl.getDataIndex()) + rMap.getIndex(itr.getDataIndex()) * nVl;
+                                       newUID = addVal(nv, ir, d, newUID, 
tmpVals, retOff);
+                                       ir = itr.next();
+                               }
+                               else {
+                                       final int nv = 
lMap.getIndex(itl.getDataIndex()) + defR;
+                                       newUID = addVal(nv, fl, d, newUID, 
tmpVals, retOff);
+                               }
+
+                               while(ir < fr) {
+                                       final int nv = 
rMap.getIndex(itr.getDataIndex()) * nVl + defL;
+                                       newUID = addVal(nv, ir, d, newUID, 
tmpVals, retOff);
+                                       ir = itr.next();
+                               }
                                final int nv = 
rMap.getIndex(itr.getDataIndex()) * nVl + defL;
                                newUID = addVal(nv, ir, d, newUID, tmpVals, 
retOff);
-                               ir = itr.next();
+
                        }
-                       final int nv = rMap.getIndex(itr.getDataIndex()) * nVl 
+ defL;
-                       newUID = addVal(nv, ir, d, newUID, tmpVals, retOff);
                }
 
                return newUID;
@@ -297,6 +326,10 @@ public class SparseEncoding implements IEncode {
                return false;
        }
 
+       public AOffset getOffsets() {
+               return off;
+       }
+
        @Override
        public String toString() {
                StringBuilder sb = new StringBuilder();
diff --git 
a/src/test/java/org/apache/sysds/test/component/compress/estim/encoding/EncodeNegativeTest.java
 
b/src/test/java/org/apache/sysds/test/component/compress/estim/encoding/EncodeNegativeTest.java
index 0edf02bf6e..a3e8a6c2d7 100644
--- 
a/src/test/java/org/apache/sysds/test/component/compress/estim/encoding/EncodeNegativeTest.java
+++ 
b/src/test/java/org/apache/sysds/test/component/compress/estim/encoding/EncodeNegativeTest.java
@@ -44,12 +44,12 @@ public class EncodeNegativeTest {
                EncodingFactory.createFromMatrixBlock(mock, true, 3);
        }
 
-       @Test(expected = NullPointerException.class)
+       @Test(expected = NotImplementedException.class)
        public void testInvalidToCallWithNullDeltaTransposed() {
                EncodingFactory.createFromMatrixBlockDelta(null, true, null);
        }
 
-       @Test(expected = NullPointerException.class)
+       @Test(expected = NotImplementedException.class)
        public void testInvalidToCallWithNullDelta() {
                EncodingFactory.createFromMatrixBlockDelta(null, false, null);
        }
diff --git 
a/src/test/java/org/apache/sysds/test/component/compress/estim/encoding/EncodeSampleCustom.java
 
b/src/test/java/org/apache/sysds/test/component/compress/estim/encoding/EncodeSampleCustom.java
index 1e02d5ff57..b531e9dbf4 100644
--- 
a/src/test/java/org/apache/sysds/test/component/compress/estim/encoding/EncodeSampleCustom.java
+++ 
b/src/test/java/org/apache/sysds/test/component/compress/estim/encoding/EncodeSampleCustom.java
@@ -19,6 +19,7 @@
 
 package org.apache.sysds.test.component.compress.estim.encoding;
 
+import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
 
 import java.io.File;
@@ -32,7 +33,13 @@ import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.sysds.runtime.compress.colgroup.mapping.AMapToData;
 import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory;
+import org.apache.sysds.runtime.compress.colgroup.offset.AOffset;
+import org.apache.sysds.runtime.compress.colgroup.offset.OffsetFactory;
 import org.apache.sysds.runtime.compress.estim.encoding.DenseEncoding;
+import org.apache.sysds.runtime.compress.estim.encoding.EncodingFactory;
+import org.apache.sysds.runtime.compress.estim.encoding.IEncode;
+import org.apache.sysds.runtime.compress.estim.encoding.SparseEncoding;
+import org.apache.sysds.test.component.compress.offset.OffsetTests;
 import org.junit.Test;
 
 import scala.NotImplementedError;
@@ -62,6 +69,153 @@ public class EncodeSampleCustom {
                }
        }
 
+       @Test
+       public void testSparse() {
+               // Custom combine from US Census Encoded dataset.
+               AMapToData Z0 = MapToFactory.create(77, 0);
+               AOffset O0 = OffsetFactory.createOffset(new int[] {4036, 4382, 
4390, 4764, 4831, 4929, 5013, 6964, 7018, 7642,
+                       8306, 8559, 8650, 9041, 9633, 9770, 11000, 11702, 
11851, 11890, 11912, 13048, 15859, 16164, 16191, 16212,
+                       17927, 18344, 19007, 19614, 19806, 20878, 21884, 21924, 
22245, 22454, 23185, 23825, 24128, 24829, 25835, 26130,
+                       26456, 26767, 27058, 28094, 28250, 28335, 28793, 30175, 
30868, 32526, 32638, 33464, 33536, 33993, 34096, 34146,
+                       34686, 35863, 36655, 37212, 37535, 37832, 38328, 38689, 
39802, 39810, 39835, 40065, 40554, 41221, 41420, 42133,
+                       42914, 43027, 43092});
+               AMapToData Z1 = MapToFactory.create(65, 0);
+               AOffset O1 = OffsetFactory.createOffset(new int[] {294, 855, 
1630, 1789, 1872, 1937, 2393, 2444, 3506, 4186, 5210,
+                       6048, 6073, 8645, 9147, 9804, 9895, 13759, 14041, 
14198, 16138, 16548, 16566, 17249, 18257, 18484, 18777,
+                       18881, 19138, 19513, 20127, 21443, 23264, 23432, 24050, 
24332, 24574, 24579, 25246, 25513, 25686, 27075, 31190,
+                       31305, 31429, 31520, 31729, 32073, 32670, 33529, 34453, 
34947, 36224, 37219, 38412, 39505, 39799, 40074, 40569,
+                       40610, 40745, 41755, 41761, 41875, 44394});
+               SparseEncoding a = EncodingFactory.createSparse(Z0, O0, 50000);
+               SparseEncoding b = EncodingFactory.createSparse(Z1, O1, 50000);
+
+               a.combine(b);
+       }
+
+       @Test
+       public void testSparse_2() {
+               // Custom combine from US Census Encoded dataset.
+               AMapToData Z0 = MapToFactory.create(8, 0);
+               AOffset O0 = OffsetFactory.createOffset(new int[] {40065, 
40554, 41221, 41420, 42133, 42914, 43027, 43092});
+               AMapToData Z1 = MapToFactory.create(7, 0);
+               AOffset O1 = OffsetFactory.createOffset(new int[] {40569, 
40610, 40745, 41755, 41761, 41875, 44394});
+               SparseEncoding a = EncodingFactory.createSparse(Z0, O0, 50000);
+               SparseEncoding b = EncodingFactory.createSparse(Z1, O1, 50000);
+
+               a.combine(b);
+       }
+
+       @Test
+       public void testSparse_3() {
+               AOffset a = OffsetFactory.createOffset(new int[] {1, 2, 3, 9});
+               AOffset b = OffsetFactory.createOffset(new int[] {1, 2, 3, 5, 
6, 7});
+               int[] exp = new int[] {1, 2, 3, 5, 6, 7, 9};
+               compareSparse(a, b, exp);
+       }
+
+       @Test
+       public void testSparse_4() {
+               AOffset a = OffsetFactory.createOffset(new int[] {1, 2, 3, 9});
+               AOffset b = OffsetFactory.createOffset(new int[] {1, 2, 3, 5, 
6, 10});
+               int[] exp = new int[] {1, 2, 3, 5, 6, 9, 10};
+               compareSparse(a, b, exp);
+       }
+
+       @Test
+       public void testSparse_5() {
+               AOffset a = OffsetFactory.createOffset(new int[] {1, 2, 3, 9});
+               AOffset b = OffsetFactory.createOffset(new int[] {1, 2, 3, 5, 
6, 10, 11, 12});
+               int[] exp = new int[] {1, 2, 3, 5, 6, 9, 10, 11, 12};
+               compareSparse(a, b, exp);
+       }
+
+       @Test
+       public void testSparse_6() {
+               AOffset a = OffsetFactory.createOffset(new int[] {1, 2, 3, 9, 
12});
+               AOffset b = OffsetFactory.createOffset(new int[] {1, 2, 3, 5, 
6, 10, 11, 12});
+               int[] exp = new int[] {1, 2, 3, 5, 6, 9, 10, 11, 12};
+               compareSparse(a, b, exp);
+       }
+
+       @Test
+       public void testSparse_7() {
+               AOffset a = OffsetFactory.createOffset(new int[] {1, 2, 3, 9, 
11, 12});
+               AOffset b = OffsetFactory.createOffset(new int[] {1, 2, 3, 5, 
6, 10, 11, 12});
+               int[] exp = new int[] {1, 2, 3, 5, 6, 9, 10, 11, 12};
+               compareSparse(a, b, exp);
+       }
+
+       @Test
+       public void testSparse_8() {
+               AOffset a = OffsetFactory.createOffset(new int[] {1, 2, 3, 9, 
11, 12, 13, 14, 15, 16});
+               AOffset b = OffsetFactory.createOffset(new int[] {1, 2, 3, 5, 
6, 10, 11, 12});
+               int[] exp = new int[] {1, 2, 3, 5, 6, 9, 10, 11, 12, 13, 14, 
15, 16};
+               compareSparse(a, b, exp);
+       }
+
+       @Test
+       public void testSparse_9() {
+               AOffset a = OffsetFactory.createOffset(new int[] {1, 2, 3, 9, 
11, 12, 13, 14, 15, 16});
+               AOffset b = OffsetFactory.createOffset(new int[] {1, 2, 3, 12, 
17});
+               int[] exp = new int[] {1, 2, 3, 9, 11, 12, 13, 14, 15, 16, 17};
+               compareSparse(a, b, exp);
+       }
+
+       @Test
+       public void testSparse_10() {
+               AOffset a = OffsetFactory.createOffset(new int[] {16});
+               AOffset b = OffsetFactory.createOffset(new int[] {1, 2, 3, 12, 
17});
+               int[] exp = new int[] {1, 2, 3, 12, 16, 17};
+               compareSparse(a, b, exp);
+       }
+
+       @Test
+       public void testSparse_11() {
+               AOffset a = OffsetFactory.createOffset(new int[] {1, 2, 3, 16, 
18});
+               AOffset b = OffsetFactory.createOffset(new int[] {17});
+               int[] exp = new int[] {1, 2, 3, 16, 17, 18};
+               compareSparse(a, b, exp);
+       }
+
+       public void compareSparse(AOffset a, AOffset b, int[] exp) {
+               try {
+                       AMapToData Z0 = MapToFactory.create(a.getSize(), 0);
+                       AMapToData Z1 = MapToFactory.create(b.getSize(), 0);
+                       SparseEncoding aa = EncodingFactory.createSparse(Z0, a, 
50000);
+                       SparseEncoding bb = EncodingFactory.createSparse(Z1, b, 
50000);
+                       SparseEncoding c = (SparseEncoding) aa.combine(bb);
+                       OffsetTests.compare(c.getOffsets(), exp);
+               }
+               catch(Exception e) {
+                       e.printStackTrace();
+                       fail("Failed combining sparse correctly.\n" + a + "\n" 
+ b + "\nExpected:" + Arrays.toString(exp));
+               }
+       }
+
+       @Test
+       public void combineSimilarOffsetButNotMap() {
+
+               AOffset a = OffsetFactory.createOffset(new int[] {1, 2, 3, 16, 
18});
+               AMapToData Z0 = MapToFactory.create(a.getSize(), 0);
+               AMapToData Z1 = MapToFactory.create(a.getSize(), 0);
+
+               SparseEncoding aa = EncodingFactory.createSparse(Z0, a, 50000);
+               SparseEncoding bb = EncodingFactory.createSparse(Z1, a, 50000);
+               IEncode c = aa.combine(bb);
+               assertTrue(c != aa);
+       }
+
+       @Test
+       public void combineSimilarMapButNotOffsets() {
+               AOffset a = OffsetFactory.createOffset(new int[] {1, 2, 3, 16, 
18});
+               AOffset b = OffsetFactory.createOffset(new int[] {1, 2, 3, 17, 
18});
+               AMapToData Z0 = MapToFactory.create(a.getSize(), 0);
+
+               SparseEncoding aa = EncodingFactory.createSparse(Z0, a, 50000);
+               SparseEncoding bb = EncodingFactory.createSparse(Z0, b, 50000);
+               IEncode c = aa.combine(bb);
+               assertTrue(c != aa);
+       }
+
        private static int[] readData(String path) {
                try {
 
diff --git a/src/test/java/org/apache/sysds/test/component/compress/estim/encoding/EncodeSampleUnbalancedTest.java b/src/test/java/org/apache/sysds/test/component/compress/estim/encoding/EncodeSampleUnbalancedTest.java
index 6f45882c7b..c499ac7f7f 100644
--- a/src/test/java/org/apache/sysds/test/component/compress/estim/encoding/EncodeSampleUnbalancedTest.java
+++ b/src/test/java/org/apache/sysds/test/component/compress/estim/encoding/EncodeSampleUnbalancedTest.java
@@ -69,6 +69,7 @@ public class EncodeSampleUnbalancedTest extends EncodeSampleMultiColTest {
                for(int i = 0; i < 10; i++) {
 
                        tests.add(createTSparse(1, .01, 2, 1, .01, 2, 100, i * 231, true, true));
+                       tests.add(createTSparse(1, .1, 3, 1, .2, 3, 100, i * 231, true, true));
                }
 
                // big sparse
diff --git a/src/test/java/org/apache/sysds/test/component/compress/estim/encoding/EncodeSampleUniformTest.java b/src/test/java/org/apache/sysds/test/component/compress/estim/encoding/EncodeSampleUniformTest.java
index 534d0d8e23..37c5ed18c5 100644
--- a/src/test/java/org/apache/sysds/test/component/compress/estim/encoding/EncodeSampleUniformTest.java
+++ b/src/test/java/org/apache/sysds/test/component/compress/estim/encoding/EncodeSampleUniformTest.java
@@ -54,12 +54,14 @@ public class EncodeSampleUniformTest extends EncodeSampleMultiColTest {
                tests.add(create(30, 10, 1.0, false, 2, 7654));
 
                // row sparse
-               tests.add(create(2, 300, 0.1, true, 2, 1251));
-               tests.add(create(2, 300, 0.1, true, 2, 11));
-               tests.add(create(2, 300, 0.2, true, 2, 65));
-               tests.add(create(2, 300, 0.24, true, 2, 245));
-               tests.add(create(2, 300, 0.24, true, 4, 16));
-               tests.add(create(2, 300, 0.23, true, 4, 15));
+               for(int i = 0; i < 5; i++) {
+                       tests.add(create(2, 300, 0.1, true, 2 , 1251 * i));
+                       tests.add(create(2, 300, 0.1, true, 2 , 11 * i));
+                       tests.add(create(2, 300, 0.2, true, 2 , 65 * i));
+                       tests.add(create(2, 300, 0.24, true, 2 , 245 * i));
+                       tests.add(create(2, 300, 0.24, true, 3 , 16 * i));
+                       tests.add(create(2, 300, 0.23, true, 3 , 15 * i));
+               }
 
                // ultra sparse
                tests.add(create(2, 10000, 0.001, true, 3, 215));
diff --git a/src/test/java/org/apache/sysds/test/component/compress/offset/OffsetTests.java b/src/test/java/org/apache/sysds/test/component/compress/offset/OffsetTests.java
index 157d9a642d..7bedc7768a 100644
--- a/src/test/java/org/apache/sysds/test/component/compress/offset/OffsetTests.java
+++ b/src/test/java/org/apache/sysds/test/component/compress/offset/OffsetTests.java
@@ -447,7 +447,7 @@ public class OffsetTests {
        }
 
        @Test
-       public void testIteratorToString(){
+       public void testIteratorToString() {
                AOffsetIterator a = o.getOffsetIterator();
                a.toString();
 
@@ -455,8 +455,12 @@ public class OffsetTests {
                b.toString();
        }
 
-       protected static void compare(AOffset o, int[] v) {
+       public static void compare(AOffset o, int[] v) {
                AIterator i = o.getIterator();
+
+               if(o.getSize() != v.length) {
+                       fail("Incorrect result sizes : " + o + " " + Arrays.toString(v));
+               }
                if(v[0] != i.value())
                        fail("incorrect result using : " + o.getClass().getSimpleName() + " expected: " + Arrays.toString(v)
                                + " but was :" + o.toString());
@@ -471,7 +475,10 @@ public class OffsetTests {
                                + o.getOffsetsLength() + "\n" + Arrays.toString(v));
        }
 
-       protected static void compareOffsetIterator(AOffset o, int[] v) {
+       public static void compareOffsetIterator(AOffset o, int[] v) {
+               if(o.getSize() != v.length) {
+                       fail("Incorrect result sizes : " + o + " " + Arrays.toString(v));
+               }
                AOffsetIterator i = o.getOffsetIterator();
                if(v[0] != i.value())
                        fail("incorrect result using : " + o.getClass().getSimpleName() + " expected: " + Arrays.toString(v)

Reply via email to