This is an automated email from the ASF dual-hosted git repository.

echuraev pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 015da7c185 [TOPI][ADRENO] Add conv2d transpose nchw texture schedule 
(#15786)
015da7c185 is described below

commit 015da7c185199add4af943970b6ee3d0a0661ec4
Author: krishnaraj36 <[email protected]>
AuthorDate: Wed Nov 15 11:37:40 2023 +0530

    [TOPI][ADRENO] Add conv2d transpose nchw texture schedule (#15786)
    
    * [TOPI][ADRENO] Add conv2d transpose nchw texture schedule
    
    Added the conv2d transpose strategy for adreno target and enable the
    optimized schedule.
    
    * Fix the whitespace lint error
    
    * Fix lint errors
    
    * Fix whitespace lint error
    
    * Removed unused variables
    
    * Add more conv2dTranspose testcases
    
    * empty update
    
    empty update for retrigger ci
    
    * Update test_conv2d_transpose_nchw_texture.py
    
    * Added more testcase to check memory scopes
    
    * Device specific alter_op_layout for conv2d_transpose
    
    * Fix in virtual device setup and added test case with scope check
    
    * Add the comment conv2d algo
    
    * Add the comment conv2d algo
    
    * Removed fp16 test case from texture
    
    It fails on a few GPU devices.
    
    * Remove OpenCL config change to avoid a mainline conflict
    
    * Add a test case for 3-channel input, which runs with the CUDA schedule
    
    * Fix in op strategy for out channel 3
    
    * Comment in test case for memory scope
    
    ---------
    
    Co-authored-by: Siva <[email protected]>
---
 python/tvm/relay/op/nn/_nn.py                      |   6 +
 python/tvm/relay/op/strategy/adreno.py             |  52 +++
 python/tvm/topi/adreno/__init__.py                 |   2 +
 .../tvm/topi/adreno/conv2d_transpose_alter_op.py   | 121 ++++++
 python/tvm/topi/adreno/conv2d_transpose_nchw.py    | 412 +++++++++++++++++++++
 python/tvm/topi/adreno/utils.py                    |  23 ++
 python/tvm/topi/nn/conv2d.py                       |  23 ++
 src/relay/transforms/annotate_texture_storage.cc   |   4 +
 .../test_conv2d_transpose_nchw_texture.py          | 325 ++++++++++++++++
 .../relay/opencl_texture/utils/adreno_utils.py     |   5 +-
 10 files changed, 972 insertions(+), 1 deletion(-)

diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py
index c68685f0ae..6acaf43fe7 100644
--- a/python/tvm/relay/op/nn/_nn.py
+++ b/python/tvm/relay/op/nn/_nn.py
@@ -335,6 +335,12 @@ def legalize_conv2d_transpose(attrs, inputs, types):
     return topi.nn.conv2d_transpose_legalize(attrs, inputs, types)
 
 
+@reg.register_alter_op_layout("nn.conv2d_transpose")
+def alter_op_layout_conv2d_transpose(attrs, inputs, tinfos, out_type):
+    """Alter the layout of conv2d_transpose through the target-dispatched TOPI hook"""
+    return topi.nn.conv2d_transpose_alter_layout(attrs, inputs, tinfos, out_type)
+
+
 @reg.register_convert_op_layout("nn.conv2d_transpose")
 def convert_conv2d_transpose(attrs, inputs, tinfos, desired_layouts):
     """Convert Layout pass registration for conv2d_transpose op.
diff --git a/python/tvm/relay/op/strategy/adreno.py 
b/python/tvm/relay/op/strategy/adreno.py
index c180eeec74..bacace9ad4 100644
--- a/python/tvm/relay/op/strategy/adreno.py
+++ b/python/tvm/relay/op/strategy/adreno.py
@@ -215,6 +215,58 @@ def 
conv2d_winograd_without_weight_transform_strategy_adreno(attrs, inputs, out_
     return strategy
 
 
+@conv2d_transpose_strategy.register("adreno")
+def conv2d_transpose_strategy_adreno(attrs, inputs, out_type, target):
+    """conv2d_transpose adreno strategy
+
+    Ungrouped NCHW/IOHW, NCHW4c/IOHW4o and NCHW/IOHW4o use the texture NCHWc
+    implementation (except 4d kernels with fewer than 4 output channels, which use
+    the CUDA one). Other NCHW cases use the CUDA grouped implementation; any other
+    data layout raises RuntimeError.
+    """
+    strategy = _op.OpStrategy()
+    _, kernel = inputs
+    dilation = attrs.get_int_tuple("dilation")
+    groups = attrs.groups
+    data_layout = attrs.data_layout
+    kernel_layout = attrs.kernel_layout
+    assert dilation == (1, 1), "not support dilate now"
+
+    if (groups == 1) and (
+        (data_layout == "NCHW" and kernel_layout == "IOHW")
+        or (data_layout == "NCHW4c" and kernel_layout == "IOHW4o")
+        or (data_layout == "NCHW" and kernel_layout == "IOHW4o")
+    ):
+        # Textures are read 4 channels at a time, so a plain 4d IOHW kernel with fewer
+        # than 4 output channels cannot be blocked; use the CUDA compute/schedule instead.
+        if len(kernel.shape) == 4:
+            _, oc, _, _ = get_const_tuple(kernel.shape)
+            if oc < 4:
+                strategy.add_implementation(
+                    wrap_compute_conv2d_transpose(topi.cuda.conv2d_transpose_nchw),
+                    wrap_topi_schedule(topi.cuda.schedule_conv2d_transpose_nchw),
+                    name="conv2d_transpose_nchw.cuda",
+                )
+                return strategy
+        strategy.add_implementation(
+            wrap_compute_conv2d_transpose(topi.adreno.conv2d_transpose_nchwc),
+            wrap_topi_schedule(topi.adreno.schedule_conv2d_transpose_nchwc),
+            name="conv2d_transpose_nchwc.image2d",
+            plevel=10,
+        )
+    elif data_layout == "NCHW":
+        strategy.add_implementation(
+            wrap_compute_conv2d_transpose(topi.cuda.conv2d_transpose_nchw, has_groups=True),
+            wrap_topi_schedule(topi.cuda.schedule_conv2d_transpose_nchw),
+            name="conv2d_transpose_nchw.cuda",
+        )
+    else:
+        raise RuntimeError(
+            "Layout not supported: ("
+            + data_layout
+            + ", "
+            + kernel_layout
+            + ") - only support NCHW, NCHW4c / IOHW4o layouts for conv2d_transpose"
+        )
+    return strategy
+
+
 @schedule_pool.register("adreno")
 def schedule_pool_adreno(attrs, outs, target):
     """schedule pooling ops for adreno"""
diff --git a/python/tvm/topi/adreno/__init__.py 
b/python/tvm/topi/adreno/__init__.py
index 55bfbee2a8..cd42848b29 100644
--- a/python/tvm/topi/adreno/__init__.py
+++ b/python/tvm/topi/adreno/__init__.py
@@ -23,7 +23,9 @@ from .conv2d_nhwc import *
 from .depthwise_conv2d_nhwc import *
 from .pooling import *
 from .conv2d_alter_op import *
+from .conv2d_transpose_alter_op import *
 from .conv2d_nchw_winograd import *
 from .conv2d_nhwc_winograd import *
 from .injective import schedule_injective
 from .reduction import *
+from .conv2d_transpose_nchw import *
diff --git a/python/tvm/topi/adreno/conv2d_transpose_alter_op.py 
b/python/tvm/topi/adreno/conv2d_transpose_alter_op.py
new file mode 100644
index 0000000000..c68e5cb7a5
--- /dev/null
+++ b/python/tvm/topi/adreno/conv2d_transpose_alter_op.py
@@ -0,0 +1,121 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name,unused-variable,unused-argument,no-member
+"""Conv2D Transpose alter op for Qualcomm Adreno GPU"""
+
+import logging
+
+import re
+import tvm
+from tvm import te
+from tvm import relay
+from tvm import autotvm
+from ..utils import get_const_tuple
+from ..nn import conv2d_transpose_alter_layout
+
+logger = logging.getLogger("topi")
+
+# Patterns matching already-blocked layouts, e.g. "NCHW4c" (data) and "IOHW4o" (kernel);
+# used below to validate ops whose layouts are already in blocked form.
+_NCHWc_matcher = re.compile("^NCHW[0-9]+c$")
+_IOHWo_matcher = re.compile("^IOHW[0-9]+o$")
+
+
+@conv2d_transpose_alter_layout.register("adreno")
+def _alter_conv2d_transpose_layout(attrs, inputs, tinfos, out_type):
+    """
+    Prepare the new conv2d_transpose with proper target blocked layout attributes.
+    OpenCL textures support 1d/2d/3d/4d textures, but a read always fetches 4 elements
+    in a line, so for now only 4-element blocked conversions are supported:
+    NCHW -> NCHW4c & IOHW -> IOHW4o
+
+    Returns the rewritten relay call, or None to keep the original op unchanged.
+    """
+    target = tvm.target.Target.current(allow_none=False)
+    dispatch_ctx = autotvm.task.DispatchContext.current
+    new_attrs = {k: attrs[k] for k in attrs.keys()}
+
+    # Parse the attributes.
+    padding = attrs.get_int_tuple("padding")
+    strides = attrs.get_int_tuple("strides")
+    output_padding = attrs.get_int_tuple("output_padding")
+    data_layout = attrs["data_layout"]
+    kernel_layout = attrs["kernel_layout"]
+    data_tensor, kernel_tensor = tinfos
+    data_dtype = data_tensor.dtype
+    out_dtype = out_type.dtype
+
+    if isinstance(dispatch_ctx, autotvm.task.ApplyGraphBest):
+        cfg = dispatch_ctx.query(target, None)
+        workload = cfg.workload
+    else:
+        _, outs = relay.backend.te_compiler.select_implementation(
+            relay.op.get("nn.conv2d_transpose"), attrs, tinfos, out_type, target
+        )
+        workload = autotvm.task.get_workload(outs)
+        cfg = dispatch_ctx.query(target, workload)
+
+    # No AutoTVM workload means no template to retarget; keep the op as is.
+    if workload is None:
+        return None
+    topi_tmpl = workload[0]
+
+    if "conv2d_transpose_nchwc" not in topi_tmpl:
+        return None
+
+    if data_layout == "NCHW" and kernel_layout == "IOHW":
+        batch_size, in_channels, height, width = get_const_tuple(data_tensor.shape)
+        in_filter_channels, out_channels, kh, kw = get_const_tuple(kernel_tensor.shape)
+        # Channel block: 4 when the count is divisible by 4, otherwise the remainder.
+        in_channel_block = in_channels % 4
+        if in_channel_block == 0:
+            in_channel_block = 4
+        num_filter_block = out_channels % 4
+        if num_filter_block == 0:
+            num_filter_block = 4
+
+        # no support yet for output channels that are not divisible by 4
+        if num_filter_block != 4:
+            return None
+
+        # update new attrs
+        new_attrs["channels"] = out_channels
+        if in_channel_block == 4:
+            new_attrs["data_layout"] = f"NCHW{in_channel_block}c"
+        else:
+            # Input channels not divisible by 4: keep 4d data, the compute packs it.
+            new_attrs["data_layout"] = "NCHW"
+        # (ic, oc, h, w) -> (ic, OC, h, w, oc)
+        new_attrs["kernel_layout"] = f"IOHW{num_filter_block}o"
+        new_attrs["out_layout"] = f"NCHW{num_filter_block}c"
+
+        # Store altered operator's config for applying of tuned AutoTVM statistics
+        if in_channel_block == 4:
+            new_data = te.placeholder(
+                (batch_size, in_channels // in_channel_block, height, width, in_channel_block),
+                dtype=data_dtype,
+            )
+        else:
+            new_data = data_tensor
+        new_kernel = te.placeholder(
+            (in_filter_channels, out_channels // num_filter_block, kh, kw, num_filter_block),
+            dtype=kernel_tensor.dtype,
+        )
+        # Argument order must follow the compute signature
+        # (Input, Filter, stride, padding, out_dtype, output_padding); otherwise the
+        # stored config never matches the altered op's workload.
+        new_workload = autotvm.task.args_to_workload(
+            [new_data, new_kernel, strides, padding, out_dtype, output_padding],
+            topi_tmpl,  # "conv2d_transpose_nchwc.image2d",
+        )
+        dispatch_ctx.update(target, new_workload, cfg)
+    else:
+        # Already blocked; plain NCHW data with IOHW4o kernel is also supported.
+        assert data_layout == "NCHW" or _NCHWc_matcher.match(data_layout)
+        assert _IOHWo_matcher.match(kernel_layout)
+    return relay.nn.conv2d_transpose(*inputs, **new_attrs)
diff --git a/python/tvm/topi/adreno/conv2d_transpose_nchw.py 
b/python/tvm/topi/adreno/conv2d_transpose_nchw.py
new file mode 100644
index 0000000000..ad8c7b88ef
--- /dev/null
+++ b/python/tvm/topi/adreno/conv2d_transpose_nchw.py
@@ -0,0 +1,412 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name,unused-variable,unused-argument,no-else-return
+"""conv2d_transpose nchw schedule on Qualcomm Adreno GPU"""
+import tvm
+from tvm import te
+from tvm import autotvm
+from .. import nn
+
+
+from ..utils import get_const_tuple, traverse_inline
+from .utils import (
+    split_to_chunks,
+    pack_input,
+    pack_filter,
+    bind_data_copy,
+    get_default_conv2d_config,
+    get_texture_storage,
+)
+
+
+@autotvm.register_topi_compute("conv2d_transpose_nchwc.image2d")
+def conv2d_transpose_nchwc(
+    cfg, Input, Filter, stride, padding, out_dtype, output_padding, groups=1
+):
+    """
+    Transposed convolution operator in NCHWc layout.
+
+    Algo:
+      1. Convert into blocked format if we have a 4d original tensor.
+         During AutoTVM tuning the conversion is replaced by plain placeholders, since
+         such a conversion is absent for a real blocked convolution and there is no
+         sense in including it in tuning.
+      2. Round the computed output width and height up to a multiple of 4.
+         This leads to a slightly bigger amount of compute but utilizes the GPU better.
+      3. Dilate and pad the input ("pad_temp"). This stage is always created, even if
+         no padding was originally needed; it works around gaps in texture annotation
+         between primary functions and limited texture support in schedules, and the
+         schedule can run it separately to produce a texture.
+      4. 5d convolution compute accumulating into out_dtype.
+      5. Cast to the original output data type.
+      6. If a 4d filter was packed at runtime: convert the output from 5d to 4d.
+
+    Parameters
+    ----------
+    cfg : ConfigEntity
+        AutoTVM config; the compute only records the stride on it.
+    Input : te.Tensor
+        4d NCHW or 5d NCHWc data.
+    Filter : te.Tensor
+        4d IOHW or 5d IOHWo kernel.
+    stride : int or tuple of two ints
+    padding : int, tuple or str
+        Forwarded to nn.get_pad_tuple.
+    out_dtype : str or None
+        Output dtype; defaults to Input.dtype.
+    output_padding : tuple of two ints
+        Each value must be smaller than the matching stride.
+    groups : int
+        Not used by this compute.
+
+    Returns
+    -------
+    te.Tensor
+        4d NCHW output if a 4d filter was packed at runtime, otherwise 5d NCHWc.
+    """
+
+    if out_dtype is None:
+        out_dtype = Input.dtype
+    assert isinstance(stride, int) or len(stride) == 2
+    if isinstance(stride, int):
+        stride_h = stride_w = stride
+    else:
+        stride_h, stride_w = stride
+
+    outpad_height, outpad_width = output_padding
+    assert outpad_height < stride_h and outpad_width < stride_w
+
+    convert_from4d = False
+    # Remember the original rank: Input is rebound to its packed 5d form below.
+    data_is_4d = len(Input.shape) == 4
+    if data_is_4d:
+        batch, in_channels, in_height, in_width = Input.shape
+        in_channel_chunks, in_channel_block, in_channel_tail = split_to_chunks(in_channels, 4)
+
+        if autotvm.GLOBAL_SCOPE.in_tuning:
+            dshape = (batch, in_channel_chunks, in_height, in_width, in_channel_block)
+            Input = tvm.te.placeholder(dshape, Input.dtype, name="data_placeholder")
+        else:
+            Input = pack_input(
+                Input,
+                "NCHW",
+                batch,
+                in_channel_chunks,
+                in_channel_block,
+                in_channel_tail,
+                in_height,
+                in_width,
+            )
+    else:
+        batch, in_channel_chunks, in_height, in_width, in_channel_block = Input.shape
+
+    if len(Filter.shape) == 4:
+        in_filter_channels, out_channels, kernel_h, kernel_w = Filter.shape
+        out_channel_chunks, out_channel_block, out_channel_tail = split_to_chunks(out_channels, 4)
+
+        if autotvm.GLOBAL_SCOPE.in_tuning:
+            kshape = (in_filter_channels, out_channel_chunks, kernel_h, kernel_w, out_channel_block)
+            Filter = tvm.te.placeholder(kshape, Filter.dtype, name="kernel_placeholder")
+        else:
+            if not data_is_4d:
+                # in_channel_tail is only known for 4d data; fail with a clear message
+                # instead of a NameError.
+                raise RuntimeError(
+                    "conv2d_transpose_nchwc: a 4d IOHW filter requires a 4d NCHW input"
+                )
+            convert_from4d = True
+            Filter = pack_filter(
+                Filter,
+                "IOHW",
+                out_channel_chunks,
+                out_channel_block,
+                out_channel_tail,
+                in_filter_channels,
+                in_channel_chunks,
+                in_channel_block,
+                in_channel_tail,
+                kernel_h,
+                kernel_w,
+            )
+    else:
+        in_filter_channels, out_channel_chunks, kernel_h, kernel_w, out_channel_block = Filter.shape
+
+    cfg.stride = stride
+
+    pad_top, pad_left, pad_bottom, pad_right = nn.get_pad_tuple(padding, (kernel_h, kernel_w))
+
+    # Transposed-conv output size; the pads are then converted into the equivalent
+    # direct-convolution pads (kernel - 1 - pad) over the stride-dilated input.
+    out_width_orig = out_width = (
+        (in_width - 1) * stride_w + kernel_w - pad_left - pad_right + outpad_width
+    )
+    pad_left = kernel_w - 1 - pad_left
+    pad_right = kernel_w - 1 - pad_right + outpad_width
+    dilated_width = stride_w * (in_width - 1) + 1
+
+    out_height_orig = out_height = (
+        (in_height - 1) * stride_h + kernel_h - pad_top - pad_bottom + outpad_height
+    )
+    pad_top = kernel_h - 1 - pad_top
+    pad_bottom = kernel_h - 1 - pad_bottom + outpad_height
+    dilated_height = stride_h * (in_height - 1) + 1
+
+    # Round the conv extent up to a multiple of 4: first to even, then +2 if needed.
+    if out_height % 2 != 0:
+        out_height += 1
+    if out_width % 2 != 0:
+        out_width += 1
+
+    if out_height % 4 != 0:
+        out_height += 2
+    if out_width % 4 != 0:
+        out_width += 2
+
+    # compute pad: zero-insert the input by the stride and add the converted pads;
+    # only stride-aligned positions inside the dilated region read Input, others are 0.
+    temp = te.compute(
+        (
+            batch,
+            in_channel_chunks,
+            pad_top + dilated_height + pad_bottom,
+            pad_left + dilated_width + pad_right,
+            in_channel_block,
+        ),
+        lambda n, c, y, x, cb: tvm.tir.if_then_else(
+            tvm.tir.all(
+                x >= pad_left,
+                x < pad_left + dilated_width,
+                tvm.tir.indexmod(x - pad_left, stride_w).equal(0),
+                y >= pad_top,
+                y < pad_top + dilated_height,
+                tvm.tir.indexmod(y - pad_top, stride_h).equal(0),
+            ),
+            Input[
+                n,
+                c,
+                tvm.tir.indexdiv(y - pad_top, stride_h),
+                tvm.tir.indexdiv(x - pad_left, stride_w),
+                cb,
+            ],
+            tvm.tir.const(0.0, Input.dtype),
+        ),
+        name="pad_temp",
+    )
+
+    # compute transposed conv: direct convolution with the spatially flipped kernel.
+    # c < out_channel_chunks, so c // out_channel_chunks is 0 and the data channel is dcc.
+    # NOTE(review): dcc * in_channel_block + dcb spans in_channel_chunks * in_channel_block,
+    # which may exceed in_filter_channels when the input channels were padded up to a
+    # multiple of 4 — confirm the Filter read stays in bounds for that case.
+    dcc = te.reduce_axis((0, in_channel_chunks), name="dcc")
+    dcb = te.reduce_axis((0, in_channel_block), name="dcb")
+    dh = te.reduce_axis((0, kernel_h), name="dh")
+    dw = te.reduce_axis((0, kernel_w), name="dw")
+    conv = te.compute(
+        (batch, out_channel_chunks, out_height, out_width, out_channel_block),
+        lambda b, c, h, w, cb: te.sum(
+            temp[
+                b, c // out_channel_chunks * (in_channel_chunks) + dcc, h + dh, w + dw, dcb
+            ].astype(out_dtype)
+            * Filter[
+                dcc * in_channel_block + dcb,
+                c % out_channel_chunks,
+                kernel_h - 1 - dh,
+                kernel_w - 1 - dw,
+                cb,
+            ].astype(out_dtype),
+            axis=[dcc, dcb, dh, dw],
+        ),
+        tag="conv2d_transpose_nchwc",
+    )
+
+    if convert_from4d and not autotvm.GLOBAL_SCOPE.in_tuning:
+        dummy_cast = te.compute(
+            (batch, out_channel_chunks, out_height_orig, out_width_orig, out_channel_block),
+            lambda n, fc, y, x, fb: conv[n, fc, y, x, fb].astype(out_dtype),
+            tag="dummy_cast",
+        )
+        return te.compute(
+            (batch, out_channels, out_height_orig, out_width_orig),
+            lambda n, c, y, x: dummy_cast[n, c // out_channel_block, y, x, c % out_channel_block],
+            tag="adreno_conv2d_transpose_latest_op",
+        )
+    else:
+        return te.compute(
+            (batch, out_channel_chunks, out_height_orig, out_width_orig, out_channel_block),
+            lambda n, ffc, y, x, ffb: conv[n, ffc, y, x, ffb].astype(out_dtype),
+            tag="adreno_conv2d_transpose_latest_op",
+        )
+
+
+@autotvm.register_topi_schedule("conv2d_transpose_nchwc.image2d")
+def schedule_conv2d_transpose_nchwc(cfg, outs):
+    """Create the schedule for conv2d_transpose_nchwc
+
+    Traverses the graph from outs and schedules the op tagged
+    "adreno_conv2d_transpose_latest_op" with schedule_conv2d_transpose_NCHWc.
+    """
+    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
+    s = te.create_schedule([x.op for x in outs])
+
+    def _callback(op):
+        if op.tag == "adreno_conv2d_transpose_latest_op":
+            schedule_conv2d_transpose_NCHWc(cfg, s, op.output(0))
+
+    traverse_inline(s, outs[0].op, _callback)
+    return s
+
+
+def schedule_conv2d_transpose_NCHWc(cfg, s, output):
+    """
+    schedule optimized for batch size = 1
+
+    Algo:
+    1. Split output axis to three parts: global work size, vthread, local worksize.
+       The limitations for tuning include heuristics from some tuned networks to limit
+       the search space and not spend much time on useless configurations.
+    2. In case of 4d convolution schedule copying of the input (and filter) into
+       5d tensors
+    3. The pad stage is inlined into a texture cache_read stage that is bound as its own
+       data copy instead of being fused into the convolution (fusing pad into the
+       convolution was observed to give a 1.5x performance drop).
+    4. cache_read is used for intermediate tensors to produce textures and guarantee
+       the best performance on the next stage. The weights are expected to arrive in
+       texture memory scope via static texture planning, so a cache_read for the kernel
+       is added only for 1x1 kernels when tuning or repacking at runtime.
+    5. For 5d convolution we schedule the latest op with binding 5d axis and vectorize
+       for textures.
+       For 4d tensor we are doing the same for the latest blocked stage, i.e. conversion
+       of data type
+    6. In case of 4d conv we need to schedule postops as well
+
+    Parameters
+    ----------
+    cfg : ConfigEntity
+        AutoTVM config; the tiling and unroll knobs are defined on it here.
+    s : Schedule
+        Schedule modified in place.
+    output : te.Tensor
+        Output of the op tagged "adreno_conv2d_transpose_latest_op".
+    """
+    latest = s.outputs[0].output(0)
+    if len(latest.op.axis) == 4:
+        latest_blocked = dummy = output.op.input_tensors[0]
+        conv = dummy.op.input_tensors[0]
+    else:
+        conv = output.op.input_tensors[0]
+        latest_blocked = latest
+
+    pad_data, kernel = s[conv].op.input_tensors
+    filter_pack_rt = bool(
+        isinstance(kernel.op, tvm.te.ComputeOp) and "filter_pack" in kernel.op.tag
+    )
+
+    if "pad_temp" in pad_data.op.name:
+        input_pad_temp = pad_data.op.input_tensors[0]
+    else:
+        input_pad_temp = pad_data
+
+    input_pack_rt = bool(
+        isinstance(input_pad_temp.op, tvm.te.ComputeOp) and "input_pack" in input_pad_temp.op.tag
+    )
+
+    ##### space definition begin #####
+    n, fc, y, x, fb = s[conv].op.axis
+    rcc, rcb, ry, rx = s[conv].op.reduce_axis
+
+    if conv.shape[1] % 2 == 0:
+        min_threads_div = 2
+    else:
+        min_threads_div = 1
+    cfg.define_split(
+        "tile_fc",
+        fc,
+        num_outputs=3,
+        filter=lambda entity: entity.size[1] <= 8
+        and entity.size[2] >= min_threads_div
+        and entity.size[2] < 256,
+    )
+    cfg.define_split(
+        "tile_y",
+        y,
+        num_outputs=3,
+        filter=lambda entity: entity.size[1] <= 8 and entity.size[2] <= 16,
+    )
+    cfg.define_split(
+        "tile_x",
+        x,
+        num_outputs=3,
+        filter=lambda entity: entity.size[1] <= 8 and entity.size[2] <= 16,
+    )
+
+    cfg.define_split("tile_rcc", rcc, num_outputs=2)
+    cfg.define_split("tile_ry", ry, num_outputs=2)
+    cfg.define_split("tile_rx", rx, num_outputs=2)
+    cfg.define_knob("auto_unroll_max_step", [0, 64])
+    cfg.define_knob("unroll_explicit", [0, 1])
+    cfg.multi_filter(
+        filter=lambda entity: (  # pylint: disable=chained-comparison
+            entity["tile_fc"].size[1] * entity["tile_y"].size[1] * entity["tile_x"].size[1]
+        )
+        <= 24
+        and 32
+        <= (entity["tile_fc"].size[2] * entity["tile_y"].size[2] * entity["tile_x"].size[2])
+        < 1024
+    )
+    if cfg.is_fallback:
+        get_default_conv2d_config(cfg, conv.shape[1], conv.shape[2], conv.shape[3])
+    ##### space definition end #####
+
+    # There are several conditions that have to be handled:
+    # 1. If we are in the tuning, we always add cache read for data to main conv kernel
+    #    to get texture in tuning opencl kernel
+    # 2. If we are repacking input in runtime, we should always explicit schedule this one more
+    #    stage of data copy from 4d to 5d (referred as pack_data).
+    # 3. If we have pad (independently if we have runtime repack or not) we should inline it in the
+    #    cache_read("texture")
+    if autotvm.GLOBAL_SCOPE.in_tuning or input_pack_rt:
+        if autotvm.GLOBAL_SCOPE.in_tuning:
+            if "pad_temp" in pad_data.op.name:
+                s[pad_data].compute_inline()
+        else:
+            if "pad_temp" in pad_data.op.name:
+                pack_data = pad_data.op.input_tensors[0]
+                bind_data_copy(s[pack_data])
+                s[pad_data].compute_inline()
+            else:
+                pack_data = pad_data
+                bind_data_copy(s[pack_data])
+
+        AT = s.cache_read(pad_data, get_texture_storage(pad_data.shape), [conv])
+        bind_data_copy(s[AT])
+    elif "pad_temp" in pad_data.op.name:
+        s[pad_data].compute_inline()
+        # create cache stage
+        AT = s.cache_read(pad_data, get_texture_storage(pad_data.shape), [conv])
+        bind_data_copy(s[AT])
+
+    if autotvm.GLOBAL_SCOPE.in_tuning or filter_pack_rt:
+        if not autotvm.GLOBAL_SCOPE.in_tuning:
+            bind_data_copy(s[kernel])
+        if kernel.shape[2] == 1 and kernel.shape[3] == 1:
+            WT = s.cache_read(kernel, get_texture_storage(kernel.shape), [conv])
+            bind_data_copy(s[WT])
+
+    s[conv].set_scope("local")
+    if latest_blocked == latest and output != latest:
+        s[output].compute_inline()
+
+    # tile and bind spatial axes
+    n, fc, y, x, fb = s[latest_blocked].op.axis
+
+    kernel_scope, n = s[latest_blocked].split(n, nparts=1)
+
+    bf, vf, tf = cfg["tile_fc"].apply(s, latest_blocked, fc)
+    by, vy, ty = cfg["tile_y"].apply(s, latest_blocked, y)
+    bx, vx, tx = cfg["tile_x"].apply(s, latest_blocked, x)
+
+    bf = s[latest_blocked].fuse(n, bf)
+    s[latest_blocked].bind(bf, te.thread_axis("blockIdx.z"))
+    s[latest_blocked].bind(by, te.thread_axis("blockIdx.y"))
+    s[latest_blocked].bind(bx, te.thread_axis("blockIdx.x"))
+    s[latest_blocked].bind(vf, te.thread_axis("vthread"))
+    s[latest_blocked].bind(vy, te.thread_axis("vthread"))
+    s[latest_blocked].bind(vx, te.thread_axis("vthread"))
+    s[latest_blocked].bind(tf, te.thread_axis("threadIdx.z"))
+    s[latest_blocked].bind(ty, te.thread_axis("threadIdx.y"))
+    s[latest_blocked].bind(tx, te.thread_axis("threadIdx.x"))
+    s[latest_blocked].reorder(bf, by, bx, vf, vy, vx, tf, ty, tx, fb)
+    s[latest_blocked].vectorize(fb)
+
+    s[conv].compute_at(s[latest_blocked], tx)
+
+    # tile reduction axes
+    n, fc, y, x, fb = s[conv].op.axis
+
+    rcc, rcb, ry, rx = s[conv].op.reduce_axis
+    rco, rci = cfg["tile_rcc"].apply(s, conv, rcc)
+    ryo, ryi = cfg["tile_ry"].apply(s, conv, ry)
+    rxo, rxi = cfg["tile_rx"].apply(s, conv, rx)
+
+    s[conv].reorder(rco, ryo, rxo, rci, ryi, rxi, rcb, n, fc, y, x, fb)
+    s[conv].vectorize(fb)
+    s[conv].unroll(rcb)
+
+    # unroll
+    s[latest_blocked].pragma(kernel_scope, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val)
+    s[latest_blocked].pragma(kernel_scope, "unroll_explicit", cfg["unroll_explicit"].val)
+
+    if latest_blocked != latest:
+        s[latest].compute_root()
+        bind_data_copy(s[latest], 1)
+        if latest != output:
+            s[output].compute_inline()
+
+    N, OCC, OH, OW, OCB = get_const_tuple(latest_blocked.shape)
+    # The kernel is IOHW-blocked: (in_channels, out_chunks, kh, kw, out_block), so the
+    # input-channel count is dim 0 (dim 1 is the number of output-channel chunks).
+    IC, _, KH, KW, _ = get_const_tuple(kernel.shape)
+    ICKHKW = IC * KH * KW
+
+    if isinstance(N, int):
+        cfg.add_flop(2 * N * OH * OW * OCC * OCB * ICKHKW)
diff --git a/python/tvm/topi/adreno/utils.py b/python/tvm/topi/adreno/utils.py
index 698a306514..a42cbeeb77 100644
--- a/python/tvm/topi/adreno/utils.py
+++ b/python/tvm/topi/adreno/utils.py
@@ -281,6 +281,22 @@ def pack_filter(
             Filter[indices[0], indices[1], indices[2], indices[3] * out_block 
+ indices[4]],
         )
 
+    def _reorder_weights_iohw(*indices):
+        # Compute rule for packing an IOHW filter into the 5d layout
+        # (in_filter_channels, out_chunks, kernel_h, kernel_w, out_block):
+        # element (i, oc_chunk, h, w, ob) reads Filter[i, oc_chunk * out_block + ob, h, w].
+        # Lanes of the last output-channel chunk past the original tail get pad_value.
+        conditionA = []
+        conditionA.append(indices[1] == out_chunks - 1)
+        conditionA.append(indices[4] >= out_original_tail)
+        conditionAT = tvm.tir.all(*conditionA)
+
+        # NOTE(review): indices[0] is the unblocked input-channel index; this also pads
+        # when it reaches in_chunks * in_block + in_original_tail — confirm that bound is
+        # the intended one for the IOHW layout.
+        conditionO = []
+        conditionO.append(conditionAT)
+        conditionO.append(indices[0] >= in_chunks * in_block + in_original_tail)
+        conditionOT = tvm.tir.any(*conditionO)
+        return tvm.tir.if_then_else(
+            conditionOT,
+            pad_value,
+            Filter[indices[0], indices[1] * out_block + indices[4], indices[2], indices[3]],
+        )
+
     if in_filter_channels == 1:
         if layout == "OIHW":
             reordered_filter = te.compute(
@@ -313,6 +329,13 @@ def pack_filter(
                 name="filter_pack",
                 tag="filter_pack",
             )
+        elif layout == "IOHW":
+            reordered_filter = te.compute(
+                [in_filter_channels, out_chunks, kernel_h, kernel_w, 
out_block],
+                _reorder_weights_iohw,
+                name="filter_pack",
+                tag="filter_pack",
+            )
         elif layout == "HWIO":
             reordered_filter = te.compute(
                 [kernel_h, kernel_w, in_filter_channels, out_chunks, 
out_block],
diff --git a/python/tvm/topi/nn/conv2d.py b/python/tvm/topi/nn/conv2d.py
index 8fdcb0dc1a..75f72ee93d 100644
--- a/python/tvm/topi/nn/conv2d.py
+++ b/python/tvm/topi/nn/conv2d.py
@@ -143,6 +143,29 @@ def conv2d_alter_layout(attrs, inputs, tinfos, out_type):
     return None
 
 
+@tvm.target.generic_func
+def conv2d_transpose_alter_layout(attrs, inputs, tinfos, out_type):
+    """Change Conv2D_Transpose layout.
+
+    Parameters
+    ----------
+    attrs : tvm.ir.Attrs
+        Attributes of current conv2d_transpose
+    inputs : list of tvm.relay.Expr
+        Grouped input symbols
+    tinfos : list
+        Input shape and dtype
+    out_type: type
+        The output type
+
+    Returns
+    -------
+    result : tvm.relay.Expr or None
+        The altered expression, or None to keep the original op. This generic
+        version never changes the layout; targets override it via ``register``.
+
+    Note
+    ----
+    Unlike other TOPI functions, this function operates on both graph level and operator level.
+    """
+    # not to change by default
+    return None
+
+
 @tvm.target.generic_func
 def conv2d_infer_layout(workload, cfg):
     """Infer input/output shapes and layouts from a workload and cfg.
diff --git a/src/relay/transforms/annotate_texture_storage.cc 
b/src/relay/transforms/annotate_texture_storage.cc
index 01d47b6953..9ccb2171d8 100644
--- a/src/relay/transforms/annotate_texture_storage.cc
+++ b/src/relay/transforms/annotate_texture_storage.cc
@@ -392,6 +392,10 @@ class StorageInfo : private 
transform::DeviceAwareExprVisitor {
           (attrs->kernel_layout == "OIHW4o" || attrs->kernel_layout == 
"HWIO4o")) {
         supports_texture_storage = true;
       }
+    } else if (auto attrs = call->attrs.as<Conv2DTransposeAttrs>()) {
+      if (attrs->data_layout == "NCHW4c" && attrs->kernel_layout == "IOHW4o") {
+        supports_texture_storage = true;
+      }
     } else if (auto attrs = call->attrs.as<GlobalPool2DAttrs>()) {
       if (attrs->layout == "NCHW4c") {
         supports_texture_storage = true;
diff --git 
a/tests/python/relay/opencl_texture/test_conv2d_transpose_nchw_texture.py 
b/tests/python/relay/opencl_texture/test_conv2d_transpose_nchw_texture.py
new file mode 100644
index 0000000000..d110c8329f
--- /dev/null
+++ b/tests/python/relay/opencl_texture/test_conv2d_transpose_nchw_texture.py
@@ -0,0 +1,325 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import re
+import tvm
+import numpy as np
+from tvm import relay
+from tvm.relay import testing
+from tvm.contrib import utils
+from utils.adreno_utils import gpu_preprocess, build_run_compare, build_run_compare_vm
+import pytest
+
+
+executor_type = tvm.testing.parameter("ge", "vm")  # graph executor / VM executor builds
+dtype = tvm.testing.parameter("float32")  # fp16 dropped: it failed on some GPU devices
+
+
+@tvm.testing.requires_opencl
+@tvm.testing.parametrize_targets("opencl -device=adreno")
+def test_conv2d_transpose_adreno(remote, target, executor_type, dtype):
+    # Single conv2d_transpose (NCHW data, IOHW kernel) with optional bias_add / relu
+    trials = [  # kh, kw, pad, stride, dilation, out_ch, (C, H, W), (bias, relu), preprocess
+        [4, 4, (1, 1), (2, 2), (1, 1), 64, (256, 100, 100), (False, False), gpu_preprocess],
+        [4, 4, (0, 0), (2, 2), (1, 1), 256, (32, 64, 64), (False, False), None],
+        [3, 3, (0, 0), (2, 2), (1, 1), 64, (256, 100, 100), (True, True), None],
+        [4, 4, (1, 1), (1, 1), (1, 1), 512, (16, 100, 100), (False, False), gpu_preprocess],
+        [5, 5, (2, 2), (2, 2), (1, 1), 4, (16, 100, 100), (True, False), gpu_preprocess],
+        [7, 7, (3, 3), (2, 2), (1, 1), 8, (4, 100, 100), (False, True), None],
+        [7, 7, (3, 3), (2, 2), (1, 1), 64, (3, 100, 100), (True, True), None],
+        [3, 3, (1, 1), (1, 1), (1, 1), 3, (16, 8, 8), (True, True), None],
+    ]
+    # Expected tensor memory scopes per trial (same order) for the graph executor build
+    ge_texture_scopes = [
+        ["", "global.texture", "global.texture-weight", "", ""],
+        ["", "global.texture", "global.texture-weight", "", ""],
+        ["", "global.texture", "global.texture-weight", "global.texture-weight", "", ""],
+        ["", "global.texture", "global.texture-weight", "", ""],
+        ["", "global.texture", "global.texture-weight", "global.texture-weight", "", ""],
+        ["", "global.texture", "global.texture-nhwc", "", ""],
+        [],
+        [],
+    ]
+    # Expected VM virtual devices / memory scopes per trial (same order) for the VM build
+    vm_texture_scopes = [
+        """
+        VM VirtualDevice[0]: device type 1, id 0 and mem_scope
+        VM VirtualDevice[1]: device type 4, id 0 and mem_scope
+        VM VirtualDevice[2]: device type 4, id 0 and mem_scope global.texture
+        VM VirtualDevice[3]: device type 4, id 0 and mem_scope global.texture-weight
+        """,
+        """
+        VM VirtualDevice[0]: device type 1, id 0 and mem_scope
+        VM VirtualDevice[1]: device type 4, id 0 and mem_scope
+        VM VirtualDevice[2]: device type 4, id 0 and mem_scope global.texture
+        VM VirtualDevice[3]: device type 4, id 0 and mem_scope global.texture-weight
+        """,
+        """
+        VM VirtualDevice[0]: device type 1, id 0 and mem_scope
+        VM VirtualDevice[1]: device type 4, id 0 and mem_scope
+        VM VirtualDevice[2]: device type 4, id 0 and mem_scope global.texture
+        VM VirtualDevice[3]: device type 4, id 0 and mem_scope global.texture-weight
+        VM VirtualDevice[4]: device type 4, id 0 and mem_scope global.texture-weight
+        """,
+        """
+        VM VirtualDevice[0]: device type 1, id 0 and mem_scope
+        VM VirtualDevice[1]: device type 4, id 0 and mem_scope
+        VM VirtualDevice[2]: device type 4, id 0 and mem_scope global.texture
+        VM VirtualDevice[3]: device type 4, id 0 and mem_scope global.texture-weight
+        """,
+        """
+        VM VirtualDevice[0]: device type 1, id 0 and mem_scope
+        VM VirtualDevice[1]: device type 4, id 0 and mem_scope
+        VM VirtualDevice[2]: device type 4, id 0 and mem_scope global.texture
+        VM VirtualDevice[3]: device type 4, id 0 and mem_scope global.texture-weight
+        VM VirtualDevice[4]: device type 4, id 0 and mem_scope global.texture-weight
+        """,
+        """
+        VM VirtualDevice[0]: device type 1, id 0 and mem_scope
+        VM VirtualDevice[1]: device type 4, id 0 and mem_scope
+        VM VirtualDevice[2]: device type 4, id 0 and mem_scope global.texture
+        VM VirtualDevice[3]: device type 4, id 0 and mem_scope global.texture-nhwc
+        """,
+        [],
+        [],
+    ]
+
+    for i, (
+        kernel_h,
+        kernel_w,
+        pad,
+        stride,
+        dilation,
+        out_channels,
+        shape,
+        composite,
+        _gpu_preprocess,
+    ) in enumerate(trials):
+        shape = (1, *shape)
+        has_bias = composite[0]
+        has_activation = composite[1]
+        input_shape = shape
+        filter_shape = (shape[1], out_channels, kernel_h, kernel_w)  # IOHW
+        x = relay.var("data", shape=input_shape, dtype=dtype)
+        w = relay.var("weight", shape=filter_shape, dtype=dtype)
+        inputs = [x, w]
+        y = relay.nn.conv2d_transpose(
+            x,
+            w,
+            channels=out_channels,
+            kernel_size=(kernel_h, kernel_w),
+            strides=stride,
+            padding=pad,
+            kernel_layout="IOHW",
+            data_layout="NCHW",
+            dilation=dilation,
+        )
+
+        np.random.seed(0)
+        initializer = relay.testing.init.Xavier()
+        filter_data = np.zeros(filter_shape).astype(dtype)
+        initializer("weight", filter_data)
+        params1 = {
+            "weight": tvm.nd.array(filter_data),
+        }
+
+        if has_bias:
+            b = relay.var("bias", shape=(out_channels,), dtype=dtype)
+            y = relay.nn.bias_add(y, b, axis=1)
+            inputs.append(b)
+            bias_data = np.zeros((out_channels,)).astype(dtype)
+            initializer("bias", bias_data)
+            params1["bias"] = tvm.nd.array(bias_data)
+        if has_activation:
+            y = relay.nn.relu(y)
+
+        mod = relay.Function(inputs, y)
+        if executor_type == "ge":
+            build_run_compare(
+                remote,
+                mod,
+                params1,
+                {"data": input_shape},
+                {"data": dtype},
+                target,
+                ge_texture_scopes[i],
+                _gpu_preprocess,
+            )
+        else:
+            build_run_compare_vm(
+                remote,
+                mod,
+                params1,
+                {"data": input_shape},
+                {"data": dtype},
+                target,
+                vm_texture_scopes[i],
+                _gpu_preprocess,
+            )
+
+
+@tvm.testing.requires_opencl
+@tvm.testing.parametrize_targets("opencl -device=adreno")
+def test_conv2d_transpose_three_layer_block(remote, target, executor_type, dtype):
+    # conv2d(1x1)+relu -> conv2d_transpose (+bias/relu per trial) -> conv2d(1x1)+relu
+    trials = [  # kh, kw, pad, stride, dilation, out_ch, (C, H, W), (bias, relu), preprocess
+        [4, 4, (1, 1), (2, 2), (1, 1), 64, (256, 100, 100), (False, False), None],
+        [3, 3, (0, 0), (1, 1), (1, 1), 64, (256, 12, 12), (True, True), gpu_preprocess],
+    ]
+    ge_texture_scopes = [  # expected memory scopes per trial, graph executor build
+        [
+            "",
+            "global.texture",
+            "global.texture-weight",
+            "global.texture",
+            "global.texture-weight",
+            "global.texture",
+            "global.texture-weight",
+            "",
+            "",
+        ],
+        [
+            "",
+            "global.texture-nhwc",
+            "global.texture-weight",
+            "global.texture-nhwc",
+            "global.texture-weight",
+            "global.texture-weight",
+            "global.texture-nhwc",
+            "global.texture-weight",
+            "",
+            "",
+        ],
+    ]
+    vm_texture_scopes = [  # expected virtual devices / memory scopes per trial, VM build
+        """
+        VM VirtualDevice[0]: device type 1, id 0 and mem_scope
+        VM VirtualDevice[1]: device type 4, id 0 and mem_scope
+        VM VirtualDevice[2]: device type 4, id 0 and mem_scope global.texture
+        VM VirtualDevice[3]: device type 4, id 0 and mem_scope global.texture
+        VM VirtualDevice[4]: device type 4, id 0 and mem_scope global.texture-weight
+        VM VirtualDevice[5]: device type 4, id 0 and mem_scope global.texture
+        VM VirtualDevice[6]: device type 4, id 0 and mem_scope global.texture-weight
+        VM VirtualDevice[7]: device type 4, id 0 and mem_scope global.texture-weight
+        """,
+        """
+        VM VirtualDevice[0]: device type 1, id 0 and mem_scope
+        VM VirtualDevice[1]: device type 4, id 0 and mem_scope
+        VM VirtualDevice[2]: device type 4, id 0 and mem_scope global.texture-nhwc
+        VM VirtualDevice[3]: device type 4, id 0 and mem_scope global.texture-nhwc
+        VM VirtualDevice[4]: device type 4, id 0 and mem_scope global.texture-weight
+        VM VirtualDevice[5]: device type 4, id 0 and mem_scope global.texture-nhwc
+        VM VirtualDevice[6]: device type 4, id 0 and mem_scope global.texture-weight
+        VM VirtualDevice[7]: device type 4, id 0 and mem_scope global.texture-weight
+        VM VirtualDevice[8]: device type 4, id 0 and mem_scope global.texture-weight
+        """,
+    ]
+
+    for i, (
+        kernel_h,
+        kernel_w,
+        pad,
+        stride,
+        dilation,
+        out_channels,
+        shape,
+        composite,
+        _gpu_preprocess,
+    ) in enumerate(trials):
+        shape = (1, *shape)
+        has_bias = composite[0]
+        has_activation = composite[1]
+        input_shape = shape
+        filter_shape = (shape[1], out_channels, kernel_h, kernel_w)  # IOHW
+        x = relay.var("data", shape=input_shape, dtype=dtype)
+        w = relay.var("weight", shape=filter_shape, dtype=dtype)
+        inputs = [x, w]
+        W1 = relay.var("weight1", shape=(shape[1], shape[1], 1, 1), dtype=dtype)
+        conv = relay.nn.conv2d(x, W1, padding=[0, 0, 0, 0], channels=shape[1], kernel_size=(1, 1))
+        inputs.append(W1)
+        conv = relay.op.nn.relu(conv)
+        y = relay.nn.conv2d_transpose(
+            conv,
+            w,
+            channels=out_channels,
+            kernel_size=(kernel_h, kernel_w),
+            strides=stride,
+            padding=pad,
+            kernel_layout="IOHW",
+            data_layout="NCHW",
+            dilation=dilation,
+        )
+
+        if has_bias:
+            b = relay.var("bias", shape=(out_channels,), dtype=dtype)
+            y = relay.nn.bias_add(y, b, axis=1)
+            inputs.append(b)
+
+        if has_activation:
+            y = relay.nn.relu(y)
+        W2 = relay.var("weight2", shape=(out_channels, out_channels, 1, 1), dtype=dtype)
+        out = relay.nn.conv2d(
+            y, W2, padding=[0, 0, 0, 0], channels=out_channels, kernel_size=(1, 1)
+        )
+        out = relay.op.nn.relu(out)
+        np.random.seed(0)
+        inputs.append(W2)
+        initializer = relay.testing.init.Xavier()
+        filter_data = np.zeros(filter_shape).astype(dtype)
+        initializer("weight", filter_data)
+        filter_data1 = np.zeros((shape[1], shape[1], 1, 1)).astype(dtype)
+        initializer("weight", filter_data1)
+        filter_data2 = np.zeros((out_channels, out_channels, 1, 1)).astype(dtype)
+        initializer("weight", filter_data2)
+        params1 = {
+            "weight": tvm.nd.array(filter_data),
+            "weight1": tvm.nd.array(filter_data1),
+            "weight2": tvm.nd.array(filter_data2),
+        }
+        if has_bias:
+            bias_data = np.zeros((out_channels,)).astype(dtype)
+            initializer("bias", bias_data)
+            params1["bias"] = tvm.nd.array(bias_data)
+
+        mod = relay.Function(inputs, out)
+
+        if executor_type == "ge":
+            build_run_compare(
+                remote,
+                mod,
+                params1,
+                {"data": input_shape},
+                {"data": dtype},
+                target,
+                ge_texture_scopes[i],
+                _gpu_preprocess,
+            )
+        else:
+            build_run_compare_vm(
+                remote,
+                mod,
+                params1,
+                {"data": input_shape},
+                {"data": dtype},
+                target,
+                vm_texture_scopes[i],
+                _gpu_preprocess,
+            )
+
+
+if __name__ == "__main__":
+    tvm.testing.main()  # lets this test file be run directly as a script
diff --git a/tests/python/relay/opencl_texture/utils/adreno_utils.py b/tests/python/relay/opencl_texture/utils/adreno_utils.py
index d9e52f8847..21bdfbdee3 100644
--- a/tests/python/relay/opencl_texture/utils/adreno_utils.py
+++ b/tests/python/relay/opencl_texture/utils/adreno_utils.py
@@ -200,7 +200,10 @@ def build_run_compare_vm(
 
 def gpu_preprocess(tvm_mod):
     layout_config = relay.transform.LayoutConfig()
-    desired_layouts = {"nn.conv2d": ["NCHW4c", "OIHW4o"]}
+    desired_layouts = {
+        "nn.conv2d": ["NCHW4c", "OIHW4o"],
+        "nn.conv2d_transpose": ["NCHW4c", "IOHW4o"],
+    }
     with layout_config:
         seq = 
tvm.transform.Sequential([relay.transform.ConvertLayout(desired_layouts)])
         with tvm.transform.PassContext(opt_level=3):


Reply via email to