This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new f18d3e6993 [SYSTEMDS-3705] Refactoring and cleanup operator scheduling
algorithms
f18d3e6993 is described below
commit f18d3e699383dcac7be4618c005b23b3c2c393d3
Author: Matthias Boehm <[email protected]>
AuthorDate: Sat Jun 1 20:16:41 2024 +0200
[SYSTEMDS-3705] Refactoring and cleanup operator scheduling algorithms
---
.../apache/sysds/conf/ConfigurationManager.java | 15 +-
src/main/java/org/apache/sysds/conf/DMLConfig.java | 4 +-
.../java/org/apache/sysds/lops/compile/Dag.java | 6 +-
.../lops/compile/linearization/IDagLinearizer.java | 39 ++++
.../linearization/IDagLinearizerFactory.java | 58 ++++++
.../linearization/LinearizerBreadthFirst.java | 36 ++++
...asedLinearize.java => LinearizerCostBased.java} | 12 +-
.../linearization/LinearizerDepthFirst.java | 47 +++++
...inearize.java => LinearizerMaxParallelism.java} | 163 ++-------------
.../linearization/LinearizerMinIntermediates.java | 92 +++++++++
...Linearize.java => LinearizerPipelineAware.java} | 14 +-
.../test/functions/caching/PinVariablesTest.java | 228 ++++++++++-----------
.../functions/linearization/ILinearizeTest.java | 8 +-
13 files changed, 434 insertions(+), 288 deletions(-)
diff --git a/src/main/java/org/apache/sysds/conf/ConfigurationManager.java
b/src/main/java/org/apache/sysds/conf/ConfigurationManager.java
index 4c8fc9a60e..0d5ee888d8 100644
--- a/src/main/java/org/apache/sysds/conf/ConfigurationManager.java
+++ b/src/main/java/org/apache/sysds/conf/ConfigurationManager.java
@@ -29,7 +29,7 @@ import org.apache.sysds.api.DMLScript;
import org.apache.sysds.conf.CompilerConfig.ConfigType;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.lops.Compression.CompressConfig;
-import org.apache.sysds.lops.compile.linearization.ILinearize;
+import
org.apache.sysds.lops.compile.linearization.IDagLinearizerFactory.DagLinearizer;
import org.apache.sysds.runtime.io.IOUtilFunctions;
import org.apache.sysds.runtime.util.CommonThreadPool;
@@ -272,7 +272,7 @@ public class ConfigurationManager{
}
public static boolean isMaxPrallelizeEnabled() {
- return (getLinearizationOrder() ==
ILinearize.DagLinearization.MAX_PARALLELIZE
+ return (getLinearizationOrder() == DagLinearizer.MAX_PARALLELIZE
|| OptimizerUtils.MAX_PARALLELIZE_ORDER);
}
@@ -303,15 +303,14 @@ public class ConfigurationManager{
return OptimizerUtils.AUTO_GPU_CACHE_EVICTION;
}
- public static ILinearize.DagLinearization getLinearizationOrder() {
+ public static DagLinearizer getLinearizationOrder() {
if (OptimizerUtils.COST_BASED_ORDERING)
- return ILinearize.DagLinearization.AUTO;
+ return DagLinearizer.AUTO;
else if (OptimizerUtils.MAX_PARALLELIZE_ORDER)
- return ILinearize.DagLinearization.MAX_PARALLELIZE;
+ return DagLinearizer.MAX_PARALLELIZE;
else
- return ILinearize.DagLinearization
-
.valueOf(ConfigurationManager.getDMLConfig().getTextValue(DMLConfig.DAG_LINEARIZATION).toUpperCase());
-
+ return
DagLinearizer.valueOf(ConfigurationManager.getDMLConfig()
+
.getTextValue(DMLConfig.DAG_LINEARIZATION).toUpperCase());
}
///////////////////////////////////////
diff --git a/src/main/java/org/apache/sysds/conf/DMLConfig.java
b/src/main/java/org/apache/sysds/conf/DMLConfig.java
index c5ca98d3a8..dd4d3b2457 100644
--- a/src/main/java/org/apache/sysds/conf/DMLConfig.java
+++ b/src/main/java/org/apache/sysds/conf/DMLConfig.java
@@ -45,7 +45,7 @@ import
org.apache.sysds.hops.codegen.SpoofCompiler.GeneratorAPI;
import org.apache.sysds.hops.codegen.SpoofCompiler.PlanSelector;
import org.apache.sysds.hops.fedplanner.FTypes.FederatedPlanner;
import org.apache.sysds.lops.Compression;
-import org.apache.sysds.lops.compile.linearization.ILinearize.DagLinearization;
+import
org.apache.sysds.lops.compile.linearization.IDagLinearizerFactory.DagLinearizer;
import org.apache.sysds.parser.ParseException;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.io.IOUtilFunctions;
@@ -173,7 +173,7 @@ public class DMLConfig
_defaultVals.put(COMPRESSED_COST_MODEL, "AUTO");
_defaultVals.put(COMPRESSED_TRANSPOSE, "auto");
_defaultVals.put(COMPRESSED_TRANSFORMENCODE, "false");
- _defaultVals.put(DAG_LINEARIZATION,
DagLinearization.DEPTH_FIRST.name());
+ _defaultVals.put(DAG_LINEARIZATION,
DagLinearizer.DEPTH_FIRST.name());
_defaultVals.put(CODEGEN, "false" );
_defaultVals.put(CODEGEN_API,
GeneratorAPI.JAVA.name() );
_defaultVals.put(CODEGEN_COMPILER,
CompilerType.AUTO.name() );
diff --git a/src/main/java/org/apache/sysds/lops/compile/Dag.java
b/src/main/java/org/apache/sysds/lops/compile/Dag.java
index 05eaa24ebb..b26c539e9a 100644
--- a/src/main/java/org/apache/sysds/lops/compile/Dag.java
+++ b/src/main/java/org/apache/sysds/lops/compile/Dag.java
@@ -43,7 +43,8 @@ import org.apache.sysds.lops.Lop.Type;
import org.apache.sysds.lops.LopsException;
import org.apache.sysds.lops.OutputParameters;
import org.apache.sysds.lops.UnaryCP;
-import org.apache.sysds.lops.compile.linearization.ILinearize;
+import org.apache.sysds.lops.compile.linearization.IDagLinearizer;
+import org.apache.sysds.lops.compile.linearization.IDagLinearizerFactory;
import org.apache.sysds.parser.DataExpression;
import org.apache.sysds.parser.StatementBlock;
import org.apache.sysds.runtime.DMLRuntimeException;
@@ -175,7 +176,8 @@ public class Dag<N extends Lop>
scratch = config.getTextValue(DMLConfig.SCRATCH_SPACE)
+ "/";
}
- List<Lop> node_v = ILinearize.linearize(nodes);
+ IDagLinearizer dl = IDagLinearizerFactory.createDagLinearizer();
+ List<Lop> node_v = dl.linearize(nodes);
prefetchFederated(node_v);
// do greedy grouping of operations
diff --git
a/src/main/java/org/apache/sysds/lops/compile/linearization/IDagLinearizer.java
b/src/main/java/org/apache/sysds/lops/compile/linearization/IDagLinearizer.java
new file mode 100644
index 0000000000..1dc301c7e3
--- /dev/null
+++
b/src/main/java/org/apache/sysds/lops/compile/linearization/IDagLinearizer.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.lops.compile.linearization;
+
+import java.util.List;
+
+import org.apache.sysds.lops.Lop;
+
+/**
+ * An interface for the linearization algorithms that order the DAG nodes into
a
+ * sequence of instructions to execute.
+ */
+public abstract class IDagLinearizer {
+ /**
+ * Linearized a DAG of lops into a sequence of lops that preserves all
+ * data dependencies.
+ *
+ * @param v roots (outputs) of a DAG of lops
+ * @return list of lops (input, inner, and outputs)
+ */
+ public abstract List<Lop> linearize(List<Lop> v);
+}
diff --git
a/src/main/java/org/apache/sysds/lops/compile/linearization/IDagLinearizerFactory.java
b/src/main/java/org/apache/sysds/lops/compile/linearization/IDagLinearizerFactory.java
new file mode 100644
index 0000000000..dcbd005b88
--- /dev/null
+++
b/src/main/java/org/apache/sysds/lops/compile/linearization/IDagLinearizerFactory.java
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.lops.compile.linearization;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+
+import org.apache.sysds.conf.ConfigurationManager;
+
+public class IDagLinearizerFactory {
+ public static Log LOG =
LogFactory.getLog(IDagLinearizerFactory.class.getName());
+
+ public enum DagLinearizer {
+ DEPTH_FIRST, BREADTH_FIRST, MIN_INTERMEDIATE, MAX_PARALLELIZE,
AUTO, PIPELINE_DEPTH_FIRST;
+ }
+
+ public static IDagLinearizer createDagLinearizer() {
+ DagLinearizer type =
ConfigurationManager.getLinearizationOrder();
+ return createDagLinearizer(type);
+ }
+
+ public static IDagLinearizer createDagLinearizer(DagLinearizer type) {
+ switch(type) {
+ case AUTO:
+ return new LinearizerCostBased();
+ case BREADTH_FIRST:
+ return new LinearizerBreadthFirst();
+ case DEPTH_FIRST:
+ return new LinearizerDepthFirst();
+ case MAX_PARALLELIZE:
+ return new LinearizerMaxParallelism();
+ case MIN_INTERMEDIATE:
+ return new LinearizerMinIntermediates();
+ case PIPELINE_DEPTH_FIRST:
+ return new LinearizerPipelineAware();
+ default:
+ LOG.warn("Invalid DAG_LINEARIZATION: "+type+",
falling back to DEPTH_FIRST ordering");
+ return new LinearizerDepthFirst();
+ }
+ }
+}
diff --git
a/src/main/java/org/apache/sysds/lops/compile/linearization/LinearizerBreadthFirst.java
b/src/main/java/org/apache/sysds/lops/compile/linearization/LinearizerBreadthFirst.java
new file mode 100644
index 0000000000..23d4150e78
--- /dev/null
+++
b/src/main/java/org/apache/sysds/lops/compile/linearization/LinearizerBreadthFirst.java
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.lops.compile.linearization;
+
+import org.apache.sysds.lops.Lop;
+
+import java.util.Comparator;
+import java.util.List;
+import java.util.stream.Collectors;
+
+public class LinearizerBreadthFirst extends IDagLinearizer
+{
+ @Override
+ public List<Lop> linearize(List<Lop> v) {
+ return v.stream()
+ .sorted(Comparator.comparing(Lop::getLevel))
+ .collect(Collectors.toList());
+ }
+}
diff --git
a/src/main/java/org/apache/sysds/lops/compile/linearization/CostBasedLinearize.java
b/src/main/java/org/apache/sysds/lops/compile/linearization/LinearizerCostBased.java
similarity index 95%
rename from
src/main/java/org/apache/sysds/lops/compile/linearization/CostBasedLinearize.java
rename to
src/main/java/org/apache/sysds/lops/compile/linearization/LinearizerCostBased.java
index 707277630b..a036cf0dd3 100644
---
a/src/main/java/org/apache/sysds/lops/compile/linearization/CostBasedLinearize.java
+++
b/src/main/java/org/apache/sysds/lops/compile/linearization/LinearizerCostBased.java
@@ -31,25 +31,25 @@ import java.util.List;
import java.util.Stack;
import java.util.stream.Collectors;
-public class CostBasedLinearize
+public class LinearizerCostBased extends IDagLinearizer
{
- public static List<Lop> getBestOrder(List<Lop> lops)
- {
+ @Override
+ public List<Lop> linearize(List<Lop> v) {
// Simplify DAG by removing literals and transient inputs and
outputs
List<Lop> removedLeaves = new ArrayList<>();
List<Lop> removedRoots = new ArrayList<>();
HashMap<Long, ArrayList<Lop>> removedInputs = new HashMap<>();
HashMap<Long, ArrayList<Lop>> removedOutputs = new HashMap<>();
- simplifyDag(lops, removedLeaves, removedRoots, removedInputs,
removedOutputs);
+ simplifyDag(v, removedLeaves, removedRoots, removedInputs,
removedOutputs);
// TODO: Partition the DAG if connected by a single node.
Optimize separately
// Collect the leaf nodes of the simplified DAG
- List<Lop> leafNodes = lops.stream().filter(l ->
l.getInputs().isEmpty()).collect(Collectors.toList());
+ List<Lop> leafNodes = v.stream().filter(l ->
l.getInputs().isEmpty()).collect(Collectors.toList());
// For each leaf, find all possible orders starting from the
given leaf
List<Order> finalOrders = new ArrayList<>();
for (Lop leaf : leafNodes)
- generateOrders(leaf, leafNodes, finalOrders,
lops.size());
+ generateOrders(leaf, leafNodes, finalOrders, v.size());
// TODO: Handle distributed and GPU operators (0 compute cost,
memory overhead on collect)
// TODO: Asynchronous operators (max of compute costs, total
operation memory overhead)
diff --git
a/src/main/java/org/apache/sysds/lops/compile/linearization/LinearizerDepthFirst.java
b/src/main/java/org/apache/sysds/lops/compile/linearization/LinearizerDepthFirst.java
new file mode 100644
index 0000000000..c0a791eea5
--- /dev/null
+++
b/src/main/java/org/apache/sysds/lops/compile/linearization/LinearizerDepthFirst.java
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.lops.compile.linearization;
+
+import org.apache.sysds.lops.Lop;
+
+import java.util.Comparator;
+import java.util.List;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+public class LinearizerDepthFirst extends IDagLinearizer
+{
+ // previously called doTopologicalSortTwoLevelOrder
+ @Override
+ public List<Lop> linearize(List<Lop> v) {
+ // partition nodes into leaf/inner nodes and dag root nodes,
+ // + sort leaf/inner nodes by ID to force depth-first scheduling
+ // + append root nodes in order of their original definition
+ // (which also preserves the original order of prints)
+ List<Lop> nodes = Stream
+ .concat(v.stream().filter(l ->
!l.getOutputs().isEmpty()).sorted(Comparator.comparing(l -> l.getID())),
+ v.stream().filter(l ->
l.getOutputs().isEmpty()))
+ .collect(Collectors.toList());
+
+ // NOTE: in contrast to hadoop execution modes, we avoid
computing the transitive
+ // closure here to ensure linear time complexity because its
unnecessary for CP and Spark
+ return nodes;
+ }
+}
diff --git
a/src/main/java/org/apache/sysds/lops/compile/linearization/ILinearize.java
b/src/main/java/org/apache/sysds/lops/compile/linearization/LinearizerMaxParallelism.java
similarity index 57%
rename from
src/main/java/org/apache/sysds/lops/compile/linearization/ILinearize.java
rename to
src/main/java/org/apache/sysds/lops/compile/linearization/LinearizerMaxParallelism.java
index 288017bf1b..92844ec5df 100644
--- a/src/main/java/org/apache/sysds/lops/compile/linearization/ILinearize.java
+++
b/src/main/java/org/apache/sysds/lops/compile/linearization/LinearizerMaxParallelism.java
@@ -19,26 +19,9 @@
package org.apache.sysds.lops.compile.linearization;
-import java.util.AbstractMap;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.Collections;
-import java.util.Comparator;
-import java.util.HashMap;
-import java.util.HashSet;
-import java.util.LinkedList;
-import java.util.List;
-import java.util.Map;
-import java.util.stream.Collectors;
-import java.util.stream.Stream;
-
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
-import org.apache.sysds.common.Types.ExecType;
import org.apache.sysds.common.Types.DataType;
-import org.apache.sysds.conf.ConfigurationManager;
+import org.apache.sysds.common.Types.ExecType;
import org.apache.sysds.hops.AggBinaryOp.SparkAggType;
-import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.lops.CSVReBlock;
import org.apache.sysds.lops.CentralMoment;
import org.apache.sysds.lops.Checkpoint;
@@ -57,141 +40,28 @@ import org.apache.sysds.lops.SpoofFused;
import org.apache.sysds.lops.UAggOuterChain;
import org.apache.sysds.lops.UnaryCP;
-/**
- * A interface for the linearization algorithms that order the DAG nodes into
a sequence of instructions to execute.
- *
- * https://en.wikipedia.org/wiki/Linearizability#Linearization_points
- */
-public class ILinearize {
- public static Log LOG = LogFactory.getLog(ILinearize.class.getName());
-
- public enum DagLinearization {
- DEPTH_FIRST, BREADTH_FIRST, MIN_INTERMEDIATE, MAX_PARALLELIZE,
AUTO, PIPELINE_DEPTH_FIRST;
- }
-
- public static List<Lop> linearize(List<Lop> v) {
- try {
- DagLinearization linearization =
ConfigurationManager.getLinearizationOrder();
-
- switch(linearization) {
- case MAX_PARALLELIZE:
- return doMaxParallelizeSort(v);
- case AUTO:
- return
CostBasedLinearize.getBestOrder(v);
- case MIN_INTERMEDIATE:
- return doMinIntermediateSort(v);
- case BREADTH_FIRST:
- return doBreadthFirstSort(v);
- case PIPELINE_DEPTH_FIRST:
- return
PipelineAwareLinearize.pipelineDepthFirst(v);
- case DEPTH_FIRST:
- default:
- return depthFirst(v);
- }
- }
- catch(Exception e) {
- LOG.warn("Invalid DAG_LINEARIZATION
"+ConfigurationManager.getLinearizationOrder()+", fallback to DEPTH_FIRST
ordering");
- return depthFirst(v);
- }
- }
-
- /**
- * Sort lops depth-first
- *
- * previously called doTopologicalSortTwoLevelOrder
- *
- * @param v List of lops to sort
- * @return Sorted list of lops
- */
- protected static List<Lop> depthFirst(List<Lop> v) {
- // partition nodes into leaf/inner nodes and dag root nodes,
- // + sort leaf/inner nodes by ID to force depth-first scheduling
- // + append root nodes in order of their original definition
- // (which also preserves the original order of prints)
- List<Lop> nodes = Stream
- .concat(v.stream().filter(l ->
!l.getOutputs().isEmpty()).sorted(Comparator.comparing(l -> l.getID())),
- v.stream().filter(l ->
l.getOutputs().isEmpty()))
- .collect(Collectors.toList());
-
- // NOTE: in contrast to hadoop execution modes, we avoid
computing the transitive
- // closure here to ensure linear time complexity because its
unnecessary for CP and Spark
- return nodes;
- }
-
- private static List<Lop> doBreadthFirstSort(List<Lop> v) {
- List<Lop> nodes =
v.stream().sorted(Comparator.comparing(Lop::getLevel)).collect(Collectors.toList());
-
- return nodes;
- }
-
- /**
- * Sort lops to execute them in an order that minimizes the memory
requirements of intermediates
- *
- * @param v List of lops to sort
- * @return Sorted list of lops
- */
- private static List<Lop> doMinIntermediateSort(List<Lop> v) {
- List<Lop> nodes = new ArrayList<>(v.size());
- // Get the lowest level in the tree to move upwards from
- List<Lop> lowestLevel = v.stream().filter(l ->
l.getOutputs().isEmpty()).collect(Collectors.toList());
-
- // Traverse the tree bottom up, choose nodes with higher memory
requirements, then reverse the list
- List<Lop> remaining = new LinkedList<>(v);
- sortRecursive(nodes, lowestLevel, remaining);
-
- // In some cases (function calls) some output lops are not in
the list of nodes to be sorted.
- // With the next layer up having output lops, they are not
added to the initial list of lops and are
- // subsequently never reached by the recursive sort.
- // We work around this issue by checking for remaining lops
after the initial sort.
- while(!remaining.isEmpty()) {
- // Start with the lowest level lops, this time by level
instead of no outputs
- int maxLevel =
remaining.stream().mapToInt(Lop::getLevel).max().orElse(-1);
- List<Lop> lowestNodes = remaining.stream().filter(l ->
l.getLevel() == maxLevel).collect(Collectors.toList());
- sortRecursive(nodes, lowestNodes, remaining);
- }
-
- // All lops were added bottom up, from highest to lowest memory
consumption, now reverse this
- Collections.reverse(nodes);
-
- return nodes;
- }
-
- private static void sortRecursive(List<Lop> result, List<Lop> input,
List<Lop> remaining) {
- // Sort input lops by memory estimate
- // Lowest level nodes (those with no outputs) receive a memory
estimate of 0 to preserve order
- // This affects prints, writes, ...
- List<Map.Entry<Lop, Long>> memEst =
input.stream().distinct().map(l -> new AbstractMap.SimpleEntry<>(l,
- l.getOutputs().isEmpty() ? 0 :
OptimizerUtils.estimateSizeExactSparsity(l.getOutputParameters().getNumRows(),
- l.getOutputParameters().getNumCols(),
l.getOutputParameters().getNnz())))
- .sorted(Comparator.comparing(e -> ((Map.Entry<Lop,
Long>) e).getValue())).collect(Collectors.toList());
-
- // Start with the highest memory estimate because the entire
list is reversed later
- Collections.reverse(memEst);
- for(Map.Entry<Lop, Long> e : memEst) {
- // Skip if the node is already in the result list
- // Skip if one of the lop's outputs is not in the
result list yet (will be added once the output lop is
- // traversed), but only if any of the output lops is
bound to be added to the result at a later stage
- if(result.contains(e.getKey()) ||
(!result.containsAll(e.getKey().getOutputs()) &&
- remaining.stream().anyMatch(l ->
e.getKey().getOutputs().contains(l))))
- continue;
- result.add(e.getKey());
- remaining.remove(e.getKey());
- // Add input lops recursively
- sortRecursive(result, e.getKey().getInputs(),
remaining);
- }
- }
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+public class LinearizerMaxParallelism extends IDagLinearizer
+{
// Place the Spark operation chains first (more expensive to less
expensive),
// followed by asynchronously triggering operators and CP chains.
- private static List<Lop> doMaxParallelizeSort(List<Lop> v)
- {
+ @Override
+ public List<Lop> linearize(List<Lop> v) {
List<Lop> v2 = v;
- boolean hasSpark =
v.stream().anyMatch(ILinearize::isDistributedOp);
- boolean hasGPU = v.stream().anyMatch(ILinearize::isGPUOp);
+ boolean hasSpark =
v.stream().anyMatch(LinearizerMaxParallelism::isDistributedOp);
+ boolean hasGPU =
v.stream().anyMatch(LinearizerMaxParallelism::isGPUOp);
// Fallback to default depth-first if all operators are CP
if (!hasSpark && !hasGPU)
- return depthFirst(v);
+ return new LinearizerDepthFirst().linearize(v);
if (hasSpark) {
// Step 1: Collect the Spark roots and #Spark
instructions in each subDAG
@@ -236,6 +106,7 @@ public class ILinearize {
}
return v2;
}
+
// Place the operators in a depth-first manner, but order
// the DAGs based on number of Spark operators
diff --git
a/src/main/java/org/apache/sysds/lops/compile/linearization/LinearizerMinIntermediates.java
b/src/main/java/org/apache/sysds/lops/compile/linearization/LinearizerMinIntermediates.java
new file mode 100644
index 0000000000..4daf435d24
--- /dev/null
+++
b/src/main/java/org/apache/sysds/lops/compile/linearization/LinearizerMinIntermediates.java
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.lops.compile.linearization;
+
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.lops.Lop;
+
+import java.util.AbstractMap;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Map.Entry;
+import java.util.stream.Collectors;
+
+/**
+ * Sort lops to execute them in an order that
+ * minimizes the memory requirements of intermediates
+ */
+public class LinearizerMinIntermediates extends IDagLinearizer
+{
+ @Override
+ public List<Lop> linearize(List<Lop> v) {
+ List<Lop> nodes = new ArrayList<>(v.size());
+ // Get the lowest level in the tree to move upwards from
+ List<Lop> lowestLevel = v.stream().filter(l ->
l.getOutputs().isEmpty()).collect(Collectors.toList());
+
+ // Traverse the tree bottom up, choose nodes with higher memory
requirements, then reverse the list
+ List<Lop> remaining = new LinkedList<>(v);
+ sortRecursive(nodes, lowestLevel, remaining);
+
+ // In some cases (function calls) some output lops are not in
the list of nodes to be sorted.
+ // With the next layer up having output lops, they are not
added to the initial list of lops and are
+ // subsequently never reached by the recursive sort.
+ // We work around this issue by checking for remaining lops
after the initial sort.
+ while(!remaining.isEmpty()) {
+ // Start with the lowest level lops, this time by level
instead of no outputs
+ int maxLevel =
remaining.stream().mapToInt(Lop::getLevel).max().orElse(-1);
+ List<Lop> lowestNodes = remaining.stream().filter(l ->
l.getLevel() == maxLevel).collect(Collectors.toList());
+ sortRecursive(nodes, lowestNodes, remaining);
+ }
+
+ // All lops were added bottom up, from highest to lowest memory
consumption, now reverse this
+ Collections.reverse(nodes);
+
+ return nodes;
+ }
+
+ private static void sortRecursive(List<Lop> result, List<Lop> input,
List<Lop> remaining) {
+ // Sort input lops by memory estimate
+ // Lowest level nodes (those with no outputs) receive a memory
estimate of 0 to preserve order
+ // This affects prints, writes, ...
+ List<Entry<Lop, Long>> memEst = input.stream().distinct().map(l
-> new AbstractMap.SimpleEntry<>(l,
+ l.getOutputs().isEmpty() ? 0 :
OptimizerUtils.estimateSizeExactSparsity(l.getOutputParameters().getNumRows(),
+ l.getOutputParameters().getNumCols(),
l.getOutputParameters().getNnz())))
+ .sorted(Comparator.comparing(e -> ((Map.Entry<Lop,
Long>) e).getValue())).collect(Collectors.toList());
+
+ // Start with the highest memory estimate because the entire
list is reversed later
+ Collections.reverse(memEst);
+ for(Map.Entry<Lop, Long> e : memEst) {
+ // Skip if the node is already in the result list
+ // Skip if one of the lop's outputs is not in the
result list yet (will be added once the output lop is
+ // traversed), but only if any of the output lops is
bound to be added to the result at a later stage
+ if(result.contains(e.getKey()) ||
(!result.containsAll(e.getKey().getOutputs()) &&
+ remaining.stream().anyMatch(l ->
e.getKey().getOutputs().contains(l))))
+ continue;
+ result.add(e.getKey());
+ remaining.remove(e.getKey());
+ // Add input lops recursively
+ sortRecursive(result, e.getKey().getInputs(),
remaining);
+ }
+ }
+}
diff --git
a/src/main/java/org/apache/sysds/lops/compile/linearization/PipelineAwareLinearize.java
b/src/main/java/org/apache/sysds/lops/compile/linearization/LinearizerPipelineAware.java
similarity index 96%
rename from
src/main/java/org/apache/sysds/lops/compile/linearization/PipelineAwareLinearize.java
rename to
src/main/java/org/apache/sysds/lops/compile/linearization/LinearizerPipelineAware.java
index dcc476a688..cebf14c6dc 100644
---
a/src/main/java/org/apache/sysds/lops/compile/linearization/PipelineAwareLinearize.java
+++
b/src/main/java/org/apache/sysds/lops/compile/linearization/LinearizerPipelineAware.java
@@ -28,8 +28,8 @@ import java.util.stream.Collectors;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.lops.OperatorOrderingUtils;
-public class PipelineAwareLinearize {
-
+public class LinearizerPipelineAware extends IDagLinearizer
+{
// Minimum number of nodes in DAG for applying algorithm
private final static int IGNORE_LIMIT = 0;
@@ -45,12 +45,12 @@ public class PipelineAwareLinearize {
* @param v List of lops to sort
* @return Sorted list of lops with set _pipelineID on the Lop Object
*/
- public static List<Lop> pipelineDepthFirst(List<Lop> v) {
-
+ @Override
+ public List<Lop> linearize(List<Lop> v) {
// If size of DAG is smaller than IGNORE_LIMIT, give all nodes
the same pipeline id
if(v.size() <= IGNORE_LIMIT) {
v.forEach(l -> l.setPipelineID(1));
- return ILinearize.depthFirst(v);
+ return new LinearizerDepthFirst().linearize(v);
}
// Find all root nodes (starting points for the depth-first
traversal)
@@ -74,11 +74,11 @@ public class PipelineAwareLinearize {
//DEVPrintDAG.asGraphviz("Step1", v);
// Step 2: Merge pipelines with only one node to another
(connected) pipeline
- PipelineAwareLinearize.mergeSingleNodePipelines(pipelineMap);
+ LinearizerPipelineAware.mergeSingleNodePipelines(pipelineMap);
//DEVPrintDAG.asGraphviz("Step2", v);
// Step 3: Merge small pipelines into bigger ones
- PipelineAwareLinearize.mergeSmallPipelines(pipelineMap);
+ LinearizerPipelineAware.mergeSmallPipelines(pipelineMap);
//DEVPrintDAG.asGraphviz("Step3", v);
// Reset the visited status of all nodes
diff --git
a/src/test/java/org/apache/sysds/test/functions/caching/PinVariablesTest.java
b/src/test/java/org/apache/sysds/test/functions/caching/PinVariablesTest.java
index 346102398c..835e0b45a2 100644
---
a/src/test/java/org/apache/sysds/test/functions/caching/PinVariablesTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/caching/PinVariablesTest.java
@@ -31,122 +31,122 @@ import org.apache.sysds.test.TestConfiguration;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.instructions.cp.Data;
-import java.util.LinkedList;
import java.util.Queue;
+import java.util.LinkedList;
import java.util.List;
public class PinVariablesTest extends AutomatedTestBase {
- private final static String TEST_NAME = "PinVariables";
- private final static String TEST_DIR = "functions/caching/";
- private final static String TEST_CLASS_DIR = TEST_DIR +
PinVariablesTest.class.getSimpleName() + "/";
-
- @Override
- public void setUp() {
- addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR,
TEST_NAME));
- }
-
- @Test
- public void testPinNoLists() {
- createMockDataAndCall(true, false, false);
- }
-
- @Test
- public void testPinShallowLists() {
- createMockDataAndCall(true, true, false);
- }
-
- @Test
- public void testPinNestedLists() {
- createMockDataAndCall(true, true, true);
- }
-
- private void createMockDataAndCall(boolean matrices, boolean list, boolean
nestedList) {
- LocalVariableMap vars = new LocalVariableMap();
- List<String> varList = new LinkedList<>();
- Queue<Boolean> varStates = new LinkedList<>();
-
- if (matrices) {
- MatrixObject mat1 = new MatrixObject(Types.ValueType.FP64,
"SomeFile1");
- mat1.enableCleanup(true);
- MatrixObject mat2 = new MatrixObject(Types.ValueType.FP64,
"SomeFile2");
- mat2.enableCleanup(true);
- MatrixObject mat3 = new MatrixObject(Types.ValueType.FP64,
"SomeFile3");
- mat3.enableCleanup(false);
- vars.put("mat1", mat1);
- vars.put("mat2", mat2);
- vars.put("mat3", mat3);
-
- varList.add("mat2");
- varList.add("mat3");
-
- varStates.add(true);
- varStates.add(false);
- }
- if (list) {
- MatrixObject mat4 = new MatrixObject(Types.ValueType.FP64,
"SomeFile4");
- mat4.enableCleanup(true);
- MatrixObject mat5 = new MatrixObject(Types.ValueType.FP64,
"SomeFile5");
- mat5.enableCleanup(false);
- List<Data> l1_data = new LinkedList<>();
- l1_data.add(mat4);
- l1_data.add(mat5);
-
- if (nestedList) {
- MatrixObject mat6 = new MatrixObject(Types.ValueType.FP64,
"SomeFile6");
- mat4.enableCleanup(true);
- List<Data> l2_data = new LinkedList<>();
- l2_data.add(mat6);
- ListObject l2 = new ListObject(l2_data);
- l1_data.add(l2);
- }
-
- ListObject l1 = new ListObject(l1_data);
- vars.put("l1", l1);
-
- varList.add("l1");
-
- // cleanup flag of inner matrix (m4)
- varStates.add(true);
- varStates.add(false);
- if (nestedList)
- varStates.add(true);
- }
-
- ExecutionContext ec = new ExecutionContext(vars);
-
- commonPinVariablesTest(ec, varList, varStates);
- }
-
- private void commonPinVariablesTest(ExecutionContext ec, List<String>
varList, Queue<Boolean> varStatesExp) {
- Queue<Boolean> varStates = ec.pinVariables(varList);
-
- // check returned cleanupEnabled flags
- Assert.assertEquals(varStatesExp, varStates);
-
- // assert updated cleanupEnabled flag to false
- for (String varName : varList) {
- Data dat = ec.getVariable(varName);
-
- if (dat instanceof CacheableData<?>)
- Assert.assertFalse(((CacheableData<?>)dat).isCleanupEnabled());
- else if (dat instanceof ListObject) {
- assertListFlagsDisabled((ListObject)dat);
- }
- }
-
- ec.unpinVariables(varList, varStates);
-
- // check returned flags after unpinVariables()
- Queue<Boolean> varStates2 = ec.pinVariables(varList);
- Assert.assertEquals(varStatesExp, varStates2);
- }
-
- private void assertListFlagsDisabled(ListObject l) {
- for (Data dat : l.getData()) {
- if (dat instanceof CacheableData<?>)
- Assert.assertFalse(((CacheableData<?>)dat).isCleanupEnabled());
- else if (dat instanceof ListObject)
- assertListFlagsDisabled((ListObject)dat);
- }
- }
+ private final static String TEST_NAME = "PinVariables";
+ private final static String TEST_DIR = "functions/caching/";
+ private final static String TEST_CLASS_DIR = TEST_DIR +
PinVariablesTest.class.getSimpleName() + "/";
+
+ @Override
+ public void setUp() {
+ addTestConfiguration(TEST_NAME, new
TestConfiguration(TEST_CLASS_DIR, TEST_NAME));
+ }
+
+ @Test
+ public void testPinNoLists() {
+ createMockDataAndCall(true, false, false);
+ }
+
+ @Test
+ public void testPinShallowLists() {
+ createMockDataAndCall(true, true, false);
+ }
+
+ @Test
+ public void testPinNestedLists() {
+ createMockDataAndCall(true, true, true);
+ }
+
+ private void createMockDataAndCall(boolean matrices, boolean list,
boolean nestedList) {
+ LocalVariableMap vars = new LocalVariableMap();
+ List<String> varList = new LinkedList<>();
+ Queue<Boolean> varStates = new LinkedList<>();
+
+ if (matrices) {
+ MatrixObject mat1 = new
MatrixObject(Types.ValueType.FP64, "SomeFile1");
+ mat1.enableCleanup(true);
+ MatrixObject mat2 = new
MatrixObject(Types.ValueType.FP64, "SomeFile2");
+ mat2.enableCleanup(true);
+ MatrixObject mat3 = new
MatrixObject(Types.ValueType.FP64, "SomeFile3");
+ mat3.enableCleanup(false);
+ vars.put("mat1", mat1);
+ vars.put("mat2", mat2);
+ vars.put("mat3", mat3);
+
+ varList.add("mat2");
+ varList.add("mat3");
+
+ varStates.add(true);
+ varStates.add(false);
+ }
+ if (list) {
+ MatrixObject mat4 = new
MatrixObject(Types.ValueType.FP64, "SomeFile4");
+ mat4.enableCleanup(true);
+ MatrixObject mat5 = new
MatrixObject(Types.ValueType.FP64, "SomeFile5");
+ mat5.enableCleanup(false);
+ List<Data> l1_data = new LinkedList<>();
+ l1_data.add(mat4);
+ l1_data.add(mat5);
+
+ if (nestedList) {
+ MatrixObject mat6 = new
MatrixObject(Types.ValueType.FP64, "SomeFile6");
+ mat6.enableCleanup(true);
+ List<Data> l2_data = new LinkedList<>();
+ l2_data.add(mat6);
+ ListObject l2 = new ListObject(l2_data);
+ l1_data.add(l2);
+ }
+
+ ListObject l1 = new ListObject(l1_data);
+ vars.put("l1", l1);
+
+ varList.add("l1");
+
+ // expected cleanup flags of inner matrices (mat4, mat5, and mat6 if nested)
+ varStates.add(true);
+ varStates.add(false);
+ if (nestedList)
+ varStates.add(true);
+ }
+
+ ExecutionContext ec = new ExecutionContext(vars);
+
+ commonPinVariablesTest(ec, varList, varStates);
+ }
+
+ private void commonPinVariablesTest(ExecutionContext ec, List<String>
varList, Queue<Boolean> varStatesExp) {
+ Queue<Boolean> varStates = ec.pinVariables(varList);
+
+ // check returned cleanupEnabled flags
+ Assert.assertEquals(varStatesExp, varStates);
+
+ // assert updated cleanupEnabled flag to false
+ for (String varName : varList) {
+ Data dat = ec.getVariable(varName);
+
+ if (dat instanceof CacheableData<?>)
+
Assert.assertFalse(((CacheableData<?>)dat).isCleanupEnabled());
+ else if (dat instanceof ListObject) {
+ assertListFlagsDisabled((ListObject)dat);
+ }
+ }
+
+ ec.unpinVariables(varList, varStates);
+
+ // check returned flags after unpinVariables()
+ Queue<Boolean> varStates2 = ec.pinVariables(varList);
+ Assert.assertEquals(varStatesExp, varStates2);
+ }
+
+ private void assertListFlagsDisabled(ListObject l) {
+ for (Data dat : l.getData()) {
+ if (dat instanceof CacheableData<?>)
+
Assert.assertFalse(((CacheableData<?>)dat).isCleanupEnabled());
+ else if (dat instanceof ListObject)
+ assertListFlagsDisabled((ListObject)dat);
+ }
+ }
}
diff --git
a/src/test/java/org/apache/sysds/test/functions/linearization/ILinearizeTest.java
b/src/test/java/org/apache/sysds/test/functions/linearization/ILinearizeTest.java
index e6680c29c2..bf1fa0fcfe 100644
---
a/src/test/java/org/apache/sysds/test/functions/linearization/ILinearizeTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/linearization/ILinearizeTest.java
@@ -36,7 +36,8 @@ import org.junit.Test;
import org.apache.sysds.common.Types.DataType;
import org.apache.sysds.common.Types.OpOp2;
import org.apache.sysds.common.Types.ValueType;
-import org.apache.sysds.lops.compile.linearization.ILinearize;
+import org.apache.sysds.lops.compile.linearization.IDagLinearizer;
+import org.apache.sysds.lops.compile.linearization.IDagLinearizerFactory;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.conf.DMLConfig;
@@ -148,8 +149,9 @@ public class ILinearizeTest extends AutomatedTestBase {
lops.forEach(l -> {l.getInputs().remove(d1);
l.getInputs().remove(d2);});
// RUN LINEARIZATION
- ILinearize.linearize(lops);
-
+ IDagLinearizer dl = IDagLinearizerFactory.createDagLinearizer();
+ dl.linearize(lops); //TODO results
+
// Set up expected pipelines
Map<Integer, List<Lop>> pipelineMap = new HashMap<>();
pipelineMap.put(4, Arrays.asList(n1, n2, n3, n4, n5, o1, o2));