This is an automated email from the ASF dual-hosted git repository.

baunsgaard pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 3e6af1b814 [SYSTEMDS-3663] Low overhead join indexes
3e6af1b814 is described below

commit 3e6af1b814bf2c71e89d79a6ca4f88fb71608ebe
Author: Sebastian Baunsgaard <[email protected]>
AuthorDate: Sun Jan 7 17:06:30 2024 +0100

    [SYSTEMDS-3663] Low overhead join indexes
    
    This commit adds a few more column index variations that allow
    efficient combination and ordering of column indexes when co-coding.
    This is critical when thousands of columns are combined, since the
    execution time is otherwise dominated not by combining the columns
    themselves but by combining their column indexes.
    
    Closes #1979
---
 .../compress/colgroup/indexes/AColIndex.java       |  56 ++++-
 .../compress/colgroup/indexes/ColIndexFactory.java |   2 +
 .../compress/colgroup/indexes/CombinedIndex.java   | 246 +++++++++++++++++++++
 .../compress/colgroup/indexes/IColIndex.java       |  80 ++++++-
 .../compress/colgroup/indexes/RangeIndex.java      |  84 ++++---
 .../compress/colgroup/indexes/TwoRangesIndex.java  |   4 +-
 6 files changed, 437 insertions(+), 35 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/AColIndex.java
 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/AColIndex.java
index df4685a65d..81a5f5b480 100644
--- 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/AColIndex.java
+++ 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/AColIndex.java
@@ -21,6 +21,8 @@ package org.apache.sysds.runtime.compress.colgroup.indexes;
 
 import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
+import org.apache.sysds.runtime.data.SparseBlock;
+import org.apache.sysds.runtime.data.SparseBlockCSR;
 
 public abstract class AColIndex implements IColIndex {
 
@@ -69,11 +71,55 @@ public abstract class AColIndex implements IColIndex {
 
        @Override
        public boolean containsAny(IColIndex idx) {
-               final IIterate it = idx.iterator();
-               while(it.hasNext())
-                       if(contains(it.next()))
-                               return true;
+               // Composite indexes are checked part by part through containsAny, instead of
+               // probing every single column they contain.
+               if(idx instanceof TwoRangesIndex){
+                       TwoRangesIndex o = (TwoRangesIndex) idx;
+                       return this.containsAny(o.idx1) || 
this.containsAny(o.idx2);
+               }
+               else if(idx instanceof CombinedIndex){
+                       CombinedIndex ci = (CombinedIndex) idx;
+                       return containsAny(ci.l) || containsAny(ci.r);
+               }
+               else{
+                       // Fallback for all other index types: probe each column of idx individually.
+                       final IIterate it = idx.iterator();
+                       while(it.hasNext())
+                               if(contains(it.next()))
+                                       return true;
+       
+                       return false;
+               }
+       }
 
-               return false;
+       @Override
+       public void decompressToDenseFromSparse(SparseBlock sb, int vr, int 
off, double[] c) {
+               // Dispatch on the concrete block type: CSR blocks use a variant without the
+               // empty-row check, every other sparse block goes through the generic variant.
+               if(sb instanceof SparseBlockCSR)
+                       decompressToDenseFromSparseCSR((SparseBlockCSR)sb, vr, 
off, c);
+               else
+                       decompressToDenseFromSparseGeneric(sb, vr, off, c);
+       }
+
+       /**
+        * Add the values stored in row vr of a sparse block into the dense array c.
+        *
+        * @param sb  The sparse block to read from
+        * @param vr  The row of the sparse block to read
+        * @param off The offset added to every target position in c
+        * @param c   The dense array that is added into
+        */
+       private void decompressToDenseFromSparseGeneric(SparseBlock sb, int vr, 
int off, double[] c) {
+               // Skip rows that the block reports as empty.
+               if(sb.isEmpty(vr))
+                       return;
+               final int apos = sb.pos(vr);
+               final int alen = sb.size(vr) + apos;
+               final int[] aix = sb.indexes(vr);
+               final double[] aval = sb.values(vr);
+               // aix[j] is used as a position in this column index; get() maps it to the target
+               // column, and the value is accumulated (+=) rather than assigned.
+               for(int j = apos; j < alen; j++)
+                       c[off + get(aix[j])] += aval[j];
+       }
+
+       /**
+        * Add the values stored in row vr of a CSR sparse block into the dense array c. Same loop as the generic
+        * variant, but without the up-front isEmpty check.
+        *
+        * @param sb  The CSR sparse block to read from
+        * @param vr  The row of the sparse block to read
+        * @param off The offset added to every target position in c
+        * @param c   The dense array that is added into
+        */
+       private void decompressToDenseFromSparseCSR(SparseBlockCSR sb, int vr, 
int off, double[] c) {
+               final int apos = sb.pos(vr);
+               final int alen = sb.size(vr) + apos;
+               final int[] aix = sb.indexes(vr);
+               final double[] aval = sb.values(vr);
+               // NOTE(review): presumably pos/size describe an empty span for rows without values, making the
+               // loop a no-op there - confirm against SparseBlockCSR.
+               for(int j = apos; j < alen; j++)
+                       c[off + get(aix[j])] += aval[j];
+       }
+
+       @Override
+       public void decompressVec(int nCol, double[] c, int off, double[] 
values, int rowIdx) {
+               // For each of the first nCol positions j of this index, accumulate values[rowIdx + j]
+               // onto the column that position j maps to, shifted by off.
+               for(int j = 0; j < nCol; j++)
+                       c[off + get(j)] += values[rowIdx + j];
        }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/ColIndexFactory.java
 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/ColIndexFactory.java
index fd929b8a1a..c9a45e4aee 100644
--- 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/ColIndexFactory.java
+++ 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/ColIndexFactory.java
@@ -48,6 +48,8 @@ public interface ColIndexFactory {
                                return RangeIndex.read(in);
                        case TWORANGE: 
                                return TwoRangesIndex.read(in);
+                       case COMBINED:
+                               return CombinedIndex.read(in);
                        default:
                                throw new DMLCompressionException("Failed 
reading column index of type: " + t);
                }
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/CombinedIndex.java
 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/CombinedIndex.java
new file mode 100644
index 0000000000..f1a80a6d27
--- /dev/null
+++ 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/CombinedIndex.java
@@ -0,0 +1,246 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.compress.colgroup.indexes;
+
+import java.io.DataInput;
+import java.io.DataOutput;
+import java.io.IOException;
+
+import org.apache.sysds.runtime.compress.DMLCompressionException;
+
+/**
+ * A column index made of two other column indexes placed after each other: first all columns of l, then all
+ * columns of r. The parts are referenced, not copied.
+ *
+ * No ordering check is done on construction, yet isSorted always returns true, so callers are expected to only
+ * pair parts where every column of l is lower than every column of r.
+ */
+public class CombinedIndex extends AColIndex {
+       /** The first part of the index */
+       protected final IColIndex l;
+       /** The second part of the index, positioned after all columns of l */
+       protected final IColIndex r;
+
+       /**
+        * Combine two column indexes into one.
+        *
+        * @param l The first part
+        * @param r The second part, its columns are positioned after the columns of l
+        */
+       public CombinedIndex(IColIndex l, IColIndex r) {
+               this.l = l;
+               this.r = r;
+       }
+
+       @Override
+       public int size() {
+               return l.size() + r.size();
+       }
+
+       @Override
+       public int get(int i) {
+               // Positions beyond the first part are looked up in the second part.
+               if(i >= l.size())
+                       return r.get(i - l.size());
+               else
+                       return l.get(i);
+       }
+
+       @Override
+       public IColIndex shift(int i) {
+               return new CombinedIndex(l.shift(i), r.shift(i));
+       }
+
+       @Override
+       public void write(DataOutput out) throws IOException {
+               // One type byte, followed by both parts written through their own write methods.
+               out.write(ColIndexType.COMBINED.ordinal());
+               l.write(out);
+               r.write(out);
+       }
+
+       @Override
+       public long getExactSizeOnDisk() {
+               // 1 is the type byte emitted in write.
+               return 1 + l.getExactSizeOnDisk() + r.getExactSizeOnDisk();
+       }
+
+       @Override
+       public long estimateInMemorySize() {
+               // NOTE(review): 16 + 8 + 8 is presumably object header plus the two references - confirm.
+               return 16 + 8 + 8 + l.estimateInMemorySize() + r.estimateInMemorySize();
+       }
+
+       @Override
+       public IIterate iterator() {
+               return new CombinedIterator();
+       }
+
+       @Override
+       public int findIndex(int i) {
+               final int a = l.findIndex(i);
+               if(a < 0) {
+                       final int b = r.findIndex(i);
+                       if(b < 0)
+                               // Not in either part. NOTE(review): if a miss is encoded as -(insertion point) - 1
+                               // in both parts, b + a + 1 is the same encoding for the combined index - confirm.
+                               return b + a + 1;
+                       else
+                               // Found in the second part, whose positions start after the first part.
+                               return b + l.size();
+               }
+               else
+                       return a;
+       }
+
+       @Override
+       public SliceResult slice(int l, int u) {
+               // The parameters l and u are the slice bounds; l shadows the field of the same name here.
+               // The slice is delegated to a materialized copy of this index.
+               return getArrayIndex().slice(l, u);
+       }
+
+       /**
+        * Compare the columns of this index with another index.
+        *
+        * Two combined indexes can hold the same columns split differently between their parts, therefore a part-wise
+        * match is only used as a shortcut and a mismatch falls through to the element-wise comparison.
+        *
+        * @param other The index to compare with
+        * @return True if both indexes contain the same columns in the same order
+        */
+       @Override
+       public boolean equals(IColIndex other) {
+               if(other == this)
+                       return true;
+               else if(size() != other.size())
+                       return false;
+               else if(other instanceof CombinedIndex) {
+                       final CombinedIndex ci = (CombinedIndex) other;
+                       if(ci.l.equals(l) && ci.r.equals(r))
+                               return true;
+               }
+
+               // Sizes are equal at this point, so both iterators finish together.
+               final IIterate t = iterator();
+               final IIterate o = other.iterator();
+               while(t.hasNext()) {
+                       if(t.next() != o.next())
+                               return false;
+               }
+               return true;
+       }
+
+       @Override
+       public IColIndex combine(IColIndex other) {
+               final int sr = other.size();
+               final int sl = size();
+               final int maxCombined = Math.max(this.get(this.size() - 1), other.get(other.size() - 1));
+               final int minCombined = Math.min(this.get(0), other.get(0));
+               // If the number of columns equals the spanned interval the result is one dense range
+               // (this relies on the two indexes not sharing columns).
+               if(sr + sl == maxCombined - minCombined + 1) {
+                       return new RangeIndex(minCombined, maxCombined + 1);
+               }
+
+               // Otherwise merge both column sequences into one array, smallest value first.
+               final int[] ret = new int[sr + sl];
+               IIterate t = iterator();
+               IIterate o = other.iterator();
+               int i = 0;
+               while(t.hasNext() && o.hasNext()) {
+                       final int tv = t.v();
+                       final int ov = o.v();
+                       if(tv < ov) {
+                               ret[i++] = tv;
+                               t.next();
+                       }
+                       else {
+                               ret[i++] = ov;
+                               o.next();
+                       }
+               }
+               while(t.hasNext())
+                       ret[i++] = t.next();
+               while(o.hasNext())
+                       ret[i++] = o.next();
+
+               return ColIndexFactory.create(ret);
+
+       }
+
+       @Override
+       public boolean isContiguous() {
+               return false;
+       }
+
+       @Override
+       public int[] getReorderingIndex() {
+               return getArrayIndex().getReorderingIndex();
+       }
+
+       @Override
+       public boolean isSorted() {
+               return true;
+       }
+
+       @Override
+       public IColIndex sort() {
+               throw new DMLCompressionException("CombinedIndex is always sorted");
+       }
+
+       @Override
+       public boolean contains(int i) {
+               return l.contains(i) || r.contains(i);
+       }
+
+       @Override
+       public double avgOfIndex() {
+               // Weight the average of each part by its size to get the average over all columns.
+               double lv = l.avgOfIndex() * l.size();
+               double rv = r.avgOfIndex() * r.size();
+               return (lv + rv) / size();
+       }
+
+       /**
+        * Materialize all columns of this index into an array and let the factory pick an index type for it.
+        *
+        * @return An index created by ColIndexFactory from all columns of this index
+        */
+       private IColIndex getArrayIndex() {
+               int s = size();
+               int[] vals = new int[s];
+               IIterate a = iterator();
+               for(int i = 0; i < s; i++) {
+                       vals[i] = a.next();
+               }
+               return ColIndexFactory.create(vals);
+       }
+
+       /**
+        * Read a combined index from the input by reading its two parts through ColIndexFactory, first l then r.
+        *
+        * @param in The input to read from
+        * @return The combined index
+        * @throws IOException If reading either part fails
+        */
+       public static CombinedIndex read(DataInput in) throws IOException {
+               return new CombinedIndex(ColIndexFactory.read(in), ColIndexFactory.read(in));
+       }
+
+       @Override
+       public String toString() {
+               StringBuilder sb = new StringBuilder();
+               sb.append(this.getClass().getSimpleName());
+               sb.append("[");
+               sb.append(l);
+               sb.append(", ");
+               sb.append(r);
+               sb.append("]");
+               return sb.toString();
+       }
+
+       /** Iterator that first walks the columns of l and then continues with the columns of r. */
+       protected class CombinedIterator implements IIterate {
+               /** True once the iterator of l is exhausted and I has been switched to r */
+               boolean doneFirst = false;
+               /** The iterator of the part currently being walked */
+               IIterate I = l.iterator();
+
+               @Override
+               public int next() {
+                       int v = I.next();
+                       // Switch to the second part right after the last element of the first part is
+                       // consumed, so v() and i() already refer to the second part afterwards.
+                       if(!I.hasNext() && !doneFirst) {
+                               doneFirst = true;
+                               I = r.iterator();
+                       }
+                       return v;
+
+               }
+
+               @Override
+               public boolean hasNext() {
+                       return I.hasNext() || doneFirst == false;
+               }
+
+               @Override
+               public int v() {
+                       return I.v();
+               }
+
+               @Override
+               public int i() {
+                       // Positions in the second part are offset by the size of the first part.
+                       if(doneFirst)
+                               return I.i() + l.size();
+                       else
+                               return I.i();
+               }
+       }
+
+}
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/IColIndex.java
 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/IColIndex.java
index 60c2cec4b2..8da8ad518f 100644
--- 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/IColIndex.java
+++ 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/IColIndex.java
@@ -22,13 +22,16 @@ package org.apache.sysds.runtime.compress.colgroup.indexes;
 import java.io.DataOutput;
 import java.io.IOException;
 
+import org.apache.sysds.runtime.data.SparseBlock;
+import org.apache.sysds.runtime.matrix.data.Pair;
+
 /**
  * Class to contain column indexes for the compression column groups.
  */
 public interface IColIndex {
 
        public static enum ColIndexType {
-               SINGLE, TWO, ARRAY, RANGE, TWORANGE, UNKNOWN;
+               SINGLE, TWO, ARRAY, RANGE, TWORANGE, COMBINED, UNKNOWN;
        }
 
        /**
@@ -212,6 +215,76 @@ public interface IColIndex {
         */
        public double avgOfIndex();
 
+       /**
+        * Decompress this column index into the dense c array.
+        *
+        * @param sb  A sparse block to extract values out of and insert into c
+        * @param vr  The row to extract from the sparse block
+        * @param off The offset that the row starts at in c.
+        * @param c   The dense output to decompress into
+        */
+       public void decompressToDenseFromSparse(SparseBlock sb, int vr, int 
off, double[] c);
+
+       /**
+        * Decompress into c using the values provided. The offset to start at in c is off, and rowIdx is similarly
+        * the offset to start at in values. nCol specifies the number of values to add.
+        *
+        * @param nCol   The number of columns to copy.
+        * @param c      The output to add into
+        * @param off    The offset to start in c
+        * @param values the values to copy from
+        * @param rowIdx The offset to start in values
+        */
+       public void decompressVec(int nCol, double[] c, int off, double[] 
values, int rowIdx);
+
+       /**
+        * Indicate if the two given column indexes are in order such that the first set of indexes all are of lower
+        * value than the second.
+        *
+        * Only the last column of a is compared with the first column of b, so the answer covers all columns only
+        * when both indexes are sorted. Both indexes must contain at least one column.
+        *
+        * @param a the first column index
+        * @param b the second column index
+        * @return If the first all is lower than the second.
+        */
+       public static boolean inOrder(IColIndex a, IColIndex b) {
+               return a.get(a.size() - 1) < b.get(0);
+       }
+
+       /**
+        * Compute the position each column of a and b gets when the two indexes are merged into one sequence.
+        *
+        * The two iterators are walked in lock step and the smaller current column is placed next; when the current
+        * columns are equal the column of b is placed first. NOTE(review): this yields an ascending merge only if
+        * both iterators return their columns in ascending order - confirm for unsorted indexes.
+        *
+        * @param a the first column index
+        * @param b the second column index
+        * @return A pair of arrays, the first with the merged position of each column of a, the second with the
+        *         merged position of each column of b
+        */
+       public static Pair<int[], int[]> reorderingIndexes(IColIndex a, 
IColIndex b){
+               final int[] ar = new int[a.size()];
+               final int[] br = new int[b.size()];
+               final IIterate ai = a.iterator();
+               final IIterate bi = b.iterator();
+
+               // ia / ib are the write positions in ar / br, i is the next position in the merged order.
+               int ia = 0;
+               int ib = 0;
+               int i = 0;
+               while(ai.hasNext() && bi.hasNext()){
+                       if(ai.v()< bi.v()){
+                               ar[ia++] = i++;
+                               ai.next();
+                       }
+                       else{
+                               br[ib++] = i++;
+                               bi.next();
+                       }
+               }
+
+               // One side is exhausted, the remaining columns of the other side take the trailing positions.
+               while(ai.hasNext()){
+                               ar[ia++] = i++;
+                       ai.next();
+               }
+
+               while(bi.hasNext()){
+                               br[ib++] = i++;
+                       bi.next();
+               }
+
+               return new Pair<int[],int[]>(ar, br);
+       }
+
        /** A Class for slice results containing indexes for the slicing of 
dictionaries, and the resulting column index */
        public static class SliceResult {
                /** Start index to slice inside the dictionary */
@@ -223,9 +296,10 @@ public interface IColIndex {
 
                /**
                 * The slice result
+                * 
                 * @param idStart The starting index
-                * @param idEnd The ending index (not inclusive)
-                * @param ret The resulting IColIndex
+                * @param idEnd   The ending index (not inclusive)
+                * @param ret     The resulting IColIndex
                 */
                protected SliceResult(int idStart, int idEnd, IColIndex ret) {
                        this.idStart = idStart;
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/RangeIndex.java
 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/RangeIndex.java
index bbe5aeb8a5..17c2bed3ba 100644
--- 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/RangeIndex.java
+++ 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/RangeIndex.java
@@ -24,6 +24,7 @@ import java.io.DataOutput;
 import java.io.IOException;
 import java.util.Arrays;
 
+import org.apache.commons.lang.NotImplementedException;
 import org.apache.sysds.runtime.compress.DMLCompressionException;
 import org.apache.sysds.runtime.compress.utils.IntArrayList;
 
@@ -133,7 +134,7 @@ public class RangeIndex extends AColIndex {
                        int minU = Math.min(u, this.u);
                        int offL = maxL - this.l;
                        int offR = minU - this.l;
-                       return new SliceResult(offL, offR, new RangeIndex(maxL 
- l, minU - l ));
+                       return new SliceResult(offL, offR, new RangeIndex(maxL 
- l, minU - l));
                }
        }
 
@@ -147,16 +148,40 @@ public class RangeIndex extends AColIndex {
                        return other.equals(this);
        }
 
+       /**
+        * Check if this range shares at least one column with the given index.
+        *
+        * @param idx The column index to check for overlap
+        * @return True if at least one column of idx is inside this range
+        */
+       @Override
+       public boolean containsAny(IColIndex idx) {
+               if(idx instanceof RangeIndex) {
+                       // Both ranges are half open [l, u), so they overlap exactly when each one starts
+                       // before the other one ends. This also covers a range nested inside the other
+                       // (e.g. [2, 5) inside [0, 10)) or ending on the same upper bound, cases the
+                       // earlier branch chain rejected with a NotImplementedException.
+                       final RangeIndex o = (RangeIndex) idx;
+                       return o.l < u && o.u > l;
+               }
+               else
+                       return super.containsAny(idx);
+       }
+
        @Override
        public IColIndex combine(IColIndex other) {
+               final int sr = other.size();
                if(other.size() == 1) {
                        int v = other.get(0);
                        if(v + 1 == l)
                                return new RangeIndex(l - 1, u);
                        else if(v == u)
                                return new RangeIndex(l, u + 1);
+                       else if(v < l)
+                               return new CombinedIndex(other, this);
+                       else
+                               return new CombinedIndex(this, other);
                }
-               if(other instanceof RangeIndex) {
+               else if(other instanceof RangeIndex) {
                        if(other.get(0) == u)
                                return new RangeIndex(l, other.get(other.size() 
- 1) + 1);
                        else if(other.get(other.size() - 1) == l - 1)
@@ -166,31 +191,40 @@ public class RangeIndex extends AColIndex {
                        else
                                return new TwoRangesIndex(this, (RangeIndex) 
other);
                }
-
-               final int sr = other.size();
-               final int sl = size();
-               final int[] ret = new int[sr + sl];
-
-               int pl = 0;
-               int pr = 0;
-               int i = 0;
-               while(pl < sl && pr < sr) {
-                       final int vl = get(pl);
-                       final int vr = other.get(pr);
-                       if(vl < vr) {
-                               ret[i++] = vl;
-                               pl++;
-                       }
-                       else {
-                               ret[i++] = vr;
-                               pr++;
+               else if(other.get(sr - 1) < l) {
+                       return new CombinedIndex(other, this);
+               }
+               else if(other.get(0) > u) {
+                       return new CombinedIndex(this, other);
+               }
+               else {
+                       // final int sr = other.size();
+                       final int sl = size();
+                       final int[] ret = new int[sr + sl];
+
+                       int pl = 0;
+                       int pr = 0;
+                       int i = 0;
+                       while(pl < sl && pr < sr) {
+                               final int vl = get(pl);
+                               final int vr = other.get(pr);
+                               if(vl < vr) {
+                                       ret[i++] = vl;
+                                       pl++;
+                               }
+                               else {
+                                       ret[i++] = vr;
+                                       pr++;
+                               }
                        }
+                       while(pl < sl)
+                               ret[i++] = get(pl++);
+                       while(pr < sr)
+                               ret[i++] = other.get(pr++);
+                       return ColIndexFactory.create(ret);
+
                }
-               while(pl < sl)
-                       ret[i++] = get(pl++);
-               while(pr < sr)
-                       ret[i++] = other.get(pr++);
-               return ColIndexFactory.create(ret);
+
        }
 
        @Override
diff --git 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/TwoRangesIndex.java
 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/TwoRangesIndex.java
index 51634b9269..f1c27d4415 100644
--- 
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/TwoRangesIndex.java
+++ 
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/TwoRangesIndex.java
@@ -28,9 +28,9 @@ import 
org.apache.sysds.runtime.compress.DMLCompressionException;
 public class TwoRangesIndex extends AColIndex {
 
        /** The lower index range */
-       private final RangeIndex idx1;
+       protected final RangeIndex idx1;
        /** The upper index range */
-       private final RangeIndex idx2;
+       protected final RangeIndex idx2;
 
        public TwoRangesIndex(RangeIndex lower, RangeIndex higher) {
                this.idx1 = lower;

Reply via email to