This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 6ab8540206 [SYSTEMDS-3657] Improved word embedding encoder / dedup
blocks
6ab8540206 is described below
commit 6ab85402066e15681377a271669d5bcf1d58596a
Author: e-strauss <[email protected]>
AuthorDate: Sat Jan 13 15:57:14 2024 +0100
[SYSTEMDS-3657] Improved word embedding encoder / dedup blocks
- Fix bug in the memory estimates of dedup blocks
- Optimize the recomputation of nnz for dedup blocks
- Add Spark broadcast statistics for the transformapply encoder
- Add an accurate memory estimate for transformapply's word embedding in
ParameterizedBuiltinOp.java
Closes #1942.
---
.../apache/sysds/hops/ParameterizedBuiltinOp.java | 29 ++++++++++++
.../apache/sysds/runtime/data/DenseBlockFP64.java | 2 +-
.../sysds/runtime/data/DenseBlockFP64DEDUP.java | 23 +++++++--
.../cp/ParameterizedBuiltinCPInstruction.java | 2 +
.../spark/ParameterizedBuiltinSPInstruction.java | 11 ++++-
.../spark/utils/FrameRDDAggregateUtils.java | 6 ++-
.../sysds/runtime/matrix/data/MatrixBlock.java | 26 ++++++----
.../encode/ColumnEncoderWordEmbedding.java | 55 ++++++----------------
.../transform/encode/MultiColumnEncoder.java | 25 +++++++---
.../TransformFrameEncodeWordEmbedding1Test.java | 5 +-
.../TransformFrameEncodeWordEmbedding2Test.java | 5 ++
.../TransformFrameEncodeWordEmbeddings.dml | 10 ++--
12 files changed, 128 insertions(+), 71 deletions(-)
diff --git a/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java
b/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java
index 01883e2f5d..964a60d528 100644
--- a/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java
+++ b/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java
@@ -47,6 +47,7 @@ import org.apache.sysds.common.Types.ExecType;
import org.apache.sysds.lops.ParameterizedBuiltin;
import org.apache.sysds.parser.DMLProgram;
import org.apache.sysds.parser.Statement;
+import org.apache.sysds.runtime.data.DenseBlockFP64DEDUP;
import org.apache.sysds.runtime.instructions.cp.ParamservBuiltinCPInstruction;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
@@ -694,6 +695,34 @@ public class ParameterizedBuiltinOp extends
MultiThreadedHop {
return ret;
}
+ @Override
+ public void computeMemEstimate(MemoTable memo){
+ if( _op == ParamBuiltinOp.TRANSFORMAPPLY){
+ Hop spec = getParameterHop("spec");
+ if(spec instanceof LiteralOp && ((LiteralOp)
spec).getStringValue().contains("word_embedding")
+ && memo.hasInputStatistics(this)){
+ //Special case for WordEmbedding Operator
+ //Step 1) Compute hop output memory estimate
(incl size inference)
+ DataCharacteristics idc =
memo.getAllInputStats(getTargetHop());
+ DataCharacteristics edc =
memo.getAllInputStats(getParameterHop("embedding"));
+ if (idc != null && edc != null &&
edc.dimsKnown() && idc.dimsKnown()) {
+ DataCharacteristics wdc = new
MatrixCharacteristics(
+ idc.getRows(), edc.getCols(),
-1, idc.getRows()*edc.getCols());
+ _outputMemEstimate =
DenseBlockFP64DEDUP.estimateMemory(
+ wdc.getRows(), edc.getCols(),
edc.getRows());
+
+ //propagate worst-case estimate
+ memo.memoizeStatistics(getHopID(), wdc);
+
+ //Step 2) Compute hop intermediate
memory estimate
+ _processingMemEstimate =
3*_outputMemEstimate; //Note Elias: factor needs to be adjusted
+ _memEstimate = getInputOutputSize();
+ return;
+ }
+ }
+ }
+ super.computeMemEstimate(memo);
+ }
@Override
public boolean allowsAllExecTypes() {
diff --git a/src/main/java/org/apache/sysds/runtime/data/DenseBlockFP64.java
b/src/main/java/org/apache/sysds/runtime/data/DenseBlockFP64.java
index ac4e8955d3..f837e95820 100644
--- a/src/main/java/org/apache/sysds/runtime/data/DenseBlockFP64.java
+++ b/src/main/java/org/apache/sysds/runtime/data/DenseBlockFP64.java
@@ -83,7 +83,7 @@ public class DenseBlockFP64 extends DenseBlockDRB
}
public static double estimateMemory(long nrows, long ncols) {
- if( (double)nrows + ncols > Long.MAX_VALUE )
+ if( (double)nrows * ncols > Long.MAX_VALUE )
return Long.MAX_VALUE;
return DenseBlock.estimateMemory(nrows, ncols)
+ MemoryEstimates.doubleArrayCost(nrows * ncols);
diff --git
a/src/main/java/org/apache/sysds/runtime/data/DenseBlockFP64DEDUP.java
b/src/main/java/org/apache/sysds/runtime/data/DenseBlockFP64DEDUP.java
index 1a3c84fa4d..c9789a9e64 100644
--- a/src/main/java/org/apache/sysds/runtime/data/DenseBlockFP64DEDUP.java
+++ b/src/main/java/org/apache/sysds/runtime/data/DenseBlockFP64DEDUP.java
@@ -31,8 +31,12 @@ public class DenseBlockFP64DEDUP extends DenseBlockDRB
{
private static final long serialVersionUID = -4012376952006079198L;
private double[][] _data;
+ //TODO: implement estimator for nr of distinct
private int _distinct = 0;
+ public void setDistinct(int d){
+ _distinct = d;
+ }
protected DenseBlockFP64DEDUP(int[] dims) {
super(dims);
reset(_rlen, _odims, 0);
@@ -317,10 +321,19 @@ public class DenseBlockFP64DEDUP extends DenseBlockDRB
return UtilFunctions.toLong(get(ix[0], pos(ix)));
}
- public double estimateMemory(){
- if( (double)_rlen + this._odims[0] > Long.MAX_VALUE )
+ public long estimateMemory(){
+ if( (double)_rlen * _odims[0] > Long.MAX_VALUE )
return Long.MAX_VALUE;
- return DenseBlock.estimateMemory(_rlen, _odims[0])
- +
MemoryEstimates.doubleArrayCost(_odims[0])*_distinct +
MemoryEstimates.objectArrayCost(_rlen);
+ return estimateMemory(_rlen, _odims[0], _distinct);
+ }
+
+ public static long estimateMemory(int rows, int cols, int duplicates){
+ return estimateMemory((long) rows, (long)cols, (long)
duplicates);
+ }
+
+ public static long estimateMemory(long rows, long cols, long
duplicates){
+ return ((long) (DenseBlock.estimateMemory(rows, cols)))
+ + ((long)
MemoryEstimates.doubleArrayCost(cols)*duplicates)
+ + ((long)
MemoryEstimates.objectArrayCost(rows));
}
-}
+}
\ No newline at end of file
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
index 0307fbb03b..d0aea7bce9 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/cp/ParameterizedBuiltinCPInstruction.java
@@ -336,6 +336,8 @@ public class ParameterizedBuiltinCPInstruction extends
ComputationCPInstruction
ec.setMatrixOutput(output.getName(), mbout);
ec.releaseFrameInput(params.get("target"));
ec.releaseFrameInput(params.get("meta"));
+ if(params.get("embedding") != null)
+ ec.releaseMatrixInput(params.get("embedding"));
}
else if(opcode.equalsIgnoreCase("transformdecode")) {
// acquire locks
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
index 3b61b768b0..61e6e799f0 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/spark/ParameterizedBuiltinSPInstruction.java
@@ -32,6 +32,7 @@ import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFlatMapFunction;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.broadcast.Broadcast;
+import org.apache.sysds.api.DMLScript;
import org.apache.sysds.common.Types;
import org.apache.sysds.common.Types.CorrectionLocationType;
import org.apache.sysds.common.Types.FileFormat;
@@ -96,6 +97,7 @@ import
org.apache.sysds.runtime.transform.tokenize.TokenizerFactory;
import org.apache.sysds.runtime.util.DataConverter;
import org.apache.sysds.runtime.util.UtilFunctions;
+import org.apache.sysds.utils.stats.SparkStatistics;
import scala.Tuple2;
public class ParameterizedBuiltinSPInstruction extends
ComputationSPInstruction {
@@ -545,15 +547,22 @@ public class ParameterizedBuiltinSPInstruction extends
ComputationSPInstruction
.createEncoder(params.get("spec"), colnames,
fo.getSchema(), (int) fo.getNumColumns(), meta, embeddings);
encoder.updateAllDCEncoders();
mcOut.setDimension(mcIn.getRows() - ((omap != null) ?
omap.getNumRmRows() : 0), encoder.getNumOutCols());
+
+ long t0 = System.nanoTime();
Broadcast<MultiColumnEncoder> bmeta =
sec.getSparkContext().broadcast(encoder);
Broadcast<TfOffsetMap> bomap = (omap != null) ?
sec.getSparkContext().broadcast(omap) : null;
+ if (DMLScript.STATISTICS) {
+
SparkStatistics.accBroadCastTime(System.nanoTime() - t0);
+ SparkStatistics.incBroadcastCount(1);
+ }
// execute transform apply
JavaPairRDD<MatrixIndexes, MatrixBlock> out;
Tuple2<Boolean, Integer> aligned =
FrameRDDAggregateUtils.checkRowAlignment(in, -1);
// NOTE: currently disabled for LegacyEncoders, because
OMIT probably results in not aligned
// blocks and for IMPUTE was an inaccuracy for the
"testHomesImputeColnamesSparkCSV" test case.
- // Expected: 8.150349617004395 vs actual: 8.15035 at 0
8 (expected is calculated from transform encode,
+
+ // Error in test case: Expected: 8.150349617004395 vs
actual: 8.15035 at 0 8 (expected is calculated from transform encode,
// which currently always uses the else branch: either
inaccuracy must come from serialisation of
// matrixblock or from binaryBlockToBinaryBlock reblock
if(aligned._1 && mcOut.getCols() <= aligned._2 &&
!encoder.hasLegacyEncoder() /*&& containsWE*/) {
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/spark/utils/FrameRDDAggregateUtils.java
b/src/main/java/org/apache/sysds/runtime/instructions/spark/utils/FrameRDDAggregateUtils.java
index e77c2209ea..08b139061b 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/spark/utils/FrameRDDAggregateUtils.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/spark/utils/FrameRDDAggregateUtils.java
@@ -43,8 +43,10 @@ public class FrameRDDAggregateUtils
return in2;
if (in2 == null)
return in1;
- if (!in1._1() || !in2._1())
- return new Tuple5<>(false, null, null, null,
null);
+ if (!in1._1() )
+ return in1;
+ if (!in2._1() )
+ return in2;
//default evaluation
int in1_max = in1._3();
diff --git
a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
index 6e3ad9f8b9..84ff9b7c52 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
@@ -1399,13 +1399,23 @@ public class MatrixBlock extends MatrixValue implements
CacheBlock<MatrixBlock>,
final ExecutorService pool = CommonThreadPool.get(k);
try {
List<Future<Long>> f = new ArrayList<>();
- final int bz = 1000;
- for(int i = 0; i < rlen; i += bz) {
- for(int ii = 0; ii < clen; ii += bz) {
+ if(denseBlock instanceof DenseBlockFP64DEDUP){
+ int bz = (int) Math.ceil(((double)
rlen) / k*2);
+ for(int i = 0; i < rlen; i += bz) {
final int j = i;
- final int jj = ii;
- f.add(pool.submit(() -> //
- recomputeNonZeros(j, Math.min(j
+ bz, rlen) - 1, jj, Math.min(jj + bz, clen) - 1)));
+ f.add(pool.submit(() ->
+
denseBlock.countNonZeros(j, Math.min(j + bz, rlen) -1, 0, clen -1)));
+ }
+ }
+ else {
+ final int bz = 1000;
+ for (int i = 0; i < rlen; i += bz) {
+ for (int ii = 0; ii < clen; ii
+= bz) {
+ final int j = i;
+ final int jj = ii;
+ f.add(pool.submit(() ->
+
recomputeNonZeros(j, Math.min(j + bz, rlen) - 1, jj, Math.min(jj + bz, clen) -
1)));
+ }
}
}
long nnz = 0;
@@ -2722,8 +2732,8 @@ public class MatrixBlock extends MatrixValue implements
CacheBlock<MatrixBlock>,
public long estimateSizeInMemory() {
if (denseBlock instanceof DenseBlockFP64DEDUP) {
- double size = getHeaderSize() + ((DenseBlockFP64DEDUP)
denseBlock).estimateMemory();
- return (long) Math.min(size, Long.MAX_VALUE);
+ long size = getHeaderSize() + ((DenseBlockFP64DEDUP)
denseBlock).estimateMemory();
+ return Math.min(size, Long.MAX_VALUE);
}
return estimateSizeInMemory(rlen, clen, getSparsity());
}
diff --git
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderWordEmbedding.java
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderWordEmbedding.java
index 8d862f8575..65fde02994 100644
---
a/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderWordEmbedding.java
+++
b/src/main/java/org/apache/sysds/runtime/transform/encode/ColumnEncoderWordEmbedding.java
@@ -37,7 +37,7 @@ import java.util.concurrent.ConcurrentHashMap;
public class ColumnEncoderWordEmbedding extends ColumnEncoder {
private MatrixBlock _wordEmbeddings;
private Map<Object, Long> _rcdMap;
- private ConcurrentHashMap<String, double[]> _embMap;
+ private HashMap<String, double[]> _embMap;
public ColumnEncoderWordEmbedding() {
super(-1);
@@ -54,6 +54,10 @@ public class ColumnEncoderWordEmbedding extends
ColumnEncoder {
public int getDomainSize(){
return _wordEmbeddings.getNumColumns();
}
+
+ public int getNrDistinctEmbeddings(){
+ return _wordEmbeddings.getNumRows();
+ }
protected ColumnEncoderWordEmbedding(int colID) {
super(colID);
}
@@ -78,50 +82,18 @@ public class ColumnEncoderWordEmbedding extends
ColumnEncoder {
embedding[i] = this._wordEmbeddings.quickGetValue((int)
r, _colID - 1 + i);
}
return embedding;
-
}
@Override
public void applyDense(CacheBlock<?> in, MatrixBlock out, int
outputCol, int rowStart, int blk){
int rowEnd = getEndIndex(in.getNumRows(), rowStart, blk);
- if(blk == -1){
- HashMap<String, double[]> _embMapSingleThread = new
HashMap<>();
- for(int i=rowStart; i<rowEnd; i++){
- String key = in.getString(i, _colID-1);
- if(key == null || key.isEmpty()) {
- continue;
- }
- double[] embedding =
_embMapSingleThread.get(key);
- if(embedding == null){
- long code = lookupRCDMap(key);
- if(code == -1L){
- continue;
- }
- embedding =
getEmbeddedingFromEmbeddingMatrix(code - 1);
- _embMapSingleThread.put(key, embedding);
- }
- out.quickSetRow(i, embedding);
- }
- }
- else{
- //map each string to the corresponding embedding vector
- for(int i=rowStart; i<rowEnd; i++){
- String key = in.getString(i, _colID-1);
- if(key == null || key.isEmpty()) {
- //codes[i-startInd] = Double.NaN;
- continue;
- }
- double[] embedding = _embMap.get(key);
- if(embedding == null){
- long code = lookupRCDMap(key);
- if(code == -1L){
- continue;
- }
- embedding =
getEmbeddedingFromEmbeddingMatrix(code - 1);
- _embMap.put(key, embedding);
- }
+ for(int i=rowStart; i<rowEnd; i++){
+ String key = in.getString(i, _colID-1);
+ if(key == null || key.isEmpty())
+ continue;
+ double[] embedding = _embMap.get(key);
+ if(embedding != null)
out.quickSetRow(i, embedding);
- }
}
}
@@ -157,7 +129,8 @@ public class ColumnEncoderWordEmbedding extends
ColumnEncoder {
@Override
public void initEmbeddings(MatrixBlock embeddings){
this._wordEmbeddings = embeddings;
- this._embMap = new ConcurrentHashMap<>();
+ this._embMap = new HashMap<>();
+ _rcdMap.forEach((word, index) -> _embMap.put((String) word,
getEmbeddedingFromEmbeddingMatrix(index - 1)));
}
@Override
@@ -182,6 +155,6 @@ public class ColumnEncoderWordEmbedding extends
ColumnEncoder {
_rcdMap.put(key, value);
}
_wordEmbeddings.readExternal(in);
- this._embMap = new ConcurrentHashMap<>();
+ initEmbeddings(_wordEmbeddings);
}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
b/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
index bd9e2ba79f..59c5f2c973 100644
---
a/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
+++
b/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
@@ -50,6 +50,7 @@ import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.compress.estim.ComEstSample;
import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
+import org.apache.sysds.runtime.data.DenseBlockFP64DEDUP;
import org.apache.sysds.runtime.data.SparseBlock;
import org.apache.sysds.runtime.data.SparseBlockCSR;
import org.apache.sysds.runtime.data.SparseRowVector;
@@ -354,12 +355,16 @@ public class MultiColumnEncoder implements Encoder {
boolean hasDC = false;
boolean hasWE = false;
+ int distinctWE = 0;
for(ColumnEncoderComposite columnEncoder : _columnEncoders) {
hasDC |=
columnEncoder.hasEncoder(ColumnEncoderDummycode.class);
- hasWE |=
columnEncoder.hasEncoder(ColumnEncoderWordEmbedding.class);
+ for (ColumnEncoder enc : columnEncoder.getEncoders())
+ if(enc instanceof ColumnEncoderWordEmbedding){
+ hasWE = true;
+ distinctWE =
((ColumnEncoderWordEmbedding) enc).getNrDistinctEmbeddings();
+ }
}
- //hasWE = false;
- outputMatrixPreProcessing(out, in, hasDC, hasWE);
+ outputMatrixPreProcessing(out, in, hasDC, hasWE, distinctWE);
if(k > 1) {
if(!_partitionDone) //happens if this method is
directly called
deriveNumRowPartitions(in, k);
@@ -548,7 +553,7 @@ public class MultiColumnEncoder implements Encoder {
return totMemOverhead;
}
- private static void outputMatrixPreProcessing(MatrixBlock output,
CacheBlock<?> input, boolean hasDC, boolean hasWE) {
+ private static void outputMatrixPreProcessing(MatrixBlock output,
CacheBlock<?> input, boolean hasDC, boolean hasWE, int distinctWE) {
long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0;
if(output.isInSparseFormat()) {
if (MatrixBlock.DEFAULT_SPARSEBLOCK !=
SparseBlock.Type.CSR
@@ -596,6 +601,8 @@ public class MultiColumnEncoder implements Encoder {
else {
// Allocate dense block and set nnz to total #entries
output.allocateDenseBlock(true, hasWE);
+ if( hasWE)
+ ((DenseBlockFP64DEDUP)
output.getDenseBlock()).setDistinct(distinctWE);
//output.setAllNonZeros();
}
@@ -1150,13 +1157,19 @@ public class MultiColumnEncoder implements Encoder {
@Override
public Object call() throws Exception {
boolean hasUDF =
_encoder.getColumnEncoders().stream().anyMatch(e ->
e.hasEncoder(ColumnEncoderUDF.class));
- boolean hasWE =
_encoder.getColumnEncoders().stream().anyMatch(e ->
e.hasEncoder(ColumnEncoderWordEmbedding.class));
+ boolean hasWE = false;
+ int distinctWE = 0;
+ for (ColumnEncoder enc : _encoder.getEncoders())
+ if(enc instanceof ColumnEncoderWordEmbedding){
+ hasWE = true;
+ distinctWE =
((ColumnEncoderWordEmbedding) enc).getNrDistinctEmbeddings();
+ }
int numCols = _encoder.getNumOutCols();
boolean hasDC =
_encoder.getColumnEncoders(ColumnEncoderDummycode.class).size() > 0;
long estNNz = (long) _input.getNumRows() * (hasUDF ?
numCols : _input.getNumColumns());
boolean sparse =
MatrixBlock.evalSparseFormatInMemory(_input.getNumRows(), numCols, estNNz) &&
!hasUDF;
_output.reset(_input.getNumRows(), numCols, sparse,
estNNz);
- outputMatrixPreProcessing(_output, _input, hasDC,
hasWE);
+ outputMatrixPreProcessing(_output, _input, hasDC,
hasWE, distinctWE);
return null;
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeWordEmbedding1Test.java
b/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeWordEmbedding1Test.java
index 25cb95b3a2..4375dcda3d 100644
---
a/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeWordEmbedding1Test.java
+++
b/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeWordEmbedding1Test.java
@@ -21,6 +21,7 @@ package org.apache.sysds.test.functions.transform;
import org.apache.sysds.common.Types.ExecMode;
import org.apache.sysds.lops.Lop;
+import org.apache.sysds.runtime.matrix.data.MatrixValue;
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.test.TestUtils;
@@ -90,8 +91,8 @@ public class TransformFrameEncodeWordEmbedding1Test extends
AutomatedTestBase
}
// Compare results
- //HashMap<MatrixValue.CellIndex, Double> res_actual =
readDMLMatrixFromOutputDir("result");
-
//TestUtils.compareMatrices(TestUtils.convertHashMapToDoubleArray(res_actual),
res_expected, 1e-6);
+ HashMap<MatrixValue.CellIndex, Double> res_actual =
readDMLMatrixFromOutputDir("result");
+
TestUtils.compareMatrices(TestUtils.convertHashMapToDoubleArray(res_actual),
res_expected, 1e-6);
}
catch(Exception ex) {
throw new RuntimeException(ex);
diff --git
a/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeWordEmbedding2Test.java
b/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeWordEmbedding2Test.java
index 9d690be8b1..b994ae83c3 100644
---
a/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeWordEmbedding2Test.java
+++
b/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeWordEmbedding2Test.java
@@ -64,6 +64,11 @@ public class TransformFrameEncodeWordEmbedding2Test extends
AutomatedTestBase
runTransformTest(TEST_NAME1, ExecMode.SPARK);
}
+ @Test
+ public void testTransformToWordEmbeddingsAuto() {
+ runTransformTest(TEST_NAME1, ExecMode.HYBRID);
+ }
+
private void runTransformTest(String testname, ExecMode rt)
{
//set runtime platform
diff --git
a/src/test/scripts/functions/transform/TransformFrameEncodeWordEmbeddings.dml
b/src/test/scripts/functions/transform/TransformFrameEncodeWordEmbeddings.dml
index 227e9311dc..a358b2669b 100644
---
a/src/test/scripts/functions/transform/TransformFrameEncodeWordEmbeddings.dml
+++
b/src/test/scripts/functions/transform/TransformFrameEncodeWordEmbeddings.dml
@@ -28,15 +28,15 @@ Data = read($2, data_type="frame", format="csv");
# Read the recode map for the distinct tokens
Meta = read($3, data_type="frame", format="csv");
-jspec = "{ids: true, recode: [1]}";
+jspec = "{ids: true, dummycode: [1]}";
#[Data_enc2, Meta2] = transformencode(target=Data, spec=jspec);
Data_enc = transformapply(target=Data, spec=jspec, meta=Meta);
-print(nrow(Data_enc) + " x " + ncol(Data_enc))
-print(toString(Data_enc[1,1]))
+#print(nrow(Data_enc) + " x " + ncol(Data_enc))
+#print(toString(Data_enc[1,1]))
# Apply the embeddings on all tokens (1K x 100)
-#R = Data_enc %*% E;
+R = Data_enc %*% E;
-#write(R, $4, format="text");
+write(R, $4, format="text");