This is an automated email from the ASF dual-hosted git repository.
baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 3e6af1b814 [SYSTEMDS-3663] Low overhead join indexes
3e6af1b814 is described below
commit 3e6af1b814bf2c71e89d79a6ca4f88fb71608ebe
Author: Sebastian Baunsgaard <[email protected]>
AuthorDate: Sun Jan 7 17:06:30 2024 +0100
[SYSTEMDS-3663] Low overhead join indexes
This commit adds a few more column index variants to allow efficient
combination and ordering of column indexes when co-coding.
This is critical in cases where thousands of columns are combined,
because the execution time is then dominated not by combining the columns
themselves but by combining their column indexes.
Closes #1979
---
.../compress/colgroup/indexes/AColIndex.java | 56 ++++-
.../compress/colgroup/indexes/ColIndexFactory.java | 2 +
.../compress/colgroup/indexes/CombinedIndex.java | 246 +++++++++++++++++++++
.../compress/colgroup/indexes/IColIndex.java | 80 ++++++-
.../compress/colgroup/indexes/RangeIndex.java | 84 ++++---
.../compress/colgroup/indexes/TwoRangesIndex.java | 4 +-
6 files changed, 437 insertions(+), 35 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/AColIndex.java
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/AColIndex.java
index df4685a65d..81a5f5b480 100644
---
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/AColIndex.java
+++
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/AColIndex.java
@@ -21,6 +21,8 @@ package org.apache.sysds.runtime.compress.colgroup.indexes;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.runtime.data.SparseBlock;
+import org.apache.sysds.runtime.data.SparseBlockCSR;
public abstract class AColIndex implements IColIndex {
@@ -69,11 +71,55 @@ public abstract class AColIndex implements IColIndex {
@Override
public boolean containsAny(IColIndex idx) {
- final IIterate it = idx.iterator();
- while(it.hasNext())
- if(contains(it.next()))
- return true;
+ if(idx instanceof TwoRangesIndex){
+ TwoRangesIndex o = (TwoRangesIndex) idx;
+ return this.containsAny(o.idx1) ||
this.containsAny(o.idx2);
+ }
+ else if(idx instanceof CombinedIndex){
+ CombinedIndex ci = (CombinedIndex) idx;
+ return containsAny(ci.l) || containsAny(ci.r);
+ }
+ else{
+ final IIterate it = idx.iterator();
+ while(it.hasNext())
+ if(contains(it.next()))
+ return true;
+
+ return false;
+ }
+ }
- return false;
+ @Override
+ public void decompressToDenseFromSparse(SparseBlock sb, int vr, int
off, double[] c) {
+ if(sb instanceof SparseBlockCSR)
+ decompressToDenseFromSparseCSR((SparseBlockCSR)sb, vr,
off, c);
+ else
+ decompressToDenseFromSparseGeneric(sb, vr, off, c);
+ }
+
+ private void decompressToDenseFromSparseGeneric(SparseBlock sb, int vr,
int off, double[] c) {
+ if(sb.isEmpty(vr))
+ return;
+ final int apos = sb.pos(vr);
+ final int alen = sb.size(vr) + apos;
+ final int[] aix = sb.indexes(vr);
+ final double[] aval = sb.values(vr);
+ for(int j = apos; j < alen; j++)
+ c[off + get(aix[j])] += aval[j];
+ }
+
+ private void decompressToDenseFromSparseCSR(SparseBlockCSR sb, int vr,
int off, double[] c) {
+ final int apos = sb.pos(vr);
+ final int alen = sb.size(vr) + apos;
+ final int[] aix = sb.indexes(vr);
+ final double[] aval = sb.values(vr);
+ for(int j = apos; j < alen; j++)
+ c[off + get(aix[j])] += aval[j];
+ }
+
+ @Override
+ public void decompressVec(int nCol, double[] c, int off, double[]
values, int rowIdx) {
+ for(int j = 0; j < nCol; j++)
+ c[off + get(j)] += values[rowIdx + j];
}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/ColIndexFactory.java
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/ColIndexFactory.java
index fd929b8a1a..c9a45e4aee 100644
---
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/ColIndexFactory.java
+++
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/ColIndexFactory.java
@@ -48,6 +48,8 @@ public interface ColIndexFactory {
return RangeIndex.read(in);
case TWORANGE:
return TwoRangesIndex.read(in);
+ case COMBINED:
+ return CombinedIndex.read(in);
default:
throw new DMLCompressionException("Failed
reading column index of type: " + t);
}
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/CombinedIndex.java
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/CombinedIndex.java
new file mode 100644
index 0000000000..f1a80a6d27
--- /dev/null
+++
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/CombinedIndex.java
@@ -0,0 +1,246 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.compress.colgroup.indexes;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+import org.apache.sysds.runtime.compress.DMLCompressionException;
+
+public class CombinedIndex extends AColIndex {
+ protected final IColIndex l;
+ protected final IColIndex r;
+
+ public CombinedIndex(IColIndex l, IColIndex r) {
+ this.l = l;
+ this.r = r;
+ }
+
+ @Override
+ public int size() {
+ return l.size() + r.size();
+ }
+
+ @Override
+ public int get(int i) {
+ if(i >= l.size())
+ return r.get(i - l.size());
+ else
+ return l.get(i);
+ }
+
+ @Override
+ public IColIndex shift(int i) {
+ return new CombinedIndex(l.shift(i), r.shift(i));
+ }
+
+ @Override
+ public void write(DataOutput out) throws IOException {
+ out.write(ColIndexType.COMBINED.ordinal());
+ l.write(out);
+ r.write(out);
+ }
+
+ @Override
+ public long getExactSizeOnDisk() {
+ return 1 + l.getExactSizeOnDisk() + r.getExactSizeOnDisk();
+ }
+
+ @Override
+ public long estimateInMemorySize() {
+ return 16 + 8 + 8 + l.estimateInMemorySize() +
r.estimateInMemorySize();
+ }
+
+ @Override
+ public IIterate iterator() {
+ return new CombinedIterator();
+ }
+
+ @Override
+ public int findIndex(int i) {
+ final int a = l.findIndex(i);
+ if(a < 0) {
+ final int b = r.findIndex(i);
+ if(b < 0)
+ return b + a + 1;
+ else
+ return b + l.size();
+ }
+ else
+ return a;
+ }
+
+ @Override
+ public SliceResult slice(int l, int u) {
+ return getArrayIndex().slice(l, u);
+ }
+
+ @Override
+ public boolean equals(IColIndex other) {
+ if(other == this)
+ return true;
+ else if(size() == other.size()) {
+ if(other instanceof CombinedIndex) {
+ CombinedIndex o = (CombinedIndex) other;
+ return o.l.equals(l) && o.r.equals(r);
+ }
+ else {
+ IIterate t = iterator();
+ IIterate o = other.iterator();
+
+ while(t.hasNext()) {
+ if(t.next() != o.next())
+ return false;
+ }
+ return true;
+ }
+ }
+ return false;
+ }
+
+ @Override
+ public IColIndex combine(IColIndex other) {
+ final int sr = other.size();
+ final int sl = size();
+ final int maxCombined = Math.max(this.get(this.size() - 1),
other.get(other.size() - 1));
+ final int minCombined = Math.min(this.get(0), other.get(0));
+ if(sr + sl == maxCombined - minCombined + 1) {
+ return new RangeIndex(minCombined, maxCombined + 1);
+ }
+
+ final int[] ret = new int[sr + sl];
+ IIterate t = iterator();
+ IIterate o = other.iterator();
+ int i = 0;
+ while(t.hasNext() && o.hasNext()) {
+ final int tv = t.v();
+ final int ov = o.v();
+ if(tv < ov) {
+ ret[i++] = tv;
+ t.next();
+ }
+ else {
+ ret[i++] = ov;
+ o.next();
+ }
+ }
+ while(t.hasNext())
+ ret[i++] = t.next();
+ while(o.hasNext())
+ ret[i++] = o.next();
+
+ return ColIndexFactory.create(ret);
+
+ }
+
+ @Override
+ public boolean isContiguous() {
+ return false;
+ }
+
+ @Override
+ public int[] getReorderingIndex() {
+ return getArrayIndex().getReorderingIndex();
+ }
+
+ @Override
+ public boolean isSorted() {
+ return true;
+ }
+
+ @Override
+ public IColIndex sort() {
+ throw new DMLCompressionException("CombinedIndex is always
sorted");
+ }
+
+ @Override
+ public boolean contains(int i) {
+ return l.contains(i) || r.contains(i);
+ }
+
+ @Override
+ public double avgOfIndex() {
+ double lv = l.avgOfIndex() * l.size();
+ double rv = r.avgOfIndex() * r.size();
+ return (lv + rv) / size();
+ }
+
+ private IColIndex getArrayIndex() {
+ int s = size();
+ int[] vals = new int[s];
+ IIterate a = iterator();
+ for(int i = 0; i < s; i++) {
+ vals[i] = a.next();
+ }
+ return ColIndexFactory.create(vals);
+ }
+
+ public static CombinedIndex read(DataInput in) throws IOException {
+ return new CombinedIndex(ColIndexFactory.read(in),
ColIndexFactory.read(in));
+ }
+
+ @Override
+ public String toString() {
+ StringBuilder sb = new StringBuilder();
+ sb.append(this.getClass().getSimpleName());
+ sb.append("[");
+ sb.append(l);
+ sb.append(", ");
+ sb.append(r);
+ sb.append("]");
+ return sb.toString();
+ }
+
+ protected class CombinedIterator implements IIterate {
+ boolean doneFirst = false;
+ IIterate I = l.iterator();
+
+ @Override
+ public int next() {
+ int v = I.next();
+ if(!I.hasNext() && !doneFirst) {
+ doneFirst = true;
+ I = r.iterator();
+ }
+ return v;
+
+ }
+
+ @Override
+ public boolean hasNext() {
+ return I.hasNext() || doneFirst == false;
+ }
+
+ @Override
+ public int v() {
+ return I.v();
+ }
+
+ @Override
+ public int i() {
+ if(doneFirst)
+ return I.i() + l.size();
+ else
+ return I.i();
+ }
+ }
+
+}
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/IColIndex.java
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/IColIndex.java
index 60c2cec4b2..8da8ad518f 100644
---
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/IColIndex.java
+++
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/IColIndex.java
@@ -22,13 +22,16 @@ package org.apache.sysds.runtime.compress.colgroup.indexes;
import java.io.DataOutput;
import java.io.IOException;
+import org.apache.sysds.runtime.data.SparseBlock;
+import org.apache.sysds.runtime.matrix.data.Pair;
+
/**
* Class to contain column indexes for the compression column groups.
*/
public interface IColIndex {
public static enum ColIndexType {
- SINGLE, TWO, ARRAY, RANGE, TWORANGE, UNKNOWN;
+ SINGLE, TWO, ARRAY, RANGE, TWORANGE, COMBINED, UNKNOWN;
}
/**
@@ -212,6 +215,76 @@ public interface IColIndex {
*/
public double avgOfIndex();
+ /**
+ * Decompress this
+ */
+ /**
+ * Decompress this column index into the dense c array.
+ *
+ * @param sb A sparse block to extract values out of and insert into c
+ * @param vr The row to extract from the sparse block
+ * @param off The offset that the row starts at in c.
+ * @param c The dense output to decompress into
+ */
+ public void decompressToDenseFromSparse(SparseBlock sb, int vr, int
off, double[] c);
+
+ /**
+ * Decompress into c using the values provided. The offset to start
into c is off and then row index is similarly the
+ * offset of values. nCol specify the number of values to add over.
+ *
+ * @param nCol The number of columns to copy.
+ * @param c The output to add into
+ * @param off The offset to start in c
+ * @param values the values to copy from
+ * @param rowIdx The offset to start in values
+ */
+ public void decompressVec(int nCol, double[] c, int off, double[]
values, int rowIdx);
+
+ /**
+ * Indicate if the two given column indexes are in order such that the
first set of indexes all are of lower value
+ * than the second.
+ *
+ * @param a the first column index
+ * @param b the second column index
+ * @return If the first all is lower than the second.
+ */
+ public static boolean inOrder(IColIndex a, IColIndex b) {
+ return a.get(a.size() - 1) < b.get(0);
+ }
+
+ public static Pair<int[], int[]> reorderingIndexes(IColIndex a,
IColIndex b){
+ final int[] ar = new int[a.size()];
+ final int[] br = new int[b.size()];
+ final IIterate ai = a.iterator();
+ final IIterate bi = b.iterator();
+
+ int ia = 0;
+ int ib = 0;
+ int i = 0;
+ while(ai.hasNext() && bi.hasNext()){
+ if(ai.v()< bi.v()){
+ ar[ia++] = i++;
+ ai.next();
+ }
+ else{
+ br[ib++] = i++;
+ bi.next();
+ }
+ }
+
+ while(ai.hasNext()){
+ ar[ia++] = i++;
+ ai.next();
+ }
+
+ while(bi.hasNext()){
+ br[ib++] = i++;
+ bi.next();
+ }
+
+ return new Pair<int[],int[]>(ar, br);
+ }
+
/** A Class for slice results containing indexes for the slicing of
dictionaries, and the resulting column index */
public static class SliceResult {
/** Start index to slice inside the dictionary */
@@ -223,9 +296,10 @@ public interface IColIndex {
/**
* The slice result
+ *
* @param idStart The starting index
- * @param idEnd The ending index (not inclusive)
- * @param ret The resulting IColIndex
+ * @param idEnd The ending index (not inclusive)
+ * @param ret The resulting IColIndex
*/
protected SliceResult(int idStart, int idEnd, IColIndex ret) {
this.idStart = idStart;
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/RangeIndex.java
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/RangeIndex.java
index bbe5aeb8a5..17c2bed3ba 100644
---
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/RangeIndex.java
+++
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/RangeIndex.java
@@ -24,6 +24,7 @@ import java.io.DataOutput;
import java.io.IOException;
import java.util.Arrays;
+import org.apache.commons.lang.NotImplementedException;
import org.apache.sysds.runtime.compress.DMLCompressionException;
import org.apache.sysds.runtime.compress.utils.IntArrayList;
@@ -133,7 +134,7 @@ public class RangeIndex extends AColIndex {
int minU = Math.min(u, this.u);
int offL = maxL - this.l;
int offR = minU - this.l;
- return new SliceResult(offL, offR, new RangeIndex(maxL
- l, minU - l ));
+ return new SliceResult(offL, offR, new RangeIndex(maxL
- l, minU - l));
}
}
@@ -147,16 +148,40 @@ public class RangeIndex extends AColIndex {
return other.equals(this);
}
+ @Override
+ public boolean containsAny(IColIndex idx) {
+ if(idx instanceof RangeIndex) {
+ RangeIndex o = (RangeIndex) idx;
+ if(o.l >= u)
+ return false;
+ else if(o.u <= l)
+ return false;
+ else if(o.l <= l && o.u > l)
+ return true;
+ else if(o.l < u && o.u > u)
+ return true;
+ else
+ throw new NotImplementedException(idx + " " +
this);
+ }
+ else
+ return super.containsAny(idx);
+ }
+
@Override
public IColIndex combine(IColIndex other) {
+ final int sr = other.size();
if(other.size() == 1) {
int v = other.get(0);
if(v + 1 == l)
return new RangeIndex(l - 1, u);
else if(v == u)
return new RangeIndex(l, u + 1);
+ else if(v < l)
+ return new CombinedIndex(other, this);
+ else
+ return new CombinedIndex(this, other);
}
- if(other instanceof RangeIndex) {
+ else if(other instanceof RangeIndex) {
if(other.get(0) == u)
return new RangeIndex(l, other.get(other.size()
- 1) + 1);
else if(other.get(other.size() - 1) == l - 1)
@@ -166,31 +191,40 @@ public class RangeIndex extends AColIndex {
else
return new TwoRangesIndex(this, (RangeIndex)
other);
}
-
- final int sr = other.size();
- final int sl = size();
- final int[] ret = new int[sr + sl];
-
- int pl = 0;
- int pr = 0;
- int i = 0;
- while(pl < sl && pr < sr) {
- final int vl = get(pl);
- final int vr = other.get(pr);
- if(vl < vr) {
- ret[i++] = vl;
- pl++;
- }
- else {
- ret[i++] = vr;
- pr++;
+ else if(other.get(sr - 1) < l) {
+ return new CombinedIndex(other, this);
+ }
+ else if(other.get(0) > u) {
+ return new CombinedIndex(this, other);
+ }
+ else {
+ // final int sr = other.size();
+ final int sl = size();
+ final int[] ret = new int[sr + sl];
+
+ int pl = 0;
+ int pr = 0;
+ int i = 0;
+ while(pl < sl && pr < sr) {
+ final int vl = get(pl);
+ final int vr = other.get(pr);
+ if(vl < vr) {
+ ret[i++] = vl;
+ pl++;
+ }
+ else {
+ ret[i++] = vr;
+ pr++;
+ }
}
+ while(pl < sl)
+ ret[i++] = get(pl++);
+ while(pr < sr)
+ ret[i++] = other.get(pr++);
+ return ColIndexFactory.create(ret);
+
}
- while(pl < sl)
- ret[i++] = get(pl++);
- while(pr < sr)
- ret[i++] = other.get(pr++);
- return ColIndexFactory.create(ret);
+
}
@Override
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/TwoRangesIndex.java
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/TwoRangesIndex.java
index 51634b9269..f1c27d4415 100644
---
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/TwoRangesIndex.java
+++
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/TwoRangesIndex.java
@@ -28,9 +28,9 @@ import
org.apache.sysds.runtime.compress.DMLCompressionException;
public class TwoRangesIndex extends AColIndex {
/** The lower index range */
- private final RangeIndex idx1;
+ protected final RangeIndex idx1;
/** The upper index range */
- private final RangeIndex idx2;
+ protected final RangeIndex idx2;
public TwoRangesIndex(RangeIndex lower, RangeIndex higher) {
this.idx1 = lower;