This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/master by this push:
     new 28742bc  [MINOR] Fix robustness and cleanup lmPredict, more gridSearch tests
28742bc is described below

commit 28742bc397887c5ab9c9a8f3193c311ebd4b9e39
Author: Matthias Boehm <[email protected]>
AuthorDate: Wed Jun 2 22:19:48 2021 +0200

    [MINOR] Fix robustness and cleanup lmPredict, more gridSearch tests
---
 scripts/builtin/lmPredict.dml                      |  6 +--
 .../runtime/compress/cocode/CoCodeCostTSMM.java    |  2 +-
 .../runtime/compress/lib/BitmapLossyEncoder.java   |  3 ++
 .../runtime/compress/lib/CLALibRightMultBy.java    |  2 +-
 .../functions/builtin/BuiltinGridSearchTest.java   | 35 ++++++++++++++---
 .../scripts/functions/builtin/GridSearchLM2.dml    | 44 ++++++++++++----------
 6 files changed, 62 insertions(+), 30 deletions(-)

diff --git a/scripts/builtin/lmPredict.dml b/scripts/builtin/lmPredict.dml
index 3a7ead7..f53e326 100644
--- a/scripts/builtin/lmPredict.dml
+++ b/scripts/builtin/lmPredict.dml
@@ -23,11 +23,11 @@ m_lmPredict = function(Matrix[Double] X, Matrix[Double] B,
   Matrix[Double] ytest, Integer icpt = 0, Boolean verbose = FALSE) 
   return (Matrix[Double] yhat)
 {
-  intercept = ifelse(icpt==0, matrix(0,1,ncol(B)), B[nrow(B),]);
-  yhat = X %*% B[1:ncol(X)] + matrix(1,nrow(X),1) %*% intercept;
+  intercept = ifelse(icpt>0 | ncol(X)+1==nrow(B), as.scalar(B[nrow(B),]), 0);
+  yhat = X %*% B[1:ncol(X),] + intercept;
 
   if( verbose ) {
-    y_residual =  ytest - yhat;
+    y_residual = ytest - yhat;
     avg_res = sum(y_residual) / nrow(ytest);
     ss_res = sum(y_residual^2);
     ss_avg_res = ss_res - nrow(ytest) * avg_res^2;
diff --git a/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCodeCostTSMM.java b/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCodeCostTSMM.java
index f31c53f..7635061 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCodeCostTSMM.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/cocode/CoCodeCostTSMM.java
@@ -149,7 +149,7 @@ public class CoCodeCostTSMM extends AColumnCoCoder {
                return cost;
        }
 
-       private double getCostOfSelfTSMM(CompressedSizeInfoColGroup g) {
+       private static double getCostOfSelfTSMM(CompressedSizeInfoColGroup g) {
                double cost = 0;
                final int nCol = g.getColumns().length;
                cost += g.getNumVals() * (nCol * (nCol + 1)) / 2;
diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/BitmapLossyEncoder.java b/src/main/java/org/apache/sysds/runtime/compress/lib/BitmapLossyEncoder.java
index 88711fe..55fde03 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/lib/BitmapLossyEncoder.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/lib/BitmapLossyEncoder.java
@@ -89,6 +89,7 @@ public class BitmapLossyEncoder {
         * @param numRows The number of Rows.
         * @return a lossy bitmap.
         */
+       @SuppressWarnings("unused")
        private static BitmapLossy make8BitLossy(Bitmap ubm, Stats stats, int numRows) {
                final double[] fp = ubm.getValues();
                int numCols = ubm.getNumColumns();
@@ -284,6 +285,7 @@ public class BitmapLossyEncoder {
                }
        }
 
+       @SuppressWarnings("unused")
        private static double[] getMemLocalDoubleArray(int length, boolean clean) {
                double[] ar = memPoolDoubleArray.get();
                if(ar != null && ar.length >= length) {
@@ -312,6 +314,7 @@ public class BitmapLossyEncoder {
                protected double maxDelta;
                protected boolean sameDelta;
 
+               @SuppressWarnings("unused")
                public Stats(double[] fp) {
                        max = Double.NEGATIVE_INFINITY;
                        min = Double.POSITIVE_INFINITY;
diff --git a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRightMultBy.java b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRightMultBy.java
index 61e275c..d486fa8 100644
--- a/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRightMultBy.java
+++ b/src/main/java/org/apache/sysds/runtime/compress/lib/CLALibRightMultBy.java
@@ -149,7 +149,7 @@ public class CLALibRightMultBy {
        }
 
        private static ColGroupEmpty findEmptyColumnsAndMakeEmptyColGroup(List<AColGroup> colGroups, int nCols, int nRows) {
-               Set<Integer> emptyColumns = new HashSet<Integer>(nCols);
+               Set<Integer> emptyColumns = new HashSet<>(nCols);
                for(int i = 0; i < nCols; i++)
                        emptyColumns.add(i);
 
diff --git a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinGridSearchTest.java b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinGridSearchTest.java
index 7d4449b..34504c9 100644
--- a/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinGridSearchTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/builtin/BuiltinGridSearchTest.java
@@ -23,16 +23,16 @@ import org.junit.Assert;
 import org.junit.Test;
 
 import org.apache.sysds.common.Types.ExecMode;
-import org.apache.sysds.common.Types.ExecType;
 import org.apache.sysds.test.AutomatedTestBase;
 import org.apache.sysds.test.TestConfiguration;
 import org.apache.sysds.test.TestUtils;
-
+import org.apache.sysds.utils.Statistics;
 
 public class BuiltinGridSearchTest extends AutomatedTestBase
 {
        private final static String TEST_NAME1 = "GridSearchLM";
        private final static String TEST_NAME2 = "GridSearchMLogreg";
+       private final static String TEST_NAME3 = "GridSearchLM2";
        private final static String TEST_DIR = "functions/builtin/";
        private final static String TEST_CLASS_DIR = TEST_DIR + BuiltinGridSearchTest.class.getSimpleName() + "/";
        
@@ -43,24 +43,45 @@ public class BuiltinGridSearchTest extends AutomatedTestBase
        public void setUp() {
                addTestConfiguration(TEST_NAME1,new TestConfiguration(TEST_CLASS_DIR, TEST_NAME1,new String[]{"R"}));
                addTestConfiguration(TEST_NAME2,new TestConfiguration(TEST_CLASS_DIR, TEST_NAME2,new String[]{"R"}));
+               addTestConfiguration(TEST_NAME3,new TestConfiguration(TEST_CLASS_DIR, TEST_NAME3,new String[]{"R"}));
        }
        
        @Test
        public void testGridSearchLmCP() {
-               runGridSearch(TEST_NAME1, ExecType.CP);
+               runGridSearch(TEST_NAME1, ExecMode.SINGLE_NODE);
+       }
+       
+       @Test
+       public void testGridSearchLmHybrid() {
+               runGridSearch(TEST_NAME1, ExecMode.HYBRID);
        }
        
        @Test
        public void testGridSearchLmSpark() {
-               runGridSearch(TEST_NAME1, ExecType.SPARK);
+               runGridSearch(TEST_NAME1, ExecMode.SPARK);
        }
        
        @Test
        public void testGridSearchMLogregCP() {
-               runGridSearch(TEST_NAME2, ExecType.CP);
+               runGridSearch(TEST_NAME2, ExecMode.SINGLE_NODE);
+       }
+       
+       @Test
+       public void testGridSearchMLogregHybrid() {
+               runGridSearch(TEST_NAME2, ExecMode.HYBRID);
+       }
+       
+       @Test
+       public void testGridSearchLm2CP() {
+               runGridSearch(TEST_NAME3, ExecMode.SINGLE_NODE);
+       }
+       
+       @Test
+       public void testGridSearchLm2Hybrid() {
+               runGridSearch(TEST_NAME3, ExecMode.HYBRID);
        }
        
-       private void runGridSearch(String testname, ExecType et)
+       private void runGridSearch(String testname, ExecMode et)
        {
                ExecMode modeOld = setExecMode(et);
                try {
@@ -78,6 +99,8 @@ public class BuiltinGridSearchTest extends AutomatedTestBase
                        
                        //expected loss smaller than default invocation
                        Assert.assertTrue(TestUtils.readDMLBoolean(output("R")));
+                       if( et != ExecMode.SPARK )
+                               Assert.assertEquals(0, Statistics.getNoOfExecutedSPInst());
                }
                finally {
                        resetExecMode(modeOld);
diff --git a/scripts/builtin/lmPredict.dml b/src/test/scripts/functions/builtin/GridSearchLM2.dml
similarity index 54%
copy from scripts/builtin/lmPredict.dml
copy to src/test/scripts/functions/builtin/GridSearchLM2.dml
index 3a7ead7..278d94c 100644
--- a/scripts/builtin/lmPredict.dml
+++ b/src/test/scripts/functions/builtin/GridSearchLM2.dml
@@ -19,24 +19,30 @@
 #
 #-------------------------------------------------------------
 
-m_lmPredict = function(Matrix[Double] X, Matrix[Double] B, 
-  Matrix[Double] ytest, Integer icpt = 0, Boolean verbose = FALSE) 
-  return (Matrix[Double] yhat)
+l2norm = function(Matrix[Double] X, Matrix[Double] y, Matrix[Double] B) 
+  return (Matrix[Double] loss)
 {
-  intercept = ifelse(icpt==0, matrix(0,1,ncol(B)), B[nrow(B),]);
-  yhat = X %*% B[1:ncol(X)] + matrix(1,nrow(X),1) %*% intercept;
-
-  if( verbose ) {
-    y_residual =  ytest - yhat;
-    avg_res = sum(y_residual) / nrow(ytest);
-    ss_res = sum(y_residual^2);
-    ss_avg_res = ss_res - nrow(ytest) * avg_res^2;
-    R2 = 1 - ss_res / (sum(ytest^2) - nrow(ytest) * (sum(ytest)/nrow(ytest))^2);
-    print("\nAccuracy:" +
-          "\n--sum(ytest) = " + sum(ytest) +
-          "\n--sum(yhat) = " + sum(yhat) +
-          "\n--AVG_RES_Y: " + avg_res +
-          "\n--SS_AVG_RES_Y: " + ss_avg_res +
-          "\n--R2: " + R2 );
-  }
+  yhat = lmPredict(X=X, B=B, ytest=y)
+  loss = as.matrix(sum((y - yhat)^2));
 }
+
+X = read($1);
+y = read($2);
+
+N = 200;
+Xtrain = X[1:N,];
+ytrain = y[1:N,];
+Xtest = X[(N+1):nrow(X),];
+ytest = y[(N+1):nrow(X),];
+
+params = list("icpt","reg", "tol", "maxi");
+paramRanges = list(seq(0,1,2),10^seq(0,-4), 10^seq(-6,-12), 10^seq(1,3));
+[B1, opt] = gridSearch(X=Xtrain, y=ytrain, train="lm", predict="l2norm",
+  numB=ncol(X)+1, params=params, paramValues=paramRanges);
+B2 = lm(X=Xtrain, y=ytrain, verbose=FALSE);
+
+l1 = l2norm(Xtest, ytest, B1);
+l2 = l2norm(Xtest, ytest, B2);
+R = as.scalar(l1 < l2);
+
+write(R, $3)

Reply via email to