This is an automated email from the ASF dual-hosted git repository. baunsgaard pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/systemds.git
commit ead3c5e9fbbfa6884d0a226abf029b2ab7157424 Author: baunsgaard <[email protected]> AuthorDate: Wed Oct 19 18:23:08 2022 +0200 [SYSTEMDS-3445] N Way combining Mapping This commit adds support for n way combining / appending AMapToData objects. This enables for instance DDC to do n-way-combining as well. --- .../runtime/compress/colgroup/AMapToDataGroup.java | 26 ++++++ .../runtime/compress/colgroup/ColGroupDDC.java | 31 ++++++- .../runtime/compress/colgroup/ColGroupDDCFOR.java | 2 +- .../runtime/compress/colgroup/ColGroupOLE.java | 2 +- .../runtime/compress/colgroup/ColGroupRLE.java | 2 +- .../runtime/compress/colgroup/ColGroupSDC.java | 2 +- .../runtime/compress/colgroup/ColGroupSDCFOR.java | 2 +- .../compress/colgroup/ColGroupSDCSingle.java | 2 +- .../compress/colgroup/ColGroupSDCSingleZeros.java | 2 +- .../compress/colgroup/ColGroupSDCZeros.java | 2 +- .../compress/colgroup/mapping/AMapToData.java | 3 + .../compress/colgroup/mapping/MapToBit.java | 60 +++++++++++++- .../compress/colgroup/mapping/MapToByte.java | 28 ++++++- .../compress/colgroup/mapping/MapToChar.java | 19 +++++ .../compress/colgroup/mapping/MapToCharPByte.java | 6 ++ .../compress/colgroup/mapping/MapToInt.java | 19 +++++ .../compress/colgroup/mapping/MapToZero.java | 13 ++- .../component/compress/combine/CombineTest.java | 16 ++-- .../component/compress/mapping/MappingTests.java | 96 +++++++++++++++++++++- 19 files changed, 302 insertions(+), 31 deletions(-) diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/AMapToDataGroup.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/AMapToDataGroup.java new file mode 100644 index 0000000000..50b65898a1 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/AMapToDataGroup.java @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.compress.colgroup; + +import org.apache.sysds.runtime.compress.colgroup.mapping.AMapToData; + +public interface AMapToDataGroup { + public AMapToData getMapToData(); +} diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java index 9ef46554d0..ba642e7df9 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDC.java @@ -46,7 +46,7 @@ import org.apache.sysds.runtime.matrix.operators.UnaryOperator; /** * Class to encapsulate information about a column group that is encoded with dense dictionary encoding (DDC). */ -public class ColGroupDDC extends APreAgg { +public class ColGroupDDC extends APreAgg implements AMapToDataGroup { private static final long serialVersionUID = -5769772089913918987L; protected final AMapToData _data; @@ -122,6 +122,11 @@ public class ColGroupDDC extends APreAgg { } + @Override + public AMapToData getMapToData(){ + return _data; + } + private void decompressToDenseBlockDenseDictSingleColOutContiguous(DenseBlock db, int rl, int ru, int offR, int offC, double[] values) { final double[] c = db.values(0); @@ -481,7 +486,7 @@ public class ColGroupDDC extends APreAgg { @Override protected AColGroup copyAndSet(int[] colIndexes, ADictionary newDictionary) { - return create(colIndexes, newDictionary, _data, getCounts()); + return create(colIndexes, newDictionary, _data, getCachedCounts()); } @Override @@ -506,10 +511,28 @@ public class ColGroupDDC extends APreAgg { return null; } - @Override public AColGroup appendNInternal(AColGroup[] g) { - return null; + for(int i = 1; i < g.length; i++) { + if(!Arrays.equals(_colIndexes, g[i]._colIndexes)) { + LOG.warn("Not same columns therefore not appending DDC\n" + Arrays.toString(_colIndexes) + "\n\n" + + Arrays.toString(g[i]._colIndexes)); + return null; + } + + if(!(g[i] instanceof ColGroupDDC)) { + LOG.warn("Not DDC but " + g[i].getClass().getSimpleName() + ", therefore not appending DDC"); + return null; + } + + final ColGroupDDC gDDC = (ColGroupDDC) g[i]; + if(!gDDC._dict.eq(_dict)) { + LOG.warn("Not same Dictionaries therefore not appending DDC\n" + _dict + "\n\n" + gDDC._dict); + return null; + } + } + AMapToData nd = _data.appendN(Arrays.copyOf(g, g.length, AMapToDataGroup[].class)); + return create(_colIndexes, _dict, nd, null); } @Override diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDCFOR.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDCFOR.java index effc025de1..ca0b2b46f8 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDCFOR.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupDDCFOR.java @@ -430,7 +430,7 @@ public class ColGroupDDCFOR extends AMorphingMMColGroup { @Override protected AColGroup copyAndSet(int[] colIndexes, ADictionary newDictionary) { - return create(colIndexes, newDictionary, _data, getCounts(), _reference); + return create(colIndexes, newDictionary, _data, getCachedCounts(), _reference); } @Override diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupOLE.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupOLE.java index d92fda137a..c966f585fe 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupOLE.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupOLE.java @@ -652,7 +652,7 @@ public class ColGroupOLE extends AColGroupOffset { @Override protected AColGroup copyAndSet(int[] colIndexes, ADictionary newDictionary) { - return create(colIndexes, _numRows, _zeros, newDictionary, _data, _ptr, getCounts()); + return create(colIndexes, _numRows, _zeros, newDictionary, _data, _ptr, getCachedCounts()); } @Override diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupRLE.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupRLE.java index 426ad96493..b923ba7c2c 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupRLE.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupRLE.java @@ -964,7 +964,7 @@ public class ColGroupRLE extends AColGroupOffset { @Override protected AColGroup copyAndSet(int[] colIndexes, ADictionary newDictionary) { - return create(colIndexes, _numRows, _zeros, newDictionary, _data, _ptr, getCounts()); + return create(colIndexes, _numRows, _zeros, newDictionary, _data, _ptr, getCachedCounts()); } @Override diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java index d36cd0289b..d055ade891 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDC.java @@ -566,7 +566,7 @@ public class ColGroupSDC extends ASDC { @Override protected AColGroup copyAndSet(int[] colIndexes, ADictionary newDictionary) { - return create(colIndexes, _numRows, newDictionary, _defaultTuple, _indexes, _data, getCounts()); + return create(colIndexes, _numRows, newDictionary, _defaultTuple, _indexes, _data, getCachedCounts()); } @Override diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCFOR.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCFOR.java index 30f0e7568e..5a1b9212df 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCFOR.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCFOR.java @@ -436,7 +436,7 @@ public class ColGroupSDCFOR extends ASDC { @Override protected AColGroup copyAndSet(int[] colIndexes, ADictionary newDictionary) { - return create(colIndexes, _numRows, newDictionary, _indexes, _data, getCounts(), _reference); + return create(colIndexes, _numRows, newDictionary, _indexes, _data, getCachedCounts(), _reference); } @Override diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java index 6a2851e753..55372423cc 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingle.java @@ -573,7 +573,7 @@ public class ColGroupSDCSingle extends ASDC { @Override protected AColGroup copyAndSet(int[] colIndexes, ADictionary newDictionary) { - return create(colIndexes, _numRows, newDictionary, _defaultTuple, _indexes, getCounts()); + return create(colIndexes, _numRows, newDictionary, _defaultTuple, _indexes, getCachedCounts()); } @Override diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingleZeros.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingleZeros.java index e044446b33..1731c7799b 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingleZeros.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCSingleZeros.java @@ -806,7 +806,7 @@ public class ColGroupSDCSingleZeros extends ASDCZero { @Override protected AColGroup copyAndSet(int[] colIndexes, ADictionary newDictionary) { - return create(colIndexes, _numRows, newDictionary, _indexes, getCounts()); + return create(colIndexes, _numRows, newDictionary, _indexes, getCachedCounts()); } @Override diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCZeros.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCZeros.java index 2d4fc037e8..961e63570e 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCZeros.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/ColGroupSDCZeros.java @@ -717,7 +717,7 @@ public class ColGroupSDCZeros extends ASDCZero { @Override protected AColGroup copyAndSet(int[] colIndexes, ADictionary newDictionary) { - return create(colIndexes, _numRows, newDictionary, _indexes, _data, getCounts()); + return create(colIndexes, _numRows, newDictionary, _indexes, _data, getCachedCounts()); } @Override diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/AMapToData.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/AMapToData.java index b9bec0b742..4e781cf5bc 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/AMapToData.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/AMapToData.java @@ -27,6 +27,7 @@ import java.util.BitSet; import org.apache.commons.lang.NotImplementedException; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.apache.sysds.runtime.compress.colgroup.AMapToDataGroup; import org.apache.sysds.runtime.compress.colgroup.dictionary.ADictionary; import org.apache.sysds.runtime.compress.colgroup.dictionary.Dictionary; import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory.MAP_TYPE; @@ -818,6 +819,8 @@ public abstract class AMapToData implements Serializable { public abstract AMapToData append(AMapToData t); + public abstract AMapToData appendN(AMapToDataGroup[] d); + @Override public String toString() { final int sz = size(); diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToBit.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToBit.java index d2f9f49cae..4d06e1efd6 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToBit.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToBit.java @@ -25,6 +25,7 @@ import java.io.IOException; import java.util.BitSet; import org.apache.commons.lang.NotImplementedException; +import org.apache.sysds.runtime.compress.colgroup.AMapToDataGroup; import org.apache.sysds.runtime.compress.colgroup.dictionary.ADictionary; import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory.MAP_TYPE; import org.apache.sysds.utils.MemoryEstimates; @@ -323,12 +324,12 @@ public class MapToBit extends AMapToData { @Override public AMapToData slice(int l, int u) { - return new MapToBit(getUnique(), _data.get(l,u), u - l); + return new MapToBit(getUnique(), _data.get(l, u), u - l); } @Override public AMapToData append(AMapToData t) { - if(t instanceof MapToBit){ + if(t instanceof MapToBit) { MapToBit tb = (MapToBit) t; BitSet tbb = tb._data; final int newSize = _size + t.size(); @@ -338,9 +339,62 @@ public class MapToBit extends AMapToData { tbb.stream().forEach(x -> ret.set(x + _size, true)); return new MapToBit(2, ret, newSize); } - else{ + else { throw new NotImplementedException("Not implemented append on Bit map different type"); } } + + @Override + public AMapToData appendN(AMapToDataGroup[] d) { + int p = 0; // pointer + for(AMapToDataGroup gd : d) + p += gd.getMapToData().size(); + final long[] ret = new long[(p - 1) / 64 + 1]; + long[] or = _data.toLongArray(); + System.arraycopy(or, 0, ret, 0, or.length); + + p = size(); + for(int i = 1; i < d.length; i++) { + final MapToBit mm = (MapToBit) d[i].getMapToData(); + final int ms = mm.size(); + or = mm._data.toLongArray(); + final int remainder = p % 64; + int retLp = p / 64; + if(remainder == 0)// Easy lining up + System.arraycopy(or, 0, ret, retLp, or.length); + else { // Not Lining up + // all but last + for(int j = 0; j < or.length - 1; j++) { + long v = or[j]; + ret[retLp] = ret[retLp] ^ (v << remainder); + retLp++; + ret[retLp] = v >>> (64 - remainder); + } + // last + long v = or[or.length - 1]; + ret[retLp] = ret[retLp] ^ (v << remainder); + retLp++; + if(retLp < ret.length) + ret[retLp] = v >>> (64 - remainder); + } + p += ms; + } + + BitSet retBS = BitSet.valueOf(ret); + return new MapToBit(getUnique(), retBS, p); + + } + + private static String bl(long l) { + int lead = Long.numberOfLeadingZeros(l); + if(lead == 64) + return "0000000000000000000000000000000000000000000000000000000000000000"; + StringBuilder sb = new StringBuilder(64); + for(int i = 0; i < lead; i++) { + sb.append('0'); + } + sb.append(Long.toBinaryString(l)); + return sb.toString(); + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToByte.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToByte.java index e0e13bf1e7..81eb974490 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToByte.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToByte.java @@ -26,6 +26,7 @@ import java.util.Arrays; import java.util.BitSet; import org.apache.commons.lang.NotImplementedException; +import org.apache.sysds.runtime.compress.colgroup.AMapToDataGroup; import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory.MAP_TYPE; import org.apache.sysds.utils.MemoryEstimates; @@ -214,13 +215,13 @@ public class MapToByte extends AMapToData { @Override public AMapToData append(AMapToData t) { if(t instanceof MapToByte) { - MapToByte tb = (MapToByte) t; - byte[] tbb = tb._data; + final MapToByte tb = (MapToByte) t; + final byte[] tbb = tb._data; final int newSize = _data.length + t.size(); final int newDistinct = Math.max(getUnique(), t.getUnique()); // copy - byte[] ret = Arrays.copyOf(_data, newSize); + final byte[] ret = Arrays.copyOf(_data, newSize); System.arraycopy(tbb, 0, ret, _data.length, t.size()); // return @@ -233,4 +234,25 @@ public class MapToByte extends AMapToData { throw new NotImplementedException("Not implemented append on Bit map different type"); } } + + @Override + public AMapToData appendN(AMapToDataGroup[] d) { + int p = 0; // pointer + for(AMapToDataGroup gd : d) + p += gd.getMapToData().size(); + final byte[] ret = Arrays.copyOf(_data, p); + + p = size(); + for(int i = 1; i < d.length; i++) { + final MapToByte mm = (MapToByte) d[i].getMapToData(); + final int ms = mm.size(); + System.arraycopy(mm._data, 0, ret, p, ms); + p += ms; + } + + if(getUnique() < 127) + return new MapToUByte(getUnique(), ret); + else + return new MapToByte(getUnique(), ret); + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToChar.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToChar.java index b309861ee2..f24c8a72e8 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToChar.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToChar.java @@ -26,6 +26,7 @@ import java.util.Arrays; import java.util.BitSet; import org.apache.commons.lang.NotImplementedException; +import org.apache.sysds.runtime.compress.colgroup.AMapToDataGroup; import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory.MAP_TYPE; import org.apache.sysds.utils.MemoryEstimates; @@ -251,4 +252,22 @@ public class MapToChar extends AMapToData { throw new NotImplementedException("Not implemented append on Bit map different type"); } } + + @Override + public AMapToData appendN(AMapToDataGroup[] d) { + int p = 0; // pointer + for(AMapToDataGroup gd : d) + p += gd.getMapToData().size(); + final char[] ret = Arrays.copyOf(_data, p); + + p = size(); + for(int i = 1; i < d.length; i++) { + final MapToChar mm = (MapToChar) d[i].getMapToData(); + final int ms = mm.size(); + System.arraycopy(mm._data, 0, ret, p, ms); + p += ms; + } + + return new MapToChar(getUnique(), ret); + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToCharPByte.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToCharPByte.java index 88b2e435e1..f1483a2bcd 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToCharPByte.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToCharPByte.java @@ -26,6 +26,7 @@ import java.util.Arrays; import java.util.BitSet; import org.apache.commons.lang.NotImplementedException; +import org.apache.sysds.runtime.compress.colgroup.AMapToDataGroup; import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory.MAP_TYPE; import org.apache.sysds.utils.MemoryEstimates; @@ -241,4 +242,9 @@ public class MapToCharPByte extends AMapToData { throw new NotImplementedException("Not implemented append on Bit map different type"); } } + + @Override + public AMapToData appendN(AMapToDataGroup[] d){ + throw new NotImplementedException(); + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToInt.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToInt.java index b4117cd5b8..2324b2bac0 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToInt.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToInt.java @@ -26,6 +26,7 @@ import java.util.Arrays; import java.util.BitSet; import org.apache.commons.lang.NotImplementedException; +import org.apache.sysds.runtime.compress.colgroup.AMapToDataGroup; import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory.MAP_TYPE; import org.apache.sysds.utils.MemoryEstimates; @@ -252,4 +253,22 @@ public class MapToInt extends AMapToData { throw new NotImplementedException("Not implemented append on Bit map different type"); } } + + @Override + public AMapToData appendN(AMapToDataGroup[] d) { + int p = 0; // pointer + for(AMapToDataGroup gd : d) + p += gd.getMapToData().size(); + final int[] ret = Arrays.copyOf(_data, p); + + p = size(); + for(int i = 1; i < d.length; i++) { + final MapToInt mm = (MapToInt) d[i].getMapToData(); + final int ms = mm.size(); + System.arraycopy(mm._data, 0, ret, p, ms); + p += ms; + } + + return new MapToInt(getUnique(), ret); + } } diff --git a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToZero.java b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToZero.java index c3cd14afc4..8e2115e5b8 100644 --- a/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToZero.java +++ b/src/main/java/org/apache/sysds/runtime/compress/colgroup/mapping/MapToZero.java @@ -25,6 +25,7 @@ import java.io.IOException; import java.util.BitSet; import org.apache.commons.lang.NotImplementedException; +import org.apache.sysds.runtime.compress.colgroup.AMapToDataGroup; import org.apache.sysds.runtime.compress.colgroup.dictionary.ADictionary; import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory.MAP_TYPE; @@ -153,9 +154,17 @@ public class MapToZero extends AMapToData { @Override public AMapToData append(AMapToData t) { - if(t instanceof MapToZero) + if(t instanceof MapToZero) return new MapToZero(_size + t.size()); - else + else throw new NotImplementedException("Not implemented append on Bit map different type"); } + + @Override + public AMapToData appendN(AMapToDataGroup[] d) { + int p = 0; // pointer + for(AMapToDataGroup gd : d) + p += gd.getMapToData().size(); + return new MapToZero(p); + } } diff --git a/src/test/java/org/apache/sysds/test/component/compress/combine/CombineTest.java b/src/test/java/org/apache/sysds/test/component/compress/combine/CombineTest.java index 8db991ef9f..4223e5c773 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/combine/CombineTest.java +++ b/src/test/java/org/apache/sysds/test/component/compress/combine/CombineTest.java @@ -87,13 +87,7 @@ public class CombineTest { @Test public void combineDDC() { - MatrixBlock mb = TestUtils.ceil(TestUtils.generateTestMatrixBlock(165, 2, 1, 3, 1.0, 2514)); - CompressedMatrixBlock csb = (CompressedMatrixBlock) CompressedMatrixBlockFactory - .compress(mb, - new CompressionSettingsBuilder().clearValidCompression().addValidCompression(CompressionType.DDC)) - .getLeft(); - - AColGroup g = csb.getColGroups().get(0); + AColGroup g = getDDC(); double sum = g.getSum(165); AColGroup ret = g.append(g); double sum2 = ret.getSum(165 * 2); @@ -101,7 +95,15 @@ public class CombineTest { AColGroup ret2 = ret.append(g); double sum3 = ret2.getSum(165 * 3); assertEquals(sum * 3, sum3, 0.001); + } + private static AColGroup getDDC(){ + MatrixBlock mb = TestUtils.ceil(TestUtils.generateTestMatrixBlock(165, 2, 1, 3, 1.0, 2514)); + CompressedMatrixBlock csb = (CompressedMatrixBlock) CompressedMatrixBlockFactory + .compress(mb, + new CompressionSettingsBuilder().clearValidCompression().addValidCompression(CompressionType.DDC)) + .getLeft(); + return csb.getColGroups().get(0); } } diff --git a/src/test/java/org/apache/sysds/test/component/compress/mapping/MappingTests.java b/src/test/java/org/apache/sysds/test/component/compress/mapping/MappingTests.java index 4d9712efcb..c1e9046830 100644 --- a/src/test/java/org/apache/sysds/test/component/compress/mapping/MappingTests.java +++ b/src/test/java/org/apache/sysds/test/component/compress/mapping/MappingTests.java @@ -32,8 +32,10 @@ import java.util.Arrays; import java.util.Collection; import java.util.Random; +import org.apache.commons.lang.NotImplementedException; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.apache.sysds.runtime.compress.colgroup.AMapToDataGroup; import org.apache.sysds.runtime.compress.colgroup.mapping.AMapToData; import org.apache.sysds.runtime.compress.colgroup.mapping.MapToCharPByte; import org.apache.sysds.runtime.compress.colgroup.mapping.MapToFactory; @@ -68,8 +70,14 @@ public class MappingTests { tests.add(new Object[] {6, t, 63, false}); tests.add(new Object[] {4, t, 63, false}); tests.add(new Object[] {3, t, 64, false}); + tests.add(new Object[] {4, t, 64, false}); tests.add(new Object[] {3, t, 65, false}); tests.add(new Object[] {5, t, 64 + 63, false}); + tests.add(new Object[] {5, t, 127, false}); + tests.add(new Object[] {5, t, 128, false}); + tests.add(new Object[] {5, t, 129, false}); + tests.add(new Object[] {7, t, 255, false}); + tests.add(new Object[] {8, t, 256, false}); tests.add(new Object[] {5, t, 1234, false}); tests.add(new Object[] {5, t, 13, true}); } @@ -80,7 +88,7 @@ public class MappingTests { this.seed = seed; this.type = type; this.size = size; - this.max =Math.min(Math.min(MappingTestUtil.getUpperBoundValue(type), fictiveMax) + 1, size); + this.max = Math.min(Math.min(MappingTestUtil.getUpperBoundValue(type), fictiveMax) + 1, size); expected = new int[size]; m = genMap(MapToFactory.create(size, max), expected, max, fill, seed); } @@ -113,7 +121,7 @@ public class MappingTests { } // to make sure that the bit set is actually filled. - for(int i = 0; i < max; i++){ + for(int i = 0; i < max; i++) { m.set(i, i); expected[i] = i; @@ -126,7 +134,6 @@ public class MappingTests { for(int i = 0; i < size; i++) if(expected[i] != m.getIndex(i)) fail("Expected equals " + Arrays.toString(expected) + "\nbut got: " + m); - } @Test @@ -214,7 +221,7 @@ public class MappingTests { @Test public void replaceMax() { - m.replace(max-1, 0); + m.replace(max - 1, 0); for(int i = 0; i < size; i++) { expected[i] = expected[i] == max - 1 ? 0 : expected[i]; @@ -271,4 +278,85 @@ public class MappingTests { + m.getType() + " " + type + " " + max + " " + m); } + @Test + public void testAppend() { + int nVal = m.getUnique(); + int[] counts = m.getCounts(new int[nVal]); + + AMapToData m2 = m.append(m); + assertEquals(m.size() * 2, m2.size()); + assertEquals(m.getUnique(), m2.getUnique()); + int[] counts2 = m2.getCounts(new int[nVal]); + + for(int i = 0; i < nVal; i++) + assertEquals(counts[i] * 2, counts2[i]); + } + + @Test + public void testAppendN() { + int nVal = m.getUnique(); + int[] counts = m.getCounts(new int[nVal]); + + try { + + AMapToData m2 = m.appendN(new AMapToDataGroup[] {// + new Holder(m), new Holder(m), new Holder(m)}); + try { + assertEquals(m.size() * 3, m2.size()); + assertEquals(m.getUnique(), m2.getUnique()); + int[] counts2 = m2.getCounts(new int[nVal]); + + for(int i = 0; i < nVal; i++) + assertEquals(counts[i] * 3, counts2[i]); + } + catch(AssertionError e) { + fail(e.getMessage() + "\nFailed appendN with in: \n" + m + "\ncomp:\n" + m2); + } + } + catch(Exception e) { + e.printStackTrace(); + fail("Failed " + e.getMessage()); + } + + } + + @Test + public void testAppendVSAppendN() { + final AMapToData m2 = m.append(m).append(m); + final AMapToData m3 = m.appendN(new AMapToDataGroup[] {// + new Holder(m), new Holder(m), new Holder(m)}); + compare(m2, m3); + } + + @Test(expected = NotImplementedException.class) + public void testAppendNotSame() { + AMapToData mm; + switch(type) { + case INT: + mm = MapToFactory.create(size, MAP_TYPE.CHAR_BYTE); + mm.copy(m); + m.append(mm); + break; + default: + mm = MapToFactory.create(size, MAP_TYPE.INT); + mm.copy(m); + m.append(mm); + } + LOG.error("Did not throw exception with: " + m); + } + + private static class Holder implements AMapToDataGroup { + + AMapToData d; + + protected Holder(AMapToData d) { + this.d = d; + } + + @Override + public AMapToData getMapToData() { + return d; + } + + } }
