This is an automated email from the ASF dual-hosted git repository. mboehm7 pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/systemds.git
commit 48cd678d1f80d7933fcdfed66d19b9b9614a6c3f Author: Matthias Boehm <[email protected]> AuthorDate: Tue Apr 4 20:28:11 2023 +0200 [MINOR] Additional reduction of unnecessary recompilation overhead --- src/main/java/org/apache/sysds/hops/TernaryOp.java | 9 +++++--- .../java/org/apache/sysds/lops/CentralMoment.java | 2 +- .../java/org/apache/sysds/lops/CoVariance.java | 3 ++- .../java/org/apache/sysds/lops/Compression.java | 3 ++- src/main/java/org/apache/sysds/lops/Ctable.java | 7 ++++--- .../apache/sysds/lops/CumulativeOffsetBinary.java | 24 ++++++++-------------- src/main/java/org/apache/sysds/lops/Data.java | 3 ++- .../java/org/apache/sysds/lops/DeCompression.java | 14 +++++-------- src/main/java/org/apache/sysds/lops/Federated.java | 22 ++++++++------------ .../java/org/apache/sysds/lops/FunctionCallCP.java | 3 ++- .../org/apache/sysds/lops/GroupedAggregate.java | 3 ++- src/main/java/org/apache/sysds/lops/LeftIndex.java | 3 ++- src/main/java/org/apache/sysds/lops/Local.java | 15 +++++--------- src/main/java/org/apache/sysds/lops/MMCJ.java | 4 ++-- src/main/java/org/apache/sysds/lops/MMTSJ.java | 3 ++- .../java/org/apache/sysds/lops/MapMultChain.java | 3 ++- src/main/java/org/apache/sysds/lops/Nary.java | 11 +++------- src/main/java/org/apache/sysds/lops/PMMJ.java | 3 ++- .../apache/sysds/lops/ParameterizedBuiltin.java | 7 ++++--- .../java/org/apache/sysds/lops/PickByCount.java | 2 +- .../java/org/apache/sysds/lops/RightIndex.java | 4 ++-- .../java/org/apache/sysds/lops/SpoofFused.java | 3 ++- src/main/java/org/apache/sysds/lops/Sql.java | 3 ++- .../org/apache/sysds/lops/TernaryAggregate.java | 6 +++--- src/main/java/org/apache/sysds/lops/Transform.java | 3 ++- src/main/java/org/apache/sysds/lops/Unary.java | 3 ++- .../apache/sysds/lops/WeightedCrossEntropy.java | 3 ++- .../java/org/apache/sysds/lops/WeightedDivMM.java | 3 ++- .../org/apache/sysds/lops/WeightedSigmoid.java | 3 ++- .../org/apache/sysds/lops/WeightedSquaredLoss.java | 3 ++- .../org/apache/sysds/lops/WeightedUnaryMM.java | 3 ++- .../java/org/apache/sysds/lops/compile/Dag.java | 9 ++++---- .../runtime/instructions/InstructionUtils.java | 6 ++++++ .../instructions/cp/VariableCPInstruction.java | 2 +- 34 files changed, 100 insertions(+), 98 deletions(-) diff --git a/src/main/java/org/apache/sysds/hops/TernaryOp.java b/src/main/java/org/apache/sysds/hops/TernaryOp.java index bca5938051..58c1f564ee 100644 --- a/src/main/java/org/apache/sysds/hops/TernaryOp.java +++ b/src/main/java/org/apache/sysds/hops/TernaryOp.java @@ -338,13 +338,16 @@ public class TernaryOp extends MultiThreadedHop k= OptimizerUtils.getConstrainedNumThreads( _maxNumThreads ); Ternary plusmult = new Ternary(_op, - getInput().get(0).constructLops(), - getInput().get(1).constructLops(), - getInput().get(2).constructLops(), + getInput(0).constructLops(), + getInput(1).constructLops(), + getInput(2).constructLops(), getDataType(),getValueType(), et, k ); setOutputDimensions(plusmult); setLineNumbers(plusmult); setLops(plusmult); + + if( _op==OpOp3.IFELSE && getInput(0).getDataType().isScalar() ) + setRequiresRecompile(); //good chance of removing ops } @Override diff --git a/src/main/java/org/apache/sysds/lops/CentralMoment.java b/src/main/java/org/apache/sysds/lops/CentralMoment.java index bb825fe5cf..13558afdcf 100644 --- a/src/main/java/org/apache/sysds/lops/CentralMoment.java +++ b/src/main/java/org/apache/sysds/lops/CentralMoment.java @@ -80,7 +80,7 @@ public class CentralMoment extends Lop */ @Override public String getInstructions(String input1, String input2, String input3, String output) { - StringBuilder sb = new StringBuilder(); + StringBuilder sb = InstructionUtils.getStringBuilder(); if( input3 == null ) { sb.append(InstructionUtils.concatOperands( getExecType().toString(), "cm", diff --git a/src/main/java/org/apache/sysds/lops/CoVariance.java b/src/main/java/org/apache/sysds/lops/CoVariance.java index a68844fa35..4548114681 100644 --- a/src/main/java/org/apache/sysds/lops/CoVariance.java +++ b/src/main/java/org/apache/sysds/lops/CoVariance.java @@ -22,6 +22,7 @@ package org.apache.sysds.lops; import org.apache.sysds.common.Types.ExecType; import org.apache.sysds.common.Types.DataType; import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.runtime.instructions.InstructionUtils; /** * Lop to compute covariance between two 1D matrices @@ -81,7 +82,7 @@ public class CoVariance extends Lop */ @Override public String getInstructions(String input1, String input2, String input3, String output) { - StringBuilder sb = new StringBuilder(); + StringBuilder sb = InstructionUtils.getStringBuilder(); sb.append( getExecType() ); sb.append( Lop.OPERAND_DELIMITOR ); sb.append( "cov" ); diff --git a/src/main/java/org/apache/sysds/lops/Compression.java b/src/main/java/org/apache/sysds/lops/Compression.java index 3dd6a2bad2..6871dc8771 100644 --- a/src/main/java/org/apache/sysds/lops/Compression.java +++ b/src/main/java/org/apache/sysds/lops/Compression.java @@ -22,6 +22,7 @@ package org.apache.sysds.lops; import org.apache.sysds.common.Types.DataType; import org.apache.sysds.common.Types.ExecType; import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.runtime.instructions.InstructionUtils; public class Compression extends Lop { public static final String OPCODE = "compress"; @@ -55,7 +56,7 @@ public class Compression extends Lop { @Override public String getInstructions(String input1, String output) { - StringBuilder sb = new StringBuilder(); + StringBuilder sb = InstructionUtils.getStringBuilder(); sb.append(getExecType()); sb.append(Lop.OPERAND_DELIMITOR); sb.append(OPCODE); diff --git a/src/main/java/org/apache/sysds/lops/Ctable.java b/src/main/java/org/apache/sysds/lops/Ctable.java index 012c17f295..3384119ed2 100644 --- a/src/main/java/org/apache/sysds/lops/Ctable.java +++ b/src/main/java/org/apache/sysds/lops/Ctable.java @@ -22,6 +22,7 @@ package org.apache.sysds.lops; import org.apache.sysds.common.Types.ExecType; import org.apache.sysds.common.Types.DataType; import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.runtime.instructions.InstructionUtils; /** @@ -120,7 +121,7 @@ public class Ctable extends Lop @Override public String getInstructions(String input1, String input2, String input3, String output) { - StringBuilder sb = new StringBuilder(); + StringBuilder sb = InstructionUtils.getStringBuilder(); sb.append( getExecType() ); sb.append( Lop.OPERAND_DELIMITOR ); if( operation != Ctable.OperationTypes.CTABLE_EXPAND_SCALAR_WEIGHT ) @@ -144,7 +145,7 @@ public class Ctable extends Lop getInputs().get(2).prepInputOperand(input3)); sb.append( OPERAND_DELIMITOR ); - if ( this.getInputs().size() > 3 ) { + if ( getInputs().size() > 3 ) { sb.append(getInputs().get(3).getOutputParameters().getLabel()); sb.append(LITERAL_PREFIX); sb.append((getInputs().get(3).getType() == Type.Data && ((Data)getInputs().get(3)).isLiteral()) ); @@ -166,7 +167,7 @@ public class Ctable extends Lop sb.append(true); sb.append( OPERAND_DELIMITOR ); } - sb.append( this.prepOutputOperand(output)); + sb.append( prepOutputOperand(output)); sb.append( OPERAND_DELIMITOR ); sb.append( _ignoreZeros ); diff --git a/src/main/java/org/apache/sysds/lops/CumulativeOffsetBinary.java b/src/main/java/org/apache/sysds/lops/CumulativeOffsetBinary.java index 80654431ef..45e32f57af 100644 --- a/src/main/java/org/apache/sysds/lops/CumulativeOffsetBinary.java +++ b/src/main/java/org/apache/sysds/lops/CumulativeOffsetBinary.java @@ -23,6 +23,7 @@ import org.apache.sysds.common.Types.ExecType; import org.apache.sysds.common.Types.AggOp; import org.apache.sysds.common.Types.DataType; import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.runtime.instructions.InstructionUtils; public class CumulativeOffsetBinary extends Lop { @@ -89,24 +90,17 @@ public class CumulativeOffsetBinary extends Lop @Override public String getInstructions(String input1, String input2, String output) { - StringBuilder sb = new StringBuilder(); - sb.append( getExecType() ); - sb.append( OPERAND_DELIMITOR ); - sb.append( getOpcode() ); - sb.append( OPERAND_DELIMITOR ); - sb.append( getInputs().get(0).prepInputOperand(input1) ); - sb.append( OPERAND_DELIMITOR ); - sb.append( getInputs().get(1).prepInputOperand(input2) ); - sb.append( OPERAND_DELIMITOR ); - sb.append( this.prepOutputOperand(output) ); + String inst = InstructionUtils.concatOperands( + getExecType().name(), getOpcode(), + getInputs().get(0).prepInputOperand(input1), + getInputs().get(1).prepInputOperand(input2), + prepOutputOperand(output) ); if( getExecType() == ExecType.SPARK ) { - sb.append( OPERAND_DELIMITOR ); - sb.append( _initValue ); - sb.append( OPERAND_DELIMITOR ); - sb.append( _broadcast ); + inst = InstructionUtils.concatOperands(inst, + String.valueOf(_initValue), String.valueOf(_broadcast) ); } - return sb.toString(); + return inst; } } diff --git a/src/main/java/org/apache/sysds/lops/Data.java b/src/main/java/org/apache/sysds/lops/Data.java index 0c03965b3f..1489c69842 100644 --- a/src/main/java/org/apache/sysds/lops/Data.java +++ b/src/main/java/org/apache/sysds/lops/Data.java @@ -27,6 +27,7 @@ import org.apache.sysds.common.Types.OpOpData; import org.apache.sysds.common.Types.ValueType; import org.apache.sysds.common.Types.ExecType; import org.apache.sysds.parser.DataExpression; +import org.apache.sysds.runtime.instructions.InstructionUtils; /** * Lop to represent data objects. Data objects represent matrices, vectors, @@ -404,7 +405,7 @@ public class Data extends Lop OutputParameters oparams = getOutputParameters(); - StringBuilder sb = new StringBuilder(); + StringBuilder sb = InstructionUtils.getStringBuilder(); sb.append( "CP" ); sb.append( OPERAND_DELIMITOR ); sb.append( "createvar" ); diff --git a/src/main/java/org/apache/sysds/lops/DeCompression.java b/src/main/java/org/apache/sysds/lops/DeCompression.java index 780767ed09..a54bf6ee00 100644 --- a/src/main/java/org/apache/sysds/lops/DeCompression.java +++ b/src/main/java/org/apache/sysds/lops/DeCompression.java @@ -22,6 +22,7 @@ package org.apache.sysds.lops; import org.apache.sysds.common.Types.ExecType; import org.apache.sysds.common.Types.DataType; import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.runtime.instructions.InstructionUtils; public class DeCompression extends Lop { @@ -50,14 +51,9 @@ public class DeCompression extends Lop @Override public String getInstructions(String input1, String output) { - StringBuilder sb = new StringBuilder(); - sb.append( getExecType() ); - sb.append( Lop.OPERAND_DELIMITOR ); - sb.append( OPCODE ); - sb.append( OPERAND_DELIMITOR ); - sb.append( getInputs().get(0).prepInputOperand(input1)); - sb.append( OPERAND_DELIMITOR ); - sb.append( prepOutputOperand(output)); - return sb.toString(); + return InstructionUtils.concatOperands( + getExecType().name(), OPCODE, + getInputs().get(0).prepInputOperand(input1), + prepOutputOperand(output)); } } diff --git a/src/main/java/org/apache/sysds/lops/Federated.java b/src/main/java/org/apache/sysds/lops/Federated.java index 2ed1de2fdb..46e1f0c6e0 100644 --- a/src/main/java/org/apache/sysds/lops/Federated.java +++ b/src/main/java/org/apache/sysds/lops/Federated.java @@ -22,6 +22,8 @@ package org.apache.sysds.lops; import java.util.HashMap; +import org.apache.sysds.runtime.instructions.InstructionUtils; + import static org.apache.sysds.common.Types.DataType; import static org.apache.sysds.common.Types.ValueType; import static org.apache.sysds.parser.DataExpression.FED_ADDRESSES; @@ -70,20 +72,12 @@ public class Federated extends Lop { @Override public String getInstructions(String type, String addresses, String ranges, String object, String output) { - StringBuilder sb = new StringBuilder("FED"); - sb.append(OPERAND_DELIMITOR); - sb.append("fedinit"); - sb.append(OPERAND_DELIMITOR); - sb.append(_type.prepScalarInputOperand(type)); - sb.append(OPERAND_DELIMITOR); - sb.append(_addresses.prepScalarInputOperand(addresses)); - sb.append(OPERAND_DELIMITOR); - sb.append(_ranges.prepScalarInputOperand(ranges)); - sb.append(OPERAND_DELIMITOR); - sb.append(_localObject.prepScalarInputOperand(object)); - sb.append(OPERAND_DELIMITOR); - sb.append(prepOutputOperand(output)); - return sb.toString(); + return InstructionUtils.concatOperands( + "FED", "fedinit", _type.prepScalarInputOperand(type), + _addresses.prepScalarInputOperand(addresses), + _ranges.prepScalarInputOperand(ranges), + _localObject.prepScalarInputOperand(object), + prepOutputOperand(output)); } @Override diff --git a/src/main/java/org/apache/sysds/lops/FunctionCallCP.java b/src/main/java/org/apache/sysds/lops/FunctionCallCP.java index f6c24f8e82..5f106c3ae6 100644 --- a/src/main/java/org/apache/sysds/lops/FunctionCallCP.java +++ b/src/main/java/org/apache/sysds/lops/FunctionCallCP.java @@ -26,6 +26,7 @@ import org.apache.sysds.hops.FunctionOp; import org.apache.sysds.hops.Hop; import org.apache.sysds.common.Types.ExecType; import org.apache.sysds.parser.DMLProgram; +import org.apache.sysds.runtime.instructions.InstructionUtils; import org.apache.sysds.common.Builtins; import org.apache.sysds.common.Types.DataType; import org.apache.sysds.common.Types.ValueType; @@ -127,7 +128,7 @@ public class FunctionCallCP extends Lop return getInstructionsMultipleReturnBuiltins(inputs, outputs); } - StringBuilder inst = new StringBuilder(); + StringBuilder inst = InstructionUtils.getStringBuilder(); inst.append(getExecType()); inst.append(Lop.OPERAND_DELIMITOR); diff --git a/src/main/java/org/apache/sysds/lops/GroupedAggregate.java b/src/main/java/org/apache/sysds/lops/GroupedAggregate.java index 40850751d1..f7108e7925 100644 --- a/src/main/java/org/apache/sysds/lops/GroupedAggregate.java +++ b/src/main/java/org/apache/sysds/lops/GroupedAggregate.java @@ -26,6 +26,7 @@ import java.util.Map.Entry; import org.apache.sysds.common.Types.ExecType; import org.apache.sysds.parser.Statement; +import org.apache.sysds.runtime.instructions.InstructionUtils; import org.apache.sysds.common.Types.DataType; import org.apache.sysds.common.Types.ValueType; @@ -102,7 +103,7 @@ public class GroupedAggregate extends Lop */ @Override public String getInstructions(String output) { - StringBuilder sb = new StringBuilder(); + StringBuilder sb = InstructionUtils.getStringBuilder(); sb.append( getExecType() ); sb.append( Lop.OPERAND_DELIMITOR ); diff --git a/src/main/java/org/apache/sysds/lops/LeftIndex.java b/src/main/java/org/apache/sysds/lops/LeftIndex.java index 3c6bf0a640..0e5ad11fd4 100644 --- a/src/main/java/org/apache/sysds/lops/LeftIndex.java +++ b/src/main/java/org/apache/sysds/lops/LeftIndex.java @@ -24,6 +24,7 @@ import org.apache.sysds.common.Types.ExecType; import org.apache.sysds.common.Types.DataType; import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.runtime.instructions.InstructionUtils; public class LeftIndex extends Lop @@ -105,7 +106,7 @@ public class LeftIndex extends Lop @Override public String getInstructions(String lhsInput, String rhsInput, String rowl, String rowu, String coll, String colu, String output) { - StringBuilder sb = new StringBuilder(); + StringBuilder sb = InstructionUtils.getStringBuilder(); sb.append( getExecType() ); sb.append( OPERAND_DELIMITOR ); diff --git a/src/main/java/org/apache/sysds/lops/Local.java b/src/main/java/org/apache/sysds/lops/Local.java index 10ebc79edd..6bad6136aa 100644 --- a/src/main/java/org/apache/sysds/lops/Local.java +++ b/src/main/java/org/apache/sysds/lops/Local.java @@ -22,6 +22,7 @@ package org.apache.sysds.lops; import org.apache.sysds.common.Types.DataType; import org.apache.sysds.common.Types.ExecType; import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.runtime.instructions.InstructionUtils; public class Local extends Lop { public static final String OPCODE = "local"; @@ -40,15 +41,9 @@ public class Local extends Lop { @Override public String getInstructions(String input1, String output) { - StringBuilder sb = new StringBuilder(); - sb.append(getExecType()); - sb.append(Lop.OPERAND_DELIMITOR); - sb.append(OPCODE); - sb.append(OPERAND_DELIMITOR); - sb.append(getInputs().get(0).prepInputOperand(input1)); - sb.append(OPERAND_DELIMITOR); - sb.append(prepOutputOperand(output)); - - return sb.toString(); + return InstructionUtils.concatOperands( + getExecType().name(), OPCODE, + getInputs().get(0).prepInputOperand(input1), + prepOutputOperand(output)); } } diff --git a/src/main/java/org/apache/sysds/lops/MMCJ.java b/src/main/java/org/apache/sysds/lops/MMCJ.java index 7659a6631f..cbd7df97a5 100644 --- a/src/main/java/org/apache/sysds/lops/MMCJ.java +++ b/src/main/java/org/apache/sysds/lops/MMCJ.java @@ -20,7 +20,7 @@ package org.apache.sysds.lops; import org.apache.sysds.hops.AggBinaryOp.SparkAggType; - +import org.apache.sysds.runtime.instructions.InstructionUtils; import org.apache.sysds.common.Types.ExecType; import org.apache.sysds.common.Types.DataType; @@ -86,7 +86,7 @@ public class MMCJ extends Lop @Override public String getInstructions(String input1, String input2, String output) { - StringBuilder sb = new StringBuilder(); + StringBuilder sb = InstructionUtils.getStringBuilder(); sb.append( getExecType() ); sb.append( Lop.OPERAND_DELIMITOR ); diff --git a/src/main/java/org/apache/sysds/lops/MMTSJ.java b/src/main/java/org/apache/sysds/lops/MMTSJ.java index cbde9b4d5c..e91391833e 100644 --- a/src/main/java/org/apache/sysds/lops/MMTSJ.java +++ b/src/main/java/org/apache/sysds/lops/MMTSJ.java @@ -24,6 +24,7 @@ import org.apache.sysds.common.Types.ExecType; import org.apache.sysds.common.Types.DataType; import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.runtime.instructions.InstructionUtils; /** @@ -80,7 +81,7 @@ public class MMTSJ extends Lop @Override public String getInstructions(String input_index1, String output_index) { - StringBuilder sb = new StringBuilder(); + StringBuilder sb = InstructionUtils.getStringBuilder(); sb.append( getExecType() ); sb.append( OPERAND_DELIMITOR ); sb.append( _multiPass ? "tsmm2" : "tsmm" ); diff --git a/src/main/java/org/apache/sysds/lops/MapMultChain.java b/src/main/java/org/apache/sysds/lops/MapMultChain.java index 9f68f3ed86..88ca82d8c6 100644 --- a/src/main/java/org/apache/sysds/lops/MapMultChain.java +++ b/src/main/java/org/apache/sysds/lops/MapMultChain.java @@ -24,6 +24,7 @@ import org.apache.sysds.common.Types.ExecType; import org.apache.sysds.common.Types.DataType; import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.runtime.instructions.InstructionUtils; public class MapMultChain extends Lop @@ -115,7 +116,7 @@ public class MapMultChain extends Lop public String getInstructions(String input1, String input2, String input3, String output) { //Spark instruction XtwXv - StringBuilder sb = new StringBuilder(); + StringBuilder sb = InstructionUtils.getStringBuilder(); sb.append(getExecType()); sb.append(Lop.OPERAND_DELIMITOR); diff --git a/src/main/java/org/apache/sysds/lops/Nary.java b/src/main/java/org/apache/sysds/lops/Nary.java index 6fb9759096..68e7191ad9 100644 --- a/src/main/java/org/apache/sysds/lops/Nary.java +++ b/src/main/java/org/apache/sysds/lops/Nary.java @@ -25,6 +25,7 @@ import org.apache.sysds.common.Types.ExecType; import org.apache.sysds.common.Types.DataType; import org.apache.sysds.common.Types.OpOpN; import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.runtime.instructions.InstructionUtils; /** * Lop to perform an operation on a variable number of operands. @@ -88,21 +89,15 @@ public class Nary extends Lop { */ @Override public String getInstructions(String[] inputs, String output) { - String opString = getOpcode(); - - StringBuilder sb = new StringBuilder(); - + StringBuilder sb = InstructionUtils.getStringBuilder(); sb.append(getExecType()); sb.append(Lop.OPERAND_DELIMITOR); - - sb.append(opString); + sb.append(getOpcode()); sb.append(OPERAND_DELIMITOR); - for( int i=0; i<inputs.length; i++ ) { sb.append(getInputs().get(i).prepInputOperand(inputs[i])); sb.append(OPERAND_DELIMITOR); } - sb.append(prepOutputOperand(output)); return sb.toString(); diff --git a/src/main/java/org/apache/sysds/lops/PMMJ.java b/src/main/java/org/apache/sysds/lops/PMMJ.java index e0c8dcf539..c25da78298 100644 --- a/src/main/java/org/apache/sysds/lops/PMMJ.java +++ b/src/main/java/org/apache/sysds/lops/PMMJ.java @@ -24,6 +24,7 @@ import org.apache.sysds.common.Types.ExecType; import org.apache.sysds.common.Types.DataType; import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.runtime.instructions.InstructionUtils; public class PMMJ extends Lop { @@ -80,7 +81,7 @@ public class PMMJ extends Lop @Override public String getInstructions(String input_index1, String input_index2, String input_index3, String output_index) { - StringBuilder sb = new StringBuilder(); + StringBuilder sb = InstructionUtils.getStringBuilder(); sb.append(getExecType()); sb.append(Lop.OPERAND_DELIMITOR); diff --git a/src/main/java/org/apache/sysds/lops/ParameterizedBuiltin.java b/src/main/java/org/apache/sysds/lops/ParameterizedBuiltin.java index eb8174dbca..fdf23a7b62 100644 --- a/src/main/java/org/apache/sysds/lops/ParameterizedBuiltin.java +++ b/src/main/java/org/apache/sysds/lops/ParameterizedBuiltin.java @@ -28,6 +28,7 @@ import org.apache.sysds.common.Types.ExecType; import org.apache.sysds.common.Types.DataType; import org.apache.sysds.common.Types.ParamBuiltinOp; import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.runtime.instructions.InstructionUtils; /** @@ -86,7 +87,7 @@ public class ParameterizedBuiltin extends Lop @Override public String getInstructions(String output) { - StringBuilder sb = new StringBuilder(); + StringBuilder sb = InstructionUtils.getStringBuilder(); sb.append( getExecType() ); sb.append( Lop.OPERAND_DELIMITOR ); @@ -224,7 +225,7 @@ public class ParameterizedBuiltin extends Lop @Override public String toString() { - StringBuilder sb = new StringBuilder(); + StringBuilder sb = InstructionUtils.getStringBuilder(); sb.append(_operation.toString()); if( !getInputs().isEmpty() ) @@ -243,7 +244,7 @@ public class ParameterizedBuiltin extends Lop } private static String compileGenericParamMap(HashMap<String, Lop> params) { - StringBuilder sb = new StringBuilder(); + StringBuilder sb = InstructionUtils.getStringBuilder(); for ( Entry<String, Lop> e : params.entrySet() ) { sb.append(e.getKey()); sb.append(NAME_VALUE_SEPARATOR); diff --git a/src/main/java/org/apache/sysds/lops/PickByCount.java b/src/main/java/org/apache/sysds/lops/PickByCount.java index 2319a5c94d..948b0fcd7e 100644 --- a/src/main/java/org/apache/sysds/lops/PickByCount.java +++ b/src/main/java/org/apache/sysds/lops/PickByCount.java @@ -83,7 +83,7 @@ public class PickByCount extends Lop */ @Override public String getInstructions(String input1, String input2, String output) { - StringBuilder sb = new StringBuilder(); + StringBuilder sb = InstructionUtils.getStringBuilder(); sb.append( getExecType() ); sb.append( Lop.OPERAND_DELIMITOR ); diff --git a/src/main/java/org/apache/sysds/lops/RightIndex.java b/src/main/java/org/apache/sysds/lops/RightIndex.java index 7858638ebb..e24e8b905a 100644 --- a/src/main/java/org/apache/sysds/lops/RightIndex.java +++ b/src/main/java/org/apache/sysds/lops/RightIndex.java @@ -20,7 +20,7 @@ package org.apache.sysds.lops; import org.apache.sysds.hops.AggBinaryOp.SparkAggType; - +import org.apache.sysds.runtime.instructions.InstructionUtils; import org.apache.sysds.common.Types.ExecType; import org.apache.sysds.common.Types.DataType; @@ -90,7 +90,7 @@ public class RightIndex extends Lop @Override public String getInstructions(String input, String rowl, String rowu, String coll, String colu, String output) { - StringBuilder sb = new StringBuilder(); + StringBuilder sb = InstructionUtils.getStringBuilder(); sb.append( getExecType() ); sb.append( OPERAND_DELIMITOR ); sb.append( getOpcode() ); diff --git a/src/main/java/org/apache/sysds/lops/SpoofFused.java b/src/main/java/org/apache/sysds/lops/SpoofFused.java index e393f5fb0c..300ba94a24 100644 --- a/src/main/java/org/apache/sysds/lops/SpoofFused.java +++ b/src/main/java/org/apache/sysds/lops/SpoofFused.java @@ -23,6 +23,7 @@ import java.util.ArrayList; import org.apache.sysds.hops.codegen.SpoofCompiler.GeneratorAPI; +import org.apache.sysds.runtime.instructions.InstructionUtils; import org.apache.sysds.common.Types.ExecType; import org.apache.sysds.common.Types.DataType; @@ -102,7 +103,7 @@ public class SpoofFused extends Lop @Override public String getInstructions(String[] inputs, String[] outputs) { - StringBuilder sb = new StringBuilder(); + StringBuilder sb = InstructionUtils.getStringBuilder(); sb.append( getExecType() ); sb.append( OPERAND_DELIMITOR ); sb.append( "spoof" ); diff --git a/src/main/java/org/apache/sysds/lops/Sql.java b/src/main/java/org/apache/sysds/lops/Sql.java index 9d3d4642a4..c687346cbf 100644 --- a/src/main/java/org/apache/sysds/lops/Sql.java +++ b/src/main/java/org/apache/sysds/lops/Sql.java @@ -20,6 +20,7 @@ package org.apache.sysds.lops; import org.apache.sysds.parser.DataExpression; +import org.apache.sysds.runtime.instructions.InstructionUtils; import java.util.HashMap; @@ -50,7 +51,7 @@ public class Sql extends Lop { @Override public String getInstructions(String input1, String input2, String input3, String input4, String output) { - StringBuilder sb = new StringBuilder(); + StringBuilder sb = InstructionUtils.getStringBuilder(); // TODO spark sb.append("CP"); sb.append(OPERAND_DELIMITOR); diff --git a/src/main/java/org/apache/sysds/lops/TernaryAggregate.java b/src/main/java/org/apache/sysds/lops/TernaryAggregate.java index 65773c01aa..b9ac1a721e 100644 --- a/src/main/java/org/apache/sysds/lops/TernaryAggregate.java +++ b/src/main/java/org/apache/sysds/lops/TernaryAggregate.java @@ -26,6 +26,7 @@ import org.apache.sysds.common.Types.DataType; import org.apache.sysds.common.Types.Direction; import org.apache.sysds.common.Types.OpOp2; import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.runtime.instructions.InstructionUtils; public class TernaryAggregate extends Lop { @@ -66,10 +67,9 @@ public class TernaryAggregate extends Lop } @Override - public String getInstructions(String input1, String input2, String input3, String output) - + public String getInstructions(String input1, String input2, String input3, String output) { - StringBuilder sb = new StringBuilder(); + StringBuilder sb = InstructionUtils.getStringBuilder(); sb.append( getExecType() ); sb.append( OPERAND_DELIMITOR ); sb.append( getOpCode() ); diff --git a/src/main/java/org/apache/sysds/lops/Transform.java b/src/main/java/org/apache/sysds/lops/Transform.java index aab8e223a1..0fcdc09fbf 100644 --- a/src/main/java/org/apache/sysds/lops/Transform.java +++ b/src/main/java/org/apache/sysds/lops/Transform.java @@ -25,6 +25,7 @@ import org.apache.sysds.common.Types.ExecType; import org.apache.sysds.common.Types.DataType; import org.apache.sysds.common.Types.ReOrgOp; import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.runtime.instructions.InstructionUtils; /* @@ -150,7 +151,7 @@ public class Transform extends Lop } private String getInstructions(String input1, int numInputs, String output) { - StringBuilder sb = new StringBuilder(); + StringBuilder sb = InstructionUtils.getStringBuilder(); sb.append( getExecType() ); sb.append( OPERAND_DELIMITOR ); sb.append( getOpcode() ); diff --git a/src/main/java/org/apache/sysds/lops/Unary.java b/src/main/java/org/apache/sysds/lops/Unary.java index d95235798a..5e83c1de4d 100644 --- a/src/main/java/org/apache/sysds/lops/Unary.java +++ b/src/main/java/org/apache/sysds/lops/Unary.java @@ -25,6 +25,7 @@ import org.apache.sysds.common.Types.ExecType; import org.apache.sysds.common.Types.DataType; import org.apache.sysds.common.Types.OpOp1; import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.runtime.instructions.InstructionUtils; /** @@ -181,7 +182,7 @@ public class Unary extends Lop @Override public String getInstructions(String input1, String input2, String output) { - StringBuilder sb = new StringBuilder(); + StringBuilder sb = InstructionUtils.getStringBuilder(); sb.append( getExecType() ); sb.append( Lop.OPERAND_DELIMITOR ); diff --git a/src/main/java/org/apache/sysds/lops/WeightedCrossEntropy.java b/src/main/java/org/apache/sysds/lops/WeightedCrossEntropy.java index 616ce4cdd2..394ce2e84f 100644 --- a/src/main/java/org/apache/sysds/lops/WeightedCrossEntropy.java +++ b/src/main/java/org/apache/sysds/lops/WeightedCrossEntropy.java @@ -24,6 +24,7 @@ import org.apache.sysds.common.Types.ExecType; import org.apache.sysds.common.Types.DataType; import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.runtime.instructions.InstructionUtils; public class WeightedCrossEntropy extends Lop { @@ -66,7 +67,7 @@ public class WeightedCrossEntropy extends Lop @Override public String getInstructions(String input1, String input2, String input3, String input4, String output) { - StringBuilder sb = new StringBuilder(); + StringBuilder sb = InstructionUtils.getStringBuilder(); sb.append(getExecType()); diff --git a/src/main/java/org/apache/sysds/lops/WeightedDivMM.java b/src/main/java/org/apache/sysds/lops/WeightedDivMM.java index 1f119700f1..5bd3da374c 100644 --- a/src/main/java/org/apache/sysds/lops/WeightedDivMM.java +++ b/src/main/java/org/apache/sysds/lops/WeightedDivMM.java @@ -23,6 +23,7 @@ package org.apache.sysds.lops; import org.apache.sysds.common.Types.DataType; import org.apache.sysds.common.Types.ValueType; import org.apache.sysds.common.Types.ExecType; +import org.apache.sysds.runtime.instructions.InstructionUtils; import org.apache.sysds.runtime.meta.MatrixCharacteristics; public class WeightedDivMM extends Lop @@ -105,7 +106,7 @@ public class WeightedDivMM extends Lop @Override public String getInstructions(String input1, String input2, String input3, String input4, String output) { - StringBuilder sb = new StringBuilder(); + StringBuilder sb = InstructionUtils.getStringBuilder(); final ExecType et = getExecType(); diff --git a/src/main/java/org/apache/sysds/lops/WeightedSigmoid.java b/src/main/java/org/apache/sysds/lops/WeightedSigmoid.java index c0a048a6d7..3a71bb17ca 100644 --- a/src/main/java/org/apache/sysds/lops/WeightedSigmoid.java +++ b/src/main/java/org/apache/sysds/lops/WeightedSigmoid.java @@ -24,6 +24,7 @@ import org.apache.sysds.common.Types.ExecType; import org.apache.sysds.common.Types.DataType; import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.runtime.instructions.InstructionUtils; public class WeightedSigmoid extends Lop { @@ -63,7 +64,7 @@ public class WeightedSigmoid extends Lop @Override public String getInstructions(String input1, String input2, String input3, String output) { - StringBuilder sb = new StringBuilder(); + StringBuilder sb = InstructionUtils.getStringBuilder(); sb.append(getExecType()); diff --git a/src/main/java/org/apache/sysds/lops/WeightedSquaredLoss.java b/src/main/java/org/apache/sysds/lops/WeightedSquaredLoss.java index c0298274b2..7d0d20f485 100644 --- a/src/main/java/org/apache/sysds/lops/WeightedSquaredLoss.java +++ b/src/main/java/org/apache/sysds/lops/WeightedSquaredLoss.java @@ -24,6 +24,7 @@ import org.apache.sysds.common.Types.ExecType; import org.apache.sysds.common.Types.DataType; import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.runtime.instructions.InstructionUtils; public class WeightedSquaredLoss extends Lop { @@ -68,7 +69,7 @@ public class WeightedSquaredLoss extends Lop @Override public String getInstructions(String input1, String input2, String input3, String input4, String output) { - StringBuilder sb = new StringBuilder(); + StringBuilder sb = InstructionUtils.getStringBuilder(); sb.append(getExecType()); diff --git a/src/main/java/org/apache/sysds/lops/WeightedUnaryMM.java b/src/main/java/org/apache/sysds/lops/WeightedUnaryMM.java index a312d612d2..9705036528 100644 --- a/src/main/java/org/apache/sysds/lops/WeightedUnaryMM.java +++ b/src/main/java/org/apache/sysds/lops/WeightedUnaryMM.java @@ -25,6 +25,7 @@ import org.apache.sysds.common.Types.ExecType; import org.apache.sysds.common.Types.DataType; import org.apache.sysds.common.Types.OpOp1; import org.apache.sysds.common.Types.ValueType; +import org.apache.sysds.runtime.instructions.InstructionUtils; public class WeightedUnaryMM extends Lop { @@ -62,7 +63,7 @@ public class WeightedUnaryMM extends Lop @Override public String getInstructions(String input1, String input2, String input3, String output) { - StringBuilder sb = new StringBuilder(); + StringBuilder sb = InstructionUtils.getStringBuilder(); sb.append(getExecType()); diff --git a/src/main/java/org/apache/sysds/lops/compile/Dag.java b/src/main/java/org/apache/sysds/lops/compile/Dag.java index 786e281d64..260d8484a4 100644 --- a/src/main/java/org/apache/sysds/lops/compile/Dag.java +++ b/src/main/java/org/apache/sysds/lops/compile/Dag.java @@ -262,12 +262,12 @@ public class Dag<N extends Lop> LOG.trace("In delete updated variables"); // CANDIDATE list of variables which could have been updated in this statement block - HashMap<String, Lop> labelNodeMapping = new HashMap<>(); + HashMap<String, Lop> labelNodeMapping = new HashMap<>(nodeV.size()); // ACTUAL list of variables whose value is updated, AND the old value of the variable // is no longer accessible/used. - HashSet<String> updatedLabels = new HashSet<>(); - HashMap<String, Lop> updatedLabelsLineNum = new HashMap<>(); + HashSet<String> updatedLabels = new HashSet<>(nodeV.size()); + HashMap<String, Lop> updatedLabelsLineNum = new HashMap<>(nodeV.size()); // first capture all transient read variables for ( Lop node : nodeV ) { @@ -307,9 +307,8 @@ public class Dag<N extends Lop> } // generate RM instructions - Instruction rm_inst = null; for ( String label : updatedLabels ) { - rm_inst = VariableCPInstruction.prepareRemoveInstruction(label); + Instruction rm_inst = VariableCPInstruction.prepareRemoveInstruction(label); rm_inst.setLocation(updatedLabelsLineNum.get(label)); if( LOG.isTraceEnabled() ) LOG.trace(rm_inst.toString()); diff --git a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java index cd30034019..6d1325cc0c 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/InstructionUtils.java @@ -121,6 +121,12 @@ public class InstructionUtils { } }; + public static StringBuilder getStringBuilder() { + StringBuilder sb = _strBuilders.get(); + sb.setLength(0); //reuse allocated space + return sb; + } + public static int checkNumFields( String str, int expected ) { //note: split required for empty tokens int numParts = str.split(Instruction.OPERAND_DELIM).length; diff --git a/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java b/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java index 4083009c8c..9ad18b385e 100644 --- a/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java +++ b/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java @@ -1255,7 +1255,7 @@ public class VariableCPInstruction extends CPInstruction implements LineageTrace } public static Instruction prepareRemoveInstruction(String... varNames) { - StringBuilder sb = new StringBuilder(); + StringBuilder sb = InstructionUtils.getStringBuilder(); sb.append("CP"); sb.append(Lop.OPERAND_DELIMITOR); sb.append("rmvar");
