reminisce commented on a change in pull request #11251: [WIP] Graph partitioner and subgraph op
URL: https://github.com/apache/incubator-mxnet/pull/11251#discussion_r195511375
 
 

 ##########
 File path: example/subgraph_op/imagenet_inference.py
 ##########
 @@ -0,0 +1,166 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import argparse
+import logging
+import os
+import time
+import mxnet as mx
+from common import modelzoo
+from mxnet import nd
+from mxnet.contrib.quantization import *
+from mxnet.base import _LIB, SymbolHandle, check_call, mx_uint, c_str_array
+
+
+def download_dataset(dataset_url, dataset_dir, logger=None):
+    if logger is not None:
+        logger.info('Downloading dataset for inference from %s to %s' % (dataset_url, dataset_dir))
+    mx.test_utils.download(dataset_url, dataset_dir)
+
+
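+# Fetch a pre-trained model from the model zoo into a 'model' directory next to this script.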
+def download_model(model_name, logger=None):
+    dir_path = os.path.dirname(os.path.realpath(__file__))
+    model_path = os.path.join(dir_path, 'model')
+    if logger is not None:
+        logger.info('Downloading model %s... into path %s' % (model_name, model_path))
+    return modelzoo.download_model(model_name, model_path)
+
+
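+# Skip the first n batches of the data iterator (supports the --num-skipped-batches option).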
+def advance_data_iter(data_iter, n):
+    assert n >= 0
+    if n == 0:
+        return data_iter
+    has_next_batch = True
+    while has_next_batch:
+        try:
+            data_iter.next()
+            n -= 1
+            if n == 0:
+                return data_iter
+        except StopIteration:
+            has_next_batch = False
+    # The iterator ran out before n batches were skipped; return it anyway so callers can still iterate.
+    return data_iter
+
+
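+# Bind the symbol into a Module, run forward passes over the data iterator,
+# and report top-1/top-5 accuracy plus throughput.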
+def score(sym, arg_params, aux_params, data, devs, label_name, max_num_examples, logger=None):
+    metrics = [mx.metric.create('acc'),
+               mx.metric.create('top_k_accuracy', top_k=5)]
+    if not isinstance(metrics, list):
+        metrics = [metrics, ]
+    mod = mx.mod.Module(symbol=sym, context=devs, label_names=[label_name, ])
+    mod.bind(for_training=False,
+             data_shapes=data.provide_data,
+             label_shapes=data.provide_label)
+    mod.set_params(arg_params, aux_params)
+
+    tic = time.time()
+    num = 0
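+    # Run a forward pass on each batch and update the accuracy metrics.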
+    for batch in data:
+        mod.forward(batch, is_train=False)
+        for m in metrics:
+            mod.update_metric(m, batch.label)
+        num += batch.data[0].shape[0]  # count the examples actually processed in this batch
+        if max_num_examples is not None and num >= max_num_examples:
+            break
+
+    speed = num / (time.time() - tic)
+
+    if logger is not None:
+        logger.info('Finished inference with %d images' % num)
+        logger.info('Throughput: %f images per second' % speed)
+        for m in metrics:
+            logger.info(m.get())
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='Score a model on a dataset')
+    parser.add_argument('--model', type=str, choices=['imagenet1k-resnet-152', 'imagenet1k-inception-bn'],
+                        help='currently only supports imagenet1k-resnet-152 or imagenet1k-inception-bn')
+    parser.add_argument('--batch-size', type=int, default=32)
+    parser.add_argument('--label-name', type=str, default='softmax_label')
+    parser.add_argument('--dataset', type=str, required=True, help='dataset path')
+    parser.add_argument('--rgb-mean', type=str, default='0,0,0')
+    parser.add_argument('--image-shape', type=str, default='3,224,224')
+    parser.add_argument('--data-nthreads', type=int, default=60, help='number of threads for data decoding')
+    parser.add_argument('--num-skipped-batches', type=int, default=0, help='number of batches to skip before running inference')
+    parser.add_argument('--num-inference-batches', type=int, required=True, help='number of batches to run for inference')
+    parser.add_argument('--shuffle-dataset', action='store_true', default=True,
+                        help='shuffle the dataset for inference')
+    parser.add_argument('--shuffle-chunk-seed', type=int, default=3982304,
+                        help='shuffling chunk seed, see'
+                             ' https://mxnet.incubator.apache.org/api/python/io/io.html?highlight=imager#mxnet.io.ImageRecordIter'
+                             ' for more details')
+    parser.add_argument('--shuffle-seed', type=int, default=48564309,
+                        help='shuffling seed, see'
+                             ' https://mxnet.incubator.apache.org/api/python/io/io.html?highlight=imager#mxnet.io.ImageRecordIter'
+                             ' for more details')
+
+    args = parser.parse_args()
+
+    logging.basicConfig()
+    logger = logging.getLogger('logger')
+    logger.setLevel(logging.INFO)
+    data_nthreads = args.data_nthreads
+    batch_size = args.batch_size
+    logger.info('batch size = %d for inference' % batch_size)
+
+    rgb_mean = args.rgb_mean
+    logger.info('rgb_mean = %s' % rgb_mean)
+    rgb_mean = [float(i) for i in rgb_mean.split(',')]
+    mean_args = {'mean_r': rgb_mean[0], 'mean_g': rgb_mean[1], 'mean_b': rgb_mean[2]}
+
+    label_name = args.label_name
+    logger.info('label_name = %s' % label_name)
+
+    image_shape = args.image_shape
+    data_shape = tuple([int(i) for i in image_shape.split(',')])
+    logger.info('Input data shape = %s' % str(data_shape))
+
+    dataset = args.dataset
+    download_dataset('http://data.mxnet.io/data/val_256_q90.rec', dataset)
+    logger.info('Dataset for inference: %s' % dataset)
+
+    # creating data iterator
+    data = mx.io.ImageRecordIter(path_imgrec=dataset,
+                                 label_width=1,
+                                 preprocess_threads=data_nthreads,
+                                 batch_size=batch_size,
+                                 data_shape=data_shape,
+                                 label_name=label_name,
+                                 rand_crop=False,
+                                 rand_mirror=False,
+                                 shuffle=args.shuffle_dataset,
+                                 shuffle_chunk_seed=args.shuffle_chunk_seed,
+                                 seed=args.shuffle_seed,
+                                 **mean_args)
+
+    # download model
+    prefix, epoch = download_model(model_name=args.model, logger=logger)
+    sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
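+    # Partition the loaded graph: operators whose types appear in op_names are
+    # candidates for being grouped into subgraph ops by the graph partitioner.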
+    op_names = ['BatchNorm', 'Convolution', 'Pooling', 'Activation']
+    out = SymbolHandle()
+    check_call(_LIB.MXPartitionGraph(sym.handle, mx_uint(len(op_names)), c_str_array(op_names),
 
 Review comment:
   We will hide this from users. The API design is still under discussion.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services
