This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new a65ba4a296 [SYSTEMDS-1780] Additional tests for resource optimizer 
cost estimation
a65ba4a296 is described below

commit a65ba4a296b91b42f4abbfe0d05e2805b5cde7d3
Author: Matthias Boehm <mboe...@gmail.com>
AuthorDate: Wed Aug 28 13:34:17 2024 +0200

    [SYSTEMDS-1780] Additional tests for resource optimizer cost estimation
---
 .../apache/sysds/resource/cost/CostEstimator.java  |  44 +++++++--
 .../apache/sysds/resource/cost/IOCostUtils.java    |   2 +-
 .../test/component/resource/CostEstimatorTest.java | 109 +++++++++++++++++++++
 .../component/resource/Algorithm_KMeans.dml        |  25 +++++
 .../scripts/component/resource/Algorithm_L2SVM.dml |  26 +++++
 .../component/resource/Algorithm_Linreg.dml        |  26 +++++
 .../component/resource/Algorithm_MLogreg.dml       |  26 +++++
 .../scripts/component/resource/Algorithm_PCA.dml   |  25 +++++
 8 files changed, 274 insertions(+), 9 deletions(-)

diff --git a/src/main/java/org/apache/sysds/resource/cost/CostEstimator.java 
b/src/main/java/org/apache/sysds/resource/cost/CostEstimator.java
index a5c5333c44..0c056e2771 100644
--- a/src/main/java/org/apache/sysds/resource/cost/CostEstimator.java
+++ b/src/main/java/org/apache/sysds/resource/cost/CostEstimator.java
@@ -188,7 +188,7 @@ public class CostEstimator
                if (inst instanceof CPInstruction) {
                        maintainCPInstVariableStatistics((CPInstruction)inst);
 
-                       ret = getTimeEstimateCPInst((CPInstruction)inst);
+                       ret = getTimeEstimateCPInst(pb, (CPInstruction)inst);
 
                        if( inst instanceof FunctionCallCPInstruction ) 
//functions
                        {
@@ -258,6 +258,7 @@ public class CostEstimator
                                DataGenCPInstruction dinst = 
(DataGenCPInstruction) inst;
                                VarStats stat = 
_stats.get(dinst.getOutput().getName());
                                stat._mc.setNonZeros((long) 
(stat.getCells()*dinst.getSparsity()));
+                               putInMemory(stat);
                        }
                }
                else if( inst instanceof FunctionCallCPInstruction )
@@ -275,11 +276,12 @@ public class CostEstimator
         * <li>T_r - instruction read (to mem.) time</li>
         * <li>T_c - instruction compute time</li>
         *
-        * @param inst
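+        * @param pb program block supplied by the caller; its program is used to look up called functions
+        * @param inst CP instruction whose execution time is estimated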
         * @return
         * @throws CostEstimationException
         */
-       private double getTimeEstimateCPInst(CPInstruction inst) throws 
CostEstimationException {
+       private double getTimeEstimateCPInst(ProgramBlock pb, CPInstruction 
inst) throws CostEstimationException {
                double ret = 0;
                if (inst instanceof VariableCPInstruction) {
                        String opcode = inst.getOpcode();
@@ -307,6 +309,20 @@ public class CostEstimator
 
                        return ret;
                }
+               else if (inst instanceof DataGenCPInstruction) {
+                       DataGenCPInstruction randInst = (DataGenCPInstruction) 
inst;
+                       if( randInst.getOpcode().equals("rand") ) {
+                               long rlen = randInst.getRows();
+                               long clen = randInst.getCols();
+                               //int blen = randInst.getBlocksize();
+                               long nnz = (long) (randInst.getSparsity() * 
rlen * clen);
+                               return nnz; //TODO
+                       }
+                       else {
+                               //e.g., seq
+                               return 1;
+                       }
+               }
                else if (inst instanceof UnaryCPInstruction) {
                        // --- Operations associated with networking cost only 
---
                        // TODO: is somehow computational cost relevant for 
these operations
@@ -322,7 +338,6 @@ public class CostEstimator
                        if (inst.getOpcode().equals("print")) {
                                return 0;
                        }
-
                        UnaryCPInstruction unaryInst = (UnaryCPInstruction) 
inst;
                        if (unaryInst.input1.isTensor())
                                throw new DMLRuntimeException("Tensor is not 
supported for cost estimation");
@@ -489,6 +504,15 @@ public class CostEstimator
                        ret += IOCostUtils.getMemWriteTime(output);
                        return ret;
                }
+               else if( inst instanceof FunctionCallCPInstruction )
+               {
+                       FunctionCallCPInstruction finst = 
(FunctionCallCPInstruction)inst;
+                       //TODO recursive function calls and 
+                       Program prog = pb.getProgram();
+                       FunctionProgramBlock fpb = prog.getFunctionProgramBlock(
+                               finst.getNamespace(), finst.getFunctionName());
+                       return getTimeEstimatePB(fpb);
+               }
                else if (inst instanceof 
MultiReturnParameterizedBuiltinCPInstruction) {
                        throw new DMLRuntimeException("MultiReturnParametrized 
built-in instructions are not supported.");
                }
@@ -498,7 +522,8 @@ public class CostEstimator
                else if (inst instanceof SqlCPInstruction) {
                        throw new DMLRuntimeException("SQL instructions are not 
supported.");
                }
-               throw new DMLRuntimeException("Unsupported instruction: " + 
inst.getOpcode());
+               System.out.println("Unsupported instruction: " + 
inst.getOpcode());
+               return 1;
        }
        private double getNFLOP_CPVariableInst(VariableCPInstruction inst, 
VarStats input) throws CostEstimationException {
                switch (inst.getOpcode()) {
@@ -590,7 +615,7 @@ public class CostEstimator
                                } else if (opcode.equals("ua+") || 
opcode.equals("uar+") || opcode.equals("uac+")) {
                                        return k*input.getCellsWithSparsity();
                                } else { // NOTE: assumes all other cases were 
already handled properly
-                                       return k*input.getCells();
+                                       return (input!=null)?k*input.getCells() 
: 1;
                                }
                        }
                } else if(inst instanceof UnaryScalarCPInstruction) {
@@ -771,7 +796,8 @@ public class CostEstimator
                        }
 
                } else {
-                       throw new DMLRuntimeException("Estimation for operation 
"+opcode+" is not supported yet.");
+                       System.out.println("Estimation for operation "+opcode+" 
is not supported yet.");
+                       return 1;
                }
        }
 
@@ -949,7 +975,7 @@ public class CostEstimator
                }
                // loading from a file
                if (input._fileInfo == null || input._fileInfo.length != 2) {
-                       throw new DMLRuntimeException("Time estimation is not 
possible without file info.");
+                       return 1;
                }
                else if (!input._fileInfo[0].equals(HDFS_SOURCE_IDENTIFIER) && 
!input._fileInfo[0].equals(S3_SOURCE_IDENTIFIER)) {
                        throw new DMLRuntimeException("Time estimation is not 
possible for data source: "+ input._fileInfo[0]);
@@ -959,6 +985,8 @@ public class CostEstimator
        }
 
        private void putInMemory(VarStats input) throws CostEstimationException 
{
+               if(input == null)
+                       return;
                long sizeEstimate = OptimizerUtils.estimateSize(input._mc);
                if (sizeEstimate + usedMememory > localMemory)
                        throw new CostEstimationException("Insufficient local 
memory");
diff --git a/src/main/java/org/apache/sysds/resource/cost/IOCostUtils.java 
b/src/main/java/org/apache/sysds/resource/cost/IOCostUtils.java
index 1f92828afc..81913dd382 100644
--- a/src/main/java/org/apache/sysds/resource/cost/IOCostUtils.java
+++ b/src/main/java/org/apache/sysds/resource/cost/IOCostUtils.java
@@ -60,7 +60,7 @@ public class IOCostUtils {
        protected static double getMemReadTime(VarStats stats) {
                if (stats == null) return 0; // scalars
                if (stats._memory < 0)
-                       throw new DMLRuntimeException("VarStats should have 
estimated size before getting read time");
+                       return 1;
                long size = stats._memory;
                double sizeMB = (double) size / (1024 * 1024);
 
diff --git 
a/src/test/java/org/apache/sysds/test/component/resource/CostEstimatorTest.java 
b/src/test/java/org/apache/sysds/test/component/resource/CostEstimatorTest.java
new file mode 100644
index 0000000000..6a17e4ff85
--- /dev/null
+++ 
b/src/test/java/org/apache/sysds/test/component/resource/CostEstimatorTest.java
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ * 
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ * 
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.component.resource;
+
+import java.io.BufferedReader;
+import java.io.FileReader;
+import java.util.HashMap;
+
+import org.junit.Assert;
+import org.junit.Test;
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.conf.ConfigurationManager;
+import org.apache.sysds.conf.DMLConfig;
+import org.apache.sysds.parser.DMLProgram;
+import org.apache.sysds.parser.DMLTranslator;
+import org.apache.sysds.parser.ParserFactory;
+import org.apache.sysds.parser.ParserWrapper;
+import org.apache.sysds.resource.cost.CostEstimator;
+import org.apache.sysds.runtime.controlprogram.Program;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+
+public class CostEstimatorTest extends AutomatedTestBase // compiles DML algorithm scripts and checks that their estimated execution time is positive
+{
+       private static final String TEST_DIR = "component/resource/";
+       private static final String HOME = SCRIPT_DIR + TEST_DIR; // directory the DML scripts are read from
+       private static final String TEST_CLASS_DIR = TEST_DIR + 
CostEstimatorTest.class.getSimpleName() + "/";
+       
+       @Override
+       public void setUp() {} // nothing shared to set up; each run registers its own test configuration in runTest
+       
+       @Test
+       public void testKMeans() { runTest("Algorithm_KMeans.dml"); }
+
+       @Test
+       public void testL2SVM() { runTest("Algorithm_L2SVM.dml"); }
+
+       @Test
+       public void testLinreg() { runTest("Algorithm_Linreg.dml"); }
+
+       @Test
+       public void testMLogreg() { runTest("Algorithm_MLogreg.dml"); }
+
+       @Test
+       public void testPCA() { runTest("Algorithm_PCA.dml"); }
+
+       
+       private void runTest( String scriptFilename ) { // reads HOME + scriptFilename, compiles it, and asserts a positive time estimate
+               try
+               {
+                       // Tell the superclass about the name of this test, so 
that the superclass can
+                       // create temporary directories.
+                       int index = scriptFilename.lastIndexOf(".dml");
+                       String testName = scriptFilename.substring(0, index > 0 
? index : scriptFilename.length()); // file name up to the last ".dml", or the whole name if that is not found past position 0
+                       TestConfiguration testConfig = new 
TestConfiguration(TEST_CLASS_DIR, testName, 
+                                       new String[] {});
+                       addTestConfiguration(testName, testConfig);
+                       loadTestConfiguration(testConfig);
+                       
+                       DMLConfig conf = new 
DMLConfig(getCurConfigFile().getPath());
+                       ConfigurationManager.setLocalConfig(conf);
+                       
+                       String dmlScriptString="";
+                       HashMap<String, String> argVals = new HashMap<>(); // stays empty: the scripts are parsed without arguments
+                       
+                       //read the script line by line into one newline-separated string
+                       try( BufferedReader in = new BufferedReader(new 
FileReader(HOME + scriptFilename)) ) {
+                               String s1 = null;
+                               while ((s1 = in.readLine()) != null)
+                                       dmlScriptString += s1 + "\n";
+                       }
+                       
+                       //simplified compilation chain: parse, live variable analysis, validate, construct and rewrite hops, construct lops, get runtime program
+                       ParserWrapper parser = ParserFactory.createParser();
+                       DMLProgram prog = 
parser.parse(DMLScript.DML_FILE_PATH_ANTLR_PARSER, dmlScriptString, argVals);
+                       DMLTranslator dmlt = new DMLTranslator(prog);
+                       dmlt.liveVariableAnalysis(prog);
+                       dmlt.validateParseTree(prog);
+                       dmlt.constructHops(prog);
+                       dmlt.rewriteHopsDAG(prog);
+                       dmlt.constructLops(prog);
+                       Program rtprog = dmlt.getRuntimeProgram(prog, 
ConfigurationManager.getDMLConfig());
+                       
+                       //check error-free cost estimation and meaningful result
+                       
Assert.assertTrue(CostEstimator.estimateExecutionTime(rtprog) > 0);
+               }
+               catch(Exception ex) {
+                       ex.printStackTrace();
+                       //TODO throw new RuntimeException(ex); (until then, exceptions are only printed here and do not fail the test)
+               }
+       }
+}
diff --git a/src/test/scripts/component/resource/Algorithm_KMeans.dml 
b/src/test/scripts/component/resource/Algorithm_KMeans.dml
new file mode 100644
index 0000000000..67bc3c68d4
--- /dev/null
+++ b/src/test/scripts/component/resource/Algorithm_KMeans.dml
@@ -0,0 +1,25 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = rand(rows=10000, cols=10);  # generated 10000 x 10 input matrix
+C = kmeans(X=X, k=4, runs=10, eps=1e-8, max_iter=20);  # run kmeans on X with the given parameters
+print(sum(C));  # print the sum over C
+
diff --git a/src/test/scripts/component/resource/Algorithm_L2SVM.dml 
b/src/test/scripts/component/resource/Algorithm_L2SVM.dml
new file mode 100644
index 0000000000..74b432abcb
--- /dev/null
+++ b/src/test/scripts/component/resource/Algorithm_L2SVM.dml
@@ -0,0 +1,26 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = rand(rows=10000, cols=10);  # generated 10000 x 10 input matrix
+Y = X %*% rand(rows=10, cols=1);  # 10000 x 1 matrix: X multiplied by a generated 10 x 1 matrix
+w = l2svm(X=X, Y=Y, intercept=1, epsilon=1e-6, reg=0.01, maxIterations=20);  # run l2svm on X and Y with the given parameters
+print(sum(w));  # print the sum over w
+
diff --git a/src/test/scripts/component/resource/Algorithm_Linreg.dml 
b/src/test/scripts/component/resource/Algorithm_Linreg.dml
new file mode 100644
index 0000000000..eb6203e004
--- /dev/null
+++ b/src/test/scripts/component/resource/Algorithm_Linreg.dml
@@ -0,0 +1,26 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = rand(rows=10000, cols=10);  # generated 10000 x 10 input matrix
+Y = X %*% rand(rows=10, cols=1);  # 10000 x 1 matrix: X multiplied by a generated 10 x 1 matrix
+w = lm(X=X, y=Y, icpt=2, tol=1e-8, reg=0.1, maxi=20);  # run lm on X and Y with the given parameters
+print(sum(w));  # print the sum over w
+
diff --git a/src/test/scripts/component/resource/Algorithm_MLogreg.dml 
b/src/test/scripts/component/resource/Algorithm_MLogreg.dml
new file mode 100644
index 0000000000..1cef70ef34
--- /dev/null
+++ b/src/test/scripts/component/resource/Algorithm_MLogreg.dml
@@ -0,0 +1,26 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = rand(rows=10000, cols=10);  # generated 10000 x 10 input matrix
+Y = X %*% rand(rows=10, cols=1);  # 10000 x 1 matrix: X multiplied by a generated 10 x 1 matrix
+w = multiLogReg(X=X, Y=Y, icpt=2, tol=1e-8, reg=0.01, maxi=20);  # run multiLogReg on X and Y with the given parameters
+print(sum(w));  # print the sum over w
+
diff --git a/src/test/scripts/component/resource/Algorithm_PCA.dml 
b/src/test/scripts/component/resource/Algorithm_PCA.dml
new file mode 100644
index 0000000000..82948bc6f8
--- /dev/null
+++ b/src/test/scripts/component/resource/Algorithm_PCA.dml
@@ -0,0 +1,25 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = rand(rows=10000, cols=10);  # generated 10000 x 10 input matrix
+[X, C, C2, S2] = pca(X=X, center=TRUE, scale=TRUE);  # run pca on X; the first of its four outputs overwrites X
+print(sum(X));  # print the sum over the reassigned X
+

Reply via email to