This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new a65ba4a296 [SYSTEMDS-1780] Additional tests for resource optimizer
cost estimation
a65ba4a296 is described below
commit a65ba4a296b91b42f4abbfe0d05e2805b5cde7d3
Author: Matthias Boehm <[email protected]>
AuthorDate: Wed Aug 28 13:34:17 2024 +0200
[SYSTEMDS-1780] Additional tests for resource optimizer cost estimation
---
.../apache/sysds/resource/cost/CostEstimator.java | 44 +++++++--
.../apache/sysds/resource/cost/IOCostUtils.java | 2 +-
.../test/component/resource/CostEstimatorTest.java | 109 +++++++++++++++++++++
.../component/resource/Algorithm_KMeans.dml | 25 +++++
.../scripts/component/resource/Algorithm_L2SVM.dml | 26 +++++
.../component/resource/Algorithm_Linreg.dml | 26 +++++
.../component/resource/Algorithm_MLogreg.dml | 26 +++++
.../scripts/component/resource/Algorithm_PCA.dml | 25 +++++
8 files changed, 274 insertions(+), 9 deletions(-)
diff --git a/src/main/java/org/apache/sysds/resource/cost/CostEstimator.java
b/src/main/java/org/apache/sysds/resource/cost/CostEstimator.java
index a5c5333c44..0c056e2771 100644
--- a/src/main/java/org/apache/sysds/resource/cost/CostEstimator.java
+++ b/src/main/java/org/apache/sysds/resource/cost/CostEstimator.java
@@ -188,7 +188,7 @@ public class CostEstimator
if (inst instanceof CPInstruction) {
maintainCPInstVariableStatistics((CPInstruction)inst);
- ret = getTimeEstimateCPInst((CPInstruction)inst);
+ ret = getTimeEstimateCPInst(pb, (CPInstruction)inst);
if( inst instanceof FunctionCallCPInstruction )
//functions
{
@@ -258,6 +258,7 @@ public class CostEstimator
DataGenCPInstruction dinst =
(DataGenCPInstruction) inst;
VarStats stat =
_stats.get(dinst.getOutput().getName());
stat._mc.setNonZeros((long)
(stat.getCells()*dinst.getSparsity()));
+ putInMemory(stat);
}
}
else if( inst instanceof FunctionCallCPInstruction )
@@ -275,11 +276,12 @@ public class CostEstimator
* <li>T_r - instruction read (to mem.) time</li>
* <li>T_c - instruction compute time</li>
*
- * @param inst
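+ * @param pb program block containing the instruction; its program is used to look up called functions
+ * @param inst CP instruction whose execution time is estimated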
* @return
* @throws CostEstimationException
*/
- private double getTimeEstimateCPInst(CPInstruction inst) throws
CostEstimationException {
+ private double getTimeEstimateCPInst(ProgramBlock pb, CPInstruction
inst) throws CostEstimationException {
double ret = 0;
if (inst instanceof VariableCPInstruction) {
String opcode = inst.getOpcode();
@@ -307,6 +309,20 @@ public class CostEstimator
return ret;
}
+ else if (inst instanceof DataGenCPInstruction) {
+ DataGenCPInstruction randInst = (DataGenCPInstruction)
inst;
+ if( randInst.getOpcode().equals("rand") ) {
+ long rlen = randInst.getRows();
+ long clen = randInst.getCols();
+ //int blen = randInst.getBlocksize();
+ long nnz = (long) (randInst.getSparsity() *
rlen * clen);
+ return nnz; //TODO
+ }
+ else {
+ //e.g., seq
+ return 1;
+ }
+ }
else if (inst instanceof UnaryCPInstruction) {
// --- Operations associated with networking cost only
---
// TODO: is somehow computational cost relevant for
these operations
@@ -322,7 +338,6 @@ public class CostEstimator
if (inst.getOpcode().equals("print")) {
return 0;
}
-
UnaryCPInstruction unaryInst = (UnaryCPInstruction)
inst;
if (unaryInst.input1.isTensor())
throw new DMLRuntimeException("Tensor is not
supported for cost estimation");
@@ -489,6 +504,15 @@ public class CostEstimator
ret += IOCostUtils.getMemWriteTime(output);
return ret;
}
+ else if( inst instanceof FunctionCallCPInstruction )
+ {
+ FunctionCallCPInstruction finst =
(FunctionCallCPInstruction)inst;
+ //TODO handle recursive function calls
+ Program prog = pb.getProgram();
+ FunctionProgramBlock fpb = prog.getFunctionProgramBlock(
+ finst.getNamespace(), finst.getFunctionName());
+ return getTimeEstimatePB(fpb);
+ }
else if (inst instanceof
MultiReturnParameterizedBuiltinCPInstruction) {
throw new DMLRuntimeException("MultiReturnParametrized
built-in instructions are not supported.");
}
@@ -498,7 +522,8 @@ public class CostEstimator
else if (inst instanceof SqlCPInstruction) {
throw new DMLRuntimeException("SQL instructions are not
supported.");
}
- throw new DMLRuntimeException("Unsupported instruction: " +
inst.getOpcode());
+ System.out.println("Unsupported instruction: " +
inst.getOpcode());
+ return 1;
}
private double getNFLOP_CPVariableInst(VariableCPInstruction inst,
VarStats input) throws CostEstimationException {
switch (inst.getOpcode()) {
@@ -590,7 +615,7 @@ public class CostEstimator
} else if (opcode.equals("ua+") ||
opcode.equals("uar+") || opcode.equals("uac+")) {
return k*input.getCellsWithSparsity();
} else { // NOTE: assumes all other cases were
already handled properly
- return k*input.getCells();
+ return (input!=null)?k*input.getCells()
: 1;
}
}
} else if(inst instanceof UnaryScalarCPInstruction) {
@@ -771,7 +796,8 @@ public class CostEstimator
}
} else {
- throw new DMLRuntimeException("Estimation for operation
"+opcode+" is not supported yet.");
+ System.out.println("Estimation for operation "+opcode+"
is not supported yet.");
+ return 1;
}
}
@@ -949,7 +975,7 @@ public class CostEstimator
}
// loading from a file
if (input._fileInfo == null || input._fileInfo.length != 2) {
- throw new DMLRuntimeException("Time estimation is not
possible without file info.");
+ return 1;
}
else if (!input._fileInfo[0].equals(HDFS_SOURCE_IDENTIFIER) &&
!input._fileInfo[0].equals(S3_SOURCE_IDENTIFIER)) {
throw new DMLRuntimeException("Time estimation is not
possible for data source: "+ input._fileInfo[0]);
@@ -959,6 +985,8 @@ public class CostEstimator
}
private void putInMemory(VarStats input) throws CostEstimationException
{
+ if(input == null)
+ return;
long sizeEstimate = OptimizerUtils.estimateSize(input._mc);
if (sizeEstimate + usedMememory > localMemory)
throw new CostEstimationException("Insufficient local
memory");
diff --git a/src/main/java/org/apache/sysds/resource/cost/IOCostUtils.java
b/src/main/java/org/apache/sysds/resource/cost/IOCostUtils.java
index 1f92828afc..81913dd382 100644
--- a/src/main/java/org/apache/sysds/resource/cost/IOCostUtils.java
+++ b/src/main/java/org/apache/sysds/resource/cost/IOCostUtils.java
@@ -60,7 +60,7 @@ public class IOCostUtils {
protected static double getMemReadTime(VarStats stats) {
if (stats == null) return 0; // scalars
if (stats._memory < 0)
- throw new DMLRuntimeException("VarStats should have
estimated size before getting read time");
+ return 1;
long size = stats._memory;
double sizeMB = (double) size / (1024 * 1024);
diff --git
a/src/test/java/org/apache/sysds/test/component/resource/CostEstimatorTest.java
b/src/test/java/org/apache/sysds/test/component/resource/CostEstimatorTest.java
new file mode 100644
index 0000000000..6a17e4ff85
--- /dev/null
+++
b/src/test/java/org/apache/sysds/test/component/resource/CostEstimatorTest.java
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.component.resource;
+
+import java.io.BufferedReader;
+import java.io.FileReader;
+import java.util.HashMap;
+
+import org.junit.Assert;
+import org.junit.Test;
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.conf.ConfigurationManager;
+import org.apache.sysds.conf.DMLConfig;
+import org.apache.sysds.parser.DMLProgram;
+import org.apache.sysds.parser.DMLTranslator;
+import org.apache.sysds.parser.ParserFactory;
+import org.apache.sysds.parser.ParserWrapper;
+import org.apache.sysds.resource.cost.CostEstimator;
+import org.apache.sysds.runtime.controlprogram.Program;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.test.TestConfiguration;
+
+public class CostEstimatorTest extends AutomatedTestBase
+{
+ private static final String TEST_DIR = "component/resource/";
+ private static final String HOME = SCRIPT_DIR + TEST_DIR;
+ private static final String TEST_CLASS_DIR = TEST_DIR +
CostEstimatorTest.class.getSimpleName() + "/";
+
+ @Override
+ public void setUp() {}
+
+ @Test
+ public void testKMeans() { runTest("Algorithm_KMeans.dml"); }
+
+ @Test
+ public void testL2SVM() { runTest("Algorithm_L2SVM.dml"); }
+
+ @Test
+ public void testLinreg() { runTest("Algorithm_Linreg.dml"); }
+
+ @Test
+ public void testMLogreg() { runTest("Algorithm_MLogreg.dml"); }
+
+ @Test
+ public void testPCA() { runTest("Algorithm_PCA.dml"); }
+
+
+ private void runTest( String scriptFilename ) {
+ try
+ {
+ // Tell the superclass about the name of this test, so
that the superclass can
+ // create temporary directories.
+ int index = scriptFilename.lastIndexOf(".dml");
+ String testName = scriptFilename.substring(0, index > 0
? index : scriptFilename.length());
+ TestConfiguration testConfig = new
TestConfiguration(TEST_CLASS_DIR, testName,
+ new String[] {});
+ addTestConfiguration(testName, testConfig);
+ loadTestConfiguration(testConfig);
+
+ DMLConfig conf = new
DMLConfig(getCurConfigFile().getPath());
+ ConfigurationManager.setLocalConfig(conf);
+
+ String dmlScriptString="";
+ HashMap<String, String> argVals = new HashMap<>();
+
+ //read script
+ try( BufferedReader in = new BufferedReader(new
FileReader(HOME + scriptFilename)) ) {
+ String s1 = null;
+ while ((s1 = in.readLine()) != null)
+ dmlScriptString += s1 + "\n";
+ }
+
+ //simplified compilation chain
+ ParserWrapper parser = ParserFactory.createParser();
+ DMLProgram prog =
parser.parse(DMLScript.DML_FILE_PATH_ANTLR_PARSER, dmlScriptString, argVals);
+ DMLTranslator dmlt = new DMLTranslator(prog);
+ dmlt.liveVariableAnalysis(prog);
+ dmlt.validateParseTree(prog);
+ dmlt.constructHops(prog);
+ dmlt.rewriteHopsDAG(prog);
+ dmlt.constructLops(prog);
+ Program rtprog = dmlt.getRuntimeProgram(prog,
ConfigurationManager.getDMLConfig());
+
+ //check error-free cost estimation and meaningful result
+
Assert.assertTrue(CostEstimator.estimateExecutionTime(rtprog) > 0);
+ }
+ catch(Exception ex) {
+ ex.printStackTrace();
+ //TODO throw new RuntimeException(ex); for now exceptions are only printed and do not fail the test
+ }
+ }
+}
diff --git a/src/test/scripts/component/resource/Algorithm_KMeans.dml
b/src/test/scripts/component/resource/Algorithm_KMeans.dml
new file mode 100644
index 0000000000..67bc3c68d4
--- /dev/null
+++ b/src/test/scripts/component/resource/Algorithm_KMeans.dml
@@ -0,0 +1,25 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = rand(rows=10000, cols=10);
+C = kmeans(X=X, k=4, runs=10, eps=1e-8, max_iter=20);
+print(sum(C));
+
diff --git a/src/test/scripts/component/resource/Algorithm_L2SVM.dml
b/src/test/scripts/component/resource/Algorithm_L2SVM.dml
new file mode 100644
index 0000000000..74b432abcb
--- /dev/null
+++ b/src/test/scripts/component/resource/Algorithm_L2SVM.dml
@@ -0,0 +1,26 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = rand(rows=10000, cols=10);
+Y = X %*% rand(rows=10, cols=1);
+w = l2svm(X=X, Y=Y, intercept=1, epsilon=1e-6, reg=0.01, maxIterations=20);
+print(sum(w));
+
diff --git a/src/test/scripts/component/resource/Algorithm_Linreg.dml
b/src/test/scripts/component/resource/Algorithm_Linreg.dml
new file mode 100644
index 0000000000..eb6203e004
--- /dev/null
+++ b/src/test/scripts/component/resource/Algorithm_Linreg.dml
@@ -0,0 +1,26 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = rand(rows=10000, cols=10);
+Y = X %*% rand(rows=10, cols=1);
+w = lm(X=X, y=Y, icpt=2, tol=1e-8, reg=0.1, maxi=20);
+print(sum(w));
+
diff --git a/src/test/scripts/component/resource/Algorithm_MLogreg.dml
b/src/test/scripts/component/resource/Algorithm_MLogreg.dml
new file mode 100644
index 0000000000..1cef70ef34
--- /dev/null
+++ b/src/test/scripts/component/resource/Algorithm_MLogreg.dml
@@ -0,0 +1,26 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = rand(rows=10000, cols=10);
+Y = X %*% rand(rows=10, cols=1);
+w = multiLogReg(X=X, Y=Y, icpt=2, tol=1e-8, reg=0.01, maxi=20);
+print(sum(w));
+
diff --git a/src/test/scripts/component/resource/Algorithm_PCA.dml
b/src/test/scripts/component/resource/Algorithm_PCA.dml
new file mode 100644
index 0000000000..82948bc6f8
--- /dev/null
+++ b/src/test/scripts/component/resource/Algorithm_PCA.dml
@@ -0,0 +1,25 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+X = rand(rows=10000, cols=10);
+[X, C, C2, S2] = pca(X=X, center=TRUE, scale=TRUE);
+print(sum(X));
+