This is an automated email from the ASF dual-hosted git repository.

arnabp20 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 23935c2a20 [SYSTEMDS-3581] New dense block to deduplicate rows
23935c2a20 is described below

commit 23935c2a2000778ab67b1eb24a4aab85482de783
Author: e-strauss <[email protected]>
AuthorDate: Thu Jun 15 12:04:58 2023 +0200

    [SYSTEMDS-3581] New dense block to deduplicate rows
    
    This patch adds a new dense block that deduplicates identical
    rows by letting them point to the same row, which is allocated once.
    Internally, it holds a 2D array whose outer entries are the
    pointers to the rows. The use case is word embedding, which produces
    many duplicate dense rows.
    
    Closes #1842
---
 .../sysds/runtime/data/DenseBlockFP64DEDUP.java    | 253 +++++++++++++++++++++
 .../sysds/runtime/data/DenseBlockFactory.java      |  35 ++-
 .../sysds/runtime/data/DenseBlockLFP64DEDUP.java   | 241 ++++++++++++++++++++
 .../sysds/runtime/matrix/data/LibMatrixMult.java   |   2 +-
 .../sysds/runtime/matrix/data/MatrixBlock.java     |  22 +-
 .../transform/encode/ColumnEncoderRecode.java      |   9 +
 .../encode/ColumnEncoderWordEmbedding.java         |  69 ++++--
 .../runtime/transform/encode/EncoderFactory.java   |  10 +-
 .../transform/encode/MultiColumnEncoder.java       |  17 +-
 ...=> TransformFrameEncodeWordEmbedding1Test.java} |   8 +-
 .../TransformFrameEncodeWordEmbedding2Test.java    |  66 ++++--
 .../TransformFrameEncodeWordEmbeddings.dml         |   2 +-
 .../TransformFrameEncodeWordEmbeddings2.dml        |   3 +-
 13 files changed, 679 insertions(+), 58 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/data/DenseBlockFP64DEDUP.java 
b/src/main/java/org/apache/sysds/runtime/data/DenseBlockFP64DEDUP.java
new file mode 100644
index 0000000000..bf0e83ec5b
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/data/DenseBlockFP64DEDUP.java
@@ -0,0 +1,253 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.data;
+
+import org.apache.commons.lang.NotImplementedException;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.util.UtilFunctions;
+
+import java.util.Arrays;
+import java.util.HashMap;
+
+public class DenseBlockFP64DEDUP extends DenseBlockDRB{
+    private double[][] _data;
+
+    protected DenseBlockFP64DEDUP(int[] dims) {
+        super(dims);
+        reset(_rlen, _odims, 0);
+    }
+
+    @Override
+    protected void allocateBlock(int bix, int length) {
+        _data[bix] = new double[length];
+    }
+
+    /**
+     * Resets this block to {@code rlen} rows of {@code odims[0]} columns, with
+     * every cell logically holding {@code v}.
+     *
+     * <p>Zero rows are represented by {@code null} row pointers, so a reset to
+     * 0 allocates no row arrays. For a non-zero {@code v} every row is
+     * materialized and filled, including when the row-pointer array had to be
+     * re-allocated because the block grew.
+     *
+     * @param rlen  number of rows after the reset
+     * @param odims dimensions other than the row dimension; {@code odims[0]} is
+     *              used as the row length
+     * @param v     value every cell is set to
+     */
+    @Override
+    public void reset(int rlen, int[] odims, double v) {
+        // compare against the pointer array directly instead of deriving the
+        // row capacity via capacity() / _odims[0], which fails for 0 columns
+        if(_data == null || rlen > _data.length) {
+            // not enough row pointers: start from a fresh, all-null array
+            _data = new double[rlen][];
+            if(v != 0.0) {
+                // a non-zero fill value has to be materialized in the new rows too
+                for(int i = 0; i < rlen; i++) {
+                    allocateBlock(i, odims[0]);
+                    Arrays.fill(_data[i], 0, odims[0], v);
+                }
+            }
+        }
+        else if(v == 0.0) {
+            // dropping the pointers is enough, null rows read as zeros
+            for(int i = 0; i < rlen; i++)
+                _data[i] = null;
+        }
+        else {
+            for(int i = 0; i < rlen; i++) {
+                // re-allocate rows that are missing or were sized for fewer columns
+                if(odims[0] > _odims[0] || _data[i] == null)
+                    allocateBlock(i, odims[0]);
+                Arrays.fill(_data[i], 0, odims[0], v);
+            }
+        }
+        _rlen = rlen;
+        _odims = odims;
+    }
+
+    @Override
+    public void resetNoFill(int rlen, int[] odims) {
+        if(_data == null || rlen > _rlen){
+            _data = new double[rlen][];
+        }
+        _rlen = rlen;
+        _odims = odims;
+    }
+
+    @Override
+    public boolean isNumeric() {
+        return true;
+    }
+
+    @Override
+    public boolean isNumeric(Types.ValueType vt) {
+        return Types.ValueType.FP64 == vt;
+    }
+
+    @Override
+    public long capacity() {
+        return (_data != null) ? _data.length*_odims[0] : -1;
+    }
+
+    @Override
+    public long countNonZeros(){
+        long nnz = 0;
+        HashMap<double[], Long> cache = new HashMap<double[], Long>();
+        for (int i = 0; i < _rlen; i++) {
+            double[] row = this._data[i];
+            if(row == null)
+                continue;
+            Long count = cache.getOrDefault(row, null);
+            if(count == null){
+                count = Long.valueOf(countNonZeros(i));
+                cache.put(row, count);
+            }
+            nnz += count;
+        }
+        return nnz;
+    }
+
+    @Override
+    public int countNonZeros(int r) {
+        return _data[r] == null ? 0 : UtilFunctions.computeNnz(_data[r], 0, 
_odims[0]);
+    }
+
+    @Override
+    protected long computeNnz(int bix, int start, int length) {
+        int nnz = 0;
+        int row_start = (int) Math.floor(start / _odims[0]);
+        int col_start = start % _odims[0];
+        for (int i = 0; i < length; i++) {
+            if(_data[row_start] == null){
+                i += _odims[0] - 1 - col_start;
+                col_start = 0;
+                row_start += 1;
+                continue;
+            }
+            nnz += _data[row_start][col_start] != 0 ? 1 : 0;
+            col_start += 1;
+            if(col_start == _odims[0]) {
+                col_start = 0;
+                row_start += 1;
+            }
+        }
+        return nnz;
+    }
+
+    @Override
+    public int pos(int r){
+        return 0;
+    }
+
+    @Override
+    public int pos(int[] ix){
+        int pos = ix[ix.length - 1];
+        for(int i = 1; i < ix.length - 1; i++)
+            pos += ix[i] * _odims[i];
+        return pos;
+    }
+
+    @Override
+    public double[] values(int r) {
+        return valuesAt(r);
+    }
+
+    @Override
+    public double[] valuesAt(int bix) {
+        return _data[bix] == null ? new double[_odims[0]] : _data[bix];
+    }
+
+    @Override
+    public int index(int r) {
+        return r;
+    }
+
+    @Override
+    public int numBlocks(){
+        return _data.length;
+    }
+
+    @Override
+    public int size(int bix) {
+        return _odims[0];
+    }
+
+    @Override
+    public void incr(int r, int c) {
+        incr(r,c,1.0);
+    }
+
+    @Override
+    public void incr(int r, int c, double delta) {
+        if(_data[r] == null)
+            allocateBlock(r, _odims[0]);
+        _data[r][c] += delta;
+    }
+
+    @Override
+    protected void fillBlock(int bix, int fromIndex, int toIndex, double v) {
+        if(_data[bix] == null)
+            allocateBlock(bix, _odims[0]);
+        Arrays.fill(_data[bix], fromIndex, toIndex, v);
+    }
+
+    @Override
+    protected void setInternal(int bix, int ix, double v) {
+        set(bix, ix, v);
+    }
+
+    @Override
+    public DenseBlock set(int r, int c, double v) {
+        if(_data[r] == null)
+            _data[r] = new double[_odims[0]];
+        _data[r][c] = v;
+        return this;
+    }
+
+    @Override
+    public DenseBlock set(int r, double[] v) {
+        if(v.length == _odims[0])
+            _data[r] = v;
+        else
+            throw new RuntimeException("set Denseblock called with an array 
length [" + v.length +"], array to overwrite is of length [" + _odims[0] + "]");
+        return this;
+    }
+
+    @Override
+    public DenseBlock set(DenseBlock db) {
+        throw new NotImplementedException();
+    }
+
+    @Override
+    public DenseBlock set(int[] ix, double v) {
+        return set(ix[0], pos(ix), v);
+    }
+
+    @Override
+    public DenseBlock set(int[] ix, long v) {
+        return set(ix[0], pos(ix), v);
+    }
+
+    @Override
+    public DenseBlock set(int[] ix, String v) {
+        return set(ix[0], pos(ix), Double.parseDouble(v));
+    }
+
+    @Override
+    public double get(int r, int c) {
+        if(_data[r] == null)
+            return 0.0;
+        else
+            return _data[r][c];
+    }
+
+    @Override
+    public double get(int[] ix) {
+        return get(ix[0], pos(ix));
+    }
+
+    @Override
+    public String getString(int[] ix) {
+        return String.valueOf(get(ix[0], pos(ix)));
+    }
+
+    @Override
+    public long getLong(int[] ix) {
+        return UtilFunctions.toLong(get(ix[0], pos(ix)));
+    }
+}
diff --git a/src/main/java/org/apache/sysds/runtime/data/DenseBlockFactory.java 
b/src/main/java/org/apache/sysds/runtime/data/DenseBlockFactory.java
index e840585f49..c48a9d0aca 100644
--- a/src/main/java/org/apache/sysds/runtime/data/DenseBlockFactory.java
+++ b/src/main/java/org/apache/sysds/runtime/data/DenseBlockFactory.java
@@ -32,15 +32,29 @@ public abstract class DenseBlockFactory
        public static DenseBlock createDenseBlock(int rlen, int clen) {
                return createDenseBlock(new int[]{rlen, clen});
        }
+
+       public static DenseBlock createDenseBlock(int rlen, int clen, boolean 
dedup) {
+               return createDenseBlock(new int[]{rlen, clen}, dedup);
+       }
        
        public static DenseBlock createDenseBlock(int[] dims) {
                return createDenseBlock(ValueType.FP64, dims);
        }
+
+       public static DenseBlock createDenseBlock(int[] dims, boolean dedup) {
+               return createDenseBlock(ValueType.FP64, dims, dedup);
+       }
        
        public static DenseBlock createDenseBlock(ValueType vt, int[] dims) {
                DenseBlock.Type type = (UtilFunctions.prod(dims) < 
Integer.MAX_VALUE) ?
                        DenseBlock.Type.DRB : DenseBlock.Type.LDRB;
-               return createDenseBlock(vt, type, dims);
+               return createDenseBlock(vt, type, dims, false);
+       }
+
+       public static DenseBlock createDenseBlock(ValueType vt, int[] dims, 
boolean dedup) {
+               DenseBlock.Type type = (UtilFunctions.prod(dims) < 
Integer.MAX_VALUE) ?
+                               DenseBlock.Type.DRB : DenseBlock.Type.LDRB;
+               return createDenseBlock(vt, type, dims, dedup);
        }
 
        public static DenseBlock createDenseBlock(BitSet data, int[] dims) {
@@ -75,7 +89,24 @@ public abstract class DenseBlockFactory
                return createDenseBlock(data, new int[]{rlen, clen});
        }
        
-       public static DenseBlock createDenseBlock(ValueType vt, DenseBlock.Type 
type, int[] dims) {
+       public static DenseBlock createDenseBlock(ValueType vt, DenseBlock.Type 
type, int[] dims, boolean dedup) {
+               if( dedup ) {
+                       switch( type ) {
+                               case DRB:
+                                       switch(vt) {
+                                               case FP64: return new 
DenseBlockFP64DEDUP(dims);
+                                               default:
+                                                       throw new 
DMLRuntimeException("Unsupported dense block value type with deduplication 
enabled: "+vt.name());
+                                       }
+                               case LDRB:
+                                       switch(vt) {
+                                               default:
+                                                       throw new 
NotImplementedException();
+                                       }
+                               default:
+                                       throw new 
DMLRuntimeException("Unexpected dense block type: "+type.name());
+                       }
+               }
                switch( type ) {
                        case DRB:
                                switch(vt) {
diff --git 
a/src/main/java/org/apache/sysds/runtime/data/DenseBlockLFP64DEDUP.java 
b/src/main/java/org/apache/sysds/runtime/data/DenseBlockLFP64DEDUP.java
new file mode 100644
index 0000000000..5a4249c920
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/data/DenseBlockLFP64DEDUP.java
@@ -0,0 +1,241 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.data;
+
+import org.apache.commons.lang.NotImplementedException;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.util.UtilFunctions;
+
+import java.util.Arrays;
+import java.util.HashMap;
+
+public class DenseBlockLFP64DEDUP extends DenseBlockLDRB{
+    //WIP
+    private double[][] _data;
+
+    protected DenseBlockLFP64DEDUP(int[] dims) {
+        super(dims);
+        reset(_rlen, _odims, 0);
+    }
+
+    @Override
+    protected void allocateBlocks(int numBlocks) {
+        _data = new double[numBlocks][];
+    }
+
+    @Override
+    protected void allocateBlock(int bix, int length) {
+        _data[bix] = new double[length];
+    }
+
+    @Override
+    public void reset(int rlen, int[] odims, double v) {
+        if(rlen >  capacity() / _odims[0]) {
+            this.allocateBlocks(rlen);
+            if (v != 0.0) {
+                for (int i = 0; i < rlen; i++) {
+                    allocateBlock(i, odims[0]);
+                    Arrays.fill(_data[i], 0, odims[0], v);
+                }
+            }
+        }
+        else{
+            if(v == 0.0){
+                for(int i = 0; i < rlen; i++)
+                    _data[i] = null;
+            }
+            else {
+                for(int i = 0; i < rlen; i++){
+                    if(odims[0] > _odims[0] ||_data[i] == null )
+                        allocateBlock(i, odims[0]);
+                    Arrays.fill(_data[i], 0, odims[0], v);
+                }
+            }
+        }
+        _blen = 1;
+        _rlen = rlen;
+        _odims = odims;
+    }
+
+    @Override
+    public boolean isNumeric() {
+        return true;
+    }
+
+    @Override
+    public boolean isNumeric(Types.ValueType vt) {
+        return Types.ValueType.FP64 == vt;
+    }
+
+    @Override
+    public boolean isContiguous() {
+        return false;
+    }
+
+    @Override
+    public long capacity() {
+        return (_data != null) ? _data.length*_odims[0] : -1;
+    }
+
+    @Override
+    public long countNonZeros(){
+        long nnz = 0;
+        HashMap<double[], Long> cache = new HashMap<double[], Long>();
+        for (int i = 0; i < _rlen; i++) {
+            double[] row = this._data[i];
+            if(row == null)
+                continue;
+            Long count = cache.getOrDefault(row, null);
+            if(count == null){
+                count = Long.valueOf(countNonZeros(i));
+                cache.put(row, count);
+            }
+            nnz += count;
+        }
+        return nnz;
+    }
+
+    @Override
+    public int countNonZeros(int r) {
+        return _data[r] == null ? 0 : UtilFunctions.computeNnz(_data[r], 0, 
_odims[0]);
+    }
+
+    @Override
+    protected long computeNnz(int bix, int start, int length) {
+        int nnz = 0;
+        int row_start = (int) Math.floor(start / _odims[0]);
+        int col_start = start % _odims[0];
+        for (int i = 0; i < length; i++) {
+            if(_data[row_start] == null){
+                i += _odims[0] - 1 - col_start;
+                col_start = 0;
+                row_start += 1;
+                continue;
+            }
+            nnz += _data[row_start][col_start] != 0 ? 1 : 0;
+            col_start += 1;
+            if(col_start == _odims[0]) {
+                col_start = 0;
+                row_start += 1;
+            }
+        }
+        return nnz;
+    }
+
+    @Override
+    public int pos(int r){
+        return 0;
+    }
+
+    @Override
+    public double[] values(int r) {
+        if(_data[r] == null)
+            allocateBlock(r, _odims[0]);
+        return _data[r];
+    }
+
+    @Override
+    public double[] valuesAt(int bix) {
+        return values(bix);
+    }
+
+
+    @Override
+    public int numBlocks(){
+        return _data.length;
+    }
+
+    @Override
+    public void incr(int r, int c) {
+        throw new NotImplementedException();
+    }
+
+    @Override
+    public void incr(int r, int c, double delta) {
+        throw new NotImplementedException();
+    }
+
+    @Override
+    protected void fillBlock(int bix, int fromIndex, int toIndex, double v) {
+        throw new NotImplementedException();
+    }
+
+    @Override
+    protected void setInternal(int bix, int ix, double v) {
+        throw new NotImplementedException();
+    }
+
+    @Override
+    public DenseBlock set(int r, int c, double v) {
+        if(_data[r] == null)
+            _data[r] = new double[_odims[0]];
+        _data[r][c] = v;
+        return this;
+    }
+
+    @Override
+    public DenseBlock set(int r, double[] v) {
+        if(v.length == _odims[0])
+            _data[r] = v;
+        else
+            throw new RuntimeException("set Denseblock called with an array 
length [" + v.length +"], array to overwrite is of length [" + _odims[0] + "]");
+        return this;
+    }
+
+    @Override
+    public DenseBlock set(DenseBlock db) {
+        throw new NotImplementedException();
+    }
+
+    @Override
+    public DenseBlock set(int[] ix, double v) {
+        throw new NotImplementedException();
+    }
+
+    @Override
+    public DenseBlock set(int[] ix, long v) {
+        throw new NotImplementedException();
+    }
+
+    @Override
+    public DenseBlock set(int[] ix, String v) {
+        throw new NotImplementedException();
+    }
+
+    @Override
+    public double get(int r, int c) {
+        return _data[r][c];
+    }
+
+    @Override
+    public double get(int[] ix) {
+        throw new NotImplementedException();
+    }
+
+    @Override
+    public String getString(int[] ix) {
+        throw new NotImplementedException();
+    }
+
+    @Override
+    public long getLong(int[] ix) {
+        throw new NotImplementedException();
+    }
+}
diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java 
b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java
index 99fb4b30ca..ce44e2e343 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java
@@ -1738,7 +1738,7 @@ public class LibMatrixMult
                        int alen = a.size(i);
                        int[] aixs = a.indexes(i);
                        double[] avals = a.values(i);
-                       if( alen==1 ) { 
+                       if( alen==1 ) {
                                //row selection (now aggregation) with 
potential scaling
                                int aix = aixs[apos];
                                int lnnz = 0;
diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java 
b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
index dcc17076c2..c075a9fd29 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
@@ -35,6 +35,7 @@ import java.util.concurrent.ExecutorService;
 import java.util.concurrent.Future;
 import java.util.stream.IntStream;
 
+import org.apache.commons.lang.NotImplementedException;
 import org.apache.commons.lang3.ArrayUtils;
 import org.apache.commons.lang3.concurrent.ConcurrentUtils;
 import org.apache.commons.logging.Log;
@@ -383,8 +384,12 @@ public class MatrixBlock extends MatrixValue implements 
CacheBlock<MatrixBlock>,
                        allocateDenseBlock();
                return this;
        }
-       
-       public boolean allocateDenseBlock(boolean clearNNZ) {
+
+       public boolean allocateDenseBlock(boolean clearNNZ){
+               return allocateDenseBlock(clearNNZ, false);
+       }
+
+       public boolean allocateDenseBlock(boolean clearNNZ, boolean 
containsDuplicates) {
                //allocate block if non-existing or too small (guaranteed to be 
0-initialized),
                long limit = (long)rlen * clen;
                //clear nnz if necessary
@@ -393,7 +398,7 @@ public class MatrixBlock extends MatrixValue implements 
CacheBlock<MatrixBlock>,
                sparse = false;
 
                if( denseBlock == null ){
-                       denseBlock = DenseBlockFactory.createDenseBlock(rlen, 
clen);
+                       denseBlock = DenseBlockFactory.createDenseBlock(rlen, 
clen, containsDuplicates);
                        return true;
                }
                else if( denseBlock.capacity() < limit ){
@@ -669,6 +674,17 @@ public class MatrixBlock extends MatrixValue implements 
CacheBlock<MatrixBlock>,
                }
        }
 
+       public void quickSetRow(int r, double[] values){
+               if(sparse)
+                       throw new NotImplementedException();
+               else{
+                       //allocate and init dense block (w/o overwriting nnz)
+                       allocateDenseBlock(false);
+                       nonZeros += UtilFunctions.computeNnz(values, 0, 
values.length) - denseBlock.countNonZeros(r);
+                       denseBlock.set(r, values);
+               }
+       }
+
        public double quickGetValueThreadSafe(int r, int c) {
                if(sparse) {
                        if(!(sparseBlock instanceof SparseBlockMCSR))
diff --git 
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderRecode.java
 
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderRecode.java
index eb7e706e0c..334149515b 100644
--- 
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderRecode.java
+++ 
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderRecode.java
@@ -88,6 +88,15 @@ public class ColumnEncoderRecode extends ColumnEncoder {
         * @return string array of token and code
         */
        public static String[] splitRecodeMapEntry(String value) {
+               // remove " chars from string (if the string contains a comma in 
the csv file, then it must be enclosed in double quotes)
+               /*if(value.contains("\"")){
+                       //remove just last and first appearance
+                       int firstIndex = value.indexOf("\"");
+                       int lastIndex = value.lastIndexOf("\"");
+                       if (firstIndex != lastIndex)
+                               value = value.substring(0, firstIndex) + 
value.substring(firstIndex + 1, lastIndex) + value.substring(lastIndex + 1);
+               }*/
+
                // Instead of using splitCSV which is forcing string with 
RFC-4180 format,
                // using Lop.DATATYPE_PREFIX separator to split token and code
                int pos = value.lastIndexOf(Lop.DATATYPE_PREFIX);
diff --git 
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderWordEmbedding.java
 
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderWordEmbedding.java
index 03584cf5ee..9d08c3bc20 100644
--- 
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderWordEmbedding.java
+++ 
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderWordEmbedding.java
@@ -20,20 +20,31 @@
 package org.apache.sysds.runtime.transform.encode;
 
 import org.apache.commons.lang.NotImplementedException;
-import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
 import org.apache.sysds.runtime.frame.data.FrameBlock;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 
+import java.util.HashMap;
+
 import static org.apache.sysds.runtime.util.UtilFunctions.getEndIndex;
 
 public class ColumnEncoderWordEmbedding extends ColumnEncoder {
-    private MatrixBlock wordEmbeddings;
+    private MatrixBlock _wordEmbeddings;
+    private HashMap<Object, Long> _rcdMap;
+    private HashMap<String, double[]> _embMap;
+
+    private long lookupRCDMap(Object key) {
+        return _rcdMap.getOrDefault(key, -1L);
+    }
 
-    //domain size is equal to the number columns of the embedding column 
(equal to length of an embedding vector)
+    private double[] lookupEMBMap(Object key) {
+        return _embMap.getOrDefault(key, null);
+    }
+
+    //domain size is equal to the number of columns of the embedding matrix, 
which is equal to the length of an embedding vector
     @Override
     public int getDomainSize(){
-        return wordEmbeddings.getNumColumns();
+        return _wordEmbeddings.getNumColumns();
     }
     protected ColumnEncoderWordEmbedding(int colID) {
         super(colID);
@@ -49,31 +60,44 @@ public class ColumnEncoderWordEmbedding extends 
ColumnEncoder {
         throw new NotImplementedException();
     }
 
-    //previous recode replaced strings with indices of the corresponding 
matrix row index
+    //previously recode replaced strings with indices of the corresponding 
matrix row index
     //now, the indices are replaced with actual word embedding vectors
-    //current limitation: in case the transform is done on multiple cols, the 
same embedding
-    //matrix is used for both transform
+    //current limitation: in case the transform is done on multiple cols, the 
same embedding matrix is used for all of them
+
+    private double[] getEmbeddedingFromEmbeddingMatrix(long r){
+        double[] embedding = new double[getDomainSize()];
+        for (int i = 0; i < getDomainSize(); i++) {
+            embedding[i] = this._wordEmbeddings.quickGetValue((int) r, _colID 
- 1 + i);
+        }
+        return embedding;
+
+    }
+
     @Override
     public void applyDense(CacheBlock<?> in, MatrixBlock out, int outputCol, 
int rowStart, int blk){
-        if (!(in instanceof MatrixBlock)){
+        /*if (!(in instanceof MatrixBlock)){
             throw new DMLRuntimeException("ColumnEncoderWordEmbedding called 
with: " + in.getClass().getSimpleName() +
                     " and not MatrixBlock");
-        }
+        }*/
         int rowEnd = getEndIndex(in.getNumRows(), rowStart, blk);
-        //map each recoded index to the corresponding embedding vector
+
+        //map each string to the corresponding embedding vector
         for(int i=rowStart; i<rowEnd; i++){
-            double embeddingIndex = in.getDouble(i, outputCol);
-            //fill row with zeroes
-            if(Double.isNaN(embeddingIndex)){
-                for (int j = outputCol; j < outputCol + getDomainSize(); j++)
-                    out.quickSetValue(i, j, 0.0);
+            String key = in.getString(i, _colID-1);
+            if(key == null || key.isEmpty()) {
+                //codes[i-startInd] = Double.NaN;
+                continue;
             }
-            //array copy
-            else{
-                for (int j = outputCol; j < outputCol + getDomainSize(); j++){
-                    out.quickSetValue(i, j, wordEmbeddings.quickGetValue((int) 
embeddingIndex - 1,j - outputCol ));
+            double[] embedding = lookupEMBMap(key);
+            if(embedding == null){
+                long code = lookupRCDMap(key);
+                if(code == -1L){
+                    continue;
                 }
+                embedding = getEmbeddedingFromEmbeddingMatrix(code - 1);
+                _embMap.put(key, embedding);
             }
+            out.quickSetRow(i, embedding);
         }
     }
 
@@ -100,12 +124,15 @@ public class ColumnEncoderWordEmbedding extends 
ColumnEncoder {
 
     @Override
     public void initMetaData(FrameBlock meta) {
-        return;
+        if(meta == null || meta.getNumRows() <= 0)
+            return;
+        _rcdMap = meta.getRecodeMap(_colID - 1); // 1-based
     }
 
     //save embeddings matrix reference for apply step
     @Override
     public void initEmbeddings(MatrixBlock embeddings){
-        this.wordEmbeddings = embeddings;
+        this._wordEmbeddings = embeddings;
+        this._embMap = new HashMap<>((int) (embeddings.getNumRows()*1.2),1.0f);
     }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java 
b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java
index 313258831a..07e84a81e7 100644
--- 
a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java
+++ 
b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java
@@ -114,13 +114,13 @@ public interface EncoderFactory {
                        // column follows binning or feature hashing
                        rcIDs = unionDistinct(rcIDs, except(except(dcIDs, 
binIDs), haIDs));
                        // NOTE: Word Embeddings requires recode as preparation
-                       rcIDs = unionDistinct(rcIDs, weIDs);
+                       //rcIDs = unionDistinct(rcIDs, weIDs);
                        // Error out if the first level encoders have overlaps
-                       if (intersect(rcIDs, binIDs, haIDs))
-                               throw new DMLRuntimeException("More than one 
encoders (recode, binning, hashing) on one column is not allowed");
+                       if (intersect(rcIDs, binIDs, haIDs, weIDs))
+                               throw new DMLRuntimeException("More than one 
encoders (recode, binning, hashing, word_embedding) on one column is not 
allowed");
 
-                       List<Integer> ptIDs = 
except(except(UtilFunctions.getSeqList(1, clen, 1), unionDistinct(rcIDs, 
haIDs)),
-                               binIDs);
+                       List<Integer> ptIDs = 
except(except(except(UtilFunctions.getSeqList(1, clen, 1), unionDistinct(rcIDs, 
haIDs)),
+                               binIDs), weIDs);
                        List<Integer> oIDs = Arrays.asList(ArrayUtils
                                .toObject(TfMetaUtils.parseJsonIDList(jSpec, 
colnames, TfMethod.OMIT.toString(), minCol, maxCol)));
                        List<Integer> mvIDs = Arrays.asList(ArrayUtils.toObject(
diff --git 
a/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
 
b/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
index 59c22f5640..2e64709b64 100644
--- 
a/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
+++ 
b/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
@@ -343,9 +343,13 @@ public class MultiColumnEncoder implements Encoder {
                        throw new DMLRuntimeException("Invalid input with wrong 
number or rows");
 
                boolean hasDC = false;
-               for(ColumnEncoderComposite columnEncoder : _columnEncoders)
-                       hasDC = 
columnEncoder.hasEncoder(ColumnEncoderDummycode.class);
-               outputMatrixPreProcessing(out, in, hasDC);
+               boolean hasWE = false;
+               for(ColumnEncoderComposite columnEncoder : _columnEncoders) {
+                       hasDC |= 
columnEncoder.hasEncoder(ColumnEncoderDummycode.class);
+                       hasWE |= 
columnEncoder.hasEncoder(ColumnEncoderWordEmbedding.class);
+               }
+               //hasWE = false;
+               outputMatrixPreProcessing(out, in, hasDC, hasWE);
                if(k > 1) {
                        if(!_partitionDone) //happens if this method is 
directly called
                                deriveNumRowPartitions(in, k);
@@ -533,7 +537,7 @@ public class MultiColumnEncoder implements Encoder {
                return totMemOverhead;
        }
 
-       private static void outputMatrixPreProcessing(MatrixBlock output, 
CacheBlock<?> input, boolean hasDC) {
+       private static void outputMatrixPreProcessing(MatrixBlock output, 
CacheBlock<?> input, boolean hasDC, boolean hasWE) {
                long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0;
                if(output.isInSparseFormat()) {
                        if (MatrixBlock.DEFAULT_SPARSEBLOCK != 
SparseBlock.Type.CSR
@@ -580,7 +584,7 @@ public class MultiColumnEncoder implements Encoder {
                }
                else {
                        // Allocate dense block and set nnz to total #entries
-                       output.allocateBlock();
+                       output.allocateDenseBlock(true, hasWE);
                        //output.setAllNonZeros();
                }
 
@@ -1119,12 +1123,13 @@ public class MultiColumnEncoder implements Encoder {
                @Override
                public Object call() throws Exception {
                        boolean hasUDF = 
_encoder.getColumnEncoders().stream().anyMatch(e -> 
e.hasEncoder(ColumnEncoderUDF.class));
+                       boolean hasWE = 
_encoder.getColumnEncoders().stream().anyMatch(e -> 
e.hasEncoder(ColumnEncoderWordEmbedding.class));
                        int numCols = _encoder.getNumOutCols();
                        boolean hasDC = 
_encoder.getColumnEncoders(ColumnEncoderDummycode.class).size() > 0;
                        long estNNz = (long) _input.getNumRows() * (hasUDF ? 
numCols : (long) _input.getNumColumns());
                        boolean sparse = 
MatrixBlock.evalSparseFormatInMemory(_input.getNumRows(), numCols, estNNz) && 
!hasUDF;
                        _output.reset(_input.getNumRows(), numCols, sparse, 
estNNz);
-                       outputMatrixPreProcessing(_output, _input, hasDC);
+                       outputMatrixPreProcessing(_output, _input, hasDC, 
hasWE);
                        return null;
                }
 
diff --git 
a/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeWordEmbeddingTest.java
 
b/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeWordEmbedding1Test.java
similarity index 96%
rename from 
src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeWordEmbeddingTest.java
rename to 
src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeWordEmbedding1Test.java
index 9972f6b1c6..a69e287d33 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeWordEmbeddingTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeWordEmbedding1Test.java
@@ -37,11 +37,11 @@ import java.util.List;
 import java.util.Map;
 import java.util.Random;
 
-public class TransformFrameEncodeWordEmbeddingTest extends AutomatedTestBase
+public class TransformFrameEncodeWordEmbedding1Test extends AutomatedTestBase
 {
     private final static String TEST_NAME1 = 
"TransformFrameEncodeWordEmbeddings";
     private final static String TEST_DIR = "functions/transform/";
-    private final static String TEST_CLASS_DIR = TEST_DIR + 
TransformFrameEncodeWordEmbeddingTest.class.getSimpleName() + "/";
+    private final static String TEST_CLASS_DIR = TEST_DIR + 
TransformFrameEncodeWordEmbedding1Test.class.getSimpleName() + "/";
 
     @Override
     public void setUp() {
@@ -61,7 +61,7 @@ public class TransformFrameEncodeWordEmbeddingTest extends 
AutomatedTestBase
         try
         {
             int rows = 100;
-            int cols = 100;
+            int cols = 300;
             getAndLoadTestConfiguration(testname);
             fullDMLScriptName = getScript();
 
@@ -72,7 +72,7 @@ public class TransformFrameEncodeWordEmbeddingTest extends 
AutomatedTestBase
             // Generate the dictionary by assigning unique ID to each distinct 
token
             Map<String,Integer> map = writeDictToCsvFile(strings, 
baseDirectory + INPUT_DIR + "dict");
             // Create the dataset by repeating and shuffling the distinct 
tokens
-            List<String> stringsColumn = shuffleAndMultiplyStrings(strings, 
10);
+            List<String> stringsColumn = shuffleAndMultiplyStrings(strings, 
32);
             writeStringsToCsvFile(stringsColumn, baseDirectory + INPUT_DIR + 
"data");
 
             programArgs = new String[]{"-stats","-args", input("embeddings"), 
input("data"), input("dict"), output("result")};
diff --git 
a/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeWordEmbedding2Test.java
 
b/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeWordEmbedding2Test.java
index 8ab52d9f64..06c8b6ee0b 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeWordEmbedding2Test.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeWordEmbedding2Test.java
@@ -42,18 +42,18 @@ import java.util.Random;
 public class TransformFrameEncodeWordEmbedding2Test extends AutomatedTestBase
 {
     private final static String TEST_NAME1 = 
"TransformFrameEncodeWordEmbeddings2";
-    private final static String TEST_NAME2 = 
"TransformFrameEncodeWordEmbeddings2MultiCols1";
-    private final static String TEST_NAME3 = 
"TransformFrameEncodeWordEmbeddings2MultiCols2";
+    private final static String TEST_NAME2a = 
"TransformFrameEncodeWordEmbeddings2MultiCols1";
+    private final static String TEST_NAME2b = 
"TransformFrameEncodeWordEmbeddings2MultiCols2";
 
     private final static String TEST_DIR = "functions/transform/";
-    private final static String TEST_CLASS_DIR = TEST_DIR + 
TransformFrameEncodeWordEmbeddingTest.class.getSimpleName() + "/";
+    private final static String TEST_CLASS_DIR = TEST_DIR + 
TransformFrameEncodeWordEmbedding1Test.class.getSimpleName() + "/";
 
     @Override
     public void setUp() {
         TestUtils.clearAssertionInformation();
         addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_DIR, 
TEST_NAME1));
-        addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_DIR, 
TEST_NAME2));
-        addTestConfiguration(TEST_NAME3, new TestConfiguration(TEST_DIR, 
TEST_NAME3));
+        addTestConfiguration(TEST_NAME2a, new TestConfiguration(TEST_DIR, 
TEST_NAME2a));
+        addTestConfiguration(TEST_NAME2b, new TestConfiguration(TEST_DIR, 
TEST_NAME2b));
     }
 
     @Test
@@ -64,23 +64,30 @@ public class TransformFrameEncodeWordEmbedding2Test extends 
AutomatedTestBase
     @Test
     @Ignore
     public void testNonRandomTransformToWordEmbeddings2Cols() {
-        runTransformTest(TEST_NAME2, ExecMode.SINGLE_NODE);
+        runTransformTest(TEST_NAME2a, ExecMode.SINGLE_NODE);
     }
 
     @Test
     @Ignore
     public void testRandomTransformToWordEmbeddings4Cols() {
-        runTransformTestMultiCols(TEST_NAME3, ExecMode.SINGLE_NODE);
+        runTransformTestMultiCols(TEST_NAME2b, ExecMode.SINGLE_NODE);
     }
 
-    private void runTransformTest(String testname, ExecMode rt)
+    @Test
+    @Ignore
+    public void runBenchmark(){
+        runBenchmark(TEST_NAME1, ExecMode.SINGLE_NODE);
+    }
+
+
+    private void runBenchmark(String testname, ExecMode rt)
     {
         //set runtime platform
         ExecMode rtold = setExecMode(rt);
         try
         {
             int rows = 100;
-            int cols = 100;
+            int cols = 300;
             getAndLoadTestConfiguration(testname);
             fullDMLScriptName = getScript();
 
@@ -100,6 +107,43 @@ public class TransformFrameEncodeWordEmbedding2Test 
extends AutomatedTestBase
             //run script
             programArgs = new String[]{"-stats","-args", input("embeddings"), 
input("data"), input("dict"), output("result")};
             runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
+        }
+        catch(Exception ex) {
+            throw new RuntimeException(ex);
+
+        }
+        finally {
+            resetExecMode(rtold);
+        }
+    }
+
+    private void runTransformTest(String testname, ExecMode rt)
+    {
+        //set runtime platform
+        ExecMode rtold = setExecMode(rt);
+        try
+        {
+            int rows = 100;
+            int cols = 300;
+            getAndLoadTestConfiguration(testname);
+            fullDMLScriptName = getScript();
+
+            // Generate random embeddings for the distinct tokens
+            double[][] a = createRandomMatrix("embeddings", rows, cols, 0, 10, 
1, new Date().getTime());
+
+            // Generate random distinct tokens
+            List<String> strings = generateRandomStrings(rows, 10);
+
+            // Generate the dictionary by assigning unique ID to each distinct 
token
+            Map<String,Integer> map = writeDictToCsvFile(strings, 
baseDirectory + INPUT_DIR + "dict");
+
+            // Create the dataset by repeating and shuffling the distinct 
tokens
+            List<String> stringsColumn = shuffleAndMultiplyStrings(strings, 
32);
+            writeStringsToCsvFile(stringsColumn, baseDirectory + INPUT_DIR + 
"data");
+
+            //run script
+            programArgs = new String[]{"-stats","-args", input("embeddings"), 
input("data"), input("dict"), output("result")};
+            runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
 
             // Manually derive the expected result
             double[][] res_expected = manuallyDeriveWordEmbeddings(cols, a, 
map, stringsColumn);
@@ -107,10 +151,6 @@ public class TransformFrameEncodeWordEmbedding2Test 
extends AutomatedTestBase
             // Compare results
             HashMap<MatrixValue.CellIndex, Double> res_actual = 
readDMLMatrixFromOutputDir("result");
             double[][] resultActualDouble = 
TestUtils.convertHashMapToDoubleArray(res_actual);
-            //System.out.println("Actual Result [" + resultActualDouble.length 
+ "x" + resultActualDouble[0].length + "]:");
-            //print2DimDoubleArray(resultActualDouble);
-            //System.out.println("\nExpected Result [" + res_expected.length + 
"x" + res_expected[0].length + "]:");
-            //print2DimDoubleArray(res_expected);
             TestUtils.compareMatrices(resultActualDouble, res_expected, 1e-6);
         }
         catch(Exception ex) {
diff --git 
a/src/test/scripts/functions/transform/TransformFrameEncodeWordEmbeddings.dml 
b/src/test/scripts/functions/transform/TransformFrameEncodeWordEmbeddings.dml
index 1aa1fb0fed..dcab56b0fd 100644
--- 
a/src/test/scripts/functions/transform/TransformFrameEncodeWordEmbeddings.dml
+++ 
b/src/test/scripts/functions/transform/TransformFrameEncodeWordEmbeddings.dml
@@ -20,7 +20,7 @@
 #-------------------------------------------------------------
 
 # Read the pre-trained word embeddings
-E = read($1, rows=100, cols=100, format="text");
+E = read($1, rows=100, cols=300, format="text");
 # Read the token sequence (1K) w/ 100 distinct tokens
 Data = read($2, data_type="frame", format="csv");
 # Read the recode map for the distinct tokens
diff --git 
a/src/test/scripts/functions/transform/TransformFrameEncodeWordEmbeddings2.dml 
b/src/test/scripts/functions/transform/TransformFrameEncodeWordEmbeddings2.dml
index 29a4bfab74..139bf2b9f4 100644
--- 
a/src/test/scripts/functions/transform/TransformFrameEncodeWordEmbeddings2.dml
+++ 
b/src/test/scripts/functions/transform/TransformFrameEncodeWordEmbeddings2.dml
@@ -20,7 +20,7 @@
 #-------------------------------------------------------------
 
 # Read the pre-trained word embeddings
-E = read($1, rows=100, cols=100, format="text");
+E = read($1, rows=100, cols=300, format="text");
 # Read the token sequence (1K) w/ 100 distinct tokens
 Data = read($2, data_type="frame", format="csv");
 # Read the recode map for the distinct tokens
@@ -28,7 +28,6 @@ Meta = read($3, data_type="frame", format="csv");
 
 jspec = "{ids: true, word_embedding: [1]}";
 Data_enc = transformapply(target=Data, spec=jspec, meta=Meta, embedding=E);
-
 write(Data_enc, $4, format="text");
 
 


Reply via email to