This is an automated email from the ASF dual-hosted git repository.

niketanpansare pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemml.git

The following commit(s) were added to refs/heads/master by this push:
     new 0cabde0  [SYSTEMML-540] Improve the performance of lstm builtin function
0cabde0 is described below

commit 0cabde0ca26c99a55c62f7e7ffac67b450dea850
Author: Niketan Pansare <npan...@us.ibm.com>
AuthorDate: Wed Feb 27 21:03:15 2019 -0800

    [SYSTEMML-540] Improve the performance of lstm builtin function

    - Allow FunctionOp to be multi-threaded.
    - Currently, only lstm builtin function will have number of threads > 1.
    - Added more tests.
---
 .../java/org/apache/sysml/hops/FunctionOp.java     | 10 +++-
 .../java/org/apache/sysml/lops/FunctionCallCP.java | 14 +++++-
 .../runtime/instructions/cp/DnnCPInstruction.java  | 13 ++---
 .../sysml/runtime/matrix/data/LibMatrixDNN.java    | 55 ++++++++++++++++------
 .../org/apache/sysml/test/gpu/LstmCPUTest.java     | 50 ++++++++++++++++++++
 5 files changed, 118 insertions(+), 24 deletions(-)

diff --git a/src/main/java/org/apache/sysml/hops/FunctionOp.java b/src/main/java/org/apache/sysml/hops/FunctionOp.java
index 66ce478..5fdc8e7 100644
--- a/src/main/java/org/apache/sysml/hops/FunctionOp.java
+++ b/src/main/java/org/apache/sysml/hops/FunctionOp.java
@@ -39,7 +39,7 @@ import org.apache.sysml.runtime.controlprogram.parfor.opt.CostEstimatorHops;
  * Note: Currently, we support expressions in function arguments along with function calls
  * in expressions with single outputs, leaving multiple outputs handling as it is.
*/ -public class FunctionOp extends Hop +public class FunctionOp extends MultiThreadedHop { public enum FunctionType{ DML, @@ -253,8 +253,14 @@ public class FunctionOp extends Hop tmp.add( in.constructLops() ); //construct function call + int numThreads = 0; + if(getFunctionType() == FunctionType.MULTIRETURN_BUILTIN && isBuiltinFunction() && et == ExecType.CP && + (getFunctionName().equalsIgnoreCase("lstm") || getFunctionName().equalsIgnoreCase("lstm_backward"))) { + numThreads = OptimizerUtils.getConstrainedNumThreads(_maxNumThreads); + } + Lop fcall = _singleOutFun ? new FunctionCallCPSingle( tmp, _fnamespace, _fname, et ) : - new FunctionCallCP(tmp, _fnamespace, _fname, _inputNames, _outputNames, _outputHops, et); + new FunctionCallCP(tmp, _fnamespace, _fname, _inputNames, _outputNames, _outputHops, et, numThreads); setLineNumbers(fcall); setLops(fcall); diff --git a/src/main/java/org/apache/sysml/lops/FunctionCallCP.java b/src/main/java/org/apache/sysml/lops/FunctionCallCP.java index 50d43de..237b806 100644 --- a/src/main/java/org/apache/sysml/lops/FunctionCallCP.java +++ b/src/main/java/org/apache/sysml/lops/FunctionCallCP.java @@ -38,10 +38,12 @@ public class FunctionCallCP extends Lop private String[] _inputNames; private String[] _outputNames; private ArrayList<Lop> _outputLops = null; + private int _numThreads; public FunctionCallCP(ArrayList<Lop> inputs, String fnamespace, String fname, - String[] inputNames, String[] outputNames, ArrayList<Hop> outputHops, ExecType et) { + String[] inputNames, String[] outputNames, ArrayList<Hop> outputHops, ExecType et, int numThreads) { this(inputs, fnamespace, fname, inputNames, outputNames, et); + _numThreads = numThreads; if(outputHops != null) { _outputLops = new ArrayList<>(); setLevel(); @@ -104,6 +106,11 @@ public class FunctionCallCP extends Lop sb.append(_outputNames[i]); } + if(_numThreads > 0) { + sb.append(Lop.OPERAND_DELIMITOR); + sb.append(_numThreads); + } + return sb.toString(); } @@ -145,6 +152,11 @@ 
public class FunctionCallCP extends Lop inst.append(Lop.OPERAND_DELIMITOR); inst.append(out); } + + if(_numThreads > 0) { + inst.append(Lop.OPERAND_DELIMITOR); + inst.append(_numThreads); + } return inst.toString(); } diff --git a/src/main/java/org/apache/sysml/runtime/instructions/cp/DnnCPInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/cp/DnnCPInstruction.java index 4043908..93ffd4f 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/cp/DnnCPInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/cp/DnnCPInstruction.java @@ -103,7 +103,7 @@ public class DnnCPInstruction extends UnaryCPInstruction { public DnnCPInstruction(CPOperand in1, CPOperand in2, CPOperand in3, CPOperand in4, CPOperand in5, CPOperand in6, CPOperand in7, CPOperand in8, CPOperand out, CPOperand out2, CPOperand out3, CPOperand out4, CPOperand out5, String opcode, String istr, - double intermediateMemoryBudget) throws DMLRuntimeException { + double intermediateMemoryBudget, int numThreads) throws DMLRuntimeException { super(CPType.Dnn, null, in1, out, opcode, istr); _in2 = in2; _in3 = in3; @@ -120,7 +120,7 @@ public class DnnCPInstruction extends UnaryCPInstruction { _padding = null; _input_shape = null; _filter_shape = null; - _numThreads = 0; + _numThreads = numThreads; _intermediateMemoryBudget = intermediateMemoryBudget; } @@ -246,7 +246,7 @@ public class DnnCPInstruction extends UnaryCPInstruction { CPOperand out3 = new CPOperand(parts[11]); // retRunningVar CPOperand out4 = new CPOperand(parts[12]); // resultSaveMean CPOperand out5 = new CPOperand(parts[13]); // resultSaveInvVariance - return new DnnCPInstruction(in1, in2, in3, in4, in5, in6, in7, in8, out, out2, out3, out4, out5, opcode, str, 0); + return new DnnCPInstruction(in1, in2, in3, in4, in5, in6, in7, in8, out, out2, out3, out4, out5, opcode, str, 0, 0); } else if (opcode.equalsIgnoreCase("batch_norm2d_backward")) { InstructionUtils.checkNumFields(parts, 9); @@ -259,10 +259,10 
@@ public class DnnCPInstruction extends UnaryCPInstruction { CPOperand out = new CPOperand(parts[7]); // dX CPOperand out2 = new CPOperand(parts[8]); // dScale CPOperand out3 = new CPOperand(parts[9]); // dBias - return new DnnCPInstruction(in1, in2, in3, in4, in5, in6, null, null, out, out2, out3, null, null, opcode, str, 0); + return new DnnCPInstruction(in1, in2, in3, in4, in5, in6, null, null, out, out2, out3, null, null, opcode, str, 0, 0); } else if (opcode.equalsIgnoreCase("lstm")) { - InstructionUtils.checkNumFields(parts, 8); + InstructionUtils.checkNumFields(parts, 9); CPOperand in1 = new CPOperand(parts[1]); // X CPOperand in2 = new CPOperand(parts[2]); // W CPOperand in3 = new CPOperand(parts[3]); // b @@ -271,7 +271,8 @@ public class DnnCPInstruction extends UnaryCPInstruction { CPOperand in6 = new CPOperand(parts[6]); // return_seq CPOperand out = new CPOperand(parts[7]); // out CPOperand out2 = new CPOperand(parts[8]); // c - return new DnnCPInstruction(in1, in2, in3, in4, in5, in6, null, null, out, out2, null, null, null, opcode, str, 0); + int numThreads = Integer.parseInt(parts[9]); + return new DnnCPInstruction(in1, in2, in3, in4, in5, in6, null, null, out, out2, null, null, null, opcode, str, 0, numThreads); } else { throw new DMLRuntimeException("Unknown opcode while parsing a DnnCPInstruction: " + str); diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java index 179b5d3..0f932ba 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixDNN.java @@ -37,16 +37,16 @@ import org.apache.sysml.runtime.functionobjects.Builtin; import org.apache.sysml.runtime.functionobjects.KahanPlus; import org.apache.sysml.runtime.functionobjects.Multiply; import org.apache.sysml.runtime.functionobjects.Plus; -import org.apache.sysml.runtime.functionobjects.ValueFunction; 
+import org.apache.sysml.runtime.functionobjects.PlusMultiply; import org.apache.sysml.runtime.functionobjects.Builtin.BuiltinCode; import org.apache.sysml.runtime.instructions.cp.KahanObject; import org.apache.sysml.runtime.matrix.operators.AggregateBinaryOperator; import org.apache.sysml.runtime.matrix.operators.AggregateOperator; import org.apache.sysml.runtime.matrix.operators.BinaryOperator; +import org.apache.sysml.runtime.matrix.operators.TernaryOperator; import org.apache.sysml.runtime.matrix.operators.UnaryOperator; import org.apache.sysml.runtime.util.CommonThreadPool; import org.apache.sysml.runtime.util.DnnUtils; -import org.apache.sysml.runtime.util.IndexRange; /* * This class allows users to invoke deep learning related operations @@ -282,11 +282,26 @@ public class LibMatrixDNN { return ret; } - private static MatrixBlock add(MatrixBlock matBlock1, MatrixBlock matBlock2) { - return (MatrixBlock) matBlock1.binaryOperations(new BinaryOperator(Plus.getPlusFnObject()), matBlock2, new MatrixBlock()); + private static MatrixBlock add(MatrixBlock matBlock1, MatrixBlock matBlock2, boolean inplace) { + BinaryOperator bop = new BinaryOperator(Plus.getPlusFnObject()); +// if(inplace) { +// matBlock1.binaryOperationsInPlace(bop, matBlock2); +// return matBlock1; +// } +// else { + return (MatrixBlock) matBlock1.binaryOperations(bop, matBlock2, new MatrixBlock()); +// } } - private static MatrixBlock multiply(MatrixBlock matBlock1, MatrixBlock matBlock2) { - return (MatrixBlock) matBlock1.binaryOperations(new BinaryOperator(Multiply.getMultiplyFnObject()), matBlock2, new MatrixBlock()); + + private static MatrixBlock multiply(MatrixBlock matBlock1, MatrixBlock matBlock2, boolean inplace) { + BinaryOperator bop = new BinaryOperator(Multiply.getMultiplyFnObject()); +// if(inplace) { +// matBlock1.binaryOperationsInPlace(bop, matBlock2); +// return matBlock1; +// } +// else { + return (MatrixBlock) matBlock1.binaryOperations(bop, matBlock2, new MatrixBlock()); +// } 
} // sigmoid(0)*c_prev + sigmoid(0)*tanh(0); @@ -296,10 +311,16 @@ public class LibMatrixDNN { private static MatrixBlock sigmoid(MatrixBlock in, int numThreads, boolean inPlace) { return (MatrixBlock) in.unaryOperations(new UnaryOperator(sigmoidOp, numThreads, inPlace), new MatrixBlock()); } + private static MatrixBlock tanh(MatrixBlock in, int numThreads, boolean inPlace) { return (MatrixBlock) in.unaryOperations(new UnaryOperator(tanhOp, numThreads, inPlace), new MatrixBlock()); } + private static MatrixBlock plusMultiply(MatrixBlock matBlock1, MatrixBlock matBlock2, MatrixBlock matBlock3) { + return matBlock1.ternaryOperations(new TernaryOperator(PlusMultiply.getFnObject()), + matBlock2, matBlock3, new MatrixBlock()); + } + public static void lstm(MatrixBlock X, MatrixBlock W, MatrixBlock b, MatrixBlock out0, MatrixBlock c0, boolean return_seq, int N, int T, int D, int M, MatrixBlock out, MatrixBlock c, // output @@ -314,19 +335,23 @@ public class LibMatrixDNN { MatrixBlock out_t = null; for(int t = 1; t <= T; t++) { MatrixBlock X_t = X.slice(0, N-1, (t-1)*D, t*D-1, new MatrixBlock()); - MatrixBlock ifog_raw = add(add(matmult(X_t, W1, numThreads), matmult(out_prev, W2, numThreads)), b); - MatrixBlock i = ifog_raw.slice(0, N-1, 0, M-1, new MatrixBlock()); - MatrixBlock f = ifog_raw.slice(0, N-1, M, 2*M-1, new MatrixBlock()); - MatrixBlock o = ifog_raw.slice(0, N-1, 2*M, 3*M-1, new MatrixBlock()); + MatrixBlock ifog_raw = add(add(matmult(X_t, W1, numThreads), matmult(out_prev, W2, numThreads), true), b, true); + + MatrixBlock ifo = ifog_raw.slice(0, N-1, 0, 3*M-1, new MatrixBlock()); + ifo = sigmoid(ifo, numThreads, true); + MatrixBlock i = ifo.slice(0, N-1, 0, M-1, new MatrixBlock()); + MatrixBlock f = ifo.slice(0, N-1, M, 2*M-1, new MatrixBlock()); + MatrixBlock o = ifo.slice(0, N-1, 2*M, 3*M-1, new MatrixBlock()); + MatrixBlock g = ifog_raw.slice(0, N-1, 3*M, 4*M-1, new MatrixBlock()); - i = sigmoid(i, numThreads, true); - f = sigmoid(f, numThreads, true); - o 
= sigmoid(o, numThreads, true); g = tanh(g, numThreads, true); + // c_t = f*c_prev + i*g - c_t = add(multiply(f, c_prev) , multiply(i, g)); + c_t = plusMultiply(multiply(f, c_prev, true), i, g); + // out_t = o*tanh(c) - out_t = multiply(o, tanh(c_t, numThreads, false)); + out_t = multiply(o, tanh(c_t, numThreads, false), true); + if(return_seq) { out = out.leftIndexingOperations(out_t, 0, N-1, (t-1)*M, t*M-1, new MatrixBlock(), UpdateType.INPLACE); } diff --git a/src/test/java/org/apache/sysml/test/gpu/LstmCPUTest.java b/src/test/java/org/apache/sysml/test/gpu/LstmCPUTest.java index 3aa37ad..785c890 100644 --- a/src/test/java/org/apache/sysml/test/gpu/LstmCPUTest.java +++ b/src/test/java/org/apache/sysml/test/gpu/LstmCPUTest.java @@ -48,6 +48,46 @@ public class LstmCPUTest extends GPUTests { } @Test + public void testLstmForward1() { + testLstmCuDNNWithNNLayer(1, 1, 1, 1, "TRUE", 0.2); + } + + @Test + public void testLstmForward2() { + testLstmCuDNNWithNNLayer(1, 1, 1, 1, "FALSE", 0.1); + } + + @Test + public void testLstmForward3() { + testLstmCuDNNWithNNLayer(20, 13, 50, 10, "TRUE", 0.15); + } + + @Test + public void testLstmForward4() { + testLstmCuDNNWithNNLayer(20, 13, 50, 10, "FALSE", 0.1); + } + + @Test + public void testLstmForward5() { + testLstmCuDNNWithNNLayer(20, 13, 1, 10, "TRUE", 0.5); + } + + @Test + public void testLstmForward6() { + testLstmCuDNNWithNNLayer(20, 13, 1, 10, "FALSE", 0.3); + } + + @Test + public void testLstmForward7() { + testLstmCuDNNWithNNLayer(20, 13, 4, 1, "TRUE", 0.8); + } + + @Test + public void testLstmForward8() { + testLstmCuDNNWithNNLayer(20, 13, 4, 1, "FALSE", 0.9); + } + + @Test public void testLstmForward9() { testLstmCuDNNWithNNLayer(1, 1, 1, 1, "TRUE", 0.9); } @@ -67,6 +107,16 @@ public class LstmCPUTest extends GPUTests { testLstmCuDNNWithNNLayer(20, 13, 50, 10, "FALSE", 0.9); } + @Test + public void testLstmForward13() { + testLstmCuDNNWithNNLayer(20, 1, 4, 10, "TRUE", 0.8); + } + + @Test + public void 
testLstmForward14() { + testLstmCuDNNWithNNLayer(20, 1, 4, 10, "FALSE", 0.9); + } + public void testLstmCuDNNWithNNLayer(int N, int T, int D, int M, String returnSequences, double sparsity) { String scriptStr1 = "source(" + builtinDML + ") as lstm;\n " + "[output, c] = lstm::forward(x, w, b, " + returnSequences + ", out0, c0)";