This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 321174ed8e [SYSTEMDS-1780] Resource elasticity: basic enumerators and
recompiler
321174ed8e is described below
commit 321174ed8eff4fd7c0ae53f83258b88057d8ad80
Author: lachezar-n <[email protected]>
AuthorDate: Wed Aug 21 19:45:59 2024 +0200
[SYSTEMDS-1780] Resource elasticity: basic enumerators and recompiler
Closes #2067.
---
.../java/org/apache/sysds/conf/CompilerConfig.java | 5 +-
.../java/org/apache/sysds/hops/AggBinaryOp.java | 4 +-
src/main/java/org/apache/sysds/hops/Hop.java | 4 +
.../java/org/apache/sysds/hops/OptimizerUtils.java | 17 +-
.../java/org/apache/sysds/lops/compile/Dag.java | 6 +
.../java/org/apache/sysds/resource/AWSUtils.java | 66 +++
.../org/apache/sysds/resource/CloudInstance.java | 99 ++++
.../java/org/apache/sysds/resource/CloudUtils.java | 132 +++++
.../apache/sysds/resource/ResourceCompiler.java | 248 ++++++++++
.../apache/sysds/resource/cost/CostEstimator.java | 5 +-
.../resource/enumeration/EnumerationUtils.java | 113 +++++
.../sysds/resource/enumeration/Enumerator.java | 522 ++++++++++++++++++++
.../resource/enumeration/GridBasedEnumerator.java | 89 ++++
.../enumeration/InterestBasedEnumerator.java | 315 ++++++++++++
.../context/SparkExecutionContext.java | 22 +-
.../parfor/stat/InfrastructureAnalyzer.java | 8 +-
.../test/component/resource/CloudUtilsTests.java | 118 +++++
.../test/component/resource/EnumeratorTests.java | 529 +++++++++++++++++++++
.../test/component/resource/RecompilationTest.java | 258 ++++++++++
.../test/component/resource/TestingUtils.java | 71 +++
src/test/scripts/component/resource/data/A.csv | 0
src/test/scripts/component/resource/data/A.csv.mtd | 10 +
src/test/scripts/component/resource/data/B.csv | 0
src/test/scripts/component/resource/data/B.csv.mtd | 10 +
src/test/scripts/component/resource/data/C.csv | 0
src/test/scripts/component/resource/data/C.csv.mtd | 10 +
src/test/scripts/component/resource/data/D.csv | 0
src/test/scripts/component/resource/data/D.csv.mtd | 10 +
src/test/scripts/component/resource/mm_test.dml | 34 ++
.../component/resource/mm_transpose_test.dml | 32 ++
30 files changed, 2701 insertions(+), 36 deletions(-)
diff --git a/src/main/java/org/apache/sysds/conf/CompilerConfig.java
b/src/main/java/org/apache/sysds/conf/CompilerConfig.java
index 9728bd2769..800fe575c9 100644
--- a/src/main/java/org/apache/sysds/conf/CompilerConfig.java
+++ b/src/main/java/org/apache/sysds/conf/CompilerConfig.java
@@ -78,7 +78,9 @@ public class CompilerConfig
CODEGEN_ENABLED,
//federated runtime conversion
- FEDERATED_RUNTIME;
+ FEDERATED_RUNTIME,
+ // resource optimization mode
+ RESOURCE_OPTIMIZATION;
}
//default flags (exposed for testing purposes only)
@@ -107,6 +109,7 @@ public class CompilerConfig
_bmap.put(ConfigType.REJECT_READ_WRITE_UNKNOWNS, true);
_bmap.put(ConfigType.MLCONTEXT, false);
_bmap.put(ConfigType.CODEGEN_ENABLED, false);
+ _bmap.put(ConfigType.RESOURCE_OPTIMIZATION, false);
_imap = new HashMap<>();
_imap.put(ConfigType.BLOCK_SIZE,
OptimizerUtils.DEFAULT_BLOCKSIZE);
diff --git a/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
b/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
index 3b09984179..640ababbad 100644
--- a/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
+++ b/src/main/java/org/apache/sysds/hops/AggBinaryOp.java
@@ -1229,7 +1229,7 @@ public class AggBinaryOp extends MultiThreadedHop {
double m2_size = m2_rows * m2_cols;
double result_size = m1_rows * m2_cols;
- int numReducersRMM = OptimizerUtils.getNumReducers(true);
+ int numReducersRMM = OptimizerUtils.getNumTasks();
// Estimate the cost of RMM
// RMM phase 1
@@ -1256,7 +1256,7 @@ public class AggBinaryOp extends MultiThreadedHop {
double m2_size = m2_rows * m2_cols;
double result_size = m1_rows * m2_cols;
- int numReducersCPMM = OptimizerUtils.getNumReducers(false);
+ int numReducersCPMM = OptimizerUtils.getNumTasks();
// Estimate the cost of CPMM
// CPMM phase 1
diff --git a/src/main/java/org/apache/sysds/hops/Hop.java
b/src/main/java/org/apache/sysds/hops/Hop.java
index 93501efa0d..44930604b3 100644
--- a/src/main/java/org/apache/sysds/hops/Hop.java
+++ b/src/main/java/org/apache/sysds/hops/Hop.java
@@ -37,6 +37,7 @@ import org.apache.sysds.common.Types.OpOp1;
import org.apache.sysds.common.Types.OpOp2;
import org.apache.sysds.common.Types.OpOpData;
import org.apache.sysds.common.Types.ValueType;
+import org.apache.sysds.conf.CompilerConfig.ConfigType;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.hops.cost.ComputeCost;
import org.apache.sysds.hops.recompile.Recompiler;
@@ -281,6 +282,9 @@ public abstract class Hop implements ParseInfo {
}
else if ( DMLScript.getGlobalExecMode() == ExecMode.SPARK )
_etypeForced = ExecType.SPARK; // enabled with -exec
spark option
+ else if ( DMLScript.getGlobalExecMode() == ExecMode.HYBRID
+ &&
ConfigurationManager.getCompilerConfigFlag(ConfigType.RESOURCE_OPTIMIZATION))
+ _etypeForced = null;
}
public void checkAndSetInvalidCPDimsAndSize()
diff --git a/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
b/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
index 0a37570ee8..9e787de8c0 100644
--- a/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
+++ b/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
@@ -658,20 +658,13 @@ public class OptimizerUtils
}
/**
- * Returns the number of reducers that potentially run in parallel.
+ * Returns the number of tasks that potentially run in parallel.
* This is either just the configured value (SystemDS config) or
- * the minimum of configured value and available reduce slots.
- *
- * @param configOnly true if configured value
- * @return number of reducers
+ * the minimum of configured value and available task slots.
+ *
+ * @return number of tasks
*/
- public static int getNumReducers( boolean configOnly ) {
- if( isSparkExecutionMode() )
- return
SparkExecutionContext.getDefaultParallelism(false);
- return InfrastructureAnalyzer.getLocalParallelism();
- }
-
- public static int getNumMappers() {
+ public static int getNumTasks() {
if( isSparkExecutionMode() )
return
SparkExecutionContext.getDefaultParallelism(false);
return InfrastructureAnalyzer.getLocalParallelism();
diff --git a/src/main/java/org/apache/sysds/lops/compile/Dag.java
b/src/main/java/org/apache/sysds/lops/compile/Dag.java
index b26c539e9a..f67cb74cd3 100644
--- a/src/main/java/org/apache/sysds/lops/compile/Dag.java
+++ b/src/main/java/org/apache/sysds/lops/compile/Dag.java
@@ -147,6 +147,12 @@ public class Dag<N extends Lop>
dt.isFrame() ? Lop.FRAME_VAR_NAME_PREFIX :
Lop.SCALAR_VAR_NAME_PREFIX) + var_index.getNextID();
}
+
+ // to be used only for resource optimization
+ public static void resetUniqueMembers() {
+ job_id.reset(-1);
+ var_index.reset(-1);
+ }
///////
// Dag modifications
diff --git a/src/main/java/org/apache/sysds/resource/AWSUtils.java
b/src/main/java/org/apache/sysds/resource/AWSUtils.java
new file mode 100644
index 0000000000..7d3ab6409a
--- /dev/null
+++ b/src/main/java/org/apache/sysds/resource/AWSUtils.java
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.resource;
+
+import org.apache.sysds.resource.enumeration.EnumerationUtils;
+
+public class AWSUtils extends CloudUtils {
+ public static final String EC2_REGEX =
"^([a-z]+)([0-9])(a|g|i?)([bdnez]*)\\.([a-z0-9]+)$";
+ @Override
+ public boolean validateInstanceName(String input) {
+ String instanceName = input.toLowerCase();
+ if (!instanceName.toLowerCase().matches(EC2_REGEX)) return
false;
+ try {
+ getInstanceType(instanceName);
+ getInstanceSize(instanceName);
+ } catch (IllegalArgumentException e) {
+ return false;
+ }
+ return true;
+ }
+
+ @Override
+ public InstanceType getInstanceType(String instanceName) {
+ String typeAsString = instanceName.split("\\.")[0];
+ // throws exception if string value is not valid
+ return InstanceType.customValueOf(typeAsString);
+ }
+
+ @Override
+ public InstanceSize getInstanceSize(String instanceName) {
+ String sizeAsString = instanceName.split("\\.")[1];
+ // throws exception if string value is not valid
+ return InstanceSize.customValueOf(sizeAsString);
+ }
+
+ @Override
+ public double calculateClusterPrice(EnumerationUtils.ConfigurationPoint
config, double time) {
+ double pricePerSeconds = getClusterCostPerHour(config) / 3600;
+ return time * pricePerSeconds;
+ }
+
+ private double
getClusterCostPerHour(EnumerationUtils.ConfigurationPoint config) {
+ if (config.numberExecutors == 0) {
+ return config.driverInstance.getPrice();
+ }
+ return config.driverInstance.getPrice() +
+
config.executorInstance.getPrice()*config.numberExecutors;
+ }
+}
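
For orientation, a minimal usage sketch of the new AWSUtils methods; the instance name "m5.xlarge" is an illustrative EC2 name, not taken from this commit:

    import org.apache.sysds.resource.AWSUtils;
    import org.apache.sysds.resource.CloudUtils;

    AWSUtils utils = new AWSUtils();
    // the name must match EC2_REGEX and map to known type/size enum values
    boolean valid = utils.validateInstanceName("m5.xlarge");            // true
    CloudUtils.InstanceType type = utils.getInstanceType("m5.xlarge");  // InstanceType.M5
    CloudUtils.InstanceSize size = utils.getInstanceSize("m5.xlarge");  // InstanceSize._XLARGE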
diff --git a/src/main/java/org/apache/sysds/resource/CloudInstance.java
b/src/main/java/org/apache/sysds/resource/CloudInstance.java
new file mode 100644
index 0000000000..d740e35b80
--- /dev/null
+++ b/src/main/java/org/apache/sysds/resource/CloudInstance.java
@@ -0,0 +1,99 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.resource;
+
+/**
+ * This class describes the configurations of a single VM instance.
+ * The idea is to use this class to represent instances of different
+ * cloud hypervisors - currently supporting only EC2 instances by AWS.
+ */
+public class CloudInstance {
+ private final String instanceName;
+ private final long memory;
+ private final int vCPUCores;
+ private final double pricePerHour;
+ private final double gFlops;
+ private final double memorySpeed;
+ private final double diskSpeed;
+ private final double networkSpeed;
+ public CloudInstance(String instanceName, long memory, int vCPUCores,
double gFlops, double memorySpeed, double diskSpeed, double networkSpeed,
double pricePerHour) {
+ this.instanceName = instanceName;
+ this.memory = memory;
+ this.vCPUCores = vCPUCores;
+ this.gFlops = gFlops;
+ this.memorySpeed = memorySpeed;
+ this.diskSpeed = diskSpeed;
+ this.networkSpeed = networkSpeed;
+ this.pricePerHour = pricePerHour;
+ }
+
+ public String getInstanceName() {
+ return instanceName;
+ }
+
+ /**
+ * @return memory of the instance in bytes
+ */
+ public long getMemory() {
+ return memory;
+ }
+
+ /**
+ * @return number of virtual CPU cores of the instance
+ */
+ public int getVCPUs() {
+ return vCPUCores;
+ }
+
+ /**
+ * @return price per hour of the instance
+ */
+ public double getPrice() {
+ return pricePerHour;
+ }
+
+ /**
+ * @return number of FLOPS of the instance
+ */
+ public long getFLOPS() {
+ return (long) (gFlops*1024)*1024*1024;
+ }
+
+ /**
+ * @return memory speed/bandwidth of the instance in MB/s
+ */
+ public double getMemorySpeed() {
+ return memorySpeed;
+ }
+
+ /**
+ * @return disk speed/bandwidth of the instance in MB/s
+ */
+ public double getDiskSpeed() {
+ return diskSpeed;
+ }
+
+ /**
+ * @return network speed/bandwidth of the instance in MB/s
+ */
+ public double getNetworkSpeed() {
+ return networkSpeed;
+ }
+}
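
As a sketch, a CloudInstance carries all characteristics up front; the concrete numbers below are hypothetical, not vendor data:

    // memory in bytes, speeds in MB/s, price per hour (all values hypothetical)
    CloudInstance inst = new CloudInstance("m5.xlarge",
        CloudUtils.GBtoBytes(16), 4, 50, 21000, 140, 160, 0.23);
    long flops = inst.getFLOPS(); // the stored GFLOPS value scaled by 1024^3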
diff --git a/src/main/java/org/apache/sysds/resource/CloudUtils.java
b/src/main/java/org/apache/sysds/resource/CloudUtils.java
new file mode 100644
index 0000000000..3a1e38f422
--- /dev/null
+++ b/src/main/java/org/apache/sysds/resource/CloudUtils.java
@@ -0,0 +1,132 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.resource;
+
+import org.apache.sysds.resource.enumeration.EnumerationUtils;
+
+import java.io.BufferedReader;
+import java.io.FileReader;
+import java.io.IOException;
+import java.util.HashMap;
+
+public abstract class CloudUtils {
+ public enum CloudProvider {
+ AWS // potentially AZURE, GOOGLE
+ }
+ public enum InstanceType {
+ // AWS EC2 instance
+ M5, M5A, M6I, M6A, M6G, M7I, M7A, M7G, // general purpose -
vCores:mem~=1:4
+ C5, C5A, C6I, C6A, C6G, C7I, C7A, C7G, // compute optimized -
vCores:mem~=1:2
+ R5, R5A, R6I, R6A, R6G, R7I, R7A, R7G; // memory optimized -
vCores:mem~=1:8
+ // Potentially VM instance types for different Cloud providers
+
+ public static InstanceType customValueOf(String name) {
+ return InstanceType.valueOf(name.toUpperCase());
+ }
+ }
+
+ public enum InstanceSize {
+ _XLARGE, _2XLARGE, _4XLARGE, _8XLARGE, _12XLARGE, _16XLARGE,
_24XLARGE, _32XLARGE, _48XLARGE;
+ // Potentially VM instance sizes for different Cloud providers
+
+ public static InstanceSize customValueOf(String name) {
+ return InstanceSize.valueOf("_"+name.toUpperCase());
+ }
+ }
+
+ public static final double MINIMAL_EXECUTION_TIME = 120; // seconds; NOTE: always set equal to or higher than DEFAULT_CLUSTER_LAUNCH_TIME
+
+ public static final double DEFAULT_CLUSTER_LAUNCH_TIME = 120; // seconds; NOTE: always set to at least 60 seconds
+
+ public static long GBtoBytes(double gb) {
+ return (long) (gb * 1024 * 1024 * 1024);
+ }
+ public abstract boolean validateInstanceName(String instanceName);
+ public abstract InstanceType getInstanceType(String instanceName);
+ public abstract InstanceSize getInstanceSize(String instanceName);
+
+ /**
+ * This method calculates the cluster price based on the
+ * estimated execution time and the cluster configuration.
+ * @param config the cluster configuration for the calculation
+ * @param time estimated execution time in seconds
+ * @return price for the given time
+ */
+ public abstract double
calculateClusterPrice(EnumerationUtils.ConfigurationPoint config, double time);
+
+ /**
+ * Performs read of csv file filled with VM instance characteristics.
+ * Each record in the csv should carry the following information
(including header):
+ * <li>API_Name - naming for VM instance used by the provider</li>
+ * <li>Memory - floating number for the instance memory in GBs</li>
+ * <li>vCPUs - number of physical threads</li>
+ * <li>gFlops - FLOPS capability of the CPU in GFLOPS (Giga)</li>
+ * <li>ramSpeed - memory bandwidth in MB/s</li>
+ * <li>diskSpeed - disk bandwidth in MB/s</li>
+ * <li>networkSpeed - network bandwidth in MB/s</li>
+ * <li>Price - price for instance per hour</li>
+ * @param instanceTablePath csv file
+ * @return map of the parsed instances keyed by instance name
+ * @throws IOException in case of problems reading the csv file
+ */
+ public HashMap<String, CloudInstance> loadInstanceInfoTable(String
instanceTablePath) throws IOException {
+ HashMap<String, CloudInstance> result = new HashMap<>();
+ int lineCount = 1;
+ // try to open the file
+ try(BufferedReader br = new BufferedReader(new
FileReader(instanceTablePath))){
+ String parsedLine;
+ // validate the file header
+ parsedLine = br.readLine();
+ if
(!parsedLine.equals("API_Name,Memory,vCPUs,gFlops,ramSpeed,diskSpeed,networkSpeed,Price"))
+ throw new IOException("Invalid CSV header
inside: " + instanceTablePath);
+
+
+ while ((parsedLine = br.readLine()) != null) {
+ String[] values = parsedLine.split(",");
+ if (values.length != 8 ||
!validateInstanceName(values[0]))
+ throw new
IOException(String.format("Invalid CSV line(%d) inside: %s", lineCount,
instanceTablePath));
+
+ String API_Name = values[0];
+ long Memory = (long)
(Double.parseDouble(values[1])*1024)*1024*1024;
+ int vCPUs = Integer.parseInt(values[2]);
+ double gFlops = Double.parseDouble(values[3]);
+ double ramSpeed = Double.parseDouble(values[4]);
+ double diskSpeed =
Double.parseDouble(values[5]);
+ double networkSpeed =
Double.parseDouble(values[6]);
+ double Price = Double.parseDouble(values[7]);
+
+ CloudInstance parsedInstance = new
CloudInstance(
+ API_Name,
+ Memory,
+ vCPUs,
+ gFlops,
+ ramSpeed,
+ diskSpeed,
+ networkSpeed,
+ Price
+ );
+ result.put(API_Name, parsedInstance);
+ lineCount++;
+ }
+ }
+
+ return result;
+ }
+}
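
A sketch of the expected instance table and its loading; the file name and row values are illustrative:

    // instances.csv - the header must match exactly:
    // API_Name,Memory,vCPUs,gFlops,ramSpeed,diskSpeed,networkSpeed,Price
    // m5.xlarge,16,4,50,21000,140,160,0.23

    CloudUtils utils = new AWSUtils();
    HashMap<String, CloudInstance> table = utils.loadInstanceInfoTable("instances.csv"); // throws IOException
    CloudInstance m5 = table.get("m5.xlarge"); // Memory column already converted from GB to bytes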
diff --git a/src/main/java/org/apache/sysds/resource/ResourceCompiler.java
b/src/main/java/org/apache/sysds/resource/ResourceCompiler.java
new file mode 100644
index 0000000000..ab5c452b56
--- /dev/null
+++ b/src/main/java/org/apache/sysds/resource/ResourceCompiler.java
@@ -0,0 +1,248 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.resource;
+
+import org.apache.spark.SparkConf;
+import org.apache.sysds.api.DMLOptions;
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.common.Types;
+import org.apache.sysds.conf.CompilerConfig;
+import org.apache.sysds.conf.ConfigurationManager;
+import org.apache.sysds.hops.Hop;
+import org.apache.sysds.hops.recompile.Recompiler;
+import org.apache.sysds.lops.Lop;
+import org.apache.sysds.lops.compile.Dag;
+import org.apache.sysds.lops.rewrite.LopRewriter;
+import org.apache.sysds.parser.*;
+import org.apache.sysds.runtime.controlprogram.*;
+import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
+import
org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
+import org.apache.sysds.runtime.instructions.Instruction;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Map;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+import static org.apache.sysds.api.DMLScript.*;
+
+/**
+ * This class does full or partial program recompilation
+ * based on a given runtime program. It uses the methods provided
+ * by {@code hops.recompile.Recompiler}.
+ * It keeps a state of the current recompilation phase in order
+ * to decide when to do full recompilation and when not.
+ */
+public class ResourceCompiler {
+ public static final long DEFAULT_DRIVER_MEMORY = 512*1024*1024; // 0.5GB
+ public static final int DEFAULT_DRIVER_THREADS = 1; // single thread
+ public static final long DEFAULT_EXECUTOR_MEMORY = 512*1024*1024; //
0.5GB
+ public static final int DEFAULT_EXECUTOR_THREADS = 1; // single thread
+ public static final int DEFAULT_NUMBER_EXECUTORS = 1; // single executor
+ static {
+ // TODO: consider moving to the executable of the resource
optimizer once implemented
+ USE_LOCAL_SPARK_CONFIG = true;
+
ConfigurationManager.getCompilerConfig().set(CompilerConfig.ConfigType.ALLOW_DYN_RECOMPILATION,
false);
+
ConfigurationManager.getCompilerConfig().set(CompilerConfig.ConfigType.RESOURCE_OPTIMIZATION,
true);
+ }
+ private static final LopRewriter _lopRewriter = new LopRewriter();
+
+ public static Program compile(String filePath, Map<String, String>
args) throws IOException {
+ // setting the dynamic recompilation flags is unnecessary during resource optimization
+ DMLOptions dmlOptions = DMLOptions.defaultOptions;
+ dmlOptions.argVals = args;
+
+ String dmlScriptStr = readDMLScript(true, filePath);
+ Map<String, String> argVals = dmlOptions.argVals;
+
+ Dag.resetUniqueMembers();
+ // NOTE: skip configuring code generation
+ // NOTE: expects setting up the initial cluster configs before
calling
+ ParserWrapper parser = ParserFactory.createParser();
+ DMLProgram dmlProgram = parser.parse(null, dmlScriptStr,
argVals);
+ DMLTranslator dmlTranslator = new DMLTranslator(dmlProgram);
+ dmlTranslator.liveVariableAnalysis(dmlProgram);
+ dmlTranslator.validateParseTree(dmlProgram);
+ dmlTranslator.constructHops(dmlProgram);
+ dmlTranslator.rewriteHopsDAG(dmlProgram);
+ dmlTranslator.constructLops(dmlProgram);
+ dmlTranslator.rewriteLopDAG(dmlProgram);
+ return dmlTranslator.getRuntimeProgram(dmlProgram,
ConfigurationManager.getDMLConfig());
+ }
+
+ private static ArrayList<Instruction> recompile(StatementBlock sb,
ArrayList<Hop> hops) {
+ // construct new lops
+ ArrayList<Lop> lops = new ArrayList<>(hops.size());
+ Hop.resetVisitStatus(hops);
+ for( Hop hop : hops ){
+ Recompiler.rClearLops(hop);
+ lops.add(hop.constructLops());
+ }
+ // apply hop-lop rewrites to cover the case of changed lop
operators
+ _lopRewriter.rewriteLopDAG(sb, lops);
+
+ Dag<Lop> dag = new Dag<>();
+ for (Lop l : lops) {
+ l.addToDag(dag);
+ }
+
+ return dag.getJobs(sb, ConfigurationManager.getDMLConfig());
+ }
+
+ /**
+ * Recompiling a given program for resource optimization for single
node execution
+ * @param program program to be recompiled
+ * @param driverMemory target driver memory
+ * @param driverCores target driver threads/cores
+ * @return the recompiled program as a new {@code Program} instance
+ */
+ public static Program doFullRecompilation(Program program, long
driverMemory, int driverCores) {
+ setDriverConfigurations(driverMemory, driverCores);
+ setSingleNodeExecution();
+ return doFullRecompilation(program);
+ }
+
+ /**
+ * Recompiling a given program for resource optimization for Spark
execution
+ * @param program program to be recompiled
+ * @param driverMemory target driver memory
+ * @param driverCores target driver threads/cores
+ * @param numberExecutors target number of executor nodes
+ * @param executorMemory target executor memory
+ * @param executorCores target executor threads/cores
+ * @return the recompiled program as a new {@code Program} instance
+ */
+ public static Program doFullRecompilation(Program program, long
driverMemory, int driverCores, int numberExecutors, long executorMemory, int
executorCores) {
+ setDriverConfigurations(driverMemory, driverCores);
+ setExecutorConfigurations(numberExecutors, executorMemory,
executorCores);
+ return doFullRecompilation(program);
+ }
+
+ private static Program doFullRecompilation(Program program) {
+ Dag.resetUniqueMembers();
+ Program newProgram = new Program();
+ ArrayList<ProgramBlock> B = Stream.concat(
+
program.getProgramBlocks().stream(),
+
program.getFunctionProgramBlocks().values().stream())
+
.collect(Collectors.toCollection(ArrayList::new));
+ doRecompilation(B, newProgram);
+ return newProgram;
+ }
+
+ private static void doRecompilation(ArrayList<ProgramBlock> origin,
Program target) {
+ for (ProgramBlock originBlock : origin) {
+ doRecompilation(originBlock, target);
+ }
+ }
+
+ private static void doRecompilation(ProgramBlock originBlock, Program
target) {
+ if (originBlock instanceof FunctionProgramBlock)
+ {
+ FunctionProgramBlock fpb =
(FunctionProgramBlock)originBlock;
+ doRecompilation(fpb.getChildBlocks(), target);
+ }
+ else if (originBlock instanceof WhileProgramBlock)
+ {
+ WhileProgramBlock wpb = (WhileProgramBlock)originBlock;
+ WhileStatementBlock sb = (WhileStatementBlock)
originBlock.getStatementBlock();
+ if(sb!=null && sb.getPredicateHops()!=null ){
+ ArrayList<Instruction> inst =
Recompiler.recompileHopsDag(sb.getPredicateHops(), null, null, true, true, 0);
+ wpb.setPredicate(inst);
+ target.addProgramBlock(wpb);
+ }
+ doRecompilation(wpb.getChildBlocks(), target);
+ }
+ else if (originBlock instanceof IfProgramBlock)
+ {
+ IfProgramBlock ipb = (IfProgramBlock)originBlock;
+ IfStatementBlock sb = (IfStatementBlock)
ipb.getStatementBlock();
+ if(sb!=null && sb.getPredicateHops()!=null ){
+ ArrayList<Instruction> inst =
Recompiler.recompileHopsDag(sb.getPredicateHops(), null, null, true, true, 0);
+ ipb.setPredicate(inst);
+ target.addProgramBlock(ipb);
+ }
+ doRecompilation(ipb.getChildBlocksIfBody(), target);
+ doRecompilation(ipb.getChildBlocksElseBody(), target);
+ }
+ else if (originBlock instanceof ForProgramBlock) //incl parfor
+ {
+ ForProgramBlock fpb = (ForProgramBlock)originBlock;
+ ForStatementBlock sb = (ForStatementBlock)
fpb.getStatementBlock();
+ if(sb!=null){
+ if( sb.getFromHops()!=null ){
+ ArrayList<Instruction> inst =
Recompiler.recompileHopsDag(sb.getFromHops(), null, null, true, true, 0);
+ fpb.setFromInstructions( inst );
+ }
+ if(sb.getToHops()!=null){
+ ArrayList<Instruction> inst =
Recompiler.recompileHopsDag(sb.getToHops(), null, null, true, true, 0);
+ fpb.setToInstructions( inst );
+ }
+ if(sb.getIncrementHops()!=null){
+ ArrayList<Instruction> inst =
Recompiler.recompileHopsDag(sb.getIncrementHops(), null, null, true, true, 0);
+ fpb.setIncrementInstructions(inst);
+ }
+ target.addProgramBlock(fpb);
+
+ }
+ doRecompilation(fpb.getChildBlocks(), target);
+ }
+ else
+ {
+ BasicProgramBlock bpb = (BasicProgramBlock)originBlock;
+ StatementBlock sb = bpb.getStatementBlock();
+ ArrayList<Instruction> inst = recompile(sb,
sb.getHops());
+ bpb.setInstructions(inst);
+ target.addProgramBlock(bpb);
+ }
+ }
+
+ public static void setDriverConfigurations(long nodeMemory, int
nodeNumCores) {
+ // TODO: think of a reasonable factor for the JVM heap as part of the node's memory
+ InfrastructureAnalyzer.setLocalMaxMemory(nodeMemory);
+ InfrastructureAnalyzer.setLocalPar(nodeNumCores);
+ }
+
+ public static void setExecutorConfigurations(int numExecutors, long
nodeMemory, int nodeNumCores) {
+ // TODO: think of a reasonable factor for the JVM heap as part of the node's memory
+ if (numExecutors > 0) {
+ DMLScript.setGlobalExecMode(Types.ExecMode.HYBRID);
+ SparkConf sparkConf =
SparkExecutionContext.createSystemDSSparkConf();
+ // ------------------ Static Configurations
-------------------
+ // TODO: think how to avoid setting them every time
+ sparkConf.set("spark.master", "local[*]");
+ sparkConf.set("spark.app.name", "SystemDS");
+ sparkConf.set("spark.memory.useLegacyMode", "false");
+ // ------------------ Static Configurations
-------------------
+ // ------------------ Dynamic Configurations
-------------------
+ sparkConf.set("spark.executor.memory",
(nodeMemory/(1024*1024))+"m");
+ sparkConf.set("spark.executor.instances",
Integer.toString(numExecutors));
+ sparkConf.set("spark.executor.cores",
Integer.toString(nodeNumCores));
+ // ------------------ Dynamic Configurations
-------------------
+ SparkExecutionContext.initLocalSparkContext(sparkConf);
+ } else {
+ throw new RuntimeException("The given number of
executors was 0");
+ }
+ }
+
+ public static void setSingleNodeExecution() {
+ DMLScript.setGlobalExecMode(Types.ExecMode.SINGLE_NODE);
+ }
+}
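
A sketch of the intended compile/recompile cycle; the script path, memory sizes, and executor counts are illustrative:

    Map<String, String> args = new HashMap<>();
    Program initial = ResourceCompiler.compile("script.dml", args); // throws IOException
    // recompile for single-node execution: 8GB driver memory, 4 cores
    Program singleNode = ResourceCompiler.doFullRecompilation(initial, 8L*1024*1024*1024, 4);
    // recompile for hybrid execution: 4 executors with 16GB memory and 8 cores each
    Program cluster = ResourceCompiler.doFullRecompilation(
        initial, 8L*1024*1024*1024, 4, 4, 16L*1024*1024*1024, 8);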
diff --git a/src/main/java/org/apache/sysds/resource/cost/CostEstimator.java
b/src/main/java/org/apache/sysds/resource/cost/CostEstimator.java
index f2787c50d3..a5c5333c44 100644
--- a/src/main/java/org/apache/sysds/resource/cost/CostEstimator.java
+++ b/src/main/java/org/apache/sysds/resource/cost/CostEstimator.java
@@ -161,7 +161,8 @@ public class CostEstimator
ForProgramBlock tmp = (ForProgramBlock)pb;
for( ProgramBlock pb2 : tmp.getChildBlocks() )
ret += getTimeEstimatePB(pb2);
-
+ // NOTE: currently ParFor blocks are handled as regular for blocks,
+ // which could lead to very inaccurate estimates for complex ParFor blocks
ret *= OptimizerUtils.getNumIterations(tmp,
DEFAULT_NUMITER);
}
else if ( pb instanceof FunctionProgramBlock ) {
@@ -960,7 +961,7 @@ public class CostEstimator
private void putInMemory(VarStats input) throws CostEstimationException
{
long sizeEstimate = OptimizerUtils.estimateSize(input._mc);
if (sizeEstimate + usedMememory > localMemory)
- throw new CostEstimationException("Insufficient local
memory for ");
+ throw new CostEstimationException("Insufficient local
memory");
usedMememory += sizeEstimate;
input._memory = sizeEstimate;
}
diff --git
a/src/main/java/org/apache/sysds/resource/enumeration/EnumerationUtils.java
b/src/main/java/org/apache/sysds/resource/enumeration/EnumerationUtils.java
new file mode 100644
index 0000000000..fa075dd9d9
--- /dev/null
+++ b/src/main/java/org/apache/sysds/resource/enumeration/EnumerationUtils.java
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.resource.enumeration;
+
+import org.apache.sysds.resource.CloudInstance;
+
+import java.util.HashMap;
+import java.util.LinkedList;
+import java.util.TreeMap;
+
+public class EnumerationUtils {
+ /**
+ * Data structure representing a projected search space for
+ * VM instances: a node's memory maps to an inner map that maps
+ * the number of cores available at that memory to a list of
+ * unique {@code CloudInstance} objects with exactly these
+ * characteristics (memory and cores).
+ * The outer layer keys on memory since it is more significant
+ * for program compilation. The inner map level contains the
+ * different core-count options for the memory it is mapped to.
+ * The last layer of LinkedLists holds the VM instances in lists
+ * since a memory-cores combination is often not unique.
+ * The {@code CloudInstance} objects are unique over the whole
+ * set of lists within this lowest level of the search space.
+ * <br></br>
+ * This representation allows compact storing of VM instance
+ * characteristics relevant for program compilation while
+ * still keeping a reference to the object carrying the
+ * whole instance information, relevant for cost estimation.
+ * <br></br>
+ * {@code TreeMap} data structures are used as building blocks for
+ * the complex search space structure to ensure ascending order
+ * of the instance characteristics - memory and number of cores.
+ */
+ public static class InstanceSearchSpace extends TreeMap<Long,
TreeMap<Integer, LinkedList<CloudInstance>>> {
+ private static final long serialVersionUID =
-8855424955793322839L;
+
+ public void initSpace(HashMap<String, CloudInstance> instances)
{
+ for (CloudInstance instance: instances.values()) {
+ long currentMemory = instance.getMemory();
+
+ this.putIfAbsent(currentMemory, new
TreeMap<>());
+ TreeMap<Integer, LinkedList<CloudInstance>>
currentSubTree = this.get(currentMemory);
+
+ currentSubTree.putIfAbsent(instance.getVCPUs(),
new LinkedList<>());
+ LinkedList<CloudInstance> currentList =
currentSubTree.get(instance.getVCPUs());
+
+ currentList.add(instance);
+ }
+ }
+ }
+
+ /**
+ * Simple data structure to hold cluster configurations
+ */
+ public static class ConfigurationPoint {
+ public CloudInstance driverInstance;
+ public CloudInstance executorInstance;
+ public int numberExecutors;
+
+ public ConfigurationPoint(CloudInstance driverInstance) {
+ this.driverInstance = driverInstance;
+ this.executorInstance = null;
+ this.numberExecutors = 0;
+ }
+
+ public ConfigurationPoint(CloudInstance driverInstance,
CloudInstance executorInstance, int numberExecutors) {
+ this.driverInstance = driverInstance;
+ this.executorInstance = executorInstance;
+ this.numberExecutors = numberExecutors;
+ }
+ }
+
+ /**
+ * Data structure to hold all data related to cost estimation
+ */
+ public static class SolutionPoint extends ConfigurationPoint {
+ double timeCost;
+ double monetaryCost;
+
+ public SolutionPoint(ConfigurationPoint inputPoint, double
timeCost, double monetaryCost) {
+ super(inputPoint.driverInstance,
inputPoint.executorInstance, inputPoint.numberExecutors);
+ this.timeCost = timeCost;
+ this.monetaryCost = monetaryCost;
+ }
+
+ public void update(ConfigurationPoint point, double timeCost,
double monetaryCost) {
+ this.driverInstance = point.driverInstance;
+ this.executorInstance = point.executorInstance;
+ this.numberExecutors = point.numberExecutors;
+ this.timeCost = timeCost;
+ this.monetaryCost = monetaryCost;
+ }
+ }
+}
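
A small sketch of initializing and traversing the nested search space; the instances map is assumed to be loaded beforehand, e.g. via CloudUtils.loadInstanceInfoTable:

    EnumerationUtils.InstanceSearchSpace space = new EnumerationUtils.InstanceSearchSpace();
    space.initSpace(instances);
    // outer key: memory in bytes; inner key: cores; value: instances with that combination
    for (Map.Entry<Long, TreeMap<Integer, LinkedList<CloudInstance>>> mem : space.entrySet())
        for (Map.Entry<Integer, LinkedList<CloudInstance>> cores : mem.getValue().entrySet())
            System.out.println(mem.getKey() + " B / " + cores.getKey() + " cores: "
                + cores.getValue().size() + " instance(s)");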
diff --git
a/src/main/java/org/apache/sysds/resource/enumeration/Enumerator.java
b/src/main/java/org/apache/sysds/resource/enumeration/Enumerator.java
new file mode 100644
index 0000000000..2147dfc368
--- /dev/null
+++ b/src/main/java/org/apache/sysds/resource/enumeration/Enumerator.java
@@ -0,0 +1,522 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+
+package org.apache.sysds.resource.enumeration;
+
+import org.apache.sysds.resource.AWSUtils;
+import org.apache.sysds.resource.CloudInstance;
+import org.apache.sysds.resource.CloudUtils;
+import org.apache.sysds.resource.ResourceCompiler;
+import org.apache.sysds.resource.cost.CostEstimationException;
+import org.apache.sysds.resource.cost.CostEstimator;
+import org.apache.sysds.runtime.controlprogram.Program;
+import
org.apache.sysds.resource.enumeration.EnumerationUtils.InstanceSearchSpace;
+import
org.apache.sysds.resource.enumeration.EnumerationUtils.ConfigurationPoint;
+import org.apache.sysds.resource.enumeration.EnumerationUtils.SolutionPoint;
+
+import java.io.IOException;
+import java.util.*;
+
+public abstract class Enumerator {
+
+ public enum EnumerationStrategy {
+ GridBased, // considering all combinations within a given range of configurations
+ InterestBased, // considering only combinations of configurations with memory budgets close to memory estimates
+ }
+
+ public enum OptimizationStrategy {
+ MinTime, // always prioritize execution time minimization
+ MinPrice, // always prioritize operation price minimization
+ }
+
+ // Static variables
------------------------------------------------------------------------------------------------
+
+ public static final int DEFAULT_MIN_EXECUTORS = 0; // Single Node
execution allowed
+ /**
+ * A reasonable upper bound for the possible number of executors
+ * is required to set limits for the search space and to avoid
+ * evaluating cluster configurations that most probably would
+ * have too high distribution overhead
+ */
+ public static final int DEFAULT_MAX_EXECUTORS = 200;
+
+ // limit for the product of the number of executors
+ // and the number of cores per executor
+ public static final int MAX_LEVEL_PARALLELISM = 1000;
+
+ /** Time/monetary cost delta (as a fraction) within which solutions are considered equally optimal */
+ public static final double COST_DELTA_FRACTION = 0.02;
+
+ // Instance variables
----------------------------------------------------------------------------------------------
+ HashMap<String, CloudInstance> instances = null;
+ Program program;
+ CloudUtils utils;
+ EnumerationStrategy enumStrategy;
+ OptimizationStrategy optStrategy;
+ private final double maxTime;
+ private final double maxPrice;
+ protected final int minExecutors;
+ protected final int maxExecutors;
+ protected final Set<CloudUtils.InstanceType> instanceTypesRange;
+ protected final Set<CloudUtils.InstanceSize> instanceSizeRange;
+
+ protected final InstanceSearchSpace driverSpace = new
InstanceSearchSpace();
+ protected final InstanceSearchSpace executorSpace = new
InstanceSearchSpace();
+ protected ArrayList<SolutionPoint> solutionPool = new ArrayList<>();
+
+ // Initialization functionality
------------------------------------------------------------------------------------
+
+ public Enumerator(Builder builder) {
+ if (builder.provider.equals(CloudUtils.CloudProvider.AWS)) {
+ utils = new AWSUtils();
+ } // as of now no other provider is supported
+ this.program = builder.program;
+ this.enumStrategy = builder.enumStrategy;
+ this.optStrategy = builder.optStrategy;
+ this.maxTime = builder.maxTime;
+ this.maxPrice = builder.maxPrice;
+ this.minExecutors = builder.minExecutors;
+ this.maxExecutors = builder.maxExecutors;
+ this.instanceTypesRange = builder.instanceTypesRange;
+ this.instanceSizeRange = builder.instanceSizeRange;
+ }
+
+ /**
+ * Meant to be used for testing purposes
+ */
+ public HashMap<String, CloudInstance> getInstances() {
+ return instances;
+ }
+
+ /**
+ * Meant to be used for testing purposes
+ */
+ public InstanceSearchSpace getDriverSpace() {
+ return driverSpace;
+ }
+
+ /**
+ * Meant to be used for testing purposes
+ */
+ public void setDriverSpace(InstanceSearchSpace inputSpace) {
+ driverSpace.putAll(inputSpace);
+ }
+
+ /**
+ * Meant to be used for testing purposes
+ */
+ public InstanceSearchSpace getExecutorSpace() {
+ return executorSpace;
+ }
+
+ /**
+ * Meant to be used for testing purposes
+ */
+ public void setExecutorSpace(InstanceSearchSpace inputSpace) {
+ executorSpace.putAll(inputSpace);
+ }
+
+ /**
+ * Meant to be used for testing purposes
+ */
+ public ArrayList<SolutionPoint> getSolutionPool() {
+ return solutionPool;
+ }
+
+ /**
+ * Meant to be used for testing purposes
+ */
+ public void setSolutionPool(ArrayList<SolutionPoint> solutionPool) {
+ this.solutionPool = solutionPool;
+ }
+
+ /**
+ * Setting the available VM instances manually.
+ * Meant to be used for testing purposes.
+ * @param inputInstances initialized map of instances
+ */
+ public void setInstanceTable(HashMap<String, CloudInstance>
inputInstances) {
+ instances = new HashMap<>();
+ for (String key: inputInstances.keySet()) {
+ if
(instanceTypesRange.contains(utils.getInstanceType(key))
+ &&
instanceSizeRange.contains(utils.getInstanceSize(key))) {
+ instances.put(key, inputInstances.get(key));
+ }
+ }
+ }
+
+ /**
+ * Loads the info table for the available VM instances
+ * and filters out the instances that are not contained
+ * in the set of allowed instance types and sizes.
+ *
+ * @param path csv file with instances' info
+ * @throws IOException in case the loading part fails at reading the
csv file
+ */
+ public void loadInstanceTableFile(String path) throws IOException {
+ HashMap<String, CloudInstance> allInstances =
utils.loadInstanceInfoTable(path);
+ instances = new HashMap<>();
+ for (String key: allInstances.keySet()) {
+ if
(instanceTypesRange.contains(utils.getInstanceType(key))
+ &&
instanceSizeRange.contains(utils.getInstanceSize(key))) {
+ instances.put(key, allInstances.get(key));
+ }
+ }
+ }
+
+ // Main functionality
----------------------------------------------------------------------------------------------
+
+ /**
+ * Called once to enumerate the search space for
+ * VM instances for driver or executor nodes.
+ * These instances are represented as {@code InstanceSearchSpace} structures.
+ */
+ public abstract void preprocessing();
+
+ /**
+ * Called once after preprocessing to fill the
+ * solution pool with potentially optimal solutions
+ * by traversing the enumerated search space.
+ * During its execution the number of potential
+ * executor nodes is estimated (enumerated)
+ * dynamically for each considered executor instance.
+ */
+ public void processing() {
+ ConfigurationPoint configurationPoint;
+ SolutionPoint optSolutionPoint = new SolutionPoint(
+ new ConfigurationPoint(null, null, -1),
+ Double.MAX_VALUE,
+ Double.MAX_VALUE
+ );
+ for (Map.Entry<Long, TreeMap<Integer,
LinkedList<CloudInstance>>> dMemoryEntry: driverSpace.entrySet()) {
+ // loop over the search space to enumerate the driver
configurations
+ for (Map.Entry<Integer, LinkedList<CloudInstance>>
dCoresEntry: dMemoryEntry.getValue().entrySet()) {
+ // single node execution mode
+ if
(evaluateSingleNodeExecution(dMemoryEntry.getKey())) {
+ program =
ResourceCompiler.doFullRecompilation(
+ program,
+ dMemoryEntry.getKey(),
+ dCoresEntry.getKey()
+ );
+ for (CloudInstance dInstance:
dCoresEntry.getValue()) {
+ configurationPoint = new
ConfigurationPoint(dInstance);
+
updateOptimalSolution(optSolutionPoint, configurationPoint);
+ }
+ }
+ // enumeration for distributed execution
+ for (Map.Entry<Long, TreeMap<Integer,
LinkedList<CloudInstance>>> eMemoryEntry: executorSpace.entrySet()) {
+ // loop over the search space to
enumerate the executor configurations
+ for (Map.Entry<Integer,
LinkedList<CloudInstance>> eCoresEntry: eMemoryEntry.getValue().entrySet()) {
+ List<Integer>
numberExecutorsSet = estimateRangeExecutors(eMemoryEntry.getKey(),
eCoresEntry.getKey());
+ // Spark execution mode
+ for (int numberExecutors:
numberExecutorsSet) {
+ // TODO: avoid full
recompilation when the driver memory is not changed
+ program =
ResourceCompiler.doFullRecompilation(
+ program,
+
dMemoryEntry.getKey(),
+
dCoresEntry.getKey(),
+
numberExecutors,
+
eMemoryEntry.getKey(),
+
eCoresEntry.getKey()
+ );
+ // TODO: avoid full
program cost estimation when the driver instance is not changed
+ for (CloudInstance
dInstance: dCoresEntry.getValue()) {
+ for
(CloudInstance eInstance: eCoresEntry.getValue()) {
+
configurationPoint = new ConfigurationPoint(dInstance, eInstance,
numberExecutors);
+
updateOptimalSolution(optSolutionPoint, configurationPoint);
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+
+ /**
+ * Decides on the overall best solution out
+ * of the pool of potential solutions filled
+ * during processing.
+ * @return - single optimal cluster configuration
+ */
+ public SolutionPoint postprocessing() {
+ if (solutionPool.isEmpty()) {
+ throw new RuntimeException("Calling postprocessing()
should follow calling processing()");
+ }
+ SolutionPoint optSolution = solutionPool.get(0);
+ double bestCost = Double.MAX_VALUE;
+ for (SolutionPoint solution: solutionPool) {
+ double combinedCost = solution.monetaryCost *
solution.timeCost;
+ if (combinedCost < bestCost) {
+ optSolution = solution;
+ bestCost = combinedCost;
+ } else if (combinedCost == bestCost) {
+ // the ascending order of the search spaces for driver and executor
+ // instances ensures that in case of equally good optimal solutions
+ // the first one found has the smallest resource characteristics.
+ // This, however, is not valid for the number
of executors
+ if (solution.numberExecutors <
optSolution.numberExecutors) {
+ optSolution = solution;
+ bestCost = combinedCost;
+ }
+ }
+ }
+ return optSolution;
+ }
+
+ // Helper methods
--------------------------------------------------------------------------------------------------
+
+ public abstract boolean evaluateSingleNodeExecution(long driverMemory);
+
+ /**
+ * Estimates the minimum and maximum number of
+ * executors based on given VM instance characteristics
+ * and on the enumeration strategy
+ *
+ * @param executorMemory memory of the currently considered executor instance
+ * @param executorCores number of CPU cores of the currently considered executor instance
+ * @return - list of candidate numbers of executors within the estimated [min, max] range
+ */
+ public abstract ArrayList<Integer> estimateRangeExecutors(long
executorMemory, int executorCores);
+
+ /**
+ * Estimates the time cost for the current program based on the
+ * given cluster configurations and following this estimation
+ * it calculates the corresponding monetary cost.
+ * @param point - cluster configuration used for (re)compiling the
current program
+ * @return - [time cost, monetary cost]
+ */
+ private double[] getCostEstimate(ConfigurationPoint point) {
+ // get the estimated time cost
+ double timeCost;
+ try {
+ // estimate execution time of the current program
+ // TODO: pass further relevant cluster configurations
to cost estimator after extending it
+ // like for example: FLOPS, I/O and networking speed
+ timeCost = CostEstimator.estimateExecutionTime(program)
+ CloudUtils.DEFAULT_CLUSTER_LAUNCH_TIME;
+ } catch (CostEstimationException e) {
+ throw new RuntimeException(e.getMessage());
+ }
+ // calculate monetary cost
+ double monetaryCost = utils.calculateClusterPrice(point,
timeCost);
+ return new double[] {timeCost, monetaryCost}; // time cost,
monetary cost
+ }
+
+ /**
+ * Invokes the estimation of the time and monetary cost
+ * based on the compiled program and the given cluster configurations.
+ * Based on the optimization strategy, the given current optimal solution
+ * and the new cost estimate, it decides if the given cluster configuration
+ * is a potential optimal solution, i.e., one with a lower cost or a cost
+ * negligibly higher than the current lowest one.
+ * @param currentOptimal solution point with the lowest cost
+ * @param newPoint new cluster configuration for estimation
+ */
+ private void updateOptimalSolution(SolutionPoint currentOptimal,
ConfigurationPoint newPoint) {
+ // TODO: clarify if setting max time and max price
simultaneously makes really sense
+ SolutionPoint newPotentialSolution;
+ boolean replaceCurrentOptimal = false;
+ double[] newCost = getCostEstimate(newPoint);
+ if (optStrategy == OptimizationStrategy.MinTime) {
+ if (newCost[1] > maxPrice || newCost[0] >=
currentOptimal.timeCost * (1 + COST_DELTA_FRACTION)) {
+ return;
+ }
+ if (newCost[0] < currentOptimal.timeCost)
replaceCurrentOptimal = true;
+ } else if (optStrategy == OptimizationStrategy.MinPrice) {
+ if (newCost[0] > maxTime || newCost[1] >=
currentOptimal.monetaryCost * (1 + COST_DELTA_FRACTION)) {
+ return;
+ }
+ if (newCost[1] < currentOptimal.monetaryCost)
replaceCurrentOptimal = true;
+ }
+ newPotentialSolution = new SolutionPoint(newPoint, newCost[0],
newCost[1]);
+ solutionPool.add(newPotentialSolution);
+ if (replaceCurrentOptimal) {
+ currentOptimal.update(newPoint, newCost[0], newCost[1]);
+ }
+ }
+
+ // Class builder
---------------------------------------------------------------------------------------------------
+
+ public static class Builder {
+ private final CloudUtils.CloudProvider provider =
CloudUtils.CloudProvider.AWS; // currently default and only choice
+ private Program program;
+ private EnumerationStrategy enumStrategy = null;
+ private OptimizationStrategy optStrategy = null;
+ private double maxTime = -1d;
+ private double maxPrice = -1d;
+ private int minExecutors = DEFAULT_MIN_EXECUTORS;
+ private int maxExecutors = DEFAULT_MAX_EXECUTORS;
+ private Set<CloudUtils.InstanceType> instanceTypesRange = null;
+ private Set<CloudUtils.InstanceSize> instanceSizeRange = null;
+
+ // GridBased specific
------------------------------------------------------------------------------------------
+ private int stepSizeExecutors = 1;
+ private int expBaseExecutors = -1; // base for exponentially increasing the number of executors; -1 disables it
+ // InterestBased specific
--------------------------------------------------------------------------------------
+ private boolean fitDriverMemory = true;
+ private boolean fitBroadcastMemory = true;
+ private boolean checkSingleNodeExecution = false;
+ private boolean fitCheckpointMemory = false;
+ public Builder() {}
+
+ public Builder withRuntimeProgram(Program program) {
+ this.program = program;
+ return this;
+ }
+
+ public Builder withEnumerationStrategy(EnumerationStrategy
strategy) {
+ this.enumStrategy = strategy;
+ return this;
+ }
+
+ public Builder withOptimizationStrategy(OptimizationStrategy
strategy) {
+ this.optStrategy = strategy;
+ return this;
+ }
+
+ public Builder withTimeLimit(double time) {
+ if (time < CloudUtils.MINIMAL_EXECUTION_TIME) {
+ throw new
IllegalArgumentException(CloudUtils.MINIMAL_EXECUTION_TIME +
+ "s is the minimum target
execution time.");
+ }
+ this.maxTime = time;
+ return this;
+ }
+
+ public Builder withBudget(double price) {
+ if (price <= 0) {
+ throw new IllegalArgumentException("The given
budget (target price) should be positive");
+ }
+ this.maxPrice = price;
+ return this;
+ }
+
+ public Builder withNumberExecutorsRange(int min, int max) {
+ this.minExecutors = min;
+ this.maxExecutors = max;
+ return this;
+ }
+
+ public Builder withInstanceTypeRange(String[] instanceTypes) {
+ this.instanceTypesRange =
typeRangeFromStrings(instanceTypes);
+ return this;
+ }
+
+ public Builder withInstanceSizeRange(String[] instanceSizes) {
+ this.instanceSizeRange =
sizeRangeFromStrings(instanceSizes);
+ return this;
+ }
+
+ public Builder withStepSizeExecutor(int stepSize) {
+ this.stepSizeExecutors = stepSize;
+ return this;
+ }
+
+
+ public Builder withFitDriverMemory(boolean fitDriverMemory) {
+ this.fitDriverMemory = fitDriverMemory;
+ return this;
+ }
+
+ public Builder withFitBroadcastMemory(boolean
fitBroadcastMemory) {
+ this.fitBroadcastMemory = fitBroadcastMemory;
+ return this;
+ }
+
+ public Builder withCheckSingleNodeExecution(boolean
checkSingleNodeExecution) {
+ this.checkSingleNodeExecution =
checkSingleNodeExecution;
+ return this;
+ }
+
+ public Builder withFitCheckpointMemory(boolean
fitCheckpointMemory) {
+ this.fitCheckpointMemory = fitCheckpointMemory;
+ return this;
+ }
+
+ public Builder withExpBaseExecutors(int expBaseExecutors) {
+ if (expBaseExecutors != -1 && expBaseExecutors < 2) {
+ throw new IllegalArgumentException("Given
exponent base for number of executors should be -1 or bigger than 1.");
+ }
+ this.expBaseExecutors = expBaseExecutors;
+ return this;
+ }
+
+ public Enumerator build() {
+ if (this.program == null) {
+ throw new IllegalArgumentException("Providing
runtime program is required");
+ }
+
+ if (instanceTypesRange == null) {
+ instanceTypesRange =
EnumSet.allOf(CloudUtils.InstanceType.class);
+ }
+
+ if (instanceSizeRange == null) {
+ instanceSizeRange =
EnumSet.allOf(CloudUtils.InstanceSize.class);
+ }
+
+ switch (optStrategy) {
+ case MinTime:
+ if (this.maxPrice < 0) {
+ throw new
IllegalArgumentException("Budget not specified but required " +
+ "for the chosen
optimization strategy: " + optStrategy);
+ }
+ break;
+ case MinPrice:
+ if (this.maxTime < 0) {
+ throw new
IllegalArgumentException("Time limit not specified but required " +
+ "for the chosen
optimization strategy: " + optStrategy);
+ }
+ break;
+ default: // in case optimization strategy was
not configured
+ throw new
IllegalArgumentException("Setting an optimization strategy is required.");
+ }
+
+ switch (enumStrategy) {
+ case GridBased:
+ return new GridBasedEnumerator(this,
stepSizeExecutors, expBaseExecutors);
+ case InterestBased:
+ if (fitCheckpointMemory &&
expBaseExecutors != -1) {
+ throw new
IllegalArgumentException("Number of executors cannot be fitted on the
checkpoint estimates and increased exponentially simultaneously.");
+ }
+ return new
InterestBasedEnumerator(this, fitDriverMemory, fitBroadcastMemory,
checkSingleNodeExecution, fitCheckpointMemory);
+ default:
+ throw new
IllegalArgumentException("Setting an enumeration strategy is required.");
+ }
+ }
+
+ protected static Set<CloudUtils.InstanceType>
typeRangeFromStrings(String[] types) {
+ Set<CloudUtils.InstanceType> result =
EnumSet.noneOf(CloudUtils.InstanceType.class);
+ for (String typeAsString: types) {
+ CloudUtils.InstanceType type =
CloudUtils.InstanceType.customValueOf(typeAsString); // can throw
IllegalArgumentException
+ result.add(type);
+ }
+ return result;
+ }
+
+ protected static Set<CloudUtils.InstanceSize>
sizeRangeFromStrings(String[] sizes) {
+ Set<CloudUtils.InstanceSize> result =
EnumSet.noneOf(CloudUtils.InstanceSize.class);
+ for (String sizeAsString: sizes) {
+ CloudUtils.InstanceSize size =
CloudUtils.InstanceSize.customValueOf(sizeAsString); // can throw
IllegalArgumentException
+ result.add(size);
+ }
+ return result;
+ }
+ }
+}
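
A sketch of wiring an enumerator via the builder; the program, file name, and time limit are illustrative (MinPrice requires a time limit, mirroring the checks in build()):

    Enumerator enumerator = new Enumerator.Builder()
        .withRuntimeProgram(program)
        .withEnumerationStrategy(Enumerator.EnumerationStrategy.GridBased)
        .withOptimizationStrategy(Enumerator.OptimizationStrategy.MinPrice)
        .withTimeLimit(3600) // seconds; required for MinPrice
        .build();
    enumerator.loadInstanceTableFile("instances.csv"); // throws IOException
    enumerator.preprocessing();
    enumerator.processing();
    EnumerationUtils.SolutionPoint best = enumerator.postprocessing();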
diff --git
a/src/main/java/org/apache/sysds/resource/enumeration/GridBasedEnumerator.java
b/src/main/java/org/apache/sysds/resource/enumeration/GridBasedEnumerator.java
new file mode 100644
index 0000000000..aa71aba139
--- /dev/null
+++
b/src/main/java/org/apache/sysds/resource/enumeration/GridBasedEnumerator.java
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.resource.enumeration;
+
+import java.util.*;
+
+public class GridBasedEnumerator extends Enumerator {
+ // step size by which the number of executors is increased
+ private final int stepSizeExecutors;
+ // base for exponentially increasing the number of executors
+ // (single node execution mode is not excluded);
+ // -1 marks no exponential increase
+ private final int expBaseExecutors;
+ public GridBasedEnumerator(Builder builder, int stepSizeExecutors, int
expBaseExecutors) {
+ super(builder);
+ this.stepSizeExecutors = stepSizeExecutors;
+ this.expBaseExecutors = expBaseExecutors;
+ }
+
+ /**
+ * Initializes the search spaces for driver and executor
+ * instances, traversed during processing, with all the
+ * available instances
+ */
+ @Override
+ public void preprocessing() {
+ driverSpace.initSpace(instances);
+ executorSpace.initSpace(instances);
+ }
+
+ @Override
+ public boolean evaluateSingleNodeExecution(long driverMemory) {
+ return minExecutors == 0;
+ }
+
+ @Override
+ public ArrayList<Integer> estimateRangeExecutors(long executorMemory,
int executorCores) {
+ // considers the maximum level of parallelism and,
+ // based on the configured flags, decides between the following
+ // methods for enumerating the number of executors:
+ // 1. increasing the number of executors by a given step size (default 1)
+ // 2. exponentially increasing the number of executors based on
+ // a given exponent base - with an additional option for 0 executors
+ int currentMax = Math.min(maxExecutors, MAX_LEVEL_PARALLELISM /
executorCores);
+ ArrayList<Integer> result;
+ if (expBaseExecutors > 1) {
+ int maxCapacity = (int) Math.floor(Math.log(currentMax)
/ Math.log(2));
+ result = new ArrayList<>(maxCapacity);
+ int exponent = 0;
+ int numExecutors;
+ while ((numExecutors = (int) Math.pow(expBaseExecutors,
exponent)) <= currentMax) {
+ if (numExecutors >= minExecutors) {
+ result.add(numExecutors);
+ }
+ exponent++;
+ }
+ } else {
+ int capacity = (int) Math.floor((double) (currentMax -
minExecutors + 1) / stepSizeExecutors);
+ result = new ArrayList<>(capacity);
+ // exclude the 0 from the iteration while keeping it as
starting point to ensure predictable steps
+ int numExecutors = minExecutors == 0? minExecutors +
stepSizeExecutors : minExecutors;
+ while (numExecutors <= currentMax) {
+ result.add(numExecutors);
+ numExecutors += stepSizeExecutors;
+ }
+ }
+
+ return result;
+ }
+}
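
A minimal sketch (not part of the commit) of the two grid strategies above, assuming minExecutors=0, maxExecutors=10, and a parallelism cap that is not reached; the expected sequences match the new EnumeratorTests below:

	// step size 2: 0 is kept only as the starting point, never emitted
	// -> 2, 4, 6, 8, 10
	for (int n = 0 + 2; n <= 10; n += 2)
		System.out.print(n + " ");
	// exponential base 2: 2^0, 2^1, ... while <= currentMax
	// -> 1, 2, 4, 8
	for (int e = 0; (int) Math.pow(2, e) <= 10; e++)
		System.out.print((int) Math.pow(2, e) + " ");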
diff --git a/src/main/java/org/apache/sysds/resource/enumeration/InterestBasedEnumerator.java b/src/main/java/org/apache/sysds/resource/enumeration/InterestBasedEnumerator.java
new file mode 100644
index 0000000000..ea6971a967
--- /dev/null
+++ b/src/main/java/org/apache/sysds/resource/enumeration/InterestBasedEnumerator.java
@@ -0,0 +1,315 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.resource.enumeration;
+
+import org.apache.sysds.hops.Hop;
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.parser.StatementBlock;
+import org.apache.sysds.runtime.controlprogram.*;
+import org.apache.sysds.resource.enumeration.EnumerationUtils.InstanceSearchSpace;
+
+import java.util.*;
+import java.util.stream.Collectors;
+
+public class InterestBasedEnumerator extends Enumerator {
+	public final static long MINIMUM_RELEVANT_MEM_ESTIMATE = 2L * 1024 * 1024 * 1024; // 2GB
+	// different instance families can have slightly different memory characteristics (e.g., EC2 Graviton (arm) instances)
+	// and using a memory delta allows not ignoring such instances
+	// TODO: enable usage of the memory delta once FLOPS and bandwidth characteristics are added to the cost estimator
+	public final static boolean USE_MEMORY_DELTA = false;
+	public final static double MEMORY_DELTA_FRACTION = 0.05; // 5%
+	public final static double BROADCAST_MEMORY_FACTOR = 0.21; // fraction of the minimum memory fraction for storage
+	// marks if memory estimates should be used when deciding
+	// the instance search space for the driver nodes
+	private final boolean fitDriverMemory;
+	// marks if memory estimates should be used when deciding
+	// the instance search space for the executor nodes
+	private final boolean fitBroadcastMemory;
+	// marks if the estimation of the range of the number of executors
+	// should exclude single-node execution mode
+	// if any of the estimates cannot fit in the driver's memory
+	private final boolean checkSingleNodeExecution;
+	// marks if the estimated output size should be
+	// considered as an interesting point when deciding the
+	// number of executors - checkpoint storage level
+	private final boolean fitCheckpointMemory;
+	// largest full memory estimate (scaled)
+	private long largestMemoryEstimateCP;
+	// ordered set of output memory estimates (scaled)
+	private TreeSet<Long> memoryEstimatesSpark;
+
+	public InterestBasedEnumerator(
+			Builder builder,
+			boolean fitDriverMemory,
+			boolean fitBroadcastMemory,
+			boolean checkSingleNodeExecution,
+			boolean fitCheckpointMemory
+	) {
+		super(builder);
+		this.fitDriverMemory = fitDriverMemory;
+		this.fitBroadcastMemory = fitBroadcastMemory;
+		this.checkSingleNodeExecution = checkSingleNodeExecution;
+		this.fitCheckpointMemory = fitCheckpointMemory;
+	}
+
+	@Override
+	public void preprocessing() {
+		InstanceSearchSpace fullSearchSpace = new InstanceSearchSpace();
+		fullSearchSpace.initSpace(instances);
+
+		if (fitDriverMemory) {
+			// get full memory estimates and scale according to the driver memory factor
+			TreeSet<Long> memoryEstimatesForDriver = getMemoryEstimates(program, false, OptimizerUtils.MEM_UTIL_FACTOR);
+			setInstanceSpace(fullSearchSpace, driverSpace, memoryEstimatesForDriver);
+			if (checkSingleNodeExecution) {
+				largestMemoryEstimateCP = !memoryEstimatesForDriver.isEmpty()? memoryEstimatesForDriver.last() : -1;
+			}
+		}
+
+		if (fitBroadcastMemory) {
+			// get output memory estimates and scale them according to the broadcast memory factor
+			// for the executors' memory search space and the driver memory factor for the driver's memory search space
+			TreeSet<Long> memoryEstimatesOutputSpark = getMemoryEstimates(program, true, BROADCAST_MEMORY_FACTOR);
+			// avoid calling getMemoryEstimates with a different factor but rescale: the output should fit twice in the CP memory
+			TreeSet<Long> memoryEstimatesOutputCP = memoryEstimatesOutputSpark.stream()
+					.map(mem -> 2 * (long) (mem * BROADCAST_MEMORY_FACTOR / OptimizerUtils.MEM_UTIL_FACTOR))
+					.collect(Collectors.toCollection(TreeSet::new));
+			setInstanceSpace(fullSearchSpace, driverSpace, memoryEstimatesOutputCP);
+			setInstanceSpace(fullSearchSpace, executorSpace, memoryEstimatesOutputSpark);
+			if (checkSingleNodeExecution) {
+				largestMemoryEstimateCP = !memoryEstimatesOutputCP.isEmpty()? memoryEstimatesOutputCP.last() : -1;
+			}
+			if (fitCheckpointMemory) {
+				memoryEstimatesSpark = memoryEstimatesOutputSpark;
+			}
+		} else {
+			executorSpace.putAll(fullSearchSpace);
+			if (fitCheckpointMemory) {
+				memoryEstimatesSpark = getMemoryEstimates(program, true, BROADCAST_MEMORY_FACTOR);
+			}
+		}
+
+		if (!fitDriverMemory && !fitBroadcastMemory) {
+			driverSpace.putAll(fullSearchSpace);
+			if (checkSingleNodeExecution) {
+				TreeSet<Long> memoryEstimatesForDriver = getMemoryEstimates(program, false, OptimizerUtils.MEM_UTIL_FACTOR);
+				largestMemoryEstimateCP = !memoryEstimatesForDriver.isEmpty()? memoryEstimatesForDriver.last() : -1;
+			}
+		}
+	}
+
+	@Override
+	public boolean evaluateSingleNodeExecution(long driverMemory) {
+		// Checking if single node execution should be excluded is optional.
+		if (checkSingleNodeExecution && minExecutors == 0 && largestMemoryEstimateCP > 0) {
+			return largestMemoryEstimateCP <= driverMemory;
+		}
+		return minExecutors == 0;
+	}
+
+	@Override
+	public ArrayList<Integer> estimateRangeExecutors(long executorMemory, int executorCores) {
+		// Considers the maximum level of parallelism and,
+		// based on the configured flags, decides between the following
+		// methods for enumerating the number of executors:
+		// 1. Such numbers that lead to a combined distributed memory
+		//    close to the output size of the HOPs
+		// 2. Enumerating all options within the established range
+		int min = Math.max(1, minExecutors);
+		int max = Math.min(maxExecutors, (MAX_LEVEL_PARALLELISM / executorCores));
+
+		ArrayList<Integer> result;
+		if (fitCheckpointMemory) {
+			result = new ArrayList<>(memoryEstimatesSpark.size() + 1);
+			int previousNumber = -1;
+			for (long estimate: memoryEstimatesSpark) {
+				// the ratio is just an intermediate for the newly enumerated number of executors
+				double ratio = (double) estimate / executorMemory;
+				int currentNumber = (int) Math.max(1, Math.floor(ratio));
+				if (currentNumber < min || currentNumber == previousNumber) {
+					continue;
+				}
+				if (currentNumber <= max) {
+					result.add(currentNumber);
+					previousNumber = currentNumber;
+				} else {
+					break;
+				}
+			}
+			// add a number that also allows the largest checkpoint to be done in memory
+			if (previousNumber < 0) {
+				// always append at least one value to allow evaluating Spark execution
+				result.add(min);
+			} else if (previousNumber < max) {
+				result.add(previousNumber + 1);
+			}
+		} else { // enumerate all options within the min-max range
+			result = new ArrayList<>((max - min) + 1);
+			for (int n = min; n <= max; n++) {
+				result.add(n);
+			}
+		}
+		return result;
+	}
+
+	// Static helper methods -------------------------------------------------------------------------------------
+	private static void setInstanceSpace(InstanceSearchSpace inputSpace, InstanceSearchSpace outputSpace, TreeSet<Long> memoryEstimates) {
+		TreeSet<Long> memoryPoints = getMemoryPoints(memoryEstimates, inputSpace.keySet());
+		for (long memory: memoryPoints) {
+			outputSpace.put(memory, inputSpace.get(memory));
+		}
+		// in case no sufficiently large memory estimates exist, set the instances with minimal memory
+		if (outputSpace.isEmpty()) {
+			long minMemory = inputSpace.firstKey();
+			outputSpace.put(minMemory, inputSpace.get(minMemory));
+		}
+	}
+
+	/**
+	 * @param availableMemory should always be a sorted set;
+	 * this is always the case for the result of {@code keySet()} called on a {@code TreeMap}
+	 */
+	private static TreeSet<Long> getMemoryPoints(TreeSet<Long> estimates, Set<Long> availableMemory) {
+		// use a tree set to avoid adding duplicates and to ensure ascending order
+		TreeSet<Long> result = new TreeSet<>();
+		// assumed ascending order
+		List<Long> relevantPoints = new ArrayList<>(availableMemory);
+		for (long estimate: estimates) {
+			if (availableMemory.isEmpty()) {
+				break;
+			}
+			// divide the list into larger and smaller points by partitioning - partitioning preserves the order
+			Map<Boolean, List<Long>> divided = relevantPoints.stream()
+					.collect(Collectors.partitioningBy(n -> n < estimate));
+			// get the points smaller than the current memory estimate
+			List<Long> smallerPoints = divided.get(true);
+			long largestOfTheSmaller = smallerPoints.isEmpty() ? -1 : smallerPoints.get(smallerPoints.size() - 1);
+			// reduce the list of relevant points - equal to or larger than the estimate
+			relevantPoints = divided.get(false);
+			// get the points greater than or equal to the current memory estimate
+			long smallestOfTheLarger = relevantPoints.isEmpty()? -1 : relevantPoints.get(0);
+
+			if (USE_MEMORY_DELTA) {
+				// A memory delta of 5% of the node's memory allows not ignoring
+				// memory points with potentially equivalent but not exactly equal values.
+				// This is the case, for example, in AWS for instances of the same type but with
+				// different additional capabilities: m5.xlarge (16GB) vs m5n.xlarge (15.25GB).
+				// Get points smaller than the current memory estimate within the memory delta
+				long memoryDelta = Math.round(estimate * MEMORY_DELTA_FRACTION);
+				for (long point : smallerPoints) {
+					if (point >= (largestOfTheSmaller - memoryDelta)) {
+						result.add(point);
+					}
+				}
+				for (long point : relevantPoints) {
+					if (point <= (smallestOfTheLarger + memoryDelta)) {
+						result.add(point);
+					} else {
+						break;
+					}
+				}
+			} else {
+				if (largestOfTheSmaller > 0) {
+					result.add(largestOfTheSmaller);
+				}
+				if (smallestOfTheLarger > 0) {
+					result.add(smallestOfTheLarger);
+				}
+			}
+		}
+		return result;
+	}
+
+	/**
+	 * Extracts the memory estimates whose original size is larger than {@code MINIMUM_RELEVANT_MEM_ESTIMATE}
+	 *
+	 * @param currentProgram program to extract the memory estimates from
+	 * @param outputOnly {@code true} - output estimate only;
+	 *                   {@code false} - sum of input, intermediate, and output estimates
+	 * @param memoryFactor factor for reverse scaling of the estimates to avoid
+	 *                     scaling the search space parameters representing the nodes' memory budget
+	 * @return memory estimates in ascending order, ensured by the {@code TreeSet} data structure
+	 */
+	public static TreeSet<Long> getMemoryEstimates(Program currentProgram, boolean outputOnly, double memoryFactor) {
+		TreeSet<Long> estimates = new TreeSet<>();
+		getMemoryEstimates(currentProgram.getProgramBlocks(), estimates, outputOnly);
+		return estimates.stream()
+				.filter(mem -> mem > MINIMUM_RELEVANT_MEM_ESTIMATE)
+				.map(mem -> (long) (mem / memoryFactor))
+				.collect(Collectors.toCollection(TreeSet::new));
+	}
+
+	private static void getMemoryEstimates(ArrayList<ProgramBlock> pbs, TreeSet<Long> mem, boolean outputOnly) {
+		for( ProgramBlock pb : pbs ) {
+			getMemoryEstimates(pb, mem, outputOnly);
+		}
+	}
+
+	private static void getMemoryEstimates(ProgramBlock pb, TreeSet<Long> mem, boolean outputOnly) {
+		if (pb instanceof FunctionProgramBlock)
+		{
+			FunctionProgramBlock fpb = (FunctionProgramBlock)pb;
+			getMemoryEstimates(fpb.getChildBlocks(), mem, outputOnly);
+		}
+		else if (pb instanceof WhileProgramBlock)
+		{
+			WhileProgramBlock fpb = (WhileProgramBlock)pb;
+			getMemoryEstimates(fpb.getChildBlocks(), mem, outputOnly);
+		}
+		else if (pb instanceof IfProgramBlock)
+		{
+			IfProgramBlock fpb = (IfProgramBlock)pb;
+			getMemoryEstimates(fpb.getChildBlocksIfBody(), mem, outputOnly);
+			getMemoryEstimates(fpb.getChildBlocksElseBody(), mem, outputOnly);
+		}
+		else if (pb instanceof ForProgramBlock) // including parfor
+		{
+			ForProgramBlock fpb = (ForProgramBlock)pb;
+			getMemoryEstimates(fpb.getChildBlocks(), mem, outputOnly);
+		}
+		else
+		{
+			StatementBlock sb = pb.getStatementBlock();
+			if( sb != null && sb.getHops() != null ){
+				Hop.resetVisitStatus(sb.getHops());
+				for( Hop hop : sb.getHops() )
+					getMemoryEstimates(hop, mem, outputOnly);
+			}
+		}
+	}
+
+	private static void getMemoryEstimates(Hop hop, TreeSet<Long> mem, boolean outputOnly)
+	{
+		if( hop.isVisited() )
+			return;
+		//process children
+		for(Hop hi : hop.getInput())
+			getMemoryEstimates(hi, mem, outputOnly);
+
+		if (outputOnly) {
+			long estimate = (long) hop.getOutputMemEstimate(0);
+			if (estimate > 0)
+				mem.add(estimate);
+		} else {
+			mem.add((long) hop.getMemEstimate());
+		}
+		hop.setVisited();
+	}
+}
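
A sketch of the neighbor-point selection that getMemoryPoints implements (without the optional 5% memory delta), assuming node memories of 8, 16, and 32 GB and a single scaled estimate of 10 GB; TreeSet floor/ceiling express the same "largest smaller / smallest larger" rule as the stream partitioning above, and GBtoBytes is the CloudUtils helper statically imported in the tests below:

	TreeSet<Long> available = new TreeSet<>(List.of(GBtoBytes(8), GBtoBytes(16), GBtoBytes(32)));
	long estimate = GBtoBytes(10);
	Long lower = available.floor(estimate);   // 8 GB: largest of the smaller points
	Long upper = available.ceiling(estimate); // 16 GB: smallest of the larger points
	// -> the 32 GB instances are pruned from the search space,
	//    as asserted by preprocessingInterestBasedDriverMemoryTest below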
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java b/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java
index c0803cbcc7..76e664245c 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/context/SparkExecutionContext.java
@@ -159,22 +159,10 @@ public class SparkExecutionContext extends ExecutionContext
 		return _spctx;
 	}
 
-	public static void initVirtualSparkContext(SparkConf sparkConf) {
-		if (_spctx != null) {
-			for (Tuple2<String, String> pair : sparkConf.getAll()) {
-				_spctx.sc().getConf().set(pair._1, pair._2);
-			}
-		} else {
-			handleIllegalReflectiveAccessSpark();
-			try {
-				_spctx = new JavaSparkContext(sparkConf);
-				// assumes NON-legacy spark version
-				_sconf = new SparkClusterConfig();
-			} catch (Exception e) {
-				throw new RuntimeException(e);
-			}
+	public static void initLocalSparkContext(SparkConf sparkConf) {
+		if (_sconf == null) {
+			_sconf = new SparkClusterConfig();
 		}
-
 		_sconf.analyzeSparkConfiguation(sparkConf);
 	}
@@ -1858,7 +1846,7 @@ public class SparkExecutionContext extends ExecutionContext
 		private static final double BROADCAST_DATA_FRACTION_LEGACY = 0.35;
 		//forward private config from Spark's UnifiedMemoryManager.scala (>1.6)
-		private static final long RESERVED_SYSTEM_MEMORY_BYTES = 300 * 1024 * 1024;
+		public static final long RESERVED_SYSTEM_MEMORY_BYTES = 300 * 1024 * 1024;
 		//meta configurations
 		private boolean _legacyVersion = false; //spark version <1.6
@@ -1985,7 +1973,7 @@ public class SparkExecutionContext extends ExecutionContext
 			_confOnly &= true;
 		}
 		else if( DMLScript.USE_LOCAL_SPARK_CONFIG ) {
-			//avoid unnecessary spark context creation in local mode (e.g., tests)
+			//avoid unnecessary spark context creation in local mode (e.g., tests, resource opt.)
 			_numExecutors = 1;
 			_defaultPar = 2;
 			_confOnly &= true;
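
A hypothetical sketch of how the renamed initLocalSparkContext can be fed a purely virtual cluster configuration without creating a JavaSparkContext; the specific SparkConf values are illustrative, not from this commit:

	SparkConf conf = new SparkConf()
		.setMaster("local[*]")
		.setAppName("resource-optimization")
		.set("spark.executor.memory", "4g")
		.set("spark.executor.instances", "2")
		.set("spark.executor.cores", "2");
	// analyzes the configuration without starting a Spark context
	SparkExecutionContext.initLocalSparkContext(conf);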
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/stat/InfrastructureAnalyzer.java b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/stat/InfrastructureAnalyzer.java
index 98c6848839..16f3cf7616 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/stat/InfrastructureAnalyzer.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/stat/InfrastructureAnalyzer.java
@@ -56,7 +56,7 @@ public class InfrastructureAnalyzer
 	private static int _remoteParReduce = -1;
 	private static boolean _localJT = false;
 	private static long _blocksize = -1;
-	
+
 	//static initialization, called for each JVM (on each node)
 	static {
 		//analyze local node properties
@@ -136,7 +136,11 @@ public class InfrastructureAnalyzer
 	public static void setLocalMaxMemory( long localMem ) {
 		_localJVMMaxMem = localMem;
 	}
-	
+
+	public static void setLocalPar(int localPar) {
+		_localPar = localPar;
+	}
+
 	public static double getLocalMaxMemoryFraction() {
 		//since parfor modifies _localJVMMaxMem, some internal primitives
 		//need access to the current fraction of total local memory
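
Together with the existing setLocalMaxMemory, the new setLocalPar lets the resource compiler emulate a target driver node in-process; a minimal sketch with illustrative values, mirroring testSetDriverConfigurations in the new RecompilationTest below:

	// emulate a virtual driver with 8 GB memory and 4 threads
	InfrastructureAnalyzer.setLocalMaxMemory(8L * 1024 * 1024 * 1024);
	InfrastructureAnalyzer.setLocalPar(4);
	// subsequent memory-budget and parallelism queries reflect the virtual node
	assert InfrastructureAnalyzer.getLocalParallelism() == 4;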
diff --git a/src/test/java/org/apache/sysds/test/component/resource/CloudUtilsTests.java b/src/test/java/org/apache/sysds/test/component/resource/CloudUtilsTests.java
new file mode 100644
index 0000000000..b7e08ae35d
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/component/resource/CloudUtilsTests.java
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.component.resource;
+
+import org.apache.sysds.resource.AWSUtils;
+import org.apache.sysds.resource.CloudInstance;
+import org.apache.sysds.resource.CloudUtils.InstanceType;
+import org.apache.sysds.resource.CloudUtils.InstanceSize;
+import org.junit.Test;
+
+import java.io.File;
+import java.io.IOException;
+import java.nio.file.Files;
+import java.util.HashMap;
+
+import static org.apache.sysds.test.component.resource.TestingUtils.assertEqualsCloudInstances;
+import static org.apache.sysds.test.component.resource.TestingUtils.getSimpleCloudInstanceMap;
+import static org.junit.Assert.*;
+
[email protected]
+public class CloudUtilsTests {
+
+ @Test
+ public void getInstanceTypeAWSTest() {
+ AWSUtils utils = new AWSUtils();
+
+ InstanceType expectedValue = InstanceType.M5;
+ InstanceType actualValue;
+
+ actualValue = utils.getInstanceType("m5.xlarge");
+ assertEquals(expectedValue, actualValue);
+
+ actualValue = utils.getInstanceType("M5.XLARGE");
+ assertEquals(expectedValue, actualValue);
+
+ try {
+ utils.getInstanceType("NON-M5.xlarge");
+ fail("Throwing IllegalArgumentException was expected");
+ } catch (IllegalArgumentException e) {
+ // this block ensures correct execution of the test
+ }
+ }
+
+ @Test
+ public void getInstanceSizeAWSTest() {
+ AWSUtils utils = new AWSUtils();
+
+ InstanceSize expectedValue = InstanceSize._XLARGE;
+ InstanceSize actualValue;
+
+ actualValue = utils.getInstanceSize("m5.xlarge");
+ assertEquals(expectedValue, actualValue);
+
+ actualValue = utils.getInstanceSize("M5.XLARGE");
+ assertEquals(expectedValue, actualValue);
+
+ try {
+ utils.getInstanceSize("m5.nonxlarge");
+ fail("Throwing IllegalArgumentException was expected");
+ } catch (IllegalArgumentException e) {
+ // this block ensures correct execution of the test
+ }
+ }
+
+ @Test
+ public void validateInstanceNameAWSTest() {
+ AWSUtils utils = new AWSUtils();
+
+ // basic intel instance (old)
+ assertTrue(utils.validateInstanceName("m5.2xlarge"));
+ assertTrue(utils.validateInstanceName("M5.2XLARGE"));
+ // basic intel instance (new)
+ assertTrue(utils.validateInstanceName("m6i.xlarge"));
+ // basic amd instance
+ assertTrue(utils.validateInstanceName("m6a.xlarge"));
+ // basic graviton instance
+ assertTrue(utils.validateInstanceName("m6g.xlarge"));
+ // invalid values
+ assertFalse(utils.validateInstanceName("v5.xlarge"));
+ assertFalse(utils.validateInstanceName("m5.notlarge"));
+ assertFalse(utils.validateInstanceName("m5xlarge"));
+ assertFalse(utils.validateInstanceName(".xlarge"));
+ assertFalse(utils.validateInstanceName("m5."));
+ }
+
+ @Test
+ public void loadCSVFileAWSTest() throws IOException {
+ AWSUtils utils = new AWSUtils();
+
+ File tmpFile = TestingUtils.generateTmpInstanceInfoTableFile();
+
+		HashMap<String, CloudInstance> actual = utils.loadInstanceInfoTable(tmpFile.getPath());
+		HashMap<String, CloudInstance> expected = getSimpleCloudInstanceMap();
+
+		for (String instanceName: expected.keySet()) {
+			assertEqualsCloudInstances(expected.get(instanceName), actual.get(instanceName));
+		}
+
+ Files.deleteIfExists(tmpFile.toPath());
+ }
+}
diff --git a/src/test/java/org/apache/sysds/test/component/resource/EnumeratorTests.java b/src/test/java/org/apache/sysds/test/component/resource/EnumeratorTests.java
new file mode 100644
index 0000000000..437770ce28
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/component/resource/EnumeratorTests.java
@@ -0,0 +1,529 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.component.resource;
+
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.resource.CloudInstance;
+import org.apache.sysds.resource.enumeration.Enumerator;
+import org.apache.sysds.resource.enumeration.EnumerationUtils.InstanceSearchSpace;
+import org.apache.sysds.resource.enumeration.EnumerationUtils.ConfigurationPoint;
+import org.apache.sysds.resource.enumeration.EnumerationUtils.SolutionPoint;
+import org.apache.sysds.resource.enumeration.InterestBasedEnumerator;
+import org.apache.sysds.runtime.controlprogram.Program;
+import org.junit.Assert;
+import org.junit.Test;
+import org.mockito.MockedStatic;
+import org.mockito.Mockito;
+
+import java.io.File;
+import java.io.IOException;
+import java.nio.file.Files;
+import java.util.*;
+
+import static org.apache.sysds.resource.CloudUtils.GBtoBytes;
+import static org.junit.Assert.*;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.ArgumentMatchers.eq;
+
[email protected]
+public class EnumeratorTests {
+
+	@Test
+	public void loadInstanceTableTest() throws IOException {
+		// loading the table is entirely implemented by the abstract class -> use any enumerator
+		Enumerator anyEnumerator = getGridBasedEnumeratorPrebuild()
+				.withInstanceTypeRange(new String[]{"m5"})
+				.withInstanceSizeRange(new String[]{"xlarge"})
+				.build();
+
+		File tmpFile = TestingUtils.generateTmpInstanceInfoTableFile();
+		anyEnumerator.loadInstanceTableFile(tmpFile.toString());
+
+		HashMap<String, CloudInstance> actualInstances = anyEnumerator.getInstances();
+
+		Assert.assertEquals(1, actualInstances.size());
+		Assert.assertNotNull(actualInstances.get("m5.xlarge"));
+
+		Files.deleteIfExists(tmpFile.toPath());
+	}
+
+	@Test
+	public void preprocessingGridBasedTest() {
+		Enumerator gridBasedEnumerator = getGridBasedEnumeratorPrebuild().build();
+
+		HashMap<String, CloudInstance> instances = TestingUtils.getSimpleCloudInstanceMap();
+		gridBasedEnumerator.setInstanceTable(instances);
+
+		gridBasedEnumerator.preprocessing();
+		// assertions for driver space
+		InstanceSearchSpace driverSpace = gridBasedEnumerator.getDriverSpace();
+		assertEquals(3, driverSpace.size());
+		assertInstanceInSearchSpace("c5.xlarge", driverSpace, 8, 4, 0);
+		assertInstanceInSearchSpace("m5.xlarge", driverSpace, 16, 4, 0);
+		assertInstanceInSearchSpace("c5.2xlarge", driverSpace, 16, 8, 0);
+		assertInstanceInSearchSpace("m5.2xlarge", driverSpace, 32, 8, 0);
+		// assertions for executor space
+		InstanceSearchSpace executorSpace = gridBasedEnumerator.getExecutorSpace();
+		assertEquals(3, executorSpace.size());
+		assertInstanceInSearchSpace("c5.xlarge", executorSpace, 8, 4, 0);
+		assertInstanceInSearchSpace("m5.xlarge", executorSpace, 16, 4, 0);
+		assertInstanceInSearchSpace("c5.2xlarge", executorSpace, 16, 8, 0);
+		assertInstanceInSearchSpace("m5.2xlarge", executorSpace, 32, 8, 0);
+	}
+
+	@Test
+	public void preprocessingInterestBasedDriverMemoryTest() {
+		Enumerator interestBasedEnumerator = getInterestBasedEnumeratorPrebuild()
+				.withFitDriverMemory(true)
+				.withFitBroadcastMemory(false)
+				.build();
+
+		HashMap<String, CloudInstance> instances = TestingUtils.getSimpleCloudInstanceMap();
+		interestBasedEnumerator.setInstanceTable(instances);
+
+		// use a 10GB (scaled) memory estimate to be between the available 8GB and 16GB driver node memory
+		TreeSet<Long> mockingMemoryEstimates = new TreeSet<>(Set.of(GBtoBytes(10)));
+		try (MockedStatic<InterestBasedEnumerator> mockedEnumerator =
+				Mockito.mockStatic(InterestBasedEnumerator.class, Mockito.CALLS_REAL_METHODS)) {
+			mockedEnumerator
+					.when(() -> InterestBasedEnumerator.getMemoryEstimates(
+							any(Program.class),
+							eq(false),
+							eq(OptimizerUtils.MEM_UTIL_FACTOR)))
+					.thenReturn(mockingMemoryEstimates);
+			interestBasedEnumerator.preprocessing();
+		}
+
+		// assertions for driver space
+		InstanceSearchSpace driverSpace = interestBasedEnumerator.getDriverSpace();
+		assertEquals(2, driverSpace.size());
+		assertInstanceInSearchSpace("c5.xlarge", driverSpace, 8, 4, 0);
+		assertInstanceInSearchSpace("m5.xlarge", driverSpace, 16, 4, 0);
+		assertInstanceInSearchSpace("c5.2xlarge", driverSpace, 16, 8, 0);
+		Assert.assertNull(driverSpace.get(GBtoBytes(32)));
+		// assertions for executor space
+		InstanceSearchSpace executorSpace = interestBasedEnumerator.getExecutorSpace();
+		assertEquals(3, executorSpace.size());
+		assertInstanceInSearchSpace("c5.xlarge", executorSpace, 8, 4, 0);
+		assertInstanceInSearchSpace("m5.xlarge", executorSpace, 16, 4, 0);
+		assertInstanceInSearchSpace("c5.2xlarge", executorSpace, 16, 8, 0);
+		assertInstanceInSearchSpace("m5.2xlarge", executorSpace, 32, 8, 0);
+	}
+
+	@Test
+	public void preprocessingInterestBasedBroadcastMemoryTest() {
+		Enumerator interestBasedEnumerator = getInterestBasedEnumeratorPrebuild()
+				.withFitDriverMemory(false)
+				.withFitBroadcastMemory(true)
+				.build();
+
+		HashMap<String, CloudInstance> instances = TestingUtils.getSimpleCloudInstanceMap();
+		interestBasedEnumerator.setInstanceTable(instances);
+
+		double outputEstimate = 2.5;
+		double scaledOutputEstimateBroadcast = outputEstimate / InterestBasedEnumerator.BROADCAST_MEMORY_FACTOR; // ~=12
+		// scaledOutputEstimateCP = 2 * outputEstimate / OptimizerUtils.MEM_UTIL_FACTOR ~= 7
+		TreeSet<Long> mockingMemoryEstimates = new TreeSet<>(Set.of(GBtoBytes(scaledOutputEstimateBroadcast)));
+		try (MockedStatic<InterestBasedEnumerator> mockedEnumerator =
+				Mockito.mockStatic(InterestBasedEnumerator.class, Mockito.CALLS_REAL_METHODS)) {
+			mockedEnumerator
+					.when(() -> InterestBasedEnumerator.getMemoryEstimates(
+							any(Program.class),
+							eq(true),
+							eq(InterestBasedEnumerator.BROADCAST_MEMORY_FACTOR)))
+					.thenReturn(mockingMemoryEstimates);
+			interestBasedEnumerator.preprocessing();
+		}
+
+		// assertions for driver space
+		InstanceSearchSpace driverSpace = interestBasedEnumerator.getDriverSpace();
+		assertEquals(1, driverSpace.size());
+		assertInstanceInSearchSpace("c5.xlarge", driverSpace, 8, 4, 0);
+		Assert.assertNull(driverSpace.get(GBtoBytes(16)));
+		Assert.assertNull(driverSpace.get(GBtoBytes(32)));
+		// assertions for executor space
+		InstanceSearchSpace executorSpace = interestBasedEnumerator.getExecutorSpace();
+		assertEquals(2, executorSpace.size());
+		assertInstanceInSearchSpace("c5.xlarge", executorSpace, 8, 4, 0);
+		assertInstanceInSearchSpace("m5.xlarge", executorSpace, 16, 4, 0);
+		assertInstanceInSearchSpace("c5.2xlarge", executorSpace, 16, 8, 0);
+		Assert.assertNull(executorSpace.get(GBtoBytes(32)));
+	}
+
+ @Test
+ public void evaluateSingleNodeExecutionGridBasedTest() {
+ Enumerator gridBasedEnumerator;
+ boolean result;
+
+ gridBasedEnumerator = getGridBasedEnumeratorPrebuild()
+ .withNumberExecutorsRange(0,1)
+ .build();
+
+ // memory not relevant for grid-based enumerator
+ result = gridBasedEnumerator.evaluateSingleNodeExecution(-1);
+ Assert.assertTrue(result);
+
+ gridBasedEnumerator = getGridBasedEnumeratorPrebuild()
+ .withNumberExecutorsRange(1,2)
+ .build();
+
+ // memory not relevant for grid-based enumerator
+ result = gridBasedEnumerator.evaluateSingleNodeExecution(-1);
+ Assert.assertFalse(result);
+ }
+
+	@Test
+	public void estimateRangeExecutorsGridBasedStepSizeTest() {
+		Enumerator gridBasedEnumerator;
+		ArrayList<Integer> expectedResult;
+		ArrayList<Integer> actualResult;
+
+		// num. executors range starting from zero and step size = 2
+		gridBasedEnumerator = getGridBasedEnumeratorPrebuild()
+				.withNumberExecutorsRange(0, 10)
+				.withStepSizeExecutor(2)
+				.build();
+		// test the general case when the max level of parallelism is not reached (0 is never part of the result)
+		expectedResult = new ArrayList<>(List.of(2, 4, 6, 8, 10));
+		actualResult = gridBasedEnumerator.estimateRangeExecutors(-1, 4);
+		Assert.assertEquals(expectedResult, actualResult);
+		// test the case when the max level of parallelism (1000) is reached (0 is never part of the result)
+		expectedResult = new ArrayList<>(List.of(2, 4));
+		actualResult = gridBasedEnumerator.estimateRangeExecutors(-1, 200);
+		Assert.assertEquals(expectedResult, actualResult);
+
+		// num. executors range not starting from zero and without step size given
+		gridBasedEnumerator = getGridBasedEnumeratorPrebuild()
+				.withNumberExecutorsRange(3, 8)
+				.build();
+		// test the general case when the max level of parallelism is not reached (0 is never part of the result)
+		expectedResult = new ArrayList<>(List.of(3, 4, 5, 6, 7, 8));
+		actualResult = gridBasedEnumerator.estimateRangeExecutors(-1, 4);
+		Assert.assertEquals(expectedResult, actualResult);
+		// test the case when the max level of parallelism (1000) is reached (0 is never part of the result)
+		expectedResult = new ArrayList<>(List.of(3, 4, 5));
+		actualResult = gridBasedEnumerator.estimateRangeExecutors(-1, 200);
+		Assert.assertEquals(expectedResult, actualResult);
+	}
+
+	@Test
+	public void estimateRangeExecutorsGridBasedExpBaseTest() {
+		Enumerator gridBasedEnumerator;
+		ArrayList<Integer> expectedResult;
+		ArrayList<Integer> actualResult;
+
+		// num. executors range starting from zero and exponential base = 2
+		gridBasedEnumerator = getGridBasedEnumeratorPrebuild()
+				.withNumberExecutorsRange(0, 10)
+				.withExpBaseExecutors(2)
+				.build();
+		// test the general case when the max level of parallelism is not reached (0 is never part of the result)
+		expectedResult = new ArrayList<>(List.of(1, 2, 4, 8));
+		actualResult = gridBasedEnumerator.estimateRangeExecutors(-1, 4);
+		Assert.assertEquals(expectedResult, actualResult);
+		// test the case when the max level of parallelism (1000) is reached (0 is never part of the result)
+		expectedResult = new ArrayList<>(List.of(1, 2, 4));
+		actualResult = gridBasedEnumerator.estimateRangeExecutors(-1, 200);
+		Assert.assertEquals(expectedResult, actualResult);
+
+		// num. executors range not starting from zero and with exponential base = 3
+		gridBasedEnumerator = getGridBasedEnumeratorPrebuild()
+				.withNumberExecutorsRange(3, 30)
+				.withExpBaseExecutors(3)
+				.build();
+		// test the general case when the max level of parallelism is not reached (0 is never part of the result)
+		expectedResult = new ArrayList<>(List.of(3, 9, 27));
+		actualResult = gridBasedEnumerator.estimateRangeExecutors(-1, 4);
+		Assert.assertEquals(expectedResult, actualResult);
+		// test the case when the max level of parallelism (1000) is reached (0 is never part of the result)
+		expectedResult = new ArrayList<>(List.of(3, 9));
+		actualResult = gridBasedEnumerator.estimateRangeExecutors(-1, 100);
+		Assert.assertEquals(expectedResult, actualResult);
+	}
+
+	@Test
+	public void evaluateSingleNodeExecutionInterestBasedTest() {
+		boolean result;
+
+		// no fitting of the memory estimates for checkpointing
+		Enumerator interestBasedEnumerator = getInterestBasedEnumeratorPrebuild()
+				.withNumberExecutorsRange(0, 5)
+				.withFitDriverMemory(false)
+				.withFitBroadcastMemory(false)
+				.withCheckSingleNodeExecution(true)
+				.build();
+
+		HashMap<String, CloudInstance> instances = TestingUtils.getSimpleCloudInstanceMap();
+		interestBasedEnumerator.setInstanceTable(instances);
+
+		TreeSet<Long> mockingMemoryEstimates = new TreeSet<>(Set.of(GBtoBytes(6), GBtoBytes(12)));
+		try (MockedStatic<InterestBasedEnumerator> mockedEnumerator =
+				Mockito.mockStatic(InterestBasedEnumerator.class, Mockito.CALLS_REAL_METHODS)) {
+			mockedEnumerator
+					.when(() -> InterestBasedEnumerator.getMemoryEstimates(
+							any(Program.class),
+							eq(false),
+							eq(OptimizerUtils.MEM_UTIL_FACTOR)))
+					.thenReturn(mockingMemoryEstimates);
+			// initiates largestMemoryEstimateCP
+			interestBasedEnumerator.preprocessing();
+		}
+
+		result = interestBasedEnumerator.evaluateSingleNodeExecution(GBtoBytes(8));
+		Assert.assertFalse(result);
+	}
+
+	@Test
+	public void estimateRangeExecutorsInterestBasedGeneralTest() {
+		ArrayList<Integer> expectedResult;
+		ArrayList<Integer> actualResult;
+
+		// no fitting of the memory estimates for checkpointing
+		Enumerator interestBasedEnumerator = getInterestBasedEnumeratorPrebuild()
+				.withNumberExecutorsRange(0, 5)
+				.build();
+		// test the general case when the max level of parallelism is not reached (0 is never part of the result)
+		expectedResult = new ArrayList<>(List.of(1, 2, 3, 4, 5));
+		actualResult = interestBasedEnumerator.estimateRangeExecutors(-1, 4);
+		Assert.assertEquals(expectedResult, actualResult);
+		// test the case when the max level of parallelism (1000) is reached (0 is never part of the result)
+		expectedResult = new ArrayList<>(List.of(1, 2, 3));
+		actualResult = interestBasedEnumerator.estimateRangeExecutors(-1, 256);
+		Assert.assertEquals(expectedResult, actualResult);
+	}
+
+	@Test
+	public void estimateRangeExecutorsInterestBasedCheckpointMemoryTest() {
+		ArrayList<Integer> expectedResult;
+		ArrayList<Integer> actualResult;
+
+		// fitting the memory estimates for checkpointing
+		Enumerator interestBasedEnumerator = getInterestBasedEnumeratorPrebuild()
+				.withNumberExecutorsRange(0, 5)
+				.withFitCheckpointMemory(true)
+				.withFitDriverMemory(false)
+				.withFitBroadcastMemory(false)
+				.build();
+
+		HashMap<String, CloudInstance> instances = TestingUtils.getSimpleCloudInstanceMap();
+		interestBasedEnumerator.setInstanceTable(instances);
+
+		TreeSet<Long> mockingMemoryEstimates = new TreeSet<>(Set.of(GBtoBytes(20), GBtoBytes(40)));
+		try (MockedStatic<InterestBasedEnumerator> mockedEnumerator =
+				Mockito.mockStatic(InterestBasedEnumerator.class, Mockito.CALLS_REAL_METHODS)) {
+			mockedEnumerator
+					.when(() -> InterestBasedEnumerator.getMemoryEstimates(
+							any(Program.class),
+							eq(true),
+							eq(InterestBasedEnumerator.BROADCAST_MEMORY_FACTOR)))
+					.thenReturn(mockingMemoryEstimates);
+			// initiates memoryEstimatesSpark
+			interestBasedEnumerator.preprocessing();
+		}
+
+		// test the general case when the max level of parallelism is not reached (0 is never part of the result)
+		expectedResult = new ArrayList<>(List.of(1, 2, 3));
+		actualResult = interestBasedEnumerator.estimateRangeExecutors(GBtoBytes(16), 4);
+		Assert.assertEquals(expectedResult, actualResult);
+		// test the case when the max level of parallelism (1000) is reached (0 is never part of the result)
+		expectedResult = new ArrayList<>(List.of(1, 2));
+		actualResult = interestBasedEnumerator.estimateRangeExecutors(GBtoBytes(16), 500);
+		Assert.assertEquals(expectedResult, actualResult);
+	}
+
+	@Test
+	public void processingTest() {
+		// all implemented enumerators should enumerate the same solution pool in this basic case - empty program
+		Enumerator gridBasedEnumerator = getGridBasedEnumeratorPrebuild()
+				.withTimeLimit(Double.MAX_VALUE)
+				.withNumberExecutorsRange(0, 2)
+				.build();
+
+		Enumerator interestBasedEnumerator = getInterestBasedEnumeratorPrebuild()
+				.withNumberExecutorsRange(0, 2)
+				.build();
+
+		HashMap<String, CloudInstance> instances = TestingUtils.getSimpleCloudInstanceMap();
+		InstanceSearchSpace space = new InstanceSearchSpace();
+		space.initSpace(instances);
+
+		// run processing for the grid-based enumerator
+		gridBasedEnumerator.setDriverSpace(space);
+		gridBasedEnumerator.setExecutorSpace(space);
+		gridBasedEnumerator.processing();
+		ArrayList<SolutionPoint> actualSolutionPoolGB = gridBasedEnumerator.getSolutionPool();
+		// run processing for the interest-based enumerator
+		interestBasedEnumerator.setDriverSpace(space);
+		interestBasedEnumerator.setExecutorSpace(space);
+		interestBasedEnumerator.processing();
+		ArrayList<SolutionPoint> actualSolutionPoolIB = interestBasedEnumerator.getSolutionPool();
+
+		ArrayList<CloudInstance> expectedInstances = new ArrayList<>(Arrays.asList(
+				instances.get("c5.xlarge"),
+				instances.get("m5.xlarge")
+		));
+		// expected solution pool with 0 executors (number executors = 0, executors and executorInstance being null)
+		// each solution having one of the available instances as driver node
+		Assert.assertEquals(expectedInstances.size(), actualSolutionPoolGB.size());
+		Assert.assertEquals(expectedInstances.size(), actualSolutionPoolIB.size());
+		for (int i = 0; i < expectedInstances.size(); i++) {
+			SolutionPoint pointGB = actualSolutionPoolGB.get(i);
+			Assert.assertEquals(0, pointGB.numberExecutors);
+			Assert.assertEquals(expectedInstances.get(i), pointGB.driverInstance);
+			Assert.assertNull(pointGB.executorInstance);
+			SolutionPoint pointIB = actualSolutionPoolIB.get(i);
+			Assert.assertEquals(0, pointIB.numberExecutors);
+			Assert.assertEquals(expectedInstances.get(i), pointIB.driverInstance);
+			Assert.assertNull(pointIB.executorInstance);
+		}
+	}
+
+	@Test
+	public void postprocessingTest() {
+		// postprocessing is equivalent for all types of enumerators
+		Enumerator enumerator = getGridBasedEnumeratorPrebuild().build();
+		// construct a solution pool with a dummy configuration point (not relevant for postprocessing)
+		ConfigurationPoint dummyPoint = new ConfigurationPoint(null);
+		SolutionPoint solution1 = new SolutionPoint(dummyPoint, 1000, 1000);
+		SolutionPoint solution2 = new SolutionPoint(dummyPoint, 900, 1000); // optimal point
+		SolutionPoint solution3 = new SolutionPoint(dummyPoint, 800, 10000);
+		SolutionPoint solution4 = new SolutionPoint(dummyPoint, 1000, 10000);
+		SolutionPoint solution5 = new SolutionPoint(dummyPoint, 900, 10000);
+		ArrayList<SolutionPoint> mockListSolutions = new ArrayList<>(List.of(solution1, solution2, solution3, solution4, solution5));
+		enumerator.setSolutionPool(mockListSolutions);
+
+		SolutionPoint optimalSolution = enumerator.postprocessing();
+		assertEquals(solution2, optimalSolution);
+	}
+
+	@Test
+	public void GridBasedEnumerationMinPriceTest() {
+		Enumerator gridBasedEnumerator = getGridBasedEnumeratorPrebuild()
+				.withNumberExecutorsRange(0, 2)
+				.build();
+
+		gridBasedEnumerator.setInstanceTable(TestingUtils.getSimpleCloudInstanceMap());
+
+		gridBasedEnumerator.preprocessing();
+		gridBasedEnumerator.processing();
+		SolutionPoint solution = gridBasedEnumerator.postprocessing();
+
+		// expected m5.xlarge since it is the cheapest instance
+		Assert.assertEquals("m5.xlarge", solution.driverInstance.getInstanceName());
+		// expected no executor nodes since tested on an empty program
+		Assert.assertEquals(0, solution.numberExecutors);
+	}
+
+	@Test
+	public void InterestBasedEnumerationMinPriceTest() {
+		Enumerator interestBasedEnumerator = getInterestBasedEnumeratorPrebuild()
+				.withNumberExecutorsRange(0, 2)
+				.build();
+
+		interestBasedEnumerator.setInstanceTable(TestingUtils.getSimpleCloudInstanceMap());
+
+		interestBasedEnumerator.preprocessing();
+		interestBasedEnumerator.processing();
+		SolutionPoint solution = interestBasedEnumerator.postprocessing();
+
+		// expected c5.xlarge since it is the instance with the least memory
+		Assert.assertEquals("c5.xlarge", solution.driverInstance.getInstanceName());
+		// expected no executor nodes since tested on an empty program
+		Assert.assertEquals(0, solution.numberExecutors);
+	}
+
+	@Test
+	public void GridBasedEnumerationMinTimeTest() {
+		Enumerator gridBasedEnumerator = getGridBasedEnumeratorPrebuild()
+				.withOptimizationStrategy(Enumerator.OptimizationStrategy.MinTime)
+				.withBudget(Double.MAX_VALUE)
+				.withNumberExecutorsRange(0, 2)
+				.build();
+
+		gridBasedEnumerator.setInstanceTable(TestingUtils.getSimpleCloudInstanceMap());
+
+		gridBasedEnumerator.preprocessing();
+		gridBasedEnumerator.processing();
+		SolutionPoint solution = gridBasedEnumerator.postprocessing();
+
+		// expected m5.xlarge since it is the cheapest instance
+		Assert.assertEquals("m5.xlarge", solution.driverInstance.getInstanceName());
+		// expected no executor nodes since tested on an empty program
+		Assert.assertEquals(0, solution.numberExecutors);
+	}
+
+	@Test
+	public void InterestBasedEnumerationMinTimeTest() {
+		Enumerator interestBasedEnumerator = getInterestBasedEnumeratorPrebuild()
+				.withOptimizationStrategy(Enumerator.OptimizationStrategy.MinTime)
+				.withBudget(Double.MAX_VALUE)
+				.withNumberExecutorsRange(0, 2)
+				.build();
+
+		interestBasedEnumerator.setInstanceTable(TestingUtils.getSimpleCloudInstanceMap());
+
+		interestBasedEnumerator.preprocessing();
+		interestBasedEnumerator.processing();
+		SolutionPoint solution = interestBasedEnumerator.postprocessing();
+
+		// expected c5.xlarge since it is the instance with the least memory
+		Assert.assertEquals("c5.xlarge", solution.driverInstance.getInstanceName());
+		// expected no executor nodes since tested on an empty program
+		Assert.assertEquals(0, solution.numberExecutors);
+	}
+
+	// Helpers
+	private static Enumerator.Builder getGridBasedEnumeratorPrebuild() {
+		Program emptyProgram = new Program();
+		return (new Enumerator.Builder())
+				.withRuntimeProgram(emptyProgram)
+				.withEnumerationStrategy(Enumerator.EnumerationStrategy.GridBased)
+				.withOptimizationStrategy(Enumerator.OptimizationStrategy.MinPrice)
+				.withTimeLimit(Double.MAX_VALUE);
+	}
+
+	private static Enumerator.Builder getInterestBasedEnumeratorPrebuild() {
+		Program emptyProgram = new Program();
+		return (new Enumerator.Builder())
+				.withRuntimeProgram(emptyProgram)
+				.withEnumerationStrategy(Enumerator.EnumerationStrategy.InterestBased)
+				.withOptimizationStrategy(Enumerator.OptimizationStrategy.MinPrice)
+				.withTimeLimit(Double.MAX_VALUE);
+	}
+
+	private static void assertInstanceInSearchSpace(
+			String expectedName,
+			InstanceSearchSpace searchSpace,
+			int memory, /* in GB */
+			int cores,
+			int index
+	) {
+		Assert.assertNotNull(searchSpace.get(GBtoBytes(memory)));
+		try {
+			String actualName = searchSpace.get(GBtoBytes(memory)).get(cores).get(index).getInstanceName();
+			Assert.assertEquals(expectedName, actualName);
+		} catch (NullPointerException e) {
+			fail(expectedName+" instances not properly passed to "+searchSpace.getClass().getName());
+		}
+	}
+}
+}
diff --git a/src/test/java/org/apache/sysds/test/component/resource/RecompilationTest.java b/src/test/java/org/apache/sysds/test/component/resource/RecompilationTest.java
new file mode 100644
index 0000000000..b2631684c2
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/component/resource/RecompilationTest.java
@@ -0,0 +1,258 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.component.resource;
+
+import org.apache.sysds.resource.ResourceCompiler;
+import org.apache.sysds.runtime.controlprogram.BasicProgramBlock;
+import org.apache.sysds.runtime.controlprogram.Program;
+import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext;
+import org.apache.sysds.runtime.controlprogram.parfor.stat.InfrastructureAnalyzer;
+import org.apache.sysds.runtime.instructions.Instruction;
+import org.apache.sysds.runtime.instructions.spark.SPInstruction;
+import org.apache.sysds.test.AutomatedTestBase;
+import org.apache.sysds.utils.Explain;
+
+import org.junit.Assert;
+import org.junit.Test;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Optional;
+
+import static org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext.SparkClusterConfig.RESERVED_SYSTEM_MEMORY_BYTES;
+
+public class RecompilationTest extends AutomatedTestBase {
+ private static final boolean DEBUG_MODE = true;
+ private static final String TEST_DIR = "component/resource/";
+ private static final String TEST_DATA_DIR = "component/resource/data/";
+ private static final String HOME = SCRIPT_DIR + TEST_DIR;
+ private static final String HOME_DATA = SCRIPT_DIR + TEST_DATA_DIR;
+	// Static configuration values -------------------------------------------------------------------------------------
+ private static final int driverThreads = 4;
+ private static final int executorThreads = 2;
+
+ @Override
+ public void setUp() {}
+
+	// Tests for setting cluster configurations ------------------------------------------------------------------------
+
+	@Test
+	public void testSetDriverConfigurations() {
+		long expectedMemory = 1024*1024*1024; // 1GB
+		int expectedThreads = 4;
+
+		ResourceCompiler.setDriverConfigurations(expectedMemory, expectedThreads);
+
+		Assert.assertEquals(expectedMemory, InfrastructureAnalyzer.getLocalMaxMemory());
+		Assert.assertEquals(expectedThreads, InfrastructureAnalyzer.getLocalParallelism());
+	}
+
+	@Test
+	public void testSetExecutorConfigurations() {
+		int numberExecutors = 10;
+		long executorMemory = 1024*1024*1024; // 1GB
+		long expectedMemoryBudget = (long) (numberExecutors*(executorMemory-RESERVED_SYSTEM_MEMORY_BYTES)*0.6);
+		int executorThreads = 4;
+		int expectedParallelism = numberExecutors*executorThreads;
+
+		ResourceCompiler.setExecutorConfigurations(numberExecutors, executorMemory, executorThreads);
+
+		Assert.assertEquals(numberExecutors, SparkExecutionContext.getNumExecutors());
+		Assert.assertEquals(expectedMemoryBudget, (long) SparkExecutionContext.getDataMemoryBudget(false, false));
+		Assert.assertEquals(expectedParallelism, SparkExecutionContext.getDefaultParallelism(false));
+	}
+
+	// Tests for regular matrix multiplication (X%*%Y) -----------------------------------------------------------------
+
+	@Test
+	public void test_CP_MM_Enforced() throws IOException {
+		// Single node cluster with 8GB driver memory -> ba+* operator
+		// X = A.csv: (10^5)x(10^4) = 10^9 ~ 8GB
+		// Y = B.csv: (10^4)x(10^3) = 10^7 ~ 80MB
+		// X %*% Y -> (10^5)x(10^3) = 10^8 ~ 800MB
+		runTestMM("A.csv", "B.csv", 8L*1024*1024*1024, 0, -1, "ba+*");
+	}
+
+	@Test
+	public void test_CP_MM_Preferred() throws IOException {
+		// Distributed cluster with 16GB driver memory (large enough to host the computation) and any executors
+		// X = A.csv: (10^5)x(10^4) = 10^9 ~ 8GB
+		// Y = B.csv: (10^4)x(10^3) = 10^7 ~ 80MB
+		// X %*% Y -> (10^5)x(10^3) = 10^8 ~ 800MB
+		runTestMM("A.csv", "B.csv", 16L*1024*1024*1024, 2, 1024*1024*1024, "ba+*");
+	}
+
+	@Test
+	public void test_SP_MAPMM() throws IOException {
+		// Distributed cluster with 4GB driver memory and 4GB executors -> mapmm operator
+		// X = A.csv: (10^5)x(10^4) = 10^9 ~ 8GB
+		// Y = B.csv: (10^4)x(10^3) = 10^7 ~ 80MB
+		// X %*% Y -> (10^5)x(10^3) = 10^8 ~ 800MB
+		runTestMM("A.csv", "B.csv", 4L*1024*1024*1024, 2, 4L*1024*1024*1024, "mapmm");
+	}
+
+	@Test
+	public void test_SP_RMM() throws IOException {
+		// Distributed cluster with 1GB driver memory and 500MB executors -> rmm operator
+		// X = A.csv: (10^5)x(10^4) = 10^9 ~ 8GB
+		// Y = B.csv: (10^4)x(10^3) = 10^7 ~ 80MB
+		// X %*% Y -> (10^5)x(10^3) = 10^8 ~ 800MB
+		runTestMM("A.csv", "B.csv", 1024*1024*1024, 2, (long) (0.5*1024*1024*1024), "rmm");
+	}
+
+	@Test
+	public void test_SP_CPMM() throws IOException {
+		// Distributed cluster with 8GB driver memory and 4GB executors -> cpmm operator
+		// X = A.csv: (10^5)x(10^4) = 10^9 ~ 8GB
+		// Y = C.csv: (10^4)x(10^4) = 10^8 ~ 800MB
+		// X %*% Y -> (10^5)x(10^4) = 10^9 ~ 8GB
+		runTestMM("A.csv", "C.csv", 8L*1024*1024*1024, 2, 4L*1024*1024*1024, "cpmm");
+	}
+
+	// Tests for transposed self matrix multiplication (t(X)%*%X) ------------------------------------------------------
+
+	@Test
+	public void test_CP_TSMM() throws IOException {
+		// Single node cluster with 8GB driver memory -> tsmm operator in CP
+		// X = B.csv: (10^4)x(10^3) = 10^7 ~ 80MB
+		// t(X) %*% X -> (10^3)x(10^3) = 10^6 ~ 8MB (single block)
+		runTestTSMM("B.csv", 8L*1024*1024*1024, 0, -1, "tsmm", false);
+	}
+
+	@Test
+	public void test_SP_TSMM() throws IOException {
+		// Distributed cluster with 1GB driver memory and 8GB executor memory -> tsmm operator in Spark
+		// X = D.csv: (10^5)x(10^3) = 10^8 ~ 800MB
+		// t(X) %*% X -> (10^3)x(10^3) = 10^6 ~ 8MB (single block)
+		runTestTSMM("D.csv", 1024*1024*1024, 2, 8L*1024*1024*1024, "tsmm", true);
+	}
+
+	@Test
+	public void test_SP_TSMM_as_CPMM() throws IOException {
+		// Distributed cluster with 8GB driver memory and 8GB executor memory -> cpmm operator in Spark
+		// X = A.csv: (10^5)x(10^4) = 10^9 ~ 8GB
+		// t(X) %*% X -> (10^4)x(10^4) = 10^8 ~ 800MB
+		runTestTSMM("A.csv", 8L*1024*1024*1024, 2, 8L*1024*1024*1024, "cpmm", true);
+	}
+
+	@Test
+	public void test_MM_RecompilationSequence() throws IOException {
+		Map<String, String> nvargs = new HashMap<>();
+		nvargs.put("$X", HOME_DATA+"A.csv");
+		nvargs.put("$Y", HOME_DATA+"B.csv");
+
+		// pre-compiled program using default values, used as source for the recompilation
+		Program precompiledProgram = generateInitialProgram(HOME+"mm_test.dml", nvargs);
+		// original compilation used for comparison
+		Program expectedProgram;
+
+		ResourceCompiler.setDriverConfigurations(8L*1024*1024*1024, driverThreads);
+		ResourceCompiler.setSingleNodeExecution();
+		expectedProgram = ResourceCompiler.compile(HOME+"mm_test.dml", nvargs);
+		runTest(precompiledProgram, expectedProgram, 8L*1024*1024*1024, 0, -1, "ba+*", false);
+
+		ResourceCompiler.setDriverConfigurations(16L*1024*1024*1024, driverThreads);
+		ResourceCompiler.setExecutorConfigurations(2, 1024*1024*1024, executorThreads);
+		expectedProgram = ResourceCompiler.compile(HOME+"mm_test.dml", nvargs);
+		runTest(precompiledProgram, expectedProgram, 16L*1024*1024*1024, 2, 1024*1024*1024, "ba+*", false);
+
+		ResourceCompiler.setDriverConfigurations(4L*1024*1024*1024, driverThreads);
+		ResourceCompiler.setExecutorConfigurations(2, 4L*1024*1024*1024, executorThreads);
+		expectedProgram = ResourceCompiler.compile(HOME+"mm_test.dml", nvargs);
+		runTest(precompiledProgram, expectedProgram, 4L*1024*1024*1024, 2, 4L*1024*1024*1024, "mapmm", true);
+
+		ResourceCompiler.setDriverConfigurations(1024*1024*1024, driverThreads);
+		ResourceCompiler.setExecutorConfigurations(2, (long) (0.5*1024*1024*1024), executorThreads);
+		expectedProgram = ResourceCompiler.compile(HOME+"mm_test.dml", nvargs);
+		runTest(precompiledProgram, expectedProgram, 1024*1024*1024, 2, (long) (0.5*1024*1024*1024), "rmm", true);
+
+		ResourceCompiler.setDriverConfigurations(8L*1024*1024*1024, driverThreads);
+		ResourceCompiler.setSingleNodeExecution();
+		expectedProgram = ResourceCompiler.compile(HOME+"mm_test.dml", nvargs);
+		runTest(precompiledProgram, expectedProgram, 8L*1024*1024*1024, 0, -1, "ba+*", false);
+	}
+
+	// Helper functions ------------------------------------------------------------------------------------------------
+	private Program generateInitialProgram(String filePath, Map<String, String> args) throws IOException {
+		ResourceCompiler.setDriverConfigurations(ResourceCompiler.DEFAULT_DRIVER_MEMORY, ResourceCompiler.DEFAULT_DRIVER_THREADS);
+		ResourceCompiler.setExecutorConfigurations(ResourceCompiler.DEFAULT_NUMBER_EXECUTORS, ResourceCompiler.DEFAULT_EXECUTOR_MEMORY, ResourceCompiler.DEFAULT_EXECUTOR_THREADS);
+		return ResourceCompiler.compile(filePath, args);
+	}
+
+	private void runTestMM(String fileX, String fileY, long driverMemory, int numberExecutors, long executorMemory, String expectedOpcode) throws IOException {
+		boolean expectedSparkExecType = !Objects.equals(expectedOpcode, "ba+*");
+		Map<String, String> nvargs = new HashMap<>();
+		nvargs.put("$X", HOME_DATA+fileX);
+		nvargs.put("$Y", HOME_DATA+fileY);
+
+		// pre-compiled program using default values, used as source for the recompilation
+		Program precompiledProgram = generateInitialProgram(HOME+"mm_test.dml", nvargs);
+
+		ResourceCompiler.setDriverConfigurations(driverMemory, driverThreads);
+		if (numberExecutors > 0) {
+			ResourceCompiler.setExecutorConfigurations(numberExecutors, executorMemory, executorThreads);
+		} else {
+			ResourceCompiler.setSingleNodeExecution();
+		}
+
+		// original compilation used for comparison
+		Program expectedProgram = ResourceCompiler.compile(HOME+"mm_test.dml", nvargs);
+		runTest(precompiledProgram, expectedProgram, driverMemory, numberExecutors, executorMemory, expectedOpcode, expectedSparkExecType);
+	}
+
+	private void runTestTSMM(String fileX, long driverMemory, int numberExecutors, long executorMemory, String expectedOpcode, boolean expectedSparkExecType) throws IOException {
+		Map<String, String> nvargs = new HashMap<>();
+		nvargs.put("$X", HOME_DATA+fileX);
+
+		// pre-compiled program using default values, used as source for the recompilation
+		Program precompiledProgram = generateInitialProgram(HOME+"mm_transpose_test.dml", nvargs);
+
+		ResourceCompiler.setDriverConfigurations(driverMemory, driverThreads);
+		if (numberExecutors > 0) {
+			ResourceCompiler.setExecutorConfigurations(numberExecutors, executorMemory, executorThreads);
+		} else {
+			ResourceCompiler.setSingleNodeExecution();
+		}
+		// original compilation used for comparison
+		Program expectedProgram = ResourceCompiler.compile(HOME+"mm_transpose_test.dml", nvargs);
+		runTest(precompiledProgram, expectedProgram, driverMemory, numberExecutors, executorMemory, expectedOpcode, expectedSparkExecType);
+	}
+
+	private void runTest(Program precompiledProgram, Program expectedProgram, long driverMemory, int numberExecutors, long executorMemory, String expectedOpcode, boolean expectedSparkExecType) {
+		String expectedProgramExplained = Explain.explain(expectedProgram);
+
+		Program recompiledProgram;
+		if (numberExecutors == 0) {
+			recompiledProgram = ResourceCompiler.doFullRecompilation(precompiledProgram, driverMemory, driverThreads);
+		} else {
+			recompiledProgram = ResourceCompiler.doFullRecompilation(precompiledProgram, driverMemory, driverThreads, numberExecutors, executorMemory, executorThreads);
+		}
+		String actualProgramExplained = Explain.explain(recompiledProgram);
+
+		if (DEBUG_MODE) System.out.println(actualProgramExplained);
+		Assert.assertEquals(expectedProgramExplained, actualProgramExplained);
+		// assert that the recompiled program contains the expected opcode with the expected execution type
+		Optional<Instruction> mmInstruction = ((BasicProgramBlock) recompiledProgram.getProgramBlocks().get(0)).getInstructions().stream()
+			.filter(inst -> Objects.equals(expectedSparkExecType, inst instanceof SPInstruction)
+				&& Objects.equals(inst.getOpcode(), expectedOpcode))
+			.findFirst();
+		Assert.assertTrue(mmInstruction.isPresent());
+	}
+}
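Note: a concrete test case built on the helpers above could look as follows. This is a minimal sketch for illustration only; the memory budgets, the static GBtoBytes import, and the expected Spark opcode "mapmm" are assumptions, not part of this commit.

	// With a 4GB driver and two 8GB executors, the ~8GB dense matrix A
	// should exceed the CP memory budget, so the recompiled plan is
	// expected to contain the broadcast-based Spark multiply ("mapmm")
	// instead of the single-node "ba+*".
	@Test
	public void testRecompilationCPtoSpark() throws IOException {
		runTestMM("A.csv", "B.csv", GBtoBytes(4), 2, GBtoBytes(8), "mapmm");
	}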
diff --git a/src/test/java/org/apache/sysds/test/component/resource/TestingUtils.java b/src/test/java/org/apache/sysds/test/component/resource/TestingUtils.java
new file mode 100644
index 0000000000..035fac6ab1
--- /dev/null
+++ b/src/test/java/org/apache/sysds/test/component/resource/TestingUtils.java
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.test.component.resource;
+
+import org.apache.sysds.resource.CloudInstance;
+import org.junit.Assert;
+
+import java.io.File;
+import java.io.IOException;
+import java.nio.file.Files;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+
+import static org.apache.sysds.resource.CloudUtils.GBtoBytes;
+
+public class TestingUtils {
+	public static void assertEqualsCloudInstances(CloudInstance expected, CloudInstance actual) {
+		Assert.assertEquals(expected.getInstanceName(), actual.getInstanceName());
+		Assert.assertEquals(expected.getMemory(), actual.getMemory());
+		Assert.assertEquals(expected.getVCPUs(), actual.getVCPUs());
+		Assert.assertEquals(expected.getFLOPS(), actual.getFLOPS());
+		Assert.assertEquals(expected.getMemorySpeed(), actual.getMemorySpeed(), 0.0);
+		Assert.assertEquals(expected.getDiskSpeed(), actual.getDiskSpeed(), 0.0);
+		Assert.assertEquals(expected.getNetworkSpeed(), actual.getNetworkSpeed(), 0.0);
+		Assert.assertEquals(expected.getPrice(), actual.getPrice(), 0.0);
+	}
+
+	public static HashMap<String, CloudInstance> getSimpleCloudInstanceMap() {
+		HashMap<String, CloudInstance> instanceMap = new HashMap<>();
+		// fill the map with enough cloud instances to allow testing iterations over all search space dimensions
+		instanceMap.put("m5.xlarge", new CloudInstance("m5.xlarge", GBtoBytes(16), 4, 0.5, 0.0, 143.75, 160, 1.5));
+		instanceMap.put("m5.2xlarge", new CloudInstance("m5.2xlarge", GBtoBytes(32), 8, 1.0, 0.0, 0.0, 0.0, 1.9));
+		instanceMap.put("c5.xlarge", new CloudInstance("c5.xlarge", GBtoBytes(8), 4, 0.5, 0.0, 0.0, 0.0, 1.7));
+		instanceMap.put("c5.2xlarge", new CloudInstance("c5.2xlarge", GBtoBytes(16), 8, 1.0, 0.0, 0.0, 0.0, 2.1));
+
+		return instanceMap;
+	}
+
+	public static File generateTmpInstanceInfoTableFile() throws IOException {
+		File tmpFile = File.createTempFile("systemds_tmp", ".csv");
+
+		List<String> csvLines = Arrays.asList(
+			"API_Name,Memory,vCPUs,gFlops,ramSpeed,diskSpeed,networkSpeed,Price",
+			"m5.xlarge,16.0,4,0.5,0,143.75,160,1.5",
+			"m5.2xlarge,32.0,8,1.0,0,0,0,1.9",
+			"c5.xlarge,8.0,4,0.5,0,0,0,1.7",
+			"c5.2xlarge,16.0,8,1.0,0,0,0,2.1"
+		);
+		Files.write(tmpFile.toPath(), csvLines);
+		return tmpFile;
+	}
+}
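For reference, the two helpers are designed to be used together: generateTmpInstanceInfoTableFile writes the same four instances that getSimpleCloudInstanceMap constructs, so a parser test can round-trip between them. A minimal sketch with inline CSV parsing follows; it assumes GBtoBytes accepts a double and stands in for whatever CloudUtils loader the production code actually uses.

	// parse the generated CSV table and compare against the hard-coded map
	File table = TestingUtils.generateTmpInstanceInfoTableFile();
	HashMap<String, CloudInstance> expected = TestingUtils.getSimpleCloudInstanceMap();
	List<String> lines = Files.readAllLines(table.toPath());
	for (String line : lines.subList(1, lines.size())) { // skip the header row
		String[] f = line.split(",");
		CloudInstance parsed = new CloudInstance(f[0],
			GBtoBytes(Double.parseDouble(f[1])), Integer.parseInt(f[2]),
			Double.parseDouble(f[3]), Double.parseDouble(f[4]),
			Double.parseDouble(f[5]), Double.parseDouble(f[6]),
			Double.parseDouble(f[7]));
		TestingUtils.assertEqualsCloudInstances(expected.get(f[0]), parsed);
	}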
diff --git a/src/test/scripts/component/resource/data/A.csv b/src/test/scripts/component/resource/data/A.csv
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/src/test/scripts/component/resource/data/A.csv.mtd b/src/test/scripts/component/resource/data/A.csv.mtd
new file mode 100644
index 0000000000..9e43aec450
--- /dev/null
+++ b/src/test/scripts/component/resource/data/A.csv.mtd
@@ -0,0 +1,10 @@
+{
+ "data_type": "matrix",
+ "value_type": "double",
+ "rows": 100000,
+ "cols": 10000,
+ "nnz": 1000000000,
+ "format": "csv",
+ "header": false,
+ "sep": ","
+}
\ No newline at end of file
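(A.csv.mtd describes a fully dense matrix: sparsity = nnz / (rows * cols) = 10^9 / (10^5 * 10^4) = 1.0, so its dense in-memory size is roughly 10^5 * 10^4 * 8 bytes = 8 GB of doubles, ignoring block overhead. This is large enough to push operations on A out of the CP memory budget under small driver configurations.)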
diff --git a/src/test/scripts/component/resource/data/B.csv b/src/test/scripts/component/resource/data/B.csv
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/src/test/scripts/component/resource/data/B.csv.mtd b/src/test/scripts/component/resource/data/B.csv.mtd
new file mode 100644
index 0000000000..db7bd64569
--- /dev/null
+++ b/src/test/scripts/component/resource/data/B.csv.mtd
@@ -0,0 +1,10 @@
+{
+ "data_type": "matrix",
+ "value_type": "double",
+ "rows": 10000,
+ "cols": 1000,
+ "nnz": 10000000,
+ "format": "csv",
+ "header": false,
+ "sep": ","
+}
\ No newline at end of file
diff --git a/src/test/scripts/component/resource/data/C.csv b/src/test/scripts/component/resource/data/C.csv
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/src/test/scripts/component/resource/data/C.csv.mtd b/src/test/scripts/component/resource/data/C.csv.mtd
new file mode 100644
index 0000000000..6b218bdeca
--- /dev/null
+++ b/src/test/scripts/component/resource/data/C.csv.mtd
@@ -0,0 +1,10 @@
+{
+ "data_type": "matrix",
+ "value_type": "double",
+ "rows": 10000,
+ "cols": 10000,
+ "nnz": 100000000,
+ "format": "csv",
+ "header": false,
+ "sep": ","
+}
\ No newline at end of file
diff --git a/src/test/scripts/component/resource/data/D.csv b/src/test/scripts/component/resource/data/D.csv
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/src/test/scripts/component/resource/data/D.csv.mtd b/src/test/scripts/component/resource/data/D.csv.mtd
new file mode 100644
index 0000000000..8552d1cdc0
--- /dev/null
+++ b/src/test/scripts/component/resource/data/D.csv.mtd
@@ -0,0 +1,10 @@
+{
+ "data_type": "matrix",
+ "value_type": "double",
+ "rows": 100000,
+ "cols": 1000,
+ "nnz": 100000000,
+ "format": "csv",
+ "header": false,
+ "sep": ","
+}
\ No newline at end of file
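(Summary of the four dense test inputs and their approximate in-memory sizes: A 100000 x 10000, ~8 GB; B 10000 x 1000, ~80 MB; C 10000 x 10000, ~800 MB; D 100000 x 1000, ~800 MB. For mm_test.dml, e.g., X=A.csv conforms with Y=B.csv, since ncol(A) = nrow(B) = 10000; the tall-skinny D is shaped for the t(X)%*%X pattern in mm_transpose_test.dml below.)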
diff --git a/src/test/scripts/component/resource/mm_test.dml b/src/test/scripts/component/resource/mm_test.dml
new file mode 100644
index 0000000000..88bcfc2f2a
--- /dev/null
+++ b/src/test/scripts/component/resource/mm_test.dml
@@ -0,0 +1,34 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+fileX = $X;
+fileY = $Y;
+# R - virtual result
+fileR = "R.csv";
+fmtR = "csv";
+
+X = read(fileX);
+Y = read(fileY);
+
+R = X%*%Y;
+
+# trigger full calculation
+write(R, fileR, fmtR);
\ No newline at end of file
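(Rough arithmetic for this script with X=A.csv and Y=B.csv: the inputs occupy ~8 GB + ~80 MB, and the 100000 x 1000 result adds another ~0.8 GB, so a CP ba+* plan needs on the order of 9 GB of driver memory. Below that, the recompiler has to place the multiply on Spark, where, e.g., mapmm can broadcast the small 80 MB operand.)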
diff --git a/src/test/scripts/component/resource/mm_transpose_test.dml b/src/test/scripts/component/resource/mm_transpose_test.dml
new file mode 100644
index 0000000000..5cdf451b6f
--- /dev/null
+++ b/src/test/scripts/component/resource/mm_transpose_test.dml
@@ -0,0 +1,32 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+fileX = $X;
+# R - virtual result
+fileR = "R.csv";
+fmtR = "csv";
+
+X = read(fileX);
+
+R = t(X)%*%X;
+
+# trigger full calculation
+write(R, fileR, fmtR);
\ No newline at end of file
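(t(X)%*%X produces an ncol(X) x ncol(X) result, so even for the 100000 x 1000 input D the output is only 1000 x 1000, i.e. ~8 MB. SystemDS recognizes this pattern as a transpose-self matrix multiply (opcode tsmm) that avoids materializing t(X), so the recompilation tests can check tsmm placement in CP vs. Spark depending on whether X itself fits in the driver budget.)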