This is an automated email from the ASF dual-hosted git repository.

mboehm7 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/systemds.git


The following commit(s) were added to refs/heads/main by this push:
     new 9f96718939 [SYSTEMDS-3714] N-gram statistics of operation sequences
9f96718939 is described below

commit 9f96718939bd43fedcf942ac71f6fd70bdcf48d6
Author: Jaybit0 <jannik.lindema...@gmail.com>
AuthorDate: Wed Jul 24 18:31:43 2024 +0200

    [SYSTEMDS-3714] N-gram statistics of operation sequences
    
    Closes #2045.
---
 src/main/java/org/apache/sysds/api/DMLOptions.java |  27 +++
 src/main/java/org/apache/sysds/api/DMLScript.java  |   9 +
 .../sysds/runtime/controlprogram/ProgramBlock.java |   6 +-
 .../java/org/apache/sysds/utils/Statistics.java    | 231 ++++++++++++++++++-
 .../org/apache/sysds/utils/stats/NGramBuilder.java | 248 +++++++++++++++++++++
 .../test/applications/ApplyTransformTest.java      |   9 +
 .../apache/sysds/test/applications/L2SVMTest.java  |   3 +
 7 files changed, 529 insertions(+), 4 deletions(-)

diff --git a/src/main/java/org/apache/sysds/api/DMLOptions.java b/src/main/java/org/apache/sysds/api/DMLOptions.java
index 70af5ba9e8..acacc39572 100644
--- a/src/main/java/org/apache/sysds/api/DMLOptions.java
+++ b/src/main/java/org/apache/sysds/api/DMLOptions.java
@@ -53,7 +53,10 @@ public class DMLOptions {
        public String               configFile    = null;             // Path 
to config file if default config and default config is to be overridden
        public boolean              clean         = false;            // 
Whether to clean up all SystemDS working directories (FS, DFS)
        public boolean              stats         = false;            // 
Whether to record and print the statistics
+       public boolean              statsNGrams  = false;            // Whether 
to record and print the statistics n-grams
        public int                  statsCount    = 10;               // 
Default statistics count
+       public int[]                statsNGramSizes = { 3 };          // 
Default n-gram tuple sizes
+       public int                  statsTopKNGrams = 10;             // How 
many of the most heavy hitting n-grams are displayed
        public boolean              fedStats      = false;            // 
Whether to record and print the federated statistics
        public int                  fedStatsCount = 10;               // 
Default federated statistics count
        public boolean              memStats      = false;            // max 
memory statistics
@@ -212,6 +215,26 @@ public class DMLOptions {
                                }
                        }
                }
+
+               dmlOptions.statsNGrams = line.hasOption("ngrams");
+               if (dmlOptions.statsNGrams){
+                       String[] nGramArgs = line.getOptionValues("ngrams"); // null when -ngrams is given without values
+                       if (nGramArgs != null && nGramArgs.length == 2) { // otherwise keep default n-gram sizes and top-k
+                               try {
+                                       String[] nGramSizeSplit = nGramArgs[0].split(","); // first value: comma separated n's
+                                       dmlOptions.statsNGramSizes = new int[nGramSizeSplit.length];
+
+                                       for (int i = 0; i < nGramSizeSplit.length; i++) {
+                                               dmlOptions.statsNGramSizes[i] = Integer.parseInt(nGramSizeSplit[i]);
+                                       }
+
+                                       dmlOptions.statsTopKNGrams = Integer.parseInt(nGramArgs[1]); // second value: topK
+                               } catch (NumberFormatException e) {
+                                       throw new org.apache.commons.cli.ParseException("Invalid argument specified for -ngrams option, must be a valid integer");
+                               }
+                       }
+               }
+
                dmlOptions.fedStats = line.hasOption("fedStats");
                if (dmlOptions.fedStats) {
                        String fedStatsCount = line.getOptionValue("fedStats");
@@ -335,6 +358,9 @@ public class DMLOptions {
                Option statsOpt = OptionBuilder.withArgName("count")
                        .withDescription("monitors and reports summary 
execution statistics; heavy hitter <count> is 10 unless overridden; default 
off")
                        .hasOptionalArg().create("stats");
+               Option ngramsOpt = OptionBuilder//.withArgName("ngrams")
+                       .withDescription("monitors and reports the most 
occurring n-grams; -ngrams <comma separated n's> <topK>")
+                       .hasOptionalArgs(2).create("ngrams");
                Option fedStatsOpt = OptionBuilder.withArgName("count")
                        .withDescription("monitors and reports summary 
execution statistics of federated workers; heavy hitter <count> is 10 unless 
overridden; default off")
                        .hasOptionalArg().create("fedStats");
@@ -396,6 +422,7 @@ public class DMLOptions {
                options.addOption(configOpt);
                options.addOption(cleanOpt);
                options.addOption(statsOpt);
+               options.addOption(ngramsOpt);
                options.addOption(fedStatsOpt);
                options.addOption(memOpt);
                options.addOption(explainOpt);
diff --git a/src/main/java/org/apache/sysds/api/DMLScript.java b/src/main/java/org/apache/sysds/api/DMLScript.java
index 3443f68740..cd86426a42 100644
--- a/src/main/java/org/apache/sysds/api/DMLScript.java
+++ b/src/main/java/org/apache/sysds/api/DMLScript.java
@@ -94,10 +94,16 @@ public class DMLScript
        private static ExecMode   EXEC_MODE                  = 
DMLOptions.defaultOptions.execMode;
        // Enable/disable to print statistics
        public static boolean     STATISTICS                 = 
DMLOptions.defaultOptions.stats;
+       // Enable/disable to print statistics n-grams
+       public static boolean     STATISTICS_NGRAMS          = 
DMLOptions.defaultOptions.statsNGrams;
        // Enable/disable to gather memory use stats in JMLC
        public static boolean     JMLC_MEM_STATISTICS        = false;
        // Set maximum heavy hitter count
        public static int         STATISTICS_COUNT           = 
DMLOptions.defaultOptions.statsCount;
+       // The sizes of recorded n-gram tuples
+       public static int[]         STATISTICS_NGRAM_SIZES   = 
DMLOptions.defaultOptions.statsNGramSizes;
+       // Set top k displayed n-grams limit
+       public static int         STATISTICS_TOP_K_NGRAMS    = 
DMLOptions.defaultOptions.statsTopKNGrams;
        // Set statistics maximum wrap length
        public static int         STATISTICS_MAX_WRAP_LEN    = 30;
        // Enable/disable to print federated statistics
@@ -250,6 +256,9 @@ public class DMLScript
                {
                        STATISTICS            = dmlOptions.stats;
                        STATISTICS_COUNT      = dmlOptions.statsCount;
+                       STATISTICS_NGRAMS     = dmlOptions.statsNGrams;
+                       STATISTICS_NGRAM_SIZES = dmlOptions.statsNGramSizes;
+                       STATISTICS_TOP_K_NGRAMS = dmlOptions.statsTopKNGrams;
                        FED_STATISTICS        = dmlOptions.fedStats;
                        FED_STATISTICS_COUNT  = dmlOptions.fedStatsCount;
                        JMLC_MEM_STATISTICS   = dmlOptions.memStats;
diff --git a/src/main/java/org/apache/sysds/runtime/controlprogram/ProgramBlock.java b/src/main/java/org/apache/sysds/runtime/controlprogram/ProgramBlock.java
index 34b954148b..4e75d5456f 100644
--- a/src/main/java/org/apache/sysds/runtime/controlprogram/ProgramBlock.java
+++ b/src/main/java/org/apache/sysds/runtime/controlprogram/ProgramBlock.java
@@ -241,7 +241,7 @@ public abstract class ProgramBlock implements ParseInfo {
        private void executeSingleInstruction(Instruction currInst, 
ExecutionContext ec) {
                try {
                        // start time measurement for statistics
-                       long t0 = (DMLScript.STATISTICS || 
LOG.isTraceEnabled()) ? System.nanoTime() : 0;
+                       long t0 = (DMLScript.STATISTICS || 
DMLScript.STATISTICS_NGRAMS || LOG.isTraceEnabled()) ? System.nanoTime() : 0;
 
                        // pre-process instruction (inst patching, listeners, 
lineage)
                        Instruction tmp = currInst.preprocessInstruction(ec);
@@ -263,6 +263,10 @@ public abstract class ProgramBlock implements ParseInfo {
                                if(DMLScript.STATISTICS) {
                                        
Statistics.maintainCPHeavyHitters(tmp.getExtendedOpcode(), System.nanoTime() - 
t0);
                                }
+
+                               if (DMLScript.STATISTICS_NGRAMS) {
+                                       
Statistics.maintainNGrams(tmp.getExtendedOpcode(), System.nanoTime() - t0);
+                               }
                        }
 
                        // optional trace information (instruction and runtime)
diff --git a/src/main/java/org/apache/sysds/utils/Statistics.java b/src/main/java/org/apache/sysds/utils/Statistics.java
index 5cab7dbd30..a7c764cf78 100644
--- a/src/main/java/org/apache/sysds/utils/Statistics.java
+++ b/src/main/java/org/apache/sysds/utils/Statistics.java
@@ -34,10 +34,11 @@ import 
org.apache.sysds.runtime.instructions.spark.SPInstruction;
 import org.apache.sysds.runtime.lineage.LineageCacheConfig.ReuseCacheType;
 import org.apache.sysds.runtime.lineage.LineageCacheStatistics;
 import org.apache.sysds.utils.stats.CodegenStatistics;
-import org.apache.sysds.utils.stats.RecompileStatistics;
+import org.apache.sysds.utils.stats.NGramBuilder;
 import org.apache.sysds.utils.stats.NativeStatistics;
-import org.apache.sysds.utils.stats.ParamServStatistics;
 import org.apache.sysds.utils.stats.ParForStatistics;
+import org.apache.sysds.utils.stats.ParamServStatistics;
+import org.apache.sysds.utils.stats.RecompileStatistics;
 import org.apache.sysds.utils.stats.SparkStatistics;
 import org.apache.sysds.utils.stats.TransformStatistics;
 
@@ -45,10 +46,13 @@ import java.lang.management.CompilationMXBean;
 import java.lang.management.GarbageCollectorMXBean;
 import java.lang.management.ManagementFactory;
 import java.text.DecimalFormat;
+import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Comparator;
 import java.util.HashMap;
 import java.util.List;
+import java.util.Locale;
+import java.util.Map;
 import java.util.Map.Entry;
 import java.util.Set;
 import java.util.concurrent.ConcurrentHashMap;
@@ -64,6 +68,46 @@ public class Statistics
                private final LongAdder time = new LongAdder();
                private final LongAdder count = new LongAdder();
        }
+
+       public static class NGramStats {
+
+               public final long n;
+               public final long cumTimeNanos;
+               public final double m2;
+
+               public static <T> Comparator<NGramBuilder.NGramEntry<T, 
NGramStats>> getComparator() {
+                       return Comparator.comparingLong(entry -> 
entry.getCumStats().cumTimeNanos);
+               }
+
+               public static NGramStats merge(NGramStats stats1, NGramStats 
stats2) {
+                       // Using the algorithm from: 
https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Parallel_algorithm
+                       long newN = stats1.n + stats2.n;
+                       long cumTimeNanos = stats1.cumTimeNanos + 
stats2.cumTimeNanos;
+
+                       // Ensure the calculation uses floating-point arithmetic
+                       double mean1 = (double) stats1.cumTimeNanos / 
1000000000d / stats1.n;
+                       double mean2 = (double) stats2.cumTimeNanos / 
1000000000d / stats2.n;
+                       double delta = mean2 - mean1;
+
+                       double newM2 = stats1.m2 + stats2.m2 + delta * delta * 
stats1.n * stats2.n / (double)newN;
+
+                       return new NGramStats(newN, cumTimeNanos, newM2);
+               }
+
+               public NGramStats(final long n, final long cumTimeNanos, final 
double m2) {
+                       this.n = n;
+                       this.cumTimeNanos = cumTimeNanos;
+                       this.m2 = m2;
+               }
+
+               public double getTimeVariance() {
+                       return m2 / Math.max(n-1, 1);
+               }
+
+               public String toString() {
+                       return String.format(Locale.US, "%.5f", (cumTimeNanos / 
1000000000d));
+               }
+       }
        
        private static long compileStartTime = 0;
        private static long compileEndTime = 0;
@@ -71,7 +115,8 @@ public class Statistics
        private static long execEndTime = 0;
        
        //heavy hitter counts and times 
-       private static final ConcurrentHashMap<String,InstStats>_instStats = 
new ConcurrentHashMap<>();
+       private static final ConcurrentHashMap<String,InstStats> _instStats = 
new ConcurrentHashMap<>();
+       private static final ConcurrentHashMap<String, NGramBuilder<String, 
NGramStats>[]> _instStatsNGram = new ConcurrentHashMap<>();
 
        // number of compiled/executed SP instructions
        private static final LongAdder numExecutedSPInst = new LongAdder();
@@ -252,6 +297,8 @@ public class Statistics
                DMLCompressionStatistics.reset();
 
                FederatedStatistics.reset();
+
+               _instStatsNGram.clear();
        }
 
        public static void resetJITCompileTime(){
@@ -353,6 +400,177 @@ public class Statistics
                tmp.time.add(timeNanos);
                tmp.count.increment();
        }
+
+       public static void maintainNGrams(String instName, long timeNanos) {
+               NGramBuilder<String, NGramStats>[] tmp = 
_instStatsNGram.computeIfAbsent(Thread.currentThread().getName(), k -> {
+                       NGramBuilder<String, NGramStats>[] threadEntry = new 
NGramBuilder[DMLScript.STATISTICS_NGRAM_SIZES.length];
+                       for (int i = 0; i < threadEntry.length; i++) {
+                               threadEntry[i] = new NGramBuilder<String, 
NGramStats>(String.class, NGramStats.class, 
DMLScript.STATISTICS_NGRAM_SIZES[i], s -> s, NGramStats::merge);
+                       }
+                       return threadEntry;
+               });
+
+               for (int i = 0; i < tmp.length; i++)
+                       tmp[i].append(instName, new NGramStats(1, timeNanos, 
0));
+       }
+
+       public static NGramBuilder<String, NGramStats>[] mergeNGrams() {
+               NGramBuilder<String, NGramStats>[] builders = new 
NGramBuilder[DMLScript.STATISTICS_NGRAM_SIZES.length];
+
+               for (int i = 0; i < builders.length; i++) {
+                       builders[i] = new NGramBuilder<String, 
NGramStats>(String.class, NGramStats.class, 
DMLScript.STATISTICS_NGRAM_SIZES[i], s -> s, NGramStats::merge);
+               }
+
+               for (int i = 0; i < DMLScript.STATISTICS_NGRAM_SIZES.length; 
i++) {
+                       for (Map.Entry<String, NGramBuilder<String, 
NGramStats>[]> entry : _instStatsNGram.entrySet()) {
+                               NGramBuilder<String, NGramStats> mbuilder = 
entry.getValue()[i];
+                               builders[i].merge(mbuilder);
+                       }
+               }
+
+               return builders;
+       }
+
+       public static String getNGramStdDevs(NGramStats[] stats, int offset, 
int prec, boolean displayZero) {
+               StringBuilder sb = new StringBuilder();
+               sb.append("(");
+               boolean containsData = false;
+               int actualIndex;
+               for (int i = 0; i < stats.length; i++) {
+                       if (i != 0)
+                               sb.append(", ");
+                       actualIndex = (offset + i) % stats.length;
+                       double var = 1000000000d * stats[actualIndex].n * 
Math.sqrt(stats[actualIndex].getTimeVariance()) / 
stats[actualIndex].cumTimeNanos;
+                       if (displayZero || var >= Math.pow(10, -prec)) {
+                               sb.append(String.format(Locale.US, "%." + prec 
+ "f", var));
+                               containsData = true;
+                       }
+               }
+               sb.append(")");
+               return containsData ? sb.toString() : "-";
+       }
+
+       public static String getNGramAvgTimes(NGramStats[] stats, int offset, 
int prec) {
+               StringBuilder sb = new StringBuilder();
+               sb.append("(");
+               int actualIndex;
+               for (int i = 0; i < stats.length; i++) {
+                       if (i != 0)
+                               sb.append(", ");
+                       actualIndex = (offset + i) % stats.length;
+                       double var = (stats[actualIndex].cumTimeNanos / 
1000000000d) / stats[actualIndex].n;
+                       sb.append(String.format(Locale.US, "%." + prec + "f", 
var));
+               }
+               sb.append(")");
+               return sb.toString();
+       }
+
+       public static String nGramToCSV(final NGramBuilder<String, NGramStats> 
mbuilder) {
+               ArrayList<String> colList = new ArrayList<>();
+               colList.add("N-Gram");
+               colList.add("Time[s]");
+
+               for (int j = 0; j < mbuilder.getSize(); j++)
+                       colList.add("Col" + (j + 1));
+               for (int j = 0; j < mbuilder.getSize(); j++)
+                       colList.add("Col" + (j + 1) + "::Mean(Time[s])");
+               for (int j = 0; j < mbuilder.getSize(); j++)
+                       colList.add("Col" + (j + 1) + "::StdDev(Time[s])/Col" + 
(j + 1) + "::Mean(Time[s])");
+
+               colList.add("Count");
+
+               return NGramBuilder.toCSV(colList.toArray(new 
String[colList.size()]), mbuilder.getTopK(100000, 
Statistics.NGramStats.getComparator(), true), e -> {
+                       StringBuilder builder = new StringBuilder();
+                       builder.append(e.getIdentifier().replace("(", 
"").replace(")", "").replace(", ", ","));
+                       builder.append(",");
+                       
builder.append(Statistics.getNGramAvgTimes(e.getStats(), e.getOffset(), 
9).replace("-", "").replace("(", "").replace(")", ""));
+                       builder.append(",");
+                       String stdDevs = 
Statistics.getNGramStdDevs(e.getStats(), e.getOffset(), 9, true).replace("-", 
"").replace("(", "").replace(")", "");
+                       if (stdDevs.isEmpty()) {
+                               for (int j = 0; j < mbuilder.getSize()-1; j++)
+                                       builder.append(",");
+                       } else {
+                               builder.append(stdDevs);
+                       }
+                       return builder.toString();
+               });
+       }
+
+       public static String getCommonNGrams(NGramBuilder<String, NGramStats> 
builder, int num) {
+               if (num <= 0 || _instStatsNGram.size() <= 0)
+                       return "-";
+
+               //NGramBuilder<String, Long> builder = mergeNGrams();
+               @SuppressWarnings("unchecked")
+               NGramBuilder.NGramEntry<String, NGramStats>[] topNGrams = 
builder.getTopK(num, NGramStats.getComparator(), 
true).toArray(NGramBuilder.NGramEntry[]::new);
+
+               final String numCol = "#";
+               final String instCol = "N-Gram";
+               final String timeSCol = "Time(s)";
+               final String timeSVar = "StdDev(t)/Mean(t)";
+               final String countCol = "Count";
+               StringBuilder sb = new StringBuilder();
+               int len = topNGrams.length;
+               int numHittersToDisplay = Math.min(num, len);
+               int maxNumLen = String.valueOf(numHittersToDisplay).length();
+               int maxInstLen = instCol.length();
+               int maxTimeSLen = timeSCol.length();
+               int maxTimeSVarLen = timeSVar.length();
+               int maxCountLen = countCol.length();
+               DecimalFormat sFormat = new DecimalFormat("#,##0.000");
+
+               for (int i = 0; i < numHittersToDisplay; i++) {
+                       long timeNs = topNGrams[i].getCumStats().cumTimeNanos;
+                       String instruction = topNGrams[i].getIdentifier();
+                       double timeS = timeNs / 1000000000d;
+
+
+                       maxInstLen = Math.max(maxInstLen, instruction.length() 
+ 1);
+
+                       String timeSString = sFormat.format(timeS);
+                       String timeSVarString = 
getNGramStdDevs(topNGrams[i].getStats(), topNGrams[i].getOffset(), 3, false);
+                       maxTimeSLen = Math.max(maxTimeSLen, 
timeSString.length());
+                       maxTimeSVarLen = Math.max(maxTimeSVarLen, 
timeSVarString.length());
+
+                       maxCountLen = Math.max(maxCountLen, 
String.valueOf(topNGrams[i].getOccurrences()).length());
+               }
+
+               maxInstLen = Math.min(maxInstLen, 
DMLScript.STATISTICS_MAX_WRAP_LEN);
+               sb.append(String.format( " %" + maxNumLen + "s  %-" + 
maxInstLen + "s  %"
+                               + maxTimeSLen + "s  %" + maxTimeSVarLen + "s  
%" + maxCountLen + "s", numCol, instCol, timeSCol, timeSVar, countCol));
+               sb.append("\n");
+               for (int i = 0; i < numHittersToDisplay; i++) {
+                       String instruction = topNGrams[i].getIdentifier();
+                       String [] wrappedInstruction = wrap(instruction, 
maxInstLen);
+
+                       //long timeNs = tmp[len - 1 - 
i].getValue().time.longValue();
+                       double timeS = topNGrams[i].getCumStats().cumTimeNanos 
/ 1000000000d;
+                       double timeVar = 
topNGrams[i].getCumStats().getTimeVariance();
+                       String timeSString = sFormat.format(timeS);
+                       String timeVarString = 
getNGramStdDevs(topNGrams[i].getStats(), topNGrams[i].getOffset(), 3, 
false);//sFormat.format(timeVar);
+
+                       long count = topNGrams[i].getOccurrences();
+                       int numLines = wrappedInstruction.length;
+
+                       for(int wrapIter = 0; wrapIter < numLines; wrapIter++) {
+                               String instStr = (wrapIter < 
wrappedInstruction.length) ? wrappedInstruction[wrapIter] : "";
+                               if(wrapIter == 0) {
+                                       // Display instruction count
+                                       sb.append(String.format(
+                                                       " %" + maxNumLen + "d  
%-" + maxInstLen + "s  %" + maxTimeSLen + "s %" + maxTimeSVarLen + "s  %" + 
maxCountLen + "d",
+                                                       (i + 1), instStr, 
timeSString, timeVarString, count));
+                               }
+                               else {
+                                       sb.append(String.format(
+                                                       " %" + maxNumLen + "s  
%-" + maxInstLen + "s  %" + maxTimeSLen + "s %" + maxTimeSVarLen + "s  %" + 
maxCountLen + "s",
+                                                       "", instStr, "", "", 
""));
+                               }
+                               sb.append("\n");
+                       }
+               }
+
+               return sb.toString();
+       }
        
        public static void maintainCPFuncCallStats(String instName) {
                InstStats tmp = _instStats.get(instName);
@@ -679,6 +897,13 @@ public class Statistics
                        sb.append("Heavy hitter instructions:\n" + 
getHeavyHitters(maxHeavyHitters));
                }
 
+               if (DMLScript.STATISTICS_NGRAMS) {
+                       NGramBuilder<String, NGramStats>[] mergedNGrams = 
mergeNGrams();
+                       for (int i = 0; i < 
DMLScript.STATISTICS_NGRAM_SIZES.length; i++) {
+                               sb.append("Most common " + 
DMLScript.STATISTICS_NGRAM_SIZES[i] + "-grams (sorted by absolute time):\n" + 
getCommonNGrams(mergedNGrams[i], DMLScript.STATISTICS_TOP_K_NGRAMS));
+                       }
+               }
+
                if(DMLScript.FED_STATISTICS) {
                        sb.append("\n");
                        
sb.append(FederatedStatistics.displayStatistics(DMLScript.FED_STATISTICS_COUNT));
diff --git a/src/main/java/org/apache/sysds/utils/stats/NGramBuilder.java b/src/main/java/org/apache/sysds/utils/stats/NGramBuilder.java
new file mode 100644
index 0000000000..7554fdcd67
--- /dev/null
+++ b/src/main/java/org/apache/sysds/utils/stats/NGramBuilder.java
@@ -0,0 +1,248 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.sysds.utils.stats;
+
+import java.lang.reflect.Array;
+import java.util.Arrays;
+import java.util.Comparator;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.function.BiFunction;
+import java.util.function.Function;
+import java.util.stream.Collectors;
+
+public class NGramBuilder<T, U> {
+
+       public static <T, U> String toCSV(String[] columnNames, 
List<NGramEntry<T, U>> entries, Function<NGramEntry<T, U>, String> statsMapper) 
{
+               StringBuilder builder = new StringBuilder(String.join(",", 
columnNames));
+               builder.append("\n");
+
+               for (NGramEntry<T, U> entry : entries) {
+                       builder.append(entry.getIdentifier().replace(",", ";"));
+                       builder.append(",");
+                       builder.append(entry.getCumStats());
+                       builder.append(",");
+
+                       if (statsMapper != null) {
+                               builder.append(statsMapper.apply(entry));
+                               builder.append(",");
+                       }
+
+                       builder.append(entry.getOccurrences());
+                       builder.append("\n");
+               }
+
+               return builder.toString();
+       }
+
+       public static class NGramEntry<T, U> {
+               private final String identifier;
+               private final T[] entry;
+               private U[] stats;
+               private U cumStats;
+               private long occurrences;
+               private int offset;
+
+               public NGramEntry(String identifier, T[] entry, U[] stats, U 
cumStats, int offset) {
+                       this.identifier = identifier;
+                       this.entry = entry;
+                       this.stats = stats;
+                       this.occurrences = 1;
+                       this.offset = offset;
+                       this.cumStats = cumStats;
+               }
+
+               public String getIdentifier() {
+                       return identifier;
+               }
+
+               public long getOccurrences() {
+                       return occurrences;
+               }
+
+               public U getStat(int index) {
+                       if (index < 0 || index >= entry.length)
+                               throw new ArrayIndexOutOfBoundsException("Index 
" + index + " is out of bounds");
+
+                       index = (index + offset) % entry.length;
+                       return stats[index];
+               }
+
+               public U getCumStats() {
+                       return cumStats;
+               }
+
+               public U[] getStats() {
+                       return stats;
+               }
+
+               public int getOffset() {
+                       return offset;
+               }
+
+               void setCumStats(U cumStats) {
+                       this.cumStats = cumStats;
+               }
+
+               public T get(int index) {
+                       if (index < 0 || index >= entry.length)
+                               throw new ArrayIndexOutOfBoundsException("Index 
" + index + " is out of bounds");
+
+                       index = (index + offset) % entry.length;
+                       return entry[index];
+               }
+
+               private NGramEntry<T, U> increment() {
+                       occurrences++;
+                       return this;
+               }
+
+               private NGramEntry<T, U> add(NGramEntry<T, U> entry) {
+                       return add(entry.occurrences);
+               }
+
+               private NGramEntry<T, U> add(long n) {
+                       occurrences += n;
+                       return this;
+               }
+       }
+
+       private final T[] currentNGram;
+       private final U[] currentStats;
+       private int currentIndex = 0;
+       private int currentSize = 0;
+       private final Function<T, String> idGenerator;
+       private final BiFunction<U, U, U> statsMerger;
+       private final ConcurrentHashMap<String, NGramEntry<T, U>> nGrams;
+
+       @SuppressWarnings("unchecked")
+       public NGramBuilder(Class<T> clazz, Class<U> clazz2, int size, 
Function<T, String> idGenerator, BiFunction<U, U, U> statsMerger) {
+               currentNGram = (T[]) Array.newInstance(clazz, size);
+               currentStats = (U[]) Array.newInstance(clazz2, size);
+               this.idGenerator = idGenerator;
+               this.nGrams = new ConcurrentHashMap<>();
+               this.statsMerger = statsMerger;
+       }
+
+       public int getSize() {
+               return currentNGram.length;
+       }
+
+       public synchronized void merge(NGramBuilder<T, U> builder) {
+               builder.nGrams.forEach((k, v) -> nGrams.merge(k, v, (v1, v2) ->
+               {
+                       v1.add(v2.occurrences);
+                       v1.setCumStats(statsMerger.apply(v1.getCumStats(), 
v2.getCumStats()));
+                       int index1 = v1.offset;
+                       int index2 = v2.offset;
+                       U[] stats1 = v1.getStats();
+                       U[] stats2 = v2.getStats();
+
+                       for (int i = 0; i < stats1.length; i++) {
+                               stats1[index1] = 
statsMerger.apply(stats1[index1], stats2[index2]);
+                               index1 = (index1 + 1) % stats1.length;
+                               index2 = (index2 + 1) % stats2.length;
+                       }
+
+                       return v1;
+               }));
+       }
+
+       public synchronized void append(T element, U stat) {
+               currentNGram[currentIndex] = element;
+               currentStats[currentIndex] = stat;
+               currentIndex = (currentIndex + 1) % currentNGram.length;
+
+               if (currentSize < currentNGram.length)
+                       currentSize++;
+
+               if (currentSize == currentNGram.length) {
+                       StringBuilder builder = new 
StringBuilder(currentNGram.length);
+                       builder.append("(");
+
+                       for (int i = 0; i < currentNGram.length; i++) {
+                               int actualIndex = (i + currentIndex) % 
currentSize;
+                               
builder.append(idGenerator.apply(currentNGram[actualIndex]));
+
+                               if (i != currentNGram.length - 1)
+                                       builder.append(", ");
+                       }
+
+                       builder.append(")");
+
+                       registerElement(builder.toString(), stat);
+               }
+       }
+
+       public synchronized List<NGramEntry<T, U>> getTopK(int k) {
+               return nGrams.entrySet().stream()
+                               
.sorted(Comparator.comparingLong((Map.Entry<String, NGramEntry<T, U>> v) -> 
v.getValue().occurrences).reversed())
+                               .map(Map.Entry::getValue)
+                               .limit(k)
+                               .collect(Collectors.toList());
+       }
+
+       public synchronized List<NGramEntry<T, U>> getTopK(int k, 
Comparator<NGramEntry<T, U>> comparator, boolean reversed) {
+               return nGrams.entrySet().stream()
+                               .sorted((e1, e2) -> reversed ? 
comparator.compare(e2.getValue(), e1.getValue()) : 
comparator.compare(e1.getValue(), e2.getValue()))
+                               .map(Map.Entry::getValue)
+                               .limit(k)
+                               .collect(Collectors.toList());
+       }
+
+       private synchronized void registerElement(String id, U stat) {
+               nGrams.compute(id, (key, entry) ->  {
+                       if (entry == null) {
+                               U cumStat = currentStats[0];
+
+                               for (int i = 1; i < currentStats.length; i++) {
+                                       cumStat = 
statsMerger.apply(currentStats[i], cumStat);
+                               }
+
+                               entry = new NGramEntry<T, U>(id, 
Arrays.copyOf(currentNGram, currentNGram.length), Arrays.copyOf(currentStats, 
currentStats.length), cumStat, currentIndex);
+                       } else {
+                               entry.increment();
+                               U[] stats = entry.getStats();
+                               U cumStat = null;
+
+                               int mCurrentIndex = currentIndex;
+                               int mIndexEntry = entry.offset;
+
+                               for (int i = 0; i < stats.length; i++) {
+                                       stats[mIndexEntry] = 
statsMerger.apply(stats[mIndexEntry], currentStats[mCurrentIndex]);
+                                       if (i == 0) {
+                                               cumStat = stats[mIndexEntry];
+                                       } else {
+                                               cumStat = 
statsMerger.apply(stats[mIndexEntry], cumStat);
+                                       }
+
+                                       mCurrentIndex = (mCurrentIndex + 1) % 
stats.length;
+                                       mIndexEntry = (mIndexEntry + 1) % 
stats.length;
+                               }
+
+                               entry.setCumStats(cumStat);
+                       }
+
+                       return entry;
+               });
+       }
+
+}
diff --git a/src/test/java/org/apache/sysds/test/applications/ApplyTransformTest.java b/src/test/java/org/apache/sysds/test/applications/ApplyTransformTest.java
index 4f8af7f149..bdd0a9b415 100644
--- a/src/test/java/org/apache/sysds/test/applications/ApplyTransformTest.java
+++ b/src/test/java/org/apache/sysds/test/applications/ApplyTransformTest.java
@@ -19,6 +19,8 @@
 
 package org.apache.sysds.test.applications;
 
+import java.io.FileWriter;
+import java.io.IOException;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collection;
@@ -27,6 +29,9 @@ import java.util.Iterator;
 import java.util.List;
 import java.util.Map;
 
+import org.apache.sysds.api.DMLScript;
+import org.apache.sysds.utils.Statistics;
+import org.apache.sysds.utils.stats.NGramBuilder;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import org.junit.runners.Parameterized;
@@ -81,6 +86,10 @@ public class ApplyTransformTest extends AutomatedTestBase{
                getAndLoadTestConfiguration(TEST_NAME);
                
                List<String> proArgs = new ArrayList<>();
+               proArgs.add("-stats");
+               proArgs.add("-ngrams");
+               proArgs.add("1,2,3,4,5,6,7,8,9,10");
+               proArgs.add("10");
                proArgs.add("-nvargs");
                proArgs.add("X=" + sourceDirectory + X);
                proArgs.add("missing_value_maps=" + 
(missing_value_maps.equals(" ") ? " " : sourceDirectory + missing_value_maps));
diff --git a/src/test/java/org/apache/sysds/test/applications/L2SVMTest.java b/src/test/java/org/apache/sysds/test/applications/L2SVMTest.java
index b77bc49a0d..dbb98160e2 100644
--- a/src/test/java/org/apache/sysds/test/applications/L2SVMTest.java
+++ b/src/test/java/org/apache/sysds/test/applications/L2SVMTest.java
@@ -83,6 +83,9 @@ public class L2SVMTest extends AutomatedTestBase
 
                List<String> proArgs = new ArrayList<>();
                proArgs.add("-stats");
+               proArgs.add("-ngrams");
+               proArgs.add("3,2");
+               proArgs.add("10");
                proArgs.add("-nvargs");
                proArgs.add("X=" + input("X"));
                proArgs.add("Y=" + input("Y"));

Reply via email to