This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new f18d3e6993 [SYSTEMDS-3705] Refactoring and cleanup operator scheduling algorithms
f18d3e6993 is described below

commit f18d3e699383dcac7be4618c005b23b3c2c393d3
Author: Matthias Boehm <mboe...@gmail.com>
AuthorDate: Sat Jun 1 20:16:41 2024 +0200

    [SYSTEMDS-3705] Refactoring and cleanup operator scheduling algorithms
---
 .../apache/sysds/conf/ConfigurationManager.java    |  15 +-
 src/main/java/org/apache/sysds/conf/DMLConfig.java |   4 +-
 .../java/org/apache/sysds/lops/compile/Dag.java    |   6 +-
 .../lops/compile/linearization/IDagLinearizer.java |  39 ++++
 .../linearization/IDagLinearizerFactory.java       |  58 ++++++
 .../linearization/LinearizerBreadthFirst.java      |  36 ++++
 ...asedLinearize.java => LinearizerCostBased.java} |  12 +-
 .../linearization/LinearizerDepthFirst.java        |  47 +++++
 ...inearize.java => LinearizerMaxParallelism.java} | 163 ++-------------
 .../linearization/LinearizerMinIntermediates.java  |  92 +++++++++
 ...Linearize.java => LinearizerPipelineAware.java} |  14 +-
 .../test/functions/caching/PinVariablesTest.java   | 228 ++++++++++-----------
 .../functions/linearization/ILinearizeTest.java    |   8 +-
 13 files changed, 434 insertions(+), 288 deletions(-)
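
In short, this commit replaces the static ILinearize.linearize() entry point with an IDagLinearizer base class and an IDagLinearizerFactory that instantiates the implementation selected by the configured linearization order (DMLConfig DAG_LINEARIZATION). A minimal sketch of the new call path, mirroring the Dag.java hunk below ('nodes' stands in for the lop list of the DAG being compiled; the wrapper class is hypothetical):

import java.util.List;

import org.apache.sysds.lops.Lop;
import org.apache.sysds.lops.compile.linearization.IDagLinearizer;
import org.apache.sysds.lops.compile.linearization.IDagLinearizerFactory;

public class LinearizeSketch {
	// Sketch only: obtain the configured linearizer and order the DAG nodes,
	// exactly as the updated Dag.java does below.
	public static List<Lop> order(List<Lop> nodes) {
		IDagLinearizer dl = IDagLinearizerFactory.createDagLinearizer();
		return dl.linearize(nodes);
	}
}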

diff --git a/src/main/java/org/apache/sysds/conf/ConfigurationManager.java b/src/main/java/org/apache/sysds/conf/ConfigurationManager.java
index 4c8fc9a60e..0d5ee888d8 100644
--- a/src/main/java/org/apache/sysds/conf/ConfigurationManager.java
+++ b/src/main/java/org/apache/sysds/conf/ConfigurationManager.java
@@ -29,7 +29,7 @@ import org.apache.sysds.api.DMLScript;
 import org.apache.sysds.conf.CompilerConfig.ConfigType;
 import org.apache.sysds.hops.OptimizerUtils;
 import org.apache.sysds.lops.Compression.CompressConfig;
-import org.apache.sysds.lops.compile.linearization.ILinearize;
+import org.apache.sysds.lops.compile.linearization.IDagLinearizerFactory.DagLinearizer;
 import org.apache.sysds.runtime.io.IOUtilFunctions;
 import org.apache.sysds.runtime.util.CommonThreadPool;
 
@@ -272,7 +272,7 @@ public class ConfigurationManager{
        }
 
        public static boolean isMaxPrallelizeEnabled() {
-               return (getLinearizationOrder() == ILinearize.DagLinearization.MAX_PARALLELIZE
+               return (getLinearizationOrder() == DagLinearizer.MAX_PARALLELIZE
                        || OptimizerUtils.MAX_PARALLELIZE_ORDER);
        }
 
@@ -303,15 +303,14 @@ public class ConfigurationManager{
                return OptimizerUtils.AUTO_GPU_CACHE_EVICTION;
        }
 
-       public static ILinearize.DagLinearization getLinearizationOrder() {
+       public static DagLinearizer getLinearizationOrder() {
                if (OptimizerUtils.COST_BASED_ORDERING)
-                       return ILinearize.DagLinearization.AUTO;
+                       return DagLinearizer.AUTO;
                else if (OptimizerUtils.MAX_PARALLELIZE_ORDER)
-                       return ILinearize.DagLinearization.MAX_PARALLELIZE;
+                       return DagLinearizer.MAX_PARALLELIZE;
                else
-                       return ILinearize.DagLinearization
-                       .valueOf(ConfigurationManager.getDMLConfig().getTextValue(DMLConfig.DAG_LINEARIZATION).toUpperCase());
-
+                       return DagLinearizer.valueOf(ConfigurationManager.getDMLConfig()
+                               .getTextValue(DMLConfig.DAG_LINEARIZATION).toUpperCase());
        }
 
        ///////////////////////////////////////
diff --git a/src/main/java/org/apache/sysds/conf/DMLConfig.java b/src/main/java/org/apache/sysds/conf/DMLConfig.java
index c5ca98d3a8..dd4d3b2457 100644
--- a/src/main/java/org/apache/sysds/conf/DMLConfig.java
+++ b/src/main/java/org/apache/sysds/conf/DMLConfig.java
@@ -45,7 +45,7 @@ import org.apache.sysds.hops.codegen.SpoofCompiler.GeneratorAPI;
 import org.apache.sysds.hops.codegen.SpoofCompiler.PlanSelector;
 import org.apache.sysds.hops.fedplanner.FTypes.FederatedPlanner;
 import org.apache.sysds.lops.Compression;
-import org.apache.sysds.lops.compile.linearization.ILinearize.DagLinearization;
+import org.apache.sysds.lops.compile.linearization.IDagLinearizerFactory.DagLinearizer;
 import org.apache.sysds.parser.ParseException;
 import org.apache.sysds.runtime.DMLRuntimeException;
 import org.apache.sysds.runtime.io.IOUtilFunctions;
@@ -173,7 +173,7 @@ public class DMLConfig
                _defaultVals.put(COMPRESSED_COST_MODEL,  "AUTO");
                _defaultVals.put(COMPRESSED_TRANSPOSE,   "auto");
                _defaultVals.put(COMPRESSED_TRANSFORMENCODE, "false");
-               _defaultVals.put(DAG_LINEARIZATION,      DagLinearization.DEPTH_FIRST.name());
+               _defaultVals.put(DAG_LINEARIZATION,      DagLinearizer.DEPTH_FIRST.name());
                _defaultVals.put(CODEGEN,                "false" );
                _defaultVals.put(CODEGEN_API,            GeneratorAPI.JAVA.name() );
                _defaultVals.put(CODEGEN_COMPILER,       CompilerType.AUTO.name() );
diff --git a/src/main/java/org/apache/sysds/lops/compile/Dag.java b/src/main/java/org/apache/sysds/lops/compile/Dag.java
index 05eaa24ebb..b26c539e9a 100644
--- a/src/main/java/org/apache/sysds/lops/compile/Dag.java
+++ b/src/main/java/org/apache/sysds/lops/compile/Dag.java
@@ -43,7 +43,8 @@ import org.apache.sysds.lops.Lop.Type;
 import org.apache.sysds.lops.LopsException;
 import org.apache.sysds.lops.OutputParameters;
 import org.apache.sysds.lops.UnaryCP;
-import org.apache.sysds.lops.compile.linearization.ILinearize;
+import org.apache.sysds.lops.compile.linearization.IDagLinearizer;
+import org.apache.sysds.lops.compile.linearization.IDagLinearizerFactory;
 import org.apache.sysds.parser.DataExpression;
 import org.apache.sysds.parser.StatementBlock;
 import org.apache.sysds.runtime.DMLRuntimeException;
@@ -175,7 +176,8 @@ public class Dag<N extends Lop>
                        scratch = config.getTextValue(DMLConfig.SCRATCH_SPACE) + "/";
                }
 
-               List<Lop> node_v = ILinearize.linearize(nodes);
+               IDagLinearizer dl = IDagLinearizerFactory.createDagLinearizer();
+               List<Lop> node_v = dl.linearize(nodes);
                prefetchFederated(node_v);
 
                // do greedy grouping of operations
diff --git a/src/main/java/org/apache/sysds/lops/compile/linearization/IDagLinearizer.java b/src/main/java/org/apache/sysds/lops/compile/linearization/IDagLinearizer.java
new file mode 100644
index 0000000000..1dc301c7e3
--- /dev/null
+++ b/src/main/java/org/apache/sysds/lops/compile/linearization/IDagLinearizer.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.lops.compile.linearization;
+
+import java.util.List;
+
+import org.apache.sysds.lops.Lop;
+
+/**
+ * An interface for the linearization algorithms that order the DAG nodes into a
+ * sequence of instructions to execute.
+ */
+public abstract class IDagLinearizer {
+       /**
+        * Linearized a DAG of lops into a sequence of lops that preserves all
+        * data dependencies.
+        * 
+        * @param v roots (outputs) of a DAG of lops
+        * @return list of lops (input, inner, and outputs)
+        */
+       public abstract List<Lop> linearize(List<Lop> v);
+}
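
For orientation, a custom ordering only needs to extend the abstract class above. The following class is a hypothetical example (not part of this commit) that keeps the incoming order unchanged, which is only valid if the input list already respects all data dependencies:

package org.apache.sysds.lops.compile.linearization;

import java.util.ArrayList;
import java.util.List;

import org.apache.sysds.lops.Lop;

// Hypothetical example, not part of this commit: a pass-through linearizer.
public class LinearizerIdentityExample extends IDagLinearizer {
	@Override
	public List<Lop> linearize(List<Lop> v) {
		// return a copy to avoid exposing the caller's list
		return new ArrayList<>(v);
	}
}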
diff --git a/src/main/java/org/apache/sysds/lops/compile/linearization/IDagLinearizerFactory.java b/src/main/java/org/apache/sysds/lops/compile/linearization/IDagLinearizerFactory.java
new file mode 100644
index 0000000000..dcbd005b88
--- /dev/null
+++ b/src/main/java/org/apache/sysds/lops/compile/linearization/IDagLinearizerFactory.java
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.lops.compile.linearization;
+
+import org.apache.commons.logging.Log;
+import org.apache.commons.logging.LogFactory;
+
+import org.apache.sysds.conf.ConfigurationManager;
+
+public class IDagLinearizerFactory {
+       public static Log LOG = LogFactory.getLog(IDagLinearizerFactory.class.getName());
+       
+       public enum DagLinearizer {
+               DEPTH_FIRST, BREADTH_FIRST, MIN_INTERMEDIATE, MAX_PARALLELIZE, AUTO, PIPELINE_DEPTH_FIRST;
+       }
+
+       public static IDagLinearizer createDagLinearizer() {
+               DagLinearizer type = ConfigurationManager.getLinearizationOrder();
+               return createDagLinearizer(type);
+       }
+       
+       public static IDagLinearizer createDagLinearizer(DagLinearizer type) {
+               switch(type) {
+                       case AUTO:
+                               return new LinearizerCostBased();
+                       case BREADTH_FIRST:
+                               return new LinearizerBreadthFirst();
+                       case DEPTH_FIRST:
+                               return new LinearizerDepthFirst();
+                       case MAX_PARALLELIZE:
+                               return new LinearizerMaxParallelism();
+                       case MIN_INTERMEDIATE:
+                               return new LinearizerMinIntermediates();
+                       case PIPELINE_DEPTH_FIRST:
+                               return new LinearizerPipelineAware();
+                       default:
+                               LOG.warn("Invalid DAG_LINEARIZATION: "+type+", falling back to DEPTH_FIRST ordering");
+                               return new LinearizerDepthFirst();
+               }
+       }
+}
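
As a usage sketch (assuming the factory and enum above), a caller can also bypass the configuration lookup and request a specific ordering directly; DagLinearizer.BREADTH_FIRST and the wrapper class are chosen purely for illustration:

import java.util.List;

import org.apache.sysds.lops.Lop;
import org.apache.sysds.lops.compile.linearization.IDagLinearizer;
import org.apache.sysds.lops.compile.linearization.IDagLinearizerFactory;
import org.apache.sysds.lops.compile.linearization.IDagLinearizerFactory.DagLinearizer;

public class ExplicitLinearizerSketch {
	// Sketch: select the breadth-first ordering explicitly instead of reading
	// the configured DAG_LINEARIZATION value via createDagLinearizer().
	public static List<Lop> breadthFirst(List<Lop> nodes) {
		IDagLinearizer dl = IDagLinearizerFactory.createDagLinearizer(DagLinearizer.BREADTH_FIRST);
		return dl.linearize(nodes);
	}
}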
diff --git a/src/main/java/org/apache/sysds/lops/compile/linearization/LinearizerBreadthFirst.java b/src/main/java/org/apache/sysds/lops/compile/linearization/LinearizerBreadthFirst.java
new file mode 100644
index 0000000000..23d4150e78
--- /dev/null
+++ b/src/main/java/org/apache/sysds/lops/compile/linearization/LinearizerBreadthFirst.java
@@ -0,0 +1,36 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.lops.compile.linearization;
+
+import org.apache.sysds.lops.Lop;
+
+import java.util.Comparator;
+import java.util.List;
+import java.util.stream.Collectors;
+
+public class LinearizerBreadthFirst extends IDagLinearizer
+{
+       @Override
+       public List<Lop> linearize(List<Lop> v) {
+               return v.stream()
+                       .sorted(Comparator.comparing(Lop::getLevel))
+                       .collect(Collectors.toList());
+       }
+}
diff --git a/src/main/java/org/apache/sysds/lops/compile/linearization/CostBasedLinearize.java b/src/main/java/org/apache/sysds/lops/compile/linearization/LinearizerCostBased.java
similarity index 95%
rename from src/main/java/org/apache/sysds/lops/compile/linearization/CostBasedLinearize.java
rename to src/main/java/org/apache/sysds/lops/compile/linearization/LinearizerCostBased.java
index 707277630b..a036cf0dd3 100644
--- a/src/main/java/org/apache/sysds/lops/compile/linearization/CostBasedLinearize.java
+++ b/src/main/java/org/apache/sysds/lops/compile/linearization/LinearizerCostBased.java
@@ -31,25 +31,25 @@ import java.util.List;
 import java.util.Stack;
 import java.util.stream.Collectors;
 
-public class CostBasedLinearize
+public class LinearizerCostBased extends IDagLinearizer
 {
-       public static List<Lop> getBestOrder(List<Lop> lops)
-       {
+       @Override
+       public List<Lop> linearize(List<Lop> v) {
                // Simplify DAG by removing literals and transient inputs and outputs
                List<Lop> removedLeaves = new ArrayList<>();
                List<Lop> removedRoots = new ArrayList<>();
                HashMap<Long, ArrayList<Lop>> removedInputs = new HashMap<>();
                HashMap<Long, ArrayList<Lop>> removedOutputs = new HashMap<>();
-               simplifyDag(lops, removedLeaves, removedRoots, removedInputs, removedOutputs);
+               simplifyDag(v, removedLeaves, removedRoots, removedInputs, removedOutputs);
                // TODO: Partition the DAG if connected by a single node. Optimize separately
 
                // Collect the leaf nodes of the simplified DAG
-               List<Lop> leafNodes = lops.stream().filter(l -> l.getInputs().isEmpty()).collect(Collectors.toList());
+               List<Lop> leafNodes = v.stream().filter(l -> l.getInputs().isEmpty()).collect(Collectors.toList());
 
                // For each leaf, find all possible orders starting from the given leaf
                List<Order> finalOrders = new ArrayList<>();
                for (Lop leaf : leafNodes)
-                       generateOrders(leaf, leafNodes, finalOrders, lops.size());
+                       generateOrders(leaf, leafNodes, finalOrders, v.size());
 
                // TODO: Handle distributed and GPU operators (0 compute cost, memory overhead on collect)
                // TODO: Asynchronous operators (max of compute costs, total operation memory overhead)
diff --git a/src/main/java/org/apache/sysds/lops/compile/linearization/LinearizerDepthFirst.java b/src/main/java/org/apache/sysds/lops/compile/linearization/LinearizerDepthFirst.java
new file mode 100644
index 0000000000..c0a791eea5
--- /dev/null
+++ b/src/main/java/org/apache/sysds/lops/compile/linearization/LinearizerDepthFirst.java
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.lops.compile.linearization;
+
+import org.apache.sysds.lops.Lop;
+
+import java.util.Comparator;
+import java.util.List;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
+public class LinearizerDepthFirst extends IDagLinearizer
+{
+       // previously called doTopologicalSortTwoLevelOrder
+       @Override
+       public List<Lop> linearize(List<Lop> v) {
+               // partition nodes into leaf/inner nodes and dag root nodes,
+               // + sort leaf/inner nodes by ID to force depth-first scheduling
+               // + append root nodes in order of their original definition
+               // (which also preserves the original order of prints)
+               List<Lop> nodes = Stream
+                       .concat(v.stream().filter(l -> !l.getOutputs().isEmpty()).sorted(Comparator.comparing(l -> l.getID())),
+                               v.stream().filter(l -> l.getOutputs().isEmpty()))
+                       .collect(Collectors.toList());
+
+               // NOTE: in contrast to hadoop execution modes, we avoid computing the transitive
+               // closure here to ensure linear time complexity because its unnecessary for CP and Spark
+               return nodes;
+       }
+}
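
The two-level partition above (leaf/inner nodes sorted by ID, roots appended in their original order) can be illustrated on plain integers; this toy snippet is only an analogy and uses no SystemDS classes:

import java.util.Comparator;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class TwoLevelOrderToy {
	public static void main(String[] args) {
		List<Integer> inner = List.of(5, 1, 3); // stands in for leaf/inner lops, keyed by ID
		List<Integer> roots = List.of(9, 7);    // stands in for DAG roots in definition order
		List<Integer> ordered = Stream.concat(
			inner.stream().sorted(Comparator.naturalOrder()),
			roots.stream())
			.collect(Collectors.toList());
		System.out.println(ordered); // prints [1, 3, 5, 9, 7]
	}
}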
diff --git a/src/main/java/org/apache/sysds/lops/compile/linearization/ILinearize.java b/src/main/java/org/apache/sysds/lops/compile/linearization/LinearizerMaxParallelism.java
similarity index 57%
rename from src/main/java/org/apache/sysds/lops/compile/linearization/ILinearize.java
rename to src/main/java/org/apache/sysds/lops/compile/linearization/LinearizerMaxParallelism.java
index 288017bf1b..92844ec5df 100644
--- a/src/main/java/org/apache/sysds/lops/compile/linearization/ILinearize.java
+++ b/src/main/java/org/apache/sysds/lops/compile/linearization/LinearizerMaxParallelism.java
@@ -19,26 +19,9 @@
 
 package org.apache.sysds.lops.compile.linearization;
 
-import java.util.AbstractMap;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.Collections;
-import java.util.Comparator;
-import java.util.HashMap;
-import java.util.HashSet;
-import java.util.LinkedList;
-import java.util.List;
-import java.util.Map;
-import java.util.stream.Collectors;
-import java.util.stream.Stream;
-
-import org.apache.commons.logging.Log;
-import org.apache.commons.logging.LogFactory;
-import org.apache.sysds.common.Types.ExecType;
 import org.apache.sysds.common.Types.DataType;
-import org.apache.sysds.conf.ConfigurationManager;
+import org.apache.sysds.common.Types.ExecType;
 import org.apache.sysds.hops.AggBinaryOp.SparkAggType;
-import org.apache.sysds.hops.OptimizerUtils;
 import org.apache.sysds.lops.CSVReBlock;
 import org.apache.sysds.lops.CentralMoment;
 import org.apache.sysds.lops.Checkpoint;
@@ -57,141 +40,28 @@ import org.apache.sysds.lops.SpoofFused;
 import org.apache.sysds.lops.UAggOuterChain;
 import org.apache.sysds.lops.UnaryCP;
 
-/**
- * A interface for the linearization algorithms that order the DAG nodes into a sequence of instructions to execute.
- *
- * https://en.wikipedia.org/wiki/Linearizability#Linearization_points
- */
-public class ILinearize {
-       public static Log LOG = LogFactory.getLog(ILinearize.class.getName());
-
-       public enum DagLinearization {
-               DEPTH_FIRST, BREADTH_FIRST, MIN_INTERMEDIATE, MAX_PARALLELIZE, AUTO, PIPELINE_DEPTH_FIRST;
-       }
-
-       public static List<Lop> linearize(List<Lop> v) {
-               try {
-                       DagLinearization linearization = ConfigurationManager.getLinearizationOrder();
-
-                       switch(linearization) {
-                               case MAX_PARALLELIZE:
-                                       return doMaxParallelizeSort(v);
-                               case AUTO:
-                                       return CostBasedLinearize.getBestOrder(v);
-                               case MIN_INTERMEDIATE:
-                                       return doMinIntermediateSort(v);
-                               case BREADTH_FIRST:
-                                       return doBreadthFirstSort(v);
-                               case PIPELINE_DEPTH_FIRST:
-                                       return PipelineAwareLinearize.pipelineDepthFirst(v);
-                               case DEPTH_FIRST:
-                               default:
-                                       return depthFirst(v);
-                       }
-               }
-               catch(Exception e) {
-                       LOG.warn("Invalid DAG_LINEARIZATION "+ConfigurationManager.getLinearizationOrder()+", fallback to DEPTH_FIRST ordering");
-                       return depthFirst(v);
-               }
-       }
-
-       /**
-        * Sort lops depth-first
-        * 
-        * previously called doTopologicalSortTwoLevelOrder
-        * 
-        * @param v List of lops to sort
-        * @return Sorted list of lops
-        */
-       protected static List<Lop> depthFirst(List<Lop> v) {
-               // partition nodes into leaf/inner nodes and dag root nodes,
-               // + sort leaf/inner nodes by ID to force depth-first scheduling
-               // + append root nodes in order of their original definition
-               // (which also preserves the original order of prints)
-               List<Lop> nodes = Stream
-                       .concat(v.stream().filter(l -> !l.getOutputs().isEmpty()).sorted(Comparator.comparing(l -> l.getID())),
-                               v.stream().filter(l -> l.getOutputs().isEmpty()))
-                       .collect(Collectors.toList());
-
-               // NOTE: in contrast to hadoop execution modes, we avoid computing the transitive
-               // closure here to ensure linear time complexity because its unnecessary for CP and Spark
-               return nodes;
-       }
-
-       private static List<Lop> doBreadthFirstSort(List<Lop> v) {
-               List<Lop> nodes = v.stream().sorted(Comparator.comparing(Lop::getLevel)).collect(Collectors.toList());
-
-               return nodes;
-       }
-
-       /**
-        * Sort lops to execute them in an order that minimizes the memory requirements of intermediates
-        * 
-        * @param v List of lops to sort
-        * @return Sorted list of lops
-        */
-       private static List<Lop> doMinIntermediateSort(List<Lop> v) {
-               List<Lop> nodes = new ArrayList<>(v.size());
-               // Get the lowest level in the tree to move upwards from
-               List<Lop> lowestLevel = v.stream().filter(l -> l.getOutputs().isEmpty()).collect(Collectors.toList());
-
-               // Traverse the tree bottom up, choose nodes with higher memory requirements, then reverse the list
-               List<Lop> remaining = new LinkedList<>(v);
-               sortRecursive(nodes, lowestLevel, remaining);
-
-               // In some cases (function calls) some output lops are not in the list of nodes to be sorted.
-               // With the next layer up having output lops, they are not added to the initial list of lops and are
-               // subsequently never reached by the recursive sort.
-               // We work around this issue by checking for remaining lops after the initial sort.
-               while(!remaining.isEmpty()) {
-                       // Start with the lowest level lops, this time by level instead of no outputs
-                       int maxLevel = remaining.stream().mapToInt(Lop::getLevel).max().orElse(-1);
-                       List<Lop> lowestNodes = remaining.stream().filter(l -> l.getLevel() == maxLevel).collect(Collectors.toList());
-                       sortRecursive(nodes, lowestNodes, remaining);
-               }
-
-               // All lops were added bottom up, from highest to lowest memory consumption, now reverse this
-               Collections.reverse(nodes);
-
-               return nodes;
-       }
-
-       private static void sortRecursive(List<Lop> result, List<Lop> input, List<Lop> remaining) {
-               // Sort input lops by memory estimate
-               // Lowest level nodes (those with no outputs) receive a memory estimate of 0 to preserve order
-               // This affects prints, writes, ...
-               List<Map.Entry<Lop, Long>> memEst = input.stream().distinct().map(l -> new AbstractMap.SimpleEntry<>(l,
-                       l.getOutputs().isEmpty() ? 0 : OptimizerUtils.estimateSizeExactSparsity(l.getOutputParameters().getNumRows(),
-                               l.getOutputParameters().getNumCols(), l.getOutputParameters().getNnz())))
-                       .sorted(Comparator.comparing(e -> ((Map.Entry<Lop, Long>) e).getValue())).collect(Collectors.toList());
-
-               // Start with the highest memory estimate because the entire list is reversed later
-               Collections.reverse(memEst);
-               for(Map.Entry<Lop, Long> e : memEst) {
-                       // Skip if the node is already in the result list
-                       // Skip if one of the lop's outputs is not in the result list yet (will be added once the output lop is
-                       // traversed), but only if any of the output lops is bound to be added to the result at a later stage
-                       if(result.contains(e.getKey()) || (!result.containsAll(e.getKey().getOutputs()) &&
-                               remaining.stream().anyMatch(l -> e.getKey().getOutputs().contains(l))))
-                               continue;
-                       result.add(e.getKey());
-                       remaining.remove(e.getKey());
-                       // Add input lops recursively
-                       sortRecursive(result, e.getKey().getInputs(), remaining);
-               }
-       }
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
 
+public class LinearizerMaxParallelism extends IDagLinearizer
+{
        // Place the Spark operation chains first (more expensive to less expensive),
        // followed by asynchronously triggering operators and CP chains.
-       private static List<Lop> doMaxParallelizeSort(List<Lop> v)
-       {
+       @Override
+       public List<Lop> linearize(List<Lop> v) {
                List<Lop> v2 = v;
-               boolean hasSpark = v.stream().anyMatch(ILinearize::isDistributedOp);
-               boolean hasGPU = v.stream().anyMatch(ILinearize::isGPUOp);
+               boolean hasSpark = v.stream().anyMatch(LinearizerMaxParallelism::isDistributedOp);
+               boolean hasGPU = v.stream().anyMatch(LinearizerMaxParallelism::isGPUOp);
 
                // Fallback to default depth-first if all operators are CP
                if (!hasSpark && !hasGPU)
-                       return depthFirst(v);
+                       return new LinearizerDepthFirst().linearize(v);
 
                if (hasSpark) {
                        // Step 1: Collect the Spark roots and #Spark instructions in each subDAG
@@ -236,6 +106,7 @@ public class ILinearize {
                }
                return v2;
        }
+       
 
        // Place the operators in a depth-first manner, but order the DAGs based on number of Spark operators
diff --git a/src/main/java/org/apache/sysds/lops/compile/linearization/LinearizerMinIntermediates.java b/src/main/java/org/apache/sysds/lops/compile/linearization/LinearizerMinIntermediates.java
new file mode 100644
index 0000000000..4daf435d24
--- /dev/null
+++ b/src/main/java/org/apache/sysds/lops/compile/linearization/LinearizerMinIntermediates.java
@@ -0,0 +1,92 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.lops.compile.linearization;
+
+import org.apache.sysds.hops.OptimizerUtils;
+import org.apache.sysds.lops.Lop;
+
+import java.util.AbstractMap;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Map.Entry;
+import java.util.stream.Collectors;
+
+/**
+ * Sort lops to execute them in an order that
+ * minimizes the memory requirements of intermediates
+ */
+public class LinearizerMinIntermediates extends IDagLinearizer
+{
+       @Override
+       public List<Lop> linearize(List<Lop> v) {
+               List<Lop> nodes = new ArrayList<>(v.size());
+               // Get the lowest level in the tree to move upwards from
+               List<Lop> lowestLevel = v.stream().filter(l -> l.getOutputs().isEmpty()).collect(Collectors.toList());
+
+               // Traverse the tree bottom up, choose nodes with higher memory requirements, then reverse the list
+               List<Lop> remaining = new LinkedList<>(v);
+               sortRecursive(nodes, lowestLevel, remaining);
+
+               // In some cases (function calls) some output lops are not in the list of nodes to be sorted.
+               // With the next layer up having output lops, they are not added to the initial list of lops and are
+               // subsequently never reached by the recursive sort.
+               // We work around this issue by checking for remaining lops after the initial sort.
+               while(!remaining.isEmpty()) {
+                       // Start with the lowest level lops, this time by level instead of no outputs
+                       int maxLevel = remaining.stream().mapToInt(Lop::getLevel).max().orElse(-1);
+                       List<Lop> lowestNodes = remaining.stream().filter(l -> l.getLevel() == maxLevel).collect(Collectors.toList());
+                       sortRecursive(nodes, lowestNodes, remaining);
+               }
+
+               // All lops were added bottom up, from highest to lowest memory consumption, now reverse this
+               Collections.reverse(nodes);
+
+               return nodes;
+       }
+       
+       private static void sortRecursive(List<Lop> result, List<Lop> input, List<Lop> remaining) {
+               // Sort input lops by memory estimate
+               // Lowest level nodes (those with no outputs) receive a memory estimate of 0 to preserve order
+               // This affects prints, writes, ...
+               List<Entry<Lop, Long>> memEst = input.stream().distinct().map(l -> new AbstractMap.SimpleEntry<>(l,
+                       l.getOutputs().isEmpty() ? 0 : OptimizerUtils.estimateSizeExactSparsity(l.getOutputParameters().getNumRows(),
+                               l.getOutputParameters().getNumCols(), l.getOutputParameters().getNnz())))
+                       .sorted(Comparator.comparing(e -> ((Map.Entry<Lop, Long>) e).getValue())).collect(Collectors.toList());
+
+               // Start with the highest memory estimate because the entire list is reversed later
+               Collections.reverse(memEst);
+               for(Map.Entry<Lop, Long> e : memEst) {
+                       // Skip if the node is already in the result list
+                       // Skip if one of the lop's outputs is not in the result list yet (will be added once the output lop is
+                       // traversed), but only if any of the output lops is bound to be added to the result at a later stage
+                       if(result.contains(e.getKey()) || (!result.containsAll(e.getKey().getOutputs()) &&
+                               remaining.stream().anyMatch(l -> e.getKey().getOutputs().contains(l))))
+                               continue;
+                       result.add(e.getKey());
+                       remaining.remove(e.getKey());
+                       // Add input lops recursively
+                       sortRecursive(result, e.getKey().getInputs(), remaining);
+               }
+       }
+}
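
To make the bottom-up strategy above concrete, here is a self-contained toy version (plain classes, not SystemDS code) of the same idea: start from the sinks, recurse into inputs with larger memory estimates first, defer a node while any of its pending consumers is still unplaced, and finally reverse the list into execution order. Node names and estimates are invented for illustration:

import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.LinkedList;
import java.util.List;
import java.util.stream.Collectors;

public class MinIntermediatesToy {
	static final class Node {
		final String name; final long mem;
		final List<Node> inputs = new ArrayList<>();
		final List<Node> outputs = new ArrayList<>();
		Node(String name, long mem) { this.name = name; this.mem = mem; }
		Node from(Node... ins) {
			for (Node in : ins) { inputs.add(in); in.outputs.add(this); }
			return this;
		}
	}

	public static void main(String[] args) {
		Node a = new Node("a", 0), b = new Node("b", 0);
		Node small = new Node("small", 8).from(a);
		Node big = new Node("big", 1024).from(a, b);
		Node root = new Node("root", 0).from(small, big);
		List<Node> all = List.of(a, b, small, big, root);

		List<Node> result = new ArrayList<>();
		List<Node> remaining = new LinkedList<>(all);
		List<Node> sinks = all.stream()
			.filter(n -> n.outputs.isEmpty()).collect(Collectors.toList());
		sortRecursive(result, sinks, remaining);
		Collections.reverse(result);
		// prints a, small, b, big, root: a valid order that materializes the
		// small intermediate before the large one
		result.forEach(n -> System.out.println(n.name));
	}

	static void sortRecursive(List<Node> result, List<Node> input, List<Node> remaining) {
		// visit larger memory estimates first; the final reversal flips this
		List<Node> byMemDesc = input.stream().distinct()
			.sorted(Comparator.comparingLong((Node n) -> n.mem).reversed())
			.collect(Collectors.toList());
		for (Node n : byMemDesc) {
			// skip nodes already placed, or with pending consumers not yet placed
			if (result.contains(n) || (!result.containsAll(n.outputs)
				&& remaining.stream().anyMatch(n.outputs::contains)))
				continue;
			result.add(n);
			remaining.remove(n);
			sortRecursive(result, n.inputs, remaining);
		}
	}
}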
diff --git a/src/main/java/org/apache/sysds/lops/compile/linearization/PipelineAwareLinearize.java b/src/main/java/org/apache/sysds/lops/compile/linearization/LinearizerPipelineAware.java
similarity index 96%
rename from src/main/java/org/apache/sysds/lops/compile/linearization/PipelineAwareLinearize.java
rename to src/main/java/org/apache/sysds/lops/compile/linearization/LinearizerPipelineAware.java
index dcc476a688..cebf14c6dc 100644
--- a/src/main/java/org/apache/sysds/lops/compile/linearization/PipelineAwareLinearize.java
+++ b/src/main/java/org/apache/sysds/lops/compile/linearization/LinearizerPipelineAware.java
@@ -28,8 +28,8 @@ import java.util.stream.Collectors;
 import org.apache.sysds.lops.Lop;
 import org.apache.sysds.lops.OperatorOrderingUtils;
 
-public class PipelineAwareLinearize {
-
+public class LinearizerPipelineAware extends IDagLinearizer
+{
        // Minimum number of nodes in DAG for applying algorithm
        private final static int IGNORE_LIMIT = 0;
 
@@ -45,12 +45,12 @@ public class PipelineAwareLinearize {
        * @param v List of lops to sort
        * @return Sorted list of lops with set _pipelineID on the Lop Object
        */
-       public static List<Lop> pipelineDepthFirst(List<Lop> v) {
-
+       @Override
+       public List<Lop> linearize(List<Lop> v) {
                // If size of DAG is smaller than IGNORE_LIMIT, give all nodes the same pipeline id
                if(v.size() <= IGNORE_LIMIT) {
                        v.forEach(l -> l.setPipelineID(1));
-                       return ILinearize.depthFirst(v);
+                       return new LinearizerDepthFirst().linearize(v);
                }
 
                // Find all root nodes (starting points for the depth-first traversal)
@@ -74,11 +74,11 @@ public class PipelineAwareLinearize {
                //DEVPrintDAG.asGraphviz("Step1", v);
 
                // Step 2: Merge pipelines with only one node to another (connected) pipeline
-               PipelineAwareLinearize.mergeSingleNodePipelines(pipelineMap);
+               LinearizerPipelineAware.mergeSingleNodePipelines(pipelineMap);
                //DEVPrintDAG.asGraphviz("Step2", v);
 
                // Step 3: Merge small pipelines into bigger ones
-               PipelineAwareLinearize.mergeSmallPipelines(pipelineMap);
+               LinearizerPipelineAware.mergeSmallPipelines(pipelineMap);
                //DEVPrintDAG.asGraphviz("Step3", v);
 
                // Reset the visited status of all nodes
diff --git a/src/test/java/org/apache/sysds/test/functions/caching/PinVariablesTest.java b/src/test/java/org/apache/sysds/test/functions/caching/PinVariablesTest.java
index 346102398c..835e0b45a2 100644
--- a/src/test/java/org/apache/sysds/test/functions/caching/PinVariablesTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/caching/PinVariablesTest.java
@@ -31,122 +31,122 @@ import org.apache.sysds.test.TestConfiguration;
 import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
 import org.apache.sysds.runtime.instructions.cp.Data;
 
-import java.util.LinkedList;
 import java.util.Queue;
+import java.util.LinkedList;
 import java.util.List;
 
 public class PinVariablesTest extends AutomatedTestBase {
-    private final static String TEST_NAME = "PinVariables";
-    private final static String TEST_DIR = "functions/caching/";
-    private final static String TEST_CLASS_DIR = TEST_DIR + PinVariablesTest.class.getSimpleName() + "/";
-
-    @Override
-    public void setUp() {
-        addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME));
-    }
-
-    @Test
-    public void testPinNoLists() {
-        createMockDataAndCall(true, false, false);
-    }
-
-    @Test
-    public void testPinShallowLists() {
-        createMockDataAndCall(true, true, false);
-    }
-
-    @Test
-    public void testPinNestedLists() {
-        createMockDataAndCall(true, true, true);
-    }
-
-    private void createMockDataAndCall(boolean matrices, boolean list, boolean nestedList) {
-        LocalVariableMap vars = new LocalVariableMap();
-        List<String> varList = new LinkedList<>();
-        Queue<Boolean> varStates = new LinkedList<>();
-
-        if (matrices) {
-            MatrixObject mat1 = new MatrixObject(Types.ValueType.FP64, "SomeFile1");
-            mat1.enableCleanup(true);
-            MatrixObject mat2 = new MatrixObject(Types.ValueType.FP64, "SomeFile2");
-            mat2.enableCleanup(true);
-            MatrixObject mat3 = new MatrixObject(Types.ValueType.FP64, "SomeFile3");
-            mat3.enableCleanup(false);
-            vars.put("mat1", mat1);
-            vars.put("mat2", mat2);
-            vars.put("mat3", mat3);
-
-            varList.add("mat2");
-            varList.add("mat3");
-
-            varStates.add(true);
-            varStates.add(false);
-        }
-        if (list) {
-            MatrixObject mat4 = new MatrixObject(Types.ValueType.FP64, "SomeFile4");
-            mat4.enableCleanup(true);
-            MatrixObject mat5 = new MatrixObject(Types.ValueType.FP64, "SomeFile5");
-            mat5.enableCleanup(false);
-            List<Data> l1_data = new LinkedList<>();
-            l1_data.add(mat4);
-            l1_data.add(mat5);
-
-            if (nestedList) {
-                MatrixObject mat6 = new MatrixObject(Types.ValueType.FP64, "SomeFile6");
-                mat4.enableCleanup(true);
-                List<Data> l2_data = new LinkedList<>();
-                l2_data.add(mat6);
-                ListObject l2 = new ListObject(l2_data);
-                l1_data.add(l2);
-            }
-
-            ListObject l1 = new ListObject(l1_data);
-            vars.put("l1", l1);
-
-            varList.add("l1");
-
-            // cleanup flag of inner matrix (m4)
-            varStates.add(true);
-            varStates.add(false);
-            if (nestedList)
-                varStates.add(true);
-        }
-
-        ExecutionContext ec = new ExecutionContext(vars);
-
-        commonPinVariablesTest(ec, varList, varStates);
-    }
-
-    private void commonPinVariablesTest(ExecutionContext ec, List<String> varList, Queue<Boolean> varStatesExp) {
-        Queue<Boolean> varStates = ec.pinVariables(varList);
-
-        // check returned cleanupEnabled flags
-        Assert.assertEquals(varStatesExp, varStates);
-
-        // assert updated cleanupEnabled flag to false
-        for (String varName : varList) {
-            Data dat = ec.getVariable(varName);
-
-            if (dat instanceof CacheableData<?>)
-                Assert.assertFalse(((CacheableData<?>)dat).isCleanupEnabled());
-            else if (dat instanceof ListObject) {
-                assertListFlagsDisabled((ListObject)dat);
-            }
-        }
-
-        ec.unpinVariables(varList, varStates);
-
-        // check returned flags after unpinVariables()
-        Queue<Boolean> varStates2 = ec.pinVariables(varList);
-        Assert.assertEquals(varStatesExp, varStates2);
-    }
-
-    private void assertListFlagsDisabled(ListObject l) {
-        for (Data dat : l.getData()) {
-            if (dat instanceof CacheableData<?>)
-                Assert.assertFalse(((CacheableData<?>)dat).isCleanupEnabled());
-            else if (dat instanceof ListObject)
-                assertListFlagsDisabled((ListObject)dat);
-        }
-    }
+       private final static String TEST_NAME = "PinVariables";
+       private final static String TEST_DIR = "functions/caching/";
+       private final static String TEST_CLASS_DIR = TEST_DIR + PinVariablesTest.class.getSimpleName() + "/";
+
+       @Override
+       public void setUp() {
+               addTestConfiguration(TEST_NAME, new TestConfiguration(TEST_CLASS_DIR, TEST_NAME));
+       }
+
+       @Test
+       public void testPinNoLists() {
+               createMockDataAndCall(true, false, false);
+       }
+
+       @Test
+       public void testPinShallowLists() {
+               createMockDataAndCall(true, true, false);
+       }
+
+       @Test
+       public void testPinNestedLists() {
+               createMockDataAndCall(true, true, true);
+       }
+
+       private void createMockDataAndCall(boolean matrices, boolean list, boolean nestedList) {
+               LocalVariableMap vars = new LocalVariableMap();
+               List<String> varList = new LinkedList<>();
+               Queue<Boolean> varStates = new LinkedList<>();
+
+               if (matrices) {
+                       MatrixObject mat1 = new MatrixObject(Types.ValueType.FP64, "SomeFile1");
+                       mat1.enableCleanup(true);
+                       MatrixObject mat2 = new MatrixObject(Types.ValueType.FP64, "SomeFile2");
+                       mat2.enableCleanup(true);
+                       MatrixObject mat3 = new MatrixObject(Types.ValueType.FP64, "SomeFile3");
+                       mat3.enableCleanup(false);
+                       vars.put("mat1", mat1);
+                       vars.put("mat2", mat2);
+                       vars.put("mat3", mat3);
+
+                       varList.add("mat2");
+                       varList.add("mat3");
+
+                       varStates.add(true);
+                       varStates.add(false);
+               }
+               if (list) {
+                       MatrixObject mat4 = new MatrixObject(Types.ValueType.FP64, "SomeFile4");
+                       mat4.enableCleanup(true);
+                       MatrixObject mat5 = new MatrixObject(Types.ValueType.FP64, "SomeFile5");
+                       mat5.enableCleanup(false);
+                       List<Data> l1_data = new LinkedList<>();
+                       l1_data.add(mat4);
+                       l1_data.add(mat5);
+
+                       if (nestedList) {
+                               MatrixObject mat6 = new MatrixObject(Types.ValueType.FP64, "SomeFile6");
+                               mat4.enableCleanup(true);
+                               List<Data> l2_data = new LinkedList<>();
+                               l2_data.add(mat6);
+                               ListObject l2 = new ListObject(l2_data);
+                               l1_data.add(l2);
+                       }
+
+                       ListObject l1 = new ListObject(l1_data);
+                       vars.put("l1", l1);
+
+                       varList.add("l1");
+
+                       // cleanup flag of inner matrix (m4)
+                       varStates.add(true);
+                       varStates.add(false);
+                       if (nestedList)
+                               varStates.add(true);
+               }
+
+               ExecutionContext ec = new ExecutionContext(vars);
+
+               commonPinVariablesTest(ec, varList, varStates);
+       }
+
+       private void commonPinVariablesTest(ExecutionContext ec, List<String> varList, Queue<Boolean> varStatesExp) {
+               Queue<Boolean> varStates = ec.pinVariables(varList);
+
+               // check returned cleanupEnabled flags
+               Assert.assertEquals(varStatesExp, varStates);
+
+               // assert updated cleanupEnabled flag to false
+               for (String varName : varList) {
+                       Data dat = ec.getVariable(varName);
+
+                       if (dat instanceof CacheableData<?>)
+                               Assert.assertFalse(((CacheableData<?>)dat).isCleanupEnabled());
+                       else if (dat instanceof ListObject) {
+                               assertListFlagsDisabled((ListObject)dat);
+                       }
+               }
+
+               ec.unpinVariables(varList, varStates);
+
+               // check returned flags after unpinVariables()
+               Queue<Boolean> varStates2 = ec.pinVariables(varList);
+               Assert.assertEquals(varStatesExp, varStates2);
+       }
+
+       private void assertListFlagsDisabled(ListObject l) {
+               for (Data dat : l.getData()) {
+                       if (dat instanceof CacheableData<?>)
+                               Assert.assertFalse(((CacheableData<?>)dat).isCleanupEnabled());
+                       else if (dat instanceof ListObject)
+                               assertListFlagsDisabled((ListObject)dat);
+               }
+       }
 }
diff --git a/src/test/java/org/apache/sysds/test/functions/linearization/ILinearizeTest.java b/src/test/java/org/apache/sysds/test/functions/linearization/ILinearizeTest.java
index e6680c29c2..bf1fa0fcfe 100644
--- a/src/test/java/org/apache/sysds/test/functions/linearization/ILinearizeTest.java
+++ b/src/test/java/org/apache/sysds/test/functions/linearization/ILinearizeTest.java
@@ -36,7 +36,8 @@ import org.junit.Test;
 import org.apache.sysds.common.Types.DataType;
 import org.apache.sysds.common.Types.OpOp2;
 import org.apache.sysds.common.Types.ValueType;
-import org.apache.sysds.lops.compile.linearization.ILinearize;
+import org.apache.sysds.lops.compile.linearization.IDagLinearizer;
+import org.apache.sysds.lops.compile.linearization.IDagLinearizerFactory;
 import org.apache.sysds.lops.Lop;
 import org.apache.sysds.conf.ConfigurationManager;
 import org.apache.sysds.conf.DMLConfig;
@@ -148,8 +149,9 @@ public class ILinearizeTest extends AutomatedTestBase {
                lops.forEach(l -> {l.getInputs().remove(d1); l.getInputs().remove(d2);});
 
                // RUN LINEARIZATION
-               ILinearize.linearize(lops);
-
+               IDagLinearizer dl = IDagLinearizerFactory.createDagLinearizer();
+               dl.linearize(lops); //TODO results
+               
                // Set up expected pipelines
                Map<Integer, List<Lop>> pipelineMap = new HashMap<>();
                pipelineMap.put(4, Arrays.asList(n1, n2, n3, n4, n5, o1, o2));

