This is an automated email from the ASF dual-hosted git repository. lukhut pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push: new d22bdce2bf [Relay][Op] Connect existing arm_cpu schedule to relay strategy for concat (#14270) d22bdce2bf is described below commit d22bdce2bf4c16fab0ed54ca320f07ed48ee85d0 Author: Ashutosh Parkhi <86472128+ashutosh-...@users.noreply.github.com> AuthorDate: Tue Mar 14 14:53:04 2023 +0000 [Relay][Op] Connect existing arm_cpu schedule to relay strategy for concat (#14270) Previously, the generic implementation was used for concatenate, before cpu schedules were made the default fallback schedules. This leads to performance degradation as it blocks fusion with nearby ops. This commit adds a Relay op strategy for the arm_cpu implementation, which makes it use the arm_cpu schedule before the cpu one. Reference: https://github.com/apache/tvm/pull/13775 Co-authored-by: Luke Hutton <luke.hut...@arm.com> --- python/tvm/relay/op/strategy/arm_cpu.py | 16 +++++-- .../relay/strategy/test_select_implementation.py | 56 ++++++++++++++++++++++ 2 files changed, 67 insertions(+), 5 deletions(-) diff --git a/python/tvm/relay/op/strategy/arm_cpu.py b/python/tvm/relay/op/strategy/arm_cpu.py index c8d51bc23c..6e6c1bf03b 100644 --- a/python/tvm/relay/op/strategy/arm_cpu.py +++ b/python/tvm/relay/op/strategy/arm_cpu.py @@ -46,11 +46,17 @@ def schedule_injective_arm_cpu(_, outs, target): return topi.arm_cpu.schedule_injective(outs) -@schedule_concatenate.register("arm_cpu") -def schedule_concatenate_arm_cpu(_, outs, target): - """schedule concatenate for arm cpu""" - with target: - return topi.arm_cpu.schedule_concatenate(outs) +@concatenate_strategy.register(["arm_cpu"]) +def concatenate_strategy_arm_cpu(attrs, inputs, out_type, target): + """concatenate arm_cpu strategy""" + strategy = _op.OpStrategy() + + strategy.add_implementation( + wrap_compute_concat(topi.concatenate), + wrap_topi_schedule(topi.arm_cpu.schedule_concatenate), + name="concatenate.arm_cpu", + ) + return strategy @schedule_pool.register(["arm_cpu"]) diff --git 
a/tests/python/relay/strategy/test_select_implementation.py b/tests/python/relay/strategy/test_select_implementation.py new file mode 100644 index 0000000000..3e63bc4751 --- /dev/null +++ b/tests/python/relay/strategy/test_select_implementation.py @@ -0,0 +1,56 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +""" Tests strategy selection for Relay ops """ +import pytest +import tvm +from tvm import relay +from tvm import te +from tvm.relay.testing import run_infer_type +import tvm.testing + + +@pytest.mark.parametrize( + "target, expected_implementation", + [("llvm", "concatenate.cpu"), ("llvm -device=arm_cpu", "concatenate.arm_cpu")], +) +def test_concatenate(target, expected_implementation): + target = tvm.target.Target(target) + + shape = (1, 1, 1, 3) + dtype = "float32" + axis = 1 + inputs = [] + inputs.append(relay.var("var0", shape=shape, dtype=dtype)) + inputs.append(relay.var("var1", shape=shape, dtype=dtype)) + input_tuple = relay.Tuple(inputs) + out = relay.op.concatenate(input_tuple, axis) + out = run_infer_type(out) + + impl, xx = relay.backend.te_compiler.select_implementation( + relay.op.get("concatenate"), + out.attrs, + [te.placeholder(shape)], + out.checked_type, + target, + use_autotvm=False, + ) + assert impl.name == expected_implementation + + +if __name__ == "__main__": + tvm.testing.main()