This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new c21273cb09 [SYSTEMDS-3782] Bag-of-words encoder for Spark backend
c21273cb09 is described below

commit c21273cb098906e40e8bf7b3f7d5f32536a342d0
Author: e-strauss <[email protected]>
AuthorDate: Sun Nov 24 14:35:06 2024 +0100

    [SYSTEMDS-3782] Bag-of-words encoder for Spark backend
    
    Closes #2145.
---
 ...ltiReturnParameterizedBuiltinSPInstruction.java |  15 +-
 .../spark/ParameterizedBuiltinSPInstruction.java   |   8 +
 .../runtime/transform/encode/ColumnEncoder.java    |   2 +-
 .../transform/encode/ColumnEncoderBagOfWords.java  | 206 ++++++++----
 .../transform/encode/ColumnEncoderComposite.java   |  14 +-
 .../runtime/transform/encode/EncoderFactory.java   |   4 +
 .../transform/encode/MultiColumnEncoder.java       | 359 +++++++++++++++------
 .../ColumnEncoderMixedFunctionalityTests.java      | 116 +++++++
 .../transform/ColumnEncoderSerializationTest.java  |  33 +-
 .../transform/TransformFrameEncodeBagOfWords.java  | 177 ++++++++--
 ...dml => TransformFrameEncodeApplyBagOfWords.dml} |  24 +-
 .../transform/TransformFrameEncodeBagOfWords.dml   |   6 +-
 12 files changed, 758 insertions(+), 206 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/MultiReturnParameterizedBuiltinSPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/MultiReturnParameterizedBuiltinSPInstruction.java
index 90683eab29..df9fd84f77 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/MultiReturnParameterizedBuiltinSPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/MultiReturnParameterizedBuiltinSPInstruction.java
@@ -62,6 +62,7 @@ import org.apache.sysds.runtime.matrix.operators.Operator;
 import org.apache.sysds.runtime.meta.DataCharacteristics;
 import org.apache.sysds.runtime.transform.TfUtils;
 import org.apache.sysds.runtime.transform.encode.ColumnEncoder;
+import org.apache.sysds.runtime.transform.encode.ColumnEncoderBagOfWords;
 import org.apache.sysds.runtime.transform.encode.ColumnEncoderBin;
 import org.apache.sysds.runtime.transform.encode.ColumnEncoderComposite;
 import org.apache.sysds.runtime.transform.encode.ColumnEncoderRecode;
@@ -263,6 +264,7 @@ public class MultiReturnParameterizedBuiltinSPInstruction 
extends ComputationSPI
                        // encoder-specific outputs
                        List<ColumnEncoderRecode> raEncoders = 
_encoder.getColumnEncoders(ColumnEncoderRecode.class);
                        List<ColumnEncoderBin> baEncoders = 
_encoder.getColumnEncoders(ColumnEncoderBin.class);
+                       List<ColumnEncoderBagOfWords> bowEncoders = 
_encoder.getColumnEncoders(ColumnEncoderBagOfWords.class);
                        ArrayList<Tuple2<Integer, Object>> ret = new 
ArrayList<>();
 
                        // output recode maps as columnID - token pairs
@@ -273,8 +275,14 @@ public class MultiReturnParameterizedBuiltinSPInstruction 
extends ComputationSPI
                                for(Entry<Integer, HashSet<Object>> e1 : 
tmp.entrySet())
                                        for(Object token : e1.getValue())
                                                ret.add(new 
Tuple2<>(e1.getKey(), token));
-                               if(!raEncoders.isEmpty())
-                                       raEncoders.forEach(columnEncoderRecode 
-> columnEncoderRecode.getCPRecodeMapsPartial().clear());
+                               raEncoders.forEach(columnEncoderRecode -> 
columnEncoderRecode.getCPRecodeMapsPartial().clear());
+                       }
+
+                       if(!bowEncoders.isEmpty()){
+                               for (ColumnEncoderBagOfWords bowEnc : 
bowEncoders)
+                                       for (Object token : 
bowEnc.getPartialTokenDictionary())
+                                               ret.add(new 
Tuple2<>(bowEnc.getColID(), token));
+                               bowEncoders.forEach(enc -> 
enc.getPartialTokenDictionary().clear());
                        }
 
                        // output binning column min/max as columnID - min/max 
pairs
@@ -321,7 +329,8 @@ public class MultiReturnParameterizedBuiltinSPInstruction 
extends ComputationSPI
                        StringBuilder sb = new StringBuilder();
 
                        // handle recode maps
-                       if(_encoder.containsEncoderForID(colID, 
ColumnEncoderRecode.class)) {
+                       if(_encoder.containsEncoderForID(colID, 
ColumnEncoderRecode.class) ||
+                                       _encoder.containsEncoderForID(colID, 
ColumnEncoderBagOfWords.class)) {
                                while(iter.hasNext()) {
                                        String token = 
TfUtils.sanitizeSpaces(iter.next().toString());
                                        sb.append(rowID).append(' 
').append(scolID).append(' ');
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
index 61e6e799f0..4d4b012444 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
@@ -88,6 +88,7 @@ import org.apache.sysds.runtime.meta.MatrixCharacteristics;
 import org.apache.sysds.runtime.transform.TfUtils.TfMethod;
 import org.apache.sysds.runtime.transform.decode.Decoder;
 import org.apache.sysds.runtime.transform.decode.DecoderFactory;
+import org.apache.sysds.runtime.transform.encode.ColumnEncoderBagOfWords;
 import org.apache.sysds.runtime.transform.encode.EncoderFactory;
 import org.apache.sysds.runtime.transform.encode.MultiColumnEncoder;
 import org.apache.sysds.runtime.transform.meta.TfMetaUtils;
@@ -1056,6 +1057,13 @@ public class ParameterizedBuiltinSPInstruction extends 
ComputationSPInstruction
 
                        // execute block transform apply
                        MultiColumnEncoder encoder = _bencoder.getValue();
+                       // we need to create a copy of the encoder since the 
bag-of-words encoder stores frameblock-specific state
+                       // which would be overwritten when multiple blocks are 
located on an executor
+                       // to avoid this, we create a shallow copy of 
the MCEncoder, where we only instantiate new bag-of-words
+                       // encoder objects with the frameblock-specific fields 
and shallow-copy the other fields (like meta)
+                       // other encoders are reused and not newly instantiated
+                       
if(!encoder.getColumnEncoders(ColumnEncoderBagOfWords.class).isEmpty())
+                               encoder = new MultiColumnEncoder(encoder); // 
create copy
                        MatrixBlock tmp = encoder.apply(blk);
                        // remap keys
                        if(_omap != null) {
diff --git 
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java 
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java
index f10da3d946..019df7f847 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java
@@ -450,7 +450,7 @@ public abstract class ColumnEncoder implements Encoder, 
Comparable<ColumnEncoder
        }
 
        public enum EncoderType {
-               Recode, FeatureHash, PassThrough, Bin, Dummycode, Omit, 
MVImpute, Composite, WordEmbedding,
+               Recode, FeatureHash, PassThrough, Bin, Dummycode, Omit, 
MVImpute, Composite, WordEmbedding, BagOfWords
        }
 
        /*
diff --git 
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderBagOfWords.java
 
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderBagOfWords.java
index c138901ad1..25b1a0ce87 100644
--- 
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderBagOfWords.java
+++ 
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderBagOfWords.java
@@ -29,6 +29,9 @@ import org.apache.sysds.runtime.frame.data.FrameBlock;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.utils.stats.TransformStatistics;
 
+import java.io.IOException;
+import java.io.ObjectInput;
+import java.io.ObjectOutput;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Comparator;
@@ -45,14 +48,15 @@ import static 
org.apache.sysds.runtime.util.UtilFunctions.getEndIndex;
 public class ColumnEncoderBagOfWords extends ColumnEncoder {
 
        public static int NUM_SAMPLES_MAP_ESTIMATION = 16000;
-       protected int[] nnzPerRow;
-       private Map<String, Integer> tokenDictionary;
-       protected String seperatorRegex = "\\s+"; // whitespace
-       protected boolean caseSensitive = false;
-       protected long nnz = 0;
-       protected long[] nnzPartials;
-       protected int defaultNnzCapacity = 64;
-       protected double avgNnzPerRow = 1.0;
+       private Map<Object, Long> _tokenDictionary; // switched from int to 
long to reuse code from RecodeEncoder
+       private HashSet<Object> _tokenDictionaryPart = null;
+       protected String _seperatorRegex = "\\s+"; // whitespace
+       protected boolean _caseSensitive = false;
+       protected int[] _nnzPerRow;
+       protected long _nnz = 0;
+       protected long[] _nnzPartials;
+       protected int _defaultNnzCapacity = 64;
+       protected double _avgNnzPerRow = 1.0;
 
        protected ColumnEncoderBagOfWords(int colID) {
                super(colID);
@@ -62,9 +66,40 @@ public class ColumnEncoderBagOfWords extends ColumnEncoder {
                super(-1);
        }
 
+       public ColumnEncoderBagOfWords(ColumnEncoderBagOfWords enc) {
+               super(enc._colID);
+               _nnzPerRow = enc._nnzPerRow != null ? enc._nnzPerRow.clone() : 
null;
+               _tokenDictionary = enc._tokenDictionary;
+               _seperatorRegex = enc._seperatorRegex;
+               _caseSensitive = enc._caseSensitive;
+       }
+
+       public void setTokenDictionary(HashMap<Object, Long> dict){
+               _tokenDictionary = dict;
+       }
+
+       public Map<Object, Long> getTokenDictionary() {
+               return _tokenDictionary;
+       }
+
        protected void initNnzPartials(int rows, int numBlocks){
-               this.nnzPerRow = new int[rows];
-               this.nnzPartials = new long[numBlocks];
+               _nnzPerRow = new int[rows];
+               _nnzPartials = new long[numBlocks];
+       }
+
+       public double computeNnzEstimate(CacheBlock<?> in, int[] sampleIndices) 
{
+               // estimates the nnz per row for this encoder
+               final int max_index = 
Math.min(ColumnEncoderBagOfWords.NUM_SAMPLES_MAP_ESTIMATION, 
sampleIndices.length);
+               int nnz = 0;
+               for (int i = 0; i < max_index; i++) {
+                       int sind = sampleIndices[i];
+                       String current = in.getString(sind, _colID - 1);
+                       if(current != null)
+                               for(String token : tokenize(current, 
_caseSensitive, _seperatorRegex))
+                                       if(!token.isEmpty() && 
_tokenDictionary.containsKey(token))
+                                               nnz++;
+               }
+               return (double) nnz / max_index;
        }
 
        public void computeMapSizeEstimate(CacheBlock<?> in, int[] 
sampleIndices) {
@@ -76,10 +111,10 @@ public class ColumnEncoderBagOfWords extends ColumnEncoder 
{
                int[] nnzPerRow = new int[max_index];
                for (int i = 0; i < max_index; i++) {
                        int sind = sampleIndices[i];
-                       String current = in.getString(sind, this._colID - 1);
+                       String current = in.getString(sind, _colID - 1);
                        Set<String> tokenSetRow = new HashSet<>();
                        if(current != null)
-                               for(String token : tokenize(current, 
caseSensitive, seperatorRegex))
+                               for(String token : tokenize(current, 
_caseSensitive, _seperatorRegex))
                                        if(!token.isEmpty()){
                                                tokenSetRow.add(token);
                                                if 
(distinctFreq.containsKey(token))
@@ -94,9 +129,9 @@ public class ColumnEncoderBagOfWords extends ColumnEncoder {
                        nnzPerRow[i] = tokenSetRow.size();
                }
                Arrays.sort(nnzPerRow);
-               avgNnzPerRow = (double) Arrays.stream(nnzPerRow).sum() / 
nnzPerRow.length;
+               _avgNnzPerRow = (double) Arrays.stream(nnzPerRow).sum() / 
nnzPerRow.length;
                // default value for HashSets in build phase -> 75% without 
resize (Division by 0.9 -> is the resize threshold)
-               defaultNnzCapacity = (int) Math.max( nnzPerRow[(int) 
(nnzPerRow.length*0.75)] / 0.9, 64);
+               _defaultNnzCapacity = (int) Math.max( nnzPerRow[(int) 
(nnzPerRow.length*0.75)] / 0.9, 64);
                // we increase the upperbound of the total count estimate by 20%
                double avgSentenceLength = numTokensSample*1.2 / max_index;
 
@@ -118,6 +153,18 @@ public class ColumnEncoderBagOfWords extends ColumnEncoder 
{
                _estMetaSize = _estNumDistincts * _avgEntrySize;
        }
 
+       public void computeNnzPerRow(CacheBlock<?> in, int start, int end){
+               for (int i = start; i < end; i++) {
+                       String current = in.getString(i, _colID - 1);
+                       HashSet<String> distinctTokens = new HashSet<>();
+                       if(current != null)
+                               for(String token : tokenize(current, 
_caseSensitive, _seperatorRegex))
+                                       if(!token.isEmpty() && 
_tokenDictionary.containsKey(token))
+                                               distinctTokens.add(token);
+                       _nnzPerRow[i] = distinctTokens.size();
+               }
+       }
+
        public static String[] tokenize(String current, boolean caseSensitive, 
String seperatorRegex) {
                // string builder is faster than regex
                StringBuilder finalString = new StringBuilder();
@@ -132,7 +179,7 @@ public class ColumnEncoderBagOfWords extends ColumnEncoder {
 
        @Override
        public int getDomainSize(){
-               return tokenDictionary.size();
+               return _tokenDictionary.size();
        }
 
        @Override
@@ -150,8 +197,6 @@ public class ColumnEncoderBagOfWords extends ColumnEncoder {
                return TransformType.BAG_OF_WORDS;
        }
 
-
-
        public Callable<Object> getBuildTask(CacheBlock<?> in) {
                return new ColumnBagOfWordsBuildTask(this, in);
        }
@@ -159,38 +204,39 @@ public class ColumnEncoderBagOfWords extends 
ColumnEncoder {
        @Override
        public void build(CacheBlock<?> in) {
                long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0;
-               tokenDictionary = new HashMap<>(_estNumDistincts);
-               int i = 0;
-               this.nnz = 0;
-               nnzPerRow = new int[in.getNumRows()];
+               _tokenDictionary = new HashMap<>(_estNumDistincts);
+               int i = 1;
+               _nnz = 0;
+               _nnzPerRow = new int[in.getNumRows()];
                HashSet<String> tokenSetPerRow;
                for (int r = 0; r < in.getNumRows(); r++) {
                        // start with a higher default capacity to avoid resizes
-                       tokenSetPerRow = new HashSet<>(defaultNnzCapacity);
-                       String current = in.getString(r, this._colID - 1);
+                       tokenSetPerRow = new HashSet<>(_defaultNnzCapacity);
+                       String current = in.getString(r, _colID - 1);
                        if(current != null)
-                               for(String token : tokenize(current, 
caseSensitive, seperatorRegex))
+                               for(String token : tokenize(current, 
_caseSensitive, _seperatorRegex))
                                        if(!token.isEmpty()){
                                                tokenSetPerRow.add(token);
-                                               
if(!this.tokenDictionary.containsKey(token))
-                                                       
this.tokenDictionary.put(token, i++);
+                                               
if(!_tokenDictionary.containsKey(token))
+                                                       
_tokenDictionary.put(token, (long) i++);
                                        }
-                       this.nnzPerRow[r] = tokenSetPerRow.size();
-                       this.nnz += tokenSetPerRow.size();
+                       _nnzPerRow[r] = tokenSetPerRow.size();
+                       _nnz += tokenSetPerRow.size();
                }
                if(DMLScript.STATISTICS)
                        
TransformStatistics.incBagOfWordsBuildTime(System.nanoTime()-t0);
        }
 
        @Override
-       public Callable<Object> getPartialBuildTask(CacheBlock<?> in, int 
startRow, int blockSize,
-                                                                               
                HashMap<Integer, Object> ret, int pos) {
-               return new BowPartialBuildTask(in, _colID, startRow, blockSize, 
ret, nnzPerRow, caseSensitive, seperatorRegex, nnzPartials, pos);
+       public Callable<Object> getPartialBuildTask(CacheBlock<?> in, 
+               int startRow, int blockSize, HashMap<Integer, Object> ret, int 
pos) {
+               return new BowPartialBuildTask(in, _colID, startRow, blockSize, 
ret,
+                       _nnzPerRow, _caseSensitive, _seperatorRegex, 
_nnzPartials, pos);
        }
 
        @Override
        public Callable<Object> getPartialMergeBuildTask(HashMap<Integer, ?> 
ret) {
-               tokenDictionary = new HashMap<>(this._estNumDistincts);
+               _tokenDictionary = new HashMap<>(_estNumDistincts);
                return new BowMergePartialBuildTask(this, ret);
        }
 
@@ -205,6 +251,32 @@ public class ColumnEncoderBagOfWords extends ColumnEncoder 
{
                }
        }
 
+       @Override
+       public void prepareBuildPartial() {
+               // ensure the partial token dictionary is allocated
+               if(_tokenDictionaryPart == null)
+                       _tokenDictionaryPart = new HashSet<>();
+       }
+
+
+       public HashSet<Object> getPartialTokenDictionary(){
+               return _tokenDictionaryPart;
+       }
+
+       @Override
+       public void buildPartial(FrameBlock in) {
+               if(!isApplicable())
+                       return;
+               for (int r = 0; r < in.getNumRows(); r++) {
+                       String current = in.getString(r, _colID - 1);
+                       if(current != null)
+                               for(String token : tokenize(current, 
_caseSensitive, _seperatorRegex)){
+                                       if(!token.isEmpty())
+                                               _tokenDictionaryPart.add(token);
+                               }
+               }
+       }
+
        protected void applySparse(CacheBlock<?> in, MatrixBlock out, int 
outputCol, int rowStart, int blk) {
                boolean mcsr = MatrixBlock.DEFAULT_SPARSEBLOCK == 
SparseBlock.Type.MCSR;
                mcsr = false; // force CSR for transformencode FIXME
@@ -214,20 +286,20 @@ public class ColumnEncoderBagOfWords extends 
ColumnEncoder {
                                throw new NotImplementedException();
                        }
                        else { // csr
-                               HashMap<String, Integer> counter = 
countTokenAppearances(in, r, _colID-1, caseSensitive, seperatorRegex);
+                               HashMap<String, Integer> counter = 
countTokenAppearances(in, r);
                                if(counter.isEmpty())
                                        sparseRowsWZeros.add(r);
                                else {
                                        SparseBlockCSR csrblock = 
(SparseBlockCSR) out.getSparseBlock();
                                        int[] rptr = csrblock.rowPointers();
                                        // assert that nnz from build is equal 
to nnz from apply
-                                       assert counter.size() == nnzPerRow[r];
-                                       Pair[] columnValuePairs = new 
Pair[counter.size()];
+                                       Pair[] columnValuePairs = new 
Pair[_nnzPerRow[r]];
                                        int i = 0;
                                        for (Map.Entry<String, Integer> entry : 
counter.entrySet()) {
                                                String token = entry.getKey();
-                                               columnValuePairs[i] = new 
Pair(outputCol + tokenDictionary.get(token), entry.getValue());
-                                               i++;
+                                               columnValuePairs[i] = new 
Pair((int) (outputCol + _tokenDictionary.getOrDefault(token, 0L) - 1), 
entry.getValue());
+                                               // if the token is not in the dictionary, i is not advanced, so 
columnValuePairs[i] is overwritten in the next iteration
+                                               i += 
_tokenDictionary.containsKey(token) ? 1 : 0;
                                        }
                                        // insertion sorts performs better on 
small arrays
                                        if(columnValuePairs.length >= 128)
@@ -237,7 +309,7 @@ public class ColumnEncoderBagOfWords extends ColumnEncoder {
                                        // Manually fill the column-indexes and 
values array
                                        for (i = 0; i < 
columnValuePairs.length; i++) {
                                                int index = 
sparseRowPointerOffset != null ? sparseRowPointerOffset[r] - 1 + i : i;
-                                               index += rptr[r] + this._colID 
-1;
+                                               index += rptr[r] + _colID -1;
                                                csrblock.indexes()[index] = 
columnValuePairs[i].key;
                                                csrblock.values()[index] = 
columnValuePairs[i].value;
                                        }
@@ -264,42 +336,68 @@ public class ColumnEncoderBagOfWords extends 
ColumnEncoder {
        @Override
        protected void applyDense(CacheBlock<?> in, MatrixBlock out, int 
outputCol, int rowStart, int blk){
                for (int r = rowStart; r < Math.max(in.getNumRows(), rowStart + 
blk); r++) {
-                       HashMap<String, Integer> counter = 
countTokenAppearances(in, r, _colID-1, caseSensitive, seperatorRegex);
+                       HashMap<String, Integer> counter = 
countTokenAppearances(in, r);
                        for (Map.Entry<String, Integer> entry : 
counter.entrySet())
-                               out.set(r, outputCol + 
tokenDictionary.get(entry.getKey()), entry.getValue());
+                               out.set(r, (int) (outputCol + 
_tokenDictionary.get(entry.getKey()) - 1), entry.getValue());
                }
        }
 
-       private static HashMap<String, Integer> countTokenAppearances(
-               CacheBlock<?> in, int r, int c, boolean caseSensitive, String 
separator)
+       private HashMap<String, Integer> countTokenAppearances(
+                       CacheBlock<?> in, int r)
        {
-               String current = in.getString(r, c);
+               String current = in.getString(r, _colID - 1);
                HashMap<String, Integer> counter = new HashMap<>();
                if(current != null)
-                       for (String token : tokenize(current, caseSensitive, 
separator))
-                               if (!token.isEmpty())
+                       for (String token : tokenize(current, _caseSensitive, 
_seperatorRegex))
+                               if (!token.isEmpty() && 
_tokenDictionary.containsKey(token))
                                        counter.put(token, 
counter.getOrDefault(token, 0) + 1);
                return counter;
        }
 
        @Override
        public void allocateMetaData(FrameBlock meta) {
-               meta.ensureAllocatedColumns(this.getDomainSize());
+               meta.ensureAllocatedColumns(getDomainSize());
        }
 
        @Override
        public FrameBlock getMetaData(FrameBlock out) {
                int rowID = 0;
                StringBuilder sb = new StringBuilder();
-               for(Map.Entry<String, Integer> e : 
this.tokenDictionary.entrySet()) {
-                       out.set(rowID++, _colID - 1, 
constructRecodeMapEntry(e.getKey(), Long.valueOf(e.getValue()), sb));
+               for(Map.Entry<Object, Long> e : _tokenDictionary.entrySet()) {
+                       out.set(rowID++, _colID - 1, 
constructRecodeMapEntry(e.getKey(), e.getValue(), sb));
                }
                return out;
        }
 
        @Override
        public void initMetaData(FrameBlock meta) {
-               throw new NotImplementedException();
+               if(meta != null && meta.getNumRows() > 0) {
+                       _tokenDictionary = meta.getRecodeMap(_colID - 1);
+               }
+       }
+
+       @Override
+       public void writeExternal(ObjectOutput out) throws IOException {
+               super.writeExternal(out);
+
+               out.writeInt(_tokenDictionary == null ? 0 : 
_tokenDictionary.size());
+               if(_tokenDictionary != null)
+                       for(Map.Entry<Object, Long> e : 
_tokenDictionary.entrySet()) {
+                               out.writeUTF((String) e.getKey());
+                               out.writeLong(e.getValue());
+                       }
+       }
+
+       @Override
+       public void readExternal(ObjectInput in) throws IOException {
+               super.readExternal(in);
+               int size = in.readInt();
+               _tokenDictionary = new HashMap<>(size * 4 / 3);
+               for(int j = 0; j < size; j++) {
+                       String key = in.readUTF();
+                       Long value = in.readLong();
+                       _tokenDictionary.put(key, value);
+               }
        }
 
        private static class BowPartialBuildTask implements Callable<Object> {
@@ -340,7 +438,7 @@ public class ColumnEncoderBagOfWords extends ColumnEncoder {
                        long nnzPartial = 0;
                        for (int r = _startRow; r < endRow; r++) {
                                tokenSetPerRow = new HashSet<>(64);
-                               String current = _input.getString(r, 
this._colID - 1);
+                               String current = _input.getString(r, _colID - 
1);
                                if(current != null)
                                        for(String token : tokenize(current, 
_caseSensitive, _seperator))
                                                if(!token.isEmpty()){
@@ -378,15 +476,15 @@ public class ColumnEncoderBagOfWords extends 
ColumnEncoder {
                @Override
                public Object call() {
                        long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0;
-                       Map<String, Integer> tokenDictionary = 
_encoder.tokenDictionary;
+                       Map<Object, Long> tokenDictionary = 
_encoder._tokenDictionary;
                        for(Object tokenSet : _partialMaps.values()){
                                ( (HashSet<?>) tokenSet).forEach(token -> {
                                        if(!tokenDictionary.containsKey(token))
-                                               tokenDictionary.put((String) 
token, tokenDictionary.size());
+                                               tokenDictionary.put(token, 
(long) tokenDictionary.size() + 1);
                                });
                        }
-                       for (long nnzPartial : _encoder.nnzPartials)
-                               _encoder.nnz += nnzPartial;
+                       for (long nnzPartial : _encoder._nnzPartials)
+                               _encoder._nnz += nnzPartial;
                        if(DMLScript.STATISTICS){
                                
TransformStatistics.incBagOfWordsBuildTime(System.nanoTime() - t0);
                        }
diff --git 
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java
 
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java
index 9544372914..536b387a1d 100644
--- 
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java
+++ 
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java
@@ -62,9 +62,9 @@ public class ColumnEncoderComposite extends ColumnEncoder {
 
        public ColumnEncoderComposite(List<ColumnEncoder> columnEncoders, 
FrameBlock meta) {
                super(-1);
-               if(!(columnEncoders.size() > 0 &&
+               if(!(!columnEncoders.isEmpty() &&
                        columnEncoders.stream().allMatch((encoder -> 
encoder._colID == columnEncoders.get(0)._colID))))
-                       throw new DMLRuntimeException("Tried to create 
Composite Encoder with no encoders or mismatching columIDs");
+                       throw new DMLRuntimeException("Tried to create 
Composite Encoder with no encoders or mismatching columnIDs");
                _colID = columnEncoders.get(0)._colID;
                _meta = meta;
                _columnEncoders = columnEncoders;
@@ -73,6 +73,11 @@ public class ColumnEncoderComposite extends ColumnEncoder {
        public ColumnEncoderComposite(List<ColumnEncoder> columnEncoders) {
                this(columnEncoders, null);
        }
+       public ColumnEncoderComposite(List<ColumnEncoder> columnEncoders, int 
colID) {
+               super(colID);
+               _columnEncoders = columnEncoders;
+               _meta = null;
+       }
 
        public ColumnEncoderComposite(ColumnEncoder columnEncoder) {
                super(columnEncoder._colID);
@@ -166,7 +171,8 @@ public class ColumnEncoderComposite extends ColumnEncoder {
                        if(t == null)
                                continue;
                        // Linear execution between encoders so they can't be 
built in parallel
-                       if(tasks.size() != 0) {
+                       if(!tasks.isEmpty()) {
+                               // TODO: is this still needed? Currently, there 
is no CompositeEncoder with two encoders that both have a build phase
                                // avoid unnecessary map initialization
                                depMap = (depMap == null) ? new HashMap<>() : 
depMap;
                                // This workaround is needed since sublist is 
only valid for effective final lists,
@@ -207,6 +213,8 @@ public class ColumnEncoderComposite extends ColumnEncoder {
        public MatrixBlock apply(CacheBlock<?> in, MatrixBlock out, int 
outputCol, int rowStart, int blk) {
                try {
                        for(int i = 0; i < _columnEncoders.size(); i++) {
+                               // set sparseRowPointerOffset in the encoder
+                               _columnEncoders.get(i).sparseRowPointerOffset = 
this.sparseRowPointerOffset;
                                if(i == 0) {
                                        // 1. encoder writes data into 
MatrixBlock Column all others use this column for further encoding
                                        _columnEncoders.get(i).apply(in, out, 
outputCol, rowStart, blk);
diff --git 
a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java 
b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java
index 05ad8e4694..1c2478d711 100644
--- 
a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java
+++ 
b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java
@@ -265,6 +265,8 @@ public interface EncoderFactory {
                        return EncoderType.Recode.ordinal();
                else if(columnEncoder instanceof ColumnEncoderWordEmbedding)
                        return EncoderType.WordEmbedding.ordinal();
+               else if(columnEncoder instanceof ColumnEncoderBagOfWords)
+                       return EncoderType.BagOfWords.ordinal();
                throw new DMLRuntimeException("Unsupported encoder type: " + 
columnEncoder.getClass().getCanonicalName());
        }
 
@@ -283,6 +285,8 @@ public interface EncoderFactory {
                                return new ColumnEncoderRecode();
                        case WordEmbedding:
                                return new ColumnEncoderWordEmbedding();
+                       case BagOfWords:
+                               return new ColumnEncoderBagOfWords();
                        default:
                                throw new DMLRuntimeException("Unsupported 
encoder type: " + etype);
                }
diff --git 
a/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
 
b/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
index 0417e67ba1..79c05ca8e7 100644
--- 
a/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
+++ 
b/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
@@ -89,6 +89,20 @@ public class MultiColumnEncoder implements Encoder {
                _columnEncoders = columnEncoders;
        }
 
+       public MultiColumnEncoder(MultiColumnEncoder menc) {
+               // This constructor creates a shallow copy for all encoders 
except for bag_of_words encoders
+               List<ColumnEncoderComposite> colEncs = menc._columnEncoders;
+               _columnEncoders= new ArrayList<>();
+               for (ColumnEncoderComposite cColEnc : colEncs) {
+                       List<ColumnEncoder> newEncs = new ArrayList<>();
+                       ColumnEncoderComposite cColEncCopy = new 
ColumnEncoderComposite(newEncs, cColEnc._colID);
+                       _columnEncoders.add(cColEncCopy);
+                       for (ColumnEncoder enc : cColEnc.getEncoders()) {
+                               newEncs.add(enc instanceof 
ColumnEncoderBagOfWords ? new ColumnEncoderBagOfWords((ColumnEncoderBagOfWords) 
enc) : enc);
+                       }
+               }
+       }
+
        public MultiColumnEncoder() {
                _columnEncoders = new ArrayList<>();
        }
@@ -327,16 +341,14 @@ public class MultiColumnEncoder implements Encoder {
 
        public MatrixBlock apply(CacheBlock<?> in, int k) {
                // domain sizes are not updated if called from transformapply
-               boolean hasUDF = _columnEncoders.stream().anyMatch(e -> 
e.hasEncoder(ColumnEncoderUDF.class));
-               boolean hasWE = _columnEncoders.stream().anyMatch(e -> 
e.hasEncoder(ColumnEncoderWordEmbedding.class));
-               for(ColumnEncoderComposite columnEncoder : _columnEncoders)
-                       columnEncoder.updateAllDCEncoders();
+               EncoderMeta encm = getEncMeta(_columnEncoders, true, k, in);
+               updateAllDCEncoders();
                int numCols = getNumOutCols();
-               long estNNz = (long) in.getNumRows() * (hasUDF ? numCols : 
hasWE ? getEstNNzRow() : in.getNumColumns());
-               // FIXME: estimate nnz for multiple encoders including 
dummycode and embedding
-               boolean sparse = 
MatrixBlock.evalSparseFormatInMemory(in.getNumRows(), numCols, estNNz) && 
!hasUDF;
+               long estNNz = (long) in.getNumRows() * (encm.hasWE || 
encm.hasUDF ? numCols : (in.getNumColumns() - encm.numBOWEnc) + encm.nnzBOW);
+               // FIXME: estimate nnz for multiple encoders including dummycode
+               boolean sparse = 
MatrixBlock.evalSparseFormatInMemory(in.getNumRows(), numCols, estNNz) && 
!encm.hasUDF;
                MatrixBlock out = new MatrixBlock(in.getNumRows(), numCols, 
sparse, estNNz);
-               return apply(in, out, 0, k);
+               return apply(in, out, 0, k, encm, estNNz);
        }
 
        public void updateAllDCEncoders(){
@@ -345,10 +357,11 @@ public class MultiColumnEncoder implements Encoder {
        }
 
        public MatrixBlock apply(CacheBlock<?> in, MatrixBlock out, int 
outputCol) {
-               return apply(in, out, outputCol, 1);
+               // unused method, only exists currently because of the interface
+               throw new DMLRuntimeException("MultiColumnEncoder apply without 
Encoder Characteristics should not be called directly");
        }
 
-       public MatrixBlock apply(CacheBlock<?> in, MatrixBlock out, int 
outputCol, int k) {
+       public MatrixBlock apply(CacheBlock<?> in, MatrixBlock out, int 
outputCol, int k, EncoderMeta encm, long nnz) {
                // There should be a encoder for every column
                if(hasLegacyEncoder() && !(in instanceof FrameBlock))
                        throw new DMLRuntimeException("LegacyEncoders do not 
support non FrameBlock Inputs");
@@ -361,31 +374,20 @@ public class MultiColumnEncoder implements Encoder {
                if(in.getNumRows() == 0)
                        throw new DMLRuntimeException("Invalid input with wrong 
number or rows");
 
-               boolean hasDC = false;
-               boolean hasWE = false;
-               //TODO adapt transform apply for BOW
-               int distinctWE = 0;
-               int sizeWE = 0;
-               for(ColumnEncoderComposite columnEncoder : _columnEncoders) {
-                       hasDC |= 
columnEncoder.hasEncoder(ColumnEncoderDummycode.class);
-                       for (ColumnEncoder enc : columnEncoder.getEncoders())
-                               if(enc instanceof ColumnEncoderWordEmbedding){
-                                       hasWE = true;
-                                       distinctWE = 
((ColumnEncoderWordEmbedding) enc).getNrDistinctEmbeddings();
-                                       sizeWE = enc.getDomainSize();
-                               }
-               }
-               outputMatrixPreProcessing(out, in, hasDC, hasWE, distinctWE, 
sizeWE, 0, null, -1);
+               ArrayList<int[]> nnzOffsets = outputMatrixPreProcessing(out, 
in, encm, nnz, k);
                if(k > 1) {
                        if(!_partitionDone) //happens if this method is 
directly called
                                deriveNumRowPartitions(in, k);
-                       applyMT(in, out, outputCol, k);
+                       applyMT(in, out, outputCol, k, nnzOffsets);
                }
                else {
-                       int offset = outputCol;
+                       int offset = outputCol, i = 0;
+                       int[] nnzOffset = null;
                        for(ColumnEncoderComposite columnEncoder : 
_columnEncoders) {
+                               columnEncoder.sparseRowPointerOffset = 
nnzOffset;
                                columnEncoder.apply(in, out, 
columnEncoder._colID - 1 + offset);
-                               offset = getOffset(offset, columnEncoder);
+                               offset = getOutputColOffset(offset, 
columnEncoder);
+                               nnzOffset = nnzOffsets != null ? 
nnzOffsets.get(i++) : null;
                        }
                }
                // Recomputing NNZ since we access the block directly
@@ -399,36 +401,44 @@ public class MultiColumnEncoder implements Encoder {
                return out;
        }
 
-       private List<DependencyTask<?>> getApplyTasks(CacheBlock<?> in, 
MatrixBlock out, int outputCol) {
+       private List<DependencyTask<?>> getApplyTasks(CacheBlock<?> in, 
MatrixBlock out, int outputCol, ArrayList<int[]> nnzOffsets) {
                List<DependencyTask<?>> tasks = new ArrayList<>();
                int offset = outputCol;
+               int i = 0;
+               int[] currentNnzOffsets = null;
                for(ColumnEncoderComposite e : _columnEncoders) {
-                       tasks.addAll(e.getApplyTasks(in, out, e._colID - 1 + 
offset, null));
-                       offset = getOffset(offset, e);
+                       tasks.addAll(e.getApplyTasks(in, out, e._colID - 1 + 
offset, currentNnzOffsets));
+                       currentNnzOffsets = nnzOffsets != null ? 
nnzOffsets.get(i++) : null;
+                       offset = getOutputColOffset(offset, e);
                }
                return tasks;
        }
 
-       private int getOffset(int offset, ColumnEncoderComposite e) {
+       private int getOutputColOffset(int offset, ColumnEncoderComposite e) {
                if(e.hasEncoder(ColumnEncoderDummycode.class))
                        offset += 
e.getEncoder(ColumnEncoderDummycode.class)._domainSize - 1;
                if(e.hasEncoder(ColumnEncoderWordEmbedding.class))
                        offset += 
e.getEncoder(ColumnEncoderWordEmbedding.class).getDomainSize() - 1;
+               if(e.hasEncoder(ColumnEncoderBagOfWords.class))
+                       offset += 
e.getEncoder(ColumnEncoderBagOfWords.class).getDomainSize() - 1;
                return offset;
        }
 
-       private void applyMT(CacheBlock<?> in, MatrixBlock out, int outputCol, 
int k) {
+       private void applyMT(CacheBlock<?> in, MatrixBlock out, int outputCol, 
int k, ArrayList<int[]> nnzOffsets) {
                DependencyThreadPool pool = new DependencyThreadPool(k);
                try {
                        if(APPLY_ENCODER_SEPARATE_STAGES) {
                                int offset = outputCol;
+                               int i = 0;
+                               int[] currentNnzOffsets = null;
                                for (ColumnEncoderComposite e : 
_columnEncoders) {
-                                       // for now bag of words is only used in 
encode
-                                       
pool.submitAllAndWait(e.getApplyTasks(in, out, e._colID - 1 + offset, null));
-                                       offset = getOffset(offset, e);
+                                       
pool.submitAllAndWait(e.getApplyTasks(in, out, e._colID - 1 + offset, 
currentNnzOffsets));
+                                       offset = getOutputColOffset(offset, e);
+                                       currentNnzOffsets = nnzOffsets != null 
? nnzOffsets.get(i) : null;
+                                       i++;
                                }
                        } else
-                               pool.submitAllAndWait(getApplyTasks(in, out, 
outputCol));
+                               pool.submitAllAndWait(getApplyTasks(in, out, 
outputCol, nnzOffsets));
                }
                catch(ExecutionException | InterruptedException e) {
                        throw new DMLRuntimeException(e);
@@ -635,7 +645,7 @@ public class MultiColumnEncoder implements Encoder {
                }
        }
 
-       private int[] getSampleIndices(CacheBlock<?> in, int sampleSize, int 
seed, int k){
+       private static int[] getSampleIndices(CacheBlock<?> in, int sampleSize, 
int seed, int k){
                return ComEstSample.getSortedSample(in.getNumRows(), 
sampleSize, seed, k);
        }
 
@@ -659,11 +669,11 @@ public class MultiColumnEncoder implements Encoder {
                return totMemOverhead;
        }
 
-       private static void outputMatrixPreProcessing(MatrixBlock output, 
CacheBlock<?> input, boolean hasDC, boolean hasWE,
-                                                                               
                  int distinctWE, int sizeWE, int numBOW, int[] nnzPerRowBOW, 
int nnz) {
+       private static ArrayList<int[]> outputMatrixPreProcessing(MatrixBlock 
output, CacheBlock<?> input, EncoderMeta encm, long nnz, int k) {
                long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0;
                if(nnz < 0)
-                       nnz = output.getNumRows() * input.getNumColumns();
+                       nnz = (long) output.getNumRows() * 
input.getNumColumns();
+               ArrayList<int[]> bowNnzRowOffsets = null;
                if(output.isInSparseFormat()) {
                        if (MatrixBlock.DEFAULT_SPARSEBLOCK != 
SparseBlock.Type.CSR
                                        && MatrixBlock.DEFAULT_SPARSEBLOCK != 
SparseBlock.Type.MCSR)
@@ -673,7 +683,7 @@ public class MultiColumnEncoder implements Encoder {
                        if (mcsr) {
                                output.allocateBlock();
                                SparseBlock block = output.getSparseBlock();
-                               if (hasDC && 
OptimizerUtils.getTransformNumThreads()>1) {
+                               if (encm.hasDC && 
OptimizerUtils.getTransformNumThreads()>1) {
                                        // DC forces a single threaded 
allocation after the build phase and
                                        // before the apply starts. Below code 
parallelizes sparse allocation.
                                        IntStream.range(0, output.getNumRows())
@@ -695,30 +705,52 @@ public class MultiColumnEncoder implements Encoder {
                                }
                        }
                        else { //csr
-                               SparseBlockCSR csrblock = new 
SparseBlockCSR(output.getNumRows(), nnz, nnz);
                                // Manually fill the row pointers based on 
nnzs/row (= #cols in the input)
                                // Not using the set() methods to 1) avoid 
binary search and shifting, 
                                // 2) reduce thread contentions on the arrays
-                               int[] rptr = csrblock.rowPointers();
-                               if(nnzPerRowBOW != null)
-                                       for (int i=0; i<rptr.length-1; i++) { 
//TODO: parallelize
-                                               int nnzPerRow = 
input.getNumColumns() - numBOW + nnzPerRowBOW[i];
-                                               rptr[i+1] = rptr[i] + nnzPerRow;
-                                       }
-                               else
-                                       for (int i=0; i<rptr.length-1; i++) { 
//TODO: parallelize
-                                               rptr[i+1] = rptr[i] + 
input.getNumColumns();
+                               int nnzInt = (int) nnz;
+                               int[] rptr = new int[output.getNumRows()+1];
+                               // easy case: no bow encoders
+                               // nnz per row = #encoders = #inputCols
+                               if(encm.numBOWEnc <= 0 )
+                                       for (int i = 0; i < rptr.length - 1; 
i++)
+                                               rptr[i + 1] = rptr[i] + 
input.getNumColumns();
+                               else {
+                                       if( encm.nnzPerRowBOW != null) {
+                                               // #nzPerRow has been already 
computed and aggregated for all bow encoders
+                                               int static_offset = 
input.getNumColumns() - encm.numBOWEnc;
+                                               // - #bow since the nnz are 
already counted
+                                               for (int i = 0; i < rptr.length 
- 1; i++) {
+                                                       int nnzPerRow = 
static_offset + encm.nnzPerRowBOW[i];
+                                                       rptr[i + 1] = rptr[i] + 
nnzPerRow;
+                                               }
+                                       } else {
+                                               // case for transform_apply 
where the #nnz ofr bow is unknown yet, since we have no build phase,
+                                               // we have to compute the nnz 
now, we parallelize for now over the #bowEncoders and #rows
+                                               // for the aggregation we 
parallelize just over the number of rows
+                                               bowNnzRowOffsets = 
getNnzPerRowFromBOWEncoders(input, encm, k);
+                                               // the last array contains the 
complete aggregation
+                                               int static_offset = 
input.getNumColumns() - 1;
+                                               // we just subtract -1 since we 
already subtracted -1 for every bow encoder except the first
+                                               int[] aggOffsets = 
bowNnzRowOffsets.get(bowNnzRowOffsets.size() - 1);
+                                               for (int i = 0; i < 
rptr.length-1; i++) {
+                                                       rptr[i+1] = rptr[i] + 
static_offset + aggOffsets[i];
+                                               }
+                                               nnzInt = rptr[rptr.length-1];
                                        }
+                               }
+                               SparseBlockCSR csrblock = new 
SparseBlockCSR(rptr, new int[nnzInt],  new double[nnzInt], nnzInt) ;
                                output.setSparseBlock(csrblock);
+
                        }
                }
                else {
                        // Allocate dense block and set nnz to total #entries
-                       output.allocateDenseBlock(true, hasWE);
-                       if( hasWE){
+                       output.allocateDenseBlock(true, encm.hasWE);
+                       if(encm.hasWE){
                                DenseBlockFP64DEDUP dedup = 
((DenseBlockFP64DEDUP) output.getDenseBlock());
-                               dedup.setDistinct(distinctWE);
-                               dedup.setEmbeddingSize(sizeWE);
+                               dedup.setDistinct(encm.distinctWE);
+                               dedup.setEmbeddingSize(encm.sizeWE);
                        }
                        //output.setAllNonZeros();
                }
@@ -727,6 +759,77 @@ public class MultiColumnEncoder implements Encoder {
                        LOG.debug("Elapsed time for allocation: "+ ((double) 
System.nanoTime() - t0) / 1000000 + " ms");
                        
TransformStatistics.incOutMatrixPreProcessingTime(System.nanoTime()-t0);
                }
+               return bowNnzRowOffsets;
+       }
+
+       private static ArrayList<int[]> 
getNnzPerRowFromBOWEncoders(CacheBlock<?> input, EncoderMeta encm, int k) {
+               ArrayList<int[]> bowNnzRowOffsets;
+               int min_block_size = 1000;
+               int num_blocks = input.getNumRows() / min_block_size;
+               // 1 <= num_blks1 <= k / #enc
+               int num_blks1= Math.min( (k + encm.numBOWEnc - 1)/ 
encm.numBOWEnc, Math.max(num_blocks, 1));
+               int blk_len1 = (input.getNumRows() + num_blks1 - 1) / num_blks1;
+               // 1 <= num_blks2 <= k
+               int num_blks2= Math.min(k, Math.max(num_blocks, 1));
+               int blk_len2 = (input.getNumRows() + num_blks2 - 1) / num_blks1;
+
+               ExecutorService pool = CommonThreadPool.get(k);
+               ArrayList<int[]> bowNnzRowOffsetsFinal = new ArrayList<>();
+               try {
+                       encm.bowEncoders.forEach(e -> e._nnzPerRow = new 
int[input.getNumRows()]);
+                       ArrayList<Future<?>> list = new ArrayList<>();
+                       for (int i = 0; i < num_blks1; i++) {
+                               int start = i * blk_len1;
+                               int end = Math.min((i + 1) * blk_len1, 
input.getNumRows());
+                               list.add(pool.submit(() -> 
encm.bowEncoders.stream().parallel()
+                                       .forEach(e -> e.computeNnzPerRow(input, 
start, end))));
+                       }
+                       for(Future<?> f : list)
+                               f.get();
+                       list.clear();
+                       int[] previous = null;
+                       for(ColumnEncoderComposite enc : encm.encs){
+                               
if(enc.hasEncoder(ColumnEncoderBagOfWords.class)){
+                                       previous = previous == null ? 
+                                               
enc.getEncoder(ColumnEncoderBagOfWords.class)._nnzPerRow :
+                                               new int[input.getNumRows()];
+                               }
+                               bowNnzRowOffsetsFinal.add(previous);
+                       }
+                       for (int i = 0; i < num_blks2; i++) {
+                               int start = i * blk_len1;
+                               list.add(pool.submit(() -> 
aggregateNnzPerRow(start, blk_len2, 
+                                       input.getNumRows(), encm.encs, 
bowNnzRowOffsetsFinal)));
+                       }
+                       for(Future<?> f : list)
+                               f.get();
+                       bowNnzRowOffsets = bowNnzRowOffsetsFinal;
+               }
+               catch(Exception ex) {
+                       throw new DMLRuntimeException(ex);
+               }
+               finally {
+                       pool.shutdown();
+               }
+               return bowNnzRowOffsets;
+       }
+
+       private static void aggregateNnzPerRow(int start, int blk_len, int 
numRows, List<ColumnEncoderComposite> encs, ArrayList<int[]> bowNnzRowOffsets) {
+               int end = Math.min(start + blk_len, numRows);
+               int pos = 0;
+               int[] aggRowOffsets = null;
+               for(ColumnEncoderComposite enc : encs){
+                       int[] currentOffsets = bowNnzRowOffsets.get(pos);
+                       if (enc.hasEncoder(ColumnEncoderBagOfWords.class)) {
+                               ColumnEncoderBagOfWords bow = 
enc.getEncoder(ColumnEncoderBagOfWords.class);
+                               if(aggRowOffsets == null)
+                                       aggRowOffsets = currentOffsets;
+                               else
+                                       for (int i = start; i < end; i++)
+                                               currentOffsets[i] = 
aggRowOffsets[i] + bow._nnzPerRow[i] - 1;
+                       }
+                       pos++;
+               }
        }
 
        private void outputMatrixPostProcessing(MatrixBlock output, int k){
@@ -822,7 +925,7 @@ public class MultiColumnEncoder implements Encoder {
                                        tasks.add(new 
ColumnMetaDataTask<>(columnEncoder, meta));
                                List<Future<Object>> taskret = 
pool.invokeAll(tasks);
                                for (Future<Object> task : taskret)
-                               task.get();
+                                       task.get();
                        }
                        catch(Exception ex) {
                                throw new DMLRuntimeException(ex);
@@ -1021,13 +1124,6 @@ public class MultiColumnEncoder implements Encoder {
                return getEncoderTypes(-1);
        }
 
-       public int getEstNNzRow(){
-               int nnz = 0;
-               for(int i = 0; i < _columnEncoders.size(); i++)
-                       nnz += _columnEncoders.get(i).getDomainSize();
-               return nnz;
-       }
-
        public int getNumOutCols() {
                int sum = 0;
                for(int i = 0; i < _columnEncoders.size(); i++)
@@ -1218,6 +1314,92 @@ public class MultiColumnEncoder implements Encoder {
                return sb.toString();
        }
 
+       private static class EncoderMeta {
+               // contains information about the encoders and their relevant 
data characteristics
+               public final boolean hasUDF;
+               public final boolean hasDC;
+               public final boolean hasWE;
+               public final int distinctWE;
+               public final int sizeWE;
+               public final long nnzBOW;
+               public final int numBOWEnc;
+               public final int[] nnzPerRowBOW;
+               public final ArrayList<ColumnEncoderBagOfWords> bowEncoders;
+               public final List<ColumnEncoderComposite> encs;
+
+               public EncoderMeta(boolean hasUDF, boolean hasDC, boolean 
hasWE, int distinctWE, int sizeWE, long nnzBOW,
+                                                  int numBOWEncoder, int[] 
nnzPerRowBOW, ArrayList<ColumnEncoderBagOfWords> bows,
+                                                  List<ColumnEncoderComposite> 
encoders) {
+                       this.hasUDF = hasUDF;
+                       this.hasDC = hasDC;
+                       this.hasWE = hasWE;
+                       this.distinctWE = distinctWE;
+                       this.sizeWE = sizeWE;
+                       this.nnzBOW = nnzBOW;
+                       this.numBOWEnc = numBOWEncoder;
+                       this.nnzPerRowBOW = nnzPerRowBOW;
+                       this.bowEncoders = bows;
+                       this.encs = encoders;
+               }
+       }
+
+       private static EncoderMeta getEncMeta(List<ColumnEncoderComposite> 
encoders, boolean noBuild, int k, CacheBlock<?> in) {
+               boolean hasUDF = false, hasDC = false, hasWE = false;
+               int distinctWE = 0;
+               int sizeWE = 0;
+               long nnzBOW = 0;
+               int numBOWEncoder = 0;
+               int[] nnzPerRowBOW = null;
+               ArrayList<ColumnEncoderBagOfWords> bows = new ArrayList<>();
+               for (ColumnEncoderComposite enc : encoders){
+                       if(enc.hasEncoder(ColumnEncoderUDF.class))
+                               hasUDF = true;
+                       else if (enc.hasEncoder(ColumnEncoderDummycode.class))
+                               hasDC = true;
+                       else if(enc.hasEncoder(ColumnEncoderBagOfWords.class)){
+                               ColumnEncoderBagOfWords bowEnc = 
enc.getEncoder(ColumnEncoderBagOfWords.class);
+                               numBOWEncoder++;
+                               nnzBOW += bowEnc._nnz;
+                               if(noBuild){
+                                       // estimate nnz by sampling
+                                       bows.add(bowEnc);
+                               } else if(nnzPerRowBOW != null)
+                                       for (int i = 0; i < 
bowEnc._nnzPerRow.length; i++) {
+                                               nnzPerRowBOW[i] += 
bowEnc._nnzPerRow[i];
+                                       }
+                               else {
+                                       nnzPerRowBOW = 
bowEnc._nnzPerRow.clone();
+                               }
+                       }
+                       else 
if(enc.hasEncoder(ColumnEncoderWordEmbedding.class)){
+                               hasWE = true;
+                               distinctWE = 
enc.getEncoder(ColumnEncoderWordEmbedding.class).getNrDistinctEmbeddings();
+                               sizeWE = enc.getDomainSize();
+                       }
+               }
+               if(!bows.isEmpty()){
+                       int[] sampleInds = getSampleIndices(in, in.getNumRows() 
> 1000 ? (int) (0.1 * in.getNumRows()) : in.getNumRows(), (int) 
System.nanoTime(), 1);
+                       // Concurrent (column-wise) bag of words nnz estimation 
per row, we estimate the number of nnz because the
+                       // exact number is only needed for sparse outputs not 
for dense, if sparse, we recount the nnz for all rows later
+                       // Note: the sampling might be problematic since we 
used for the sparsity estimation -> which impacts performance
+                       // if we go for the non-ideal output format
+                       ExecutorService pool = CommonThreadPool.get(k);
+                       try {
+                               Double result = pool.submit(() -> 
bows.stream().parallel()
+                                                       .mapToDouble(e -> 
e.computeNnzEstimate(in, sampleInds))
+                                                       .sum()).get();
+                               nnzBOW = (long) Math.ceil(result);
+                       }
+                       catch(Exception ex) {
+                               throw new DMLRuntimeException(ex);
+                       }
+                       finally{
+                               pool.shutdown();
+                       }
+               }
+        return new EncoderMeta(hasUDF, hasDC, hasWE, distinctWE, sizeWE, 
nnzBOW, numBOWEncoder, nnzPerRowBOW, bows, encoders);
+       }
+
        /*
         * Currently, not in use will be integrated in the future
         */
@@ -1271,41 +1453,12 @@ public class MultiColumnEncoder implements Encoder {
 
                @Override
                public Object call() {
-                       boolean hasUDF = false, hasDC = false, hasWE = false;
-                       int distinctWE = 0;
-                       int sizeWE = 0;
-                       long nnzBOW = 0;
-                       int numBOWEncoder = 0;
-                       int[] nnzPerRowBOW = null;
-                       for (ColumnEncoderComposite enc : 
_encoder.getEncoders()){
-                               if(enc.hasEncoder(ColumnEncoderUDF.class))
-                                       hasUDF = true;
-                               else if 
(enc.hasEncoder(ColumnEncoderDummycode.class))
-                                       hasDC = true;
-                               else 
if(enc.hasEncoder(ColumnEncoderBagOfWords.class)){
-                                       ColumnEncoderBagOfWords bowEnc = 
enc.getEncoder(ColumnEncoderBagOfWords.class);
-                                       numBOWEncoder++;
-                                       nnzBOW += bowEnc.nnz;
-                                       if(nnzPerRowBOW != null)
-                                               for (int i = 0; i < 
bowEnc.nnzPerRow.length; i++) {
-                                                       nnzPerRowBOW[i] += 
bowEnc.nnzPerRow[i];
-                                               }
-                                       else {
-                                               nnzPerRowBOW = 
bowEnc.nnzPerRow.clone();
-                                       }
-                               }
-                               else 
if(enc.hasEncoder(ColumnEncoderWordEmbedding.class)){
-                                       hasWE = true;
-                                       distinctWE = 
enc.getEncoder(ColumnEncoderWordEmbedding.class).getNrDistinctEmbeddings();
-                                       sizeWE = enc.getDomainSize();
-                               }
-                       }
-
+                       EncoderMeta encm = getEncMeta(_encoder.getEncoders(), 
false, -1, _input);
                        int numCols = _encoder.getNumOutCols();
-                       long estNNz = (long) _input.getNumRows() * (hasUDF ? 
numCols : _input.getNumColumns() - numBOWEncoder) + nnzBOW;
-                       boolean sparse = 
MatrixBlock.evalSparseFormatInMemory(_input.getNumRows(), numCols, estNNz) && 
!hasUDF;
+                       long estNNz = (long) _input.getNumRows() * (encm.hasUDF 
? numCols : _input.getNumColumns() - encm.numBOWEnc) + encm.nnzBOW;
+                       boolean sparse = 
MatrixBlock.evalSparseFormatInMemory(_input.getNumRows(), numCols, estNNz) && 
!encm.hasUDF;
                        _output.reset(_input.getNumRows(), numCols, sparse, 
estNNz);
-                       outputMatrixPreProcessing(_output, _input, hasDC, 
hasWE, distinctWE, sizeWE, numBOWEncoder, nnzPerRowBOW, (int) estNNz);
+                       outputMatrixPreProcessing(_output, _input, encm, 
estNNz, 1);
                        return null;
                }
 
@@ -1381,13 +1534,14 @@ public class MultiColumnEncoder implements Encoder {
 
                @Override
                public Object call() throws Exception {
+                       // updates the outputCol offset and sets the nnz 
offsets, which are created by bow encoders, in each encoder
                        int currentCol = -1;
                        int currentOffset = 0;
                        int[] sparseRowPointerOffsets = null;
                        for(DependencyTask<?> dtask : _applyTasksWrappers) {
                                ((ApplyTasksWrapperTask) 
dtask).setOffset(currentOffset);
                                if(sparseRowPointerOffsets != null)
-                                       ((ApplyTasksWrapperTask) 
dtask).setSparseRowPointerOffsets(sparseRowPointerOffsets.clone());
+                                       ((ApplyTasksWrapperTask) 
dtask).setSparseRowPointerOffsets(sparseRowPointerOffsets);
                                int nonOffsetCol = ((ApplyTasksWrapperTask) 
dtask)._encoder._colID - 1;
                                if(nonOffsetCol > currentCol) {
                                        currentCol = nonOffsetCol;
@@ -1398,11 +1552,14 @@ public class MultiColumnEncoder implements Encoder {
                                                ColumnEncoderBagOfWords bow = 
enc.getEncoder(ColumnEncoderBagOfWords.class);
                                                currentOffset += 
bow.getDomainSize() - 1;
                                                if(sparseRowPointerOffsets == 
null)
-                                                       sparseRowPointerOffsets 
= bow.nnzPerRow.clone();
-                                               else
+                                                       sparseRowPointerOffsets 
= bow._nnzPerRow;
+                                               else{
+                                                       sparseRowPointerOffsets 
= sparseRowPointerOffsets.clone();
+                                                       // TODO: experiment if 
it makes sense to parallize here (for frames with many rows)
                                                        for (int r = 0; r < 
sparseRowPointerOffsets.length; r++) {
-                                                               
sparseRowPointerOffsets[r] += bow.nnzPerRow[r] - 1;
+                                                               
sparseRowPointerOffsets[r] += bow._nnzPerRow[r] - 1;
                                                        }
+                                               }
                                        }
                                }
                        }
diff --git 
a/src/test/java/org/apache/sysds/test/functions/transform/ColumnEncoderMixedFunctionalityTests.java
 
b/src/test/java/org/apache/sysds/test/functions/transform/ColumnEncoderMixedFunctionalityTests.java
new file mode 100644
index 0000000000..3537998d99
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/functions/transform/ColumnEncoderMixedFunctionalityTests.java
@@ -0,0 +1,116 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.transform;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.frame.data.FrameBlock;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.transform.encode.ColumnEncoder;
+import org.apache.sysds.runtime.transform.encode.ColumnEncoderBagOfWords;
+import org.apache.sysds.runtime.transform.encode.ColumnEncoderComposite;
+import org.apache.sysds.runtime.transform.encode.EncoderFactory;
+import org.apache.sysds.runtime.transform.encode.EncoderOmit;
+import org.apache.sysds.runtime.transform.encode.MultiColumnEncoder;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import static org.junit.Assert.assertThrows;
+import static org.junit.Assert.assertTrue;
+
+public class ColumnEncoderMixedFunctionalityTests extends AutomatedTestBase
+{
+       @Override
+       public void setUp() {
+               TestUtils.clearAssertionInformation();
+       }
+
+       @Test
+       public void testCompositeConstructor1() {
+               ColumnEncoderComposite cEnc1 = new ColumnEncoderComposite(null, 
1);
+               ColumnEncoderComposite cEnc2 = new 
ColumnEncoderComposite(cEnc1);
+               assert cEnc1.getColID() == cEnc2.getColID();
+
+       }
+       @Test
+       public void testCompositeConstructor2() {
+               List<ColumnEncoder> encoderList = new ArrayList<>();
+               encoderList.add( new ColumnEncoderComposite(null, 1));
+               encoderList.add( new ColumnEncoderComposite(null, 2));
+               DMLRuntimeException e = assertThrows(DMLRuntimeException.class, 
() ->  new ColumnEncoderComposite(encoderList, null));
+               assertTrue(e.getMessage().contains("Tried to create Composite 
Encoder with no encoders or mismatching columnIDs"));
+       }
+
+       @Test
+       public void testEncoderFactoryGetUnsupportedEncoderType(){
+               DMLRuntimeException e = assertThrows(DMLRuntimeException.class, 
() ->  EncoderFactory.getEncoderType(new ColumnEncoderComposite()));
+               assertTrue(e.getMessage().contains("Unsupported encoder type: 
org.apache.sysds.runtime.transform.encode.ColumnEncoderComposite"));
+       }
+
+       @Test
+       public void testEncoderFactoryCreateUnsupportedInstanceType(){
+               // type(7) = composite, which we don't use for encoding the type
+               DMLRuntimeException e = assertThrows(DMLRuntimeException.class, 
() ->  EncoderFactory.createInstance(7));
+               assertTrue(e.getMessage().contains("Unsupported encoder type: 
Composite"));
+       }
+
+       @Test
+       public void testMultiColumnEncoderApplyWithWrongInputCharacteristics1(){
+               // apply call without metadata about encoders
+               MultiColumnEncoder mEnc = new MultiColumnEncoder();
+               DMLRuntimeException e = assertThrows(DMLRuntimeException.class, 
() -> mEnc.apply(null, null, 0));
+               assertTrue(e.getMessage().contains("MultiColumnEncoder apply 
without Encoder Characteristics should not be called directly"));
+       }
+
+       @Test
+       public void testMultiColumnEncoderApplyWithWrongInputCharacteristics2(){
+               // apply with LegacyEncoders + non FrameBlock Inputs
+               MultiColumnEncoder mEnc = new MultiColumnEncoder();
+               mEnc.addReplaceLegacyEncoder(new EncoderOmit());
+               DMLRuntimeException e = assertThrows(DMLRuntimeException.class, 
() -> mEnc.apply(new MatrixBlock(), null, 0, 0, null, 0L));
+               assertTrue(e.getMessage().contains("LegacyEncoders do not 
support non FrameBlock Inputs"));
+       }
+
+       @Test
+       public void testMultiColumnEncoderApplyWithWrongInputCharacteristics3(){
+               // #CompositeEncoders != #cols
+               ArrayList<ColumnEncoder> encs = new ArrayList<>();
+               encs.add(new ColumnEncoderBagOfWords());
+               ArrayList<ColumnEncoderComposite> cEncs = new ArrayList<>();
+               cEncs.add(new ColumnEncoderComposite(encs));
+               MultiColumnEncoder mEnc = new MultiColumnEncoder(cEncs);
+               FrameBlock in = new FrameBlock(2, Types.ValueType.FP64);
+               DMLRuntimeException e = assertThrows(DMLRuntimeException.class, 
() -> mEnc.apply(in, null, 0, 0, null, 0L));
+               assertTrue(e.getMessage().contains("Not every column in has a 
CompositeEncoder. Please make sure every column has a encoder or slice the 
input accordingly"));
+       }
+
+       @Test
+       public void testMultiColumnEncoderApplyWithWrongInputCharacteristics4(){
+               // input has 0 rows
+               MultiColumnEncoder mEnc = new MultiColumnEncoder();
+               MatrixBlock in = new MatrixBlock();
+               DMLRuntimeException e = assertThrows(DMLRuntimeException.class, 
() -> mEnc.apply(in, null, 0, 0, null, 0L));
+               assertTrue(e.getMessage().contains("Invalid input with wrong 
number or rows"));
+       }
+}
diff --git 
a/src/test/java/org/apache/sysds/test/functions/transform/ColumnEncoderSerializationTest.java
 
b/src/test/java/org/apache/sysds/test/functions/transform/ColumnEncoderSerializationTest.java
index aeec927e73..2bd1e64697 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/transform/ColumnEncoderSerializationTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/transform/ColumnEncoderSerializationTest.java
@@ -25,11 +25,14 @@ import java.io.IOException;
 import java.io.ObjectInput;
 import java.io.ObjectInputStream;
 import java.io.ObjectOutputStream;
+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
 
 import org.apache.sysds.common.Types;
 import org.apache.sysds.runtime.frame.data.FrameBlock;
 import org.apache.sysds.runtime.transform.encode.ColumnEncoder;
+import org.apache.sysds.runtime.transform.encode.ColumnEncoderBagOfWords;
 import org.apache.sysds.runtime.transform.encode.ColumnEncoderComposite;
 import org.apache.sysds.runtime.transform.encode.EncoderFactory;
 import org.apache.sysds.runtime.transform.encode.MultiColumnEncoder;
@@ -56,7 +59,8 @@ public class ColumnEncoderSerializationTest extends 
AutomatedTestBase
                RECODE,
                DUMMY,
                IMPUTE,
-               OMIT
+               OMIT,
+               BOW
        }
 
        @Override
@@ -88,6 +92,12 @@ public class ColumnEncoderSerializationTest extends 
AutomatedTestBase
        @Test
        public void testComposite8() { runTransformSerTest(TransformType.OMIT, 
schemaStrings); }
 
+       @Test
+       public void testComposite9() { runTransformSerTest(TransformType.BOW, 
schemaStrings); }
+
+       @Test
+       public void testComposite10() { runTransformSerTest(TransformType.BOW, 
schemaMixed); }
+
 
 
 
@@ -117,11 +127,21 @@ public class ColumnEncoderSerializationTest extends 
AutomatedTestBase
                                        "{ \"id\": 7, \"method\": 
\"global_mode\" }, { \"id\": 9, \"method\": \"global_mean\" } ]\n\n}";
                else if (type == TransformType.OMIT)
                        spec = "{ \"ids\": true, \"omit\": [ 1,2,4,5,6,7,8,9 ], 
\"recode\": [ 2, 7 ] }";
+               else if (type == TransformType.BOW)
+                       spec = "{ \"ids\": true, \"omit\": [ 1,4,5,6,8,9 ], 
\"bag_of_words\": [ 2, 7 ] }";
 
                frame.setSchema(schema);
                String[] cnames = frame.getColumnNames();
 
                MultiColumnEncoder encoderIn = 
EncoderFactory.createEncoder(spec, cnames, frame.getNumColumns(), null);
+               if(type == TransformType.BOW){
+                       List<ColumnEncoderBagOfWords> encs = 
encoderIn.getColumnEncoders(ColumnEncoderBagOfWords.class);
+                       HashMap<Object, Long> dict = new HashMap<>();
+                       dict.put("val1", 1L);
+                       dict.put("val2", 2L);
+                       dict.put("val3", 300L);
+                       encs.forEach(e -> e.setTokenDictionary(dict));
+               }
                MultiColumnEncoder encoderOut;
 
                // serialization and deserialization
@@ -141,7 +161,16 @@ public class ColumnEncoderSerializationTest extends 
AutomatedTestBase
                for(Class<? extends ColumnEncoder> classtype: typesIn){
                        
Assert.assertArrayEquals(encoderIn.getFromAllIntArray(classtype, 
ColumnEncoder::getColID), encoderOut.getFromAllIntArray(classtype, 
ColumnEncoder::getColID));
                }
-
+               if(type == TransformType.BOW){
+                       List<ColumnEncoderBagOfWords> encsIn = 
encoderIn.getColumnEncoders(ColumnEncoderBagOfWords.class);
+                       List<ColumnEncoderBagOfWords> encsOut = 
encoderOut.getColumnEncoders(ColumnEncoderBagOfWords.class);
+                       for (int i = 0; i < encsIn.size(); i++) {
+                               Map<Object, Long> eOutDict = 
encsOut.get(i).getTokenDictionary();
+                               
encsIn.get(i).getTokenDictionary().forEach((k,v) -> {
+                                       assert v.equals(eOutDict.get(k));
+                               });
+                       }
+               }
 
        }
 
diff --git 
a/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeBagOfWords.java
 
b/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeBagOfWords.java
index e3d1c07be2..1965d743f7 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeBagOfWords.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeBagOfWords.java
@@ -19,13 +19,13 @@
 
 package org.apache.sysds.test.functions.transform;
 
-import org.apache.commons.lang3.NotImplementedException;
 import org.apache.sysds.common.Types;
 import org.apache.sysds.common.Types.ExecMode;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.frame.data.FrameBlock;
 import org.apache.sysds.runtime.matrix.data.MatrixValue;
 import org.apache.sysds.runtime.transform.encode.ColumnEncoderBagOfWords;
+import org.apache.sysds.runtime.transform.encode.MultiColumnEncoder;
 import org.apache.sysds.test.AutomatedTestBase;
 import org.apache.sysds.test.TestConfiguration;
 import org.apache.sysds.test.TestUtils;
@@ -36,17 +36,18 @@ import java.io.IOException;
 import java.nio.file.Path;
 import java.nio.file.Paths;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.nio.file.Files;
 
 import static 
org.apache.sysds.runtime.transform.encode.ColumnEncoderBagOfWords.tokenize;
-import static org.junit.Assert.assertThrows;
 
 public class TransformFrameEncodeBagOfWords extends AutomatedTestBase
 {
        private final static String TEST_NAME1 = 
"TransformFrameEncodeBagOfWords";
+       private final static String TEST_NAME2 = 
"TransformFrameEncodeApplyBagOfWords";
        private final static String TEST_DIR = "functions/transform/";
        private final static String TEST_CLASS_DIR = TEST_DIR + 
TransformFrameEncodeBagOfWords.class.getSimpleName() + "/";
        // for benchmarking: Digital_Music_Text.csv
@@ -56,6 +57,7 @@ public class TransformFrameEncodeBagOfWords extends 
AutomatedTestBase
        public void setUp() {
                TestUtils.clearAssertionInformation();
                addTestConfiguration(TEST_NAME1, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME1));
+               addTestConfiguration(TEST_NAME2, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME2));
        }
 
        // These tests result in dense output
@@ -63,6 +65,17 @@ public class TransformFrameEncodeBagOfWords extends 
AutomatedTestBase
        public void testTransformBagOfWords() {
                runTransformTest(TEST_NAME1, ExecMode.SINGLE_NODE, false, 
false);
        }
+       @Test
+       public void testTransformApplyBagOfWords() {
+               runTransformTest(TEST_NAME2, ExecMode.SINGLE_NODE, false, 
false);
+       }
+
+       @Test
+       public void testTransformApplySeparateStagesBagOfWords() {
+               MultiColumnEncoder.APPLY_ENCODER_SEPARATE_STAGES = true;
+               runTransformTest(TEST_NAME2, ExecMode.SINGLE_NODE, false, 
false);
+               MultiColumnEncoder.APPLY_ENCODER_SEPARATE_STAGES = false;
+       }
 
        @Test
        public void testTransformBagOfWordsError() {
@@ -74,32 +87,62 @@ public class TransformFrameEncodeBagOfWords extends 
AutomatedTestBase
                runTransformTest(TEST_NAME1, ExecMode.SINGLE_NODE, true, false);
        }
 
+       @Test
+       public void testTransformApplyBagOfWordsPlusRecode() {
+               runTransformTest(TEST_NAME2, ExecMode.SINGLE_NODE, true, false);
+       }
+
        @Test
        public void testTransformBagOfWords2() {
                runTransformTest(TEST_NAME1, ExecMode.SINGLE_NODE, false, true);
        }
 
+       @Test
+       public void testTransformApplyBagOfWords2() {
+               runTransformTest(TEST_NAME2, ExecMode.SINGLE_NODE, false, true);
+       }
+
        @Test
        public void testTransformBagOfWordsPlusRecode2() {
                runTransformTest(TEST_NAME1, ExecMode.SINGLE_NODE, true, true);
        }
 
+       @Test
+       public void testTransformApplyBagOfWordsPlusRecode2() {
+               runTransformTest(TEST_NAME2, ExecMode.SINGLE_NODE, true, true);
+       }
+
        // AmazonReviewDataset transformation results in a sparse output
        @Test
        public void testTransformBagOfWordsAmazonReviews() {
                runTransformTest(TEST_NAME1, ExecMode.SINGLE_NODE, false, 
false, true);
        }
 
+       @Test
+       public void testTransformApplyBagOfWordsAmazonReviews() {
+               runTransformTest(TEST_NAME2, ExecMode.SINGLE_NODE, false, 
false, true);
+       }
+
        @Test
        public void testTransformBagOfWordsAmazonReviews2() {
                runTransformTest(TEST_NAME1, ExecMode.SINGLE_NODE, false, true, 
true);
        }
 
+       @Test
+       public void testTransformApplyBagOfWordsAmazonReviews2() {
+               runTransformTest(TEST_NAME2, ExecMode.SINGLE_NODE, false, true, 
true);
+       }
+
        @Test
        public void testTransformBagOfWordsAmazonReviewsAndRandRecode() {
                runTransformTest(TEST_NAME1, ExecMode.SINGLE_NODE, true, false, 
true);
        }
 
+       @Test
+       public void testTransformApplyBagOfWordsAmazonReviewsAndRandRecode() {
+               runTransformTest(TEST_NAME2, ExecMode.SINGLE_NODE, true, false, 
true);
+       }
+
        @Test
        public void testTransformBagOfWordsAmazonReviewsAndDummyCode() {
                // TODO: compare result
@@ -118,16 +161,55 @@ public class TransformFrameEncodeBagOfWords extends 
AutomatedTestBase
        }
 
        @Test
-       public void testNotImplementedFunction(){
-               ColumnEncoderBagOfWords bow = new ColumnEncoderBagOfWords();
-               assertThrows(NotImplementedException.class, () -> 
bow.initMetaData(null));
+       public void testTransformApplyBagOfWordsAmazonReviewsAndRandRecode2() {
+               runTransformTest(TEST_NAME2, ExecMode.SINGLE_NODE, true, true, 
true);
+       }
+
+       @Test
+       public void 
testTransformApplySeparateStagesBagOfWordsAmazonReviewsAndRandRecode2() {
+               MultiColumnEncoder.APPLY_ENCODER_SEPARATE_STAGES = true;
+               runTransformTest(TEST_NAME2, ExecMode.SINGLE_NODE, true, true, 
true);
+               MultiColumnEncoder.APPLY_ENCODER_SEPARATE_STAGES = false;
        }
 
-       //@Test
+
+       @Test
        public void testTransformBagOfWordsSpark() {
                runTransformTest(TEST_NAME1, ExecMode.SPARK, false, false);
        }
 
+       @Test
+       public void testTransformBagOfWordsAmazonReviewsSpark() {
+               runTransformTest(TEST_NAME1, ExecMode.SPARK, false, false, 
true);
+       }
+
+       @Test
+       public void testTransformBagOfWordsAmazonReviews2Spark() {
+               runTransformTest(TEST_NAME1, ExecMode.SPARK, false, true, true);
+       }
+
+       @Test
+       public void testTransformBagOfWordsAmazonReviewsAndRandRecodeSpark() {
+               runTransformTest(TEST_NAME1, ExecMode.SPARK, true, false, true);
+       }
+
+       @Test
+       public void testTransformBagOfWordsAmazonReviewsAndRandRecode2Spark() {
+               runTransformTest(TEST_NAME1, ExecMode.SPARK, true, true, true);
+       }
+
+       @Test
+       public void testBuildPartialBagOfWordsNotApplicable() {
+               ColumnEncoderBagOfWords bow = new ColumnEncoderBagOfWords();
+               assert bow.getColID() == -1;
+               try {
+                       bow.buildPartial(null); // should run without error
+               } catch (Exception e) {
+                       throw new AssertionError("Test failed: Expected no 
errors due to early abort (colId = -1). " +
+                                       "Encountered exception:\n" + e + 
"\nMessage: " + Arrays.toString(e.getStackTrace()));
+               }
+       }
+
        private void runTransformTest(String testname, ExecMode rt, boolean 
recode, boolean dup){
                runTransformTest(testname, rt, recode, dup, false);
        }
@@ -154,31 +236,31 @@ public class TransformFrameEncodeBagOfWords extends 
AutomatedTestBase
                        if(!fromFile)
                                writeStringsToCsvFile(sentenceColumn, 
recodeColumn, baseDirectory + INPUT_DIR + "data", dup);
 
-                       int mode = 0;
-                       if(error)
-                               mode = 1;
-                       if(dc)
-                               mode = 2;
-                       if(pt)
-                               mode = 3;
-                       programArgs = new String[]{"-stats","-args", fromFile ? 
DATASET_DIR + DATASET : input("data"),
+                       int mode = error ? 1 : (dc ? 2 : (pt ? 3 : 0));
+                       programArgs = new String[]{"-explain", 
"recompile_runtime", "-stats","-args", fromFile ? DATASET_DIR + DATASET : 
input("data"),
                                        output("result"), output("dict"), 
String.valueOf(recode), String.valueOf(dup),
                                        String.valueOf(fromFile), 
String.valueOf(mode)};
                        if(error)
                                runTest(true, EXCEPTION_EXPECTED, 
DMLRuntimeException.class, -1);
                        else{
                                runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
-                               FrameBlock dict_frame = readDMLFrameFromHDFS( 
"dict", Types.FileFormat.CSV);
-                               int cols = recode? dict_frame.getNumRows() + 1 
: dict_frame.getNumRows();
-                               if(dup)
-                                       cols *= 2;
-                               if(mode == 0){
-                                       HashMap<MatrixValue.CellIndex, Double> 
res_actual = readDMLMatrixFromOutputDir("result");
-                                       double[][] result = 
TestUtils.convertHashMapToDoubleArray(res_actual, 
Math.min(sentenceColumn.length, 100),
-                                                       cols);
-                                       checkResults(sentenceColumn, result, 
recodeColumn, dict_frame, dup ? 2 : 1);
+                               if(testname == TEST_NAME2){
+                                       double errorValue = 
readDMLScalarFromOutputDir("result").values()
+                                                       
.stream().findFirst().orElse(1000.0);
+                                       System.out.println(errorValue);
+                                       assert errorValue <= 10;
+                               } else {
+                                       FrameBlock dict_frame = 
readDMLFrameFromHDFS( "dict", Types.FileFormat.CSV);
+                                       int cols = recode? 
dict_frame.getNumRows() + 1 : dict_frame.getNumRows();
+                                       if(dup)
+                                               cols *= 2;
+                                       if(mode == 0){
+                                               HashMap<MatrixValue.CellIndex, 
Double> res_actual = readDMLMatrixFromOutputDir("result");
+                                               double[][] result = 
TestUtils.convertHashMapToDoubleArray(res_actual, 
Math.min(sentenceColumn.length, 100),
+                                                               cols);
+                                               checkResults(sentenceColumn, 
result, recodeColumn, dict_frame, dup ? 2 : 1);
+                                       }
                                }
-
                        }
 
 
@@ -211,6 +293,9 @@ public class TransformFrameEncodeBagOfWords extends 
AutomatedTestBase
        {
                HashMap<String, Integer>[] indices = new HashMap[duplicates];
                HashMap<String, Integer>[] rcdMaps = new HashMap[duplicates];
+               String errors = "";
+               int num_errors = 0;
+               int max_errors = 100;
                int frameCol = 0;
                // even when the set of tokens is the same for duplicates, the 
order in which the tokens dicts are merged
                // is not always the same for all columns in multithreaded mode
@@ -219,8 +304,9 @@ public class TransformFrameEncodeBagOfWords extends 
AutomatedTestBase
                        rcdMaps[i] = new HashMap<>();
                        for (int j = 0; j < dict.getNumRows(); j++) {
                                String[] tuple = dict.getString(j, 
frameCol).split("\u00b7");
-                               indices[i].put(tuple[0], 
Integer.parseInt(tuple[1]));
+                               indices[i].put(tuple[0], 
Integer.parseInt(tuple[1]) - 1);
                        }
+                       System.out.println("Bow dict size: " + 
indices[i].size());
                        frameCol++;
                        if(recodeColumn != null){
                                for (int j = 0; j < dict.getNumRows(); j++) {
@@ -232,6 +318,7 @@ public class TransformFrameEncodeBagOfWords extends 
AutomatedTestBase
                                }
                                frameCol++;
                        }
+                       System.out.println("Rec dict size: " + 
rcdMaps[i].size());
                }
 
                // only check the first 100 rows
@@ -266,19 +353,49 @@ public class TransformFrameEncodeBagOfWords extends 
AutomatedTestBase
                                for(Map.Entry<String, Integer> entry : 
count.entrySet()){
                                        String word = entry.getKey();
                                        int count_expected = entry.getValue();
-                                       int index = indices[j].get(word);
-                                       assert result[row][index + offset] == 
count_expected;
+                                       Integer index = indices[j].get(word);
+                                       if(index == null){
+                                               throw new AssertionError("row 
[" + row + "]: not found word: " + word);
+                                       }
+                                       if(result[row][index + offset] != 
count_expected){
+                                               String error_message = "bow 
result[" + row + "," + (index + offset) + "]=" +
+                                                               
result[row][index + offset] + " does not match the expected: " + count_expected;
+                                               if(num_errors < max_errors)
+                                                       errors += error_message 
+ '\n';
+                                               else
+                                                       throw new 
AssertionError(errors + error_message);
+                                               num_errors++;
+                                       }
+                               }
+                               for(int zeroIndex : zeroIndices){
+                                       if(result[row][offset + zeroIndex] != 
0){
+                                               String error_message = "bow 
result[" + row + "," + (offset + zeroIndex) + "]=" +
+                                                               
result[row][offset + zeroIndex] + " does not match the expected: 0";
+                                               if(num_errors < max_errors)
+                                                       errors += error_message 
+ '\n';
+                                               else
+                                                       throw new 
AssertionError(errors + error_message);
+                                               num_errors++;
+                                       }
                                }
-                               for(int zeroIndex : zeroIndices)
-                                       assert result[row][offset + zeroIndex] 
== 0;
                                offset += indices[j].size();
                                // compare results: recode
                                if(recodeColumn != null){
-                                       assert result[row][offset] == 
rcdMaps[j].get(recodeColumn[row]);
+                                       if(result[row][offset] != 
rcdMaps[j].get(recodeColumn[row])){
+                                               String error_message = "recode 
result[" + row + "," + offset + "]=" +
+                                                               
result[row][offset]+ " does not match the expected: " + 
rcdMaps[j].get(recodeColumn[row]);
+                                               if(num_errors < max_errors)
+                                                       errors += error_message 
+ '\n';
+                                               else
+                                                       throw new 
AssertionError(errors + error_message);
+                                               num_errors++;
+                                       }
                                        offset++;
                                }
                        }
                }
+               if (num_errors > 0)
+                       throw new AssertionError(errors);
        }
 
        public static void writeStringsToCsvFile(String[] sentences, String[] 
recodeTokens, String fileName, boolean duplicate) throws IOException {
diff --git 
a/src/test/scripts/functions/transform/TransformFrameEncodeBagOfWords.dml 
b/src/test/scripts/functions/transform/TransformFrameEncodeApplyBagOfWords.dml
similarity index 85%
copy from 
src/test/scripts/functions/transform/TransformFrameEncodeBagOfWords.dml
copy to 
src/test/scripts/functions/transform/TransformFrameEncodeApplyBagOfWords.dml
index 49231c97c0..b2cd0ccda2 100644
--- a/src/test/scripts/functions/transform/TransformFrameEncodeBagOfWords.dml
+++ 
b/src/test/scripts/functions/transform/TransformFrameEncodeApplyBagOfWords.dml
@@ -55,13 +55,16 @@ if(as.integer($7) == 3){
   jspec = "{ids: true, bag_of_words: [1]}";
 }
 
+[Data_enc, Meta] = transformencode(target=Data, spec=jspec);
+while(FALSE){}
+
 i = 0
 total = 0
 j = 0
 # set to 20 for benchmarking
-while(i < 1){
+while(i < 30){
   t0 = time()
-  [Data_enc, Meta] = transformencode(target=Data, spec=jspec);
+  Data_enc2 = transformapply(target=Data, spec=jspec, meta=Meta)
   if(i > 10){
     total = total + time() - t0
     j = j + 1
@@ -69,10 +72,13 @@ while(i < 1){
   i = i + 1
 }
 print(total/1000000000 / j)
-print(nrow(Data_enc) + " x " + ncol(Data_enc))
-#reduce nr rows for large input tests
-if(nrow(Data_enc) > 100){
-  Data_enc = Data_enc[1:100,]
-}
-write(Data_enc, $2, format="text");
-write(Meta, $3, format="csv");
+
+i = 0
+
+Error = sign(Data_enc2 - Data_enc)
+Error_agg = sum(Error * Error)
+#print(sum(sign(Data_enc2)))
+#print(sum(sign(Data_enc)))
+#print(Error_agg)
+write(Error_agg, $2, format="text");
+
diff --git 
a/src/test/scripts/functions/transform/TransformFrameEncodeBagOfWords.dml 
b/src/test/scripts/functions/transform/TransformFrameEncodeBagOfWords.dml
index 49231c97c0..2a69f314a4 100644
--- a/src/test/scripts/functions/transform/TransformFrameEncodeBagOfWords.dml
+++ b/src/test/scripts/functions/transform/TransformFrameEncodeBagOfWords.dml
@@ -21,15 +21,15 @@
 
 # Read the token sequence (1K) w/ 100 distinct tokens
 Data = read($1, data_type="frame", format="csv");
+#print(toString(Data))
 
 if(!as.boolean($4) & as.boolean($6)){
   Data = Data[,1]
 }
+while(FALSE){}
 if(as.boolean($5) & as.boolean($6)){
   Data = cbind(Data,Data)
 }
-while(FALSE){}
-
 if (as.boolean($4)) {
   if (as.boolean($5)) {
     jspec = "{ids: true, bag_of_words: [1,3], recode : [2,4]}";
@@ -54,7 +54,7 @@ if(as.integer($7) == 3){
   Data = cbind(Data, ones)
   jspec = "{ids: true, bag_of_words: [1]}";
 }
-
+while(FALSE){}
 i = 0
 total = 0
 j = 0

Reply via email to