This is an automated email from the ASF dual-hosted git repository. baunsgaard pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push: new 5173aa072a [MINOR] Add a custom LongInt hashmap 5173aa072a is described below commit 5173aa072a7fd2ebae6ef3ba1260801140da264c Author: Sebastian Baunsgaard <baunsga...@apache.org> AuthorDate: Tue Apr 16 13:49:08 2024 +0200 [MINOR] Add a custom LongInt hashmap This commit adds a new longint hash map for efficient combining of column groups. The commit does not enable the HashMap, but separate it into a smaller self standing and tested commit. Closes #2020 --- .../runtime/compress/utils/HashMapLongInt.java | 221 +++++++++++++++++++++ .../compress/util/HashMapLongIntTest.java | 84 ++++++++ 2 files changed, 305 insertions(+) diff --git a/src/main/java/org/apache/sysds/runtime/compress/utils/HashMapLongInt.java b/src/main/java/org/apache/sysds/runtime/compress/utils/HashMapLongInt.java new file mode 100644 index 0000000000..8379a06698 --- /dev/null +++ b/src/main/java/org/apache/sysds/runtime/compress/utils/HashMapLongInt.java @@ -0,0 +1,221 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.runtime.compress.utils; + +import java.util.Arrays; +import java.util.Iterator; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; +import org.apache.sysds.runtime.compress.utils.HashMapLongInt.KV; + +public class HashMapLongInt implements Iterable<KV> { + protected static final Log LOG = LogFactory.getLog(HashMapLongInt.class.getName()); + + protected long[][] keys; + protected int[][] values; + protected int size = 0; + + public HashMapLongInt(int arrSize) { + keys = createKeys(arrSize); + values = createValues(arrSize); + } + + public int size() { + return size; + } + + /** + * return -1 if there was no such key. + * + * @param key the key to add + * @param value The value for that key. + * @return -1 if there was no such key, otherwise the value + */ + public int putIfAbsent(long key, int value) { + final int ix = hash(key); + if(keys[ix] == null) + return createBucket(ix, key, value); + else + return addToBucket(ix, key, value); + } + + public int get(long key) { + final int ix = hash(key); + final long[] bucketKeys = keys[ix]; + if(bucketKeys != null) { + for(int i = 0; i < bucketKeys.length; i++) { + if(bucketKeys[i] == key) + return values[ix][i]; + } + } + return -1; + } + + private int addToBucket(int ix, long key, int value) { + final long[] bucketKeys = keys[ix]; + for(int i = 0; i < bucketKeys.length; i++) { + if(bucketKeys[i] == key) + return values[ix][i]; + else if(bucketKeys[i] == -1) { + bucketKeys[i] = key; + values[ix][i] = value; + size++; + return -1; + } + } + return reallocateBucket(ix, key, value); + } + + private int reallocateBucket(int ix, long key, int value) { + final long[] bucketKeys = keys[ix]; + final int len = bucketKeys.length; + + // there was no match in the bucket + // reallocate bucket. + long[] newBucketKeys = new long[len * 2]; + int[] newBucketValues = new int[len * 2]; + System.arraycopy(bucketKeys, 0, newBucketKeys, 0, len); + System.arraycopy(values[ix], 0, newBucketValues, 0, len); + Arrays.fill(newBucketKeys, len + 1, newBucketKeys.length, -1L); + newBucketKeys[len] = key; + newBucketValues[len] = value; + + keys[ix] = newBucketKeys; + values[ix] = newBucketValues; + + size++; + return -1; + } + + private int createBucket(int ix, long key, int value) { + keys[ix] = new long[4]; + values[ix] = new int[4]; + keys[ix][0] = key; + values[ix][0] = value; + keys[ix][1] = -1; + keys[ix][2] = -1; + keys[ix][3] = -1; + size++; + return -1; + } + + protected long[][] createKeys(int size) { + return new long[size][]; + } + + protected int[][] createValues(int size) { + return new int[size][]; + } + + protected int hash(long key) { + return Long.hashCode(key) % keys.length; + } + + @Override + public Iterator<KV> iterator() { + return new Itt(); + } + + private class Itt implements Iterator<KV> { + + private final int lastBucket; + private final int lastCell; + private int bucketId = 0; + private int bucketCell = 0; + private KV tmp = new KV(-1, -1); + + protected Itt() { + if(size == 0) { + lastBucket = 0; + lastCell = 0; + } + else { + int tmpLastBucket = keys.length - 1; + long[] bucket = keys[tmpLastBucket]; + while((bucket = keys[tmpLastBucket]) == null) { + tmpLastBucket--; + } + int tmpLastCell = bucket.length - 1; + while(bucket[tmpLastCell] == -1) { + tmpLastCell--; + } + lastBucket = tmpLastBucket; + lastCell = tmpLastCell; + } + } + + @Override + public boolean hasNext() { + return bucketId < lastBucket || (bucketId == lastBucket && bucketCell <= lastCell); + } + + @Override + public KV next() { + long[] bucket = keys[bucketId]; + if(bucket != null && (bucketCell >= bucket.length || bucket[bucketCell] == -1)) { + bucketId++; + bucketCell = 0; + } + while((bucket = keys[bucketId]) == null) { + bucket = keys[bucketId++]; + } + + tmp.set(bucket[bucketCell], values[bucketId][bucketCell]); + bucketCell++; + return tmp; + } + + } + + public class KV { + public long k; + public int v; + + private KV(long k, int v) { + this.k = k; + this.v = v; + } + + protected KV set(long k, int v) { + this.k = k; + this.v = v; + return this; + } + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(this.getClass().getSimpleName()); + sb.append(" "); + for(int i = 0; i < keys.length; i++) { + if(keys[i] != null) { + sb.append(String.format("\nB:%d: ", i)); + for(int j = 0; j < keys[i].length; j++) { + if(keys[i][j] != -1) + sb.append(String.format("%d->%d, ", keys[i][j], values[i][j])); + } + } + } + return sb.delete(sb.length() - 2, sb.length()).toString(); + } + +} diff --git a/src/test/java/org/apache/sysds/test/component/compress/util/HashMapLongIntTest.java b/src/test/java/org/apache/sysds/test/component/compress/util/HashMapLongIntTest.java new file mode 100644 index 0000000000..404380816a --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/compress/util/HashMapLongIntTest.java @@ -0,0 +1,84 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.component.compress.util; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +import java.util.HashSet; +import java.util.Set; + +import org.apache.sysds.runtime.compress.utils.HashMapLongInt; +import org.apache.sysds.runtime.compress.utils.HashMapLongInt.KV; +import org.junit.Test; + +public class HashMapLongIntTest { + + @Test + public void add1() { + addSize(new HashMapLongInt(1)); + } + + @Test + public void add2() { + addSize(new HashMapLongInt(2)); + } + + @Test + public void add4() { + addSize(new HashMapLongInt(4)); + } + + @Test + public void add10() { + addSize(new HashMapLongInt(10)); + } + + @Test + public void add100() { + addSize(new HashMapLongInt(100)); + } + + public void addSize(HashMapLongInt a) { + int r = a.putIfAbsent(1, 1); + assertEquals(-1, r); + int r2 = a.putIfAbsent(1, 1); + assertEquals(1, r2); + for(int i = 2; i < 10; i++) { + + a.putIfAbsent(i, i); + } + assertEquals(9, a.size()); + assertEquals(9, a.putIfAbsent(9, 9)); + Set<Long> s = new HashSet<>(); + Set<Integer> v = new HashSet<>(); + for(KV k : a) { + s.add(k.k); + v.add(k.v); + } + for(int i = 1; i < 10; i++) { + assertTrue(s.contains(Long.valueOf(i))); + assertTrue(v.contains(Integer.valueOf(i))); + } + assertEquals(9, s.size()); + assertEquals(4, a.get(4)); + assertEquals(-1, a.get(13)); + } +}