This is an automated email from the ASF dual-hosted git repository.
mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push:
new 335239e929 [SYSTEMDS-3714] Extended N-gram statistics based on lineage
335239e929 is described below
commit 335239e929c27c65aa59af6f22375826d61a662f
Author: Jaybit0 <[email protected]>
AuthorDate: Tue Sep 3 12:35:13 2024 +0200
[SYSTEMDS-3714] Extended N-gram statistics based on lineage
Closes #2062.
---
src/main/java/org/apache/sysds/api/DMLOptions.java | 11 +-
src/main/java/org/apache/sysds/api/DMLScript.java | 2 +
.../sysds/runtime/controlprogram/ProgramBlock.java | 8 +-
.../sysds/runtime/instructions/Instruction.java | 3 +
.../sysds/runtime/lineage/LineageItemUtils.java | 39 ++++
.../apache/sysds/runtime/lineage/LineageMap.java | 4 +
.../sysds/runtime/matrix/data/MatrixBlock.java | 2 -
.../java/org/apache/sysds/utils/Statistics.java | 225 ++++++++++++++++++++-
.../org/apache/sysds/utils/stats/NGramBuilder.java | 32 +++
.../sysds/performance/matrix/MatrixAggregate.java | 5 +-
.../apache/sysds/test/applications/L2SVMTest.java | 19 +-
11 files changed, 330 insertions(+), 20 deletions(-)
diff --git a/src/main/java/org/apache/sysds/api/DMLOptions.java
b/src/main/java/org/apache/sysds/api/DMLOptions.java
index acacc39572..5bd5e019d0 100644
--- a/src/main/java/org/apache/sysds/api/DMLOptions.java
+++ b/src/main/java/org/apache/sysds/api/DMLOptions.java
@@ -57,6 +57,7 @@ public class DMLOptions {
public int statsCount = 10; //
Default statistics count
public int[] statsNGramSizes = { 3 }; //
Default n-gram tuple sizes
public int statsTopKNGrams = 10; // How
many of the most heavy hitting n-grams are displayed
+ public boolean statsNGramsUseLineage = true; // If
N-Grams use lineage for data-dependent tracking
public boolean fedStats = false; //
Whether to record and print the federated statistics
public int fedStatsCount = 10; //
Default federated statistics count
public boolean memStats = false; // max
memory statistics
@@ -219,7 +220,7 @@ public class DMLOptions {
dmlOptions.statsNGrams = line.hasOption("ngrams");
if (dmlOptions.statsNGrams){
String[] nGramArgs = line.getOptionValues("ngrams");
- if (nGramArgs.length == 2) {
+ if (nGramArgs.length >= 2) {
try {
String[] nGramSizeSplit =
nGramArgs[0].split(",");
dmlOptions.statsNGramSizes = new
int[nGramSizeSplit.length];
@@ -229,10 +230,18 @@ public class DMLOptions {
}
dmlOptions.statsTopKNGrams =
Integer.parseInt(nGramArgs[1]);
+
+ if (nGramArgs.length == 3) {
+
dmlOptions.statsNGramsUseLineage = Boolean.parseBoolean(nGramArgs[2]);
+ }
} catch (NumberFormatException e) {
throw new
org.apache.commons.cli.ParseException("Invalid argument specified for -ngrams
option, must be a valid integer");
}
}
+
+ if (dmlOptions.statsNGramsUseLineage) {
+ dmlOptions.lineage = true;
+ }
}
dmlOptions.fedStats = line.hasOption("fedStats");
diff --git a/src/main/java/org/apache/sysds/api/DMLScript.java
b/src/main/java/org/apache/sysds/api/DMLScript.java
index 2137915f22..81ce1f04b0 100644
--- a/src/main/java/org/apache/sysds/api/DMLScript.java
+++ b/src/main/java/org/apache/sysds/api/DMLScript.java
@@ -104,6 +104,8 @@ public class DMLScript
public static int[] STATISTICS_NGRAM_SIZES =
DMLOptions.defaultOptions.statsNGramSizes;
// Set top k displayed n-grams limit
public static int STATISTICS_TOP_K_NGRAMS =
DMLOptions.defaultOptions.statsTopKNGrams;
+ // Set if N-Grams use lineage for data-dependent tracking
+ public static boolean STATISTICS_NGRAMS_USE_LINEAGE =
DMLOptions.defaultOptions.statsNGramsUseLineage;
// Set statistics maximum wrap length
public static int STATISTICS_MAX_WRAP_LEN = 30;
// Enable/disable to print federated statistics
diff --git
a/src/main/java/org/apache/sysds/runtime/controlprogram/ProgramBlock.java
b/src/main/java/org/apache/sysds/runtime/controlprogram/ProgramBlock.java
index 4e75d5456f..0739334680 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/ProgramBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/ProgramBlock.java
@@ -241,7 +241,8 @@ public abstract class ProgramBlock implements ParseInfo {
private void executeSingleInstruction(Instruction currInst,
ExecutionContext ec) {
try {
// start time measurement for statistics
- long t0 = (DMLScript.STATISTICS ||
DMLScript.STATISTICS_NGRAMS || LOG.isTraceEnabled()) ? System.nanoTime() : 0;
+ long t0 = (DMLScript.STATISTICS ||
DMLScript.STATISTICS_NGRAMS || LOG.isTraceEnabled())
+ ? System.nanoTime() : 0;
// pre-process instruction (inst patching, listeners,
lineage)
Instruction tmp = currInst.preprocessInstruction(ec);
@@ -264,9 +265,8 @@ public abstract class ProgramBlock implements ParseInfo {
Statistics.maintainCPHeavyHitters(tmp.getExtendedOpcode(), System.nanoTime() -
t0);
}
- if (DMLScript.STATISTICS_NGRAMS) {
-
Statistics.maintainNGrams(tmp.getExtendedOpcode(), System.nanoTime() - t0);
- }
+ if (DMLScript.STATISTICS_NGRAMS)
+
Statistics.maintainNGramsFromLineage(tmp, ec, t0);
}
// optional trace information (instruction and runtime)
diff --git
a/src/main/java/org/apache/sysds/runtime/instructions/Instruction.java
b/src/main/java/org/apache/sysds/runtime/instructions/Instruction.java
index 969dfaf5c2..50238aadd8 100644
--- a/src/main/java/org/apache/sysds/runtime/instructions/Instruction.java
+++ b/src/main/java/org/apache/sysds/runtime/instructions/Instruction.java
@@ -26,6 +26,7 @@ import org.apache.sysds.lops.Lop;
import org.apache.sysds.parser.DataIdentifier;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.matrix.operators.Operator;
+import org.apache.sysds.utils.Statistics;
public abstract class Instruction
{
@@ -214,6 +215,8 @@ public abstract class Instruction
* @return instruction
*/
public Instruction preprocessInstruction(ExecutionContext ec) {
+ if (DMLScript.STATISTICS_NGRAMS &&
DMLScript.STATISTICS_NGRAMS_USE_LINEAGE)
+ Statistics.prepareNGramInst(null); // Reset the current
LineageItem for this thread
// Lineage tracing
if (DMLScript.LINEAGE)
ec.traceLineage(this);
diff --git
a/src/main/java/org/apache/sysds/runtime/lineage/LineageItemUtils.java
b/src/main/java/org/apache/sysds/runtime/lineage/LineageItemUtils.java
index 58dab47534..5766437fe1 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageItemUtils.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageItemUtils.java
@@ -63,6 +63,7 @@ import
org.apache.sysds.runtime.instructions.cp.VariableCPInstruction;
import
org.apache.sysds.runtime.instructions.fed.ReorgFEDInstruction.DiagMatrix;
import org.apache.sysds.runtime.instructions.fed.ReorgFEDInstruction.Rdiag;
import org.apache.sysds.runtime.util.HDFSTool;
+import org.apache.sysds.utils.Statistics;
import java.io.IOException;
import java.util.ArrayList;
@@ -124,6 +125,44 @@ public class LineageItemUtils {
public static boolean isFunctionDebugging () {
return FUNCTION_DEBUGGING;
}
+
+ public static String explainLineageType(LineageItem li,
Statistics.LineageNGramExtension ext) {
+ if (li.getType() == LineageItemType.Literal) {
+ String[] splt = li.getData().split("·");
+ if (splt.length >= 3)
+ return splt[1] + "·" + splt[2];
+ return "·";
+ }
+ return ext != null ? ext.getDataType() + "·" +
ext.getValueType() : "··";
+ }
+
+ public static String explainLineageWithTypes(LineageItem li,
Statistics.LineageNGramExtension ext) {
+ if (li.getType() == LineageItemType.Literal) {
+ String[] splt = li.getData().split("·");
+ if (splt.length >= 3)
+ return "L·" + splt[1] + "·" + splt[2];
+ return "L··";
+ }
+ return li.getOpcode() + "·" + (ext != null ? ext.getDataType()
+ "·" + ext.getValueType() : "·");
+ }
+
+ public static String explainLineageAsInstruction(LineageItem li,
Statistics.LineageNGramExtension ext) {
+ StringBuilder sb = new
StringBuilder(explainLineageWithTypes(li, ext));
+ sb.append("(");
+ if (li.getInputs() != null) {
+ int ctr = 0;
+ for (LineageItem liIn : li.getInputs()) {
+ if (ctr++ != 0)
+ sb.append(" ° ");
+ if (liIn.getType() == LineageItemType.Literal)
+ sb.append("L_" +
explainLineageType(liIn, Statistics.getExtendedLineage(li)));
+ else
+ sb.append(explainLineageType(liIn,
Statistics.getExtendedLineage(li)));
+ }
+ }
+ sb.append(")");
+ return sb.toString();
+ }
public static String explainSingleLineageItem(LineageItem li) {
StringBuilder sb = new StringBuilder();
diff --git a/src/main/java/org/apache/sysds/runtime/lineage/LineageMap.java
b/src/main/java/org/apache/sysds/runtime/lineage/LineageMap.java
index 41875bdfdf..2b3c981d9e 100644
--- a/src/main/java/org/apache/sysds/runtime/lineage/LineageMap.java
+++ b/src/main/java/org/apache/sysds/runtime/lineage/LineageMap.java
@@ -32,6 +32,7 @@ import
org.apache.sysds.runtime.instructions.cp.VariableCPInstruction;
import org.apache.sysds.runtime.instructions.spark.WriteSPInstruction;
import org.apache.sysds.runtime.lineage.LineageItem.LineageItemType;
import org.apache.sysds.utils.Explain;
+import org.apache.sysds.utils.Statistics;
import java.util.HashMap;
import java.util.Map;
@@ -146,6 +147,9 @@ public class LineageMap {
}
private void trace(Instruction inst, ExecutionContext ec, Pair<String,
LineageItem> li) {
+ if (li != null && li.getValue() != null &&
DMLScript.STATISTICS_NGRAMS && DMLScript.STATISTICS_NGRAMS_USE_LINEAGE)
+ Statistics.prepareNGramInst(li);
+
if (inst instanceof VariableCPInstruction) {
VariableCPInstruction vcp_inst =
((VariableCPInstruction) inst);
diff --git
a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
index 79eb73ba0a..f76502ef7c 100644
--- a/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/matrix/data/MatrixBlock.java
@@ -73,7 +73,6 @@ import
org.apache.sysds.runtime.functionobjects.Builtin.BuiltinCode;
import org.apache.sysds.runtime.functionobjects.CM;
import org.apache.sysds.runtime.functionobjects.CTable;
import org.apache.sysds.runtime.functionobjects.DiagIndex;
-import org.apache.sysds.runtime.functionobjects.Divide;
import org.apache.sysds.runtime.functionobjects.FunctionObject;
import org.apache.sysds.runtime.functionobjects.IfElse;
import org.apache.sysds.runtime.functionobjects.KahanFunction;
@@ -96,7 +95,6 @@ import org.apache.sysds.runtime.instructions.cp.KahanObject;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
import org.apache.sysds.runtime.instructions.spark.data.IndexedMatrixValue;
import org.apache.sysds.runtime.io.IOUtilFunctions;
-import org.apache.sysds.runtime.matrix.data.LibMatrixBincell.BinaryAccessType;
import org.apache.sysds.runtime.matrix.operators.AggregateBinaryOperator;
import org.apache.sysds.runtime.matrix.operators.AggregateOperator;
import org.apache.sysds.runtime.matrix.operators.AggregateTernaryOperator;
diff --git a/src/main/java/org/apache/sysds/utils/Statistics.java
b/src/main/java/org/apache/sysds/utils/Statistics.java
index 3ad613c842..c0f087d0b0 100644
--- a/src/main/java/org/apache/sysds/utils/Statistics.java
+++ b/src/main/java/org/apache/sysds/utils/Statistics.java
@@ -26,13 +26,19 @@ import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.hops.fedplanner.FederatedCompilationTimer;
import org.apache.sysds.runtime.controlprogram.caching.CacheStatistics;
+import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
+import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedStatistics;
import org.apache.sysds.runtime.instructions.Instruction;
import org.apache.sysds.runtime.instructions.InstructionUtils;
+import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.cp.FunctionCallCPInstruction;
import org.apache.sysds.runtime.instructions.spark.SPInstruction;
import org.apache.sysds.runtime.lineage.LineageCacheConfig.ReuseCacheType;
+import org.apache.sysds.runtime.meta.DataCharacteristics;
import org.apache.sysds.runtime.lineage.LineageCacheStatistics;
+import org.apache.sysds.runtime.lineage.LineageItem;
+import org.apache.sysds.runtime.lineage.LineageItemUtils;
import org.apache.sysds.utils.stats.CodegenStatistics;
import org.apache.sysds.utils.stats.NGramBuilder;
import org.apache.sysds.utils.stats.NativeStatistics;
@@ -46,6 +52,7 @@ import java.lang.management.CompilationMXBean;
import java.lang.management.GarbageCollectorMXBean;
import java.lang.management.ManagementFactory;
import java.text.DecimalFormat;
+import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
@@ -54,10 +61,12 @@ import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Map.Entry;
+import java.util.Optional;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.DoubleAdder;
import java.util.concurrent.atomic.LongAdder;
+import java.util.function.Consumer;
/**
* This class captures all statistics.
@@ -74,6 +83,7 @@ public class Statistics
public final long n;
public final long cumTimeNanos;
public final double m2;
+ public final HashMap<String, Double> meta;
public static <T> Comparator<NGramBuilder.NGramEntry<T,
NGramStats>> getComparator() {
return Comparator.comparingLong(entry ->
entry.getCumStats().cumTimeNanos);
@@ -91,13 +101,25 @@ public class Statistics
double newM2 = stats1.m2 + stats2.m2 + delta * delta *
stats1.n * stats2.n / (double)newN;
- return new NGramStats(newN, cumTimeNanos, newM2);
+ HashMap<String, Double> cpy = null;
+
+ if (stats1.meta != null) {
+ cpy = new HashMap<>(stats1.meta);
+ final HashMap<String, Double> mCpy = cpy;
+ if (stats2.meta != null)
+ stats2.meta.forEach((key, value) ->
mCpy.merge(key, value, Double::sum));
+ } else if (stats2.meta != null) {
+ cpy = new HashMap<>(stats2.meta);
+ }
+
+ return new NGramStats(newN, cumTimeNanos, newM2, cpy);
}
- public NGramStats(final long n, final long cumTimeNanos, final
double m2) {
+ public NGramStats(final long n, final long cumTimeNanos, final
double m2, HashMap<String, Double> meta) {
this.n = n;
this.cumTimeNanos = cumTimeNanos;
this.m2 = m2;
+ this.meta = meta;
}
public double getTimeVariance() {
@@ -107,6 +129,54 @@ public class Statistics
public String toString() {
return String.format(Locale.US, "%.5f", (cumTimeNanos /
1000000000d));
}
+
+ public HashMap<String, Double> getMeta() {
+ return meta;
+ }
+ }
+
+ public static class LineageNGramExtension {
+ private String _datatype;
+ private String _valuetype;
+ private long _execNanos;
+
+ private HashMap<String, Double> _meta;
+
+ public void setDataType(String dataType) {
+ _datatype = dataType;
+ }
+
+ public String getDataType() {
+ return _datatype == null ? "" : _datatype;
+ }
+
+ public void setValueType(String valueType) {
+ _valuetype = valueType;
+ }
+
+ public String getValueType() {
+ return _valuetype == null ? "" : _valuetype;
+ }
+
+ public void setExecNanos(long nanos) {
+ _execNanos = nanos;
+ }
+
+ public long getExecNanos() {
+ return _execNanos;
+ }
+
+ public void setMeta(String key, Double value) {
+ if (_meta == null)
+ _meta = new HashMap<>();
+ _meta.put(key, value);
+ }
+
+ public Object getMeta(String key) {
+ if (_meta == null)
+ return null;
+ return _meta.get(key);
+ }
}
private static long compileStartTime = 0;
@@ -117,6 +187,8 @@ public class Statistics
//heavy hitter counts and times
private static final ConcurrentHashMap<String,InstStats> _instStats =
new ConcurrentHashMap<>();
private static final ConcurrentHashMap<String, NGramBuilder<String,
NGramStats>[]> _instStatsNGram = new ConcurrentHashMap<>();
+ private static final ConcurrentHashMap<Long, Entry<String,
LineageItem>> _instStatsLineageTracker = new ConcurrentHashMap<>();
+ private static final ConcurrentHashMap<LineageItem,
LineageNGramExtension> _lineageExtensions = new ConcurrentHashMap<>();
// number of compiled/executed SP instructions
private static final LongAdder numExecutedSPInst = new LongAdder();
@@ -299,6 +371,8 @@ public class Statistics
FederatedStatistics.reset();
_instStatsNGram.clear();
+ _instStatsLineageTracker.clear();
+ _instStats.clear();
}
public static void resetJITCompileTime(){
@@ -401,6 +475,120 @@ public class Statistics
tmp.count.increment();
}
+ public static void prepareNGramInst(Entry<String, LineageItem> li) {
+ if (li == null)
+
_instStatsLineageTracker.remove(Thread.currentThread().getId());
+ else
+
_instStatsLineageTracker.put(Thread.currentThread().getId(), li);
+ }
+
+ public static Optional<Entry<String, LineageItem>>
getCurrentLineageItem() {
+ Entry<String, LineageItem> item =
_instStatsLineageTracker.get(Thread.currentThread().getId());
+ return item == null ? Optional.empty() : Optional.of(item);
+ }
+
+ public static synchronized void clearNGramRecording() {
+ NGramBuilder<String, NGramStats>[] bl =
_instStatsNGram.get(Thread.currentThread().getName());
+ for (NGramBuilder<String, NGramStats> b : bl)
+ b.clearCurrentRecording();
+ }
+
+ public static synchronized void extendLineageItem(LineageItem li,
LineageNGramExtension ext) {
+ _lineageExtensions.put(li, ext);
+ }
+
+ public static synchronized LineageNGramExtension
getExtendedLineage(LineageItem li) {
+ return _lineageExtensions.get(li);
+ }
+
+ public static synchronized void maintainNGramsFromLineage(Instruction
tmp, ExecutionContext ec, long t0) {
+ final long nanoTime = System.nanoTime() - t0;
+ if (DMLScript.STATISTICS_NGRAMS_USE_LINEAGE) {
+ Statistics.getCurrentLineageItem().ifPresent(li -> {
+ Data data = ec.getVariable(li.getKey());
+ Statistics.LineageNGramExtension ext = new
Statistics.LineageNGramExtension();
+ if (data != null) {
+
ext.setDataType(data.getDataType().toString());
+
ext.setValueType(data.getValueType().toString());
+ if (data instanceof CacheableData) {
+ DataCharacteristics dc =
((CacheableData<?>)data).getDataCharacteristics();
+ ext.setMeta("NDims",
(double)dc.getNumDims());
+ ext.setMeta("NumRows",
(double)dc.getRows());
+ ext.setMeta("NumCols",
(double)dc.getCols());
+ ext.setMeta("NonZeros",
(double)dc.getNonZeros());
+ }
+ }
+ ext.setExecNanos(nanoTime);
+ Statistics.extendLineageItem(li.getValue(),
ext);
+
Statistics.maintainNGramsFromLineage(li.getValue());
+ });
+ } else
+ Statistics.maintainNGrams(tmp.getExtendedOpcode(),
nanoTime);
+ }
+
+ @SuppressWarnings("unchecked")
+ public static synchronized void maintainNGramsFromLineage(LineageItem
li) {
+ NGramBuilder<String, NGramStats>[] tmp =
_instStatsNGram.computeIfAbsent(Thread.currentThread().getName(), k -> {
+ NGramBuilder<String, NGramStats>[] threadEntry = new
NGramBuilder[DMLScript.STATISTICS_NGRAM_SIZES.length];
+ for (int i = 0; i < threadEntry.length; i++) {
+ threadEntry[i] = new NGramBuilder<String,
NGramStats>(String.class, NGramStats.class,
DMLScript.STATISTICS_NGRAM_SIZES[i], s -> s, NGramStats::merge);
+ }
+ return threadEntry;
+ });
+ addLineagePaths(li, new ArrayList<>(), new ArrayList<>(), tmp);
+ }
+
+ /**
+ * Adds the corresponding sequences of instructions to the n-grams.
+ * <p></p>
+ * Example: 2-grams from (a*b + a/c) will add [(*,+), (/,+)]
+ * @param li
+ * @param currentPath
+ * @param indexes
+ * @param builders
+ */
+ private static void addLineagePaths(LineageItem li,
ArrayList<Entry<LineageItem, LineageNGramExtension>> currentPath,
ArrayList<Integer> indexes, NGramBuilder<String, NGramStats>[] builders) {
+ if (li.getType() == LineageItem.LineageItemType.Literal)
+ return; // Skip literals as they are no real instruction
+
+ currentPath.add(new AbstractMap.SimpleEntry<>(li,
getExtendedLineage(li)));
+
+ int maxSize = 0;
+ NGramBuilder<String, NGramStats> matchingBuilder = null;
+
+ for (NGramBuilder<String, NGramStats> builder : builders) {
+ if (builder.getSize() == currentPath.size())
+ matchingBuilder = builder;
+ if (builder.getSize() > maxSize)
+ maxSize = builder.getSize();
+ }
+
+ if (matchingBuilder != null) {
+ // If we have an n-gram builder with n =
currentPath.size(), then we want to insert the entry
+ // As we cannot incrementally add the instructions (we
have a DAG rather than a sequence of instructions)
+ // we need to clear the current n-grams
+ clearNGramRecording();
+ // We then record a new n-gram with all the
LineageItems of the current lineage path
+ Entry<LineageItem, LineageNGramExtension> currentEntry
= currentPath.get(currentPath.size()-1);
+
matchingBuilder.append(LineageItemUtils.explainLineageAsInstruction(currentEntry.getKey(),
currentEntry.getValue()) + (indexes.size() > 0 ? ("[" +
indexes.get(currentPath.size()-2) + "]") : ""), new NGramStats(1,
currentEntry.getValue() != null ? currentEntry.getValue().getExecNanos() : 0,
0, currentEntry.getValue() != null ? currentEntry.getValue()._meta : null));
+ for (int i = currentPath.size()-2; i >= 0; i--) {
+ currentEntry = currentPath.get(i);
+
matchingBuilder.append(LineageItemUtils.explainLineageAsInstruction(currentEntry.getKey(),
currentEntry.getValue()) + (i > 0 ? ("[" + indexes.get(i-1) + "]") : ""), new
NGramStats(1, currentEntry.getValue() != null ?
currentEntry.getValue().getExecNanos() : 0, 0, currentEntry.getValue() != null
? currentEntry.getValue()._meta : null));
+ }
+ }
+
+ if (currentPath.size() < maxSize && li.getInputs() != null) {
+ int idx = 0;
+ for (LineageItem input : li.getInputs()) {
+ indexes.add(idx++);
+ addLineagePaths(input, currentPath, indexes,
builders);
+ indexes.remove(indexes.size()-1);
+ }
+ }
+
+ currentPath.remove(currentPath.size()-1);
+ }
+
@SuppressWarnings("unchecked")
public static void maintainNGrams(String instName, long timeNanos) {
NGramBuilder<String, NGramStats>[] tmp =
_instStatsNGram.computeIfAbsent(Thread.currentThread().getName(), k -> {
@@ -412,7 +600,7 @@ public class Statistics
});
for (int i = 0; i < tmp.length; i++)
- tmp[i].append(instName, new NGramStats(1, timeNanos,
0));
+ tmp[i].append(instName, new NGramStats(1, timeNanos, 0,
null));
}
@SuppressWarnings("unchecked")
@@ -467,7 +655,7 @@ public class Statistics
return sb.toString();
}
- public static String nGramToCSV(final NGramBuilder<String, NGramStats>
mbuilder) {
+ public static void toCSVStream(final NGramBuilder<String, NGramStats>
mbuilder, final Consumer<String> lineConsumer) {
ArrayList<String> colList = new ArrayList<>();
colList.add("N-Gram");
colList.add("Time[s]");
@@ -478,10 +666,12 @@ public class Statistics
colList.add("Col" + (j + 1) + "::Mean(Time[s])");
for (int j = 0; j < mbuilder.getSize(); j++)
colList.add("Col" + (j + 1) + "::StdDev(Time[s])/Col" +
(j + 1) + "::Mean(Time[s])");
+ for (int j = 0; j < mbuilder.getSize(); j++)
+ colList.add("Col" + (j + 1) + "_Meta");
colList.add("Count");
- return NGramBuilder.toCSV(colList.toArray(new
String[colList.size()]), mbuilder.getTopK(100000,
Statistics.NGramStats.getComparator(), true), e -> {
+ NGramBuilder.toCSVStream(colList.toArray(new
String[colList.size()]), mbuilder.getTopK(100000,
Statistics.NGramStats.getComparator(), true), e -> {
StringBuilder builder = new StringBuilder();
builder.append(e.getIdentifier().replace("(",
"").replace(")", "").replace(", ", ","));
builder.append(",");
@@ -494,8 +684,31 @@ public class Statistics
} else {
builder.append(stdDevs);
}
+ //builder.append(",");
+ boolean first = true;
+ NGramStats[] stats = e.getStats();
+ for (int i = 0; i < stats.length; i++) {
+ builder.append(",");
+ NGramStats stat = stats[i];
+ if (stat.getMeta() != null) {
+ for (Entry<String, Double> metaData :
stat.getMeta().entrySet()) {
+ if (first)
+ first = false;
+ else
+ builder.append("&");
+ if (metaData.getValue() != null)
+
builder.append(metaData.getKey()).append(":").append(metaData.getValue());
+ }
+ }
+ }
return builder.toString();
- });
+ }, lineConsumer);
+ }
+
+ public static String nGramToCSV(final NGramBuilder<String, NGramStats>
mbuilder) {
+ final StringBuilder b = new StringBuilder();
+ toCSVStream(mbuilder, b::append);
+ return b.toString();
}
public static String getCommonNGrams(NGramBuilder<String, NGramStats>
builder, int num) {
diff --git a/src/main/java/org/apache/sysds/utils/stats/NGramBuilder.java
b/src/main/java/org/apache/sysds/utils/stats/NGramBuilder.java
index e0212e5c73..85d8012789 100644
--- a/src/main/java/org/apache/sysds/utils/stats/NGramBuilder.java
+++ b/src/main/java/org/apache/sysds/utils/stats/NGramBuilder.java
@@ -19,6 +19,8 @@
package org.apache.sysds.utils.stats;
+import org.apache.commons.lang3.function.TriFunction;
+
import java.lang.reflect.Array;
import java.util.Arrays;
import java.util.Comparator;
@@ -26,6 +28,7 @@ import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.BiFunction;
+import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.Collectors;
@@ -53,6 +56,30 @@ public class NGramBuilder<T, U> {
return builder.toString();
}
+ public static <T, U> void toCSVStream(String[] columnNames,
List<NGramEntry<T, U>> entries, Function<NGramEntry<T, U>, String> statsMapper,
Consumer<String> lineConsumer) {
+ StringBuilder builder = new StringBuilder(String.join(",",
columnNames));
+ builder.append("\n");
+ lineConsumer.accept(builder.toString());
+ builder.setLength(0);
+
+ for (NGramEntry<T, U> entry : entries) {
+ builder.append(entry.getIdentifier().replace(",", ";"));
+ builder.append(",");
+ builder.append(entry.getCumStats());
+ builder.append(",");
+
+ if (statsMapper != null) {
+ builder.append(statsMapper.apply(entry));
+ builder.append(",");
+ }
+
+ builder.append(entry.getOccurrences());
+ builder.append("\n");
+ lineConsumer.accept(builder.toString());
+ builder.setLength(0);
+ }
+ }
+
public static class NGramEntry<T, U> {
private final String identifier;
private final T[] entry;
@@ -209,6 +236,11 @@ public class NGramBuilder<T, U> {
.collect(Collectors.toList());
}
+ public synchronized void clearCurrentRecording() {
+ currentIndex = 0;
+ currentSize = 0;
+ }
+
private synchronized void registerElement(String id, U stat) {
nGrams.compute(id, (key, entry) -> {
if (entry == null) {
diff --git
a/src/test/java/org/apache/sysds/performance/matrix/MatrixAggregate.java
b/src/test/java/org/apache/sysds/performance/matrix/MatrixAggregate.java
index f6a466efaa..8e60ee97cb 100644
--- a/src/test/java/org/apache/sysds/performance/matrix/MatrixAggregate.java
+++ b/src/test/java/org/apache/sysds/performance/matrix/MatrixAggregate.java
@@ -21,7 +21,6 @@ package org.apache.sysds.performance.matrix;
import org.apache.sysds.performance.compression.APerfTest;
import org.apache.sysds.performance.generators.ConstMatrix;
-import org.apache.sysds.performance.generators.GenPair;
import org.apache.sysds.performance.generators.IGenerate;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.test.TestUtils;
@@ -39,8 +38,8 @@ public class MatrixAggregate extends APerfTest<Object,
MatrixBlock> {
public void run() throws Exception {
MatrixBlock mb = gen.take();
- String info = String.format("rows: %5d cols: %5d sp: %5.3f par:
%2d", mb.getNumRows(), mb.getNumColumns(),
- mb.getSparsity(), k);
+ String info = String.format("rows: %5d cols: %5d sp: %5.3f par:
%2d",
+ mb.getNumRows(), mb.getNumColumns(), mb.getSparsity(),
k);
warmup(() -> sum(), 100);
execute(() -> sum(), info + " sum");
}
diff --git a/src/test/java/org/apache/sysds/test/applications/L2SVMTest.java
b/src/test/java/org/apache/sysds/test/applications/L2SVMTest.java
index dbb98160e2..534b058425 100644
--- a/src/test/java/org/apache/sysds/test/applications/L2SVMTest.java
+++ b/src/test/java/org/apache/sysds/test/applications/L2SVMTest.java
@@ -67,7 +67,16 @@ public class L2SVMTest extends AutomatedTestBase
}
@Test
- public void testL2SVM()
+ public void testL2SVM1() {
+ testL2SVM(true);
+ }
+
+ @Test
+ public void testL2SVM2() {
+ testL2SVM(false);
+ }
+
+ private void testL2SVM(boolean ngrams)
{
System.out.println("------------ BEGIN " + TEST_NAME
+ " TEST WITH {" + numRecords + ", " + numFeatures
@@ -83,9 +92,11 @@ public class L2SVMTest extends AutomatedTestBase
List<String> proArgs = new ArrayList<>();
proArgs.add("-stats");
- proArgs.add("-ngrams");
- proArgs.add("3,2");
- proArgs.add("10");
+ if (ngrams) {
+ proArgs.add("-ngrams");
+ proArgs.add("3,2");
+ proArgs.add("10");
+ }
proArgs.add("-nvargs");
proArgs.add("X=" + input("X"));
proArgs.add("Y=" + input("Y"));