This is an automated email from the ASF dual-hosted git repository.
liuyizhi pushed a commit to branch v0.6
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/v0.6 by this push:
new ab76831 [BACKPORT-0.6][Quantization] Fix annotation for multiply op (#4458) (#5850)
ab76831 is described below
commit ab76831126e21d224fbd7d04c5f26f3ad3628c2e
Author: masahi <[email protected]>
AuthorDate: Fri Jun 19 15:31:25 2020 +0900
[BACKPORT-0.6][Quantization] Fix annotation for multiply op (#4458) (#5850)
* fix mul rewrite
* register Realize Rewrite for global avg pool and add test
* remove unnecessary check
* improve the test case
---
python/tvm/relay/quantize/_annotate.py | 6 ++--
src/relay/pass/quantize/realize.cc | 7 ++--
tests/python/relay/test_pass_auto_quantize.py | 49 +++++++++++++++++++++++++++
3 files changed, 56 insertions(+), 6 deletions(-)
diff --git a/python/tvm/relay/quantize/_annotate.py b/python/tvm/relay/quantize/_annotate.py
index 9d679d2..ab98f3c 100644
--- a/python/tvm/relay/quantize/_annotate.py
+++ b/python/tvm/relay/quantize/_annotate.py
@@ -214,8 +214,10 @@ def multiply_rewrite(ref_call, new_args, ctx):
# quantize lhs to INPUT field
if lhs_kind == QAnnotateKind.ACTIVATION:
lhs_expr = attach_simulated_quantize(lhs_expr, QAnnotateKind.INPUT)
- # quantize rhs to WEIGHT field
- rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.WEIGHT)
+ if _analysis.check_constant(rhs_expr):
+ rhs_expr = attach_simulated_quantize(rhs_expr,
QAnnotateKind.WEIGHT)
+ else:
+ rhs_expr = attach_simulated_quantize(rhs_expr, QAnnotateKind.INPUT)
expr = _forward_op(ref_call, [lhs_expr, rhs_expr])
return QAnnotateExpr(expr, QAnnotateKind.ACTIVATION)
diff --git a/src/relay/pass/quantize/realize.cc b/src/relay/pass/quantize/realize.cc
index 4cf84f4..773551a 100644
--- a/src/relay/pass/quantize/realize.cc
+++ b/src/relay/pass/quantize/realize.cc
@@ -278,13 +278,9 @@ Expr MulRealize(const Call& ref_call,
DataType dtype = cfg->dtype_activation;
if (lhs->dtype != dtype) {
ldata = Cast(ldata, dtype);
- } else {
- CHECK_EQ(lhs->dtype, dtype);
}
if (rhs->dtype != dtype) {
rdata = Cast(rdata, dtype);
- } else {
- CHECK_EQ(rhs->dtype, dtype);
}
Expr ret = ForwardOp(ref_call, {ldata, rdata});
@@ -499,6 +495,9 @@ Expr AvgPoolRealize(const Call& ref_call,
RELAY_REGISTER_OP("nn.avg_pool2d")
.set_attr<FForwardRewrite>("FQRealizeRewrite", AvgPoolRealize);
+RELAY_REGISTER_OP("nn.global_avg_pool2d")
+.set_attr<FForwardRewrite>("FQRealizeRewrite", AvgPoolRealize);
+
Expr CastHintRealize(const Call& ref_call,
const Array<Expr>& new_args,
const NodeRef& ctx) {
diff --git a/tests/python/relay/test_pass_auto_quantize.py b/tests/python/relay/test_pass_auto_quantize.py
new file mode 100644
index 0000000..e4aa36b
--- /dev/null
+++ b/tests/python/relay/test_pass_auto_quantize.py
@@ -0,0 +1,49 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+from tvm import relay
+from tvm.relay import testing
+
+
+def quantize_and_build(out):
+ f = relay.Function(relay.analysis.free_vars(out), out)
+ mod, params = testing.create_workload(f)
+
+ with relay.quantize.qconfig(skip_conv_layers=[]):
+ qmod = relay.quantize.quantize(mod, params)
+
+ relay.build(qmod, "llvm", params=params)
+
+
+def test_mul_rewrite():
+ """a test case where rhs of mul is not constant"""
+ data = relay.var("data", shape=(1, 16, 64, 64))
+ multiplier = relay.sigmoid(relay.var("data", shape=(1, 16, 1, 1)))
+ conv = relay.nn.conv2d(data, relay.var("weight"),
+ kernel_size=(3, 3),
+ padding=(1, 1),
+ channels=16)
+ act = relay.nn.relu(data=conv)
+
+ quantize_and_build(act * multiplier)
+
+ pool = relay.nn.global_avg_pool2d(data=act)
+
+ quantize_and_build(act * pool)
+
+if __name__ == "__main__":
+ test_mul_rewrite()