This is an automated email from the ASF dual-hosted git repository.

lukhut pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 9f8fe3c503 [topi] Add `arm_cpu` specific pooling schedules (#14855)
9f8fe3c503 is described below

commit 9f8fe3c503b1aea30c785055bbc35c83c69c1359
Author: Jack Frankland <30410009+franklandj...@users.noreply.github.com>
AuthorDate: Mon Jul 10 21:12:13 2023 +0100

    [topi] Add `arm_cpu` specific pooling schedules (#14855)
    
    This commit:
    * Adds specialized `arm_cpu` pooling schedules for both fixed width and
      scalable vectors.
    * Enables topi testing of new `arm_cpu` schedules.
---
 python/tvm/relay/op/strategy/arm_cpu.py          | 17 ++---
 python/tvm/topi/arm_cpu/mprofile/dsp/__init__.py |  2 +
 python/tvm/topi/arm_cpu/mprofile/dsp/pool.py     | 30 +++++---
 python/tvm/topi/arm_cpu/pooling.py               | 91 +++++++++++++++++++++++-
 tests/python/topi/python/test_topi_pooling.py    |  1 +
 5 files changed, 116 insertions(+), 25 deletions(-)

diff --git a/python/tvm/relay/op/strategy/arm_cpu.py b/python/tvm/relay/op/strategy/arm_cpu.py
index dc3b16aa82..3a9f7e1c11 100644
--- a/python/tvm/relay/op/strategy/arm_cpu.py
+++ b/python/tvm/relay/op/strategy/arm_cpu.py
@@ -26,6 +26,7 @@ from tvm import relay, topi, tir
 from ....auto_scheduler import is_auto_scheduler_enabled
 from ....meta_schedule import is_meta_schedule_enabled
 from ....topi.generic import conv2d as conv2d_generic
+from ....topi.arm_cpu.mprofile import dsp
 from .. import op as _op
 from .generic import *
 
@@ -63,19 +64,11 @@ def concatenate_strategy_arm_cpu(attrs, inputs, out_type, target):
 def schedule_pool_arm_cpu(attrs, outs, target):
     """schedule pooling ops arm cpu"""
     layout = attrs.layout
-    avg_pool = isinstance(attrs, relay.op.op_attrs.AvgPool2DAttrs)
     with target:
-        if (
-            avg_pool
-            and target.features.has_dsp
-            and layout in ("NCW", "NCHW")
-            or not avg_pool
-            and target.features.has_dsp
-            and layout in ("NWC", "NHWC")
-        ):
-            return topi.arm_cpu.schedule_pool(outs, layout)
-        logger.warning("pool is not optimized for arm cpu.")
-        return topi.generic.schedule_pool(outs, layout)
+        if target.features.has_dsp:
+            is_avg_pool = isinstance(attrs, relay.op.op_attrs.AvgPool2DAttrs)
+            return dsp.pool.schedule_pool(outs, layout, is_avg_pool)
+        return topi.arm_cpu.schedule_pool(outs, layout)
 
 
 def _get_padding_width(padding):
diff --git a/python/tvm/topi/arm_cpu/mprofile/dsp/__init__.py b/python/tvm/topi/arm_cpu/mprofile/dsp/__init__.py
index 13a83393a9..35e3f35a10 100644
--- a/python/tvm/topi/arm_cpu/mprofile/dsp/__init__.py
+++ b/python/tvm/topi/arm_cpu/mprofile/dsp/__init__.py
@@ -14,3 +14,5 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
+"""Schedule for arm_cpu targets supporting DSP"""
+from .pool import schedule_pool
diff --git a/python/tvm/topi/arm_cpu/mprofile/dsp/pool.py b/python/tvm/topi/arm_cpu/mprofile/dsp/pool.py
index 4416831124..7bd24dfee7 100644
--- a/python/tvm/topi/arm_cpu/mprofile/dsp/pool.py
+++ b/python/tvm/topi/arm_cpu/mprofile/dsp/pool.py
@@ -20,18 +20,12 @@ import logging

 import tvm

-from tvm import te
+from tvm import te, topi
 from tvm.topi.utils import traverse_inline

-from .micro_kernel.max_pool import (
-    intrin_max,
-    max_impl,
-)
+from .micro_kernel.max_pool import intrin_max, max_impl

-from .micro_kernel.avg_pool import (
-    intrin_sum,
-    sum_impl,
-)
+from .micro_kernel.avg_pool import intrin_sum, sum_impl

 logger = logging.getLogger("topi")

@@ -100,8 +94,24 @@ def schedule_avgpool_2d_nchw(s, op):
     s[output].pragma(n, "import_c", sum_impl(pool_w, uniq_id))


-def pool_dsp_schedule(outs, layout):
+def schedule_pool(outs, layout, is_avg_pool):
     """Schedule function for v7e-m DSP instructions of pooling."""
+
+    if is_avg_pool and layout not in ["NCW", "NCHW"]:
+        logger.warning(
+            "avg pool is only supported for NCW or NCHW layouts on DSP "
+            "enabled targets, falling back on generic pool "
+            "implementation"
+        )
+        return topi.generic.schedule_pool(outs, layout)
+    if not is_avg_pool and layout not in ["NWC", "NHWC"]:
+        logger.warning(
+            "max pool is only supported for NWC or NHWC layouts on DSP "
+            "enabled targets, falling back on generic pool "
+            "implementation"
+        )
+        return topi.generic.schedule_pool(outs, layout)
+
     s = te.create_schedule([x.op for x in outs])

     def _callback(op):
diff --git a/python/tvm/topi/arm_cpu/pooling.py b/python/tvm/topi/arm_cpu/pooling.py
index f09f008934..248383c08c 100644
--- a/python/tvm/topi/arm_cpu/pooling.py
+++ b/python/tvm/topi/arm_cpu/pooling.py
@@ -17,9 +17,94 @@
 # pylint: disable=invalid-name, unused-variable
 """Schedule for pooling operators"""

-from .mprofile.dsp.pool import pool_dsp_schedule
+import logging
+from tvm import topi, te
+from tvm.target import Target
+from .. import tag


 def schedule_pool(outs, layout):
-    """Create schedule for avgpool/maxpool with dsp"""
-    return pool_dsp_schedule(outs, layout)
+    """Create schedule for avgpool/maxpool"""
+
+    if layout != "NHWC":
+        logger = logging.getLogger("topi")
+        logger.warning(
+            "We currently only support NHWC target specific pools on arm_cpu, "
+            "falling back on generic pool scheduling"
+        )
+        return topi.generic.schedule_pool(outs, layout)
+
+    return schedule_pool_2d(outs)
+
+
+def schedule_pool_2d(outs):
+    """Create arm_cpu specific 2D schedule for avgpool/maxpool"""
+
+    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
+    schedule_ops = [x.op for x in outs]
+    schedule = te.create_schedule(schedule_ops)
+    scheduled_ops = []
+
+    def traverse(op):
+        # Recursively inline any injective operation that isn't the pooling
+        # operation or hasn't already been scheduled.
+        if tag.is_injective(op.tag):
+            if op not in schedule.outputs:
+                schedule[op].compute_inline()
+            for tensor in op.input_tensors:
+                if isinstance(tensor.op, te.tensor.ComputeOp) and tensor.op not in scheduled_ops:
+                    traverse(tensor.op)
+        # schedule the actual pooling operation
+        elif op.tag.startswith("pool"):
+            n, height, width, channel = schedule[op].op.axis
+            # Average pool consists of two parts; a sum then a division.
+            # We can schedule the division loop to parallelize across the fused
+            # height/width axes and vectorize across channels.
+            enable_explicit_vectorization = not Target.current(allow_none=False).features.has_sve
+            if op != outs[0].op:
+                output = outs[0]
+                output_fused = schedule[output].fuse(output.op.axis[1], output.op.axis[2])
+                schedule[output].parallel(output_fused)
+                vectorization_factor = (
+                    8 if enable_explicit_vectorization else output.op.axis[3].dom.extent
+                )
+                _, inner = schedule[output].split(output.op.axis[3], vectorization_factor)
+                schedule[output].vectorize(inner)
+
+            padded_input = op.input_tensors[0]
+            if isinstance(padded_input.op, te.tensor.ComputeOp):
+                schedule[padded_input].compute_inline()
+
+            # For targets without SVE try explicitly vectorizing the channel
+            # loop. For SVE targets leave the loop in place for LLVM to convert
+            # into a scalable vector loop.
+            vectorization_factor = 8 if enable_explicit_vectorization else channel.dom.extent
+            channel_outer, channel_inner = schedule[op].split(channel, vectorization_factor)
+            schedule[op].vectorize(channel_inner)
+            schedule[op].parallel(height)
+            if len(schedule[op].op.reduce_axis) > 0:
+                filter_height, filter_width = schedule[op].op.reduce_axis
+                # We consider any filter of area < 10 to be small enough to
+                # unroll; 3x3 filters have shown better performance when
+                # unrolled.
+                if filter_height.dom.extent * filter_width.dom.extent <= 9:
+                    # For small filters, unrolling the filter loops allows us to
+                    # vectorize over channels without reordering anything.
+                    schedule[op].unroll(filter_width)
+                    schedule[op].unroll(filter_height)
+                else:
+                    # Reordering so that channels is the fastest moving axis allows
+                    # LLVM to vectorize across contiguous memory in the NHWC
+                    # ordering.
+                    schedule[op].reorder(
+                        n, height, width, filter_height, filter_width, channel_outer, channel_inner
+                    )
+            else:
+                schedule[op].reorder(n, height, width, channel_outer, channel_inner)
+        else:
+            raise RuntimeError("Unsupported operator: %s" % op.tag)
+
+        scheduled_ops.append(op)
+
+    traverse(outs[0].op)
+    return schedule
diff --git a/tests/python/topi/python/test_topi_pooling.py b/tests/python/topi/python/test_topi_pooling.py
index 5f8aebabc2..0d0ee65ad4 100644
--- a/tests/python/topi/python/test_topi_pooling.py
+++ b/tests/python/topi/python/test_topi_pooling.py
@@ -28,6 +28,7 @@ from tvm.topi.utils import get_const_tuple
 
 _pool_schedule = {
     "generic": topi.generic.schedule_pool,
+    "arm_cpu": topi.arm_cpu.schedule_pool,
     "cpu": topi.x86.schedule_pool,
     "gpu": topi.cuda.schedule_pool,
     "hls": topi.hls.schedule_pool,

Reply via email to