This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 6ab8540206 [SYSTEMDS-3657] Improved word embedding encoder / dedup blocks
6ab8540206 is described below

commit 6ab85402066e15681377a271669d5bcf1d58596a
Author: e-strauss <[email protected]>
AuthorDate: Sat Jan 13 15:57:14 2024 +0100

    [SYSTEMDS-3657] Improved word embedding encoder / dedup blocks
    
    - Bug fix for memory estimates of dedup block
    - optimised recalc of nnz for dedup block
    - add stats for spark broadcast for transformapply encoder
    - added accurate mem estimate for transform apply's wordembedding in ParameterizedBuiltinOp.java
    
    Closes #1942.
---
 .../apache/sysds/hops/ParameterizedBuiltinOp.java  | 29 ++++++++++++
 .../apache/sysds/runtime/data/DenseBlockFP64.java  |  2 +-
 .../sysds/runtime/data/DenseBlockFP64DEDUP.java    | 23 +++++++--
 .../cp/ParameterizedBuiltinCPInstruction.java      |  2 +
 .../spark/ParameterizedBuiltinSPInstruction.java   | 11 ++++-
 .../spark/utils/FrameRDDAggregateUtils.java        |  6 ++-
 .../sysds/runtime/matrix/data/MatrixBlock.java     | 26 ++++++----
 .../encode/ColumnEncoderWordEmbedding.java         | 55 ++++++----------------
 .../transform/encode/MultiColumnEncoder.java       | 25 +++++++---
 .../TransformFrameEncodeWordEmbedding1Test.java    |  5 +-
 .../TransformFrameEncodeWordEmbedding2Test.java    |  5 ++
 .../TransformFrameEncodeWordEmbeddings.dml         | 10 ++--
 12 files changed, 128 insertions(+), 71 deletions(-)

diff --git a/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java 
b/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java
index 01883e2f5d..964a60d528 100644
--- a/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java
+++ b/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java
@@ -47,6 +47,7 @@ import org.apache.sysds.common.Types.ExecType;
 import org.apache.sysds.lops.ParameterizedBuiltin;
 import org.apache.sysds.parser.DMLProgram;
 import org.apache.sysds.parser.Statement;
+import org.apache.sysds.runtime.data.DenseBlockFP64DEDUP;
 import org.apache.sysds.runtime.instructions.cp.ParamservBuiltinCPInstruction;
 import org.apache.sysds.runtime.meta.DataCharacteristics;
 import org.apache.sysds.runtime.meta.MatrixCharacteristics;
@@ -694,6 +695,34 @@ public class ParameterizedBuiltinOp extends 
MultiThreadedHop {
                
                return ret;
        }
+       @Override
+       public void computeMemEstimate(MemoTable memo){
+               if( _op == ParamBuiltinOp.TRANSFORMAPPLY){
+                       Hop spec = getParameterHop("spec");
+                       if(spec instanceof LiteralOp && ((LiteralOp) 
spec).getStringValue().contains("word_embedding")
+                               && memo.hasInputStatistics(this)){
+                               //Special case for WordEmbedding Operator
+                               //Step 1) Compute hop output memory estimate 
(incl size inference)
+                               DataCharacteristics idc = 
memo.getAllInputStats(getTargetHop());
+                               DataCharacteristics edc = 
memo.getAllInputStats(getParameterHop("embedding"));
+                               if (idc != null && edc != null && 
edc.dimsKnown() && idc.dimsKnown()) {
+                                       DataCharacteristics wdc = new 
MatrixCharacteristics(
+                                               idc.getRows(), edc.getCols(), 
-1, idc.getRows()*edc.getCols());
+                                       _outputMemEstimate = 
DenseBlockFP64DEDUP.estimateMemory(
+                                               wdc.getRows(), edc.getCols(), 
edc.getRows());
+
+                                       //propagate worst-case estimate
+                                       memo.memoizeStatistics(getHopID(), wdc);
+
+                                       //Step 2) Compute hop intermediate 
memory estimate
+                                       _processingMemEstimate = 
3*_outputMemEstimate; //Note Elias: factor needs to be adjusted
+                                       _memEstimate = getInputOutputSize();
+                                       return;
+                               }
+                       }
+               }
+               super.computeMemEstimate(memo);
+       }
        
        @Override 
        public boolean allowsAllExecTypes() {
diff --git a/src/main/java/org/apache/sysds/runtime/data/DenseBlockFP64.java 
b/src/main/java/org/apache/sysds/runtime/data/DenseBlockFP64.java
index ac4e8955d3..f837e95820 100644
--- a/src/main/java/org/apache/sysds/runtime/data/DenseBlockFP64.java
+++ b/src/main/java/org/apache/sysds/runtime/data/DenseBlockFP64.java
@@ -83,7 +83,7 @@ public class DenseBlockFP64 extends DenseBlockDRB
        }
        
        public static double estimateMemory(long nrows, long ncols) {
-               if( (double)nrows + ncols > Long.MAX_VALUE )
+               if( (double)nrows * ncols > Long.MAX_VALUE )
                        return Long.MAX_VALUE;
                return DenseBlock.estimateMemory(nrows, ncols)
                        + MemoryEstimates.doubleArrayCost(nrows * ncols);
diff --git 
a/src/main/java/org/apache/sysds/runtime/data/DenseBlockFP64DEDUP.java 
b/src/main/java/org/apache/sysds/runtime/data/DenseBlockFP64DEDUP.java
index 1a3c84fa4d..c9789a9e64 100644
--- a/src/main/java/org/apache/sysds/runtime/data/DenseBlockFP64DEDUP.java
+++ b/src/main/java/org/apache/sysds/runtime/data/DenseBlockFP64DEDUP.java
@@ -31,8 +31,12 @@ public class DenseBlockFP64DEDUP extends DenseBlockDRB
 {
        private static final long serialVersionUID = -4012376952006079198L;
        private double[][] _data;
+       //TODO: implement estimator for nr of distinct
        private int _distinct = 0;
 
+       public void setDistinct(int d){
+               _distinct = d;
+       }
        protected DenseBlockFP64DEDUP(int[] dims) {
                super(dims);
                reset(_rlen, _odims, 0);
@@ -317,10 +321,19 @@ public class DenseBlockFP64DEDUP extends DenseBlockDRB
                return UtilFunctions.toLong(get(ix[0], pos(ix)));
        }
 
-       public double estimateMemory(){
-               if( (double)_rlen + this._odims[0] > Long.MAX_VALUE )
+       public long estimateMemory(){
+               if( (double)_rlen * _odims[0] > Long.MAX_VALUE )
                        return Long.MAX_VALUE;
-               return DenseBlock.estimateMemory(_rlen, _odims[0])
-                               + 
MemoryEstimates.doubleArrayCost(_odims[0])*_distinct + 
MemoryEstimates.objectArrayCost(_rlen);
+               return estimateMemory(_rlen, _odims[0], _distinct);
+       }
+
+       public static long estimateMemory(int rows, int cols, int duplicates){
+               return estimateMemory((long) rows, (long)cols, (long) 
duplicates);
+       }
+
+       public static long estimateMemory(long rows, long cols, long 
duplicates){
+               return ((long) (DenseBlock.estimateMemory(rows, cols)))
+                               + ((long) 
MemoryEstimates.doubleArrayCost(cols)*duplicates) 
+                               + ((long) 
MemoryEstimates.objectArrayCost(rows));
        }
-}
+}
\ No newline at end of file
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
index 0307fbb03b..d0aea7bce9 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
@@ -336,6 +336,8 @@ public class ParameterizedBuiltinCPInstruction extends 
ComputationCPInstruction
                        ec.setMatrixOutput(output.getName(), mbout);
                        ec.releaseFrameInput(params.get("target"));
                        ec.releaseFrameInput(params.get("meta"));
+                       if(params.get("embedding") != null)
+                               ec.releaseMatrixInput(params.get("embedding"));
                }
                else if(opcode.equalsIgnoreCase("transformdecode")) {
                        // acquire locks
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
index 3b61b768b0..61e6e799f0 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
@@ -32,6 +32,7 @@ import org.apache.spark.api.java.function.Function;
 import org.apache.spark.api.java.function.PairFlatMapFunction;
 import org.apache.spark.api.java.function.PairFunction;
 import org.apache.spark.broadcast.Broadcast;
+import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.common.Types;
 import org.apache.sysds.common.Types.CorrectionLocationType;
 import org.apache.sysds.common.Types.FileFormat;
@@ -96,6 +97,7 @@ import 
org.apache.sysds.runtime.transform.tokenize.TokenizerFactory;
 import org.apache.sysds.runtime.util.DataConverter;
 import org.apache.sysds.runtime.util.UtilFunctions;
 
+import org.apache.sysds.utils.stats.SparkStatistics;
 import scala.Tuple2;
 
 public class ParameterizedBuiltinSPInstruction extends 
ComputationSPInstruction {
@@ -545,15 +547,22 @@ public class ParameterizedBuiltinSPInstruction extends 
ComputationSPInstruction
                                .createEncoder(params.get("spec"), colnames, 
fo.getSchema(), (int) fo.getNumColumns(), meta, embeddings);
                        encoder.updateAllDCEncoders();
                        mcOut.setDimension(mcIn.getRows() - ((omap != null) ? 
omap.getNumRmRows() : 0), encoder.getNumOutCols());
+
+                       long t0 = System.nanoTime();
                        Broadcast<MultiColumnEncoder> bmeta = 
sec.getSparkContext().broadcast(encoder);
                        Broadcast<TfOffsetMap> bomap = (omap != null) ? 
sec.getSparkContext().broadcast(omap) : null;
+                       if (DMLScript.STATISTICS) {
+                               
SparkStatistics.accBroadCastTime(System.nanoTime() - t0);
+                               SparkStatistics.incBroadcastCount(1);
+                       }
 
                        // execute transform apply
                        JavaPairRDD<MatrixIndexes, MatrixBlock> out;
                        Tuple2<Boolean, Integer> aligned = 
FrameRDDAggregateUtils.checkRowAlignment(in, -1);
                        // NOTE: currently disabled for LegacyEncoders, because 
OMIT probably results in not aligned
                        // blocks and for IMPUTE was an inaccuracy for the 
"testHomesImputeColnamesSparkCSV" test case.
-                       // Expected: 8.150349617004395 vs actual: 8.15035 at 0 
8 (expected is calculated from transform encode,
+
+                       // Error in test case: Expected: 8.150349617004395 vs 
actual: 8.15035 at 0 8 (expected is calculated from transform encode,
                        // which currently always uses the else branch: either 
inaccuracy must come from serialisation of
                        // matrixblock or from binaryBlockToBinaryBlock reblock
                        if(aligned._1 && mcOut.getCols() <= aligned._2 && 
!encoder.hasLegacyEncoder() /*&& containsWE*/) {
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/utils/FrameRDDAggregateUtils.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/utils/FrameRDDAggregateUtils.java
index e77c2209ea..08b139061b 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/spark/utils/FrameRDDAggregateUtils.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/spark/utils/FrameRDDAggregateUtils.java
@@ -43,8 +43,10 @@ public class FrameRDDAggregateUtils
                                return in2;
                        if (in2 == null)
                                return in1;
-                       if (!in1._1() || !in2._1())
-                               return new Tuple5<>(false, null, null, null, 
null);
+                       if (!in1._1() )
+                               return in1;
+                       if (!in2._1() )
+                               return in2;
 
                        //default evaluation
                        int in1_max = in1._3();
diff --git 
a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java 
b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
index 6e3ad9f8b9..84ff9b7c52 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
@@ -1399,13 +1399,23 @@ public class MatrixBlock extends MatrixValue implements 
CacheBlock<MatrixBlock>,
                        final ExecutorService pool = CommonThreadPool.get(k);
                        try {
                                List<Future<Long>> f = new ArrayList<>();
-                               final int bz = 1000;
-                               for(int i = 0; i < rlen; i += bz) {
-                                       for(int ii = 0; ii < clen; ii += bz) {
+                               if(denseBlock instanceof DenseBlockFP64DEDUP){
+                                       int bz = (int) Math.ceil(((double) 
rlen) / k*2);
+                                       for(int i = 0; i < rlen; i += bz) {
                                                final int j = i;
-                                               final int jj = ii;
-                                               f.add(pool.submit(() -> //
-                                               recomputeNonZeros(j, Math.min(j 
+ bz, rlen) - 1, jj, Math.min(jj + bz, clen) - 1)));
+                                               f.add(pool.submit(() -> 
+                                                       
denseBlock.countNonZeros(j, Math.min(j + bz, rlen) -1, 0, clen -1)));
+                                       }
+                               }
+                               else {
+                                       final int bz = 1000;
+                                       for (int i = 0; i < rlen; i += bz) {
+                                               for (int ii = 0; ii < clen; ii 
+= bz) {
+                                                       final int j = i;
+                                                       final int jj = ii;
+                                                       f.add(pool.submit(() ->
+                                                               
recomputeNonZeros(j, Math.min(j + bz, rlen) - 1, jj, Math.min(jj + bz, clen) - 
1)));
+                                               }
                                        }
                                }
                                long nnz = 0;
@@ -2722,8 +2732,8 @@ public class MatrixBlock extends MatrixValue implements 
CacheBlock<MatrixBlock>,
        
        public long estimateSizeInMemory() {
                if (denseBlock instanceof DenseBlockFP64DEDUP) {
-                       double size = getHeaderSize() + ((DenseBlockFP64DEDUP) 
denseBlock).estimateMemory();
-                       return (long) Math.min(size, Long.MAX_VALUE);
+                       long size = getHeaderSize() + ((DenseBlockFP64DEDUP) 
denseBlock).estimateMemory();
+                       return Math.min(size, Long.MAX_VALUE);
                }
                return estimateSizeInMemory(rlen, clen, getSparsity());
        }
diff --git 
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderWordEmbedding.java
 
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderWordEmbedding.java
index 8d862f8575..65fde02994 100644
--- 
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderWordEmbedding.java
+++ 
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderWordEmbedding.java
@@ -37,7 +37,7 @@ import java.util.concurrent.ConcurrentHashMap;
 public class ColumnEncoderWordEmbedding extends ColumnEncoder {
        private MatrixBlock _wordEmbeddings;
        private Map<Object, Long> _rcdMap;
-       private ConcurrentHashMap<String, double[]> _embMap;
+       private HashMap<String, double[]> _embMap;
 
        public ColumnEncoderWordEmbedding() {
                super(-1);
@@ -54,6 +54,10 @@ public class ColumnEncoderWordEmbedding extends 
ColumnEncoder {
        public int getDomainSize(){
                return _wordEmbeddings.getNumColumns();
        }
+
+       public int getNrDistinctEmbeddings(){
+               return _wordEmbeddings.getNumRows();
+       }
        protected ColumnEncoderWordEmbedding(int colID) {
                super(colID);
        }
@@ -78,50 +82,18 @@ public class ColumnEncoderWordEmbedding extends 
ColumnEncoder {
                        embedding[i] = this._wordEmbeddings.quickGetValue((int) 
r, _colID - 1 + i);
                }
                return embedding;
-
        }
 
        @Override
        public void applyDense(CacheBlock<?> in, MatrixBlock out, int 
outputCol, int rowStart, int blk){
                int rowEnd = getEndIndex(in.getNumRows(), rowStart, blk);
-               if(blk == -1){
-                       HashMap<String, double[]> _embMapSingleThread = new 
HashMap<>();
-                       for(int i=rowStart; i<rowEnd; i++){
-                               String key = in.getString(i, _colID-1);
-                               if(key == null || key.isEmpty()) {
-                                       continue;
-                               }
-                               double[] embedding = 
_embMapSingleThread.get(key);
-                               if(embedding == null){
-                                       long code = lookupRCDMap(key);
-                                       if(code == -1L){
-                                               continue;
-                                       }
-                                       embedding = 
getEmbeddedingFromEmbeddingMatrix(code - 1);
-                                       _embMapSingleThread.put(key, embedding);
-                               }
-                               out.quickSetRow(i, embedding);
-                       }
-               }
-               else{
-                       //map each string to the corresponding embedding vector
-                       for(int i=rowStart; i<rowEnd; i++){
-                               String key = in.getString(i, _colID-1);
-                               if(key == null || key.isEmpty()) {
-                                       //codes[i-startInd] = Double.NaN;
-                                       continue;
-                               }
-                               double[] embedding = _embMap.get(key);
-                               if(embedding == null){
-                                       long code = lookupRCDMap(key);
-                                       if(code == -1L){
-                                               continue;
-                                       }
-                                       embedding = 
getEmbeddedingFromEmbeddingMatrix(code - 1);
-                                       _embMap.put(key, embedding);
-                               }
+               for(int i=rowStart; i<rowEnd; i++){
+                       String key = in.getString(i, _colID-1);
+                       if(key == null || key.isEmpty())
+                               continue;
+                       double[] embedding = _embMap.get(key);
+                       if(embedding != null)
                                out.quickSetRow(i, embedding);
-                       }
                }
        }
 
@@ -157,7 +129,8 @@ public class ColumnEncoderWordEmbedding extends 
ColumnEncoder {
        @Override
        public void initEmbeddings(MatrixBlock embeddings){
                this._wordEmbeddings = embeddings;
-               this._embMap = new ConcurrentHashMap<>();
+               this._embMap = new HashMap<>();
+               _rcdMap.forEach((word, index) -> _embMap.put((String) word, 
getEmbeddedingFromEmbeddingMatrix(index - 1)));
        }
 
        @Override
@@ -182,6 +155,6 @@ public class ColumnEncoderWordEmbedding extends 
ColumnEncoder {
                        _rcdMap.put(key, value);
                }
                _wordEmbeddings.readExternal(in);
-               this._embMap = new ConcurrentHashMap<>();
+               initEmbeddings(_wordEmbeddings);
        }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
 
b/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
index bd9e2ba79f..59c5f2c973 100644
--- 
a/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
+++ 
b/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
@@ -50,6 +50,7 @@ import org.apache.sysds.hops.OptimizerUtils;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.compress.estim.ComEstSample;
 import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
+import org.apache.sysds.runtime.data.DenseBlockFP64DEDUP;
 import org.apache.sysds.runtime.data.SparseBlock;
 import org.apache.sysds.runtime.data.SparseBlockCSR;
 import org.apache.sysds.runtime.data.SparseRowVector;
@@ -354,12 +355,16 @@ public class MultiColumnEncoder implements Encoder {
 
                boolean hasDC = false;
                boolean hasWE = false;
+               int distinctWE = 0;
                for(ColumnEncoderComposite columnEncoder : _columnEncoders) {
                        hasDC |= 
columnEncoder.hasEncoder(ColumnEncoderDummycode.class);
-                       hasWE |= 
columnEncoder.hasEncoder(ColumnEncoderWordEmbedding.class);
+                       for (ColumnEncoder enc : columnEncoder.getEncoders())
+                               if(enc instanceof ColumnEncoderWordEmbedding){
+                                       hasWE = true;
+                                       distinctWE = 
((ColumnEncoderWordEmbedding) enc).getNrDistinctEmbeddings();
+                               }
                }
-               //hasWE = false;
-               outputMatrixPreProcessing(out, in, hasDC, hasWE);
+               outputMatrixPreProcessing(out, in, hasDC, hasWE, distinctWE);
                if(k > 1) {
                        if(!_partitionDone) //happens if this method is 
directly called
                                deriveNumRowPartitions(in, k);
@@ -548,7 +553,7 @@ public class MultiColumnEncoder implements Encoder {
                return totMemOverhead;
        }
 
-       private static void outputMatrixPreProcessing(MatrixBlock output, 
CacheBlock<?> input, boolean hasDC, boolean hasWE) {
+       private static void outputMatrixPreProcessing(MatrixBlock output, 
CacheBlock<?> input, boolean hasDC, boolean hasWE, int distinctWE) {
                long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0;
                if(output.isInSparseFormat()) {
                        if (MatrixBlock.DEFAULT_SPARSEBLOCK != 
SparseBlock.Type.CSR
@@ -596,6 +601,8 @@ public class MultiColumnEncoder implements Encoder {
                else {
                        // Allocate dense block and set nnz to total #entries
                        output.allocateDenseBlock(true, hasWE);
+                       if( hasWE)
+                               ((DenseBlockFP64DEDUP) 
output.getDenseBlock()).setDistinct(distinctWE);
                        //output.setAllNonZeros();
                }
 
@@ -1150,13 +1157,19 @@ public class MultiColumnEncoder implements Encoder {
                @Override
                public Object call() throws Exception {
                        boolean hasUDF = 
_encoder.getColumnEncoders().stream().anyMatch(e -> 
e.hasEncoder(ColumnEncoderUDF.class));
-                       boolean hasWE = 
_encoder.getColumnEncoders().stream().anyMatch(e -> 
e.hasEncoder(ColumnEncoderWordEmbedding.class));
+                       boolean hasWE = false;
+                       int distinctWE = 0;
+                       for (ColumnEncoder enc : _encoder.getEncoders())
+                               if(enc instanceof ColumnEncoderWordEmbedding){
+                                       hasWE = true;
+                                       distinctWE = 
((ColumnEncoderWordEmbedding) enc).getNrDistinctEmbeddings();
+                               }
                        int numCols = _encoder.getNumOutCols();
                        boolean hasDC = 
_encoder.getColumnEncoders(ColumnEncoderDummycode.class).size() > 0;
                        long estNNz = (long) _input.getNumRows() * (hasUDF ? 
numCols : _input.getNumColumns());
                        boolean sparse = 
MatrixBlock.evalSparseFormatInMemory(_input.getNumRows(), numCols, estNNz) && 
!hasUDF;
                        _output.reset(_input.getNumRows(), numCols, sparse, 
estNNz);
-                       outputMatrixPreProcessing(_output, _input, hasDC, 
hasWE);
+                       outputMatrixPreProcessing(_output, _input, hasDC, 
hasWE, distinctWE);
                        return null;
                }
 
diff --git 
a/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeWordEmbedding1Test.java
 
b/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeWordEmbedding1Test.java
index 25cb95b3a2..4375dcda3d 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeWordEmbedding1Test.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeWordEmbedding1Test.java
@@ -21,6 +21,7 @@ package org.apache.sysds.test.functions.transform;
 
 import org.apache.sysds.common.Types.ExecMode;
 import org.apache.sysds.lops.Lop;
+import org.apache.sysds.runtime.matrix.data.MatrixValue;
 import org.apache.sysds.test.AutomatedTestBase;
 import org.apache.sysds.test.TestConfiguration;
 import org.apache.sysds.test.TestUtils;
@@ -90,8 +91,8 @@ public class TransformFrameEncodeWordEmbedding1Test extends 
AutomatedTestBase
                        }
 
                        // Compare results
-                       //HashMap<MatrixValue.CellIndex, Double> res_actual = 
readDMLMatrixFromOutputDir("result");
-                       
//TestUtils.compareMatrices(TestUtils.convertHashMapToDoubleArray(res_actual), 
res_expected, 1e-6);
+                       HashMap<MatrixValue.CellIndex, Double> res_actual = 
readDMLMatrixFromOutputDir("result");
+                       
TestUtils.compareMatrices(TestUtils.convertHashMapToDoubleArray(res_actual), 
res_expected, 1e-6);
                }
                catch(Exception ex) {
                        throw new RuntimeException(ex);
diff --git 
a/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeWordEmbedding2Test.java
 
b/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeWordEmbedding2Test.java
index 9d690be8b1..b994ae83c3 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeWordEmbedding2Test.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeWordEmbedding2Test.java
@@ -64,6 +64,11 @@ public class TransformFrameEncodeWordEmbedding2Test extends 
AutomatedTestBase
                runTransformTest(TEST_NAME1, ExecMode.SPARK);
        }
 
+       @Test
+       public void testTransformToWordEmbeddingsAuto() {
+               runTransformTest(TEST_NAME1, ExecMode.HYBRID);
+       }
+
        private void runTransformTest(String testname, ExecMode rt)
        {
                //set runtime platform
diff --git 
a/src/test/scripts/functions/transform/TransformFrameEncodeWordEmbeddings.dml 
b/src/test/scripts/functions/transform/TransformFrameEncodeWordEmbeddings.dml
index 227e9311dc..a358b2669b 100644
--- 
a/src/test/scripts/functions/transform/TransformFrameEncodeWordEmbeddings.dml
+++ 
b/src/test/scripts/functions/transform/TransformFrameEncodeWordEmbeddings.dml
@@ -28,15 +28,15 @@ Data = read($2, data_type="frame", format="csv");
 # Read the recode map for the distinct tokens
 Meta = read($3, data_type="frame", format="csv");
 
-jspec = "{ids: true, recode: [1]}";
+jspec = "{ids: true, dummycode: [1]}";
 #[Data_enc2, Meta2] = transformencode(target=Data, spec=jspec);
 
 Data_enc = transformapply(target=Data, spec=jspec, meta=Meta);
-print(nrow(Data_enc) + " x " + ncol(Data_enc))
-print(toString(Data_enc[1,1]))
+#print(nrow(Data_enc) + " x " + ncol(Data_enc))
+#print(toString(Data_enc[1,1]))
 
 # Apply the embeddings on all tokens (1K x 100)
-#R = Data_enc %*% E;
+R = Data_enc %*% E;
 
-#write(R, $4, format="text");
+write(R, $4, format="text");
 

Reply via email to