kparzysz-quic commented on code in PR #15057:
URL: https://github.com/apache/tvm/pull/15057#discussion_r1227260589


##########
python/tvm/topi/hexagon/slice_ops/avg_pool2d.py:
##########
@@ -16,118 +16,206 @@
 # under the License.
 # pylint: disable=invalid-name, unused-variable, unused-argument, too-many-locals, pointless-exception-statement
 
-""" Compute and schedule for avg_pool2d slice op
-
-Please note the following assumptions made by the implementation:
-
-1) The input must be padded in advance to account for 'padding'. In addition,
-   both input and output must be padded as per the physical buffer layout.
-2) The current implementation assumes 'count_include_pad' to be 'True'. It can be
-   modified to support 'False' case but the element count for the pooling window
-   must be pre-computed and provided as an input to reduce the run-time overhead.
-3) 'padding' is ignored. It must be handled outside of the sliced op.
-4) Please note that this implementation will not work if the output includes any
-   physical layout related padding as it can result into out-of-bound access
-   for the input.
-"""
+""" Compute and schedule for avg_pool2d slice op """
 
 from tvm import te
 from tvm import tir
 from ..utils import get_layout_transform_fn
+from ...utils import get_const_tuple
+from ...nn.utils import get_pad_tuple
+from ...nn.pad import pad
+from ..compute_poolarea import compute_PoolArea
 
 
-def validate_out_shape(out_shape, in_shape, kernel, stride, dilation):
-    """Validate output shape"""
-    _, oh, ow, _ = out_shape
-    _, ih, iw, _ = in_shape
+def avg_pool2d_NCHW(
+    data, kernel, stride, padding, dilation, count_include_pad, oshape, odtype="float16"
+):
+    """avg_pool2d compute"""
+    if odtype != "float16":
+        raise RuntimeError(f"Unsupported output dtype '{odtype}'")
     kh, kw = kernel
+    rh = te.reduce_axis((0, kh), name="rh")
+    rw = te.reduce_axis((0, kw), name="rw")
     sh, sw = stride
     dh, dw = dilation
-    if ih < (oh - 1) * sh + dh * (kh - 1) + 1:
-        raise RuntimeError("Output height is too large")
-    if iw < (ow - 1) * sw + dw * (kw - 1) + 1:
-        raise RuntimeError("Output width is too large")
 
+    dilated_kh = (kh - 1) * dh + 1
+    dilated_kw = (kw - 1) * dw + 1
+
+    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
+        get_const_tuple(padding), (dilated_kh, dilated_kw)
+    )
+
+    # DOPAD
 
-def avg_pool2d_compute(A, kernel, stride, dilation, oshape, odtype="float16"):
+    if pad_top != 0 or pad_down != 0 or pad_left != 0 or pad_right != 0:
+        pad_before = (0, 0, pad_top, pad_left)
+        pad_after = (0, 0, pad_down, pad_right)
+        data_pad = pad(data, pad_before, pad_after, name="data_pad")
+    else:
+        # By definition when True, zero-padding will be included in the averaging calculation
+        # This is equivalent to PoolArea = (kh * kw)
+        count_include_pad = True
+        data_pad = data
+
+    Sum = te.compute(
+        oshape,
+        lambda b, c, h, w: te.sum(
+            data_pad[b, c, h * sh + dh * rh, w * sw + dw * rw].astype("float32"), axis=[rh, rw]
+        ),
+        name="pool_sum",
+    )
+
+    if not count_include_pad:
+        # Compute PoolArea using unpadded input tensor
+        _, _, oh, ow = oshape
+        _, _, ih, iw = data.shape
+
+        PoolArea = te.compute(
+            (oh, ow),
+            lambda i, j: compute_PoolArea(i, j, ih, iw, kh, kw, sh, sw, dh, dw, pad_top, pad_left),
+            name="pool_area",
+        )
+
+        InvArea = te.compute(
+            (oh, ow),
+            lambda i, j: tir.if_then_else(
+                tir.all(PoolArea[i, j] > 0), (float(1) / PoolArea[i, j]), 0
+            ),
+            name="inverse_area",
+        )
+
+        Avg = te.compute(
+            oshape,
+            lambda b, c, h, w: (Sum[b, c, h, w] * InvArea[h, w]).astype(odtype),
+            name="pool_avg",
+        )
+    else:
+        InvArea = float(1) / (kh * kw)
+        Avg = te.compute(
+            oshape, lambda b, c, h, w: (Sum[b, c, h, w] * InvArea).astype(odtype), name="pool_avg"
+        )
+
+    return Avg
+
+
+def avg_pool2d_NHWC(
+    data, kernel, stride, padding, dilation, count_include_pad, oshape, odtype="float16"
+):
     """avg_pool2d compute"""
     if odtype != "float16":
-        RuntimeError(f"Unsupported output dtype '{odtype}'")
+        raise RuntimeError(f"Unsupported output dtype '{odtype}'")
     kh, kw = kernel
     rh = te.reduce_axis((0, kh), name="rh")
     rw = te.reduce_axis((0, kw), name="rw")
-    ob, oh, ow, oc = oshape
-    if isinstance(ob, int):
-        validate_out_shape(oshape, A.shape, kernel, stride, dilation)
 
     sh, sw = stride
     dh, dw = dilation
     InvArea = float(1) / (kh * kw)
 
+    dilated_kh = (kh - 1) * dh + 1
+    dilated_kw = (kw - 1) * dw + 1
+
+    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
+        get_const_tuple(padding), (dilated_kh, dilated_kw)
+    )
+
+    # DOPAD
+    if pad_top != 0 or pad_down != 0 or pad_left != 0 or pad_right != 0:
+        pad_before = (0, pad_top, pad_left, 0)
+        pad_after = (0, pad_down, pad_right, 0)
+        data_pad = pad(data, pad_before, pad_after, name="data_pad")
+    else:
+        # By definition when True, zero-padding will be included in the averaging calculation
+        # This is equivalent to PoolArea = (kh * kw)
+        count_include_pad = True
+        data_pad = data
+
     Sum = te.compute(
         oshape,
         lambda b, h, w, c: te.sum(
-            A[b, h * sh + dh * rh, w * sw + dw * rw, c].astype("float32"), axis=[rh, rw]
+            data_pad[b, h * sh + dh * rh, w * sw + dw * rw, c].astype("float32"), axis=[rh, rw]
         ),
-        name="sum",
-    )
-    Avg = te.compute(
-        oshape, lambda b, h, w, c: (Sum[b, h, w, c] * InvArea).astype(A.dtype), name="avg"
+        name="pool_sum",
     )
+
+    if not count_include_pad:
+        # Compute PoolArea using unpadded input tensor
+        _, oh, ow, _ = oshape
+        _, ih, iw, _ = data.shape
+
+        PoolArea = te.compute(
+            (oh, ow),
+            lambda i, j: compute_PoolArea(i, j, ih, iw, kh, kw, sh, sw, dh, dw, pad_top, pad_left),
+            name="pool_area",
+        )
+
+        InvArea = te.compute(
+            (oh, ow),
+            lambda i, j: tir.if_then_else(
+                tir.all(PoolArea[i, j] > 0), (float(1) / PoolArea[i, j]), 0
+            ),
+            name="inverse_area",
+        )
+
+        Avg = te.compute(
+            oshape,
+            lambda b, h, w, c: (Sum[b, h, w, c] * InvArea[h, w]).astype(odtype),
+            name="pool_avg",
+        )
+    else:
+        InvArea = float(1) / (kh * kw)
+        Avg = te.compute(
+            oshape, lambda b, h, w, c: (Sum[b, h, w, c] * InvArea).astype(odtype), name="pool_avg"
+        )
+
     return Avg
 
 
-def schedule_nhwc_8h2w32c2w(outs, ins, output_layout: str, input_layout: str):
-    """Schedule for input and output layout nhwc-8h2w32c2w"""
+def schedule_8h2w32c2w(outs, ins, output_layout: str, input_layout: str):
+    """Schedule for input and output layout 8h2w32c2w"""
     func = te.create_prim_func([ins, outs])
+    print(func)
     s = tir.Schedule(func)
-    Sum = s.get_block("sum")
-    Avg = s.get_block("avg")
+    Sum = s.get_block("pool_sum")
+    Avg = s.get_block("pool_avg")
 
+    mem_scope = "global.vtcm"
+    sum_read = s.cache_read(Sum, 0, mem_scope)
+    avg_write = s.cache_write(Avg, 0, mem_scope)
     input_transform_fn = get_layout_transform_fn(input_layout)
     output_transform_fn = get_layout_transform_fn(output_layout)
-    s.transform_layout(Sum, ("read", 0), input_transform_fn)
-    s.transform_layout(Avg, ("write", 0), output_transform_fn)
-
-    # Schedule 'Avg'
-    n, h, w, c = s.get_loops(Avg)
-    ho, hi = s.split(h, [None, 8])
-    wo, wi = s.split(w, [None, 4])
-    wio, wii = s.split(wi, [None, 2])
-    co, ci = s.split(c, [None, 32])
-    s.reorder(n, ho, wo, co, hi, wio, ci, wii)
-    ci_wii = s.fuse(ci, wii)
-    s.vectorize(ci_wii)
-
-    # Schedule 'Sum'
-    s.compute_at(Sum, wio)
-    Sum_axis = s.get_loops(Sum)
-    s.reorder(Sum_axis[-2], Sum_axis[-1], Sum_axis[-4], Sum_axis[-3])
-    ci_wii = s.fuse(Sum_axis[-4], Sum_axis[-3])
-    # s.vectorize(ci_wii) # Doesn't work
+    s.transform_layout(Sum, ("read", 0), input_transform_fn, pad_value=0.0)
+    s.transform_layout(Avg, ("write", 0), output_transform_fn, pad_value=0.0)
     return s
 
 
-def schedule_n11c_1024c(outs, ins, output_layout: str, input_layout: str):
-    """Schedule for output layout: n11c-1024c, input layout: nhwc-8h2w32c2w"""
+def schedule_1024c(outs, ins, output_layout: str, input_layout: str):
+    """Schedule for output layout: 1024c, input layout: 8h2w32c2w"""
     func = te.create_prim_func([ins, outs])
     s = tir.Schedule(func)
-    Sum = s.get_block("sum")
-    Avg = s.get_block("avg")
+    Sum = s.get_block("pool_sum")
+    Avg = s.get_block("pool_avg")
 
+    mem_scope = "global.vtcm"
+    sum_read = s.cache_read(Sum, 0, mem_scope)
+    avg_write = s.cache_write(Avg, 0, mem_scope)
     input_transform_fn = get_layout_transform_fn(input_layout)
     output_transform_fn = get_layout_transform_fn(output_layout)
-    s.transform_layout(Sum, ("read", 0), input_transform_fn)
-    s.transform_layout(Avg, ("write", 0), output_transform_fn)
+    s.transform_layout(Sum, ("read", 0), input_transform_fn, pad_value=0.0)
+    s.transform_layout(Avg, ("write", 0), output_transform_fn, pad_value=0.0)
 
     # Schedule 'Avg'
-    n, h, w, c = s.get_loops(Avg)
-    co, ci = s.split(c, [None, 1024])
+    if output_layout == "n11c-1024c-2d":
+        n, h, w, c = s.get_loops(Avg)
+    else:
+        n, c, h, w = s.get_loops(Avg)
+    _, ci = s.split(c, [None, 1024])
     cio, cii = s.split(ci, [None, 64])
     s.vectorize(cii)
 
     # Schedule 'Sum'
-    s.compute_at(Sum, cio)
+    # s.compute_at(Sum, cio)

Review Comment:
   Done.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to