This is an automated email from the ASF dual-hosted git repository.

arnabp20 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 384a7071a4 [SYSTEMDS-3580] Add word embedding encoder
384a7071a4 is described below

commit 384a7071a41bf6ff5cf5c7410ebb351effb2e47b
Author: e-strauss <[email protected]>
AuthorDate: Fri Jun 9 12:15:36 2023 +0200

    [SYSTEMDS-3580] Add word embedding encoder
    
    This patch extends the transformapply API to accept pre-trained word
    embeddings, along with the dictionary, as inputs. The new word embedding
    column encoder is placed after recode and replaces the recoded indices
    with the embedding vectors. This addition removes the need for a
    matrix multiplication to produce the embedding matrix.
    The current implementation is slower than the baseline (w/ MatMult).
    Future commits will introduce a new dense block that deduplicates
    large embeddings, and will add multi-threading.
    
    Closes #1839
---
 .../ParameterizedBuiltinFunctionExpression.java    |   8 +-
 .../apache/sysds/runtime/data/DenseBlockFP64.java  |   2 +-
 .../cp/ParameterizedBuiltinCPInstruction.java      |  13 +-
 .../apache/sysds/runtime/transform/TfUtils.java    |   2 +-
 .../runtime/transform/encode/ColumnEncoder.java    |  10 +-
 .../transform/encode/ColumnEncoderComposite.java   |   6 +
 .../encode/ColumnEncoderWordEmbedding.java         | 111 +++++++++
 .../runtime/transform/encode/EncoderFactory.java   |  33 ++-
 .../transform/encode/MultiColumnEncoder.java       |  34 ++-
 .../sysds/utils/stats/TransformStatistics.java     |  11 +-
 .../TransformFrameEncodeWordEmbedding2Test.java    | 258 +++++++++++++++++++++
 .../TransformFrameEncodeWordEmbeddings2.dml        |  36 +++
 ...ansformFrameEncodeWordEmbeddings2MultiCols1.dml |  43 ++++
 ...ansformFrameEncodeWordEmbeddings2MultiCols2.dml |  44 ++++
 14 files changed, 593 insertions(+), 18 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
 
b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
index 1d30d13fea..1906ee818e 100644
--- 
a/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
+++ 
b/src/main/java/org/apache/sysds/parser/ParameterizedBuiltinFunctionExpression.java
@@ -49,6 +49,7 @@ public class ParameterizedBuiltinFunctionExpression extends 
DataIdentifier
        public static final String TF_FN_PARAM_DATA = "target";
        public static final String TF_FN_PARAM_MTD2 = "meta";
        public static final String TF_FN_PARAM_SPEC = "spec";
+       public static final String TF_FN_PARAM_EMBD = "embedding";
        public static final String LINEAGE_TRACE = "lineage";
        public static final String TF_FN_PARAM_MTD = "transformPath"; //NOTE 
MB: for backwards compatibility
        
@@ -617,11 +618,14 @@ public class ParameterizedBuiltinFunctionExpression 
extends DataIdentifier
                //validate data / metadata (recode maps)
                checkDataType(false, "transformapply", TF_FN_PARAM_DATA, 
DataType.FRAME, conditional);
                checkDataType(false, "transformapply", TF_FN_PARAM_MTD2, 
DataType.FRAME, conditional);
-               
+
                //validate specification
                checkDataValueType(false, "transformapply", TF_FN_PARAM_SPEC, 
DataType.SCALAR, ValueType.STRING, conditional);
                validateTransformSpec(TF_FN_PARAM_SPEC, conditional);
-               
+
+               //validate additional argument for word_embeddings tranform
+               checkDataType(true, "transformapply", TF_FN_PARAM_EMBD, 
DataType.MATRIX, conditional);
+
                //set output dimensions
                output.setDataType(DataType.MATRIX);
                output.setValueType(ValueType.FP64);
diff --git a/src/main/java/org/apache/sysds/runtime/data/DenseBlockFP64.java 
b/src/main/java/org/apache/sysds/runtime/data/DenseBlockFP64.java
index 44f8846ea9..719ad3a9cd 100644
--- a/src/main/java/org/apache/sysds/runtime/data/DenseBlockFP64.java
+++ b/src/main/java/org/apache/sysds/runtime/data/DenseBlockFP64.java
@@ -178,7 +178,7 @@ public class DenseBlockFP64 extends DenseBlockDRB
                System.arraycopy(v, 0, _data, pos(r), _odims[0]);
                return this;
        }
-       
+
        @Override
        public DenseBlock set(int[] ix, double v) {
                _data[pos(ix)] = v;
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
index 9dfbdbec7f..18a199e930 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
@@ -54,7 +54,11 @@ import 
org.apache.sysds.runtime.transform.tokenize.TokenizerFactory;
 import org.apache.sysds.runtime.util.AutoDiff;
 import org.apache.sysds.runtime.util.DataConverter;
 
-import java.util.*;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.LinkedHashMap;
+import java.util.List;
 import java.util.stream.Collectors;
 import java.util.stream.IntStream;
 
@@ -310,11 +314,12 @@ public class ParameterizedBuiltinCPInstruction extends 
ComputationCPInstruction
                        // acquire locks
                        FrameBlock data = 
ec.getFrameInput(params.get("target"));
                        FrameBlock meta = ec.getFrameInput(params.get("meta"));
+                       MatrixBlock embeddings = params.get("embedding") != 
null ? ec.getMatrixInput(params.get("embedding")) : null;
                        String[] colNames = data.getColumnNames();
 
                        // compute transformapply
                        MultiColumnEncoder encoder = EncoderFactory
-                               .createEncoder(params.get("spec"), colNames, 
data.getNumColumns(), meta);
+                               .createEncoder(params.get("spec"), colNames, 
data.getNumColumns(), meta, embeddings);
                        MatrixBlock mbout = encoder.apply(data, 
OptimizerUtils.getTransformNumThreads());
 
                        // release locks
@@ -346,7 +351,7 @@ public class ParameterizedBuiltinCPInstruction extends 
ComputationCPInstruction
 
                        // compute transformapply
                        MultiColumnEncoder encoder = EncoderFactory
-                               .createEncoder(params.get("spec"), colNames, 
meta.getNumColumns(), null);
+                               .createEncoder(params.get("spec"), colNames, 
meta.getNumColumns(), null, null);
                        MatrixBlock mbout = encoder.getColMapping(meta);
 
                        // release locks
@@ -532,6 +537,8 @@ public class ParameterizedBuiltinCPInstruction extends 
ComputationCPInstruction
                        CPOperand target = new CPOperand(params.get("target"), 
ValueType.FP64, DataType.FRAME);
                        CPOperand meta = getLiteral("meta", ValueType.UNKNOWN, 
DataType.FRAME);
                        CPOperand spec = getStringLiteral("spec");
+                       //FIXME: Taking only spec file name as a literal leads 
to wrong reuse
+                       //TODO: Add Embedding to the lineage item
                        return Pair.of(output.getName(),
                                new LineageItem(getOpcode(), 
LineageItemUtils.getLineage(ec, target, meta, spec)));
                }
diff --git a/src/main/java/org/apache/sysds/runtime/transform/TfUtils.java 
b/src/main/java/org/apache/sysds/runtime/transform/TfUtils.java
index ec4758a819..b264004b61 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/TfUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/TfUtils.java
@@ -47,7 +47,7 @@ public class TfUtils implements Serializable
        
        //transform methods
        public enum TfMethod {
-               IMPUTE, RECODE, HASH, BIN, DUMMYCODE, UDF, OMIT;
+               IMPUTE, RECODE, HASH, BIN, DUMMYCODE, UDF, OMIT, WORD_EMBEDDING;
                @Override
                public String toString() {
                        return name().toLowerCase();
diff --git 
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java 
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java
index 610e0cc414..3020553e71 100644
--- a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java
+++ b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoder.java
@@ -65,8 +65,13 @@ public abstract class ColumnEncoder implements Encoder, 
Comparable<ColumnEncoder
        protected int _nBuildPartitions = 0;
        protected int _nApplyPartitions = 0;
 
+       //Override in ColumnEncoderWordEmbedding
+       public void initEmbeddings(MatrixBlock embeddings){
+               return;
+       }
+
        protected enum TransformType{
-               BIN, RECODE, DUMMYCODE, FEATURE_HASH, PASS_THROUGH, UDF, N_A
+               BIN, RECODE, DUMMYCODE, FEATURE_HASH, PASS_THROUGH, UDF, 
WORD_EMBEDDING, N_A
        }
 
        protected ColumnEncoder(int colID) {
@@ -106,6 +111,9 @@ public abstract class ColumnEncoder implements Encoder, 
Comparable<ColumnEncoder
                                case DUMMYCODE:
                                        
TransformStatistics.incDummyCodeApplyTime(t);
                                        break;
+                               case WORD_EMBEDDING:
+                                       
TransformStatistics.incWordEmbeddingApplyTime(t);
+                                       break;
                                case FEATURE_HASH:
                                        
TransformStatistics.incFeatureHashingApplyTime(t);
                                        break;
diff --git 
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java
 
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java
index 6f18263a26..fd69d5bf26 100644
--- 
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java
+++ 
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderComposite.java
@@ -319,6 +319,12 @@ public class ColumnEncoderComposite extends ColumnEncoder {
                        columnEncoder.initMetaData(out);
        }
 
+       //pass down init to actual encoders, only ColumnEncoderWordEmbedding 
has actually implemented the init method
+       public void initEmbeddings(MatrixBlock embeddings){
+               for(ColumnEncoder columnEncoder : _columnEncoders)
+                       columnEncoder.initEmbeddings(embeddings);
+       }
+
        @Override
        public String toString() {
                StringBuilder sb = new StringBuilder();
diff --git 
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderWordEmbedding.java
 
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderWordEmbedding.java
new file mode 100644
index 0000000000..03584cf5ee
--- /dev/null
+++ 
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderWordEmbedding.java
@@ -0,0 +1,111 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.runtime.transform.encode;
+
+import org.apache.commons.lang.NotImplementedException;
+import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
+import org.apache.sysds.runtime.frame.data.FrameBlock;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+
+import static org.apache.sysds.runtime.util.UtilFunctions.getEndIndex;
+
+public class ColumnEncoderWordEmbedding extends ColumnEncoder {
+    private MatrixBlock wordEmbeddings;
+
+    //domain size is equal to the number columns of the embedding column 
(equal to length of an embedding vector)
+    @Override
+    public int getDomainSize(){
+        return wordEmbeddings.getNumColumns();
+    }
+    protected ColumnEncoderWordEmbedding(int colID) {
+        super(colID);
+    }
+
+    @Override
+    protected double getCode(CacheBlock<?> in, int row) {
+        throw new NotImplementedException();
+    }
+
+    @Override
+    protected double[] getCodeCol(CacheBlock<?> in, int startInd, int blkSize) 
{
+        throw new NotImplementedException();
+    }
+
+    //previous recode replaced strings with indices of the corresponding 
matrix row index
+    //now, the indices are replaced with actual word embedding vectors
+    //current limitation: in case the transform is done on multiple cols, the 
same embedding
+    //matrix is used for both transform
+    @Override
+    public void applyDense(CacheBlock<?> in, MatrixBlock out, int outputCol, 
int rowStart, int blk){
+        if (!(in instanceof MatrixBlock)){
+            throw new DMLRuntimeException("ColumnEncoderWordEmbedding called 
with: " + in.getClass().getSimpleName() +
+                    " and not MatrixBlock");
+        }
+        int rowEnd = getEndIndex(in.getNumRows(), rowStart, blk);
+        //map each recoded index to the corresponding embedding vector
+        for(int i=rowStart; i<rowEnd; i++){
+            double embeddingIndex = in.getDouble(i, outputCol);
+            //fill row with zeroes
+            if(Double.isNaN(embeddingIndex)){
+                for (int j = outputCol; j < outputCol + getDomainSize(); j++)
+                    out.quickSetValue(i, j, 0.0);
+            }
+            //array copy
+            else{
+                for (int j = outputCol; j < outputCol + getDomainSize(); j++){
+                    out.quickSetValue(i, j, wordEmbeddings.quickGetValue((int) 
embeddingIndex - 1,j - outputCol ));
+                }
+            }
+        }
+    }
+
+
+    @Override
+    protected TransformType getTransformType() {
+        return TransformType.WORD_EMBEDDING;
+    }
+
+    @Override
+    public void build(CacheBlock<?> in) {
+        throw new NotImplementedException();
+    }
+
+    @Override
+    public void allocateMetaData(FrameBlock meta) {
+        throw new NotImplementedException();
+    }
+
+    @Override
+    public FrameBlock getMetaData(FrameBlock out) {
+        throw new NotImplementedException();
+    }
+
+    @Override
+    public void initMetaData(FrameBlock meta) {
+        return;
+    }
+
+    //save embeddings matrix reference for apply step
+    @Override
+    public void initEmbeddings(MatrixBlock embeddings){
+        this.wordEmbeddings = embeddings;
+    }
+}
diff --git 
a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java 
b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java
index 075b6fbdd4..313258831a 100644
--- 
a/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java
+++ 
b/src/main/java/org/apache/sysds/runtime/transform/encode/EncoderFactory.java
@@ -36,6 +36,7 @@ import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.common.Types.ValueType;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.frame.data.FrameBlock;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
 import org.apache.sysds.runtime.transform.TfUtils.TfMethod;
 import org.apache.sysds.runtime.transform.encode.ColumnEncoder.EncoderType;
 import org.apache.sysds.runtime.transform.meta.TfMetaUtils;
@@ -68,7 +69,21 @@ public interface EncoderFactory {
        }
 
        public static MultiColumnEncoder createEncoder(String spec, String[] 
colnames, ValueType[] schema, FrameBlock meta,
-               int minCol, int maxCol) {
+               int minCol, int maxCol){
+               return createEncoder(spec, colnames, schema, meta, null, 
minCol, maxCol);
+       }
+
+       public static MultiColumnEncoder createEncoder(String spec, String[] 
colnames, int clen, FrameBlock meta, MatrixBlock embeddings) {
+               return createEncoder(spec, colnames, 
UtilFunctions.nCopies(clen, ValueType.STRING), meta, embeddings);
+       }
+
+       public static MultiColumnEncoder createEncoder(String spec, String[] 
colnames, ValueType[] schema,
+                                                                               
                   FrameBlock meta, MatrixBlock embeddings) {
+               return createEncoder(spec, colnames, schema, meta, embeddings, 
-1, -1);
+       }
+
+       public static MultiColumnEncoder createEncoder(String spec, String[] 
colnames, ValueType[] schema, FrameBlock meta,
+               MatrixBlock embeddings, int minCol, int maxCol) {
                MultiColumnEncoder encoder;
                int clen = schema.length;
 
@@ -88,9 +103,18 @@ public interface EncoderFactory {
                        List<Integer> dcIDs = Arrays.asList(ArrayUtils
                                .toObject(TfMetaUtils.parseJsonIDList(jSpec, 
colnames, TfMethod.DUMMYCODE.toString(), minCol, maxCol)));
                        List<Integer> binIDs = 
TfMetaUtils.parseBinningColIDs(jSpec, colnames, minCol, maxCol);
+                       List<Integer> weIDs = Arrays.asList(ArrayUtils
+                                       
.toObject(TfMetaUtils.parseJsonIDList(jSpec, colnames, 
TfMethod.WORD_EMBEDDING.toString(), minCol, maxCol)));
+
+                       //check if user passed an embeddings matrix
+                       if(!weIDs.isEmpty() && embeddings == null)
+                               throw new DMLRuntimeException("Missing argument 
Embeddings Matrix for transform [" + TfMethod.WORD_EMBEDDING + "]");
+
                        // NOTE: any dummycode column requires recode as 
preparation, unless the dummycode
                        // column follows binning or feature hashing
                        rcIDs = unionDistinct(rcIDs, except(except(dcIDs, 
binIDs), haIDs));
+                       // NOTE: Word Embeddings requires recode as preparation
+                       rcIDs = unionDistinct(rcIDs, weIDs);
                        // Error out if the first level encoders have overlaps
                        if (intersect(rcIDs, binIDs, haIDs))
                                throw new DMLRuntimeException("More than one 
encoders (recode, binning, hashing) on one column is not allowed");
@@ -114,7 +138,9 @@ public interface EncoderFactory {
                        if(!ptIDs.isEmpty())
                                for(Integer id : ptIDs)
                                        addEncoderToMap(new 
ColumnEncoderPassThrough(id), colEncoders);
-                       
+                       if(!weIDs.isEmpty())
+                               for(Integer id : weIDs)
+                                       addEncoderToMap(new 
ColumnEncoderWordEmbedding(id), colEncoders);
                        if(!binIDs.isEmpty())
                                for(Object o : (JSONArray) 
jSpec.get(TfMethod.BIN.toString())) {
                                        JSONObject colspec = (JSONObject) o;
@@ -185,6 +211,9 @@ public interface EncoderFactory {
                                }
                                encoder.initMetaData(meta);
                        }
+                       //initialize embeddings matrix block in the encoders in 
case word embedding transform is used
+                       if(!weIDs.isEmpty())
+                               encoder.initEmbeddings(embeddings);
                }
                catch(Exception ex) {
                        throw new DMLRuntimeException(ex);
diff --git 
a/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
 
b/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
index 6838cdd1e2..59c22f5640 100644
--- 
a/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
+++ 
b/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
@@ -314,10 +314,12 @@ public class MultiColumnEncoder implements Encoder {
        public MatrixBlock apply(CacheBlock<?> in, int k) {
                // domain sizes are not updated if called from transformapply
                boolean hasUDF = _columnEncoders.stream().anyMatch(e -> 
e.hasEncoder(ColumnEncoderUDF.class));
+               boolean hasWE = _columnEncoders.stream().anyMatch(e -> 
e.hasEncoder(ColumnEncoderWordEmbedding.class));
                for(ColumnEncoderComposite columnEncoder : _columnEncoders)
                        columnEncoder.updateAllDCEncoders();
                int numCols = getNumOutCols();
-               long estNNz = (long) in.getNumRows() * (hasUDF ? numCols : 
(long) in.getNumColumns());
+               long estNNz = (long) in.getNumRows() * (hasUDF ? numCols : 
hasWE ? getEstNNzRow() : (long) in.getNumColumns());
+               // FIXME: estimate nnz for multiple encoders including 
dummycode and embedding
                boolean sparse = 
MatrixBlock.evalSparseFormatInMemory(in.getNumRows(), numCols, estNNz) && 
!hasUDF;
                MatrixBlock out = new MatrixBlock(in.getNumRows(), numCols, 
sparse, estNNz);
                return apply(in, out, 0, k);
@@ -353,8 +355,7 @@ public class MultiColumnEncoder implements Encoder {
                        int offset = outputCol;
                        for(ColumnEncoderComposite columnEncoder : 
_columnEncoders) {
                                columnEncoder.apply(in, out, 
columnEncoder._colID - 1 + offset);
-                               if 
(columnEncoder.hasEncoder(ColumnEncoderDummycode.class))
-                                       offset += 
columnEncoder.getEncoder(ColumnEncoderDummycode.class)._domainSize - 1;
+                               offset = getOffset(offset, columnEncoder);
                        }
                }
                // Recomputing NNZ since we access the block directly
@@ -373,12 +374,19 @@ public class MultiColumnEncoder implements Encoder {
                int offset = outputCol;
                for(ColumnEncoderComposite e : _columnEncoders) {
                        tasks.addAll(e.getApplyTasks(in, out, e._colID - 1 + 
offset));
-                       if(e.hasEncoder(ColumnEncoderDummycode.class))
-                               offset += 
e.getEncoder(ColumnEncoderDummycode.class)._domainSize - 1;
+                       offset = getOffset(offset, e);
                }
                return tasks;
        }
 
+       private int getOffset(int offset, ColumnEncoderComposite e) {
+               if(e.hasEncoder(ColumnEncoderDummycode.class))
+                       offset += 
e.getEncoder(ColumnEncoderDummycode.class)._domainSize - 1;
+               if(e.hasEncoder(ColumnEncoderWordEmbedding.class))
+                       offset += 
e.getEncoder(ColumnEncoderWordEmbedding.class).getDomainSize() - 1;
+               return offset;
+       }
+
        private void applyMT(CacheBlock<?> in, MatrixBlock out, int outputCol, 
int k) {
                DependencyThreadPool pool = new DependencyThreadPool(k);
                try {
@@ -386,8 +394,7 @@ public class MultiColumnEncoder implements Encoder {
                                int offset = outputCol;
                                for (ColumnEncoderComposite e : 
_columnEncoders) {
                                        
pool.submitAllAndWait(e.getApplyTasks(in, out, e._colID - 1 + offset));
-                                       if 
(e.hasEncoder(ColumnEncoderDummycode.class))
-                                               offset += 
e.getEncoder(ColumnEncoderDummycode.class)._domainSize - 1;
+                                       offset = getOffset(offset, e);
                                }
                        } else
                                pool.submitAllAndWait(getApplyTasks(in, out, 
outputCol));
@@ -696,6 +703,12 @@ public class MultiColumnEncoder implements Encoder {
                        _legacyMVImpute.initMetaData(meta);
        }
 
+       //pass down init to composite encoders
+       public void initEmbeddings(MatrixBlock embeddings) {
+               for(ColumnEncoder columnEncoder : _columnEncoders)
+                       columnEncoder.initEmbeddings(embeddings);
+       }
+
        @Override
        public void prepareBuildPartial() {
                for(Encoder encoder : _columnEncoders)
@@ -855,6 +868,13 @@ public class MultiColumnEncoder implements Encoder {
                return getEncoderTypes(-1);
        }
 
+       public int getEstNNzRow(){
+               int nnz = 0;
+               for(int i = 0; i < _columnEncoders.size(); i++)
+                       nnz += _columnEncoders.get(i).getDomainSize();
+               return nnz;
+       }
+
        public int getNumOutCols() {
                int sum = 0;
                for(int i = 0; i < _columnEncoders.size(); i++)
diff --git 
a/src/main/java/org/apache/sysds/utils/stats/TransformStatistics.java 
b/src/main/java/org/apache/sysds/utils/stats/TransformStatistics.java
index b7779e4ee1..9ace729462 100644
--- a/src/main/java/org/apache/sysds/utils/stats/TransformStatistics.java
+++ b/src/main/java/org/apache/sysds/utils/stats/TransformStatistics.java
@@ -32,6 +32,8 @@ public class TransformStatistics {
        //private static final LongAdder applyTime = new LongAdder();
        private static final LongAdder recodeApplyTime = new LongAdder();
        private static final LongAdder dummyCodeApplyTime = new LongAdder();
+
+       private static final LongAdder wordEmbeddingApplyTime = new LongAdder();
        private static final LongAdder passThroughApplyTime = new LongAdder();
        private static final LongAdder featureHashingApplyTime = new 
LongAdder();
        private static final LongAdder binningApplyTime = new LongAdder();
@@ -55,6 +57,10 @@ public class TransformStatistics {
                dummyCodeApplyTime.add(t);
        }
 
+       public static void incWordEmbeddingApplyTime(long t){
+               wordEmbeddingApplyTime.add(t);
+       }
+
        public static void incBinningApplyTime(long t) {
                binningApplyTime.add(t);
        }
@@ -112,7 +118,7 @@ public class TransformStatistics {
                return dummyCodeApplyTime.longValue() + 
binningApplyTime.longValue() +
                                featureHashingApplyTime.longValue() + 
passThroughApplyTime.longValue() +
                                recodeApplyTime.longValue() + 
UDFApplyTime.longValue() +
-                               omitApplyTime.longValue() + 
imputeApplyTime.longValue();
+                               omitApplyTime.longValue() + 
imputeApplyTime.longValue() + wordEmbeddingApplyTime.longValue();
        }
 
        public static void reset() {
@@ -163,6 +169,9 @@ public class TransformStatistics {
                        if(dummyCodeApplyTime.longValue() > 0)
                                sb.append("\tDummyCode apply 
time:\t").append(String.format("%.3f",
                                        
dummyCodeApplyTime.longValue()*1e-9)).append(" sec.\n");
+                       if(wordEmbeddingApplyTime.longValue() > 0)
+                               sb.append("\tWordEmbedding apply 
time:\t").append(String.format("%.3f",
+                                               
wordEmbeddingApplyTime.longValue()*1e-9)).append(" sec.\n");
                        if(featureHashingApplyTime.longValue() > 0)
                                sb.append("\tHashing apply 
time:\t").append(String.format("%.3f",
                                        
featureHashingApplyTime.longValue()*1e-9)).append(" sec.\n");
diff --git 
a/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeWordEmbedding2Test.java
 
b/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeWordEmbedding2Test.java
new file mode 100644
index 0000000000..8ab52d9f64
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeWordEmbedding2Test.java
@@ -0,0 +1,258 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.transform;
+
+import org.apache.sysds.common.Types.ExecMode;
+import org.apache.sysds.lops.Lop;
+import org.apache.sysds.runtime.matrix.data.MatrixValue;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Ignore;
+import org.junit.Test;
+
+import java.io.BufferedWriter;
+import java.io.FileWriter;
+import java.io.IOException;
+import java.io.UncheckedIOException;
+import java.nio.charset.StandardCharsets;
+import java.nio.file.Files;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Date;
+import java.util.HashMap;
+import java.util.LinkedHashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import java.util.Set;
+
+/**
+ * Tests transformapply with the word_embedding encoder: a token column is recoded via a
+ * dictionary and each recoded ID is replaced by the corresponding row of a pre-trained
+ * embedding matrix. The expected result is derived by looking up each token's embedding
+ * row directly.
+ */
+public class TransformFrameEncodeWordEmbedding2Test extends AutomatedTestBase
+{
+    private final static String TEST_NAME1 = "TransformFrameEncodeWordEmbeddings2";
+    private final static String TEST_NAME2 = "TransformFrameEncodeWordEmbeddings2MultiCols1";
+    private final static String TEST_NAME3 = "TransformFrameEncodeWordEmbeddings2MultiCols2";
+
+    private final static String TEST_DIR = "functions/transform/";
+    private final static String TEST_CLASS_DIR = TEST_DIR + TransformFrameEncodeWordEmbedding2Test.class.getSimpleName() + "/";
+
+    // Number of distinct tokens (= rows of the embedding matrix) and embedding width;
+    // both must agree with the hard-coded read(..., rows=100, cols=100) in the DML scripts.
+    private final static int NUM_TOKENS = 100;
+    private final static int EMBEDDING_DIM = 100;
+    // Length of each generated token
+    private final static int TOKEN_LENGTH = 10;
+    // Alphabet of the generated tokens
+    private final static String TOKEN_CHARACTERS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789";
+    // Tolerance for comparing actual and expected embeddings
+    private final static double EPS = 1e-6;
+
+    @Override
+    public void setUp() {
+        TestUtils.clearAssertionInformation();
+        addTestConfiguration(TEST_NAME1, new TestConfiguration(TEST_DIR, TEST_NAME1));
+        addTestConfiguration(TEST_NAME2, new TestConfiguration(TEST_DIR, TEST_NAME2));
+        addTestConfiguration(TEST_NAME3, new TestConfiguration(TEST_DIR, TEST_NAME3));
+    }
+
+    @Test
+    public void testTransformToWordEmbeddings() {
+        runTransformTest(TEST_NAME1, ExecMode.SINGLE_NODE);
+    }
+
+    @Test
+    @Ignore
+    public void testNonRandomTransformToWordEmbeddings2Cols() {
+        runTransformTest(TEST_NAME2, ExecMode.SINGLE_NODE);
+    }
+
+    @Test
+    @Ignore
+    public void testRandomTransformToWordEmbeddings4Cols() {
+        runTransformTestMultiCols(TEST_NAME3, ExecMode.SINGLE_NODE);
+    }
+
+    /** Runs a script writing the single output "result"; data has 320 samples per distinct token. */
+    private void runTransformTest(String testname, ExecMode rt) {
+        runWordEmbeddingTest(testname, rt, 320, "result");
+    }
+
+    /** Runs a script writing "result" and "result2"; data has 10 samples per distinct token. */
+    private void runTransformTestMultiCols(String testname, ExecMode rt) {
+        runWordEmbeddingTest(testname, rt, 10, "result", "result2");
+    }
+
+    /**
+     * Generates embeddings, tokens, dictionary and data column, runs the DML script, and
+     * checks that every named output equals the manually derived embedding matrix.
+     *
+     * @param testname    registered test configuration / DML script name
+     * @param rt          execution mode to run the script in
+     * @param multiplier  length of the data column as a multiple of NUM_TOKENS
+     * @param outputNames output matrices written by the script, passed as $4, $5, ...
+     */
+    private void runWordEmbeddingTest(String testname, ExecMode rt, int multiplier, String... outputNames) {
+        ExecMode rtold = setExecMode(rt);
+        try {
+            getAndLoadTestConfiguration(testname);
+            fullDMLScriptName = getScript();
+
+            // Random embeddings, one row per distinct token (clock-seeded)
+            double[][] embeddings = createRandomMatrix("embeddings", NUM_TOKENS, EMBEDDING_DIM, 0, 10, 1, System.currentTimeMillis());
+
+            // Distinct random tokens and the dictionary assigning each a unique ID
+            List<String> tokens = generateRandomStrings(NUM_TOKENS, TOKEN_LENGTH);
+            Map<String,Integer> dict = writeDictToCsvFile(tokens, baseDirectory + INPUT_DIR + "dict");
+
+            // Data column: tokens sampled with replacement
+            List<String> column = shuffleAndMultiplyStrings(tokens, multiplier);
+            writeStringsToCsvFile(column, baseDirectory + INPUT_DIR + "data");
+
+            List<String> args = new ArrayList<>(Arrays.asList("-stats", "-args", input("embeddings"), input("data"), input("dict")));
+            for (String name : outputNames)
+                args.add(output(name));
+            programArgs = args.toArray(new String[0]);
+            runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
+
+            // Every output must match the embeddings looked up directly via the dictionary
+            double[][] expected = manuallyDeriveWordEmbeddings(EMBEDDING_DIM, embeddings, dict, column);
+            for (String name : outputNames) {
+                HashMap<MatrixValue.CellIndex, Double> actual = readDMLMatrixFromOutputDir(name);
+                TestUtils.compareMatrices(TestUtils.convertHashMapToDoubleArray(actual), expected, EPS);
+            }
+        }
+        catch(Exception ex) {
+            throw new RuntimeException(ex);
+        }
+        finally {
+            resetExecMode(rtold);
+        }
+    }
+
+    /**
+     * Builds the expected output: row i is the embedding row of the i-th token of the data
+     * column, where {@code map} gives each token's 0-based row in {@code a}.
+     */
+    private double[][] manuallyDeriveWordEmbeddings(int cols, double[][] a, Map<String, Integer> map, List<String> stringsColumn) {
+        double[][] res_expected = new double[stringsColumn.size()][cols];
+        for (int i = 0; i < stringsColumn.size(); i++) {
+            int rowMapped = map.get(stringsColumn.get(i));
+            System.arraycopy(a[rowMapped], 0, res_expected[i], 0, cols);
+        }
+        return res_expected;
+    }
+
+    /**
+     * Returns {@code strings.size() * multiply} tokens drawn uniformly at random, with
+     * replacement, from {@code strings}; not every token is guaranteed to appear.
+     */
+    public static List<String> shuffleAndMultiplyStrings(List<String> strings, int multiply){
+        List<String> out = new ArrayList<>();
+        Random random = new Random();
+        for (int i = 0; i < strings.size()*multiply; i++) {
+            out.add(strings.get(random.nextInt(strings.size())));
+        }
+        return out;
+    }
+
+    /**
+     * Generates {@code numStrings} distinct random alphanumeric tokens of length {@code stringLength}.
+     *
+     * @throws IllegalArgumentException if fewer than {@code numStrings} distinct tokens of that length exist
+     */
+    public static List<String> generateRandomStrings(int numStrings, int stringLength) {
+        if (Math.pow(TOKEN_CHARACTERS.length(), stringLength) < numStrings)
+            throw new IllegalArgumentException("Cannot generate " + numStrings
+                + " distinct tokens of length " + stringLength);
+        // A set guarantees distinct tokens; a duplicate would receive two dictionary codes.
+        Set<String> distinct = new LinkedHashSet<>();
+        Random random = new Random();
+        while (distinct.size() < numStrings)
+            distinct.add(generateRandomString(random, stringLength, TOKEN_CHARACTERS));
+        return new ArrayList<>(distinct);
+    }
+
+    /** Returns a string of {@code stringLength} characters drawn uniformly from {@code characters}. */
+    public static String generateRandomString(Random random, int stringLength, String characters){
+        StringBuilder randomString = new StringBuilder();
+        for (int j = 0; j < stringLength; j++) {
+            int randomIndex = random.nextInt(characters.length());
+            randomString.append(characters.charAt(randomIndex));
+        }
+        return randomString.toString();
+    }
+
+    /**
+     * Writes one token per line to {@code fileName} in UTF-8.
+     *
+     * @throws UncheckedIOException if the file cannot be written
+     */
+    public static void writeStringsToCsvFile(List<String> strings, String fileName) {
+        try (BufferedWriter bw = Files.newBufferedWriter(Paths.get(fileName), StandardCharsets.UTF_8)) {
+            for (String line : strings) {
+                bw.write(line);
+                bw.newLine();
+            }
+        } catch (IOException e) {
+            throw new UncheckedIOException("Failed to write " + fileName, e);
+        }
+    }
+
+    /**
+     * Writes the recode dictionary ("token" + Lop.DATATYPE_PREFIX + 1-based code per line)
+     * to {@code fileName} and returns the token to 0-based row mapping.
+     * UTF-8 is used explicitly because the separator is not ASCII and the platform
+     * default charset may differ.
+     *
+     * @throws UncheckedIOException if the file cannot be written
+     */
+    public static Map<String,Integer> writeDictToCsvFile(List<String> strings, String fileName) {
+        Map<String,Integer> map = new HashMap<>();
+        try (BufferedWriter bw = Files.newBufferedWriter(Paths.get(fileName), StandardCharsets.UTF_8)) {
+            for (int i = 0; i < strings.size(); i++) {
+                map.put(strings.get(i), i);
+                bw.write(strings.get(i) + Lop.DATATYPE_PREFIX + (i+1) + "\n");
+            }
+        } catch (IOException e) {
+            throw new UncheckedIOException("Failed to write dictionary " + fileName, e);
+        }
+        return map;
+    }
+}
diff --git 
a/src/test/scripts/functions/transform/TransformFrameEncodeWordEmbeddings2.dml 
b/src/test/scripts/functions/transform/TransformFrameEncodeWordEmbeddings2.dml
new file mode 100644
index 0000000000..29a4bfab74
--- /dev/null
+++ 
b/src/test/scripts/functions/transform/TransformFrameEncodeWordEmbeddings2.dml
@@ -0,0 +1,36 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Read the pre-trained word embeddings (one 100-dimensional row per distinct token)
+E = read($1, rows=100, cols=100, format="text");
+# Read the token sequence (100 distinct tokens x 320 = 32K rows as generated by
+# TransformFrameEncodeWordEmbedding2Test)
+Data = read($2, data_type="frame", format="csv");
+# Read the recode map for the distinct tokens
+Meta = read($3, data_type="frame", format="csv");
+
+# Recode column 1 via Meta and replace each recoded ID with its embedding vector from E
+jspec = "{ids: true, word_embedding: [1]}";
+Data_enc = transformapply(target=Data, spec=jspec, meta=Meta, embedding=E);
+
+write(Data_enc, $4, format="text");
+
+
+
+
diff --git 
a/src/test/scripts/functions/transform/TransformFrameEncodeWordEmbeddings2MultiCols1.dml
 
b/src/test/scripts/functions/transform/TransformFrameEncodeWordEmbeddings2MultiCols1.dml
new file mode 100644
index 0000000000..00484697d6
--- /dev/null
+++ 
b/src/test/scripts/functions/transform/TransformFrameEncodeWordEmbeddings2MultiCols1.dml
@@ -0,0 +1,43 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Read the pre-trained word embeddings (one 100-dimensional row per distinct token)
+E = read($1, rows=100, cols=100, format="text");
+# Read the token sequence (100 distinct tokens x 320 = 32K rows as generated by
+# TransformFrameEncodeWordEmbedding2Test)
+Data = read($2, data_type="frame", format="csv");
+# Read the recode map for the distinct tokens
+Meta = read($3, data_type="frame", format="csv");
+
+# Surround the token column with constant columns, giving Data = [1, token, 1].
+# nrow (not length, which counts all cells) yields exactly one value per row.
+DataExtension = as.frame(matrix(1, rows=nrow(Data), cols=1))
+Data = cbind(Data, DataExtension)
+Data = cbind(DataExtension, Data)
+# NOTE(review): Meta gets 2 columns while Data has 3. Only column 2 is word-embedded,
+# so presumably only meta column 2 is read -- confirm (MultiCols2 keeps the counts equal).
+Meta = cbind(Meta, Meta)
+
+jspec = "{ids: true, word_embedding: [2]}";
+Data_enc = transformapply(target=Data, spec=jspec, meta=Meta, embedding=E);
+
+# Columns 2:101 hold the embedding of the token column
+Data_enc = Data_enc[,2:101]
+write(Data_enc, $4, format="text");
+
+
+
+
diff --git 
a/src/test/scripts/functions/transform/TransformFrameEncodeWordEmbeddings2MultiCols2.dml
 
b/src/test/scripts/functions/transform/TransformFrameEncodeWordEmbeddings2MultiCols2.dml
new file mode 100644
index 0000000000..fd742520e7
--- /dev/null
+++ 
b/src/test/scripts/functions/transform/TransformFrameEncodeWordEmbeddings2MultiCols2.dml
@@ -0,0 +1,44 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# Read the pre-trained word embeddings (one 100-dimensional row per distinct token)
+E = read($1, rows=100, cols=100, format="text");
+# Read the token sequence (1K) w/ 100 distinct tokens
+Data = read($2, data_type="frame", format="csv");
+# Read the recode map for the distinct tokens
+Meta = read($3, data_type="frame", format="csv");
+
+# Build Data = [token, 1, token, 1] and a matching 4-column Meta.
+# nrow (not length, which counts all cells) yields exactly one value per row.
+DataExtension = as.frame(matrix(1, rows=nrow(Data), cols=1))
+Data = cbind(Data, DataExtension)
+Data = cbind(Data, Data)
+Meta = cbind(Meta, Meta)
+Meta = cbind(Meta, Meta)
+
+jspec = "{ids: true, word_embedding: [1,3]}";
+Data_enc = transformapply(target=Data, spec=jspec, meta=Meta, embedding=E);
+
+# Embeddings of columns 1 and 3; column 101 (presumably the constant column 2)
+# lies between them
+Data_enc1 = Data_enc[,1:100]
+Data_enc2 = Data_enc[,102:201]
+write(Data_enc1, $4, format="text");
+write(Data_enc2, $5, format="text");
+
+
+


Reply via email to