rahul003 commented on a change in pull request #10283: [MXNET-242][Tutorial] 
Fine-tuning ONNX model in Gluon
URL: https://github.com/apache/incubator-mxnet/pull/10283#discussion_r178376584
 
 

 ##########
 File path: docs/tutorials/onnx/fine_tuning_gluon.md
 ##########
 @@ -0,0 +1,441 @@
+
+# Fine-tuning an ONNX model with MXNet/Gluon
+
+Fine-tuning is a common practice in transfer learning: one takes advantage of the pre-trained weights of a network and uses them as an initializer for one's own task. Indeed, it is often difficult to gather a dataset large enough to allow training deep and complex networks such as ResNet152 or VGG16 from scratch. For example, in an image classification task, using a network trained on a large dataset like ImageNet provides a good base from which the weights can be slightly updated, or fine-tuned, to accurately predict the new classes. We will see in this tutorial that this can be achieved even with a relatively small number of new training examples.
+
+
+[Open Neural Network Exchange (ONNX)](https://github.com/onnx/onnx) provides 
an open source format for AI models. It defines an extensible computation graph 
model, as well as definitions of built-in operators and standard data types.
+
+In this tutorial we will:
+    
+- learn how to pick a specific layer from a pre-trained .onnx model file
+- learn how to load this model in Gluon and fine-tune it on a different dataset
+
+## Prerequisites
+
+To run this tutorial you will need the following Python modules installed:
+- [MXNet](http://mxnet.incubator.apache.org/install/index.html)
+- [onnx](https://github.com/onnx/onnx)
+- matplotlib
+- wget
+
+We also recommend that you first complete this tutorial:
+- [Inference using an ONNX model on MXNet 
Gluon](https://mxnet.incubator.apache.org/tutorials/onnx/inference_on_onnx_model.html)
+
+
+```python
+import numpy as np
+import mxnet as mx
+from mxnet import gluon, nd, autograd
+from mxnet.gluon.data.vision.datasets import ImageFolderDataset
+from mxnet.gluon.data import DataLoader
+import mxnet.contrib.onnx as onnx_mxnet
+%matplotlib inline
+import matplotlib.pyplot as plt
+import tarfile, os
+import wget
+import json
+import multiprocessing
+```
+
+
+### Downloading supporting files
+These are images and a visualization script.
+
+
+```python
+image_folder = "images"
+utils_file = "utils.py" # contains utility functions to plot nice visualizations
+images = ['wrench', 'dolphin', 'lotus']
+base_url = "https://raw.githubusercontent.com/dmlc/web-data/master/mxnet/doc/tutorials/onnx/{}?raw=true"
+
+if not os.path.isdir(image_folder):
+    os.makedirs(image_folder)
+    for image in images:
+        wget.download(base_url.format("{}/{}.jpg".format(image_folder, 
image)), image_folder)
+if not os.path.isfile(utils_file):
+    wget.download(base_url.format(utils_file))
+```
+
+
+```python
+from utils import *
+```
+
+## Downloading a model from the ONNX model zoo
+
+We download a pre-trained model, in our case the [vgg16](https://arxiv.org/abs/1409.1556) model, trained on [ImageNet](http://www.image-net.org/), from the [ONNX model zoo](https://github.com/onnx/models). The model comes packaged as a `tar.gz` archive containing a `model.onnx` model file and some sample input/output data.
+
+
+```python
+base_url = "https://s3.amazonaws.com/download.onnx/models/"
+current_model = "vgg16"
+model_folder = "model"
+archive_file = "{}.tar.gz".format(current_model)
+archive_path = os.path.join(model_folder, archive_file)
+url = "{}{}".format(base_url, archive_file)
+onnx_path = os.path.join(model_folder, current_model, 'model.onnx')
+
+# Create the model folder and download the zipped model
+if not os.path.isdir(model_folder):
+    os.makedirs(model_folder)
+if not os.path.isfile(archive_path):
+    print('Downloading the {} model to {}...'.format(current_model, 
archive_path))
+    wget.download(url, model_folder)
+    print('{} downloaded'.format(current_model))
+
+# Extract the model
+if not os.path.isdir(os.path.join(model_folder, current_model)):
+    print('Extracting {} in {}...'.format(archive_path, model_folder))
+    tar = tarfile.open(archive_path, "r:gz")
+    tar.extractall(model_folder)
+    tar.close()
+    print('Model extracted.')
+```
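As a quick sanity check, you can list the archive's members to confirm that `model.onnx` and the sample data are present before relying on the extracted files. This is a minimal sketch; the `list_archive` helper is a name introduced here for illustration, not part of the tutorial's required code:

```python
import tarfile

def list_archive(path):
    """Return the member names of a gzipped tar archive."""
    with tarfile.open(path, "r:gz") as tar:
        return [member.name for member in tar.getmembers()]

# Example (assumes the archive was downloaded to `archive_path` as above):
# print(list_archive(archive_path))
```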
+
+## Downloading the Caltech101 dataset
+
+The [Caltech101 dataset](http://www.vision.caltech.edu/Image_Datasets/Caltech101/) consists of pictures of objects belonging to 101 categories, with roughly 40 to 800 images per category; most categories have about 50 images.
+
+*L. Fei-Fei, R. Fergus and P. Perona. Learning generative visual models from 
few training examples: an incremental Bayesian approach tested on 101 object 
categories. IEEE. CVPR 2004, Workshop on Generative-Model
+Based Vision. 2004*
+
+
+```python
+data_folder = "data"
+dataset_name = "101_ObjectCategories"
+archive_file = "{}.tar.gz".format(dataset_name)
+archive_path = os.path.join(data_folder, archive_file)
+data_url = "https://s3.us-east-2.amazonaws.com/mxnet-public/"
+if not os.path.isdir(data_folder):
+    os.makedirs(data_folder)
+if not os.path.isfile(archive_path):
+    print('Downloading {} in {}...'.format(archive_file, data_folder))
+    wget.download("{}{}".format(data_url, archive_file), data_folder)
+    print('Extracting {} in {}...'.format(archive_file, data_folder))
+    tar = tarfile.open(archive_path, "r:gz")
+    tar.extractall(data_folder)
+    tar.close()
+    print('Data extracted.')
+```
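To verify the figures above on the number of images per category, you can count the files in each extracted category folder. This is a hedged sketch; `images_per_category` is a helper name introduced here for illustration:

```python
import os

def images_per_category(root):
    """Count image files in each immediate subfolder of `root`."""
    counts = {}
    for category in sorted(os.listdir(root)):
        category_path = os.path.join(root, category)
        if os.path.isdir(category_path):
            counts[category] = sum(
                1 for f in os.listdir(category_path)
                if f.lower().endswith(('.jpg', '.jpeg', '.png'))
            )
    return counts

# Example (assumes the dataset was extracted as above):
# counts = images_per_category(os.path.join(data_folder, dataset_name))
# print(min(counts.values()), max(counts.values()))
```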
 
 Review comment:
  Also, it might be better to rename the folder in the S3 bucket to `Caltech101` rather than just `101_ObjectCategories`?

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services
