This is an automated email from the ASF dual-hosted git repository.

patriczhao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 7983ef1  fix quantize graph pass (#14605)
7983ef1 is described below

commit 7983ef1cd4fd8e4895c158f35f57e200d71817c4
Author: Xinyu Chen <xinyu1.c...@intel.com>
AuthorDate: Wed Apr 3 19:58:22 2019 -0700

    fix quantize graph pass (#14605)
---
 src/operator/quantization/quantize_graph_pass.cc | 12 ++-
 tests/python/mkl/test_subgraph.py                | 53 ++++++++++++-
 tests/python/quantization/test_quantization.py   | 95 ++++++++++++++++++++++++
 3 files changed, 154 insertions(+), 6 deletions(-)

diff --git a/src/operator/quantization/quantize_graph_pass.cc b/src/operator/quantization/quantize_graph_pass.cc
index af53397..5bd9e8a 100644
--- a/src/operator/quantization/quantize_graph_pass.cc
+++ b/src/operator/quantization/quantize_graph_pass.cc
@@ -265,7 +265,7 @@ Graph QuantizeGraph(Graph &&src) {
             (mirror_node->op() != Op::Get("_contrib_dequantize"))) {
           // here we calculate the output number (exclude min/max, in order to
           // calculate min/max index from mirror node) based on assumption that
-          // there is only 1min and 1max output from mirror node (which is
+          // there is only 1 min and 1 max output from mirror node (which is
           // currently true)
           size_t num_outputs = mirror_node->num_outputs() - 2;
           uint32_t min_index = num_outputs + 2 * e.index;
@@ -297,9 +297,13 @@ Graph QuantizeGraph(Graph &&src) {
       // Only insert dequantize for those Ops supports quantize and not excluded.
       NodePtr mirror_node = mirror_map.at(e.node.get());
       NodeEntry mirror_entry = NodeEntry{mirror_node, e.index, e.version};
-      size_t num_inputs = e.node->num_inputs();
-      uint32_t min_index = num_inputs + 2 * e.index;
-      uint32_t max_index = num_inputs + 2 * e.index + 1;
+      // here we calculate the output number (exclude min/max, in order to
+      // calculate min/max index from mirror node) based on assumption that
+      // there is only 1 min and 1 max output from mirror node (which is
+      // currently true)
+      size_t num_outputs = e.node->num_outputs();
+      uint32_t min_index = num_outputs + 2 * e.index;
+      uint32_t max_index = num_outputs + 2 * e.index + 1;
 
       NodePtr dequantize_node = CreateNode("_contrib_dequantize",
           e.node->attrs.name + "_dequantize");
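[Context note, not part of the patch: the comment above indicates the pass assumes a quantized
node lays out its outputs as [out_0, ..., out_{N-1}, min, max], i.e. exactly one min and one max
entry after the N real outputs, so the min/max indices are computed from that real-output count.
A minimal Python sketch of that arithmetic (illustrative only; the helper name is hypothetical
and does not exist in the codebase):

    # Sketch of the min/max output indexing the fix relies on.
    # Assumption: outputs are ordered [out_0, ..., out_{N-1}, min, max].
    def min_max_output_indices(num_real_outputs, entry_index):
        min_index = num_real_outputs + 2 * entry_index
        max_index = num_real_outputs + 2 * entry_index + 1
        return min_index, max_index

    # e.g. a node with a single real output keeps min at output 1 and max at output 2
    assert min_max_output_indices(1, 0) == (1, 2)

The "- 2" in the first hunk appears to strip those min/max entries from the quantized mirror
node's output count, and the second hunk switches from num_inputs() to num_outputs() so the
same layout assumption is used when wiring up the inserted dequantize node.]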
diff --git a/tests/python/mkl/test_subgraph.py b/tests/python/mkl/test_subgraph.py
index c8cf79e..761eb47 100644
--- a/tests/python/mkl/test_subgraph.py
+++ b/tests/python/mkl/test_subgraph.py
@@ -85,7 +85,7 @@ def check_qsym_scale_align(qsym):
 
 
 def check_qsym_forward(qsym, qarg_params, qaux_params, batch, data_shape, label_shape):
-  mod = mx.mod.Module(symbol=qsym, context=mx.current_context())
+  mod = Module(symbol=qsym, context=mx.current_context())
   mod.bind(for_training=False,
            data_shapes=[('data', data_shape)],
            label_shapes=[('softmax_label', label_shape)])
@@ -96,7 +96,7 @@ def check_qsym_forward(qsym, qarg_params, qaux_params, batch, data_shape, label_
   return mod.get_outputs()
 
 def check_qsym_dummy_forward(qsym, batch, data_shape, label_shape):
-  mod = mx.mod.Module(symbol=qsym, context=mx.current_context())
+  mod = Module(symbol=qsym, context=mx.current_context())
   mod.bind(for_training=False,
            data_shapes=[('data', data_shape)],
            label_shapes=[('softmax_label', label_shape)])
@@ -185,6 +185,55 @@ def check_quantize(sym, data_shape, out_type, name='conv',
       assert_almost_equal(ref_out[i].asnumpy(), quantized_out[i].asnumpy(), atol = 1)
 
 @with_seed()
+def check_quantize_whole_model_with_forward():
+  def check_qsym_forward(qsym, qarg_params, qaux_params, data_shape):
+    mod = Module(symbol=qsym, label_names=None, context=mx.current_context())
+    mod.bind(for_training=False,
+             data_shapes=[('data', data_shape)])
+    mod.set_params(qarg_params, qaux_params)
+    data = [mx.random.uniform(-1.0, 1.0, shape=shape) for _, shape in mod.data_shapes]
+    batch = mx.io.DataBatch(data, [])
+    mod.forward(batch, is_train=False)
+    for output in mod.get_outputs():
+        output.wait_to_read()
+
+  def check_quantize_whole_model(out_type):
+    batch_size = 4
+    data_shape = (batch_size, 4, 10, 10)
+    data = mx.sym.Variable('data')
+    conv0 = mx.sym.Convolution(data, kernel=(1, 1), num_filter=16, name='conv0')
+    sym = mx.sym.Convolution(conv0, kernel=(1, 1), num_filter=16, name='conv1')
+    sym_sg = sym.get_backend_symbol('MKLDNN')
+    mod = Module(symbol=sym, label_names=[])
+    mod.bind(for_training=False,
+             data_shapes=[('data', data_shape)])
+
+    mod.init_params(mx.init.Normal(0.5))
+    arg_params, aux_params = mod.get_params()
+
+    excluded_sym_names = []
+
+    calib_data = mx.nd.random.uniform(shape=data_shape)
+    calib_data = NDArrayIter(data=calib_data)
+    calib_data = DummyIter(calib_data)
+    calib_layer = lambda name: name.endswith('_output')
+    qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(sym=sym_sg,
+                                                                     arg_params=arg_params,
+                                                                     aux_params=aux_params,
+                                                                     ctx=mx.current_context(),
+                                                                     excluded_sym_names=excluded_sym_names,
+                                                                     quantized_dtype=out_type,
+                                                                     calib_mode='naive',
+                                                                     calib_data=calib_data,
+                                                                     calib_layer=calib_layer,
+                                                                     num_calib_examples=5)
+    qsym = qsym.get_backend_symbol('MKLDNN_POST_QUANTIZE')
+    check_qsym_forward(qsym, qarg_params, qaux_params, data_shape)
+
+  for qdtype in ['uint8', 'int8', 'auto']:
+    check_quantize_whole_model(qdtype)
+
+@with_seed()
 def check_fusion(sym, data_shape, attrs_op, name='conv', check_quantization=True):
   op_name = config[name][OP_NAME]
   sg_pass_name = config[name][SG_PASS_NAME]
diff --git a/tests/python/quantization/test_quantization.py b/tests/python/quantization/test_quantization.py
index eedc867..757df81 100644
--- a/tests/python/quantization/test_quantization.py
+++ b/tests/python/quantization/test_quantization.py
@@ -678,6 +678,101 @@ def test_quantize_model_with_forward():
         check_quantize_model(qdtype)
 
 @with_seed()
+def test_quantize_conv_with_forward():
+    def check_quantize_model(qdtype):
+        if is_test_for_native_cpu():
+            print('skipped testing test_quantize_model_with_forward for native cpu since it is not supported yet')
+            return
+        elif qdtype == 'int8' and is_test_for_mkldnn():
+            print('skipped testing test_quantize_model_with_forward for mkldnn cpu int8 since it is not supported yet')
+            return
+        elif qdtype == 'uint8' and is_test_for_gpu():
+            print('skipped testing test_quantize_model_with_forward for gpu uint8 since it is not supported yet')
+            return
+
+        def check_params(params, qparams, qsym=None):
+            if qsym is None:
+                assert len(params) == len(qparams)
+                for k, v in params.items():
+                    assert k in qparams
+                    assert same(v.asnumpy(), qparams[k].asnumpy())
+            else:
+                qparams_ground_truth = mx.contrib.quant._quantize_params(qsym, params, th_dict = {})
+                assert len(qparams) == len(qparams_ground_truth)
+                for k, v in qparams_ground_truth.items():
+                    assert k in qparams
+                    assert same(v.asnumpy(), qparams[k].asnumpy())
+
+        def check_qsym_calibrated(qsym):
+            attrs = qsym.attr_dict()
+            for k, v in attrs.items():
+                if k.find('requantize_') != -1:
+                    assert 'min_calib_range' in v
+                    assert 'max_calib_range' in v
+
+        def check_qsym_qdtype(qsym, qdtype):
+            attrs = qsym.attr_dict()
+            for k, v in attrs.items():
+                if k.find('_quantize') != -1:
+                    assert 'out_type' in v
+                    assert v['out_type'] == qdtype
+
+        def check_qsym_forward(qsym, qarg_params, qaux_params, data_shape):
+            mod = mx.mod.Module(symbol=qsym, label_names=None, context=mx.current_context())
+            mod.bind(for_training=False,
+                     data_shapes=[('data', data_shape)])
+            mod.set_params(qarg_params, qaux_params)
+            data = [mx.random.uniform(-1.0, 1.0, shape=shape) for _, shape in mod.data_shapes]
+            batch = mx.io.DataBatch(data, [])
+            mod.forward(batch, is_train=False)
+            for output in mod.get_outputs():
+                output.wait_to_read()
+
+        batch_size = 4
+        dshape = (batch_size, 4, 10, 10)
+        data = mx.sym.Variable('data')
+        sym = mx.sym.Convolution(data, kernel=(1, 1), num_filter=16, name='conv0')
+
+        mod = Module(symbol=sym, label_names=None)
+        mod.bind(data_shapes=[('data', dshape)])
+
+        mod.init_params()
+        arg_params, aux_params = mod.get_params()
+        excluded_sym_names = []
+
+        qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(sym=sym,
+                                                                         arg_params=arg_params,
+                                                                         aux_params=aux_params,
+                                                                         excluded_sym_names=excluded_sym_names,
+                                                                         ctx=mx.current_context(),
+                                                                         quantized_dtype=qdtype,
+                                                                         calib_mode='none')
+        check_params(arg_params, qarg_params, qsym)
+        check_params(aux_params, qaux_params)
+        check_qsym_forward(qsym, qarg_params, qaux_params, dshape)
+
+        calib_data = mx.nd.random.uniform(shape=dshape)
+        calib_data = NDArrayIter(data=calib_data, batch_size=batch_size)
+        calib_data = DummyIter(calib_data)
+        qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(sym=sym,
+                                                                         arg_params=arg_params,
+                                                                         aux_params=aux_params,
+                                                                         excluded_sym_names=excluded_sym_names,
+                                                                         ctx=mx.current_context(),
+                                                                         quantized_dtype=qdtype,
+                                                                         calib_mode='naive',
+                                                                         calib_data=calib_data,
+                                                                         num_calib_examples=20)
+        check_params(arg_params, qarg_params, qsym)
+        check_params(aux_params, qaux_params)
+        check_qsym_calibrated(qsym)
+        check_qsym_qdtype(qsym, qdtype)
+        check_qsym_forward(qsym, qarg_params, qaux_params, dshape)
+
+    for qdtype in ['uint8', 'int8']:
+        check_quantize_model(qdtype)
+
+@with_seed()
 def test_quantize_sym_with_calib():
     sym = get_fp32_sym()
     offline_params = [name for name in sym.list_arguments()
