This is an automated email from the ASF dual-hosted git repository.

kparzysz pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 9d98da2736 [Hexagon] Implement avg_pool2d slice op (#11417)
9d98da2736 is described below

commit 9d98da27361429cb558930032f074172bc99b7c3
Author: Jyotsna Verma <73191103+jverma-q...@users.noreply.github.com>
AuthorDate: Wed Jun 15 12:40:37 2022 -0500

    [Hexagon] Implement avg_pool2d slice op (#11417)
    
    * Implement avg_pool2d slice op
    
    * Address review comments and fix the STIR schedule
    
    * Fix formatting issues
    
    * Address pylint errors
    
    * Additional formatting issues
    
    * more pylint fixes
    
    * Changed arch version to v68 for now
    
    * Changing arch version back to v69
    
    * Move the test to tests/python/contrib/test_hexagon/topi
---
 python/tvm/topi/hexagon/slice_ops/__init__.py      |  22 ++
 python/tvm/topi/hexagon/slice_ops/avg_pool2d.py    | 141 ++++++++
 python/tvm/topi/hexagon/utils.py                   |  52 +++
 .../python/contrib/test_hexagon/infrastructure.py  |  20 ++
 .../test_hexagon/topi/test_avg_pool2d_slice.py     | 369 +++++++++++++++++++++
 5 files changed, 604 insertions(+)

diff --git a/python/tvm/topi/hexagon/slice_ops/__init__.py b/python/tvm/topi/hexagon/slice_ops/__init__.py
new file mode 100644
index 0000000000..b52d410676
--- /dev/null
+++ b/python/tvm/topi/hexagon/slice_ops/__init__.py
@@ -0,0 +1,22 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+""" Computes and Schedules for Hexagon slice ops. """
+
+# pylint: disable=wildcard-import
+
+from .avg_pool2d import avg_pool2d_compute, avg_pool2d_STIR_schedule
diff --git a/python/tvm/topi/hexagon/slice_ops/avg_pool2d.py b/python/tvm/topi/hexagon/slice_ops/avg_pool2d.py
new file mode 100644
index 0000000000..306be543d8
--- /dev/null
+++ b/python/tvm/topi/hexagon/slice_ops/avg_pool2d.py
@@ -0,0 +1,141 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-variable, unused-argument, too-many-locals
+
+""" Compute and schedule for avg_pool2d slice op

+Please note the following assumptions made by the implementation:
+
+1) The input must be padded in advance to account for 'padding'. In addition,
+   both input and output must be padded as per the physical buffer layout.
+2) The current implementation assumes 'count_include_pad' is True. It can be
+   modified to support the 'False' case, but the element count for each pooling
+   window must then be pre-computed and provided as an input to reduce the
+   run-time overhead.
+3) 'padding' is ignored. It must be handled outside of the sliced op.
+4) This implementation will not work if padding the output to its physical
+   layout adds any elements (i.e. the output shape is not already a multiple of
+   the layout's block size): the extra output positions would read past the end
+   of the input, resulting in out-of-bounds accesses.
+"""
+
+from tvm import te
+from tvm import tir
+from ..utils import get_layout_transform_fn
+
+
+def validate_out_shape(out_shape, in_shape, kernel, stride, dilation):
+    """Check that an NHWC ``in_shape`` is large enough to produce ``out_shape``.
+
+    For each spatial axis the last output element reads input index
+    ``(out - 1) * stride + dilation * (kernel - 1)``, so the input extent must
+    be at least one larger than that.
+
+    Raises
+    ------
+    RuntimeError
+        If the output height (checked first) or width needs more input than
+        ``in_shape`` provides.
+    """
+
+    def _min_input_extent(out_extent, k, s, d):
+        # Smallest input extent that keeps the last pooling window in bounds.
+        return (out_extent - 1) * s + d * (k - 1) + 1
+
+    _, out_h, out_w, _ = out_shape
+    _, in_h, in_w, _ = in_shape
+    k_h, k_w = kernel
+    s_h, s_w = stride
+    d_h, d_w = dilation
+
+    need_h = _min_input_extent(out_h, k_h, s_h, d_h)
+    if in_h < need_h:
+        raise RuntimeError("Output height is too large")
+    need_w = _min_input_extent(out_w, k_w, s_w, d_w)
+    if in_w < need_w:
+        raise RuntimeError("Output width is too large")
+
+
+def avg_pool2d_compute(A, out_shape, kernel, stride, dilation):
+    """Build the TE compute for a sliced NHWC avg_pool2d.
+
+    ``A`` is expected to be pre-padded (see the module docstring). Each pooling
+    window is summed in float32 (stage "sum"), multiplied by
+    ``1 / (kernel_h * kernel_w)`` and cast back to ``A.dtype`` (stage "avg").
+    Because the divisor is the full window area, padded elements are counted.
+
+    Parameters
+    ----------
+    A : te.Tensor
+        NHWC input tensor.
+    out_shape : list
+        NHWC output shape of the slice.
+    kernel, stride, dilation : pair
+        (height, width) values.
+
+    Returns
+    -------
+    te.Tensor
+        The "avg" tensor, whose producer is the "sum" stage.
+    """
+    k_h, k_w = kernel
+    win_h = te.reduce_axis((0, k_h), name="rh")
+    win_w = te.reduce_axis((0, k_w), name="rw")
+    out_b, _, _, _ = out_shape
+    # Validation runs only when the batch dim is a Python int (presumably to
+    # skip symbolic shapes).
+    if isinstance(out_b, int):
+        validate_out_shape(out_shape, A.shape, kernel, stride, dilation)
+
+    s_h, s_w = stride
+    d_h, d_w = dilation
+    inv_area = 1.0 / (k_h * k_w)
+
+    def _window_sum(b, h, w, c):
+        # Accumulate in float32 regardless of the input dtype.
+        elem = A[b, h * s_h + d_h * win_h, w * s_w + d_w * win_w, c]
+        return te.sum(elem.astype("float32"), axis=[win_h, win_w])
+
+    window_sum = te.compute(out_shape, _window_sum, name="sum")
+    return te.compute(
+        out_shape,
+        lambda b, h, w, c: (window_sum[b, h, w, c] * inv_area).astype(A.dtype),
+        name="avg",
+    )
+
+
+def STIR_schedule_nhwc_8h2w32c2w(outs, ins, output_layout: str, input_layout: str):
+    """Schedule for input and output layout nhwc-8h2w32c2w
+
+    Builds a PrimFunc from ``[ins, outs]`` and returns a ``tir.Schedule`` in
+    which:
+
+    * read buffer 0 of the "sum" block is transformed with the index map for
+      ``input_layout``, and write buffer 0 of the "avg" block with the index
+      map for ``output_layout`` (both via ``get_layout_transform_fn``);
+    * the "avg" loops are split by 8 (h), 4 then 2 (w) and 32 (c), and the
+      innermost 32 (c) x 2 (w) loops are fused and vectorized;
+    * the "sum" block is moved under the ``wio`` loop of "avg".
+
+    Parameters
+    ----------
+    outs : te.Tensor
+        Output tensor; its PrimFunc must contain blocks named "sum" and "avg"
+        (as produced by ``avg_pool2d_compute``).
+    ins : te.Tensor
+        Input placeholder.
+    output_layout, input_layout : str
+        Layout strings accepted by ``get_layout_transform_fn``.
+
+    Returns
+    -------
+    tir.Schedule
+    """
+    func = te.create_prim_func([ins, outs])
+    s = tir.Schedule(func)
+    Sum = s.get_block("sum")
+    Avg = s.get_block("avg")
+
+    # Rewrite buffer indexing to the physical (blocked) layouts.
+    input_transform_fn = get_layout_transform_fn(input_layout)
+    output_transform_fn = get_layout_transform_fn(output_layout)
+    s.transform_layout(Sum, ("read", 0), input_transform_fn)
+    s.transform_layout(Avg, ("write", 0), output_transform_fn)
+
+    # Schedule 'Avg'
+    # Split h by 8, w by 4 (then by 2) and c by 32 to match the 8h2w32c2w block.
+    n, h, w, c = s.get_loops(Avg)
+    ho, hi = s.split(h, [None, 8])
+    wo, wi = s.split(w, [None, 4])
+    wio, wii = s.split(wi, [None, 2])
+    co, ci = s.split(c, [None, 32])
+    # Inner loops become (h % 8, (w % 4) // 2, c % 32, w % 2), the same order
+    # as the inner axes of the nhwc-8h2w32c2w index map.
+    s.reorder(n, ho, wo, co, hi, wio, ci, wii)
+    # Fuse the innermost 32 (c) x 2 (w) loops into one 64-iteration vector loop.
+    ci_wii = s.fuse(ci, wii)
+    s.vectorize(ci_wii)
+
+    # Schedule 'Sum'
+    # Move the 'Sum' computation under the 'wio' loop of 'Avg'.
+    s.compute_at(Sum, wio)
+    Sum_axis = s.get_loops(Sum)
+    # NOTE(review): presumably the last four loops here are the two spatial
+    # loops generated by compute_at followed by the reduction loops (rh, rw);
+    # this moves the last two outward and then fuses the two loops that end up
+    # innermost. TODO confirm against the printed IR.
+    s.reorder(Sum_axis[-2], Sum_axis[-1], Sum_axis[-4], Sum_axis[-3])
+    ci_wii = s.fuse(Sum_axis[-4], Sum_axis[-3])
+    # NOTE: vectorizing this fused loop (s.vectorize(ci_wii)) was tried; the
+    # original author reports that it doesn't work, so the loop is left scalar.
+    return s
+
+
+def STIR_schedule_n11c_1024c(outs, ins, output_layout: str, input_layout: str):
+    """Schedule for output layout: n11c-1024c, input layout: nhwc-8h2w32c2w
+
+    Builds a PrimFunc from ``[ins, outs]`` and returns a ``tir.Schedule`` in
+    which:
+
+    * read buffer 0 of the "sum" block is transformed with the index map for
+      ``input_layout``, and write buffer 0 of the "avg" block with the index
+      map for ``output_layout`` (both via ``get_layout_transform_fn``);
+    * only the channel loop of "avg" is tiled: split by 1024 and then by 64,
+      with the innermost 64-channel loop vectorized;
+    * the "sum" block is moved under the ``cio`` loop of "avg".
+
+    Parameters
+    ----------
+    outs : te.Tensor
+        Output tensor; its PrimFunc must contain blocks named "sum" and "avg"
+        (as produced by ``avg_pool2d_compute``).
+    ins : te.Tensor
+        Input placeholder.
+    output_layout, input_layout : str
+        Layout strings accepted by ``get_layout_transform_fn``.
+
+    Returns
+    -------
+    tir.Schedule
+    """
+    func = te.create_prim_func([ins, outs])
+    s = tir.Schedule(func)
+    Sum = s.get_block("sum")
+    Avg = s.get_block("avg")
+
+    # Rewrite buffer indexing to the physical (blocked) layouts.
+    input_transform_fn = get_layout_transform_fn(input_layout)
+    output_transform_fn = get_layout_transform_fn(output_layout)
+    s.transform_layout(Sum, ("read", 0), input_transform_fn)
+    s.transform_layout(Avg, ("write", 0), output_transform_fn)
+
+    # Schedule 'Avg'
+    # n, h and w are not tiled; c is split into 1024-channel chunks (the
+    # n11c-1024c block), each handled 64 channels at a time.
+    n, h, w, c = s.get_loops(Avg)
+    co, ci = s.split(c, [None, 1024])
+    cio, cii = s.split(ci, [None, 64])
+    s.vectorize(cii)
+
+    # Schedule 'Sum'
+    # Move the 'Sum' computation under the 'cio' loop of 'Avg'.
+    s.compute_at(Sum, cio)
+    Sum_axis = s.get_loops(Sum)
+    # NOTE(review): presumably the last three loops are the channel loop
+    # generated by compute_at followed by the reduction loops (rh, rw); this
+    # places the last two outside the channel loop. TODO confirm against the IR.
+    s.reorder(Sum_axis[-2], Sum_axis[-1], Sum_axis[-3])
+    # NOTE: vectorizing the channel loop (s.vectorize(Sum_axis[-3])) was tried;
+    # the original author reports that it doesn't work, so it is left scalar.
+    return s
+
+
+def avg_pool2d_STIR_schedule(outs, ins, output_layout: str, input_layout: str):
+    """Create the STIR (TensorIR) schedule for avg_pool2d.
+
+    The schedule is chosen by ``output_layout``. ``input_layout`` is passed
+    through unchanged to the chosen schedule.
+
+    Raises
+    ------
+    RuntimeError
+        If no schedule exists for ``output_layout``.
+    """
+    schedulers = (
+        ("nhwc-8h2w32c2w-2d", STIR_schedule_nhwc_8h2w32c2w),
+        ("n11c-1024c-2d", STIR_schedule_n11c_1024c),
+    )
+    for layout_name, build_schedule in schedulers:
+        if output_layout == layout_name:
+            return build_schedule(outs, ins, output_layout, input_layout)
+    raise RuntimeError(f"Unexpected layout '{output_layout}'")
diff --git a/python/tvm/topi/hexagon/utils.py b/python/tvm/topi/hexagon/utils.py
new file mode 100644
index 0000000000..af6e3de9c3
--- /dev/null
+++ b/python/tvm/topi/hexagon/utils.py
@@ -0,0 +1,52 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name
+"""Common hexagon specific utilities"""
+from tvm import te
+
+
+def n11c_1024c_2d(n, h, w, c):
+    """Index map for the "n11c-1024c-2d" layout.
+
+    Channels are split into chunks of 1024. ``te.AXIS_SEPARATOR`` separates
+    the leading axes (n, h, w, chunk) from the offset within the chunk.
+    """
+    chunk = c // 1024
+    offset = c % 1024
+    return [n, h, w, chunk, te.AXIS_SEPARATOR, offset]
+
+
+def n11c_1024c_1d(n, h, w, c):
+    """Index map for the "n11c-1024c-1d" layout.
+
+    Channels are split into a chunk index (``c // 1024``) and an offset within
+    the chunk (``c % 1024``); there is no axis separator.
+    """
+    chunk = c // 1024
+    offset = c % 1024
+    return [n, h, w, chunk, offset]
+
+
+def nhwc_8h2w32c2w_2d(n, h, w, c):
+    """Index map for the "nhwc-8h2w32c2w-2d" layout.
+
+    Outer axes: ``n, h // 8, w // 4, c // 32``. After ``te.AXIS_SEPARATOR``
+    come the inner axes ``h % 8, (w % 4) // 2, c % 32, w % 2``, so each
+    8 (h) x 4 (w) x 32 (c) block keeps pairs of adjacent w values innermost.
+    """
+    outer = [n, h // 8, w // 4, c // 32]
+    inner = [h % 8, (w % 4) // 2, c % 32, w % 2]
+    return outer + [te.AXIS_SEPARATOR] + inner
+
+
+def nhwc_8h2w32c2w_1d(n, h, w, c):
+    """Index map for the "nhwc-8h2w32c2w-1d" layout.
+
+    Outer axes ``n, h // 8, w // 4, c // 32`` followed directly (no axis
+    separator) by inner axes ``h % 8, (w % 4) // 2, c % 32, w % 2``.
+    """
+    outer = [n, h // 8, w // 4, c // 32]
+    inner = [h % 8, (w % 4) // 2, c % 32, w % 2]
+    return outer + inner
+
+
+def get_layout_transform_fn(layout):
+    """Look up the index-map function for a layout string.
+
+    Supported layouts: "nhwc-8h2w32c2w-2d", "nhwc-8h2w32c2w-1d",
+    "n11c-1024c-2d" and "n11c-1024c-1d".
+
+    Raises
+    ------
+    RuntimeError
+        If ``layout`` is not one of the supported strings.
+    """
+    known_layouts = (
+        ("nhwc-8h2w32c2w-2d", nhwc_8h2w32c2w_2d),
+        ("nhwc-8h2w32c2w-1d", nhwc_8h2w32c2w_1d),
+        ("n11c-1024c-2d", n11c_1024c_2d),
+        ("n11c-1024c-1d", n11c_1024c_1d),
+    )
+    for name, index_map in known_layouts:
+        if layout == name:
+            return index_map
+    raise RuntimeError(f"Unexpected layout '{layout}'")
diff --git a/tests/python/contrib/test_hexagon/infrastructure.py b/tests/python/contrib/test_hexagon/infrastructure.py
index 01eef86e6b..57a9dff8b4 100644
--- a/tests/python/contrib/test_hexagon/infrastructure.py
+++ b/tests/python/contrib/test_hexagon/infrastructure.py
@@ -14,6 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+# pylint: disable=invalid-name
 
 """ Hexagon testing infrastructure """
 
@@ -228,3 +229,22 @@ def conv2d_compute(X, filt, pad, stride, dilation):
         )
 
     return output_shape, compute
+
+
+def transform_numpy(arr_np, current_layout: str, new_layout: str):
+    """Return ``arr_np`` (in ``current_layout``) rearranged into ``new_layout``.
+
+    Only ``current_layout == "nhwc"`` is supported. Targets:
+
+    * ``"nhwc"``: the input array itself is returned.
+    * ``"nhwc-8h2w32c2w-2d"`` / ``"nhwc-8h2w32c2w-1d"``: an 8-d array of shape
+      ``(n, h//8, w//4, c//32, 8, 2, 32, 2)``; h, w and c must be multiples of
+      8, 4 and 32 respectively for the reshape to succeed.
+    * ``"n11c-1024c-2d"`` / ``"n11c-1024c-1d"``: a 5-d array of shape
+      ``(n, 1, 1, c//1024, 1024)``; h and w must both be 1.
+
+    The "-2d" and "-1d" variants of a layout produce the same numpy array.
+
+    Raises
+    ------
+    RuntimeError
+        For an unsupported ``current_layout`` or ``new_layout``.
+    """
+
+    def _to_8h2w32c2w(arr):
+        n, h, w, c = arr.shape
+        # Split h -> (h//8, 8), w -> (w//4, 2, 2) and c -> (c//32, 32) ...
+        split = arr.reshape([n, h // 8, 8, w // 4, 2, 2, c // 32, 32])
+        # ... then move chunk indices outward: (n, h//8, w//4, c//32, 8, 2, 32, 2).
+        return split.transpose(0, 1, 3, 6, 2, 4, 7, 5)
+
+    def _to_n11c_1024c(arr):
+        n, h, w, c = arr.shape
+        assert h == 1 and w == 1, "The size of h and w must be 1"
+        return arr.reshape([n, 1, 1, c // 1024, 1024])
+
+    if current_layout != "nhwc":
+        raise RuntimeError(f"Unexpected current_layout '{current_layout}'")
+    if new_layout == "nhwc":
+        return arr_np
+    if new_layout in ("nhwc-8h2w32c2w-2d", "nhwc-8h2w32c2w-1d"):
+        return _to_8h2w32c2w(arr_np)
+    if new_layout in ("n11c-1024c-2d", "n11c-1024c-1d"):
+        return _to_n11c_1024c(arr_np)
+    raise RuntimeError(f"Unexpected new_layout '{new_layout}'")
diff --git a/tests/python/contrib/test_hexagon/topi/test_avg_pool2d_slice.py b/tests/python/contrib/test_hexagon/topi/test_avg_pool2d_slice.py
new file mode 100644
index 0000000000..6cbd84b7ee
--- /dev/null
+++ b/tests/python/contrib/test_hexagon/topi/test_avg_pool2d_slice.py
@@ -0,0 +1,369 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import pytest
+import numpy as np
+
+from tvm import te, topi
+
+import tvm.testing
+from tvm.topi import testing
+from tvm.contrib.hexagon.build import HexagonLauncher
+import tvm.topi.hexagon.slice_ops as sl
+from ..infrastructure import allocate_hexagon_array, transform_numpy
+
+
+# The tests below always assume an "nhwc-8h2w32c2w-2d" input layout.
+input_layout = tvm.testing.parameter("nhwc-8h2w32c2w-2d")
+
+
+@tvm.testing.fixture
+def input_np(input_shape, dtype):
+    """Random unpadded NHWC input with values in [0, 1), cast to ``dtype``."""
+    values = np.random.random(input_shape)
+    return values.astype(dtype)
+
+
+@tvm.testing.fixture
+def transformed_expected_output_np(expected_output_np, output_layout):
+    """Expected NHWC output rearranged into ``output_layout``."""
+    source_layout = "nhwc"
+    return transform_numpy(expected_output_np, source_layout, output_layout)
+
+
+@tvm.testing.fixture
+def transformed_input_np_padded(input_np_padded, input_layout):
+    """Padded NHWC input rearranged into ``input_layout``."""
+    source_layout = "nhwc"
+    return transform_numpy(input_np_padded, source_layout, input_layout)
+
+
+class TestAvgPool2dSlice:
+    """Run the sliced avg_pool2d on Hexagon and compare with the numpy reference.
+
+    NOTE: input_layout is always assumed to be "nhwc-8h2w32c2w-2d".
+    'padding' is [pad_before_h, pad_before_w, pad_after_h, pad_after_w].
+    """
+
+    (
+        output_shape,
+        kernel,
+        stride,
+        dilation,
+        padding,
+        ceil_mode,
+        count_include_pad,
+        output_layout,
+        dtype,
+    ) = tvm.testing.parameters(
+        (
+            [1, 8, 8, 32],
+            [3, 3],
+            [1, 1],
+            [1, 1],
+            [0, 0, 0, 0],
+            False,
+            True,
+            "nhwc-8h2w32c2w-2d",
+            "float16",
+        ),
+        (
+            [1, 16, 16, 32],
+            [3, 3],
+            [1, 1],
+            [1, 1],
+            [0, 0, 0, 0],
+            False,
+            True,
+            "nhwc-8h2w32c2w-2d",
+            "float16",
+        ),
+        (
+            [1, 8, 8, 32],
+            [8, 8],
+            [1, 1],
+            [1, 1],
+            [0, 0, 0, 0],
+            False,
+            True,
+            "nhwc-8h2w32c2w-2d",
+            "float16",
+        ),
+        # Test non-one stride and dilation
+        (
+            [1, 8, 8, 32],
+            [3, 3],
+            [2, 3],
+            [1, 1],
+            [0, 0, 0, 0],
+            False,
+            True,
+            "nhwc-8h2w32c2w-2d",
+            "float16",
+        ),
+        (
+            [1, 8, 8, 32],
+            [3, 3],
+            [2, 2],
+            [2, 2],
+            [0, 0, 0, 0],
+            False,
+            True,
+            "nhwc-8h2w32c2w-2d",
+            "float16",
+        ),
+        (
+            [1, 8, 8, 32],
+            [3, 3],
+            [2, 2],
+            [2, 3],
+            [0, 0, 0, 0],
+            False,
+            True,
+            "nhwc-8h2w32c2w-2d",
+            "float16",
+        ),
+        # Test non-zero padding
+        (
+            [1, 8, 8, 32],
+            [3, 3],
+            [1, 1],
+            [1, 1],
+            [1, 1, 1, 1],
+            False,
+            True,
+            "nhwc-8h2w32c2w-2d",
+            "float16",
+        ),
+        (
+            [1, 8, 8, 32],
+            [3, 3],
+            [1, 1],
+            [1, 1],
+            [1, 2, 3, 4],
+            False,
+            True,
+            "nhwc-8h2w32c2w-2d",
+            "float16",
+        ),
+        (
+            [1, 8, 8, 32],
+            [3, 3],
+            [3, 2],
+            [2, 3],
+            [1, 2, 3, 4],
+            False,
+            True,
+            "nhwc-8h2w32c2w-2d",
+            "float16",
+        ),
+        # Test n11c-1024c-2d layout, which requires input and output to have
+        # different layouts
+        (
+            [1, 1, 1, 2048],
+            [8, 8],
+            [1, 1],
+            [1, 1],
+            [0, 0, 0, 0],
+            False,
+            True,
+            "n11c-1024c-2d",
+            "float16",
+        ),
+        (
+            [1, 1, 1, 2048],
+            [6, 6],
+            [1, 1],
+            [1, 1],
+            [0, 0, 0, 0],
+            False,
+            True,
+            "n11c-1024c-2d",
+            "float16",
+        ),
+        (
+            [1, 1, 1, 2048],
+            [3, 3],
+            [2, 2],
+            [1, 1],
+            [0, 0, 0, 0],
+            False,
+            True,
+            "n11c-1024c-2d",
+            "float16",
+        ),
+        (
+            [1, 1, 1, 2048],
+            [4, 4],
+            [2, 2],
+            [2, 3],
+            [0, 0, 0, 0],
+            False,
+            True,
+            "n11c-1024c-2d",
+            "float16",
+        ),
+    )
+
+    @tvm.testing.fixture
+    def expected_output_np(
+        self,
+        input_np,
+        kernel,
+        stride,
+        dilation,
+        padding,
+        ceil_mode,
+        count_include_pad,
+    ):
+        """NHWC reference output from ``poolnd_python`` on the unpadded input."""
+        pad_before = padding[:2]
+        pad_after = padding[2:]
+        ref_np = tvm.topi.testing.poolnd_python(
+            input_np,
+            kernel,
+            stride,
+            dilation,
+            pad_before,
+            pad_after,
+            "avg",  # pool_type
+            count_include_pad,
+            # Honor the parameter instead of silently forcing False. Note that
+            # 'input_shape' ignores ceil_mode, so only False is exercised today.
+            ceil_mode,
+            layout="NHWC",
+        )
+        return ref_np
+
+    @tvm.testing.fixture
+    def input_shape(self, output_shape, kernel, padding, stride, dilation, output_layout):
+        """Unpadded NHWC input shape that yields exactly ``output_shape``.
+
+        ceil_mode is ignored in this calculation.
+        """
+        o_b, o_h, o_w, o_c = output_shape
+        d_h, d_w = dilation
+        s_h, s_w = stride
+        k_h, k_w = kernel
+        pad_before_h, pad_before_w = padding[:2]
+        pad_after_h, pad_after_w = padding[2:]
+
+        if output_layout == "n11c-1024c-2d":
+            assert (
+                pad_before_w == 0 and pad_after_w == 0 and pad_before_h == 0 and pad_after_h == 0
+            ), "Padding must be zero for n11c-1024c-2d layout"
+            assert o_h == 1 and o_w == 1, "Output height and width must be 1"
+
+        in_h = (o_h - 1) * s_h + d_h * (k_h - 1) + 1 - pad_before_h - pad_after_h
+        in_w = (o_w - 1) * s_w + d_w * (k_w - 1) + 1 - pad_before_w - pad_after_w
+
+        return [o_b, in_h, in_w, o_c]
+
+    @tvm.testing.fixture
+    def input_shape_padded(self, input_shape, padding, output_layout):
+        """Input shape including 'padding', rounded up for the physical layout.
+
+        Height and width are rounded up to multiples of 8 and 4 respectively,
+        matching the nhwc-8h2w32c2w-2d input layout that is always assumed.
+        """
+        pad_before_h, pad_before_w = padding[:2]
+        pad_after_h, pad_after_w = padding[2:]
+        padded_input_height = ((input_shape[1] + pad_before_h + pad_after_h + 7) // 8) * 8
+        padded_input_width = ((input_shape[2] + pad_before_w + pad_after_w + 3) // 4) * 4
+        return [input_shape[0], padded_input_height, padded_input_width, input_shape[3]]
+
+    @tvm.testing.fixture
+    def input_np_padded(self, input_np, input_shape, input_shape_padded, padding):
+        """Zero-pad ``input_np`` up to ``input_shape_padded``.
+
+        The leading pad is 'padding'; the trailing pad covers both the
+        'padding' and the rounding needed by the physical layout.
+        """
+        pad_before_h, pad_before_w = padding[:2]
+        pad_after_h = input_shape_padded[1] - input_shape[1] - pad_before_h
+        pad_after_w = input_shape_padded[2] - input_shape[2] - pad_before_w
+        input_padded = np.pad(
+            input_np,
+            ((0, 0), (pad_before_h, pad_after_h), (pad_before_w, pad_after_w), (0, 0)),
+            "constant",
+        )
+        return input_padded
+
+    @tvm.testing.requires_hexagon
+    def test_avg_pool2d_slice(
+        self,
+        stride,
+        kernel,
+        dtype,
+        dilation,
+        padding,
+        count_include_pad,
+        input_layout,
+        output_layout,
+        output_shape,
+        input_shape,
+        input_shape_padded,
+        input_np,
+        input_np_padded,
+        transformed_input_np_padded,
+        transformed_expected_output_np,
+        expected_output_np,
+        hexagon_session,
+    ):
+        """Build for Hexagon, run on device and compare in the output layout."""
+        if output_layout not in ("nhwc-8h2w32c2w-2d", "n11c-1024c-2d"):
+            raise RuntimeError(f"Unexpected layout '{output_layout}'")
+
+        target_hexagon = tvm.target.hexagon("v69")
+        A = te.placeholder(input_shape_padded, name="A", dtype=dtype)
+
+        M = sl.avg_pool2d_compute(A, output_shape, kernel, stride, dilation)
+
+        # tir schedule
+        tir_schedule = sl.avg_pool2d_STIR_schedule(M, A, output_layout, input_layout)
+        sch = tir_schedule.mod
+
+        # The input layout and both supported output layouts place the axis
+        # separator after the fourth index.
+        input_axis_separator = [4]
+        output_axis_separator = [4]
+
+        with tvm.transform.PassContext(opt_level=3):
+            func = tvm.build(
+                sch,
+                [A, M],
+                tvm.target.Target(target_hexagon, host=target_hexagon),
+                name="avg_pool2d",
+            )
+
+        input_arr = allocate_hexagon_array(
+            hexagon_session.device,
+            data=transformed_input_np_padded,
+            axis_separators=input_axis_separator,
+            mem_scope="global.vtcm",
+        )
+        output_arr = allocate_hexagon_array(
+            hexagon_session.device,
+            transformed_expected_output_np.shape,
+            dtype,
+            axis_separators=output_axis_separator,
+            mem_scope="global.vtcm",
+        )
+
+        mod = hexagon_session.load_module(func)
+        mod(input_arr, output_arr)
+
+        # Reshape the device result back to the blocked logical shape it was
+        # allocated with (presumably numpy() returns the 2-d physical shape
+        # because of the axis separators).
+        output_np = output_arr.numpy().reshape(transformed_expected_output_np.shape)
+        np.testing.assert_allclose(
+            output_np, transformed_expected_output_np, rtol=1e-3, atol=1e-3
+        )
+
+
+if __name__ == "__main__":
+    # 'sys' is not imported at module level, so import it here; without it,
+    # running this file as a script fails with NameError.
+    import sys
+
+    sys.exit(pytest.main(sys.argv))

Reply via email to