This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 8185411c8b [MINOR] Cleanup code quality hops compilation / federated
planners
8185411c8b is described below
commit 8185411c8b24bd9d8311da1c092b84d5a33a4748
Author: Matthias Boehm <[email protected]>
AuthorDate: Sun Mar 24 20:05:49 2024 +0100
[MINOR] Cleanup code quality hops compilation / federated planners
---
src/main/java/org/apache/sysds/hops/DnnOp.java | 13 +++---
src/main/java/org/apache/sysds/hops/Hop.java | 8 ++--
src/main/java/org/apache/sysds/hops/MemoTable.java | 3 +-
.../sysds/hops/codegen/template/TemplateUtils.java | 4 +-
.../hops/ipa/IPAPassRewriteFederatedPlan.java | 2 +-
.../apache/sysds/hops/rewrite/HopDagValidator.java | 3 +-
.../RewriteAlgebraicSimplificationDynamic.java | 19 +++-----
.../RewriteCommonSubexpressionElimination.java | 3 +-
.../sysds/hops/rewrite/RewriteConstantFolding.java | 23 +++++-----
.../RewriteElementwiseMultChainOptimization.java | 9 ++--
.../sysds/hops/rewrite/RewriteGPUSpecificOps.java | 7 +--
.../hops/rewrite/RewriteIndexingVectorization.java | 9 ++--
.../RewriteMatrixMultChainOptimization.java | 17 +++----
.../RewriteMatrixMultChainOptimizationSparse.java | 6 +--
...ewriteMatrixMultChainOptimizationTranspose.java | 3 +-
.../RewriteRemoveDanglingParentReferences.java | 3 +-
.../rewrite/RewriteRemovePersistentReadWrite.java | 3 +-
.../hops/rewrite/RewriteRemoveReadAfterWrite.java | 4 +-
.../rewrite/RewriteRemoveUnnecessaryCasts.java | 10 ++--
.../compress/colgroup/indexes/IColIndex.java | 2 +-
.../compress/workload/WorkloadAnalyzer.java | 4 +-
.../controlprogram/FunctionProgramBlock.java | 53 ----------------------
.../controlprogram/paramserv/ParamservUtils.java | 4 +-
.../controlprogram/parfor/opt/CostEstimator.java | 8 ++--
.../runtime/controlprogram/parfor/opt/OptNode.java | 9 ++--
.../parfor/opt/OptTreeConverter.java | 6 +--
.../parfor/opt/OptimizerRuleBased.java | 8 ++--
.../parfor/opt/ProgramRecompiler.java | 7 ++-
.../sysds/runtime/frame/data/columns/DDCArray.java | 4 +-
29 files changed, 99 insertions(+), 155 deletions(-)
diff --git a/src/main/java/org/apache/sysds/hops/DnnOp.java
b/src/main/java/org/apache/sysds/hops/DnnOp.java
index 3502696583..1600b5298a 100644
--- a/src/main/java/org/apache/sysds/hops/DnnOp.java
+++ b/src/main/java/org/apache/sysds/hops/DnnOp.java
@@ -20,6 +20,7 @@
package org.apache.sysds.hops;
import java.util.ArrayList;
+import java.util.List;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
@@ -122,9 +123,7 @@ public class DnnOp extends MultiThreadedHop {
ExecType et = optFindExecType();
- ArrayList<Hop> inputs = getInput();
- switch( op )
- {
+ switch( op ) {
case MAX_POOL:
case MAX_POOL_BACKWARD:
case AVG_POOL:
@@ -135,7 +134,7 @@ public class DnnOp extends MultiThreadedHop {
case BIASADD:
case BIASMULT: {
if(et == ExecType.CP || et == ExecType.GPU) {
- setLops(constructDnnLops(et, inputs));
+ setLops(constructDnnLops(et,
getInput()));
break;
}
throw new HopsException("Unimplemented DnnOp
for execution type: " + et.name());
@@ -144,7 +143,7 @@ public class DnnOp extends MultiThreadedHop {
case CHANNEL_SUMS:
case UPDATE_NESTEROV_X: {
if(et == ExecType.GPU) {
- setLops(constructDnnLops(et, inputs));
+ setLops(constructDnnLops(et,
getInput()));
break;
}
throw new HopsException("Unimplemented DnnOp
for execution type: " + et.name());
@@ -254,7 +253,7 @@ public class DnnOp extends MultiThreadedHop {
return null;
}
- public Lop constructDnnLops(ExecType et, ArrayList<Hop> inputs) {
+ public Lop constructDnnLops(ExecType et, List<Hop> inputs) {
if(inputs.size() != getNumExpectedInputs())
throw new HopsException("Incorrect number of inputs for
" + op.name());
@@ -262,7 +261,7 @@ public class DnnOp extends MultiThreadedHop {
//
---------------------------------------------------------------
// Deal with fused operators and contruct
lhsInputLop/optionalRhsInputLop
Lop lhsInputLop = null; Lop optionalRhsInputLop = null;
- ArrayList<Hop> inputsOfPotentiallyFusedOp = inputs;
+ List<Hop> inputsOfPotentiallyFusedOp = inputs;
OpOpDnn lopOp = op;
// RELU_MAX_POOLING and RELU_MAX_POOLING_BACKWARD is extremely
useful for CP backend
diff --git a/src/main/java/org/apache/sysds/hops/Hop.java
b/src/main/java/org/apache/sysds/hops/Hop.java
index 34cb568b5d..127fe7e145 100644
--- a/src/main/java/org/apache/sysds/hops/Hop.java
+++ b/src/main/java/org/apache/sysds/hops/Hop.java
@@ -86,9 +86,9 @@ public abstract class Hop implements ParseInfo {
protected UpdateType _updateType = UpdateType.COPY;
/** The output Hops that are connected to this Hop */
- protected ArrayList<Hop> _parent = new ArrayList<>();
+ protected List<Hop> _parent = new ArrayList<>();
/** The input Hops that are connected to this Hop */
- protected ArrayList<Hop> _input = new ArrayList<>();
+ protected List<Hop> _input = new ArrayList<>();
/** Currently used exec type */
protected ExecType _etype = null;
@@ -929,11 +929,11 @@ public abstract class Hop implements ParseInfo {
return false;
}
- public ArrayList<Hop> getParent() {
+ public List<Hop> getParent() {
return _parent;
}
- public ArrayList<Hop> getInput() {
+ public List<Hop> getInput() {
return _input;
}
diff --git a/src/main/java/org/apache/sysds/hops/MemoTable.java
b/src/main/java/org/apache/sysds/hops/MemoTable.java
index dc6437410b..e67001292e 100644
--- a/src/main/java/org/apache/sysds/hops/MemoTable.java
+++ b/src/main/java/org/apache/sysds/hops/MemoTable.java
@@ -27,6 +27,7 @@ import org.apache.sysds.runtime.meta.MatrixCharacteristics;
import java.util.ArrayList;
import java.util.HashMap;
+import java.util.List;
/**
* Memoization Table (hop id, worst-case matrix characteristics).
@@ -97,7 +98,7 @@ public class MemoTable
_memo.put(hopID, dc);
}
- public DataCharacteristics[] getAllInputStats(ArrayList<Hop> inputs )
+ public DataCharacteristics[] getAllInputStats(List<Hop> inputs )
{
if( inputs == null )
return null;
diff --git
a/src/main/java/org/apache/sysds/hops/codegen/template/TemplateUtils.java
b/src/main/java/org/apache/sysds/hops/codegen/template/TemplateUtils.java
index 5a4a7ac62a..16d88c5a98 100644
--- a/src/main/java/org/apache/sysds/hops/codegen/template/TemplateUtils.java
+++ b/src/main/java/org/apache/sysds/hops/codegen/template/TemplateUtils.java
@@ -19,10 +19,10 @@
package org.apache.sysds.hops.codegen.template;
-import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
+import java.util.List;
import java.util.Map;
import java.util.Set;
@@ -674,7 +674,7 @@ public class TemplateUtils
return ret;
}
- private static boolean checkContainment(ArrayList<Hop> inputs, Hop
probe, boolean inclTranspose) {
+ private static boolean checkContainment(List<Hop> inputs, Hop probe,
boolean inclTranspose) {
if( !inclTranspose )
return inputs.contains(probe);
for( Hop hop : inputs )
diff --git
a/src/main/java/org/apache/sysds/hops/ipa/IPAPassRewriteFederatedPlan.java
b/src/main/java/org/apache/sysds/hops/ipa/IPAPassRewriteFederatedPlan.java
index 09630df954..af23b19e35 100644
--- a/src/main/java/org/apache/sysds/hops/ipa/IPAPassRewriteFederatedPlan.java
+++ b/src/main/java/org/apache/sysds/hops/ipa/IPAPassRewriteFederatedPlan.java
@@ -67,7 +67,7 @@ public class IPAPassRewriteFederatedPlan extends IPAPass {
private void generatePlan(DMLProgram prog, FunctionCallGraph fgraph,
FunctionCallSizeInfo fcallSizes, String splanner){
FederatedPlanner planner =
FederatedPlanner.isCompiled(splanner) ?
FederatedPlanner.valueOf(splanner.toUpperCase()) :
- FederatedPlanner.COMPILE_COST_BASED;
+ FederatedPlanner.COMPILE_FED_HEURISTIC;
// run planner rewrite with forced federated exec types
planner.getPlanner().rewriteProgram(prog, fgraph, fcallSizes);
diff --git a/src/main/java/org/apache/sysds/hops/rewrite/HopDagValidator.java
b/src/main/java/org/apache/sysds/hops/rewrite/HopDagValidator.java
index 29e82c664d..a8d77cf234 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/HopDagValidator.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/HopDagValidator.java
@@ -23,6 +23,7 @@ import static org.apache.sysds.hops.HopsException.check;
import java.util.ArrayList;
import java.util.HashSet;
+import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
@@ -106,7 +107,7 @@ public class HopDagValidator {
"not properly linked to its parent
pid=%d %s",
parent.getHopID(),
parent.getClass().getName());
- final ArrayList<Hop> input = hop.getInput();
+ final List<Hop> input = hop.getInput();
final DataType dt = hop.getDataType();
final ValueType vt = hop.getValueType();
diff --git
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
index 3e1c498f01..fea525703d 100644
---
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
+++
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteAlgebraicSimplificationDynamic.java
@@ -562,7 +562,6 @@ public class RewriteAlgebraicSimplificationDynamic extends
HopRewriteRule
return hi;
}
- @SuppressWarnings("unchecked")
private static Hop fuseDatagenAndReorgOperation(Hop parent, Hop hi, int
pos)
{
if( HopRewriteUtils.isTransposeOperation(hi)
@@ -575,7 +574,7 @@ public class RewriteAlgebraicSimplificationDynamic extends
HopRewriteRule
{
//relink all parents and dataop (remove
transpose)
HopRewriteUtils.removeAllChildReferences(hi);
- ArrayList<Hop> parents = (ArrayList<Hop>)
hi.getParent().clone();
+ List<Hop> parents = new
ArrayList<>(hi.getParent());
for( int i=0; i<parents.size(); i++ ) {
Hop lparent = parents.get(i);
int ppos =
HopRewriteUtils.getChildReferencePos(lparent, hi);
@@ -600,7 +599,6 @@ public class RewriteAlgebraicSimplificationDynamic extends
HopRewriteRule
return hi;
}
- @SuppressWarnings("unchecked")
private static Hop simplifyColwiseAggregate( Hop parent, Hop hi, int
pos ) {
if( hi instanceof AggUnaryOp )
{
@@ -636,7 +634,7 @@ public class RewriteAlgebraicSimplificationDynamic extends
HopRewriteRule
else if( input.getDim2() == 1 )
{
//get old parents (before
creating cast over aggregate)
- ArrayList<Hop> parents =
(ArrayList<Hop>) hi.getParent().clone();
+ List<Hop> parents = new
ArrayList<>(hi.getParent());
//simplify col-aggregate to
full aggregate
uhi.setDirection(Direction.RowCol);
@@ -662,7 +660,6 @@ public class RewriteAlgebraicSimplificationDynamic extends
HopRewriteRule
return hi;
}
- @SuppressWarnings("unchecked")
private static Hop simplifyRowwiseAggregate( Hop parent, Hop hi, int
pos ) {
if( hi instanceof AggUnaryOp )
{
@@ -701,7 +698,7 @@ public class RewriteAlgebraicSimplificationDynamic extends
HopRewriteRule
else if( input.getDim1() == 1 )
{
//get old parents (before
creating cast over aggregate)
- ArrayList<Hop> parents =
(ArrayList<Hop>) hi.getParent().clone();
+ List<Hop> parents = new
ArrayList<>(hi.getParent());
//simplify row-aggregate to
full aggregate
uhi.setDirection(Direction.RowCol);
@@ -1270,7 +1267,6 @@ public class RewriteAlgebraicSimplificationDynamic
extends HopRewriteRule
return hi;
}
- @SuppressWarnings("unchecked")
private static Hop pushdownBinaryOperationOnDiag(Hop parent, Hop hi,
int pos)
{
//diag(X)*7 --> diag(X*7) in order to (1) reduce required
memory for b(*) and
@@ -1304,8 +1300,8 @@ public class RewriteAlgebraicSimplificationDynamic
extends HopRewriteRule
{
//remove all parent links to binary op (since
we want to reorder
//we cannot just look at the current parent)
- ArrayList<Hop> parents = (ArrayList<Hop>)
hi.getParent().clone();
- ArrayList<Integer> parentspos = new
ArrayList<>();
+ List<Hop> parents = new
ArrayList<>(hi.getParent());
+ List<Integer> parentspos = new ArrayList<>();
for(Hop lparent : parents) {
int lpos =
HopRewriteUtils.getChildReferencePos(lparent, hi);
HopRewriteUtils.removeChildReferenceByPos(lparent, hi, lpos);
@@ -2525,7 +2521,6 @@ public class RewriteAlgebraicSimplificationDynamic
extends HopRewriteRule
* @param pos position
* @return high-level operator
*/
- @SuppressWarnings("unchecked")
private static Hop reorderMinusMatrixMult(Hop parent, Hop hi, int pos)
{
if( HopRewriteUtils.isMatrixMultiply(hi) ) //X%*%Y
@@ -2545,7 +2540,7 @@ public class RewriteAlgebraicSimplificationDynamic
extends HopRewriteRule
HopRewriteUtils.removeChildReference(hi,
hileft);
//get old parents (before creating minus over
matrix mult)
- ArrayList<Hop> parents = (ArrayList<Hop>)
hi.getParent().clone();
+ List<Hop> parents = new
ArrayList<>(hi.getParent());
//create new operators
BinaryOp minus =
HopRewriteUtils.createBinary(new LiteralOp(0), hi, OpOp2.MINUS);
@@ -2579,7 +2574,7 @@ public class RewriteAlgebraicSimplificationDynamic
extends HopRewriteRule
HopRewriteUtils.removeChildReference(hi,
hiright);
//get old parents (before creating minus over
matrix mult)
- ArrayList<Hop> parents = (ArrayList<Hop>)
hi.getParent().clone();
+ List<Hop> parents = new
ArrayList<>(hi.getParent());
//create new operators
BinaryOp minus =
HopRewriteUtils.createBinary(new LiteralOp(0), hi, OpOp2.MINUS);
diff --git
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteCommonSubexpressionElimination.java
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteCommonSubexpressionElimination.java
index d620b7d10a..399e23d48c 100644
---
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteCommonSubexpressionElimination.java
+++
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteCommonSubexpressionElimination.java
@@ -21,6 +21,7 @@ package org.apache.sysds.hops.rewrite;
import java.util.ArrayList;
import java.util.HashMap;
+import java.util.List;
import org.apache.sysds.hops.DataOp;
import org.apache.sysds.hops.Hop;
@@ -186,7 +187,7 @@ public class RewriteCommonSubexpressionElimination extends
HopRewriteRule
hop.getParent().remove(j);
//replace h2 w/ h1 in h2-parent
inputs
- ArrayList<Hop> parent =
h2.getParent();
+ List<Hop> parent =
h2.getParent();
for( Hop p : parent )
for( int k=0;
k<p.getInput().size(); k++ )
if(
p.getInput().get(k)==h2 ) {
diff --git
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteConstantFolding.java
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteConstantFolding.java
index 6980e5b661..9655bed2d1 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteConstantFolding.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteConstantFolding.java
@@ -191,9 +191,8 @@ public class RewriteConstantFolding extends HopRewriteRule
return _tmpEC;
}
- private static boolean isApplicableBinaryOp( Hop hop )
- {
- ArrayList<Hop> in = hop.getInput();
+ private static boolean isApplicableBinaryOp( Hop hop ) {
+ List<Hop> in = hop.getInput();
return ( hop instanceof BinaryOp
&& in.get(0) instanceof LiteralOp
&& in.get(1) instanceof LiteralOp
@@ -205,7 +204,7 @@ public class RewriteConstantFolding extends HopRewriteRule
}
private static boolean isApplicableUnaryOp( Hop hop ) {
- ArrayList<Hop> in = hop.getInput();
+ List<Hop> in = hop.getInput();
return ( hop instanceof UnaryOp
&& in.get(0) instanceof LiteralOp
&& ((UnaryOp)hop).getOp() != OpOp1.EXISTS
@@ -226,16 +225,16 @@ public class RewriteConstantFolding extends HopRewriteRule
}
private static boolean isApplicableFalseConjunctivePredicate( Hop hop )
{
- ArrayList<Hop> in = hop.getInput();
- return ( HopRewriteUtils.isBinary(hop, OpOp2.AND) &&
hop.getDataType().isScalar()
- && ( (in.get(0) instanceof LiteralOp &&
!((LiteralOp)in.get(0)).getBooleanValue())
- ||(in.get(1) instanceof LiteralOp &&
!((LiteralOp)in.get(1)).getBooleanValue())) );
+ List<Hop> in = hop.getInput();
+ return (HopRewriteUtils.isBinary(hop, OpOp2.AND) &&
hop.getDataType().isScalar()
+ && ( (in.get(0) instanceof LiteralOp &&
!((LiteralOp)in.get(0)).getBooleanValue())
+ ||(in.get(1) instanceof LiteralOp &&
!((LiteralOp)in.get(1)).getBooleanValue())) );
}
private static boolean isApplicableTrueDisjunctivePredicate( Hop hop ) {
- ArrayList<Hop> in = hop.getInput();
- return ( HopRewriteUtils.isBinary(hop, OpOp2.OR) &&
hop.getDataType().isScalar()
- && ( (in.get(0) instanceof LiteralOp &&
((LiteralOp)in.get(0)).getBooleanValue())
- ||(in.get(1) instanceof LiteralOp &&
((LiteralOp)in.get(1)).getBooleanValue())) );
+ List<Hop> in = hop.getInput();
+ return (HopRewriteUtils.isBinary(hop, OpOp2.OR) &&
hop.getDataType().isScalar()
+ && ( (in.get(0) instanceof LiteralOp &&
((LiteralOp)in.get(0)).getBooleanValue())
+ ||(in.get(1) instanceof LiteralOp &&
((LiteralOp)in.get(1)).getBooleanValue())) );
}
}
diff --git
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteElementwiseMultChainOptimization.java
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteElementwiseMultChainOptimization.java
index 005005b9d7..cd56c56c18 100644
---
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteElementwiseMultChainOptimization.java
+++
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteElementwiseMultChainOptimization.java
@@ -24,6 +24,7 @@ import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
+import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.SortedSet;
@@ -125,7 +126,7 @@ public class RewriteElementwiseMultChainOptimization
extends HopRewriteRule {
}
private static void recurseInputs(final Hop parent) {
- final ArrayList<Hop> inputs = parent.getInput();
+ final List<Hop> inputs = parent.getInput();
for (int i = 0; i < inputs.size(); i++) {
final Hop input = inputs.get(i);
final Hop newInput = rule_RewriteEMult(input);
@@ -308,14 +309,14 @@ public class RewriteElementwiseMultChainOptimization
extends HopRewriteRule {
* @return Whether this interior emult or any child emult has a foreign
parent.
*/
private static boolean checkForeignParent(final Set<BinaryOp> emults,
final BinaryOp child) {
- final ArrayList<Hop> parents = child.getParent();
+ final List<Hop> parents = child.getParent();
if (parents.size() > 1)
for (final Hop parent : parents)
if (!(parent instanceof BinaryOp) ||
!emults.contains(parent))
return false;
// child does not have foreign parents
- final ArrayList<Hop> inputs = child.getInput();
+ final List<Hop> inputs = child.getInput();
final Hop left = inputs.get(0), right = inputs.get(1);
return (!isBinaryMult(left) || checkForeignParent(emults,
(BinaryOp)left)) &&
(!isBinaryMult(right) ||
checkForeignParent(emults, (BinaryOp)right));
@@ -334,7 +335,7 @@ public class RewriteElementwiseMultChainOptimization
extends HopRewriteRule {
// TODO proper handling of DAGs (avoid collecting the same leaf
multiple times)
// TODO exclude hops with unknown dimensions and move rewrites
to dynamic rewrites
- final ArrayList<Hop> inputs = root.getInput();
+ final List<Hop> inputs = root.getInput();
final Hop left = inputs.get(0), right = inputs.get(1);
if (isBinaryMult(left))
diff --git
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteGPUSpecificOps.java
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteGPUSpecificOps.java
index 2e82aedb9e..fc7286cdd8 100644
--- a/src/main/java/org/apache/sysds/hops/rewrite/RewriteGPUSpecificOps.java
+++ b/src/main/java/org/apache/sysds/hops/rewrite/RewriteGPUSpecificOps.java
@@ -22,6 +22,7 @@ package org.apache.sysds.hops.rewrite;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
+import java.util.List;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.hops.AggUnaryOp;
@@ -309,7 +310,7 @@ public class RewriteGPUSpecificOps extends HopRewriteRule {
&&
OptimizerUtils.rEvalSimpleDoubleExpression(getFirstInput(h), new HashMap<>())
== expectedValue;
}
- private static boolean isAnyBinaryAdd(ArrayList<Hop> hops) {
+ private static boolean isAnyBinaryAdd(List<Hop> hops) {
if(hops != null) {
for(Hop h : hops) {
if(h instanceof BinaryOp &&
((BinaryOp)h).getOp() == OpOp2.PLUS)
@@ -466,7 +467,7 @@ public class RewriteGPUSpecificOps extends HopRewriteRule {
* @param mu value of mu
* @return an array [ema_mean_upd, ema_mean] if any of the expression
matched, else null
*/
- private static Hop [] getUpdatedMovingAverageExpressions(ArrayList<Hop>
rhsTimesOps, double mu) {
+ private static Hop[] getUpdatedMovingAverageExpressions(List<Hop>
rhsTimesOps, double mu) {
if(rhsTimesOps == null || rhsTimesOps.size() == 0)
return null;
@@ -490,7 +491,7 @@ public class RewriteGPUSpecificOps extends HopRewriteRule {
* @param rhsTimesOps hop representing BinaryOp of expression
(1-mu)*mean
* @return value of mu if the expression matched else null
*/
- private static Double
getMuFromUpdatedMovingAverageExpressions(ArrayList<Hop> rhsTimesOps) {
+ private static Double
getMuFromUpdatedMovingAverageExpressions(List<Hop> rhsTimesOps) {
if(rhsTimesOps == null || rhsTimesOps.size() == 0)
return null;
diff --git
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteIndexingVectorization.java
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteIndexingVectorization.java
index 3c10f02732..9c04959ed5 100644
---
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteIndexingVectorization.java
+++
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteIndexingVectorization.java
@@ -254,7 +254,6 @@ public class RewriteIndexingVectorization extends
HopRewriteRule
}
}
- @SuppressWarnings("unchecked")
private static Hop vectorizeLeftIndexing( Hop hop )
{
Hop ret = hop;
@@ -315,8 +314,8 @@ public class RewriteIndexingVectorization extends
HopRewriteRule
//new row left indexing operator (for
all parents, only intermediates are guaranteed to have 1 parent)
//(note: it's important to clone the
parent list before creating newLix on top of ihop0)
- ArrayList<Hop> ihop0parents =
(ArrayList<Hop>) ihop0.getParent().clone();
- ArrayList<Integer> ihop0parentsPos =
new ArrayList<>();
+ List<Hop> ihop0parents = new
ArrayList<>(ihop0.getParent());
+ List<Integer> ihop0parentsPos = new
ArrayList<>();
for( Hop parent : ihop0parents ) {
int posp =
HopRewriteUtils.getChildReferencePos(parent, ihop0);
HopRewriteUtils.removeChildReferenceByPos(parent, ihop0, posp); //input data
@@ -394,8 +393,8 @@ public class RewriteIndexingVectorization extends
HopRewriteRule
//new row left indexing operator (for
all parents, only intermediates are guaranteed to have 1 parent)
//(note: it's important to clone the
parent list before creating newLix on top of ihop0)
- ArrayList<Hop> ihop0parents =
(ArrayList<Hop>) ihop0.getParent().clone();
- ArrayList<Integer> ihop0parentsPos =
new ArrayList<>();
+ List<Hop> ihop0parents = new
ArrayList<>(ihop0.getParent());
+ List<Integer> ihop0parentsPos = new
ArrayList<>();
for( Hop parent : ihop0parents ) {
int posp =
HopRewriteUtils.getChildReferencePos(parent, ihop0);
HopRewriteUtils.removeChildReferenceByPos(parent, ihop0, posp); //input data
diff --git
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteMatrixMultChainOptimization.java
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteMatrixMultChainOptimization.java
index 9b942c0478..fdd2f8343f 100644
---
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteMatrixMultChainOptimization.java
+++
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteMatrixMultChainOptimization.java
@@ -21,6 +21,7 @@ package org.apache.sysds.hops.rewrite;
import java.util.ArrayList;
import java.util.Arrays;
+import java.util.List;
import org.apache.commons.lang3.mutable.MutableInt;
import org.apache.sysds.hops.AggBinaryOp;
@@ -108,9 +109,9 @@ public class RewriteMatrixMultChainOptimization extends
HopRewriteRule
+ ", " + hop.getHopID() + ", " + hop.getName()
+ ")");
}
- ArrayList<Hop> mmChain = new ArrayList<>();
- ArrayList<Hop> mmOperators = new ArrayList<>();
- ArrayList<Hop> tempList;
+ List<Hop> mmChain = new ArrayList<>();
+ List<Hop> mmOperators = new ArrayList<>();
+ List<Hop> tempList;
// Step 1: Identify the chain (mmChain) & clear all links among
the Hops
// that are involved in mmChain.
@@ -181,7 +182,7 @@ public class RewriteMatrixMultChainOptimization extends
HopRewriteRule
optimizeMMChain(hop, mmChain, mmOperators, state);
}
- protected void optimizeMMChain(Hop hop, ArrayList<Hop> mmChain,
ArrayList<Hop> mmOperators, ProgramRewriteStatus state) {
+ protected void optimizeMMChain(Hop hop, List<Hop> mmChain, List<Hop>
mmOperators, ProgramRewriteStatus state) {
// Step 2: construct dims array
double[] dimsArray = new double[mmChain.size() + 1];
boolean dimsKnown = getDimsArray( hop, mmChain, dimsArray );
@@ -264,8 +265,8 @@ public class RewriteMatrixMultChainOptimization extends
HopRewriteRule
* @param split optimal order
* @param level log level
*/
- protected final void mmChainRelinkHops(Hop h, int i, int j,
ArrayList<Hop> mmChain,
- ArrayList<Hop> mmOperators, MutableInt opIndex, int[][] split,
int level)
+ protected final void mmChainRelinkHops(Hop h, int i, int j, List<Hop>
mmChain,
+ List<Hop> mmOperators, MutableInt opIndex, int[][] split, int
level)
{
//NOTE: the opIndex is a MutableInt in order to get the correct
positions
//in ragged chains like ((((a, b), c), (D, E), f), e) that
might be given
@@ -319,7 +320,7 @@ public class RewriteMatrixMultChainOptimization extends
HopRewriteRule
}
}
- protected static void clearLinksWithinChain( Hop hop, ArrayList<Hop>
operators )
+ protected static void clearLinksWithinChain( Hop hop, List<Hop>
operators )
{
for( int i=0; i < operators.size(); i++ ) {
Hop op = operators.get(i);
@@ -346,7 +347,7 @@ public class RewriteMatrixMultChainOptimization extends
HopRewriteRule
* @param dimsArray dimension array
* @return true if all dimensions known
*/
- protected static boolean getDimsArray( Hop hop, ArrayList<Hop> chain,
double[] dimsArray )
+ protected static boolean getDimsArray( Hop hop, List<Hop> chain,
double[] dimsArray )
{
boolean dimsKnown = true;
diff --git
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteMatrixMultChainOptimizationSparse.java
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteMatrixMultChainOptimizationSparse.java
index 69301fbda7..48b457f759 100644
---
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteMatrixMultChainOptimizationSparse.java
+++
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteMatrixMultChainOptimizationSparse.java
@@ -19,8 +19,8 @@
package org.apache.sysds.hops.rewrite;
-import java.util.ArrayList;
import java.util.Arrays;
+import java.util.List;
import org.apache.commons.lang3.mutable.MutableInt;
import org.apache.sysds.common.Types.OpOpData;
@@ -48,7 +48,7 @@ import org.apache.sysds.runtime.matrix.data.MatrixBlock;
public class RewriteMatrixMultChainOptimizationSparse extends
RewriteMatrixMultChainOptimization
{
@Override
- protected void optimizeMMChain(Hop hop, ArrayList<Hop> mmChain,
ArrayList<Hop> mmOperators, ProgramRewriteStatus state) {
+ protected void optimizeMMChain(Hop hop, List<Hop> mmChain, List<Hop>
mmOperators, ProgramRewriteStatus state) {
// Step 2: construct dims array and input matrices
double[] dimsArray = new double[mmChain.size() + 1];
boolean dimsKnown = getDimsArray( hop, mmChain, dimsArray );
@@ -127,7 +127,7 @@ public class RewriteMatrixMultChainOptimizationSparse
extends RewriteMatrixMultC
return split;
}
- private static boolean getInputMatrices(Hop hop, ArrayList<Hop> chain,
MMNode[] sketchArray, ProgramRewriteStatus state) {
+ private static boolean getInputMatrices(Hop hop, List<Hop> chain,
MMNode[] sketchArray, ProgramRewriteStatus state) {
boolean inputsAvail = true;
LocalVariableMap vars = state.getVariables();
diff --git
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteMatrixMultChainOptimizationTranspose.java
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteMatrixMultChainOptimizationTranspose.java
index 56702542a8..b327480609 100644
---
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteMatrixMultChainOptimizationTranspose.java
+++
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteMatrixMultChainOptimizationTranspose.java
@@ -21,6 +21,7 @@ package org.apache.sysds.hops.rewrite;
import java.util.ArrayList;
import java.util.Arrays;
+import java.util.List;
import org.apache.commons.lang3.mutable.MutableInt;
import org.apache.sysds.common.Types;
@@ -150,7 +151,7 @@ public class RewriteMatrixMultChainOptimizationTranspose
extends HopRewriteRule
mmChainIndex++;
}
else {
- ArrayList<Hop> tempList =
mmChain.get(mmChainIndex).getInput();
+ List<Hop> tempList =
mmChain.get(mmChainIndex).getInput();
if( tempList.size() != 2 ) {
throw new
HopsException(hop.printErrorLocation() + "Hops::rule_OptimizeMMChain():
AggBinary must have exactly two inputs.");
}
diff --git
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteRemoveDanglingParentReferences.java
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteRemoveDanglingParentReferences.java
index e786e51c34..573afe856c 100644
---
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteRemoveDanglingParentReferences.java
+++
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteRemoveDanglingParentReferences.java
@@ -20,6 +20,7 @@
package org.apache.sysds.hops.rewrite;
import java.util.ArrayList;
+import java.util.List;
import org.apache.sysds.common.Types.OpOp1;
import org.apache.sysds.common.Types.OpOpData;
@@ -86,7 +87,7 @@ public class RewriteRemoveDanglingParentReferences extends
HopRewriteRule
}
//process node itself and children recursively
- ArrayList<Hop> inputs = hop.getInput();
+ List<Hop> inputs = hop.getInput();
if( !pin && hop.getParent().isEmpty() && !isValidRootNode(hop)
) {
HopRewriteUtils.cleanupUnreferenced(hop);
count++;
diff --git
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteRemovePersistentReadWrite.java
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteRemovePersistentReadWrite.java
index e0d9033add..8f63587578 100644
---
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteRemovePersistentReadWrite.java
+++
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteRemovePersistentReadWrite.java
@@ -34,6 +34,7 @@ import org.apache.sysds.runtime.meta.MetaDataFormat;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
+import java.util.List;
/**
* This rewrite is a custom rewrite for JMLC in order to replace all
persistent reads
@@ -95,7 +96,7 @@ public class RewriteRemovePersistentReadWrite extends
HopRewriteRule
return;
//recursively process childs
- ArrayList<Hop> inputs = hop.getInput();
+ List<Hop> inputs = hop.getInput();
for( int i=0; i<inputs.size(); i++ )
rule_RemovePersistentDataOp( inputs.get(i) );
diff --git
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteRemoveReadAfterWrite.java
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteRemoveReadAfterWrite.java
index 5793632b2d..032b395f24 100644
---
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteRemoveReadAfterWrite.java
+++
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteRemoveReadAfterWrite.java
@@ -21,6 +21,7 @@ package org.apache.sysds.hops.rewrite;
import java.util.ArrayList;
import java.util.HashMap;
+import java.util.List;
import java.util.Map.Entry;
import org.apache.sysds.common.Types.OpOpData;
@@ -41,7 +42,6 @@ public class RewriteRemoveReadAfterWrite extends
HopRewriteRule
{
@Override
- @SuppressWarnings("unchecked")
public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots,
ProgramRewriteStatus state)
{
if( roots == null )
@@ -66,7 +66,7 @@ public class RewriteRemoveReadAfterWrite extends
HopRewriteRule
{
//rewire read consumers to write input
Hop input =
writes.get(rfname).getInput().get(0);
- ArrayList<Hop> parents = (ArrayList<Hop>)
rhop.getParent().clone();
+ List<Hop> parents = new
ArrayList<>(rhop.getParent());
for( Hop p : parents )
HopRewriteUtils.replaceChildReference(p, rhop, input);
}
diff --git
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteRemoveUnnecessaryCasts.java
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteRemoveUnnecessaryCasts.java
index 341323ef8f..c1f0648a8d 100644
---
a/src/main/java/org/apache/sysds/hops/rewrite/RewriteRemoveUnnecessaryCasts.java
+++
b/src/main/java/org/apache/sysds/hops/rewrite/RewriteRemoveUnnecessaryCasts.java
@@ -20,6 +20,7 @@
package org.apache.sysds.hops.rewrite;
import java.util.ArrayList;
+import java.util.List;
import org.apache.sysds.hops.Hop;
import org.apache.sysds.hops.UnaryOp;
@@ -60,7 +61,6 @@ public class RewriteRemoveUnnecessaryCasts extends
HopRewriteRule
return root;
}
- @SuppressWarnings("unchecked")
private void rule_RemoveUnnecessaryCasts( Hop hop )
{
//check mark processed
@@ -68,7 +68,7 @@ public class RewriteRemoveUnnecessaryCasts extends
HopRewriteRule
return;
//recursively process childs
- ArrayList<Hop> inputs = hop.getInput();
+ List<Hop> inputs = hop.getInput();
for( int i=0; i<inputs.size(); i++ )
rule_RemoveUnnecessaryCasts( inputs.get(i) );
@@ -82,11 +82,11 @@ public class RewriteRemoveUnnecessaryCasts extends
HopRewriteRule
//if input/output types match, no need to cast
if( vtIn == vtOut && vtIn != ValueType.UNKNOWN )
{
- ArrayList<Hop> parents = hop.getParent();
+ List<Hop> parents = hop.getParent();
for( int i=0; i<parents.size(); i++ ) //for all
parents
{
Hop p = parents.get(i);
- ArrayList<Hop> pin = p.getInput();
+ List<Hop> pin = p.getInput();
for( int j=0; j<pin.size(); j++ ) //for
all parent childs
{
Hop pinj = pin.get(j);
@@ -112,7 +112,7 @@ public class RewriteRemoveUnnecessaryCasts extends
HopRewriteRule
|| (uop1.getOp()==OpOp1.CAST_AS_SCALAR &&
uop2.getOp()==OpOp1.CAST_AS_MATRIX) ) {
Hop input = uop2.getInput().get(0);
//rewire parents
- ArrayList<Hop> parents = (ArrayList<Hop>)
hop.getParent().clone();
+ List<Hop> parents = new
ArrayList<>(hop.getParent());
for( Hop p : parents )
HopRewriteUtils.replaceChildReference(p, hop, input);
}
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/IColIndex.java
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/IColIndex.java
index 8da8ad518f..ed9581b72a 100644
---
a/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/IColIndex.java
+++
b/src/main/java/org/apache/sysds/runtime/compress/colgroup/indexes/IColIndex.java
@@ -282,7 +282,7 @@ public interface IColIndex {
bi.next();
}
- return new Pair<int[],int[]>(ar, br);
+ return new Pair<>(ar, br);
}
/** A Class for slice results containing indexes for the slicing of
dictionaries, and the resulting column index */
diff --git
a/src/main/java/org/apache/sysds/runtime/compress/workload/WorkloadAnalyzer.java
b/src/main/java/org/apache/sysds/runtime/compress/workload/WorkloadAnalyzer.java
index a4c15b2b53..bec6ac18aa 100644
---
a/src/main/java/org/apache/sysds/runtime/compress/workload/WorkloadAnalyzer.java
+++
b/src/main/java/org/apache/sysds/runtime/compress/workload/WorkloadAnalyzer.java
@@ -397,7 +397,7 @@ public class WorkloadAnalyzer {
}
else if(hop instanceof BinaryOp) {
if(HopRewriteUtils.isBinary(hop, OpOp2.CBIND)) {
- ArrayList<Hop> in = hop.getInput();
+ List<Hop> in = hop.getInput();
o = new OpNormal(hop, true);
if(isOverlapping(in.get(0)) ||
isOverlapping(in.get(1))) {
overlapping.add(hop.getHopID());
@@ -412,7 +412,7 @@ public class WorkloadAnalyzer {
return;
}
else {
- ArrayList<Hop> in = hop.getInput();
+ List<Hop> in = hop.getInput();
final boolean ol0 =
isOverlapping(in.get(0));
final boolean ol1 =
isOverlapping(in.get(1));
final boolean ol = ol0 || ol1;
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/FunctionProgramBlock.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/FunctionProgramBlock.java
index 8bae43d6ef..00c975719a 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/FunctionProgramBlock.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/FunctionProgramBlock.java
@@ -29,18 +29,11 @@ import org.apache.sysds.common.Types.ExecMode;
import org.apache.sysds.common.Types.FunctionBlock;
import org.apache.sysds.common.Types.ValueType;
import org.apache.sysds.conf.ConfigurationManager;
-import org.apache.sysds.conf.DMLConfig;
-import org.apache.sysds.hops.OptimizerUtils;
-import org.apache.sysds.hops.fedplanner.AFederatedPlanner;
-import org.apache.sysds.hops.fedplanner.FTypes;
-import org.apache.sysds.hops.fedplanner.FTypes.FederatedPlanner;
import org.apache.sysds.hops.recompile.Recompiler;
import org.apache.sysds.hops.recompile.Recompiler.ResetType;
import org.apache.sysds.parser.DataIdentifier;
-import org.apache.sysds.parser.FunctionStatementBlock;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.DMLScriptException;
-import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.util.ProgramConverter;
@@ -57,7 +50,6 @@ public class FunctionProgramBlock extends ProgramBlock
implements FunctionBlock
private boolean _recompileOnce = false;
private boolean _nondeterministic = false;
- private boolean _isFedPlan = false;
public FunctionProgramBlock( Program prog, List<DataIdentifier>
inputParams, List<DataIdentifier> outputParams) {
super(prog);
@@ -129,14 +121,7 @@ public class FunctionProgramBlock extends ProgramBlock
implements FunctionBlock
boolean codegen =
ConfigurationManager.isCodegenEnabled();
boolean singlenode =
DMLScript.getGlobalExecMode() == ExecMode.SINGLE_NODE;
ResetType reset = (codegen || singlenode) ?
ResetType.RESET_KNOWN_DIMS : ResetType.RESET;
-
Recompiler.recompileProgramBlockHierarchy(_childBlocks, tmp, _tid, false,
reset);
- if (shouldRunFedPlanner(ec)) {
-
recompileFederatedPlan((LocalVariableMap) ec.getVariables().clone());
- // recreate instructions/LOPs for new
updated HOPs
-
Recompiler.recompileProgramBlockHierarchy(_childBlocks, tmp, _tid, false,
reset);
- }
-
if( DMLScript.STATISTICS ){
long t1 = System.nanoTime();
@@ -165,44 +150,6 @@ public class FunctionProgramBlock extends ProgramBlock
implements FunctionBlock
// check return values
checkOutputParameters(ec.getVariables());
}
-
- private boolean shouldRunFedPlanner(ExecutionContext ec) {
- String planner =
ConfigurationManager.getDMLConfig().getTextValue(DMLConfig.FEDERATED_PLANNER);
- if (!OptimizerUtils.FEDERATED_COMPILATION &&
!FTypes.FederatedPlanner.isCompiled(planner))
- return false;
- for (String varName : ec.getVariables().keySet()) {
- Data variable = ec.getVariable(varName);
- if (variable instanceof CacheableData<?> &&
((CacheableData<?>) variable).isFederated()) {
- _isFedPlan = true;
- return true;
- }
- }
- if (_isFedPlan) {
- _isFedPlan = false;
- // current function uses HOPs with FED execution type.
Remove the forced FED execution type by running
- // planner again
- return true;
- }
- else {
- return false;
- }
- }
-
- /**
- * Recompile the HOPs of the function, keeping federation in mind.
- * @param variableMap The variable map for the function arguments
- */
- private void recompileFederatedPlan(LocalVariableMap variableMap) {
- String splanner = ConfigurationManager.getDMLConfig()
- .getTextValue(DMLConfig.FEDERATED_PLANNER);
- AFederatedPlanner planner = FederatedPlanner
- .valueOf(splanner.toUpperCase()).getPlanner();
- if (planner == null)
- // unreachable, if planner does not support compilation
cost based would be chosen
- throw new DMLRuntimeException(
- "Recompilation chose to apply federation
planner, but configured planner does not support compilation.");
- planner.rewriteFunctionDynamic((FunctionStatementBlock) _sb,
variableMap);
- }
protected void checkOutputParameters( LocalVariableMap vars )
{
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamservUtils.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamservUtils.java
index 2a6877d89e..29ea19b713 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamservUtils.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/paramserv/ParamservUtils.java
@@ -20,7 +20,6 @@
package org.apache.sysds.runtime.controlprogram.paramserv;
import java.io.IOException;
-import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Map.Entry;
@@ -344,8 +343,7 @@ public class ParamservUtils {
mhop.setMaxNumThreads(k);
recompiled = true;
}
- ArrayList<Hop> inputs = hop.getInput();
- for (Hop h : inputs) {
+ for (Hop h : hop.getInput()) {
recompiled |= rAssignParallelismAndRecompile(h, k,
recompiled);
}
hop.setVisited();
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/CostEstimator.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/CostEstimator.java
index da407959b6..29be3d8c94 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/CostEstimator.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/CostEstimator.java
@@ -19,8 +19,8 @@
package org.apache.sysds.runtime.controlprogram.parfor.opt;
-import java.util.ArrayList;
import java.util.Collection;
+import java.util.List;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
@@ -207,16 +207,16 @@ public abstract class CostEstimator
return -1;
}
- protected double getMaxEstimate( TestMeasure measure,
ArrayList<OptNode> nodes, ExecType et ) {
+ protected double getMaxEstimate( TestMeasure measure, List<OptNode>
nodes, ExecType et ) {
return nodes.stream().mapToDouble(n -> getEstimate(measure, n,
et))
.max().orElse(Double.NEGATIVE_INFINITY);
}
- protected double getSumEstimate( TestMeasure measure,
ArrayList<OptNode> nodes, ExecType et ) {
+ protected double getSumEstimate( TestMeasure measure, List<OptNode>
nodes, ExecType et ) {
return nodes.stream().mapToDouble(n -> getEstimate(measure, n,
et)).sum();
}
- protected double getWeightedEstimate( TestMeasure measure,
ArrayList<OptNode> nodes, ExecType et ) {
+ protected double getWeightedEstimate( TestMeasure measure,
List<OptNode> nodes, ExecType et ) {
return nodes.stream().mapToDouble(n -> getEstimate(measure, n,
et)).sum() / nodes.size(); //weighting
}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/OptNode.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/OptNode.java
index f195721f31..a67e9a8673 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/OptNode.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/OptNode.java
@@ -23,6 +23,7 @@ import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
+import java.util.List;
import java.util.Set;
import org.apache.commons.lang3.ArrayUtils;
@@ -82,7 +83,7 @@ public class OptNode
}
//child nodes
- private ArrayList<OptNode> _childs = null;
+ private List<OptNode> _childs = null;
//node configuration
private long _id = -1;
@@ -178,17 +179,17 @@ public class OptNode
_childs.add( child );
}
- public void addChilds( ArrayList<OptNode> childs ) {
+ public void addChilds( List<OptNode> childs ) {
if( _childs==null )
_childs = new ArrayList<>();
_childs.addAll( childs );
}
- public void setChilds(ArrayList<OptNode> childs) {
+ public void setChilds(List<OptNode> childs) {
_childs = childs;
}
- public ArrayList<OptNode> getChilds() {
+ public List<OptNode> getChilds() {
return _childs;
}
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/OptTreeConverter.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/OptTreeConverter.java
index 38e57429c7..5cc354d30e 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/OptTreeConverter.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/OptTreeConverter.java
@@ -442,9 +442,9 @@ public class OptTreeConverter
return ret;
}
- public static ArrayList<OptNode> rCreateAbstractOptNodes(Hop hop,
LocalVariableMap vars, OptTreePlanMappingAbstract hlMap, Set<String> memo) {
- ArrayList<OptNode> ret = new ArrayList<>();
- ArrayList<Hop> in = hop.getInput();
+ public static List<OptNode> rCreateAbstractOptNodes(Hop hop,
LocalVariableMap vars, OptTreePlanMappingAbstract hlMap, Set<String> memo) {
+ List<OptNode> ret = new ArrayList<>();
+ List<Hop> in = hop.getInput();
if( hop.isVisited() )
return ret;
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/OptimizerRuleBased.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/OptimizerRuleBased.java
index fb97c6a15c..1d0b07ef35 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/OptimizerRuleBased.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/OptimizerRuleBased.java
@@ -1234,11 +1234,10 @@ public class OptimizerRuleBased extends Optimizer {
protected void rAssignRemainingParallelism(OptNode n, int parforK, int
opsK)
{
- ArrayList<OptNode> childs = n.getChilds();
- if( childs != null )
+ if( n.getChilds() != null )
{
boolean recompileSB = false;
- for( OptNode c : childs )
+ for( OptNode c : n.getChilds() )
{
//NOTE: we cannot shortcut with
c.setSerialParFor() on par=1 because
//this would miss to recompile multi-threaded
hop operations
@@ -1548,8 +1547,7 @@ public class OptimizerRuleBased extends Optimizer {
{
//check that all parents are transpose-safe
operations
//(even a transient write would not be safe due
to indirection into other DAGs)
- ArrayList<Hop> parent = h.getParent();
- for( Hop p : parent )
+ for( Hop p : h.getParent() )
ret &= p.isTransposeSafe();
}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/ProgramRecompiler.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/ProgramRecompiler.java
index 191d58cfa3..8405cda3ee 100644
---
a/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/ProgramRecompiler.java
+++
b/src/main/java/org/apache/sysds/runtime/controlprogram/parfor/opt/ProgramRecompiler.java
@@ -20,6 +20,7 @@
package org.apache.sysds.runtime.controlprogram.parfor.opt;
import java.util.ArrayList;
+import java.util.List;
import org.apache.sysds.common.Types.ExecType;
import org.apache.sysds.conf.ConfigurationManager;
@@ -373,8 +374,7 @@ public class ProgramRecompiler
if( hop.isVisited() )
return ret;
- ArrayList<Hop> in = hop.getInput();
-
+ List<Hop> in = hop.getInput();
if( hop instanceof IndexingOp )
{
String inMatrix = hop.getInput().get(0).getName();
@@ -409,8 +409,7 @@ public class ProgramRecompiler
if( hop.isVisited() )
return ret;
- ArrayList<Hop> in = hop.getInput();
-
+ List<Hop> in = hop.getInput();
if( hop instanceof IndexingOp )
{
String inMatrix = hop.getInput().get(0).getName();
diff --git
a/src/main/java/org/apache/sysds/runtime/frame/data/columns/DDCArray.java
b/src/main/java/org/apache/sysds/runtime/frame/data/columns/DDCArray.java
index 3b7200c7be..0b5a31f13a 100644
--- a/src/main/java/org/apache/sysds/runtime/frame/data/columns/DDCArray.java
+++ b/src/main/java/org/apache/sysds/runtime/frame/data/columns/DDCArray.java
@@ -64,11 +64,11 @@ public class DDCArray<T> extends ACompressedArray<T> {
}
public <J> DDCArray<J> setDict(Array<J> dict) {
- return new DDCArray<J>(dict, map);
+ return new DDCArray<>(dict, map);
}
public DDCArray<T> nullDict() {
- return new DDCArray<T>(null, map);
+ return new DDCArray<>(null, map);
}
private static int getTryThreshold(ValueType t, int allRows, long
inMemSize) {