tkonolige commented on a change in pull request #10310:
URL: https://github.com/apache/tvm/pull/10310#discussion_r812335125
##########
File path: python/tvm/topi/arm_cpu/conv2d_alter_op.py
##########
@@ -365,3 +439,164 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
inputs[0], new_kernel_expr, **new_attrs
)
return None
+
+
@conv2d_legalize.register("arm_cpu")
def _conv2d_legalize(attrs, inputs, arg_types):
    """Legalizes Conv2D op for the ARM int8 intrinsics.

    The int8 intrinsics need the data and kernel dtypes to match and need
    input/output channel counts that are multiples of 8. When data and kernel
    are a uint8/int8 mix, the data is shifted by 128 into the kernel's
    signedness and the convolution result is corrected afterwards. Channels
    are zero-padded up to the next multiple of 8, and the output is sliced
    back to its original shape.

    Only NHWC/HWIO and NCHW/OIHW layouts, dilation (1, 1) and groups == 1
    are handled; anything else returns None (no legalization).

    Parameters
    ----------
    attrs : tvm.ir.Attrs
        Attributes of current convolution
    inputs : list of tvm.relay.Expr
        The args of the Relay expr to be legalized
    arg_types : list of types
        List of input and output types (data, kernel, output)

    Returns
    -------
    result : tvm.relay.Expr or None
        The legalized expr, or None when no legalization applies
    """
    # Dilation not supported yet.
    dilation = attrs.get_int_tuple("dilation")
    if not (dilation[0] == 1 and dilation[1] == 1):
        return None

    # No legalization for grouped/depthwise convolutions yet.
    if attrs.get_int("groups") != 1:
        return None

    # Every path below requires one of these two layout pairs; bail out early
    # instead of re-checking the layout in each step.
    data_layout = attrs["data_layout"]
    kernel_layout = attrs["kernel_layout"]
    is_nhwc = data_layout == "NHWC" and kernel_layout == "HWIO"
    is_nchw = data_layout == "NCHW" and kernel_layout == "OIHW"
    if not (is_nhwc or is_nchw):
        return None

    # Collect the input/output types.
    data_tensor, kernel_tensor, output_tensor = arg_types[0], arg_types[1], arg_types[2]
    data_dtype = data_tensor.dtype
    kernel_dtype = kernel_tensor.dtype

    # Collect the input exprs.
    data, kernel = inputs

    # Copy the conv attrs so they can be modified.
    new_attrs = {k: attrs[k] for k in attrs.keys()}

    # ARM intrinsics need the datatypes of data and kernel to be the same.
    # Original --> C = A (conv) B
    #
    # uint8 data, int8 kernel: shift data to int8
    #   C = (A - 128 + 128) (conv) B
    #   C = (A' conv B) + 128 (conv) B,   where A' = A - 128
    #
    # int8 data, uint8 kernel: shift data to uint8
    #   C = (A + 128 - 128) (conv) B
    #   C = (A' conv B) - 128 (conv) B,   where A' = A + 128
    #
    # 128 (conv) B is a reduction of the kernel over its (I, H, W) axes,
    # i.e. one value per output channel.
    data_shifted = False
    if (data_dtype == "uint8" and kernel_dtype == "int8") or (
        data_dtype == "int8" and kernel_dtype == "uint8"
    ):
        if data_dtype == "uint8":
            before_shift, after_shift = relay.subtract, relay.add
            data_dtype = "int8"
            # Zero in the original domain maps to -128 after the shift.
            pad_value = -128
        else:
            before_shift, after_shift = relay.add, relay.subtract
            data_dtype = "uint8"
            # Zero in the original domain maps to +128 after the shift.
            pad_value = 128
        data_shifted = True

        padding = attrs.get_int_tuple("padding")
        kh, kw = attrs.get_int_tuple("kernel_size")
        pt, pl, pb, pr = get_pad_tuple(padding, (kh, kw))

        if is_nhwc:
            # HWIO kernel: reduce over H, W, I -> shape (O,), which broadcasts
            # against the trailing channel axis of an NHWC output.
            adjust_shift = relay.sum(relay.cast(kernel, dtype="int32"), axis=(0, 1, 2))
            pad_width = ((0, 0), (pt, pb), (pl, pr), (0, 0))
        else:
            # OIHW kernel: reduce over I, H, W -> (O,), then expand to (O, 1, 1)
            # so it broadcasts against an NCHW output.
            adjust_shift = relay.sum(relay.cast(kernel, dtype="int32"), axis=(1, 2, 3))
            adjust_shift = relay.expand_dims(adjust_shift, axis=1, num_newaxis=2)
            pad_width = ((0, 0), (0, 0), (pt, pb), (pl, pr))

        # Shift in int32 to avoid wrap-around, then narrow to the target dtype.
        data = relay.cast(data, "int32")
        data = before_shift(data, relay.const(128, "int32"))
        data = relay.cast(data, data_dtype)

        # The conv's implicit zero-padding would be wrong in the shifted
        # domain, so pad explicitly with the shifted image of zero.
        if any(padding):
            data = relay.nn.pad(data, pad_width=pad_width, pad_value=pad_value)
            new_attrs["padding"] = (0, 0)

        # Multiply by 128 to form the correction term 128 (conv) B.
        adjust_shift = relay.multiply(adjust_shift, relay.const(128, "int32"))

    # Legalize only if the datatypes are suitable for fast Int8 instructions.
    if not is_int8_hw_support(data_dtype, kernel_dtype):
        return None

    # Int8 instructions require input and output channels to be multiples of
    # 8. Input channels: pad both data and kernel (zero kernel taps make the
    # padded data values irrelevant). Output channels: pad the kernel and
    # strided_slice the output back.
    if is_nhwc:
        in_channel = data_tensor.shape[3].value
        out_channel = kernel_tensor.shape[3].value
    else:
        in_channel = data_tensor.shape[1].value
        out_channel = kernel_tensor.shape[0].value

    if in_channel % 8 != 0:
        diff = ((in_channel + 7) // 8) * 8 - in_channel
        if is_nhwc:
            data = relay.nn.pad(data, pad_width=((0, 0), (0, 0), (0, 0), (0, diff)))
            kernel = relay.nn.pad(kernel, pad_width=((0, 0), (0, 0), (0, diff), (0, 0)))
        else:
            # Channel axis is 1 for both NCHW data and OIHW kernel.
            pad_width = ((0, 0), (0, diff), (0, 0), (0, 0))
            data = relay.nn.pad(data, pad_width=pad_width)
            kernel = relay.nn.pad(kernel, pad_width=pad_width)

    if out_channel % 8 != 0:
        new_out_channel = ((out_channel + 7) // 8) * 8
        diff = new_out_channel - out_channel
        if is_nhwc:
            kernel = relay.nn.pad(kernel, pad_width=((0, 0), (0, 0), (0, 0), (0, diff)))
        else:
            kernel = relay.nn.pad(kernel, pad_width=((0, diff), (0, 0), (0, 0), (0, 0)))
        new_attrs["channels"] = new_out_channel
        out = relay.nn.conv2d(data, kernel, **new_attrs)
        original_out_shape = [x.value for x in output_tensor.shape]
        out = relay.strided_slice(out, begin=[0, 0, 0, 0], end=original_out_shape)
    else:
        out = relay.nn.conv2d(data, kernel, **new_attrs)

    # Undo the data shift; adjust_shift was computed from the unpadded kernel,
    # so its shape matches the (sliced) original output channels.
    if data_shifted:
        out = after_shift(out, adjust_shift)

    return out
Review comment:
I moved the common code into `topi/generic/conv2d.py`
##########
File path: python/tvm/topi/arm_cpu/conv2d_alter_op.py
##########
@@ -333,6 +343,70 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
)
dispatch_ctx.update(target, new_workload, cfg)
return relay.nn.contrib_depthwise_conv2d_nchwc(*inputs, **new_attrs)
+
+ if topi_tmpl == "conv2d_NCHWc_int8.arm_cpu":
+ assert data_layout == "NCHW" and kernel_layout == "OIHW"
+ if cfg.is_fallback:
+ _get_default_config_int8(
+ cfg,
+ data_tensor,
+ kernel_tensor,
+ strides,
+ padding,
+ dilation,
+ out_dtype,
+ False,
+ data_layout,
+ )
+
+ batch_size, in_channel, height, width =
get_const_tuple(data_tensor.shape)
+ out_channel, channel_multiplier, kh, kw =
get_const_tuple(kernel_tensor.shape)
+ ic_bn, oc_bn = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1]
+ n_elems = 8
+
+ # convert kernel data layout from 4D to 7D
+ data_expr, kernel_expr = inputs
+ kernel_IHWO = relay.transpose(kernel_expr, axes=(1, 2, 3, 0))
+ kernel_IHWOo = relay.reshape(kernel_IHWO, (in_channel, kh, kw,
out_channel // oc_bn, oc_bn))
+ kernel_OHWoI = relay.transpose(kernel_IHWOo, axes=(3, 1, 2, 4, 0))
+ kernel_OHWoIi = relay.reshape(
+ kernel_OHWoI, (out_channel // oc_bn, kh, kw, oc_bn, in_channel //
ic_bn, ic_bn)
+ )
+ kernel_OHWoIie = relay.reshape(
+ kernel_OHWoIi,
+ (out_channel // oc_bn, kh, kw, oc_bn, in_channel // ic_bn, ic_bn
// n_elems, n_elems),
+ )
+ kernel_OIHWioe = relay.transpose(kernel_OHWoIie, axes=(0, 4, 1, 2, 5,
3, 6))
Review comment:
done
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]