This is an automated email from the ASF dual-hosted git repository. mboehm7 pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push: new f18d3e6993 [SYSTEMDS-3705] Refactoring and cleanup operator scheduling algorithms f18d3e6993 is described below commit f18d3e699383dcac7be4618c005b23b3c2c393d3 Author: Matthias Boehm <mboe...@gmail.com> AuthorDate: Sat Jun 1 20:16:41 2024 +0200 [SYSTEMDS-3705] Refactoring and cleanup operator scheduling algorithms --- .../apache/sysds/conf/ConfigurationManager.java | 15 +- src/main/java/org/apache/sysds/conf/DMLConfig.java | 4 +- .../java/org/apache/sysds/lops/compile/Dag.java | 6 +- .../lops/compile/linearization/IDagLinearizer.java | 39 ++++ .../linearization/IDagLinearizerFactory.java | 58 ++++++ .../linearization/LinearizerBreadthFirst.java | 36 ++++ ...asedLinearize.java => LinearizerCostBased.java} | 12 +- .../linearization/LinearizerDepthFirst.java | 47 +++++ ...inearize.java => LinearizerMaxParallelism.java} | 163 ++------------- .../linearization/LinearizerMinIntermediates.java | 92 +++++++++ ...Linearize.java => LinearizerPipelineAware.java} | 14 +- .../test/functions/caching/PinVariablesTest.java | 228 ++++++++++----------- .../functions/linearization/ILinearizeTest.java | 8 +- 13 files changed, 434 insertions(+), 288 deletions(-) diff --git a/src/main/java/org/apache/sysds/conf/ConfigurationManager.java b/src/main/java/org/apache/sysds/conf/ConfigurationManager.java index 4c8fc9a60e..0d5ee888d8 100644 --- a/src/main/java/org/apache/sysds/conf/ConfigurationManager.java +++ b/src/main/java/org/apache/sysds/conf/ConfigurationManager.java @@ -29,7 +29,7 @@ import org.apache.sysds.api.DMLScript; import org.apache.sysds.conf.CompilerConfig.ConfigType; import org.apache.sysds.hops.OptimizerUtils; import org.apache.sysds.lops.Compression.CompressConfig; -import org.apache.sysds.lops.compile.linearization.ILinearize; +import org.apache.sysds.lops.compile.linearization.IDagLinearizerFactory.DagLinearizer; import org.apache.sysds.runtime.io.IOUtilFunctions; import 
org.apache.sysds.runtime.util.CommonThreadPool; @@ -272,7 +272,7 @@ public class ConfigurationManager{ } public static boolean isMaxPrallelizeEnabled() { - return (getLinearizationOrder() == ILinearize.DagLinearization.MAX_PARALLELIZE + return (getLinearizationOrder() == DagLinearizer.MAX_PARALLELIZE || OptimizerUtils.MAX_PARALLELIZE_ORDER); } @@ -303,15 +303,14 @@ public class ConfigurationManager{ return OptimizerUtils.AUTO_GPU_CACHE_EVICTION; } - public static ILinearize.DagLinearization getLinearizationOrder() { + public static DagLinearizer getLinearizationOrder() { if (OptimizerUtils.COST_BASED_ORDERING) - return ILinearize.DagLinearization.AUTO; + return DagLinearizer.AUTO; else if (OptimizerUtils.MAX_PARALLELIZE_ORDER) - return ILinearize.DagLinearization.MAX_PARALLELIZE; + return DagLinearizer.MAX_PARALLELIZE; else - return ILinearize.DagLinearization - .valueOf(ConfigurationManager.getDMLConfig().getTextValue(DMLConfig.DAG_LINEARIZATION).toUpperCase()); - + return DagLinearizer.valueOf(ConfigurationManager.getDMLConfig() + .getTextValue(DMLConfig.DAG_LINEARIZATION).toUpperCase()); } /////////////////////////////////////// diff --git a/src/main/java/org/apache/sysds/conf/DMLConfig.java b/src/main/java/org/apache/sysds/conf/DMLConfig.java index c5ca98d3a8..dd4d3b2457 100644 --- a/src/main/java/org/apache/sysds/conf/DMLConfig.java +++ b/src/main/java/org/apache/sysds/conf/DMLConfig.java @@ -45,7 +45,7 @@ import org.apache.sysds.hops.codegen.SpoofCompiler.GeneratorAPI; import org.apache.sysds.hops.codegen.SpoofCompiler.PlanSelector; import org.apache.sysds.hops.fedplanner.FTypes.FederatedPlanner; import org.apache.sysds.lops.Compression; -import org.apache.sysds.lops.compile.linearization.ILinearize.DagLinearization; +import org.apache.sysds.lops.compile.linearization.IDagLinearizerFactory.DagLinearizer; import org.apache.sysds.parser.ParseException; import org.apache.sysds.runtime.DMLRuntimeException; import org.apache.sysds.runtime.io.IOUtilFunctions; @@ 
-173,7 +173,7 @@ public class DMLConfig _defaultVals.put(COMPRESSED_COST_MODEL, "AUTO"); _defaultVals.put(COMPRESSED_TRANSPOSE, "auto"); _defaultVals.put(COMPRESSED_TRANSFORMENCODE, "false"); - _defaultVals.put(DAG_LINEARIZATION, DagLinearization.DEPTH_FIRST.name()); + _defaultVals.put(DAG_LINEARIZATION, DagLinearizer.DEPTH_FIRST.name()); _defaultVals.put(CODEGEN, "false" ); _defaultVals.put(CODEGEN_API, GeneratorAPI.JAVA.name() ); _defaultVals.put(CODEGEN_COMPILER, CompilerType.AUTO.name() ); diff --git a/src/main/java/org/apache/sysds/lops/compile/Dag.java b/src/main/java/org/apache/sysds/lops/compile/Dag.java index 05eaa24ebb..b26c539e9a 100644 --- a/src/main/java/org/apache/sysds/lops/compile/Dag.java +++ b/src/main/java/org/apache/sysds/lops/compile/Dag.java @@ -43,7 +43,8 @@ import org.apache.sysds.lops.Lop.Type; import org.apache.sysds.lops.LopsException; import org.apache.sysds.lops.OutputParameters; import org.apache.sysds.lops.UnaryCP; -import org.apache.sysds.lops.compile.linearization.ILinearize; +import org.apache.sysds.lops.compile.linearization.IDagLinearizer; +import org.apache.sysds.lops.compile.linearization.IDagLinearizerFactory; import org.apache.sysds.parser.DataExpression; import org.apache.sysds.parser.StatementBlock; import org.apache.sysds.runtime.DMLRuntimeException; @@ -175,7 +176,8 @@ public class Dag<N extends Lop> scratch = config.getTextValue(DMLConfig.SCRATCH_SPACE) + "/"; } - List<Lop> node_v = ILinearize.linearize(nodes); + IDagLinearizer dl = IDagLinearizerFactory.createDagLinearizer(); + List<Lop> node_v = dl.linearize(nodes); prefetchFederated(node_v); // do greedy grouping of operations diff --git a/src/main/java/org/apache/sysds/lops/compile/linearization/IDagLinearizer.java b/src/main/java/org/apache/sysds/lops/compile/linearization/IDagLinearizer.java new file mode 100644 index 0000000000..1dc301c7e3 --- /dev/null +++ b/src/main/java/org/apache/sysds/lops/compile/linearization/IDagLinearizer.java @@ -0,0 +1,39 @@ +/* + * 
Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.lops.compile.linearization; + +import java.util.List; + +import org.apache.sysds.lops.Lop; + +/** + * An abstract base class for the linearization algorithms that order the DAG nodes into a + * sequence of instructions to execute. + */ +public abstract class IDagLinearizer { + /** + * Linearizes a DAG of lops into a sequence of lops that preserves all + * data dependencies. + * + * @param v all lops of a DAG (inputs, inner nodes, and outputs) + * @return list of lops (input, inner, and outputs) + */ + public abstract List<Lop> linearize(List<Lop> v); +} diff --git a/src/main/java/org/apache/sysds/lops/compile/linearization/IDagLinearizerFactory.java b/src/main/java/org/apache/sysds/lops/compile/linearization/IDagLinearizerFactory.java new file mode 100644 index 0000000000..dcbd005b88 --- /dev/null +++ b/src/main/java/org/apache/sysds/lops/compile/linearization/IDagLinearizerFactory.java @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership.
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.lops.compile.linearization; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.apache.sysds.conf.ConfigurationManager; + +public class IDagLinearizerFactory { + public static Log LOG = LogFactory.getLog(IDagLinearizerFactory.class.getName()); + + public enum DagLinearizer { + DEPTH_FIRST, BREADTH_FIRST, MIN_INTERMEDIATE, MAX_PARALLELIZE, AUTO, PIPELINE_DEPTH_FIRST; + } + + public static IDagLinearizer createDagLinearizer() { + DagLinearizer type = ConfigurationManager.getLinearizationOrder(); + return createDagLinearizer(type); + } + + public static IDagLinearizer createDagLinearizer(DagLinearizer type) { + switch(type) { + case AUTO: + return new LinearizerCostBased(); + case BREADTH_FIRST: + return new LinearizerBreadthFirst(); + case DEPTH_FIRST: + return new LinearizerDepthFirst(); + case MAX_PARALLELIZE: + return new LinearizerMaxParallelism(); + case MIN_INTERMEDIATE: + return new LinearizerMinIntermediates(); + case PIPELINE_DEPTH_FIRST: + return new LinearizerPipelineAware(); + default: + LOG.warn("Invalid DAG_LINEARIZATION: "+type+", falling back to DEPTH_FIRST ordering"); + return new LinearizerDepthFirst(); + } + } +} diff --git a/src/main/java/org/apache/sysds/lops/compile/linearization/LinearizerBreadthFirst.java 
b/src/main/java/org/apache/sysds/lops/compile/linearization/LinearizerBreadthFirst.java new file mode 100644 index 0000000000..23d4150e78 --- /dev/null +++ b/src/main/java/org/apache/sysds/lops/compile/linearization/LinearizerBreadthFirst.java @@ -0,0 +1,36 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysds.lops.compile.linearization; + +import org.apache.sysds.lops.Lop; + +import java.util.Comparator; +import java.util.List; +import java.util.stream.Collectors; + +public class LinearizerBreadthFirst extends IDagLinearizer +{ + @Override + public List<Lop> linearize(List<Lop> v) { + return v.stream() + .sorted(Comparator.comparing(Lop::getLevel)) + .collect(Collectors.toList()); + } +} diff --git a/src/main/java/org/apache/sysds/lops/compile/linearization/CostBasedLinearize.java b/src/main/java/org/apache/sysds/lops/compile/linearization/LinearizerCostBased.java similarity index 95% rename from src/main/java/org/apache/sysds/lops/compile/linearization/CostBasedLinearize.java rename to src/main/java/org/apache/sysds/lops/compile/linearization/LinearizerCostBased.java index 707277630b..a036cf0dd3 100644 --- a/src/main/java/org/apache/sysds/lops/compile/linearization/CostBasedLinearize.java +++ b/src/main/java/org/apache/sysds/lops/compile/linearization/LinearizerCostBased.java @@ -31,25 +31,25 @@ import java.util.List; import java.util.Stack; import java.util.stream.Collectors; -public class CostBasedLinearize +public class LinearizerCostBased extends IDagLinearizer { - public static List<Lop> getBestOrder(List<Lop> lops) - { + @Override + public List<Lop> linearize(List<Lop> v) { // Simplify DAG by removing literals and transient inputs and outputs List<Lop> removedLeaves = new ArrayList<>(); List<Lop> removedRoots = new ArrayList<>(); HashMap<Long, ArrayList<Lop>> removedInputs = new HashMap<>(); HashMap<Long, ArrayList<Lop>> removedOutputs = new HashMap<>(); - simplifyDag(lops, removedLeaves, removedRoots, removedInputs, removedOutputs); + simplifyDag(v, removedLeaves, removedRoots, removedInputs, removedOutputs); // TODO: Partition the DAG if connected by a single node. 
Optimize separately // Collect the leaf nodes of the simplified DAG - List<Lop> leafNodes = lops.stream().filter(l -> l.getInputs().isEmpty()).collect(Collectors.toList()); + List<Lop> leafNodes = v.stream().filter(l -> l.getInputs().isEmpty()).collect(Collectors.toList()); // For each leaf, find all possible orders starting from the given leaf List<Order> finalOrders = new ArrayList<>(); for (Lop leaf : leafNodes) - generateOrders(leaf, leafNodes, finalOrders, lops.size()); + generateOrders(leaf, leafNodes, finalOrders, v.size()); // TODO: Handle distributed and GPU operators (0 compute cost, memory overhead on collect) // TODO: Asynchronous operators (max of compute costs, total operation memory overhead) diff --git a/src/main/java/org/apache/sysds/lops/compile/linearization/LinearizerDepthFirst.java b/src/main/java/org/apache/sysds/lops/compile/linearization/LinearizerDepthFirst.java new file mode 100644 index 0000000000..c0a791eea5 --- /dev/null +++ b/src/main/java/org/apache/sysds/lops/compile/linearization/LinearizerDepthFirst.java @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.sysds.lops.compile.linearization; + +import org.apache.sysds.lops.Lop; + +import java.util.Comparator; +import java.util.List; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +public class LinearizerDepthFirst extends IDagLinearizer +{ + // previously called doTopologicalSortTwoLevelOrder + @Override + public List<Lop> linearize(List<Lop> v) { + // partition nodes into leaf/inner nodes and dag root nodes, + // + sort leaf/inner nodes by ID to force depth-first scheduling + // + append root nodes in order of their original definition + // (which also preserves the original order of prints) + List<Lop> nodes = Stream + .concat(v.stream().filter(l -> !l.getOutputs().isEmpty()).sorted(Comparator.comparing(l -> l.getID())), + v.stream().filter(l -> l.getOutputs().isEmpty())) + .collect(Collectors.toList()); + + // NOTE: in contrast to hadoop execution modes, we avoid computing the transitive + // closure here to ensure linear time complexity because its unnecessary for CP and Spark + return nodes; + } +} diff --git a/src/main/java/org/apache/sysds/lops/compile/linearization/ILinearize.java b/src/main/java/org/apache/sysds/lops/compile/linearization/LinearizerMaxParallelism.java similarity index 57% rename from src/main/java/org/apache/sysds/lops/compile/linearization/ILinearize.java rename to src/main/java/org/apache/sysds/lops/compile/linearization/LinearizerMaxParallelism.java index 288017bf1b..92844ec5df 100644 --- a/src/main/java/org/apache/sysds/lops/compile/linearization/ILinearize.java +++ b/src/main/java/org/apache/sysds/lops/compile/linearization/LinearizerMaxParallelism.java @@ -19,26 +19,9 @@ package org.apache.sysds.lops.compile.linearization; -import java.util.AbstractMap; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; -import java.util.Comparator; -import java.util.HashMap; -import java.util.HashSet; -import java.util.LinkedList; -import java.util.List; -import 
java.util.Map; -import java.util.stream.Collectors; -import java.util.stream.Stream; - -import org.apache.commons.logging.Log; -import org.apache.commons.logging.LogFactory; -import org.apache.sysds.common.Types.ExecType; import org.apache.sysds.common.Types.DataType; -import org.apache.sysds.conf.ConfigurationManager; +import org.apache.sysds.common.Types.ExecType; import org.apache.sysds.hops.AggBinaryOp.SparkAggType; -import org.apache.sysds.hops.OptimizerUtils; import org.apache.sysds.lops.CSVReBlock; import org.apache.sysds.lops.CentralMoment; import org.apache.sysds.lops.Checkpoint; @@ -57,141 +40,28 @@ import org.apache.sysds.lops.SpoofFused; import org.apache.sysds.lops.UAggOuterChain; import org.apache.sysds.lops.UnaryCP; -/** - * A interface for the linearization algorithms that order the DAG nodes into a sequence of instructions to execute. - * - * https://en.wikipedia.org/wiki/Linearizability#Linearization_points - */ -public class ILinearize { - public static Log LOG = LogFactory.getLog(ILinearize.class.getName()); - - public enum DagLinearization { - DEPTH_FIRST, BREADTH_FIRST, MIN_INTERMEDIATE, MAX_PARALLELIZE, AUTO, PIPELINE_DEPTH_FIRST; - } - - public static List<Lop> linearize(List<Lop> v) { - try { - DagLinearization linearization = ConfigurationManager.getLinearizationOrder(); - - switch(linearization) { - case MAX_PARALLELIZE: - return doMaxParallelizeSort(v); - case AUTO: - return CostBasedLinearize.getBestOrder(v); - case MIN_INTERMEDIATE: - return doMinIntermediateSort(v); - case BREADTH_FIRST: - return doBreadthFirstSort(v); - case PIPELINE_DEPTH_FIRST: - return PipelineAwareLinearize.pipelineDepthFirst(v); - case DEPTH_FIRST: - default: - return depthFirst(v); - } - } - catch(Exception e) { - LOG.warn("Invalid DAG_LINEARIZATION "+ConfigurationManager.getLinearizationOrder()+", fallback to DEPTH_FIRST ordering"); - return depthFirst(v); - } - } - - /** - * Sort lops depth-first - * - * previously called doTopologicalSortTwoLevelOrder - * - 
* @param v List of lops to sort - * @return Sorted list of lops - */ - protected static List<Lop> depthFirst(List<Lop> v) { - // partition nodes into leaf/inner nodes and dag root nodes, - // + sort leaf/inner nodes by ID to force depth-first scheduling - // + append root nodes in order of their original definition - // (which also preserves the original order of prints) - List<Lop> nodes = Stream - .concat(v.stream().filter(l -> !l.getOutputs().isEmpty()).sorted(Comparator.comparing(l -> l.getID())), - v.stream().filter(l -> l.getOutputs().isEmpty())) - .collect(Collectors.toList()); - - // NOTE: in contrast to hadoop execution modes, we avoid computing the transitive - // closure here to ensure linear time complexity because its unnecessary for CP and Spark - return nodes; - } - - private static List<Lop> doBreadthFirstSort(List<Lop> v) { - List<Lop> nodes = v.stream().sorted(Comparator.comparing(Lop::getLevel)).collect(Collectors.toList()); - - return nodes; - } - - /** - * Sort lops to execute them in an order that minimizes the memory requirements of intermediates - * - * @param v List of lops to sort - * @return Sorted list of lops - */ - private static List<Lop> doMinIntermediateSort(List<Lop> v) { - List<Lop> nodes = new ArrayList<>(v.size()); - // Get the lowest level in the tree to move upwards from - List<Lop> lowestLevel = v.stream().filter(l -> l.getOutputs().isEmpty()).collect(Collectors.toList()); - - // Traverse the tree bottom up, choose nodes with higher memory requirements, then reverse the list - List<Lop> remaining = new LinkedList<>(v); - sortRecursive(nodes, lowestLevel, remaining); - - // In some cases (function calls) some output lops are not in the list of nodes to be sorted. - // With the next layer up having output lops, they are not added to the initial list of lops and are - // subsequently never reached by the recursive sort. - // We work around this issue by checking for remaining lops after the initial sort. 
- while(!remaining.isEmpty()) { - // Start with the lowest level lops, this time by level instead of no outputs - int maxLevel = remaining.stream().mapToInt(Lop::getLevel).max().orElse(-1); - List<Lop> lowestNodes = remaining.stream().filter(l -> l.getLevel() == maxLevel).collect(Collectors.toList()); - sortRecursive(nodes, lowestNodes, remaining); - } - - // All lops were added bottom up, from highest to lowest memory consumption, now reverse this - Collections.reverse(nodes); - - return nodes; - } - - private static void sortRecursive(List<Lop> result, List<Lop> input, List<Lop> remaining) { - // Sort input lops by memory estimate - // Lowest level nodes (those with no outputs) receive a memory estimate of 0 to preserve order - // This affects prints, writes, ... - List<Map.Entry<Lop, Long>> memEst = input.stream().distinct().map(l -> new AbstractMap.SimpleEntry<>(l, - l.getOutputs().isEmpty() ? 0 : OptimizerUtils.estimateSizeExactSparsity(l.getOutputParameters().getNumRows(), - l.getOutputParameters().getNumCols(), l.getOutputParameters().getNnz()))) - .sorted(Comparator.comparing(e -> ((Map.Entry<Lop, Long>) e).getValue())).collect(Collectors.toList()); - - // Start with the highest memory estimate because the entire list is reversed later - Collections.reverse(memEst); - for(Map.Entry<Lop, Long> e : memEst) { - // Skip if the node is already in the result list - // Skip if one of the lop's outputs is not in the result list yet (will be added once the output lop is - // traversed), but only if any of the output lops is bound to be added to the result at a later stage - if(result.contains(e.getKey()) || (!result.containsAll(e.getKey().getOutputs()) && - remaining.stream().anyMatch(l -> e.getKey().getOutputs().contains(l)))) - continue; - result.add(e.getKey()); - remaining.remove(e.getKey()); - // Add input lops recursively - sortRecursive(result, e.getKey().getInputs(), remaining); - } - } +import java.util.ArrayList; +import java.util.Arrays; +import 
java.util.Comparator; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; +public class LinearizerMaxParallelism extends IDagLinearizer +{ // Place the Spark operation chains first (more expensive to less expensive), // followed by asynchronously triggering operators and CP chains. - private static List<Lop> doMaxParallelizeSort(List<Lop> v) - { + @Override + public List<Lop> linearize(List<Lop> v) { List<Lop> v2 = v; - boolean hasSpark = v.stream().anyMatch(ILinearize::isDistributedOp); - boolean hasGPU = v.stream().anyMatch(ILinearize::isGPUOp); + boolean hasSpark = v.stream().anyMatch(LinearizerMaxParallelism::isDistributedOp); + boolean hasGPU = v.stream().anyMatch(LinearizerMaxParallelism::isGPUOp); // Fallback to default depth-first if all operators are CP if (!hasSpark && !hasGPU) - return depthFirst(v); + return new LinearizerDepthFirst().linearize(v); if (hasSpark) { // Step 1: Collect the Spark roots and #Spark instructions in each subDAG @@ -236,6 +106,7 @@ public class ILinearize { } return v2; } + // Place the operators in a depth-first manner, but order // the DAGs based on number of Spark operators diff --git a/src/main/java/org/apache/sysds/lops/compile/linearization/LinearizerMinIntermediates.java b/src/main/java/org/apache/sysds/lops/compile/linearization/LinearizerMinIntermediates.java new file mode 100644 index 0000000000..4daf435d24 --- /dev/null +++ b/src/main/java/org/apache/sysds/lops/compile/linearization/LinearizerMinIntermediates.java @@ -0,0 +1,92 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.lops.compile.linearization; + +import org.apache.sysds.hops.OptimizerUtils; +import org.apache.sysds.lops.Lop; + +import java.util.AbstractMap; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Comparator; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.stream.Collectors; + +/** + * Sort lops to execute them in an order that + * minimizes the memory requirements of intermediates + */ +public class LinearizerMinIntermediates extends IDagLinearizer +{ + @Override + public List<Lop> linearize(List<Lop> v) { + List<Lop> nodes = new ArrayList<>(v.size()); + // Get the lowest level in the tree to move upwards from + List<Lop> lowestLevel = v.stream().filter(l -> l.getOutputs().isEmpty()).collect(Collectors.toList()); + + // Traverse the tree bottom up, choose nodes with higher memory requirements, then reverse the list + List<Lop> remaining = new LinkedList<>(v); + sortRecursive(nodes, lowestLevel, remaining); + + // In some cases (function calls) some output lops are not in the list of nodes to be sorted. + // With the next layer up having output lops, they are not added to the initial list of lops and are + // subsequently never reached by the recursive sort. + // We work around this issue by checking for remaining lops after the initial sort. 
+ while(!remaining.isEmpty()) { + // Start with the lowest level lops, this time by level instead of no outputs + int maxLevel = remaining.stream().mapToInt(Lop::getLevel).max().orElse(-1); + List<Lop> lowestNodes = remaining.stream().filter(l -> l.getLevel() == maxLevel).collect(Collectors.toList()); + sortRecursive(nodes, lowestNodes, remaining); + } + + // All lops were added bottom up, from highest to lowest memory consumption, now reverse this + Collections.reverse(nodes); + + return nodes; + } + + private static void sortRecursive(List<Lop> result, List<Lop> input, List<Lop> remaining) { + // Sort input lops by memory estimate + // Lowest level nodes (those with no outputs) receive a memory estimate of 0 to preserve order + // This affects prints, writes, ... + List<Entry<Lop, Long>> memEst = input.stream().distinct().map(l -> new AbstractMap.SimpleEntry<>(l, + l.getOutputs().isEmpty() ? 0 : OptimizerUtils.estimateSizeExactSparsity(l.getOutputParameters().getNumRows(), + l.getOutputParameters().getNumCols(), l.getOutputParameters().getNnz()))) + .sorted(Comparator.comparing(e -> ((Map.Entry<Lop, Long>) e).getValue())).collect(Collectors.toList()); + + // Start with the highest memory estimate because the entire list is reversed later + Collections.reverse(memEst); + for(Map.Entry<Lop, Long> e : memEst) { + // Skip if the node is already in the result list + // Skip if one of the lop's outputs is not in the result list yet (will be added once the output lop is + // traversed), but only if any of the output lops is bound to be added to the result at a later stage + if(result.contains(e.getKey()) || (!result.containsAll(e.getKey().getOutputs()) && + remaining.stream().anyMatch(l -> e.getKey().getOutputs().contains(l)))) + continue; + result.add(e.getKey()); + remaining.remove(e.getKey()); + // Add input lops recursively + sortRecursive(result, e.getKey().getInputs(), remaining); + } + } +} diff --git 
a/src/main/java/org/apache/sysds/lops/compile/linearization/PipelineAwareLinearize.java b/src/main/java/org/apache/sysds/lops/compile/linearization/LinearizerPipelineAware.java similarity index 96% rename from src/main/java/org/apache/sysds/lops/compile/linearization/PipelineAwareLinearize.java rename to src/main/java/org/apache/sysds/lops/compile/linearization/LinearizerPipelineAware.java index dcc476a688..cebf14c6dc 100644 --- a/src/main/java/org/apache/sysds/lops/compile/linearization/PipelineAwareLinearize.java +++ b/src/main/java/org/apache/sysds/lops/compile/linearization/LinearizerPipelineAware.java @@ -28,8 +28,8 @@ import java.util.stream.Collectors; import org.apache.sysds.lops.Lop; import org.apache.sysds.lops.OperatorOrderingUtils; -public class PipelineAwareLinearize { - +public class LinearizerPipelineAware extends IDagLinearizer +{ // Minimum number of nodes in DAG for applying algorithm private final static int IGNORE_LIMIT = 0; @@ -45,12 +45,12 @@ public class PipelineAwareLinearize { * @param v List of lops to sort * @return Sorted list of lops with set _pipelineID on the Lop Object */ - public static List<Lop> pipelineDepthFirst(List<Lop> v) { - + @Override + public List<Lop> linearize(List<Lop> v) { // If size of DAG is smaller than IGNORE_LIMIT, give all nodes the same pipeline id if(v.size() <= IGNORE_LIMIT) { v.forEach(l -> l.setPipelineID(1)); - return ILinearize.depthFirst(v); + return new LinearizerDepthFirst().linearize(v); } // Find all root nodes (starting points for the depth-first traversal) @@ -74,11 +74,11 @@ public class PipelineAwareLinearize { //DEVPrintDAG.asGraphviz("Step1", v); // Step 2: Merge pipelines with only one node to another (connected) pipeline - PipelineAwareLinearize.mergeSingleNodePipelines(pipelineMap); + LinearizerPipelineAware.mergeSingleNodePipelines(pipelineMap); //DEVPrintDAG.asGraphviz("Step2", v); // Step 3: Merge small pipelines into bigger ones - PipelineAwareLinearize.mergeSmallPipelines(pipelineMap); + 
LinearizerPipelineAware.mergeSmallPipelines(pipelineMap); //DEVPrintDAG.asGraphviz("Step3", v); // Reset the visited status of all nodes diff --git a/src/test/java/org/apache/sysds/test/functions/caching/PinVariablesTest.java b/src/test/java/org/apache/sysds/test/functions/caching/PinVariablesTest.java index 346102398c..835e0b45a2 100644 --- a/src/test/java/org/apache/sysds/test/functions/caching/PinVariablesTest.java +++ b/src/test/java/org/apache/sysds/test/functions/caching/PinVariablesTest.java @@ -31,122 +31,122 @@ import org.apache.sysds.test.TestConfiguration; import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; import org.apache.sysds.runtime.instructions.cp.Data; -import java.util.LinkedList; import java.util.Queue; +import java.util.LinkedList; import java.util.List; public class PinVariablesTest extends AutomatedTestBase { - private final static String TEST_NAME = "PinVariables"; - private final static String TEST_DIR = "functions/caching/"; - private final static String TEST_CLASS_DIR = TEST_DIR + PinVariablesTest.class.getSimpleName() + "/"; - - @Override - public void setUp() { - addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME)); - } - - @Test - public void testPinNoLists() { - createMockDataAndCall(true, false, false); - } - - @Test - public void testPinShallowLists() { - createMockDataAndCall(true, true, false); - } - - @Test - public void testPinNestedLists() { - createMockDataAndCall(true, true, true); - } - - private void createMockDataAndCall(boolean matrices, boolean list, boolean nestedList) { - LocalVariableMap vars = new LocalVariableMap(); - List<String> varList = new LinkedList<>(); - Queue<Boolean> varStates = new LinkedList<>(); - - if (matrices) { - MatrixObject mat1 = new MatrixObject(Types.ValueType.FP64, "SomeFile1"); - mat1.enableCleanup(true); - MatrixObject mat2 = new MatrixObject(Types.ValueType.FP64, "SomeFile2"); - mat2.enableCleanup(true); - MatrixObject mat3 = new 
MatrixObject(Types.ValueType.FP64, "SomeFile3"); - mat3.enableCleanup(false); - vars.put("mat1", mat1); - vars.put("mat2", mat2); - vars.put("mat3", mat3); - - varList.add("mat2"); - varList.add("mat3"); - - varStates.add(true); - varStates.add(false); - } - if (list) { - MatrixObject mat4 = new MatrixObject(Types.ValueType.FP64, "SomeFile4"); - mat4.enableCleanup(true); - MatrixObject mat5 = new MatrixObject(Types.ValueType.FP64, "SomeFile5"); - mat5.enableCleanup(false); - List<Data> l1_data = new LinkedList<>(); - l1_data.add(mat4); - l1_data.add(mat5); - - if (nestedList) { - MatrixObject mat6 = new MatrixObject(Types.ValueType.FP64, "SomeFile6"); - mat4.enableCleanup(true); - List<Data> l2_data = new LinkedList<>(); - l2_data.add(mat6); - ListObject l2 = new ListObject(l2_data); - l1_data.add(l2); - } - - ListObject l1 = new ListObject(l1_data); - vars.put("l1", l1); - - varList.add("l1"); - - // cleanup flag of inner matrix (m4) - varStates.add(true); - varStates.add(false); - if (nestedList) - varStates.add(true); - } - - ExecutionContext ec = new ExecutionContext(vars); - - commonPinVariablesTest(ec, varList, varStates); - } - - private void commonPinVariablesTest(ExecutionContext ec, List<String> varList, Queue<Boolean> varStatesExp) { - Queue<Boolean> varStates = ec.pinVariables(varList); - - // check returned cleanupEnabled flags - Assert.assertEquals(varStatesExp, varStates); - - // assert updated cleanupEnabled flag to false - for (String varName : varList) { - Data dat = ec.getVariable(varName); - - if (dat instanceof CacheableData<?>) - Assert.assertFalse(((CacheableData<?>)dat).isCleanupEnabled()); - else if (dat instanceof ListObject) { - assertListFlagsDisabled((ListObject)dat); - } - } - - ec.unpinVariables(varList, varStates); - - // check returned flags after unpinVariables() - Queue<Boolean> varStates2 = ec.pinVariables(varList); - Assert.assertEquals(varStatesExp, varStates2); - } - - private void assertListFlagsDisabled(ListObject l) { - for 
(Data dat : l.getData()) { - if (dat instanceof CacheableData<?>) - Assert.assertFalse(((CacheableData<?>)dat).isCleanupEnabled()); - else if (dat instanceof ListObject) - assertListFlagsDisabled((ListObject)dat); - } - } + private final static String TEST_NAME = "PinVariables"; + private final static String TEST_DIR = "functions/caching/"; + private final static String TEST_CLASS_DIR = TEST_DIR + PinVariablesTest.class.getSimpleName() + "/"; + + @Override + public void setUp() { + addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME)); + } + + @Test + public void testPinNoLists() { + createMockDataAndCall(true, false, false); + } + + @Test + public void testPinShallowLists() { + createMockDataAndCall(true, true, false); + } + + @Test + public void testPinNestedLists() { + createMockDataAndCall(true, true, true); + } + + private void createMockDataAndCall(boolean matrices, boolean list, boolean nestedList) { + LocalVariableMap vars = new LocalVariableMap(); + List<String> varList = new LinkedList<>(); + Queue<Boolean> varStates = new LinkedList<>(); + + if (matrices) { + MatrixObject mat1 = new MatrixObject(Types.ValueType.FP64, "SomeFile1"); + mat1.enableCleanup(true); + MatrixObject mat2 = new MatrixObject(Types.ValueType.FP64, "SomeFile2"); + mat2.enableCleanup(true); + MatrixObject mat3 = new MatrixObject(Types.ValueType.FP64, "SomeFile3"); + mat3.enableCleanup(false); + vars.put("mat1", mat1); + vars.put("mat2", mat2); + vars.put("mat3", mat3); + + varList.add("mat2"); + varList.add("mat3"); + + varStates.add(true); + varStates.add(false); + } + if (list) { + MatrixObject mat4 = new MatrixObject(Types.ValueType.FP64, "SomeFile4"); + mat4.enableCleanup(true); + MatrixObject mat5 = new MatrixObject(Types.ValueType.FP64, "SomeFile5"); + mat5.enableCleanup(false); + List<Data> l1_data = new LinkedList<>(); + l1_data.add(mat4); + l1_data.add(mat5); + + if (nestedList) { + MatrixObject mat6 = new MatrixObject(Types.ValueType.FP64, 
"SomeFile6"); + mat4.enableCleanup(true); + List<Data> l2_data = new LinkedList<>(); + l2_data.add(mat6); + ListObject l2 = new ListObject(l2_data); + l1_data.add(l2); + } + + ListObject l1 = new ListObject(l1_data); + vars.put("l1", l1); + + varList.add("l1"); + + // cleanup flag of inner matrix (m4) + varStates.add(true); + varStates.add(false); + if (nestedList) + varStates.add(true); + } + + ExecutionContext ec = new ExecutionContext(vars); + + commonPinVariablesTest(ec, varList, varStates); + } + + private void commonPinVariablesTest(ExecutionContext ec, List<String> varList, Queue<Boolean> varStatesExp) { + Queue<Boolean> varStates = ec.pinVariables(varList); + + // check returned cleanupEnabled flags + Assert.assertEquals(varStatesExp, varStates); + + // assert updated cleanupEnabled flag to false + for (String varName : varList) { + Data dat = ec.getVariable(varName); + + if (dat instanceof CacheableData<?>) + Assert.assertFalse(((CacheableData<?>)dat).isCleanupEnabled()); + else if (dat instanceof ListObject) { + assertListFlagsDisabled((ListObject)dat); + } + } + + ec.unpinVariables(varList, varStates); + + // check returned flags after unpinVariables() + Queue<Boolean> varStates2 = ec.pinVariables(varList); + Assert.assertEquals(varStatesExp, varStates2); + } + + private void assertListFlagsDisabled(ListObject l) { + for (Data dat : l.getData()) { + if (dat instanceof CacheableData<?>) + Assert.assertFalse(((CacheableData<?>)dat).isCleanupEnabled()); + else if (dat instanceof ListObject) + assertListFlagsDisabled((ListObject)dat); + } + } } diff --git a/src/test/java/org/apache/sysds/test/functions/linearization/ILinearizeTest.java b/src/test/java/org/apache/sysds/test/functions/linearization/ILinearizeTest.java index e6680c29c2..bf1fa0fcfe 100644 --- a/src/test/java/org/apache/sysds/test/functions/linearization/ILinearizeTest.java +++ b/src/test/java/org/apache/sysds/test/functions/linearization/ILinearizeTest.java @@ -36,7 +36,8 @@ import 
org.junit.Test; import org.apache.sysds.common.Types.DataType; import org.apache.sysds.common.Types.OpOp2; import org.apache.sysds.common.Types.ValueType; -import org.apache.sysds.lops.compile.linearization.ILinearize; +import org.apache.sysds.lops.compile.linearization.IDagLinearizer; +import org.apache.sysds.lops.compile.linearization.IDagLinearizerFactory; import org.apache.sysds.lops.Lop; import org.apache.sysds.conf.ConfigurationManager; import org.apache.sysds.conf.DMLConfig; @@ -148,8 +149,9 @@ public class ILinearizeTest extends AutomatedTestBase { lops.forEach(l -> {l.getInputs().remove(d1); l.getInputs().remove(d2);}); // RUN LINEARIZATION - ILinearize.linearize(lops); - + IDagLinearizer dl = IDagLinearizerFactory.createDagLinearizer(); + dl.linearize(lops); //TODO results + // Set up expected pipelines Map<Integer, List<Lop>> pipelineMap = new HashMap<>(); pipelineMap.put(4, Arrays.asList(n1, n2, n3, n4, n5, o1, o2));