This is an automated email from the ASF dual-hosted git repository. mbaret pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push: new 77e4fd1 [BYOC][ACL] Depthwise convolution support (#7206) 77e4fd1 is described below commit 77e4fd16bfb8175c83638870af9646d1027f0de7 Author: Dmitriy Smirnov <dmitriy.smir...@arm.com> AuthorDate: Mon Jan 11 14:26:05 2021 +0000 [BYOC][ACL] Depthwise convolution support (#7206) * [BYOC][ACL] Depthwise convolution support Added support for depthwise convolution. ACL only supports depth-wise convolution when the kernel size is 3x3 or 5x5 and the strides are (1, 1) or (2, 2); if this is not the case, the operator falls back to TVM. Also rework tests to remove non-deterministic trials. *Compute Library for the Arm Architecture (ACL). *All credits to Luke Hutton @lhutton1 Change-Id: Ida1f5802a65377b84325edf14a0149242c1af857 * linter * CHECK -> ICHECK Co-authored-by: Luke Hutton <luke.hut...@arm.com> --- docs/deploy/arm_compute_lib.rst | 8 +- python/tvm/relay/op/contrib/arm_compute_lib.py | 109 +++++++- python/tvm/relay/testing/__init__.py | 6 +- .../backend/contrib/arm_compute_lib/codegen.cc | 48 ++-- src/runtime/contrib/arm_compute_lib/acl_runtime.cc | 69 ++++- src/runtime/contrib/arm_compute_lib/acl_utils.cc | 10 + src/runtime/contrib/arm_compute_lib/acl_utils.h | 9 + .../contrib/test_arm_compute_lib/infrastructure.py | 42 --- .../contrib/test_arm_compute_lib/test_conv2d.py | 283 +++++++++++++-------- .../contrib/test_arm_compute_lib/test_dense.py | 83 +++--- .../contrib/test_arm_compute_lib/test_network.py | 4 +- 11 files changed, 450 insertions(+), 221 deletions(-) diff --git a/docs/deploy/arm_compute_lib.rst b/docs/deploy/arm_compute_lib.rst index a2eaa5f..5d11241 100644 --- a/docs/deploy/arm_compute_lib.rst +++ b/docs/deploy/arm_compute_lib.rst @@ -15,7 +15,7 @@ specific language governing permissions and limitations under the License. 
-Relay Arm :sup:`®` Compute Library Integration +Relay Arm:sup:`®` Compute Library Integration ============================================== **Author**: `Luke Hutton <https://github.com/lhutton1>`_ @@ -195,12 +195,14 @@ Operator support | | Simple: nn.conv2d | | | Composite: nn.pad?, nn.conv2d, nn.bias_add?, nn.relu? | | | | -| | (only groups = 1 supported) | +| | Normal and depth-wise (when kernel is 3x3 or 5x5 and strides are 1x1 | +| | or 2x2) convolution supported. Grouped convolution is not supported. | +----------------------+-------------------------------------------------------------------------+ | qnn.conv2d | uint8: | | | Composite: nn.pad?, nn.conv2d, nn.bias_add?, nn.relu?, qnn.requantize | | | | -| | (only groups = 1 supported) | +| | Normal and depth-wise (when kernel is 3x3 or 5x5 and strides are 1x1 | +| | or 2x2) convolution supported. Grouped convolution is not supported. | +----------------------+-------------------------------------------------------------------------+ | nn.dense | fp32: | | | Simple: nn.dense | diff --git a/python/tvm/relay/op/contrib/arm_compute_lib.py b/python/tvm/relay/op/contrib/arm_compute_lib.py index a78ad29..8a03cb1 100644 --- a/python/tvm/relay/op/contrib/arm_compute_lib.py +++ b/python/tvm/relay/op/contrib/arm_compute_lib.py @@ -19,12 +19,15 @@ import numpy as np import tvm +from tvm._ffi import register_func from tvm.relay.expr import const from tvm.relay import transform from tvm.relay.build_module import bind_params_by_name +from tvm.relay.testing.temp_op_attr import TempOpAttr from ...dataflow_pattern import wildcard, is_op, is_constant, is_expr from .register import register_pattern_table +from ..strategy.generic import is_depthwise_conv2d def is_arm_compute_runtime_enabled(): @@ -71,6 +74,61 @@ def partition_for_arm_compute_lib(mod, params=None): return seq(mod) +@register_func("relay.ext.arm_compute_lib.optimize") +def preprocess_module(mod): + """ + Pre-process a module containing functions ready for ACL 
codegen. For now we enforce OHWI + kernel layout and fold the transforms away. + + Parameters + ---------- + mod : Module + The module to run passes on. + + Returns + ------- + preprocessed_mod : The processed module. + """ + + def convert_layout_conv2d(conv2d_function): + def convert_conv(attrs, inputs, tinfos, desired_layouts): + new_attrs = dict(attrs) + data_info = tinfos[0] + weight_info = tinfos[1] + desired_data_layout, desired_kernel_layout = map(str, desired_layouts) + new_attrs["data_layout"] = desired_data_layout + new_attrs["kernel_layout"] = desired_kernel_layout + + if is_depthwise_conv2d( + data_info.shape, + attrs["data_layout"], + weight_info.shape, + attrs["kernel_layout"], + attrs["groups"], + ): + dkl = desired_kernel_layout + new_attrs["kernel_layout"] = dkl[3] + dkl[1:3] + dkl[0] + return conv2d_function(*inputs, **new_attrs) + + return convert_conv + + with TempOpAttr( + "nn.conv2d", "FTVMConvertOpLayout", convert_layout_conv2d(tvm.relay.nn.conv2d) + ), TempOpAttr( + "qnn.conv2d", "FTVMConvertOpLayout", convert_layout_conv2d(tvm.relay.qnn.op.conv2d) + ): + seq = tvm.transform.Sequential( + [ + transform.ConvertLayout( + {"nn.conv2d": ["NHWC", "OHWI"], "qnn.conv2d": ["NHWC", "OHWI"]} + ), + transform.FoldConstant(), + ] + ) + preprocessed_mod = seq(mod) + return preprocessed_mod + + @register_pattern_table("arm_compute_lib") def arm_compute_lib_pattern_table(): """Get the ACL pattern table.""" @@ -236,8 +294,6 @@ _register_external_op_helper("reshape") def conv2d(expr): """Check if the external ACL codegen for conv2d should be used.""" attrs, args = expr.attrs, expr.args - if attrs.groups != 1: - return False if attrs.data_layout != "NHWC": return False if attrs.out_dtype != "float32" and attrs.out_dtype != "": @@ -248,14 +304,25 @@ def conv2d(expr): kernel_typ = args[1].checked_type if len(kernel_typ.shape) != 4 or kernel_typ.dtype != "float32": return False + is_depthwise = is_depthwise_conv2d( + data_typ.shape, + attrs["data_layout"], + 
kernel_typ.shape, + attrs["kernel_layout"], + attrs["groups"], + ) + if is_depthwise: + return depthwise_conv2d(attrs, args) + # ACL doesn't support grouped convolution + if attrs.groups != 1 and not is_depthwise: + return False return True def qnn_conv2d(expr): """Check if the external ACL codegen for qnn.conv2d should be used.""" attrs, args = expr.attrs, expr.args - if attrs.groups != 1: - return False + if attrs.data_layout != "NHWC": return False if attrs.out_dtype != "int32" and attrs.out_dtype != "": @@ -266,6 +333,40 @@ def qnn_conv2d(expr): kernel_typ = args[1].checked_type if len(kernel_typ.shape) != 4 or kernel_typ.dtype != "uint8": return False + is_depthwise = is_depthwise_conv2d( + data_typ.shape, + attrs["data_layout"], + kernel_typ.shape, + attrs["kernel_layout"], + attrs["groups"], + ) + if is_depthwise: + return depthwise_conv2d(attrs, args) + # ACL doesn't support grouped convolution + if attrs.groups != 1 and not is_depthwise: + return False + return True + + +def depthwise_conv2d(attrs, args): + """Check if the external ACL codegen for depthwise convolution should be used. + + Note + ---- + Relay does not have a depthwise conv2d operator whilst ACL does. We simply + separate the checks for depthwise for clarity. 
+ """ + kernel_typ = args[1].checked_type + # Only supports 3x3, 5x5 depthwise + if ( + kernel_typ.shape[0] not in [3, 5] + or kernel_typ.shape[1] not in [3, 5] + or kernel_typ.shape[0] != kernel_typ.shape[1] + ): + return False + # Stride must be (1, 1) or (2, 2) + if (attrs.strides[0], attrs.strides[1]) not in [(1, 1), (2, 2)]: + return False return True diff --git a/python/tvm/relay/testing/__init__.py b/python/tvm/relay/testing/__init__.py index 0b81cb9..f0c79be 100644 --- a/python/tvm/relay/testing/__init__.py +++ b/python/tvm/relay/testing/__init__.py @@ -22,9 +22,9 @@ import numpy as np import tvm from tvm import te -import tvm.relay as relay -import tvm.relay.op as op -from tvm.relay import Prelude +from tvm import relay +from tvm.relay import op +from tvm.relay.prelude import Prelude from tvm.testing import enabled_targets from . import mlp diff --git a/src/relay/backend/contrib/arm_compute_lib/codegen.cc b/src/relay/backend/contrib/arm_compute_lib/codegen.cc index a963242..e0669ae 100644 --- a/src/relay/backend/contrib/arm_compute_lib/codegen.cc +++ b/src/relay/backend/contrib/arm_compute_lib/codegen.cc @@ -24,6 +24,7 @@ #include <tvm/ir/module.h> #include <tvm/relay/attrs/nn.h> #include <tvm/relay/type.h> +#include <tvm/tir/analysis.h> #include <memory> #include <string> @@ -126,7 +127,7 @@ class ACLJSONSerializer : public backend::contrib::JSONSerializer { nodes.activation = current_call; current_call = current_call->args[0].as<CallNode>(); } - if (backend::IsOp(current_call, "nn.bias_add")) { + if (backend::IsOp(current_call, "add")) { nodes.bias = current_call; current_call = current_call->args[0].as<CallNode>(); } @@ -154,19 +155,32 @@ class ACLJSONSerializer : public backend::contrib::JSONSerializer { */ std::shared_ptr<JSONGraphNode> CreateCompositeConvJSONNode(const CallNode* cn) { CompositeConvNode nodes = UnpackCompositeConvolution(cn); - std::string name = "nn.conv2d"; const auto* conv_attr = nodes.conv->attrs.as<Conv2DAttrs>(); 
ICHECK(conv_attr); - ICHECK(conv_attr->kernel_layout == "OHWI") - << "Kernel layout must be OHWI, has the module been pre-processed correctly?"; + + std::string name; + std::string name_prefix = "nn"; + + // Distinguish between normal and depth-wise convolution + if (conv_attr->channels.defined() && + tvm::tir::ExprDeepEqual()(conv_attr->channels, conv_attr->groups) && + conv_attr->groups != 1) { + name = "depthwise_conv2d"; + ICHECK(conv_attr->kernel_layout == "IHWO") + << "Kernel layout must be IHWO, has the module been pre-processed correctly?"; + } else { + name = "conv2d"; + ICHECK(conv_attr->kernel_layout == "OHWI") + << "Kernel layout must be OHWI, has the module been pre-processed correctly?"; + } // Inputs must be added in the same order they appear in the relay graph. std::vector<JSONGraphNodeEntry> inputs; inputs.push_back(VisitExpr(cn->args[0])[0]); inputs.push_back(VisitExpr(nodes.conv->args[1])[0]); if (nodes.requantize) { - name = "qnn.conv2d"; + name_prefix = "qnn"; inputs.push_back(VisitExpr(nodes.conv->args[2])[0]); // input zero-point inputs.push_back(VisitExpr(nodes.conv->args[3])[0]); // kernel zero-point inputs.push_back(VisitExpr(nodes.conv->args[4])[0]); // input scale @@ -180,7 +194,7 @@ class ACLJSONSerializer : public backend::contrib::JSONSerializer { inputs.push_back(VisitExpr(nodes.requantize->args[4])[0]); // output zero-point } - auto json_node = std::make_shared<JSONGraphNode>(name, "kernel", inputs, 1); + auto json_node = std::make_shared<JSONGraphNode>(name_prefix + "." 
+ name, "kernel", inputs, 1); SetCallNodeAttribute(json_node, nodes.conv); // Override attributes @@ -224,10 +238,11 @@ class ACLJSONSerializer : public backend::contrib::JSONSerializer { nodes.requantize = current_call; current_call = current_call->args[0].as<CallNode>(); } - if (backend::IsOp(current_call, "nn.bias_add")) { + if (backend::IsOp(current_call, "add")) { nodes.bias = current_call; current_call = current_call->args[0].as<CallNode>(); } + // Enforce a dense node exists at this point during traversal if (nodes.requantize) { ICHECK(backend::IsOp(current_call, "qnn.dense")); @@ -330,25 +345,6 @@ class ACLJSONSerializer : public backend::contrib::JSONSerializer { }; /*! - * \brief Pre-process a module containing functions ready for ACL codegen. - * - * For now we enforce OHWI kernel layout and fold the transforms away. - * - * \param mod The module to be pre-processed. - * \return The processed module. - */ -IRModule PreProcessModule(const IRModule& mod) { - IRModule preprocessed_module; - tvm::Map<String, Array<String>> desired_layouts = {{"nn.conv2d", {"NHWC", "OHWI"}}, - {"qnn.conv2d", {"NHWC", "OHWI"}}}; - preprocessed_module = transform::ConvertLayout(desired_layouts)(mod); - preprocessed_module = transform::FoldConstant()(preprocessed_module); - return preprocessed_module; -} - -TVM_REGISTER_GLOBAL("relay.ext.arm_compute_lib.optimize").set_body_typed(PreProcessModule); - -/*! * \brief Create a runtime module for ACL. 
* * This consists of a series of "serialized functions" which each represent a diff --git a/src/runtime/contrib/arm_compute_lib/acl_runtime.cc b/src/runtime/contrib/arm_compute_lib/acl_runtime.cc index 09879bd..ed8f6ad 100644 --- a/src/runtime/contrib/arm_compute_lib/acl_runtime.cc +++ b/src/runtime/contrib/arm_compute_lib/acl_runtime.cc @@ -32,6 +32,7 @@ #include <arm_compute/core/Types.h> #include <arm_compute/runtime/NEON/functions/NEArithmeticAddition.h> #include <arm_compute/runtime/NEON/functions/NEConvolutionLayer.h> +#include <arm_compute/runtime/NEON/functions/NEDepthwiseConvolutionLayer.h> #include <arm_compute/runtime/NEON/functions/NEElementwiseOperations.h> #include <arm_compute/runtime/NEON/functions/NEFullyConnectedLayer.h> #include <arm_compute/runtime/NEON/functions/NEPoolingLayer.h> @@ -131,6 +132,9 @@ class ACLRuntime : public JSONRuntimeBase { if ("nn.conv2d" == op_name || "qnn.conv2d" == op_name) { CreateConvolution2DLayer(&layer_, node, mm); num_pools++; + } else if ("nn.depthwise_conv2d" == op_name || "qnn.depthwise_conv2d" == op_name) { + CreateDepthwiseConvolution2DLayer(&layer_, node, mm); + num_pools++; } else if ("nn.dense" == op_name || "qnn.dense" == op_name) { CreateFullyConnectedLayer(&layer_, node, mm); num_pools++; @@ -227,12 +231,7 @@ class ACLRuntime : public JSONRuntimeBase { arm_compute::ActivationLayerInfo act_info; if (node.HasAttr("activation_type")) { std::string activation_type = node.GetAttr<std::vector<std::string>>("activation_type")[0]; - if (activation_type == "relu") { - act_info = arm_compute::ActivationLayerInfo( - arm_compute::ActivationLayerInfo::ActivationFunction::RELU); - } else { - LOG(FATAL) << "Unsupported activation function"; - } + act_info = MakeACLActivationInfo(activation_type); } arm_compute::Size2D dilation_2d(std::stoi(dilation[0]), std::stoi(dilation[1])); @@ -270,6 +269,64 @@ class ACLRuntime : public JSONRuntimeBase { } /*! + * \brief Create a 2D depthwise convolution layer. 
+ * + * \param layer The ACL layer to build. Containing inputs, outputs and the ACL function. + * \param node The JSON representation of the operator. + * \param mm The ACL conv2d layer can request auxiliary memory from TVM. + */ + void CreateDepthwiseConvolution2DLayer( + CachedLayer* layer, const JSONGraphNode& node, + const std::shared_ptr<arm_compute::MemoryManagerOnDemand>& mm) { + std::vector<std::string> padding = node.GetAttr<std::vector<std::string>>("padding"); + std::vector<std::string> strides = node.GetAttr<std::vector<std::string>>("strides"); + std::vector<std::string> dilation = node.GetAttr<std::vector<std::string>>("dilation"); + arm_compute::PadStrideInfo pad_stride_info = MakeACLPadStride(padding, strides); + + arm_compute::ActivationLayerInfo act_info; + if (node.HasAttr("activation_type")) { + std::string activation_type = node.GetAttr<std::vector<std::string>>("activation_type")[0]; + act_info = MakeACLActivationInfo(activation_type); + } + + arm_compute::Size2D dilation_2d(std::stoi(dilation[0]), std::stoi(dilation[1])); + + // Collect inputs and outputs, handling both nn.conv2d and qnn.conv2d cases. 
+ std::vector<JSONGraphNodeEntry> inputs = node.GetInputs(); + size_t num_inputs = inputs.size(); + bool has_bias; + if (node.GetOpName() == "qnn.depthwise_conv2d") { + ICHECK(num_inputs >= 8U && num_inputs <= 9U) + << "Quantized convolution requires 9 inputs with a bias, 8 inputs without."; + has_bias = num_inputs == 9; + layer->inputs.push_back(MakeACLTensorFromJSONEntry(inputs[0], &inputs[4], &inputs[2])); + layer->inputs.push_back(MakeACLTensorFromJSONEntry(inputs[1], &inputs[5], &inputs[3])); + if (has_bias) { + layer->inputs.push_back(MakeACLTensorFromJSONEntry(inputs[6])); + } + layer->outputs.push_back( + MakeACLTensorFromJSONNode(node, &inputs[6 + has_bias], &inputs[7 + has_bias])); + } else { + ICHECK(num_inputs >= 2U && num_inputs <= 3U) + << "Convolution requires 3 inputs with a bias, 2 inputs without."; + has_bias = num_inputs == 3; + for (const auto& i : inputs) { + layer->inputs.push_back(MakeACLTensorFromJSONEntry(i)); + } + layer->outputs.push_back(MakeACLTensorFromJSONNode(node)); + } + + // Depth multiplier is the final dimension in acl weights tensor (IWH*M*) + int depth_multiplier = layer->inputs[1].info()->tensor_shape()[3]; + + auto function = std::make_shared<arm_compute::NEDepthwiseConvolutionLayer>(mm); + function->configure(&layer->inputs[0], &layer->inputs[1], + has_bias ? &layer->inputs[2] : nullptr, &layer->outputs[0], pad_stride_info, + depth_multiplier, act_info, dilation_2d); + layer->function = function; + } + + /*! * \brief Create a fully connected (dense) layer. * * \param layer The ACL layer to build. Containing inputs, outputs and the ACL function. 
diff --git a/src/runtime/contrib/arm_compute_lib/acl_utils.cc b/src/runtime/contrib/arm_compute_lib/acl_utils.cc index 604c619..3b26209 100644 --- a/src/runtime/contrib/arm_compute_lib/acl_utils.cc +++ b/src/runtime/contrib/arm_compute_lib/acl_utils.cc @@ -134,6 +134,16 @@ arm_compute::DataType MakeACLDataType(const DLDataType& data_type) { } } +arm_compute::ActivationLayerInfo MakeACLActivationInfo(const std::string& activation_type) { + auto act_func = arm_compute::ActivationLayerInfo::ActivationFunction::IDENTITY; + if (activation_type == "relu") { + act_func = arm_compute::ActivationLayerInfo::ActivationFunction::RELU; + } else { + LOG(FATAL) << "Activation " << activation_type << " unsupported by ACL runtime"; + } + return {act_func}; +} + template <typename T> std::vector<T> GetVectorFromDLTensor(const DLTensor* tensor) { ICHECK(tensor) << "Cannot convert a nullptr"; diff --git a/src/runtime/contrib/arm_compute_lib/acl_utils.h b/src/runtime/contrib/arm_compute_lib/acl_utils.h index 576ed91..dbb006f 100644 --- a/src/runtime/contrib/arm_compute_lib/acl_utils.h +++ b/src/runtime/contrib/arm_compute_lib/acl_utils.h @@ -109,6 +109,15 @@ arm_compute::PadStrideInfo MakeACLPadStride(const std::vector<std::string>& pad, arm_compute::DataType MakeACLDataType(const DLDataType& data_type); /*! + * \brief Convert string to arm_compute::ActivationLayerInfo + * + * \param activation_type A string representing activation function. + * Currently supports the following options: "relu". + * \return arm_compute::ActivationLayerInfo. + */ +arm_compute::ActivationLayerInfo MakeACLActivationInfo(const std::string& activation_type); + +/*! * \brief Get a vector from DLTensor data. * \note Performs a copy of data. 
* diff --git a/tests/python/contrib/test_arm_compute_lib/infrastructure.py b/tests/python/contrib/test_arm_compute_lib/infrastructure.py index c5d711d..80cd584 100644 --- a/tests/python/contrib/test_arm_compute_lib/infrastructure.py +++ b/tests/python/contrib/test_arm_compute_lib/infrastructure.py @@ -303,45 +303,3 @@ def verify_codegen( f"Actual={codegen_str} \n" f"Expected={known_good_codegen_str}" ) - - -def generate_trials(space, r_factor=3): - """Generates a series of trials. - - This algorithm generates a series of non-deterministic trials given a - space of options to test. A trial is generated by pulling a value from - each option in the space. On some occasions the values are shuffled to - ensure a different trial on each r_factor iteration. The algorithm ensures - that each value from an option is used at least once. The total number of - trials is determined by the r_factor * the option with the largest number - of values. - - Parameters - ---------- - space: List[List[Any]] - A list of different options with varying values to test. - r_factor: (optional) int - The repeat factor. - - Returns - ------- - A list of trials specifying values for each option. 
- - """ - np.random.seed(0) - max_len = 1 - for option in space: - max_len = max(max_len, len(option)) - - num_trials = r_factor * max_len - trials = [] - for i in range(num_trials): - trial = [] - for option in space: - if i % len(option) == 0: - np.random.shuffle(option) - trial.append(option[i % len(option)]) - - trials.append(trial) - - return trials diff --git a/tests/python/contrib/test_arm_compute_lib/test_conv2d.py b/tests/python/contrib/test_arm_compute_lib/test_conv2d.py index 4496a2a..cc5bbfe 100644 --- a/tests/python/contrib/test_arm_compute_lib/test_conv2d.py +++ b/tests/python/contrib/test_arm_compute_lib/test_conv2d.py @@ -21,15 +21,14 @@ import numpy as np import tvm from tvm import relay -from .infrastructure import ( +from test_arm_compute_lib.infrastructure import ( skip_runtime_test, skip_codegen_test, build_and_run, verify, verify_codegen, - generate_trials, ) -from .infrastructure import Device +from test_arm_compute_lib.infrastructure import Device def _get_model( @@ -57,7 +56,12 @@ def _get_model( if len(padding) == 2: padding = (padding[0], padding[1], padding[0], padding[1]) shape = (shape[0], shape[1] + padding[0] * 2, shape[2] + padding[1] * 2, shape[3]) - weight_shape = (kernel_h, kernel_w, shape[3] // groups, channels) + is_depthwise = shape[3] == channels == groups + weight_format = "HWOI" if is_depthwise else "HWIO" + if weight_format == "HWIO": + weight_shape = (kernel_h, kernel_w, shape[3] // groups, channels) + else: + weight_shape = (kernel_h, kernel_w, channels, shape[3] // groups) w = tvm.nd.array(np.random.uniform(-128, 127, weight_shape).astype(dtype)) weights = relay.const(w, dtype) out = relay.nn.conv2d( @@ -65,7 +69,7 @@ def _get_model( weights, kernel_size=(kernel_h, kernel_w), data_layout="NHWC", - kernel_layout="HWIO", + kernel_layout=weight_format, dilation=dilation, strides=strides, padding=padding, @@ -75,7 +79,8 @@ def _get_model( ) params = {"w": w} if has_bias: - b = tvm.nd.array(np.random.uniform(-128, 127, 
weight_shape[3]).astype(dtype)) + bias_shape = weight_shape[2] if is_depthwise else weight_shape[3] + b = tvm.nd.array(np.random.uniform(-128, 127, bias_shape).astype(dtype)) biasc = relay.const(b, dtype) out = relay.nn.bias_add(out, biasc, axis=3) params["b"] = b @@ -134,7 +139,12 @@ def _get_qnn_model( if len(padding) == 2: padding = (padding[0], padding[1], padding[0], padding[1]) shape = (shape[0], shape[1] + padding[0] * 2, shape[2] + padding[1] * 2, shape[3]) - weight_shape = (kernel_h, kernel_w, shape[3] // groups, channels) + is_depthwise = shape[3] == channels == groups + weight_format = "HWOI" if is_depthwise else "HWIO" + if weight_format == "HWIO": + weight_shape = (kernel_h, kernel_w, shape[3] // groups, channels) + else: + weight_shape = (kernel_h, kernel_w, channels, shape[3] // groups) w = tvm.nd.array(np.random.uniform(0, 255, weight_shape).astype(dtype)) weights = relay.const(w, dtype) out = relay.qnn.op.conv2d( @@ -146,7 +156,7 @@ def _get_qnn_model( kernel_scale=relay.const(kernel_sc, "float32"), kernel_size=(kernel_h, kernel_w), data_layout="NHWC", - kernel_layout="HWIO", + kernel_layout=weight_format, dilation=dilation, strides=strides, padding=padding, @@ -156,7 +166,8 @@ def _get_qnn_model( ) params = {"w": w} if has_bias: - b = tvm.nd.array(np.random.uniform(0, 255, weight_shape[3]).astype("int32")) + bias_shape = weight_shape[2] if is_depthwise else weight_shape[3] + b = tvm.nd.array(np.random.uniform(-128, 127, bias_shape).astype("int32")) biasc = relay.const(b, "int32") out = relay.nn.bias_add(out, biasc, axis=3) params["b"] = b @@ -188,21 +199,30 @@ def _get_expected_codegen( ): if len(padding) == 2: padding = (padding[0], padding[1], padding[0], padding[1]) - weight_shape = (channels, kernel_h, kernel_w, shape[3] // groups) output_height = ((shape[1] - kernel_h + padding[0] + padding[2]) / strides[0]) + 1 output_width = ((shape[2] - kernel_w + padding[1] + padding[3]) / strides[1]) + 1 output_shape = (1, int(output_height), 
int(output_width), channels) out_dtype = "int32" if dtype == "uint8" else "float32" + is_depthwise = shape[3] == channels == groups + weight_format = "IHWO" if is_depthwise else "OHWI" + if weight_format == "IHWO": + weight_shape = (shape[3] // groups, kernel_h, kernel_w, channels) + else: + weight_shape = (channels, kernel_h, kernel_w, shape[3] // groups) + if is_depthwise: + name = "nn.depthwise_conv2d" + else: + name = "nn.conv2d" node = { "op": "kernel", - "name": "nn.conv2d", + "name": name, "inputs": [], "attrs": { - "groups": [["1"]], + "groups": [[str(groups)]], "num_outputs": "1", "data_layout": [["NHWC"]], - "kernel_layout": [["OHWI"]], + "kernel_layout": [[weight_format]], "channels": [[str(channels)]], "dilation": [[str(dilation[0]), str(dilation[1])]], "out_layout": [[""]], @@ -229,7 +249,7 @@ def _get_expected_codegen( # qnn.conv2d params, input and kernel if dtype == "uint8": - node["name"] = "qnn.conv2d" + node["name"] = "qnn." + node["name"].split(".")[1] for param_dtype in ["int32", "float32"]: for _ in range(2): inputs.append( @@ -246,7 +266,10 @@ def _get_expected_codegen( { "op": "const", "name": "", - "attrs": {"shape": [[[weight_shape[0]]]], "dtype": [[bias_dtype]]}, + "attrs": { + "shape": [[[1, 1, 1, weight_shape[3] if is_depthwise else weight_shape[0]]]], + "dtype": [[bias_dtype]], + }, } ) @@ -275,29 +298,43 @@ def test_conv2d(): device = Device() np.random.seed(0) - kernel_hs = [1, 2, 3, 5] - kernel_ws = [1, 2, 3, 5] - pad = [(1, 1), (2, 2), (2, 1)] - strides = [(1, 1), (2, 2)] - dilation = [(1, 1)] - out_channels = [4, 7, 16] - input_shapes = [(10, 10, 14), (12, 15, 16), (20, 20, 20)] - # composite operator (pad, bias, activation) - composite = [ - (False, False, False), - (False, True, False), - (False, False, True), - (False, True, True), - (True, False, False), - ] dtype = "float32" - trials = generate_trials( - [kernel_hs, kernel_ws, pad, strides, dilation, out_channels, input_shapes, composite], 3 - ) + trials = [ + # Normal 
convolution + [2, 2, (1, 1), (1, 1), (1, 1), 4, (10, 10, 14), (False, False, False), False], + [2, 1, (2, 2), (1, 1), (1, 1), 7, (12, 15, 16), (False, False, True), False], + [3, 3, (2, 1), (1, 1), (1, 1), 4, (10, 10, 14), (False, True, False), False], + [3, 3, (1, 1), (1, 1), (1, 1), 16, (12, 15, 16), (False, False, False), False], + [5, 5, (1, 1), (2, 2), (1, 1), 4, (10, 10, 14), (True, False, False), False], + [1, 3, (1, 1), (1, 1), (1, 1), 7, (20, 20, 20), (False, False, True), False], + [2, 2, (2, 2), (1, 1), (1, 1), 4, (20, 20, 20), (False, True, False), False], + [5, 5, (1, 1), (2, 2), (1, 1), 4, (10, 10, 14), (True, False, False), False], + [3, 3, (2, 1), (1, 1), (1, 1), 7, (20, 20, 20), (False, False, False), False], + [3, 3, (1, 1), (2, 2), (1, 1), 16, (10, 10, 14), (False, True, True), False], + # Depth-wise convolution + [3, 3, (1, 1), (1, 1), (1, 1), 20, (20, 20, 20), (False, False, True), True], + [5, 5, (2, 2), (1, 1), (1, 1), 20, (20, 20, 20), (False, True, False), True], + [3, 3, (2, 2), (2, 2), (1, 1), 14, (10, 10, 14), (True, False, False), True], + [5, 5, (0, 0), (1, 1), (1, 1), 20, (20, 20, 20), (False, False, False), True], + [3, 3, (1, 1), (2, 2), (1, 1), 14, (10, 10, 14), (False, True, True), True], + ] - for kernel_h, kernel_w, pad, stride, dilation, out_channels, input_shapes, composite in trials: - groups = 1 - shape = (1, *input_shapes) + for ( + kernel_h, + kernel_w, + pad, + stride, + dilation, + out_channels, + shape, + composite, + is_depthwise, + ) in trials: + shape = (1, *shape) + if is_depthwise: + groups = shape[3] + else: + groups = 1 outputs = [] inputs = { "a": tvm.nd.array(np.random.uniform(-128, 127, shape).astype(dtype)), @@ -338,31 +375,43 @@ def test_codegen_conv2d(): if skip_codegen_test(): return - np.random.seed(0) - - kernel_hs = [1, 2, 3, 5] - kernel_ws = [1, 2, 3, 5] - pad = [(1, 1), (2, 2), (2, 1)] - strides = [(1, 1), (2, 2)] - dilation = [(1, 1)] - out_channels = [4, 7, 16] - input_shapes = [(10, 10, 14), (12, 
15, 16), (20, 20, 20)] - # composite operator (pad, bias, activation) - composite = [ - (False, False, False), - (False, True, False), - (False, False, True), - (False, True, True), - (True, False, False), - ] dtype = "float32" - trials = generate_trials( - [kernel_hs, kernel_ws, pad, strides, dilation, out_channels, input_shapes, composite], 3 - ) + trials = [ + # Normal convolution + [2, 2, (1, 1), (1, 1), (1, 1), 4, (10, 10, 14), (False, False, False), False], + [2, 1, (2, 2), (1, 1), (1, 1), 7, (12, 15, 16), (False, False, True), False], + [3, 3, (2, 1), (1, 1), (1, 1), 4, (10, 10, 14), (False, True, False), False], + [3, 3, (1, 1), (1, 1), (1, 1), 16, (12, 15, 16), (False, False, False), False], + [5, 5, (1, 1), (2, 2), (1, 1), 4, (10, 10, 14), (True, False, False), False], + [1, 3, (1, 1), (1, 1), (1, 1), 7, (20, 20, 20), (False, False, True), False], + [2, 2, (2, 2), (1, 1), (1, 1), 4, (20, 20, 20), (False, True, False), False], + [5, 5, (1, 1), (2, 2), (1, 1), 4, (10, 10, 14), (True, False, False), False], + [3, 3, (2, 1), (1, 1), (1, 1), 7, (20, 20, 20), (False, False, False), False], + [3, 3, (1, 1), (2, 2), (1, 1), 16, (10, 10, 14), (False, True, True), False], + # Depth-wise convolution + [3, 3, (1, 1), (1, 1), (1, 1), 20, (20, 20, 20), (False, False, True), True], + [5, 5, (2, 2), (1, 1), (1, 1), 20, (20, 20, 20), (False, True, False), True], + [3, 3, (2, 2), (2, 2), (1, 1), 14, (10, 10, 14), (True, False, False), True], + [5, 5, (0, 0), (1, 1), (1, 1), 20, (20, 20, 20), (False, False, False), True], + [3, 3, (1, 1), (2, 2), (1, 1), 14, (10, 10, 14), (False, True, True), True], + ] - for kernel_h, kernel_w, pad, stride, dilation, out_channels, input_shapes, composite in trials: - groups = 1 - shape = (1, *input_shapes) + for ( + kernel_h, + kernel_w, + pad, + stride, + dilation, + out_channels, + shape, + composite, + is_depthwise, + ) in trials: + shape = (1, *shape) + if is_depthwise: + groups = shape[3] + else: + groups = 1 inputs = {"a"} args = 
(shape, kernel_h, kernel_w, pad, stride, dilation, groups, dtype, out_channels) @@ -389,29 +438,43 @@ def test_qnn_conv2d(): device = Device() np.random.seed(0) - kernel_hs = [1, 2, 3, 5] - kernel_ws = [1, 2, 3, 5] - pad = [(1, 1), (2, 2)] - strides = [(1, 1), (2, 2)] - dilation = [(1, 1)] - out_channels = [4, 7, 16] - input_shapes = [(10, 10, 14), (12, 15, 16), (20, 20, 20)] - # composite operator (pad, bias, activation) - composite = [ - (False, False, False), - (False, True, False), - (False, False, True), - (False, True, True), - (True, False, False), - ] dtype = "uint8" - trials = generate_trials( - [kernel_hs, kernel_ws, pad, strides, dilation, out_channels, input_shapes, composite], 3 - ) + trials = [ + # Normal convolution + [2, 2, (1, 1), (1, 1), (1, 1), 4, (10, 10, 14), (False, False, False), False], + [2, 1, (2, 2), (1, 1), (1, 1), 7, (12, 15, 16), (False, False, True), False], + [3, 3, (2, 1), (1, 1), (1, 1), 4, (10, 10, 14), (False, True, False), False], + [3, 3, (1, 1), (1, 1), (1, 1), 16, (12, 15, 16), (False, False, False), False], + [5, 5, (1, 1), (2, 2), (1, 1), 4, (10, 10, 14), (True, False, False), False], + [1, 3, (1, 1), (1, 1), (1, 1), 7, (20, 20, 20), (False, False, True), False], + [2, 2, (2, 2), (1, 1), (1, 1), 4, (20, 20, 20), (False, True, False), False], + [5, 5, (1, 1), (2, 2), (1, 1), 4, (10, 10, 14), (True, False, False), False], + [3, 3, (2, 1), (1, 1), (1, 1), 7, (20, 20, 20), (False, False, False), False], + [3, 3, (1, 1), (2, 2), (1, 1), 16, (10, 10, 14), (False, True, True), False], + # Depth-wise convolution + [3, 3, (1, 1), (1, 1), (1, 1), 20, (20, 20, 20), (False, False, True), True], + [5, 5, (2, 2), (1, 1), (1, 1), 20, (20, 20, 20), (False, True, False), True], + [3, 3, (2, 2), (2, 2), (1, 1), 14, (10, 10, 14), (True, False, False), True], + [5, 5, (0, 0), (1, 1), (1, 1), 20, (20, 20, 20), (False, False, False), True], + [3, 3, (1, 1), (2, 2), (1, 1), 14, (10, 10, 14), (False, True, True), True], + ] - for kernel_h, 
kernel_w, pad, stride, dilation, out_channels, input_shapes, composite in trials: - groups = 1 - shape = (1, *input_shapes) + for ( + kernel_h, + kernel_w, + pad, + stride, + dilation, + out_channels, + shape, + composite, + is_depthwise, + ) in trials: + shape = (1, *shape) + if is_depthwise: + groups = shape[3] + else: + groups = 1 outputs = [] inputs = {"a": tvm.nd.array(np.random.uniform(0, 255, shape).astype(dtype))} @@ -463,36 +526,52 @@ def test_qnn_conv2d(): "output scale": output_sc, "output zero point": output_zp, } - verify(outputs, atol=1, rtol=0, config=config, verify_saturation=True) + + atol = 2 if is_depthwise else 1 + verify(outputs, atol=atol, rtol=0, config=config, verify_saturation=True) def test_codegen_qnn_conv2d(): if skip_codegen_test(): return - kernel_hs = [1, 2, 3, 5] - kernel_ws = [1, 2, 3, 5] - pad = [(1, 1), (2, 2), (2, 1)] - strides = [(1, 1), (2, 2)] - dilation = [(1, 1)] - out_channels = [4, 7, 16] - input_shapes = [(10, 10, 14), (12, 15, 16), (20, 20, 20)] - # composite operator (pad, bias, activation) - composite = [ - (False, False, False), - (False, True, False), - (False, False, True), - (False, True, True), - (True, False, False), - ] dtype = "uint8" - trials = generate_trials( - [kernel_hs, kernel_ws, pad, strides, dilation, out_channels, input_shapes, composite], 3 - ) + trials = [ + # Normal convolution + [2, 2, (1, 1), (1, 1), (1, 1), 4, (10, 10, 14), (False, False, False), False], + [2, 1, (2, 2), (1, 1), (1, 1), 7, (12, 15, 16), (False, False, True), False], + [3, 3, (2, 1), (1, 1), (1, 1), 4, (10, 10, 14), (False, True, False), False], + [3, 3, (1, 1), (1, 1), (1, 1), 16, (12, 15, 16), (False, False, False), False], + [5, 5, (1, 1), (2, 2), (1, 1), 4, (10, 10, 14), (True, False, False), False], + [1, 3, (1, 1), (1, 1), (1, 1), 7, (20, 20, 20), (False, False, True), False], + [2, 2, (2, 2), (1, 1), (1, 1), 4, (20, 20, 20), (False, True, False), False], + [5, 5, (1, 1), (2, 2), (1, 1), 4, (10, 10, 14), (True, False, 
False), False], + [3, 3, (2, 1), (1, 1), (1, 1), 7, (20, 20, 20), (False, False, False), False], + [3, 3, (1, 1), (2, 2), (1, 1), 16, (10, 10, 14), (False, True, True), False], + # Depth-wise convolution + [3, 3, (1, 1), (1, 1), (1, 1), 20, (20, 20, 20), (False, False, True), True], + [5, 5, (2, 2), (1, 1), (1, 1), 20, (20, 20, 20), (False, True, False), True], + [3, 3, (2, 2), (2, 2), (1, 1), 14, (10, 10, 14), (True, False, False), True], + [5, 5, (0, 0), (1, 1), (1, 1), 20, (20, 20, 20), (False, False, False), True], + [3, 3, (1, 1), (2, 2), (1, 1), 14, (10, 10, 14), (False, True, True), True], + ] - for kernel_h, kernel_w, pad, stride, dilation, out_channels, input_shapes, composite in trials: - groups = 1 - shape = (1, *input_shapes) + for ( + kernel_h, + kernel_w, + pad, + stride, + dilation, + out_channels, + shape, + composite, + is_depthwise, + ) in trials: + shape = (1, *shape) + if is_depthwise: + groups = shape[3] + else: + groups = 1 inputs = {"a"} input_zp = 100 diff --git a/tests/python/contrib/test_arm_compute_lib/test_dense.py b/tests/python/contrib/test_arm_compute_lib/test_dense.py index 0279aa7..dba7be6 100644 --- a/tests/python/contrib/test_arm_compute_lib/test_dense.py +++ b/tests/python/contrib/test_arm_compute_lib/test_dense.py @@ -28,7 +28,6 @@ from test_arm_compute_lib.infrastructure import ( build_and_run, verify, verify_codegen, - generate_trials, ) @@ -184,17 +183,19 @@ def test_dense(): device = Device() np.random.seed(0) - dtype = ["float32"] - shape = [ - (1, (1, 128), (16, 128), 16), - (1, (32, 32), (32, 32), 32), - (0, (1, 64), (1, 64), 1), - (0, (11, 2), (2, 2), 2), + dtype = "float32" + trials = [ + [(1, 128), (16, 128), 16, True, 1], + [(1, 128), (16, 128), 16, False, 1], + [(32, 32), (32, 32), 32, True, 1], + [(32, 32), (32, 32), 32, False, 1], + [(1, 64), (1, 64), 1, True, 0], + [(1, 64), (1, 64), 1, False, 0], + [(11, 2), (2, 2), 2, True, 0], + [(11, 2), (2, 2), 2, False, 0], ] - composite = [False, True] - trials = 
generate_trials([dtype, shape, composite], 3) - for dtype, (acl_partitions, shape, weight_shape, units), composite in trials: + for shape, weight_shape, units, composite, acl_partitions in trials: outputs = [] inputs = {"a": tvm.nd.array(np.random.uniform(-128, 127, shape).astype(dtype))} func, params = _get_model( @@ -230,19 +231,26 @@ def test_codegen_dense(): np.random.seed(0) - dtype = ["float32"] - shape = [(1, (1, 128), (16, 128), 16), (1, (32, 32), (32, 32), 32), (0, (1, 64), (1, 64), 1)] - composite = [False, True] - trials = generate_trials([dtype, shape, composite], 3) + dtype = "float32" + trials = [ + [(1, 128), (16, 128), 16, True, 1], + [(1, 128), (16, 128), 16, False, 1], + [(32, 32), (32, 32), 32, True, 1], + [(32, 32), (32, 32), 32, False, 1], + [(1, 64), (1, 64), 1, True, 0], + [(1, 64), (1, 64), 1, False, 0], + ] - for dtype, (acl_partitions, shape, weight_shape, units), composite in trials: + for shape, weight_shape, units, composite, acl_partitions in trials: inputs = {"a"} args = (shape, weight_shape, units, dtype) func, params = _get_model(*args, var_names=iter(inputs), has_bias=composite) exp_codegen = _get_expected_codegen(*args, has_bias=composite) - verify_codegen(func, exp_codegen, acl_partitions, 1 - acl_partitions) + verify_codegen( + func, exp_codegen, acl_partitions, (1 - acl_partitions) * (2 - int(not composite)) + ) def test_qnn_dense(): @@ -254,19 +262,21 @@ def test_qnn_dense(): device = Device() np.random.seed(0) - dtype = ["uint8"] - shape = [ - (0, (4, 4), (4, 4), 4), - (1, (16, 16), (4, 16), 4), - (1, (1, 128), (16, 128), 16), - (1, (32, 32), (32, 32), 32), - (0, (1, 64), (1, 64), 1), + dtype = "uint8" + trials = [ + [(4, 4), (4, 4), 4, True, 0], + [(4, 4), (4, 4), 4, False, 0], + [(16, 16), (4, 16), 4, True, 1], + [(16, 16), (4, 16), 4, False, 1], + [(1, 128), (16, 128), 16, True, 1], + [(1, 128), (16, 128), 16, False, 1], + [(32, 32), (32, 32), 32, True, 1], + [(32, 32), (32, 32), 32, False, 1], + [(1, 64), (1, 64), 1, 
True, 0], + [(1, 64), (1, 64), 1, False, 0], ] - composite = [False, True] - trials = generate_trials([dtype, shape, composite], 3) - - for dtype, (acl_partitions, shape, weight_shape, units), composite in trials: + for shape, weight_shape, units, composite, acl_partitions in trials: outputs = [] inputs = {"a": tvm.nd.array(np.random.uniform(0, 255, shape).astype(dtype))} input_zp = 100 @@ -328,12 +338,17 @@ def test_codegen_qnn_dense(): np.random.seed(0) - dtype = ["uint8"] - shape = [(1, (1, 128), (16, 128), 16), (1, (32, 32), (32, 32), 32), (0, (1, 64), (1, 64), 1)] - composite = [False, True] - trials = generate_trials([dtype, shape, composite], 3) + dtype = "uint8" + trials = [ + [(1, 128), (16, 128), 16, True, 1], + [(1, 128), (16, 128), 16, False, 1], + [(32, 32), (32, 32), 32, True, 1], + [(32, 32), (32, 32), 32, False, 1], + [(1, 64), (1, 64), 1, True, 0], + [(1, 64), (1, 64), 1, False, 0], + ] - for dtype, (acl_partitions, shape, weight_shape, units), composite in trials: + for shape, weight_shape, units, composite, acl_partitions in trials: inputs = {"a"} args = (shape, weight_shape, units, dtype) @@ -357,7 +372,9 @@ def test_codegen_qnn_dense(): has_bias=composite, ) exp_codegen = _get_expected_codegen(*args, has_bias=composite) - verify_codegen(func, exp_codegen, acl_partitions, 2 - 2 * acl_partitions) + verify_codegen( + func, exp_codegen, acl_partitions, (1 - acl_partitions) * (3 - int(not composite)) + ) if __name__ == "__main__": diff --git a/tests/python/contrib/test_arm_compute_lib/test_network.py b/tests/python/contrib/test_arm_compute_lib/test_network.py index 898446b..462df14 100644 --- a/tests/python/contrib/test_arm_compute_lib/test_network.py +++ b/tests/python/contrib/test_arm_compute_lib/test_network.py @@ -123,7 +123,7 @@ def test_mobilenet(): return mod, params, inputs _build_and_run_network( - *get_model(), device=device, tvm_ops=73, acl_partitions=18, atol=0.002, rtol=0.01 + *get_model(), device=device, tvm_ops=56, acl_partitions=31, 
atol=0.002, rtol=0.01 ) @@ -148,7 +148,7 @@ def test_quantized_mobilenet(): return mod, params, inputs _build_and_run_network( - *get_model(), device=device, tvm_ops=42, acl_partitions=17, atol=8, rtol=0 + *get_model(), device=device, tvm_ops=3, acl_partitions=30, atol=9, rtol=0 )