This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 6505c7b109 [MINOR] Improved dynamic recompilation (reduced recompile
overhead)
6505c7b109 is described below
commit 6505c7b1091db15058cd76550521602f131daae1
Author: Matthias Boehm <[email protected]>
AuthorDate: Mon Apr 3 22:05:30 2023 +0200
[MINOR] Improved dynamic recompilation (reduced recompile overhead)
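This patch makes selected changes to the instruction handling during
recompilation in order to reduce the recompilation overhead for scripts
like decisionTree, whose runtime is dominated by recompilation on small
data. Most notably, recompiled instruction sets in parfor contexts are now
shallow-copied with thread-id replacement instead of deep-copied, and the
memo table used for deep-copying HOP DAGs is reused per thread instead of
being reallocated on every call. On the existing Titanic test case, this
patch improves end-to-end runtime as follows, but it is beneficial for a
wide variety of scripts: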
OLD:
Total execution time: 19.646 sec.
HOP DAGs recompiled (PRED, SB): 136/71630.
HOP DAGs recompile time: 31.189 sec.
Functions recompiled: 58.
Functions recompile time: 0.238 sec.
NEW:
Total execution time: 17.148 sec.
HOP DAGs recompiled (PRED, SB): 136/71630.
HOP DAGs recompile time: 26.572 sec.
Functions recompiled: 58.
Functions recompile time: 0.231 sec.
---
.../apache/sysds/hops/recompile/Recompiler.java | 42 ++++----
.../instructions/cp/VariableCPInstruction.java | 109 ++++-----------------
.../gpu/context/GPUMemoryEviction.java | 5 +-
.../instructions/gpu/context/GPUMemoryManager.java | 1 -
.../runtime/lineage/LineageGPUCacheEviction.java | 1 +
.../sysds/runtime/util/ProgramConverter.java | 10 ++
.../part1/BuiltinDecisionTreeRealDataTest.java | 3 +-
7 files changed, 56 insertions(+), 115 deletions(-)
diff --git a/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java
b/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java
index 392d303ca4..d5c8169d68 100644
--- a/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java
+++ b/src/main/java/org/apache/sysds/hops/recompile/Recompiler.java
@@ -27,6 +27,7 @@ import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
+import java.util.Map.Entry;
import org.apache.hadoop.fs.FileSystem;
import org.apache.hadoop.fs.Path;
@@ -62,6 +63,7 @@ import org.apache.sysds.lops.compile.Dag;
import org.apache.sysds.lops.rewrite.LopRewriter;
import org.apache.sysds.parser.DMLProgram;
import org.apache.sysds.parser.DataExpression;
+import org.apache.sysds.parser.DataIdentifier;
import org.apache.sysds.parser.ForStatementBlock;
import org.apache.sysds.parser.IfStatementBlock;
import org.apache.sysds.parser.ParseInfo;
@@ -136,6 +138,12 @@ public class Recompiler {
@Override protected LopRewriter initialValue() {return new
LopRewriter();}
};
+ // additional reused objects to avoid repeated, incremental
reallocation on deepCopyDags
+ private static ThreadLocal<HashMap<Long,Hop>> _memoHop = new
ThreadLocal<HashMap<Long,Hop>>() {
+ @Override protected HashMap<Long,Hop> initialValue() { return
new HashMap<>(); }
+ @Override public HashMap<Long,Hop> get() { var tmp =
super.get(); tmp.clear(); return tmp; }
+ };
+
public enum ResetType {
RESET,
RESET_KNOWN_DIMS,
@@ -166,7 +174,7 @@ public class Recompiler {
// replace thread ids in new instructions
if( ProgramBlock.isThreadID(tid) ) //only in parfor context
- newInst =
ProgramConverter.createDeepCopyInstructionSet(newInst, tid, -1, null, null,
null, false, false);
+ newInst =
ProgramConverter.createShallowCopyInstructionSet(newInst, tid);
// remove writes if called through mlcontext or jmlc
if( ec.getVariables().getRegisteredOutputs() != null )
@@ -198,7 +206,7 @@ public class Recompiler {
// replace thread ids in new instructions
if( ProgramBlock.isThreadID(tid) ) //only in parfor context
- newInst =
ProgramConverter.createDeepCopyInstructionSet(newInst, tid, -1, null, null,
null, false, false);
+ newInst =
ProgramConverter.createShallowCopyInstructionSet(newInst, tid);
// explain recompiled instructions
if( DMLScript.EXPLAIN == ExplainType.RECOMPILE_RUNTIME )
@@ -220,7 +228,7 @@ public class Recompiler {
// replace thread ids in new instructions
if( ProgramBlock.isThreadID(tid) ) //only in parfor context
- newInst =
ProgramConverter.createDeepCopyInstructionSet(newInst, tid, -1, null, null,
null, false, false);
+ newInst =
ProgramConverter.createShallowCopyInstructionSet(newInst, tid);
// explain recompiled hops / instructions
if( DMLScript.EXPLAIN == ExplainType.RECOMPILE_RUNTIME )
@@ -242,7 +250,7 @@ public class Recompiler {
// replace thread ids in new instructions
if( ProgramBlock.isThreadID(tid) ) //only in parfor context
- newInst =
ProgramConverter.createDeepCopyInstructionSet(newInst, tid, -1, null, null,
null, false, false);
+ newInst =
ProgramConverter.createShallowCopyInstructionSet(newInst, tid);
// explain recompiled hops / instructions
if( DMLScript.EXPLAIN == ExplainType.RECOMPILE_RUNTIME )
@@ -390,7 +398,7 @@ public class Recompiler {
rSetMaxParallelism(hops, maxK);
// construct lops
- ArrayList<Lop> lops = new ArrayList<>();
+ ArrayList<Lop> lops = new ArrayList<>(hops.size());
for( Hop hopRoot : hops ){
lops.add(hopRoot.constructLops());
}
@@ -564,7 +572,7 @@ public class Recompiler {
try {
//note: need memo table over all independent DAGs in
order to
//account for shared transient reads (otherwise more
instructions generated)
- HashMap<Long, Hop> memo = new HashMap<>(); //orig ID,
new clone
+ HashMap<Long, Hop> memo = _memoHop.get(); //orig ID,
new clone
for( Hop hopRoot : hops )
ret.add(rDeepCopyHopsDag(hopRoot, memo));
}
@@ -585,7 +593,7 @@ public class Recompiler {
Hop ret = null;
try {
- HashMap<Long, Hop> memo = new HashMap<>(); //orig ID,
new clone
+ HashMap<Long, Hop> memo = _memoHop.get(); //orig ID,
new clone
ret = rDeepCopyHopsDag(hops, memo);
}
catch(Exception ex) {
@@ -955,8 +963,7 @@ public class Recompiler {
private static MatrixObject createOutputMatrix(long dim1, long dim2,
long nnz) {
MatrixObject moOut = new MatrixObject(ValueType.FP64, null);
int blksz = ConfigurationManager.getBlocksize();
- DataCharacteristics mc = new MatrixCharacteristics(
- dim1, dim2, blksz, nnz);
+ DataCharacteristics mc = new MatrixCharacteristics(dim1, dim2,
blksz, nnz);
MetaDataFormat meta = new MetaDataFormat(mc,null);
moOut.setMetaData(meta);
return moOut;
@@ -1140,19 +1147,12 @@ public class Recompiler {
* @param callVars Map of variables eligible for propagation.
* @param sb DML statement block.
*/
- public static void removeUpdatedScalars( LocalVariableMap callVars,
StatementBlock sb )
- {
- if( sb != null )
- {
+ public static void removeUpdatedScalars( LocalVariableMap callVars,
StatementBlock sb ) {
+ if( sb != null ) {
//remove updated scalar variables from constants
- for( String varname :
sb.variablesUpdated().getVariables().keySet() )
- {
- Data dat = callVars.get(varname);
- if( dat != null && dat.getDataType() ==
DataType.SCALAR )
- {
- callVars.remove(varname);
- }
- }
+ for( Entry<String, DataIdentifier> v :
sb.variablesUpdated().getVariables().entrySet() )
+ if( v.getValue().getDataType().isScalar() )
+ callVars.remove(v.getKey());
}
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java
index 4c55c7207b..4083009c8c 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/cp/VariableCPInstruction.java
@@ -1267,44 +1267,18 @@ public class VariableCPInstruction extends
CPInstruction implements LineageTrace
}
public static Instruction prepareCopyInstruction(String srcVar, String
destVar) {
- StringBuilder sb = new StringBuilder();
- sb.append("CP");
- sb.append(Lop.OPERAND_DELIMITOR);
- sb.append("cpvar");
- sb.append(Lop.OPERAND_DELIMITOR);
- sb.append(srcVar);
- sb.append(Lop.OPERAND_DELIMITOR);
- sb.append(destVar);
- return parseInstruction(sb.toString());
+ return parseInstruction(
+ InstructionUtils.concatOperands("CP", "cpvar", srcVar,
destVar));
}
public static Instruction prepMoveInstruction(String srcVar, String
destFileName, String format) {
- StringBuilder sb = new StringBuilder();
- sb.append("CP");
- sb.append(Lop.OPERAND_DELIMITOR);
- sb.append("mvvar");
- sb.append(Lop.OPERAND_DELIMITOR);
- sb.append(srcVar);
- sb.append(Lop.OPERAND_DELIMITOR);
- sb.append(destFileName);
- sb.append(Lop.OPERAND_DELIMITOR);
- sb.append(format);
- String str = sb.toString();
- return parseInstruction(str);
+ return parseInstruction(
+ InstructionUtils.concatOperands("CP", "mvvar", srcVar,
destFileName, format));
}
public static Instruction prepMoveInstruction(String srcVar, String
destVar) {
- // example: mvvar tempA A
- StringBuilder sb = new StringBuilder();
- sb.append("CP");
- sb.append(Lop.OPERAND_DELIMITOR);
- sb.append("mvvar");
- sb.append(Lop.OPERAND_DELIMITOR);
- sb.append(srcVar);
- sb.append(Lop.OPERAND_DELIMITOR);
- sb.append(destVar);
- String str = sb.toString();
- return parseInstruction(str);
+ return parseInstruction(
+ InstructionUtils.concatOperands("CP", "mvvar", srcVar,
destVar));
}
private static String getBasicCreatevarString(String varName, String
fileName, boolean fNameOverride, DataType dt, String format) {
@@ -1313,22 +1287,10 @@ public class VariableCPInstruction extends
CPInstruction implements LineageTrace
boolean lfNameOverride = fNameOverride && !ConfigurationManager
.getCompilerConfigFlag(ConfigType.IGNORE_TEMPORARY_FILENAMES);
- StringBuilder sb = new StringBuilder();
- sb.append("CP");
- sb.append(Lop.OPERAND_DELIMITOR);
- sb.append("createvar");
- sb.append(Lop.OPERAND_DELIMITOR);
- sb.append(varName);
- sb.append(Lop.OPERAND_DELIMITOR);
- sb.append(fileName); // Constant
CREATEVAR_FILE_NAME_VAR_POS is used to find a position of filename within a
string generated through this function.
- // If
this position of filename within this string changes then constant
CREATEVAR_FILE_NAME_VAR_POS to be updated.
- sb.append(Lop.OPERAND_DELIMITOR);
- sb.append(lfNameOverride);
- sb.append(Lop.OPERAND_DELIMITOR);
- sb.append(dt.toString());
- sb.append(Lop.OPERAND_DELIMITOR);
- sb.append(format);
- return sb.toString();
+ // Constant CREATEVAR_FILE_NAME_VAR_POS is used to find a
position of filename within a string generated through this function.
+ // If this position of filename within this string changes then
constant CREATEVAR_FILE_NAME_VAR_POS to be updated.
+ return InstructionUtils.concatOperands(
+ "CP", "createvar", varName, fileName,
String.valueOf(lfNameOverride), dt.toString(), format);
}
public static Instruction prepCreatevarInstruction(String varName,
String fileName, boolean fNameOverride, String format) {
@@ -1336,47 +1298,18 @@ public class VariableCPInstruction extends
CPInstruction implements LineageTrace
}
public static Instruction prepCreatevarInstruction(String varName,
String fileName, boolean fNameOverride, DataType dt, String format,
DataCharacteristics mc, UpdateType update) {
- StringBuilder sb = new StringBuilder();
- sb.append(getBasicCreatevarString(varName, fileName,
fNameOverride, dt, format));
-
- sb.append(Lop.OPERAND_DELIMITOR);
- sb.append(mc.getRows());
- sb.append(Lop.OPERAND_DELIMITOR);
- sb.append(mc.getCols());
- sb.append(Lop.OPERAND_DELIMITOR);
- sb.append(mc.getBlocksize());
- sb.append(Lop.OPERAND_DELIMITOR);
- sb.append(mc.getNonZeros());
- sb.append(Lop.OPERAND_DELIMITOR);
- sb.append(update.toString().toLowerCase());
-
- return parseInstruction(sb.toString());
+ return parseInstruction(InstructionUtils.concatOperands(
+ getBasicCreatevarString(varName, fileName,
fNameOverride, dt, format),
+ String.valueOf(mc.getRows()),
String.valueOf(mc.getCols()), String.valueOf(mc.getBlocksize()),
+ String.valueOf(mc.getNonZeros()),
update.toString().toLowerCase()));
}
public static Instruction prepCreatevarInstruction(String varName,
String fileName, boolean fNameOverride, DataType dt, String format,
DataCharacteristics mc, UpdateType update, boolean hasHeader, String delim,
boolean sparse) {
- StringBuilder sb = new StringBuilder();
- sb.append(getBasicCreatevarString(varName, fileName,
fNameOverride, dt, format));
-
- sb.append(Lop.OPERAND_DELIMITOR);
- sb.append(mc.getRows());
- sb.append(Lop.OPERAND_DELIMITOR);
- sb.append(mc.getCols());
- sb.append(Lop.OPERAND_DELIMITOR);
- sb.append(mc.getBlocksize());
- sb.append(Lop.OPERAND_DELIMITOR);
- sb.append(mc.getNonZeros());
- sb.append(Lop.OPERAND_DELIMITOR);
- sb.append(update.toString().toLowerCase());
-
- sb.append(Lop.OPERAND_DELIMITOR);
- sb.append(hasHeader);
- sb.append(Lop.OPERAND_DELIMITOR);
- sb.append(delim);
- sb.append(Lop.OPERAND_DELIMITOR);
- sb.append(sparse);
-
- String str = sb.toString();
- return parseInstruction(str);
+ return parseInstruction(InstructionUtils.concatOperands(
+ getBasicCreatevarString(varName, fileName,
fNameOverride, dt, format),
+ String.valueOf(mc.getRows()),
String.valueOf(mc.getCols()), String.valueOf(mc.getBlocksize()),
+ String.valueOf(mc.getNonZeros()),
update.toString().toLowerCase(),
+ String.valueOf(hasHeader), delim,
String.valueOf(sparse)));
}
@Override
@@ -1393,10 +1326,10 @@ public class VariableCPInstruction extends
CPInstruction implements LineageTrace
int iPos2 = StringUtils.indexOf(instString,
Lop.OPERAND_DELIMITOR, iPos+1);
StringBuilder sb = new StringBuilder();
- sb.append(instString.substring(0,iPos+1));
// It takes first part before file name.
+ sb.append(instString.substring(0,iPos+1)); // It takes
first part before file name.
// This will replace 'pattern' with 'replace' string
from file name.
sb.append(ProgramConverter.saveReplaceFilenameThreadID(instString.substring(iPos+1,
iPos2+1), pattern, replace));
- sb.append(instString.substring(iPos2+1));
// It takes last part after file name.
+ sb.append(instString.substring(iPos2+1)); // It takes
last part after file name.
instString = sb.toString();
}
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/gpu/context/GPUMemoryEviction.java
b/src/main/java/org/apache/sysds/runtime/instructions/gpu/context/GPUMemoryEviction.java
index 3a55fd906b..831390489e 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/gpu/context/GPUMemoryEviction.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/gpu/context/GPUMemoryEviction.java
@@ -23,11 +23,7 @@ import java.util.ArrayList;
import java.util.List;
import org.apache.sysds.api.DMLScript;
-import org.apache.sysds.runtime.lineage.LineageCacheConfig;
import org.apache.sysds.runtime.lineage.LineageCacheEntry;
-import org.apache.sysds.runtime.lineage.LineageCacheStatistics;
-import org.apache.sysds.runtime.lineage.LineageGPUCacheEviction;
-import org.apache.sysds.utils.GPUStatistics;
public class GPUMemoryEviction implements Runnable
{
@@ -41,6 +37,7 @@ public class GPUMemoryEviction implements Runnable
numEvicts = 0;
}
+ @SuppressWarnings("unused")
@Override
public void run() {
//long currentAvailableMemory = allocator.getAvailableMemory();
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/gpu/context/GPUMemoryManager.java
b/src/main/java/org/apache/sysds/runtime/instructions/gpu/context/GPUMemoryManager.java
index 57636726a2..a52c9eb6a9 100644
---
a/src/main/java/org/apache/sysds/runtime/instructions/gpu/context/GPUMemoryManager.java
+++
b/src/main/java/org/apache/sysds/runtime/instructions/gpu/context/GPUMemoryManager.java
@@ -22,7 +22,6 @@
import static jcuda.runtime.JCuda.cudaMemGetInfo;
import static jcuda.runtime.JCuda.cudaMemset;
-import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
diff --git
a/src/main/java/org/apache/sysds/runtime/lineage/LineageGPUCacheEviction.java
b/src/main/java/org/apache/sysds/runtime/lineage/LineageGPUCacheEviction.java
index 429794a3ce..9135565bd9 100644
---
a/src/main/java/org/apache/sysds/runtime/lineage/LineageGPUCacheEviction.java
+++
b/src/main/java/org/apache/sysds/runtime/lineage/LineageGPUCacheEviction.java
@@ -110,6 +110,7 @@ public class LineageGPUCacheEviction
return _startTimestamp;
}
+ @SuppressWarnings("unused")
private static void adjustD2HTransferSpeed(double sizeByte, double
copyTime) {
double sizeMB = sizeByte / (1024*1024);
double newTSpeed = sizeMB / copyTime; //bandwidth (MB/sec) +
java overhead
diff --git a/src/main/java/org/apache/sysds/runtime/util/ProgramConverter.java
b/src/main/java/org/apache/sysds/runtime/util/ProgramConverter.java
index 8a4476bf63..34a5287b70 100644
--- a/src/main/java/org/apache/sysds/runtime/util/ProgramConverter.java
+++ b/src/main/java/org/apache/sysds/runtime/util/ProgramConverter.java
@@ -466,6 +466,15 @@ public class ProgramConverter
return tmp;
}
+ public static ArrayList<Instruction>
createShallowCopyInstructionSet(ArrayList<Instruction> insts, long pid) {
+ ArrayList<Instruction> ret = new ArrayList<>();
+ for( Instruction inst : insts ) {
+ //save replacement of thread id references in
instructions
+ ret.add(saveReplaceThreadID( inst,
Lop.CP_ROOT_THREAD_ID, Lop.CP_CHILD_THREAD+pid));
+ }
+ return ret;
+ }
+
public static Instruction cloneInstruction( Instruction oInst, long
pid, boolean plain, boolean cpFunctions )
{
Instruction inst = null;
@@ -1724,6 +1733,7 @@ public class ProgramConverter
//////////
// CUSTOM SAFE LITERAL REPLACEMENT
+
/**
* In-place replacement of thread ids in filenames, functions names etc
*
diff --git
a/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinDecisionTreeRealDataTest.java
b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinDecisionTreeRealDataTest.java
index 42985ba362..2af6784c36 100644
---
a/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinDecisionTreeRealDataTest.java
+++
b/src/test/java/org/apache/sysds/test/functions/builtin/part1/BuiltinDecisionTreeRealDataTest.java
@@ -52,7 +52,8 @@ public class BuiltinDecisionTreeRealDataTest extends
AutomatedTestBase {
String HOME = SCRIPT_DIR + TEST_DIR;
fullDMLScriptName = HOME + TEST_NAME + ".dml";
- programArgs = new String[] {"-args", data, tfspec,
output("R")};
+ programArgs = new String[] {"-stats",
+ "-args", data, tfspec, output("R")};
runTest(true, false, null, -1);