Repository: systemml
Updated Branches:
  refs/heads/master ddcb9e019 -> 933a17c5f


[SYSTEMML-540] Allow for `allreduce_parallel_batches` for multi-GPU training

- Introduced `allreduce_parallel_batches` for multi-GPU training as per
  Mike's suggestion.
- Moved `train_algo` and `test_algo` from solver specification to Python API
  to conform with Caffe as per Berthold's suggestion.
- Updated the documentation for Caffe2DML.

Closes #543.
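
For reference, a minimal sketch of the new user-facing API (the object name `lenet` and the parameter values below are illustrative, not part of this commit):

```python
# train_algo/test_algo now live on the Python API rather than in the solver file.
# Valid values: minibatch, batch, allreduce_parallel_batches, allreduce.
lenet.set(train_algo="allreduce_parallel_batches",  # data-parallel multi-GPU training
          test_algo="minibatch",
          parallel_batches=2)                       # e.g. one mini-batch per GPU
lenet.setGPU(True).setForceGPU(True)
```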


Project: http://git-wip-us.apache.org/repos/asf/systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/systemml/commit/933a17c5
Tree: http://git-wip-us.apache.org/repos/asf/systemml/tree/933a17c5
Diff: http://git-wip-us.apache.org/repos/asf/systemml/diff/933a17c5

Branch: refs/heads/master
Commit: 933a17c5f38ac71d2aa3e855da1e38797fe6603d
Parents: ddcb9e0
Author: Niketan Pansare <npan...@us.ibm.com>
Authored: Fri Jun 16 10:14:44 2017 -0800
Committer: Niketan Pansare <npan...@us.ibm.com>
Committed: Fri Jun 16 11:14:44 2017 -0700

----------------------------------------------------------------------
 docs/beginners-guide-caffe2dml.md               |  23 +-
 src/main/proto/caffe/caffe.proto                |  47 ++--
 src/main/python/systemml/mllearn/estimators.py  |   7 +-
 .../org/apache/sysml/api/dl/Caffe2DML.scala     | 247 ++++++++++++++-----
 4 files changed, 248 insertions(+), 76 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/systemml/blob/933a17c5/docs/beginners-guide-caffe2dml.md
----------------------------------------------------------------------
diff --git a/docs/beginners-guide-caffe2dml.md b/docs/beginners-guide-caffe2dml.md
index b44b113..f15e025 100644
--- a/docs/beginners-guide-caffe2dml.md
+++ b/docs/beginners-guide-caffe2dml.md
@@ -94,6 +94,8 @@ lenet.setStatistics(True).setExplain(True)
 
 # If you want to force GPU execution, please make sure the required dependencies are available.  
 # lenet.setGPU(True).setForceGPU(True)
+# Example usage of train_algo and test_algo. Assume two GPUs on the driver.
+# lenet.set(train_algo="allreduce_parallel_batches", test_algo="minibatch", parallel_batches=2)
 
 # (Optional but recommended) Enable native BLAS. 
 lenet.setConfigProperty("native.blas", "auto")
@@ -108,6 +110,16 @@ lenet.predict(X_test)
 
 For more detail on enabling native BLAS, please see the documentation for the [native backend](http://apache.github.io/systemml/native-backend).
 
+Common settings for the `train_algo` and `test_algo` parameters:
+
+| | PySpark script | Changes to Network/Solver |
+|---|---|---|
+| Single-node CPU execution (similar to Caffe with solver_mode: CPU) | `caffe2dml.set(train_algo="minibatch", test_algo="minibatch")` | Ensure that `batch_size` is set to an appropriate value (for example: 64) |
+| Single-node single-GPU execution | `caffe2dml.set(train_algo="minibatch", test_algo="minibatch").setGPU(True).setForceGPU(True)` | Ensure that `batch_size` is set to an appropriate value (for example: 64) |
+| Single-node multi-GPU execution (similar to Caffe with solver_mode: GPU) | `caffe2dml.set(train_algo="allreduce_parallel_batches", test_algo="minibatch", parallel_batches=num_gpu).setGPU(True).setForceGPU(True)` | Ensure that `batch_size` is set to an appropriate value (for example: 64) |
+| Distributed prediction | `caffe2dml.set(test_algo="allreduce")` | |
+| Distributed synchronous training | `caffe2dml.set(train_algo="allreduce_parallel_batches", parallel_batches=num_cluster_cores)` | Ensure that `batch_size` is set to an appropriate value (for example: 64) |
+
 ## Frequently asked questions
 
 #### What is the purpose of Caffe2DML API ?
@@ -282,4 +294,13 @@ train_df.write.parquet('kaggle-cats-dogs.parquet')
 #### Can I use Caffe2DML via Scala ?
 
 Though we recommend using Caffe2DML via its Python interfaces, it is possible to use it by creating an object of the class
-`org.apache.sysml.api.dl.Caffe2DML`. It is important to note that Caffe2DML's scala API is packaged in `systemml-*-extra.jar`.
\ No newline at end of file
+`org.apache.sysml.api.dl.Caffe2DML`. It is important to note that Caffe2DML's scala API is packaged in `systemml-*-extra.jar`.
+
+
+#### How can I view the script generated by Caffe2DML ?
+
+To view the generated DML script (and additional debugging information), please set the `debug` parameter to True.
+
+```python
+caffe2dmlObject.set(debug=True)
+```
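
For illustration, a minimal end-to-end sketch of the single-node multi-GPU row of the table above (the solver file name, input shape, and data variables are assumptions, not part of this commit):

```python
from systemml.mllearn import Caffe2DML

# Hypothetical LeNet solver file; X_train/y_train/X_test are assumed to exist.
lenet = Caffe2DML(spark, solver='lenet_solver.proto', input_shape=(1, 28, 28))
lenet = lenet.set(train_algo="allreduce_parallel_batches", test_algo="minibatch",
                  parallel_batches=2)               # e.g. two GPUs on the driver
lenet = lenet.setGPU(True).setForceGPU(True)
lenet.fit(X_train, y_train)
lenet.predict(X_test)
```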

http://git-wip-us.apache.org/repos/asf/systemml/blob/933a17c5/src/main/proto/caffe/caffe.proto
----------------------------------------------------------------------
diff --git a/src/main/proto/caffe/caffe.proto b/src/main/proto/caffe/caffe.proto
index ca30f21..37d78e9 100644
--- a/src/main/proto/caffe/caffe.proto
+++ b/src/main/proto/caffe/caffe.proto
@@ -98,7 +98,7 @@ message NetParameter {
 // NOTE
 // Update the next available ID when you add a new SolverParameter field.
 //
-// SolverParameter next available ID: 43 (last added: test_algo)
+// SolverParameter next available ID: 42 (last added: layer_wise_reduce)
 message SolverParameter {
   //////////////////////////////////////////////////////////////////////////////
   // Specifying the train and test networks
@@ -113,10 +113,6 @@ message SolverParameter {
   // A test_iter must be specified for each test_net.
  // A test_level and/or a test_stage may also be specified for each test_net.
  //////////////////////////////////////////////////////////////////////////////
-  
-  // SystemML extension
-  optional string train_algo = 41 [default = "minibatch"];
-  optional string test_algo = 42 [default = "minibatch"];
 
   // Proto filename for the train net, possibly combined with one or more
   // test nets.
@@ -132,8 +128,7 @@ message SolverParameter {
   // The states for the train/test nets. Must be unspecified or
   // specified once per net.
   //
-  // By default, all states will have solver = true;
-  // train_state will have phase = TRAIN,
+  // By default, train_state will have phase = TRAIN,
   // and all test_state's will have phase = TEST.
   // Other defaults are set according to the NetState defaults.
   optional NetState train_state = 26;
@@ -243,6 +238,9 @@ message SolverParameter {
   }
   // DEPRECATED: use type instead of solver_type
   optional SolverType solver_type = 30 [default = SGD];
+
+  // Overlap compute and communication for data parallel training
+  optional bool layer_wise_reduce = 41 [default = true];
 }
 
 // A message that stores the solver snapshots
@@ -422,7 +420,7 @@ message TransformationParameter {
   optional uint32 crop_size = 3 [default = 0];
   // mean_file and mean_value cannot be specified at the same time
   optional string mean_file = 4;
-  // if specified can be repeated once (would substract it from all the channels)
+  // if specified can be repeated once (would subtract it from all the channels)
   // or can be repeated the same number of times as channels
   // (would subtract them from the corresponding channel)
   repeated float mean_value = 5;
@@ -438,7 +436,7 @@ message LossParameter {
   optional int32 ignore_label = 1;
   // How to normalize the loss for loss layers that aggregate across batches,
   // spatial dimensions, or other dimensions.  Currently only implemented in
-  // SoftmaxWithLoss layer.
+  // SoftmaxWithLoss and SigmoidCrossEntropyLoss layers.
   enum NormalizationMode {
     // Divide by the number of examples in the batch times spatial dimensions.
     // Outputs that receive the ignore label will NOT be ignored in computing
@@ -452,6 +450,8 @@ message LossParameter {
     // Do not normalize the loss.
     NONE = 3;
   }
+  // For historical reasons, the default normalization for
+  // SigmoidCrossEntropyLoss is BATCH_SIZE and *not* VALID.
   optional NormalizationMode normalization = 3 [default = VALID];
   // Deprecated.  Ignored if normalization is specified.  If normalization
   // is not specified, then setting this to false will be equivalent to
@@ -502,11 +502,21 @@ message ConcatParameter {
 }
 
 message BatchNormParameter {
-  // If false, accumulate global mean/variance values via a moving average. If
-  // true, use those accumulated values instead of computing mean/variance
-  // across the batch.
+  // If false, normalization is performed over the current mini-batch
+  // and global statistics are accumulated (but not yet used) by a moving
+  // average.
+  // If true, those accumulated mean and variance values are used for the
+  // normalization.
+  // By default, it is set to false when the network is in the training
+  // phase and true when the network is in the testing phase.
   optional bool use_global_stats = 1;
-  // How much does the moving average decay each iteration?
+  // What fraction of the moving average remains each iteration?
+  // Smaller values make the moving average decay faster, giving more
+  // weight to the recent values.
+  // Each iteration updates the moving average @f$S_{t-1}@f$ with the
+  // current mean @f$ Y_t @f$ by
+  // @f$ S_t = (1-\beta)Y_t + \beta \cdot S_{t-1} @f$, where @f$ \beta @f$
+  // is the moving_average_fraction parameter.
   optional float moving_average_fraction = 2 [default = .999];
   // Small value to add to the variance estimate so that we don't divide by
   // zero.
@@ -657,8 +667,8 @@ message DataParameter {
   optional bool mirror = 6 [default = false];
   // Force the encoded image to have 3 color channels
   optional bool force_encoded_color = 9 [default = false];
-  // Prefetch queue (Number of batches to prefetch to host memory, increase if data access bandwidth varies).
+  // Prefetch queue (Increase if data feeding bandwidth varies, within the limit of device memory for GPU training)
   optional uint32 prefetch = 10 [default = 4];
 }
 
@@ -805,6 +815,7 @@ message ImageDataParameter {
 message InfogainLossParameter {
   // Specify the infogain matrix source.
   optional string source = 1;
+  optional int32 axis = 2 [default = 1]; // axis of prob
 }
 
 message InnerProductParameter {
@@ -927,9 +938,7 @@ message PythonParameter {
   // string, dictionary in Python dict format, JSON, etc. You may parse this
   // string in `setup` method and use it in `forward` and `backward`.
   optional string param_str = 3 [default = ''];
-  // Whether this PythonLayer is shared among worker solvers during data parallelism.
-  // If true, each worker solver sequentially run forward from this layer.
-  // This value should be set true if you are using it as a data layer.
+  // DEPRECATED
   optional bool share_in_parallel = 4 [default = false];
 }
 
@@ -1398,6 +1407,6 @@ message PReLUParameter {
 
   // Initial value of a_i. Default is a_i=0.25 for all i.
   optional FillerParameter filler = 1;
-  // Whether or not slope paramters are shared across channels.
+  // Whether or not slope parameters are shared across channels.
   optional bool channel_shared = 2 [default = false];
 }
\ No newline at end of file
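
As a side note on the `BatchNormParameter` comment above, a quick numeric sketch of the documented moving-average update (plain Python for illustration, not SystemML or Caffe code):

```python
# S_t = (1 - beta) * Y_t + beta * S_{t-1}, where beta = moving_average_fraction.
beta = 0.999                        # default moving_average_fraction
S = 0.0                             # accumulated global statistic
for Y in [1.0, 1.2, 0.8, 1.1]:      # made-up per-batch means
    S = (1.0 - beta) * Y + beta * S
print(S)                            # beta close to 1 => slow decay, long memory
```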

http://git-wip-us.apache.org/repos/asf/systemml/blob/933a17c5/src/main/python/systemml/mllearn/estimators.py
----------------------------------------------------------------------
diff --git a/src/main/python/systemml/mllearn/estimators.py b/src/main/python/systemml/mllearn/estimators.py
index ec225c4..30e66d4 100644
--- a/src/main/python/systemml/mllearn/estimators.py
+++ b/src/main/python/systemml/mllearn/estimators.py
@@ -737,15 +737,20 @@ class Caffe2DML(BaseSystemMLClassifier):
         if ignore_weights is not None:
             self.estimator.setWeightsToIgnore(ignore_weights)
             
-    def set(self, num_classes=None, debug=None):
+    def set(self, debug=None, train_algo=None, test_algo=None, parallel_batches=None):
         """
         Set input to Caffe2DML
         
         Parameters
         ----------
         debug: to add debugging DML code such as classification report, print DML script, etc (default: False)
+        train_algo: can be minibatch, batch, allreduce_parallel_batches or allreduce (default: minibatch)
+        test_algo: can be minibatch, batch, allreduce_parallel_batches or allreduce (default: minibatch)
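+        parallel_batches: number of parallel mini-batches; required when train_algo or test_algo is allreduce_parallel_batches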
         """
         if debug is not None: self.estimator.setInput("$debug", str(debug).upper())
+        if train_algo is not None: self.estimator.setInput("$train_algo", str(train_algo).lower())
+        if test_algo is not None: self.estimator.setInput("$test_algo", str(test_algo).lower())
+        if parallel_batches is not None: self.estimator.setInput("$parallel_batches", str(parallel_batches))
         return self
     
     def visualize(self, layerName=None, varType='weight', aggFn='mean'):
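
The new keyword arguments are forwarded to the generated DML script as dollar-prefixed inputs ($train_algo, $test_algo, $parallel_batches), which the Scala side reads back via getTrainAlgo/getTestAlgo with a default of "minibatch". A usage sketch (the object name is hypothetical):

```python
# caffe2dmlObject is any constructed Caffe2DML estimator.
caffe2dmlObject.set(train_algo="allreduce_parallel_batches",
                    test_algo="minibatch",
                    parallel_batches=2)
caffe2dmlObject.set(debug=True)  # also prints the generated DML script
```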

http://git-wip-us.apache.org/repos/asf/systemml/blob/933a17c5/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala
----------------------------------------------------------------------
diff --git a/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala b/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala
index f9b7ecc..f338fd7 100644
--- a/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala
+++ b/src/main/scala/org/apache/sysml/api/dl/Caffe2DML.scala
@@ -182,6 +182,9 @@ class Caffe2DML(val sc: SparkContext, val solverParam:Caffe.SolverParameter,
   // Method called by Python mllearn to visualize variable of certain layer
   def visualizeLayer(layerName:String, varType:String, aggFn:String): Unit = visualizeLayer(net, layerName, varType, aggFn)
   
+  def getTrainAlgo():String = if(inputs.containsKey("$train_algo")) inputs.get("$train_algo") else "minibatch"
+  def getTestAlgo():String = if(inputs.containsKey("$test_algo")) inputs.get("$test_algo") else "minibatch"
+    
   // ================================================================================================
   // The below method parses the provided network and solver file and generates DML script.
	def getTrainingScript(isSingleNode:Boolean):(Script, String, String)  = {
@@ -209,47 +212,97 @@ class Caffe2DML(val sc: SparkContext, val solverParam:Caffe.SolverParameter,
	 // ----------------------------------------------------------------------------
          // Main logic
          forBlock("e", "1", "max_epochs") {
-           solverParam.getTrainAlgo.toLowerCase match {
+           getTrainAlgo.toLowerCase match {
              case "minibatch" => 
                forBlock("i", "1", "num_iters_per_epoch") {
                  getTrainingBatch(tabDMLScript)
-                 tabDMLScript.append("iter = start_iter + i\n")
+                 tabDMLScript.append("iter = iter + 1\n")
+                 // -------------------------------------------------------
+                 // Perform forward, backward and update on minibatch
                  forward; backward; update
+                 // -------------------------------------------------------
                  displayLoss(lossLayers(0), shouldValidate)
             performSnapshot
                }
              case "batch" => {
-          tabDMLScript.append("iter = start_iter + i\n")
+          tabDMLScript.append("iter = iter + 1\n")
+          // -------------------------------------------------------
+          // Perform forward, backward and update on entire dataset
           forward; backward; update
+          // -------------------------------------------------------
           displayLoss(lossLayers(0), shouldValidate)
           performSnapshot
              }
+             case "allreduce_parallel_batches" => {
+               // This setting uses the batch size provided by the user
+          if(!inputs.containsKey("$parallel_batches")) {
+            throw new RuntimeException("The parameter parallel_batches is required for allreduce_parallel_batches")
+          }
+          // The user specifies the number of parallel_batches
+          // This ensures that the user of the generated script remembers to provide the command-line parameter $parallel_batches
+          assign(tabDMLScript, "parallel_batches", "$parallel_batches")
+          assign(tabDMLScript, "group_batch_size", "parallel_batches*" + Caffe2DML.batchSize)
+          assign(tabDMLScript, "groups", "as.integer(ceil(" + Caffe2DML.numImages + "/group_batch_size))")
+          // Grab groups of mini-batches
+          forBlock("g", "1", "groups") {
+            tabDMLScript.append("iter = iter + 1\n")
+            // Get next group of mini-batches
+            assign(tabDMLScript, "group_beg", "((g-1) * group_batch_size) %% " + Caffe2DML.numImages + " + 1")
+            assign(tabDMLScript, "group_end", "min(" + Caffe2DML.numImages + ", group_beg + group_batch_size - 1)")
+            assign(tabDMLScript, "X_group_batch", Caffe2DML.X + "[group_beg:group_end,]")
+            assign(tabDMLScript, "y_group_batch", Caffe2DML.y + "[group_beg:group_end,]")
+            initializeGradients("parallel_batches")
+            parForBlock("j", "1", "parallel_batches") {
+              // Get a mini-batch in this group
+              assign(tabDMLScript, "beg", "((j-1) * " + Caffe2DML.batchSize + ") %% nrow(X_group_batch) + 1")
+              assign(tabDMLScript, "end", "min(nrow(X_group_batch), beg + " + Caffe2DML.batchSize + " - 1)")
+              assign(tabDMLScript, "Xb", "X_group_batch[beg:end,]")
+              assign(tabDMLScript, "yb", "y_group_batch[beg:end,]")
+              forward; backward
+              flattenGradients
+            }
+            aggregateAggGradients
+                 update
+                 // -------------------------------------------------------
+                 assign(tabDMLScript, "Xb", "X_group_batch")
+            assign(tabDMLScript, "yb", "y_group_batch")
+            displayLoss(lossLayers(0), shouldValidate)
+            performSnapshot
+          }
+             }
              case "allreduce" => {
+               // This is distributed synchronous gradient descent
                forBlock("i", "1", "num_iters_per_epoch") {
-                 getTrainingBatch(tabDMLScript)
-                 assign(tabDMLScript, "X_group_batch", "Xb")
-                 assign(tabDMLScript, "y_group_batch", "yb")
-                 tabDMLScript.append("iter = start_iter + i\n")
-                 initAggGradients
-                 parForBlock("j", "1", "nrow(y_group_batch)") {
+                 tabDMLScript.append("iter = iter + 1\n")
+                 // -------------------------------------------------------
+            // Perform forward, backward and update on minibatch in parallel
+                 assign(tabDMLScript, "beg", "((i-1) * " + Caffe2DML.batchSize + ") %% " + Caffe2DML.numImages + " + 1")
+                 assign(tabDMLScript, "end", " min(beg +  " + Caffe2DML.batchSize + " - 1, " + Caffe2DML.numImages + ")")
+                 assign(tabDMLScript, "X_group_batch", Caffe2DML.X + "[beg:end,]")
+            assign(tabDMLScript, "y_group_batch", Caffe2DML.y + "[beg:end,]")
+                 tabDMLScript.append("local_batch_size = nrow(y_group_batch)\n")
+                 val localBatchSize = "local_batch_size"
+                 initializeGradients(localBatchSize)
+                 parForBlock("j", "1", localBatchSize) {
                    assign(tabDMLScript, "Xb", "X_group_batch[j,]")
                    assign(tabDMLScript, "yb", "y_group_batch[j,]")
-                   forward; backward("_agg")
-              flattenAndStoreAggGradients_j
+                   forward; backward
+              flattenGradients
                  }
-                 aggregateAggGradients
-            tabDMLScript.append("iter = start_iter + parallel_batches\n")    
+                 aggregateAggGradients    
                  update
+                 // -------------------------------------------------------
+                 assign(tabDMLScript, "Xb", "X_group_batch")
+            assign(tabDMLScript, "yb", "y_group_batch")
             displayLoss(lossLayers(0), shouldValidate)
             performSnapshot
                }
              }
-             case _ => throw new DMLRuntimeException("Unsupported train algo:" + solverParam.getTrainAlgo)
+             case _ => throw new DMLRuntimeException("Unsupported train algo:" + getTrainAlgo)
            }
            // After every epoch, update the learning rate
            tabDMLScript.append("# Learning rate\n")
            lrPolicy.updateLearningRate(tabDMLScript)
-           tabDMLScript.append("start_iter = start_iter + num_iters_per_epoch\n")
          }
	 // ----------------------------------------------------------------------------
          
@@ -308,29 +361,38 @@ class Caffe2DML(val sc: SparkContext, val solverParam:Caffe.SolverParameter,
   private def displayLoss(lossLayer:IsLossLayer, shouldValidate:Boolean):Unit = {
     if(solverParam.getDisplay > 0) {
       // Append the DML to compute training loss
-      tabDMLScript.append("# Compute training loss & accuracy\n")
-      ifBlock("iter  %% " + solverParam.getDisplay + " == 0") {
-        assign(tabDMLScript, "loss", "0"); assign(tabDMLScript, "accuracy", "0")
-        lossLayer.computeLoss(dmlScript, numTabs)
-        assign(tabDMLScript, "training_loss", "loss"); assign(tabDMLScript, "training_accuracy", "accuracy")
-        tabDMLScript.append(print( dmlConcat( asDMLString("Iter:"), "iter", 
-            asDMLString(", training loss:"), "training_loss", asDMLString(", training accuracy:"), "training_accuracy" )))
-        appendTrainingVisualizationBody(dmlScript, numTabs)
-        printClassificationReport
+      if(!getTrainAlgo.toLowerCase.startsWith("allreduce")) {
+        // Compute training loss for the non-allreduce settings
+        tabDMLScript.append("# Compute training loss & accuracy\n")
+        ifBlock("iter  %% " + solverParam.getDisplay + " == 0") {
+          assign(tabDMLScript, "loss", "0"); assign(tabDMLScript, "accuracy", "0")
+          lossLayer.computeLoss(dmlScript, numTabs)
+          assign(tabDMLScript, "training_loss", "loss"); assign(tabDMLScript, "training_accuracy", "accuracy")
+          tabDMLScript.append(print( dmlConcat( asDMLString("Iter:"), "iter", 
+              asDMLString(", training loss:"), "training_loss", asDMLString(", training accuracy:"), "training_accuracy" )))
+          appendTrainingVisualizationBody(dmlScript, numTabs)
+          printClassificationReport
+        }
+      }
+      else {
+        Caffe2DML.LOG.info("Training loss is not printed for train_algo=" + getTrainAlgo)
       }
       if(shouldValidate) {
+        if(  getTrainAlgo.toLowerCase.startsWith("allreduce") &&
+            getTestAlgo.toLowerCase.startsWith("allreduce")) {
+          Caffe2DML.LOG.warn("The setting: train_algo=" + getTrainAlgo + " and test_algo=" + getTestAlgo + " is not recommended. Consider changing test_algo=minibatch")
+        }
         // Append the DML to compute validation loss
        val numValidationBatches = if(solverParam.getTestIterCount > 0) solverParam.getTestIter(0) else 0
         tabDMLScript.append("# Compute validation loss & accuracy\n")
         ifBlock("iter  %% " + solverParam.getTestInterval + " == 0") {
          assign(tabDMLScript, "loss", "0"); assign(tabDMLScript, "accuracy", "0")
-          solverParam.getTestAlgo.toLowerCase match {
+          getTestAlgo.toLowerCase match {
             case "minibatch" => {
               assign(tabDMLScript, "validation_loss", "0")
               assign(tabDMLScript, "validation_accuracy", "0")
               forBlock("iVal", "1", "num_iters_per_epoch") {
                  getValidationBatch(tabDMLScript)
-                 tabDMLScript.append("iter = start_iter + i\n")
                  forward;  lossLayer.computeLoss(dmlScript, numTabs)
                tabDMLScript.append("validation_loss = validation_loss + loss\n")
                tabDMLScript.append("validation_accuracy = validation_accuracy + accuracy\n")
@@ -344,7 +406,60 @@ class Caffe2DML(val sc: SparkContext, val solverParam:Caffe.SolverParameter,
              assign(tabDMLScript, "validation_loss", "loss"); assign(tabDMLScript, "validation_accuracy", "accuracy")
               
             }
-            case _ => throw new DMLRuntimeException("Unsupported test algo:" + solverParam.getTestAlgo)
+            case "allreduce_parallel_batches" => {
+              // This setting uses the batch size provided by the user
+              if(!inputs.containsKey("$parallel_batches")) {
+                throw new RuntimeException("The parameter parallel_batches is required for allreduce_parallel_batches")
+              }
+              // The user specifies the number of parallel_batches
+              // This ensures that the user of the generated script remembers to provide the command-line parameter $parallel_batches
+              assign(tabDMLScript, "parallel_batches_val", "$parallel_batches")
+              assign(tabDMLScript, "group_batch_size_val", "parallel_batches_val*" + Caffe2DML.batchSize)
+              assign(tabDMLScript, "groups_val", "as.integer(ceil(" + Caffe2DML.numValidationImages + "/group_batch_size_val))")
+              assign(tabDMLScript, "validation_accuracy", "0")
+              assign(tabDMLScript, "validation_loss", "0")
+              // Grab groups of mini-batches
+              forBlock("g_val", "1", "groups_val") {
+                assign(tabDMLScript, "group_beg_val", "((g_val-1) * group_batch_size_val) %% " + Caffe2DML.numValidationImages + " + 1")
+                assign(tabDMLScript, "group_end_val", "min(" + Caffe2DML.numValidationImages + ", group_beg_val + group_batch_size_val - 1)")
+                assign(tabDMLScript, "X_group_batch_val", Caffe2DML.XVal + "[group_beg_val:group_end_val,]")
+                assign(tabDMLScript, "y_group_batch_val", Caffe2DML.yVal + "[group_beg_val:group_end_val,]")
+                assign(tabDMLScript, "group_validation_loss", matrix("0", "parallel_batches_val", "1"))
+                assign(tabDMLScript, "group_validation_accuracy", matrix("0", "parallel_batches_val", "1"))
+                // Run graph on each mini-batch in this group in parallel (ideally on multiple GPUs)
+                parForBlock("iVal", "1", "parallel_batches_val") {
+                  assign(tabDMLScript, "beg_val", "((iVal-1) * " + Caffe2DML.batchSize + ") %% nrow(y_group_batch_val) + 1")
+                  assign(tabDMLScript, "end_val", "min(nrow(y_group_batch_val), beg_val + " + Caffe2DML.batchSize + " - 1)")
+                  assign(tabDMLScript, "Xb", "X_group_batch_val[beg_val:end_val,]")
+                  assign(tabDMLScript, "yb", "y_group_batch_val[beg_val:end_val,]")
+                  net.getLayers.map(layer => net.getCaffeLayer(layer).forward(tabDMLScript, false))
+                  lossLayer.computeLoss(dmlScript, numTabs)
+                  assign(tabDMLScript, "group_validation_loss[iVal,1]", "loss")
+                  assign(tabDMLScript, "group_validation_accuracy[iVal,1]", "accuracy")
+                }
+                assign(tabDMLScript, "validation_loss", "validation_loss + sum(group_validation_loss)")
+                assign(tabDMLScript, "validation_accuracy", "validation_accuracy + sum(group_validation_accuracy)")
+              }
+              assign(tabDMLScript, "validation_accuracy", "validation_accuracy/groups_val")
+            }
+            case "allreduce" => {
+              // This setting does not use the batch size for validation and allows the parfor optimizer to select a plan
+              // by minimizing the memory requirement (i.e. batch size = 1)
+              assign(tabDMLScript, "group_validation_loss", matrix("0", Caffe2DML.numValidationImages, "1"))
+              assign(tabDMLScript, "group_validation_accuracy", matrix("0", Caffe2DML.numValidationImages, "1"))
+              parForBlock("iVal", "1", Caffe2DML.numValidationImages) {
+                assign(tabDMLScript, "Xb",  Caffe2DML.XVal + "[iVal,]")
+                assign(tabDMLScript, "yb",  Caffe2DML.yVal + "[iVal,]")
+                net.getLayers.map(layer => net.getCaffeLayer(layer).forward(tabDMLScript, false))
+                lossLayer.computeLoss(dmlScript, numTabs)
+                assign(tabDMLScript, "group_validation_loss[iVal,1]", "loss")
+                assign(tabDMLScript, "group_validation_accuracy[iVal,1]", "accuracy")
+              }
+              assign(tabDMLScript, "validation_loss", "sum(group_validation_loss)")
+              assign(tabDMLScript, "validation_accuracy", "mean(group_validation_accuracy)")
+            }
+            
+            case _ => throw new DMLRuntimeException("Unsupported test algo:" + getTestAlgo)
           }
           tabDMLScript.append(print( dmlConcat( asDMLString("Iter:"), "iter", 
              asDMLString(", validation loss:"), "validation_loss", asDMLString(", validation accuracy:"), "validation_accuracy" )))
@@ -368,23 +483,22 @@ class Caffe2DML(val sc: SparkContext, val solverParam:Caffe.SolverParameter,
     tabDMLScript.append("# Perform forward pass\n")
	 net.getLayers.map(layer => net.getCaffeLayer(layer).forward(tabDMLScript, false))
   }
-  private def backward():Unit = backward("")
-  private def backward(suffix:String):Unit = {
+  private def backward():Unit = {
     tabDMLScript.append("# Perform backward pass\n")
-    net.getLayers.reverse.map(layer => net.getCaffeLayer(layer).backward(tabDMLScript, suffix))
+    net.getLayers.reverse.map(layer => net.getCaffeLayer(layer).backward(tabDMLScript, ""))
   }
   private def update():Unit = {
     tabDMLScript.append("# Update the parameters\n")
    net.getLayers.map(layer => solver.update(tabDMLScript, net.getCaffeLayer(layer)))
   }
-  private def initAggGradients():Unit = {
-    tabDMLScript.append("# Data structure to store gradients computed in parallel")
+  private def initializeGradients(parallel_batches:String):Unit = {
+    tabDMLScript.append("# Data structure to store gradients computed in parallel\n")
     net.getLayers.map(layer => net.getCaffeLayer(layer)).map(l => {
-      if(l.shouldUpdateWeight) assign(tabDMLScript, l.dWeight + "_agg", matrix("0", "parallel_batches", multiply(nrow(l.weight), ncol(l.weight))))
-      if(l.shouldUpdateBias) assign(tabDMLScript, l.dBias + "_agg", matrix("0", "parallel_batches", multiply(nrow(l.bias), ncol(l.bias)))) 
+      if(l.shouldUpdateWeight) assign(tabDMLScript, l.dWeight + "_agg", matrix("0", parallel_batches, multiply(nrow(l.weight), ncol(l.weight))))
+      if(l.shouldUpdateBias) assign(tabDMLScript, l.dBias + "_agg", matrix("0", parallel_batches, multiply(nrow(l.bias), ncol(l.bias)))) 
     })
   }
-  private def flattenAndStoreAggGradients_j():Unit = {
+  private def flattenGradients():Unit = {
    tabDMLScript.append("# Flatten and store gradients for this parallel execution\n")
     net.getLayers.map(layer => net.getCaffeLayer(layer)).map(l => {
       if(l.shouldUpdateWeight) assign(tabDMLScript, l.dWeight + "_agg[j,]", 
@@ -404,7 +518,7 @@ class Caffe2DML(val sc: SparkContext, val solverParam:Caffe.SolverParameter,
   }
  // Set iteration-related variables such as max_epochs, num_iters_per_epoch, lr, etc.
   def setIterationVariables():Unit = {
-    solverParam.getTrainAlgo.toLowerCase match {
+    getTrainAlgo.toLowerCase match {
            case "batch" => 
             assign(tabDMLScript, "max_epochs", solverParam.getMaxIter.toString)
            case _ => {
@@ -412,14 +526,13 @@ class Caffe2DML(val sc: SparkContext, val solverParam:Caffe.SolverParameter,
             ceilDivide(tabDMLScript, "max_epochs", solverParam.getMaxIter.toString, "num_iters_per_epoch")
            }
          }
-         assign(tabDMLScript, "start_iter", "0")
+         assign(tabDMLScript, "iter", "0")
          assign(tabDMLScript, "lr", solverParam.getBaseLr.toString)
   }
  // -------------------------------------------------------------------------------------------
 }
 
-class Caffe2DMLModel(val mloutput: MLResults,  
-    val numClasses:String, val sc: SparkContext, val solver:CaffeSolver,
+class Caffe2DMLModel(val numClasses:String, val sc: SparkContext, val solver:CaffeSolver,
     val net:CaffeNetwork, val lrPolicy:LearningRatePolicy,
     val estimator:Caffe2DML) 
  extends Model[Caffe2DMLModel] with HasMaxOuterIter with BaseSystemMLClassifierModel with DMLGenerator {
@@ -427,14 +540,14 @@ class Caffe2DMLModel(val mloutput: MLResults,
   // Invoked by Python, MLPipeline
   val uid:String = "caffe_model_" + (new Random).nextLong 
   def this(estimator:Caffe2DML) =  {
-    this(null, Utils.numClasses(estimator.net), estimator.sc, estimator.solver,
+    this(Utils.numClasses(estimator.net), estimator.sc, estimator.solver,
         estimator.net,
        // new CaffeNetwork(estimator.solverParam.getNet, caffe.Caffe.Phase.TEST, estimator.numChannels, estimator.height, estimator.width), 
         estimator.lrPolicy, estimator) 
   }
       
  override def copy(extra: org.apache.spark.ml.param.ParamMap): Caffe2DMLModel = {
-    val that = new Caffe2DMLModel(mloutput, numClasses, sc, solver, net, lrPolicy, estimator)
+    val that = new Caffe2DMLModel(numClasses, sc, solver, net, lrPolicy, estimator)
     copyValues(that, extra)
   }
   // --------------------------------------------------------------
@@ -459,11 +572,9 @@ class Caffe2DMLModel(val mloutput: MLResults,
     assign(tabDMLScript, "X", "X_full")
     
    // Initialize the layers and solvers. Reads weights and bias if readWeights is true.
-    val readWeights = {
-           if(mloutput == null && estimator.inputs.containsKey("$weights")) true
-           else if(mloutput == null) throw new DMLRuntimeException("Cannot call predict/score without calling either fit or by providing weights")
-           else false
-         }
+    if(!estimator.inputs.containsKey("$weights") && estimator.mloutput == null) 
+      throw new DMLRuntimeException("Cannot call predict/score without calling either fit or by providing weights")
+    val readWeights = estimator.inputs.containsKey("$weights") || estimator.mloutput != null
     initWeights(net, solver, readWeights)
          
	 // Do not update mean and variance in batchnorm
@@ -472,7 +583,7 @@ class Caffe2DMLModel(val mloutput: MLResults,
          val lossLayers = getLossLayers(net)
          
          assign(tabDMLScript, "Prob", matrix("0", Caffe2DML.numImages, 
numClasses))
-         estimator.solverParam.getTestAlgo.toLowerCase match {
+         estimator.getTestAlgo.toLowerCase match {
       case "minibatch" => {
        ceilDivide(tabDMLScript(), "num_iters", Caffe2DML.numImages, Caffe2DML.batchSize)
         forBlock("i", "1", "num_iters") {
@@ -486,15 +597,41 @@ class Caffe2DMLModel(val mloutput: MLResults,
        net.getLayers.map(layer => net.getCaffeLayer(layer).forward(tabDMLScript, true))
         assign(tabDMLScript, "Prob", lossLayers(0).out)
       }
+      case "allreduce_parallel_batches" => {
+        // This setting uses the batch size provided by the user
+        if(!estimator.inputs.containsKey("$parallel_batches")) {
+          throw new RuntimeException("The parameter parallel_batches is required for allreduce_parallel_batches")
+        }
+        // The user specifies the number of parallel_batches
+        // This ensures that the user of the generated script remembers to provide the command-line parameter $parallel_batches
+        assign(tabDMLScript, "parallel_batches", "$parallel_batches")
+        assign(tabDMLScript, "group_batch_size", "parallel_batches*" + Caffe2DML.batchSize)
+        assign(tabDMLScript, "groups", "as.integer(ceil(" + Caffe2DML.numImages + "/group_batch_size))")
+        // Grab groups of mini-batches
+        forBlock("g", "1", "groups") {
+          assign(tabDMLScript, "group_beg", "((g-1) * group_batch_size) %% " + Caffe2DML.numImages + " + 1")
+          assign(tabDMLScript, "group_end", "min(" + Caffe2DML.numImages + ", group_beg + group_batch_size - 1)")
+          assign(tabDMLScript, "X_group_batch", "X_full[group_beg:group_end,]")
+          // Run graph on each mini-batch in this group in parallel (ideally on multiple GPUs)
+          parForBlock("j", "1", "parallel_batches") {
+            assign(tabDMLScript, "beg", "((j-1) * " + Caffe2DML.batchSize + ") %% nrow(X_group_batch) + 1")
+            assign(tabDMLScript, "end", "min(nrow(X_group_batch), beg + " + Caffe2DML.batchSize + " - 1)")
+            assign(tabDMLScript, "Xb", "X_group_batch[beg:end,]")
+            net.getLayers.map(layer => net.getCaffeLayer(layer).forward(tabDMLScript, true))
+            assign(tabDMLScript, "Prob[beg:end,]", lossLayers(0).out)
+          }
+        }
+      }
       case "allreduce" => {
-        ceilDivide(tabDMLScript(), "num_iters", Caffe2DML.numImages, Caffe2DML.batchSize)
-        parForBlock("i", "1", "num_iters") {
-          getTestBatch(tabDMLScript)
+        // This setting does not use the batch size for scoring and allows the parfor optimizer to select a plan
+        // by minimizing the memory requirement (i.e. batch size = 1)
+        parForBlock("i", "1", Caffe2DML.numImages) {
+          assign(tabDMLScript, "Xb", "X_full[i,]")
           net.getLayers.map(layer => net.getCaffeLayer(layer).forward(tabDMLScript, true))
-          assign(tabDMLScript, "Prob[beg:end,]", lossLayers(0).out)
+          assign(tabDMLScript, "Prob[i,]", lossLayers(0).out)
         }
       }
-      case _ => throw new DMLRuntimeException("Unsupported test algo:" + estimator.solverParam.getTestAlgo)
+      case _ => throw new DMLRuntimeException("Unsupported test algo:" + estimator.getTestAlgo)
     }
                
                val predictionScript = dmlScript.toString()
@@ -505,10 +642,10 @@ class Caffe2DMLModel(val mloutput: MLResults,
                updateMeanVarianceForBatchNorm(net, true)
                
          val script = dml(predictionScript).out("Prob").in(estimator.inputs)
-         if(mloutput != null) {
+         if(estimator.mloutput != null) {
            // fit was called
-	 net.getLayers.map(net.getCaffeLayer(_)).filter(_.weight != null).map(l => script.in(l.weight, mloutput.getMatrix(l.weight)))
-	 net.getLayers.map(net.getCaffeLayer(_)).filter(_.bias != null).map(l => script.in(l.bias, mloutput.getMatrix(l.bias)))
+	 net.getLayers.map(net.getCaffeLayer(_)).filter(_.weight != null).map(l => script.in(l.weight, estimator.mloutput.getMatrix(l.weight)))
+	 net.getLayers.map(net.getCaffeLayer(_)).filter(_.bias != null).map(l => script.in(l.bias, estimator.mloutput.getMatrix(l.bias)))
          }
          (script, "X_full")
   }
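
To clarify the index arithmetic emitted by the allreduce_parallel_batches branches above, here is a small standalone sketch of the group/mini-batch ranges the generated DML iterates over (a plain-Python re-implementation with made-up sizes; the DML itself is 1-based, this sketch is 0-based with exclusive upper bounds):

```python
import math

num_images, batch_size, parallel_batches = 1000, 64, 2
group_batch_size = parallel_batches * batch_size        # rows fetched per group
groups = math.ceil(num_images / group_batch_size)

for g in range(groups):
    group_beg = (g * group_batch_size) % num_images
    group_end = min(num_images, group_beg + group_batch_size)
    group_rows = group_end - group_beg
    # Each mini-batch in the group can run forward/backward on its own GPU;
    # gradients are then aggregated (all-reduced) before a single update.
    for j in range(parallel_batches):
        beg = (j * batch_size) % group_rows
        end = min(group_rows, beg + batch_size)
        print(f"group {g}: mini-batch rows [{group_beg + beg}, {group_beg + end})")
```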
