This is an automated email from the ASF dual-hosted git repository.
arnabp20 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 3764784 [SYSTEMDS-3236] Cache-friendly Apply phase for dense target matrix
3764784 is described below
commit 3764784f5d2ca7098fbde4f85bef4ff3244e1b64
Author: arnabp <[email protected]>
AuthorDate: Fri Dec 3 16:19:55 2021 +0100
[SYSTEMDS-3236] Cache-friendly Apply phase for dense target matrix
This patch adds loop-tiling logic to the apply phase of transformencode
to exploit CPU caches. Currently, the changes are limited to dense
matrices.
Loop tiling shows a 2x performance improvement when recoding a frame
with 5M rows and 100 columns (100K unique values in each), using 32 threads.
---
.../runtime/transform/encode/ColumnEncoder.java | 16 ++++++++++++++-
.../runtime/transform/encode/ColumnEncoderBin.java | 23 ++++++++++++++++++++++
.../transform/encode/ColumnEncoderComposite.java | 5 +++++
.../transform/encode/ColumnEncoderDummycode.java | 5 +++++
.../transform/encode/ColumnEncoderFeatureHash.java | 16 +++++++++++++++
.../transform/encode/ColumnEncoderPassThrough.java | 9 +++++++++
.../transform/encode/ColumnEncoderRecode.java | 22 ++++++++++++---------
7 files changed, 86 insertions(+), 10 deletions(-)
diff --git
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java
index 0706cb4..d82d0a9 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java
@@ -116,6 +116,8 @@ public abstract class ColumnEncoder implements Encoder,
Comparable<ColumnEncoder
protected abstract double getCode(CacheBlock in, int row);
+ protected abstract double[] getCodeCol(CacheBlock in, int startInd, int
blkSize);
+
protected void applySparse(CacheBlock in, MatrixBlock out, int
outputCol, int rowStart, int blk){
int index = _colID - 1;
@@ -126,10 +128,22 @@ public abstract class ColumnEncoder implements Encoder,
Comparable<ColumnEncoder
}
}
- protected void applyDense(CacheBlock in, MatrixBlock out, int
outputCol, int rowStart, int blk){
+ /*protected void applyDense(CacheBlock in, MatrixBlock out, int
outputCol, int rowStart, int blk){
for(int i = rowStart; i < getEndIndex(in.getNumRows(),
rowStart, blk); i++) {
out.quickSetValue(i, outputCol, getCode(in, i));
}
+ }*/
+
+ protected void applyDense(CacheBlock in, MatrixBlock out, int
outputCol, int rowStart, int blk){
+ // Apply loop tiling to exploit CPU caches
+ double[] codes = getCodeCol(in, rowStart, blk);
+ int rowEnd = getEndIndex(in.getNumRows(), rowStart, blk);
+ int B = 32; //tile size
+ for(int i = rowStart; i < rowEnd; i+=B) {
+ int lim = Math.min(i+B, rowEnd);
+ for (int ii=i; ii<lim; ii++)
+ out.quickSetValue(ii, outputCol,
codes[ii-rowStart]);
+ }
}
protected abstract TransformType getTransformType();
diff --git
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderBin.java
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderBin.java
index ab9f662..d802254 100644
---
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderBin.java
+++
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderBin.java
@@ -95,6 +95,7 @@ public class ColumnEncoderBin extends ColumnEncoder {
}
protected double getCode(CacheBlock in, int row){
+ // find the right bucket for a single row
if( _binMins.length == 0 || _binMaxs.length == 0 ) {
LOG.warn("ColumnEncoderBin: applyValue without bucket
boundaries, assign 1");
return 1; //robustness in case of missing bins
@@ -107,6 +108,28 @@ public class ColumnEncoderBin extends ColumnEncoder {
return ((ix < 0) ? Math.abs(ix + 1) : ix) + 1;
}
+
+ @Override
+ protected double[] getCodeCol(CacheBlock in, int startInd, int blkSize)
{
+ // find the right bucket for a block of rows
+ int endInd = getEndIndex(in.getNumRows(), startInd, blkSize);
+ double codes[] = new double[endInd-startInd];
+ for (int i=startInd; i<endInd; i++) {
+ if (_binMins.length == 0 || _binMaxs.length == 0) {
+ LOG.warn("ColumnEncoderBin: applyValue without
bucket boundaries, assign 1");
+ codes[i-startInd] = 1; //robustness in case of
missing bins
+ continue;
+ }
+ double inVal = in.getDoubleNaN(i, _colID - 1);
+ if (Double.isNaN(inVal) || inVal < _binMins[0] || inVal
> _binMaxs[_binMaxs.length-1]) {
+ codes[i-startInd] = Double.NaN;
+ continue;
+ }
+ int ix = Arrays.binarySearch(_binMaxs, inVal);
+ codes[i-startInd] = ((ix < 0) ? Math.abs(ix + 1) : ix)
+ 1;
+ }
+ return codes;
+ }
@Override
protected TransformType getTransformType() {
diff --git
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java
index 1f5fce6..a4a9563 100644
---
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java
+++
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java
@@ -208,6 +208,11 @@ public class ColumnEncoderComposite extends ColumnEncoder {
}
@Override
+ protected double[] getCodeCol(CacheBlock in, int startInd, int blkSize)
{
+ throw new DMLRuntimeException("CompositeEncoder does not have a
Code");
+ }
+
+ @Override
protected TransformType getTransformType() {
return TransformType.N_A;
}
diff --git
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderDummycode.java
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderDummycode.java
index 1ea4933..27fe6fd 100644
---
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderDummycode.java
+++
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderDummycode.java
@@ -75,6 +75,11 @@ public class ColumnEncoderDummycode extends ColumnEncoder {
throw new DMLRuntimeException("DummyCoder does not have a
code");
}
+ @Override
+ protected double[] getCodeCol(CacheBlock in, int startInd, int blkSize)
{
+ throw new DMLRuntimeException("DummyCoder does not have a
code");
+ }
+
protected void applySparse(CacheBlock in, MatrixBlock out, int
outputCol, int rowStart, int blk){
if (!(in instanceof MatrixBlock)){
throw new DMLRuntimeException("ColumnEncoderDummycode
called with: " + in.getClass().getSimpleName() +
diff --git
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderFeatureHash.java
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderFeatureHash.java
index 922b5a8..f7d9033 100644
---
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderFeatureHash.java
+++
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderFeatureHash.java
@@ -19,6 +19,7 @@
package org.apache.sysds.runtime.transform.encode;
+import static org.apache.sysds.runtime.util.UtilFunctions.getEndIndex;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
@@ -65,12 +66,27 @@ public class ColumnEncoderFeatureHash extends ColumnEncoder
{
@Override
protected double getCode(CacheBlock in, int row) {
+ // hash a single row
String key = in.getString(row, _colID - 1);
if(key == null)
return Double.NaN;
return (key.hashCode() % _K) + 1;
}
+ protected double[] getCodeCol(CacheBlock in, int startInd, int blkSize)
{
+ // hash a block of rows
+ int endInd = getEndIndex(in.getNumRows(), startInd, blkSize);
+ double codes[] = new double[endInd-startInd];
+ for (int i=startInd; i<endInd; i++) {
+ String key = in.getString(i, _colID - 1);
+ if(key == null || key.isEmpty())
+ codes[i-startInd] = Double.NaN;
+ else
+ codes[i-startInd] = (key.hashCode() % _K) + 1;
+ }
+ return codes;
+ }
+
@Override
public void build(CacheBlock in) {
// do nothing (no meta data other than K)
diff --git
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderPassThrough.java
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderPassThrough.java
index bc20fe3..e249159 100644
---
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderPassThrough.java
+++
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderPassThrough.java
@@ -65,6 +65,15 @@ public class ColumnEncoderPassThrough extends ColumnEncoder {
return in.getDoubleNaN(row, _colID - 1);
}
+ @Override
+ protected double[] getCodeCol(CacheBlock in, int startInd, int blkSize)
{
+ int endInd = getEndIndex(in.getNumRows(), startInd, blkSize);
+ double codes[] = new double[endInd-startInd];
+ for (int i=startInd; i<endInd; i++) {
+ codes[i-startInd] = in.getDoubleNaN(i, _colID-1);
+ }
+ return codes;
+ }
protected void applySparse(CacheBlock in, MatrixBlock out, int
outputCol, int rowStart, int blk){
Set<Integer> sparseRowsWZeros = null;
diff --git
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderRecode.java
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderRecode.java
index 8246f97..14c0b61 100644
---
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderRecode.java
+++
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderRecode.java
@@ -171,6 +171,7 @@ public class ColumnEncoderRecode extends ColumnEncoder {
}
protected double getCode(CacheBlock in, int r){
+ // lookup for a single row
Object okey = in.getString(r, _colID - 1);
String key = (okey != null) ? okey.toString() : null;
if(key == null || key.isEmpty())
@@ -179,16 +180,19 @@ public class ColumnEncoderRecode extends ColumnEncoder {
return (code < 0) ? Double.NaN : code;
}
- protected double[] getCodeCol(CacheBlock in) {
- Object[] coldata = (Object[])
((FrameBlock)in).getColumnData(_colID-1);
- double codes[] = new double[in.getNumRows()];
- for (int i=0; i<coldata.length; i++) {
- Object okey = coldata[i];
- String key = (okey != null) ? okey.toString() : null;
- if(key == null || key.isEmpty())
- codes[i] = Double.NaN;
+ @Override
+ protected double[] getCodeCol(CacheBlock in, int startInd, int blkSize)
{
+ // lookup for a block of rows
+ int endInd = getEndIndex(in.getNumRows(), startInd, blkSize);
+ double codes[] = new double[endInd-startInd];
+ for (int i=startInd; i<endInd; i++) {
+ String key = in.getString(i, _colID-1);
+ if(key == null || key.isEmpty()) {
+ codes[i-startInd] = Double.NaN;
+ continue;
+ }
long code = lookupRCDMap(key);
- codes[i] = code;
+ codes[i-startInd] = (code < 0) ? Double.NaN : code;
}
return codes;
}