This is an automated email from the ASF dual-hosted git repository.

thomasdelteil pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 4700b40  Gluon end to end tutorial (#13411)
4700b40 is described below

commit 4700b40a9d9f2bcd915d959381203db5a8b84f89
Author: Lai Wei <roywei...@gmail.com>
AuthorDate: Thu Jan 24 11:35:14 2019 -0800

    Gluon end to end tutorial (#13411)
    
    * initial draft gluon tutorial
    
    * add reference
    
    * add cpp inference
    
    * improve wording
    
    * address pr comments
    
    * add util functions on dataset
    
    * move util file
    
    * update link
    
    * fix typo, add test
    
    * allow download
    
    * update wording
    
    * update links
    
    * address comments
    
    * use lr scheduler with optimizer
    
    * separate into 2 tutorials
    
    * add c++ tutorial to test whitelist
---
 .../data/oxford_102_flower_dataset.py              | 219 ++++++++++++++
 docs/tutorials/c++/mxnet_cpp_inference_tutorial.md | 267 ++++++++++++++++
 .../gluon/gluon_from_experiment_to_deployment.md   | 334 +++++++++++++++++++++
 tests/tutorials/test_sanity_tutorials.py           |   1 +
 tests/tutorials/test_tutorials.py                  |   3 +
 5 files changed, 824 insertions(+)

diff --git a/docs/tutorial_utils/data/oxford_102_flower_dataset.py 
b/docs/tutorial_utils/data/oxford_102_flower_dataset.py
new file mode 100644
index 0000000..0dcae22
--- /dev/null
+++ b/docs/tutorial_utils/data/oxford_102_flower_dataset.py
@@ -0,0 +1,219 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+"""
+This script downloads and prepares the Oxford 102 Category Flower Dataset for training.
+Dataset is from: http://www.robots.ox.ac.uk/~vgg/data/flowers/102/index.html
+Script is modified from: https://github.com/Arsey/keras-transfer-learning-for-oxford102
+"""
+
+import glob
+import os
+import tarfile
+from shutil import copyfile
+
+import numpy as np
+from mxnet import gluon
+from scipy.io import loadmat
+
+label_names = [
+    'pink primrose',
+    'hard-leaved pocket orchid',
+    'canterbury bells',
+    'sweet pea',
+    'english marigold',
+    'tiger lily',
+    'moon orchid',
+    'bird of paradise',
+    'monkshood',
+    'globe thistle',
+    'snapdragon',
+    "colt's foot",
+    'king protea',
+    'spear thistle',
+    'yellow iris',
+    'globe-flower',
+    'purple coneflower',
+    'peruvian lily',
+    'balloon flower',
+    'giant white arum lily',
+    'fire lily',
+    'pincushion flower',
+    'fritillary',
+    'red ginger',
+    'grape hyacinth',
+    'corn poppy',
+    'prince of wales feathers',
+    'stemless gentian',
+    'artichoke',
+    'sweet william',
+    'carnation',
+    'garden phlox',
+    'love in the mist',
+    'mexican aster',
+    'alpine sea holly',
+    'ruby-lipped cattleya',
+    'cape flower',
+    'great masterwort',
+    'siam tulip',
+    'lenten rose',
+    'barbeton daisy',
+    'daffodil',
+    'sword lily',
+    'poinsettia',
+    'bolero deep blue',
+    'wallflower',
+    'marigold',
+    'buttercup',
+    'oxeye daisy',
+    'common dandelion',
+    'petunia',
+    'wild pansy',
+    'primula',
+    'sunflower',
+    'pelargonium',
+    'bishop of llandaff',
+    'gaura',
+    'geranium',
+    'orange dahlia',
+    'pink-yellow dahlia?',
+    'cautleya spicata',
+    'japanese anemone',
+    'black-eyed susan',
+    'silverbush',
+    'californian poppy',
+    'osteospermum',
+    'spring crocus',
+    'bearded iris',
+    'windflower',
+    'tree poppy',
+    'gazania',
+    'azalea',
+    'water lily',
+    'rose',
+    'thorn apple',
+    'morning glory',
+    'passion flower',
+    'lotus',
+    'toad lily',
+    'anthurium',
+    'frangipani',
+    'clematis',
+    'hibiscus',
+    'columbine',
+    'desert-rose',
+    'tree mallow',
+    'magnolia',
+    'cyclamen',
+    'watercress',
+    'canna lily',
+    'hippeastrum ',
+    'bee balm',
+    'ball moss',
+    'foxglove',
+    'bougainvillea',
+    'camellia',
+    'mallow',
+    'mexican petunia',
+    'bromelia',
+    'blanket flower',
+    'trumpet creeper',
+    'blackberry lily'
+]
+
+def download_data():
+    data_url = 'http://www.robots.ox.ac.uk/~vgg/data/flowers/102/'
+    image_file_name = '102flowers.tgz'
+    label_file_name = 'imagelabels.mat'
+    setid_file_name = 'setid.mat'
+
+    global data_path, image_path, label_path, setid_path
+    image_path = os.path.join(data_path, image_file_name)
+    label_path = os.path.join(data_path, label_file_name)
+    setid_path = os.path.join(data_path, setid_file_name)
+    # download the dataset into current directory
+    if not os.path.exists(data_path):
+        os.mkdir(data_path)
+    if not os.path.isfile(image_path):
+        gluon.utils.download(url=data_url + image_file_name, path=data_path)
+    if not os.path.exists(os.path.join(data_path, 'jpg')):
+        print("Extracting downloaded dataset...")
+        tarfile.open(image_path).extractall(path=data_path)
+    if not os.path.isfile(label_path):
+        gluon.utils.download(url=data_url + label_file_name, path=data_path)
+    if not os.path.isfile(setid_path):
+        gluon.utils.download(url=data_url + setid_file_name, path=data_path)
+
+
+def prepare_data():
+    # Read .mat file containing training, testing, and validation sets.
+    global data_path, image_path, label_path, setid_path, label_names
+    setid = loadmat(setid_path)
+
+    idx_train = setid['trnid'][0] - 1
+    idx_test = setid['tstid'][0] - 1
+    idx_valid = setid['valid'][0] - 1
+
+    # Read .mat file containing image labels.
+    image_labels = loadmat(label_path)['labels'][0]
+
+    # Subtract one to get 0-based labels
+    image_labels -= 1
+
+    # convert label from number to flower names
+    image_labels = [label_names[i] for i in image_labels]
+    # extracted images are stored in folder 'jpg'
+    files = sorted(glob.glob(os.path.join(data_path, 'jpg', '*.jpg')))
+    file_label_pairs = np.array([i for i in zip(files, image_labels)])
+
+    # move files from extracted folder to train, test, valid
+    move_files('train', file_label_pairs[idx_test, :])
+    move_files('test', file_label_pairs[idx_train, :])
+    move_files('valid', file_label_pairs[idx_valid, :])
+
+
+def move_files(dir_name, file_label_pairs):
+    data_segment_dir = os.path.join(data_path, dir_name)
+    if not os.path.exists(data_segment_dir):
+        os.mkdir(data_segment_dir)
+
+    for label in label_names:
+        class_dir = os.path.join(data_segment_dir, label)
+        if not os.path.exists(class_dir):
+            os.mkdir(class_dir)
+
+    for file, label in file_label_pairs:
+        src = str(file)
+        dst = os.path.join(data_path, dir_name, label, src.split(os.sep)[-1])
+        copyfile(src, dst)
+
+
+def generate_synset():
+    with open('synset.txt', 'w') as f:
+        # Gluon Dataset API will load synset in sorted order
+        for label in sorted(label_names):
+            f.write(label.strip() + '\n')
+        f.close()
+
+
+def get_data(dir_name):
+    global data_path
+    data_path = dir_name
+    download_data()
+    prepare_data()
+    generate_synset()
diff --git a/docs/tutorials/c++/mxnet_cpp_inference_tutorial.md 
b/docs/tutorials/c++/mxnet_cpp_inference_tutorial.md
new file mode 100644
index 0000000..e55e7c9
--- /dev/null
+++ b/docs/tutorials/c++/mxnet_cpp_inference_tutorial.md
@@ -0,0 +1,267 @@
+# MXNet C++ API inference tutorial
+
+## Overview
+MXNet provides various useful tools and interfaces for deploying your model 
for inference. For example, you can use [MXNet Model 
Server](https://github.com/awslabs/mxnet-model-server) to start a service and 
host your trained model easily.
+Besides that, you can also use MXNet's different language APIs to integrate 
your model with your existing service. We provide 
[Python](https://mxnet.incubator.apache.org/api/python/module/module.html),    
[Java](https://mxnet.incubator.apache.org/api/java/index.html), 
[Scala](https://mxnet.incubator.apache.org/api/scala/index.html), and 
[C++](https://mxnet.incubator.apache.org/api/c++/index.html) APIs.
+
+This tutorial is a continuation of the [Gluon end to end tutorial](https://github.com/apache/incubator-mxnet/tree/master/docs/tutorials/gluon/gluon_from_experiment_to_deployment.md). In this part, we focus on the MXNet C++ API. We have slightly modified the code in the [C++ Inference Example](https://github.com/apache/incubator-mxnet/tree/master/cpp-package/example/inference) for our use case.
+
+## Prerequisites
+
+To complete this tutorial, you need:
+- Complete the training part of the [Gluon end to end tutorial](https://github.com/apache/incubator-mxnet/tree/master/docs/tutorials/gluon/gluon_from_experiment_to_deployment.md)
+- Learn the basics about the [MXNet C++ API](https://github.com/apache/incubator-mxnet/tree/master/cpp-package)
+
+
+## Set up the MXNet C++ API
+To use the C++ API in MXNet, you need to build MXNet from source with the C++ package. Please follow the [build from source guide](https://mxnet.incubator.apache.org/install/ubuntu_setup.html) and the [C++ Package documentation](https://github.com/apache/incubator-mxnet/tree/master/cpp-package)
+to enable the C++ API.
+In short, you need to build MXNet from source with the `USE_CPP_PACKAGE` flag set to 1. For example: `make -j USE_CPP_PACKAGE=1`.
+
+## Load the model and run inference
+
+After you complete [the previous tutorial](https://github.com/apache/incubator-mxnet/tree/master/docs/tutorials/gluon/gluon_from_experiment_to_deployment.md), you will get the following output files:
+1. Model Architecture stored in `flower-recognition-symbol.json`
+2. Model parameter values stored in `flower-recognition-0040.params` (`0040` is for 40 epochs we ran)
+3. Label names stored in `synset.txt`
+4. Mean and standard deviation values stored in `mean_std_224.nd` for image normalization.
+
+
+Now we need to write the C++ code to load them and run prediction on a test image.
+The full code is available in the [C++ Inference Example](https://github.com/apache/incubator-mxnet/tree/master/cpp-package/example/inference). We will walk you through it and point out the changes needed for our use case.
+
+
+
+### Write a predictor using the MXNet C++ API
+
+In general, the C++ inference code should follow the 4 steps below. We can do that using a Predictor class.
+1. Load the pre-trained model
+2. Load the parameters of the pre-trained model
+3. Load the image to be classified into an NDArray and apply the image transformations we used in training
+4. Run the forward pass and predict the class of the input image
+
+```cpp
+class Predictor {
+ public:
+    Predictor() {}
+    Predictor(const std::string& model_json_file,
+              const std::string& model_params_file,
+              const Shape& input_shape,
+              bool gpu_context_type = false,
+              const std::string& synset_file = "",
+              const std::string& mean_image_file = "");
+    // Loads, normalizes and classifies one image and logs the predicted class.
+    void PredictImage(const std::string& image_file);
+    ~Predictor();
+
+ private:
+    void LoadModel(const std::string& model_json_file);
+    void LoadParameters(const std::string& model_parameters_file);
+    void LoadSynset(const std::string& synset_file);
+    NDArray LoadInputImage(const std::string& image_file);
+    void LoadMeanImageData();
+    void LoadDefaultMeanImageData();
+    void NormalizeInput(const std::string& mean_image_file);
+    inline bool FileExists(const std::string& name) {
+        struct stat buffer;
+        return (stat(name.c_str(), &buffer) == 0);
+    }
+    NDArray mean_img;
+    std::map<std::string, NDArray> args_map;
+    std::map<std::string, NDArray> aux_map;
+    std::vector<std::string> output_labels;  // class names read from the synset file
+    Symbol net;
+    // Initialized so that a default-constructed Predictor does not hold an
+    // indeterminate pointer (e.g. if the destructor deletes the executor).
+    Executor *executor = nullptr;
+    Shape input_shape;  // indexed as (batch, channels, height, width)
+    NDArray mean_image_data;     // mean values used to normalize the input
+    NDArray std_dev_image_data;  // standard deviations used to normalize the input
+    Context global_ctx = Context::cpu();
+    std::string mean_image_file;
+};
+```
+
+### Load the model, synset file, and normalization values
+
+In the Predictor constructor, you need to provide paths to saved json and 
param files. After that, add the following methods `LoadModel` and 
`LoadParameters` to load the network and its parameters. This part is the same 
as [the 
example](https://github.com/apache/incubator-mxnet/blob/master/cpp-package/example/inference/inception_inference.cpp).
+
+Next, we need to load the synset file and the normalization values. We have made the following changes, since our synset file contains flower names and we use both the mean and the standard deviation for image normalization.
+
+```c++
+/*
+ * Reads the synset file line by line into output_labels; line i holds the
+ * name that is reported when the model predicts class index i.
+ */
+void Predictor::LoadSynset(const std::string& synset_file) {
+  if (!FileExists(synset_file)) {
+    LG << "Synset file " << synset_file << " does not exist";
+    throw std::runtime_error("Synset file does not exist");
+  }
+  LG << "Loading the synset file.";
+  std::ifstream synset_stream(synset_file.c_str());
+  if (!synset_stream.is_open()) {
+    std::cerr << "Error opening synset file " << synset_file << std::endl;
+    throw std::runtime_error("Error in opening the synset file.");
+  }
+  // The stream is closed automatically when it goes out of scope.
+  for (std::string line; std::getline(synset_stream, line);) {
+    output_labels.push_back(line);
+  }
+}
+
+/*
+ * The following function loads the mean and standard deviation values that
+ * the training script saved under the keys "mean_img" and "std_img".
+ * This data will be used for normalizing the image before running the forward
+ * pass. Each array must hold input_shape.Size() values, i.e. the same number
+ * of elements as the input image data.
+ */
+void Predictor::LoadMeanImageData() {
+  LG << "Load the mean image data that will be used to normalize "
+     << "the image before running forward pass.";
+  // Read the file once and take both arrays from the same map.
+  std::map<std::string, NDArray> mean_std_map = NDArray::LoadToMap(mean_image_file);
+  if (mean_std_map.find("mean_img") == mean_std_map.end() ||
+      mean_std_map.find("std_img") == mean_std_map.end()) {
+    LG << "File " << mean_image_file << " must contain 'mean_img' and 'std_img'";
+    throw std::runtime_error("Mean/std file is missing required arrays");
+  }
+  mean_image_data = NDArray(input_shape, global_ctx, false);
+  mean_image_data.SyncCopyFromCPU(mean_std_map["mean_img"].GetData(),
+                                  input_shape.Size());
+  NDArray::WaitAll();
+  std_dev_image_data = NDArray(input_shape, global_ctx, false);
+  std_dev_image_data.SyncCopyFromCPU(mean_std_map["std_img"].GetData(),
+                                     input_shape.Size());
+  NDArray::WaitAll();
+}
+```
+
+
+
+### Load input image
+
+Now let's add a method that loads the image we want to classify and converts it to an NDArray for prediction.
+```cpp
+/*
+ * Reads an image with OpenCV, converts it to RGB, resizes it to the model's
+ * input size and returns it as an NDArray in planar (channel, row, column)
+ * order, matching input_shape. Pixel values are left in [0, 255].
+ */
+NDArray Predictor::LoadInputImage(const std::string& image_file) {
+  if (!FileExists(image_file)) {
+    LG << "Image file " << image_file << " does not exist";
+    throw std::runtime_error("Image file does not exist");
+  }
+  LG << "Loading the image " << image_file << std::endl;
+  cv::Mat mat = cv::imread(image_file);
+  if (mat.empty()) {
+    LG << "Unable to decode image " << image_file;
+    throw std::runtime_error("Unable to decode image file");
+  }
+  /* OpenCV decodes images in BGR order, but the model was trained on RGB images */
+  cv::cvtColor(mat, mat, cv::COLOR_BGR2RGB);
+  /* resize pictures to (224, 224) according to the pretrained model */
+  int height = input_shape[2];
+  int width = input_shape[3];
+  int channels = input_shape[1];
+  // cv::Size takes (width, height)
+  cv::resize(mat, mat, cv::Size(width, height));
+  // Convert OpenCV's interleaved pixel layout into planar (channel-first) order.
+  std::vector<float> array;
+  array.reserve(static_cast<size_t>(channels) * height * width);
+  for (int c = 0; c < channels; ++c) {
+    for (int i = 0; i < height; ++i) {
+      for (int j = 0; j < width; ++j) {
+        // each row holds `width` pixels of mat.channels() bytes each
+        array.push_back(static_cast<float>(mat.data[(i * width + j) * mat.channels() + c]));
+      }
+    }
+  }
+  NDArray image_data = NDArray(input_shape, global_ctx, false);
+  image_data.SyncCopyFromCPU(array.data(), input_shape.Size());
+  NDArray::WaitAll();
+  return image_data;
+}
+```
+
+### Predict the image
+
+Finally, let's run the inference. It's basically using MXNet executor to do a 
forward pass. To run predictions on multiple images, you can load the images in 
a list of NDArrays and run prediction in batches. Note that the Predictor class 
may not be thread safe. Calling it in multi-threaded environments was not 
tested. To utilize multi-threaded prediction, you need to use the C predict 
API. Please follow the [C predict 
example](https://github.com/apache/incubator-mxnet/tree/master/example [...]
+
+An additional step is to scale the image NDArray values to `[0, 1]` and then normalize them with the mean and standard deviation we just loaded.
+
+```cpp
+/*
+ * The following function runs the forward pass on the model.
+ * The executor is created in the constructor.
+ * It logs the predicted class name (or the class index when no synset is
+ * loaded) together with the highest output score.
+ */
+void Predictor::PredictImage(const std::string& image_file) {
+  // Load the input image
+  NDArray image_data = LoadInputImage(image_file);
+
+  // Normalize the image: scale pixel values to [0, 1], then subtract the mean
+  // and divide by the standard deviation, as was done during training.
+  // Slice(0, 1) selects the first (and, with a batch size of 1, only) image.
+  image_data.Slice(0, 1) /= 255.0;
+  image_data -= mean_image_data;
+  image_data /= std_dev_image_data;
+
+  LG << "Running the forward pass on model to predict the image";
+  /*
+   * The executor->arg_arrays represent the arguments to the model.
+   *
+   * Copying the image_data that contains the NDArray of input image
+   * to the arg map of the executor. The input is stored with the key "data" in the map.
+   */
+  image_data.CopyTo(&(executor->arg_dict()["data"]));
+  NDArray::WaitAll();
+
+  // Run the forward pass.
+  executor->Forward(false);
+
+  // The output is available in executor->outputs.
+  auto array = executor->outputs[0].Copy(global_ctx);
+  NDArray::WaitAll();
+
+  /*
+   * Find the index of the highest output score using the argmax operator.
+   * The exported network ends with a Dense layer (no softmax), so this score
+   * is a raw, unnormalized output value rather than a probability or accuracy.
+   */
+  auto predicted = array.ArgmaxChannel();
+  NDArray::WaitAll();
+
+  int best_idx = predicted.At(0, 0);
+  float best_score = array.At(0, best_idx);
+
+  // Fall back to reporting the index when there is no name for this class.
+  bool has_label = best_idx >= 0 &&
+                   best_idx < static_cast<int>(output_labels.size());
+  if (!has_label) {
+    LG << "The model predicts the highest accuracy of " << best_score << " at index "
+       << best_idx;
+  } else {
+    LG << "The model predicts the input image to be a [" << output_labels[best_idx]
+       << " ] with Accuracy = " << best_score << std::endl;
+  }
+}
+```
+
+### Compile and run the inference code
+
+You can find the [full code for the inference example](https://github.com/apache/incubator-mxnet/tree/master/cpp-package/example/inference) in the `cpp-package` folder of the project. To compile it, use this [Makefile](https://github.com/apache/incubator-mxnet/blob/master/cpp-package/example/inference/Makefile).
+
+Make a copy of the example code, rename it to `flower_inference` and apply the 
changes we mentioned above. Now you will be able to compile and run inference. 
Run `make all`. Once this is complete, run inference with the following 
parameters. Remember to set your `LD_LIBRARY_PATH` to point to MXNet library if 
you have not done so.
+
+```bash
+make all
+export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:path/to/incubator-mxnet/lib
+./flower_inference --symbol flower-recognition-symbol.json --params flower-recognition-0040.params --synset synset.txt --mean mean_std_224.nd --image ./data/test/lotus/image_01832.jpg
+```
+
+Then it will predict your image:
+
+```bash
+[17:38:51] resnet.cpp:150: Loading the model from 
flower-recognition-symbol.json
+
+[17:38:51] resnet.cpp:163: Loading the model parameters from 
flower-recognition-0040.params
+
+[17:38:52] resnet.cpp:190: Loading the synset file.
+[17:38:52] resnet.cpp:211: Load the mean image data that will be used to 
normalize the image before running forward pass.
+[17:38:52] resnet.cpp:263: Loading the image ./data/test/lotus/image_01832.jpg
+
+[17:38:52] resnet.cpp:299: Running the forward pass on model to predict the 
image
+[17:38:52] resnet.cpp:331: The model predicts the input image to be a [lotus ] 
with Accuracy = 8.63046
+```
+
+
+
+## What's next
+
+Now you can explore more ways to run inference and deploy your models:
+1. [Java Inference 
examples](https://github.com/apache/incubator-mxnet/tree/master/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer)
+2. [Scala Inference 
examples](https://mxnet.incubator.apache.org/tutorials/scala/)
+3. [ONNX model inference 
examples](https://mxnet.incubator.apache.org/tutorials/onnx/inference_on_onnx_model.html)
+4. [MXNet Model Server 
Examples](https://github.com/awslabs/mxnet-model-server/tree/master/examples)
+
+## References
+
+1. [Gluon end to end tutorial](https://github.com/apache/incubator-mxnet/tree/master/docs/tutorials/gluon/gluon_from_experiment_to_deployment.md)
+2. [MXNet C++ inference example](https://github.com/apache/incubator-mxnet/blob/master/cpp-package/example/inference/)
+3. [MXNet C++ package](https://github.com/apache/incubator-mxnet/tree/master/cpp-package)
\ No newline at end of file
diff --git a/docs/tutorials/gluon/gluon_from_experiment_to_deployment.md 
b/docs/tutorials/gluon/gluon_from_experiment_to_deployment.md
new file mode 100644
index 0000000..87e6f24
--- /dev/null
+++ b/docs/tutorials/gluon/gluon_from_experiment_to_deployment.md
@@ -0,0 +1,334 @@
+
+# Gluon: from experiment to deployment, an end to end tutorial
+
+## Overview
+MXNet Gluon API comes with a lot of great features, and it can provide you everything you need: from experimentation to deploying the model. In this tutorial, we will walk you through a common use case on how to build a model using Gluon, train it on your data, and deploy it for inference.
+This tutorial covers training and inference in Python. Please continue to the [C++ inference part](https://github.com/apache/incubator-mxnet/tree/master/docs/tutorials/c++/mxnet_cpp_inference_tutorial.md) after you finish.
+
+Let's say you need to build a service that provides flower species 
recognition. A common problem is that you don't have enough data to train a 
good model. In such cases, a technique called Transfer Learning can be used to 
make a more robust model.
+In Transfer Learning we make use of a pre-trained model that solves a related 
task, and was trained on a very large standard dataset, such as ImageNet. 
ImageNet is from a different domain, but we can utilize the knowledge in this 
pre-trained model to perform the new task at hand.
+
+Gluon provides State of the Art models for many of the standard tasks such as 
Classification, Object Detection, Segmentation, etc. In this tutorial we will 
use the pre-trained model [ResNet50 V2](https://arxiv.org/abs/1603.05027) 
trained on ImageNet dataset. This model achieves 77.11% top-1 accuracy on 
ImageNet. We seek to transfer as much knowledge as possible for our task of 
recognizing different species of flowers.
+
+
+
+
+## Prerequisites
+
+To complete this tutorial, you need:
+
+- [Build MXNet from 
source](https://mxnet.incubator.apache.org/install/ubuntu_setup.html#build-mxnet-from-source)
 with Python(Gluon) and C++ Packages
+- Learn the basics about Gluon with [A 60-minute Gluon Crash 
Course](https://gluon-crash-course.mxnet.io/)
+
+
+## The Data
+
+We will use the [Oxford 102 Category Flower 
Dataset](http://www.robots.ox.ac.uk/~vgg/data/flowers/102/) as an example to 
show you the steps.
+We have prepared a utility file to help you download and organize your data 
into train, test, and validation sets. Run the following Python code to 
download and prepare the data:
+
+
+```python
+import mxnet as mx
+data_util_file = "oxford_102_flower_dataset.py"
+base_url = 
"https://raw.githubusercontent.com/roywei/incubator-mxnet/gluon_tutorial/docs/tutorial_utils/data/{}?raw=true";
+mx.test_utils.download(base_url.format(data_util_file), fname=data_util_file)
+import oxford_102_flower_dataset
+
+# download and move data to train, test, valid folders
+path = './data'
+oxford_102_flower_dataset.get_data(path)
+```
+
+Now your data will be organized into the following format. All images that belong to the same category are placed in the same folder, following this pattern:
+```bash
+data
+|--train
+|   |-- class0
+|   |   |-- image_06736.jpg
+|   |   |-- image_06741.jpg
+...
+|   |-- class1
+|   |   |-- image_06755.jpg
+|   |   |-- image_06899.jpg
+...
+|-- test
+|   |-- class0
+|   |   |-- image_00731.jpg
+|   |   |-- image_0002.jpg
+...
+|   |-- class1
+|   |   |-- image_00036.jpg
+|   |   |-- image_05011.jpg
+
+```
+
+## Training using Gluon
+
+### Define Hyper-parameters
+
+Now let's first import necessary packages:
+
+
+```python
+import math
+import os
+import time
+from multiprocessing import cpu_count
+
+from mxnet import autograd
+from mxnet import gluon, init
+from mxnet.gluon import nn
+from mxnet.gluon.data.vision import transforms
+from mxnet.gluon.model_zoo.vision import resnet50_v2
+```
+
+Next, we define the hyper-parameters that we will use for fine-tuning. We will 
use the [MXNet learning rate 
scheduler](https://mxnet.incubator.apache.org/tutorials/gluon/learning_rate_schedules.html)
 to adjust learning rates during training.
+
+
+```python
+classes = 102
+epochs = 40
+lr = 0.001
+per_device_batch_size = 32
+momentum = 0.9
+wd = 0.0001
+
+lr_factor = 0.75
+# learning rate change at following epochs
+lr_epochs = [10, 20, 30]
+
+num_gpus = mx.context.num_gpus()
+num_workers = cpu_count()
+ctx = [mx.gpu(i) for i in range(num_gpus)] if num_gpus > 0 else [mx.cpu()]
+batch_size = per_device_batch_size * max(num_gpus, 1)
+```
+
+Now we will apply data augmentations on training images. This makes minor 
alterations on the training images, and our model will consider them as 
distinct images. This can be very useful for fine-tuning on a relatively small 
dataset, and it will help improve the model. We can use the Gluon [DataSet 
API](https://mxnet.incubator.apache.org/tutorials/gluon/datasets.html), 
[DataLoader 
API](https://mxnet.incubator.apache.org/tutorials/gluon/datasets.html), and 
[Transform API](https://mxnet.in [...]
+1. Randomly crop the image and resize it to 224x224
+2. Randomly flip the image horizontally
+3. Randomly jitter color and add noise
+4. Transpose the data from `[height, width, num_channels]` to `[num_channels, 
height, width]`, and map values from [0, 255] to [0, 1]
+5. Normalize with the mean and standard deviation from the ImageNet dataset.
+
+For validation and inference, we replace the random steps 1-3 with a deterministic resize to 256 pixels followed by a 224x224 center crop, and then apply steps 4 and 5. We also need to save the mean and standard deviation values for [inference using C++](https://github.com/apache/incubator-mxnet/tree/master/docs/tutorials/c++/mxnet_cpp_inference_tutorial.md).
+
+```python
+jitter_param = 0.4
+lighting_param = 0.1
+
+# mean and std for normalizing image value in range (0,1)
+mean = [0.485, 0.456, 0.406]
+std = [0.229, 0.224, 0.225]
+
+training_transformer = transforms.Compose([
+    transforms.RandomResizedCrop(224),
+    transforms.RandomFlipLeftRight(),
+    transforms.RandomColorJitter(brightness=jitter_param, 
contrast=jitter_param,
+                                 saturation=jitter_param),
+    transforms.RandomLighting(lighting_param),
+    transforms.ToTensor(),
+    transforms.Normalize(mean, std)
+])
+
+validation_transformer = transforms.Compose([
+    transforms.Resize(256),
+    transforms.CenterCrop(224),
+    transforms.ToTensor(),
+    transforms.Normalize(mean, std)
+])
+
+# save mean and std NDArray values for inference
+mean_img = mx.nd.stack(*[mx.nd.full((224, 224), m) for m in mean])
+std_img = mx.nd.stack(*[mx.nd.full((224, 224), s) for s in std])
+mx.nd.save('mean_std_224.nd', {"mean_img": mean_img, "std_img": std_img})
+
+train_path = os.path.join(path, 'train')
+val_path = os.path.join(path, 'valid')
+test_path = os.path.join(path, 'test')
+
+# loading the data and apply pre-processing(transforms) on images
+train_data = gluon.data.DataLoader(
+    
gluon.data.vision.ImageFolderDataset(train_path).transform_first(training_transformer),
+    batch_size=batch_size, shuffle=True, num_workers=num_workers)
+
+val_data = gluon.data.DataLoader(
+    
gluon.data.vision.ImageFolderDataset(val_path).transform_first(validation_transformer),
+    batch_size=batch_size, shuffle=False, num_workers=num_workers)
+
+test_data = gluon.data.DataLoader(
+    
gluon.data.vision.ImageFolderDataset(test_path).transform_first(validation_transformer),
+    batch_size=batch_size, shuffle=False, num_workers=num_workers)
+```
+
+### Loading pre-trained model
+
+
+We will use the ResNet50_v2 model pre-trained on the [ImageNet Dataset](http://www.image-net.org/) with 1000 classes. To match the classes in the Flower dataset, we must replace the last (output) layer with a new layer that has 102 outputs, and then initialize its parameters.
+
+Before we go to training, one unique Gluon feature you should be aware of is 
hybridization. It allows you to convert your imperative code to a static 
symbolic graph, which is much more efficient to execute. There are two main 
benefits of hybridizing your model: better performance and easier serialization 
for deployment. The best part is that it's as simple as just calling 
`net.hybridize()`. To know more about Gluon hybridization, please follow the 
[hybridization tutorial](https://mxnet.i [...]
+
+
+
+```python
+# load pre-trained resnet50_v2 from model zoo
+finetune_net = resnet50_v2(pretrained=True, ctx=ctx)
+
+# change last softmax layer since number of classes are different
+with finetune_net.name_scope():
+    finetune_net.output = nn.Dense(classes)
+finetune_net.output.initialize(init.Xavier(), ctx=ctx)
+# hybridize for better performance
+finetune_net.hybridize()
+
+num_batch = len(train_data)
+
+# setup learning rate scheduler
+iterations_per_epoch = math.ceil(num_batch)
+# learning rate change at following steps
+lr_steps = [epoch * iterations_per_epoch for epoch in lr_epochs]
+schedule = mx.lr_scheduler.MultiFactorScheduler(step=lr_steps, 
factor=lr_factor, base_lr=lr)
+
+# setup optimizer with learning rate scheduler, metric, and loss function
+sgd_optimizer = mx.optimizer.SGD(learning_rate=lr, lr_scheduler=schedule, 
momentum=momentum, wd=wd)
+metric = mx.metric.Accuracy()
+softmax_cross_entropy = gluon.loss.SoftmaxCrossEntropyLoss()
+```
+
+### Fine-tuning model on your custom dataset
+
+Now let's define the test metrics and start fine-tuning.
+
+
+
+```python
+def test(net, val_data, ctx):
+    metric = mx.metric.Accuracy()
+    for i, (data, label) in enumerate(val_data):
+        data = gluon.utils.split_and_load(data, ctx_list=ctx, even_split=False)
+        label = gluon.utils.split_and_load(label, ctx_list=ctx, 
even_split=False)
+        outputs = [net(x) for x in data]
+        metric.update(label, outputs)
+    return metric.get()
+
+# the Trainer applies sgd_optimizer (and its learning-rate schedule) to every
+# parameter of the network being fine-tuned
+trainer = gluon.Trainer(finetune_net.collect_params(), optimizer=sgd_optimizer)
+
+# start with epoch 1 for easier learning rate calculation
+for epoch in range(1, epochs + 1):
+
+    tic = time.time()
+    train_loss = 0
+    # `metric` is the Accuracy metric created alongside the optimizer; it
+    # accumulates over one epoch, so reset it at the start of each epoch
+    metric.reset()
+
+    for i, (data, label) in enumerate(train_data):
+        # get the images and labels
+        # split each batch across the devices in ctx; even_split=False allows
+        # batches that do not divide evenly across the devices
+        data = gluon.utils.split_and_load(data, ctx_list=ctx, even_split=False)
+        label = gluon.utils.split_and_load(label, ctx_list=ctx, 
even_split=False)
+        # record the forward pass on every device so backward() can compute gradients
+        with autograd.record():
+            outputs = [finetune_net(x) for x in data]
+            loss = [softmax_cross_entropy(yhat, y) for yhat, y in zip(outputs, 
label)]
+        for l in loss:
+            l.backward()
+
+        # NOTE(review): step(batch_size) normalizes gradients by 1/batch_size;
+        # this assumes each batch from train_data holds batch_size samples
+        trainer.step(batch_size)
+        # average of the per-device mean losses for this batch
+        train_loss += sum([l.mean().asscalar() for l in loss]) / len(loss)
+        metric.update(label, outputs)
+
+    _, train_acc = metric.get()
+    # presumably num_batch is the number of batches per epoch (defined earlier) — confirm
+    train_loss /= num_batch
+    _, val_acc = test(finetune_net, val_data, ctx)
+
+    print('[Epoch %d] Train-acc: %.3f, loss: %.3f | Val-acc: %.3f | 
learning-rate: %.3E | time: %.1f' %
+          (epoch, train_acc, train_loss, val_acc, trainer.learning_rate, 
time.time() - tic))
+
+# final evaluation on the held-out test split
+_, test_acc = test(finetune_net, test_data, ctx)
+print('[Finished] Test-acc: %.3f' % (test_acc))
+```
+
+Following is the training result:
+```bash
+[Epoch 40] Train-acc: 0.945, loss: 0.354 | Val-acc: 0.955 | learning-rate: 
4.219E-04 | time: 17.8
+[Finished] Test-acc: 0.952
+```
+In the previous example output, we trained the model using an [AWS p3.8xlarge instance](https://aws.amazon.com/ec2/instance-types/p3/) with 4 Tesla V100 GPUs. We were able to reach a test accuracy of 95.2% (validation accuracy 95.5%) with 40 epochs in around 12 minutes. This was really fast because our model was pre-trained on a much larger dataset, ImageNet, with around 1.3 million images, and the features it learned there transfer well to our small dataset.
+
+
+### Save the fine-tuned model
+
+
+We have now trained our custom model. It can be serialized into model files using the `export` function, which writes the model architecture to a `.json` file and the model parameters to a `.params` file.
+
+
+
+```python
+finetune_net.export("flower-recognition", epoch=epochs)
+
+```
+
+`export` creates `flower-recognition-symbol.json` and `flower-recognition-0040.params` (`0040` is the epoch number passed to `export`, i.e. the 40 epochs we ran) in the current directory. These files can be used for model deployment in the next section.
+
+## Load the model and run inference using the MXNet Module API
+
+MXNet provides various useful tools and interfaces for deploying your model 
for inference. For example, you can use [MXNet Model 
Server](https://github.com/awslabs/mxnet-model-server) to start a service and 
host your trained model easily.
+Besides that, you can also use MXNet's different language APIs to integrate 
your model with your existing service. We provide 
[Python](https://mxnet.incubator.apache.org/api/python/module/module.html),    
[Java](https://mxnet.incubator.apache.org/api/java/index.html), 
[Scala](https://mxnet.incubator.apache.org/api/scala/index.html), and 
[C++](https://mxnet.incubator.apache.org/api/c++/index.html) APIs.
+
+Here we briefly show how to run inference using the Module API in Python. A more detailed explanation is available in the [Predict Image Tutorial](https://mxnet.incubator.apache.org/tutorials/python/predict_image.html).
+In general, prediction consists of the following steps:
+1. Load the model architecture (symbol file) and trained parameter values 
(params file)
+2. Load the synset file for label names
+3. Load the image and apply the same transformation we applied to the validation dataset during training
+4. Run a forward pass on the image data
+5. Convert output probabilities to predicted label name
+
+```python
+import numpy as np
+from collections import namedtuple
+
+ctx = mx.cpu()
+# load model symbol and params
+sym, arg_params, aux_params = mx.model.load_checkpoint('flower-recognition', 
epochs)
+mod = mx.mod.Module(symbol=sym, context=ctx, label_names=None)
+mod.bind(for_training=False, data_shapes=[('data', (1, 3, 224, 224))], 
label_shapes=mod._label_shapes)
+mod.set_params(arg_params, aux_params, allow_missing=True)
+
+# load synset for label names
+with open('synset.txt', 'r') as f:
+    labels = [l.rstrip() for l in f]
+
+# load an image for prediction
+img = mx.image.imread('./data/test/lotus/image_01832.jpg')
+# apply transform we did during training
+img = validation_transformer(img)
+# batchify
+img = img.expand_dims(axis=0)
+Batch = namedtuple('Batch', ['data'])
+mod.forward(Batch([img]))
+prob = mod.get_outputs()[0].asnumpy()
+prob = np.squeeze(prob)
+idx = np.argmax(prob)
+print('probability=%f, class=%s' % (prob[idx], labels[idx]))
+```
+
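+Following is the output. The image has been correctly classified as lotus. Note that the printed `probability` is the network's raw score (logit) for the predicted class rather than a normalized probability, which is why it is larger than 1: softmax is only applied inside `SoftmaxCrossEntropyLoss` during training. Apply `mx.nd.softmax` to the output if you need probabilities between 0 and 1.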
+```bash
+probability=9.798435, class=lotus
+```
+
+## What's next
+
+You can continue to the [next 
tutorial](https://github.com/apache/incubator-mxnet/tree/master/docs/tutorials/c++/mxnet_cpp_inference_tutorial.md)
 on how to load the model we just trained and run inference using MXNet C++ API.
+
+You can also find more ways to run inference and deploy your models here:
+1. [Java Inference 
examples](https://github.com/apache/incubator-mxnet/tree/master/scala-package/examples/src/main/java/org/apache/mxnetexamples/javaapi/infer)
+2. [Scala Inference 
examples](https://mxnet.incubator.apache.org/tutorials/scala/)
+3. [ONNX model inference 
examples](https://mxnet.incubator.apache.org/tutorials/onnx/inference_on_onnx_model.html)
+4. [MXNet Model Server 
Examples](https://github.com/awslabs/mxnet-model-server/tree/master/examples)
+
+## References
+
+1. [Transfer Learning for Oxford102 Flower 
Dataset](https://github.com/Arsey/keras-transfer-learning-for-oxford102)
+2. [Gluon book on 
fine-tuning](https://gluon.mxnet.io/chapter08_computer-vision/fine-tuning.html)
+3. [Gluon CV transfer learning 
tutorial](https://gluon-cv.mxnet.io/build/examples_classification/transfer_learning_minc.html)
+4. [Gluon crash course](https://gluon-crash-course.mxnet.io/)
+5. [Gluon CPP inference 
example](https://github.com/apache/incubator-mxnet/blob/master/cpp-package/example/inference/)
+
+<!-- INSERT SOURCE DOWNLOAD BUTTONS -->
\ No newline at end of file
diff --git a/tests/tutorials/test_sanity_tutorials.py 
b/tests/tutorials/test_sanity_tutorials.py
index 644a611..429527d 100644
--- a/tests/tutorials/test_sanity_tutorials.py
+++ b/tests/tutorials/test_sanity_tutorials.py
@@ -28,6 +28,7 @@ whitelist = ['basic/index.md',
              'c++/basics.md',
              'c++/index.md',
              'c++/subgraphAPI.md',
+             'c++/mxnet_cpp_inference_tutorial.md',
              'control_flow/index.md',
              'embedded/index.md',
              'embedded/wine_detector.md',
diff --git a/tests/tutorials/test_tutorials.py 
b/tests/tutorials/test_tutorials.py
index 8d8ef39..37ba991 100644
--- a/tests/tutorials/test_tutorials.py
+++ b/tests/tutorials/test_tutorials.py
@@ -151,6 +151,9 @@ def test_python_logistic_regression() :
 def test_python_numpy_gotchas() :
     assert _test_tutorial_nb('gluon/gotchas_numpy_in_mxnet')
 
+def test_gluon_end_to_end():  # checks the Gluon end-to-end tutorial (gluon/gluon_from_experiment_to_deployment) via _test_tutorial_nb
+    assert _test_tutorial_nb('gluon/gluon_from_experiment_to_deployment')
+
 def test_python_mnist():
     assert _test_tutorial_nb('python/mnist')
 

Reply via email to