[SYSTEMML-1524] Graduate `nn` library to `scripts/nn`

This graduates the SystemML `nn` deep learning library from the staging
directory to the top-level `scripts` directory.  The aim is to have the
library fully ready as part of the 1.0 release, alongside Caffe2DML,
GPU support, and native BLAS.

Closes #472.
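
Because only the base directory changes (`scripts/staging/SystemML-NN/nn` ->
`scripts/nn`), the library's internal layout and relative `source()` paths are
unchanged. A minimal sketch, assuming scripts are now launched from the
`scripts` directory (previously from `scripts/staging/SystemML-NN`):

```
# imports resolve the same way before and after the move
source("nn/layers/affine.dml") as affine
source("nn/optim/sgd.dml") as sgd
```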


Project: http://git-wip-us.apache.org/repos/asf/incubator-systemml/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-systemml/commit/43c321d1
Tree: http://git-wip-us.apache.org/repos/asf/incubator-systemml/tree/43c321d1
Diff: http://git-wip-us.apache.org/repos/asf/incubator-systemml/diff/43c321d1

Branch: refs/heads/master
Commit: 43c321d18675d9b76483e0d1d8b156196172efdb
Parents: 1f5cf69
Author: Mike Dusenberry <mwdus...@us.ibm.com>
Authored: Wed Apr 26 14:40:46 2017 -0700
Committer: Mike Dusenberry <mwdus...@us.ibm.com>
Committed: Wed Apr 26 14:40:46 2017 -0700

----------------------------------------------------------------------
 scripts/nn/README.md                            |  183 ++
 scripts/nn/examples/Example - MNIST LeNet.ipynb |  189 ++
 .../Example - MNIST Softmax Classifier.ipynb    |  179 ++
 scripts/nn/examples/README.md                   |   74 +
 scripts/nn/examples/get_mnist_data.sh           |   28 +
 scripts/nn/examples/mnist_lenet-predict.dml     |   91 +
 scripts/nn/examples/mnist_lenet-train.dml       |  123 ++
 scripts/nn/examples/mnist_lenet.dml             |  331 ++++
 scripts/nn/examples/mnist_softmax-predict.dml   |   77 +
 scripts/nn/examples/mnist_softmax-train.dml     |  110 ++
 scripts/nn/examples/mnist_softmax.dml           |  178 ++
 scripts/nn/layers/affine.dml                    |   92 +
 scripts/nn/layers/batch_norm1d.dml              |  210 +++
 scripts/nn/layers/batch_norm2d.dml              |  238 +++
 scripts/nn/layers/conv2d.dml                    |  194 ++
 scripts/nn/layers/conv2d_builtin.dml            |  160 ++
 scripts/nn/layers/cross_entropy_loss.dml        |   78 +
 scripts/nn/layers/dropout.dml                   |   76 +
 scripts/nn/layers/l1_loss.dml                   |   72 +
 scripts/nn/layers/l1_reg.dml                    |   56 +
 scripts/nn/layers/l2_loss.dml                   |   72 +
 scripts/nn/layers/l2_reg.dml                    |   56 +
 scripts/nn/layers/log_loss.dml                  |   76 +
 scripts/nn/layers/lstm.dml                      |  260 +++
 scripts/nn/layers/max_pool2d.dml                |  159 ++
 scripts/nn/layers/max_pool2d_builtin.dml        |  103 +
 scripts/nn/layers/relu.dml                      |   59 +
 scripts/nn/layers/rnn.dml                       |  183 ++
 scripts/nn/layers/scale_shift1d.dml             |   95 +
 scripts/nn/layers/scale_shift2d.dml             |  107 ++
 scripts/nn/layers/sigmoid.dml                   |   62 +
 scripts/nn/layers/softmax.dml                   |   87 +
 scripts/nn/layers/tanh.dml                      |   65 +
 scripts/nn/optim/adagrad.dml                    |   77 +
 scripts/nn/optim/adam.dml                       |   97 +
 scripts/nn/optim/rmsprop.dml                    |   79 +
 scripts/nn/optim/sgd.dml                        |   42 +
 scripts/nn/optim/sgd_momentum.dml               |   71 +
 scripts/nn/optim/sgd_nesterov.dml               |   81 +
 scripts/nn/test/README.md                       |   32 +
 scripts/nn/test/conv2d_simple.dml               |  213 +++
 scripts/nn/test/grad_check.dml                  | 1769 ++++++++++++++++++
 scripts/nn/test/max_pool2d_simple.dml           |  172 ++
 scripts/nn/test/run_tests.dml                   |   90 +
 scripts/nn/test/test.dml                        |  549 ++++++
 scripts/nn/test/util.dml                        |  155 ++
 scripts/nn/util.dml                             |  202 ++
 scripts/staging/SystemML-NN/README.md           |  183 --
 .../nn/examples/Example - MNIST LeNet.ipynb     |  189 --
 .../Example - MNIST Softmax Classifier.ipynb    |  179 --
 .../staging/SystemML-NN/nn/examples/README.md   |   74 -
 .../SystemML-NN/nn/examples/get_mnist_data.sh   |   28 -
 .../nn/examples/mnist_lenet-predict.dml         |   91 -
 .../nn/examples/mnist_lenet-train.dml           |  123 --
 .../SystemML-NN/nn/examples/mnist_lenet.dml     |  331 ----
 .../nn/examples/mnist_softmax-predict.dml       |   77 -
 .../nn/examples/mnist_softmax-train.dml         |  110 --
 .../SystemML-NN/nn/examples/mnist_softmax.dml   |  178 --
 .../staging/SystemML-NN/nn/layers/affine.dml    |   92 -
 .../SystemML-NN/nn/layers/batch_norm1d.dml      |  210 ---
 .../SystemML-NN/nn/layers/batch_norm2d.dml      |  238 ---
 .../staging/SystemML-NN/nn/layers/conv2d.dml    |  194 --
 .../SystemML-NN/nn/layers/conv2d_builtin.dml    |  160 --
 .../nn/layers/cross_entropy_loss.dml            |   78 -
 .../staging/SystemML-NN/nn/layers/dropout.dml   |   76 -
 .../staging/SystemML-NN/nn/layers/l1_loss.dml   |   72 -
 .../staging/SystemML-NN/nn/layers/l1_reg.dml    |   56 -
 .../staging/SystemML-NN/nn/layers/l2_loss.dml   |   72 -
 .../staging/SystemML-NN/nn/layers/l2_reg.dml    |   56 -
 .../staging/SystemML-NN/nn/layers/log_loss.dml  |   76 -
 scripts/staging/SystemML-NN/nn/layers/lstm.dml  |  260 ---
 .../SystemML-NN/nn/layers/max_pool2d.dml        |  159 --
 .../nn/layers/max_pool2d_builtin.dml            |  103 -
 scripts/staging/SystemML-NN/nn/layers/relu.dml  |   59 -
 scripts/staging/SystemML-NN/nn/layers/rnn.dml   |  183 --
 .../SystemML-NN/nn/layers/scale_shift1d.dml     |   95 -
 .../SystemML-NN/nn/layers/scale_shift2d.dml     |  107 --
 .../staging/SystemML-NN/nn/layers/sigmoid.dml   |   62 -
 .../staging/SystemML-NN/nn/layers/softmax.dml   |   87 -
 scripts/staging/SystemML-NN/nn/layers/tanh.dml  |   65 -
 .../staging/SystemML-NN/nn/optim/adagrad.dml    |   77 -
 scripts/staging/SystemML-NN/nn/optim/adam.dml   |   97 -
 .../staging/SystemML-NN/nn/optim/rmsprop.dml    |   79 -
 scripts/staging/SystemML-NN/nn/optim/sgd.dml    |   42 -
 .../SystemML-NN/nn/optim/sgd_momentum.dml       |   71 -
 .../SystemML-NN/nn/optim/sgd_nesterov.dml       |   81 -
 scripts/staging/SystemML-NN/nn/test/README.md   |   32 -
 .../SystemML-NN/nn/test/conv2d_simple.dml       |  213 ---
 .../staging/SystemML-NN/nn/test/grad_check.dml  | 1769 ------------------
 .../SystemML-NN/nn/test/max_pool2d_simple.dml   |  172 --
 .../staging/SystemML-NN/nn/test/run_tests.dml   |   90 -
 scripts/staging/SystemML-NN/nn/test/test.dml    |  549 ------
 scripts/staging/SystemML-NN/nn/test/util.dml    |  155 --
 scripts/staging/SystemML-NN/nn/util.dml         |  202 --
 94 files changed, 7752 insertions(+), 7752 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/43c321d1/scripts/nn/README.md
----------------------------------------------------------------------
diff --git a/scripts/nn/README.md b/scripts/nn/README.md
new file mode 100644
index 0000000..b80f2c6
--- /dev/null
+++ b/scripts/nn/README.md
@@ -0,0 +1,183 @@
+<!--
+{% comment %}
+Licensed to the Apache Software Foundation (ASF) under one or more
+contributor license agreements.  See the NOTICE file distributed with
+this work for additional information regarding copyright ownership.
+The ASF licenses this file to you under the Apache License, Version 2.0
+(the "License"); you may not use this file except in compliance with
+the License.  You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+{% endcomment %}
+-->
+
+# SystemML-NN
+
+### A deep learning library for [Apache SystemML](https://github.com/apache/incubator-systemml).
+
+## Examples:
+#### Please see the [`examples`](nn/examples) folder for more detailed examples, or view the following two quick examples.
+### Neural net for regression with vanilla SGD:
+```python
+# Imports
+source("nn/layers/affine.dml") as affine
+source("nn/layers/l2_loss.dml") as l2_loss
+source("nn/layers/relu.dml") as relu
+source("nn/optim/sgd.dml") as sgd
+
+# Generate input data
+N = 1024 # num examples
+D = 100 # num features
+t = 1 # num targets
+X = rand(rows=N, cols=D, pdf="normal")
+y = rand(rows=N, cols=t)
+
+# Create 2-layer network:
+## affine1 -> relu1 -> affine2
+M = 64 # number of neurons
+[W1, b1] = affine::init(D, M)
+[W2, b2] = affine::init(M, t)
+
+# Initialize optimizer
+lr = 0.05  # learning rate
+decay = 0.99  # learning rate decay constant
+
+# Optimize
+print("Starting optimization")
+batch_size = 32
+epochs = 5
+iters = ceil(N / batch_size)
+for (e in 1:epochs) {
+  for(i in 1:iters) {
+    # Get next batch
+    beg = (i-1)*batch_size + 1
+    end = min(N, beg + batch_size - 1)
+    X_batch = X[beg:end,]
+    y_batch = y[beg:end,]
+
+    # Compute forward pass
+    out1 = affine::forward(X_batch, W1, b1)
+    outr1 = relu::forward(out1)
+    out2 = affine::forward(outr1, W2, b2)
+
+    # Compute loss
+    loss = l2_loss::forward(out2, y_batch)
+    print("L2 loss: " + loss)
+
+    # Compute backward pass
+    dout2 = l2_loss::backward(out2, y_batch)
+    [doutr1, dW2, db2] = affine::backward(dout2, outr1, W2, b2)
+    dout1 = relu::backward(doutr1, out1)
+    [dX_batch, dW1, db1] = affine::backward(dout1, X_batch, W1, b1)
+
+    # Optimize with vanilla SGD
+    W1 = sgd::update(W1, dW1, lr)
+    b1 = sgd::update(b1, db1, lr)
+    W2 = sgd::update(W2, dW2, lr)
+    b2 = sgd::update(b2, db2, lr)
+  }
+  # Decay learning rate
+  lr = lr * decay
+}
+```
+
+### Neural net for multi-class classification with dropout and SGD w/ Nesterov momentum:
+```python
+# Imports
+source("nn/layers/affine.dml") as affine
+source("nn/layers/cross_entropy_loss.dml") as cross_entropy_loss
+source("nn/layers/dropout.dml") as dropout
+source("nn/layers/relu.dml") as relu
+source("nn/layers/softmax.dml") as softmax
+source("nn/optim/sgd_nesterov.dml") as sgd_nesterov
+
+# Generate input data
+N = 1024 # num examples
+D = 100 # num features
+t = 5 # num targets
+X = rand(rows=N, cols=D, pdf="normal")
+classes = round(rand(rows=N, cols=1, min=1, max=t, pdf="uniform"))
+y = matrix(0, rows=N, cols=t)
+parfor (i in 1:N) {
+  y[i, as.scalar(classes[i,1])] = 1  # one-hot encoding
+}
+
+# Create network:
+# affine1 -> relu1 -> dropout1 -> affine2 -> relu2 -> dropout2 -> affine3 -> softmax
+H1 = 64 # number of neurons in 1st hidden layer
+H2 = 64 # number of neurons in 2nd hidden layer
+p = 0.5  # dropout probability
+[W1, b1] = affine::init(D, H1)
+[W2, b2] = affine::init(H1, H2)
+[W3, b3] = affine::init(H2, t)
+
+# Initialize SGD w/ Nesterov momentum optimizer
+lr = 0.05  # learning rate
+mu = 0.5  # momentum
+decay = 0.99  # learning rate decay constant
+vW1 = sgd_nesterov::init(W1); vb1 = sgd_nesterov::init(b1)
+vW2 = sgd_nesterov::init(W2); vb2 = sgd_nesterov::init(b2)
+vW3 = sgd_nesterov::init(W3); vb3 = sgd_nesterov::init(b3)
+
+# Optimize
+print("Starting optimization")
+batch_size = 64
+epochs = 10
+iters = ceil(N / batch_size)
+for (e in 1:epochs) {
+  for(i in 1:iters) {
+    # Get next batch
+    beg = (i-1)*batch_size + 1
+    end = min(N, beg + batch_size - 1)
+    X_batch = X[beg:end,]
+    y_batch = y[beg:end,]
+
+    # Compute forward pass
+    ## layer 1:
+    out1 = affine::forward(X_batch, W1, b1)
+    outr1 = relu::forward(out1)
+    [outd1, maskd1] = dropout::forward(outr1, p, -1)
+    ## layer 2:
+    out2 = affine::forward(outd1, W2, b2)
+    outr2 = relu::forward(out2)
+    [outd2, maskd2] = dropout::forward(outr2, p, -1)
+    ## layer 3:
+    out3 = affine::forward(outd2, W3, b3)
+    probs = softmax::forward(out3)
+
+    # Compute loss
+    loss = cross_entropy_loss::forward(probs, y_batch)
+    print("Cross entropy loss: " + loss)
+
+    # Compute backward pass
+    ## loss:
+    dprobs = cross_entropy_loss::backward(probs, y_batch)
+    ## layer 3:
+    dout3 = softmax::backward(dprobs, out3)
+    [doutd2, dW3, db3] = affine::backward(dout3, outd2, W3, b3)
+    ## layer 2:
+    doutr2 = dropout::backward(doutd2, outr2, p, maskd2)
+    dout2 = relu::backward(doutr2, out2)
+    [doutd1, dW2, db2] = affine::backward(dout2, outd1, W2, b2)
+    ## layer 1:
+    doutr1 = dropout::backward(doutd1, outr1, p, maskd1)
+    dout1 = relu::backward(doutr1, out1)
+    [dX_batch, dW1, db1] = affine::backward(dout1, X_batch, W1, b1)
+
+    # Optimize with SGD w/ Nesterov momentum
+    [W1, vW1] = sgd_nesterov::update(W1, dW1, lr, mu, vW1)
+    [b1, vb1] = sgd_nesterov::update(b1, db1, lr, mu, vb1)
+    [W2, vW2] = sgd_nesterov::update(W2, dW2, lr, mu, vW2)
+    [b2, vb2] = sgd_nesterov::update(b2, db2, lr, mu, vb2)
+    [W3, vW3] = sgd_nesterov::update(W3, dW3, lr, mu, vW3)
+    [b3, vb3] = sgd_nesterov::update(b3, db3, lr, mu, vb3)
+  }
+  # Anneal momentum towards 0.999
+  mu = mu + (0.999 - mu)/(1+epochs-e)
+  # Decay learning rate
+  lr = lr * decay
+}
+```

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/43c321d1/scripts/nn/examples/Example - MNIST LeNet.ipynb
----------------------------------------------------------------------
diff --git a/scripts/nn/examples/Example - MNIST LeNet.ipynb b/scripts/nn/examples/Example - MNIST LeNet.ipynb
new file mode 100644
index 0000000..0423269
--- /dev/null
+++ b/scripts/nn/examples/Example - MNIST LeNet.ipynb   
@@ -0,0 +1,189 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Quick Setup"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Create a SystemML MLContext object\n",
+    "from systemml import MLContext, dml\n",
+    "ml = MLContext(sc)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Download Data - MNIST"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The MNIST dataset contains labeled images of handwritten digits, where 
each example is a 28x28 pixel image of grayscale values in the range [0,255] 
stretched out as 784 pixels, and each label is one of 10 possible digits in 
[0,9].  Here, we download 60,000 training examples, and 10,000 test examples, 
where the format is \"label, pixel_1, pixel_2, ..., pixel_n\"."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "%%sh\n",
+    "mkdir -p data/mnist/\n",
+    "cd data/mnist/\n",
+    "curl -O https://pjreddie.com/media/files/mnist_train.csv\n";,
+    "curl -O https://pjreddie.com/media/files/mnist_test.csv";
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## SystemML \"LeNet\" Neural Network"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### 1. Train"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "script_string = \"\"\"\n",
+    "source(\"nn/examples/mnist_lenet.dml\") as mnist_lenet\n",
+    "\n",
+    "# Read training data\n",
+    "data = read($data, format=\"csv\")\n",
+    "n = nrow(data)\n",
+    "\n",
+    "# Extract images and labels\n",
+    "images = data[,2:ncol(data)]\n",
+    "labels = data[,1]\n",
+    "\n",
+    "# Scale images to [-1,1], and one-hot encode the labels\n",
+    "images = (images / 255.0) * 2 - 1\n",
+    "labels = table(seq(1, n), labels+1, n, 10)\n",
+    "\n",
+    "# Split into training (55,000 examples) and validation (5,000 
examples)\n",
+    "X = images[5001:nrow(images),]\n",
+    "X_val = images[1:5000,]\n",
+    "y = labels[5001:nrow(images),]\n",
+    "y_val = labels[1:5000,]\n",
+    "\n",
+    "# Train\n",
+    "epochs = 10\n",
+    "[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, y, X_val, y_val, 
C, Hin, Win, epochs)\n",
+    "\"\"\"\n",
+    "script = (dml(script_string).input(\"$data\", 
\"data/mnist/mnist_train.csv\")\n",
+    "                            .input(C=1, Hin=28, Win=28)\n",
+    "                            .output(\"W1\", \"b1\", \"W2\", \"b2\", 
\"W3\", \"b3\", \"W4\", \"b4\"))\n",
+    "W1, b1, W2, b2, W3, b3, W4, b4 = (ml.execute(script)\n",
+    "                                    .get(\"W1\", \"b1\", \"W2\", \"b2\", 
\"W3\", \"b3\", \"W4\", \"b4\"))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### 2. Compute Test Accuracy"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "script_string = \"\"\"\n",
+    "source(\"nn/examples/mnist_lenet.dml\") as mnist_lenet\n",
+    "\n",
+    "# Read test data\n",
+    "data = read($data, format=\"csv\")\n",
+    "n = nrow(data)\n",
+    "\n",
+    "# Extract images and labels\n",
+    "X_test = data[,2:ncol(data)]\n",
+    "y_test = data[,1]\n",
+    "\n",
+    "# Scale images to [-1,1], and one-hot encode the labels\n",
+    "X_test = (X_test / 255.0) * 2 - 1\n",
+    "y_test = table(seq(1, n), y_test+1, n, 10)\n",
+    "\n",
+    "# Eval on test set\n",
+    "probs = mnist_lenet::predict(X_test, C, Hin, Win, W1, b1, W2, b2, W3, b3, 
W4, b4)\n",
+    "[loss, accuracy] = mnist_lenet::eval(probs, y_test)\n",
+    "\n",
+    "print(\"Test Accuracy: \" + accuracy)\n",
+    "\"\"\"\n",
+    "script = dml(script_string).input(**{\"$data\": 
\"data/mnist/mnist_train.csv\",\n",
+    "                                     \"C\": 1, \"Hin\": 28, \"Win\": 
28,\n",
+    "                                     \"W1\": W1, \"b1\": b1,\n",
+    "                                     \"W2\": W2, \"b2\": b2,\n",
+    "                                     \"W3\": W3, \"b3\": b3,\n",
+    "                                     \"W4\": W4, \"b4\": b4})\n",
+    "ml.execute(script)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### 3. Extract Model Into Spark DataFrames For Future Use"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "W1_df = W1.toDF()\n",
+    "b1_df = b1.toDF()\n",
+    "W2_df = W2.toDF()\n",
+    "b2_df = b2.toDF()\n",
+    "W3_df = W3.toDF()\n",
+    "b3_df = b3.toDF()\n",
+    "W4_df = W4.toDF()\n",
+    "b4_df = b4.toDF()\n",
+    "W1_df, b1_df, W2_df, b2_df, W3_df, b3_df, W4_df, b4_df"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 + Spark 2.x + SystemML",
+   "language": "python",
+   "name": "pyspark3_2.x"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.6.1"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 1
+}
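
Both notebooks one-hot encode the digit labels with DML's `table()` before
training. A minimal standalone sketch of that step (the literal label values
below are illustrative only):

```
n = 4
labels = matrix("2 0 9 1", rows=n, cols=1)  # digit labels in [0,9]
Y = table(seq(1, n), labels+1, n, 10)       # (n x 10); row i is 1 at column labels[i]+1
```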

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/43c321d1/scripts/nn/examples/Example - MNIST Softmax Classifier.ipynb
----------------------------------------------------------------------
diff --git a/scripts/nn/examples/Example - MNIST Softmax Classifier.ipynb b/scripts/nn/examples/Example - MNIST Softmax Classifier.ipynb
new file mode 100644
index 0000000..5e7182a
--- /dev/null
+++ b/scripts/nn/examples/Example - MNIST Softmax Classifier.ipynb      
@@ -0,0 +1,179 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Quick Setup"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "scrolled": false
+   },
+   "outputs": [],
+   "source": [
+    "# Create a SystemML MLContext object\n",
+    "from systemml import MLContext, dml\n",
+    "ml = MLContext(sc)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Download Data - MNIST"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "The MNIST dataset contains labeled images of handwritten digits, where 
each example is a 28x28 pixel image of grayscale values in the range [0,255] 
stretched out as 784 pixels, and each label is one of 10 possible digits in 
[0,9].  Here, we download 60,000 training examples, and 10,000 test examples, 
where the format is \"label, pixel_1, pixel_2, ..., pixel_n\"."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "collapsed": true
+   },
+   "outputs": [],
+   "source": [
+    "%%sh\n",
+    "mkdir -p data/mnist/\n",
+    "cd data/mnist/\n",
+    "curl -O https://pjreddie.com/media/files/mnist_train.csv\n";,
+    "curl -O https://pjreddie.com/media/files/mnist_test.csv";
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## SystemML Softmax Model"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### 1. Train"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "training = \"\"\"\n",
+    "source(\"nn/examples/mnist_softmax.dml\") as mnist_softmax\n",
+    "\n",
+    "# Read training data\n",
+    "data = read($data, format=\"csv\")\n",
+    "n = nrow(data)\n",
+    "\n",
+    "# Extract images and labels\n",
+    "images = data[,2:ncol(data)]\n",
+    "labels = data[,1]\n",
+    "\n",
+    "# Scale images to [0,1], and one-hot encode the labels\n",
+    "images = images / 255.0\n",
+    "labels = table(seq(1, n), labels+1, n, 10)\n",
+    "\n",
+    "# Split into training (55,000 examples) and validation (5,000 
examples)\n",
+    "X = images[5001:nrow(images),]\n",
+    "X_val = images[1:5000,]\n",
+    "y = labels[5001:nrow(images),]\n",
+    "y_val = labels[1:5000,]\n",
+    "\n",
+    "# Train\n",
+    "epochs = 1\n",
+    "[W, b] = mnist_softmax::train(X, y, X_val, y_val, epochs)\n",
+    "\"\"\"\n",
+    "script = dml(training).input(\"$data\", 
\"data/mnist/mnist_train.csv\").output(\"W\", \"b\")\n",
+    "W, b = ml.execute(script).get(\"W\", \"b\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### 2. Compute Test Accuracy"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "testing = \"\"\"\n",
+    "source(\"nn/examples/mnist_softmax.dml\") as mnist_softmax\n",
+    "\n",
+    "# Read test data\n",
+    "data = read($data, format=\"csv\")\n",
+    "n = nrow(data)\n",
+    "\n",
+    "# Extract images and labels\n",
+    "X_test = data[,2:ncol(data)]\n",
+    "y_test = data[,1]\n",
+    "\n",
+    "# Scale images to [0,1], and one-hot encode the labels\n",
+    "X_test = X_test / 255.0\n",
+    "y_test = table(seq(1, n), y_test+1, n, 10)\n",
+    "\n",
+    "# Eval on test set\n",
+    "probs = mnist_softmax::predict(X_test, W, b)\n",
+    "[loss, accuracy] = mnist_softmax::eval(probs, y_test)\n",
+    "\n",
+    "print(\"Test Accuracy: \" + accuracy)\n",
+    "\"\"\"\n",
+    "script = dml(testing).input(\"$data\", \"data/mnist/mnist_test.csv\", 
W=W, b=b)\n",
+    "ml.execute(script)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### 3. Extract Model Into Spark DataFrames For Future Use"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "W_df = W.toDF()\n",
+    "b_df = b.toDF()\n",
+    "W_df, b_df"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.6.1"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 1
+}

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/43c321d1/scripts/nn/examples/README.md
----------------------------------------------------------------------
diff --git a/scripts/nn/examples/README.md b/scripts/nn/examples/README.md
new file mode 100644
index 0000000..d5e9d04
--- /dev/null
+++ b/scripts/nn/examples/README.md
@@ -0,0 +1,74 @@
+<!--
+{% comment %}
+Licensed to the Apache Software Foundation (ASF) under one or more
+contributor license agreements.  See the NOTICE file distributed with
+this work for additional information regarding copyright ownership.
+The ASF licenses this file to you under the Apache License, Version 2.0
+(the "License"); you may not use this file except in compliance with
+the License.  You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+{% endcomment %}
+-->
+
+# SystemML-NN Examples
+
+#### This folder contains scripts and PySpark Jupyter notebooks serving as examples of using the *SystemML-NN* (`nn`) deep learning library.
+
+---
+
+# Examples
+### MNIST Softmax Classifier
+
+* This example trains a softmax classifier, which is essentially a multi-class logistic regression model, on the MNIST data.  The model will be trained on the *training* images, validated on the *validation* images, and tested for final performance metrics on the *test* images.
+* Notebook: `Example - MNIST Softmax Classifier.ipynb`.
+* DML Functions: `mnist_softmax.dml`
+* Training script: `mnist_softmax-train.dml`
+* Prediction script: `mnist_softmax-predict.dml`
+
+### MNIST "LeNet" Neural Net
+
+* This example trains a neural network on the MNIST data using a ["LeNet" architecture](http://yann.lecun.com/exdb/publis/pdf/lecun-98.pdf). The model will be trained on the *training* images, validated on the *validation* images, and tested for final performance metrics on the *test* images.
+* Notebook: `Example - MNIST LeNet.ipynb`.
+* DML Functions: `mnist_lenet.dml`
+* Training script: `mnist_lenet-train.dml`
+* Prediction script: `mnist_lenet-predict.dml`
+
+---
+
+# Setup
+## Code
+* To run the examples, please first download and unzip the project via GitHub using the "Clone or download" button on the [homepage of the project](https://github.com/dusenberrymw/systemml-nn), *or* via the following commands:
+
+  ```
+  git clone https://github.com/dusenberrymw/systemml-nn.git
+  ```
+
+* Then, move into the `systemml-nn` folder via:
+  ```
+  cd systemml-nn
+  ```
+
+## Data
+* These examples use the classic [MNIST](http://yann.lecun.com/exdb/mnist/) dataset, which contains labeled 28x28 pixel images of handwritten digits in the range of 0-9.  There are 60,000 training images, and 10,000 testing images.  Of the 60,000 training images, 5,000 will be used as validation images (a sketch of this split appears below, after the Execution section).
+* **Download**:
+  * **Notebooks**: The data will be automatically downloaded as a step in either of the example notebooks.
+  * **Training scripts**: Please run `get_mnist_data.sh` to download the data separately.
+
+## Execution
+* These examples contain scripts written in SystemML's R-like language (`*.dml`), as well as PySpark Jupyter notebooks (`*.ipynb`).  The scripts contain the math for the algorithms, enclosed in functions, and the notebooks serve as full, end-to-end examples of reading in data, training models using the functions within the scripts, and evaluating final performance.
+* **Notebooks**: To run the notebook examples, please install the SystemML Python package with `pip install systemml`, and then start up Jupyter in the following manner from this directory (or for more information, please see [this great blog post](http://spark.tc/0-to-life-changing-application-with-apache-systemml/)):
+
+  ```
+  PYSPARK_DRIVER_PYTHON=jupyter PYSPARK_DRIVER_PYTHON_OPTS="notebook" pyspark --master local[*] --driver-memory 3G --driver-class-path SystemML.jar --jars SystemML.jar
+  ```
+
+  Note that all printed output, such as training statistics, from the SystemML scripts will be sent to the terminal in which Jupyter was started (for now...).
+
+* **Scripts**: To run the scripts from the command line using `spark-submit`, please see the comments located at the top of the `-train` and `-predict` scripts.
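
As referenced in the Data section above, the training scripts and notebooks
carve the 60,000 MNIST training rows into 55,000 training and 5,000 validation
examples. A minimal sketch of that split (the random `images`/`labels`
matrices below are stand-ins for the real data, for illustration only):

```
# stand-in for the 60,000 x 784 MNIST training matrix (illustrative only)
images = rand(rows=60000, cols=784, min=0, max=1)
labels = round(rand(rows=60000, cols=1, min=0, max=9, pdf="uniform"))

# first 5,000 rows serve as the validation set; the remainder is for training
X     = images[5001:nrow(images),]  # 55,000 training examples
X_val = images[1:5000,]             # 5,000 validation examples
y     = labels[5001:nrow(labels),]
y_val = labels[1:5000,]
```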

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/43c321d1/scripts/nn/examples/get_mnist_data.sh
----------------------------------------------------------------------
diff --git a/scripts/nn/examples/get_mnist_data.sh 
b/scripts/nn/examples/get_mnist_data.sh
new file mode 100755
index 0000000..deb0c40
--- /dev/null
+++ b/scripts/nn/examples/get_mnist_data.sh
@@ -0,0 +1,28 @@
+#!/usr/bin/env bash
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+DIR="$(cd "$(dirname "$0")" && pwd)"
+mkdir -p "$DIR/data/mnist/"
+cd "$DIR/data/mnist/"
+curl -O https://pjreddie.com/media/files/mnist_train.csv
+curl -O https://pjreddie.com/media/files/mnist_test.csv
+

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/43c321d1/scripts/nn/examples/mnist_lenet-predict.dml
----------------------------------------------------------------------
diff --git a/scripts/nn/examples/mnist_lenet-predict.dml 
b/scripts/nn/examples/mnist_lenet-predict.dml
new file mode 100644
index 0000000..85a5307
--- /dev/null
+++ b/scripts/nn/examples/mnist_lenet-predict.dml
@@ -0,0 +1,91 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# MNIST LeNet - Predict
+#
+# This script computes the class probability predictions of a
+# trained convolutional net using the "LeNet" architecture on
+# images of handwritten digits.
+#
+# Inputs:
+#  - X: File containing input images.
+#     The format is "pixel_1, pixel_2, ..., pixel_n".
+#  - C: Number of color channels in the images.
+#  - Hin: Input image height.
+#  - Win: Input image width.
+#  - model_dir: Directory containing the trained weights and biases
+#     of the model.
+#  - out_dir: Directory to store class probability predictions for
+#     each image.
+#  - fmt: [DEFAULT: "csv"] File format of `X` and output predictions.
+#     Options include: "csv", "mm", "text", and "binary".
+#
+# Outputs:
+#  - probs: File containing class probability predictions for each
+#     image.
+#
+# Data:
+# The X file should contain images of handwritten digits,
+# where each example is a 28x28 pixel image of grayscale values in
+# the range [0,255] stretched out as 784 pixels.
+#
+# Sample Invocation (running from outside the `nn` folder):
+# 1. Download images.
+#
+#   For example, save images to `nn/examples/data/mnist/images.csv`.
+#
+# 2. Execute using Spark
+#   ```
+#   spark-submit --master local[*] --driver-memory 5G
+#   --conf spark.driver.maxResultSize=0 --conf spark.rpc.message.maxSize=128
+#   $SYSTEMML_HOME/target/SystemML.jar -f nn/examples/mnist_lenet-predict.dml
+#   -nvargs X=nn/examples/data/mnist/images.csv C=1 Hin=28 Win=28
+#   model_dir=nn/examples/model/mnist_lenet out_dir=nn/examples/data/mnist
+#   ```
+#
+source("nn/examples/mnist_lenet.dml") as mnist_lenet
+
+# Read training data
+fmt = ifdef($fmt, "csv")
+X = read($X, format=fmt)
+C = $C
+Hin = $Hin
+Win = $Win
+
+# Scale images to [-1,1]
+X = (X / 255.0) * 2 - 1
+
+# Read model coefficients
+W1 = read($model_dir+"/W1")
+b1 = read($model_dir+"/b1")
+W2 = read($model_dir+"/W2")
+b2 = read($model_dir+"/b2")
+W3 = read($model_dir+"/W3")
+b3 = read($model_dir+"/b3")
+W4 = read($model_dir+"/W4")
+b4 = read($model_dir+"/b4")
+
+# Predict classes
+probs = mnist_lenet::predict(X, C, Hin, Win, W1, b1, W2, b2, W3, b3, W4, b4)
+
+# Output results
+write(probs, $out_dir+"/probs."+fmt, format=fmt)
+

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/43c321d1/scripts/nn/examples/mnist_lenet-train.dml
----------------------------------------------------------------------
diff --git a/scripts/nn/examples/mnist_lenet-train.dml 
b/scripts/nn/examples/mnist_lenet-train.dml
new file mode 100644
index 0000000..0fc733e
--- /dev/null
+++ b/scripts/nn/examples/mnist_lenet-train.dml
@@ -0,0 +1,123 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# MNIST LeNet - Train
+#
+# This script trains a convolutional net using the "LeNet" architecture
+# on images of handwritten digits.
+#
+# Inputs:
+#  - train: File containing labeled MNIST training images.
+#     The format is "label, pixel_1, pixel_2, ..., pixel_n".
+#  - test: File containing labeled MNIST test images.
+#     The format is "label, pixel_1, pixel_2, ..., pixel_n".
+#  - C: Number of color channels in the images.
+#  - Hin: Input image height.
+#  - Win: Input image width.
+#  - epochs: [DEFAULT: 10] Total number of full training loops over
+#     the full data set.
+#  - out_dir: [DEFAULT: "."] Directory to store weights and bias
+#     matrices of trained model, as well as final test accuracy.
+#  - fmt: [DEFAULT: "csv"] File format of `train` and `test` data.
+#     Options include: "csv", "mm", "text", and "binary".
+#
+# Outputs:
+#  - W1, W2, W3, W4: Files containing the trained weights of the model.
+#  - b1, b2, b3, b4: Files containing the trained biases of the model.
+#  - accuracy: File containing the final accuracy on the test data.
+#
+# Data:
+# The MNIST dataset contains labeled images of handwritten digits,
+# where each example is a 28x28 pixel image of grayscale values in
+# the range [0,255] stretched out as 784 pixels, and each label is
+# one of 10 possible digits in [0,9].
+#
+# Sample Invocation (running from outside the `nn` folder):
+# 1. Download data (60,000 training examples, and 10,000 test examples)
+#   ```
+#   nn/examples/get_mnist_data.sh
+#   ```
+#
+# 2. Execute using Spark
+#   ```
+#   spark-submit --master local[*] --driver-memory 10G
+#   --conf spark.driver.maxResultSize=0 --conf spark.rpc.message.maxSize=128
+#   $SYSTEMML_HOME/target/SystemML.jar -f nn/examples/mnist_lenet-train.dml
+#   -nvargs train=nn/examples/data/mnist/mnist_train.csv test=nn/examples/data/mnist/mnist_test.csv
+#   C=1 Hin=28 Win=28 epochs=10 out_dir=nn/examples/model/mnist_lenet
+#   ```
+#
+source("nn/examples/mnist_lenet.dml") as mnist_lenet
+
+# Read training data & settings
+fmt = ifdef($fmt, "csv")
+train = read($train, format=fmt)
+test = read($test, format=fmt)
+C = $C
+Hin = $Hin
+Win = $Win
+epochs = ifdef($epochs, 10)
+out_dir = ifdef($out_dir, ".")
+
+# Extract images and labels
+images = train[,2:ncol(train)]
+labels = train[,1]
+X_test = test[,2:ncol(test)]
+y_test = test[,1]
+
+# Scale images to [-1,1], and one-hot encode the labels
+n = nrow(train)
+n_test = nrow(test)
+images = (images / 255.0) * 2 - 1
+labels = table(seq(1, n), labels+1, n, 10)
+X_test = (X_test / 255.0) * 2 - 1
+y_test = table(seq(1, n_test), y_test+1, n_test, 10)
+
+# Split into training (55,000 examples) and validation (5,000 examples)
+X = images[5001:nrow(images),]
+X_val = images[1:5000,]
+y = labels[5001:nrow(images),]
+y_val = labels[1:5000,]
+
+# Train
+[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, y, X_val, y_val, C, Hin, Win, epochs)
+
+# Write model out
+write(W1, out_dir+"/W1")
+write(b1, out_dir+"/b1")
+write(W2, out_dir+"/W2")
+write(b2, out_dir+"/b2")
+write(W3, out_dir+"/W3")
+write(b3, out_dir+"/b3")
+write(W4, out_dir+"/W4")
+write(b4, out_dir+"/b4")
+
+# Eval on test set
+probs = mnist_lenet::predict(X_test, C, Hin, Win, W1, b1, W2, b2, W3, b3, W4, b4)
+[loss, accuracy] = mnist_lenet::eval(probs, y_test)
+
+# Output results
+print("Test Accuracy: " + accuracy)
+write(accuracy, out_dir+"/accuracy")
+
+print("")
+print("")
+

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/43c321d1/scripts/nn/examples/mnist_lenet.dml
----------------------------------------------------------------------
diff --git a/scripts/nn/examples/mnist_lenet.dml 
b/scripts/nn/examples/mnist_lenet.dml
new file mode 100644
index 0000000..e5755c4
--- /dev/null
+++ b/scripts/nn/examples/mnist_lenet.dml
@@ -0,0 +1,331 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+/*
+ * MNIST LeNet Example
+ */
+# Imports
+source("nn/layers/affine.dml") as affine
+source("nn/layers/conv2d_builtin.dml") as conv2d
+source("nn/layers/cross_entropy_loss.dml") as cross_entropy_loss
+source("nn/layers/dropout.dml") as dropout
+source("nn/layers/l2_reg.dml") as l2_reg
+source("nn/layers/max_pool2d_builtin.dml") as max_pool2d
+source("nn/layers/relu.dml") as relu
+source("nn/layers/softmax.dml") as softmax
+source("nn/optim/sgd_nesterov.dml") as sgd_nesterov
+
+train = function(matrix[double] X, matrix[double] y,
+                 matrix[double] X_val, matrix[double] y_val,
+                 int C, int Hin, int Win, int epochs)
+    return (matrix[double] W1, matrix[double] b1,
+            matrix[double] W2, matrix[double] b2,
+            matrix[double] W3, matrix[double] b3,
+            matrix[double] W4, matrix[double] b4) {
+  /*
+   * Trains a convolutional net using the "LeNet" architecture.
+   *
+   * The input matrix, X, has N examples, each represented as a 3D
+   * volume unrolled into a single vector.  The targets, y, have K
+   * classes, and are one-hot encoded.
+   *
+   * Inputs:
+   *  - X: Input data matrix, of shape (N, C*Hin*Win).
+   *  - y: Target matrix, of shape (N, K).
+   *  - X_val: Input validation data matrix, of shape (N, C*Hin*Win).
+   *  - y_val: Target validation matrix, of shape (N, K).
+   *  - C: Number of input channels (dimensionality of input depth).
+   *  - Hin: Input height.
+   *  - Win: Input width.
+   *  - epochs: Total number of full training loops over the full data set.
+   *
+   * Outputs:
+   *  - W1: 1st layer weights (parameters) matrix, of shape (F1, C*Hf*Wf).
+   *  - b1: 1st layer biases vector, of shape (F1, 1).
+   *  - W2: 2nd layer weights (parameters) matrix, of shape (F2, F1*Hf*Wf).
+   *  - b2: 2nd layer biases vector, of shape (F2, 1).
+   *  - W3: 3rd layer weights (parameters) matrix, of shape (F2*(Hin/4)*(Win/4), N3).
+   *  - b3: 3rd layer biases vector, of shape (1, N3).
+   *  - W4: 4th layer weights (parameters) matrix, of shape (N3, K).
+   *  - b4: 4th layer biases vector, of shape (1, K).
+   */
+  N = nrow(X)
+  K = ncol(y)
+
+  # Create network:
+  # conv1 -> relu1 -> pool1 -> conv2 -> relu2 -> pool2 -> affine3 -> relu3 -> affine4 -> softmax
+  Hf = 5  # filter height
+  Wf = 5  # filter width
+  stride = 1
+  pad = 2  # for same output dimensions: (Hf - stride) / 2 = (5 - 1) / 2 = 2
+
+  F1 = 32  # num conv filters in conv1
+  F2 = 64  # num conv filters in conv2
+  N3 = 512  # num nodes in affine3
+  # Note: affine4 has K nodes, which is equal to the number of target dimensions (num classes)
+
+  [W1, b1] = conv2d::init(F1, C, Hf, Wf)  # inputs: (N, C*Hin*Win)
+  [W2, b2] = conv2d::init(F2, F1, Hf, Wf)  # inputs: (N, F1*(Hin/2)*(Win/2))
+  [W3, b3] = affine::init(F2*(Hin/2/2)*(Win/2/2), N3)  # inputs: (N, F2*(Hin/2/2)*(Win/2/2))
+  [W4, b4] = affine::init(N3, K)  # inputs: (N, N3)
+  W4 = W4 / sqrt(2)  # different initialization, since being fed into softmax, instead of relu
+
+  # Initialize SGD w/ Nesterov momentum optimizer
+  lr = 0.01  # learning rate
+  mu = 0.9  #0.5  # momentum
+  decay = 0.95  # learning rate decay constant
+  vW1 = sgd_nesterov::init(W1); vb1 = sgd_nesterov::init(b1)
+  vW2 = sgd_nesterov::init(W2); vb2 = sgd_nesterov::init(b2)
+  vW3 = sgd_nesterov::init(W3); vb3 = sgd_nesterov::init(b3)
+  vW4 = sgd_nesterov::init(W4); vb4 = sgd_nesterov::init(b4)
+
+  # Regularization
+  lambda = 5e-04
+
+  # Optimize
+  print("Starting optimization")
+  batch_size = 64
+  iters = ceil(N / batch_size)
+  for (e in 1:epochs) {
+    for(i in 1:iters) {
+      # Get next batch
+      beg = ((i-1) * batch_size) %% N + 1
+      end = min(N, beg + batch_size - 1)
+      X_batch = X[beg:end,]
+      y_batch = y[beg:end,]
+
+      # Compute forward pass
+      ## layer 1: conv1 -> relu1 -> pool1
+      [outc1, Houtc1, Woutc1] = conv2d::forward(X_batch, W1, b1, C, Hin, Win, Hf, Wf, stride, stride,
+                                                pad, pad)
+      outr1 = relu::forward(outc1)
+      [outp1, Houtp1, Woutp1] = max_pool2d::forward(outr1, F1, Houtc1, Woutc1, Hf=2, Wf=2,
+                                                    strideh=2, stridew=2, padh=0, padw=0)
+      ## layer 2: conv2 -> relu2 -> pool2
+      [outc2, Houtc2, Woutc2] = conv2d::forward(outp1, W2, b2, F1, Houtp1, Woutp1, Hf, Wf,
+                                                stride, stride, pad, pad)
+      outr2 = relu::forward(outc2)
+      [outp2, Houtp2, Woutp2] = max_pool2d::forward(outr2, F2, Houtc2, Woutc2, Hf=2, Wf=2,
+                                                    strideh=2, stridew=2, padh=0, padw=0)
+      ## layer 3:  affine3 -> relu3 -> dropout
+      outa3 = affine::forward(outp2, W3, b3)
+      outr3 = relu::forward(outa3)
+      [outd3, maskd3] = dropout::forward(outr3, 0.5, -1)
+      ## layer 4:  affine4 -> softmax
+      outa4 = affine::forward(outd3, W4, b4)
+      probs = softmax::forward(outa4)
+
+      # Compute loss & accuracy for training & validation data every 100 iterations.
+      if (i %% 100 == 0) {
+        # Compute training loss & accuracy
+        loss_data = cross_entropy_loss::forward(probs, y_batch)
+        loss_reg_W1 = l2_reg::forward(W1, lambda)
+        loss_reg_W2 = l2_reg::forward(W2, lambda)
+        loss_reg_W3 = l2_reg::forward(W3, lambda)
+        loss_reg_W4 = l2_reg::forward(W4, lambda)
+        loss = loss_data + loss_reg_W1 + loss_reg_W2 + loss_reg_W3 + loss_reg_W4
+        accuracy = mean(rowIndexMax(probs) == rowIndexMax(y_batch))
+
+        # Compute validation loss & accuracy
+        probs_val = predict(X_val, C, Hin, Win, W1, b1, W2, b2, W3, b3, W4, b4)
+        loss_val = cross_entropy_loss::forward(probs_val, y_val)
+        accuracy_val = mean(rowIndexMax(probs_val) == rowIndexMax(y_val))
+
+        # Output results
+        print("Epoch: " + e + ", Iter: " + i + ", Train Loss: " + loss + ", 
Train Accuracy: "
+              + accuracy + ", Val Loss: " + loss_val + ", Val Accuracy: " + 
accuracy_val)
+      }
+
+      # Compute data backward pass
+      ## loss:
+      dprobs = cross_entropy_loss::backward(probs, y_batch)
+      ## layer 4:  affine4 -> softmax
+      douta4 = softmax::backward(dprobs, outa4)
+      [doutd3, dW4, db4] = affine::backward(douta4, outd3, W4, b4)
+      ## layer 3:  affine3 -> relu3 -> dropout
+      doutr3 = dropout::backward(doutd3, outr3, 0.5, maskd3)
+      douta3 = relu::backward(doutr3, outa3)
+      [doutp2, dW3, db3] = affine::backward(douta3, outp2, W3, b3)
+      ## layer 2: conv2 -> relu2 -> pool2
+      doutr2 = max_pool2d::backward(doutp2, Houtp2, Woutp2, outr2, F2, Houtc2, Woutc2, Hf=2, Wf=2,
+                                    strideh=2, stridew=2, padh=0, padw=0)
+      doutc2 = relu::backward(doutr2, outc2)
+      [doutp1, dW2, db2] = conv2d::backward(doutc2, Houtc2, Woutc2, outp1, W2, b2, F1,
+                                            Houtp1, Woutp1, Hf, Wf, stride, stride, pad, pad)
+      ## layer 1: conv1 -> relu1 -> pool1
+      doutr1 = max_pool2d::backward(doutp1, Houtp1, Woutp1, outr1, F1, Houtc1, Woutc1, Hf=2, Wf=2,
+                                    strideh=2, stridew=2, padh=0, padw=0)
+      doutc1 = relu::backward(doutr1, outc1)
+      [dX_batch, dW1, db1] = conv2d::backward(doutc1, Houtc1, Woutc1, X_batch, W1, b1, C, Hin, Win,
+                                              Hf, Wf, stride, stride, pad, pad)
+
+      # Compute regularization backward pass
+      dW1_reg = l2_reg::backward(W1, lambda)
+      dW2_reg = l2_reg::backward(W2, lambda)
+      dW3_reg = l2_reg::backward(W3, lambda)
+      dW4_reg = l2_reg::backward(W4, lambda)
+      dW1 = dW1 + dW1_reg
+      dW2 = dW2 + dW2_reg
+      dW3 = dW3 + dW3_reg
+      dW4 = dW4 + dW4_reg
+
+      # Optimize with SGD w/ Nesterov momentum
+      [W1, vW1] = sgd_nesterov::update(W1, dW1, lr, mu, vW1)
+      [b1, vb1] = sgd_nesterov::update(b1, db1, lr, mu, vb1)
+      [W2, vW2] = sgd_nesterov::update(W2, dW2, lr, mu, vW2)
+      [b2, vb2] = sgd_nesterov::update(b2, db2, lr, mu, vb2)
+      [W3, vW3] = sgd_nesterov::update(W3, dW3, lr, mu, vW3)
+      [b3, vb3] = sgd_nesterov::update(b3, db3, lr, mu, vb3)
+      [W4, vW4] = sgd_nesterov::update(W4, dW4, lr, mu, vW4)
+      [b4, vb4] = sgd_nesterov::update(b4, db4, lr, mu, vb4)
+    }
+    # Anneal momentum towards 0.999
+    #mu = mu + (0.999 - mu)/(1+epochs-e)
+    # Decay learning rate
+    lr = lr * decay
+  }
+}
+
+predict = function(matrix[double] X, int C, int Hin, int Win,
+                   matrix[double] W1, matrix[double] b1,
+                   matrix[double] W2, matrix[double] b2,
+                   matrix[double] W3, matrix[double] b3,
+                   matrix[double] W4, matrix[double] b4)
+    return (matrix[double] probs) {
+  /*
+   * Computes the class probability predictions of a convolutional
+   * net using the "LeNet" architecture.
+   *
+   * The input matrix, X, has N examples, each represented as a 3D
+   * volume unrolled into a single vector.
+   *
+   * Inputs:
+   *  - X: Input data matrix, of shape (N, C*Hin*Win).
+   *  - C: Number of input channels (dimensionality of input depth).
+   *  - Hin: Input height.
+   *  - Win: Input width.
+   *  - W1: 1st layer weights (parameters) matrix, of shape (F1, C*Hf*Wf).
+   *  - b1: 1st layer biases vector, of shape (F1, 1).
+   *  - W2: 2nd layer weights (parameters) matrix, of shape (F2, F1*Hf*Wf).
+   *  - b2: 2nd layer biases vector, of shape (F2, 1).
+   *  - W3: 3rd layer weights (parameters) matrix, of shape (F2*(Hin/4)*(Win/4), N3).
+   *  - b3: 3rd layer biases vector, of shape (1, N3).
+   *  - W4: 4th layer weights (parameters) matrix, of shape (N3, K).
+   *  - b4: 4th layer biases vector, of shape (1, K).
+   *
+   * Outputs:
+   *  - probs: Class probabilities, of shape (N, K).
+   */
+  N = nrow(X)
+
+  # Network:
+  # conv1 -> relu1 -> pool1 -> conv2 -> relu2 -> pool2 -> affine3 -> relu3 -> affine4 -> softmax
+  Hf = 5  # filter height
+  Wf = 5  # filter width
+  stride = 1
+  pad = 2  # for same output dimensions: (Hf - stride) / 2 = (5 - 1) / 2 = 2
+
+  F1 = nrow(W1)  # num conv filters in conv1
+  F2 = nrow(W2)  # num conv filters in conv2
+  N3 = ncol(W3)  # num nodes in affine3
+  K = ncol(W4)  # num nodes in affine4, equal to number of target dimensions (num classes)
+
+  # Compute predictions over mini-batches
+  probs = matrix(0, rows=N, cols=K)
+  batch_size = 64
+  iters = ceil(N / batch_size)
+  for(i in 1:iters) {
+    # Get next batch
+    beg = ((i-1) * batch_size) %% N + 1
+    end = min(N, beg + batch_size - 1)
+    X_batch = X[beg:end,]
+
+    # Compute forward pass
+    ## layer 1: conv1 -> relu1 -> pool1
+    [outc1, Houtc1, Woutc1] = conv2d::forward(X_batch, W1, b1, C, Hin, Win, Hf, Wf, stride, stride,
+                                              pad, pad)
+    outr1 = relu::forward(outc1)
+    [outp1, Houtp1, Woutp1] = max_pool2d::forward(outr1, F1, Houtc1, Woutc1, Hf=2, Wf=2,
+                                                  strideh=2, stridew=2, padh=0, padw=0)
+    ## layer 2: conv2 -> relu2 -> pool2
+    [outc2, Houtc2, Woutc2] = conv2d::forward(outp1, W2, b2, F1, Houtp1, Woutp1, Hf, Wf,
+                                              stride, stride, pad, pad)
+    outr2 = relu::forward(outc2)
+    [outp2, Houtp2, Woutp2] = max_pool2d::forward(outr2, F2, Houtc2, Woutc2, Hf=2, Wf=2,
+                                                  strideh=2, stridew=2, padh=0, padw=0)
+    ## layer 3:  affine3 -> relu3
+    outa3 = affine::forward(outp2, W3, b3)
+    outr3 = relu::forward(outa3)
+    ## layer 4:  affine4 -> softmax
+    outa4 = affine::forward(outr3, W4, b4)
+    probs_batch = softmax::forward(outa4)
+
+    # Store predictions
+    probs[beg:end,] = probs_batch
+  }
+}
+
+eval = function(matrix[double] probs, matrix[double] y)
+    return (double loss, double accuracy) {
+  /*
+   * Evaluates a convolutional net using the "LeNet" architecture.
+   *
+   * The probs matrix contains the class probability predictions
+   * of K classes over N examples.  The targets, y, have K classes,
+   * and are one-hot encoded.
+   *
+   * Inputs:
+   *  - probs: Class probabilities, of shape (N, K).
+   *  - y: Target matrix, of shape (N, K).
+   *
+   * Outputs:
+   *  - loss: Scalar loss, of shape (1).
+   *  - accuracy: Scalar accuracy, of shape (1).
+   */
+  # Compute loss & accuracy
+  loss = cross_entropy_loss::forward(probs, y)
+  correct_pred = rowIndexMax(probs) == rowIndexMax(y)
+  accuracy = mean(correct_pred)
+}
+
+generate_dummy_data = function()
+    return (matrix[double] X, matrix[double] y, int C, int Hin, int Win) {
+  /*
+   * Generate a dummy dataset similar to the MNIST dataset.
+   *
+   * Outputs:
+   *  - X: Input data matrix, of shape (N, D).
+   *  - y: Target matrix, of shape (N, K).
+   *  - C: Number of input channels (dimensionality of input depth).
+   *  - Hin: Input height.
+   *  - Win: Input width.
+   */
+  # Generate dummy input data
+  N = 1024  # num examples
+  C = 1  # num input channels
+  Hin = 28  # input height
+  Win = 28  # input width
+  K = 10  # num target classes
+  X = rand(rows=N, cols=C*Hin*Win, pdf="normal")
+  classes = round(rand(rows=N, cols=1, min=1, max=K, pdf="uniform"))
+  y = table(seq(1, N), classes, N, K)  # one-hot encoding with all K classes
+}
+
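
Taken together, the functions above support a quick end-to-end smoke test. A
minimal sketch using `generate_dummy_data` (assuming execution from the
`scripts` directory so that the `nn/...` paths resolve):

```
source("nn/examples/mnist_lenet.dml") as mnist_lenet

# generate small dummy train & validation sets
[X, y, C, Hin, Win] = mnist_lenet::generate_dummy_data()
[X_val, y_val, C, Hin, Win] = mnist_lenet::generate_dummy_data()

# train for a single epoch, then evaluate on the training data
[W1, b1, W2, b2, W3, b3, W4, b4] = mnist_lenet::train(X, y, X_val, y_val, C, Hin, Win, 1)
probs = mnist_lenet::predict(X, C, Hin, Win, W1, b1, W2, b2, W3, b3, W4, b4)
[loss, accuracy] = mnist_lenet::eval(probs, y)
print("Dummy-data accuracy: " + accuracy)
```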

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/43c321d1/scripts/nn/examples/mnist_softmax-predict.dml
----------------------------------------------------------------------
diff --git a/scripts/nn/examples/mnist_softmax-predict.dml 
b/scripts/nn/examples/mnist_softmax-predict.dml
new file mode 100644
index 0000000..4c8c434
--- /dev/null
+++ b/scripts/nn/examples/mnist_softmax-predict.dml
@@ -0,0 +1,77 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# MNIST Softmax - Predict
+#
+# This script computes the class probability predictions of a
+# trained softmax classifier on images of handwritten digits.
+#
+# Inputs:
+#  - X: File containing input images.
+#     The format is "pixel_1, pixel_2, ..., pixel_n".
+#  - model_dir: Directory containing the trained weights and biases
+#     of the model.
+#  - out_dir: Directory to store class probability predictions for
+#     each image.
+#  - fmt: [DEFAULT: "csv"] File format of `X` and output predictions.
+#     Options include: "csv", "mm", "text", and "binary".
+#
+# Outputs:
+#  - probs: File containing class probability predictions for each
+#     image.
+#
+# Data:
+# The X file should contain images of handwritten digits,
+# where each example is a 28x28 pixel image of grayscale values in
+# the range [0,255] stretched out as 784 pixels.
+#
+# Sample Invocation:
+# 1. Download images.
+#
+#   For example, save images to `nn/examples/data/mnist/images.csv`.
+#
+# 2. Execute using Spark
+#   ```
+#   spark-submit --master local[*] --driver-memory 5G
+#   --conf spark.driver.maxResultSize=0 --conf spark.rpc.message.maxSize=128
+#   $SYSTEMML_HOME/target/SystemML.jar -f nn/examples/mnist_softmax-predict.dml
+#   -nvargs X=nn/examples/data/mnist/images.csv
+#   model_dir=nn/examples/model/mnist_softmax out_dir=nn/examples/data/mnist
+#   ```
+#
+source("nn/examples/mnist_softmax.dml") as mnist_softmax
+
+# Read input images
+fmt = ifdef($fmt, "csv")
+X = read($X, format=fmt)
+
+# Scale images to [0,1]
+X = X / 255.0
+
+# Read model coefficients
+W = read($model_dir+"/W")
+b = read($model_dir+"/b")
+
+# Predict classes
+probs = mnist_softmax::predict(X, W, b)
+
+# Output results
+write(probs, $out_dir+"/probs."+fmt, format=fmt)
+

http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/43c321d1/scripts/nn/examples/mnist_softmax-train.dml
----------------------------------------------------------------------
diff --git a/scripts/nn/examples/mnist_softmax-train.dml 
b/scripts/nn/examples/mnist_softmax-train.dml
new file mode 100644
index 0000000..09970f0
--- /dev/null
+++ b/scripts/nn/examples/mnist_softmax-train.dml
@@ -0,0 +1,110 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+# MNIST Softmax - Train
+#
+# This script trains a softmax classifier on images of handwritten
+# digits.
+#
+# Inputs:
+#  - train: File containing labeled MNIST training images.
+#     The format is "label, pixel_1, pixel_2, ..., pixel_n".
+#  - test: File containing labeled MNIST test images.
+#     The format is "label, pixel_1, pixel_2, ..., pixel_n".
+#  - epochs: [DEFAULT: 1] Total number of full training loops over
+#     the full data set.
+#  - out_dir: [DEFAULT: "."] Directory to store weights and bias
+#     matrices of trained model, as well as final test accuracy.
+#  - fmt: [DEFAULT: "csv"] File format of `train` and `test` data.
+#     Options include: "csv", "mm", "text", and "binary".
+#
+# Outputs:
+#  - W: File containing the trained weights of the model.
+#  - b: File containing the trained biases of the model.
+#  - accuracy: File containing the final accuracy on the test data.
+#
+# Data:
+# The MNIST dataset contains labeled images of handwritten digits,
+# where each example is a 28x28 pixel image of grayscale values in
+# the range [0,255] stretched out as 784 pixels, and each label is
+# one of 10 possible digits in [0,9].
+#
+# Sample Invocation (running from within the `scripts` folder):
+# 1. Download data (60,000 training examples, and 10,000 test examples)
+#   ```
+#   nn/examples/get_mnist_data.sh
+#   ```
+#
+# 2. Execute using Spark
+#   ```
+#   spark-submit --master local[*] --driver-memory 10G
+#   --conf spark.driver.maxResultSize=0 --conf spark.rpc.message.maxSize=128
+#   $SYSTEMML_HOME/target/SystemML.jar -f nn/examples/mnist_softmax-train.dml
+#   -nvargs train=nn/examples/data/mnist/mnist_train.csv test=nn/examples/data/mnist/mnist_test.csv
+#   epochs=1 out_dir=nn/examples/model/mnist_softmax
+#   ```
+#
+source("nn/examples/mnist_softmax.dml") as mnist_softmax
+
+# Read training and test data
+fmt = ifdef($fmt, "csv")
+train = read($train, format=fmt)
+test = read($test, format=fmt)
+epochs = ifdef($epochs, 1)
+out_dir = ifdef($out_dir, ".")
+
+# Extract images and labels
+images = train[,2:ncol(train)]
+labels = train[,1]
+X_test = test[,2:ncol(test)]
+y_test = test[,1]
+
+# Scale images to [0,1], and one-hot encode the labels
+n = nrow(train)
+n_test = nrow(test)
+classes = 10
+images = images / 255.0
+labels = table(seq(1, n), labels+1, n, classes)
+X_test = X_test / 255.0
+y_test = table(seq(1, n_test), y_test+1, n_test, classes)
+
+# Split into training (55,000 examples) and validation (5,000 examples)
+X = images[5001:nrow(images),]
+X_val = images[1:5000,]
+y = labels[5001:nrow(images),]
+y_val = labels[1:5000,]
+
+# Train
+[W, b] = mnist_softmax::train(X, y, X_val, y_val, epochs)
+
+# Write model out
+write(W, out_dir+"/W")
+write(b, out_dir+"/b")
+
+# Eval on test set
+probs = mnist_softmax::predict(X_test, W, b)
+[loss, accuracy] = mnist_softmax::eval(probs, y_test)
+
+# Output results
+print("Test Accuracy: " + accuracy)
+write(accuracy, out_dir+"/accuracy")
+
+print("")
+print("")
+

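The `table(seq(1, n), labels+1, n, classes)` calls above are the idiomatic DML way to one-hot encode labels: row i receives a 1 in the column of its (1-based) class index, and the explicit output dimensions guard against classes that happen to be absent from the data. A tiny self-contained sketch of the same pattern, with made-up labels:

```
# One-hot encode three example digit labels (2, 0, 9) into a 3x10 matrix.
labels = matrix("2 0 9", rows=3, cols=1)
n = nrow(labels)
classes = 10
Y = table(seq(1, n), labels+1, n, classes)  # Y[i, labels[i]+1] = 1, all else 0
print(toString(Y))
```
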
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/43c321d1/scripts/nn/examples/mnist_softmax.dml
----------------------------------------------------------------------
diff --git a/scripts/nn/examples/mnist_softmax.dml b/scripts/nn/examples/mnist_softmax.dml
new file mode 100644
index 0000000..a529a12
--- /dev/null
+++ b/scripts/nn/examples/mnist_softmax.dml
@@ -0,0 +1,178 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+/*
+ * MNIST Softmax Example
+ */
+# Imports
+source("nn/layers/affine.dml") as affine
+source("nn/layers/cross_entropy_loss.dml") as cross_entropy_loss
+source("nn/layers/softmax.dml") as softmax
+source("nn/optim/sgd_nesterov.dml") as sgd_nesterov
+
+train = function(matrix[double] X, matrix[double] y,
+                 matrix[double] X_val, matrix[double] y_val,
+                 int epochs)
+    return (matrix[double] W, matrix[double] b) {
+  /*
+   * Trains a softmax classifier.
+   *
+   * The input matrix, X, has N examples, each with D features.
+   * The targets, y, have K classes, and are one-hot encoded.
+   *
+   * Inputs:
+   *  - X: Input data matrix, of shape (N, D).
+   *  - y: Target matrix, of shape (N, K).
+   *  - X_val: Input validation data matrix, of shape (N, D).
+   *  - y_val: Target validation matrix, of shape (N, K).
+   *  - epochs: Total number of full training loops over the full data set.
+   *
+   * Outputs:
+   *  - W: Weights (parameters) matrix, of shape (D, K).
+   *  - b: Biases vector, of shape (1, K).
+   */
+  N = nrow(X)  # num examples
+  D = ncol(X)  # num features
+  K = ncol(y)  # num classes
+
+  # Create softmax classifier:
+  # affine -> softmax
+  [W, b] = affine::init(D, K)
+  W = W / sqrt(2.0/(D)) * sqrt(1/(D))  # rescale He init (sqrt(2/D)) to sqrt(1/D), since this layer feeds softmax rather than ReLU
+
+  # Initialize SGD w/ Nesterov momentum optimizer
+  lr = 0.2  # learning rate
+  mu = 0  # momentum
+  decay = 0.99  # learning rate decay constant
+  vW = sgd_nesterov::init(W)  # optimizer momentum state for W
+  vb = sgd_nesterov::init(b)  # optimizer momentum state for b
+
+  # Optimize
+  print("Starting optimization")
+  batch_size = 50
+  iters = 1000  # fixed iterations per epoch; alternatively, ceil(N / batch_size)
+  for (e in 1:epochs) {
+    for(i in 1:iters) {
+      # Get next batch
+      beg = ((i-1) * batch_size) %% N + 1
+      end = min(N, beg + batch_size - 1)
+      X_batch = X[beg:end,]
+      y_batch = y[beg:end,]
+
+      # Compute forward pass
+      ## affine & softmax:
+      out = affine::forward(X_batch, W, b)
+      probs = softmax::forward(out)
+
+      # Compute loss & accuracy for training & validation data
+      loss = cross_entropy_loss::forward(probs, y_batch)
+      accuracy = mean(rowIndexMax(probs) == rowIndexMax(y_batch))
+      probs_val = predict(X_val, W, b)
+      loss_val = cross_entropy_loss::forward(probs_val, y_val)
+      accuracy_val = mean(rowIndexMax(probs_val) == rowIndexMax(y_val))
+      print("Epoch: " + e + ", Iter: " + i + ", Train Loss: " + loss + ", Train Accuracy: " +
+            accuracy + ", Val Loss: " + loss_val + ", Val Accuracy: " + accuracy_val)
+
+      # Compute backward pass
+      ## loss:
+      dprobs = cross_entropy_loss::backward(probs, y_batch)
+      ## affine & softmax:
+      dout = softmax::backward(dprobs, out)
+      [dX_batch, dW, db] = affine::backward(dout, X_batch, W, b)
+
+      # Optimize with SGD w/ Nesterov momentum
+      [W, vW] = sgd_nesterov::update(W, dW, lr, mu, vW)
+      [b, vb] = sgd_nesterov::update(b, db, lr, mu, vb)
+    }
+    # Anneal momentum towards 0.999
+    mu = mu + (0.999 - mu)/(1+epochs-e)
+    # Decay learning rate
+    lr = lr * decay
+  }
+}
+
+predict = function(matrix[double] X, matrix[double] W, matrix[double] b)
+    return (matrix[double] probs) {
+  /*
+   * Computes the class probability predictions of a softmax classifier.
+   *
+   * The input matrix, X, has N examples, each with D features.
+   *
+   * Inputs:
+   *  - X: Input data matrix, of shape (N, D).
+   *  - W: Weights (parameters) matrix, of shape (D, K).
+   *  - b: Biases vector, of shape (1, K).
+   *
+   * Outputs:
+   *  - probs: Class probabilities, of shape (N, K).
+   */
+  # Compute forward pass
+  ## affine & softmax:
+  out = affine::forward(X, W, b)
+  probs = softmax::forward(out)
+}
+
+eval = function(matrix[double] probs, matrix[double] y)
+    return (double loss, double accuracy) {
+  /*
+   * Evaluates a softmax classifier.
+   *
+   * The probs matrix contains the class probability predictions
+   * of K classes over N examples.  The targets, y, have K classes,
+   * and are one-hot encoded.
+   *
+   * Inputs:
+   *  - probs: Class probabilities, of shape (N, K).
+   *  - y: Target matrix, of shape (N, K).
+   *
+   * Outputs:
+   *  - loss: Scalar loss, of shape (1).
+   *  - accuracy: Scalar accuracy, of shape (1).
+   */
+  # Compute loss & accuracy
+  loss = cross_entropy_loss::forward(probs, y)
+  correct_pred = rowIndexMax(probs) == rowIndexMax(y)
+  accuracy = mean(correct_pred)
+}
+
+generate_dummy_data = function()
+    return (matrix[double] X, matrix[double] y, int C, int Hin, int Win) {
+  /*
+   * Generate a dummy dataset similar to the MNIST dataset.
+   *
+   * Outputs:
+   *  - X: Input data matrix, of shape (N, D).
+   *  - y: Target matrix, of shape (N, K).
+   *  - C: Number of input channels (dimensionality of input depth).
+   *  - Hin: Input height.
+   *  - Win: Input width.
+   */
+  # Generate dummy input data
+  N = 1024  # num examples
+  C = 1  # num input channels
+  Hin = 28  # input height
+  Win = 28  # input width
+  T = 10  # num targets
+  X = rand(rows=N, cols=C*Hin*Win, pdf="normal")
+  classes = round(rand(rows=N, cols=1, min=1, max=T, pdf="uniform"))
+  y = table(seq(1, N), classes, N, T)  # one-hot encoding with all T classes represented
+}
+

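Because the module ships `generate_dummy_data`, its `train`/`predict`/`eval` chain can be smoke-tested without downloading MNIST. A minimal sketch, assuming execution from the `scripts` directory; with one epoch on random data, accuracy near chance (~0.1) is the expected outcome:

```
source("nn/examples/mnist_softmax.dml") as mnist_softmax

# Random MNIST-shaped data for training and for validation/evaluation
[X, y, C, Hin, Win] = mnist_softmax::generate_dummy_data()
[X_val, y_val, C, Hin, Win] = mnist_softmax::generate_dummy_data()

[W, b] = mnist_softmax::train(X, y, X_val, y_val, 1)
probs = mnist_softmax::predict(X_val, W, b)
[loss, accuracy] = mnist_softmax::eval(probs, y_val)
print("Dummy-data accuracy (chance is ~0.1): " + accuracy)
```
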
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/43c321d1/scripts/nn/layers/affine.dml
----------------------------------------------------------------------
diff --git a/scripts/nn/layers/affine.dml b/scripts/nn/layers/affine.dml
new file mode 100644
index 0000000..c9a740b
--- /dev/null
+++ b/scripts/nn/layers/affine.dml
@@ -0,0 +1,92 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+/*
+ * Affine (fully-connected) layer.
+ */
+
+forward = function(matrix[double] X, matrix[double] W, matrix[double] b)
+    return (matrix[double] out) {
+  /*
+   * Computes the forward pass for an affine (fully-connected) layer
+   * with M neurons.  The input data has N examples, each with D
+   * features.
+   *
+   * Inputs:
+   *  - X: Inputs, of shape (N, D).
+   *  - W: Weights, of shape (D, M).
+   *  - b: Biases, of shape (1, M).
+   *
+   * Outputs:
+   *  - out: Outputs, of shape (N, M).
+   */
+  out = X %*% W + b
+}
+
+backward = function(matrix[double] dout, matrix[double] X,
+                    matrix[double] W, matrix[double] b)
+    return (matrix[double] dX, matrix[double] dW, matrix[double] db) {
+  /*
+   * Computes the backward pass for a fully-connected (affine) layer
+   * with M neurons.
+   *
+   * Inputs:
+   *  - dout: Gradient wrt `out` from upstream, of shape (N, M).
+   *  - X: Inputs, of shape (N, D).
+   *  - W: Weights, of shape (D, M).
+   *  - b: Biases, of shape (1, M).
+   *
+   * Outputs:
+   *  - dX: Gradient wrt `X`, of shape (N, D).
+   *  - dW: Gradient wrt `W`, of shape (D, M).
+   *  - db: Gradient wrt `b`, of shape (1, M).
+   */
+  dX = dout %*% t(W)
+  dW = t(X) %*% dout
+  db = colSums(dout)
+}
+
+init = function(int D, int M)
+    return (matrix[double] W, matrix[double] b) {
+  /*
+   * Initialize the parameters of this layer.
+   *
+   * Note: This is just a convenience function, and parameters
+   * may be initialized manually if needed.
+   *
+   * We use the heuristic by He et al., which limits the magnification
+   * of inputs/gradients during forward/backward passes by scaling
+   * unit-Gaussian weights by a factor of sqrt(2/n), under the
+   * assumption of relu neurons.
+   *  - http://arxiv.org/abs/1502.01852
+   *
+   * Inputs:
+   *  - D: Dimensionality of the input features (number of features).
+   *  - M: Number of neurons in this layer.
+   *
+   * Outputs:
+   *  - W: Weights, of shape (D, M).
+   *  - b: Biases, of shape (1, M).
+   */
+  W = rand(rows=D, cols=M, pdf="normal") * sqrt(2.0/D)
+  b = matrix(0, rows=1, cols=M)
+}
+

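For reference, the layer's forward/backward contract in miniature, assuming execution from the `scripts` directory; the toy sizes are arbitrary, and the random `dout` merely stands in for a gradient arriving from a downstream layer:

```
source("nn/layers/affine.dml") as affine

N = 4  # examples
D = 3  # input features
M = 2  # neurons
X = rand(rows=N, cols=D)
[W, b] = affine::init(D, M)

out = affine::forward(X, W, b)                  # shape (N, M)
dout = rand(rows=N, cols=M)                     # stand-in upstream gradient
[dX, dW, db] = affine::backward(dout, X, W, b)  # shapes (N, D), (D, M), (1, M)
```
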
http://git-wip-us.apache.org/repos/asf/incubator-systemml/blob/43c321d1/scripts/nn/layers/batch_norm1d.dml
----------------------------------------------------------------------
diff --git a/scripts/nn/layers/batch_norm1d.dml b/scripts/nn/layers/batch_norm1d.dml
new file mode 100644
index 0000000..2ccffdb
--- /dev/null
+++ b/scripts/nn/layers/batch_norm1d.dml
@@ -0,0 +1,210 @@
+#-------------------------------------------------------------
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+#-------------------------------------------------------------
+
+/*
+ * 1D Batch Normalization layer.
+ */
+
+forward = function(matrix[double] X, matrix[double] gamma, matrix[double] beta,
+                   string mode, matrix[double] ema_mean, matrix[double] ema_var,
+                   double mu, double epsilon)
+    return (matrix[double] out, matrix[double] ema_mean_upd, matrix[double] ema_var_upd,
+            matrix[double] cache_mean, matrix[double] cache_var, matrix[double] cache_norm) {
+  /*
+   * Computes the forward pass for a 1D batch normalization layer.
+   * The input data has N examples, each with D features.
+   *
+   * A batch normalization layer uses the per-feature sample mean and
+   * per-feature uncorrected sample variance during training to
+   * normalize each feature of the input data.  Additionally, it
+   * introduces learnable parameters (gamma, beta) to control the
+   * amount of normalization.
+   *
+   *   `y = ((x-mean) / sqrt(var+eps)) * gamma + beta`
+   *
+   * This implementation maintains exponential moving averages of the
+   * mean and variance during training for use during testing.
+   *
+   * Reference:
+   *  - Batch Normalization: Accelerating Deep Network Training by
+   *    Reducing Internal Covariate Shift, S. Ioffe & C. Szegedy, 2015
+   *    - https://arxiv.org/abs/1502.03167
+   *
+   * Inputs:
+   *  - X: Inputs, of shape (N, D).
+   *  - gamma: Scale parameters, of shape (1, D).
+   *  - beta: Shift parameters, of shape (1, D).
+   *  - mode: 'train' or 'test' to indicate if the model is currently
+   *      being trained or tested.  During training, the current batch
+   *      mean and variance will be used to normalize the inputs, while
+   *      during testing, the exponential average of the mean and
+   *      variance over all previous batches will be used.
+   *  - ema_mean: Exponential moving average of the mean, of
+   *      shape (1, D).
+   *  - ema_var: Exponential moving average of the variance, of
+   *      shape (1, D).
+   *  - mu: Momentum value for moving averages.
+   *      Typical values are in the range of [0.9, 0.999].
+   *  - epsilon: Smoothing term to avoid divide by zero errors.
+   *      Typical values are in the range of [1e-5, 1e-3].
+   *
+   * Outputs:
+   *  - out: Outputs, of shape (N, D).
+   *  - ema_mean_upd: Updated exponential moving average of the mean,
+   *      of shape (1, D).
+   *  - ema_var_upd: Updated exponential moving average of the variance,
+   *      of shape (1, D).
+   *  - cache_mean: Cache of the batch mean, of shape (1, D).
+   *      Note: This is used for performance during training.
+   *  - cache_var: Cache of the batch variance, of shape (1, D).
+   *      Note: This is used for performance during training.
+   *  - cache_norm: Cache of the normalized inputs, of shape (N, D).
+   *      Note: This is used for performance during training.
+   */
+  N = nrow(X)
+
+  if (mode == 'train') {
+    # Compute feature-wise mean and variance
+    mean = colMeans(X)  # shape (1, D)
+    # var = (1/N) * colSums((X-mean)^2)
+    var = colVars(X) * ((N-1)/N)  # compute uncorrected variance, of shape (1, D)
+    # Update moving averages
+    ema_mean_upd = mu*ema_mean + (1-mu)*mean
+    ema_var_upd = mu*ema_var + (1-mu)*var
+  }
+  else {
+    # Use moving averages of mean and variance during testing
+    mean = ema_mean
+    var = ema_var
+    ema_mean_upd = ema_mean
+    ema_var_upd = ema_var
+  }
+
+  # Normalize, shift, and scale
+  # norm = (X-mean)*(var+epsilon)^(-1/2)
+  norm = (X-mean) / sqrt(var+epsilon)  # shape (N, D)
+  out = norm*gamma + beta  # shape (N, D)
+
+  # Save variables for backward pass
+  cache_mean = mean
+  cache_var = var
+  cache_norm = norm
+}
+
+backward = function(matrix[double] dout, matrix[double] out,
+                    matrix[double] ema_mean_upd, matrix[double] ema_var_upd,
+                    matrix[double] cache_mean, matrix[double] cache_var, matrix[double] cache_norm,
+                    matrix[double] X, matrix[double] gamma, matrix[double] beta,
+                    string mode, matrix[double] ema_mean, matrix[double] ema_var,
+                    double mu, double epsilon)
+      return (matrix[double] dX, matrix[double] dgamma, matrix[double] dbeta) {
+  /*
+   * Computes the backward pass for a 1D batch normalization layer.
+   *
+   * Inputs:
+   *  - dout: Gradient wrt `out` from upstream, of shape (N, D).
+   *  - out: Outputs from the forward pass, of shape (N, D).
+   *  - ema_mean_upd: Updated exponential moving average of the mean
+   *      from the forward pass, of shape (1, D).
+   *  - ema_var_upd: Updated exponential moving average of the variance
+   *      from the forward pass, of shape (1, D).
+   *  - cache_mean: Cache of the batch mean from the forward pass, of
+   *      shape (1, D).  Note: This is used for performance during
+   *      training.
+   *  - cache_var: Cache of the batch variance from the forward pass,
+   *      of shape (1, D).  Note: This is used for performance during
+   *      training.
+   *  - cache_norm: Cache of the normalized inputs from the forward
+   *      pass, of shape (N, D).  Note: This is used for performance
+   *      during training.
+   *  - X: Inputs, of shape (N, D).
+   *  - gamma: Scale parameters, of shape (1, D).
+   *  - beta: Shift parameters, of shape (1, D).
+   *  - mode: 'train' or 'test' to indicate if the model is currently
+   *      being trained or tested.  During training, the current batch
+   *      mean and variance will be used to normalize the inputs, while
+   *      during testing, the exponential average of the mean and
+   *      variance over all previous batches will be used.
+   *  - ema_mean: Exponential moving average of the mean, of
+   *      shape (1, D).
+   *  - ema_var: Exponential moving average of the variance, of
+   *      shape (1, D).
+   *  - mu: Momentum value for moving averages.
+   *      Typical values are in the range of [0.9, 0.999].
+   *  - epsilon: Smoothing term to avoid divide by zero errors.
+   *      Typical values are in the range of [1e-5, 1e-3].
+   *
+   * Outputs:
+   *  - dX: Gradient wrt `X`, of shape (N, D).
+   *  - dgamma: Gradient wrt `gamma`, of shape (1, D).
+   *  - dbeta: Gradient wrt `beta`, of shape (1, D).
+   */
+  N = nrow(X)
+  mean = cache_mean
+  var = cache_var
+  norm = cache_norm
+  centered = X-mean
+
+  if (mode == 'train') {
+    # Compute gradients during training
+    dgamma = colSums(dout*norm)  # shape (1, D)
+    dbeta = colSums(dout)  # shape (1, D)
+    dnorm = dout * gamma  # shape (N, D)
+    dvar = (-1/2) * colSums(centered * (var+epsilon)^(-3/2) * dnorm)  # shape (1, D)
+    dmean = colSums((-dnorm/sqrt(var+epsilon)) + ((-2/N)*centered*dvar))  # shape (1, D)
+    dX = (dnorm/sqrt(var+epsilon)) + ((2/N)*centered*dvar) + ((1/N)*dmean)  # shape (N, D)
+  }
+  else {
+    # Compute gradients during testing
+    dgamma = colSums(dout*norm)  # shape (1, D)
+    dbeta = colSums(dout)  # shape (1, D)
+    dnorm = dout * gamma  # shape (N, D)
+    dX = dnorm / sqrt(var+epsilon)  # shape (N, D)
+  }
+}
+
+init = function(int D)
+    return (matrix[double] gamma, matrix[double] beta,
+            matrix[double] ema_mean, matrix[double] ema_var) {
+  /*
+   * Initialize the parameters of this layer.
+   *
+   * Note: This is just a convenience function, and parameters
+   * may be initialized manually if needed.
+   *
+   * Inputs:
+   *  - D: Dimensionality of the input features (number of features).
+   *
+   * Outputs:
+   *  - gamma: Scale parameters, of shape (1, D).
+   *  - beta: Shift parameters, of shape (1, D).
+   *  - ema_mean: Exponential moving average of the mean, of
+   *      shape (1, D).
+   *  - ema_var: Exponential moving average of the variance, of
+   *      shape (1, D).
+   */
+  gamma = matrix(1, rows=1, cols=D)
+  beta = matrix(0, rows=1, cols=D)
+  ema_mean = matrix(0, rows=1, cols=D)
+  ema_var = matrix(1, rows=1, cols=D)
+}
+

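To make the train/test asymmetry concrete, a minimal usage sketch, assuming execution from the `scripts` directory; the `mu=0.9` and `epsilon=1e-5` values are simply picked from the typical ranges documented above:

```
source("nn/layers/batch_norm1d.dml") as batch_norm1d

N = 8
D = 4
X = rand(rows=N, cols=D)
[gamma, beta, ema_mean, ema_var] = batch_norm1d::init(D)

# 'train' mode: normalize with batch statistics and update the moving averages
[out, ema_mean, ema_var, c_mean, c_var, c_norm] = batch_norm1d::forward(X, gamma, beta, "train", ema_mean, ema_var, 0.9, 1e-5)

# 'test' mode: normalize with the moving averages accumulated during training
[out_t, ema_mean, ema_var, c_mean, c_var, c_norm] = batch_norm1d::forward(X, gamma, beta, "test", ema_mean, ema_var, 0.9, 1e-5)
```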