Repository: systemml
Updated Branches:
  refs/heads/master 0529350a3 -> b9b273d87
[SYSTEMML-540] Added support for GPU relu, scalar min and max operations Project: http://git-wip-us.apache.org/repos/asf/systemml/repo Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/b9b273d8 Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/b9b273d8 Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/b9b273d8 Branch: refs/heads/master Commit: b9b273d87acd2643962307692c156783f7ce7543 Parents: 0529350 Author: Niketan Pansare <npan...@us.ibm.com> Authored: Fri Jan 19 16:14:51 2018 -0800 Committer: Niketan Pansare <npan...@us.ibm.com> Committed: Fri Jan 19 16:14:51 2018 -0800 ---------------------------------------------------------------------- .../java/org/apache/sysml/hops/BinaryOp.java | 2 +- .../instructions/GPUInstructionParser.java | 2 + .../gpu/BuiltinBinaryGPUInstruction.java | 12 +++- .../gpu/ScalarMatrixBuiltinGPUInstruction.java | 72 ++++++++++++++++++++ .../runtime/matrix/data/LibMatrixCUDA.java | 5 +- 5 files changed, 88 insertions(+), 5 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/systemml/blob/b9b273d8/src/main/java/org/apache/sysml/hops/BinaryOp.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/hops/BinaryOp.java b/src/main/java/org/apache/sysml/hops/BinaryOp.java index 1553b7a..d3e36a3 100644 --- a/src/main/java/org/apache/sysml/hops/BinaryOp.java +++ b/src/main/java/org/apache/sysml/hops/BinaryOp.java @@ -161,7 +161,7 @@ public class BinaryOp extends Hop OpOp2 [] supportedOps = { OpOp2.MULT, OpOp2.PLUS, OpOp2.MINUS, OpOp2.DIV, OpOp2.POW, OpOp2.MINUS1_MULT, OpOp2.MODULUS, OpOp2.INTDIV, OpOp2.LESS, OpOp2.LESSEQUAL, OpOp2.EQUAL, OpOp2.NOTEQUAL, OpOp2.GREATER, OpOp2.GREATEREQUAL}; - if(isMatrixScalar && op == OpOp2.MINUS_NZ) { + if(isMatrixScalar && (op == OpOp2.MINUS_NZ || op == OpOp2.MIN || op == OpOp2.MAX)) { // Only supported for matrix scalar: return true; } 
http://git-wip-us.apache.org/repos/asf/systemml/blob/b9b273d8/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java b/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java index 138b4f5..3c19b1a 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/GPUInstructionParser.java @@ -103,6 +103,8 @@ public class GPUInstructionParser extends InstructionParser // Binary Builtin functions String2GPUInstructionType.put( "solve", GPUINSTRUCTION_TYPE.BuiltinBinary); + String2GPUInstructionType.put( "min", GPUINSTRUCTION_TYPE.BuiltinBinary); + String2GPUInstructionType.put( "max", GPUINSTRUCTION_TYPE.BuiltinBinary); // Aggregate Unary String2GPUInstructionType.put( "ua+" , GPUINSTRUCTION_TYPE.AggregateUnary); // Sum http://git-wip-us.apache.org/repos/asf/systemml/blob/b9b273d8/src/main/java/org/apache/sysml/runtime/instructions/gpu/BuiltinBinaryGPUInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/BuiltinBinaryGPUInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/BuiltinBinaryGPUInstruction.java index 36016e7..b145c09 100644 --- a/src/main/java/org/apache/sysml/runtime/instructions/gpu/BuiltinBinaryGPUInstruction.java +++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/BuiltinBinaryGPUInstruction.java @@ -64,14 +64,20 @@ public abstract class BuiltinBinaryGPUInstruction extends GPUInstruction { // Determine appropriate Function Object based on opcode ValueFunction func = Builtin.getBuiltinFnObject(opcode); + + boolean isMatrixMatrix = in1.getDataType() == Expression.DataType.MATRIX && in2.getDataType() == Expression.DataType.MATRIX; + 
boolean isMatrixScalar = (in1.getDataType() == Expression.DataType.MATRIX && in2.getDataType() == Expression.DataType.SCALAR) || + (in1.getDataType() == Expression.DataType.SCALAR && in2.getDataType() == Expression.DataType.MATRIX); - // Only for "solve" if ( in1.getDataType() == Expression.DataType.SCALAR && in2.getDataType() == Expression.DataType.SCALAR ) throw new DMLRuntimeException("GPU : Unsupported GPU builtin operations on 2 scalars"); - else if ( in1.getDataType() == Expression.DataType.MATRIX && in2.getDataType() == Expression.DataType.MATRIX ) + else if ( isMatrixMatrix && opcode.equals("solve") ) return new MatrixMatrixBuiltinGPUInstruction(new BinaryOperator(func), in1, in2, out, opcode, str, 2); + else if ( isMatrixScalar && (opcode.equals("min") || opcode.equals("max")) ) + return new ScalarMatrixBuiltinGPUInstruction(new BinaryOperator(func), in1, in2, out, opcode, str, 2); + else - throw new DMLRuntimeException("GPU : Unsupported GPU builtin operations on a matrix and a scalar"); + throw new DMLRuntimeException("GPU : Unsupported GPU builtin operations on a matrix and a scalar:" + opcode); } http://git-wip-us.apache.org/repos/asf/systemml/blob/b9b273d8/src/main/java/org/apache/sysml/runtime/instructions/gpu/ScalarMatrixBuiltinGPUInstruction.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/instructions/gpu/ScalarMatrixBuiltinGPUInstruction.java b/src/main/java/org/apache/sysml/runtime/instructions/gpu/ScalarMatrixBuiltinGPUInstruction.java new file mode 100644 index 0000000..f8e024f --- /dev/null +++ b/src/main/java/org/apache/sysml/runtime/instructions/gpu/ScalarMatrixBuiltinGPUInstruction.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysml.runtime.instructions.gpu; + +import org.apache.sysml.parser.Expression.DataType; +import org.apache.sysml.runtime.DMLRuntimeException; +import org.apache.sysml.runtime.controlprogram.caching.MatrixObject; +import org.apache.sysml.runtime.controlprogram.context.ExecutionContext; +import org.apache.sysml.runtime.instructions.InstructionUtils; +import org.apache.sysml.runtime.instructions.cp.CPOperand; +import org.apache.sysml.runtime.instructions.cp.ScalarObject; +import org.apache.sysml.runtime.matrix.data.LibMatrixCUDA; +import org.apache.sysml.runtime.matrix.data.LibMatrixCuDNN; +import org.apache.sysml.runtime.matrix.operators.Operator; +import org.apache.sysml.utils.GPUStatistics; + +public class ScalarMatrixBuiltinGPUInstruction extends BuiltinBinaryGPUInstruction { + + protected ScalarMatrixBuiltinGPUInstruction(Operator op, CPOperand input1, CPOperand input2, CPOperand output, + String opcode, String istr, int _arity) { + super(op, input1, input2, output, opcode, istr, _arity); + _gputype = GPUINSTRUCTION_TYPE.BuiltinUnary; + } + + @Override + public void processInstruction(ExecutionContext ec) throws DMLRuntimeException { + GPUStatistics.incrementNoOfExecutedGPUInst(); + + String opcode = getOpcode(); + CPOperand mat = ( input1.getDataType() == DataType.MATRIX ) ? input1 : input2; + CPOperand scalar = ( input1.getDataType() == DataType.MATRIX ) ? 
input2 : input1; + MatrixObject in1 = getMatrixInputForGPUInstruction(ec, mat.getName()); + ScalarObject constant = (ScalarObject) ec.getScalarInput(scalar.getName(), scalar.getValueType(), scalar.isLiteral()); + + if(opcode.equals("max")) { + ec.setMetaData(output.getName(), in1.getNumRows(), in1.getNumColumns()); + double constVal = constant.getDoubleValue(); + if(constVal == 0) + LibMatrixCuDNN.relu(ec, ec.getGPUContext(0), getExtendedOpcode(), in1, output.getName()); + else + LibMatrixCUDA.matrixScalarOp(ec, ec.getGPUContext(0), getExtendedOpcode(), in1, output.getName(), false, + InstructionUtils.parseScalarBinaryOperator(opcode, false, constVal)); + } else if(opcode.equals("min")) { + ec.setMetaData(output.getName(), in1.getNumRows(), in1.getNumColumns()); + double constVal = constant.getDoubleValue(); + LibMatrixCUDA.matrixScalarOp(ec, ec.getGPUContext(0), getExtendedOpcode(), in1, output.getName(), false, + InstructionUtils.parseScalarBinaryOperator(opcode, false, constVal)); + } else { + throw new DMLRuntimeException("Unsupported GPU operator:" + opcode); + } + ec.releaseMatrixInputForGPUInstruction(mat.getName()); + ec.releaseMatrixOutputForGPUInstruction(output.getName()); + } + +} http://git-wip-us.apache.org/repos/asf/systemml/blob/b9b273d8/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java ---------------------------------------------------------------------- diff --git a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java index 4f2de29..63c57c9 100644 --- a/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java +++ b/src/main/java/org/apache/sysml/runtime/matrix/data/LibMatrixCUDA.java @@ -61,6 +61,7 @@ import org.apache.sysml.runtime.functionobjects.ReduceCol; import org.apache.sysml.runtime.functionobjects.ReduceDiag; import org.apache.sysml.runtime.functionobjects.ReduceRow; import 
org.apache.sysml.runtime.functionobjects.ValueFunction; +import org.apache.sysml.runtime.functionobjects.Builtin.BuiltinCode; import org.apache.sysml.runtime.instructions.cp.DoubleObject; import org.apache.sysml.runtime.instructions.gpu.GPUInstruction; import org.apache.sysml.runtime.instructions.gpu.context.CSRPointer; @@ -1310,7 +1311,7 @@ public class LibMatrixCUDA { * @param op operator * @throws DMLRuntimeException if DMLRuntimeException occurs */ - private static void matrixScalarOp(ExecutionContext ec, GPUContext gCtx, String instName, MatrixObject in, String outputName, boolean isInputTransposed, + public static void matrixScalarOp(ExecutionContext ec, GPUContext gCtx, String instName, MatrixObject in, String outputName, boolean isInputTransposed, ScalarOperator op) throws DMLRuntimeException { if (ec.getGPUContext(0) != gCtx) throw new DMLRuntimeException("GPU : Invalid internal state, the GPUContext set with the ExecutionContext is not the same used to run this LibMatrixCUDA function"); @@ -1604,6 +1605,8 @@ public class LibMatrixCUDA { else if(fn instanceof MinusNz) return 16; else if(fn instanceof Modulus) return 17; else if(fn instanceof IntegerDivide) return 18; + else if(fn instanceof Builtin && ((Builtin)fn).getBuiltinCode()==BuiltinCode.MIN) return 11; + else if(fn instanceof Builtin && ((Builtin)fn).getBuiltinCode()==BuiltinCode.MAX) return 12; throw new DMLRuntimeException("The given value function is not supported:" + fn.getClass().getName()); }