Repository: systemml
Updated Branches:
  refs/heads/master b56612f02 -> 3bad7e7a3


[MINOR] Profile memory use in JMLC execution

This PR adds utilities to profile memory use during execution in JMLC. 
Specifically, the following changes were made:

1. Added options setStatistics() and gatherMemStats() to api.jmlc.Connection, 
which control whether statistics are gathered and, if so, whether memory use 
is profiled. Both options are false by default. Also added a statistics() 
method to api.jmlc.PreparedScript to display the resulting statistics. The 
points below apply only when running in JMLC mode with memory statistics 
enabled.
2. Modified utils.Statistics to track the memory used by distinct CacheBlock 
objects. At the conclusion of the script, the maximum memory use is reported. 
Memory use is computed by calling the object's getInMemorySize() method. This 
will generally be a slight over-estimate of the actual memory used by the 
object.
3. If FINEGRAINED_STATISTICS are enabled, Statistics will also track the memory 
used by each named variable in a DML script and report it in a table similar 
to the one for heavy hitter instructions. The goal is to detect unexpectedly 
large intermediate matrices (e.g. resulting from an outer product X %*% t(X)).
4. If FINEGRAINED_STATISTICS are enabled, Statistics will attempt to measure 
memory use more accurately by checking whether an object has been garbage 
collected. This is done by maintaining a soft reference to the object and 
periodically checking to see if it has become null. This is enabled only when 
using fine-grained statistics since it introduces potentially non-trivial 
overheads by scanning a list of live objects. Note that simply using rmvar to 
remove a live variable results in a substantial underestimate of memory used by 
the program and so this method is not used. When finegrained statistics are not 
enabled, the resulting statistics will be an overestimate.
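
The bookkeeping described in points 2 and 4 can be sketched roughly as follows. 
This is a simplified illustration, not the committed code: the class 
PinnedMemoryTracker and its method names are hypothetical, and the size is 
passed in by the caller rather than read from getInMemorySize(). Like the 
commit, it keys objects by System.identityHashCode, which is not guaranteed to 
be unique.

```java
import java.lang.ref.SoftReference;
import java.util.HashMap;
import java.util.Map;

// Hypothetical stand-in for the tracking added to utils.Statistics.
class PinnedMemoryTracker {
    // identity hash -> estimated size in bytes of each tracked object
    private final Map<Integer, Double> sizes = new HashMap<>();
    // identity hash -> soft reference used to notice garbage collection
    private final Map<Integer, SoftReference<Object>> live = new HashMap<>();
    private double current = 0.0;
    private double peak = 0.0;

    // Register (or re-register) an object and update the running peak.
    synchronized void add(Object block, double sizeInBytes) {
        int id = System.identityHashCode(block);
        Double prev = sizes.put(id, sizeInBytes);
        current += sizeInBytes - (prev == null ? 0.0 : prev);
        live.putIfAbsent(id, new SoftReference<>(block));
        peak = Math.max(peak, current);
    }

    // Explicitly drop an object, e.g. when its data is cleared.
    synchronized void remove(int id) {
        Double size = sizes.remove(id);
        if (size != null)
            current -= size;
        live.remove(id);
    }

    // Scan the soft references and drop entries whose referent is gone.
    synchronized void sweep() {
        live.entrySet().removeIf(e -> {
            if (e.getValue().get() != null)
                return false; // still reachable, keep counting it
            Double size = sizes.remove(e.getKey());
            if (size != null)
                current -= size;
            return true;
        });
    }

    synchronized double currentBytes() { return current; }
    synchronized double peakBytes() { return peak; }
}
```

The sweep is the part that costs time with many live objects, which is why 
the commit restricts it to fine-grained statistics.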

Potential impacts to performance: when finegrained statistics are enabled there 
will be some performance degradation from maintaining the set of live variables.

Potential Improvements: Related to the above, it would be nice to find a way of 
accurately tracking when an object is actually released without resorting to 
checking whether a soft reference has become null. It might also be nice to 
include a line number indicating where a "heavy hitting object" was created to 
make debugging easier.
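
One possible direction for the first improvement, offered only as an idea and 
not something this commit implements, is to register each reference with a 
java.lang.ref.ReferenceQueue so that collected objects are handed back by the 
JVM instead of being found by scanning. The sketch below swaps in 
WeakReference for the SoftReference used in the commit, and the class 
QueueBackedTracker and its members are hypothetical names.

```java
import java.lang.ref.Reference;
import java.lang.ref.ReferenceQueue;
import java.lang.ref.WeakReference;
import java.util.HashSet;
import java.util.Set;

// Hypothetical alternative to polling soft references for null.
class QueueBackedTracker {
    // Weak reference that remembers the size estimate of its referent.
    static final class SizedRef extends WeakReference<Object> {
        final double sizeInBytes;
        SizedRef(Object referent, double sizeInBytes, ReferenceQueue<Object> queue) {
            super(referent, queue);
            this.sizeInBytes = sizeInBytes;
        }
    }

    private final ReferenceQueue<Object> queue = new ReferenceQueue<>();
    // Holds the reference objects themselves so they stay reachable
    // until they have been drained from the queue.
    private final Set<SizedRef> refs = new HashSet<>();
    private double current = 0.0;

    synchronized SizedRef add(Object block, double sizeInBytes) {
        SizedRef ref = new SizedRef(block, sizeInBytes, queue);
        refs.add(ref);
        current += sizeInBytes;
        return ref;
    }

    // Process only the references the JVM has enqueued; no full scan.
    synchronized int drain() {
        int released = 0;
        Reference<?> r;
        while ((r = queue.poll()) != null) {
            SizedRef ref = (SizedRef) r;
            if (refs.remove(ref)) {
                current -= ref.sizeInBytes;
                released++;
            }
        }
        return released;
    }

    synchronized double currentBytes() { return current; }
}
```

The cost of drain() is proportional to the number of objects actually 
released rather than to the number of live objects. Whether weak references 
are cleared at a point that matches the memory accounting the commit is after 
would still need to be checked.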

Closes #794.


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/3bad7e7a
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/3bad7e7a
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/3bad7e7a

Branch: refs/heads/master
Commit: 3bad7e7a36fa31d58ca8240251953db5921ab45e
Parents: b56612f
Author: Anthony Thomas <ahtho...@eng.ucsd.edu>
Authored: Fri Jul 6 11:10:17 2018 -0700
Committer: Niketan Pansare <npan...@us.ibm.com>
Committed: Fri Jul 6 11:23:36 2018 -0700

----------------------------------------------------------------------
 docs/jmlc.md                                    |  26 ++-
 .../java/org/apache/sysml/api/DMLScript.java    |   1 +
 .../org/apache/sysml/api/jmlc/Connection.java   |  18 ++
 .../apache/sysml/api/jmlc/PreparedScript.java   |  11 +-
 .../controlprogram/LocalVariableMap.java        |  23 ++-
 .../runtime/controlprogram/ProgramBlock.java    |   2 +
 .../controlprogram/caching/CacheableData.java   |  17 +-
 .../java/org/apache/sysml/utils/Statistics.java | 187 ++++++++++++++++---
 8 files changed, 250 insertions(+), 35 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/3bad7e7a/docs/jmlc.md
----------------------------------------------------------------------
diff --git a/docs/jmlc.md b/docs/jmlc.md
index 2183700..a703d01 100644
--- a/docs/jmlc.md
+++ b/docs/jmlc.md
@@ -49,6 +49,18 @@ of SystemML's distributed modes, such as Spark batch mode or 
Hadoop batch mode,
 distributed computing capabilities. JMLC offers embeddability at the cost of 
performance, so its use is
 dependent on the nature of the business use case being addressed.
 
+## Statistics
+
+JMLC can be configured to gather runtime statistics, as in the MLContext API, 
by calling Connection's `setStatistics()`
+method with a value of `true`. JMLC can also be configured to gather 
statistics on the memory used by matrices and
+frames in the DML script. To enable collection of memory statistics, call 
Connection's `gatherMemStats()` method
+with a value of `true`. When finegrained statistics are enabled in 
`SystemML.conf`, JMLC will also report the variables
+in the DML script which used the most memory. By default, the memory use 
reported will be an overestimate of the actual
+memory required to run the program. When finegrained statistics are enabled, 
JMLC will gather more accurate statistics
+by keeping track of garbage collection events and reducing the memory estimate 
accordingly. The most accurate way to
+determine the memory required by a script is to run the script in a single 
thread and enable finegrained statistics.
+
+An example showing how to enable statistics in JMLC is presented in the 
section below.
 
 ---
 
@@ -114,11 +126,19 @@ the resulting `"predicted_y"` matrix. We repeat this 
process. When done, we clos
  
         // obtain connection to SystemML
         Connection conn = new Connection();
+
+        // turn on gathering of runtime statistics and memory use
+        conn.setStatistics(true);
+        conn.gatherMemStats(true);
  
         // read in and precompile DML script, registering inputs and outputs
         String dml = conn.readScript("scoring-example.dml");
         PreparedScript script = conn.prepareScript(dml, new String[] { "W", 
"X" }, new String[] { "predicted_y" }, false);
- 
+
+        // obtain the runtime plan generated by SystemML
+        String plan = script.explain();
+        System.out.println(plan);
+
         double[][] mtx = matrix(4, 3, new double[] { 1, 2, 3, 4, 5, 6, 7, 8, 9 
});
         double[][] result = null;
  
@@ -127,6 +147,10 @@ the resulting `"predicted_y"` matrix. We repeat this 
process. When done, we clos
         script.setMatrix("X", randomMatrix(3, 3, -1, 1, 0.7));
         result = script.executeScript().getMatrix("predicted_y");
         displayMatrix(result);
+
+        // print the resulting runtime statistics
+        String stats = script.statistics();
+        System.out.println(stats);
  
         script.setMatrix("W", mtx);
         script.setMatrix("X", randomMatrix(3, 3, -1, 1, 0.7));

http://git-wip-us.apache.org/repos/asf/systemml/blob/3bad7e7a/src/main/java/org/apache/sysml/api/DMLScript.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/DMLScript.java 
b/src/main/java/org/apache/sysml/api/DMLScript.java
index 215d082..ee4be28 100644
--- a/src/main/java/org/apache/sysml/api/DMLScript.java
+++ b/src/main/java/org/apache/sysml/api/DMLScript.java
@@ -167,6 +167,7 @@ public class DMLScript
        public static RUNTIME_PLATFORM  rtplatform          = 
DMLOptions.defaultOptions.execMode;    // the execution mode
        public static boolean           STATISTICS          = 
DMLOptions.defaultOptions.stats;       // whether to print statistics
        public static boolean           FINEGRAINED_STATISTICS  = false;        
                                             // whether to print fine-grained 
statistics
+       public static boolean                   JMLC_MEMORY_STATISTICS = false; 
                                                         // whether to gather 
memory use stats in JMLC
        public static int               STATISTICS_COUNT    = 
DMLOptions.defaultOptions.statsCount;  // statistics maximum heavy hitter count
        public static int               STATISTICS_MAX_WRAP_LEN = 30;           
                     // statistics maximum wrap length
        public static boolean           ENABLE_DEBUG_MODE   = 
DMLOptions.defaultOptions.debug;       // debug mode

http://git-wip-us.apache.org/repos/asf/systemml/blob/3bad7e7a/src/main/java/org/apache/sysml/api/jmlc/Connection.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/jmlc/Connection.java 
b/src/main/java/org/apache/sysml/api/jmlc/Connection.java
index 7d631f8..550b1c6 100644
--- a/src/main/java/org/apache/sysml/api/jmlc/Connection.java
+++ b/src/main/java/org/apache/sysml/api/jmlc/Connection.java
@@ -178,6 +178,24 @@ public class Connection implements Closeable
                
                setLocalConfigs();
        }
+
+       /**
+        * Sets a boolean flag indicating if runtime statistics should be 
gathered
+        * Same behavior as in "MLContext.setStatistics()"
+        *
+        * @param stats boolean value with true indicating statistics should be 
gathered
+        */
+       public void setStatistics(boolean stats) { DMLScript.STATISTICS = 
stats; }
+
+       /**
+        * Sets a boolean flag indicating if memory profiling statistics should 
be
+        * gathered. The option is false by default.
+        * @param stats boolean value with true indicating memory statistics 
should be gathered
+        */
+       public void gatherMemStats(boolean stats) {
+               DMLScript.STATISTICS = stats || DMLScript.STATISTICS;
+               DMLScript.JMLC_MEMORY_STATISTICS = stats;
+       }
        
        /**
         * Prepares (precompiles) a script and registers input and output 
variables.

http://git-wip-us.apache.org/repos/asf/systemml/blob/3bad7e7a/src/main/java/org/apache/sysml/api/jmlc/PreparedScript.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/api/jmlc/PreparedScript.java 
b/src/main/java/org/apache/sysml/api/jmlc/PreparedScript.java
index c0d0be2..dec7eb9 100644
--- a/src/main/java/org/apache/sysml/api/jmlc/PreparedScript.java
+++ b/src/main/java/org/apache/sysml/api/jmlc/PreparedScript.java
@@ -30,6 +30,7 @@ import org.apache.commons.logging.Log;
 import org.apache.commons.logging.LogFactory;
 import org.apache.sysml.api.ConfigurableAPI;
 import org.apache.sysml.api.DMLException;
+import org.apache.sysml.api.DMLScript;
 import org.apache.sysml.conf.CompilerConfig;
 import org.apache.sysml.conf.ConfigurationManager;
 import org.apache.sysml.conf.DMLConfig;
@@ -60,6 +61,7 @@ import org.apache.sysml.runtime.matrix.data.MatrixBlock;
 import org.apache.sysml.runtime.matrix.data.OutputInfo;
 import org.apache.sysml.runtime.util.DataConverter;
 import org.apache.sysml.utils.Explain;
+import org.apache.sysml.utils.Statistics;
 
 /**
  * Representation of a prepared (precompiled) DML/PyDML script.
@@ -446,7 +448,7 @@ public class PreparedScript implements ConfigurableAPI
                
                //clear thread-local configurations
                ConfigurationManager.clearLocalConfigs();
-               
+
                return rvars;
        }
        
@@ -458,6 +460,13 @@ public class PreparedScript implements ConfigurableAPI
        public String explain() {
                return Explain.explain(_prog);
        }
+
+       /**
+        * Return a string containing runtime statistics. Note: these are not 
thread local
+        * and will reflect execution in all threads
+        * @return string containing statistics
+        */
+       public String statistics() { return Statistics.display(); }
        
        /**
         * Enables function recompilation, selectively for the given functions. 

http://git-wip-us.apache.org/repos/asf/systemml/blob/3bad7e7a/src/main/java/org/apache/sysml/runtime/controlprogram/LocalVariableMap.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/controlprogram/LocalVariableMap.java 
b/src/main/java/org/apache/sysml/runtime/controlprogram/LocalVariableMap.java
index a926816..6ab1efe 100644
--- 
a/src/main/java/org/apache/sysml/runtime/controlprogram/LocalVariableMap.java
+++ 
b/src/main/java/org/apache/sysml/runtime/controlprogram/LocalVariableMap.java
@@ -26,10 +26,12 @@ import java.util.Map.Entry;
 import java.util.Set;
 import java.util.StringTokenizer;
 
+import org.apache.sysml.api.DMLScript;
 import org.apache.sysml.runtime.controlprogram.caching.CacheableData;
 import org.apache.sysml.runtime.controlprogram.parfor.ProgramConverter;
 import org.apache.sysml.runtime.controlprogram.parfor.util.IDSequence;
 import org.apache.sysml.runtime.instructions.cp.Data;
+import org.apache.sysml.utils.Statistics;
 
 /**
  * Replaces <code>HashMap&lang;String, Data&rang;</code> as the table of
@@ -123,21 +125,26 @@ public class LocalVariableMap implements Cloneable
        }
 
        public double getPinnedDataSize() {
-               //note: this method returns the total size of distinct pinned
-               //data objects that are not subject to automatic eviction 
-               //(in JMLC all matrices and frames are pinned)
-               
+               // note: this method returns the total size of distinct pinned
+               // data objects that are not subject to automatic eviction
+               // (in JMLC all matrices and frames are pinned)
+
                //compute map of distinct cachable data
                Map<Integer, Data> dict = new HashMap<>();
+               double total = 0.0;
                for( Entry<String,Data> e : localMap.entrySet() ) {
                        int hash = System.identityHashCode(e.getValue());
-                       if( !dict.containsKey(hash) && e.getValue() instanceof 
CacheableData )
+                       if( !dict.containsKey(hash) && e.getValue() instanceof 
CacheableData ) {
                                dict.put(hash, e.getValue());
+                               double size = ((CacheableData) 
e.getValue()).getDataSize();
+                               if ((DMLScript.JMLC_MEMORY_STATISTICS) && 
(DMLScript.FINEGRAINED_STATISTICS))
+                                       
Statistics.maintainCPHeavyHittersMem(e.getKey(), size);
+                               total += size;
+                       }
                }
-               
+
                //compute total in-memory size
-               return dict.values().stream().mapToDouble(
-                       d -> ((CacheableData<?>)d).getDataSize()).sum();
+               return total;
        }
        
        public long countPinnedData() {

http://git-wip-us.apache.org/repos/asf/systemml/blob/3bad7e7a/src/main/java/org/apache/sysml/runtime/controlprogram/ProgramBlock.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/controlprogram/ProgramBlock.java 
b/src/main/java/org/apache/sysml/runtime/controlprogram/ProgramBlock.java
index b9a5133..b7476ae 100644
--- a/src/main/java/org/apache/sysml/runtime/controlprogram/ProgramBlock.java
+++ b/src/main/java/org/apache/sysml/runtime/controlprogram/ProgramBlock.java
@@ -259,6 +259,8 @@ public class ProgramBlock implements ParseInfo
                                Statistics.maintainCPHeavyHitters(
                                        tmp.getExtendedOpcode(), 
System.nanoTime()-t0);
                        }
+                       if ((DMLScript.JMLC_MEMORY_STATISTICS) && 
(DMLScript.FINEGRAINED_STATISTICS))
+                               ec.getVariables().getPinnedDataSize();
 
                        // optional trace information (instruction and runtime)
                        if( LOG.isTraceEnabled() ) {

http://git-wip-us.apache.org/repos/asf/systemml/blob/3bad7e7a/src/main/java/org/apache/sysml/runtime/controlprogram/caching/CacheableData.java
----------------------------------------------------------------------
diff --git 
a/src/main/java/org/apache/sysml/runtime/controlprogram/caching/CacheableData.java
 
b/src/main/java/org/apache/sysml/runtime/controlprogram/caching/CacheableData.java
index 54d8e14..5b1c26b 100644
--- 
a/src/main/java/org/apache/sysml/runtime/controlprogram/caching/CacheableData.java
+++ 
b/src/main/java/org/apache/sysml/runtime/controlprogram/caching/CacheableData.java
@@ -55,6 +55,7 @@ import org.apache.sysml.runtime.matrix.data.InputInfo;
 import org.apache.sysml.runtime.matrix.data.OutputInfo;
 import org.apache.sysml.runtime.util.LocalFileUtils;
 import org.apache.sysml.runtime.util.MapReduceTool;
+import org.apache.sysml.utils.Statistics;
 
 
 /**
@@ -503,6 +504,9 @@ public abstract class CacheableData<T extends CacheBlock> 
extends Data
                
                setDirty(true);
                _isAcquireFromEmpty = false;
+
+               if (DMLScript.JMLC_MEMORY_STATISTICS)
+                       Statistics.addCPMemObject(newData);
                
                //set references to new data
                if (newData == null)
@@ -569,6 +573,11 @@ public abstract class CacheableData<T extends CacheBlock> 
extends Data
                                }
                                _requiresLocalWrite = false;
                        }
+
+                       if ((DMLScript.JMLC_MEMORY_STATISTICS) && (this._data 
!= null)) {
+                               int hash = System.identityHashCode(this._data);
+                               Statistics.removeCPMemObject(hash);
+                       }
                        
                        //create cache
                        createCache();
@@ -597,8 +606,12 @@ public abstract class CacheableData<T extends CacheBlock> 
extends Data
                // clear existing WB / FS representation (but prevent 
unnecessary probes)
                if( !(isEmpty(true)||(_data!=null && isBelowCachingThreshold()) 
                          ||(_data!=null && !isCachingActive()) )) //additional 
condition for JMLC
-                       freeEvictedBlob();      
-               
+                       freeEvictedBlob();
+
+               if ((DMLScript.JMLC_MEMORY_STATISTICS) && (this._data != null)) 
{
+                       int hash = System.identityHashCode(this._data);
+                       Statistics.removeCPMemObject(hash);
+               }
                // clear the in-memory data
                _data = null;
                clearCache();

http://git-wip-us.apache.org/repos/asf/systemml/blob/3bad7e7a/src/main/java/org/apache/sysml/utils/Statistics.java
----------------------------------------------------------------------
diff --git a/src/main/java/org/apache/sysml/utils/Statistics.java 
b/src/main/java/org/apache/sysml/utils/Statistics.java
index 0618b38..107afe0 100644
--- a/src/main/java/org/apache/sysml/utils/Statistics.java
+++ b/src/main/java/org/apache/sysml/utils/Statistics.java
@@ -22,6 +22,7 @@ package org.apache.sysml.utils;
 import java.lang.management.CompilationMXBean;
 import java.lang.management.GarbageCollectorMXBean;
 import java.lang.management.ManagementFactory;
+import java.lang.ref.SoftReference;
 import java.text.DecimalFormat;
 import java.util.Arrays;
 import java.util.Comparator;
@@ -29,11 +30,13 @@ import java.util.List;
 import java.util.Map.Entry;
 import java.util.Set;
 import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.atomic.DoubleAdder;
 import java.util.concurrent.atomic.LongAdder;
 
 import org.apache.sysml.api.DMLScript;
 import org.apache.sysml.conf.ConfigurationManager;
 import org.apache.sysml.hops.OptimizerUtils;
+import org.apache.sysml.runtime.controlprogram.caching.CacheBlock;
 import org.apache.sysml.runtime.controlprogram.caching.CacheStatistics;
 import org.apache.sysml.runtime.controlprogram.context.SparkExecutionContext;
 import org.apache.sysml.runtime.instructions.Instruction;
@@ -69,6 +72,21 @@ public class Statistics
        private static final LongAdder numExecutedSPInst = new LongAdder();
        private static final LongAdder numCompiledSPInst = new LongAdder();
 
+       // number and size of pinned objects in scope
+       private static final DoubleAdder sizeofPinnedObjects = new 
DoubleAdder();
+       private static long maxNumPinnedObjects = 0;
+       private static double maxSizeofPinnedObjects = 0;
+
+       // Maps to keep track of CP memory objects for JMLC (e.g. in memory 
matrices and frames)
+       private static final ConcurrentHashMap<String,Double> _cpMemObjs = new 
ConcurrentHashMap<>();
+       private static final ConcurrentHashMap<Integer,Double> _currCPMemObjs = 
new ConcurrentHashMap<>();
+
+       // this hash map maintains soft references to the cache blocks in 
memory. It is periodically scanned to check for
+       // objects which have been garbage collected. This enables more 
accurate memory statistics. Relying on rmvar
+       // instructions to determine when an object has been de-allocated 
results in a substantial underestimate to memory
+       // use by the program since garbage collection will not occur 
immediately.
+       private static final 
ConcurrentHashMap<Integer,SoftReference<CacheBlock>> _liveObjects = new 
ConcurrentHashMap<>();
+
        //JVM stats (low frequency updates)
        private static long jitCompileTime = 0; //in milli sec
        private static long jvmGCTime = 0; //in milli sec
@@ -553,7 +571,7 @@ public class Statistics
        public static void accPSBatchIndexingTime(long t) {
                psBatchIndexTime.add(t);
        }
-       
+
        public static String getCPHeavyHitterCode( Instruction inst )
        {
                String opcode = null;
@@ -583,6 +601,71 @@ public class Statistics
                return opcode;
        }
 
+       public static void addCPMemObject(CacheBlock data) {
+               int hash = System.identityHashCode(data);
+               double sizeof = data.getInMemorySize();
+
+               double sizePrev = _currCPMemObjs.getOrDefault(hash, 0.0);
+               _currCPMemObjs.put(hash, sizeof);
+               sizeofPinnedObjects.add(sizeof - sizePrev);
+               if (DMLScript.FINEGRAINED_STATISTICS)
+                       _liveObjects.putIfAbsent(hash, new 
SoftReference<>(data));
+               maintainMemMaxStats();
+               checkForDeadBlocks();
+       }
+
+       /**
+        * If finegrained statistics are enabled searches through a map of soft 
references to find objects
+        * which have been garbage collected. This results in more accurate 
statistics on memory use but
+        * introduces overhead so is only enabled with finegrained stats and 
when running in JMLC
+        */
+       public static void checkForDeadBlocks() {
+               if (!DMLScript.FINEGRAINED_STATISTICS)
+                       return;
+               for (Entry<Integer,SoftReference<CacheBlock>> e : 
_liveObjects.entrySet()) {
+                       if (e.getValue().get() == null) {
+                               removeCPMemObject(e.getKey());
+                               _liveObjects.remove(e.getKey());
+                       }
+               }
+       }
+
+       /**
+        * Helper method to keep track of the maximum number of pinned
+        * objects and total size yet seen
+        */
+       private static void maintainMemMaxStats() {
+               if (maxSizeofPinnedObjects < sizeofPinnedObjects.doubleValue())
+                       maxSizeofPinnedObjects = 
sizeofPinnedObjects.doubleValue();
+               if (maxNumPinnedObjects < _currCPMemObjs.size())
+                       maxNumPinnedObjects = _currCPMemObjs.size();
+       }
+
+       /**
+        * Helper method to remove a memory object which has become unpinned
+        * @param hash hash of data object
+        */
+       public static void removeCPMemObject( int hash ) {
+               if (_currCPMemObjs.containsKey(hash)) {
+                       double sizeof = _currCPMemObjs.remove(hash);
+                       sizeofPinnedObjects.add(-1.0 * sizeof);
+               }
+       }
+
+       /**
+        * Helper method which keeps track of the heaviest weight objects (by 
total memory used)
+        * throughout execution of the program. Only reported if JMLC memory 
statistics are enabled and
+        * finegrained statistics are enabled. We only keep track of the 
-largest- instance of data associated with a
+        * particular string identifier so no need to worry about multiple 
bindings to the same name
+        * @param name String denoting the variables name
+        * @param sizeof objects size (estimated bytes)
+        */
+       public static void maintainCPHeavyHittersMem( String name, double 
sizeof ) {
+               double prevSize = _cpMemObjs.getOrDefault(name, 0.0);
+               if (prevSize < sizeof)
+                       _cpMemObjs.put(name, sizeof);
+       }
+
        /**
         * "Maintains" or adds time to per instruction/op timers, also 
increments associated count
         * @param instName name of the instruction/op
@@ -708,6 +791,56 @@ public class Statistics
                return sb.toString();
        }
 
+       public static String getCPHeavyHittersMem(int num) {
+               int n = _cpMemObjs.size();
+               if ((n <= 0) || (num <= 0))
+                       return "-";
+
+               Entry<String,Double>[] entries = 
_cpMemObjs.entrySet().toArray(new Entry[_cpMemObjs.size()]);
+               Arrays.sort(entries, new Comparator<Entry<String, Double>>() {
+                       @Override
+                       public int compare(Entry<String, Double> a, 
Entry<String, Double> b) {
+                               return b.getValue().compareTo(a.getValue());
+                       }
+               });
+
+               int numHittersToDisplay = Math.min(num, n);
+               int numPadLen = String.format("%d", 
numHittersToDisplay).length();
+               int maxNameLength = 0;
+               for (String name : _cpMemObjs.keySet())
+                       maxNameLength = Math.max(name.length(), maxNameLength);
+
+               maxNameLength = Math.max(maxNameLength, "Object".length());
+               StringBuilder res = new StringBuilder();
+               res.append(String.format("  %-" + numPadLen + "s" + "  %-" + 
maxNameLength + "s" + "  %s\n",
+                               "#", "Object", "Memory"));
+
+               // lots of futzing around to format strings...
+               for (int ix = 1; ix <= numHittersToDisplay; ix++) {
+                       String objName = entries[ix-1].getKey();
+                       String objSize = 
byteCountToDisplaySize(entries[ix-1].getValue());
+                       String numStr = String.format("  %-" + numPadLen + "s", 
ix);
+                       String objNameStr = String.format("  %-" + 
maxNameLength + "s ", objName);
+                       res.append(numStr + objNameStr + String.format("  %s", 
objSize) + "\n");
+               }
+
+               return res.toString();
+       }
+
+       /**
+        * Helper method to create a nice representation of byte counts - this 
was copied from
+        * GPUMemoryManager and should eventually be refactored probably...
+        */
+       private static String byteCountToDisplaySize(double numBytes) {
+               if (numBytes < 1024) {
+                       return numBytes + " bytes";
+               }
+               else {
+                       int exp = (int) (Math.log(numBytes) / 
6.931471805599453);
+                       return String.format("%.3f %sB", ((double)numBytes) / 
Math.pow(1024, exp), "KMGTP".charAt(exp-1));
+               }
+       }
+
        /**
         * Returns the total time of asynchronous JIT compilation in 
milliseconds.
         * 
@@ -786,6 +919,10 @@ public class Statistics
                return parforMergeTime;
        }
 
+       public static long getNumPinnedObjects() { return maxNumPinnedObjects; }
+
+       public static double getSizeofPinnedObjects() { return 
maxSizeofPinnedObjects; }
+
        /**
         * Returns statistics of the DML program that was recently completed as 
a string
         * @return statistics as a string
@@ -813,7 +950,7 @@ public class Statistics
        public static String display(int maxHeavyHitters)
        {
                StringBuilder sb = new StringBuilder();
-               
+
                sb.append("SystemML Statistics:\n");
                if( DMLScript.STATISTICS ) {
                        sb.append("Total elapsed time:\t\t" + 
String.format("%.3f", (getCompileTime()+getRunTime())*1e-9) + " sec.\n"); // 
nanoSec --> sec
@@ -828,47 +965,49 @@ public class Statistics
                else {
                        if( DMLScript.STATISTICS ) //moved into stats on Shiv's 
request
                                sb.append("Number of compiled MR Jobs:\t" + 
getNoOfCompiledMRJobs() + ".\n");
-                       sb.append("Number of executed MR Jobs:\t" + 
getNoOfExecutedMRJobs() + ".\n");   
+                       sb.append("Number of executed MR Jobs:\t" + 
getNoOfExecutedMRJobs() + ".\n");
                }
 
                if( DMLScript.USE_ACCELERATOR && DMLScript.STATISTICS)
                        sb.append(GPUStatistics.getStringForCudaTimers());
-               
+
                //show extended caching/compilation statistics
-               if( DMLScript.STATISTICS ) 
+               if( DMLScript.STATISTICS )
                {
                        if(NativeHelper.CURRENT_NATIVE_BLAS_STATE == 
NativeHelper.NativeBlasState.SUCCESSFULLY_LOADED_NATIVE_BLAS_AND_IN_USE) {
-                               String blas = NativeHelper.getCurrentBLAS(); 
-                               sb.append("Native " + blas + " calls (dense 
mult/conv/bwdF/bwdD):\t" + numNativeLibMatrixMultCalls.longValue()  + "/" + 
+                               String blas = NativeHelper.getCurrentBLAS();
+                               sb.append("Native " + blas + " calls (dense 
mult/conv/bwdF/bwdD):\t" + numNativeLibMatrixMultCalls.longValue()  + "/" +
                                                
numNativeConv2dCalls.longValue() + "/" + 
numNativeConv2dBwdFilterCalls.longValue()
                                                + "/" + 
numNativeConv2dBwdDataCalls.longValue() + ".\n");
-                               sb.append("Native " + blas + " calls (sparse 
conv/bwdF/bwdD):\t" +  
+                               sb.append("Native " + blas + " calls (sparse 
conv/bwdF/bwdD):\t" +
                                                
numNativeSparseConv2dCalls.longValue() + "/" + 
numNativeSparseConv2dBwdFilterCalls.longValue()
                                                + "/" + 
numNativeSparseConv2dBwdDataCalls.longValue() + ".\n");
                                sb.append("Native " + blas + " times (dense 
mult/conv/bwdF/bwdD):\t" + String.format("%.3f", nativeLibMatrixMultTime*1e-9) 
+ "/" +
-                                               String.format("%.3f", 
nativeConv2dTime*1e-9) + "/" + String.format("%.3f", 
nativeConv2dBwdFilterTime*1e-9) + "/" + 
+                                               String.format("%.3f", 
nativeConv2dTime*1e-9) + "/" + String.format("%.3f", 
nativeConv2dBwdFilterTime*1e-9) + "/" +
                                                String.format("%.3f", 
nativeConv2dBwdDataTime*1e-9) + ".\n");
                        }
                        if(recomputeNNZTime != 0 || examSparsityTime != 0 || 
allocateDoubleArrTime != 0) {
                                sb.append("MatrixBlock times 
(recomputeNNZ/examSparsity/allocateDoubleArr):\t" + String.format("%.3f", 
recomputeNNZTime*1e-9) + "/" +
-                                       String.format("%.3f", 
examSparsityTime*1e-9) + "/" + String.format("%.3f", 
allocateDoubleArrTime*1e-9)  + ".\n");
+                                               String.format("%.3f", 
examSparsityTime*1e-9) + "/" + String.format("%.3f", 
allocateDoubleArrTime*1e-9)  + ".\n");
                        }
-                       
+
                        sb.append("Cache hits (Mem, WB, FS, HDFS):\t" + 
CacheStatistics.displayHits() + ".\n");
                        sb.append("Cache writes (WB, FS, HDFS):\t" + 
CacheStatistics.displayWrites() + ".\n");
                        sb.append("Cache times (ACQr/m, RLS, EXP):\t" + 
CacheStatistics.displayTime() + " sec.\n");
+                       if (DMLScript.JMLC_MEMORY_STATISTICS)
+                               sb.append("Max size of objects in CP memory:\t" + byteCountToDisplaySize(getSizeofPinnedObjects()) + " ("  + getNumPinnedObjects() + " total objects)" + "\n");
                        sb.append("HOP DAGs recompiled (PRED, SB):\t" + 
getHopRecompiledPredDAGs() + "/" + getHopRecompiledSBDAGs() + ".\n");
                        sb.append("HOP DAGs recompile time:\t" + 
String.format("%.3f", ((double)getHopRecompileTime())/1000000000) + " sec.\n");
                        if( getFunRecompiles()>0 ) {
                                sb.append("Functions recompiled:\t\t" + 
getFunRecompiles() + ".\n");
-                               sb.append("Functions recompile time:\t" + 
String.format("%.3f", ((double)getFunRecompileTime())/1000000000) + " sec.\n"); 
      
+                               sb.append("Functions recompile time:\t" + 
String.format("%.3f", ((double)getFunRecompileTime())/1000000000) + " sec.\n");
                        }
                        if( ConfigurationManager.isCodegenEnabled() ) {
                                sb.append("Codegen compile (DAG,CP,JC):\t" + 
getCodegenDAGCompile() + "/"
                                                + getCodegenCPlanCompile() + 
"/" + getCodegenClassCompile() + ".\n");
                                sb.append("Codegen enum (ALLt/p,EVALt/p):\t" + 
getCodegenEnumAll() + "/" +
                                                getCodegenEnumAllP() + "/" + 
getCodegenEnumEval() + "/" + getCodegenEnumEvalP() + ".\n");
-                               sb.append("Codegen compile times (DAG,JC):\t" + 
String.format("%.3f", (double)getCodegenCompileTime()/1000000000) + "/" + 
+                               sb.append("Codegen compile times (DAG,JC):\t" + 
String.format("%.3f", (double)getCodegenCompileTime()/1000000000) + "/" +
                                                String.format("%.3f", 
(double)getCodegenClassCompileTime()/1000000000)  + " sec.\n");
                                sb.append("Codegen enum plan cache hits:\t" + 
getCodegenPlanCacheHits() + "/" + getCodegenPlanCacheTotal() + ".\n");
                                sb.append("Codegen op plan cache hits:\t" + 
getCodegenOpCacheHits() + "/" + getCodegenOpCacheTotal() + ".\n");
@@ -878,28 +1017,28 @@ public class Statistics
                                sb.append("Spark ctx create time "+lazy+":\t"+
                                                String.format("%.3f", 
((double)sparkCtxCreateTime)*1e-9)  + " sec.\n" ); // nanoSec --> sec
                                sb.append("Spark trans counts (par,bc,col):" +
-                                               String.format("%d/%d/%d.\n", 
sparkParallelizeCount.longValue(), 
+                                               String.format("%d/%d/%d.\n", 
sparkParallelizeCount.longValue(),
                                                                
sparkBroadcastCount.longValue(), sparkCollectCount.longValue()));
                                sb.append("Spark trans times (par,bc,col):\t" +
-                                               String.format("%.3f/%.3f/%.3f 
secs.\n", 
-                                                                
((double)sparkParallelize.longValue())*1e-9,
-                                                                
((double)sparkBroadcast.longValue())*1e-9,
-                                                                
((double)sparkCollect.longValue())*1e-9));
+                                               String.format("%.3f/%.3f/%.3f 
secs.\n",
+                                                               
((double)sparkParallelize.longValue())*1e-9,
+                                                               
((double)sparkBroadcast.longValue())*1e-9,
+                                                               
((double)sparkCollect.longValue())*1e-9));
                        }
                        if (psNumWorkers.longValue() > 0) {
                                sb.append(String.format("Paramserv total num 
workers:\t%d.\n", psNumWorkers.longValue()));
                                sb.append(String.format("Paramserv setup 
time:\t\t%.3f secs.\n", psSetupTime.doubleValue() / 1000));
                                sb.append(String.format("Paramserv grad compute 
time:\t%.3f secs.\n", psGradientComputeTime.doubleValue() / 1000));
                                sb.append(String.format("Paramserv model update 
time:\t%.3f/%.3f secs.\n",
-                                       psLocalModelUpdateTime.doubleValue() / 
1000, psAggregationTime.doubleValue() / 1000));
+                                               
psLocalModelUpdateTime.doubleValue() / 1000, psAggregationTime.doubleValue() / 
1000));
                                sb.append(String.format("Paramserv model 
broadcast time:\t%.3f secs.\n", psModelBroadcastTime.doubleValue() / 1000));
                                sb.append(String.format("Paramserv batch slice 
time:\t%.3f secs.\n", psBatchIndexTime.doubleValue() / 1000));
                        }
                        if( parforOptCount>0 ){
                                sb.append("ParFor loops optimized:\t\t" + 
getParforOptCount() + ".\n");
-                               sb.append("ParFor optimize time:\t\t" + 
String.format("%.3f", ((double)getParforOptTime())/1000) + " sec.\n");  
-                               sb.append("ParFor initialize time:\t\t" + 
String.format("%.3f", ((double)getParforInitTime())/1000) + " sec.\n");       
-                               sb.append("ParFor result merge time:\t" + 
String.format("%.3f", ((double)getParforMergeTime())/1000) + " sec.\n");      
+                               sb.append("ParFor optimize time:\t\t" + 
String.format("%.3f", ((double)getParforOptTime())/1000) + " sec.\n");
+                               sb.append("ParFor initialize time:\t\t" + 
String.format("%.3f", ((double)getParforInitTime())/1000) + " sec.\n");
+                               sb.append("ParFor result merge time:\t" + 
String.format("%.3f", ((double)getParforMergeTime())/1000) + " sec.\n");
                                sb.append("ParFor total update in-place:\t" + 
lTotalUIPVar + "/" + lTotalLixUIP + "/" + lTotalLix + "\n");
                        }
 
@@ -908,8 +1047,10 @@ public class Statistics
                        sb.append("Total JVM GC time:\t\t" + 
((double)getJVMgcTime())/1000 + " sec.\n");
                        LibMatrixDNN.appendStatistics(sb);
                        sb.append("Heavy hitter instructions:\n" + 
getHeavyHitters(maxHeavyHitters));
+                       if ((DMLScript.JMLC_MEMORY_STATISTICS) && 
(DMLScript.FINEGRAINED_STATISTICS))
+                               sb.append("Heavy hitter objects:\n" + 
getCPHeavyHittersMem(maxHeavyHitters));
                }
-               
+
                return sb.toString();
        }
 }
