This is an automated email from the ASF dual-hosted git repository.

janardhan pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new f66da42  [SYSTEMDS-2597] Verify PageRank script works with MLContext
f66da42 is described below

commit f66da4272afa28c3dafc25575c46510508d8de9c
Author: Janardhan Pulivarthi <[email protected]>
AuthorDate: Thu Aug 6 12:15:17 2020 +0530

    [SYSTEMDS-2597] Verify PageRank script works with MLContext
    
      * Tests the PageRank script with MLContext against an R script
      * Makes `AutomatedTestBase.runRScript` honor fullRScriptName, consistent with the rest of `AutomatedTestBase`
    
    Closes #1005.
---
 .../org/apache/sysds/test/AutomatedTestBase.java   |  4 +-
 .../functions/mlcontext/MLContextPageRankTest.java | 90 ++++++++++++++++++++++
 .../functions/mlcontext/MLContextTestBase.java     |  2 +-
 3 files changed, 94 insertions(+), 2 deletions(-)

diff --git a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
index 56845c5..8c1973b 100644
--- a/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
+++ b/src/test/java/org/apache/sysds/test/AutomatedTestBase.java
@@ -905,7 +905,9 @@ public abstract class AutomatedTestBase {
         */
        protected void runRScript(boolean newWay) {
 
-               String executionFile = sourceDirectory + selectedTest + ".R";
+               String executionFile = sourceDirectory + selectedTest + ".R";
+               if(fullRScriptName != null)
+                       executionFile = fullRScriptName;
 
                // *** HACK ALERT *** HACK ALERT *** HACK ALERT ***
                // Some of the R scripts will fail if the "expected" directory doesn't exist.
diff --git a/src/test/java/org/apache/sysds/test/functions/mlcontext/MLContextPageRankTest.java b/src/test/java/org/apache/sysds/test/functions/mlcontext/MLContextPageRankTest.java
new file mode 100644
index 0000000..880b286
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/functions/mlcontext/MLContextPageRankTest.java
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.functions.mlcontext;
+
+import org.apache.log4j.Logger;
+import org.apache.sysds.api.mlcontext.Script;
+import org.apache.sysds.runtime.matrix.data.MatrixBlock;
+import org.apache.sysds.runtime.matrix.data.MatrixValue;
+import org.apache.sysds.test.TestUtils;
+import org.junit.Test;
+
+import java.util.HashMap;
+
+import static org.apache.sysds.api.mlcontext.ScriptFactory.dmlFromFile;
+
+public class MLContextPageRankTest extends MLContextTestBase {
+       protected static Logger log = Logger.getLogger(MLContextPageRankTest.class);
+
+       protected final static String TEST_SCRIPT_PAGERANK = "scripts/staging/PageRank.dml";
+       private final static double sparsity1 = 0.41; // dense
+       private final static double sparsity2 = 0.05; // sparse
+
+       private final static double eps = 0.1;
+
+       private final static int rows = 1468;
+       private final static int cols = 1468;
+       private final static double alpha = 0.85;
+       private final static double maxiter = 10;
+
+       @Test
+       public void testPageRankSparse() {
+               runPageRankTestMLC(true);
+       }
+
+       @Test
+       public void testPageRankDense() {
+               runPageRankTestMLC(false);
+       }
+
+       /** Runs PageRank via the reference R script and via MLContext on identical (sparse or dense) inputs and compares the resulting p. */
+       private void runPageRankTestMLC(boolean sparse) {
+
+               //generate actual datasets
+               double[][] G = getRandomMatrix(rows, cols, 1, 1, sparse?sparsity2:sparsity1, 234);
+               double[][] p = getRandomMatrix(cols, 1, 0, 1e-14, 1, 71);
+               double[][] e = getRandomMatrix(rows, 1, 0, 1e-14, 1, 72);
+               double[][] u = getRandomMatrix(1, cols, 0, 1e-14, 1, 73);
+               writeInputMatrixWithMTD("G", G, true);
+               writeInputMatrixWithMTD("p", p, true);
+               writeInputMatrixWithMTD("e", e, true);
+               writeInputMatrixWithMTD("u", u, true);
+
+               //compute the expected result with the reference R script (reads the inputs written above)
+               fullRScriptName = "src/test/scripts/functions/codegenalg/Algorithm_PageRank.R";
+
+               rCmd = getRCmd(inputDir(), String.valueOf(alpha),
+                               String.valueOf(maxiter), expectedDir());
+               runRScript(true);
+
+
+               //run the same algorithm in DML via MLContext, binding the in-memory inputs directly
+               Script pr = dmlFromFile(TEST_SCRIPT_PAGERANK);
+               pr.in("G", G).in("p", p).in("e", e).in("u", u)
+                               .in("$5", alpha).in("$6", maxiter)
+                               .out("p");
+               MatrixBlock outmat = ml.execute(pr).getMatrix("p").toMatrixBlock();
+
+
+               //compare matrices
+               HashMap<MatrixValue.CellIndex, Double> rfile = readRMatrixFromFS("p");
+               TestUtils.compareMatrices(rfile, outmat, eps);
+       }
+}
diff --git a/src/test/java/org/apache/sysds/test/functions/mlcontext/MLContextTestBase.java b/src/test/java/org/apache/sysds/test/functions/mlcontext/MLContextTestBase.java
index ce1abf3..1bf764c 100644
--- a/src/test/java/org/apache/sysds/test/functions/mlcontext/MLContextTestBase.java
+++ b/src/test/java/org/apache/sysds/test/functions/mlcontext/MLContextTestBase.java
@@ -67,7 +67,7 @@ public abstract class MLContextTestBase extends AutomatedTestBase {
        @Override
        public void setUp() {
                Class<? extends MLContextTestBase> clazz = this.getClass();
-               String dir = (testDir == null) ? "functions/mlcontext" : testDir;
+               String dir = (testDir == null) ? "functions/mlcontext/" : testDir;
                String name = (testName == null) ? clazz.getSimpleName() : testName;
 
                addTestConfiguration(dir, name);

Reply via email to