This is an automated email from the ASF dual-hosted git repository. mboehm7 pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push: new 321174ed8e [SYSTEMDS-1780] Resource elasticity: basic enumerators and recompiler 321174ed8e is described below commit 321174ed8eff4fd7c0ae53f83258b88057d8ad80 Author: lachezar-n <lachezar.nikolo...@gmail.com> AuthorDate: Wed Aug 21 19:45:59 2024 +0200 [SYSTEMDS-1780] Resource elasticity: basic enumerators and recompiler Closes #2067. --- .../java/org/apache/sysds/conf/CompilerConfig.java | 5 +- .../java/org/apache/sysds/hops/AggBinaryOp.java | 4 +- src/main/java/org/apache/sysds/hops/Hop.java | 4 + .../java/org/apache/sysds/hops/OptimizerUtils.java | 17 +- .../java/org/apache/sysds/lops/compile/Dag.java | 6 + .../java/org/apache/sysds/resource/AWSUtils.java | 66 +++ .../org/apache/sysds/resource/CloudInstance.java | 99 ++++ .../java/org/apache/sysds/resource/CloudUtils.java | 132 +++++ .../apache/sysds/resource/ResourceCompiler.java | 248 ++++++++++ .../apache/sysds/resource/cost/CostEstimator.java | 5 +- .../resource/enumeration/EnumerationUtils.java | 113 +++++ .../sysds/resource/enumeration/Enumerator.java | 522 ++++++++++++++++++++ .../resource/enumeration/GridBasedEnumerator.java | 89 ++++ .../enumeration/InterestBasedEnumerator.java | 315 ++++++++++++ .../context/SparkExecutionContext.java | 22 +- .../parfor/stat/InfrastructureAnalyzer.java | 8 +- .../test/component/resource/CloudUtilsTests.java | 118 +++++ .../test/component/resource/EnumeratorTests.java | 529 +++++++++++++++++++++ .../test/component/resource/RecompilationTest.java | 258 ++++++++++ .../test/component/resource/TestingUtils.java | 71 +++ src/test/scripts/component/resource/data/A.csv | 0 src/test/scripts/component/resource/data/A.csv.mtd | 10 + src/test/scripts/component/resource/data/B.csv | 0 src/test/scripts/component/resource/data/B.csv.mtd | 10 + src/test/scripts/component/resource/data/C.csv | 0 src/test/scripts/component/resource/data/C.csv.mtd | 10 + src/test/scripts/component/resource/data/D.csv | 0 
src/test/scripts/component/resource/data/D.csv.mtd | 10 + src/test/scripts/component/resource/mm_test.dml | 34 ++ .../component/resource/mm_transpose_test.dml | 32 ++ 30 files changed, 2701 insertions(+), 36 deletions(-) diff --git a/src/main/java/org/apache/sysds/conf/CompilerConfig.java b/src/main/java/org/apache/sysds/conf/CompilerConfig.java index 9728bd2769..800fe575c9 100644 --- a/src/main/java/org/apache/sysds/conf/CompilerConfig.java +++ b/src/main/java/org/apache/sysds/conf/CompilerConfig.java @@ -78,7 +78,9 @@ public class CompilerConfig CODEGEN_ENABLED, //federated runtime conversion - FEDERATED_RUNTIME; + FEDERATED_RUNTIME, + // resource optimization mode + RESOURCE_OPTIMIZATION; } //default flags (exposed for testing purposes only) @@ -107,6 +109,7 @@ public class CompilerConfig _bmap.put(ConfigType.REJECT_READ_WRITE_UNKNOWNS, true); _bmap.put(ConfigType.MLCONTEXT, false); _bmap.put(ConfigType.CODEGEN_ENABLED, false); + _bmap.put(ConfigType.RESOURCE_OPTIMIZATION, false); _imap = new HashMap<>(); _imap.put(ConfigType.BLOCK_SIZE, OptimizerUtils.DEFAULT_BLOCKSIZE); diff --git a/src/main/java/org/apache/sysds/hops/AggBinaryOp.java b/src/main/java/org/apache/sysds/hops/AggBinaryOp.java index 3b09984179..640ababbad 100644 --- a/src/main/java/org/apache/sysds/hops/AggBinaryOp.java +++ b/src/main/java/org/apache/sysds/hops/AggBinaryOp.java @@ -1229,7 +1229,7 @@ public class AggBinaryOp extends MultiThreadedHop { double m2_size = m2_rows * m2_cols; double result_size = m1_rows * m2_cols; - int numReducersRMM = OptimizerUtils.getNumReducers(true); + int numReducersRMM = OptimizerUtils.getNumTasks(); // Estimate the cost of RMM // RMM phase 1 @@ -1256,7 +1256,7 @@ public class AggBinaryOp extends MultiThreadedHop { double m2_size = m2_rows * m2_cols; double result_size = m1_rows * m2_cols; - int numReducersCPMM = OptimizerUtils.getNumReducers(false); + int numReducersCPMM = OptimizerUtils.getNumTasks(); // Estimate the cost of CPMM // CPMM phase 1 diff --git 
a/src/main/java/org/apache/sysds/hops/Hop.java b/src/main/java/org/apache/sysds/hops/Hop.java index 93501efa0d..44930604b3 100644 --- a/src/main/java/org/apache/sysds/hops/Hop.java +++ b/src/main/java/org/apache/sysds/hops/Hop.java @@ -37,6 +37,7 @@ import org.apache.sysds.common.Types.OpOp1; import org.apache.sysds.common.Types.OpOp2; import org.apache.sysds.common.Types.OpOpData; import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.conf.CompilerConfig.ConfigType; import org.apache.sysds.conf.ConfigurationManager; import org.apache.sysds.hops.cost.ComputeCost; import org.apache.sysds.hops.recompile.Recompiler; @@ -281,6 +282,9 @@ public abstract class Hop implements ParseInfo { } else if ( DMLScript.getGlobalExecMode() == ExecMode.SPARK ) _etypeForced = ExecType.SPARK; // enabled with -exec spark option + else if ( DMLScript.getGlobalExecMode() == ExecMode.HYBRID + && ConfigurationManager.getCompilerConfigFlag(ConfigType.RESOURCE_OPTIMIZATION)) + _etypeForced = null; } public void checkAndSetInvalidCPDimsAndSize() diff --git a/src/main/java/org/apache/sysds/hops/OptimizerUtils.java b/src/main/java/org/apache/sysds/hops/OptimizerUtils.java index 0a37570ee8..9e787de8c0 100644 --- a/src/main/java/org/apache/sysds/hops/OptimizerUtils.java +++ b/src/main/java/org/apache/sysds/hops/OptimizerUtils.java @@ -658,20 +658,13 @@ public class OptimizerUtils } /** - * Returns the number of reducers that potentially run in parallel. + * Returns the number of tasks that potentially run in parallel. * This is either just the configured value (SystemDS config) or - * the minimum of configured value and available reduce slots. - * - * @param configOnly true if configured value - * @return number of reducers + * the minimum of configured value and available task slots. 
+ * + * @return number of tasks */ - public static int getNumReducers( boolean configOnly ) { - if( isSparkExecutionMode() ) - return SparkExecutionContext.getDefaultParallelism(false); - return InfrastructureAnalyzer.getLocalParallelism(); - } - - public static int getNumMappers() { + public static int getNumTasks() { if( isSparkExecutionMode() ) return SparkExecutionContext.getDefaultParallelism(false); return InfrastructureAnalyzer.getLocalParallelism(); diff --git a/src/main/java/org/apache/sysds/lops/compile/Dag.java b/src/main/java/org/apache/sysds/lops/compile/Dag.java index b26c539e9a..f67cb74cd3 100644 --- a/src/main/java/org/apache/sysds/lops/compile/Dag.java +++ b/src/main/java/org/apache/sysds/lops/compile/Dag.java @@ -147,6 +147,12 @@ public class Dag<N extends Lop> dt.isFrame() ? Lop.FRAME_VAR_NAME_PREFIX : Lop.SCALAR_VAR_NAME_PREFIX) + var_index.getNextID(); } + + // to be used only resource optimization + public static void resetUniqueMembers() { + job_id.reset(-1); + var_index.reset(-1); + } /////// // Dag modifications diff --git a/src/main/java/org/apache/sysds/resource/AWSUtils.java b/src/main/java/org/apache/sysds/resource/AWSUtils.java new file mode 100644 index 0000000000..7d3ab6409a --- /dev/null +++ b/src/main/java/org/apache/sysds/resource/AWSUtils.java @@ -0,0 +1,66 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.resource; + +import org.apache.sysds.resource.enumeration.EnumerationUtils; + +public class AWSUtils extends CloudUtils { + public static final String EC2_REGEX = "^([a-z]+)([0-9])(a|g|i?)([bdnez]*)\\.([a-z0-9]+)$"; + @Override + public boolean validateInstanceName(String input) { + String instanceName = input.toLowerCase(); + if (!instanceName.toLowerCase().matches(EC2_REGEX)) return false; + try { + getInstanceType(instanceName); + getInstanceSize(instanceName); + } catch (IllegalArgumentException e) { + return false; + } + return true; + } + + @Override + public InstanceType getInstanceType(String instanceName) { + String typeAsString = instanceName.split("\\.")[0]; + // throws exception if string value is not valid + return InstanceType.customValueOf(typeAsString); + } + + @Override + public InstanceSize getInstanceSize(String instanceName) { + String sizeAsString = instanceName.split("\\.")[1]; + // throws exception if string value is not valid + return InstanceSize.customValueOf(sizeAsString); + } + + @Override + public double calculateClusterPrice(EnumerationUtils.ConfigurationPoint config, double time) { + double pricePerSeconds = getClusterCostPerHour(config) / 3600; + return time * pricePerSeconds; + } + + private double getClusterCostPerHour(EnumerationUtils.ConfigurationPoint config) { + if (config.numberExecutors == 0) { + return config.driverInstance.getPrice(); + } + return config.driverInstance.getPrice() + + config.executorInstance.getPrice()*config.numberExecutors; + } +} diff --git 
a/src/main/java/org/apache/sysds/resource/CloudInstance.java b/src/main/java/org/apache/sysds/resource/CloudInstance.java new file mode 100644 index 0000000000..d740e35b80 --- /dev/null +++ b/src/main/java/org/apache/sysds/resource/CloudInstance.java @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.resource; + +/** + * This class describes the configurations of a single VM instance. + * The idea is to use this class to represent instances of different + * cloud hypervisors - currently supporting only EC2 instances by AWS. 
+ */ +public class CloudInstance { + private final String instanceName; + private final long memory; + private final int vCPUCores; + private final double pricePerHour; + private final double gFlops; + private final double memorySpeed; + private final double diskSpeed; + private final double networkSpeed; + public CloudInstance(String instanceName, long memory, int vCPUCores, double gFlops, double memorySpeed, double diskSpeed, double networkSpeed, double pricePerHour) { + this.instanceName = instanceName; + this.memory = memory; + this.vCPUCores = vCPUCores; + this.gFlops = gFlops; + this.memorySpeed = memorySpeed; + this.diskSpeed = diskSpeed; + this.networkSpeed = networkSpeed; + this.pricePerHour = pricePerHour; + } + + public String getInstanceName() { + return instanceName; + } + + /** + * @return memory of the instance in bytes + */ + public long getMemory() { + return memory; + } + + /** + * @return number of virtual CPU cores of the instance + */ + public int getVCPUs() { + return vCPUCores; + } + + /** + * @return price per hour of the instance + */ + public double getPrice() { + return pricePerHour; + } + + /** + * @return number of FLOPS of the instance + */ + public long getFLOPS() { + return (long) (gFlops*1024)*1024*1024; + } + + /** + * @return memory speed/bandwidth of the instance in MB/s + */ + public double getMemorySpeed() { + return memorySpeed; + } + + /** + * @return disk speed/bandwidth of the instance in MB/s + */ + public double getDiskSpeed() { + return diskSpeed; + } + + /** + * @return network speed/bandwidth of the instance in MB/s + */ + public double getNetworkSpeed() { + return networkSpeed; + } +} diff --git a/src/main/java/org/apache/sysds/resource/CloudUtils.java b/src/main/java/org/apache/sysds/resource/CloudUtils.java new file mode 100644 index 0000000000..3a1e38f422 --- /dev/null +++ b/src/main/java/org/apache/sysds/resource/CloudUtils.java @@ -0,0 +1,132 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or 
more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.resource; + +import org.apache.sysds.resource.enumeration.EnumerationUtils; + +import java.io.BufferedReader; +import java.io.FileReader; +import java.io.IOException; +import java.util.HashMap; + +public abstract class CloudUtils { + public enum CloudProvider { + AWS // potentially AZURE, GOOGLE + } + public enum InstanceType { + // AWS EC2 instance + M5, M5A, M6I, M6A, M6G, M7I, M7A, M7G, // general purpose - vCores:mem~=1:4 + C5, C5A, C6I, C6A, C6G, C7I, C7A, C7G, // compute optimized - vCores:mem~=1:2 + R5, R5A, R6I, R6A, R6G, R7I, R7A, R7G; // memory optimized - vCores:mem~=1:8 + // Potentially VM instance types for different Cloud providers + + public static InstanceType customValueOf(String name) { + return InstanceType.valueOf(name.toUpperCase()); + } + } + + public enum InstanceSize { + _XLARGE, _2XLARGE, _4XLARGE, _8XLARGE, _12XLARGE, _16XLARGE, _24XLARGE, _32XLARGE, _48XLARGE; + // Potentially VM instance sizes for different Cloud providers + + public static InstanceSize customValueOf(String name) { + return InstanceSize.valueOf("_"+name.toUpperCase()); + } + } + + public static final double MINIMAL_EXECUTION_TIME = 120; // seconds; NOTE: set always equal or higher than 
DEFAULT_CLUSTER_LAUNCH_TIME + + public static final double DEFAULT_CLUSTER_LAUNCH_TIME = 120; // seconds; NOTE: set always to at least 60 seconds + + public static long GBtoBytes(double gb) { + return (long) (gb * 1024 * 1024 * 1024); + } + public abstract boolean validateInstanceName(String instanceName); + public abstract InstanceType getInstanceType(String instanceName); + public abstract InstanceSize getInstanceSize(String instanceName); + + /** + * This method calculates the cluster price based on the + * estimated execution time and the cluster configuration. + * @param config the cluster configuration for the calculation + * @param time estimated execution time in seconds + * @return price for the given time + */ + public abstract double calculateClusterPrice(EnumerationUtils.ConfigurationPoint config, double time); + + /** + * Performs read of csv file filled with VM instance characteristics. + * Each record in the csv should carry the following information (including header): + * <li>API_Name - naming for VM instance used by the provider</li> + * <li>Memory - floating number for the instance memory in GBs</li> + * <li>vCPUs - number of physical threads</li> + * <li>gFlops - FLOPS capability of the CPU in GFLOPS (Giga)</li> + * <li>ramSpeed - memory bandwidth in MB/s</li> + * <li>diskSpeed - disk bandwidth in MB/s</li> + * <li>networkSpeed - network bandwidth in MB/s</li> + * <li>Price - price for instance per hour</li> + * @param instanceTablePath csv file + * @return map with filtered instances + * @throws IOException in case of a problem reading the csv file + */ + public HashMap<String, CloudInstance> loadInstanceInfoTable(String instanceTablePath) throws IOException { + HashMap<String, CloudInstance> result = new HashMap<>(); + int lineCount = 1; + // try to open the file + try(BufferedReader br = new BufferedReader(new FileReader(instanceTablePath))){ + String parsedLine; + // validate the file header + parsedLine = br.readLine(); + if 
(!parsedLine.equals("API_Name,Memory,vCPUs,gFlops,ramSpeed,diskSpeed,networkSpeed,Price")) + throw new IOException("Invalid CSV header inside: " + instanceTablePath); + + + while ((parsedLine = br.readLine()) != null) { + String[] values = parsedLine.split(","); + if (values.length != 8 || !validateInstanceName(values[0])) + throw new IOException(String.format("Invalid CSV line(%d) inside: %s", lineCount, instanceTablePath)); + + String API_Name = values[0]; + long Memory = (long) (Double.parseDouble(values[1])*1024)*1024*1024; + int vCPUs = Integer.parseInt(values[2]); + double gFlops = Double.parseDouble(values[3]); + double ramSpeed = Double.parseDouble(values[4]); + double diskSpeed = Double.parseDouble(values[5]); + double networkSpeed = Double.parseDouble(values[6]); + double Price = Double.parseDouble(values[7]); + + CloudInstance parsedInstance = new CloudInstance( + API_Name, + Memory, + vCPUs, + gFlops, + ramSpeed, + diskSpeed, + networkSpeed, + Price + ); + result.put(API_Name, parsedInstance); + lineCount++; + } + } + + return result; + } +} diff --git a/src/main/java/org/apache/sysds/resource/ResourceCompiler.java b/src/main/java/org/apache/sysds/resource/ResourceCompiler.java new file mode 100644 index 0000000000..ab5c452b56 --- /dev/null +++ b/src/main/java/org/apache/sysds/resource/ResourceCompiler.java @@ -0,0 +1,248 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.resource; + +import org.apache.spark.SparkConf; +import org.apache.sysds.api.DMLOptions; +import org.apache.sysds.api.DMLScript; +import org.apache.sysds.common.Types; +import org.apache.sysds.conf.CompilerConfig; +import org.apache.sysds.conf.ConfigurationManager; +import org.apache.sysds.hops.Hop; +import org.apache.sysds.hops.recompile.Recompiler; +import org.apache.sysds.lops.Lop; +import org.apache.sysds.lops.compile.Dag; +import org.apache.sysds.lops.rewrite.LopRewriter; +import org.apache.sysds.parser.*; +import org.apache.sysds.runtime.controlprogram.*; +import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext; +import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer; +import org.apache.sysds.runtime.instructions.Instruction; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static org.apache.sysds.api.DMLScript.*; + +/** + * This class does full or partial program recompilation + * based on given runtime program. It uses the methods provided + * by {@code hops.recompile.Recompiler}). + * It keeps a state of the current recompilation phase in order + * to decide when to do full recompilation and when not. 
+ */ +public class ResourceCompiler { + public static final long DEFAULT_DRIVER_MEMORY = 512*1024*1024; // 0.5GB + public static final int DEFAULT_DRIVER_THREADS = 1; // 0.5GB + public static final long DEFAULT_EXECUTOR_MEMORY = 512*1024*1024; // 0.5GB + public static final int DEFAULT_EXECUTOR_THREADS = 1; // 0.5GB + public static final int DEFAULT_NUMBER_EXECUTORS = 1; // 0.5GB + static { + // TODO: consider moving to the executable of the resource optimizer once implemented + USE_LOCAL_SPARK_CONFIG = true; + ConfigurationManager.getCompilerConfig().set(CompilerConfig.ConfigType.ALLOW_DYN_RECOMPILATION, false); + ConfigurationManager.getCompilerConfig().set(CompilerConfig.ConfigType.RESOURCE_OPTIMIZATION, true); + } + private static final LopRewriter _lopRewriter = new LopRewriter(); + + public static Program compile(String filePath, Map<String, String> args) throws IOException { + // setting the dynamic recompilation flags during resource optimization is obsolete + DMLOptions dmlOptions =DMLOptions.defaultOptions; + dmlOptions.argVals = args; + + String dmlScriptStr = readDMLScript(true, filePath); + Map<String, String> argVals = dmlOptions.argVals; + + Dag.resetUniqueMembers(); + // NOTE: skip configuring code generation + // NOTE: expects setting up the initial cluster configs before calling + ParserWrapper parser = ParserFactory.createParser(); + DMLProgram dmlProgram = parser.parse(null, dmlScriptStr, argVals); + DMLTranslator dmlTranslator = new DMLTranslator(dmlProgram); + dmlTranslator.liveVariableAnalysis(dmlProgram); + dmlTranslator.validateParseTree(dmlProgram); + dmlTranslator.constructHops(dmlProgram); + dmlTranslator.rewriteHopsDAG(dmlProgram); + dmlTranslator.constructLops(dmlProgram); + dmlTranslator.rewriteLopDAG(dmlProgram); + return dmlTranslator.getRuntimeProgram(dmlProgram, ConfigurationManager.getDMLConfig()); + } + + private static ArrayList<Instruction> recompile(StatementBlock sb, ArrayList<Hop> hops) { + // construct new lops + 
ArrayList<Lop> lops = new ArrayList<>(hops.size()); + Hop.resetVisitStatus(hops); + for( Hop hop : hops ){ + Recompiler.rClearLops(hop); + lops.add(hop.constructLops()); + } + // apply hop-lop rewrites to cover the case of changed lop operators + _lopRewriter.rewriteLopDAG(sb, lops); + + Dag<Lop> dag = new Dag<>(); + for (Lop l : lops) { + l.addToDag(dag); + } + + return dag.getJobs(sb, ConfigurationManager.getDMLConfig()); + } + + /** + * Recompiling a given program for resource optimization for single node execution + * @param program program to be recompiled + * @param driverMemory target driver memory + * @param driverCores target driver threads/cores + * @return the recompiled program as a new {@code Program} instance + */ + public static Program doFullRecompilation(Program program, long driverMemory, int driverCores) { + setDriverConfigurations(driverMemory, driverCores); + setSingleNodeExecution(); + return doFullRecompilation(program); + } + + /** + * Recompiling a given program for resource optimization for Spark execution + * @param program program to be recompiled + * @param driverMemory target driver memory + * @param driverCores target driver threads/cores + * @param numberExecutors target number of executor nodes + * @param executorMemory target executor memory + * @param executorCores target executor threads/cores + * @return the recompiled program as a new {@code Program} instance + */ + public static Program doFullRecompilation(Program program, long driverMemory, int driverCores, int numberExecutors, long executorMemory, int executorCores) { + setDriverConfigurations(driverMemory, driverCores); + setExecutorConfigurations(numberExecutors, executorMemory, executorCores); + return doFullRecompilation(program); + } + + private static Program doFullRecompilation(Program program) { + Dag.resetUniqueMembers(); + Program newProgram = new Program(); + ArrayList<ProgramBlock> B = Stream.concat( + program.getProgramBlocks().stream(), + 
program.getFunctionProgramBlocks().values().stream()) + .collect(Collectors.toCollection(ArrayList::new)); + doRecompilation(B, newProgram); + return newProgram; + } + + private static void doRecompilation(ArrayList<ProgramBlock> origin, Program target) { + for (ProgramBlock originBlock : origin) { + doRecompilation(originBlock, target); + } + } + + private static void doRecompilation(ProgramBlock originBlock, Program target) { + if (originBlock instanceof FunctionProgramBlock) + { + FunctionProgramBlock fpb = (FunctionProgramBlock)originBlock; + doRecompilation(fpb.getChildBlocks(), target); + } + else if (originBlock instanceof WhileProgramBlock) + { + WhileProgramBlock wpb = (WhileProgramBlock)originBlock; + WhileStatementBlock sb = (WhileStatementBlock) originBlock.getStatementBlock(); + if(sb!=null && sb.getPredicateHops()!=null ){ + ArrayList<Instruction> inst = Recompiler.recompileHopsDag(sb.getPredicateHops(), null, null, true, true, 0); + wpb.setPredicate(inst); + target.addProgramBlock(wpb); + } + doRecompilation(wpb.getChildBlocks(), target); + } + else if (originBlock instanceof IfProgramBlock) + { + IfProgramBlock ipb = (IfProgramBlock)originBlock; + IfStatementBlock sb = (IfStatementBlock) ipb.getStatementBlock(); + if(sb!=null && sb.getPredicateHops()!=null ){ + ArrayList<Instruction> inst = Recompiler.recompileHopsDag(sb.getPredicateHops(), null, null, true, true, 0); + ipb.setPredicate(inst); + target.addProgramBlock(ipb); + } + doRecompilation(ipb.getChildBlocksIfBody(), target); + doRecompilation(ipb.getChildBlocksElseBody(), target); + } + else if (originBlock instanceof ForProgramBlock) //incl parfor + { + ForProgramBlock fpb = (ForProgramBlock)originBlock; + ForStatementBlock sb = (ForStatementBlock) fpb.getStatementBlock(); + if(sb!=null){ + if( sb.getFromHops()!=null ){ + ArrayList<Instruction> inst = Recompiler.recompileHopsDag(sb.getFromHops(), null, null, true, true, 0); + fpb.setFromInstructions( inst ); + } + if(sb.getToHops()!=null){ + 
ArrayList<Instruction> inst = Recompiler.recompileHopsDag(sb.getToHops(), null, null, true, true, 0); + fpb.setToInstructions( inst ); + } + if(sb.getIncrementHops()!=null){ + ArrayList<Instruction> inst = Recompiler.recompileHopsDag(sb.getIncrementHops(), null, null, true, true, 0); + fpb.setIncrementInstructions(inst); + } + target.addProgramBlock(fpb); + + } + doRecompilation(fpb.getChildBlocks(), target); + } + else + { + BasicProgramBlock bpb = (BasicProgramBlock)originBlock; + StatementBlock sb = bpb.getStatementBlock(); + ArrayList<Instruction> inst = recompile(sb, sb.getHops()); + bpb.setInstructions(inst); + target.addProgramBlock(bpb); + } + } + + public static void setDriverConfigurations(long nodeMemory, int nodeNumCores) { + // TODO: think of reasonable factor for the JVM heap as part of the node's memory + InfrastructureAnalyzer.setLocalMaxMemory(nodeMemory); + InfrastructureAnalyzer.setLocalPar(nodeNumCores); + } + + public static void setExecutorConfigurations(int numExecutors, long nodeMemory, int nodeNumCores) { + // TODO: think of reasonable factor for the JVM heap as part of the node's memory + if (numExecutors > 0) { + DMLScript.setGlobalExecMode(Types.ExecMode.HYBRID); + SparkConf sparkConf = SparkExecutionContext.createSystemDSSparkConf(); + // ------------------ Static Configurations ------------------- + // TODO: think how to avoid setting them every time + sparkConf.set("spark.master", "local[*]"); + sparkConf.set("spark.app.name", "SystemDS"); + sparkConf.set("spark.memory.useLegacyMode", "false"); + // ------------------ Static Configurations ------------------- + // ------------------ Dynamic Configurations ------------------- + sparkConf.set("spark.executor.memory", (nodeMemory/(1024*1024))+"m"); + sparkConf.set("spark.executor.instances", Integer.toString(numExecutors)); + sparkConf.set("spark.executor.cores", Integer.toString(nodeNumCores)); + // ------------------ Dynamic Configurations ------------------- + 
SparkExecutionContext.initLocalSparkContext(sparkConf); + } else { + throw new RuntimeException("The given number of executors was 0"); + } + } + + public static void setSingleNodeExecution() { + DMLScript.setGlobalExecMode(Types.ExecMode.SINGLE_NODE); + } +} diff --git a/src/main/java/org/apache/sysds/resource/cost/CostEstimator.java b/src/main/java/org/apache/sysds/resource/cost/CostEstimator.java index f2787c50d3..a5c5333c44 100644 --- a/src/main/java/org/apache/sysds/resource/cost/CostEstimator.java +++ b/src/main/java/org/apache/sysds/resource/cost/CostEstimator.java @@ -161,7 +161,8 @@ public class CostEstimator ForProgramBlock tmp = (ForProgramBlock)pb; for( ProgramBlock pb2 : tmp.getChildBlocks() ) ret += getTimeEstimatePB(pb2); - + // NOTE: currently ParFor blocks are handled as regular for block + // what could lead to very inaccurate estimation in case of complex ParFor blocks ret *= OptimizerUtils.getNumIterations(tmp, DEFAULT_NUMITER); } else if ( pb instanceof FunctionProgramBlock ) { @@ -960,7 +961,7 @@ public class CostEstimator private void putInMemory(VarStats input) throws CostEstimationException { long sizeEstimate = OptimizerUtils.estimateSize(input._mc); if (sizeEstimate + usedMememory > localMemory) - throw new CostEstimationException("Insufficient local memory for "); + throw new CostEstimationException("Insufficient local memory"); usedMememory += sizeEstimate; input._memory = sizeEstimate; } diff --git a/src/main/java/org/apache/sysds/resource/enumeration/EnumerationUtils.java b/src/main/java/org/apache/sysds/resource/enumeration/EnumerationUtils.java new file mode 100644 index 0000000000..fa075dd9d9 --- /dev/null +++ b/src/main/java/org/apache/sysds/resource/enumeration/EnumerationUtils.java @@ -0,0 +1,113 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.resource.enumeration; + +import org.apache.sysds.resource.CloudInstance; + +import java.util.HashMap; +import java.util.LinkedList; +import java.util.TreeMap; + +public class EnumerationUtils { + /** + * Data structure representing a projected search space for + * VM instances as node's memory mapped to further maps with + * the node's numbers of cores for the given memory + * mapped to a list of unique objects of type {@code CloudInstance} + * which have these corresponding characteristics (memory and cores). + * The higher layer keeps the memory since it is more significant + * for the program compilation. The lower map level contains + * the different options for number of cores for the memory that + * this map data structure is being mapped to. The last layer + * of LinkedLists represents the unique VM instances in lists + * since the memory - cores combinations are often not unique. + * The {@code CloudInstance} objects are unique over the whole + * set of lists within this lowest level of the search space. + * <br></br> + * This representation allows compact storing of VM instance + * characteristics relevant for program compilation while + * still keeping a reference to the object carrying the + * whole instance information, relevant for cost estimation. 
+ * <br></br> + * {@code TreeMap} data structures are used as building blocks for + * the complex search space structure to ensure ascending order + * of the instance characteristics - memory and number of cores. + */ + public static class InstanceSearchSpace extends TreeMap<Long, TreeMap<Integer, LinkedList<CloudInstance>>> { + private static final long serialVersionUID = -8855424955793322839L; + + public void initSpace(HashMap<String, CloudInstance> instances) { + for (CloudInstance instance: instances.values()) { + long currentMemory = instance.getMemory(); + + this.putIfAbsent(currentMemory, new TreeMap<>()); + TreeMap<Integer, LinkedList<CloudInstance>> currentSubTree = this.get(currentMemory); + + currentSubTree.putIfAbsent(instance.getVCPUs(), new LinkedList<>()); + LinkedList<CloudInstance> currentList = currentSubTree.get(instance.getVCPUs()); + + currentList.add(instance); + } + } + } + + /** + * Simple data structure to hold cluster configurations + */ + public static class ConfigurationPoint { + public CloudInstance driverInstance; + public CloudInstance executorInstance; + public int numberExecutors; + + public ConfigurationPoint(CloudInstance driverInstance) { + this.driverInstance = driverInstance; + this.executorInstance = null; + this.numberExecutors = 0; + } + + public ConfigurationPoint(CloudInstance driverInstance, CloudInstance executorInstance, int numberExecutors) { + this.driverInstance = driverInstance; + this.executorInstance = executorInstance; + this.numberExecutors = numberExecutors; + } + } + + /** + * Data structure to hold all data related to cost estimation + */ + public static class SolutionPoint extends ConfigurationPoint { + double timeCost; + double monetaryCost; + + public SolutionPoint(ConfigurationPoint inputPoint, double timeCost, double monetaryCost) { + super(inputPoint.driverInstance, inputPoint.executorInstance, inputPoint.numberExecutors); + this.timeCost = timeCost; + this.monetaryCost = monetaryCost; + } + + public 
void update(ConfigurationPoint point, double timeCost, double monetaryCost) { + this.driverInstance = point.driverInstance; + this.executorInstance = point.executorInstance; + this.numberExecutors = point.numberExecutors; + this.timeCost = timeCost; + this.monetaryCost = monetaryCost; + } + } +} diff --git a/src/main/java/org/apache/sysds/resource/enumeration/Enumerator.java b/src/main/java/org/apache/sysds/resource/enumeration/Enumerator.java new file mode 100644 index 0000000000..2147dfc368 --- /dev/null +++ b/src/main/java/org/apache/sysds/resource/enumeration/Enumerator.java @@ -0,0 +1,522 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
*/

package org.apache.sysds.resource.enumeration;

import org.apache.sysds.resource.AWSUtils;
import org.apache.sysds.resource.CloudInstance;
import org.apache.sysds.resource.CloudUtils;
import org.apache.sysds.resource.ResourceCompiler;
import org.apache.sysds.resource.cost.CostEstimationException;
import org.apache.sysds.resource.cost.CostEstimator;
import org.apache.sysds.runtime.controlprogram.Program;
import org.apache.sysds.resource.enumeration.EnumerationUtils.InstanceSearchSpace;
import org.apache.sysds.resource.enumeration.EnumerationUtils.ConfigurationPoint;
import org.apache.sysds.resource.enumeration.EnumerationUtils.SolutionPoint;

import java.io.IOException;
import java.util.*;

/**
 * Abstract base class for enumerating cluster configurations
 * (driver/executor instance types and number of executors) and
 * selecting the configuration with minimal time or monetary cost.
 */
public abstract class Enumerator {

	public enum EnumerationStrategy {
		GridBased, // considering all combination within a given range of configuration
		InterestBased, // considering only combinations of configurations with memory budge close to memory estimates
	}

	public enum OptimizationStrategy {
		MinTime, // always prioritize execution time minimization
		MinPrice, // always prioritize operation price minimization
	}

	// Static variables ------------------------------------------------------------------------------------------------

	public static final int DEFAULT_MIN_EXECUTORS = 0; // Single Node execution allowed
	/**
	 * A reasonable upper bound for the possible number of executors
	 * is required to set limits for the search space and to avoid
	 * evaluating cluster configurations that most probably would
	 * have too high distribution overhead
	 */
	public static final int DEFAULT_MAX_EXECUTORS = 200;

	// limit for the ratio number of executor and number
	// of executor per executor
	public static final int MAX_LEVEL_PARALLELISM = 1000;

	/** Time/Monetary delta for considering optimal solutions as fraction */
	public static final double COST_DELTA_FRACTION = 0.02;

	// Instance variables ----------------------------------------------------------------------------------------------

	HashMap<String, CloudInstance> instances = null;
	Program program;
	CloudUtils utils;
	EnumerationStrategy enumStrategy;
	OptimizationStrategy optStrategy;
	private final double maxTime;
	private final double maxPrice;
	protected final int minExecutors;
	protected final int maxExecutors;
	protected final Set<CloudUtils.InstanceType> instanceTypesRange;
	protected final Set<CloudUtils.InstanceSize> instanceSizeRange;

	protected final InstanceSearchSpace driverSpace = new InstanceSearchSpace();
	protected final InstanceSearchSpace executorSpace = new InstanceSearchSpace();
	protected ArrayList<SolutionPoint> solutionPool = new ArrayList<>();

	// Initialization functionality ------------------------------------------------------------------------------------

	public Enumerator(Builder builder) {
		if (builder.provider.equals(CloudUtils.CloudProvider.AWS)) {
			utils = new AWSUtils();
		} // as of now no other provider is supported
		this.program = builder.program;
		this.enumStrategy = builder.enumStrategy;
		this.optStrategy = builder.optStrategy;
		this.maxTime = builder.maxTime;
		this.maxPrice = builder.maxPrice;
		this.minExecutors = builder.minExecutors;
		this.maxExecutors = builder.maxExecutors;
		this.instanceTypesRange = builder.instanceTypesRange;
		this.instanceSizeRange = builder.instanceSizeRange;
	}

	/**
	 * Meant to be used for testing purposes
	 * @return the current instance table
	 */
	public HashMap<String, CloudInstance> getInstances() {
		return instances;
	}

	/**
	 * Meant to be used for testing purposes
	 * @return the search space for driver instances
	 */
	public InstanceSearchSpace getDriverSpace() {
		return driverSpace;
	}

	/**
	 * Meant to be used for testing purposes
	 * @param inputSpace search space to merge into the driver space
	 */
	public void setDriverSpace(InstanceSearchSpace inputSpace) {
		driverSpace.putAll(inputSpace);
	}

	/**
	 * Meant to be used for testing purposes
	 * @return the search space for executor instances
	 */
	public InstanceSearchSpace getExecutorSpace() {
		return executorSpace;
	}

	/**
	 * Meant to be used for testing purposes
	 * @param inputSpace search space to merge into the executor space
	 */
	public void setExecutorSpace(InstanceSearchSpace inputSpace) {
		executorSpace.putAll(inputSpace);
	}

	/**
	 * Meant to be used for testing purposes
	 * @return the current pool of potential solutions
	 */
	public ArrayList<SolutionPoint> getSolutionPool() {
		return solutionPool;
	}

	/**
	 * Meant to be used for testing purposes
	 * @param solutionPool pool of potential solutions to set
	 */
	public void setSolutionPool(ArrayList<SolutionPoint> solutionPool) {
		this.solutionPool = solutionPool;
	}

	/**
	 * Setting the available VM instances manually.
	 * Meant to be used for testing purposes.
	 * @param inputInstances initialized map of instances
	 */
	public void setInstanceTable(HashMap<String, CloudInstance> inputInstances) {
		instances = filterInstances(inputInstances);
	}

	/**
	 * Loads the info table for the available VM instances
	 * and filters out the instances that are not contained
	 * in the set of allowed instance types and sizes.
	 *
	 * @param path csv file with instances' info
	 * @throws IOException in case the loading part fails at reading the csv file
	 */
	public void loadInstanceTableFile(String path) throws IOException {
		instances = filterInstances(utils.loadInstanceInfoTable(path));
	}

	/**
	 * Keeps only the instances whose type and size are within
	 * the configured ranges (shared by manual and file-based setup).
	 *
	 * @param allInstances full map of instance name to instance description
	 * @return filtered map containing only allowed instances
	 */
	private HashMap<String, CloudInstance> filterInstances(HashMap<String, CloudInstance> allInstances) {
		HashMap<String, CloudInstance> filtered = new HashMap<>();
		for (Map.Entry<String, CloudInstance> entry : allInstances.entrySet()) {
			String key = entry.getKey();
			if (instanceTypesRange.contains(utils.getInstanceType(key))
					&& instanceSizeRange.contains(utils.getInstanceSize(key))) {
				filtered.put(key, entry.getValue());
			}
		}
		return filtered;
	}

	// Main functionality ----------------------------------------------------------------------------------------------

	/**
	 * Called once to enumerate the search space for
	 * VM instances for driver or executor nodes.
	 * These instances are being represented as
	 * {@code InstanceSearchSpace} structures.
	 */
	public abstract void preprocessing();

	/**
	 * Called once after preprocessing to fill the
	 * pool with optimal solutions by parsing
	 * the enumerated search space.
	 * Within its execution the number of potential
	 * executor nodes is being estimated (enumerated)
	 * dynamically for each parsed executor instance.
	 */
	public void processing() {
		ConfigurationPoint configurationPoint;
		// sentinel point with infinite cost; replaced by the first feasible solution
		SolutionPoint optSolutionPoint = new SolutionPoint(
				new ConfigurationPoint(null, null, -1),
				Double.MAX_VALUE,
				Double.MAX_VALUE
		);
		for (Map.Entry<Long, TreeMap<Integer, LinkedList<CloudInstance>>> dMemoryEntry: driverSpace.entrySet()) {
			// loop over the search space to enumerate the driver configurations
			for (Map.Entry<Integer, LinkedList<CloudInstance>> dCoresEntry: dMemoryEntry.getValue().entrySet()) {
				// single node execution mode
				if (evaluateSingleNodeExecution(dMemoryEntry.getKey())) {
					program = ResourceCompiler.doFullRecompilation(
							program,
							dMemoryEntry.getKey(),
							dCoresEntry.getKey()
					);
					for (CloudInstance dInstance: dCoresEntry.getValue()) {
						configurationPoint = new ConfigurationPoint(dInstance);
						updateOptimalSolution(optSolutionPoint, configurationPoint);
					}
				}
				// enumeration for distributed execution
				for (Map.Entry<Long, TreeMap<Integer, LinkedList<CloudInstance>>> eMemoryEntry: executorSpace.entrySet()) {
					// loop over the search space to enumerate the executor configurations
					for (Map.Entry<Integer, LinkedList<CloudInstance>> eCoresEntry: eMemoryEntry.getValue().entrySet()) {
						List<Integer> numberExecutorsSet = estimateRangeExecutors(eMemoryEntry.getKey(), eCoresEntry.getKey());
						// Spark execution mode
						for (int numberExecutors: numberExecutorsSet) {
							// TODO: avoid full recompilation when the driver memory is not changed
							program = ResourceCompiler.doFullRecompilation(
									program,
									dMemoryEntry.getKey(),
									dCoresEntry.getKey(),
									numberExecutors,
									eMemoryEntry.getKey(),
									eCoresEntry.getKey()
							);
							// TODO: avoid full program cost estimation when the driver instance is not changed
							for (CloudInstance dInstance: dCoresEntry.getValue()) {
								for (CloudInstance eInstance: eCoresEntry.getValue()) {
									configurationPoint = new ConfigurationPoint(dInstance, eInstance, numberExecutors);
									updateOptimalSolution(optSolutionPoint, configurationPoint);
								}
							}
						}
					}
				}
			}
		}
	}

	/**
	 * Deciding in the overall best solution out
	 * of the filled pool of potential solutions
	 * after processing.
	 * @return - single optimal cluster configuration
	 */
	public SolutionPoint postprocessing() {
		if (solutionPool.isEmpty()) {
			throw new RuntimeException("Calling postprocessing() should follow calling processing()");
		}
		SolutionPoint optSolution = solutionPool.get(0);
		double bestCost = Double.MAX_VALUE;
		for (SolutionPoint solution: solutionPool) {
			double combinedCost = solution.monetaryCost * solution.timeCost;
			if (combinedCost < bestCost) {
				optSolution = solution;
				bestCost = combinedCost;
			} else if (combinedCost == bestCost) {
				// the ascending order of the searching spaces for driver and executor
				// instances ensures that in case of equally good optimal solutions
				// the first one has at least resource characteristics.
				// This, however, is not valid for the number of executors
				if (solution.numberExecutors < optSolution.numberExecutors) {
					optSolution = solution;
					bestCost = combinedCost;
				}
			}
		}
		return optSolution;
	}

	// Helper methods --------------------------------------------------------------------------------------------------

	/**
	 * Decides if single node execution should be evaluated
	 * for a driver node with the given memory budget.
	 *
	 * @param driverMemory memory of the currently considered driver instance
	 * @return true if single node execution should be evaluated
	 */
	public abstract boolean evaluateSingleNodeExecution(long driverMemory);

	/**
	 * Estimates the minimum and maximum number of
	 * executors based on given VM instance characteristics
	 * and on the enumeration strategy
	 *
	 * @param executorMemory memory of currently considered executor instance
	 * @param executorCores CPU of cores of currently considered executor instance
	 * @return - [min, max]
	 */
	public abstract ArrayList<Integer> estimateRangeExecutors(long executorMemory, int executorCores);

	/**
	 * Estimates the time cost for the current program based on the
	 * given cluster configurations and following this estimation
	 * it calculates the corresponding monetary cost.
	 * @param point - cluster configuration used for (re)compiling the current program
	 * @return - [time cost, monetary cost]
	 */
	private double[] getCostEstimate(ConfigurationPoint point) {
		// get the estimated time cost
		double timeCost;
		try {
			// estimate execution time of the current program
			// TODO: pass further relevant cluster configurations to cost estimator after extending it
			// like for example: FLOPS, I/O and networking speed
			timeCost = CostEstimator.estimateExecutionTime(program) + CloudUtils.DEFAULT_CLUSTER_LAUNCH_TIME;
		} catch (CostEstimationException e) {
			// preserve the cause for debuggability instead of only forwarding the message
			throw new RuntimeException(e.getMessage(), e);
		}
		// calculate monetary cost
		double monetaryCost = utils.calculateClusterPrice(point, timeCost);
		return new double[] {timeCost, monetaryCost}; // time cost, monetary cost
	}

	/**
	 * Invokes the estimation of the time and monetary cost
	 * based on the compiled program and the given cluster configurations.
	 * Following the optimization strategy, the given current optimal solution
	 * and the new cost estimation, it decides if the given cluster configuration
	 * can be potential optimal solution having lower cost or such a cost
	 * that is negligibly higher than the current lowest one.
	 * @param currentOptimal solution point with the lowest cost
	 * @param newPoint new cluster configuration for estimation
	 */
	private void updateOptimalSolution(SolutionPoint currentOptimal, ConfigurationPoint newPoint) {
		// TODO: clarify if setting max time and max price simultaneously makes really sense
		SolutionPoint newPotentialSolution;
		boolean replaceCurrentOptimal = false;
		double[] newCost = getCostEstimate(newPoint);
		if (optStrategy == OptimizationStrategy.MinTime) {
			if (newCost[1] > maxPrice || newCost[0] >= currentOptimal.timeCost * (1 + COST_DELTA_FRACTION)) {
				return;
			}
			if (newCost[0] < currentOptimal.timeCost) replaceCurrentOptimal = true;
		} else if (optStrategy == OptimizationStrategy.MinPrice) {
			if (newCost[0] > maxTime || newCost[1] >= currentOptimal.monetaryCost * (1 + COST_DELTA_FRACTION)) {
				return;
			}
			if (newCost[1] < currentOptimal.monetaryCost) replaceCurrentOptimal = true;
		}
		newPotentialSolution = new SolutionPoint(newPoint, newCost[0], newCost[1]);
		solutionPool.add(newPotentialSolution);
		if (replaceCurrentOptimal) {
			currentOptimal.update(newPoint, newCost[0], newCost[1]);
		}
	}

	// Class builder ---------------------------------------------------------------------------------------------------

	public static class Builder {
		private final CloudUtils.CloudProvider provider = CloudUtils.CloudProvider.AWS; // currently default and only choice
		private Program program;
		private EnumerationStrategy enumStrategy = null;
		private OptimizationStrategy optStrategy = null;
		private double maxTime = -1d;
		private double maxPrice = -1d;
		private int minExecutors = DEFAULT_MIN_EXECUTORS;
		private int maxExecutors = DEFAULT_MAX_EXECUTORS;
		private Set<CloudUtils.InstanceType> instanceTypesRange = null;
		private Set<CloudUtils.InstanceSize> instanceSizeRange = null;

		// GridBased specific ------------------------------------------------------------------------------------------
		private int stepSizeExecutors = 1;
		private int expBaseExecutors = -1; // flag for exp. increasing number of executors if -1
		// InterestBased specific --------------------------------------------------------------------------------------
		private boolean fitDriverMemory = true;
		private boolean fitBroadcastMemory = true;
		private boolean checkSingleNodeExecution = false;
		private boolean fitCheckpointMemory = false;

		public Builder() {}

		public Builder withRuntimeProgram(Program program) {
			this.program = program;
			return this;
		}

		public Builder withEnumerationStrategy(EnumerationStrategy strategy) {
			this.enumStrategy = strategy;
			return this;
		}

		public Builder withOptimizationStrategy(OptimizationStrategy strategy) {
			this.optStrategy = strategy;
			return this;
		}

		public Builder withTimeLimit(double time) {
			if (time < CloudUtils.MINIMAL_EXECUTION_TIME) {
				throw new IllegalArgumentException(CloudUtils.MINIMAL_EXECUTION_TIME +
						"s is the minimum target execution time.");
			}
			this.maxTime = time;
			return this;
		}

		public Builder withBudget(double price) {
			if (price <= 0) {
				throw new IllegalArgumentException("The given budget (target price) should be positive");
			}
			this.maxPrice = price;
			return this;
		}

		public Builder withNumberExecutorsRange(int min, int max) {
			this.minExecutors = min;
			this.maxExecutors = max;
			return this;
		}

		public Builder withInstanceTypeRange(String[] instanceTypes) {
			this.instanceTypesRange = typeRangeFromStrings(instanceTypes);
			return this;
		}

		public Builder withInstanceSizeRange(String[] instanceSizes) {
			this.instanceSizeRange = sizeRangeFromStrings(instanceSizes);
			return this;
		}

		public Builder withStepSizeExecutor(int stepSize) {
			this.stepSizeExecutors = stepSize;
			return this;
		}

		public Builder withFitDriverMemory(boolean fitDriverMemory) {
			this.fitDriverMemory = fitDriverMemory;
			return this;
		}

		public Builder withFitBroadcastMemory(boolean fitBroadcastMemory) {
			this.fitBroadcastMemory = fitBroadcastMemory;
			return this;
		}

		public Builder withCheckSingleNodeExecution(boolean checkSingleNodeExecution) {
			this.checkSingleNodeExecution = checkSingleNodeExecution;
			return this;
		}

		public Builder withFitCheckpointMemory(boolean fitCheckpointMemory) {
			this.fitCheckpointMemory = fitCheckpointMemory;
			return this;
		}

		public Builder withExpBaseExecutors(int expBaseExecutors) {
			if (expBaseExecutors != -1 && expBaseExecutors < 2) {
				throw new IllegalArgumentException("Given exponent base for number of executors should be -1 or bigger than 1.");
			}
			this.expBaseExecutors = expBaseExecutors;
			return this;
		}

		public Enumerator build() {
			if (this.program == null) {
				throw new IllegalArgumentException("Providing runtime program is required");
			}

			if (instanceTypesRange == null) {
				instanceTypesRange = EnumSet.allOf(CloudUtils.InstanceType.class);
			}

			if (instanceSizeRange == null) {
				instanceSizeRange = EnumSet.allOf(CloudUtils.InstanceSize.class);
			}

			// explicit null checks: switching on a null enum reference would throw a
			// bare NullPointerException before ever reaching the intended default arm
			if (optStrategy == null) {
				throw new IllegalArgumentException("Setting an optimization strategy is required.");
			}
			if (enumStrategy == null) {
				throw new IllegalArgumentException("Setting an enumeration strategy is required.");
			}

			switch (optStrategy) {
				case MinTime:
					if (this.maxPrice < 0) {
						throw new IllegalArgumentException("Budget not specified but required " +
								"for the chosen optimization strategy: " + optStrategy);
					}
					break;
				case MinPrice:
					if (this.maxTime < 0) {
						throw new IllegalArgumentException("Time limit not specified but required " +
								"for the chosen optimization strategy: " + optStrategy);
					}
					break;
				default: // unreachable: strategy was null-checked above
					throw new IllegalArgumentException("Setting an optimization strategy is required.");
			}

			switch (enumStrategy) {
				case GridBased:
					return new GridBasedEnumerator(this, stepSizeExecutors, expBaseExecutors);
				case InterestBased:
					if (fitCheckpointMemory && expBaseExecutors != -1) {
						throw new IllegalArgumentException("Number of executors cannot be fitted on the checkpoint estimates and increased exponentially simultaneously.");
					}
					return new InterestBasedEnumerator(this, fitDriverMemory, fitBroadcastMemory, checkSingleNodeExecution, fitCheckpointMemory);
				default: // unreachable: strategy was null-checked above
					throw new IllegalArgumentException("Setting an enumeration strategy is required.");
			}
		}

		protected static Set<CloudUtils.InstanceType> typeRangeFromStrings(String[] types) {
			Set<CloudUtils.InstanceType> result = EnumSet.noneOf(CloudUtils.InstanceType.class);
			for (String typeAsString: types) {
				CloudUtils.InstanceType type = CloudUtils.InstanceType.customValueOf(typeAsString); // can throw IllegalArgumentException
				result.add(type);
			}
			return result;
		}

		protected static Set<CloudUtils.InstanceSize> sizeRangeFromStrings(String[] sizes) {
			Set<CloudUtils.InstanceSize> result = EnumSet.noneOf(CloudUtils.InstanceSize.class);
			for (String sizeAsString: sizes) {
				CloudUtils.InstanceSize size = CloudUtils.InstanceSize.customValueOf(sizeAsString); // can throw IllegalArgumentException
				result.add(size);
			}
			return result;
		}
	}
}
You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.sysds.resource.enumeration;

import java.util.*;

/**
 * Enumerator considering all configuration combinations
 * within a given range (grid search).
 */
public class GridBasedEnumerator extends Enumerator {
	// marks if the number of executors should
	// be increased by a given step
	private final int stepSizeExecutors;
	// marks if the number of executors should
	// be increased exponentially
	// (single node execution mode is not excluded)
	// -1 marks no exp. increasing
	private final int expBaseExecutors;

	public GridBasedEnumerator(Builder builder, int stepSizeExecutors, int expBaseExecutors) {
		super(builder);
		this.stepSizeExecutors = stepSizeExecutors;
		this.expBaseExecutors = expBaseExecutors;
	}

	/**
	 * Initializes the pool for driver and executor
	 * instances parsed at processing with all the
	 * available instances
	 */
	@Override
	public void preprocessing() {
		driverSpace.initSpace(instances);
		executorSpace.initSpace(instances);
	}

	@Override
	public boolean evaluateSingleNodeExecution(long driverMemory) {
		// single node execution is evaluated whenever zero executors are allowed
		return minExecutors == 0;
	}

	@Override
	public ArrayList<Integer> estimateRangeExecutors(long executorMemory, int executorCores) {
		// consider the maximum level of parallelism and
		// based on the initiated flags decides for the following methods
		// for enumeration of the number of executors:
		// 1. Increasing the number of executor with given step size (default 1)
		// 2. Exponentially increasing number of executors based on
		//    a given exponent base - with additional option for 0 executors
		int currentMax = Math.min(maxExecutors, MAX_LEVEL_PARALLELISM / executorCores);
		ArrayList<Integer> result;
		if (expBaseExecutors > 1) {
			// capacity hint only: use the actual exponent base and guard against
			// currentMax == 0, where Math.log(0) = -Infinity would previously
			// produce a negative capacity and ArrayList would throw
			int maxCapacity = currentMax > 0
					? (int) Math.floor(Math.log(currentMax) / Math.log(expBaseExecutors)) + 1
					: 0;
			result = new ArrayList<>(maxCapacity);
			int exponent = 0;
			int numExecutors;
			while ((numExecutors = (int) Math.pow(expBaseExecutors, exponent)) <= currentMax) {
				if (numExecutors >= minExecutors) {
					result.add(numExecutors);
				}
				exponent++;
			}
		} else {
			// capacity hint only: clamp to 0 for the (empty-result) case currentMax < minExecutors
			int capacity = Math.max(0, (currentMax - minExecutors + 1) / stepSizeExecutors);
			result = new ArrayList<>(capacity);
			// exclude the 0 from the iteration while keeping it as starting point to ensure predictable steps
			int numExecutors = minExecutors == 0? minExecutors + stepSizeExecutors : minExecutors;
			while (numExecutors <= currentMax) {
				result.add(numExecutors);
				numExecutors += stepSizeExecutors;
			}
		}

		return result;
	}
}
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.resource.enumeration; + +import org.apache.sysds.hops.Hop; +import org.apache.sysds.hops.OptimizerUtils; +import org.apache.sysds.parser.StatementBlock; +import org.apache.sysds.runtime.controlprogram.*; +import org.apache.sysds.resource.enumeration.EnumerationUtils.InstanceSearchSpace; + +import java.util.*; +import java.util.stream.Collectors; + +public class InterestBasedEnumerator extends Enumerator { + public final static long MINIMUM_RELEVANT_MEM_ESTIMATE = 2L * 1024 * 1024 * 1024; // 2GB + // different instance families can have slightly different memory characteristics (e.g. 
EC2 Graviton (arm) instances) + // and using memory delta allows not ignoring such instances + // TODO: enable usage of memory delta when FLOPS and bandwidth characteristics added to cost estimator + public final static boolean USE_MEMORY_DELTA = false; + public final static double MEMORY_DELTA_FRACTION = 0.05; // 5% + public final static double BROADCAST_MEMORY_FACTOR = 0.21; // fraction of the minimum memory fraction for storage + // marks if memory estimates should be used at deciding + // for the search space of the instance for the driver nodes + private final boolean fitDriverMemory; + // marks if memory estimates should be used at deciding + // for the search space of the instance for the executor nodes + private final boolean fitBroadcastMemory; + // marks if the estimation of the range of number of executors + // for consideration should exclude single node execution mode + // if any of the estimates cannot fit in the driver's memory + private final boolean checkSingleNodeExecution; + // marks if the estimated output size should be + // considered as interesting point at deciding the + // number of executors - checkpoint storage level + private final boolean fitCheckpointMemory; + // largest full memory estimate (scaled) + private long largestMemoryEstimateCP; + // ordered set ot output memory estimates (scaled) + private TreeSet<Long> memoryEstimatesSpark; + public InterestBasedEnumerator( + Builder builder, + boolean fitDriverMemory, + boolean fitBroadcastMemory, + boolean checkSingleNodeExecution, + boolean fitCheckpointMemory + ) { + super(builder); + this.fitDriverMemory = fitDriverMemory; + this.fitBroadcastMemory = fitBroadcastMemory; + this.checkSingleNodeExecution = checkSingleNodeExecution; + this.fitCheckpointMemory = fitCheckpointMemory; + } + + @Override + public void preprocessing() { + InstanceSearchSpace fullSearchSpace = new InstanceSearchSpace(); + fullSearchSpace.initSpace(instances); + + if (fitDriverMemory) { + // get full memory 
estimates and scale according ot the driver memory factor + TreeSet<Long> memoryEstimatesForDriver = getMemoryEstimates(program, false, OptimizerUtils.MEM_UTIL_FACTOR); + setInstanceSpace(fullSearchSpace, driverSpace, memoryEstimatesForDriver); + if (checkSingleNodeExecution) { + largestMemoryEstimateCP = !memoryEstimatesForDriver.isEmpty()? memoryEstimatesForDriver.last() : -1; + } + } + + if (fitBroadcastMemory) { + // get output memory estimates and scaled according the broadcast memory factor + // for executors' memory search space and driver memory factor for driver's memory search space + TreeSet<Long> memoryEstimatesOutputSpark = getMemoryEstimates(program, true, BROADCAST_MEMORY_FACTOR); + // avoid calling getMemoryEstimates with different factor but rescale: output should fit twice in the CP memory + TreeSet<Long> memoryEstimatesOutputCP = memoryEstimatesOutputSpark.stream() + .map(mem -> 2 * (long) (mem * BROADCAST_MEMORY_FACTOR / OptimizerUtils.MEM_UTIL_FACTOR)) + .collect(Collectors.toCollection(TreeSet::new)); + setInstanceSpace(fullSearchSpace, driverSpace, memoryEstimatesOutputCP); + setInstanceSpace(fullSearchSpace, executorSpace, memoryEstimatesOutputSpark); + if (checkSingleNodeExecution) { + largestMemoryEstimateCP = !memoryEstimatesOutputCP.isEmpty()? memoryEstimatesOutputCP.last() : -1; + } + if (fitCheckpointMemory) { + memoryEstimatesSpark = memoryEstimatesOutputSpark; + } + } else { + executorSpace.putAll(fullSearchSpace); + if (fitCheckpointMemory) { + memoryEstimatesSpark = getMemoryEstimates(program, true, BROADCAST_MEMORY_FACTOR); + } + } + + if (!fitDriverMemory && !fitBroadcastMemory) { + driverSpace.putAll(fullSearchSpace); + if (checkSingleNodeExecution) { + TreeSet<Long> memoryEstimatesForDriver = getMemoryEstimates(program, false, OptimizerUtils.MEM_UTIL_FACTOR); + largestMemoryEstimateCP = !memoryEstimatesForDriver.isEmpty()? 
memoryEstimatesForDriver.last() : -1; + } + } + } + + @Override + public boolean evaluateSingleNodeExecution(long driverMemory) { + // Checking if single node execution should be excluded is optional. + if (checkSingleNodeExecution && minExecutors == 0 && largestMemoryEstimateCP > 0) { + return largestMemoryEstimateCP <= driverMemory; + } + return minExecutors == 0; + } + + @Override + public ArrayList<Integer> estimateRangeExecutors(long executorMemory, int executorCores) { + // consider the maximum level of parallelism and + // based on the initiated flags decides on the following methods + // for enumeration of the number of executors: + // 1. Such a number that leads to combined distributed memory + // close to the output size of the HOPs + // 3. Enumerating all options with the established range + int min = Math.max(1, minExecutors); + int max = Math.min(maxExecutors, (MAX_LEVEL_PARALLELISM / executorCores)); + + ArrayList<Integer> result; + if (fitCheckpointMemory) { + result = new ArrayList<>(memoryEstimatesSpark.size() + 1); + int previousNumber = -1; + for (long estimate: memoryEstimatesSpark) { + // the ratio is just an intermediate for the new enumerated number of executors + double ratio = (double) estimate / executorMemory; + int currentNumber = (int) Math.max(1, Math.floor(ratio)); + if (currentNumber < min || currentNumber == previousNumber) { + + continue; + } + if (currentNumber <= max) { + result.add(currentNumber); + previousNumber = currentNumber; + } else { + break; + } + } + // add a number that allow also the largest checkpoint to be done in memory + if (previousNumber < 0) { + // always append at least one value to allow evaluating Spark execution + result.add(min); + } else if (previousNumber < max) { + result.add(previousNumber + 1); + } + } else { // enumerate all options within the min-max range + result = new ArrayList<>((max - min) + 1); + for (int n = min; n <= max; n++) { + result.add(n); + } + } + return result; + } + + // Static 
helper methods ------------------------------------------------------------------------------------------- + private static void setInstanceSpace(InstanceSearchSpace inputSpace, InstanceSearchSpace outputSpace, TreeSet<Long> memoryEstimates) { + TreeSet<Long> memoryPoints = getMemoryPoints(memoryEstimates, inputSpace.keySet()); + for (long memory: memoryPoints) { + outputSpace.put(memory, inputSpace.get(memory)); + } + // in case no large enough memory estimates exist set the instances with minimal memory + if (outputSpace.isEmpty()) { + long minMemory = inputSpace.firstKey(); + outputSpace.put(minMemory, inputSpace.get(minMemory)); + } + } + + /** + * @param availableMemory should be always a sorted set; + * this is always the case for the result of {@code keySet()} called on {@code TreeMap} + */ + private static TreeSet<Long> getMemoryPoints(TreeSet<Long> estimates, Set<Long> availableMemory) { + // use tree set to avoid adding duplicates and ensure ascending order + TreeSet<Long> result = new TreeSet<>(); + // assumed ascending order + List<Long> relevantPoints = new ArrayList<>(availableMemory); + for (long estimate: estimates) { + if (availableMemory.isEmpty()) { + break; + } + // divide list on larger and smaller by partitioning - partitioning preserve the order + Map<Boolean, List<Long>> divided = relevantPoints.stream() + .collect(Collectors.partitioningBy(n -> n < estimate)); + // get the points smaller than the current memory estimate + List<Long> smallerPoints = divided.get(true); + long largestOfTheSmaller = smallerPoints.isEmpty() ? -1 : smallerPoints.get(smallerPoints.size() - 1); + // reduce the list of relevant points - equal or larger than the estimate + relevantPoints = divided.get(false); + // get points greater or equal than the current memory estimate + long smallestOfTheLarger = relevantPoints.isEmpty()? 
-1 : relevantPoints.get(0); + + if (USE_MEMORY_DELTA) { + // Delta memory of 5% of the node's memory allows not ignoring + // memory points with potentially equivalent values but not exactly the same values. + // This is the case for example in AWS for instances of the same type but with + // different additional capabilities: m5.xlarge (16GB) vs m5n.xlarge (15.25GB). + // Get points smaller than the current memory estimate within the memory delta + long memoryDelta = Math.round(estimate * MEMORY_DELTA_FRACTION); + for (long point : smallerPoints) { + if (point >= (largestOfTheSmaller - memoryDelta)) { + result.add(point); + } + } + for (long point : relevantPoints) { + if (point <= (smallestOfTheLarger + memoryDelta)) { + result.add(point); + } else { + break; + } + } + } else { + if (largestOfTheSmaller > 0) { + result.add(largestOfTheSmaller); + } + if (smallestOfTheLarger > 0) { + result.add(smallestOfTheLarger); + } + } + } + return result; + } + + /** + * Extracts the memory estimates whose original size is larger than {@code MINIMUM_RELEVANT_MEM_ESTIMATE} + * + * @param currentProgram program for extracting the memory estimates from + * @param outputOnly {@code true} - output estimate only; + * {@code false} - sum of input, intermediate and output estimates + * @param memoryFactor factor for reverse scaling the estimates to avoid + * scaling the search space parameters representing the nodes' memory budget + * @return memory estimates in ascending order ensured by the {@code TreeSet} data structure + */ + public static TreeSet<Long> getMemoryEstimates(Program currentProgram, boolean outputOnly, double memoryFactor) { + TreeSet<Long> estimates = new TreeSet<>(); + getMemoryEstimates(currentProgram.getProgramBlocks(), estimates, outputOnly); + return estimates.stream() + .filter(mem -> mem > MINIMUM_RELEVANT_MEM_ESTIMATE) + .map(mem -> (long) (mem / memoryFactor)) + .collect(Collectors.toCollection(TreeSet::new)); + } + + private static void 
getMemoryEstimates(ArrayList<ProgramBlock> pbs, TreeSet<Long> mem, boolean outputOnly) { + for( ProgramBlock pb : pbs ) { + getMemoryEstimates(pb, mem, outputOnly); + } + } + + private static void getMemoryEstimates(ProgramBlock pb, TreeSet<Long> mem, boolean outputOnly) { + if (pb instanceof FunctionProgramBlock) + { + FunctionProgramBlock fpb = (FunctionProgramBlock)pb; + getMemoryEstimates(fpb.getChildBlocks(), mem, outputOnly); + } + else if (pb instanceof WhileProgramBlock) + { + WhileProgramBlock fpb = (WhileProgramBlock)pb; + getMemoryEstimates(fpb.getChildBlocks(), mem, outputOnly); + } + else if (pb instanceof IfProgramBlock) + { + IfProgramBlock fpb = (IfProgramBlock)pb; + getMemoryEstimates(fpb.getChildBlocksIfBody(), mem, outputOnly); + getMemoryEstimates(fpb.getChildBlocksElseBody(), mem, outputOnly); + } + else if (pb instanceof ForProgramBlock) // including parfor + { + ForProgramBlock fpb = (ForProgramBlock)pb; + getMemoryEstimates(fpb.getChildBlocks(), mem, outputOnly); + } + else + { + StatementBlock sb = pb.getStatementBlock(); + if( sb != null && sb.getHops() != null ){ + Hop.resetVisitStatus(sb.getHops()); + for( Hop hop : sb.getHops() ) + getMemoryEstimates(hop, mem, outputOnly); + } + } + } + + private static void getMemoryEstimates(Hop hop, TreeSet<Long> mem, boolean outputOnly) + { + if( hop.isVisited() ) + return; + //process children + for(Hop hi : hop.getInput()) + getMemoryEstimates(hi, mem, outputOnly); + + if (outputOnly) { + long estimate = (long) hop.getOutputMemEstimate(0); + if (estimate > 0) + mem.add(estimate); + } else { + mem.add((long) hop.getMemEstimate()); + } + hop.setVisited(); + } +} diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java b/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java index c0803cbcc7..76e664245c 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java +++ 
b/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java @@ -159,22 +159,10 @@ public class SparkExecutionContext extends ExecutionContext return _spctx; } - public static void initVirtualSparkContext(SparkConf sparkConf) { - if (_spctx != null) { - for (Tuple2<String, String> pair : sparkConf.getAll()) { - _spctx.sc().getConf().set(pair._1, pair._2); - } - } else { - handleIllegalReflectiveAccessSpark(); - try { - _spctx = new JavaSparkContext(sparkConf); - // assumes NON-legacy spark version - _sconf = new SparkClusterConfig(); - } catch (Exception e) { - throw new RuntimeException(e); - } + public static void initLocalSparkContext(SparkConf sparkConf) { + if (_sconf == null) { + _sconf = new SparkClusterConfig(); } - _sconf.analyzeSparkConfiguation(sparkConf); } @@ -1858,7 +1846,7 @@ public class SparkExecutionContext extends ExecutionContext private static final double BROADCAST_DATA_FRACTION_LEGACY = 0.35; //forward private config from Spark's UnifiedMemoryManager.scala (>1.6) - private static final long RESERVED_SYSTEM_MEMORY_BYTES = 300 * 1024 * 1024; + public static final long RESERVED_SYSTEM_MEMORY_BYTES = 300 * 1024 * 1024; //meta configurations private boolean _legacyVersion = false; //spark version <1.6 @@ -1985,7 +1973,7 @@ public class SparkExecutionContext extends ExecutionContext _confOnly &= true; } else if( DMLScript.USE_LOCAL_SPARK_CONFIG ) { - //avoid unnecessary spark context creation in local mode (e.g., tests) + //avoid unnecessary spark context creation in local mode (e.g., tests, resource opt.) 
_numExecutors = 1; _defaultPar = 2; _confOnly &= true; diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/stat/InfrastructureAnalyzer.java b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/stat/InfrastructureAnalyzer.java index 98c6848839..16f3cf7616 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/stat/InfrastructureAnalyzer.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/stat/InfrastructureAnalyzer.java @@ -56,7 +56,7 @@ public class InfrastructureAnalyzer private static int _remoteParReduce = -1; private static boolean _localJT = false; private static long _blocksize = -1; - + //static initialization, called for each JVM (on each node) static { //analyze local node properties @@ -136,7 +136,11 @@ public class InfrastructureAnalyzer public static void setLocalMaxMemory( long localMem ) { _localJVMMaxMem = localMem; } - + + public static void setLocalPar(int localPar) { + _localPar = localPar; + } + public static double getLocalMaxMemoryFraction() { //since parfor modifies _localJVMMaxMem, some internal primitives //need access to the current fraction of total local memory diff --git a/src/test/java/org/apache/sysds/test/component/resource/CloudUtilsTests.java b/src/test/java/org/apache/sysds/test/component/resource/CloudUtilsTests.java new file mode 100644 index 0000000000..b7e08ae35d --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/resource/CloudUtilsTests.java @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.component.resource; + +import org.apache.sysds.resource.AWSUtils; +import org.apache.sysds.resource.CloudInstance; +import org.apache.sysds.resource.CloudUtils.InstanceType; +import org.apache.sysds.resource.CloudUtils.InstanceSize; +import org.junit.Test; + +import java.io.File; +import java.io.IOException; +import java.nio.file.Files; +import java.util.HashMap; + +import static org.apache.sysds.test.component.resource.TestingUtils.assertEqualsCloudInstances; +import static org.apache.sysds.test.component.resource.TestingUtils.getSimpleCloudInstanceMap; +import static org.junit.Assert.*; + +@net.jcip.annotations.NotThreadSafe +public class CloudUtilsTests { + + @Test + public void getInstanceTypeAWSTest() { + AWSUtils utils = new AWSUtils(); + + InstanceType expectedValue = InstanceType.M5; + InstanceType actualValue; + + actualValue = utils.getInstanceType("m5.xlarge"); + assertEquals(expectedValue, actualValue); + + actualValue = utils.getInstanceType("M5.XLARGE"); + assertEquals(expectedValue, actualValue); + + try { + utils.getInstanceType("NON-M5.xlarge"); + fail("Throwing IllegalArgumentException was expected"); + } catch (IllegalArgumentException e) { + // this block ensures correct execution of the test + } + } + + @Test + public void getInstanceSizeAWSTest() { + AWSUtils utils = new AWSUtils(); + + InstanceSize expectedValue = InstanceSize._XLARGE; + InstanceSize actualValue; + + actualValue = utils.getInstanceSize("m5.xlarge"); + assertEquals(expectedValue, actualValue); + + actualValue = 
utils.getInstanceSize("M5.XLARGE"); + assertEquals(expectedValue, actualValue); + + try { + utils.getInstanceSize("m5.nonxlarge"); + fail("Throwing IllegalArgumentException was expected"); + } catch (IllegalArgumentException e) { + // this block ensures correct execution of the test + } + } + + @Test + public void validateInstanceNameAWSTest() { + AWSUtils utils = new AWSUtils(); + + // basic intel instance (old) + assertTrue(utils.validateInstanceName("m5.2xlarge")); + assertTrue(utils.validateInstanceName("M5.2XLARGE")); + // basic intel instance (new) + assertTrue(utils.validateInstanceName("m6i.xlarge")); + // basic amd instance + assertTrue(utils.validateInstanceName("m6a.xlarge")); + // basic graviton instance + assertTrue(utils.validateInstanceName("m6g.xlarge")); + // invalid values + assertFalse(utils.validateInstanceName("v5.xlarge")); + assertFalse(utils.validateInstanceName("m5.notlarge")); + assertFalse(utils.validateInstanceName("m5xlarge")); + assertFalse(utils.validateInstanceName(".xlarge")); + assertFalse(utils.validateInstanceName("m5.")); + } + + @Test + public void loadCSVFileAWSTest() throws IOException { + AWSUtils utils = new AWSUtils(); + + File tmpFile = TestingUtils.generateTmpInstanceInfoTableFile(); + + HashMap<String, CloudInstance> actual = utils.loadInstanceInfoTable(tmpFile.getPath()); + HashMap<String, CloudInstance> expected = getSimpleCloudInstanceMap(); + + for (String instanceName: expected.keySet()) { + assertEqualsCloudInstances(expected.get(instanceName), actual.get(instanceName)); + } + + Files.deleteIfExists(tmpFile.toPath()); + } +} diff --git a/src/test/java/org/apache/sysds/test/component/resource/EnumeratorTests.java b/src/test/java/org/apache/sysds/test/component/resource/EnumeratorTests.java new file mode 100644 index 0000000000..437770ce28 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/resource/EnumeratorTests.java @@ -0,0 +1,529 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under 
one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.component.resource; + +import org.apache.sysds.hops.OptimizerUtils; +import org.apache.sysds.resource.CloudInstance; +import org.apache.sysds.resource.enumeration.Enumerator; +import org.apache.sysds.resource.enumeration.EnumerationUtils.InstanceSearchSpace; +import org.apache.sysds.resource.enumeration.EnumerationUtils.ConfigurationPoint; +import org.apache.sysds.resource.enumeration.EnumerationUtils.SolutionPoint; +import org.apache.sysds.resource.enumeration.InterestBasedEnumerator; +import org.apache.sysds.runtime.controlprogram.Program; +import org.junit.Assert; +import org.junit.Test; +import org.mockito.MockedStatic; +import org.mockito.Mockito; + +import java.io.File; +import java.io.IOException; +import java.nio.file.Files; +import java.util.*; + +import static org.apache.sysds.resource.CloudUtils.GBtoBytes; +import static org.junit.Assert.*; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; + +@net.jcip.annotations.NotThreadSafe +public class EnumeratorTests { + + @Test + public void loadInstanceTableTest() throws IOException { + // loading the table is entirely implemented by the abstract class + // use any enumerator + Enumerator 
anyEnumerator = getGridBasedEnumeratorPrebuild() + .withInstanceTypeRange(new String[]{"m5"}) + .withInstanceSizeRange(new String[]{"xlarge"}) + .build(); + + File tmpFile = TestingUtils.generateTmpInstanceInfoTableFile(); + anyEnumerator.loadInstanceTableFile(tmpFile.toString()); + + HashMap<String, CloudInstance> actualInstances = anyEnumerator.getInstances(); + + Assert.assertEquals(1, actualInstances.size()); + Assert.assertNotNull(actualInstances.get("m5.xlarge")); + + Files.deleteIfExists(tmpFile.toPath()); + } + + @Test + public void preprocessingGridBasedTest() { + Enumerator gridBasedEnumerator = getGridBasedEnumeratorPrebuild().build(); + + HashMap<String, CloudInstance> instances = TestingUtils.getSimpleCloudInstanceMap(); + gridBasedEnumerator.setInstanceTable(instances); + + gridBasedEnumerator.preprocessing(); + // assertions for driver space + InstanceSearchSpace driverSpace = gridBasedEnumerator.getDriverSpace(); + assertEquals(3, driverSpace.size()); + assertInstanceInSearchSpace("c5.xlarge", driverSpace, 8, 4, 0); + assertInstanceInSearchSpace("m5.xlarge", driverSpace, 16, 4, 0); + assertInstanceInSearchSpace("c5.2xlarge", driverSpace, 16, 8, 0); + assertInstanceInSearchSpace("m5.2xlarge", driverSpace, 32, 8, 0); + // assertions for executor space + InstanceSearchSpace executorSpace = gridBasedEnumerator.getDriverSpace(); + assertEquals(3, executorSpace.size()); + assertInstanceInSearchSpace("c5.xlarge", executorSpace, 8, 4, 0); + assertInstanceInSearchSpace("m5.xlarge", executorSpace, 16, 4, 0); + assertInstanceInSearchSpace("c5.2xlarge", executorSpace, 16, 8, 0); + assertInstanceInSearchSpace("m5.2xlarge", executorSpace, 32, 8, 0); + } + + @Test + public void preprocessingInterestBasedDriverMemoryTest() { + Enumerator interestBasedEnumerator = getInterestBasedEnumeratorPrebuild() + .withFitDriverMemory(true) + .withFitBroadcastMemory(false) + .build(); + + HashMap<String, CloudInstance> instances = TestingUtils.getSimpleCloudInstanceMap(); + 
interestBasedEnumerator.setInstanceTable(instances); + + // use 10GB (scaled) memory estimate to be between the available 8GB and 16GB driver node's memory + TreeSet<Long> mockingMemoryEstimates = new TreeSet<>(Set.of(GBtoBytes(10))); + try (MockedStatic<InterestBasedEnumerator> mockedEnumerator = + Mockito.mockStatic(InterestBasedEnumerator.class, Mockito.CALLS_REAL_METHODS)) { + mockedEnumerator + .when(() -> InterestBasedEnumerator.getMemoryEstimates( + any(Program.class), + eq(false), + eq(OptimizerUtils.MEM_UTIL_FACTOR))) + .thenReturn(mockingMemoryEstimates); + interestBasedEnumerator.preprocessing(); + } + + // assertions for driver space + InstanceSearchSpace driverSpace = interestBasedEnumerator.getDriverSpace(); + assertEquals(2, driverSpace.size()); + assertInstanceInSearchSpace("c5.xlarge", driverSpace, 8, 4, 0); + assertInstanceInSearchSpace("m5.xlarge", driverSpace, 16, 4, 0); + assertInstanceInSearchSpace("c5.2xlarge", driverSpace, 16, 8, 0); + Assert.assertNull(driverSpace.get(GBtoBytes(32))); + // assertions for executor space + InstanceSearchSpace executorSpace = interestBasedEnumerator.getExecutorSpace(); + assertEquals(3, executorSpace.size()); + assertInstanceInSearchSpace("c5.xlarge", executorSpace, 8, 4, 0); + assertInstanceInSearchSpace("m5.xlarge", executorSpace, 16, 4, 0); + assertInstanceInSearchSpace("c5.2xlarge", executorSpace, 16, 8, 0); + assertInstanceInSearchSpace("m5.2xlarge", executorSpace, 32, 8, 0); + } + + @Test + public void preprocessingInterestBasedBroadcastMemoryTest() { + Enumerator interestBasedEnumerator = getInterestBasedEnumeratorPrebuild() + .withFitDriverMemory(false) + .withFitBroadcastMemory(true) + .build(); + + HashMap<String, CloudInstance> instances = TestingUtils.getSimpleCloudInstanceMap(); + interestBasedEnumerator.setInstanceTable(instances); + + double outputEstimate = 2.5; + double scaledOutputEstimateBroadcast = outputEstimate / InterestBasedEnumerator.BROADCAST_MEMORY_FACTOR; // ~=12 + // 
scaledOutputEstimateCP = 2 * outputEstimate / OptimizerUtils.MEM_UTIL_FACTOR ~= 7 + TreeSet<Long> mockingMemoryEstimates = new TreeSet<>(Set.of(GBtoBytes(scaledOutputEstimateBroadcast))); + try (MockedStatic<InterestBasedEnumerator> mockedEnumerator = + Mockito.mockStatic(InterestBasedEnumerator.class, Mockito.CALLS_REAL_METHODS)) { + mockedEnumerator + .when(() -> InterestBasedEnumerator.getMemoryEstimates( + any(Program.class), + eq(true), + eq(InterestBasedEnumerator.BROADCAST_MEMORY_FACTOR))) + .thenReturn(mockingMemoryEstimates); + interestBasedEnumerator.preprocessing(); + } + + // assertions for driver space + InstanceSearchSpace driverSpace = interestBasedEnumerator.getDriverSpace(); + assertEquals(1, driverSpace.size()); + assertInstanceInSearchSpace("c5.xlarge", driverSpace, 8, 4, 0); + Assert.assertNull(driverSpace.get(GBtoBytes(16))); + Assert.assertNull(driverSpace.get(GBtoBytes(32))); + // assertions for executor space + InstanceSearchSpace executorSpace = interestBasedEnumerator.getExecutorSpace(); + assertEquals(2, executorSpace.size()); + assertInstanceInSearchSpace("c5.xlarge", executorSpace, 8, 4, 0); + assertInstanceInSearchSpace("m5.xlarge", executorSpace, 16, 4, 0); + assertInstanceInSearchSpace("c5.2xlarge", executorSpace, 16, 8, 0); + Assert.assertNull(executorSpace.get(GBtoBytes(32))); + } + + @Test + public void evaluateSingleNodeExecutionGridBasedTest() { + Enumerator gridBasedEnumerator; + boolean result; + + gridBasedEnumerator = getGridBasedEnumeratorPrebuild() + .withNumberExecutorsRange(0,1) + .build(); + + // memory not relevant for grid-based enumerator + result = gridBasedEnumerator.evaluateSingleNodeExecution(-1); + Assert.assertTrue(result); + + gridBasedEnumerator = getGridBasedEnumeratorPrebuild() + .withNumberExecutorsRange(1,2) + .build(); + + // memory not relevant for grid-based enumerator + result = gridBasedEnumerator.evaluateSingleNodeExecution(-1); + Assert.assertFalse(result); + } + + @Test + public void 
estimateRangeExecutorsGridBasedStepSizeTest() { + Enumerator gridBasedEnumerator; + ArrayList<Integer> expectedResult; + ArrayList<Integer> actualResult; + + // num. executors range starting from zero and step size = 2 + gridBasedEnumerator = getGridBasedEnumeratorPrebuild() + .withNumberExecutorsRange(0, 10) + .withStepSizeExecutor(2) + .build(); + // test the general case when the max level of parallelism is not reached (0 is never part of the result) + expectedResult = new ArrayList<>(List.of(2, 4, 6, 8, 10)); + actualResult = gridBasedEnumerator.estimateRangeExecutors(-1, 4); + Assert.assertEquals(expectedResult, actualResult); + // test the case when the max level of parallelism (1000) is reached (0 is never part of the result) + expectedResult = new ArrayList<>(List.of(2, 4)); + actualResult = gridBasedEnumerator.estimateRangeExecutors(-1, 200); + Assert.assertEquals(expectedResult, actualResult); + + // num. executors range not starting from zero and without step size given + gridBasedEnumerator = getGridBasedEnumeratorPrebuild() + .withNumberExecutorsRange(3, 8) + .build(); + // test the general case when the max level of parallelism is not reached (0 is never part of the result) + expectedResult = new ArrayList<>(List.of(3, 4, 5, 6, 7, 8)); + actualResult = gridBasedEnumerator.estimateRangeExecutors(-1, 4); + Assert.assertEquals(expectedResult, actualResult); + // test the case when the max level of parallelism (1000) is reached (0 is never part of the result) + expectedResult = new ArrayList<>(List.of(3, 4, 5)); + actualResult = gridBasedEnumerator.estimateRangeExecutors(-1, 200); + Assert.assertEquals(expectedResult, actualResult); + } + + @Test + public void estimateRangeExecutorsGridBasedExpBaseTest() { + Enumerator gridBasedEnumerator; + ArrayList<Integer> expectedResult; + ArrayList<Integer> actualResult; + + // num. 
executors range starting from zero and exponential base = 2 + gridBasedEnumerator = getGridBasedEnumeratorPrebuild() + .withNumberExecutorsRange(0, 10) + .withExpBaseExecutors(2) + .build(); + // test the general case when the max level of parallelism is not reached (0 is never part of the result) + expectedResult = new ArrayList<>(List.of(1, 2, 4, 8)); + actualResult = gridBasedEnumerator.estimateRangeExecutors(-1, 4); + Assert.assertEquals(expectedResult, actualResult); + // test the case when the max level of parallelism (1000) is reached (0 is never part of the result) + expectedResult = new ArrayList<>(List.of(1, 2, 4)); + actualResult = gridBasedEnumerator.estimateRangeExecutors(-1, 200); + Assert.assertEquals(expectedResult, actualResult); + + // num. executors range not starting from zero and with exponential base = 3 + gridBasedEnumerator = getGridBasedEnumeratorPrebuild() + .withNumberExecutorsRange(3, 30) + .withExpBaseExecutors(3) + .build(); + // test the general case when the max level of parallelism is not reached (0 is never part of the result) + expectedResult = new ArrayList<>(List.of(3,9, 27)); + actualResult = gridBasedEnumerator.estimateRangeExecutors(-1, 4); + Assert.assertEquals(expectedResult, actualResult); + // test the case when the max level of parallelism (1000) is reached (0 is never part of the result) + expectedResult = new ArrayList<>(List.of(3,9)); + actualResult = gridBasedEnumerator.estimateRangeExecutors(-1, 100); + Assert.assertEquals(expectedResult, actualResult); + } + + @Test + public void evaluateSingleNodeExecutionInterestBasedTest() { + boolean result; + + // no fitting the memory estimates for checkpointing + Enumerator interestBasedEnumerator = getInterestBasedEnumeratorPrebuild() + .withNumberExecutorsRange(0, 5) + .withFitDriverMemory(false) + .withFitBroadcastMemory(false) + .withCheckSingleNodeExecution(true) + .build(); + + HashMap<String, CloudInstance> instances = TestingUtils.getSimpleCloudInstanceMap(); + 
interestBasedEnumerator.setInstanceTable(instances); + + TreeSet<Long> mockingMemoryEstimates = new TreeSet<>(Set.of(GBtoBytes(6), GBtoBytes(12))); + try (MockedStatic<InterestBasedEnumerator> mockedEnumerator = + Mockito.mockStatic(InterestBasedEnumerator.class, Mockito.CALLS_REAL_METHODS)) { + mockedEnumerator + .when(() -> InterestBasedEnumerator.getMemoryEstimates( + any(Program.class), + eq(false), + eq(OptimizerUtils.MEM_UTIL_FACTOR))) + .thenReturn(mockingMemoryEstimates); + // initiate memoryEstimatesSpark + interestBasedEnumerator.preprocessing(); + } + + result = interestBasedEnumerator.evaluateSingleNodeExecution(GBtoBytes(8)); + Assert.assertFalse(result); + } + + @Test + public void estimateRangeExecutorsInterestBasedGeneralTest() { + ArrayList<Integer> expectedResult; + ArrayList<Integer>actualResult; + + // no fitting the memory estimates for checkpointing + Enumerator interestBasedEnumerator = getInterestBasedEnumeratorPrebuild() + .withNumberExecutorsRange(0, 5) + .build(); + // test the general case when the max level of parallelism is not reached (0 is never part of the result) + expectedResult = new ArrayList<>(List.of(1, 2, 3, 4, 5)); + actualResult = interestBasedEnumerator.estimateRangeExecutors(-1, 4); + Assert.assertEquals(expectedResult, actualResult); + // test the case when the max level of parallelism (1000) is reached (0 is never part of the result) + expectedResult = new ArrayList<>(List.of(1, 2, 3)); + actualResult = interestBasedEnumerator.estimateRangeExecutors(-1, 256); + Assert.assertEquals(expectedResult, actualResult); + } + + @Test + public void estimateRangeExecutorsInterestBasedCheckpointMemoryTest() { + ArrayList<Integer> expectedResult; + ArrayList<Integer>actualResult; + + // fitting the memory estimates for checkpointing + Enumerator interestBasedEnumerator = getInterestBasedEnumeratorPrebuild() + .withNumberExecutorsRange(0, 5) + .withFitCheckpointMemory(true) + .withFitDriverMemory(false) + 
.withFitBroadcastMemory(false) + .build(); + + HashMap<String, CloudInstance> instances = TestingUtils.getSimpleCloudInstanceMap(); + interestBasedEnumerator.setInstanceTable(instances); + + TreeSet<Long> mockingMemoryEstimates = new TreeSet<>(Set.of(GBtoBytes(20), GBtoBytes(40))); + try (MockedStatic<InterestBasedEnumerator> mockedEnumerator = + Mockito.mockStatic(InterestBasedEnumerator.class, Mockito.CALLS_REAL_METHODS)) { + mockedEnumerator + .when(() -> InterestBasedEnumerator.getMemoryEstimates( + any(Program.class), + eq(true), + eq(InterestBasedEnumerator.BROADCAST_MEMORY_FACTOR))) + .thenReturn(mockingMemoryEstimates); + // initiate memoryEstimatesSpark + interestBasedEnumerator.preprocessing(); + } + + // test the general case when the max level of parallelism is not reached (0 is never part of the result) + expectedResult = new ArrayList<>(List.of(1, 2, 3)); + actualResult = interestBasedEnumerator.estimateRangeExecutors(GBtoBytes(16), 4); + Assert.assertEquals(expectedResult, actualResult); + // test the case when the max level of parallelism (1000) is reached (0 is never part of the result) + expectedResult = new ArrayList<>(List.of(1, 2)); + actualResult = interestBasedEnumerator.estimateRangeExecutors(GBtoBytes(16), 500); + Assert.assertEquals(expectedResult, actualResult); + } + + @Test + public void processingTest() { + // all implemented enumerators should enumerate the same solution pool in this basic case - empty program + Enumerator gridBasedEnumerator = getGridBasedEnumeratorPrebuild() + .withTimeLimit(Double.MAX_VALUE) + .withNumberExecutorsRange(0, 2) + .build(); + + Enumerator interestBasedEnumerator = getInterestBasedEnumeratorPrebuild() + .withNumberExecutorsRange(0, 2) + .build(); + + HashMap<String, CloudInstance> instances = TestingUtils.getSimpleCloudInstanceMap(); + InstanceSearchSpace space = new InstanceSearchSpace(); + space.initSpace(instances); + + // run processing for the grid based enumerator + 
gridBasedEnumerator.setDriverSpace(space); + gridBasedEnumerator.setExecutorSpace(space); + gridBasedEnumerator.processing(); + ArrayList<SolutionPoint> actualSolutionPoolGB = gridBasedEnumerator.getSolutionPool(); + // run processing for the interest based enumerator + interestBasedEnumerator.setDriverSpace(space); + interestBasedEnumerator.setExecutorSpace(space); + interestBasedEnumerator.processing(); + ArrayList<SolutionPoint> actualSolutionPoolIB = gridBasedEnumerator.getSolutionPool(); + + + ArrayList<CloudInstance> expectedInstances = new ArrayList<>(Arrays.asList( + instances.get("c5.xlarge"), + instances.get("m5.xlarge") + )); + // expected solution pool with 0 executors (number executors = 0, executors and executorInstance being null) + // each solution having one of the available instances as driver node + Assert.assertEquals(expectedInstances.size(), actualSolutionPoolGB.size()); + Assert.assertEquals(expectedInstances.size(), actualSolutionPoolIB.size()); + for (int i = 0; i < expectedInstances.size(); i++) { + SolutionPoint pointGB = actualSolutionPoolGB.get(i); + Assert.assertEquals(0, pointGB.numberExecutors); + Assert.assertEquals(expectedInstances.get(i), pointGB.driverInstance); + Assert.assertNull(pointGB.executorInstance); + SolutionPoint pointIB = actualSolutionPoolGB.get(i); + Assert.assertEquals(0, pointIB.numberExecutors); + Assert.assertEquals(expectedInstances.get(i), pointIB.driverInstance); + Assert.assertNull(pointIB.executorInstance); + } + } + + @Test + public void postprocessingTest() { + // postprocessing equivalent for all types of enumerators + Enumerator enumerator = getGridBasedEnumeratorPrebuild().build(); + // construct solution pool + // first dummy configuration point since not relevant for postprocessing + ConfigurationPoint dummyPoint = new ConfigurationPoint(null); + SolutionPoint solution1 = new SolutionPoint(dummyPoint, 1000, 1000); + SolutionPoint solution2 = new SolutionPoint(dummyPoint, 900, 1000); // optimal point 
+ SolutionPoint solution3 = new SolutionPoint(dummyPoint, 800, 10000); + SolutionPoint solution4 = new SolutionPoint(dummyPoint, 1000, 10000); + SolutionPoint solution5 = new SolutionPoint(dummyPoint, 900, 10000); + ArrayList<SolutionPoint> mockListSolutions = new ArrayList<>(List.of(solution1, solution2, solution3, solution4, solution5)); + enumerator.setSolutionPool(mockListSolutions); + + SolutionPoint optimalSolution = enumerator.postprocessing(); + assertEquals(solution2, optimalSolution); + } + + @Test + public void GridBasedEnumerationMinPriceTest() { + Enumerator gridBasedEnumerator = getGridBasedEnumeratorPrebuild() + .withNumberExecutorsRange(0, 2) + .build(); + + gridBasedEnumerator.setInstanceTable(TestingUtils.getSimpleCloudInstanceMap()); + + gridBasedEnumerator.preprocessing(); + gridBasedEnumerator.processing(); + SolutionPoint solution = gridBasedEnumerator.postprocessing(); + + // expected m5.xlarge since it is the cheaper + Assert.assertEquals("m5.xlarge", solution.driverInstance.getInstanceName()); + // expected no executor nodes since tested for a 'zero' program + Assert.assertEquals(0, solution.numberExecutors); + } + + @Test + public void InterestBasedEnumerationMinPriceTest() { + Enumerator interestBasedEnumerator = getInterestBasedEnumeratorPrebuild() + .withNumberExecutorsRange(0, 2) + .build(); + + interestBasedEnumerator.setInstanceTable(TestingUtils.getSimpleCloudInstanceMap()); + + interestBasedEnumerator.preprocessing(); + interestBasedEnumerator.processing(); + SolutionPoint solution = interestBasedEnumerator.postprocessing(); + + // expected c5.xlarge since is the instance with at least memory + Assert.assertEquals("c5.xlarge", solution.driverInstance.getInstanceName()); + // expected no executor nodes since tested for a 'zero' program + Assert.assertEquals(0, solution.numberExecutors); + } + + @Test + public void GridBasedEnumerationMinTimeTest() { + Enumerator gridBasedEnumerator = getGridBasedEnumeratorPrebuild() + 
.withOptimizationStrategy(Enumerator.OptimizationStrategy.MinTime) + .withBudget(Double.MAX_VALUE) + .withNumberExecutorsRange(0, 2) + .build(); + + gridBasedEnumerator.setInstanceTable(TestingUtils.getSimpleCloudInstanceMap()); + + gridBasedEnumerator.preprocessing(); + gridBasedEnumerator.processing(); + SolutionPoint solution = gridBasedEnumerator.postprocessing(); + + // expected m5.xlarge since it is the cheaper + Assert.assertEquals("m5.xlarge", solution.driverInstance.getInstanceName()); + // expected no executor nodes since tested for a 'zero' program + Assert.assertEquals(0, solution.numberExecutors); + } + + @Test + public void InterestBasedEnumerationMinTimeTest() { + Enumerator interestBasedEnumerator = getInterestBasedEnumeratorPrebuild() + .withOptimizationStrategy(Enumerator.OptimizationStrategy.MinTime) + .withBudget(Double.MAX_VALUE) + .withNumberExecutorsRange(0, 2) + .build(); + + interestBasedEnumerator.setInstanceTable(TestingUtils.getSimpleCloudInstanceMap()); + + interestBasedEnumerator.preprocessing(); + interestBasedEnumerator.processing(); + SolutionPoint solution = interestBasedEnumerator.postprocessing(); + + // expected c5.xlarge since is the instance with at least memory + Assert.assertEquals("c5.xlarge", solution.driverInstance.getInstanceName()); + // expected no executor nodes since tested for a 'zero' program + Assert.assertEquals(0, solution.numberExecutors); + } + + // Helpers + private static Enumerator.Builder getGridBasedEnumeratorPrebuild() { + Program emptyProgram = new Program(); + return (new Enumerator.Builder()) + .withRuntimeProgram(emptyProgram) + .withEnumerationStrategy(Enumerator.EnumerationStrategy.GridBased) + .withOptimizationStrategy(Enumerator.OptimizationStrategy.MinPrice) + .withTimeLimit(Double.MAX_VALUE); + } + + private static Enumerator.Builder getInterestBasedEnumeratorPrebuild() { + Program emptyProgram = new Program(); + return (new Enumerator.Builder()) + .withRuntimeProgram(emptyProgram) + 
.withEnumerationStrategy(Enumerator.EnumerationStrategy.InterestBased) + .withOptimizationStrategy(Enumerator.OptimizationStrategy.MinPrice) + .withTimeLimit(Double.MAX_VALUE); + } + + private static void assertInstanceInSearchSpace( + String expectedName, + InstanceSearchSpace searchSpace, + int memory, /* in GB */ + int cores, + int index + ) { + Assert.assertNotNull(searchSpace.get(GBtoBytes(memory))); + try { + String actualName = searchSpace.get(GBtoBytes(memory)).get(cores).get(index).getInstanceName(); + Assert.assertEquals(expectedName, actualName); + } catch (NullPointerException e) { + fail(expectedName+" instances not properly passed to "+searchSpace.getClass().getName()); + } + } +} diff --git a/src/test/java/org/apache/sysds/test/component/resource/RecompilationTest.java b/src/test/java/org/apache/sysds/test/component/resource/RecompilationTest.java new file mode 100644 index 0000000000..b2631684c2 --- /dev/null +++ b/src/test/java/org/apache/sysds/test/component/resource/RecompilationTest.java @@ -0,0 +1,258 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysds.test.component.resource; + +import org.apache.sysds.resource.ResourceCompiler; +import org.apache.sysds.runtime.controlprogram.BasicProgramBlock; +import org.apache.sysds.runtime.controlprogram.Program; +import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext; +import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer; +import org.apache.sysds.runtime.instructions.Instruction; +import org.apache.sysds.runtime.instructions.spark.SPInstruction; +import org.apache.sysds.test.AutomatedTestBase; +import org.apache.sysds.utils.Explain; + +import org.junit.Assert; +import org.junit.Test; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; + +import static org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext.SparkClusterConfig.RESERVED_SYSTEM_MEMORY_BYTES; + +public class RecompilationTest extends AutomatedTestBase { + private static final boolean DEBUG_MODE = true; + private static final String TEST_DIR = "component/resource/"; + private static final String TEST_DATA_DIR = "component/resource/data/"; + private static final String HOME = SCRIPT_DIR + TEST_DIR; + private static final String HOME_DATA = SCRIPT_DIR + TEST_DATA_DIR; + // Static Configuration values ------------------------------------------------------------------------------------- + private static final int driverThreads = 4; + private static final int executorThreads = 2; + + @Override + public void setUp() {} + + // Tests for setting cluster configurations ------------------------------------------------------------------------ + + @Test + public void testSetDriverConfigurations() { + long expectedMemory = 1024*1024*1024; // 1GB + int expectedThreads = 4; + + ResourceCompiler.setDriverConfigurations(expectedMemory, expectedThreads); + + Assert.assertEquals(expectedMemory, InfrastructureAnalyzer.getLocalMaxMemory()); + 
Assert.assertEquals(expectedThreads, InfrastructureAnalyzer.getLocalParallelism()); + } + + @Test + public void testSetExecutorConfigurations() { + int numberExecutors = 10; + long executorMemory = 1024*1024*1024; // 1GB + long expectedMemoryBudget = (long) (numberExecutors*(executorMemory-RESERVED_SYSTEM_MEMORY_BYTES)*0.6); + int executorThreads = 4; + int expectedParallelism = numberExecutors*executorThreads; + + ResourceCompiler.setExecutorConfigurations(numberExecutors, executorMemory, executorThreads); + + Assert.assertEquals(numberExecutors, SparkExecutionContext.getNumExecutors()); + Assert.assertEquals(expectedMemoryBudget, (long) SparkExecutionContext.getDataMemoryBudget(false, false)); + Assert.assertEquals(expectedParallelism, SparkExecutionContext.getDefaultParallelism(false)); + } + + // Tests for regular matrix multiplication (X%*%Y) ----------------------------------------------------------------- + + @Test + public void test_CP_MM_Enforced() throws IOException { + // Single node cluster with 8GB driver memory -> ba+* operator + // X = A.csv: (10^5)x(10^4) = 10^9 ~ 8BG + // Y = B.csv: (10^4)x(10^3) = 10^7 ~ 80MB + // X %*% Y -> (10^5)x(10^3) = 10^8 ~ 800MB + runTestMM("A.csv", "B.csv", 8L*1024*1024*1024, 0, -1, "ba+*"); + } + + @Test + public void test_CP_MM_Preferred() throws IOException { + // Distributed cluster with 16GB driver memory (large enough to host the computation) and any executors + // X = A.csv: (10^5)x(10^4) = 10^9 ~ 8BG + // Y = B.csv: (10^4)x(10^3) = 10^7 ~ 80MB + // X %*% Y -> (10^5)x(10^3) = 10^8 ~ 800MB + runTestMM("A.csv", "B.csv", 16L*1024*1024*1024, 2, 1024*1024*1024, "ba+*"); + } + + @Test + public void test_SP_MAPMM() throws IOException { + // Distributed cluster with 4GB driver memory and 4GB executors -> mapmm operator + // X = A.csv: (10^5)x(10^4) = 10^9 ~ 8BG + // Y = B.csv: (10^4)x(10^3) = 10^7 ~ 80MB + // X %*% Y -> (10^5)x(10^3) = 10^8 ~ 800MB + runTestMM("A.csv", "B.csv", 4L*1024*1024*1024, 2, 4L*1024*1024*1024, 
"mapmm"); + } + + @Test + public void test_SP_RMM() throws IOException { + // Distributed cluster with 1GB driver memory and 500MB executors -> rmm operator + // X = A.csv: (10^5)x(10^4) = 10^9 ~ 8BG + // Y = B.csv: (10^4)x(10^3) = 10^7 ~ 80MB + // X %*% Y -> (10^5)x(10^3) = 10^8 ~ 800MB + runTestMM("A.csv", "B.csv", 1024*1024*1024, 2, (long) (0.5*1024*1024*1024), "rmm"); + } + + @Test + public void test_SP_CPMM() throws IOException { + // Distributed cluster with 8GB driver memory and 4GB executors -> cpmm operator + // X = A.csv: (10^5)x(10^4) = 10^9 ~ 8BG + // Y = C.csv: (10^4)x(10^4) = 10^8 ~ 800MB + // X %*% Y -> (10^5)x(10^4) = 10^9 ~ 8GB + runTestMM("A.csv", "C.csv", 8L*1024*1024*1024, 2, 4L*1024*1024*1024, "cpmm"); + } + + // Tests for transposed self matrix multiplication (t(X)%*%X) ------------------------------------------------------ + + @Test + public void test_CP_TSMM() throws IOException { + // Single node cluster with 8GB driver memory -> tsmm operator in CP + // X = B.csv: (10^4)x(10^3) = 10^7 ~ 80MB + // t(X) %*% X -> (10^3)x(10^3) = 10^6 ~ 8MB (single block) + runTestTSMM("B.csv", 8L*1024*1024*1024, 0, -1, "tsmm", false); + } + + @Test + public void test_SP_TSMM() throws IOException { + // Distributed cluster with 1GB driver memory and 8GB executor memory -> tsmm operator in Spark + // X = D.csv: (10^5)x(10^3) = 10^8 ~ 800MB + // t(X) %*% X -> (10^3)x(10^3) = 10^6 ~ 8MB (single block) + runTestTSMM("D.csv", 1024*1024*1024, 2, 8L*1024*1024*1024, "tsmm", true); + } + + @Test + public void test_SP_TSMM_as_CPMM() throws IOException { + // Distributed cluster with 8GB driver memory and 8GB executor memory -> cpmm operator in Spark + // X = A.csv: (10^5)x(10^4) = 10^9 ~ 8GB + // t(X) %*% X -> (10^4)x(10^4) = 10^8 ~ 800MB + runTestTSMM("A.csv", 8L*1024*1024*1024, 2, 8L*1024*1024*1024, "cpmm", true); + } + + @Test + public void test_MM_RecompilationSequence() throws IOException { + Map<String, String> nvargs = new HashMap<>(); + nvargs.put("$X", 
HOME_DATA+"A.csv"); + nvargs.put("$Y", HOME_DATA+"B.csv"); + + // pre-compiled program using default values to be used as source for the recompilation + Program precompiledProgram = generateInitialProgram(HOME+"mm_test.dml", nvargs); + // original compilation used for comparison + Program expectedProgram; + + ResourceCompiler.setDriverConfigurations(8L*1024*1024*1024, driverThreads); + ResourceCompiler.setSingleNodeExecution(); + expectedProgram = ResourceCompiler.compile(HOME+"mm_test.dml", nvargs); + runTest(precompiledProgram, expectedProgram, 8L*1024*1024*1024, 0, -1, "ba+*", false); + + ResourceCompiler.setDriverConfigurations(16L*1024*1024*1024, driverThreads); + ResourceCompiler.setExecutorConfigurations(2, 1024*1024*1024, executorThreads); + expectedProgram = ResourceCompiler.compile(HOME+"mm_test.dml", nvargs); + runTest(precompiledProgram, expectedProgram, 16L*1024*1024*1024, 2, 1024*1024*1024, "ba+*", false); + + ResourceCompiler.setDriverConfigurations(4L*1024*1024*1024, driverThreads); + ResourceCompiler.setExecutorConfigurations(2, 4L*1024*1024*1024, executorThreads); + expectedProgram = ResourceCompiler.compile(HOME+"mm_test.dml", nvargs); + runTest(precompiledProgram, expectedProgram, 4L*1024*1024*1024, 2, 4L*1024*1024*1024, "mapmm", true); + + ResourceCompiler.setDriverConfigurations(1024*1024*1024, driverThreads); + ResourceCompiler.setExecutorConfigurations(2, (long) (0.5*1024*1024*1024), executorThreads); + expectedProgram = ResourceCompiler.compile(HOME+"mm_test.dml", nvargs); + runTest(precompiledProgram, expectedProgram, 1024*1024*1024, 2, (long) (0.5*1024*1024*1024), "rmm", true); + + ResourceCompiler.setDriverConfigurations(8L*1024*1024*1024, driverThreads); + ResourceCompiler.setSingleNodeExecution(); + expectedProgram = ResourceCompiler.compile(HOME+"mm_test.dml", nvargs); + runTest(precompiledProgram, expectedProgram, 8L*1024*1024*1024, 0, -1, "ba+*", false); + } + + // Helper functions 
------------------------------------------------------------------------------------------------ + private Program generateInitialProgram(String filePath, Map<String, String> args) throws IOException { + ResourceCompiler.setDriverConfigurations(ResourceCompiler.DEFAULT_DRIVER_MEMORY, ResourceCompiler.DEFAULT_DRIVER_THREADS); + ResourceCompiler.setExecutorConfigurations(ResourceCompiler.DEFAULT_NUMBER_EXECUTORS, ResourceCompiler.DEFAULT_EXECUTOR_MEMORY, ResourceCompiler.DEFAULT_EXECUTOR_THREADS); + return ResourceCompiler.compile(filePath, args); + } + + private void runTestMM(String fileX, String fileY, long driverMemory, int numberExecutors, long executorMemory, String expectedOpcode) throws IOException { + boolean expectedSparkExecType = !Objects.equals(expectedOpcode,"ba+*"); + Map<String, String> nvargs = new HashMap<>(); + nvargs.put("$X", HOME_DATA+fileX); + nvargs.put("$Y", HOME_DATA+fileY); + + // pre-compiled program using default values to be used as source for the recompilation + Program precompiledProgram = generateInitialProgram(HOME+"mm_test.dml", nvargs); + + ResourceCompiler.setDriverConfigurations(driverMemory, driverThreads); + if (numberExecutors > 0) { + ResourceCompiler.setExecutorConfigurations(numberExecutors, executorMemory, executorThreads); + } else { + ResourceCompiler.setSingleNodeExecution(); + } + + // original compilation used for comparison + Program expectedProgram = ResourceCompiler.compile(HOME+"mm_test.dml", nvargs); + runTest(precompiledProgram, expectedProgram, driverMemory, numberExecutors, executorMemory, expectedOpcode, expectedSparkExecType); + } + + private void runTestTSMM(String fileX, long driverMemory, int numberExecutors, long executorMemory, String expectedOpcode, boolean expectedSparkExecType) throws IOException { + Map<String, String> nvargs = new HashMap<>(); + nvargs.put("$X", HOME_DATA+fileX); + + // pre-compiled program using default values to be used as source for the recompilation + Program precompiledProgram 
= generateInitialProgram(HOME+"mm_transpose_test.dml", nvargs); + + ResourceCompiler.setDriverConfigurations(driverMemory, driverThreads); + if (numberExecutors > 0) { + ResourceCompiler.setExecutorConfigurations(numberExecutors, executorMemory, executorThreads); + } else { + ResourceCompiler.setSingleNodeExecution(); + } + // original compilation used for comparison + Program expectedProgram = ResourceCompiler.compile(HOME+"mm_transpose_test.dml", nvargs); + runTest(precompiledProgram, expectedProgram, driverMemory, numberExecutors, executorMemory, expectedOpcode, expectedSparkExecType); + } + + private void runTest(Program precompiledProgram, Program expectedProgram, long driverMemory, int numberExecutors, long executorMemory, String expectedOpcode, boolean expectedSparkExecType) { + String expectedProgramExplained = Explain.explain(expectedProgram); + + Program recompiledProgram; + if (numberExecutors == 0) { + recompiledProgram = ResourceCompiler.doFullRecompilation(precompiledProgram, driverMemory, driverThreads); + } else { + recompiledProgram = ResourceCompiler.doFullRecompilation(precompiledProgram, driverMemory, driverThreads, numberExecutors, executorMemory, executorThreads); + } + String actualProgramExplained = Explain.explain(recompiledProgram); + + if (DEBUG_MODE) System.out.println(actualProgramExplained); + Assert.assertEquals(expectedProgramExplained, actualProgramExplained); + Optional<Instruction> mmInstruction = ((BasicProgramBlock) recompiledProgram.getProgramBlocks().get(0)).getInstructions().stream() + .filter(inst -> (Objects.equals(expectedSparkExecType, inst instanceof SPInstruction) && Objects.equals(inst.getOpcode(), expectedOpcode))) + .findFirst(); + Assert.assertTrue(mmInstruction.isPresent()); + } +} diff --git a/src/test/java/org/apache/sysds/test/component/resource/TestingUtils.java b/src/test/java/org/apache/sysds/test/component/resource/TestingUtils.java new file mode 100644 index 0000000000..035fac6ab1 --- /dev/null +++ 
b/src/test/java/org/apache/sysds/test/component/resource/TestingUtils.java @@ -0,0 +1,71 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.test.component.resource; + +import org.apache.sysds.resource.CloudInstance; +import org.junit.Assert; + +import java.io.File; +import java.io.IOException; +import java.nio.file.Files; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; + +import static org.apache.sysds.resource.CloudUtils.GBtoBytes; + +public class TestingUtils { + public static void assertEqualsCloudInstances(CloudInstance expected, CloudInstance actual) { + Assert.assertEquals(expected.getInstanceName(), actual.getInstanceName()); + Assert.assertEquals(expected.getMemory(), actual.getMemory()); + Assert.assertEquals(expected.getVCPUs(), actual.getVCPUs()); + Assert.assertEquals(expected.getFLOPS(), actual.getFLOPS()); + Assert.assertEquals(expected.getMemorySpeed(), actual.getMemorySpeed(), 0.0); + Assert.assertEquals(expected.getDiskSpeed(), actual.getDiskSpeed(), 0.0); + Assert.assertEquals(expected.getNetworkSpeed(), actual.getNetworkSpeed(), 0.0); + Assert.assertEquals(expected.getPrice(), actual.getPrice(), 0.0); + + } + + public static 
HashMap<String, CloudInstance> getSimpleCloudInstanceMap() { + HashMap<String, CloudInstance> instanceMap = new HashMap<>(); + // fill the map wsearchStrategyh enough cloud instances to allow testing all search space dimension searchStrategyerations + instanceMap.put("m5.xlarge", new CloudInstance("m5.xlarge", GBtoBytes(16), 4, 0.5, 0.0, 143.75, 160, 1.5)); + instanceMap.put("m5.2xlarge", new CloudInstance("m5.2xlarge", GBtoBytes(32), 8, 1.0, 0.0, 0.0, 0.0, 1.9)); + instanceMap.put("c5.xlarge", new CloudInstance("c5.xlarge", GBtoBytes(8), 4, 0.5, 0.0, 0.0, 0.0, 1.7)); + instanceMap.put("c5.2xlarge", new CloudInstance("c5.2xlarge", GBtoBytes(16), 8, 1.0, 0.0, 0.0, 0.0, 2.1)); + + return instanceMap; + } + + public static File generateTmpInstanceInfoTableFile() throws IOException { + File tmpFile = File.createTempFile("systemds_tmp", ".csv"); + + List<String> csvLines = Arrays.asList( + "API_Name,Memory,vCPUs,gFlops,ramSpeed,diskSpeed,networkSpeed,Price", + "m5.xlarge,16.0,4,0.5,0,143.75,160,1.5", + "m5.2xlarge,32.0,8,1.0,0,0,0,1.9", + "c5.xlarge,8.0,4,0.5,0,0,0,1.7", + "c5.2xlarge,16.0,8,1.0,0,0,0,2.1" + ); + Files.write(tmpFile.toPath(), csvLines); + return tmpFile; + } +} diff --git a/src/test/scripts/component/resource/data/A.csv b/src/test/scripts/component/resource/data/A.csv new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/test/scripts/component/resource/data/A.csv.mtd b/src/test/scripts/component/resource/data/A.csv.mtd new file mode 100644 index 0000000000..9e43aec450 --- /dev/null +++ b/src/test/scripts/component/resource/data/A.csv.mtd @@ -0,0 +1,10 @@ +{ + "data_type": "matrix", + "value_type": "double", + "rows": 100000, + "cols": 10000, + "nnz": 1000000000, + "format": "csv", + "header": false, + "sep": "," +} \ No newline at end of file diff --git a/src/test/scripts/component/resource/data/B.csv b/src/test/scripts/component/resource/data/B.csv new file mode 100644 index 0000000000..e69de29bb2 diff --git 
a/src/test/scripts/component/resource/data/B.csv.mtd b/src/test/scripts/component/resource/data/B.csv.mtd new file mode 100644 index 0000000000..db7bd64569 --- /dev/null +++ b/src/test/scripts/component/resource/data/B.csv.mtd @@ -0,0 +1,10 @@ +{ + "data_type": "matrix", + "value_type": "double", + "rows": 10000, + "cols": 1000, + "nnz": 10000000, + "format": "csv", + "header": false, + "sep": "," +} \ No newline at end of file diff --git a/src/test/scripts/component/resource/data/C.csv b/src/test/scripts/component/resource/data/C.csv new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/test/scripts/component/resource/data/C.csv.mtd b/src/test/scripts/component/resource/data/C.csv.mtd new file mode 100644 index 0000000000..6b218bdeca --- /dev/null +++ b/src/test/scripts/component/resource/data/C.csv.mtd @@ -0,0 +1,10 @@ +{ + "data_type": "matrix", + "value_type": "double", + "rows": 10000, + "cols": 10000, + "nnz": 100000000, + "format": "csv", + "header": false, + "sep": "," +} \ No newline at end of file diff --git a/src/test/scripts/component/resource/data/D.csv b/src/test/scripts/component/resource/data/D.csv new file mode 100644 index 0000000000..e69de29bb2 diff --git a/src/test/scripts/component/resource/data/D.csv.mtd b/src/test/scripts/component/resource/data/D.csv.mtd new file mode 100644 index 0000000000..8552d1cdc0 --- /dev/null +++ b/src/test/scripts/component/resource/data/D.csv.mtd @@ -0,0 +1,10 @@ +{ + "data_type": "matrix", + "value_type": "double", + "rows": 100000, + "cols": 1000, + "nnz": 100000000, + "format": "csv", + "header": false, + "sep": "," +} \ No newline at end of file diff --git a/src/test/scripts/component/resource/mm_test.dml b/src/test/scripts/component/resource/mm_test.dml new file mode 100644 index 0000000000..88bcfc2f2a --- /dev/null +++ b/src/test/scripts/component/resource/mm_test.dml @@ -0,0 +1,34 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation 
(ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +fileX = $X; +fileY = $Y; +# R - virtual result +fileR = "R.csv"; +fmtR = "csv"; + +X = read(fileX); +Y = read(fileY); + +R = X%*%Y; + +# trigger full calculation +write(R, fileR, fmtR); \ No newline at end of file diff --git a/src/test/scripts/component/resource/mm_transpose_test.dml b/src/test/scripts/component/resource/mm_transpose_test.dml new file mode 100644 index 0000000000..5cdf451b6f --- /dev/null +++ b/src/test/scripts/component/resource/mm_transpose_test.dml @@ -0,0 +1,32 @@ +#------------------------------------------------------------- +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# +#------------------------------------------------------------- + +fileX = $X; +# R - virtual result +fileR = "R.csv"; +fmtR = "csv"; + +X = read(fileX); + +R = t(X)%*%X; + +# trigger full calculation +write(R, fileR, fmtR); \ No newline at end of file