This is an automated email from the ASF dual-hosted git repository. mboehm7 pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/systemds.git
The following commit(s) were added to refs/heads/main by this push: new 9f96718939 [SYSTEMDS-3714] N-gram statistics of operation sequences 9f96718939 is described below commit 9f96718939bd43fedcf942ac71f6fd70bdcf48d6 Author: Jaybit0 <jannik.lindema...@gmail.com> AuthorDate: Wed Jul 24 18:31:43 2024 +0200 [SYSTEMDS-3714] N-gram statistics of operation sequences Closes #2045. --- src/main/java/org/apache/sysds/api/DMLOptions.java | 27 +++ src/main/java/org/apache/sysds/api/DMLScript.java | 9 + .../sysds/runtime/controlprogram/ProgramBlock.java | 6 +- .../java/org/apache/sysds/utils/Statistics.java | 231 ++++++++++++++++++- .../org/apache/sysds/utils/stats/NGramBuilder.java | 248 +++++++++++++++++++++ .../test/applications/ApplyTransformTest.java | 9 + .../apache/sysds/test/applications/L2SVMTest.java | 3 + 7 files changed, 529 insertions(+), 4 deletions(-) diff --git a/src/main/java/org/apache/sysds/api/DMLOptions.java b/src/main/java/org/apache/sysds/api/DMLOptions.java index 70af5ba9e8..acacc39572 100644 --- a/src/main/java/org/apache/sysds/api/DMLOptions.java +++ b/src/main/java/org/apache/sysds/api/DMLOptions.java @@ -53,7 +53,10 @@ public class DMLOptions { public String configFile = null; // Path to config file if default config and default config is to be overridden public boolean clean = false; // Whether to clean up all SystemDS working directories (FS, DFS) public boolean stats = false; // Whether to record and print the statistics + public boolean statsNGrams = false; // Whether to record and print the statistics n-grams public int statsCount = 10; // Default statistics count + public int[] statsNGramSizes = { 3 }; // Default n-gram tuple sizes + public int statsTopKNGrams = 10; // How many of the most heavy hitting n-grams are displayed public boolean fedStats = false; // Whether to record and print the federated statistics public int fedStatsCount = 10; // Default federated statistics count public boolean memStats = false; // max memory statistics @@ -212,6 +215,26 @@ public class DMLOptions { } } } + + dmlOptions.statsNGrams = line.hasOption("ngrams"); + if (dmlOptions.statsNGrams){ + String[] nGramArgs = line.getOptionValues("ngrams"); + if (nGramArgs.length == 2) { + try { + String[] nGramSizeSplit = nGramArgs[0].split(","); + dmlOptions.statsNGramSizes = new int[nGramSizeSplit.length]; + + for (int i = 0; i < nGramSizeSplit.length; i++) { + dmlOptions.statsNGramSizes[i] = Integer.parseInt(nGramSizeSplit[i]); + } + + dmlOptions.statsTopKNGrams = Integer.parseInt(nGramArgs[1]); + } catch (NumberFormatException e) { + throw new org.apache.commons.cli.ParseException("Invalid argument specified for -ngrams option, must be a valid integer"); + } + } + } + dmlOptions.fedStats = line.hasOption("fedStats"); if (dmlOptions.fedStats) { String fedStatsCount = line.getOptionValue("fedStats"); @@ -335,6 +358,9 @@ public class DMLOptions { Option statsOpt = OptionBuilder.withArgName("count") .withDescription("monitors and reports summary execution statistics; heavy hitter <count> is 10 unless overridden; default off") .hasOptionalArg().create("stats"); + Option ngramsOpt = OptionBuilder//.withArgName("ngrams") + .withDescription("monitors and reports the most occurring n-grams; -ngrams <comma separated n's> <topK>") + .hasOptionalArgs(2).create("ngrams"); Option fedStatsOpt = OptionBuilder.withArgName("count") .withDescription("monitors and reports summary execution statistics of federated workers; heavy hitter <count> is 10 unless overridden; default off") .hasOptionalArg().create("fedStats"); @@ -396,6 +422,7 @@ public class DMLOptions { options.addOption(configOpt); options.addOption(cleanOpt); options.addOption(statsOpt); + options.addOption(ngramsOpt); options.addOption(fedStatsOpt); options.addOption(memOpt); options.addOption(explainOpt); diff --git a/src/main/java/org/apache/sysds/api/DMLScript.java b/src/main/java/org/apache/sysds/api/DMLScript.java index 3443f68740..cd86426a42 100644 --- a/src/main/java/org/apache/sysds/api/DMLScript.java +++ b/src/main/java/org/apache/sysds/api/DMLScript.java @@ -94,10 +94,16 @@ public class DMLScript private static ExecMode EXEC_MODE = DMLOptions.defaultOptions.execMode; // Enable/disable to print statistics public static boolean STATISTICS = DMLOptions.defaultOptions.stats; + // Enable/disable to print statistics n-grams + public static boolean STATISTICS_NGRAMS = DMLOptions.defaultOptions.statsNGrams; // Enable/disable to gather memory use stats in JMLC public static boolean JMLC_MEM_STATISTICS = false; // Set maximum heavy hitter count public static int STATISTICS_COUNT = DMLOptions.defaultOptions.statsCount; + // The sizes of recorded n-gram tuples + public static int[] STATISTICS_NGRAM_SIZES = DMLOptions.defaultOptions.statsNGramSizes; + // Set top k displayed n-grams limit + public static int STATISTICS_TOP_K_NGRAMS = DMLOptions.defaultOptions.statsTopKNGrams; // Set statistics maximum wrap length public static int STATISTICS_MAX_WRAP_LEN = 30; // Enable/disable to print federated statistics @@ -250,6 +256,9 @@ public class DMLScript { STATISTICS = dmlOptions.stats; STATISTICS_COUNT = dmlOptions.statsCount; + STATISTICS_NGRAMS = dmlOptions.statsNGrams; + STATISTICS_NGRAM_SIZES = dmlOptions.statsNGramSizes; + STATISTICS_TOP_K_NGRAMS = dmlOptions.statsTopKNGrams; FED_STATISTICS = dmlOptions.fedStats; FED_STATISTICS_COUNT = dmlOptions.fedStatsCount; JMLC_MEM_STATISTICS = dmlOptions.memStats; diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/ProgramBlock.java b/src/main/java/org/apache/sysds/runtime/controlprogram/ProgramBlock.java index 34b954148b..4e75d5456f 100644 --- a/src/main/java/org/apache/sysds/runtime/controlprogram/ProgramBlock.java +++ b/src/main/java/org/apache/sysds/runtime/controlprogram/ProgramBlock.java @@ -241,7 +241,7 @@ public abstract class ProgramBlock implements ParseInfo { private void executeSingleInstruction(Instruction currInst, ExecutionContext ec) { try { // start time measurement for statistics - long t0 = (DMLScript.STATISTICS || LOG.isTraceEnabled()) ? System.nanoTime() : 0; + long t0 = (DMLScript.STATISTICS || DMLScript.STATISTICS_NGRAMS || LOG.isTraceEnabled()) ? System.nanoTime() : 0; // pre-process instruction (inst patching, listeners, lineage) Instruction tmp = currInst.preprocessInstruction(ec); @@ -263,6 +263,10 @@ public abstract class ProgramBlock implements ParseInfo { if(DMLScript.STATISTICS) { Statistics.maintainCPHeavyHitters(tmp.getExtendedOpcode(), System.nanoTime() - t0); } + + if (DMLScript.STATISTICS_NGRAMS) { + Statistics.maintainNGrams(tmp.getExtendedOpcode(), System.nanoTime() - t0); + } } // optional trace information (instruction and runtime) diff --git a/src/main/java/org/apache/sysds/utils/Statistics.java b/src/main/java/org/apache/sysds/utils/Statistics.java index 5cab7dbd30..a7c764cf78 100644 --- a/src/main/java/org/apache/sysds/utils/Statistics.java +++ b/src/main/java/org/apache/sysds/utils/Statistics.java @@ -34,10 +34,11 @@ import org.apache.sysds.runtime.instructions.spark.SPInstruction; import org.apache.sysds.runtime.lineage.LineageCacheConfig.ReuseCacheType; import org.apache.sysds.runtime.lineage.LineageCacheStatistics; import org.apache.sysds.utils.stats.CodegenStatistics; -import org.apache.sysds.utils.stats.RecompileStatistics; +import org.apache.sysds.utils.stats.NGramBuilder; import org.apache.sysds.utils.stats.NativeStatistics; -import org.apache.sysds.utils.stats.ParamServStatistics; import org.apache.sysds.utils.stats.ParForStatistics; +import org.apache.sysds.utils.stats.ParamServStatistics; +import org.apache.sysds.utils.stats.RecompileStatistics; import org.apache.sysds.utils.stats.SparkStatistics; import org.apache.sysds.utils.stats.TransformStatistics; @@ -45,10 +46,13 @@ import java.lang.management.CompilationMXBean; import java.lang.management.GarbageCollectorMXBean; import java.lang.management.ManagementFactory; import java.text.DecimalFormat; +import java.util.ArrayList; import java.util.Arrays; import java.util.Comparator; import java.util.HashMap; import java.util.List; +import java.util.Locale; +import java.util.Map; import java.util.Map.Entry; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; @@ -64,6 +68,46 @@ public class Statistics private final LongAdder time = new LongAdder(); private final LongAdder count = new LongAdder(); } + + public static class NGramStats { + + public final long n; + public final long cumTimeNanos; + public final double m2; + + public static <T> Comparator<NGramBuilder.NGramEntry<T, NGramStats>> getComparator() { + return Comparator.comparingLong(entry -> entry.getCumStats().cumTimeNanos); + } + + public static NGramStats merge(NGramStats stats1, NGramStats stats2) { + // Using the algorithm from: https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm + long newN = stats1.n + stats2.n; + long cumTimeNanos = stats1.cumTimeNanos + stats2.cumTimeNanos; + + // Ensure the calculation uses floating-point arithmetic + double mean1 = (double) stats1.cumTimeNanos / 1000000000d / stats1.n; + double mean2 = (double) stats2.cumTimeNanos / 1000000000d / stats2.n; + double delta = mean2 - mean1; + + double newM2 = stats1.m2 + stats2.m2 + delta * delta * stats1.n * stats2.n / (double)newN; + + return new NGramStats(newN, cumTimeNanos, newM2); + } + + public NGramStats(final long n, final long cumTimeNanos, final double m2) { + this.n = n; + this.cumTimeNanos = cumTimeNanos; + this.m2 = m2; + } + + public double getTimeVariance() { + return m2 / Math.max(n-1, 1); + } + + public String toString() { + return String.format(Locale.US, "%.5f", (cumTimeNanos / 1000000000d)); + } + } private static long compileStartTime = 0; private static long compileEndTime = 0; @@ -71,7 +115,8 @@ public class Statistics private static long execEndTime = 0; //heavy hitter counts and times - private static final ConcurrentHashMap<String,InstStats>_instStats = new ConcurrentHashMap<>(); + private static final ConcurrentHashMap<String,InstStats> _instStats = new ConcurrentHashMap<>(); + private static final ConcurrentHashMap<String, NGramBuilder<String, NGramStats>[]> _instStatsNGram = new ConcurrentHashMap<>(); // number of compiled/executed SP instructions private static final LongAdder numExecutedSPInst = new LongAdder(); @@ -252,6 +297,8 @@ public class Statistics DMLCompressionStatistics.reset(); FederatedStatistics.reset(); + + _instStatsNGram.clear(); } public static void resetJITCompileTime(){ @@ -353,6 +400,177 @@ public class Statistics tmp.time.add(timeNanos); tmp.count.increment(); } + + public static void maintainNGrams(String instName, long timeNanos) { + NGramBuilder<String, NGramStats>[] tmp = _instStatsNGram.computeIfAbsent(Thread.currentThread().getName(), k -> { + NGramBuilder<String, NGramStats>[] threadEntry = new NGramBuilder[DMLScript.STATISTICS_NGRAM_SIZES.length]; + for (int i = 0; i < threadEntry.length; i++) { + threadEntry[i] = new NGramBuilder<String, NGramStats>(String.class, NGramStats.class, DMLScript.STATISTICS_NGRAM_SIZES[i], s -> s, NGramStats::merge); + } + return threadEntry; + }); + + for (int i = 0; i < tmp.length; i++) + tmp[i].append(instName, new NGramStats(1, timeNanos, 0)); + } + + public static NGramBuilder<String, NGramStats>[] mergeNGrams() { + NGramBuilder<String, NGramStats>[] builders = new NGramBuilder[DMLScript.STATISTICS_NGRAM_SIZES.length]; + + for (int i = 0; i < builders.length; i++) { + builders[i] = new NGramBuilder<String, NGramStats>(String.class, NGramStats.class, DMLScript.STATISTICS_NGRAM_SIZES[i], s -> s, NGramStats::merge); + } + + for (int i = 0; i < DMLScript.STATISTICS_NGRAM_SIZES.length; i++) { + for (Map.Entry<String, NGramBuilder<String, NGramStats>[]> entry : _instStatsNGram.entrySet()) { + NGramBuilder<String, NGramStats> mbuilder = entry.getValue()[i]; + builders[i].merge(mbuilder); + } + } + + return builders; + } + + public static String getNGramStdDevs(NGramStats[] stats, int offset, int prec, boolean displayZero) { + StringBuilder sb = new StringBuilder(); + sb.append("("); + boolean containsData = false; + int actualIndex; + for (int i = 0; i < stats.length; i++) { + if (i != 0) + sb.append(", "); + actualIndex = (offset + i) % stats.length; + double var = 1000000000d * stats[actualIndex].n * Math.sqrt(stats[actualIndex].getTimeVariance()) / stats[actualIndex].cumTimeNanos; + if (displayZero || var >= Math.pow(10, -prec)) { + sb.append(String.format(Locale.US, "%." + prec + "f", var)); + containsData = true; + } + } + sb.append(")"); + return containsData ? sb.toString() : "-"; + } + + public static String getNGramAvgTimes(NGramStats[] stats, int offset, int prec) { + StringBuilder sb = new StringBuilder(); + sb.append("("); + int actualIndex; + for (int i = 0; i < stats.length; i++) { + if (i != 0) + sb.append(", "); + actualIndex = (offset + i) % stats.length; + double var = (stats[actualIndex].cumTimeNanos / 1000000000d) / stats[actualIndex].n; + sb.append(String.format(Locale.US, "%." + prec + "f", var)); + } + sb.append(")"); + return sb.toString(); + } + + public static String nGramToCSV(final NGramBuilder<String, NGramStats> mbuilder) { + ArrayList<String> colList = new ArrayList<>(); + colList.add("N-Gram"); + colList.add("Time[s]"); + + for (int j = 0; j < mbuilder.getSize(); j++) + colList.add("Col" + (j + 1)); + for (int j = 0; j < mbuilder.getSize(); j++) + colList.add("Col" + (j + 1) + "::Mean(Time[s])"); + for (int j = 0; j < mbuilder.getSize(); j++) + colList.add("Col" + (j + 1) + "::StdDev(Time[s])/Col" + (j + 1) + "::Mean(Time[s])"); + + colList.add("Count"); + + return NGramBuilder.toCSV(colList.toArray(new String[colList.size()]), mbuilder.getTopK(100000, Statistics.NGramStats.getComparator(), true), e -> { + StringBuilder builder = new StringBuilder(); + builder.append(e.getIdentifier().replace("(", "").replace(")", "").replace(", ", ",")); + builder.append(","); + builder.append(Statistics.getNGramAvgTimes(e.getStats(), e.getOffset(), 9).replace("-", "").replace("(", "").replace(")", "")); + builder.append(","); + String stdDevs = Statistics.getNGramStdDevs(e.getStats(), e.getOffset(), 9, true).replace("-", "").replace("(", "").replace(")", ""); + if (stdDevs.isEmpty()) { + for (int j = 0; j < mbuilder.getSize()-1; j++) + builder.append(","); + } else { + builder.append(stdDevs); + } + return builder.toString(); + }); + } + + public static String getCommonNGrams(NGramBuilder<String, NGramStats> builder, int num) { + if (num <= 0 || _instStatsNGram.size() <= 0) + return "-"; + + //NGramBuilder<String, Long> builder = mergeNGrams(); + @SuppressWarnings("unchecked") + NGramBuilder.NGramEntry<String, NGramStats>[] topNGrams = builder.getTopK(num, NGramStats.getComparator(), true).toArray(NGramBuilder.NGramEntry[]::new); + + final String numCol = "#"; + final String instCol = "N-Gram"; + final String timeSCol = "Time(s)"; + final String timeSVar = "StdDev(t)/Mean(t)"; + final String countCol = "Count"; + StringBuilder sb = new StringBuilder(); + int len = topNGrams.length; + int numHittersToDisplay = Math.min(num, len); + int maxNumLen = String.valueOf(numHittersToDisplay).length(); + int maxInstLen = instCol.length(); + int maxTimeSLen = timeSCol.length(); + int maxTimeSVarLen = timeSVar.length(); + int maxCountLen = countCol.length(); + DecimalFormat sFormat = new DecimalFormat("#,##0.000"); + + for (int i = 0; i < numHittersToDisplay; i++) { + long timeNs = topNGrams[i].getCumStats().cumTimeNanos; + String instruction = topNGrams[i].getIdentifier(); + double timeS = timeNs / 1000000000d; + + + maxInstLen = Math.max(maxInstLen, instruction.length() + 1); + + String timeSString = sFormat.format(timeS); + String timeSVarString = getNGramStdDevs(topNGrams[i].getStats(), topNGrams[i].getOffset(), 3, false); + maxTimeSLen = Math.max(maxTimeSLen, timeSString.length()); + maxTimeSVarLen = Math.max(maxTimeSVarLen, timeSVarString.length()); + + maxCountLen = Math.max(maxCountLen, String.valueOf(topNGrams[i].getOccurrences()).length()); + } + + maxInstLen = Math.min(maxInstLen, DMLScript.STATISTICS_MAX_WRAP_LEN); + sb.append(String.format( " %" + maxNumLen + "s %-" + maxInstLen + "s %" + + maxTimeSLen + "s %" + maxTimeSVarLen + "s %" + maxCountLen + "s", numCol, instCol, timeSCol, timeSVar, countCol)); + sb.append("\n"); + for (int i = 0; i < numHittersToDisplay; i++) { + String instruction = topNGrams[i].getIdentifier(); + String [] wrappedInstruction = wrap(instruction, maxInstLen); + + //long timeNs = tmp[len - 1 - i].getValue().time.longValue(); + double timeS = topNGrams[i].getCumStats().cumTimeNanos / 1000000000d; + double timeVar = topNGrams[i].getCumStats().getTimeVariance(); + String timeSString = sFormat.format(timeS); + String timeVarString = getNGramStdDevs(topNGrams[i].getStats(), topNGrams[i].getOffset(), 3, false);//sFormat.format(timeVar); + + long count = topNGrams[i].getOccurrences(); + int numLines = wrappedInstruction.length; + + for(int wrapIter = 0; wrapIter < numLines; wrapIter++) { + String instStr = (wrapIter < wrappedInstruction.length) ? wrappedInstruction[wrapIter] : ""; + if(wrapIter == 0) { + // Display instruction count + sb.append(String.format( + " %" + maxNumLen + "d %-" + maxInstLen + "s %" + maxTimeSLen + "s %" + maxTimeSVarLen + "s %" + maxCountLen + "d", + (i + 1), instStr, timeSString, timeVarString, count)); + } + else { + sb.append(String.format( + " %" + maxNumLen + "s %-" + maxInstLen + "s %" + maxTimeSLen + "s %" + maxTimeSVarLen + "s %" + maxCountLen + "s", + "", instStr, "", "", "")); + } + sb.append("\n"); + } + } + + return sb.toString(); + } public static void maintainCPFuncCallStats(String instName) { InstStats tmp = _instStats.get(instName); @@ -679,6 +897,13 @@ public class Statistics sb.append("Heavy hitter instructions:\n" + getHeavyHitters(maxHeavyHitters)); } + if (DMLScript.STATISTICS_NGRAMS) { + NGramBuilder<String, NGramStats>[] mergedNGrams = mergeNGrams(); + for (int i = 0; i < DMLScript.STATISTICS_NGRAM_SIZES.length; i++) { + sb.append("Most common " + DMLScript.STATISTICS_NGRAM_SIZES[i] + "-grams (sorted by absolute time):\n" + getCommonNGrams(mergedNGrams[i], DMLScript.STATISTICS_TOP_K_NGRAMS)); + } + } + if(DMLScript.FED_STATISTICS) { sb.append("\n"); sb.append(FederatedStatistics.displayStatistics(DMLScript.FED_STATISTICS_COUNT)); diff --git a/src/main/java/org/apache/sysds/utils/stats/NGramBuilder.java b/src/main/java/org/apache/sysds/utils/stats/NGramBuilder.java new file mode 100644 index 0000000000..7554fdcd67 --- /dev/null +++ b/src/main/java/org/apache/sysds/utils/stats/NGramBuilder.java @@ -0,0 +1,248 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.sysds.utils.stats; + +import java.lang.reflect.Array; +import java.util.Arrays; +import java.util.Comparator; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.function.BiFunction; +import java.util.function.Function; +import java.util.stream.Collectors; + +public class NGramBuilder<T, U> { + + public static <T, U> String toCSV(String[] columnNames, List<NGramEntry<T, U>> entries, Function<NGramEntry<T, U>, String> statsMapper) { + StringBuilder builder = new StringBuilder(String.join(",", columnNames)); + builder.append("\n"); + + for (NGramEntry<T, U> entry : entries) { + builder.append(entry.getIdentifier().replace(",", ";")); + builder.append(","); + builder.append(entry.getCumStats()); + builder.append(","); + + if (statsMapper != null) { + builder.append(statsMapper.apply(entry)); + builder.append(","); + } + + builder.append(entry.getOccurrences()); + builder.append("\n"); + } + + return builder.toString(); + } + + public static class NGramEntry<T, U> { + private final String identifier; + private final T[] entry; + private U[] stats; + private U cumStats; + private long occurrences; + private int offset; + + public NGramEntry(String identifier, T[] entry, U[] stats, U cumStats, int offset) { + this.identifier = identifier; + this.entry = entry; + this.stats = stats; + this.occurrences = 1; + this.offset = offset; + this.cumStats = cumStats; + } + + public String getIdentifier() { + return identifier; + } + + public long getOccurrences() { + return occurrences; + } + + public U getStat(int index) { + if (index < 0 || index >= entry.length) + throw new ArrayIndexOutOfBoundsException("Index " + index + " is out of bounds"); + + index = (index + offset) % entry.length; + return stats[index]; + } + + public U getCumStats() { + return cumStats; + } + + public U[] getStats() { + return stats; + } + + public int getOffset() { + return offset; + } + + void setCumStats(U cumStats) { + this.cumStats = cumStats; + } + + public T get(int index) { + if (index < 0 || index >= entry.length) + throw new ArrayIndexOutOfBoundsException("Index " + index + " is out of bounds"); + + index = (index + offset) % entry.length; + return entry[index]; + } + + private NGramEntry<T, U> increment() { + occurrences++; + return this; + } + + private NGramEntry<T, U> add(NGramEntry<T, U> entry) { + return add(entry.occurrences); + } + + private NGramEntry<T, U> add(long n) { + occurrences += n; + return this; + } + } + + private final T[] currentNGram; + private final U[] currentStats; + private int currentIndex = 0; + private int currentSize = 0; + private final Function<T, String> idGenerator; + private final BiFunction<U, U, U> statsMerger; + private final ConcurrentHashMap<String, NGramEntry<T, U>> nGrams; + + @SuppressWarnings("unchecked") + public NGramBuilder(Class<T> clazz, Class<U> clazz2, int size, Function<T, String> idGenerator, BiFunction<U, U, U> statsMerger) { + currentNGram = (T[]) Array.newInstance(clazz, size); + currentStats = (U[]) Array.newInstance(clazz2, size); + this.idGenerator = idGenerator; + this.nGrams = new ConcurrentHashMap<>(); + this.statsMerger = statsMerger; + } + + public int getSize() { + return currentNGram.length; + } + + public synchronized void merge(NGramBuilder<T, U> builder) { + builder.nGrams.forEach((k, v) -> nGrams.merge(k, v, (v1, v2) -> + { + v1.add(v2.occurrences); + v1.setCumStats(statsMerger.apply(v1.getCumStats(), v2.getCumStats())); + int index1 = v1.offset; + int index2 = v2.offset; + U[] stats1 = v1.getStats(); + U[] stats2 = v2.getStats(); + + for (int i = 0; i < stats1.length; i++) { + stats1[index1] = statsMerger.apply(stats1[index1], stats2[index2]); + index1 = (index1 + 1) % stats1.length; + index2 = (index2 + 1) % stats2.length; + } + + return v1; + })); + } + + public synchronized void append(T element, U stat) { + currentNGram[currentIndex] = element; + currentStats[currentIndex] = stat; + currentIndex = (currentIndex + 1) % currentNGram.length; + + if (currentSize < currentNGram.length) + currentSize++; + + if (currentSize == currentNGram.length) { + StringBuilder builder = new StringBuilder(currentNGram.length); + builder.append("("); + + for (int i = 0; i < currentNGram.length; i++) { + int actualIndex = (i + currentIndex) % currentSize; + builder.append(idGenerator.apply(currentNGram[actualIndex])); + + if (i != currentNGram.length - 1) + builder.append(", "); + } + + builder.append(")"); + + registerElement(builder.toString(), stat); + } + } + + public synchronized List<NGramEntry<T, U>> getTopK(int k) { + return nGrams.entrySet().stream() + .sorted(Comparator.comparingLong((Map.Entry<String, NGramEntry<T, U>> v) -> v.getValue().occurrences).reversed()) + .map(Map.Entry::getValue) + .limit(k) + .collect(Collectors.toList()); + } + + public synchronized List<NGramEntry<T, U>> getTopK(int k, Comparator<NGramEntry<T, U>> comparator, boolean reversed) { + return nGrams.entrySet().stream() + .sorted((e1, e2) -> reversed ? comparator.compare(e2.getValue(), e1.getValue()) : comparator.compare(e1.getValue(), e2.getValue())) + .map(Map.Entry::getValue) + .limit(k) + .collect(Collectors.toList()); + } + + private synchronized void registerElement(String id, U stat) { + nGrams.compute(id, (key, entry) -> { + if (entry == null) { + U cumStat = currentStats[0]; + + for (int i = 1; i < currentStats.length; i++) { + cumStat = statsMerger.apply(currentStats[i], cumStat); + } + + entry = new NGramEntry<T, U>(id, Arrays.copyOf(currentNGram, currentNGram.length), Arrays.copyOf(currentStats, currentStats.length), cumStat, currentIndex); + } else { + entry.increment(); + U[] stats = entry.getStats(); + U cumStat = null; + + int mCurrentIndex = currentIndex; + int mIndexEntry = entry.offset; + + for (int i = 0; i < stats.length; i++) { + stats[mIndexEntry] = statsMerger.apply(stats[mIndexEntry], currentStats[mCurrentIndex]); + if (i == 0) { + cumStat = stats[mIndexEntry]; + } else { + cumStat = statsMerger.apply(stats[mIndexEntry], cumStat); + } + + mCurrentIndex = (mCurrentIndex + 1) % stats.length; + mIndexEntry = (mIndexEntry + 1) % stats.length; + } + + entry.setCumStats(cumStat); + } + + return entry; + }); + } + +} diff --git a/src/test/java/org/apache/sysds/test/applications/ApplyTransformTest.java b/src/test/java/org/apache/sysds/test/applications/ApplyTransformTest.java index 4f8af7f149..bdd0a9b415 100644 --- a/src/test/java/org/apache/sysds/test/applications/ApplyTransformTest.java +++ b/src/test/java/org/apache/sysds/test/applications/ApplyTransformTest.java @@ -19,6 +19,8 @@ package org.apache.sysds.test.applications; +import java.io.FileWriter; +import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; @@ -27,6 +29,9 @@ import java.util.Iterator; import java.util.List; import java.util.Map; +import org.apache.sysds.api.DMLScript; +import org.apache.sysds.utils.Statistics; +import org.apache.sysds.utils.stats.NGramBuilder; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; @@ -81,6 +86,10 @@ public class ApplyTransformTest extends AutomatedTestBase{ getAndLoadTestConfiguration(TEST_NAME); List<String> proArgs = new ArrayList<>(); + proArgs.add("-stats"); + proArgs.add("-ngrams"); + proArgs.add("1,2,3,4,5,6,7,8,9,10"); + proArgs.add("10"); proArgs.add("-nvargs"); proArgs.add("X=" + sourceDirectory + X); proArgs.add("missing_value_maps=" + (missing_value_maps.equals(" ") ? " " : sourceDirectory + missing_value_maps)); diff --git a/src/test/java/org/apache/sysds/test/applications/L2SVMTest.java b/src/test/java/org/apache/sysds/test/applications/L2SVMTest.java index b77bc49a0d..dbb98160e2 100644 --- a/src/test/java/org/apache/sysds/test/applications/L2SVMTest.java +++ b/src/test/java/org/apache/sysds/test/applications/L2SVMTest.java @@ -83,6 +83,9 @@ public class L2SVMTest extends AutomatedTestBase List<String> proArgs = new ArrayList<>(); proArgs.add("-stats"); + proArgs.add("-ngrams"); + proArgs.add("3,2"); + proArgs.add("10"); proArgs.add("-nvargs"); proArgs.add("X=" + input("X")); proArgs.add("Y=" + input("Y"));