This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 516822f957 [SYSTEMDS-3648] Extended ops for dedup-blocks, LSTM CP
instruction
516822f957 is described below
commit 516822f957710fe1be742031d05dbc5f488a473b
Author: e-strauss <[email protected]>
AuthorDate: Wed Nov 8 15:33:41 2023 +0100
[SYSTEMDS-3648] Extended ops for dedup-blocks, LSTM CP instruction
Closes #1994.
---
.../java/org/apache/sysds/hops/FunctionOp.java | 4 +-
.../apache/sysds/hops/ParameterizedBuiltinOp.java | 2 +-
.../sysds/parser/BuiltinFunctionExpression.java | 31 +-
.../sysds/runtime/data/DenseBlockFP64DEDUP.java | 215 ++++++---
.../runtime/instructions/CPInstructionParser.java | 2 +
.../runtime/instructions/InstructionUtils.java | 16 +-
.../runtime/instructions/cp/DnnCPInstruction.java | 134 +++++-
.../sysds/runtime/matrix/data/DnnParameters.java | 42 +-
.../sysds/runtime/matrix/data/LibMatrixAgg.java | 36 +-
.../sysds/runtime/matrix/data/LibMatrixDNN.java | 31 ++
.../runtime/matrix/data/LibMatrixDNNLSTM.java | 497 +++++++++++++++++++++
.../sysds/runtime/matrix/data/LibMatrixMult.java | 26 +-
.../sysds/runtime/matrix/data/MatrixBlock.java | 86 ++--
.../transform/encode/MultiColumnEncoder.java | 17 +-
.../sysds/test/component/compress/io/IOTest.java | 2 +
.../apache/sysds/test/functions/dnn/LSTMTest.java | 172 +++++++
.../test/functions/io/binary/SerializeTest.java | 12 +-
.../TransformFrameEncodeWordEmbedding1Test.java | 19 +-
.../TransformFrameEncodeWordEmbedding2Test.java | 35 +-
.../TransformFrameEncodeWordEmbeddingMMTest.java | 6 +-
...ransformFrameEncodeWordEmbeddingRowSumTest.java | 5 +
.../scripts/functions/tensor/LSTMBackwardTest.dml | 86 ++++
.../scripts/functions/tensor/LSTMForwardTest.dml | 66 +++
...TransformFrameEncodeWordEmbeddings1Reshape.dml} | 17 +-
...TransformFrameEncodeWordEmbeddings2Reshape.dml} | 16 +-
.../TransformFrameEncodeWordEmbeddingsMM.dml | 4 +-
26 files changed, 1412 insertions(+), 167 deletions(-)
diff --git a/src/main/java/org/apache/sysds/hops/FunctionOp.java
b/src/main/java/org/apache/sysds/hops/FunctionOp.java
index 45f21e975b..28cd6eeafb 100644
--- a/src/main/java/org/apache/sysds/hops/FunctionOp.java
+++ b/src/main/java/org/apache/sysds/hops/FunctionOp.java
@@ -349,9 +349,7 @@ public class FunctionOp extends Hop
&&
OptimizerUtils.isSparkExecutionMode())) ? ExecType.SPARK : ExecType.CP);
}
else if(isBuiltinFunction &&
(getFunctionName().equalsIgnoreCase("lstm") ||
getFunctionName().equalsIgnoreCase("lstm_backward"))) {
- if(!DMLScript.USE_ACCELERATOR)
- throw new RuntimeException("The
function " + getFunctionName() + " is only supported on GPU.");
- _etype = ExecType.GPU;
+ _etype = DMLScript.USE_ACCELERATOR ?
ExecType.GPU : ExecType.CP;
}
else if(isBuiltinFunction &&
(getFunctionName().equalsIgnoreCase("batch_norm2d") ||
getFunctionName().equalsIgnoreCase("batch_norm2d_backward"))) {
_etype = DMLScript.USE_ACCELERATOR ?
ExecType.GPU : ExecType.CP;
diff --git a/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java
b/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java
index 964a60d528..a2bd1f188a 100644
--- a/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java
+++ b/src/main/java/org/apache/sysds/hops/ParameterizedBuiltinOp.java
@@ -723,7 +723,7 @@ public class ParameterizedBuiltinOp extends
MultiThreadedHop {
}
super.computeMemEstimate(memo);
}
-
+
@Override
public boolean allowsAllExecTypes() {
return false;
diff --git
a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
index fa1a163036..b5b27682d5 100644
--- a/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
+++ b/src/main/java/org/apache/sysds/parser/BuiltinFunctionExpression.java
@@ -217,6 +217,8 @@ public class BuiltinFunctionExpression extends
DataIdentifier {
case LSTM:
{
+ //TODO: LSTM on GPU has different INPUT/OUTPUT than
LSTM on CPU
+
// X, W, bias, out0, c0, return_sequences
checkNumParameters(6);
checkMatrixParam(getFirstExpr());
@@ -225,8 +227,8 @@ public class BuiltinFunctionExpression extends
DataIdentifier {
checkMatrixParam(getFourthExpr());
checkMatrixParam(getFifthExpr());
- // setup output properties
- if(getOutputs() == null || getOutputs().length != 2) {
+ // setup output properties, on CPU there are 3 more
additionally outputs (cache_out, cache_c, cache_ifog)
+ if(getOutputs() == null || (getOutputs().length != 2 &&
getOutputs().length != 5)) {
int numOutputs = getOutputs() == null ? 0 :
getOutputs().length;
raiseValidateError("The builtin function lstm
has two outputs, but instead found: " + numOutputs, conditional);
}
@@ -244,7 +246,30 @@ public class BuiltinFunctionExpression extends
DataIdentifier {
cy.setValueType(ValueType.FP64);
cy.setDimensions(getExpr(4).getOutput().getDim1(),
getExpr(4).getOutput().getDim2());
cy.setBlocksize(getExpr(4).getOutput().getBlocksize());
-
+
+ if(getOutputs().length == 5){
+ DataIdentifier cache_out = (DataIdentifier)
getOutputs()[2];
+ DataIdentifier cache_c = (DataIdentifier)
getOutputs()[3];
+ DataIdentifier cache_ifog = (DataIdentifier)
getOutputs()[4];
+
+ // Output3 - cache_out: (T,N*M) T is unknown
upfront
+ cache_out.setDataType(DataType.MATRIX);
+ cache_out.setValueType(ValueType.FP64);
+ cache_out.setDimensions(-1, -1);
+
cache_out.setBlocksize(getFirstExpr().getOutput().getBlocksize());
+
+ // Output4 - cache_c: (T,N*M)
+ cache_c.setDataType(DataType.MATRIX);
+ cache_c.setValueType(ValueType.FP64);
+ cache_out.setDimensions(-1, -1);
+
cache_out.setBlocksize(getFirstExpr().getOutput().getBlocksize());
+
+ // Output5 - cache_ifog: (T,N*M)
+ cache_ifog.setDataType(DataType.MATRIX);
+ cache_ifog.setValueType(ValueType.FP64);
+ cache_ifog.setDimensions(-1, -1);
+
cache_ifog.setBlocksize(getFirstExpr().getOutput().getBlocksize());
+ }
break;
}
case LSTM_BACKWARD:
diff --git
a/src/main/java/org/apache/sysds/runtime/data/DenseBlockFP64DEDUP.java
b/src/main/java/org/apache/sysds/runtime/data/DenseBlockFP64DEDUP.java
index d1b0c8a91b..af16ec3d85 100644
--- a/src/main/java/org/apache/sysds/runtime/data/DenseBlockFP64DEDUP.java
+++ b/src/main/java/org/apache/sysds/runtime/data/DenseBlockFP64DEDUP.java
@@ -24,7 +24,6 @@ import org.apache.sysds.common.Types;
import org.apache.sysds.runtime.util.UtilFunctions;
import org.apache.sysds.utils.MemoryEstimates;
-import java.util.Arrays;
import java.util.HashMap;
public class DenseBlockFP64DEDUP extends DenseBlockDRB
@@ -33,10 +32,19 @@ public class DenseBlockFP64DEDUP extends DenseBlockDRB
private double[][] _data;
//TODO: implement estimator for nr of distinct
private int _distinct = 0;
+ private int _emb_size = 0;
+ private int _embPerRow = 0;
public void setDistinct(int d){
_distinct = d;
}
+
+ public void setEmbeddingSize(int s){
+ _emb_size = s;
+ if (_odims[0] % _emb_size != 0)
+ throw new RuntimeException("[Error] DedupDenseBlock:
ncols[=" + _odims[0] + "] % emb_size[=" + _emb_size + "] != 0");
+ _embPerRow = _odims[0] / _emb_size;
+ }
protected DenseBlockFP64DEDUP(int[] dims) {
super(dims);
reset(_rlen, _odims, 0);
@@ -46,29 +54,31 @@ public class DenseBlockFP64DEDUP extends DenseBlockDRB
return _distinct;
}
+ public int getNrEmbsPerRow(){
+ return _embPerRow;
+ }
+
+ public int getEmbSize(){
+ return _emb_size;
+ }
+
@Override
protected void allocateBlock(int bix, int length) {
- _data[bix] = new double[length];
+ _data = new double[length][];
}
@Override
public void reset(int rlen, int[] odims, double v) {
if(rlen > _rlen)
- _data = new double[rlen][];
+ allocateBlock(0,rlen);
else{
if(_data == null)
- _data = new double[rlen][];
- if(v == 0.0){
+ allocateBlock(0,rlen);
+ if(v == 0.0)
for(int i = 0; i < rlen; i++)
_data[i] = null;
- }
- else {
- for(int i = 0; i < rlen; i++) {
- if(odims[0] > _odims[0] ||_data[i] ==
null )
- allocateBlock(i, odims[0]);
- Arrays.fill(_data[i], 0, odims[0], v);
- }
- }
+ else
+ throw new NotImplementedException("Reset of
DedupBlock with constant value is supported");
}
_rlen = rlen;
_odims = odims;
@@ -76,13 +86,19 @@ public class DenseBlockFP64DEDUP extends DenseBlockDRB
@Override
public void resetNoFill(int rlen, int[] odims) {
- if(_data == null || rlen > _rlen){
+ if(_data == null || rlen > _rlen)
_data = new double[rlen][];
- }
_rlen = rlen;
_odims = odims;
}
+ public void resetNoFillDedup(int rlen, int embsPerRow) {
+ if(_data == null || rlen > _rlen)
+ _data = new double[rlen*embsPerRow][];
+ _embPerRow = embsPerRow;
+ _rlen = rlen;
+ }
+
@Override
public boolean isNumeric() {
return true;
@@ -95,7 +111,7 @@ public class DenseBlockFP64DEDUP extends DenseBlockDRB
@Override
public long capacity() {
- return (_data != null) ? ((long) _data.length)*_odims[0] : -1;
+ return (_data != null) ? _data.length : -1;
}
@Override
@@ -143,7 +159,7 @@ public class DenseBlockFP64DEDUP extends DenseBlockDRB
@Override
protected long computeNnz(int bix, int start, int length) {
int nnz = 0;
- int row_start = (int) Math.floor(start / _odims[0]);
+ int row_start = (int) Math.floor(((double) start) / _odims[0]);
int col_start = start % _odims[0];
for (int i = 0; i < length; i++) {
if(_data[row_start] == null){
@@ -163,59 +179,91 @@ public class DenseBlockFP64DEDUP extends DenseBlockDRB
}
@Override
- public int pos(int r){
- return 0;
+ public int numBlocks() {
+ int blocksize = blockSize();
+ if(blocksize < _rlen){
+ int numBlocks = _rlen / blocksize;
+ if (_rlen % blocksize > 0)
+ numBlocks += 1;
+ return numBlocks;
+ }
+ else
+ return 1;
}
@Override
- public int pos(int r, int c){
- return c;
+ public int blockSize() {
+ int blocksize = Integer.MAX_VALUE / _odims[0];
+ return Math.min(blocksize, _rlen);
}
@Override
- public int pos(int[] ix){
- int pos = ix[ix.length - 1];
- for(int i = 1; i < ix.length - 1; i++)
- pos += ix[i] * _odims[i];
- return pos;
+ public int blockSize(int bix) {
+ int blocksize = blockSize();
+ return Math.min(blocksize, _rlen - bix * blocksize);
}
@Override
- public int blockSize(int bix) {
- return 1;
+ public boolean isContiguous() {
+ return numBlocks() == 1;
}
@Override
- public boolean isContiguous() {
- return false;
+ public boolean isContiguous(int rl, int ru) {
+ return index(rl) == index(ru);
}
+
@Override
- public boolean isContiguous(int rl, int ru){
- return rl == ru;
+ public int pos(int r) {
+ return (r % blockSize()) * _odims[0];
}
+
@Override
- public double[] values(int r) {
- return valuesAt(r);
+ public int pos(int r, int c) {
+ return (r % blockSize()) * _odims[0] + c;
}
@Override
- public double[] valuesAt(int bix) {
- return _data[bix] == null ? new double[_odims[0]] : _data[bix];
+ public int pos(int[] ix){
+ int pos = pos(ix[0]);
+ pos += ix[ix.length - 1];
+ for(int i = 1; i < ix.length - 1; i++)
+ pos += ix[i] * _odims[i];
+ return pos;
}
@Override
- public int index(int r) {
- return r;
+ public double[] values(int r) {
+ return valuesAt(index(r));
+ }
+
+ @Override
+ public double[] valuesAt(int bix) {
+ int blocksize = blockSize(bix);
+ int blocksizeOther = blockSize();
+ double[] out = new double[_odims[0]*blocksize];
+ if(_data != null) {
+ for (int i = 0; i < blocksize; i++) {
+ for (int j = 0; j < _embPerRow; j++) {
+ int posInDedup = i * _embPerRow + j;
+ int posInDense = posInDedup * _emb_size;
+ posInDedup +=
bix*blocksizeOther*_embPerRow;
+ if(_data[posInDedup] != null)
+
System.arraycopy(_data[posInDedup], 0, out, posInDense, _emb_size);
+ }
+ }
+ }
+ return out;
}
@Override
- public int numBlocks(){
- return _data.length;
+ public int index(int r) {
+ return r / blockSize();
}
@Override
public int size(int bix) {
- return _odims[0];
+ return blockSize(bix) * _odims[0];
}
@Override
@@ -223,18 +271,50 @@ public class DenseBlockFP64DEDUP extends DenseBlockDRB
incr(r,c,1.0);
}
+ public void createDeepCopyOfEmbedding(int pos){
+ if(_data[pos] == null)
+ _data[pos] = new double[_emb_size];
+ else {
+ double[] tmp = new double[_emb_size];
+ System.arraycopy(_data[pos], 0, tmp, 0, _emb_size);
+ _data[pos] = tmp;
+ }
+ }
+
@Override
public void incr(int r, int c, double delta) {
- if(_data[r] == null)
- allocateBlock(r, _odims[0]);
- _data[r][c] += delta;
+ int roffset = c / _emb_size;
+ int coffset = c % _emb_size;
+
+ //creates a deep copy to avoid unexpected changes in other rows
due deduplication
+ createDeepCopyOfEmbedding(r*_embPerRow + roffset);
+ _data[r*_embPerRow + roffset][coffset] += delta;
}
@Override
protected void fillBlock(int bix, int fromIndex, int toIndex, double v)
{
- if(_data[bix] == null)
- allocateBlock(bix, _odims[0]);
- Arrays.fill(_data[bix], fromIndex, toIndex, v);
+ int roffset = fromIndex / _emb_size;
+ int coffset = fromIndex % _emb_size;
+ int r2offset = fromIndex / _emb_size;
+ int c2offset = fromIndex % _emb_size;
+ int blockoffset = bix*blockSize();
+
+ int c = coffset;
+ int cmax = _emb_size;
+ int rmax = r2offset;
+
+ if(c2offset != 0)
+ rmax += 1;
+ for (int r = roffset; r < rmax; r++) {
+ //creates a deep copy to avoid unexpected changes in
other rows due deduplication
+ createDeepCopyOfEmbedding(blockoffset + roffset);
+ if(r == r2offset)
+ cmax = c2offset;
+ for(; c < cmax; c++){
+ _data[blockoffset + r][c] = v;
+ }
+ c = 0;
+ }
}
@Override
@@ -244,18 +324,30 @@ public class DenseBlockFP64DEDUP extends DenseBlockDRB
@Override
public DenseBlock set(int r, int c, double v) {
- if(_data[r] == null)
- _data[r] = new double[_odims[0]];
- _data[r][c] = v;
+ int roffset = c / _emb_size;
+ int coffset = c % _emb_size;
+
+ //creates a deep copy to avoid unexpected changes in other rows
due deduplication
+ createDeepCopyOfEmbedding(r*_embPerRow + roffset);
+ _data[r*_embPerRow + roffset][coffset] = v;
return this;
}
@Override
public DenseBlock set(int r, double[] v) {
- if(v.length == _odims[0])
+ if(_embPerRow == 1)
_data[r] = v;
else
- throw new RuntimeException("set Denseblock called with
an array length [" + v.length +"], array to overwrite is of length [" +
_odims[0] + "]");
+ for (int i = 0; i < _embPerRow; i++) {
+ //creates a deep copy to avoid unexpected
changes in other rows due deduplication
+ createDeepCopyOfEmbedding(r*_embPerRow + i);
+ System.arraycopy(v, i*_emb_size,
_data[r*_embPerRow + i],0, _emb_size);
+ }
+ return this;
+ }
+
+ public DenseBlock setDedupDirectly(int r, double[] v) {
+ _data[r] = v;
return this;
}
@@ -265,6 +357,7 @@ public class DenseBlockFP64DEDUP extends DenseBlockDRB
}
@Override
+ //todo
public DenseBlock set(int rl, int ru, int ol, int ou, DenseBlock db) {
if( !(db instanceof DenseBlockFP64DEDUP))
throw new NotImplementedException();
@@ -298,12 +391,20 @@ public class DenseBlockFP64DEDUP extends DenseBlockDRB
return set(ix[0], pos(ix), Double.parseDouble(v));
}
+ public double[] getDedupDirectly(int pos){
+ return _data[pos];
+ }
+
@Override
public double get(int r, int c) {
- if(_data[r] == null)
+ if(_embPerRow == 1)
+ return _data[r][c];
+ int roffset = c / _emb_size;
+ int coffset = c % _emb_size;
+ if(_data[r*_embPerRow + roffset] == null)
return 0.0;
else
- return _data[r][c];
+ return _data[r*_embPerRow + roffset][coffset];
}
@Override
@@ -322,18 +423,16 @@ public class DenseBlockFP64DEDUP extends DenseBlockDRB
}
public long estimateMemory(){
- if( (double)_rlen * _odims[0] > Long.MAX_VALUE )
- return Long.MAX_VALUE;
return estimateMemory(_rlen, _odims[0], _distinct);
}
public static long estimateMemory(int rows, int cols, int duplicates){
- return estimateMemory((long) rows, (long)cols, (long)
duplicates);
+ return estimateMemory((long) rows, (long) cols, (long)
duplicates);
}
public static long estimateMemory(long rows, long cols, long
duplicates){
return ((long) (DenseBlock.estimateMemory(rows, cols)))
- + ((long)
MemoryEstimates.doubleArrayCost(cols)*duplicates)
+ + ((long)
MemoryEstimates.doubleArrayCost(cols)*duplicates)
+ ((long)
MemoryEstimates.objectArrayCost(rows));
}
}
\ No newline at end of file
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
index c73d755b5e..52f6257bf9 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/CPInstructionParser.java
@@ -289,6 +289,8 @@ public class CPInstructionParser extends InstructionParser {
String2CPInstructionType.put( "bias_multiply" ,
CPType.Dnn);
String2CPInstructionType.put( "batch_norm2d",
CPType.Dnn);
String2CPInstructionType.put( "batch_norm2d_backward",
CPType.Dnn);
+ String2CPInstructionType.put( "lstm" , CPType.Dnn);
+ String2CPInstructionType.put( "lstm_backward" ,
CPType.Dnn);
// Quaternary instruction opcodes
String2CPInstructionType.put( "wsloss" , CPType.Quaternary);
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
index 906348eea7..9c54aa282d 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java
@@ -616,8 +616,14 @@ public class InstructionUtils {
parseExtendedBinaryOperator(opcode));
}
-
- public static BinaryOperator parseBinaryOperator(String opcode)
+
+ public static BinaryOperator parseBinaryOperator(String opcode, int k){
+ BinaryOperator bop = parseBinaryOperator(opcode);
+ bop.setNumThreads(k);
+ return bop;
+ }
+
+ public static BinaryOperator parseBinaryOperator(String opcode)
{
if(opcode.equalsIgnoreCase("=="))
return new BinaryOperator(Equals.getEqualsFnObject());
@@ -919,6 +925,12 @@ public class InstructionUtils {
throw new DMLRuntimeException("Unknown binary opcode " +
opcode);
}
+
+ public static ScalarOperator parseScalarBinaryOperator(String opcode,
boolean arg1IsScalar, double constant, int k){
+ ScalarOperator sop = parseScalarBinaryOperator(opcode,
arg1IsScalar, constant);
+ sop.setNumThreads(k);
+ return sop;
+ }
public static String deriveAggregateOperatorOpcode(String opcode) {
switch( opcode ) {
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/DnnCPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/DnnCPInstruction.java
index 2b7429834c..9dc55078c9 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/cp/DnnCPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/cp/DnnCPInstruction.java
@@ -50,6 +50,9 @@ public class DnnCPInstruction extends UnaryCPInstruction {
private final CPOperand _in6;
private final CPOperand _in7;
private final CPOperand _in8;
+ private final CPOperand _in9;
+ private final CPOperand _in10;
+ private final CPOperand _in11;
private final CPOperand _out2;
private final CPOperand _out3;
private final CPOperand _out4;
@@ -68,6 +71,7 @@ public class DnnCPInstruction extends UnaryCPInstruction {
_in2 = in2;
_in3 = in3;
_in4 = null; _in5 = null; _in6 = null; _in7 = null; _in8 = null;
+ _in9 = null; _in10 = null; _in11 = null;
_out2 = null; _out3 = null; _out4 = null; _out5 = null;
_stride = stride;
_padding = padding;
@@ -103,9 +107,13 @@ public class DnnCPInstruction extends UnaryCPInstruction {
ArrayList<CPOperand> filter_shape, int numThreads,
double intermediateMemoryBudget) {
this(in, in2, in3, out, stride, padding, input_shape,
filter_shape, numThreads, intermediateMemoryBudget, opcode, istr);
}
+
+ public DnnCPInstruction(CPOperand in1, CPOperand in2, CPOperand in3,
CPOperand in4, CPOperand in5, CPOperand in6, CPOperand in7, CPOperand in8,
CPOperand out1, CPOperand out2, CPOperand out3, CPOperand out4, CPOperand out5,
String opcode, String str, int i) {
+ this(in1, in2, in3, in4, in5, in6, in7, in8,null, null, null,
out1, out2, out3, out4, out5, opcode, str, 0);
+ }
public DnnCPInstruction(CPOperand in1, CPOperand in2, CPOperand in3,
CPOperand in4, CPOperand in5,
- CPOperand in6, CPOperand in7, CPOperand in8,
+ CPOperand in6, CPOperand in7, CPOperand in8, CPOperand
in9, CPOperand in10, CPOperand in11,
CPOperand out, CPOperand out2, CPOperand out3,
CPOperand out4, CPOperand out5, String opcode, String istr,
double intermediateMemoryBudget) throws
DMLRuntimeException {
super(CPType.Dnn, null, in1, out, opcode, istr);
@@ -116,6 +124,9 @@ public class DnnCPInstruction extends UnaryCPInstruction {
_in6 = in6;
_in7 = in7;
_in8 = in8;
+ _in9 = in9;
+ _in10 = in10;
+ _in11 = in11;
_out2 = out2;
_out3 = out3;
_out4 = out4;
@@ -265,6 +276,41 @@ public class DnnCPInstruction extends UnaryCPInstruction {
CPOperand out3 = new CPOperand(parts[9]); // dBias
return new DnnCPInstruction(in1, in2, in3, in4, in5,
in6, null, null, out, out2, out3, null, null, opcode, str, 0);
}
+ else if (opcode.equalsIgnoreCase("lstm")) {
+ InstructionUtils.checkNumFields(parts, 11);
+ CPOperand in1 = new CPOperand(parts[1]);
+ CPOperand in2 = new CPOperand(parts[2]);
+ CPOperand in3 = new CPOperand(parts[3]);
+ CPOperand in4 = new CPOperand(parts[4]);
+ CPOperand in5 = new CPOperand(parts[5]);
+ CPOperand in6 = new CPOperand(parts[6]);
+ CPOperand out1 = new CPOperand(parts[7]);
+ CPOperand out2 = new CPOperand(parts[8]);
+ CPOperand out3 = new CPOperand(parts[9]);
+ CPOperand out4 = new CPOperand(parts[10]);
+ CPOperand out5 = new CPOperand(parts[11]);
+ return new DnnCPInstruction(in1, in2, in3, in4, in5,
in6, null, null, out1, out2, out3, out4, out5, opcode, str, 0);
+ } if(opcode.equalsIgnoreCase("lstm_backward")){
+ InstructionUtils.checkNumFields(parts, 16);
+ CPOperand in1 = new CPOperand(parts[1]);
+ CPOperand in2 = new CPOperand(parts[2]);
+ CPOperand in3 = new CPOperand(parts[3]);
+ CPOperand in4 = new CPOperand(parts[4]);
+ CPOperand in5 = new CPOperand(parts[5]);
+ CPOperand in6 = new CPOperand(parts[6]);
+ CPOperand in7 = new CPOperand(parts[7]);
+ CPOperand in8 = new CPOperand(parts[8]);
+ CPOperand in9 = new CPOperand(parts[9]);
+ CPOperand in10 = new CPOperand(parts[10]);
+ CPOperand in11 = new CPOperand(parts[11]);
+
+ CPOperand out1 = new CPOperand(parts[12]);
+ CPOperand out2 = new CPOperand(parts[13]);
+ CPOperand out3 = new CPOperand(parts[14]);
+ CPOperand out4 = new CPOperand(parts[15]);
+ CPOperand out5 = new CPOperand(parts[16]);
+ return new DnnCPInstruction(in1, in2, in3, in4, in5,
in6, in7, in8, in9, in10, in11, out1, out2, out3, out4, out5, opcode, str, 0);
+ }
else {
throw new DMLRuntimeException("Unknown opcode while
parsing a DnnCPInstruction: " + str);
}
@@ -408,7 +454,81 @@ public class DnnCPInstruction extends UnaryCPInstruction {
filter.sparseToDense();
return filter.isInSparseFormat();
}
-
+
+ private void processLSTMInstruction(ExecutionContext ec, boolean
backward) {
+ // batchSize=N, seqLength=T, numFeatures=D and hiddenSize=M
+ // input X:(N, T*D), ==> (T, D, N)
+ // weight W:(D+M+2, 4M)
+ // previous output input3 (also represented by hx)
+ // and cell state input4 (also represented by cx): (N, M) ==>
(1, M, N)
+ MatrixBlock X = ec.getMatrixInput(input1.getName());
+ MatrixBlock W = ec.getMatrixInput(_in2.getName());
+ MatrixBlock bias = ec.getMatrixInput(_in3.getName());
+ MatrixBlock out0 = ec.getMatrixInput(_in4.getName());
+ MatrixBlock c0 = ec.getMatrixInput(_in5.getName());
+ boolean return_sequences =
ec.getScalarInput(_in6).getBooleanValue();
+
+ MatrixBlock dout = null, dc = null, cache_out = null, cache_c =
null, cache_ifog = null;
+ if(backward){
+ dout = ec.getMatrixInput(_in7.getName());
+ dc = ec.getMatrixInput(_in8.getName());
+ cache_out = ec.getMatrixInput(_in9.getName());
+ cache_c = ec.getMatrixInput(_in10.getName());
+ cache_ifog = ec.getMatrixInput(_in11.getName());
+ }
+
+ //Check input dimensions
+ int M = out0.getNumColumns(); // hiddenSize .. since input3:
(N, M)
+ int N = out0.getNumRows();
+ int numRowsW = W.getNumRows();
+ int numColsW = W.getNumColumns();
+ int D = numRowsW - M; // since W:(D+M, 4M) ... numFeatures
+ int T = X.getNumColumns() / D;
+ if(c0.getNumColumns() != out0.getNumColumns() ||
out0.getNumRows() != c0.getNumRows()){
+ throw new DMLRuntimeException("Incorrect input
dimension for LSTM. Expected input4 and input3 Matrix to be of "+
+ "the same Dimension (N, M), but got
("+c0.getNumRows()+", " +c0.getNumColumns()+") and ("+
+ out0.getNumRows()+",
"+out0.getNumColumns()+")");
+ }
+ if(W.getNumColumns() != 4*M){
+ throw new DMLRuntimeException("Incorrect input
dimension for LSTM. Expected Weight Matrix to be of "+
+ "Dimension (D+M, 4M) = ("+numRowsW+",
"+4*M+"), but got ("+numRowsW+", "+numColsW+")");
+ }
+ if(bias.getNumColumns() != 4*M || bias.getNumRows() != 1){
+ throw new DMLRuntimeException("Incorrect input
dimension for LSTM. Expected bias Matrix to be of "+
+ "Dimension (1, 4M) = (1, "+4*M+"), but
got ("+bias.getNumRows()+", "+bias.getNumColumns()+")");
+ }
+
+ //prepare output matrices
+ // out = backward / forward
+ // -------------------------
+ // out1 = dX / out
+ // out2 = dW / c
+ // out3 = db / cache_out
+ // out4 = dout0 / cache_c
+ // out5 = dc0 / cache_ifog
+ MatrixBlock out1 = new MatrixBlock(N, backward ? T*D :
return_sequences ? T*M : M,false);
+ MatrixBlock out2 = new MatrixBlock(backward ? D + M : N,
backward ? 4*M : M,false);
+ MatrixBlock out3 = new MatrixBlock(backward ? 1 : T, backward ?
4*M : N*M,false);
+ MatrixBlock out4 = new MatrixBlock(backward ? N : T, backward ?
M : N*M,false);
+ MatrixBlock out5 = new MatrixBlock(backward ? N : T, backward
? M : N*4*M,false);
+
+ //
+ DnnParameters params = new DnnParameters(N,D,T,M,X, W, bias,
out0, c0, cache_out, cache_c, cache_ifog, return_sequences, dout, dc,out1,
out2, out3, out4, out5, _numThreads);
+ if(backward)
+ LibMatrixDNN.lstmBackward(params);
+ else
+ LibMatrixDNN.lstm(params);
+
+ // release inputs/outputs
+ ec.releaseMatrixInput(input1.getName(), _in2.getName(),
_in3.getName(), _in4.getName(), _in5.getName());
+ if(backward)
+ ec.releaseMatrixInput(_in7.getName(), _in8.getName(),
_in9.getName(), _in10.getName(), _in11.getName());
+ ec.setMatrixOutput(output.getName(), out1);
+ ec.setMatrixOutput(_out2.getName(), out2);
+ ec.setMatrixOutput(_out3.getName(), out3);
+ ec.setMatrixOutput(_out4.getName(), out4);
+ ec.setMatrixOutput(_out5.getName(), out5);
+ }
@Override
public void processInstruction(ExecutionContext ec) {
@@ -433,6 +553,14 @@ public class DnnCPInstruction extends UnaryCPInstruction {
processBatchNorm2dBackwardInstruction(ec);
return;
}
+ else if (instOpcode.equalsIgnoreCase("lstm")) {
+ processLSTMInstruction(ec, false);
+ return;
+ }
+ else if (instOpcode.equalsIgnoreCase("lstm_backward")) {
+ processLSTMInstruction(ec, true);
+ return;
+ }
// acquire inputs
MatrixBlock outputBlock = null;
@@ -581,7 +709,7 @@ public class DnnCPInstruction extends UnaryCPInstruction {
ec.releaseMatrixInput(input1.getName());
ec.setMatrixOutput(getOutputVariableName(), outputBlock);
}
-
+
/**
* Reset the number of thread to respect the intermediate CP memory
budget
*
diff --git
a/src/main/java/org/apache/sysds/runtime/matrix/data/DnnParameters.java
b/src/main/java/org/apache/sysds/runtime/matrix/data/DnnParameters.java
index cee6c0d2c1..006e500329 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/DnnParameters.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/DnnParameters.java
@@ -27,20 +27,22 @@ import org.apache.sysds.runtime.util.DnnUtils;
/**
* This class is container that stores parameters required for executing
following operations:
- * conv2d, conv2d_backward_data, conv2d_backward_filter, maxpooling,
maxpooling_backward
+ * conv2d, conv2d_backward_data, conv2d_backward_filter, maxpooling,
maxpooling_backward, lstm, lstm_backward
*/
public class DnnParameters implements Serializable
{
private static final long serialVersionUID = -212362627205772829L;
- public int N, C, H, W, K, R, S, P, Q;
+ public int N, C, H, W, K, R, S, P, Q, D, T, M;
public int stride_h, stride_w, pad_h, pad_w;
public int numThreads;
// Optional variables used by ConvolutionCPInstruction
public boolean enableNative = false;
-
+ public boolean return_sequences;
+
public MatrixBlock input1; public MatrixBlock input2; public
MatrixBlock output;
+ public MatrixBlock input3, input4, input5, input6, input7, input8,
input9, output2, output3, output4, output5;
public MatrixBlock bias;
public int [] start_indexes_h, end_indexes_h, start_indexes_w,
end_indexes_w;
@@ -97,7 +99,39 @@ public class DnnParameters implements Serializable
Q = (int) DnnUtils.getQ(W, S, stride_w, pad_w);
this.numThreads = numThreads;
}
-
+
+ public DnnParameters(int N, int D, int T, int M, MatrixBlock x,
MatrixBlock w, MatrixBlock bias, MatrixBlock out0,
+ MatrixBlock c0, boolean
return_sequences, int numThreads){
+ this.N = N;
+ this.D = D;
+ this.T = T;
+ this.M = M;
+
+ this.input1 = x;
+ this.input2 = w;
+ this.bias = bias;
+ this.input3 = out0;
+ this.input4 = c0;
+
+ this.return_sequences = return_sequences;
+ this.numThreads = numThreads;
+ }
+
+ public DnnParameters(int n, int d, int t, int m, MatrixBlock x,
MatrixBlock w, MatrixBlock bias, MatrixBlock out0, MatrixBlock c0, MatrixBlock
cache_out, MatrixBlock cache_c, MatrixBlock cache_ifog, boolean
return_sequences, MatrixBlock dout, MatrixBlock dc, MatrixBlock dx, MatrixBlock
dw, MatrixBlock db, MatrixBlock dout0, MatrixBlock dc0, int numThreads) {
+ this(n, d, t, m, x, w, bias, out0, c0, return_sequences,
numThreads);
+ this.input5 = dout;
+ this.input6 = dc;
+ this.input7 = cache_out;
+ this.input8 = cache_c;
+ this.input9 = cache_ifog;
+ this.output = dx;
+ this.output2 = dw;
+ this.output3 = db;
+ this.output4 = dout0;
+ this.output5 = dc0;
+ }
+
+
private static int convertToInt(long val) {
if( val > Integer.MAX_VALUE )
throw new DMLRuntimeException("The value for
DnnParameters is too large:" + val);
diff --git
a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixAgg.java
b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixAgg.java
index 64a7bf3aa9..2efdc2a0f7 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixAgg.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixAgg.java
@@ -1888,7 +1888,7 @@ public class LibMatrixAgg {
*/
private static void d_uakp( DenseBlock a, DenseBlock c, int n,
KahanObject kbuff, KahanPlus kplus, int rl, int ru ) {
if(a instanceof DenseBlockFP64DEDUP)
- uakpDedup(a, c, n, kbuff, kplus, rl, ru);
+ uakpDedup((DenseBlockFP64DEDUP) a, c, n, kbuff, kplus,
rl, ru);
else {
final int bil = a.index(rl);
final int biu = a.index(ru - 1);
@@ -1929,7 +1929,7 @@ public class LibMatrixAgg {
private static void d_uarkp( DenseBlock a, DenseBlock c, int n,
KahanObject kbuff, KahanPlus kplus, int rl, int ru )
{
if(a instanceof DenseBlockFP64DEDUP)
- uarkpDedup(a, c, n, kbuff, kplus, rl, ru);
+ uarkpDedup((DenseBlockFP64DEDUP) a, c, n, kbuff, kplus,
rl, ru);
else {
for (int i = rl; i < ru; i++) {
kbuff.set(0, 0); //reset buffer
@@ -1963,7 +1963,7 @@ public class LibMatrixAgg {
*/
private static void d_uackp( DenseBlock a, DenseBlock c, int n,
KahanObject kbuff, KahanPlus kplus, int rl, int ru ) {
if(a instanceof DenseBlockFP64DEDUP)
- uackpDedup(a, c, n, kbuff, kplus, rl, ru);
+ uackpDedup((DenseBlockFP64DEDUP) a, c, n, kbuff, kplus,
rl, ru);
else {
for( int i=rl; i<ru; i++ )
sumAgg( a.values(i), c, a.pos(i), n, kbuff,
kplus );
@@ -3628,10 +3628,15 @@ public class LibMatrixAgg {
/////////////////////////////////////////////////////
- private static void uakpDedup (DenseBlock a, DenseBlock c, int n,
KahanObject kbuff, KahanPlus kplus, int rl, int ru) {
+ private static void uakpDedup (DenseBlockFP64DEDUP a, DenseBlock c, int
n, KahanObject kbuff, KahanPlus kplus, int rl, int ru) {
HashMap<double[], Integer> counts = new HashMap<>();
+ if(a.getNrEmbsPerRow() != 1){
+ //TODO: currently impossible case, since Dedup reshape
is not supported yet, once it is, this method needs
+ // to be implemented
+ throw new NotImplementedException("Check TODO");
+ }
for(int i = rl; i < ru; i++) {
- double[] row = a.values(i);
+ double[] row = a.getDedupDirectly(i);
Integer count = counts.getOrDefault(row, 0);
count += 1;
counts.put(row, count);
@@ -3643,14 +3648,18 @@ public class LibMatrixAgg {
});
}
- private static void uarkpDedup( DenseBlock a, DenseBlock c, int n,
KahanObject kbuff, KahanPlus kplus, int rl, int ru ) {
+ private static void uarkpDedup( DenseBlockFP64DEDUP a, DenseBlock c,
int n, KahanObject kbuff, KahanPlus kplus, int rl, int ru ) {
HashMap<double[], double[]> cache = new HashMap<>();
+ if(a.getNrEmbsPerRow() != 1){
+ //TODO: currently impossible case, since Dedup reshape
is not supported yet, once it is, this method needs
+ // to be implemented
+ throw new NotImplementedException("Check TODO");
+ }
for(int i = rl; i < ru; i++) {
- double[] row = a.values(i);
- int finalI = i;
+ double[] row = a.getDedupDirectly(i);
double[] kbuff_array = cache.computeIfAbsent(row,
lambda_row -> {
kbuff.set(0, 0);
- sum(lambda_row, a.pos(finalI), n, kbuff, kplus);
+ sum(lambda_row, 0, n, kbuff, kplus);
return new double[] {kbuff._sum,
kbuff._correction};
});
cache.putIfAbsent(row, kbuff_array);
@@ -3659,10 +3668,15 @@ public class LibMatrixAgg {
}
}
- private static void uackpDedup( DenseBlock a, DenseBlock c, int n,
KahanObject kbuff, KahanPlus kplus, int rl, int ru ) {
+ private static void uackpDedup( DenseBlockFP64DEDUP a, DenseBlock c,
int n, KahanObject kbuff, KahanPlus kplus, int rl, int ru ) {
HashMap<double[], Integer> counts = new HashMap<>();
+ if(a.getNrEmbsPerRow() != 1){
+ //TODO: currently impossible case, since Dedup reshape
is not supported yet, once it is, this method needs
+ // to be implemented
+ throw new NotImplementedException("Check TODO");
+ }
for(int i = rl; i < ru; i++) {
- double[] row = a.values(i);
+ double[] row = a.getDedupDirectly(i);
Integer count = counts.getOrDefault(row, 0);
count += 1;
counts.put(row, count);
diff --git
a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixDNN.java
b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixDNN.java
index 26a00425a0..5bf133e93c 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixDNN.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixDNN.java
@@ -651,6 +651,37 @@ public class LibMatrixDNN {
params.output.setNonZeros(0);
}
}
+
+ public static void lstm(DnnParameters params){
+ long nnz;
+ if(LibMatrixDNNLSTM.checkLSTMInputForOptimisation(params)){
+ params.output.allocateDenseBlock();
+ params.output2.allocateDenseBlock();
+ params.output3.allocateDenseBlock();
+ params.output4.allocateDenseBlock();
+ params.output5.allocateDenseBlock();
+ nnz = execute(LibMatrixDNNLSTM.getLSTMWorkers(params),
params);
+ }
+ else
+ nnz = LibMatrixDNNLSTM.lstmGeneric(params);
+ //post-processing: maintain nnz
+ params.output.setNonZeros(nnz);
+ params.output.examSparsity();
+ }
+
+ public static void lstmBackward(DnnParameters params) {
+ long nnz;
+
if(LibMatrixDNNLSTM.checkLSTMBackwardInputForOptimisation(params)){
+ //out.allocateDenseBlock();
+ //cout.allocateDenseBlock();
+ nnz = execute(LibMatrixDNNLSTM.getLSTMWorkers(params),
params);
+ }
+ else
+ nnz = LibMatrixDNNLSTM.lstmBackwardGeneric(params);
+ //post-processing: maintain nnz
+ params.output.setNonZeros(nnz);
+ params.output.examSparsity();
+ }
/**
* Executes the tasks in parallel using java's ExecutorService.
diff --git
a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixDNNLSTM.java
b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixDNNLSTM.java
new file mode 100644
index 0000000000..04e1ec445d
--- /dev/null
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixDNNLSTM.java
@@ -0,0 +1,497 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sysds.runtime.matrix.data;
+
+import org.apache.commons.math3.util.FastMath;
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.functionobjects.KahanPlus;
+import org.apache.sysds.runtime.functionobjects.SwapIndex;
+import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.KahanObject;
+import org.apache.sysds.runtime.matrix.operators.AggregateBinaryOperator;
+import org.apache.sysds.runtime.matrix.operators.AggregateUnaryOperator;
+import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
+import org.apache.sysds.runtime.matrix.operators.ReorgOperator;
+import org.apache.sysds.runtime.matrix.operators.UnaryOperator;
+import org.apache.sysds.runtime.matrix.operators.ScalarOperator;
+import org.apache.sysds.runtime.util.UtilFunctions;
+
+import java.util.ArrayList;
+import java.util.concurrent.Callable;
+
+import static
org.apache.sysds.runtime.functionobjects.KahanPlus.getKahanPlusFnObject;
+import static org.apache.sysds.runtime.instructions.InstructionUtils.*;
+
+public class LibMatrixDNNLSTM {
+ private static final int row_tile_size = 4;
+ private static final boolean kahan = false;
+ private static final boolean optimized = true;
+ public static ArrayList<Callable<Long>> getLSTMWorkers(DnnParameters
params) {
+ ArrayList<Callable<Long>> ret = new ArrayList<>();
+ int k =
OptimizerUtils.getConstrainedNumThreads(params.numThreads);
+ int taskSize = (int) (Math.ceil((double) params.N / k));
+
+ //very small input => use less threads
+ if(taskSize < row_tile_size && (params.D+params.M)*params.T <
256*25)
+ taskSize = row_tile_size;
+ for(int i = 0; i*taskSize < params.N; i++)
+ ret.add(new LSTMExecutor(i*taskSize,
Math.min((i+1)*taskSize, params.N),params));
+ return ret;
+ }
+
+ public static void lstmTile(int n, int d, int T, int m, int start, int
end, MatrixBlock x, MatrixBlock w,
+ MatrixBlock bias,
MatrixBlock out0, MatrixBlock c0, boolean return_sequences,
+ MatrixBlock out,
MatrixBlock cout, MatrixBlock cache_out, MatrixBlock cache_c, MatrixBlock
cache_ifog){
+
+ //inputs arrays
+ double[] c_0_values = c0.getDenseBlockValues();
+ double[] bias_values = bias.getDenseBlockValues();
+ double[] out0_values = out0.getDenseBlockValues();
+ double[] w_values = w.getDenseBlockValues();
+ double[] x_values = x.getDenseBlockValues();
+
+ double[] out_values = out.getDenseBlockValues();
+ double[] cout_values = cout.getDenseBlockValues();
+ double[] cache_out_values = cache_out.getDenseBlockValues();
+ double[] cache_c_values = cache_c.getDenseBlockValues();
+ double[] cache_ifog_values = cache_ifog.getDenseBlockValues();
+
+ int c_prev_pointer;
+
+ //constants
+ final boolean biasAllocated = bias.isAllocated();
+ final boolean xAllocated = x.isAllocated();
+ final boolean wAllocated = w.isAllocated();
+ final int tile_size_i = row_tile_size;
+ final int tile_size_j = 32;
+ final int tile_size_k = 1024;
+ final int m_4 = 4*m;
+ final int m_T = T*m;
+
+ int[] pos_in_x = new int[tile_size_i];
+ int pos_in_sequence;
+ double[] ifog = new double[tile_size_i*4*m];
+
+ KahanObject kbuff[] = kahan ? new KahanObject[tile_size_i*4*m]
: null;
+ if(kahan)
+ for (int i = 0; i < tile_size_i*4*m; i++)
+ kbuff[i] = new KahanObject(0,0);
+ KahanPlus kplus = kahan ? getKahanPlusFnObject() : null;
+
+ double[] out_prev_values = null;
+ double[] c_prev_values = null;
+
+ for( int bi = start; bi < end; bi+=tile_size_i ) {
+ int bimin = Math.min(end, bi + tile_size_i);
+
+ //init out_prev
+ if (out0_values != null) {
+ if (out_prev_values == null)
+ out_prev_values = new double[m *
tile_size_i];
+ for (int i = bi, i_internal = 0; i < bimin;
i++, i_internal++) {
+ c_prev_pointer = i * m;
+ for (int j = 0; j < m; j++)
+ out_prev_values[j + i_internal
* m] = out0_values[c_prev_pointer + j];
+ }
+ } else
+ out_prev_values = new double[m * tile_size_i];
+
+ //init c_prev
+ if (c_0_values != null) {
+ if (c_prev_values == null)
+ c_prev_values = new double[m *
tile_size_i];
+ for (int i = bi, i_internal = 0; i < bimin;
i++, i_internal++) {
+ c_prev_pointer = i * m;
+ for (int j = 0; j < m; j++)
+ c_prev_values[j + i_internal *
m] = c_0_values[c_prev_pointer + j];
+ }
+ } else
+ c_prev_values = new double[m * tile_size_i];
+
+ //calculate position of input token sequence for all
rows in tile
+ for (int i = bi, i_internal = 0; i < bimin; i++,
i_internal++) {
+ pos_in_x[i_internal] = i * x.getNumColumns();
+ }
+ //iterate timesteps
+ for (int t = 0; t < T; t++) {
+ pos_in_sequence = t * d;
+ int offset_t_internal = t*m;
+ int offset_t = offset_t_internal*n;
+ int offset_t2 = offset_t*4;
+ //init ifog with bias values
+ for (int j = 0; j < 4 * m; j++) {
+ //for all rows in the row tile
+ for (int i = bi, i_internal = 0; i <
bimin; i++, i_internal++) {
+ if(kahan)
+ kbuff[j + i_internal *
m_4].set(biasAllocated ? bias_values[j] : 0.0, 0.0);
+ else
+ ifog[j + i_internal *
m_4] = biasAllocated ? bias_values[j] : 0.0;
+ }
+ }
+
+ //iterate input token tiles
+ if(xAllocated)
+ for (int bj = 0; bj < d; bj +=
tile_size_j)
+ //iterate weight tiles
+ if(wAllocated)
+ for (int bk = 0, bjmin
= Math.min(d, bj + tile_size_j); bk < m_4; bk += tile_size_k) {
+ int bkmin =
Math.min(m_4, bk + tile_size_k);
+
+ //core loop:
adds the input token to the ifog-gates
+ for (int i =
bi, i_internal = 0; i < bimin; i++, i_internal++) {
+ int
pos_internal_ifog_i = i_internal * m_4;
+ int pos
= pos_in_x[i_internal] + pos_in_sequence;
+ for
(int j = bj; j < bjmin; j++) {
+
int offset_w = j * 4 * m;
+
int offset_x = pos + j;
+
for (int k = bk; k < bkmin; k++) {
+
if (kahan)
+
kplus.execute2(kbuff[pos_internal_ifog_i + k],
x_values[offset_x] * w_values[k + offset_w]);
+
else
+
ifog[pos_internal_ifog_i + k] += x_values[offset_x] *
w_values[k + offset_w];
+
}
+ }
+ }
+ }
+ //iterate hidden state tiles
+ for (int bj = 0; bj < m; bj += tile_size_j)
+ //iterate weight tiles
+ if(wAllocated)
+ for (int bk = 0, bjmin =
Math.min(m, bj + tile_size_j); bk < 4 * m; bk += tile_size_k) {
+ int bkmin = Math.min(4
* m, bk + tile_size_k);
+
+ //core loop: adds the
hidden state to the ifog-gates
+ for (int i = bi,
i_internal = 0; i < bimin; i++, i_internal++) {
+ int
offset_out_prev = i_internal * m;
+ int
offset_internal = offset_out_prev*4;
+ for (int j =
bj; j < bjmin; j++){
+ int
offset_tmp = (j + d) * m_4;
+ for
(int k = bk; k < bkmin; k++){
+
int offset_w = k + offset_tmp;
+
if(kahan)
+
kplus.execute2(kbuff[offset_internal + k],
out_prev_values[offset_out_prev + j] * w_values[offset_w]);
+
else
+
ifog[offset_internal + k] += out_prev_values[offset_out_prev + j] *
w_values[offset_w];
+ }
+ }
+ }
+ }
+
+ //calculate new hidden state for the current
tile
+ for (int i = bi, i_internal = 0; i < bimin;
i++, i_internal++) {
+ //from now on only elementwise
operations
+
+ //calculate index offset for array
operations
+ int offset_internal_i = i_internal * 4
* m;
+ int offset_internal_f =
offset_internal_i + m;
+ int offset_internal_o =
offset_internal_f + m;
+ int offset_internal_g =
offset_internal_o + m;
+ int offset_c_internal = i_internal * m;
+ int offset_out = i*m_T +
offset_t_internal;
+
+ int offset_i = i*m;
+ int offset_cache = offset_t + offset_i;
+ int offset_cache_i = offset_t2 +
offset_i*4;
+ int offset_cache_f = offset_cache_i + m;
+ int offset_cache_o = offset_cache_f + m;
+ int offset_cache_g = offset_cache_o + m;
+
+ for (int j = 0; j < m; j++) {
+ double ig, fg, og,gg;
+ if(kahan){
+ ig = 1.0 /
(FastMath.exp(-kbuff[offset_internal_i + j]._sum) + 1.0);
+ fg = 1.0 /
(FastMath.exp(-kbuff[offset_internal_f + j]._sum) + 1.0);
+ og = 1.0 /
(FastMath.exp(-kbuff[offset_internal_o + j]._sum) + 1.0);
+ gg =
FastMath.tanh(kbuff[offset_internal_g + j]._sum);
+ } else{
+ ig = 1.0 /
(FastMath.exp(-ifog[offset_internal_i + j]) + 1.0);
+ fg = 1.0 /
(FastMath.exp(-ifog[offset_internal_f + j]) + 1.0);
+ og = 1.0 /
(FastMath.exp(-ifog[offset_internal_o + j]) + 1.0);
+ gg =
FastMath.tanh(ifog[offset_internal_g + j]);
+ }
+ //c_prev_values.shape = (N,M)
+ double c =
c_prev_values[offset_c_internal + j] * fg + ig * gg;
+ double o = FastMath.tanh(c) *
og;
+
+ //out.shape = (N,T*M)
+ if (return_sequences)
+ out_values[offset_out +
j] = o;
+ //out.setValue(i, t * m
+ j, o);
+
+ //set caches
+ cache_out_values[offset_cache +
j] = o;
+ cache_c_values[offset_cache +
j] = c;
+
cache_ifog_values[offset_cache_i + j] = ig;
+
cache_ifog_values[offset_cache_f + j] = fg;
+
cache_ifog_values[offset_cache_o + j] = og;
+
cache_ifog_values[offset_cache_g + j] = gg;
+
+ c_prev_values[offset_c_internal
+ j] = c;
+
out_prev_values[offset_c_internal + j] = o;
+
+ }
+ }
+ }
+ for (int i = bi, i_internal = 0; i < bimin; i++,
i_internal++) {
+ int offset_i = i*m;
+ for (int j = 0; j < m; j++) {
+ cout_values[offset_i + j] =
c_prev_values[i_internal * m + j];
+ if (!return_sequences)
+ out_values[offset_i + j] =
out_prev_values[i_internal * m + j];
+ }
+ }
+ }
+ }
+
+
+ public static long lstmGeneric(DnnParameters params) {
+ //applies the LSTM operation on the input matrices using the
generic matrix block operations
+
+ MatrixBlock x = params.input1, w = params.input2, bias =
params.bias;
+ MatrixBlock out = params.input3, c = params.input4;
+ MatrixBlock cache_out = params.output3, cache_c =
params.output4, cache_ifog = params.output5;
+
+ int k =
OptimizerUtils.getConstrainedNumThreads(params.numThreads);
+ int M = params.M;
+
+ //init Operators
+ BinaryOperator plus =
InstructionUtils.parseBinaryOperator("+",k);
+ BinaryOperator emult =
InstructionUtils.parseBinaryOperator("*",k);
+ UnaryOperator tanh =
InstructionUtils.parseUnaryOperator("tanh",k);
+ UnaryOperator sigmoid =
InstructionUtils.parseUnaryOperator("sigmoid",k);
+ AggregateBinaryOperator mmult =
InstructionUtils.getMatMultOperator(k);
+
+ //iterate time steps
+ for (int t = 0; t < params.T; t++) {
+ //Extract the current input vector
+ MatrixBlock x_t = x.slice(0, x.rlen - 1, t*params.D ,
(t+1)*params.D - 1);
+
+ // Compute input, forget, output, and g gates
+ // ifog = input %*% W + b
+ MatrixBlock ifog = x_t.append(out, true);
+ ifog = ifog.aggregateBinaryOperations(ifog, w, mmult);
+ ifog = ifog.binaryOperations(plus, bias);
+
+ // Apply sigmoid to i, f, o gates and tanh to g gate
+ MatrixBlock ifo = ifog.slice(0, ifog.rlen - 1, 0, 3*M -
1).unaryOperations(sigmoid);
+ MatrixBlock i = ifo.slice(0, ifog.rlen - 1, 0, M - 1);
+ MatrixBlock f = ifo.slice(0, ifog.rlen - 1, M, 2*M - 1);
+ MatrixBlock o = ifo.slice(0, ifog.rlen - 1, 2*M, 3*M -
1);
+ MatrixBlock g = ifog.slice(0, ifog.rlen - 1, 3*M, 4*M -
1).unaryOperations(tanh);
+
+ // Update cell state
+ // c = ifog[,M+1:2*M]*c_prev +
ifog[,1:M]*ifog[,3*M+1:4*M] # shape (N, M)
+ MatrixBlock tmp = i.binaryOperations(emult, g);
+ c = f.binaryOperations(emult, c).binaryOperations(plus,
tmp, t == params.T-1 ? params.output2 : null);
+
+ // Compute output
+ // out_t = ifog[,2*M+1:3*M] * tanh::forward(c) # shape
(N, M)
+ tmp = c.unaryOperations(tanh);
+ if(params.return_sequences){
+ out = o.binaryOperations(emult, tmp);
+ params.output.leftIndexingOperations(out, 0,
out.rlen - 1, t*M,(t + 1)*M - 1,
+ null,
MatrixObject.UpdateType.INPLACE );
+ }
+ else
+ out = o.binaryOperations(emult, tmp, t ==
params.T-1 ? params.output : null);
+
+ //store caches
+ ifog = ifo.append(g, true);
+ MatrixBlock cache_out_t = LibMatrixReorg.reshape(out,
new MatrixBlock(), 1, cache_out.clen, true);
+ cache_out.leftIndexingOperations(cache_out_t, t, t,0,
cache_out.clen - 1, null, MatrixObject.UpdateType.INPLACE );
+
+ MatrixBlock cache_c_t = LibMatrixReorg.reshape(c, new
MatrixBlock(), 1, cache_c.clen, true);
+ cache_c.leftIndexingOperations(cache_c_t, t, t,0,
cache_c.clen - 1, null, MatrixObject.UpdateType.INPLACE );
+
+ MatrixBlock cache_ifog_t = LibMatrixReorg.reshape(ifog,
new MatrixBlock(), 1, cache_ifog.clen, true);
+ cache_ifog.leftIndexingOperations(cache_ifog_t, t,
t,0,cache_ifog.clen - 1, null, MatrixObject.UpdateType.INPLACE );
+ }
+ return params.output.recomputeNonZeros();
+ }
+
+ @SuppressWarnings("unused")
+ public static long lstmBackwardGeneric(DnnParameters params) {
+ //TODO elias: currently we apply operator each on the whole
batch,
+ // -> slice the batch into small parts -> each thread processes
one part through all timesteps (maybe even slice
+ // the batch into smaller section) -> this should help keep the
data local -> cache friendly
+
+ //inputs
+ MatrixBlock x = params.input1, w = params.input2, bias =
params.bias;
+ MatrixBlock out0 = params.input3, c0 = params.input4, dout =
params.input5, dc = params.input6;
+ MatrixBlock cache_out = params.input7, cache_c = params.input8,
cache_ifog = params.input9;
+
+ //outputs
+ MatrixBlock dX = params.output, dW = null, db = null;
+
+ int k =
OptimizerUtils.getConstrainedNumThreads(params.numThreads);
+ int M = params.M;
+
+ //init Operators
+ BinaryOperator plus = parseBinaryOperator("+",k);
+ BinaryOperator emult = parseBinaryOperator("*",k);
+ ScalarOperator exp2 = parseScalarBinaryOperator("^2",false,
0.0, k);
+ ScalarOperator minus = parseScalarBinaryOperator("-",true, 1.0,
k);
+ UnaryOperator tanh = parseUnaryOperator("tanh", k);
+ UnaryOperator sprop = parseUnaryOperator("sprop", k);
+ AggregateUnaryOperator colsum =
parseBasicAggregateUnaryOperator("uack+",k);
+ ReorgOperator transpose = new
ReorgOperator(SwapIndex.getSwapIndexFnObject(), k);
+ AggregateBinaryOperator mmult =
InstructionUtils.getMatMultOperator(k);
+
+ //if(!params.return_sequences): get the predecessing partial
derivative
+ //else: load the
predecessing partial derivative for timestep t in the for loop
+ MatrixBlock dout_prev = params.return_sequences ? null : dout;
+
+ //precompute t(W)
+ //Note elias: optionally calculated it multiple times in for
loop
+ w = w.reorgOperations(transpose, new MatrixBlock(), 0, 0, 0);
+
+ //iterate time steps reversely (backpropagation)
+ for(int t = params.T - 1; t >= 0; t--){
+ //get the predecessing partial derivative
+ if(params.return_sequences)
+ if(t == params.T-1)
+ dout_prev = dout.slice(0, dout.rlen-1,
t*M, (t+1)*M - 1);
+ else
+ dout_prev = dout.slice(0, dout.rlen-1,
t*M, (t+1)*M - 1).binaryOperations(plus, dout_prev);
+
+ //load and reuse cached results from forward pass for
the current time step
+ MatrixBlock c_t =
LibMatrixReorg.reshape(cache_c.slice(t, t, 0, cache_c.clen - 1), new
MatrixBlock(), params.N, M, true);
+ MatrixBlock c_prev = t==0 ? c0 :
LibMatrixReorg.reshape(cache_c.slice(t - 1, t - 1, 0, cache_c.clen - 1), new
MatrixBlock(), params.N, M, true);
+ MatrixBlock ifog =
LibMatrixReorg.reshape(cache_ifog.slice(t, t,0, cache_ifog.clen - 1), new
MatrixBlock(), params.N, 4*M, true);
+ MatrixBlock i = ifog.slice(0, ifog.rlen - 1, 0, M -1);
+ MatrixBlock f = ifog.slice(0, ifog.rlen - 1, M, 2*M -1);
+ MatrixBlock o = ifog.slice(0, ifog.rlen - 1, 2*M, 3*M
-1);
+ MatrixBlock g = ifog.slice(0, ifog.rlen - 1, 3*M,
ifog.clen -1);
+
+ //dct = dct + o*tanh::backward(dout_t, ct) # shape (N,
M)
+ MatrixBlock tanh_forward = c_t.unaryOperations(tanh);
+ MatrixBlock tanh_back =
tanh_forward.scalarOperations(exp2, new MatrixBlock())
+ .scalarOperations(minus, new
MatrixBlock());
+ tanh_back = tanh_back.binaryOperations(emult,
dout_prev);
+ MatrixBlock tmp = o.binaryOperations(emult, tanh_back);
+ dc = dc.binaryOperations(plus, tmp);
+
+ //do = tanh::forward(ct) * dout_t # output gate, shape
(N, M)
+ MatrixBlock d_o = tanh_forward.binaryOperations(emult,
dout_prev);
+
+ //df = c_prev * dct # forget gate, shape (N, M)
+ MatrixBlock d_f = c_prev.binaryOperations(emult, dc);
+
+ //di = g * dct # input gate, shape (N, M)
+ MatrixBlock d_i = g.binaryOperations(emult, dc);
+
+ //dg = i * dct # g gate, shape (N, M)
+ MatrixBlock d_g = i.binaryOperations(emult, dc);
+
+ //di_raw = i * (1-i) * di
+ //df_raw = f * (1-f) * df
+ //do_raw = o * (1-o) * do
+ //dg_raw = (1-g^2) * dg
+ //difog_raw = cbind(di_raw, df_raw, do_raw, dg_raw) #
shape (N, 4M)
+ MatrixBlock difog_raw = new MatrixBlock(params.N, 4*M,
false);
+ MatrixBlock di_raw = i.unaryOperations(sprop, new
MatrixBlock()).binaryOperations(emult, d_i);
+ difog_raw.leftIndexingOperations(di_raw,0,
difog_raw.rlen - 1, 0, M-1, null,
+ MatrixObject.UpdateType.INPLACE);
+ MatrixBlock df_raw = f.unaryOperations(sprop, new
MatrixBlock()).binaryOperations(emult, d_f);
+ difog_raw.leftIndexingOperations(df_raw,0,
difog_raw.rlen - 1, M, 2*M-1, null,
+ MatrixObject.UpdateType.INPLACE);
+ MatrixBlock do_raw = o.unaryOperations(sprop, new
MatrixBlock()).binaryOperations(emult, d_o);
+ difog_raw.leftIndexingOperations(do_raw,0,
difog_raw.rlen - 1, 2*M, 3*M-1, null,
+ MatrixObject.UpdateType.INPLACE);
+ MatrixBlock dg_raw = g.scalarOperations(exp2, new
MatrixBlock()).scalarOperations(minus, new
MatrixBlock()).binaryOperations(emult, d_g);
+ difog_raw.leftIndexingOperations(dg_raw,0,
difog_raw.rlen - 1, 3*M, 4*M-1, null,
+ MatrixObject.UpdateType.INPLACE);
+
+ //load the current input vector and in the cached
previous hidden state
+ MatrixBlock x_t = x.slice(0, x.rlen - 1, t*params.D ,
(t+1)*params.D - 1);
+ MatrixBlock out_prev = t==0 ? out0 :
LibMatrixReorg.reshape(cache_out.slice(t - 1, t - 1, 0, cache_out.clen - 1),
new MatrixBlock(), params.N, M, true);
+
+ //merge mm for dx and dout_prev: input = cbind(X_t,
out_prev) # shape (N, D+M)
+ MatrixBlock in_t = x_t.append(out_prev,
true).reorgOperations(transpose, new MatrixBlock(), 0, 0, 0);
+
+ //dW = dW + t(input) %*% difog_raw # shape (D+M, 4M)
+ tmp = in_t.aggregateBinaryOperations(in_t, difog_raw,
params.T == 1 ? params.output2 : null,mmult);
+ dW = (t==params.T-1) ? tmp : dW.binaryOperations(plus,
tmp, (t == 0) ? params.output2 : null);
+
+ //db = db + colSums(difog_raw) # shape (1, 4M)
+ tmp = difog_raw.aggregateUnaryOperations(colsum,
params.T == 1 ? params.output3 : null, difog_raw.rlen, new MatrixIndexes(1,1),
true);
+ db = (t==params.T-1) ? tmp : db.binaryOperations(plus,
tmp,(t == 0) ? params.output3 : null);
+
+ //dinput = difog_raw %*% t(W) # shape (N, D+M)
+ MatrixBlock dinput =
difog_raw.aggregateBinaryOperations(difog_raw, w, mmult);
+
+ //dX[,(t-1)*D+1:t*D] = dinput[,1:D]
+ dX.leftIndexingOperations(dinput.slice(0, dinput.rlen -
1, 0, params.D-1),0, dX.rlen - 1, t*params.D, (t+1)*params.D - 1, null,
MatrixObject.UpdateType.INPLACE);
+
+ //dout_prev = dinput[,D+1:D+M] # shape (N, M)
+ //if(t == 0) -> dout0 = dout_prev
+ dout_prev = dinput.slice(0, dinput.rlen - 1, params.D,
dinput.clen - 1, (t == 0) ? params.output4 : null);
+
+ //dc_prev = f * dct # shape (N, M)
+ //if(t == 0) -> dc0 = dc_prev
+ dc = f.binaryOperations(emult, dc, (t == 0) ?
params.output5 : null);
+ }
+
+ return params.output.recomputeNonZeros();
+ }
+
+ public static boolean checkLSTMInputForOptimisation(DnnParameters
params) {
+ //optimised just for FP64 single block or Empty:
+// System.out.println(!params.input1.isAllocated() + " | " +
!params.input1.sparse + " | " + (params.input1.denseBlock.numBlocks() == 1));
+// System.out.println(!params.input2.isAllocated() + " | " +
!params.input2.sparse + " | " + (params.input2.denseBlock.numBlocks() == 1));
+// System.out.println(!params.bias.isAllocated() + " | " +
!params.bias.sparse + " | " + (params.bias.denseBlock.numBlocks() == 1));
+// System.out.println(!params.input4.isAllocated() + " | " +
!params.input4.sparse + " | " + (params.input4.denseBlock.numBlocks() == 1));
+// System.out.println(!params.input3.isAllocated() + " | " +
!params.input3.sparse + " | " + (params.input3.denseBlock.numBlocks() == 1));
+// System.out.println(optimized);
+
+ //largest output size if cache_ifog (T, N*M)
+ boolean fits_FP64 = (UtilFunctions.prod(new
int[]{params.T,params.N,params.M}) < Integer.MAX_VALUE);
+
+ return (!params.input1.isAllocated() || (!params.input1.sparse
&& params.input1.denseBlock.numBlocks() == 1))
+ && (!params.input2.isAllocated() ||
(!params.input2.sparse && params.input2.denseBlock.numBlocks() == 1))
+ && (!params.bias.isAllocated() ||
(!params.bias.sparse && params.bias.denseBlock.numBlocks() == 1))
+ && (!params.input4.isAllocated() ||
(!params.input4.sparse && params.input4.denseBlock.numBlocks() == 1))
+ && (!params.input3.isAllocated() ||
(!params.input3.sparse && params.input3.denseBlock.numBlocks() == 1))
+ && fits_FP64
+ && optimized;
+ }
+
+ public static boolean
checkLSTMBackwardInputForOptimisation(DnnParameters params) {
+ return false;
+ }
+
+ private static class LSTMExecutor implements Callable<Long> {
+ protected final int _rl, _ru;
+ protected final DnnParameters _params;
+
+ public LSTMExecutor(int rl, int ru, DnnParameters params) {
+ _rl = rl;
+ _ru = ru;
+ _params = params;
+ }
+
+ @Override
+ public Long call() throws Exception {
+ lstmTile(_params.N, _params.D, _params.T, _params.M,
_rl, _ru, _params.input1, _params.input2, _params.bias, _params.input3,
_params.input4, _params.return_sequences, _params.output, _params.output2,
_params.output3, _params.output4, _params.output5);
+ //multithreaded nnz maintenance of current working set
+ return _params.output.recomputeNonZeros(_rl, _ru - 1);
+ }
+ }
+}
diff --git
a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java
b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java
index 96c75cecdc..780afdad67 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/LibMatrixMult.java
@@ -29,6 +29,7 @@ import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.stream.IntStream;
+import org.apache.commons.lang3.NotImplementedException;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.commons.math3.util.FastMath;
@@ -208,8 +209,17 @@ public class LibMatrixMult
((SparseBlockMCSR)
ret.getSparseBlock()).setNnzEstimatePerRow(m2.clen, m2.clen);
}
- if(m1.denseBlock instanceof DenseBlockFP64DEDUP)
+ if(m1.denseBlock instanceof DenseBlockFP64DEDUP){
+ DenseBlockFP64DEDUP tmp = (DenseBlockFP64DEDUP)
m1.denseBlock;
+ if(tmp.getNrEmbsPerRow() != 1){
+ //TODO: currently impossible case, since Dedup
reshape is not supported yet, once it is, this method needs
+ // to be implemented
+ throw new NotImplementedException("Check TODO");
+ }
ret.allocateDenseBlock(true, true);
+ tmp = (DenseBlockFP64DEDUP) ret.denseBlock;
+ tmp.setEmbeddingSize(ret.clen);
+ }
else
ret.allocateBlock();
@@ -1164,11 +1174,16 @@ public class LibMatrixMult
}
}
- public static void matrixMultDenseDenseMMDedup(DenseBlock a, DenseBlock
b, DenseBlock c, int n, int cd, int rl, int ru, ConcurrentHashMap<double[],
double[]> cache) {
+ public static void matrixMultDenseDenseMMDedup(DenseBlockFP64DEDUP a,
DenseBlock b, DenseBlockFP64DEDUP c, int n, int cd, int rl, int ru,
ConcurrentHashMap<double[], double[]> cache) {
//n = m2.clen;
//cd = m1.clen;
+ if(a.getNrEmbsPerRow() != 1){
+ //TODO: currently impossible case, since Dedup reshape
is not supported yet, once it is, this method needs
+ // to be implemented
+ throw new NotImplementedException("Check TODO");
+ }
for (int i = rl; i < ru; i++) {
- double[] a_row = a.values(i);
+ double[] a_row = a.getDedupDirectly(i);
double[] c_row = cache.getOrDefault(a_row, null);
if (c_row == null) {
c_row = new double[n];
@@ -1180,10 +1195,9 @@ public class LibMatrixMult
c_row[j] += a_row[k] *
b_column[b.pos(k, j)];
}
}
- //the following requires
cache.put(a_row, c_row);
}
- c.set(i, c_row);
+ c.setDedupDirectly(i, c_row);
}
}
@@ -4688,7 +4702,7 @@ public class LibMatrixMult
matrixMultUltraSparse(_m1, _m2, _ret, _m1Perm,
rl, ru);
else if(!_m1.sparse && !_m2.sparse)
if(_m1.denseBlock instanceof
DenseBlockFP64DEDUP && _m2.denseBlock.isContiguous(0,_m1.clen) && cl == 0 && cu
== _m2.clen)
-
matrixMultDenseDenseMMDedup(_m1.denseBlock, _m2.denseBlock, _ret.denseBlock,
_m2.clen, _m1.clen, rl, ru, _cache);
+
matrixMultDenseDenseMMDedup((DenseBlockFP64DEDUP) _m1.denseBlock,
_m2.denseBlock, (DenseBlockFP64DEDUP) _ret.denseBlock, _m2.clen, _m1.clen, rl,
ru, _cache);
else
matrixMultDenseDense(_m1, _m2, _ret,
_tm2, _pm2r, rl, ru, cl, cu);
diff --git
a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
index 84ff9b7c52..86ee70bc18 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
@@ -440,7 +440,23 @@ public class MatrixBlock extends MatrixValue implements
CacheBlock<MatrixBlock>,
denseBlock = DenseBlockFactory.createDenseBlock(rlen,
clen, containsDuplicates);
return true;
}
- else if( containsDuplicates && !(denseBlock instanceof
DenseBlockFP64DEDUP)) {
+ else if(denseBlock instanceof DenseBlockFP64DEDUP){
+ if( containsDuplicates ){
+ //capacity() of DedupDenseBlock returns size of
internal pointer array
+ //therefore: allocation of DedupDenseBlock
makes just sense if each row contains a single deduplicated embedding
+ //otherwise info about the nr of embeddings
need to be known upfront
+ //then the cond becomes: if(
denseBlock.capacity() < rlen*nr_of_embeddings_per_row )
+ if( denseBlock.capacity() < rlen )
+ denseBlock.reset(rlen, clen);
+ else
+ return false;
+ } else
+ denseBlock =
DenseBlockFactory.createDenseBlock(rlen, clen, false);
+ return true;
+ }
+ else if( containsDuplicates ) {
+ //info: currently dedup allocation assumes, that each
row contains a single embedding
+ //therefore clen == embedding_size
denseBlock = DenseBlockFactory.createDenseBlock(rlen,
clen, true);
return true;
}
@@ -589,7 +605,7 @@ public class MatrixBlock extends MatrixValue implements
CacheBlock<MatrixBlock>,
/**
* Get if this MatrixBlock is an empty block. The call can potentially
tricker a recomputation of non zeros if the
* non-zero count is unknown.
- *
+ *
* @param safe True if we want to ensure the count non zeros if the nnz
is unknown.
* @return If the block is empty.
*/
@@ -725,7 +741,7 @@ public class MatrixBlock extends MatrixValue implements
CacheBlock<MatrixBlock>,
throw new NotImplementedException();
else{
//allocate and init dense block (w/o overwriting nnz)
- allocateDenseBlock(false);
+ allocateDenseBlock(false,denseBlock instanceof
DenseBlockFP64DEDUP);
nonZeros += UtilFunctions.computeNnz(values, 0,
values.length) - denseBlock.countNonZeros(r);
denseBlock.set(r, values);
}
@@ -764,14 +780,14 @@ public class MatrixBlock extends MatrixValue implements
CacheBlock<MatrixBlock>,
}
public List<Integer> containsVector(MatrixBlock pattern, boolean
earlyAbort) {
- //note: in contract to containsValue, we return the row index
where a match
+ //note: in contract to containsValue, we return the row index
where a match
//was found in order to reuse these block operations for Spark
ops as well
-
+
//basic error handling
if( clen != pattern.clen || pattern.rlen != 1 )
throw new DMLRuntimeException("contains only supports
pattern row vectors of matching "
+ "number of columns: " +
getDataCharacteristics()+" vs "+pattern.getDataCharacteristics());
-
+
//make a pass over the data to determine if it includes the
//pattern, with early abort as soon as the pattern is found
double[] dpattern =
DataConverter.convertToDoubleVector(pattern, false, false);
@@ -779,7 +795,7 @@ public class MatrixBlock extends MatrixValue implements
CacheBlock<MatrixBlock>,
getSparseBlock().contains(dpattern, earlyAbort) :
getDenseBlock().contains(dpattern, earlyAbort);
}
-
+
/**
* <p>Append value is only used when values are appended at the end of
each row for the sparse representation</p>
*
@@ -1231,16 +1247,16 @@ public class MatrixBlock extends MatrixValue implements
CacheBlock<MatrixBlock>,
public final void examSparsity() {
examSparsity(true, 1);
}
-
+
/**
* Evaluates if this matrix block should be in sparse format in
* memory. Depending on the current representation, the state of the
- * matrix block is changed to the right representation if necessary.
- * Note that this consumes for the time of execution memory for both
+ * matrix block is changed to the right representation if necessary.
+ * Note that this consumes for the time of execution memory for both
* representations.
- *
+ *
* Allowing CSR format is default for this operation.
- *
+ *
* @param k parallelization degree
*/
public final void examSparsity(int k ) {
@@ -1263,10 +1279,10 @@ public class MatrixBlock extends MatrixValue implements
CacheBlock<MatrixBlock>,
/**
* Evaluates if this matrix block should be in sparse format in
* memory. Depending on the current representation, the state of the
- * matrix block is changed to the right representation if necessary.
- * Note that this consumes for the time of execution memory for both
+ * matrix block is changed to the right representation if necessary.
+ * Note that this consumes for the time of execution memory for both
* representations.
- *
+ *
* @param allowCSR allow CSR format on dense to sparse conversion
* @param k parallelization degree
*/
@@ -1380,7 +1396,7 @@ public class MatrixBlock extends MatrixValue implements
CacheBlock<MatrixBlock>,
nonZeros = denseBlock.countNonZeros();
else // both blocks not allocated.
nonZeros = 0;
-
+
return nonZeros;
}
@@ -1403,7 +1419,7 @@ public class MatrixBlock extends MatrixValue implements
CacheBlock<MatrixBlock>,
int bz = (int) Math.ceil(((double)
rlen) / k*2);
for(int i = 0; i < rlen; i += bz) {
final int j = i;
- f.add(pool.submit(() ->
+ f.add(pool.submit(() ->
denseBlock.countNonZeros(j, Math.min(j + bz, rlen) -1, 0, clen -1)));
}
}
@@ -1434,7 +1450,7 @@ public class MatrixBlock extends MatrixValue implements
CacheBlock<MatrixBlock>,
}
else{
nonZeros = 0;
- }
+ }
return nonZeros;
}
@@ -2129,18 +2145,21 @@ public class MatrixBlock extends MatrixValue implements
CacheBlock<MatrixBlock>,
private void readDedupDenseBlock(DataInput in) throws IOException,
DMLRuntimeException {
allocateDenseBlock(true,true);
- DenseBlock a = getDenseBlock();
+ DenseBlockFP64DEDUP a = (DenseBlockFP64DEDUP) getDenseBlock();
+ int embPerRow = in.readInt();
+ int embSize= in.readInt();
+ a.setEmbeddingSize(embSize);
if(a.getDim(0) != rlen || a.getDim(1) != clen)
- a.resetNoFill(rlen, clen); // reset the dimensions of a
if incorrect.
+ a.resetNoFillDedup(rlen,embPerRow); // reset the
dimensions of a if incorrect.
HashMap<Integer, double[]> mapping = new HashMap<>();
- for( int i=0; i<rlen; i++ ) {
+ for( int i=0; i<rlen*embPerRow; i++ ) {
Integer pos = in.readInt();
double[] row = mapping.get(pos);
if( row == null){
- row = new double[clen];
+ row = new double[embSize];
mapping.put(pos, row);
}
- a.set(i, row);
+ a.setDedupDirectly(i, row);
}
for (int i = 0; i < mapping.size(); i++) {
double[] row = mapping.get(i);
@@ -2333,19 +2352,20 @@ public class MatrixBlock extends MatrixValue implements
CacheBlock<MatrixBlock>,
out.writeByte( BlockType.DEDUP_BLOCK.ordinal() );
DenseBlockFP64DEDUP a = (DenseBlockFP64DEDUP) getDenseBlock();
- if (rlen > a.numBlocks())
- throw new DMLRuntimeException("Serialize
DedupDenseblock: block does not contain enough rows ["+a.numBlocks() +" < " +
rlen + "]");
-
+ //if (rlen > a.numBlocks())
+ // throw new DMLRuntimeException("Serialize
DedupDenseblock: block does not contain enough rows ["+a.numBlocks() +" < " +
rlen + "]");
+ out.writeInt(a.getNrEmbsPerRow());
+ out.writeInt(a.getEmbSize());
HashMap<double[], Integer> mapping = new HashMap<>((int)
(a.getNrDistinctRows()*1.1));
ArrayList<double[]> unique_rows = new ArrayList<>((int)
(a.getNrDistinctRows()*1.1));
-
- for(int i=0; i<rlen; i++) {
- double[] avals = a.values(i); //equals 1 row
- Integer pos = mapping.get(avals);
+ int embsPerRow = a.getNrEmbsPerRow();
+ for(int i=0; i<rlen*embsPerRow; i++) {
+ double[] vals = a.getDedupDirectly(i); //equals 1 row
+ Integer pos = mapping.get(vals);
if (pos == null) {
pos = mapping.size();
- unique_rows.add(avals);
- mapping.put(avals, pos);
+ unique_rows.add(vals);
+ mapping.put(vals, pos);
}
out.writeInt(pos);
}
@@ -4678,7 +4698,7 @@ public class MatrixBlock extends MatrixValue implements
CacheBlock<MatrixBlock>,
}
return (MatrixBlock)result;
}
-
+
@Override
public MatrixBlock aggregateUnaryOperations(AggregateUnaryOperator op,
MatrixValue result,
int blen, MatrixIndexes indexesIn, boolean inCP) {
diff --git
a/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
b/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
index 59c5f2c973..11c3781c04 100644
---
a/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
+++
b/src/main/java/org/apache/sysds/runtime/transform/encode/MultiColumnEncoder.java
@@ -356,15 +356,17 @@ public class MultiColumnEncoder implements Encoder {
boolean hasDC = false;
boolean hasWE = false;
int distinctWE = 0;
+ int sizeWE = 0;
for(ColumnEncoderComposite columnEncoder : _columnEncoders) {
hasDC |=
columnEncoder.hasEncoder(ColumnEncoderDummycode.class);
for (ColumnEncoder enc : columnEncoder.getEncoders())
if(enc instanceof ColumnEncoderWordEmbedding){
hasWE = true;
distinctWE =
((ColumnEncoderWordEmbedding) enc).getNrDistinctEmbeddings();
+ sizeWE = ((ColumnEncoderWordEmbedding)
enc).getDomainSize();
}
}
- outputMatrixPreProcessing(out, in, hasDC, hasWE, distinctWE);
+ outputMatrixPreProcessing(out, in, hasDC, hasWE, distinctWE,
sizeWE);
if(k > 1) {
if(!_partitionDone) //happens if this method is
directly called
deriveNumRowPartitions(in, k);
@@ -553,7 +555,7 @@ public class MultiColumnEncoder implements Encoder {
return totMemOverhead;
}
- private static void outputMatrixPreProcessing(MatrixBlock output,
CacheBlock<?> input, boolean hasDC, boolean hasWE, int distinctWE) {
+ private static void outputMatrixPreProcessing(MatrixBlock output,
CacheBlock<?> input, boolean hasDC, boolean hasWE, int distinctWE, int sizeWE) {
long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0;
if(output.isInSparseFormat()) {
if (MatrixBlock.DEFAULT_SPARSEBLOCK !=
SparseBlock.Type.CSR
@@ -601,8 +603,11 @@ public class MultiColumnEncoder implements Encoder {
else {
// Allocate dense block and set nnz to total #entries
output.allocateDenseBlock(true, hasWE);
- if( hasWE)
- ((DenseBlockFP64DEDUP)
output.getDenseBlock()).setDistinct(distinctWE);
+ if( hasWE){
+ DenseBlockFP64DEDUP dedup =
((DenseBlockFP64DEDUP) output.getDenseBlock());
+ dedup.setDistinct(distinctWE);
+ dedup.setEmbeddingSize(sizeWE);
+ }
//output.setAllNonZeros();
}
@@ -1159,17 +1164,19 @@ public class MultiColumnEncoder implements Encoder {
boolean hasUDF =
_encoder.getColumnEncoders().stream().anyMatch(e ->
e.hasEncoder(ColumnEncoderUDF.class));
boolean hasWE = false;
int distinctWE = 0;
+ int sizeWE = 0;
for (ColumnEncoder enc : _encoder.getEncoders())
if(enc instanceof ColumnEncoderWordEmbedding){
hasWE = true;
distinctWE =
((ColumnEncoderWordEmbedding) enc).getNrDistinctEmbeddings();
+ sizeWE = ((ColumnEncoderWordEmbedding)
enc).getDomainSize();
}
int numCols = _encoder.getNumOutCols();
boolean hasDC =
_encoder.getColumnEncoders(ColumnEncoderDummycode.class).size() > 0;
long estNNz = (long) _input.getNumRows() * (hasUDF ?
numCols : _input.getNumColumns());
boolean sparse =
MatrixBlock.evalSparseFormatInMemory(_input.getNumRows(), numCols, estNNz) &&
!hasUDF;
_output.reset(_input.getNumRows(), numCols, sparse,
estNNz);
- outputMatrixPreProcessing(_output, _input, hasDC,
hasWE, distinctWE);
+ outputMatrixPreProcessing(_output, _input, hasDC,
hasWE, distinctWE,sizeWE);
return null;
}
diff --git
a/src/test/java/org/apache/sysds/test/component/compress/io/IOTest.java
b/src/test/java/org/apache/sysds/test/component/compress/io/IOTest.java
index 3c18cf049b..9ec4aee526 100644
--- a/src/test/java/org/apache/sysds/test/component/compress/io/IOTest.java
+++ b/src/test/java/org/apache/sysds/test/component/compress/io/IOTest.java
@@ -32,9 +32,11 @@ import org.apache.sysds.runtime.compress.io.WriterCompressed;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.test.TestUtils;
import org.junit.AfterClass;
+import org.junit.Ignore;
import org.junit.Test;
@net.jcip.annotations.NotThreadSafe
+@Ignore //see corrupted tests TODO move out of component tests
public class IOTest {
protected static final Log LOG =
LogFactory.getLog(IOTest.class.getName());
diff --git a/src/test/java/org/apache/sysds/test/functions/dnn/LSTMTest.java
b/src/test/java/org/apache/sysds/test/functions/dnn/LSTMTest.java
new file mode 100644
index 0000000000..bf19abd8db
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/dnn/LSTMTest.java
@@ -0,0 +1,172 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sysds.test.functions.dnn;
+
+import org.apache.sysds.common.Types;
+import org.apache.sysds.runtime.matrix.data.MatrixValue;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Ignore;
+import org.junit.Test;
+
+import java.util.HashMap;
+
+public class LSTMTest extends AutomatedTestBase {
+ String TEST_NAME1 = "LSTMForwardTest";
+ String TEST_NAME2 = "LSTMBackwardTest";
+ private final static String TEST_DIR = "functions/tensor/";
+
+ @Override
+ public void setUp() {
+ TestUtils.clearAssertionInformation();
+ addTestConfiguration(TEST_NAME1, new
TestConfiguration(TEST_DIR, TEST_NAME1));
+ addTestConfiguration(TEST_NAME2, new
TestConfiguration(TEST_DIR, TEST_NAME2));
+ }
+
+ @Test
+ public void testLSTMForwardLocalSingleSample1(){
+ runLSTMTest(1, 32, 1,1, TEST_NAME1);
+ }
+
+ @Test
+ public void testLSTMForwardLocalSingleSample2(){
+ runLSTMTest(1, 1, 64,1, TEST_NAME1);
+ }
+
+ @Test
+ public void testLSTMForwardLocalSingleSample3(){
+ runLSTMTest(1, 1, 1,2048, TEST_NAME1);
+ }
+
+ //note elias: for large hidden sizes there is discrepancy between
built-in and the dml script
+ @Test
+ public void testLSTMForwardLocalSingleSample4(){
+ runLSTMTest(1, 32, 32,1025, 0,0, 1e-2, TEST_NAME1,false);
+ }
+
+ @Test
+ public void testLSTMForwardLocal1(){
+ runLSTMTest(64, 2, 2,2, TEST_NAME1);
+ }
+
+ @Test
+ public void testLSTMForwardLocal2(){
+ runLSTMTest(32, 8, 1,1, TEST_NAME1);
+ }
+
+ @Test
+ public void testLSTMForwardLocal3(){
+ runLSTMTest(32, 1, 64,1, TEST_NAME1);
+ }
+
+ @Test
+ public void testLSTMForwardLocal4(){
+ runLSTMTest(32, 8, 36,1025, TEST_NAME1);
+ }
+
+ @Test
+ public void testLSTMForwardLocal5(){
+ runLSTMTest(32, 75, 128,256, 0, 1, 1e-3, TEST_NAME1, false);
+ }
+
+ @Test
+ public void testLSTMBackwardLocalSingleSample1(){
+ runLSTMTest(1, 2, 3,4,0,1,1e-5, TEST_NAME2, true);
+ }
+
+ @Test
+ public void testLSTMBackwardLocal1(){
+ runLSTMTest(64, 32, 16,32,0,0,1e-5, TEST_NAME2, true);
+ }
+
+ @Test
+ public void testLSTMBackwardLocal2(){
+ runLSTMTest(64, 32, 16,32,0,1,1e-5, TEST_NAME2, true);
+ }
+
+ @Test
+ @Ignore
+ public void testLSTMForwardLocalLarge(){
+ runLSTMTest(100, 32, 128,64, 0, 1, 1e-5, TEST_NAME1, false);
+ }
+
+ @Test
+ @Ignore
+ public void testLSTMBackwardLocalLarge(){
+ runLSTMTest(128, 128, 128,64, 0, 0, 1e-5, TEST_NAME2, true);
+ }
+
+ private void runLSTMTest(double batch_size, double seq_length, double
num_features, double hidden_size, String testname){
+ runLSTMTest(batch_size, seq_length, num_features,
hidden_size,0, testname);
+ }
+
+ private void runLSTMTest(double batch_size, double seq_length, double
num_features, double hidden_size, int debug, String testname){
+ runLSTMTest(batch_size, seq_length, num_features,
hidden_size,debug, 0, 1e-5, testname, false);
+ }
+
+ private void runLSTMTest(double batch_size, double seq_length, double
num_features, double hidden_size, int debug, int seq, double precision, String
testname, boolean backward)
+ {
+ //set runtime platform
+ Types.ExecMode rtold = setExecMode(Types.ExecMode.SINGLE_NODE);
+ try
+ {
+ getAndLoadTestConfiguration(testname);
+ fullDMLScriptName = getScript();
+
+ //run script
+ //"-explain", "runtime",
+ programArgs = new String[]{"-stats","-args",
String.valueOf(batch_size), String.valueOf(seq_length),
+ String.valueOf(num_features),
String.valueOf(hidden_size), String.valueOf(debug), String.valueOf(seq),
+ output("1A"),output("1B"),output("2A"),
output("2B"),output("3A"),output("3B"),"","","",""};
+ int offset = 0;
+ if(backward){
+ programArgs[14 + offset] = output("4A");
+ programArgs[15 + offset] = output("4B");
+ programArgs[16 + offset] = output("5A");
+ programArgs[17 + offset] = output("5B");
+ }
+ //output("4A"), output("4B"),output("5A"),output("5B")
+ runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
+
+ // Compare results
+ extracted(precision,"1");
+ extracted(precision,"2");
+ extracted(precision,"3");
+ if(backward){
+ extracted(precision,"4");
+ extracted(precision,"5");
+ }
+ }
+ catch(Exception ex) {
+ throw new RuntimeException(ex);
+ }
+ finally {
+ resetExecMode(rtold);
+ }
+ }
+
+ private void extracted(double precision, String output) {
+ HashMap<MatrixValue.CellIndex, Double> res_actual =
readDMLMatrixFromOutputDir(output+"A");
+ double[][] resultActualDouble =
TestUtils.convertHashMapToDoubleArray(res_actual);
+ HashMap<MatrixValue.CellIndex, Double> res_expected =
readDMLMatrixFromOutputDir(output+"B");
+ double[][] resultExpectedDouble =
TestUtils.convertHashMapToDoubleArray(res_expected);
+ TestUtils.compareMatrices(resultExpectedDouble,
resultActualDouble, precision);
+ }
+}
diff --git
a/src/test/java/org/apache/sysds/test/functions/io/binary/SerializeTest.java
b/src/test/java/org/apache/sysds/test/functions/io/binary/SerializeTest.java
index c189bb47aa..a54703f41f 100644
--- a/src/test/java/org/apache/sysds/test/functions/io/binary/SerializeTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/io/binary/SerializeTest.java
@@ -21,6 +21,7 @@ package org.apache.sysds.test.functions.io.binary;
import com.google.crypto.tink.subtle.Random;
import org.apache.sysds.lops.Lop;
+import org.apache.sysds.runtime.data.DenseBlockFP64DEDUP;
import org.apache.sysds.runtime.frame.data.FrameBlock;
import org.apache.sysds.runtime.transform.encode.EncoderFactory;
import org.apache.sysds.runtime.transform.encode.MultiColumnEncoder;
@@ -196,6 +197,8 @@ public class SerializeTest extends AutomatedTestBase
double[][] X_duplicated = new double[rows*10][];
MatrixBlock mb = new MatrixBlock(rows*10, cols, false,
0, true);
mb.allocateDenseBlock(true, true);
+ DenseBlockFP64DEDUP dedup = (DenseBlockFP64DEDUP)
mb.getDenseBlock();
+ dedup.setEmbeddingSize(cols);
HashMap<double[], Integer > seen = new HashMap<>();
for (int i = 0; i < rows*10; i++) {
int row = Random.randInt(rows);
@@ -205,7 +208,7 @@ public class SerializeTest extends AutomatedTestBase
seen.put(X[row], tmpPos);
}
X_duplicated[i] = X[row];
- mb.quickSetRow(i, X[row]);
+ dedup.setDedupDirectly(i, X[row]);
}
String fname = SCRIPT_DIR + TEST_DIR +
"dedupSerializedBlock.out";
@@ -222,15 +225,16 @@ public class SerializeTest extends AutomatedTestBase
}
//compare matrices - values
+ DenseBlockFP64DEDUP dedup2 = (DenseBlockFP64DEDUP)
mb2.getDenseBlock();
HashMap<double[], Integer > seen2 = new HashMap<>();
- for( int i=0; i<mb.getNumRows(); i++ ){
- double[] row = mb2.getDenseBlock().values(i);
+ for( int i=0;
i<mb.getNumRows()*dedup2.getNrEmbsPerRow(); i++ ){
+ double[] row = dedup2.getDedupDirectly(i);
Integer tmpPos = seen2.get(row);
if(tmpPos == null) {
tmpPos = seen2.size();
seen2.put(row, tmpPos);
}
- Integer posMb1 =
seen.get(mb.getDenseBlock().values(i));
+ Integer posMb1 =
seen.get(dedup.getDedupDirectly(i));
Assert.assertEquals( (long) tmpPos, (long)
posMb1);
}
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeWordEmbedding1Test.java
b/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeWordEmbedding1Test.java
index 4375dcda3d..b8da9a1083 100644
---
a/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeWordEmbedding1Test.java
+++
b/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeWordEmbedding1Test.java
@@ -37,9 +37,13 @@ import java.util.List;
import java.util.Map;
import java.util.Random;
+import static
org.apache.sysds.test.functions.transform.TransformFrameEncodeWordEmbedding2Test.manuallyDeriveWordEmbeddings;
+import static
org.apache.sysds.test.functions.transform.TransformFrameEncodeWordEmbedding2Test.manuallyDeriveWordEmbeddingsReshape;
+
public class TransformFrameEncodeWordEmbedding1Test extends AutomatedTestBase
{
private final static String TEST_NAME1 =
"TransformFrameEncodeWordEmbeddings";
+ private final static String TEST_NAME2 =
"TransformFrameEncodeWordEmbeddings1Reshape";
private final static String TEST_DIR = "functions/transform/";
private final static String TEST_CLASS_DIR = TEST_DIR +
TransformFrameEncodeWordEmbedding1Test.class.getSimpleName() + "/";
@@ -47,6 +51,7 @@ public class TransformFrameEncodeWordEmbedding1Test extends
AutomatedTestBase
public void setUp() {
TestUtils.clearAssertionInformation();
addTestConfiguration(TEST_NAME1, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME1));
+ addTestConfiguration(TEST_NAME2, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME2));
}
@Test
@@ -59,6 +64,12 @@ public class TransformFrameEncodeWordEmbedding1Test extends
AutomatedTestBase
runTransformTest(TEST_NAME1, ExecMode.SPARK);
}
+ @Test
+ public void testTransformToWordEmbeddingsWithReshape() {
+ runTransformTest(TEST_NAME2, ExecMode.SINGLE_NODE);
+ }
+
+
private void runTransformTest(String testname, ExecMode rt)
{
//set runtime platform
@@ -84,11 +95,7 @@ public class TransformFrameEncodeWordEmbedding1Test extends
AutomatedTestBase
runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
// Manually derive the expected result
- double[][] res_expected = new
double[stringsColumn.size()][cols];
- for (int i = 0; i < stringsColumn.size(); i++) {
- int rowMapped = map.get(stringsColumn.get(i));
- System.arraycopy(a[rowMapped], 0,
res_expected[i], 0, cols);
- }
+ double[][] res_expected = testname.equals(TEST_NAME2) ?
manuallyDeriveWordEmbeddingsReshape(cols, a, map, stringsColumn, 10) :
manuallyDeriveWordEmbeddings(cols, a, map, stringsColumn);
// Compare results
HashMap<MatrixValue.CellIndex, Double> res_actual =
readDMLMatrixFromOutputDir("result");
@@ -100,7 +107,7 @@ public class TransformFrameEncodeWordEmbedding1Test extends
AutomatedTestBase
finally {
resetExecMode(rtold);
}
-}
+ }
public static List<String> shuffleAndMultiplyStrings(List<String>
strings, int multiply){
List<String> out = new ArrayList<>();
diff --git
a/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeWordEmbedding2Test.java
b/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeWordEmbedding2Test.java
index b994ae83c3..e4ec5180c5 100644
---
a/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeWordEmbedding2Test.java
+++
b/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeWordEmbedding2Test.java
@@ -40,8 +40,7 @@ import java.util.Random;
public class TransformFrameEncodeWordEmbedding2Test extends AutomatedTestBase
{
private final static String TEST_NAME1 =
"TransformFrameEncodeWordEmbeddings2";
- private final static String TEST_NAME2a =
"TransformFrameEncodeWordEmbeddings2MultiCols1";
- private final static String TEST_NAME2b =
"TransformFrameEncodeWordEmbeddings2MultiCols2";
+ private final static String TEST_NAME2 =
"TransformFrameEncodeWordEmbeddings2Reshape";
private final static String TEST_DIR = "functions/transform/";
private final static String TEST_CLASS_DIR = TEST_DIR +
TransformFrameEncodeWordEmbedding2Test.class.getSimpleName() + "/";
@@ -50,8 +49,7 @@ public class TransformFrameEncodeWordEmbedding2Test extends
AutomatedTestBase
public void setUp() {
TestUtils.clearAssertionInformation();
addTestConfiguration(TEST_NAME1, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME1));
- addTestConfiguration(TEST_NAME2a, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME2a));
- addTestConfiguration(TEST_NAME2b, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME2b));
+ addTestConfiguration(TEST_NAME2, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME2));
}
@Test
@@ -69,7 +67,16 @@ public class TransformFrameEncodeWordEmbedding2Test extends
AutomatedTestBase
runTransformTest(TEST_NAME1, ExecMode.HYBRID);
}
- private void runTransformTest(String testname, ExecMode rt)
+ @Test
+ public void testTransformToWordEmbeddingsWithReshape() {
+ runTransformTest(TEST_NAME2, ExecMode.SINGLE_NODE,10);
+ }
+
+ private void runTransformTest(String testname, ExecMode rt){
+ runTransformTest(testname, rt, 1);
+ }
+
+ private void runTransformTest(String testname, ExecMode rt, int reshape)
{
//set runtime platform
ExecMode rtold = setExecMode(rt);
@@ -89,8 +96,9 @@ public class TransformFrameEncodeWordEmbedding2Test extends
AutomatedTestBase
// Generate the dictionary by assigning unique ID to
each distinct token
Map<String,Integer> map = writeDictToCsvFile(strings,
baseDirectory + INPUT_DIR + "dict");
+ int multiplier = 320/32;
// Create the dataset by repeating and shuffling the
distinct tokens
- List<String> stringsColumn =
shuffleAndMultiplyStrings(strings, 320);
+ List<String> stringsColumn =
shuffleAndMultiplyStrings(strings, multiplier);
writeStringsToCsvFile(stringsColumn, baseDirectory +
INPUT_DIR + "data");
//run script
@@ -98,11 +106,12 @@ public class TransformFrameEncodeWordEmbedding2Test
extends AutomatedTestBase
runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
// Manually derive the expected result
- double[][] res_expected =
manuallyDeriveWordEmbeddings(cols, a, map, stringsColumn);
+ //double[][] res_expected =
manuallyDeriveWordEmbeddings(cols, a, map, stringsColumn);
+ double[][] res_expected = testname.equals(TEST_NAME2) ?
manuallyDeriveWordEmbeddingsReshape(cols, a, map, stringsColumn, reshape) :
manuallyDeriveWordEmbeddings(cols, a, map, stringsColumn);
// Compare results
HashMap<MatrixValue.CellIndex, Double> res_actual =
readDMLMatrixFromOutputDir("result");
- double[][] resultActualDouble =
TestUtils.convertHashMapToDoubleArray(res_actual, rows*320, cols);
+ double[][] resultActualDouble =
testname.equals(TEST_NAME2) ? TestUtils.convertHashMapToDoubleArray(res_actual,
rows*multiplier / reshape, cols*reshape) :
TestUtils.convertHashMapToDoubleArray(res_actual, rows*multiplier, cols);
TestUtils.compareMatrices(res_expected,
resultActualDouble, 1e-6);
}
catch(Exception ex) {
@@ -123,6 +132,16 @@ public class TransformFrameEncodeWordEmbedding2Test
extends AutomatedTestBase
return res_expected;
}
+ public static double[][] manuallyDeriveWordEmbeddingsReshape(int cols,
double[][] a, Map<String, Integer> map, List<String> stringsColumn, int factor){
+ double[][] res_expected = new double[stringsColumn.size() /
factor][cols*factor];
+ for (int i = 0; i < stringsColumn.size()/ factor; i++)
+ for (int j = 0; j < factor; j++) {
+ int rowMapped =
map.get(stringsColumn.get(i*factor + j));
+ System.arraycopy(a[rowMapped], 0,
res_expected[i], j*cols, cols);
+ }
+ return res_expected;
+ }
+
public static List<String> shuffleAndMultiplyStrings(List<String>
strings, int multiply){
List<String> out = new ArrayList<>();
Random random = new Random();
diff --git
a/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeWordEmbeddingMMTest.java
b/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeWordEmbeddingMMTest.java
index 79c7fdf388..966c7c465b 100644
---
a/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeWordEmbeddingMMTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeWordEmbeddingMMTest.java
@@ -59,8 +59,8 @@ public class TransformFrameEncodeWordEmbeddingMMTest extends
AutomatedTestBase {
Types.ExecMode rtold = setExecMode(rt);
try
{
- int rows = 100;
- int cols = 300;
+ int rows = 10;
+ int cols = 30;
getAndLoadTestConfiguration(testname);
fullDMLScriptName = getScript();
@@ -75,7 +75,7 @@ public class TransformFrameEncodeWordEmbeddingMMTest extends
AutomatedTestBase {
Map<String,Integer> map = writeDictToCsvFile(strings,
baseDirectory + INPUT_DIR + "dict");
// Create the dataset by repeating and shuffling the
distinct tokens
- int factor = 320;
+ int factor = 32;
rows *= factor;
List<String> stringsColumn =
shuffleAndMultiplyStrings(strings, factor);
writeStringsToCsvFile(stringsColumn, baseDirectory +
INPUT_DIR + "data");
diff --git
a/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeWordEmbeddingRowSumTest.java
b/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeWordEmbeddingRowSumTest.java
index 1d07469484..d8926fc190 100644
---
a/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeWordEmbeddingRowSumTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/transform/TransformFrameEncodeWordEmbeddingRowSumTest.java
@@ -59,6 +59,11 @@ public class TransformFrameEncodeWordEmbeddingRowSumTest
extends AutomatedTestBa
runDedupRowSumTest(TEST_NAME1, Types.ExecMode.SINGLE_NODE);
}
+ @Test
+ public void testDedupRowSumsSpark() {
+ runDedupRowSumTest(TEST_NAME1, Types.ExecMode.SPARK);
+ }
+
@Test
public void testDedupColSums() {
runDedupColSumTest(TEST_NAME2, Types.ExecMode.SINGLE_NODE);
diff --git a/src/test/scripts/functions/tensor/LSTMBackwardTest.dml
b/src/test/scripts/functions/tensor/LSTMBackwardTest.dml
new file mode 100644
index 0000000000..3645fad53a
--- /dev/null
+++ b/src/test/scripts/functions/tensor/LSTMBackwardTest.dml
@@ -0,0 +1,86 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+source("scripts/nn/layers/lstm.dml") as lstm
+
+batch_size = as.integer($1)
+seq_length = as.integer($2)
+num_features = as.integer($3)
+hidden_size = as.integer($4)
+debug = as.logical(as.integer($5))
+seq = as.logical(as.integer($6))
+
+[W,b,out0, c0] = lstm::init(batch_size, num_features, hidden_size)
+lstmIn = rand(rows=batch_size, cols=seq_length*num_features, min=-2, max=2,
pdf="uniform")
+W = rand(rows=num_features + hidden_size, cols=hidden_size*4, min=-1, max=1,
pdf="uniform")
+b = rand(rows=1, cols=4*hidden_size, min=-1, max=1, pdf="uniform")
+out0 = rand(rows=batch_size, cols=hidden_size, min=-1, max=1, pdf="uniform")
+c0 = rand(rows=batch_size, cols=hidden_size, min=-1, max=1, pdf="uniform")
+dout = rand(rows=batch_size, cols=hidden_size, min=-1, max=1, pdf="uniform")
+if(seq){
+ dout = rand(rows=batch_size, cols=hidden_size*seq_length, min=-1, max=1,
pdf="uniform")
+}
+dc = rand(rows=batch_size, cols=hidden_size, min=-1, max=1, pdf="uniform")
+
+#print(toString(b[1,1]))
+#print(toString(W[1,1]))
+#print(toString(lstmIn[1,1]))
+#print(toString(out0[1,1]))
+#print(toString(c0[1,1]))
+
+[out, c, cache_out, cache_c, cache_ifog] = lstm(lstmIn, W, b, out0, c0, seq)
+[out2, c2, cache_out2, cache_c2, cache_ifog2] = lstm::forward(lstmIn,
W,b,seq_length,num_features,seq,out0, c0)
+
+t0 = time()
+[dx, dw, db, dout0, dc0] = lstm_backward(lstmIn, W, b, out0, c0, seq, dout,
dc, cache_out, cache_c, cache_ifog)
+t1 = time()
+[dx2, dw2, db2, dout02, dc02] = lstm::backward(dout, dc, lstmIn,
W,b,seq_length,num_features,seq,out0, c0,cache_out2, cache_c2, cache_ifog2)
+t2 = time()
+
+if(debug){
+ print(toString(out))
+ print(toString(out2))
+}
+
+print(toString(dw[1,1]))
+print(toString(dw2[1,1]))
+#print(toString(dx[1,1]))
+#print(toString(dx2[1,1]))
+#print(toString(db[1,1]))
+#print(toString(db2[1,1]))
+#print(toString(dout0[1,1]))
+#print(toString(dout02[1,1]))
+#print(toString(dc0[1,1]))
+#print(toString(dc02[1,1]))
+
+write(dx, $7, format="text");
+write(dx2, $8, format="text");
+write(dw, $9, format="text");
+write(dw2, $10, format="text");
+write(db, $11, format="text");
+write(db2, $12, format="text");
+write(dout0, $13, format="text");
+write(dout02, $14, format="text");
+write(dc0, $15, format="text");
+write(dc02, $16, format="text");
+
+T = 1000000
+print("built-in took: " + (t1 - t0)/T + " ms")
+print("dml-script took: " + (t2 - t1)/T + " ms")
\ No newline at end of file
diff --git a/src/test/scripts/functions/tensor/LSTMForwardTest.dml
b/src/test/scripts/functions/tensor/LSTMForwardTest.dml
new file mode 100644
index 0000000000..e6e819b997
--- /dev/null
+++ b/src/test/scripts/functions/tensor/LSTMForwardTest.dml
@@ -0,0 +1,66 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+source("scripts/nn/layers/lstm.dml") as lstm
+
+batch_size = as.integer($1)
+seq_length = as.integer($2)
+num_features = as.integer($3)
+hidden_size = as.integer($4)
+debug = as.logical(as.integer($5))
+seq = as.logical(as.integer($6))
+
+[W,b,out0, c0] = lstm::init(batch_size, num_features, hidden_size)
+lstmIn = rand(rows=batch_size, cols=seq_length*num_features, min=-2, max=2,
pdf="uniform")
+W = rand(rows=num_features + hidden_size, cols=hidden_size*4, min=-1, max=1,
pdf="uniform")
+b = rand(rows=1, cols=4*hidden_size, min=-1, max=1, pdf="uniform")
+out0 = rand(rows=batch_size, cols=hidden_size, min=-1, max=1, pdf="uniform")
+c0 = rand(rows=batch_size, cols=hidden_size, min=-1, max=1, pdf="uniform")
+
+#print(toString(b[1,1]))
+#print(toString(W[1,1]))
+#print(toString(lstmIn[1,1]))
+#print(toString(out0[1,1]))
+#print(toString(c0[1,1]))
+
+
+
+t0 = time()
+[out, c, cache_out, cache_c, cache_ifog] = lstm(lstmIn, W, b, out0, c0, seq)
+t1 = time()
+if(debug){
+ print(toString(out))
+}
+[out2, c2, cache_out2, cache_c2, cache_ifog2] = lstm::forward(lstmIn,
W,b,seq_length,num_features,seq,out0, c0)
+t2 = time()
+if(debug){
+ print(toString(out2))
+}
+
+write(cache_out, $7, format="text");
+write(cache_out2, $8, format="text");
+write(cache_c, $9, format="text");
+write(cache_c2, $10, format="text");
+write(cache_ifog, $11, format="text");
+write(cache_ifog2, $12, format="text");
+
+T = 1000000
+print("built-in took: " + (t1 - t0)/T + " ms")
+print("dml-script took: " + (t2 - t1)/T + " ms")
\ No newline at end of file
diff --git
a/src/test/scripts/functions/transform/TransformFrameEncodeWordEmbeddingsMM.dml
b/src/test/scripts/functions/transform/TransformFrameEncodeWordEmbeddings1Reshape.dml
similarity index 79%
copy from
src/test/scripts/functions/transform/TransformFrameEncodeWordEmbeddingsMM.dml
copy to
src/test/scripts/functions/transform/TransformFrameEncodeWordEmbeddings1Reshape.dml
index c439ef50d7..6d67a2a7d8 100644
---
a/src/test/scripts/functions/transform/TransformFrameEncodeWordEmbeddingsMM.dml
+++
b/src/test/scripts/functions/transform/TransformFrameEncodeWordEmbeddings1Reshape.dml
@@ -21,17 +21,22 @@
# Read the pre-trained word embeddings
E = read($1, rows=100, cols=300, format="text");
+
# Read the token sequence (1K) w/ 100 distinct tokens
Data = read($2, data_type="frame", format="csv");
+
# Read the recode map for the distinct tokens
Meta = read($3, data_type="frame", format="csv");
-#Read the matrix that is used for multiplication after transform
-MM = read($4, rows=300, cols=300, format="text");
-jspec = "{ids: true, word_embedding: [1]}";
-Data_enc = transformapply(target=Data, spec=jspec, meta=Meta, embedding=E);
-Product = Data_enc %*% MM
-write(Product, $5, format="text");
+jspec = "{ids: true, dummycode: [1]}";
+Data_enc = transformapply(target=Data, spec=jspec, meta=Meta);
+
+# Apply the embeddings on all tokens (1K x 100)
+Data_enc = Data_enc %*% E;
+
+seq_len = 10
+Data_enc = matrix(Data_enc, rows=nrow(Data_enc) / seq_len,
cols=ncol(Data_enc)*seq_len)
+write(Data_enc, $4, format="text");
diff --git
a/src/test/scripts/functions/transform/TransformFrameEncodeWordEmbeddings2MultiCols1.dml
b/src/test/scripts/functions/transform/TransformFrameEncodeWordEmbeddings2Reshape.dml
similarity index 80%
rename from
src/test/scripts/functions/transform/TransformFrameEncodeWordEmbeddings2MultiCols1.dml
rename to
src/test/scripts/functions/transform/TransformFrameEncodeWordEmbeddings2Reshape.dml
index 00484697d6..2559af769a 100644
---
a/src/test/scripts/functions/transform/TransformFrameEncodeWordEmbeddings2MultiCols1.dml
+++
b/src/test/scripts/functions/transform/TransformFrameEncodeWordEmbeddings2Reshape.dml
@@ -20,24 +20,22 @@
#-------------------------------------------------------------
# Read the pre-trained word embeddings
-E = read($1, rows=100, cols=100, format="text");
+E = read($1, rows=100, cols=300, format="text");
# Read the token sequence (1K) w/ 100 distinct tokens
Data = read($2, data_type="frame", format="csv");
# Read the recode map for the distinct tokens
Meta = read($3, data_type="frame", format="csv");
-DataExtension = as.frame(matrix(1, rows=length(Data), cols=1))
-Data = cbind(Data, DataExtension)
-Data = cbind(DataExtension, Data)
-Meta = cbind(Meta, Meta)
-
-jspec = "{ids: true, word_embedding: [2]}";
-#jspec = "{ids: true, dummycode: [2]}";
+jspec = "{ids: true, word_embedding: [1]}";
Data_enc = transformapply(target=Data, spec=jspec, meta=Meta, embedding=E);
-Data_enc = Data_enc[,2:101]
+seq_len = 10
+N = nrow(Data_enc) / seq_len
+Data_enc = matrix(Data_enc, rows=N, cols=ncol(Data_enc)*seq_len)
+
write(Data_enc, $4, format="text");
+
diff --git
a/src/test/scripts/functions/transform/TransformFrameEncodeWordEmbeddingsMM.dml
b/src/test/scripts/functions/transform/TransformFrameEncodeWordEmbeddingsMM.dml
index c439ef50d7..36d299e80f 100644
---
a/src/test/scripts/functions/transform/TransformFrameEncodeWordEmbeddingsMM.dml
+++
b/src/test/scripts/functions/transform/TransformFrameEncodeWordEmbeddingsMM.dml
@@ -20,13 +20,13 @@
#-------------------------------------------------------------
# Read the pre-trained word embeddings
-E = read($1, rows=100, cols=300, format="text");
+E = read($1, rows=10, cols=30, format="text");
# Read the token sequence (1K) w/ 100 distinct tokens
Data = read($2, data_type="frame", format="csv");
# Read the recode map for the distinct tokens
Meta = read($3, data_type="frame", format="csv");
#Read the matrix that is used for multiplication after transform
-MM = read($4, rows=300, cols=300, format="text");
+MM = read($4, rows=30, cols=30, format="text");
jspec = "{ids: true, word_embedding: [1]}";
Data_enc = transformapply(target=Data, spec=jspec, meta=Meta, embedding=E);