This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 6fd96bbda2 [MINOR] Cleanup code quality inter-procedural analysis /
recompiler
6fd96bbda2 is described below
commit 6fd96bbda21e7a556fd5f17b232454ac1c41068e
Author: Matthias Boehm <[email protected]>
AuthorDate: Fri Mar 22 19:53:37 2024 +0100
[MINOR] Cleanup code quality inter-procedural analysis / recompiler
---
src/main/java/org/apache/sysds/hops/Hop.java | 16 +++++----
.../java/org/apache/sysds/hops/OptimizerUtils.java | 18 +++++-----
src/main/java/org/apache/sysds/hops/ReorgOp.java | 4 +--
.../apache/sysds/hops/ipa/FunctionCallGraph.java | 33 +++++++++---------
.../sysds/hops/ipa/FunctionCallSizeInfo.java | 4 +--
.../sysds/hops/ipa/IPAPassFlagNonDeterminism.java | 13 ++++----
.../hops/ipa/IPAPassForwardFunctionCalls.java | 13 +++++---
.../sysds/hops/ipa/IPAPassInlineFunctions.java | 18 +++++-----
.../hops/ipa/IPAPassPropagateReplaceLiterals.java | 5 ++-
.../hops/ipa/IPAPassRemoveConstantBinaryOps.java | 11 +++---
.../ipa/IPAPassRemoveUnnecessaryCheckpoints.java | 17 +++++-----
.../sysds/hops/ipa/InterProceduralAnalysis.java | 18 +++++-----
.../sysds/hops/recompile/LiteralReplacement.java | 5 +--
.../sysds/hops/recompile/RecompileStatus.java | 5 +--
.../apache/sysds/hops/recompile/Recompiler.java | 39 ++++++++++++----------
.../apache/sysds/hops/rewrite/HopRewriteUtils.java | 15 ++++-----
16 files changed, 122 insertions(+), 112 deletions(-)
diff --git a/src/main/java/org/apache/sysds/hops/Hop.java
b/src/main/java/org/apache/sysds/hops/Hop.java
index 265ba672e9..7b2ff60253 100644
--- a/src/main/java/org/apache/sysds/hops/Hop.java
+++ b/src/main/java/org/apache/sysds/hops/Hop.java
@@ -23,6 +23,8 @@ import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
@@ -954,7 +956,7 @@ public abstract class Hop implements ParseInfo {
h._parent.add(this);
}
- public void addAllInputs( ArrayList<Hop> list ) {
+ public void addAllInputs( List<Hop> list ) {
for( Hop h : list )
addInput(h);
}
@@ -1130,13 +1132,13 @@ public abstract class Hop implements ParseInfo {
return _dataType.isScalar() || _dc.colsKnown();
}
- public static void resetVisitStatus( ArrayList<Hop> hops ) {
+ public static void resetVisitStatus( List<Hop> hops ) {
if( hops != null )
for( Hop hopRoot : hops )
hopRoot.resetVisitStatus();
}
- public static void resetVisitStatus( ArrayList<Hop> hops, boolean force
) {
+ public static void resetVisitStatus( List<Hop> hops, boolean force ) {
if( !force )
resetVisitStatus(hops);
else {
@@ -1413,7 +1415,7 @@ public abstract class Hop implements ParseInfo {
setDim1(computeSizeInformation(input, vars));
}
- public void refreshRowsParameterInformation( Hop input,
LocalVariableMap vars, HashMap<Long,Long> memo ) {
+ public void refreshRowsParameterInformation( Hop input,
LocalVariableMap vars, Map<Long,Long> memo ) {
setDim1(computeSizeInformation(input, vars, memo));
}
@@ -1421,7 +1423,7 @@ public abstract class Hop implements ParseInfo {
setDim2(computeSizeInformation(input, vars));
}
- public void refreshColsParameterInformation( Hop input,
LocalVariableMap vars, HashMap<Long,Long> memo ) {
+ public void refreshColsParameterInformation( Hop input,
LocalVariableMap vars, Map<Long,Long> memo ) {
setDim2(computeSizeInformation(input, vars, memo));
}
@@ -1429,7 +1431,7 @@ public abstract class Hop implements ParseInfo {
return computeSizeInformation(input, vars, new
HashMap<Long,Long>());
}
- public long computeSizeInformation( Hop input, LocalVariableMap vars,
HashMap<Long,Long> memo )
+ public long computeSizeInformation( Hop input, LocalVariableMap vars,
Map<Long,Long> memo )
{
long ret = -1;
try {
@@ -1460,7 +1462,7 @@ public abstract class Hop implements ParseInfo {
return computeBoundsInformation(input, vars, new HashMap<Long,
Double>());
}
- public static double computeBoundsInformation( Hop input,
LocalVariableMap vars, HashMap<Long, Double> memo ) {
+ public static double computeBoundsInformation( Hop input,
LocalVariableMap vars, Map<Long, Double> memo ) {
double ret = Double.MAX_VALUE;
try {
ret = OptimizerUtils.rEvalSimpleDoubleExpression(input,
memo, vars);
diff --git a/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
b/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
index 8953cba378..0a37570ee8 100644
--- a/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
+++ b/src/main/java/org/apache/sysds/hops/OptimizerUtils.java
@@ -1457,7 +1457,7 @@ public class OptimizerUtils
* @param valMemo ?
* @return size expression
*/
- public static long rEvalSimpleLongExpression( Hop root, HashMap<Long,
Long> valMemo )
+ public static long rEvalSimpleLongExpression( Hop root, Map<Long, Long>
valMemo )
{
long ret = Long.MAX_VALUE;
@@ -1470,7 +1470,7 @@ public class OptimizerUtils
return ret;
}
- public static long rEvalSimpleLongExpression( Hop root, HashMap<Long,
Long> valMemo, LocalVariableMap vars )
+ public static long rEvalSimpleLongExpression( Hop root, Map<Long, Long>
valMemo, LocalVariableMap vars )
{
long ret = Long.MAX_VALUE;
@@ -1483,7 +1483,7 @@ public class OptimizerUtils
return ret;
}
- public static double rEvalSimpleDoubleExpression( Hop root,
HashMap<Long, Double> valMemo )
+ public static double rEvalSimpleDoubleExpression( Hop root, Map<Long,
Double> valMemo )
{
//memoization (prevent redundant computation of common subexpr)
if( valMemo.containsKey(root.getHopID()) )
@@ -1510,7 +1510,7 @@ public class OptimizerUtils
return ret;
}
- public static double rEvalSimpleDoubleExpression( Hop root,
HashMap<Long, Double> valMemo, LocalVariableMap vars )
+ public static double rEvalSimpleDoubleExpression( Hop root, Map<Long,
Double> valMemo, LocalVariableMap vars )
{
//memoization (prevent redundant computation of common subexpr)
if( valMemo.containsKey(root.getHopID()) )
@@ -1538,7 +1538,7 @@ public class OptimizerUtils
return ret;
}
- protected static double rEvalSimpleUnaryDoubleExpression( Hop root,
HashMap<Long, Double> valMemo )
+ protected static double rEvalSimpleUnaryDoubleExpression( Hop root,
Map<Long, Double> valMemo )
{
//memoization (prevent redundant computation of common subexpr)
if( valMemo.containsKey(root.getHopID()) )
@@ -1576,7 +1576,7 @@ public class OptimizerUtils
return ret;
}
- protected static double rEvalSimpleUnaryDoubleExpression( Hop root,
HashMap<Long, Double> valMemo, LocalVariableMap vars )
+ protected static double rEvalSimpleUnaryDoubleExpression( Hop root,
Map<Long, Double> valMemo, LocalVariableMap vars )
{
//memoization (prevent redundant computation of common subexpr)
if( valMemo.containsKey(root.getHopID()) )
@@ -1614,7 +1614,7 @@ public class OptimizerUtils
return ret;
}
- protected static double rEvalSimpleBinaryDoubleExpression( Hop root,
HashMap<Long, Double> valMemo )
+ protected static double rEvalSimpleBinaryDoubleExpression( Hop root,
Map<Long, Double> valMemo )
{
//memoization (prevent redundant computation of common subexpr)
if( valMemo.containsKey(root.getHopID()) )
@@ -1649,7 +1649,7 @@ public class OptimizerUtils
return ret;
}
- protected static double rEvalSimpleTernaryDoubleExpression( Hop root,
HashMap<Long, Double> valMemo ) {
+ protected static double rEvalSimpleTernaryDoubleExpression( Hop root,
Map<Long, Double> valMemo ) {
//memoization (prevent redundant computation of common subexpr)
if( valMemo.containsKey(root.getHopID()) )
return valMemo.get(root.getHopID());
@@ -1666,7 +1666,7 @@ public class OptimizerUtils
return ret;
}
- protected static double rEvalSimpleBinaryDoubleExpression( Hop root,
HashMap<Long, Double> valMemo, LocalVariableMap vars )
+ protected static double rEvalSimpleBinaryDoubleExpression( Hop root,
Map<Long, Double> valMemo, LocalVariableMap vars )
{
//memoization (prevent redundant computation of common subexpr)
if( valMemo.containsKey(root.getHopID()) )
diff --git a/src/main/java/org/apache/sysds/hops/ReorgOp.java
b/src/main/java/org/apache/sysds/hops/ReorgOp.java
index 057bdac782..0dce06964c 100644
--- a/src/main/java/org/apache/sysds/hops/ReorgOp.java
+++ b/src/main/java/org/apache/sysds/hops/ReorgOp.java
@@ -30,7 +30,7 @@ import org.apache.sysds.lops.Transform;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.meta.MatrixCharacteristics;
-import java.util.ArrayList;
+import java.util.List;
/**
* Reorg (cell) operation: aij
@@ -66,7 +66,7 @@ public class ReorgOp extends MultiThreadedHop
refreshSizeInformation();
}
- public ReorgOp(String l, DataType dt, ValueType vt, ReOrgOp o,
ArrayList<Hop> inp)
+ public ReorgOp(String l, DataType dt, ValueType vt, ReOrgOp o,
List<Hop> inp)
{
super(l, dt, vt);
_op = o;
diff --git a/src/main/java/org/apache/sysds/hops/ipa/FunctionCallGraph.java
b/src/main/java/org/apache/sysds/hops/ipa/FunctionCallGraph.java
index 690eb19f69..feeafe83e1 100644
--- a/src/main/java/org/apache/sysds/hops/ipa/FunctionCallGraph.java
+++ b/src/main/java/org/apache/sysds/hops/ipa/FunctionCallGraph.java
@@ -24,6 +24,7 @@ import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
+import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.Stack;
@@ -58,18 +59,18 @@ public class FunctionCallGraph
//unrolled function call graph, in call direction
//(mapping from function keys to called function keys)
- private final HashMap<String, HashSet<String>> _fGraph;
+ private final Map<String, Set<String>> _fGraph;
//program-wide function call operators per target function
//(mapping from function keys to set of its function calls)
- private final HashMap<String, ArrayList<FunctionOp>> _fCalls;
- private final HashMap<String, ArrayList<StatementBlock>> _fCallsSB;
+ private final Map<String, List<FunctionOp>> _fCalls;
+ private final Map<String, List<StatementBlock>> _fCallsSB;
//subset of direct or indirect recursive functions
- private final HashSet<String> _fRecursive;
+ private final Set<String> _fRecursive;
//subset of side-effect-free functions
- private final HashSet<String> _fSideEffectFree;
+ private final Set<String> _fSideEffectFree;
// a boolean value to indicate if exists the second order function
(e.g. eval, paramserv)
// and the UDFs that are marked secondorder="true"
@@ -168,7 +169,7 @@ public class FunctionCallGraph
_fCallsSB.remove(fkey);
_fRecursive.remove(fkey);
_fGraph.remove(fkey);
- for( Entry<String, HashSet<String>> e : _fGraph.entrySet() )
+ for( Entry<String, Set<String>> e : _fGraph.entrySet() )
e.getValue().removeIf(s -> s.equals(fkey));
}
@@ -195,8 +196,8 @@ public class FunctionCallGraph
* @param fkey new function key of called function
*/
public void replaceFunctionCalls(String fkeyOld, String fkey) {
- ArrayList<FunctionOp> fopTmp = _fCalls.get(fkeyOld);
- ArrayList<StatementBlock> sbTmp =_fCallsSB.get(fkeyOld);
+ List<FunctionOp> fopTmp = _fCalls.get(fkeyOld);
+ List<StatementBlock> sbTmp =_fCallsSB.get(fkeyOld);
_fCalls.remove(fkeyOld);
_fCallsSB.remove(fkeyOld);
_fCalls.put(fkey, fopTmp);
@@ -205,7 +206,7 @@ public class FunctionCallGraph
_fRecursive.remove(fkeyOld);
_fSideEffectFree.remove(fkeyOld);
_fGraph.remove(fkeyOld);
- for( HashSet<String> hs : _fGraph.values() )
+ for( Set<String> hs : _fGraph.values() )
hs.remove(fkeyOld);
}
@@ -350,7 +351,7 @@ public class FunctionCallGraph
try {
//construct the main function call graph
Stack<String> fstack = new Stack<>();
- HashSet<String> lfset = new HashSet<>();
+ Set<String> lfset = new HashSet<>();
_fGraph.put(MAIN_FUNCTION_KEY, new HashSet<String>());
for( StatementBlock sblk : prog.getStatementBlocks() )
ret |=
rConstructFunctionCallGraph(MAIN_FUNCTION_KEY, sblk, fstack, lfset);
@@ -373,7 +374,7 @@ public class FunctionCallGraph
try {
Stack<String> fstack = new Stack<>();
- HashSet<String> lfset = new HashSet<>();
+ Set<String> lfset = new HashSet<>();
_fGraph.put(MAIN_FUNCTION_KEY, new HashSet<String>());
return rConstructFunctionCallGraph(MAIN_FUNCTION_KEY,
sb, fstack, lfset);
}
@@ -382,7 +383,7 @@ public class FunctionCallGraph
}
}
- private boolean rConstructFunctionCallGraph(String fkey, StatementBlock
sb, Stack<String> fstack, HashSet<String> lfset) {
+ private boolean rConstructFunctionCallGraph(String fkey, StatementBlock
sb, Stack<String> fstack, Set<String> lfset) {
boolean ret = false;
if (sb instanceof WhileStatementBlock) {
WhileStatement ws = (WhileStatement)sb.getStatement(0);
@@ -408,7 +409,7 @@ public class FunctionCallGraph
}
else {
// For generic StatementBlock
- ArrayList<Hop> hopsDAG = sb.getHops();
+ List<Hop> hopsDAG = sb.getHops();
if( hopsDAG == null || hopsDAG.isEmpty() )
return false; //nothing to do
@@ -428,7 +429,7 @@ public class FunctionCallGraph
return ret;
}
- private boolean rConstructFunctionCallGraph(Hop hop, String fkey,
StatementBlock sb, Stack<String> fstack, HashSet<String> lfset) {
+ private boolean rConstructFunctionCallGraph(Hop hop, String fkey,
StatementBlock sb, Stack<String> fstack, Set<String> lfset) {
boolean ret = false;
if( hop.isVisited() )
return ret;
@@ -452,7 +453,7 @@ public class FunctionCallGraph
return ret;
}
- private boolean addFunctionOpToGraph(FunctionOp fop, String fkey,
StatementBlock sb, Stack<String> fstack, HashSet<String> lfset) {
+ private boolean addFunctionOpToGraph(FunctionOp fop, String fkey,
StatementBlock sb, Stack<String> fstack, Set<String> lfset) {
try{
boolean ret = false;
String lfkey = fop.getFunctionKey();
@@ -523,7 +524,7 @@ public class FunctionCallGraph
}
else {
// For generic StatementBlock
- ArrayList<Hop> hopsDAG = sb.getHops();
+ List<Hop> hopsDAG = sb.getHops();
if( hopsDAG == null || hopsDAG.isEmpty() )
return false; //nothing to do
//function ops can only occur as root nodes of the dag
diff --git a/src/main/java/org/apache/sysds/hops/ipa/FunctionCallSizeInfo.java
b/src/main/java/org/apache/sysds/hops/ipa/FunctionCallSizeInfo.java
index 551ce987ab..11a9ded1c1 100644
--- a/src/main/java/org/apache/sysds/hops/ipa/FunctionCallSizeInfo.java
+++ b/src/main/java/org/apache/sysds/hops/ipa/FunctionCallSizeInfo.java
@@ -252,7 +252,7 @@ public class FunctionCallSizeInfo
if( flist == null || flist.isEmpty() ) //robustness
removed functions
continue;
FunctionOp first = flist.get(0);
- HashSet<Integer> tmp = new HashSet<>();
+ Set<Integer> tmp = new HashSet<>();
for( int j=0; j<first.getInput().size(); j++ ) {
//if nnz known it is safe to propagate those
nnz because for multiple calls
//we checked of equivalence and hence all calls
have the same nnz
@@ -271,7 +271,7 @@ public class FunctionCallSizeInfo
continue;
FunctionOp first = flist.get(0);
//initialize w/ all literals of first call
- HashSet<Integer> tmp = new HashSet<>();
+ Set<Integer> tmp = new HashSet<>();
for( int j=0; j<first.getInput().size(); j++ )
if( first.getInput().get(j) instanceof
LiteralOp )
tmp.add(j);
diff --git
a/src/main/java/org/apache/sysds/hops/ipa/IPAPassFlagNonDeterminism.java
b/src/main/java/org/apache/sysds/hops/ipa/IPAPassFlagNonDeterminism.java
index 9dbe81447e..7a20393525 100644
--- a/src/main/java/org/apache/sysds/hops/ipa/IPAPassFlagNonDeterminism.java
+++ b/src/main/java/org/apache/sysds/hops/ipa/IPAPassFlagNonDeterminism.java
@@ -19,9 +19,10 @@
package org.apache.sysds.hops.ipa;
-import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
+import java.util.List;
+import java.util.Set;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.hops.FunctionOp;
@@ -56,7 +57,7 @@ public class IPAPassFlagNonDeterminism extends IPAPass {
try {
// Find the individual functions and statementblocks
with non-determinism.
- HashSet<String> ndfncs = new HashSet<>();
+ Set<String> ndfncs = new HashSet<>();
for (String fkey : fgraph.getReachableFunctions()) {
FunctionStatementBlock fsblock =
prog.getFunctionStatementBlock(fkey);
FunctionStatement fnstmt =
(FunctionStatement)fsblock.getStatement(0);
@@ -88,7 +89,7 @@ public class IPAPassFlagNonDeterminism extends IPAPass {
return false;
}
- private boolean rIsNonDeterministicFnc (String fname,
ArrayList<StatementBlock> sbs)
+ private boolean rIsNonDeterministicFnc (String fname,
List<StatementBlock> sbs)
{
boolean isND = false;
for (StatementBlock sb : sbs)
@@ -124,7 +125,7 @@ public class IPAPassFlagNonDeterminism extends IPAPass {
return isND;
}
- private void rMarkNondeterministicSBs (ArrayList<StatementBlock> sbs,
HashSet<String> ndfncs)
+ private void rMarkNondeterministicSBs (List<StatementBlock> sbs,
Set<String> ndfncs)
{
for (StatementBlock sb : sbs)
{
@@ -156,7 +157,7 @@ public class IPAPassFlagNonDeterminism extends IPAPass {
}
}
- private boolean rMarkNondeterministicHop(Hop hop, HashSet<String>
ndfncs) {
+ private boolean rMarkNondeterministicHop(Hop hop, Set<String> ndfncs) {
if (hop.isVisited())
return false;
@@ -182,7 +183,7 @@ public class IPAPassFlagNonDeterminism extends IPAPass {
return isND;
}
- private void propagate2Callers (FunctionCallGraph fgraph,
HashSet<String> ndfncs, HashSet<String> fstack, String fkey) {
+ private void propagate2Callers (FunctionCallGraph fgraph, Set<String>
ndfncs, Set<String> fstack, String fkey) {
Collection<String> cfkeys = fgraph.getCalledFunctions(fkey);
if (cfkeys != null) {
for (String cfkey : cfkeys) {
diff --git
a/src/main/java/org/apache/sysds/hops/ipa/IPAPassForwardFunctionCalls.java
b/src/main/java/org/apache/sysds/hops/ipa/IPAPassForwardFunctionCalls.java
index e35f2f50d6..b0f22e0578 100644
--- a/src/main/java/org/apache/sysds/hops/ipa/IPAPassForwardFunctionCalls.java
+++ b/src/main/java/org/apache/sysds/hops/ipa/IPAPassForwardFunctionCalls.java
@@ -23,6 +23,9 @@ import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
import java.util.stream.IntStream;
import org.apache.sysds.common.Types.OpOpData;
@@ -90,7 +93,7 @@ public class IPAPassForwardFunctionCalls extends IPAPass
return false;
}
- private static boolean singleFunctionOp(ArrayList<Hop> hops) {
+ private static boolean singleFunctionOp(List<Hop> hops) {
if( hops==null || hops.isEmpty() || hops.size()!=1 )
return false;
return hops.get(0) instanceof FunctionOp;
@@ -114,7 +117,7 @@ public class IPAPassForwardFunctionCalls extends IPAPass
private static boolean isFirstSubsetOfSecond(String[] first, String[]
second) {
//build phase: second
- HashSet<String> probe = new HashSet<>();
+ Set<String> probe = new HashSet<>();
for( String s : second )
probe.add(s);
//probe phase: first
@@ -123,13 +126,13 @@ public class IPAPassForwardFunctionCalls extends IPAPass
private static void reconcileFunctionInputsInPlace(FunctionOp call1,
FunctionOp call2) {
//prepare all input of call2 for probing
- HashMap<String,Hop> probe = new HashMap<>();
+ Map<String,Hop> probe = new HashMap<>();
for( int i=0; i<call2.getInput().size(); i++ )
probe.put(call2.getInputVariableNames()[i],
call2.getInput().get(i));
//construct new named inputs for call1 (in right order)
- ArrayList<String> varNames = new ArrayList<>();
- ArrayList<Hop> inputs = new ArrayList<>();
+ List<String> varNames = new ArrayList<>();
+ List<Hop> inputs = new ArrayList<>();
for( int i=0; i<call1.getInput().size(); i++ )
if( probe.containsKey(call1.getInputVariableNames()[i])
) {
varNames.add(call1.getInputVariableNames()[i]);
diff --git
a/src/main/java/org/apache/sysds/hops/ipa/IPAPassInlineFunctions.java
b/src/main/java/org/apache/sysds/hops/ipa/IPAPassInlineFunctions.java
index dee56174c2..c8359c1a63 100644
--- a/src/main/java/org/apache/sysds/hops/ipa/IPAPassInlineFunctions.java
+++ b/src/main/java/org/apache/sysds/hops/ipa/IPAPassInlineFunctions.java
@@ -19,10 +19,10 @@
package org.apache.sysds.hops.ipa;
-import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
+import java.util.Map;
import java.util.Set;
import org.apache.sysds.common.Types.OpOpData;
@@ -76,7 +76,7 @@ public class IPAPassInlineFunctions extends IPAPass
LOG.debug("IPA: Inline function
'"+fkey+"'");
//replace all relevant function calls
- ArrayList<Hop> hops =
fstmt.getBody().get(0).getHops();
+ List<Hop> hops =
fstmt.getBody().get(0).getHops();
List<FunctionOp> fcalls =
fgraph.getFunctionCalls(fkey);
List<StatementBlock> fcallsSB =
fgraph.getFunctionCallsSB(fkey);
boolean removedAll = true;
@@ -98,10 +98,10 @@ public class IPAPassInlineFunctions extends IPAPass
}
//step 1: deep copy hop dag
- ArrayList<Hop> hops2 =
Recompiler.deepCopyHopsDag(hops);
+ List<Hop> hops2 =
Recompiler.deepCopyHopsDag(hops);
//step 2: replace inputs
- HashMap<String,Hop> inMap = new
HashMap<>();
+ Map<String,Hop> inMap = new HashMap<>();
for(int j=0; j<op.getInput().size();
j++) {
String argName =
op.getInputVariableNames()[j];
DataIdentifier di =
fstmt.getInputParam(argName);
@@ -113,7 +113,7 @@ public class IPAPassInlineFunctions extends IPAPass
replaceTransientReads(hops2, inMap);
//step 3: replace outputs
- HashMap<String,String> outMap = new
HashMap<>();
+ Map<String,String> outMap = new
HashMap<>();
String[] opOutputs =
op.getOutputVariableNames();
for(int j=0; j<opOutputs.length; j++)
outMap.put(fstmt.getOutputParams().get(j).getName(), opOutputs[j]);
@@ -148,7 +148,7 @@ public class IPAPassInlineFunctions extends IPAPass
return ret;
}
- private static boolean containsFunctionOp(ArrayList<Hop> hops) {
+ private static boolean containsFunctionOp(List<Hop> hops) {
if( hops==null || hops.isEmpty() )
return false;
Hop.resetVisitStatus(hops);
@@ -157,7 +157,7 @@ public class IPAPassInlineFunctions extends IPAPass
return ret;
}
- private static int countOperators(ArrayList<Hop> hops) {
+ private static int countOperators(List<Hop> hops) {
if( hops==null || hops.isEmpty() )
return 0;
Hop.resetVisitStatus(hops);
@@ -179,14 +179,14 @@ public class IPAPassInlineFunctions extends IPAPass
return count;
}
- private static void replaceTransientReads(ArrayList<Hop> hops,
HashMap<String, Hop> inMap) {
+ private static void replaceTransientReads(List<Hop> hops, Map<String,
Hop> inMap) {
Hop.resetVisitStatus(hops);
for( Hop hop : hops )
rReplaceTransientReads(hop, inMap);
Hop.resetVisitStatus(hops);
}
- private static void rReplaceTransientReads(Hop current, HashMap<String,
Hop> inMap) {
+ private static void rReplaceTransientReads(Hop current, Map<String,
Hop> inMap) {
if( current.isVisited() )
return;
for( int i=0; i<current.getInput().size(); i++ ) {
diff --git
a/src/main/java/org/apache/sysds/hops/ipa/IPAPassPropagateReplaceLiterals.java
b/src/main/java/org/apache/sysds/hops/ipa/IPAPassPropagateReplaceLiterals.java
index 0f33a455c1..d56e5467ce 100644
---
a/src/main/java/org/apache/sysds/hops/ipa/IPAPassPropagateReplaceLiterals.java
+++
b/src/main/java/org/apache/sysds/hops/ipa/IPAPassPropagateReplaceLiterals.java
@@ -19,7 +19,6 @@
package org.apache.sysds.hops.ipa;
-import java.util.ArrayList;
import java.util.List;
import org.apache.sysds.common.Types.OpOpData;
@@ -72,7 +71,7 @@ public class IPAPassPropagateReplaceLiterals extends IPAPass
if( fcallSizes.hasSafeLiterals(fkey) ) {
FunctionStatementBlock fsb =
prog.getFunctionStatementBlock(fkey);
FunctionStatement fstmt =
(FunctionStatement)fsb.getStatement(0);
- ArrayList<DataIdentifier> finputs =
fstmt.getInputParams();
+ List<DataIdentifier> finputs =
fstmt.getInputParams();
//populate call vars with amenable literals
LocalVariableMap callVars = new
LocalVariableMap();
@@ -154,7 +153,7 @@ public class IPAPassPropagateReplaceLiterals extends IPAPass
}
}
- private static void replaceLiterals(ArrayList<Hop> roots,
LocalVariableMap constants) {
+ private static void replaceLiterals(List<Hop> roots, LocalVariableMap
constants) {
if( roots == null )
return;
diff --git
a/src/main/java/org/apache/sysds/hops/ipa/IPAPassRemoveConstantBinaryOps.java
b/src/main/java/org/apache/sysds/hops/ipa/IPAPassRemoveConstantBinaryOps.java
index ffb7d6849f..4033c2089e 100644
---
a/src/main/java/org/apache/sysds/hops/ipa/IPAPassRemoveConstantBinaryOps.java
+++
b/src/main/java/org/apache/sysds/hops/ipa/IPAPassRemoveConstantBinaryOps.java
@@ -19,8 +19,9 @@
package org.apache.sysds.hops.ipa;
-import java.util.ArrayList;
import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
import org.apache.sysds.hops.BinaryOp;
import org.apache.sysds.hops.DataGenOp;
@@ -59,7 +60,7 @@ public class IPAPassRemoveConstantBinaryOps extends IPAPass
public boolean rewriteProgram( DMLProgram prog, FunctionCallGraph
fgraph, FunctionCallSizeInfo fcallSizes ) {
//approach: scan over top-level program (guaranteed to be
unconditional),
//collect ones=matrix(1,...); remove b(*)ones if not outer
operation
- HashMap<String, Hop> mOnes = new HashMap<>();
+ Map<String, Hop> mOnes = new HashMap<>();
for( StatementBlock sb : prog.getStatementBlocks() ) {
//pruning updated variables
@@ -81,7 +82,7 @@ public class IPAPassRemoveConstantBinaryOps extends IPAPass
return false;
}
- private static void collectMatrixOfOnes(ArrayList<Hop> roots,
HashMap<String,Hop> mOnes)
+ private static void collectMatrixOfOnes(List<Hop> roots,
Map<String,Hop> mOnes)
{
if( roots == null )
return;
@@ -96,7 +97,7 @@ public class IPAPassRemoveConstantBinaryOps extends IPAPass
}
}
- private static void rRemoveConstantBinaryOp(StatementBlock sb,
HashMap<String,Hop> mOnes) {
+ private static void rRemoveConstantBinaryOp(StatementBlock sb,
Map<String,Hop> mOnes) {
if( sb instanceof IfStatementBlock )
{
IfStatementBlock isb = (IfStatementBlock) sb;
@@ -131,7 +132,7 @@ public class IPAPassRemoveConstantBinaryOps extends IPAPass
}
}
- private static void rRemoveConstantBinaryOp(Hop hop,
HashMap<String,Hop> mOnes)
+ private static void rRemoveConstantBinaryOp(Hop hop, Map<String,Hop>
mOnes)
{
if( hop.isVisited() )
return;
diff --git
a/src/main/java/org/apache/sysds/hops/ipa/IPAPassRemoveUnnecessaryCheckpoints.java
b/src/main/java/org/apache/sysds/hops/ipa/IPAPassRemoveUnnecessaryCheckpoints.java
index c78aac627c..e491aeec75 100644
---
a/src/main/java/org/apache/sysds/hops/ipa/IPAPassRemoveUnnecessaryCheckpoints.java
+++
b/src/main/java/org/apache/sysds/hops/ipa/IPAPassRemoveUnnecessaryCheckpoints.java
@@ -23,6 +23,7 @@ import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
+import java.util.Map;
import java.util.Set;
import org.apache.sysds.common.Types.OpOp1;
@@ -72,7 +73,7 @@ public class IPAPassRemoveUnnecessaryCheckpoints extends
IPAPass
//collect checkpoints; determine if used before update; remove
first checkpoint
//on second checkpoint if update in between and not used before
update
- HashMap<String, Hop> chkpointCand = new HashMap<>();
+ Map<String, Hop> chkpointCand = new HashMap<>();
for( StatementBlock sb : dmlp.getStatementBlocks() )
{
@@ -123,7 +124,7 @@ public class IPAPassRemoveUnnecessaryCheckpoints extends
IPAPass
//collect checkpoints and remove unnecessary checkpoints
if( HopRewriteUtils.isLastLevelStatementBlock(sb) ) {
- ArrayList<Hop> tmp =
collectCheckpoints(sb.getHops());
+ List<Hop> tmp =
collectCheckpoints(sb.getHops());
for( Hop chkpoint : tmp ) {
if(
chkpointCand.containsKey(chkpoint.getName()) ) {
chkpointCand.get(chkpoint.getName()).setRequiresCheckpoint(false);
@@ -140,7 +141,7 @@ public class IPAPassRemoveUnnecessaryCheckpoints extends
IPAPass
//after update if not used before update (best effort move
which often avoids
//the second checkpoint on loops even though used in between)
- HashMap<String, Hop> chkpointCand = new HashMap<>();
+ Map<String, Hop> chkpointCand = new HashMap<>();
for( StatementBlock sb : dmlp.getStatementBlocks() )
{
@@ -197,7 +198,7 @@ public class IPAPassRemoveUnnecessaryCheckpoints extends
IPAPass
//collect checkpoints
if( HopRewriteUtils.isLastLevelStatementBlock(sb) ) {
- ArrayList<Hop> tmp =
collectCheckpoints(sb.getHops());
+ List<Hop> tmp =
collectCheckpoints(sb.getHops());
for( Hop chkpoint : tmp )
chkpointCand.put(chkpoint.getName(),
chkpoint);
}
@@ -219,19 +220,17 @@ public class IPAPassRemoveUnnecessaryCheckpoints extends
IPAPass
}
}
- private static ArrayList<Hop> collectCheckpoints(ArrayList<Hop> roots)
- {
- ArrayList<Hop> ret = new ArrayList<>();
+ private static List<Hop> collectCheckpoints(List<Hop> roots) {
+ List<Hop> ret = new ArrayList<>();
if( roots != null ) {
Hop.resetVisitStatus(roots);
for( Hop root : roots )
rCollectCheckpoints(root, ret);
}
-
return ret;
}
- private static void rCollectCheckpoints(Hop hop, ArrayList<Hop>
checkpoints)
+ private static void rCollectCheckpoints(Hop hop, List<Hop> checkpoints)
{
if( hop.isVisited() )
return;
diff --git
a/src/main/java/org/apache/sysds/hops/ipa/InterProceduralAnalysis.java
b/src/main/java/org/apache/sysds/hops/ipa/InterProceduralAnalysis.java
index d0ea21a8aa..de350baea0 100644
--- a/src/main/java/org/apache/sysds/hops/ipa/InterProceduralAnalysis.java
+++ b/src/main/java/org/apache/sysds/hops/ipa/InterProceduralAnalysis.java
@@ -102,7 +102,7 @@ public class InterProceduralAnalysis {
private FunctionCallGraph _fgraph;
//set IPA passes to apply in order
- private final ArrayList<IPAPass> _passes;
+ private final List<IPAPass> _passes;
/**
* Creates a handle for performing inter-procedural analysis
@@ -266,7 +266,7 @@ public class InterProceduralAnalysis {
//check size-preserving characteristic
if( ret ) {
FunctionCallSizeInfo fcallSizes = new
FunctionCallSizeInfo(_fgraph, false);
- HashSet<String> fnStack = new HashSet<>();
+ Set<String> fnStack = new HashSet<>();
LocalVariableMap callVars = new LocalVariableMap();
//populate input (recognizable numbers, later reset)
@@ -366,7 +366,7 @@ public class InterProceduralAnalysis {
//remove updated constant scalars
Recompiler.removeUpdatedScalars(callVars, sb);
//old stats in, new stats out if updated
- ArrayList<Hop> roots = sb.getHops();
+ List<Hop> roots = sb.getHops();
DMLProgram prog = sb.getDMLProg();
//replace scalar reads with literals
if( replaceScalars ) {
@@ -396,7 +396,7 @@ public class InterProceduralAnalysis {
* @param roots List of HOPs.
* @param vars Map of variables eligible for propagation.
*/
- private static void propagateScalarsAcrossDAG(ArrayList<Hop> roots,
LocalVariableMap vars) {
+ private static void propagateScalarsAcrossDAG(List<Hop> roots,
LocalVariableMap vars) {
for (Hop hop : roots) {
try {
Recompiler.rReplaceLiterals(hop, vars, true);
@@ -428,7 +428,7 @@ public class InterProceduralAnalysis {
* @param roots List of HOP DAG root nodes.
* @param vars Map of variables eligible for propagation.
*/
- private static void propagateStatisticsAcrossDAG( ArrayList<Hop> roots,
LocalVariableMap vars ) {
+ private static void propagateStatisticsAcrossDAG( List<Hop> roots,
LocalVariableMap vars ) {
if( roots == null )
return;
@@ -460,7 +460,7 @@ public class InterProceduralAnalysis {
* @param fcallSizes function call summary
* @param fnStack Function stack to determine current scope.
*/
- private void propagateStatisticsIntoFunctions(DMLProgram prog,
ArrayList<Hop> roots, LocalVariableMap callVars, FunctionCallSizeInfo
fcallSizes, Set<String> fnStack, boolean replaceScalars) {
+ private void propagateStatisticsIntoFunctions(DMLProgram prog,
List<Hop> roots, LocalVariableMap callVars, FunctionCallSizeInfo fcallSizes,
Set<String> fnStack, boolean replaceScalars) {
for( Hop root : roots )
propagateStatisticsIntoFunctions(prog, root, callVars,
fcallSizes, fnStack, replaceScalars);
}
@@ -530,7 +530,7 @@ public class InterProceduralAnalysis {
//note: due to arbitrary binding sequences of named function
arguments,
//we cannot use the sequence as defined in the function
signature
String[] funArgNames = fop.getInputVariableNames();
- ArrayList<Hop> inputOps = fop.getInput();
+ List<Hop> inputOps = fop.getInput();
String fkey = fop.getFunctionKey();
//iterate over all parameters (with robustness for missing
parameters)
@@ -589,7 +589,7 @@ public class InterProceduralAnalysis {
* calling program's variable map.
*/
private static void extractFunctionCallReturnStatistics(
FunctionStatement fstmt, FunctionOp fop, LocalVariableMap tmpVars,
LocalVariableMap callVars, boolean overwrite ) {
- ArrayList<DataIdentifier> foutputOps = fstmt.getOutputParams();
+ List<DataIdentifier> foutputOps = fstmt.getOutputParams();
String[] outputVars = fop.getOutputVariableNames();
String fkey = fop.getFunctionKey();
@@ -647,7 +647,7 @@ public class InterProceduralAnalysis {
}
private static void
extractFunctionCallUnknownReturnStatistics(FunctionStatement fstmt, FunctionOp
fop, LocalVariableMap callVars) {
- ArrayList<DataIdentifier> foutputOps = fstmt.getOutputParams();
+ List<DataIdentifier> foutputOps = fstmt.getOutputParams();
String[] outputVars = fop.getOutputVariableNames();
String fkey = fop.getFunctionKey();
try {
diff --git
a/src/main/java/org/apache/sysds/hops/recompile/LiteralReplacement.java
b/src/main/java/org/apache/sysds/hops/recompile/LiteralReplacement.java
index 4740472ba6..fdd112ac07 100644
--- a/src/main/java/org/apache/sysds/hops/recompile/LiteralReplacement.java
+++ b/src/main/java/org/apache/sysds/hops/recompile/LiteralReplacement.java
@@ -20,6 +20,7 @@
package org.apache.sysds.hops.recompile;
import java.util.ArrayList;
+import java.util.List;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.hops.AggUnaryOp;
@@ -93,7 +94,7 @@ public class LiteralReplacement
//because hop c marked as visited, and
(2) repeated evaluation of uagg ops
if( c.getParent().size() > 1 ) {
//multiple parents
- ArrayList<Hop> parents = new
ArrayList<>(c.getParent());
+ List<Hop> parents = new
ArrayList<>(c.getParent());
for( Hop p : parents ) {
int pos =
HopRewriteUtils.getChildReferencePos(p, c);
HopRewriteUtils.removeChildReferenceByPos(p, c, pos);
@@ -369,7 +370,7 @@ public class LiteralReplacement
&& HopRewriteUtils.isData(in,
OpOpData.TRANSIENTREAD) ) {
ListObject list =
(ListObject)ec.getVariables().get(in.getName());
if( list.getLength() <= 128 ) {
- ArrayList<Hop> tmp = new ArrayList<>();
+ List<Hop> tmp = new ArrayList<>();
for( int i=0; i < list.getLength(); i++
) {
String varname =
Dag.getNextUniqueVarname(DataType.MATRIX);
MatrixObject mo =
(MatrixObject) list.slice(i);
diff --git a/src/main/java/org/apache/sysds/hops/recompile/RecompileStatus.java
b/src/main/java/org/apache/sysds/hops/recompile/RecompileStatus.java
index edb03d24d9..943f0043a9 100644
--- a/src/main/java/org/apache/sysds/hops/recompile/RecompileStatus.java
+++ b/src/main/java/org/apache/sysds/hops/recompile/RecompileStatus.java
@@ -24,6 +24,7 @@ import org.apache.sysds.runtime.controlprogram.ProgramBlock;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import java.util.HashMap;
+import java.util.Map;
public class RecompileStatus
{
@@ -37,7 +38,7 @@ public class RecompileStatus
private boolean _requiresRecompile = false;
//collection of extracted statistics for control flow reconciliation
- private final HashMap<String, DataCharacteristics> _lastTWrites;
+ private final Map<String, DataCharacteristics> _lastTWrites;
public RecompileStatus() {
this(0, true, ResetType.NO_RESET, false);
@@ -55,7 +56,7 @@ public class RecompileStatus
_initialCodegen = initialCodegen;
}
- public HashMap<String, DataCharacteristics> getTWriteStats() {
+ public Map<String, DataCharacteristics> getTWriteStats() {
return _lastTWrites;
}
diff --git a/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java
b/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java
index dfa6988122..70e21c34bc 100644
--- a/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java
+++ b/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java
@@ -25,9 +25,10 @@ import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
-import java.util.HashSet;
import java.util.List;
+import java.util.Map;
import java.util.Map.Entry;
+import java.util.Set;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
@@ -142,9 +143,9 @@ public class Recompiler {
};
// additional reused objects to avoid repeated, incremental
reallocation on deepCopyDags
- private static ThreadLocal<HashMap<Long,Hop>> _memoHop = new
ThreadLocal<>() {
- @Override protected HashMap<Long,Hop> initialValue() { return
new HashMap<>(); }
- @Override public HashMap<Long,Hop> get() { var tmp =
super.get(); tmp.clear(); return tmp; }
+ private static ThreadLocal<Map<Long,Hop>> _memoHop = new
ThreadLocal<>() {
+ @Override protected Map<Long,Hop> initialValue() { return new
HashMap<>(); }
+ @Override public Map<Long,Hop> get() { var tmp = super.get();
tmp.clear(); return tmp; }
};
public enum ResetType {
@@ -488,7 +489,7 @@ public class Recompiler {
* @param fnStack function stack
* @param et execution type
*/
- public static void recompileProgramBlockHierarchy2Forced(
ArrayList<ProgramBlock> pbs, long tid, HashSet<String> fnStack, ExecType et ) {
+ public static void recompileProgramBlockHierarchy2Forced(
ArrayList<ProgramBlock> pbs, long tid, Set<String> fnStack, ExecType et ) {
synchronized( pbs ) {
for( ProgramBlock pb : pbs )
rRecompileProgramBlock2Forced(pb, tid, fnStack,
et);
@@ -575,7 +576,7 @@ public class Recompiler {
try {
//note: need memo table over all independent DAGs in
order to
//account for shared transient reads (otherwise more
instructions generated)
- HashMap<Long, Hop> memo = _memoHop.get(); //orig ID,
new clone
+ Map<Long, Hop> memo = _memoHop.get(); //orig ID, new
clone
for( Hop hopRoot : hops )
ret.add(rDeepCopyHopsDag(hopRoot, memo));
}
@@ -596,7 +597,7 @@ public class Recompiler {
Hop ret = null;
try {
- HashMap<Long, Hop> memo = _memoHop.get(); //orig ID,
new clone
+ Map<Long, Hop> memo = _memoHop.get(); //orig ID, new
clone
ret = rDeepCopyHopsDag(hops, memo);
}
catch(Exception ex) {
@@ -606,7 +607,7 @@ public class Recompiler {
return ret;
}
- private static Hop rDeepCopyHopsDag( Hop hop, HashMap<Long,Hop> memo )
+ private static Hop rDeepCopyHopsDag( Hop hop, Map<Long,Hop> memo )
throws CloneNotSupportedException
{
Hop ret = memo.get(hop.getHopID());
@@ -628,7 +629,7 @@ public class Recompiler {
}
- public static void updateFunctionNames(ArrayList<Hop> hops, long pid)
+ public static void updateFunctionNames(List<Hop> hops, long pid)
{
Hop.resetVisitStatus(hops);
for( Hop hopRoot : hops )
@@ -1041,7 +1042,7 @@ public class Recompiler {
status.trackRecompile(fsb.requiresPredicateRecompilation());
}
- public static void rRecompileProgramBlock2Forced( ProgramBlock pb, long
tid, HashSet<String> fnStack, ExecType et ) {
+ public static void rRecompileProgramBlock2Forced( ProgramBlock pb, long
tid, Set<String> fnStack, ExecType et ) {
if (pb instanceof WhileProgramBlock)
{
WhileProgramBlock pbTmp = (WhileProgramBlock)pb;
@@ -1130,7 +1131,9 @@ public class Recompiler {
}
}
- private static void rRecompileProgramBlock2Forced(String fnamespace,
String fname, Program prog, long tid, HashSet<String> fnStack, ExecType et) {
+ private static void rRecompileProgramBlock2Forced(String fnamespace,
String fname,
+ Program prog, long tid, Set<String> fnStack, ExecType et)
+ {
String fKey = DMLProgram.constructFunctionKey(fnamespace,
fname);
if( !fnStack.contains(fKey) ) { //memoization for multiple
calls, recursion
fnStack.add(fKey);
@@ -1159,11 +1162,11 @@ public class Recompiler {
}
}
- public static void extractDAGOutputStatistics(ArrayList<Hop> hops,
LocalVariableMap vars) {
+ public static void extractDAGOutputStatistics(List<Hop> hops,
LocalVariableMap vars) {
extractDAGOutputStatistics(hops, vars, true);
}
- public static void extractDAGOutputStatistics(ArrayList<Hop> hops,
LocalVariableMap vars, boolean overwrite) {
+ public static void extractDAGOutputStatistics(List<Hop> hops,
LocalVariableMap vars, boolean overwrite) {
for( Hop hop : hops ) //for all hop roots
extractDAGOutputStatistics(hop, vars, overwrite);
}
@@ -1384,7 +1387,7 @@ public class Recompiler {
else if ( hop instanceof DataGenOp )
{
DataGenOp d = (DataGenOp) hop;
- HashMap<String,Integer> params = d.getParamIndexMap();
+ Map<String,Integer> params = d.getParamIndexMap();
if ( d.getOp() == OpOpDG.RAND ||
d.getOp()==OpOpDG.SINIT
|| d.getOp() == OpOpDG.SAMPLE || d.getOp() ==
OpOpDG.FRAMEINIT )
{
@@ -1394,7 +1397,7 @@ public class Recompiler {
int ix1 =
params.get(DataExpression.RAND_ROWS);
int ix2 =
params.get(DataExpression.RAND_COLS);
//update rows/cols by evaluating simple
expression of literals, nrow, ncol, scalars, binaryops
- HashMap<Long, Long> memo = new
HashMap<>();
+ Map<Long, Long> memo = new HashMap<>();
d.refreshRowsParameterInformation(d.getInput().get(ix1), vars, memo);
d.refreshColsParameterInformation(d.getInput().get(ix2), vars, memo);
if( !(initUnknown & d.dimsKnown()) )
@@ -1407,7 +1410,7 @@ public class Recompiler {
int ix1 = params.get(Statement.SEQ_FROM);
int ix2 = params.get(Statement.SEQ_TO);
int ix3 = params.get(Statement.SEQ_INCR);
- HashMap<Long, Double> memo = new HashMap<>();
+ Map<Long, Double> memo = new HashMap<>();
double from =
Hop.computeBoundsInformation(d.getInput().get(ix1), vars, memo);
double to =
Hop.computeBoundsInformation(d.getInput().get(ix2), vars, memo);
double incr =
Hop.computeBoundsInformation(d.getInput().get(ix3), vars, memo);
@@ -1437,7 +1440,7 @@ public class Recompiler {
if (hop.getDataType() != DataType.TENSOR) {
hop.refreshSizeInformation(); //update incl
reset
if (!hop.dimsKnown()) {
- HashMap<Long, Long> memo = new
HashMap<>();
+ Map<Long, Long> memo = new HashMap<>();
hop.refreshRowsParameterInformation(hop.getInput().get(1), vars, memo);
hop.refreshColsParameterInformation(hop.getInput().get(2), vars, memo);
}
@@ -1455,7 +1458,7 @@ public class Recompiler {
hop.setDim2(1);
}
else {
- HashMap<Long, Double> memo = new
HashMap<>();
+ Map<Long, Double> memo = new
HashMap<>();
double rl =
Hop.computeBoundsInformation(hop.getInput().get(1), vars, memo);
double ru =
Hop.computeBoundsInformation(hop.getInput().get(2), vars, memo);
double cl =
Hop.computeBoundsInformation(hop.getInput().get(3), vars, memo);
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java
b/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java
index eab6a861f8..61bd0921ce 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/HopRewriteUtils.java
@@ -81,7 +81,6 @@ import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.util.UtilFunctions;
-import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
@@ -264,7 +263,7 @@ public class HopRewriteUtils {
* @return hnew
*/
public static Hop rewireAllParentChildReferences( Hop hold, Hop hnew ) {
- ArrayList<Hop> parents = hold.getParent();
+ List<Hop> parents = hold.getParent();
while (!parents.isEmpty())
HopRewriteUtils.replaceChildReference(parents.get(0),
hold, hnew);
return hnew;
@@ -487,7 +486,7 @@ public class HopRewriteUtils {
return datagen;
}
- public static Hop createDataGenOpByVal( ArrayList<LiteralOp> values,
long rows, long cols )
+ public static Hop createDataGenOpByVal( List<LiteralOp> values, long
rows, long cols )
{
StringBuilder sb = new StringBuilder();
for(LiteralOp lit : values) {
@@ -606,7 +605,7 @@ public class HopRewriteUtils {
return reorg;
}
- public static ReorgOp createReorg(ArrayList<Hop> inputs, ReOrgOp rop) {
+ public static ReorgOp createReorg(List<Hop> inputs, ReOrgOp rop) {
Hop main = inputs.get(0);
ReorgOp reorg = new ReorgOp(main.getName(), main.getDataType(),
main.getValueType(), rop, inputs);
reorg.setBlocksize(main.getBlocksize());
@@ -1016,7 +1015,7 @@ public class HopRewriteUtils {
return isTransposeOperation(hop) && hop.getParent().size() <=
maxParents;
}
- public static boolean containsTransposeOperation(ArrayList<Hop> hops) {
+ public static boolean containsTransposeOperation(List<Hop> hops) {
boolean ret = false;
for( Hop hop : hops )
ret |= isTransposeOperation(hop);
@@ -1398,7 +1397,7 @@ public class HopRewriteUtils {
public static boolean hasOnlyWriteParents( Hop hop, boolean
inclTransient, boolean inclPersistent ) {
boolean ret = true;
- ArrayList<Hop> parents = hop.getParent();
+ List<Hop> parents = hop.getParent();
for( Hop p : parents ) {
if( inclTransient && inclPersistent )
ret &= ( p instanceof DataOp &&
(((DataOp)p).getOp()==OpOpData.TRANSIENTWRITE
@@ -1425,7 +1424,7 @@ public class HopRewriteUtils {
&& ((DataOp)hop).getFileFormat()!=FileFormat.BINARY);
}
- public static boolean containsOp(ArrayList<Hop> candidates, Class<?
extends Hop> clazz) {
+ public static boolean containsOp(List<Hop> candidates, Class<? extends
Hop> clazz) {
if( candidates != null )
for( Hop cand : candidates )
if( cand.getClass().equals(clazz) )
@@ -1652,7 +1651,7 @@ public class HopRewriteUtils {
&& hop.getInput().stream().anyMatch(h ->
h.getDataType().isList());
}
- public static boolean containsSecondOrderBuiltin(ArrayList<Hop> roots) {
+ public static boolean containsSecondOrderBuiltin(List<Hop> roots) {
Hop.resetVisitStatus(roots);
return roots.stream().anyMatch(r ->
containsSecondOrderBuiltin(r));
}