This is an automated email from the ASF dual-hosted git repository.

arnabp20 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 444df94  [SYSTEMDS-2893] Lineage tracing GPU instructions
444df94 is described below

commit 444df94ed8e990322e2a3a7a6fc7e3a0545b46da
Author: arnabp <[email protected]>
AuthorDate: Thu Mar 11 19:12:33 2021 +0100

    [SYSTEMDS-2893] Lineage tracing GPU instructions
    
    This patch extends lineage tracing to GPU instructions.
    This is an initial version, and only a few operations
    are supported at the moment. Furthermore, this patch adds
    a test that tunes hyper-parameters for LM on the GPU, recomputes
    the result from lineage on the CPU, and verifies that the results match.
---
 .../gpu/AggregateBinaryGPUInstruction.java         | 12 ++-
 .../gpu/ArithmeticBinaryGPUInstruction.java        | 13 +++-
 .../gpu/BuiltinBinaryGPUInstruction.java           | 71 ++++++++++-------
 .../gpu/MatrixMatrixBuiltinGPUInstruction.java     | 39 +++++-----
 .../instructions/gpu/ReorgGPUInstruction.java      | 12 ++-
 .../runtime/lineage/LineageRecomputeUtils.java     | 21 ++++-
 src/test/java/org/apache/sysds/test/TestUtils.java |  8 ++
 .../test/functions/builtin/BuiltinCsplineTest.java |  6 +-
 .../functions/lineage/LineageTraceGPUTest.java     | 91 ++++++++++++++++++++++
 .../scripts/functions/lineage/LineageTraceGPU1.dml | 49 ++++++++++++
 10 files changed, 264 insertions(+), 58 deletions(-)

diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/gpu/AggregateBinaryGPUInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/gpu/AggregateBinaryGPUInstruction.java
index 93cda1a..4d8f70c 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/gpu/AggregateBinaryGPUInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/gpu/AggregateBinaryGPUInstruction.java
@@ -18,6 +18,7 @@
  */
 package org.apache.sysds.runtime.instructions.gpu;
 
+import org.apache.commons.lang3.tuple.Pair;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
@@ -26,6 +27,9 @@ import org.apache.sysds.runtime.functionobjects.Plus;
 import org.apache.sysds.runtime.functionobjects.SwapIndex;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
 import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.lineage.LineageItem;
+import org.apache.sysds.runtime.lineage.LineageItemUtils;
+import org.apache.sysds.runtime.lineage.LineageTraceable;
 import org.apache.sysds.runtime.matrix.data.LibMatrixCUDA;
 import org.apache.sysds.runtime.matrix.data.LibMatrixCuMatMult;
 import org.apache.sysds.runtime.matrix.data.MatrixBlock;
@@ -34,7 +38,7 @@ import org.apache.sysds.runtime.matrix.operators.Operator;
 import org.apache.sysds.runtime.matrix.operators.ReorgOperator;
 import org.apache.sysds.utils.GPUStatistics;
 
-public class AggregateBinaryGPUInstruction extends GPUInstruction {
+public class AggregateBinaryGPUInstruction extends GPUInstruction implements 
LineageTraceable {
        private CPOperand _input1 = null;
        private CPOperand _input2 = null;
        private CPOperand _output = null;
@@ -97,4 +101,10 @@ public class AggregateBinaryGPUInstruction extends 
GPUInstruction {
                MatrixObject mo = ec.getMatrixObject(var);
                return LibMatrixCUDA.isInSparseFormat(ec.getGPUContext(0), mo);
        }
+
+       @Override
+       public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
+               return Pair.of(_output.getName(), new LineageItem(getOpcode(),
+                       LineageItemUtils.getLineage(ec, _input1, _input2)));
+       }
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/gpu/ArithmeticBinaryGPUInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/gpu/ArithmeticBinaryGPUInstruction.java
index 43d4e12..f451910 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/gpu/ArithmeticBinaryGPUInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/gpu/ArithmeticBinaryGPUInstruction.java
@@ -19,13 +19,18 @@
 
 package org.apache.sysds.runtime.instructions.gpu;
 
+import org.apache.commons.lang3.tuple.Pair;
 import org.apache.sysds.common.Types.DataType;
 import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
 import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.lineage.LineageItem;
+import org.apache.sysds.runtime.lineage.LineageItemUtils;
+import org.apache.sysds.runtime.lineage.LineageTraceable;
 import org.apache.sysds.runtime.matrix.operators.Operator;
 
-public abstract class ArithmeticBinaryGPUInstruction extends GPUInstruction {
+public abstract class ArithmeticBinaryGPUInstruction extends GPUInstruction 
implements LineageTraceable {
        protected CPOperand _input1;
        protected CPOperand _input2;
        protected CPOperand _output;
@@ -65,4 +70,10 @@ public abstract class ArithmeticBinaryGPUInstruction extends 
GPUInstruction {
                else
                        throw new DMLRuntimeException("Unsupported GPU 
ArithmeticInstruction.");
        }
+
+       @Override
+       public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
+               return Pair.of(_output.getName(), new LineageItem(getOpcode(),
+                       LineageItemUtils.getLineage(ec, _input1, _input2)));
+       }
 }
\ No newline at end of file
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/gpu/BuiltinBinaryGPUInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/gpu/BuiltinBinaryGPUInstruction.java
index 937c692..82d3222 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/gpu/BuiltinBinaryGPUInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/gpu/BuiltinBinaryGPUInstruction.java
@@ -19,17 +19,22 @@
 
 package org.apache.sysds.runtime.instructions.gpu;
 
+import org.apache.commons.lang3.tuple.Pair;
 import org.apache.sysds.common.Types.DataType;
 import org.apache.sysds.common.Types.ValueType;
 import org.apache.sysds.runtime.DMLRuntimeException;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.functionobjects.Builtin;
 import org.apache.sysds.runtime.functionobjects.ValueFunction;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
 import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.lineage.LineageItem;
+import org.apache.sysds.runtime.lineage.LineageItemUtils;
+import org.apache.sysds.runtime.lineage.LineageTraceable;
 import org.apache.sysds.runtime.matrix.operators.BinaryOperator;
 import org.apache.sysds.runtime.matrix.operators.Operator;
 
-public abstract class BuiltinBinaryGPUInstruction extends GPUInstruction {
+public abstract class BuiltinBinaryGPUInstruction extends GPUInstruction 
implements LineageTraceable {
        @SuppressWarnings("unused")
        private int _arity;
 
@@ -45,42 +50,48 @@ public abstract class BuiltinBinaryGPUInstruction extends 
GPUInstruction {
                this.input2 = input2;
        }
 
-  public static BuiltinBinaryGPUInstruction parseInstruction(String str) {
-    CPOperand in1 = new CPOperand("", ValueType.UNKNOWN, DataType.UNKNOWN);
-    CPOperand in2 = new CPOperand("", ValueType.UNKNOWN, DataType.UNKNOWN);
-    CPOperand out = new CPOperand("", ValueType.UNKNOWN, DataType.UNKNOWN);
+       public static BuiltinBinaryGPUInstruction parseInstruction(String str) {
+               CPOperand in1 = new CPOperand("", ValueType.UNKNOWN, 
DataType.UNKNOWN);
+               CPOperand in2 = new CPOperand("", ValueType.UNKNOWN, 
DataType.UNKNOWN);
+               CPOperand out = new CPOperand("", ValueType.UNKNOWN, 
DataType.UNKNOWN);
 
-    String[] parts = InstructionUtils.getInstructionPartsWithValueType(str);
-    InstructionUtils.checkNumFields ( parts, 3 );
+               String[] parts = 
InstructionUtils.getInstructionPartsWithValueType(str);
+               InstructionUtils.checkNumFields(parts, 3);
 
-    String opcode = parts[0];
-    in1.split(parts[1]);
-    in2.split(parts[2]);
-    out.split(parts[3]);
+               String opcode = parts[0];
+               in1.split(parts[1]);
+               in2.split(parts[2]);
+               out.split(parts[3]);
 
-    // check for valid data type of output
-    if((in1.getDataType() == DataType.MATRIX || in2.getDataType() == 
DataType.MATRIX) && out.getDataType() != DataType.MATRIX)
-      throw new DMLRuntimeException("Element-wise matrix operations between 
variables " + in1.getName() +
-              " and " + in2.getName() + " must produce a matrix, which " + 
out.getName() + " is not");
+               // check for valid data type of output
+               if ((in1.getDataType() == DataType.MATRIX || in2.getDataType() 
== DataType.MATRIX) &&
+                               out.getDataType() != DataType.MATRIX)
+                       throw new DMLRuntimeException("Element-wise matrix 
operations between variables " + in1.getName() + " and "
+                               + in2.getName() + " must produce a matrix, 
which " + out.getName() + " is not");
 
-    // Determine appropriate Function Object based on opcode
-    ValueFunction func = Builtin.getBuiltinFnObject(opcode);
-    
-    boolean isMatrixMatrix = in1.getDataType() == DataType.MATRIX && 
in2.getDataType() == DataType.MATRIX;
-    boolean isMatrixScalar = (in1.getDataType() == DataType.MATRIX && 
in2.getDataType() == DataType.SCALAR) || 
-                                                       (in1.getDataType() == 
DataType.SCALAR && in2.getDataType() == DataType.MATRIX);
+               // Determine appropriate Function Object based on opcode
+               ValueFunction func = Builtin.getBuiltinFnObject(opcode);
 
-    if ( in1.getDataType() == DataType.SCALAR && in2.getDataType() == 
DataType.SCALAR )
-      throw new DMLRuntimeException("GPU : Unsupported GPU builtin operations 
on 2 scalars");
-    else if ( isMatrixMatrix && opcode.equals("solve") )
-      return new MatrixMatrixBuiltinGPUInstruction(new BinaryOperator(func), 
in1, in2, out, opcode, str, 2);
-    else if ( isMatrixScalar && (opcode.equals("min") || opcode.equals("max")) 
)
-        return new ScalarMatrixBuiltinGPUInstruction(new BinaryOperator(func), 
in1, in2, out, opcode, str, 2);
+               boolean isMatrixMatrix = in1.getDataType() == DataType.MATRIX 
&& in2.getDataType() == DataType.MATRIX;
+               boolean isMatrixScalar = (in1.getDataType() == DataType.MATRIX 
&& in2.getDataType() == DataType.SCALAR) ||
+                               (in1.getDataType() == DataType.SCALAR && 
in2.getDataType() == DataType.MATRIX);
 
-    else
-      throw new DMLRuntimeException("GPU : Unsupported GPU builtin operations 
on a matrix and a scalar:" + opcode);
+               if (in1.getDataType() == DataType.SCALAR && in2.getDataType() 
== DataType.SCALAR)
+                       throw new DMLRuntimeException("GPU : Unsupported GPU 
builtin operations on 2 scalars");
+               else if (isMatrixMatrix && opcode.equals("solve"))
+                       return new MatrixMatrixBuiltinGPUInstruction(new 
BinaryOperator(func), in1, in2, out, opcode, str, 2);
+               else if (isMatrixScalar && (opcode.equals("min") || 
opcode.equals("max")))
+                       return new ScalarMatrixBuiltinGPUInstruction(new 
BinaryOperator(func), in1, in2, out, opcode, str, 2);
 
+               else
+                       throw new DMLRuntimeException(
+                               "GPU : Unsupported GPU builtin operations on a 
matrix and a scalar:" + opcode);
+       }
 
-  }
+       @Override
+       public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
+               return Pair.of(output.getName(), new LineageItem(getOpcode(),
+                       LineageItemUtils.getLineage(ec, input1, input2)));
+       }
 
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/gpu/MatrixMatrixBuiltinGPUInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/gpu/MatrixMatrixBuiltinGPUInstruction.java
index b31eb72..d1c6a9b 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/gpu/MatrixMatrixBuiltinGPUInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/gpu/MatrixMatrixBuiltinGPUInstruction.java
@@ -35,24 +35,25 @@ public class MatrixMatrixBuiltinGPUInstruction extends 
BuiltinBinaryGPUInstructi
                _gputype = GPUINSTRUCTION_TYPE.BuiltinUnary;
        }
 
-  @Override
-  public void processInstruction(ExecutionContext ec) {
-    GPUStatistics.incrementNoOfExecutedGPUInst();
-
-    String opcode = getOpcode();
-    MatrixObject mat1 = getMatrixInputForGPUInstruction(ec, input1.getName());
-    MatrixObject mat2 = getMatrixInputForGPUInstruction(ec, input2.getName());
-
-    if(opcode.equals("solve")) {
-      ec.setMetaData(output.getName(), mat1.getNumColumns(), 1);
-      LibMatrixCUDA.solve(ec, ec.getGPUContext(0), getExtendedOpcode(), mat1, 
mat2, output.getName());
-
-    } else {
-      throw new DMLRuntimeException("Unsupported GPU operator:" + opcode);
-    }
-    ec.releaseMatrixInputForGPUInstruction(input1.getName());
-    ec.releaseMatrixInputForGPUInstruction(input2.getName());
-    ec.releaseMatrixOutputForGPUInstruction(output.getName());
-  }
+       @Override
+       public void processInstruction(ExecutionContext ec) {
+               GPUStatistics.incrementNoOfExecutedGPUInst();
+
+               String opcode = getOpcode();
+               MatrixObject mat1 = getMatrixInputForGPUInstruction(ec, 
input1.getName());
+               MatrixObject mat2 = getMatrixInputForGPUInstruction(ec, 
input2.getName());
+
+               if (opcode.equals("solve")) {
+                       ec.setMetaData(output.getName(), mat1.getNumColumns(), 
1);
+                       LibMatrixCUDA.solve(ec, ec.getGPUContext(0), 
getExtendedOpcode(), mat1, mat2, output.getName());
+
+               }
+               else {
+                       throw new DMLRuntimeException("Unsupported GPU 
operator:" + opcode);
+               }
+               ec.releaseMatrixInputForGPUInstruction(input1.getName());
+               ec.releaseMatrixInputForGPUInstruction(input2.getName());
+               ec.releaseMatrixOutputForGPUInstruction(output.getName());
+       }
 
 }
diff --git 
a/src/main/java/org/apache/sysds/runtime/instructions/gpu/ReorgGPUInstruction.java
 
b/src/main/java/org/apache/sysds/runtime/instructions/gpu/ReorgGPUInstruction.java
index fbc579c..a3e36d0 100644
--- 
a/src/main/java/org/apache/sysds/runtime/instructions/gpu/ReorgGPUInstruction.java
+++ 
b/src/main/java/org/apache/sysds/runtime/instructions/gpu/ReorgGPUInstruction.java
@@ -19,18 +19,22 @@
 
 package org.apache.sysds.runtime.instructions.gpu;
 
+import org.apache.commons.lang3.tuple.Pair;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.functionobjects.SwapIndex;
 import org.apache.sysds.runtime.instructions.InstructionUtils;
 import org.apache.sysds.runtime.instructions.cp.CPOperand;
+import org.apache.sysds.runtime.lineage.LineageItem;
+import org.apache.sysds.runtime.lineage.LineageItemUtils;
+import org.apache.sysds.runtime.lineage.LineageTraceable;
 import org.apache.sysds.runtime.matrix.data.LibMatrixCUDA;
 import org.apache.sysds.runtime.matrix.operators.Operator;
 import org.apache.sysds.runtime.matrix.operators.ReorgOperator;
 import org.apache.sysds.utils.GPUStatistics;
 
-public class ReorgGPUInstruction extends GPUInstruction {
+public class ReorgGPUInstruction extends GPUInstruction implements 
LineageTraceable {
        private CPOperand _input;
        private CPOperand _output;
 
@@ -80,4 +84,10 @@ public class ReorgGPUInstruction extends GPUInstruction {
                ec.releaseMatrixInputForGPUInstruction(_input.getName());
                ec.releaseMatrixOutputForGPUInstruction(_output.getName());
        }
+
+       @Override
+       public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
+               return Pair.of(_output.getName(), new LineageItem(getOpcode(),
+                       LineageItemUtils.getLineage(ec, _input)));
+       }
 }
\ No newline at end of file
diff --git 
a/src/main/java/org/apache/sysds/runtime/lineage/LineageRecomputeUtils.java 
b/src/main/java/org/apache/sysds/runtime/lineage/LineageRecomputeUtils.java
index 66e1932..3bf36a1 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageRecomputeUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageRecomputeUtils.java
@@ -89,7 +89,21 @@ public class LineageRecomputeUtils {
                LineageItem root = LineageParser.parseLineageTrace(mainTrace);
                if (dedupPatches != null)
                        LineageParser.parseLineageTraceDedup(dedupPatches);
+
+               // Disable GPU execution. TODO: Support GPU
+               boolean GPUenabled = false;
+               if (DMLScript.USE_ACCELERATOR) {
+                       GPUenabled = true;
+                       DMLScript.USE_ACCELERATOR = false;
+               }
+               // Reset statistics
+               if (DMLScript.STATISTICS)
+                       Statistics.reset();
+
                Data ret = computeByLineage(root);
+
+               if (GPUenabled)
+                       DMLScript.USE_ACCELERATOR = true;
                // Cleanup the statics
                loopPatchMap.clear();
                return ret;
@@ -115,9 +129,9 @@ public class LineageRecomputeUtils {
                partDagRoots.put(varname, out);
                constructBasicBlock(partDagRoots, varname, prog);
                
-               // Reset cache due to cleaned data objects
+               // Reset cache to avoid erroneous reuse
                LineageCache.resetCache();
-               //execute instructions and get result
+               // Execute instructions and get result
                if (DEBUG) {
                        DMLScript.STATISTICS = true;
                        ExplainCounts counts = 
Explain.countDistributedOperations(prog);
@@ -125,10 +139,11 @@ public class LineageRecomputeUtils {
                }
                ec.setProgram(prog);
                prog.execute(ec);
-               if (DEBUG) {
+               if (DEBUG || DMLScript.STATISTICS) {
                        Statistics.stopRunTimer();
                        
System.out.println(Statistics.display(DMLScript.STATISTICS_COUNT));
                }
+
                return ec.getVariable(varname);
        }
        
diff --git a/src/test/java/org/apache/sysds/test/TestUtils.java 
b/src/test/java/org/apache/sysds/test/TestUtils.java
index b8c6304..6244833 100644
--- a/src/test/java/org/apache/sysds/test/TestUtils.java
+++ b/src/test/java/org/apache/sysds/test/TestUtils.java
@@ -78,6 +78,8 @@ import org.apache.sysds.runtime.util.DataConverter;
 import org.apache.sysds.runtime.util.UtilFunctions;
 import org.junit.Assert;
 
+import jcuda.runtime.JCuda;
+
 
 /**
  * <p>
@@ -3063,4 +3065,10 @@ public class TestUtils
                                return true;
                return false;
        }
+       
+       public static int isGPUAvailable() {
+               // returns cudaSuccess if at least one gpu is available
+               final int[] deviceCount = new int[1];
+               return JCuda.cudaGetDeviceCount(deviceCount);
+       }
 }
diff --git 
a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinCsplineTest.java 
b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinCsplineTest.java
index 7846fa8..41ec221 100644
--- 
a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinCsplineTest.java
+++ 
b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinCsplineTest.java
@@ -17,7 +17,7 @@
  * under the License.
  */
  
-package org.apache.sysds.test.applications;
+package org.apache.sysds.test.functions.builtin;
 
 import java.util.ArrayList;
 import java.util.Arrays;
@@ -43,8 +43,8 @@ public class BuiltinCsplineTest extends AutomatedTestBase {
        
        protected int numRecords;
        private final static int numDim = 1;
-    
-    public BuiltinCsplineTest(int rows, int cols) {
+
+       public BuiltinCsplineTest(int rows, int cols) {
                numRecords = rows;
        }
 
diff --git 
a/src/test/java/org/apache/sysds/test/functions/lineage/LineageTraceGPUTest.java
 
b/src/test/java/org/apache/sysds/test/functions/lineage/LineageTraceGPUTest.java
new file mode 100644
index 0000000..405ead7
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/functions/lineage/LineageTraceGPUTest.java
@@ -0,0 +1,91 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.lineage;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+
+import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
+import org.apache.sysds.runtime.instructions.cp.Data;
+import org.apache.sysds.runtime.lineage.Lineage;
+import org.apache.sysds.runtime.lineage.LineageRecomputeUtils;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.data.MatrixValue.CellIndex;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+
+import jcuda.runtime.cudaError;
+
+public class LineageTraceGPUTest extends AutomatedTestBase{
+       
+       protected static final String TEST_DIR = "functions/lineage/";
+       protected static final String TEST_NAME1 = "LineageTraceGPU1"; 
+       protected String TEST_CLASS_DIR = TEST_DIR + 
LineageTraceGPUTest.class.getSimpleName() + "/";
+       
+       protected static final int numRecords = 10;
+       protected static final int numFeatures = 5;
+       
+       
+       @Override
+       public void setUp() {
+               TestUtils.clearAssertionInformation();
+               addTestConfiguration( TEST_NAME1, new 
TestConfiguration(TEST_CLASS_DIR, TEST_NAME1, new String[] {"R"}) );
+       }
+       
+       @Test
+       public void simpleHLM_gpu() {              //hyper-parameter tuning 
over LM (simple)
+               testLineageTraceExec(TEST_NAME1);
+       }
+       
+       private void testLineageTraceExec(String testname) {
+               System.out.println("------------ BEGIN " + testname + 
"------------");
+               
+               int gpuStatus = TestUtils.isGPUAvailable(); 
+               getAndLoadTestConfiguration(testname);
+               List<String> proArgs = new ArrayList<>();
+               
+               proArgs.add("-stats");
+               if (gpuStatus == cudaError.cudaSuccess)
+                       proArgs.add("-gpu");
+               proArgs.add("-lineage");
+               proArgs.add("-args");
+               proArgs.add(output("R"));
+               proArgs.add(String.valueOf(numRecords));
+               proArgs.add(String.valueOf(numFeatures));
+               programArgs = proArgs.toArray(new String[proArgs.size()]);
+               fullDMLScriptName = getScript();
+               
+               Lineage.resetInternalState();
+               //run the test
+               runTest(true, EXCEPTION_NOT_EXPECTED, null, -1);
+               
+               //get lineage and generate program
+               String Rtrace = readDMLLineageFromHDFS("R");
+               //NOTE: the generated program is CP-only.
+               Data ret = 
LineageRecomputeUtils.parseNComputeLineageTrace(Rtrace, null);
+               
+               HashMap<CellIndex, Double> dmlfile = 
readDMLMatrixFromOutputDir("R");
+               MatrixBlock tmp = ((MatrixObject)ret).acquireReadAndRelease();
+               TestUtils.compareMatrices(dmlfile, tmp, 1e-6);
+       }
+}
diff --git a/src/test/scripts/functions/lineage/LineageTraceGPU1.dml 
b/src/test/scripts/functions/lineage/LineageTraceGPU1.dml
new file mode 100644
index 0000000..dd223ac
--- /dev/null
+++ b/src/test/scripts/functions/lineage/LineageTraceGPU1.dml
@@ -0,0 +1,49 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+SimlinRegDS = function(Matrix[Double] X, Matrix[Double] y, Double lamda, 
Integer N)
+ return (Matrix[double] beta)
+{
+  A = (t(X) %*% X) + diag(matrix(lamda, rows=N, cols=1));
+  b = t(X) %*% y;
+  beta = solve(A, b);
+}
+
+no_lamda = 10;
+
+stp = (0.1 - 0.0001)/no_lamda;
+lamda = 0.0001;
+lim = 0.1;
+
+X = rand(rows=1000, cols=100, seed=42);
+y = rand(rows=1000, cols=1, seed=42);
+N = ncol(X);
+R = matrix(0, rows=N, cols=no_lamda+2);
+i = 1;
+
+while (lamda < lim)
+{
+  beta = SimlinRegDS(X, y, lamda, N);
+  R[,i] = beta;
+  lamda = lamda + stp;
+  i = i + 1;
+}
+write(R, $1);
+

Reply via email to