This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 15e185d922 [Hexagon][QNN] Improve performance wo QNN canonicalization
(#13734)
15e185d922 is described below
commit 15e185d922b4de567aa2f74c71aedbc0b56952df
Author: ibsidorenko <[email protected]>
AuthorDate: Wed Jan 11 11:03:26 2023 +0300
[Hexagon][QNN] Improve performance wo QNN canonicalization (#13734)
This commit improves the performance of models tuned with MetaSchedule
for the Hexagon target when QNN canonicalization is disabled.
Benchmark of several models tuned with MetaSchedule (MS) on a Snapdragon 8 Gen 1:
| Model           | QNN canon. enabled, ms | QNN canon. disabled, ms | Speedup |
|-----------------|------------------------|-------------------------|---------|
| ResNet, int8    | 50                     | 48                      | +4.2%   |
| Inception, int8 | 103                    | 106                     | -2.8%   |
| SRGAN, int8     | 348                    | 431                     | -19.3%  |
Speedup = (enabled - disabled) / disabled, so a positive value means the build
without QNN canonicalization is faster.
What was done:
1) Added two new passes, QnnLegalize and QnnCanonicalize. They are thin
wrappers around Legalize("FTVMQnnLegalize") and
Legalize("FTVMQnnCanonicalize") respectively, so that each step can be
disabled by name (e.g. disabled_pass=["QnnCanonicalize"]).
2) Added the ability to disable inlining of specific blocks in the
MetaSchedule AutoInline rule, for example via
T.block_attr({"meta_schedule.inline_rule": "disable"}).
3) Implemented compute, alter-op-layout and legalization functions for the
qnn.conv2d operation on the Hexagon target.
---
include/tvm/relay/transform.h | 4 +
include/tvm/tir/stmt.h | 3 +
python/tvm/relay/qnn/op/_qnn.py | 10 +-
python/tvm/relay/qnn/op/legalizations.py | 70 +++++++
python/tvm/relay/qnn/strategy/hexagon.py | 13 ++
python/tvm/topi/hexagon/qnn/__init__.py | 1 +
python/tvm/topi/hexagon/qnn/conv2d_alter_op.py | 53 ++++++
python/tvm/topi/hexagon/qnn/nn.py | 208 +++++++++++++++++----
python/tvm/topi/nn/qnn.py | 19 ++
src/meta_schedule/schedule_rule/auto_inline.cc | 5 +
src/relay/qnn/op/convolution.cc | 3 +-
src/relay/qnn/op/requantize.cc | 9 +-
src/relay/qnn/pass/legalize.cc | 22 ++-
.../test_hexagon/test_wo_qnn_canonicalization.py | 121 +++++++++++-
14 files changed, 499 insertions(+), 42 deletions(-)
diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h
index cdea8e8e3c..3227f7979d 100644
--- a/include/tvm/relay/transform.h
+++ b/include/tvm/relay/transform.h
@@ -710,6 +710,10 @@ TVM_DLL Function UnCPS(const Function& f);
*/
TVM_DLL Expr DeDup(const Expr& e);
+namespace legalize {
+TVM_DLL Expr Legalize(const Expr& expr, const std::string&
legalize_map_attr_name);
+} // namespace legalize
+
} // namespace relay
} // namespace tvm
diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h
index dc257b1e8a..96e03477a1 100644
--- a/include/tvm/tir/stmt.h
+++ b/include/tvm/tir/stmt.h
@@ -1613,6 +1613,9 @@ constexpr const char* meta_schedule_auto_tensorize_init =
"meta_schedule.auto_te
*/
constexpr const char* warp_execution = "warp_execution";
+/*! \brief Mark that a block is disallowed in auto inline. */
+constexpr const char* meta_schedule_inline_rule = "meta_schedule.inline_rule";
+
/*!
* \brief Check if attr_key is a pragma key extension
* \param attr_key The attr key to be compared
diff --git a/python/tvm/relay/qnn/op/_qnn.py b/python/tvm/relay/qnn/op/_qnn.py
index 64ef1ee92a..c9c4c86e8b 100644
--- a/python/tvm/relay/qnn/op/_qnn.py
+++ b/python/tvm/relay/qnn/op/_qnn.py
@@ -22,7 +22,7 @@ from tvm import topi
from .. import strategy
from ...op.op import register_compute
from ...op.op import register_injective_schedule
-from ...op.op import register_strategy, register_pattern, OpPattern
+from ...op.op import register_strategy, register_pattern,
register_alter_op_layout, OpPattern
@register_compute("qnn.simulated_quantize")
@@ -83,7 +83,13 @@ register_pattern("qnn.concatenate", OpPattern.INJECTIVE)
# qnn.conv2d
register_strategy("qnn.conv2d", strategy.qnn_conv2d_strategy)
-register_pattern("qnn.conv2d", OpPattern.OUT_ELEMWISE_FUSABLE)
+
+
+@register_alter_op_layout("qnn.conv2d")
+def alter_op_layout_qnn_conv2d(attrs, inputs, tinfos, out_type):
+ """Alternate the layout of qnn.conv2d"""
+ return topi.nn.qnn_conv2d_alter_layout(attrs, inputs, tinfos, out_type)
+
# qnn.dense
register_strategy("qnn.dense", strategy.qnn_dense_strategy)
diff --git a/python/tvm/relay/qnn/op/legalizations.py
b/python/tvm/relay/qnn/op/legalizations.py
index ad016bc200..9baabf36a9 100644
--- a/python/tvm/relay/qnn/op/legalizations.py
+++ b/python/tvm/relay/qnn/op/legalizations.py
@@ -405,6 +405,11 @@ def is_fast_int8_on_intel():
return target_has_sse42(target.mcpu)
+# Helper function to align up given value.
+def helper_align_up(value, aligner):
+ return ((value + aligner) // aligner) * aligner
+
+
########################
# ARM CPU legalizations.
########################
@@ -483,3 +488,68 @@ def _qnn_dense_legalize_cuda(attrs, inputs, types):
# CUDA prefers both datatypes to be the int8.
return helper_change_dtypes_to_int8(attrs, inputs, types,
relay.qnn.op.dense)
return None
+
+
+########################
+# Hexagon legalizations.
+########################
+
+IN_CHANNEL_VECTOR_LENGTH = 4
+OUT_CHANNEL_VECTOR_LENGTH = 32
+
+
+@qnn_conv2d_legalize.register("hexagon")
+def _qnn_conv2d_legalize_hexagon(attrs, inputs, types):
+ """Legalize qnn.conv2d op for vrmpy tensorization.
+
+ If the inputs are signed or unsigned int8 and data/kernel layouts are
NCHW/OIHW, then the input
+ and output channels are padded to be a multiple of 4 and 32 respectively.
+ """
+ data_layout = attrs["data_layout"]
+ kernel_layout = attrs["kernel_layout"]
+
+ if data_layout != "NCHW" or kernel_layout != "OIHW":
+ return None
+
+ data_tensor, kernel_tensor = types[0], types[1]
+
+ if "int8" in data_tensor.dtype and "int8" in kernel_tensor.dtype:
+ in_channel = data_tensor.shape[1].value
+ out_channel = kernel_tensor.shape[0].value
+ ic_modified = False
+ oc_modified = False
+ data, kernel, input_zp, output_zp, input_scale, output_scale = inputs
+
+ if in_channel % IN_CHANNEL_VECTOR_LENGTH != 0:
+ new_in_channel = helper_align_up(in_channel,
IN_CHANNEL_VECTOR_LENGTH)
+ diff = new_in_channel - in_channel
+ pad_width = ((0, 0), (0, diff), (0, 0), (0, 0))
+ data = relay.nn.pad(data, pad_width=pad_width)
+ kernel = relay.nn.pad(kernel, pad_width=pad_width)
+ ic_modified = True
+
+ new_out_channel = out_channel
+ if out_channel % OUT_CHANNEL_VECTOR_LENGTH != 0:
+ new_out_channel = helper_align_up(out_channel,
OUT_CHANNEL_VECTOR_LENGTH)
+ diff = new_out_channel - out_channel
+ kernel = relay.nn.pad(kernel, pad_width=((0, diff), (0, 0), (0,
0), (0, 0)))
+ oc_modified = True
+
+ if ic_modified is True or oc_modified is True:
+ new_attrs = dict(attrs)
+ if oc_modified:
+ new_attrs["channels"] = new_out_channel
+ out = relay.qnn.op.conv2d(
+ data, kernel, input_zp, output_zp, input_scale,
output_scale, **new_attrs
+ )
+ output_tensor = types[6]
+ original_out_shape = list(output_tensor.shape)
+ out = relay.strided_slice(out, begin=[0, 0, 0, 0],
end=original_out_shape)
+ else:
+ out = relay.qnn.op.conv2d(
+ data, kernel, input_zp, output_zp, input_scale,
output_scale, **new_attrs
+ )
+
+ return out
+
+ return None
diff --git a/python/tvm/relay/qnn/strategy/hexagon.py
b/python/tvm/relay/qnn/strategy/hexagon.py
index d17812e3fb..c25c96f8ed 100644
--- a/python/tvm/relay/qnn/strategy/hexagon.py
+++ b/python/tvm/relay/qnn/strategy/hexagon.py
@@ -17,12 +17,18 @@
"""Definition of Hexagon operator strategy."""
# pylint: disable=unused-argument,wildcard-import,unused-wildcard-import
+import re
+
from tvm import topi
from .generic import *
from ... import op as _op
from ...op.strategy.generic import is_depthwise_conv2d
+NCHWC_MATCHER = re.compile("^NCHW[0-9]+c$")
+OIHWIOI_MATCHER = re.compile("^OIHW[0-9]+i[0-9]+o[0-9]+i$")
+
+
@qnn_quantize_strategy.register("hexagon")
def qnn_quantize_strategy_hexagon(attrs, inputs, out_type, target):
"""qnn.quantize strategy for Hexagon"""
@@ -135,6 +141,13 @@ def qnn_conv2d_strategy_hexagon(attrs, inputs, out_type,
target):
wrap_topi_schedule(topi.hexagon.schedule_qnn_conv2d),
name="qnn_conv2d.hexagon",
)
+ elif NCHWC_MATCHER.match(data_layout) and
OIHWIOI_MATCHER.match(kernel_layout):
+ if data.dtype == "uint8" and kernel.dtype == "int8":
+ strategy.add_implementation(
+ wrap_topi_qnn_conv2d(topi.hexagon.qnn_conv2d_NCHWc_int8),
+
wrap_topi_schedule(topi.hexagon.schedule_qnn_conv2d_NCHWc_int8),
+ name="qnn_conv2d_NCHWc_int8.hexagon",
+ )
elif is_depthwise_conv2d(data.shape, data_layout, kernel.shape,
kernel_layout, groups):
if data_layout == "NCHW" and kernel_layout == "OIHW":
strategy.add_implementation(
diff --git a/python/tvm/topi/hexagon/qnn/__init__.py
b/python/tvm/topi/hexagon/qnn/__init__.py
index d41d8854d7..b8cdc7a26d 100644
--- a/python/tvm/topi/hexagon/qnn/__init__.py
+++ b/python/tvm/topi/hexagon/qnn/__init__.py
@@ -29,3 +29,4 @@ from .nn import *
from .qdepthwise_conv2d_slice import qdepthwise_conv2d_compute,
qdepthwise_conv2d_schedule
from .adaptive_avg_pool1d import *
from .global_avg_pool2d import *
+from .conv2d_alter_op import *
diff --git a/python/tvm/topi/hexagon/qnn/conv2d_alter_op.py
b/python/tvm/topi/hexagon/qnn/conv2d_alter_op.py
new file mode 100644
index 0000000000..867a477956
--- /dev/null
+++ b/python/tvm/topi/hexagon/qnn/conv2d_alter_op.py
@@ -0,0 +1,53 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""QNN Conv2d alter op functions for Hexagon"""
+
+from tvm import relay
+from ...nn import qnn_conv2d_alter_layout
+from ...utils import get_const_tuple
+
+
+@qnn_conv2d_alter_layout.register("hexagon")
+def _alter_qnn_conv2d_layout(attrs, inputs, tinfos, _out_type):
+ data_layout = attrs["data_layout"]
+ kernel_layout = attrs["kernel_layout"]
+ data_tensor, kernel_tensor, _, _, _, _ = tinfos
+
+ if (
+ "int8" in data_tensor.dtype
+ and "int8" in kernel_tensor.dtype
+ and data_layout == "NCHW"
+ and kernel_layout == "OIHW"
+ ):
+ out_channel, in_channel, _, _ = get_const_tuple(kernel_tensor.shape)
+
+ if out_channel % 32 != 0 or in_channel % 4 != 0:
+ return None
+
+ n_elems = 4
+ oc_bn = 32
+ ic_bn = min(in_channel, 32)
+
+ new_attrs = dict(attrs)
+ new_attrs["channels"] = out_channel
+ new_attrs["data_layout"] = "NCHW%dc" % ic_bn
+ new_attrs["kernel_layout"] = "OIHW{:n}i{:n}o{:n}i".format(ic_bn //
n_elems, oc_bn, n_elems)
+ new_attrs["out_layout"] = "NCHW%dc" % oc_bn
+
+ return relay.qnn.op.conv2d(*inputs, **new_attrs)
+
+ return None
diff --git a/python/tvm/topi/hexagon/qnn/nn.py
b/python/tvm/topi/hexagon/qnn/nn.py
index 49220d0fd0..aabdf2a63b 100644
--- a/python/tvm/topi/hexagon/qnn/nn.py
+++ b/python/tvm/topi/hexagon/qnn/nn.py
@@ -17,14 +17,17 @@
"""Hexagon QNN operators"""
# pylint: disable=invalid-name
+import numpy as np
+
import tvm
from tvm import te, topi
-from ..utils import saturate
+from ..utils import saturate, get_fixed_point_value
from ...utils import get_const_tuple
from ...nn.utils import get_pad_tuple
from ...nn.pad import pad
from ... import tag, nn
-from ...x86.concat import concatenate
+from ..conv2d import conv2d_NCHWc_int8
+from ...transform import concatenate
def clip_cast(val, dtype):
@@ -36,7 +39,9 @@ def clip_cast(val, dtype):
# Return True if given Tensor is scalar constant value.
def is_constant(tensor: te.Tensor):
- return tensor.ndim == 0
+ return tensor.ndim == 0 and (
+ isinstance(tensor.op.body[0], (tvm.tir.expr.FloatImm,
tvm.tir.expr.IntImm))
+ )
def get_qnn_param(param, indices, axis):
@@ -65,6 +70,11 @@ def default_schedule(outs):
outs = [outs] if isinstance(outs, tvm.te.tensor.Tensor) else outs
s = tvm.te.create_schedule([x.op for x in outs])
tvm.te.schedule.AutoInlineInjective(s)
+ for x in outs:
+ fused = s[x].fuse(*x.op.axis)
+ outer, inner = s[x].split(fused, factor=128 //
np.dtype(x.dtype).itemsize)
+ s[x].vectorize(inner)
+ s[x].parallel(outer)
return s
@@ -140,30 +150,58 @@ def schedule_qnn_dequantize(outs):
return default_schedule(outs)
-def qnn_requantize(data, input_scale, input_zp, output_scale, output_zp,
axis=-1, out_dtype="int8"):
+def qnn_requantize(
+ data: te.Tensor,
+ input_scale: te.Tensor,
+ input_zp: te.Tensor,
+ output_scale: te.Tensor,
+ output_zp: te.Tensor,
+ axis=-1,
+ out_dtype="int8",
+):
"""Compute for qnn.requantize
- Q_output = zp_output + round((scale_input)/(scale_output) * (Q_input -
zp_input))
+ If both input and output scales are constant scalars then we convert scale
to fixed point value
+ and use integer arithmetic only for performance optimization purpose.
+ But this is a tradeoff between performance and accuracy, since we use
int16 data type to
+ represent fixed point values (against QNN lowering approach where we use
int32 for that).
+
+ if input and/or output scales are not constant scalars then we use the
following formula:
+ Q_output = zp_output + round((scale_input)/(scale_output) * (Q_input -
zp_input))
TODO: support 'rounding' and 'compute_dtype' arguments.
"""
- def _compute(*indices):
- value = data(*indices)
+ if is_constant(input_scale) and is_constant(output_scale):
+ iscale = input_scale.op.body[0].value
+ oscale = output_scale.op.body[0].value
+ scale = iscale / oscale
+ scale_fixed_point, rsh = get_fixed_point_value(scale, "int16")
+
+ def _compute(*indices):
+ value = data(*indices)
+ # Subtract input zero point:
+ sub = te.subtract(value, input_zp)
+ # Fixed point multiply + roundup delta:
+ mul = (sub * scale_fixed_point + (1 << (rsh - 1))) >> rsh
+ # Add output zero point + clip + cast:
+ return saturate(te.add(mul, output_zp),
out_dtype).astype(out_dtype)
- iscale = get_qnn_param(input_scale, indices, axis)
- oscale = get_qnn_param(output_scale, indices, axis)
+ return te.compute(data.shape, _compute)
+
+ else:
- sub = te.subtract(value, input_zp)
- mul = te.div(iscale, oscale)
- val = te.add(te.round(te.multiply(mul, sub)), output_zp)
+ def _compute(*indices):
+ value = data(*indices)
+ iscale = get_qnn_param(input_scale, indices, axis)
+ oscale = get_qnn_param(output_scale, indices, axis)
- # clip + cast:
- const_min = tvm.tir.min_value(out_dtype)
- const_max = tvm.tir.max_value(out_dtype)
- return te.max(tvm.te.min(val, const_max), const_min).astype(out_dtype)
+ sub = te.subtract(value, input_zp)
+ mul = te.div(iscale, oscale)
+ val = te.add(te.round(te.multiply(mul, sub)), output_zp)
+ return saturate(val, out_dtype).astype(out_dtype)
- return te.compute(data.shape, _compute)
+ return te.compute(data.shape, _compute)
def schedule_qnn_requantize(outs):
@@ -188,9 +226,15 @@ def compute_qnn_binary_op(
):
"""Compute for QNN binary operation
- Q_output = output_zp + round((lhs_scale)/(output_scale) * (lhs_input -
lhs_zp))
- _OP_ round((rhs_scale)/(output_scale) * (rhs_input -
rhs_zp))
- where _OP_ is add/subtract
+ If rhs/lhs/output scales are constant scalars then we convert scale to
fixed point value
+ and use integer arithmetic only for performance optimization purpose.
+ But this is a tradeoff between performance and accuracy, since we use
int16 data type to
+ represent fixed point values (against QNN lowering approach where we use
int32 for that).
+
+ if rhs/lhs/output scales are not constant scalars then we use the
following formula:
+ Q_output = output_zp + round((lhs_scale)/(output_scale) * (lhs_input -
lhs_zp))
+ _OP_ round((rhs_scale)/(output_scale) * (rhs_input -
rhs_zp))
+ where _OP_ is add/subtract
"""
assert lhs.dtype == rhs.dtype
dtype = lhs.dtype
@@ -200,13 +244,24 @@ def compute_qnn_binary_op(
"int32"
)
- def _compute_tensor(x: te.Tensor, iscale, input_zp):
- return te.compute(
- x.shape,
- lambda *i: te.round(
- te.multiply(te.div(iscale, output_scale), te.subtract(x(*i),
input_zp))
- ).astype("int32"),
- )
+ def _compute_tensor(x: te.Tensor, input_scale, input_zp):
+ if is_constant(input_scale) and is_constant(output_scale):
+ iscale = input_scale.op.body[0].value
+ oscale = output_scale.op.body[0].value
+ scale = iscale / oscale
+ scale_fixed_point, rsh = get_fixed_point_value(scale, "int16")
+ return te.compute(
+ x.shape,
+ lambda *i: (te.subtract(x(*i), input_zp) * scale_fixed_point +
(1 << (rsh - 1)))
+ >> rsh,
+ )
+ else:
+ return te.compute(
+ x.shape,
+ lambda *i: te.round(
+ te.multiply(te.div(input_scale, output_scale),
te.subtract(x(*i), input_zp))
+ ).astype("int32"),
+ )
if is_constant(lhs):
lhs_tensor = _compute_const(lhs, lhs_scale, lhs_zp)
@@ -391,7 +446,7 @@ def qnn_concatenate(data, axis, out_dtype):
# Requantize tensors and add them to the list.
args.append(qnn_requantize(tensor, i_scale, i_zp, o_scale, o_zp,
out_dtype=out_dtype))
- # Call x86 implementation of concatenate.
+ # Call generic implementation of concatenate.
return concatenate(args, axis)
@@ -454,6 +509,15 @@ def qnn_conv2d( # Conv2d inputs
get_const_tuple(padding), (dilated_kernel_h, dilated_kernel_w)
)
+ # Subtract zero point from weights. axis=0 in get_qnn_param means 'O'
dimension in "OIHW"
+ # weights layout.
+ weight = te.compute(
+ weight.shape,
+ lambda *indices: te.subtract(
+ weight(*indices), get_qnn_param(kernel_zero_point, indices, axis=0)
+ ),
+ )
+
# Subtract zero point from input and then do padding with 0 value
data = te.compute(data.shape, lambda *indices: te.subtract(data(*indices),
input_zero_point))
@@ -469,7 +533,6 @@ def qnn_conv2d( # Conv2d inputs
kh = te.reduce_axis((0, kernel_height), name="kh")
kw = te.reduce_axis((0, kernel_width), name="kw")
- # axis=0 in get_qnn_param means 'O' dimension in "OIHW" weights layout.
out = te.compute(
oshape,
lambda n, oc, oh, ow: te.sum(
@@ -479,9 +542,7 @@ def qnn_conv2d( # Conv2d inputs
oh * height_stride + kh * dilation_h,
ow * width_stride + kw * dilation_w,
].astype("int32")
- * te.subtract(
- weight[oc, ic, kh, kw], get_qnn_param(kernel_zero_point, (oc,
ic, kh, kw), axis=0)
- ).astype("int32"),
+ * weight[oc, ic, kh, kw].astype("int32"),
axis=[ic, kh, kw],
),
)
@@ -532,6 +593,89 @@ def schedule_qnn_conv2d(outs):
return default_schedule(outs)
+def qnn_conv2d_NCHWc_int8( # Conv2d inputs
+ data,
+ weight,
+ # Conv2d quantization params:
+ input_zero_point,
+ kernel_zero_point,
+ _input_scale,
+ _kernel_scale,
+ # bias
+ bias,
+ # Requantization params:
+ rq_input_scale,
+ rq_input_zero_point,
+ rq_output_scale,
+ rq_output_zero_point,
+ # Conv2d attributes:
+ strides,
+ padding,
+ dilation,
+ _oshape,
+ odtype,
+):
+ """Compute for qnn.conv2d with NCHWc layout."""
+ # Subtract zero point from weights. Need to disable inline of this block
+ # (meta_schedule.inline_rule = disable). Otherwise, inline prevents from
tensorization.
+ weight = te.compute(
+ weight.shape,
+ lambda *i: te.subtract(weight(*i),
kernel_zero_point).astype(weight.dtype),
+ name="weight_zp",
+ attrs={"meta_schedule.inline_rule": "disable"},
+ )
+
+ # Subtract zero point from input. Again need to disable inline of this
block
+ # (meta_schedule.inline_rule = disable). Otherwise, inline prevents from
tensorization.
+ data = te.compute(
+ data.shape,
+ lambda *i: te.subtract(data(*i), input_zero_point).astype(data.dtype),
+ name="data_zp",
+ attrs={"meta_schedule.inline_rule": "disable"},
+ )
+
+ strides = get_const_tuple(strides)
+ padding = get_const_tuple(padding)
+ dilation = get_const_tuple(dilation)
+ out = conv2d_NCHWc_int8(data, weight, strides, padding, dilation,
"NCHW32c", "NCHW32c")
+
+ # Add bias
+ if bias is not None:
+ assert len(out.shape) == len(bias.shape)
+ assert bias.shape[2] == 1 and bias.shape[3] == 1
+ out = te.compute(
+ out.shape, lambda n, c, h, w, ci: out[n, c, h, w, ci] + bias[n, c,
0, 0, ci]
+ )
+
+ # Requantize output of convolution
+ # Q_output = zp_output + round((scale_input)/(scale_output) * (Q_input -
zp_input))
+ if rq_input_scale is not None and rq_output_scale is not None:
+ # Now supported only scalar and 1D quantization parameters
+ assert len(rq_input_scale.shape) == 0 or len(rq_input_scale.shape) == 1
+ assert len(rq_output_scale.shape) == 0 or len(rq_output_scale.shape)
== 1
+ axis = -1
+ if len(rq_input_scale.shape) == 1 or len(rq_output_scale.shape) == 1:
+ axis = 1 # Axis param should correspond to 'C' dimension.
+
+ return qnn_requantize(
+ out,
+ rq_input_scale,
+ rq_input_zero_point,
+ rq_output_scale,
+ rq_output_zero_point,
+ axis,
+ odtype,
+ )
+
+ return out
+
+
+def schedule_qnn_conv2d_NCHWc_int8(outs):
+ """Schedule for qnn.conv2d with NCHWc layout."""
+
+ return default_schedule(outs)
+
+
def qnn_depthwise_conv2d( # Conv2d inputs
data,
weight,
diff --git a/python/tvm/topi/nn/qnn.py b/python/tvm/topi/nn/qnn.py
index 222f7a7c22..7a29266b08 100644
--- a/python/tvm/topi/nn/qnn.py
+++ b/python/tvm/topi/nn/qnn.py
@@ -236,3 +236,22 @@ def qnn_add_alter_layout(_attrs, _inputs, _tinfos,
_out_type):
Unlike other TOPI functions, this function operates on both graph level
and operator level.
"""
return None
+
+
[email protected]_func
+def qnn_conv2d_alter_layout(_attrs, _inputs, _tinfos, _out_type):
+ """Change qnn.conv2D layout.
+ Not to change by default
+
+ Parameters
+ ----------
+ attrs : tvm.ir.Attrs
+ Attributes of current convolution
+ inputs : tvm.relay.Expr
+ Grouped input symbols
+ tinfos : list
+ Input shape and dtype
+ out_type: type
+ The output type
+ """
+ return None
diff --git a/src/meta_schedule/schedule_rule/auto_inline.cc
b/src/meta_schedule/schedule_rule/auto_inline.cc
index d2d48b9008..22e8396925 100644
--- a/src/meta_schedule/schedule_rule/auto_inline.cc
+++ b/src/meta_schedule/schedule_rule/auto_inline.cc
@@ -139,6 +139,11 @@ inline InlineType AutoInlineNode::CheckInline(const
tir::Schedule& sch,
}
}
}
+ // Cond 6. The block is disallowed for auto inline
+ if (Optional<String> ann =
+ tir::GetAnn<String>(block_sref,
tir::attr::meta_schedule_inline_rule)) {
+ if (ann.value() == "disable") return InlineType::kNoInline;
+ }
// Last cond: Check inline into the consumers or the spatial producer
tir::StmtSRef scope_block = tir::GetScopeRoot(sch->state(), block_sref,
/*require_stage_pipeline=*/false);
diff --git a/src/relay/qnn/op/convolution.cc b/src/relay/qnn/op/convolution.cc
index 2170ba76e0..f5ac6af1df 100644
--- a/src/relay/qnn/op/convolution.cc
+++ b/src/relay/qnn/op/convolution.cc
@@ -860,7 +860,8 @@ operator to understand how to scale back the int32 output
to (u)int8 or (u)int16
.add_type_rel("QnnConv2D", QnnConv2DRel)
.set_attr<TNonComputational>("TNonComputational", true)
.set_attr<FTVMLegalize>("FTVMQnnCanonicalize", QnnConv2DCanonicalize)
- .set_attr<FInferCorrectLayout>("FInferCorrectLayout",
QnnConvInferCorrectLayout);
+ .set_attr<FInferCorrectLayout>("FInferCorrectLayout",
QnnConvInferCorrectLayout)
+ .set_attr<TOpPattern>("TOpPattern", kOutEWiseFusable);
TVM_REGISTER_GLOBAL("relay.qnn.op._make.conv2d").set_body_typed(MakeQnnConv2D);
diff --git a/src/relay/qnn/op/requantize.cc b/src/relay/qnn/op/requantize.cc
index 91df4a287c..e1d27ee536 100644
--- a/src/relay/qnn/op/requantize.cc
+++ b/src/relay/qnn/op/requantize.cc
@@ -76,10 +76,17 @@ InferCorrectLayoutOutput RequantizeInferCorrectLayout(const
Attrs& attrs,
if (old_dim == layout_dim) {
new_axis = tvm::Integer(axis_index);
}
- // Collect only the primal axis.
+
if (layout_axis.IsPrimal()) {
new_layout_string += layout_dim;
axis_index++;
+ } else {
+ // Propogate layout if input_zero_point and input_scale are scalar
values.
+ ICHECK_GE(old_in_types.size(), 3);
+ if (IsScalarType(old_in_types[1]) && IsScalarType(old_in_types[2])) {
+ new_layout_string +=
std::to_string(new_in_layouts[0].FactorOf(layout_axis)) + layout_dim;
+ axis_index++;
+ }
}
}
diff --git a/src/relay/qnn/pass/legalize.cc b/src/relay/qnn/pass/legalize.cc
index a5906cf5e6..fd88c4df8c 100644
--- a/src/relay/qnn/pass/legalize.cc
+++ b/src/relay/qnn/pass/legalize.cc
@@ -30,10 +30,28 @@ namespace qnn {
namespace transform {
+// QnnLegalize pass is a wrapper for relay::legalize::Legalize pass.
+Pass QnnLegalize() {
+ runtime::TypedPackedFunc<Function(Function, IRModule,
relay::transform::PassContext)> pass_func =
+ [=](Function f, IRModule m, relay::transform::PassContext pc) {
+ return Downcast<Function>(relay::legalize::Legalize(f,
"FTVMQnnLegalize"));
+ };
+ return relay::transform::CreateFunctionPass(pass_func, 1, "QnnLegalize",
{"InferType"});
+}
+
+// QnnCanonicalize pass is a wrapper for relay::legalize::Legalize pass.
+Pass QnnCanonicalize() {
+ runtime::TypedPackedFunc<Function(Function, IRModule,
relay::transform::PassContext)> pass_func =
+ [=](Function f, IRModule m, relay::transform::PassContext pc) {
+ return Downcast<Function>(relay::legalize::Legalize(f,
"FTVMQnnCanonicalize"));
+ };
+ return relay::transform::CreateFunctionPass(pass_func, 1, "QnnCanonicalize",
{"InferType"});
+}
+
Pass Legalize() {
Array<Pass> pass_seqs;
- pass_seqs.push_back(relay::transform::Legalize("FTVMQnnLegalize"));
- pass_seqs.push_back(relay::transform::Legalize("FTVMQnnCanonicalize"));
+ pass_seqs.push_back(QnnLegalize());
+ pass_seqs.push_back(QnnCanonicalize());
relay::transform::Pass seq = relay::transform::Sequential(pass_seqs,
"qnn.Legalize");
return seq;
}
diff --git a/tests/python/contrib/test_hexagon/test_wo_qnn_canonicalization.py
b/tests/python/contrib/test_hexagon/test_wo_qnn_canonicalization.py
index 06e738d9b7..e583b1b5ea 100644
--- a/tests/python/contrib/test_hexagon/test_wo_qnn_canonicalization.py
+++ b/tests/python/contrib/test_hexagon/test_wo_qnn_canonicalization.py
@@ -59,7 +59,7 @@ def execute(mod_executor, inputs: dict):
def build_hexagon_module(mod):
- with tvm.transform.PassContext(opt_level=3,
disabled_pass=["qnn.Legalize"]):
+ with tvm.transform.PassContext(opt_level=3,
disabled_pass=["QnnCanonicalize"]):
hexagon_lowered = tvm.relay.build(
mod,
tvm.target.Target(HEXAGON_AOT_LLVM_TARGET,
host=HEXAGON_AOT_LLVM_TARGET),
@@ -87,7 +87,7 @@ def test_qnn_conv2d_rq(hexagon_session: Session):
weight_shape = [16, 8, 3, 3]
data = relay.var("data", shape=data_shape, dtype="float32")
weight = relay.var("weight", shape=weight_shape, dtype="float32")
- op0 = relay.qnn.op.quantize(data, relay.const(0.078), relay.const(0),
out_dtype="int8")
+ op0 = relay.qnn.op.quantize(data, relay.const(0.078), relay.const(0),
out_dtype="uint8")
op1 = relay.qnn.op.quantize(weight, relay.const(0.07), relay.const(0),
out_dtype="int8")
op2 = relay.qnn.op.conv2d(
op0,
@@ -116,7 +116,7 @@ def test_qnn_conv2d_rq(hexagon_session: Session):
# Reference compilation
llvm_lowered = build_ref_module(relay_mod)
- data_np = np.random.rand(*data_shape) - 0.5
+ data_np = np.random.rand(*data_shape)
weight_np = np.random.rand(*weight_shape) - 0.5
inputs = {"data": data_np, "weight": weight_np}
@@ -181,7 +181,8 @@ def test_qnn_dense_bias_rq(hexagon_session: Session):
llvm_m = tvm.runtime.executor.AotModule(llvm_lowered["default"](dev))
llvm_out = execute(llvm_m, inputs)
- np.testing.assert_equal(hexagon_output, llvm_out)
+ # Diff by 1 is Ok.
+ tvm.testing.assert_allclose(hexagon_output, llvm_out, atol=1)
class TestQnnBinaryOp:
@@ -278,5 +279,117 @@ class TestQnnBinaryOp:
tvm.testing.assert_allclose(hexagon_output, llvm_output, atol=1)
+class TestQnnOp:
+ """QNN op test class"""
+
+ @tvm.testing.requires_hexagon
+ def test_qnn_requantize(self, hexagon_session: Session):
+ """qnn.requantize test without QNN canonicalization."""
+ data_shape = [256]
+ data = relay.var("data", shape=data_shape, dtype="int32")
+
+ op = relay.qnn.op.requantize(
+ data,
+ input_scale=relay.const(0.156),
+ input_zero_point=relay.const(2),
+ output_scale=relay.const(0.212),
+ output_zero_point=relay.const(1),
+ out_dtype="int8",
+ )
+ mod = tvm.IRModule.from_expr(op)
+
+ # Compile for Hexagon
+ hexagon_lowered = build_hexagon_module(mod)
+
+ # Reference compilation
+ llvm_lowered = build_ref_module(mod)
+
+ data_np = np.arange(-256, 256, 2, dtype="int32")
+ inputs = {"data": data_np}
+
+ hx_m = hexagon_session.get_executor_from_factory(hexagon_lowered)
+ hexagon_output = execute(hx_m, inputs)
+
+ dev = tvm.cpu(0)
+ llvm_m = tvm.runtime.executor.AotModule(llvm_lowered["default"](dev))
+ llvm_output = execute(llvm_m, inputs)
+
+ np.testing.assert_equal(hexagon_output, llvm_output)
+
+ @tvm.testing.requires_hexagon
+ def test_qnn_concatenate(self, hexagon_session: Session):
+ """qnn.concatenate op test without QNN canonicalization."""
+ x_shape = [1, 64]
+ y_shape = [2, 64]
+ z_shape = [3, 64]
+ input_x = relay.var("x", shape=x_shape, dtype="uint8")
+ input_y = relay.var("y", shape=y_shape, dtype="uint8")
+ input_z = relay.var("z", shape=z_shape, dtype="uint8")
+
+ op = relay.qnn.op.concatenate(
+ (input_x, input_y, input_z),
+ input_scales=(relay.const(0.3), relay.const(0.7),
relay.const(1.3)),
+ input_zero_points=(relay.const(0), relay.const(1), relay.const(2)),
+ output_scale=relay.const(0.8),
+ output_zero_point=relay.const(5),
+ axis=0,
+ )
+ mod = tvm.IRModule.from_expr(op)
+
+ # Compile for Hexagon
+ hexagon_lowered = build_hexagon_module(mod)
+
+ # Reference compilation
+ llvm_lowered = build_ref_module(mod)
+
+ x_np = np.arange(0, 64, 1, dtype="uint8").reshape(x_shape)
+ y_np = np.arange(0, 128, 1, dtype="uint8").reshape(y_shape)
+ z_np = np.arange(0, 192, 1, dtype="uint8").reshape(z_shape)
+ inputs = {"x": x_np, "y": y_np, "z": z_np}
+
+ hx_m = hexagon_session.get_executor_from_factory(hexagon_lowered)
+ hexagon_output = execute(hx_m, inputs)
+
+ dev = tvm.cpu(0)
+ llvm_m = tvm.runtime.executor.AotModule(llvm_lowered["default"](dev))
+ llvm_output = execute(llvm_m, inputs)
+
+ # Diff by 1 is Ok.
+ tvm.testing.assert_allclose(hexagon_output, llvm_output, atol=1)
+
+ @tvm.testing.requires_hexagon
+ def test_qnn_tanh(self, hexagon_session: Session):
+ """qnn.tanh op test without QNN canonicalization."""
+ data_shape = [256]
+ data = relay.var("data", shape=data_shape, dtype="uint8")
+
+ op = relay.qnn.op.tanh(
+ data,
+ scale=relay.const(0.518),
+ zero_point=relay.const(137),
+ output_scale=relay.const(0.207),
+ output_zero_point=relay.const(128),
+ )
+ mod = tvm.IRModule.from_expr(op)
+
+ # Compile for Hexagon
+ hexagon_lowered = build_hexagon_module(mod)
+
+ # Reference compilation
+ llvm_lowered = build_ref_module(mod)
+
+ data_np = np.arange(0, 256, 1, dtype="uint8")
+ inputs = {"data": data_np}
+
+ hx_m = hexagon_session.get_executor_from_factory(hexagon_lowered)
+ hexagon_output = execute(hx_m, inputs)
+
+ dev = tvm.cpu(0)
+ llvm_m = tvm.runtime.executor.AotModule(llvm_lowered["default"](dev))
+ llvm_output = execute(llvm_m, inputs)
+
+ np.testing.assert_equal(hexagon_output, llvm_output)
+
+
if __name__ == "__main__":
tvm.testing.main()