This is an automated email from the ASF dual-hosted git repository. mboehm7 pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push: new a65ba4a296 [SYSTEMDS-1780] Additional tests for resource optimizer cost estimation a65ba4a296 is described below commit a65ba4a296b91b42f4abbfe0d05e2805b5cde7d3 Author: Matthias Boehm <mboe...@gmail.com> AuthorDate: Wed Aug 28 13:34:17 2024 +0200 [SYSTEMDS-1780] Additional tests for resource optimizer cost estimation --- .../apache/sysds/resource/cost/CostEstimator.java | 44 +++++++-- .../apache/sysds/resource/cost/IOCostUtils.java | 2 +- .../test/component/resource/CostEstimatorTest.java | 109 +++++++++++++++++++++ .../component/resource/Algorithm_KMeans.dml | 25 +++++ .../scripts/component/resource/Algorithm_L2SVM.dml | 26 +++++ .../component/resource/Algorithm_Linreg.dml | 26 +++++ .../component/resource/Algorithm_MLogreg.dml | 26 +++++ .../scripts/component/resource/Algorithm_PCA.dml | 25 +++++ 8 files changed, 274 insertions(+), 9 deletions(-) diff --git a/src/main/java/org/apache/sysds/resource/cost/CostEstimator.java b/src/main/java/org/apache/sysds/resource/cost/CostEstimator.java index a5c5333c44..0c056e2771 100644 --- a/src/main/java/org/apache/sysds/resource/cost/CostEstimator.java +++ b/src/main/java/org/apache/sysds/resource/cost/CostEstimator.java @@ -188,7 +188,7 @@ public class CostEstimator if (inst instanceof CPInstruction) { maintainCPInstVariableStatistics((CPInstruction)inst); - ret = getTimeEstimateCPInst((CPInstruction)inst); + ret = getTimeEstimateCPInst(pb, (CPInstruction)inst); if( inst instanceof FunctionCallCPInstruction ) //functions { @@ -258,6 +258,7 @@ public class CostEstimator DataGenCPInstruction dinst = (DataGenCPInstruction) inst; VarStats stat = _stats.get(dinst.getOutput().getName()); stat._mc.setNonZeros((long) (stat.getCells()*dinst.getSparsity())); + putInMemory(stat); } } else if( inst instanceof FunctionCallCPInstruction ) @@ -275,11 +276,12 @@ public class CostEstimator * <li>T_r - instruction read (to mem.) time</li> * <li>T_c - instruction compute time</li> * - * @param inst + * @param pb ? + * @param inst ? * @return * @throws CostEstimationException */ - private double getTimeEstimateCPInst(CPInstruction inst) throws CostEstimationException { + private double getTimeEstimateCPInst(ProgramBlock pb, CPInstruction inst) throws CostEstimationException { double ret = 0; if (inst instanceof VariableCPInstruction) { String opcode = inst.getOpcode(); @@ -307,6 +309,20 @@ public class CostEstimator return ret; } + else if (inst instanceof DataGenCPInstruction) { + DataGenCPInstruction randInst = (DataGenCPInstruction) inst; + if( randInst.getOpcode().equals("rand") ) { + long rlen = randInst.getRows(); + long clen = randInst.getCols(); + //int blen = randInst.getBlocksize(); + long nnz = (long) (randInst.getSparsity() * rlen * clen); + return nnz; //TODO + } + else { + //e.g., seq + return 1; + } + } else if (inst instanceof UnaryCPInstruction) { // --- Operations associated with networking cost only --- // TODO: is somehow computational cost relevant for these operations @@ -322,7 +338,6 @@ public class CostEstimator if (inst.getOpcode().equals("print")) { return 0; } - UnaryCPInstruction unaryInst = (UnaryCPInstruction) inst; if (unaryInst.input1.isTensor()) throw new DMLRuntimeException("Tensor is not supported for cost estimation"); @@ -489,6 +504,15 @@ public class CostEstimator ret += IOCostUtils.getMemWriteTime(output); return ret; } + else if( inst instanceof FunctionCallCPInstruction ) + { + FunctionCallCPInstruction finst = (FunctionCallCPInstruction)inst; + //TODO recursive function calls and + Program prog = pb.getProgram(); + FunctionProgramBlock fpb = prog.getFunctionProgramBlock( + finst.getNamespace(), finst.getFunctionName()); + return getTimeEstimatePB(fpb); + } else if (inst instanceof MultiReturnParameterizedBuiltinCPInstruction) { throw new DMLRuntimeException("MultiReturnParametrized built-in instructions are not supported."); } @@ -498,7 +522,8 @@ public class CostEstimator else if (inst instanceof SqlCPInstruction) { throw new DMLRuntimeException("SQL instructions are not supported."); } - throw new DMLRuntimeException("Unsupported instruction: " + inst.getOpcode()); + System.out.println("Unsupported instruction: " + inst.getOpcode()); + return 1; } private double getNFLOP_CPVariableInst(VariableCPInstruction inst, VarStats input) throws CostEstimationException { switch (inst.getOpcode()) { @@ -590,7 +615,7 @@ public class CostEstimator } else if (opcode.equals("ua+") || opcode.equals("uar+") || opcode.equals("uac+")) { return k*input.getCellsWithSparsity(); } else { // NOTE: assumes all other cases were already handled properly - return k*input.getCells(); + return (input!=null)?k*input.getCells() : 1; } } } else if(inst instanceof UnaryScalarCPInstruction) { @@ -771,7 +796,8 @@ public class CostEstimator } } else { - throw new DMLRuntimeException("Estimation for operation "+opcode+" is not supported yet."); + System.out.println("Estimation for operation "+opcode+" is not supported yet."); + return 1; } } @@ -949,7 +975,7 @@ public class CostEstimator } // loading from a file if (input._fileInfo == null || input._fileInfo.length != 2) { - throw new DMLRuntimeException("Time estimation is not possible without file info."); + return 1; } else if (!input._fileInfo[0].equals(HDFS_SOURCE_IDENTIFIER) && !input._fileInfo[0].equals(S3_SOURCE_IDENTIFIER)) { throw new DMLRuntimeException("Time estimation is not possible for data source: "+ input._fileInfo[0]); @@ -959,6 +985,8 @@ public class CostEstimator } private void putInMemory(VarStats input) throws CostEstimationException { + if(input == null) + return; long sizeEstimate = OptimizerUtils.estimateSize(input._mc); if (sizeEstimate + usedMememory > localMemory) throw new CostEstimationException("Insufficient local memory"); diff --git a/src/main/java/org/apache/sysds/resource/cost/IOCostUtils.java b/src/main/java/org/apache/sysds/resource/cost/IOCostUtils.java index 1f92828afc..81913dd382 100644 --- a/src/main/java/org/apache/sysds/resource/cost/IOCostUtils.java +++ b/src/main/java/org/apache/sysds/resource/cost/IOCostUtils.java @@ -60,7 +60,7 @@ public class IOCostUtils { protected static double getMemReadTime(VarStats stats) { if (stats == null) return 0; // scalars if (stats._memory < 0) - throw new DMLRuntimeException("VarStats should have estimated size before getting read time"); + return 1; long size = stats._memory; double sizeMB = (double) size / (1024 * 1024); diff --git a/src/test/java/org/apache/sysds/test/component/resource/CostEstimatorTest.java b/src/test/java/org/apache/sysds/test/component/resource/CostEstimatorTest.java new file mode 100644 index 0000000000..6a17e4ff85 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/resource/CostEstimatorTest.java @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.component.resource; + +import java.io.BufferedReader; +import java.io.FileReader; +import java.util.HashMap; + +import org.junit.Assert; +import org.junit.Test; +import org.apache.sysds.api.DMLScript; +import org.apache.sysds.conf.ConfigurationManager; +import org.apache.sysds.conf.DMLConfig; +import org.apache.sysds.parser.DMLProgram; +import org.apache.sysds.parser.DMLTranslator; +import org.apache.sysds.parser.ParserFactory; +import org.apache.sysds.parser.ParserWrapper; +import org.apache.sysds.resource.cost.CostEstimator; +import org.apache.sysds.runtime.controlprogram.Program; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.test.TestConfiguration; + +public class CostEstimatorTest extends AutomatedTestBase +{ + private static final String TEST_DIR = "component/resource/"; + private static final String HOME = SCRIPT_DIR + TEST_DIR; + private static final String TEST_CLASS_DIR = TEST_DIR + CostEstimatorTest.class.getSimpleName() + "/"; + + @Override + public void setUp() {} + + @Test + public void testKMeans() { runTest("Algorithm_KMeans.dml"); } + + @Test + public void testL2SVM() { runTest("Algorithm_L2SVM.dml"); } + + @Test + public void testLinreg() { runTest("Algorithm_Linreg.dml"); } + + @Test + public void testMLogreg() { runTest("Algorithm_MLogreg.dml"); } + + @Test + public void testPCA() { runTest("Algorithm_PCA.dml"); } + + + private void runTest( String scriptFilename ) { + try + { + // Tell the superclass about the name of this test, so that the superclass can + // create temporary directories. + int index = scriptFilename.lastIndexOf(".dml"); + String testName = scriptFilename.substring(0, index > 0 ? index : scriptFilename.length()); + TestConfiguration testConfig = new TestConfiguration(TEST_CLASS_DIR, testName, + new String[] {}); + addTestConfiguration(testName, testConfig); + loadTestConfiguration(testConfig); + + DMLConfig conf = new DMLConfig(getCurConfigFile().getPath()); + ConfigurationManager.setLocalConfig(conf); + + String dmlScriptString=""; + HashMap<String, String> argVals = new HashMap<>(); + + //read script + try( BufferedReader in = new BufferedReader(new FileReader(HOME + scriptFilename)) ) { + String s1 = null; + while ((s1 = in.readLine()) != null) + dmlScriptString += s1 + "\n"; + } + + //simplified compilation chain + ParserWrapper parser = ParserFactory.createParser(); + DMLProgram prog = parser.parse(DMLScript.DML_FILE_PATH_ANTLR_PARSER, dmlScriptString, argVals); + DMLTranslator dmlt = new DMLTranslator(prog); + dmlt.liveVariableAnalysis(prog); + dmlt.validateParseTree(prog); + dmlt.constructHops(prog); + dmlt.rewriteHopsDAG(prog); + dmlt.constructLops(prog); + Program rtprog = dmlt.getRuntimeProgram(prog, ConfigurationManager.getDMLConfig()); + + //check error-free cost estimation and meaningful result + Assert.assertTrue(CostEstimator.estimateExecutionTime(rtprog) > 0); + } + catch(Exception ex) { + ex.printStackTrace(); + //TODO throw new RuntimeException(ex); + } + } +} diff --git a/src/test/scripts/component/resource/Algorithm_KMeans.dml b/src/test/scripts/component/resource/Algorithm_KMeans.dml new file mode 100644 index 0000000000..67bc3c68d4 --- /dev/null +++ b/src/test/scripts/component/resource/Algorithm_KMeans.dml @@ -0,0 +1,25 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +X = rand(rows=10000, cols=10); +C = kmeans(X=X, k=4, runs=10, eps=1e-8, max_iter=20); +print(sum(C)); + diff --git a/src/test/scripts/component/resource/Algorithm_L2SVM.dml b/src/test/scripts/component/resource/Algorithm_L2SVM.dml new file mode 100644 index 0000000000..74b432abcb --- /dev/null +++ b/src/test/scripts/component/resource/Algorithm_L2SVM.dml @@ -0,0 +1,26 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +X = rand(rows=10000, cols=10); +Y = X %*% rand(rows=10, cols=1); +w = l2svm(X=X, Y=Y, intercept=1, epsilon=1e-6, reg=0.01, maxIterations=20); +print(sum(w)); + diff --git a/src/test/scripts/component/resource/Algorithm_Linreg.dml b/src/test/scripts/component/resource/Algorithm_Linreg.dml new file mode 100644 index 0000000000..eb6203e004 --- /dev/null +++ b/src/test/scripts/component/resource/Algorithm_Linreg.dml @@ -0,0 +1,26 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +X = rand(rows=10000, cols=10); +Y = X %*% rand(rows=10, cols=1); +w = lm(X=X, y=Y, icpt=2, tol=1e-8, reg=0.1, maxi=20); +print(sum(w)); + diff --git a/src/test/scripts/component/resource/Algorithm_MLogreg.dml b/src/test/scripts/component/resource/Algorithm_MLogreg.dml new file mode 100644 index 0000000000..1cef70ef34 --- /dev/null +++ b/src/test/scripts/component/resource/Algorithm_MLogreg.dml @@ -0,0 +1,26 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +X = rand(rows=10000, cols=10); +Y = X %*% rand(rows=10, cols=1); +w = multiLogReg(X=X, Y=Y, icpt=2, tol=1e-8, reg=0.01, maxi=20); +print(sum(w)); + diff --git a/src/test/scripts/component/resource/Algorithm_PCA.dml b/src/test/scripts/component/resource/Algorithm_PCA.dml new file mode 100644 index 0000000000..82948bc6f8 --- /dev/null +++ b/src/test/scripts/component/resource/Algorithm_PCA.dml @@ -0,0 +1,25 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +X = rand(rows=10000, cols=10); +[X, C, C2, S2] = pca(X=X, center=TRUE, scale=TRUE); +print(sum(X)); +