This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new c21273cb09 [SYSTEMDS-3782] Bag-of-words encoder for Spark backend
c21273cb09 is described below
commit c21273cb098906e40e8bf7b3f7d5f32536a342d0
Author: e-strauss <[email protected]>
AuthorDate: Sun Nov 24 14:35:06 2024 +0100
[SYSTEMDS-3782] Bag-of-words encoder for Spark backend
Closes #2145.
---
...ltiReturnParameterizedBuiltinSPInstruction.java | 15 +-
.../spark/ParameterizedBuiltinSPInstruction.java | 8 +
.../runtime/transform/encode/ColumnEncoder.java | 2 +-
.../transform/encode/ColumnEncoderBagOfWords.java | 206 ++++++++----
.../transform/encode/ColumnEncoderComposite.java | 14 +-
.../runtime/transform/encode/EncoderFactory.java | 4 +
.../transform/encode/MultiColumnEncoder.java | 359 +++++++++++++++------
.../ColumnEncoderMixedFunctionalityTests.java | 116 +++++++
.../transform/ColumnEncoderSerializationTest.java | 33 +-
.../transform/TransformFrameEncodeBagOfWords.java | 177 ++++++++--
...dml => TransformFrameEncodeApplyBagOfWords.dml} | 24 +-
.../transform/TransformFrameEncodeBagOfWords.dml | 6 +-
12 files changed, 758 insertions(+), 206 deletions(-)
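
For context, a minimal usage sketch (not part of this commit) of how the new encoder is reached from Java, assuming the EncoderFactory/MultiColumnEncoder entry points exercised by the tests below; the single-column spec is illustrative.

    import org.apache.sysds.runtime.frame.data.FrameBlock;
    import org.apache.sysds.runtime.matrix.data.MatrixBlock;
    import org.apache.sysds.runtime.transform.encode.EncoderFactory;
    import org.apache.sysds.runtime.transform.encode.MultiColumnEncoder;

    public class BagOfWordsUsageSketch {
        // encode one text column into per-token count columns
        public static MatrixBlock encodeBow(FrameBlock frame) {
            String spec = "{ \"ids\": true, \"bag_of_words\": [ 1 ] }";
            MultiColumnEncoder enc = EncoderFactory.createEncoder(
                spec, frame.getColumnNames(), frame.getNumColumns(), null);
            return enc.encode(frame); // build token dictionary, then apply
        }
    }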
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/MultiReturnParameterizedBuiltinSPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/MultiReturnParameterizedBuiltinSPInstruction.java
index 90683eab29..df9fd84f77 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/spark/MultiReturnParameterizedBuiltinSPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/MultiReturnParameterizedBuiltinSPInstruction.java
@@ -62,6 +62,7 @@ import org.apache.sysds.runtime.matrix.operators.Operator;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.transform.TfUtils;
import org.apache.sysds.runtime.transform.encode.ColumnEncoder;
+import org.apache.sysds.runtime.transform.encode.ColumnEncoderBagOfWords;
import org.apache.sysds.runtime.transform.encode.ColumnEncoderBin;
import org.apache.sysds.runtime.transform.encode.ColumnEncoderComposite;
import org.apache.sysds.runtime.transform.encode.ColumnEncoderRecode;
@@ -263,6 +264,7 @@ public class MultiReturnParameterizedBuiltinSPInstruction extends ComputationSPInstruction
 			// encoder-specific outputs
 			List<ColumnEncoderRecode> raEncoders = _encoder.getColumnEncoders(ColumnEncoderRecode.class);
 			List<ColumnEncoderBin> baEncoders = _encoder.getColumnEncoders(ColumnEncoderBin.class);
+			List<ColumnEncoderBagOfWords> bowEncoders = _encoder.getColumnEncoders(ColumnEncoderBagOfWords.class);
 			ArrayList<Tuple2<Integer, Object>> ret = new ArrayList<>();
 			// output recode maps as columnID - token pairs
@@ -273,8 +275,14 @@ public class MultiReturnParameterizedBuiltinSPInstruction extends ComputationSPInstruction
 			for(Entry<Integer, HashSet<Object>> e1 : tmp.entrySet())
 				for(Object token : e1.getValue())
 					ret.add(new Tuple2<>(e1.getKey(), token));
-			if(!raEncoders.isEmpty())
-				raEncoders.forEach(columnEncoderRecode -> columnEncoderRecode.getCPRecodeMapsPartial().clear());
+				raEncoders.forEach(columnEncoderRecode -> columnEncoderRecode.getCPRecodeMapsPartial().clear());
+			}
+
+			if(!bowEncoders.isEmpty()){
+				for (ColumnEncoderBagOfWords bowEnc : bowEncoders)
+					for (Object token : bowEnc.getPartialTokenDictionary())
+						ret.add(new Tuple2<>(bowEnc.getColID(), token));
+				bowEncoders.forEach(enc -> enc.getPartialTokenDictionary().clear());
 			}
 			// output binning column min/max as columnID - min/max pairs
@@ -321,7 +329,8 @@ public class MultiReturnParameterizedBuiltinSPInstruction extends ComputationSPInstruction
 			StringBuilder sb = new StringBuilder();
 			// handle recode maps
-			if(_encoder.containsEncoderForID(colID, ColumnEncoderRecode.class)) {
+			if(_encoder.containsEncoderForID(colID, ColumnEncoderRecode.class) ||
+				_encoder.containsEncoderForID(colID, ColumnEncoderBagOfWords.class)) {
 				while(iter.hasNext()) {
 					String token = TfUtils.sanitizeSpaces(iter.next().toString());
 					sb.append(rowID).append(' ').append(scolID).append(' ');
diff --git a/src/main/java/org/apache/sysds/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
index 61e6e799f0..4d4b012444 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
@@ -88,6 +88,7 @@ import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import org.apache.sysds.runtime.transform.TfUtils.TfMethod;
import org.apache.sysds.runtime.transform.decode.Decoder;
import org.apache.sysds.runtime.transform.decode.DecoderFactory;
+import org.apache.sysds.runtime.transform.encode.ColumnEncoderBagOfWords;
import org.apache.sysds.runtime.transform.encode.EncoderFactory;
import org.apache.sysds.runtime.transform.encode.MultiColumnEncoder;
import org.apache.sysds.runtime.transform.meta.TfMetaUtils;
@@ -1056,6 +1057,13 @@ public class ParameterizedBuiltinSPInstruction extends ComputationSPInstruction
 					// execute block transform apply
 					MultiColumnEncoder encoder = _bencoder.getValue();
+					// we need to create a copy of the encoder, since the bag-of-words encoder stores frame-block-specific
+					// state which would be overwritten when multiple blocks are located on an executor.
+					// To avoid this, we create a shallow copy of the MultiColumnEncoder where we only instantiate new
+					// bag-of-words encoder objects with the frame-block-specific fields and shallow-copy the other
+					// fields (like meta); all other encoders are reused and not newly instantiated.
+					if(!encoder.getColumnEncoders(ColumnEncoderBagOfWords.class).isEmpty())
+						encoder = new MultiColumnEncoder(encoder); // create copy
 					MatrixBlock tmp = encoder.apply(blk);
 					// remap keys
 					if(_omap != null) {
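
For illustration, a simplified sketch of the copy pattern the comment above describes (names simplified, not the actual SystemDS classes): stateless encoders are shared between tasks, while encoders with block-local state are re-instantiated. The real copy constructor is added in MultiColumnEncoder further down in this commit.

    import java.util.ArrayList;
    import java.util.List;

    class CopyPatternSketch {
        interface Enc { }
        static class BowEnc implements Enc {
            int[] nnzPerRow; // block-local state that must not be shared across executor tasks
            BowEnc() { }
            BowEnc(BowEnc other) { /* copy shared config (e.g. dictionary), leave block state fresh */ }
        }
        // shallow copy: reuse stateless encoders, re-create stateful ones
        static List<Enc> copy(List<Enc> src) {
            List<Enc> out = new ArrayList<>();
            for(Enc e : src)
                out.add(e instanceof BowEnc ? new BowEnc((BowEnc) e) : e);
            return out;
        }
    }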
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java
index f10da3d946..019df7f847 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java
@@ -450,7 +450,7 @@ public abstract class ColumnEncoder implements Encoder, Comparable<ColumnEncoder
 	}
 	public enum EncoderType {
-		Recode, FeatureHash, PassThrough, Bin, Dummycode, Omit, MVImpute, Composite, WordEmbedding,
+		Recode, FeatureHash, PassThrough, Bin, Dummycode, Omit, MVImpute, Composite, WordEmbedding, BagOfWords
 	}
/*
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderBagOfWords.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderBagOfWords.java
index c138901ad1..25b1a0ce87 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderBagOfWords.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderBagOfWords.java
@@ -29,6 +29,9 @@ import org.apache.sysds.runtime.frame.data.FrameBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.utils.stats.TransformStatistics;
+import java.io.IOException;
+import java.io.ObjectInput;
+import java.io.ObjectOutput;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
@@ -45,14 +48,15 @@ import static org.apache.sysds.runtime.util.UtilFunctions.getEndIndex;
public class ColumnEncoderBagOfWords extends ColumnEncoder {
public static int NUM_SAMPLES_MAP_ESTIMATION = 16000;
-	protected int[] nnzPerRow;
-	private Map<String, Integer> tokenDictionary;
-	protected String seperatorRegex = "\\s+"; // whitespace
-	protected boolean caseSensitive = false;
-	protected long nnz = 0;
-	protected long[] nnzPartials;
-	protected int defaultNnzCapacity = 64;
-	protected double avgNnzPerRow = 1.0;
+	private Map<Object, Long> _tokenDictionary; // switched from int to long to reuse code from RecodeEncoder
+	private HashSet<Object> _tokenDictionaryPart = null;
+	protected String _seperatorRegex = "\\s+"; // whitespace
+	protected boolean _caseSensitive = false;
+	protected int[] _nnzPerRow;
+	protected long _nnz = 0;
+	protected long[] _nnzPartials;
+	protected int _defaultNnzCapacity = 64;
+	protected double _avgNnzPerRow = 1.0;
protected ColumnEncoderBagOfWords(int colID) {
super(colID);
@@ -62,9 +66,40 @@ public class ColumnEncoderBagOfWords extends ColumnEncoder {
super(-1);
}
+	public ColumnEncoderBagOfWords(ColumnEncoderBagOfWords enc) {
+		super(enc._colID);
+		_nnzPerRow = enc._nnzPerRow != null ? enc._nnzPerRow.clone() : null;
+		_tokenDictionary = enc._tokenDictionary;
+		_seperatorRegex = enc._seperatorRegex;
+		_caseSensitive = enc._caseSensitive;
+	}
+
+	public void setTokenDictionary(HashMap<Object, Long> dict){
+		_tokenDictionary = dict;
+	}
+
+	public Map<Object, Long> getTokenDictionary() {
+		return _tokenDictionary;
+	}
+
 	protected void initNnzPartials(int rows, int numBlocks){
-		this.nnzPerRow = new int[rows];
-		this.nnzPartials = new long[numBlocks];
+		_nnzPerRow = new int[rows];
+		_nnzPartials = new long[numBlocks];
+	}
+
+	public double computeNnzEstimate(CacheBlock<?> in, int[] sampleIndices) {
+		// estimates the nnz per row for this encoder
+		final int max_index = Math.min(ColumnEncoderBagOfWords.NUM_SAMPLES_MAP_ESTIMATION, sampleIndices.length);
+		int nnz = 0;
+		for (int i = 0; i < max_index; i++) {
+			int sind = sampleIndices[i];
+			String current = in.getString(sind, _colID - 1);
+			if(current != null)
+				for(String token : tokenize(current, _caseSensitive, _seperatorRegex))
+					if(!token.isEmpty() && _tokenDictionary.containsKey(token))
+						nnz++;
+		}
+		return (double) nnz / max_index;
 	}
 	public void computeMapSizeEstimate(CacheBlock<?> in, int[] sampleIndices) {
@@ -76,10 +111,10 @@ public class ColumnEncoderBagOfWords extends ColumnEncoder {
 		int[] nnzPerRow = new int[max_index];
 		for (int i = 0; i < max_index; i++) {
 			int sind = sampleIndices[i];
-			String current = in.getString(sind, this._colID - 1);
+			String current = in.getString(sind, _colID - 1);
 			Set<String> tokenSetRow = new HashSet<>();
 			if(current != null)
-				for(String token : tokenize(current, caseSensitive, seperatorRegex))
+				for(String token : tokenize(current, _caseSensitive, _seperatorRegex))
 					if(!token.isEmpty()){
 						tokenSetRow.add(token);
 						if (distinctFreq.containsKey(token))
@@ -94,9 +129,9 @@ public class ColumnEncoderBagOfWords extends ColumnEncoder {
 			nnzPerRow[i] = tokenSetRow.size();
 		}
 		Arrays.sort(nnzPerRow);
-		avgNnzPerRow = (double) Arrays.stream(nnzPerRow).sum() / nnzPerRow.length;
+		_avgNnzPerRow = (double) Arrays.stream(nnzPerRow).sum() / nnzPerRow.length;
 		// default capacity for HashSets in the build phase -> ~75% of rows without resize (division by 0.9, the resize threshold)
-		defaultNnzCapacity = (int) Math.max( nnzPerRow[(int) (nnzPerRow.length*0.75)] / 0.9, 64);
+		_defaultNnzCapacity = (int) Math.max( nnzPerRow[(int) (nnzPerRow.length*0.75)] / 0.9, 64);

 		// we increase the upper bound of the total count estimate by 20%
 		double avgSentenceLength = numTokensSample*1.2 / max_index;
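
A worked example of the capacity heuristic above (made-up values): the 75th-percentile row nnz is divided by the 0.9 resize threshold so roughly three quarters of the per-row token sets never trigger a HashSet resize, with 64 as the floor.

    public class CapacitySketch {
        public static void main(String[] args) {
            int[] nnzPerRow = {1, 2, 2, 3, 3, 4, 5, 9}; // sorted sample of per-row distinct token counts
            int p75 = nnzPerRow[(int) (nnzPerRow.length * 0.75)]; // index 6 -> 5
            int capacity = (int) Math.max(p75 / 0.9, 64); // max(5.55..., 64) = 64
            System.out.println(capacity);
        }
    }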
@@ -118,6 +153,18 @@ public class ColumnEncoderBagOfWords extends ColumnEncoder {
 		_estMetaSize = _estNumDistincts * _avgEntrySize;
 	}
+	public void computeNnzPerRow(CacheBlock<?> in, int start, int end){
+		for (int i = start; i < end; i++) {
+			String current = in.getString(i, _colID - 1);
+			HashSet<String> distinctTokens = new HashSet<>();
+			if(current != null)
+				for(String token : tokenize(current, _caseSensitive, _seperatorRegex))
+					if(!token.isEmpty() && _tokenDictionary.containsKey(token))
+						distinctTokens.add(token);
+			_nnzPerRow[i] = distinctTokens.size();
+		}
+	}
+
 	public static String[] tokenize(String current, boolean caseSensitive, String seperatorRegex) {
 		// string builder is faster than regex
 		StringBuilder finalString = new StringBuilder();
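
The tokenize body is cut off in this hunk; a plausible StringBuilder-based variant (an assumption, not the committed code) normalizes case and whitespace in a single pass instead of chaining String.toLowerCase and a regex split. This is consistent with the !token.isEmpty() guards used throughout, since splitting on single spaces can yield empty tokens.

    public class TokenizeSketch {
        public static String[] tokenize(String s, boolean caseSensitive) {
            StringBuilder sb = new StringBuilder(s.length());
            for(int i = 0; i < s.length(); i++) {
                char c = s.charAt(i);
                // map all whitespace to a single separator, lower-case on the fly if requested
                sb.append(Character.isWhitespace(c) ? ' ' : (caseSensitive ? c : Character.toLowerCase(c)));
            }
            return sb.toString().split(" "); // empty tokens are filtered by the callers
        }
    }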
@@ -132,7 +179,7 @@ public class ColumnEncoderBagOfWords extends ColumnEncoder {
@Override
public int getDomainSize(){
- return tokenDictionary.size();
+ return _tokenDictionary.size();
}
@Override
@@ -150,8 +197,6 @@ public class ColumnEncoderBagOfWords extends ColumnEncoder {
return TransformType.BAG_OF_WORDS;
}
-
-
public Callable<Object> getBuildTask(CacheBlock<?> in) {
return new ColumnBagOfWordsBuildTask(this, in);
}
@@ -159,38 +204,39 @@ public class ColumnEncoderBagOfWords extends ColumnEncoder {
 	@Override
 	public void build(CacheBlock<?> in) {
 		long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0;
-		tokenDictionary = new HashMap<>(_estNumDistincts);
-		int i = 0;
-		this.nnz = 0;
-		nnzPerRow = new int[in.getNumRows()];
+		_tokenDictionary = new HashMap<>(_estNumDistincts);
+		int i = 1;
+		_nnz = 0;
+		_nnzPerRow = new int[in.getNumRows()];
 		HashSet<String> tokenSetPerRow;
 		for (int r = 0; r < in.getNumRows(); r++) {
 			// start with a higher default capacity to avoid resizes
-			tokenSetPerRow = new HashSet<>(defaultNnzCapacity);
-			String current = in.getString(r, this._colID - 1);
+			tokenSetPerRow = new HashSet<>(_defaultNnzCapacity);
+			String current = in.getString(r, _colID - 1);
 			if(current != null)
-				for(String token : tokenize(current, caseSensitive, seperatorRegex))
+				for(String token : tokenize(current, _caseSensitive, _seperatorRegex))
 					if(!token.isEmpty()){
 						tokenSetPerRow.add(token);
-						if(!this.tokenDictionary.containsKey(token))
-							this.tokenDictionary.put(token, i++);
+						if(!_tokenDictionary.containsKey(token))
+							_tokenDictionary.put(token, (long) i++);
 					}
-			this.nnzPerRow[r] = tokenSetPerRow.size();
-			this.nnz += tokenSetPerRow.size();
+			_nnzPerRow[r] = tokenSetPerRow.size();
+			_nnz += tokenSetPerRow.size();
 		}
 		if(DMLScript.STATISTICS)
 			TransformStatistics.incBagOfWordsBuildTime(System.nanoTime()-t0);
 	}
 	@Override
-	public Callable<Object> getPartialBuildTask(CacheBlock<?> in, int startRow, int blockSize,
-			HashMap<Integer, Object> ret, int pos) {
-		return new BowPartialBuildTask(in, _colID, startRow, blockSize, ret, nnzPerRow, caseSensitive, seperatorRegex, nnzPartials, pos);
+	public Callable<Object> getPartialBuildTask(CacheBlock<?> in,
+		int startRow, int blockSize, HashMap<Integer, Object> ret, int pos) {
+		return new BowPartialBuildTask(in, _colID, startRow, blockSize, ret,
+			_nnzPerRow, _caseSensitive, _seperatorRegex, _nnzPartials, pos);
 	}
 	@Override
 	public Callable<Object> getPartialMergeBuildTask(HashMap<Integer, ?> ret) {
-		tokenDictionary = new HashMap<>(this._estNumDistincts);
+		_tokenDictionary = new HashMap<>(_estNumDistincts);
 		return new BowMergePartialBuildTask(this, ret);
 	}
@@ -205,6 +251,32 @@ public class ColumnEncoderBagOfWords extends ColumnEncoder {
 		}
 	}
+	@Override
+	public void prepareBuildPartial() {
+		// ensure allocated partial recode map
+		if(_tokenDictionaryPart == null)
+			_tokenDictionaryPart = new HashSet<>();
+	}
+
+	public HashSet<Object> getPartialTokenDictionary(){
+		return _tokenDictionaryPart;
+	}
+
+	@Override
+	public void buildPartial(FrameBlock in) {
+		if(!isApplicable())
+			return;
+		for (int r = 0; r < in.getNumRows(); r++) {
+			String current = in.getString(r, _colID - 1);
+			if(current != null)
+				for(String token : tokenize(current, _caseSensitive, _seperatorRegex)){
+					if(!token.isEmpty())
+						_tokenDictionaryPart.add(token);
+				}
+		}
+	}
+
 	protected void applySparse(CacheBlock<?> in, MatrixBlock out, int outputCol, int rowStart, int blk) {
 		boolean mcsr = MatrixBlock.DEFAULT_SPARSEBLOCK == SparseBlock.Type.MCSR;
 		mcsr = false; // force CSR for transformencode FIXME
@@ -214,20 +286,20 @@ public class ColumnEncoderBagOfWords extends ColumnEncoder {
 			throw new NotImplementedException();
 		}
 		else { // csr
-			HashMap<String, Integer> counter = countTokenAppearances(in, r, _colID-1, caseSensitive, seperatorRegex);
+			HashMap<String, Integer> counter = countTokenAppearances(in, r);
 			if(counter.isEmpty())
 				sparseRowsWZeros.add(r);
 			else {
 				SparseBlockCSR csrblock = (SparseBlockCSR) out.getSparseBlock();
 				int[] rptr = csrblock.rowPointers();
 				// assert that nnz from build is equal to nnz from apply
-				assert counter.size() == nnzPerRow[r];
-				Pair[] columnValuePairs = new Pair[counter.size()];
+				Pair[] columnValuePairs = new Pair[_nnzPerRow[r]];
 				int i = 0;
 				for (Map.Entry<String, Integer> entry : counter.entrySet()) {
 					String token = entry.getKey();
-					columnValuePairs[i] = new Pair(outputCol + tokenDictionary.get(token), entry.getValue());
-					i++;
+					columnValuePairs[i] = new Pair((int) (outputCol + _tokenDictionary.getOrDefault(token, 0L) - 1), entry.getValue());
+					// if the token is not included, columnValuePairs[i] is overwritten in the next iteration
+					i += _tokenDictionary.containsKey(token) ? 1 : 0;
 				}
 				// insertion sort performs better on small arrays
 				if(columnValuePairs.length >= 128)
@@ -237,7 +309,7 @@ public class ColumnEncoderBagOfWords extends ColumnEncoder {
 				// Manually fill the column-indexes and values array
 				for (i = 0; i < columnValuePairs.length; i++) {
 					int index = sparseRowPointerOffset != null ? sparseRowPointerOffset[r] - 1 + i : i;
-					index += rptr[r] + this._colID -1;
+					index += rptr[r] + _colID -1;
 					csrblock.indexes()[index] = columnValuePairs[i].key;
 					csrblock.values()[index] = columnValuePairs[i].value;
 				}
@@ -264,42 +336,68 @@ public class ColumnEncoderBagOfWords extends ColumnEncoder {
 	@Override
 	protected void applyDense(CacheBlock<?> in, MatrixBlock out, int outputCol, int rowStart, int blk){
 		for (int r = rowStart; r < Math.max(in.getNumRows(), rowStart + blk); r++) {
-			HashMap<String, Integer> counter = countTokenAppearances(in, r, _colID-1, caseSensitive, seperatorRegex);
+			HashMap<String, Integer> counter = countTokenAppearances(in, r);
 			for (Map.Entry<String, Integer> entry : counter.entrySet())
-				out.set(r, outputCol + tokenDictionary.get(entry.getKey()), entry.getValue());
+				out.set(r, (int) (outputCol + _tokenDictionary.get(entry.getKey()) - 1), entry.getValue());
 		}
 	}
-	private static HashMap<String, Integer> countTokenAppearances(
-		CacheBlock<?> in, int r, int c, boolean caseSensitive, String separator)
+	private HashMap<String, Integer> countTokenAppearances(
+		CacheBlock<?> in, int r)
 	{
-		String current = in.getString(r, c);
+		String current = in.getString(r, _colID - 1);
 		HashMap<String, Integer> counter = new HashMap<>();
 		if(current != null)
-			for (String token : tokenize(current, caseSensitive, separator))
-				if (!token.isEmpty())
+			for (String token : tokenize(current, _caseSensitive, _seperatorRegex))
+				if (!token.isEmpty() && _tokenDictionary.containsKey(token))
 					counter.put(token, counter.getOrDefault(token, 0) + 1);
 		return counter;
 	}
 	@Override
 	public void allocateMetaData(FrameBlock meta) {
-		meta.ensureAllocatedColumns(this.getDomainSize());
+		meta.ensureAllocatedColumns(getDomainSize());
 	}
 	@Override
 	public FrameBlock getMetaData(FrameBlock out) {
 		int rowID = 0;
 		StringBuilder sb = new StringBuilder();
-		for(Map.Entry<String, Integer> e : this.tokenDictionary.entrySet()) {
-			out.set(rowID++, _colID - 1, constructRecodeMapEntry(e.getKey(), Long.valueOf(e.getValue()), sb));
+		for(Map.Entry<Object, Long> e : _tokenDictionary.entrySet()) {
+			out.set(rowID++, _colID - 1, constructRecodeMapEntry(e.getKey(), e.getValue(), sb));
 		}
 		return out;
 	}
 	@Override
 	public void initMetaData(FrameBlock meta) {
-		throw new NotImplementedException();
+		if(meta != null && meta.getNumRows() > 0) {
+			_tokenDictionary = meta.getRecodeMap(_colID - 1);
+		}
+	}
+
+	@Override
+	public void writeExternal(ObjectOutput out) throws IOException {
+		super.writeExternal(out);
+
+		out.writeInt(_tokenDictionary == null ? 0 : _tokenDictionary.size());
+		if(_tokenDictionary != null)
+			for(Map.Entry<Object, Long> e : _tokenDictionary.entrySet()) {
+				out.writeUTF((String) e.getKey());
+				out.writeLong(e.getValue());
+			}
+	}
+
+	@Override
+	public void readExternal(ObjectInput in) throws IOException {
+		super.readExternal(in);
+		int size = in.readInt();
+		_tokenDictionary = new HashMap<>(size * 4 / 3);
+		for(int j = 0; j < size; j++) {
+			String key = in.readUTF();
+			Long value = in.readLong();
+			_tokenDictionary.put(key, value);
+		}
+	}
private static class BowPartialBuildTask implements Callable<Object> {
@@ -340,7 +438,7 @@ public class ColumnEncoderBagOfWords extends ColumnEncoder {
long nnzPartial = 0;
for (int r = _startRow; r < endRow; r++) {
tokenSetPerRow = new HashSet<>(64);
-				String current = _input.getString(r, this._colID - 1);
+				String current = _input.getString(r, _colID - 1);
 				if(current != null)
 					for(String token : tokenize(current, _caseSensitive, _seperator))
 						if(!token.isEmpty()){
@@ -378,15 +476,15 @@ public class ColumnEncoderBagOfWords extends ColumnEncoder {
 		@Override
 		public Object call() {
 			long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0;
-			Map<String, Integer> tokenDictionary = _encoder.tokenDictionary;
+			Map<Object, Long> tokenDictionary = _encoder._tokenDictionary;
 			for(Object tokenSet : _partialMaps.values()){
 				( (HashSet<?>) tokenSet).forEach(token -> {
 					if(!tokenDictionary.containsKey(token))
-						tokenDictionary.put((String) token, tokenDictionary.size());
+						tokenDictionary.put(token, (long) tokenDictionary.size() + 1);
 				});
 			}
-			for (long nnzPartial : _encoder.nnzPartials)
-				_encoder.nnz += nnzPartial;
+			for (long nnzPartial : _encoder._nnzPartials)
+				_encoder._nnz += nnzPartial;
 			if(DMLScript.STATISTICS){
 				TransformStatistics.incBagOfWordsBuildTime(System.nanoTime() - t0);
 			}
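
A toy illustration of the 1-based id assignment in the merge above: each previously unseen token receives id = size + 1, matching the recode-map convention where ids start at 1.

    import java.util.HashMap;
    import java.util.Map;

    public class MergeIdSketch {
        public static void main(String[] args) {
            Map<Object, Long> dict = new HashMap<>();
            for(Object token : new String[]{"a", "b", "a", "c"})
                if(!dict.containsKey(token))
                    dict.put(token, (long) dict.size() + 1); // a->1, b->2, c->3
            System.out.println(dict);
        }
    }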
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java
index 9544372914..536b387a1d 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java
@@ -62,9 +62,9 @@ public class ColumnEncoderComposite extends ColumnEncoder {
 	public ColumnEncoderComposite(List<ColumnEncoder> columnEncoders, FrameBlock meta) {
 		super(-1);
-		if(!(columnEncoders.size() > 0 &&
+		if(!(!columnEncoders.isEmpty() &&
 			columnEncoders.stream().allMatch((encoder -> encoder._colID == columnEncoders.get(0)._colID))))
-			throw new DMLRuntimeException("Tried to create Composite Encoder with no encoders or mismatching columIDs");
+			throw new DMLRuntimeException("Tried to create Composite Encoder with no encoders or mismatching columnIDs");
 		_colID = columnEncoders.get(0)._colID;
 		_meta = meta;
 		_columnEncoders = columnEncoders;
@@ -73,6 +73,11 @@ public class ColumnEncoderComposite extends ColumnEncoder {
 	public ColumnEncoderComposite(List<ColumnEncoder> columnEncoders) {
 		this(columnEncoders, null);
 	}
+	public ColumnEncoderComposite(List<ColumnEncoder> columnEncoders, int colID) {
+		super(colID);
+		_columnEncoders = columnEncoders;
+		_meta = null;
+	}
 	public ColumnEncoderComposite(ColumnEncoder columnEncoder) {
 		super(columnEncoder._colID);
@@ -166,7 +171,8 @@ public class ColumnEncoderComposite extends ColumnEncoder {
 			if(t == null)
 				continue;
 			// Linear execution between encoders so they can't be built in parallel
-			if(tasks.size() != 0) {
+			if(!tasks.isEmpty()) {
+				// TODO: is this still needed? currently there is no CompositeEncoder with 2 encoders with a build phase
 				// avoid unnecessary map initialization
 				depMap = (depMap == null) ? new HashMap<>() : depMap;
 				// This workaround is needed since sublist is only valid for effectively final lists,
@@ -207,6 +213,8 @@ public class ColumnEncoderComposite extends ColumnEncoder {
 	public MatrixBlock apply(CacheBlock<?> in, MatrixBlock out, int outputCol, int rowStart, int blk) {
 		try {
 			for(int i = 0; i < _columnEncoders.size(); i++) {
+				// set sparseRowPointerOffset in the encoder
+				_columnEncoders.get(i).sparseRowPointerOffset = this.sparseRowPointerOffset;
 				if(i == 0) {
 					// the 1st encoder writes data into the MatrixBlock column; all others use this column for further encoding
 					_columnEncoders.get(i).apply(in, out, outputCol, rowStart, blk);
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java
index 05ad8e4694..1c2478d711 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java
@@ -265,6 +265,8 @@ public interface EncoderFactory {
 			return EncoderType.Recode.ordinal();
 		else if(columnEncoder instanceof ColumnEncoderWordEmbedding)
 			return EncoderType.WordEmbedding.ordinal();
+		else if(columnEncoder instanceof ColumnEncoderBagOfWords)
+			return EncoderType.BagOfWords.ordinal();
 		throw new DMLRuntimeException("Unsupported encoder type: " + columnEncoder.getClass().getCanonicalName());
 	}
@@ -283,6 +285,8 @@ public interface EncoderFactory {
 				return new ColumnEncoderRecode();
 			case WordEmbedding:
 				return new ColumnEncoderWordEmbedding();
+			case BagOfWords:
+				return new ColumnEncoderBagOfWords();
 			default:
 				throw new DMLRuntimeException("Unsupported encoder type: " + etype);
 		}
diff --git a/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java b/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
index 0417e67ba1..79c05ca8e7 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
@@ -89,6 +89,20 @@ public class MultiColumnEncoder implements Encoder {
 		_columnEncoders = columnEncoders;
 	}

+	public MultiColumnEncoder(MultiColumnEncoder menc) {
+		// This constructor creates a shallow copy for all encoders except for bag-of-words encoders
+		List<ColumnEncoderComposite> colEncs = menc._columnEncoders;
+		_columnEncoders = new ArrayList<>();
+		for (ColumnEncoderComposite cColEnc : colEncs) {
+			List<ColumnEncoder> newEncs = new ArrayList<>();
+			ColumnEncoderComposite cColEncCopy = new ColumnEncoderComposite(newEncs, cColEnc._colID);
+			_columnEncoders.add(cColEncCopy);
+			for (ColumnEncoder enc : cColEnc.getEncoders()) {
+				newEncs.add(enc instanceof ColumnEncoderBagOfWords ?
+					new ColumnEncoderBagOfWords((ColumnEncoderBagOfWords) enc) : enc);
+			}
+		}
+	}
+
 	public MultiColumnEncoder() {
 		_columnEncoders = new ArrayList<>();
 	}
@@ -327,16 +341,14 @@ public class MultiColumnEncoder implements Encoder {
 	public MatrixBlock apply(CacheBlock<?> in, int k) {
 		// domain sizes are not updated if called from transformapply
-		boolean hasUDF = _columnEncoders.stream().anyMatch(e -> e.hasEncoder(ColumnEncoderUDF.class));
-		boolean hasWE = _columnEncoders.stream().anyMatch(e -> e.hasEncoder(ColumnEncoderWordEmbedding.class));
-		for(ColumnEncoderComposite columnEncoder : _columnEncoders)
-			columnEncoder.updateAllDCEncoders();
+		EncoderMeta encm = getEncMeta(_columnEncoders, true, k, in);
+		updateAllDCEncoders();
 		int numCols = getNumOutCols();
-		long estNNz = (long) in.getNumRows() * (hasUDF ? numCols : hasWE ? getEstNNzRow() : in.getNumColumns());
-		// FIXME: estimate nnz for multiple encoders including dummycode and embedding
-		boolean sparse = MatrixBlock.evalSparseFormatInMemory(in.getNumRows(), numCols, estNNz) && !hasUDF;
+		long estNNz = (long) in.getNumRows() * (encm.hasWE || encm.hasUDF ? numCols : (in.getNumColumns() - encm.numBOWEnc) + encm.nnzBOW);
+		// FIXME: estimate nnz for multiple encoders including dummycode
+		boolean sparse = MatrixBlock.evalSparseFormatInMemory(in.getNumRows(), numCols, estNNz) && !encm.hasUDF;
 		MatrixBlock out = new MatrixBlock(in.getNumRows(), numCols, sparse, estNNz);
-		return apply(in, out, 0, k);
+		return apply(in, out, 0, k, encm, estNNz);
 	}
 	public void updateAllDCEncoders(){
@@ -345,10 +357,11 @@ public class MultiColumnEncoder implements Encoder {
 	}

 	public MatrixBlock apply(CacheBlock<?> in, MatrixBlock out, int outputCol) {
-		return apply(in, out, outputCol, 1);
+		// unused method; currently only exists because of the interface
+		throw new DMLRuntimeException("MultiColumnEncoder apply without Encoder Characteristics should not be called directly");
 	}
-	public MatrixBlock apply(CacheBlock<?> in, MatrixBlock out, int outputCol, int k) {
+	public MatrixBlock apply(CacheBlock<?> in, MatrixBlock out, int outputCol, int k, EncoderMeta encm, long nnz) {
 		// There should be an encoder for every column
 		if(hasLegacyEncoder() && !(in instanceof FrameBlock))
 			throw new DMLRuntimeException("LegacyEncoders do not support non FrameBlock Inputs");
@@ -361,31 +374,20 @@ public class MultiColumnEncoder implements Encoder {
 		if(in.getNumRows() == 0)
 			throw new DMLRuntimeException("Invalid input with wrong number or rows");
-		boolean hasDC = false;
-		boolean hasWE = false;
-		//TODO adapt transform apply for BOW
-		int distinctWE = 0;
-		int sizeWE = 0;
-		for(ColumnEncoderComposite columnEncoder : _columnEncoders) {
-			hasDC |= columnEncoder.hasEncoder(ColumnEncoderDummycode.class);
-			for (ColumnEncoder enc : columnEncoder.getEncoders())
-				if(enc instanceof ColumnEncoderWordEmbedding){
-					hasWE = true;
-					distinctWE = ((ColumnEncoderWordEmbedding) enc).getNrDistinctEmbeddings();
-					sizeWE = enc.getDomainSize();
-				}
-		}
-		outputMatrixPreProcessing(out, in, hasDC, hasWE, distinctWE, sizeWE, 0, null, -1);
+		ArrayList<int[]> nnzOffsets = outputMatrixPreProcessing(out, in, encm, nnz, k);
 		if(k > 1) {
 			if(!_partitionDone) //happens if this method is directly called
 				deriveNumRowPartitions(in, k);
-			applyMT(in, out, outputCol, k);
+			applyMT(in, out, outputCol, k, nnzOffsets);
 		}
 		else {
-			int offset = outputCol;
+			int offset = outputCol, i = 0;
+			int[] nnzOffset = null;
 			for(ColumnEncoderComposite columnEncoder : _columnEncoders) {
+				columnEncoder.sparseRowPointerOffset = nnzOffset;
 				columnEncoder.apply(in, out, columnEncoder._colID - 1 + offset);
-				offset = getOffset(offset, columnEncoder);
+				offset = getOutputColOffset(offset, columnEncoder);
+				nnzOffset = nnzOffsets != null ? nnzOffsets.get(i++) : null;
 			}
 		}
 		// Recomputing NNZ since we access the block directly
@@ -399,36 +401,44 @@ public class MultiColumnEncoder implements Encoder {
return out;
}
-	private List<DependencyTask<?>> getApplyTasks(CacheBlock<?> in, MatrixBlock out, int outputCol) {
+	private List<DependencyTask<?>> getApplyTasks(CacheBlock<?> in, MatrixBlock out, int outputCol, ArrayList<int[]> nnzOffsets) {
 		List<DependencyTask<?>> tasks = new ArrayList<>();
 		int offset = outputCol;
+		int i = 0;
+		int[] currentNnzOffsets = null;
 		for(ColumnEncoderComposite e : _columnEncoders) {
-			tasks.addAll(e.getApplyTasks(in, out, e._colID - 1 + offset, null));
-			offset = getOffset(offset, e);
+			tasks.addAll(e.getApplyTasks(in, out, e._colID - 1 + offset, currentNnzOffsets));
+			currentNnzOffsets = nnzOffsets != null ? nnzOffsets.get(i++) : null;
+			offset = getOutputColOffset(offset, e);
 		}
 		return tasks;
 	}

-	private int getOffset(int offset, ColumnEncoderComposite e) {
+	private int getOutputColOffset(int offset, ColumnEncoderComposite e) {
 		if(e.hasEncoder(ColumnEncoderDummycode.class))
 			offset += e.getEncoder(ColumnEncoderDummycode.class)._domainSize - 1;
 		if(e.hasEncoder(ColumnEncoderWordEmbedding.class))
 			offset += e.getEncoder(ColumnEncoderWordEmbedding.class).getDomainSize() - 1;
+		if(e.hasEncoder(ColumnEncoderBagOfWords.class))
+			offset += e.getEncoder(ColumnEncoderBagOfWords.class).getDomainSize() - 1;
 		return offset;
 	}
-	private void applyMT(CacheBlock<?> in, MatrixBlock out, int outputCol, int k) {
+	private void applyMT(CacheBlock<?> in, MatrixBlock out, int outputCol, int k, ArrayList<int[]> nnzOffsets) {
 		DependencyThreadPool pool = new DependencyThreadPool(k);
 		try {
 			if(APPLY_ENCODER_SEPARATE_STAGES) {
 				int offset = outputCol;
+				int i = 0;
+				int[] currentNnzOffsets = null;
 				for (ColumnEncoderComposite e : _columnEncoders) {
-					// for now bag of words is only used in encode
-					pool.submitAllAndWait(e.getApplyTasks(in, out, e._colID - 1 + offset, null));
-					offset = getOffset(offset, e);
+					pool.submitAllAndWait(e.getApplyTasks(in, out, e._colID - 1 + offset, currentNnzOffsets));
+					offset = getOutputColOffset(offset, e);
+					currentNnzOffsets = nnzOffsets != null ? nnzOffsets.get(i) : null;
+					i++;
 				}
 			} else
-				pool.submitAllAndWait(getApplyTasks(in, out, outputCol));
+				pool.submitAllAndWait(getApplyTasks(in, out, outputCol, nnzOffsets));
 		}
 		catch(ExecutionException | InterruptedException e) {
 			throw new DMLRuntimeException(e);
@@ -635,7 +645,7 @@ public class MultiColumnEncoder implements Encoder {
}
}
-	private int[] getSampleIndices(CacheBlock<?> in, int sampleSize, int seed, int k){
+	private static int[] getSampleIndices(CacheBlock<?> in, int sampleSize, int seed, int k){
 		return ComEstSample.getSortedSample(in.getNumRows(), sampleSize, seed, k);
 	}
@@ -659,11 +669,11 @@ public class MultiColumnEncoder implements Encoder {
return totMemOverhead;
}
-	private static void outputMatrixPreProcessing(MatrixBlock output, CacheBlock<?> input, boolean hasDC, boolean hasWE,
-			int distinctWE, int sizeWE, int numBOW, int[] nnzPerRowBOW, int nnz) {
+	private static ArrayList<int[]> outputMatrixPreProcessing(MatrixBlock output, CacheBlock<?> input, EncoderMeta encm, long nnz, int k) {
 		long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0;
 		if(nnz < 0)
-			nnz = output.getNumRows() * input.getNumColumns();
+			nnz = (long) output.getNumRows() * input.getNumColumns();
+		ArrayList<int[]> bowNnzRowOffsets = null;
 		if(output.isInSparseFormat()) {
 			if (MatrixBlock.DEFAULT_SPARSEBLOCK != SparseBlock.Type.CSR
 				&& MatrixBlock.DEFAULT_SPARSEBLOCK != SparseBlock.Type.MCSR)
@@ -673,7 +683,7 @@ public class MultiColumnEncoder implements Encoder {
if (mcsr) {
output.allocateBlock();
SparseBlock block = output.getSparseBlock();
-				if (hasDC && OptimizerUtils.getTransformNumThreads()>1) {
+				if (encm.hasDC && OptimizerUtils.getTransformNumThreads()>1) {
 					// DC forces a single threaded allocation after the build phase and
 					// before the apply starts. Below code parallelizes sparse allocation.
IntStream.range(0, output.getNumRows())
@@ -695,30 +705,52 @@ public class MultiColumnEncoder implements Encoder {
 				}
 			}
 			else { //csr
-				SparseBlockCSR csrblock = new SparseBlockCSR(output.getNumRows(), nnz, nnz);
 				// Manually fill the row pointers based on nnzs/row (= #cols in the input)
 				// Not using the set() methods to 1) avoid binary search and shifting,
 				// 2) reduce thread contentions on the arrays
-				int[] rptr = csrblock.rowPointers();
-				if(nnzPerRowBOW != null)
-					for (int i=0; i<rptr.length-1; i++) { //TODO: parallelize
-						int nnzPerRow = input.getNumColumns() - numBOW + nnzPerRowBOW[i];
-						rptr[i+1] = rptr[i] + nnzPerRow;
-					}
-				else
-					for (int i=0; i<rptr.length-1; i++) { //TODO: parallelize
-						rptr[i+1] = rptr[i] + input.getNumColumns();
+				int nnzInt = (int) nnz;
+				int[] rptr = new int[output.getNumRows()+1];
+				// easy case: no bow encoders -> nnz per row = #encoders = #input cols
+				if(encm.numBOWEnc <= 0 )
+					for (int i = 0; i < rptr.length - 1; i++)
+						rptr[i + 1] = rptr[i] + input.getNumColumns();
+				else {
+					if( encm.nnzPerRowBOW != null) {
+						// the nnz per row have already been computed and aggregated for all bow encoders
+						int static_offset = input.getNumColumns() - encm.numBOWEnc;
+						// minus #bow since their nnz are already counted
+						for (int i = 0; i < rptr.length - 1; i++) {
+							int nnzPerRow = static_offset + encm.nnzPerRowBOW[i];
+							rptr[i + 1] = rptr[i] + nnzPerRow;
+						}
+					} else {
+						// case for transformapply, where the #nnz for bow is not known yet since there is no build phase:
+						// we have to compute the nnz now; for now we parallelize over the #bowEncoders and #rows,
+						// for the aggregation we parallelize just over the number of rows
+						bowNnzRowOffsets = getNnzPerRowFromBOWEncoders(input, encm, k);
+						// the last array contains the complete aggregation
+						int static_offset = input.getNumColumns() - 1;
+						// we just subtract 1 since we already subtracted 1 for every bow encoder except the first
+						int[] aggOffsets = bowNnzRowOffsets.get(bowNnzRowOffsets.size() - 1);
+						for (int i = 0; i < rptr.length-1; i++) {
+							rptr[i+1] = rptr[i] + static_offset + aggOffsets[i];
+						}
+						nnzInt = rptr[rptr.length-1];
 					}
+				}
+				SparseBlockCSR csrblock = new SparseBlockCSR(rptr, new int[nnzInt], new double[nnzInt], nnzInt);
 				output.setSparseBlock(csrblock);
 			}
}
else {
// Allocate dense block and set nnz to total #entries
-			output.allocateDenseBlock(true, hasWE);
-			if( hasWE){
+			output.allocateDenseBlock(true, encm.hasWE);
+			if(encm.hasWE){
 				DenseBlockFP64DEDUP dedup = ((DenseBlockFP64DEDUP) output.getDenseBlock());
-				dedup.setDistinct(distinctWE);
-				dedup.setEmbeddingSize(sizeWE);
+				dedup.setDistinct(encm.distinctWE);
+				dedup.setEmbeddingSize(encm.sizeWE);
}
//output.setAllNonZeros();
}
@@ -727,6 +759,77 @@ public class MultiColumnEncoder implements Encoder {
 			LOG.debug("Elapsed time for allocation: "+ ((double) System.nanoTime() - t0) / 1000000 + " ms");
 		TransformStatistics.incOutMatrixPreProcessingTime(System.nanoTime()-t0);
 		}
+		return bowNnzRowOffsets;
+	}
+
+	private static ArrayList<int[]> getNnzPerRowFromBOWEncoders(CacheBlock<?> input, EncoderMeta encm, int k) {
+		ArrayList<int[]> bowNnzRowOffsets;
+		int min_block_size = 1000;
+		int num_blocks = input.getNumRows() / min_block_size;
+		// 1 <= num_blks1 <= k / #enc
+		int num_blks1 = Math.min( (k + encm.numBOWEnc - 1) / encm.numBOWEnc, Math.max(num_blocks, 1));
+		int blk_len1 = (input.getNumRows() + num_blks1 - 1) / num_blks1;
+		// 1 <= num_blks2 <= k
+		int num_blks2 = Math.min(k, Math.max(num_blocks, 1));
+		int blk_len2 = (input.getNumRows() + num_blks2 - 1) / num_blks1;
+
+		ExecutorService pool = CommonThreadPool.get(k);
+		ArrayList<int[]> bowNnzRowOffsetsFinal = new ArrayList<>();
+		try {
+			encm.bowEncoders.forEach(e -> e._nnzPerRow = new int[input.getNumRows()]);
+			ArrayList<Future<?>> list = new ArrayList<>();
+			for (int i = 0; i < num_blks1; i++) {
+				int start = i * blk_len1;
+				int end = Math.min((i + 1) * blk_len1, input.getNumRows());
+				list.add(pool.submit(() -> encm.bowEncoders.stream().parallel()
+					.forEach(e -> e.computeNnzPerRow(input, start, end))));
+			}
+			for(Future<?> f : list)
+				f.get();
+			list.clear();
+			int[] previous = null;
+			for(ColumnEncoderComposite enc : encm.encs){
+				if(enc.hasEncoder(ColumnEncoderBagOfWords.class)){
+					previous = previous == null ?
+						enc.getEncoder(ColumnEncoderBagOfWords.class)._nnzPerRow :
+						new int[input.getNumRows()];
+				}
+				bowNnzRowOffsetsFinal.add(previous);
+			}
+			for (int i = 0; i < num_blks2; i++) {
+				int start = i * blk_len1;
+				list.add(pool.submit(() -> aggregateNnzPerRow(start, blk_len2,
+					input.getNumRows(), encm.encs, bowNnzRowOffsetsFinal)));
+			}
+			for(Future<?> f : list)
+				f.get();
+			bowNnzRowOffsets = bowNnzRowOffsetsFinal;
+		}
+		catch(Exception ex) {
+			throw new DMLRuntimeException(ex);
+		}
+		finally {
+			pool.shutdown();
+		}
+		return bowNnzRowOffsets;
+	}
+
+	private static void aggregateNnzPerRow(int start, int blk_len, int numRows, List<ColumnEncoderComposite> encs, ArrayList<int[]> bowNnzRowOffsets) {
+		int end = Math.min(start + blk_len, numRows);
+		int pos = 0;
+		int[] aggRowOffsets = null;
+		for(ColumnEncoderComposite enc : encs){
+			int[] currentOffsets = bowNnzRowOffsets.get(pos);
+			if (enc.hasEncoder(ColumnEncoderBagOfWords.class)) {
+				ColumnEncoderBagOfWords bow = enc.getEncoder(ColumnEncoderBagOfWords.class);
+				if(aggRowOffsets == null)
+					aggRowOffsets = currentOffsets;
+				else
+					for (int i = start; i < end; i++)
+						currentOffsets[i] = aggRowOffsets[i] + bow._nnzPerRow[i] - 1;
+			}
+			pos++;
+		}
 	}
private void outputMatrixPostProcessing(MatrixBlock output, int k){
@@ -822,7 +925,7 @@ public class MultiColumnEncoder implements Encoder {
 				tasks.add(new ColumnMetaDataTask<>(columnEncoder, meta));
 			List<Future<Object>> taskret = pool.invokeAll(tasks);
for (Future<Object> task : taskret)
- task.get();
+ task.get();
}
catch(Exception ex) {
throw new DMLRuntimeException(ex);
@@ -1021,13 +1124,6 @@ public class MultiColumnEncoder implements Encoder {
return getEncoderTypes(-1);
}
- public int getEstNNzRow(){
- int nnz = 0;
- for(int i = 0; i < _columnEncoders.size(); i++)
- nnz += _columnEncoders.get(i).getDomainSize();
- return nnz;
- }
-
public int getNumOutCols() {
int sum = 0;
for(int i = 0; i < _columnEncoders.size(); i++)
@@ -1218,6 +1314,92 @@ public class MultiColumnEncoder implements Encoder {
return sb.toString();
}
+	private static class EncoderMeta {
+		// contains information about the encoders and their relevant data characteristics
+		public final boolean hasUDF;
+		public final boolean hasDC;
+		public final boolean hasWE;
+		public final int distinctWE;
+		public final int sizeWE;
+		public final long nnzBOW;
+		public final int numBOWEnc;
+		public final int[] nnzPerRowBOW;
+		public final ArrayList<ColumnEncoderBagOfWords> bowEncoders;
+		public final List<ColumnEncoderComposite> encs;
+
+		public EncoderMeta(boolean hasUDF, boolean hasDC, boolean hasWE, int distinctWE, int sizeWE, long nnzBOW,
+				int numBOWEncoder, int[] nnzPerRowBOW, ArrayList<ColumnEncoderBagOfWords> bows,
+				List<ColumnEncoderComposite> encoders) {
+			this.hasUDF = hasUDF;
+			this.hasDC = hasDC;
+			this.hasWE = hasWE;
+			this.distinctWE = distinctWE;
+			this.sizeWE = sizeWE;
+			this.nnzBOW = nnzBOW;
+			this.numBOWEnc = numBOWEncoder;
+			this.nnzPerRowBOW = nnzPerRowBOW;
+			this.bowEncoders = bows;
+			this.encs = encoders;
+		}
+	}
+
+	private static EncoderMeta getEncMeta(List<ColumnEncoderComposite> encoders, boolean noBuild, int k, CacheBlock<?> in) {
+		boolean hasUDF = false, hasDC = false, hasWE = false;
+		int distinctWE = 0;
+		int sizeWE = 0;
+		long nnzBOW = 0;
+		int numBOWEncoder = 0;
+		int[] nnzPerRowBOW = null;
+		ArrayList<ColumnEncoderBagOfWords> bows = new ArrayList<>();
+		for (ColumnEncoderComposite enc : encoders){
+			if(enc.hasEncoder(ColumnEncoderUDF.class))
+				hasUDF = true;
+			else if (enc.hasEncoder(ColumnEncoderDummycode.class))
+				hasDC = true;
+			else if(enc.hasEncoder(ColumnEncoderBagOfWords.class)){
+				ColumnEncoderBagOfWords bowEnc = enc.getEncoder(ColumnEncoderBagOfWords.class);
+				numBOWEncoder++;
+				nnzBOW += bowEnc._nnz;
+				if(noBuild){
+					// estimate nnz by sampling
+					bows.add(bowEnc);
+				} else if(nnzPerRowBOW != null)
+					for (int i = 0; i < bowEnc._nnzPerRow.length; i++) {
+						nnzPerRowBOW[i] += bowEnc._nnzPerRow[i];
+					}
+				else {
+					nnzPerRowBOW = bowEnc._nnzPerRow.clone();
+				}
+			}
+			else if(enc.hasEncoder(ColumnEncoderWordEmbedding.class)){
+				hasWE = true;
+				distinctWE = enc.getEncoder(ColumnEncoderWordEmbedding.class).getNrDistinctEmbeddings();
+				sizeWE = enc.getDomainSize();
+			}
+		}
+		if(!bows.isEmpty()){
+			int[] sampleInds = getSampleIndices(in, in.getNumRows() > 1000 ? (int) (0.1 * in.getNumRows()) : in.getNumRows(), (int) System.nanoTime(), 1);
+			// Concurrent (column-wise) bag-of-words nnz estimation per row; we estimate the number of nnz because the
+			// exact number is only needed for sparse outputs, not for dense; if sparse, we recount the nnz for all rows later.
+			// Note: the sampling might be problematic since it is used for the sparsity estimation, which impacts
+			// performance if we go for the non-ideal output format.
+			ExecutorService pool = CommonThreadPool.get(k);
+			try {
+				Double result = pool.submit(() -> bows.stream().parallel()
+					.mapToDouble(e -> e.computeNnzEstimate(in, sampleInds))
+					.sum()).get();
+				nnzBOW = (long) Math.ceil(result);
+			}
+			catch(Exception ex) {
+				throw new DMLRuntimeException(ex);
+			}
+			finally{
+				pool.shutdown();
+			}
+		}
+		return new EncoderMeta(hasUDF, hasDC, hasWE, distinctWE, sizeWE, nnzBOW, numBOWEncoder, nnzPerRowBOW, bows, encoders);
+	}
+
 	/*
 	 * Currently not in use; will be integrated in the future
 	 */
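
The sampling estimate used in getEncMeta above, in isolation (a simplified sketch; the real code runs one such estimate per bag-of-words encoder in parallel and sums them): count dictionary hits over the sampled rows and average them to an expected nnz per row.

    public class NnzEstimateSketch {
        // expected nnz per row = dictionary hits in the sample / sample size
        public static double estimate(int[] hitsPerSampledRow) {
            int nnz = 0;
            for(int h : hitsPerSampledRow)
                nnz += h;
            return (double) nnz / hitsPerSampledRow.length;
        }
        public static void main(String[] args) {
            System.out.println(estimate(new int[]{3, 1, 2})); // 2.0
        }
    }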
@@ -1271,41 +1453,12 @@ public class MultiColumnEncoder implements Encoder {
 		@Override
 		public Object call() {
-			boolean hasUDF = false, hasDC = false, hasWE = false;
-			int distinctWE = 0;
-			int sizeWE = 0;
-			long nnzBOW = 0;
-			int numBOWEncoder = 0;
-			int[] nnzPerRowBOW = null;
-			for (ColumnEncoderComposite enc : _encoder.getEncoders()){
-				if(enc.hasEncoder(ColumnEncoderUDF.class))
-					hasUDF = true;
-				else if (enc.hasEncoder(ColumnEncoderDummycode.class))
-					hasDC = true;
-				else if(enc.hasEncoder(ColumnEncoderBagOfWords.class)){
-					ColumnEncoderBagOfWords bowEnc = enc.getEncoder(ColumnEncoderBagOfWords.class);
-					numBOWEncoder++;
-					nnzBOW += bowEnc.nnz;
-					if(nnzPerRowBOW != null)
-						for (int i = 0; i < bowEnc.nnzPerRow.length; i++) {
-							nnzPerRowBOW[i] += bowEnc.nnzPerRow[i];
-						}
-					else {
-						nnzPerRowBOW = bowEnc.nnzPerRow.clone();
-					}
-				}
-				else if(enc.hasEncoder(ColumnEncoderWordEmbedding.class)){
-					hasWE = true;
-					distinctWE = enc.getEncoder(ColumnEncoderWordEmbedding.class).getNrDistinctEmbeddings();
-					sizeWE = enc.getDomainSize();
-				}
-			}
-
+			EncoderMeta encm = getEncMeta(_encoder.getEncoders(), false, -1, _input);
 			int numCols = _encoder.getNumOutCols();
-			long estNNz = (long) _input.getNumRows() * (hasUDF ? numCols : _input.getNumColumns() - numBOWEncoder) + nnzBOW;
-			boolean sparse = MatrixBlock.evalSparseFormatInMemory(_input.getNumRows(), numCols, estNNz) && !hasUDF;
+			long estNNz = (long) _input.getNumRows() * (encm.hasUDF ? numCols : _input.getNumColumns() - encm.numBOWEnc) + encm.nnzBOW;
+			boolean sparse = MatrixBlock.evalSparseFormatInMemory(_input.getNumRows(), numCols, estNNz) && !encm.hasUDF;
 			_output.reset(_input.getNumRows(), numCols, sparse, estNNz);
-			outputMatrixPreProcessing(_output, _input, hasDC, hasWE, distinctWE, sizeWE, numBOWEncoder, nnzPerRowBOW, (int) estNNz);
+			outputMatrixPreProcessing(_output, _input, encm, estNNz, 1);
 			return null;
 		}
@@ -1381,13 +1534,14 @@ public class MultiColumnEncoder implements Encoder {
 		@Override
 		public Object call() throws Exception {
+			// updates the outputCol offset and sets the nnz offsets, which are created by bow encoders, in each encoder
 			int currentCol = -1;
 			int currentOffset = 0;
 			int[] sparseRowPointerOffsets = null;
 			for(DependencyTask<?> dtask : _applyTasksWrappers) {
 				((ApplyTasksWrapperTask) dtask).setOffset(currentOffset);
 				if(sparseRowPointerOffsets != null)
-					((ApplyTasksWrapperTask) dtask).setSparseRowPointerOffsets(sparseRowPointerOffsets.clone());
+					((ApplyTasksWrapperTask) dtask).setSparseRowPointerOffsets(sparseRowPointerOffsets);
 				int nonOffsetCol = ((ApplyTasksWrapperTask) dtask)._encoder._colID - 1;
 				if(nonOffsetCol > currentCol) {
 					currentCol = nonOffsetCol;
@@ -1398,11 +1552,14 @@ public class MultiColumnEncoder implements Encoder {
 						ColumnEncoderBagOfWords bow = enc.getEncoder(ColumnEncoderBagOfWords.class);
 						currentOffset += bow.getDomainSize() - 1;
 						if(sparseRowPointerOffsets == null)
-							sparseRowPointerOffsets = bow.nnzPerRow.clone();
-						else
+							sparseRowPointerOffsets = bow._nnzPerRow;
+						else{
+							sparseRowPointerOffsets = sparseRowPointerOffsets.clone();
+							// TODO: experiment whether it makes sense to parallelize here (for frames with many rows)
 							for (int r = 0; r < sparseRowPointerOffsets.length; r++) {
-								sparseRowPointerOffsets[r] += bow.nnzPerRow[r] - 1;
+								sparseRowPointerOffsets[r] += bow._nnzPerRow[r] - 1;
 							}
+						}
 					}
 				}
 			}
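
A worked example of the offset chaining above (made-up counts): with two bag-of-words columns whose per-row nnz are {2, 1} and {3, 4}, the task after the first encoder sees offsets {2, 1}; after the second, the aggregate becomes {2+3-1, 1+4-1} = {4, 4}, the -1 accounting for the single output column each encoder replaces.

    public class OffsetChainSketch {
        public static void main(String[] args) {
            int[] offsets = {2, 1};      // nnz per row of the first bow column
            int[] bowNnzPerRow = {3, 4}; // nnz per row of the second bow column
            int[] next = offsets.clone();
            for(int r = 0; r < next.length; r++)
                next[r] += bowNnzPerRow[r] - 1;
            System.out.println(java.util.Arrays.toString(next)); // [4, 4]
        }
    }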
diff --git a/src/test/java/org/apache/sysds/test/functions/transform/ColumnEncoderMixedFunctionalityTests.java b/src/test/java/org/apache/sysds/test/functions/transform/ColumnEncoderMixedFunctionalityTests.java
new file mode 100644
index 0000000000..3537998d99
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/transform/ColumnEncoderMixedFunctionalityTests.java
@@ -0,0 +1,116 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.transform;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.frame.data.FrameBlock;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.transform.encode.ColumnEncoder;
+import org.apache.sysds.runtime.transform.encode.ColumnEncoderBagOfWords;
+import org.apache.sysds.runtime.transform.encode.ColumnEncoderComposite;
+import org.apache.sysds.runtime.transform.encode.EncoderFactory;
+import org.apache.sysds.runtime.transform.encode.EncoderOmit;
+import org.apache.sysds.runtime.transform.encode.MultiColumnEncoder;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import static org.junit.Assert.assertThrows;
+import static org.junit.Assert.assertTrue;
+
+public class ColumnEncoderMixedFunctionalityTests extends AutomatedTestBase
+{
+ @Override
+ public void setUp() {
+ TestUtils.clearAssertionInformation();
+ }
+
+	@Test
+	public void testCompositeConstructor1() {
+		ColumnEncoderComposite cEnc1 = new ColumnEncoderComposite(null, 1);
+		ColumnEncoderComposite cEnc2 = new ColumnEncoderComposite(cEnc1);
+		assert cEnc1.getColID() == cEnc2.getColID();
+	}
+
+	@Test
+	public void testCompositeConstructor2() {
+		List<ColumnEncoder> encoderList = new ArrayList<>();
+		encoderList.add( new ColumnEncoderComposite(null, 1));
+		encoderList.add( new ColumnEncoderComposite(null, 2));
+		DMLRuntimeException e = assertThrows(DMLRuntimeException.class, () -> new ColumnEncoderComposite(encoderList, null));
+		assertTrue(e.getMessage().contains("Tried to create Composite Encoder with no encoders or mismatching columnIDs"));
+	}
+
+	@Test
+	public void testEncoderFactoryGetUnsupportedEncoderType(){
+		DMLRuntimeException e = assertThrows(DMLRuntimeException.class, () -> EncoderFactory.getEncoderType(new ColumnEncoderComposite()));
+		assertTrue(e.getMessage().contains("Unsupported encoder type: org.apache.sysds.runtime.transform.encode.ColumnEncoderComposite"));
+	}
+
+	@Test
+	public void testEncoderFactoryCreateUnsupportedInstanceType(){
+		// type(7) = composite, which we don't use for encoding the type
+		DMLRuntimeException e = assertThrows(DMLRuntimeException.class, () -> EncoderFactory.createInstance(7));
+		assertTrue(e.getMessage().contains("Unsupported encoder type: Composite"));
+	}
+
+	@Test
+	public void testMultiColumnEncoderApplyWithWrongInputCharacteristics1(){
+		// apply call without metadata about encoders
+		MultiColumnEncoder mEnc = new MultiColumnEncoder();
+		DMLRuntimeException e = assertThrows(DMLRuntimeException.class, () -> mEnc.apply(null, null, 0));
+		assertTrue(e.getMessage().contains("MultiColumnEncoder apply without Encoder Characteristics should not be called directly"));
+	}
+
+	@Test
+	public void testMultiColumnEncoderApplyWithWrongInputCharacteristics2(){
+		// apply with LegacyEncoders + non FrameBlock inputs
+		MultiColumnEncoder mEnc = new MultiColumnEncoder();
+		mEnc.addReplaceLegacyEncoder(new EncoderOmit());
+		DMLRuntimeException e = assertThrows(DMLRuntimeException.class, () -> mEnc.apply(new MatrixBlock(), null, 0, 0, null, 0L));
+		assertTrue(e.getMessage().contains("LegacyEncoders do not support non FrameBlock Inputs"));
+	}
+
+	@Test
+	public void testMultiColumnEncoderApplyWithWrongInputCharacteristics3(){
+		// #CompositeEncoders != #cols
+		ArrayList<ColumnEncoder> encs = new ArrayList<>();
+		encs.add(new ColumnEncoderBagOfWords());
+		ArrayList<ColumnEncoderComposite> cEncs = new ArrayList<>();
+		cEncs.add(new ColumnEncoderComposite(encs));
+		MultiColumnEncoder mEnc = new MultiColumnEncoder(cEncs);
+		FrameBlock in = new FrameBlock(2, Types.ValueType.FP64);
+		DMLRuntimeException e = assertThrows(DMLRuntimeException.class, () -> mEnc.apply(in, null, 0, 0, null, 0L));
+		assertTrue(e.getMessage().contains("Not every column in has a CompositeEncoder. Please make sure every column has a encoder or slice the input accordingly"));
+	}
+
+	@Test
+	public void testMultiColumnEncoderApplyWithWrongInputCharacteristics4(){
+		// input has 0 rows
+		MultiColumnEncoder mEnc = new MultiColumnEncoder();
+		MatrixBlock in = new MatrixBlock();
+		DMLRuntimeException e = assertThrows(DMLRuntimeException.class, () -> mEnc.apply(in, null, 0, 0, null, 0L));
+		assertTrue(e.getMessage().contains("Invalid input with wrong number or rows"));
+	}
+}
diff --git a/src/test/java/org/apache/sysds/test/functions/transform/ColumnEncoderSerializationTest.java b/src/test/java/org/apache/sysds/test/functions/transform/ColumnEncoderSerializationTest.java
index aeec927e73..2bd1e64697 100644
--- a/src/test/java/org/apache/sysds/test/functions/transform/ColumnEncoderSerializationTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/transform/ColumnEncoderSerializationTest.java
@@ -25,11 +25,14 @@ import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
+import java.util.HashMap;
import java.util.List;
+import java.util.Map;
import org.apache.sysds.common.Types;
import org.apache.sysds.runtime.frame.data.FrameBlock;
import org.apache.sysds.runtime.transform.encode.ColumnEncoder;
+import org.apache.sysds.runtime.transform.encode.ColumnEncoderBagOfWords;
import org.apache.sysds.runtime.transform.encode.ColumnEncoderComposite;
import org.apache.sysds.runtime.transform.encode.EncoderFactory;
import org.apache.sysds.runtime.transform.encode.MultiColumnEncoder;
@@ -56,7 +59,8 @@ public class ColumnEncoderSerializationTest extends AutomatedTestBase
 		RECODE,
 		DUMMY,
 		IMPUTE,
-		OMIT
+		OMIT,
+		BOW
 	}
 	@Override
@@ -88,6 +92,12 @@ public class ColumnEncoderSerializationTest extends AutomatedTestBase
 	@Test
 	public void testComposite8() { runTransformSerTest(TransformType.OMIT, schemaStrings); }

+	@Test
+	public void testComposite9() { runTransformSerTest(TransformType.BOW, schemaStrings); }
+
+	@Test
+	public void testComposite10() { runTransformSerTest(TransformType.BOW, schemaMixed); }
+
@@ -117,11 +127,21 @@ public class ColumnEncoderSerializationTest extends AutomatedTestBase
 				"{ \"id\": 7, \"method\": \"global_mode\" }, { \"id\": 9, \"method\": \"global_mean\" } ]\n\n}";
 		else if (type == TransformType.OMIT)
 			spec = "{ \"ids\": true, \"omit\": [ 1,2,4,5,6,7,8,9 ], \"recode\": [ 2, 7 ] }";
+		else if (type == TransformType.BOW)
+			spec = "{ \"ids\": true, \"omit\": [ 1,4,5,6,8,9 ], \"bag_of_words\": [ 2, 7 ] }";

 		frame.setSchema(schema);
 		String[] cnames = frame.getColumnNames();
 		MultiColumnEncoder encoderIn = EncoderFactory.createEncoder(spec, cnames, frame.getNumColumns(), null);
+		if(type == TransformType.BOW){
+			List<ColumnEncoderBagOfWords> encs = encoderIn.getColumnEncoders(ColumnEncoderBagOfWords.class);
+			HashMap<Object, Long> dict = new HashMap<>();
+			dict.put("val1", 1L);
+			dict.put("val2", 2L);
+			dict.put("val3", 300L);
+			encs.forEach(e -> e.setTokenDictionary(dict));
+		}
 		MultiColumnEncoder encoderOut;
 		// serialization and deserialization
@@ -141,7 +161,16 @@ public class ColumnEncoderSerializationTest extends AutomatedTestBase
 		for(Class<? extends ColumnEncoder> classtype: typesIn){
 			Assert.assertArrayEquals(encoderIn.getFromAllIntArray(classtype, ColumnEncoder::getColID), encoderOut.getFromAllIntArray(classtype, ColumnEncoder::getColID));
 		}
-
+		if(type == TransformType.BOW){
+			List<ColumnEncoderBagOfWords> encsIn = encoderIn.getColumnEncoders(ColumnEncoderBagOfWords.class);
+			List<ColumnEncoderBagOfWords> encsOut = encoderOut.getColumnEncoders(ColumnEncoderBagOfWords.class);
+			for (int i = 0; i < encsIn.size(); i++) {
+				Map<Object, Long> eOutDict = encsOut.get(i).getTokenDictionary();
+				encsIn.get(i).getTokenDictionary().forEach((k,v) -> {
+					assert v.equals(eOutDict.get(k));
+				});
+			}
+		}
 	}
diff --git a/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeBagOfWords.java b/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeBagOfWords.java
index e3d1c07be2..1965d743f7 100644
--- a/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeBagOfWords.java
+++ b/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeBagOfWords.java
@@ -19,13 +19,13 @@
package org.apache.sysds.test.functions.transform;
-import org.apache.commons.lang3.NotImplementedException;
import org.apache.sysds.common.Types;
import org.apache.sysds.common.Types.ExecMode;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.frame.data.FrameBlock;
import org.apache.sysds.runtime.matrix.data.MatrixValue;
import org.apache.sysds.runtime.transform.encode.ColumnEncoderBagOfWords;
+import org.apache.sysds.runtime.transform.encode.MultiColumnEncoder;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
@@ -36,17 +36,18 @@ import java.io.IOException;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
+import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.nio.file.Files;
import static org.apache.sysds.runtime.transform.encode.ColumnEncoderBagOfWords.tokenize;
-import static org.junit.Assert.assertThrows;
public class TransformFrameEncodeBagOfWords extends AutomatedTestBase
{
private final static String TEST_NAME1 = "TransformFrameEncodeBagOfWords";
+ private final static String TEST_NAME2 = "TransformFrameEncodeApplyBagOfWords";
private final static String TEST_DIR = "functions/transform/";
private final static String TEST_CLASS_DIR = TEST_DIR + TransformFrameEncodeBagOfWords.class.getSimpleName() + "/";
// for benchmarking: Digital_Music_Text.csv
@@ -56,6 +57,7 @@ public class TransformFrameEncodeBagOfWords extends AutomatedTestBase
public void setUp() {
TestUtils.clearAssertionInformation();
addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1));
+ addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2));
}
// These tests result in dense output
@@ -63,6 +65,17 @@ public class TransformFrameEncodeBagOfWords extends AutomatedTestBase
public void testTransformBagOfWords() {
runTransformTest(TEST_NAME1, ExecMode.SINGLE_NODE, false, false);
}
+ @Test
+ public void testTransformApplyBagOfWords() {
+ runTransformTest(TEST_NAME2, ExecMode.SINGLE_NODE, false, false);
+ }
+
+ @Test
+ public void testTransformApplySeparateStagesBagOfWords() {
+ MultiColumnEncoder.APPLY_ENCODER_SEPARATE_STAGES = true;
+ runTransformTest(TEST_NAME2, ExecMode.SINGLE_NODE, false, false);
+ MultiColumnEncoder.APPLY_ENCODER_SEPARATE_STAGES = false;
+ }
@Test
public void testTransformBagOfWordsError() {
@@ -74,32 +87,62 @@ public class TransformFrameEncodeBagOfWords extends AutomatedTestBase
runTransformTest(TEST_NAME1, ExecMode.SINGLE_NODE, true, false);
}
+ @Test
+ public void testTransformApplyBagOfWordsPlusRecode() {
+ runTransformTest(TEST_NAME2, ExecMode.SINGLE_NODE, true, false);
+ }
+
@Test
public void testTransformBagOfWords2() {
runTransformTest(TEST_NAME1, ExecMode.SINGLE_NODE, false, true);
}
+ @Test
+ public void testTransformApplyBagOfWords2() {
+ runTransformTest(TEST_NAME2, ExecMode.SINGLE_NODE, false, true);
+ }
+
@Test
public void testTransformBagOfWordsPlusRecode2() {
runTransformTest(TEST_NAME1, ExecMode.SINGLE_NODE, true, true);
}
+ @Test
+ public void testTransformApplyBagOfWordsPlusRecode2() {
+ runTransformTest(TEST_NAME2, ExecMode.SINGLE_NODE, true, true);
+ }
+
// AmazonReviewDataset transformation results in a sparse output
@Test
public void testTransformBagOfWordsAmazonReviews() {
runTransformTest(TEST_NAME1, ExecMode.SINGLE_NODE, false, false, true);
}
+ @Test
+ public void testTransformApplyBagOfWordsAmazonReviews() {
+ runTransformTest(TEST_NAME2, ExecMode.SINGLE_NODE, false, false, true);
+ }
+
@Test
public void testTransformBagOfWordsAmazonReviews2() {
runTransformTest(TEST_NAME1, ExecMode.SINGLE_NODE, false, true, true);
}
+ @Test
+ public void testTransformApplyBagOfWordsAmazonReviews2() {
+ runTransformTest(TEST_NAME2, ExecMode.SINGLE_NODE, false, true, true);
+ }
+
@Test
public void testTransformBagOfWordsAmazonReviewsAndRandRecode() {
runTransformTest(TEST_NAME1, ExecMode.SINGLE_NODE, true, false, true);
}
+ @Test
+ public void testTransformApplyBagOfWordsAmazonReviewsAndRandRecode() {
+ runTransformTest(TEST_NAME2, ExecMode.SINGLE_NODE, true, false, true);
+ }
+
@Test
public void testTransformBagOfWordsAmazonReviewsAndDummyCode() {
// TODO: compare result
@@ -118,16 +161,55 @@ public class TransformFrameEncodeBagOfWords extends AutomatedTestBase
}
@Test
- public void testNotImplementedFunction(){
- ColumnEncoderBagOfWords bow = new ColumnEncoderBagOfWords();
- assertThrows(NotImplementedException.class, () -> bow.initMetaData(null));
+ public void testTransformApplyBagOfWordsAmazonReviewsAndRandRecode2() {
+ runTransformTest(TEST_NAME2, ExecMode.SINGLE_NODE, true, true, true);
+ }
+
+ @Test
+ public void testTransformApplySeparateStagesBagOfWordsAmazonReviewsAndRandRecode2() {
+ MultiColumnEncoder.APPLY_ENCODER_SEPARATE_STAGES = true;
+ runTransformTest(TEST_NAME2, ExecMode.SINGLE_NODE, true, true, true);
+ MultiColumnEncoder.APPLY_ENCODER_SEPARATE_STAGES = false;
}
- //@Test
+
+ @Test
public void testTransformBagOfWordsSpark() {
runTransformTest(TEST_NAME1, ExecMode.SPARK, false, false);
}
+ @Test
+ public void testTransformBagOfWordsAmazonReviewsSpark() {
+ runTransformTest(TEST_NAME1, ExecMode.SPARK, false, false, true);
+ }
+
+ @Test
+ public void testTransformBagOfWordsAmazonReviews2Spark() {
+ runTransformTest(TEST_NAME1, ExecMode.SPARK, false, true, true);
+ }
+
+ @Test
+ public void testTransformBagOfWordsAmazonReviewsAndRandRecodeSpark() {
+ runTransformTest(TEST_NAME1, ExecMode.SPARK, true, false, true);
+ }
+
+ @Test
+ public void testTransformBagOfWordsAmazonReviewsAndRandRecode2Spark() {
+ runTransformTest(TEST_NAME1, ExecMode.SPARK, true, true, true);
+ }
+
+ @Test
+ public void testBuildPartialBagOfWordsNotApplicable() {
+ ColumnEncoderBagOfWords bow = new ColumnEncoderBagOfWords();
+ assert bow.getColID() == -1;
+ try {
+ bow.buildPartial(null); // should run without error
+ } catch (Exception e) {
+ throw new AssertionError("Test failed: Expected no
errors due to early abort (colId = -1). " +
+ "Encountered exception:\n" + e +
"\nMessage: " + Arrays.toString(e.getStackTrace()));
+ }
+ }
+
private void runTransformTest(String testname, ExecMode rt, boolean recode, boolean dup){
runTransformTest(testname, rt, recode, dup, false);
}
@@ -154,31 +236,31 @@ public class TransformFrameEncodeBagOfWords extends AutomatedTestBase
if(!fromFile)
writeStringsToCsvFile(sentenceColumn, recodeColumn, baseDirectory + INPUT_DIR + "data", dup);
- int mode = 0;
- if(error)
- mode = 1;
- if(dc)
- mode = 2;
- if(pt)
- mode = 3;
- programArgs = new String[]{"-stats","-args", fromFile ? DATASET_DIR + DATASET : input("data"),
+ int mode = error ? 1 : (dc ? 2 : (pt ? 3 : 0));
+ programArgs = new String[]{"-explain", "recompile_runtime", "-stats","-args", fromFile ? DATASET_DIR + DATASET : input("data"),
output("result"), output("dict"), String.valueOf(recode), String.valueOf(dup), String.valueOf(fromFile), String.valueOf(mode)};
if(error)
runTest(true, EXCEPTION_EXPECTED, DMLRuntimeException.class, -1);
else{
runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
- FrameBlock dict_frame = readDMLFrameFromHDFS( "dict", Types.FileFormat.CSV);
- int cols = recode? dict_frame.getNumRows() + 1 : dict_frame.getNumRows();
- if(dup)
- cols *= 2;
- if(mode == 0){
- HashMap<MatrixValue.CellIndex, Double> res_actual = readDMLMatrixFromOutputDir("result");
- double[][] result = TestUtils.convertHashMapToDoubleArray(res_actual, Math.min(sentenceColumn.length, 100),
- cols);
- checkResults(sentenceColumn, result, recodeColumn, dict_frame, dup ? 2 : 1);
+ if(TEST_NAME2.equals(testname)){
+ double errorValue = readDMLScalarFromOutputDir("result").values()
+ .stream().findFirst().orElse(1000.0);
+ System.out.println(errorValue);
+ assert errorValue <= 10;
+ } else {
+ FrameBlock dict_frame = readDMLFrameFromHDFS( "dict", Types.FileFormat.CSV);
+ int cols = recode? dict_frame.getNumRows() + 1 : dict_frame.getNumRows();
+ if(dup)
+ cols *= 2;
+ if(mode == 0){
+ HashMap<MatrixValue.CellIndex, Double> res_actual = readDMLMatrixFromOutputDir("result");
+ double[][] result = TestUtils.convertHashMapToDoubleArray(res_actual, Math.min(sentenceColumn.length, 100),
+ cols);
+ checkResults(sentenceColumn, result, recodeColumn, dict_frame, dup ? 2 : 1);
+ }
}
-
}
@@ -211,6 +293,9 @@ public class TransformFrameEncodeBagOfWords extends AutomatedTestBase
{
HashMap<String, Integer>[] indices = new HashMap[duplicates];
HashMap<String, Integer>[] rcdMaps = new HashMap[duplicates];
+ String errors = "";
+ int num_errors = 0;
+ int max_errors = 100;
int frameCol = 0;
// even when the set of tokens is the same for duplicates, the order in which the token dicts are merged
// is not always the same for all columns in multithreaded mode
@@ -219,8 +304,9 @@ public class TransformFrameEncodeBagOfWords extends AutomatedTestBase
rcdMaps[i] = new HashMap<>();
for (int j = 0; j < dict.getNumRows(); j++) {
String[] tuple = dict.getString(j, frameCol).split("\u00b7");
- indices[i].put(tuple[0], Integer.parseInt(tuple[1]));
+ indices[i].put(tuple[0], Integer.parseInt(tuple[1]) - 1);
}
+ System.out.println("Bow dict size: " + indices[i].size());
frameCol++;
if(recodeColumn != null){
for (int j = 0; j < dict.getNumRows(); j++) {
@@ -232,6 +318,7 @@ public class TransformFrameEncodeBagOfWords extends AutomatedTestBase
}
frameCol++;
}
+ System.out.println("Rec dict size: " +
rcdMaps[i].size());
}
// only check the first 100 rows
@@ -266,19 +353,49 @@ public class TransformFrameEncodeBagOfWords extends AutomatedTestBase
for(Map.Entry<String, Integer> entry :
count.entrySet()){
String word = entry.getKey();
int count_expected = entry.getValue();
- int index = indices[j].get(word);
- assert result[row][index + offset] ==
count_expected;
+ Integer index = indices[j].get(word);
+ if(index == null){
+ throw new AssertionError("row
[" + row + "]: not found word: " + word);
+ }
+ if(result[row][index + offset] !=
count_expected){
+ String error_message = "bow
result[" + row + "," + (index + offset) + "]=" +
+
result[row][index + offset] + " does not match the expected: " + count_expected;
+ if(num_errors < max_errors)
+ errors += error_message
+ '\n';
+ else
+ throw new
AssertionError(errors + error_message);
+ num_errors++;
+ }
+ }
+ for(int zeroIndex : zeroIndices){
+ if(result[row][offset + zeroIndex] !=
0){
+ String error_message = "bow
result[" + row + "," + (offset + zeroIndex) + "]=" +
+
result[row][offset + zeroIndex] + " does not match the expected: 0";
+ if(num_errors < max_errors)
+ errors += error_message
+ '\n';
+ else
+ throw new
AssertionError(errors + error_message);
+ num_errors++;
+ }
}
- for(int zeroIndex : zeroIndices)
- assert result[row][offset + zeroIndex]
== 0;
offset += indices[j].size();
// compare results: recode
if(recodeColumn != null){
- assert result[row][offset] ==
rcdMaps[j].get(recodeColumn[row]);
+ if(result[row][offset] !=
rcdMaps[j].get(recodeColumn[row])){
+ String error_message = "recode
result[" + row + "," + offset + "]=" +
+
result[row][offset]+ " does not match the expected: " +
rcdMaps[j].get(recodeColumn[row]);
+ if(num_errors < max_errors)
+ errors += error_message
+ '\n';
+ else
+ throw new
AssertionError(errors + error_message);
+ num_errors++;
+ }
offset++;
}
}
}
+ if (num_errors > 0)
+ throw new AssertionError(errors);
}
public static void writeStringsToCsvFile(String[] sentences, String[] recodeTokens, String fileName, boolean duplicate) throws IOException {
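Background on what checkResults recomputes on the Java side: it tokenizes each input sentence, counts token occurrences, and compares those counts against the encoded matrix, using the token-to-column mapping from the meta frame (shifted by -1 because the dictionary indices are 1-based). A minimal sketch of that counting technique, assuming a naive whitespace tokenizer (the encoder's actual tokenize() may differ; names here are illustrative):

    import java.util.HashMap;
    import java.util.Map;

    public class BowCountSketch {
        // naive whitespace tokenizer; an assumption for illustration only
        static String[] simpleTokenize(String sentence) {
            return sentence.toLowerCase().split("\\s+");
        }

        public static void main(String[] args) {
            String sentence = "the cat saw the dog";
            Map<String, Integer> counts = new HashMap<>();
            for(String token : simpleTokenize(sentence))
                counts.merge(token, 1, Integer::sum); // increment per occurrence
            // a bag-of-words output row holds, per dictionary token (column),
            // that token's count in the sentence; absent tokens stay 0
            System.out.println(counts); // {the=2, cat=1, saw=1, dog=1}
        }
    }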
diff --git a/src/test/scripts/functions/transform/TransformFrameEncodeBagOfWords.dml b/src/test/scripts/functions/transform/TransformFrameEncodeApplyBagOfWords.dml
similarity index 85%
copy from src/test/scripts/functions/transform/TransformFrameEncodeBagOfWords.dml
copy to src/test/scripts/functions/transform/TransformFrameEncodeApplyBagOfWords.dml
index 49231c97c0..b2cd0ccda2 100644
--- a/src/test/scripts/functions/transform/TransformFrameEncodeBagOfWords.dml
+++ b/src/test/scripts/functions/transform/TransformFrameEncodeApplyBagOfWords.dml
@@ -55,13 +55,16 @@ if(as.integer($7) == 3){
jspec = "{ids: true, bag_of_words: [1]}";
}
+[Data_enc, Meta] = transformencode(target=Data, spec=jspec);
+while(FALSE){}
+
i = 0
total = 0
j = 0
# set to 20 for benchmarking
-while(i < 1){
+while(i < 30){
t0 = time()
- [Data_enc, Meta] = transformencode(target=Data, spec=jspec);
+ Data_enc2 = transformapply(target=Data, spec=jspec, meta=Meta)
if(i > 10){
total = total + time() - t0
j = j + 1
@@ -69,10 +72,13 @@ while(i < 1){
i = i + 1
}
print(total/1000000000 / j)
-print(nrow(Data_enc) + " x " + ncol(Data_enc))
-#reduce nr rows for large input tests
-if(nrow(Data_enc) > 100){
- Data_enc = Data_enc[1:100,]
-}
-write(Data_enc, $2, format="text");
-write(Meta, $3, format="csv");
+
+i = 0
+
+Error = sign(Data_enc2 - Data_enc)
+Error_agg = sum(Error * Error)
+#print(sum(sign(Data_enc2)))
+#print(sum(sign(Data_enc)))
+#print(Error_agg)
+write(Error_agg, $2, format="text");
+
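The apply script validates transformapply against transformencode by counting mismatching cells: sign() maps any nonzero difference to +/-1, and squaring makes each mismatch contribute exactly 1, so Error_agg is the number of differing cells. An equivalent check in plain Java (array names are illustrative):

    public class MismatchCountSketch {
        // count cells where the two encodings disagree, mirroring
        // sum(sign(A - B) * sign(A - B)) in the DML script
        static int countMismatches(double[][] a, double[][] b) {
            int mismatches = 0;
            for(int i = 0; i < a.length; i++)
                for(int j = 0; j < a[i].length; j++)
                    if(a[i][j] != b[i][j])
                        mismatches++; // sign(d)^2 is 1 for any nonzero d
            return mismatches;
        }

        public static void main(String[] args) {
            double[][] enc  = {{1, 0}, {0, 2}};
            double[][] enc2 = {{1, 0}, {1, 2}};
            System.out.println(countMismatches(enc, enc2)); // 1
        }
    }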
diff --git a/src/test/scripts/functions/transform/TransformFrameEncodeBagOfWords.dml b/src/test/scripts/functions/transform/TransformFrameEncodeBagOfWords.dml
index 49231c97c0..2a69f314a4 100644
--- a/src/test/scripts/functions/transform/TransformFrameEncodeBagOfWords.dml
+++ b/src/test/scripts/functions/transform/TransformFrameEncodeBagOfWords.dml
@@ -21,15 +21,15 @@
# Read the token sequence (1K) w/ 100 distinct tokens
Data = read($1, data_type="frame", format="csv");
+#print(toString(Data))
if(!as.boolean($4) & as.boolean($6)){
Data = Data[,1]
}
+while(FALSE){}
if(as.boolean($5) & as.boolean($6)){
Data = cbind(Data,Data)
}
-while(FALSE){}
-
if (as.boolean($4)) {
if (as.boolean($5)) {
jspec = "{ids: true, bag_of_words: [1,3], recode : [2,4]}";
@@ -54,7 +54,7 @@ if(as.integer($7) == 3){
Data = cbind(Data, ones)
jspec = "{ids: true, bag_of_words: [1]}";
}
-
+while(FALSE){}
i = 0
total = 0
j = 0
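Both scripts share the same benchmarking skeleton: the loop repeats the measured operation, discards the first iterations as warm-up (the i > 10 guard), and averages only the remaining ones. A plain-Java sketch of that pattern, assuming a hypothetical placeholder workload (names are illustrative):

    public class WarmupTimingSketch {
        public static void main(String[] args) {
            long total = 0;
            int j = 0;
            for(int i = 0; i < 30; i++) {
                long t0 = System.nanoTime();
                workload(); // stand-in for the measured transformencode/apply call
                if(i > 10) { // skip early iterations as JIT/cache warm-up
                    total += System.nanoTime() - t0;
                    j++;
                }
            }
            System.out.println(total / 1e9 / j + " s per iteration");
        }

        static void workload() {
            // hypothetical placeholder for the measured operation
            Math.sqrt(12345.0);
        }
    }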