This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 517c420d7b [TOPI][ADRENO] Add Group Conv2d texture schedule (#17274)
517c420d7b is described below

commit 517c420d7b89029638926f10bbe9bed27f23bb5f
Author: krishnaraj36 <quic_kvegi...@quicinc.com>
AuthorDate: Mon Aug 19 18:22:45 2024 +0530

    [TOPI][ADRENO] Add Group Conv2d texture schedule (#17274)
    
    * Added Support for Adreno Texture Based Group Convolution
    
    * Added Few Testcases and Fixed Compute
    
    * Limited Support for Group Convolution
    
    * Removed Dead Code, Fixed Minor Issues
    
    ---------
    
    Co-authored-by: Sanjay Shankar Krishnaa <sa...@qti.qualcomm.com>
---
 python/tvm/relay/op/strategy/adreno.py             |  31 +-
 python/tvm/topi/adreno/__init__.py                 |   1 +
 python/tvm/topi/adreno/group_conv2d_nchw.py        | 386 +++++++++++++++++++++
 .../test_group_conv2d_nchw_texture.py              | 208 +++++++++++
 4 files changed, 625 insertions(+), 1 deletion(-)

diff --git a/python/tvm/relay/op/strategy/adreno.py b/python/tvm/relay/op/strategy/adreno.py
index bacace9ad4..99e4d0a405 100644
--- a/python/tvm/relay/op/strategy/adreno.py
+++ b/python/tvm/relay/op/strategy/adreno.py
@@ -182,8 +182,37 @@ def conv2d_strategy_adreno(attrs, inputs, out_type, 
target):
                     + kernel_layout
                     + ") - only support NCHW4c / OIHW4o and NHWC / HWOI 
layouts for conv2d"
                 )
+        elif (data_layout == "NCHW4c" or data_layout == "NCHW") and (
+            kernel_layout == "OIHW" or kernel_layout == "OIHW4o"
+        ):
+            pad_in_chunks = (len(data.shape) == 5 and data.shape[1] % groups 
!= 0) or (
+                len(data.shape) == 4 and data.shape[1] % (groups * 4) != 0
+            )
+            pad_out_chunks = (len(kernel.shape) == 5 and kernel.shape[0] % 
groups != 0) or (
+                len(kernel.shape) == 4 and kernel.shape[0] % (groups * 4) != 0
+            )
+
+            if not (pad_in_chunks or pad_out_chunks):
+                strategy.add_implementation(
+                    wrap_compute_conv2d(topi.adreno.group_conv2d_nchwc),
+                    
wrap_topi_schedule(topi.adreno.schedule_group_conv2d_nchwc),
+                    name="group_conv2d_nchwc.image2d",
+                    plevel=10,
+                )
+            elif len(data.shape) == 4 and len(kernel.shape) == 4:
+                strategy.add_implementation(
+                    wrap_compute_conv2d(topi.cuda.group_conv2d_nchw, 
has_groups=True),
+                    wrap_topi_schedule(topi.cuda.schedule_group_conv2d_nchw),
+                    name="group_conv2d_nchw.cuda",
+                )
+            else:
+                raise RuntimeError(
+                    "General group convolution is not currently supported for 
NCHWc layouts"
+                )
         else:
-            raise RuntimeError("General group convolution is not currently 
supported")
+            raise RuntimeError(
+                "General group convolution has limited support for NCHW(4c) 
layouts..."
+            )
     return strategy
 
 
diff --git a/python/tvm/topi/adreno/__init__.py b/python/tvm/topi/adreno/__init__.py
index cd42848b29..2c0ed20f10 100644
--- a/python/tvm/topi/adreno/__init__.py
+++ b/python/tvm/topi/adreno/__init__.py
@@ -20,6 +20,7 @@
 from .conv2d_nchw import *
 from .depthwise_conv2d_nchw import *
 from .conv2d_nhwc import *
+from .group_conv2d_nchw import *
 from .depthwise_conv2d_nhwc import *
 from .pooling import *
 from .conv2d_alter_op import *
diff --git a/python/tvm/topi/adreno/group_conv2d_nchw.py b/python/tvm/topi/adreno/group_conv2d_nchw.py
new file mode 100644
index 0000000000..f1ab7fcf0e
--- /dev/null
+++ b/python/tvm/topi/adreno/group_conv2d_nchw.py
@@ -0,0 +1,386 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name,unused-variable,unused-argument,no-else-return
+
+"""Group Conv2d NCHW Operator wt Schedule on Qualcomm Adreno GPU"""
+import tvm
+from tvm import te
+from tvm import autotvm
+
+from ..utils import get_const_tuple, traverse_inline
+from .utils import (
+    split_to_chunks,
+    pack_input,
+    pack_filter,
+    expand_spatial_dimensions,
+    add_pad,
+    bind_data_copy,
+    get_default_conv2d_config,
+    get_texture_storage,
+)
+
+
+@autotvm.register_topi_schedule("group_conv2d_nchwc.image2d")
+def schedule_group_conv2d_nchwc(cfg, outs):
+    """Create the schedule for group_conv2d_nchw"""
+    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
+    s = te.create_schedule([x.op for x in outs])
+
+    def _callback(op):
+        if op.tag == "adreno_group_conv2d_latest_op":
+            schedule_group_conv2d_NCHWc_KCRSk(cfg, s, op.output(0))
+
+    traverse_inline(s, outs[0].op, _callback)
+    return s
+
+
+@autotvm.register_topi_compute("group_conv2d_nchwc.image2d")
+def group_conv2d_nchwc(cfg, Input, Filter, stride, padding, dilation, 
out_dtype):
+    """
+    Group Convolution Operator in NCHWc layout.
+    Algo:
+      1. Convert into blocked format if we have a 4d original tensor.
+         When tuning with AutoTVM we replace the conversion with plain placeholder tensors,
+         since such conversion is absent for real blocked convolution and there is no sense
+         in including it in tuning
+      2. Expand spatial dimensions so that width and height are divisible by a factor of 4.
+         This leads to a slightly larger amount of compute but allows the GPU to be utilized
+         much better
+      3. Add paddings. This happens even if we do not need pad originally. This is useful
+         due to work surrounding the gaps of texture annotation between Primary Functions
+         and limited support of textures in schedules. Later on this pad will be executed
+         separately and will produce texture
+      4. 5d Convolution compute with accumulating into out_dtype
+      5. Cast to the origin output data type
+      6. For case of 4d convolution: convert of output from 5d to 4d
+    """
+
+    if out_dtype is None:
+        out_dtype = Input.dtype
+
+    assert isinstance(stride, int) or len(stride) == 2
+    assert isinstance(dilation, int) or len(dilation) == 2
+
+    if isinstance(stride, int):
+        stride_h = stride_w = stride
+    else:
+        stride_h, stride_w = stride
+    if isinstance(dilation, int):
+        dilation_h = dilation_w = dilation
+    else:
+        dilation_h, dilation_w = dilation
+
+    convert_from4d = False
+    if len(Input.shape) == 4:
+        batch, in_channels, in_height, in_width = Input.shape
+        in_channel_chunks, in_channel_block, in_channel_tail = 
split_to_chunks(in_channels, 4)
+
+        if autotvm.GLOBAL_SCOPE.in_tuning:
+            dshape = (batch, in_channel_chunks, in_height, in_width, 
in_channel_block)
+            Input = tvm.te.placeholder(dshape, Input.dtype, 
name="data_placeholder")
+        else:
+            Input = pack_input(
+                Input,
+                "NCHW",
+                batch,
+                in_channel_chunks,
+                in_channel_block,
+                in_channel_tail,
+                in_height,
+                in_width,
+            )
+    else:
+        batch, in_channel_chunks, in_height, in_width, in_channel_block = 
Input.shape
+        in_channels = in_channel_chunks * in_channel_block
+
+    if len(Filter.shape) == 4:
+        out_channels, in_filter_channels, kernel_h, kernel_w = Filter.shape
+        out_channel_chunks, out_channel_block, out_channel_tail = 
split_to_chunks(out_channels, 4)
+
+        if autotvm.GLOBAL_SCOPE.in_tuning:
+            kshape = (out_channel_chunks, in_filter_channels, kernel_h, 
kernel_w, out_channel_block)
+            Filter = tvm.te.placeholder(kshape, Filter.dtype, 
name="kernel_placeholder")
+        else:
+            convert_from4d = True
+            Filter = pack_filter(
+                Filter,
+                "OIHW",
+                out_channel_chunks,
+                out_channel_block,
+                out_channel_tail,
+                in_filter_channels,
+                in_channel_chunks,
+                in_channel_block,
+                in_channel_tail,
+                kernel_h,
+                kernel_w,
+            )
+    else:
+        out_channel_chunks, in_filter_channels, kernel_h, kernel_w, 
out_channel_block = Filter.shape
+        out_channels = out_channel_chunks * out_channel_block
+
+    assert in_channels % in_filter_channels == 0
+    groups = in_channels // in_filter_channels
+
+    # Compute Constraints...
+    assert out_channel_chunks % groups == 0
+    assert in_channel_chunks % groups == 0
+
+    out_height_orig, out_height, out_width_orig, out_width = 
expand_spatial_dimensions(
+        in_height, in_width, kernel_h, kernel_w, dilation_h, dilation_w, 
padding, stride_h, stride_w
+    )
+
+    temp = add_pad(
+        Input,
+        "NCHW",
+        out_height_orig,
+        out_width_orig,
+        kernel_h,
+        kernel_w,
+        dilation_h,
+        dilation_w,
+        padding,
+        stride_h,
+        stride_w,
+    )
+
+    in_group_channel_chunks = in_channel_chunks // groups
+    in_group_channel_block = in_channel_block
+    out_group_channel_chunks = out_channel_chunks // groups
+    rcc = te.reduce_axis((0, in_group_channel_chunks), name="rcc")
+    rcb = te.reduce_axis((0, in_group_channel_block), name="rcb")
+    ry = te.reduce_axis((0, kernel_h), name="ry")
+    rx = te.reduce_axis((0, kernel_w), name="rx")
+
+    conv = te.compute(
+        (batch, out_channel_chunks, out_height, out_width, out_channel_block),
+        lambda nn, occ, yy, xx, obb: te.sum(
+            (
+                temp[
+                    nn,
+                    occ // out_group_channel_chunks * in_group_channel_chunks 
+ rcc,
+                    yy * stride_h + ry * dilation_h,
+                    xx * stride_w + rx * dilation_w,
+                    rcb,
+                ]
+                * Filter[occ, rcc * in_group_channel_block + rcb, ry, rx, obb]
+            ).astype(out_dtype),
+            axis=[rcc, rcb, ry, rx],
+        ),
+        tag="conv2d_nchwc_group",
+    )
+
+    if convert_from4d and not autotvm.GLOBAL_SCOPE.in_tuning:
+        dummy_cast = te.compute(
+            (batch, out_channel_chunks, out_height_orig, out_width_orig, 
out_channel_block),
+            lambda n, fc, y, x, fb: conv[n, fc, y, x, fb].astype(out_dtype),
+            tag="dummy_cast",
+        )
+        return te.compute(
+            (batch, out_channels, out_height_orig, out_width_orig),
+            lambda n, c, y, x: dummy_cast[n, c // out_channel_block, y, x, c % 
out_channel_block],
+            tag="adreno_group_conv2d_latest_op",
+        )
+    else:
+        return te.compute(
+            (batch, out_channel_chunks, out_height_orig, out_width_orig, 
out_channel_block),
+            lambda n, ffc, y, x, ffb: conv[n, ffc, y, x, 
ffb].astype(out_dtype),
+            tag="adreno_group_conv2d_latest_op",
+        )
+
+
+def schedule_group_conv2d_NCHWc_KCRSk(cfg, s, output):
+    """
+    Schedule optimized for batch size = 1
+
+    Algo:
+    1. Split output axis to three parts: global work size, vthread, local worksize.
+       The limitations for tuning include heuristics from some tuned networks to limit
+       the search space and not spend much time on useless configurations.
+    2. In case of 4d convolution schedule copying of the input (and filter) 
into
+      5d tensors
+    4. pad should be scheduled separately to create independent opencl kernel. 
If pad is
+       inlined into convolution, this gives 1.5x performance drop
+    5. We are using cache_read for intermediate tensors to produce texture and guarantee
+       the best performance on the next stage.
+       The weights are managed through static texture planning mechanism and are guaranteed
+       to come in texture memory scope.
+       This way we are calling cache_read only for data tensor
+    6. For 5d convolution we schedule the latest op with binding 5d axis and 
vectorize
+       for textures
+       For 4d tensor we are doing the same for the latest blocked stage, i.e. 
conversion
+       of data type
+    7. In case of 4d conv we need to schedule postops as well
+    """
+    latest = s.outputs[0].output(0)
+    if len(latest.op.axis) == 4:
+        latest_blocked = dummy = output.op.input_tensors[0]
+        conv = dummy.op.input_tensors[0]
+    else:
+        conv = output.op.input_tensors[0]
+        latest_blocked = latest
+
+    pad_data, kernel = s[conv].op.input_tensors
+    filter_pack_rt = bool(
+        isinstance(kernel.op, tvm.te.ComputeOp) and "filter_pack" in 
kernel.op.tag
+    )
+
+    if "pad_temp" in pad_data.op.name:
+        input_pad_temp = pad_data.op.input_tensors[0]
+    else:
+        input_pad_temp = pad_data
+
+    input_pack_rt = bool(
+        isinstance(input_pad_temp.op, tvm.te.ComputeOp) and "input_pack" in 
input_pad_temp.op.tag
+    )
+
+    ##### space definition begin #####
+    n, fc, y, x, fb = s[conv].op.axis
+    rcc, rcb, ry, rx = s[conv].op.reduce_axis
+
+    if conv.shape[1] % 2 == 0:
+        min_threads_div = 2
+    else:
+        min_threads_div = 1
+    cfg.define_split(
+        "tile_fc",
+        fc,
+        num_outputs=3,
+        filter=lambda entity: entity.size[1] <= 8
+        and entity.size[2] >= min_threads_div
+        and entity.size[2] < 256,
+    )
+    cfg.define_split(
+        "tile_y",
+        y,
+        num_outputs=3,
+        filter=lambda entity: entity.size[1] <= 8 and entity.size[2] <= 16,
+    )
+    cfg.define_split(
+        "tile_x",
+        x,
+        num_outputs=3,
+        filter=lambda entity: entity.size[1] <= 8 and entity.size[2] <= 16,
+    )
+
+    cfg.define_split("tile_rcc", rcc, num_outputs=2)
+    cfg.define_split("tile_ry", ry, num_outputs=2)
+    cfg.define_split("tile_rx", rx, num_outputs=2)
+    cfg.define_knob("auto_unroll_max_step", [0, 512, 1500])
+    cfg.define_knob("unroll_explicit", [0, 1])
+    cfg.multi_filter(
+        filter=lambda entity: (  # pylint: disable=chained-comparison
+            entity["tile_fc"].size[1] * entity["tile_y"].size[1] * 
entity["tile_x"].size[1]
+        )
+        <= 24
+        and 32
+        <= (entity["tile_fc"].size[2] * entity["tile_y"].size[2] * 
entity["tile_x"].size[2])
+        < 1024
+    )
+    if cfg.is_fallback:
+        get_default_conv2d_config(cfg, conv.shape[1], conv.shape[2], 
conv.shape[3])
+    ##### space definition end #####
+
+    pad_data, kernel = s[conv].op.input_tensors
+    # There are several conditions that have to be handled:
+    # 1. If we are in tuning, we always add cache read for data to main conv kernel
+    #    to get texture in tuning opencl kernel
+    # 2. If we are repacking input in runtime, we should always explicitly schedule this one more
+    #    stage of data copy from 4d to 5d (referred as pack_data).
+    # 3. If we have pad (independently if we have runtime repack or not) we should inline it in the
+    #    cache_read("texture")
+    if autotvm.GLOBAL_SCOPE.in_tuning or input_pack_rt:
+        if autotvm.GLOBAL_SCOPE.in_tuning:
+            if "pad_temp" in pad_data.op.name:
+                s[pad_data].compute_inline()
+        else:
+            if "pad_temp" in pad_data.op.name:
+                pack_data = pad_data.op.input_tensors[0]
+                bind_data_copy(s[pack_data])
+                s[pad_data].compute_inline()
+            else:
+                pack_data = pad_data
+                bind_data_copy(s[pack_data])
+
+        AT = s.cache_read(pad_data, get_texture_storage(pad_data.shape), 
[conv])
+        bind_data_copy(s[AT])
+    elif "pad_temp" in pad_data.op.name:
+        s[pad_data].compute_inline()
+        # create cache stage
+        AT = s.cache_read(pad_data, get_texture_storage(pad_data.shape), 
[conv])
+        bind_data_copy(s[AT])
+
+    if autotvm.GLOBAL_SCOPE.in_tuning or filter_pack_rt:
+        if not autotvm.GLOBAL_SCOPE.in_tuning:
+            bind_data_copy(s[kernel])
+        if kernel.shape[2] == 1 and kernel.shape[3] == 1:
+            WT = s.cache_read(kernel, get_texture_storage(kernel.shape), 
[conv])
+            bind_data_copy(s[WT])
+
+    s[conv].set_scope("local")
+    if latest_blocked == latest and output != latest:
+        s[output].compute_inline()
+
+    # tile and bind spatial axes
+    n, fc, y, x, fb = s[latest_blocked].op.axis
+
+    kernel_scope, n = s[latest_blocked].split(n, nparts=1)
+
+    bf, vf, tf = cfg["tile_fc"].apply(s, latest_blocked, fc)
+    by, vy, ty = cfg["tile_y"].apply(s, latest_blocked, y)
+    bx, vx, tx = cfg["tile_x"].apply(s, latest_blocked, x)
+
+    bf = s[latest_blocked].fuse(n, bf)
+    s[latest_blocked].bind(bf, te.thread_axis("blockIdx.z"))
+    s[latest_blocked].bind(by, te.thread_axis("blockIdx.y"))
+    s[latest_blocked].bind(bx, te.thread_axis("blockIdx.x"))
+    s[latest_blocked].bind(vf, te.thread_axis("vthread"))
+    s[latest_blocked].bind(vy, te.thread_axis("vthread"))
+    s[latest_blocked].bind(vx, te.thread_axis("vthread"))
+    s[latest_blocked].bind(tf, te.thread_axis("threadIdx.z"))
+    s[latest_blocked].bind(ty, te.thread_axis("threadIdx.y"))
+    s[latest_blocked].bind(tx, te.thread_axis("threadIdx.x"))
+    s[latest_blocked].reorder(bf, by, bx, vf, vy, vx, tf, ty, tx, fb)
+    s[latest_blocked].vectorize(fb)
+
+    s[conv].compute_at(s[latest_blocked], tx)
+
+    # tile reduction axes
+    n, fc, y, x, fb = s[conv].op.axis
+    rcc, rcb, ry, rx = s[conv].op.reduce_axis
+
+    rco, rci = cfg["tile_rcc"].apply(s, conv, rcc)
+    ryo, ryi = cfg["tile_ry"].apply(s, conv, ry)
+    rxo, rxi = cfg["tile_rx"].apply(s, conv, rx)
+    s[conv].reorder(rco, ryo, rxo, rci, ryi, rxi, rcb, n, fc, y, x, fb)
+    s[conv].unroll(rcb)
+    s[conv].vectorize(fb)
+
+    # unroll
+    s[latest_blocked].pragma(kernel_scope, "auto_unroll_max_step", 
cfg["auto_unroll_max_step"].val)
+    s[latest_blocked].pragma(kernel_scope, "unroll_explicit", 
cfg["unroll_explicit"].val)
+
+    if latest_blocked != latest:
+        s[latest].compute_root()
+        bind_data_copy(s[latest], 1)
+        if latest != output:
+            s[output].compute_inline()
+
+    N, OCC, OH, OW, OCB = get_const_tuple(latest_blocked.shape)
+    _, IC, KH, KW, _ = get_const_tuple(kernel.shape)
+    ICKHKW = IC * KH * KW
+
+    if isinstance(N, int):
+        cfg.add_flop(2 * N * OH * OW * OCC * OCB * ICKHKW)
diff --git a/tests/python/relay/opencl_texture/test_group_conv2d_nchw_texture.py b/tests/python/relay/opencl_texture/test_group_conv2d_nchw_texture.py
new file mode 100644
index 0000000000..bd05610e92
--- /dev/null
+++ b/tests/python/relay/opencl_texture/test_group_conv2d_nchw_texture.py
@@ -0,0 +1,208 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import os
+import re
+import tvm
+import numpy as np
+from tvm import relay
+from tvm.relay import testing
+from utils.adreno_utils import build_run_compare, build_run_compare_vm
+
+executor_type = tvm.testing.parameter("ge", "vm")
+dtype = tvm.testing.parameter("float32")
+
+
+@tvm.testing.requires_opencl
+@tvm.testing.parametrize_targets("opencl -device=adreno")
+def test_group_conv2d_nchwc_adreno_encoder1(remote, target, executor_type, 
dtype):
+    input_shape = (1, 512, 56, 100)
+    filter_shape = (512, 64, 3, 3)
+    bias_shape = (1, 512, 1, 1)
+    A = relay.var("data", shape=input_shape, dtype=dtype)
+    B = relay.var("weight", shape=filter_shape, dtype=dtype)
+    bias = relay.var("bias", shape=bias_shape, dtype=dtype)
+
+    conv = relay.nn.conv2d(
+        A,
+        B,
+        data_layout="NCHW",
+        kernel_layout="OIHW",
+        padding=[1, 1, 1, 1],
+        strides=[1, 1],
+        out_dtype=dtype,
+        channels=512,
+        groups=8,
+        dilation=1,
+        kernel_size=(3, 3),
+    )
+    D = relay.op.add(conv, bias)
+    D = relay.op.nn.relu(D)
+
+    mod = relay.Function([A, B, bias], D)
+    np.random.seed(1)
+    initializer = relay.testing.init.Xavier()
+    filter_data = np.zeros(filter_shape).astype(dtype)
+    bias_data = np.zeros(bias_shape).astype(dtype)
+    initializer("weight", filter_data)
+    initializer("bias", bias_data)
+    params1 = {
+        "weight": tvm.nd.array(filter_data),
+        "bias": tvm.nd.array(bias_data),
+    }
+
+    if executor_type == "ge":
+        build_run_compare(remote, mod, params1, {"data": input_shape}, 
{"data": dtype}, target)
+    else:
+        build_run_compare_vm(remote, mod, params1, {"data": input_shape}, 
{"data": dtype}, target)
+
+
+@tvm.testing.requires_opencl
+@tvm.testing.parametrize_targets("opencl -device=adreno")
+def test_group_conv2d_nchwc_adreno_encoder2(remote, target, executor_type, 
dtype):
+    input_shape = (1, 1024, 56, 100)
+    filter_shape = (512, 128, 3, 3)
+    bias_shape = (1, 512, 1, 1)
+    A = relay.var("data", shape=input_shape, dtype=dtype)
+    B = relay.var("weight", shape=filter_shape, dtype=dtype)
+    bias = relay.var("bias", shape=bias_shape, dtype=dtype)
+
+    conv = relay.nn.conv2d(
+        A,
+        B,
+        data_layout="NCHW",
+        kernel_layout="OIHW",
+        padding=[3, 3, 3, 3],
+        strides=[2, 2],
+        out_dtype=dtype,
+        channels=512,
+        groups=8,
+        dilation=2,
+        kernel_size=(3, 3),
+    )
+    D = relay.op.add(conv, bias)
+    D = relay.op.nn.relu(D)
+
+    mod = relay.Function([A, B, bias], D)
+    np.random.seed(1)
+    initializer = relay.testing.init.Xavier()
+    filter_data = np.zeros(filter_shape).astype(dtype)
+    bias_data = np.zeros(bias_shape).astype(dtype)
+    initializer("weight", filter_data)
+    initializer("bias", bias_data)
+    params1 = {
+        "weight": tvm.nd.array(filter_data),
+        "bias": tvm.nd.array(bias_data),
+    }
+
+    if executor_type == "ge":
+        build_run_compare(remote, mod, params1, {"data": input_shape}, 
{"data": dtype}, target)
+    else:
+        build_run_compare_vm(remote, mod, params1, {"data": input_shape}, 
{"data": dtype}, target)
+
+
+@tvm.testing.requires_opencl
+@tvm.testing.parametrize_targets("opencl -device=adreno")
+def test_group_conv2d_nchwc_adreno_nontrivial(remote, target, executor_type, 
dtype):
+    input_shape = (1, 56, 56, 100)
+    filter_shape = (112, 8, 7, 3)
+    bias_shape = (1, 112, 1, 1)
+    A = relay.var("data", shape=input_shape, dtype=dtype)
+    B = relay.var("weight", shape=filter_shape, dtype=dtype)
+    bias = relay.var("bias", shape=bias_shape, dtype=dtype)
+
+    conv = relay.nn.conv2d(
+        A,
+        B,
+        data_layout="NCHW",
+        kernel_layout="OIHW",
+        padding=[3, 3, 3, 3],
+        strides=[1, 2],
+        out_dtype=dtype,
+        channels=112,
+        groups=7,
+        dilation=2,
+        kernel_size=(7, 3),
+    )
+    D = relay.op.add(conv, bias)
+    D = relay.op.nn.relu(D)
+
+    mod = relay.Function([A, B, bias], D)
+    np.random.seed(1)
+    initializer = relay.testing.init.Xavier()
+    filter_data = np.zeros(filter_shape).astype(dtype)
+    bias_data = np.zeros(bias_shape).astype(dtype)
+    initializer("weight", filter_data)
+    initializer("bias", bias_data)
+    params1 = {
+        "weight": tvm.nd.array(filter_data),
+        "bias": tvm.nd.array(bias_data),
+    }
+
+    if executor_type == "ge":
+        build_run_compare(remote, mod, params1, {"data": input_shape}, 
{"data": dtype}, target)
+    else:
+        build_run_compare_vm(remote, mod, params1, {"data": input_shape}, 
{"data": dtype}, target)
+
+
+@tvm.testing.requires_opencl
+@tvm.testing.parametrize_targets("opencl -device=adreno")
+def test_group_conv2d_nchwc_default(remote, target, executor_type, dtype):
+    input_shape = (1, 49, 56, 100)
+    filter_shape = (343, 7, 3, 3)
+    bias_shape = (1, 343, 1, 1)
+    A = relay.var("data", shape=input_shape, dtype=dtype)
+    B = relay.var("weight", shape=filter_shape, dtype=dtype)
+    bias = relay.var("bias", shape=bias_shape, dtype=dtype)
+
+    # C = relay.nn.relu(A)
+    conv = relay.nn.conv2d(
+        A,
+        B,
+        data_layout="NCHW",
+        kernel_layout="OIHW",
+        padding=[1, 1, 1, 1],
+        strides=[1, 1],
+        out_dtype=dtype,
+        channels=343,
+        groups=7,
+        dilation=1,
+        kernel_size=(3, 3),
+    )
+    D = relay.op.add(conv, bias)
+    D = relay.op.nn.relu(D)
+
+    mod = relay.Function([A, B, bias], D)
+    np.random.seed(1)
+    initializer = relay.testing.init.Xavier()
+    filter_data = np.zeros(filter_shape).astype(dtype)
+    bias_data = np.zeros(bias_shape).astype(dtype)
+    initializer("weight", filter_data)
+    initializer("bias", bias_data)
+    params1 = {
+        "weight": tvm.nd.array(filter_data),
+        "bias": tvm.nd.array(bias_data),
+    }
+
+    if executor_type == "ge":
+        build_run_compare(remote, mod, params1, {"data": input_shape}, 
{"data": dtype}, target)
+    else:
+        build_run_compare_vm(remote, mod, params1, {"data": input_shape}, 
{"data": dtype}, target)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()

Reply via email to