This is an automated email from the ASF dual-hosted git repository. sanirudh pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push: new 122398995d [VM][Hexagon] Cache operations when bypass mode is enabled (#16762) 122398995d is described below commit 122398995daa43d843a81ab3aaeba4b63a02d5b9 Author: Abhikrant Sharma <quic_abhik...@quicinc.com> AuthorDate: Sat Mar 23 15:37:17 2024 +0530 [VM][Hexagon] Cache operations when bypass mode is enabled (#16762) This is needed as Hexagon DMA engine expects cache maintenance by applications --- src/runtime/relax_vm/hexagon/builtin.cc | 12 +++++++++++- tests/python/contrib/test_hexagon/test_dma_builtin.py | 6 +++--- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/src/runtime/relax_vm/hexagon/builtin.cc b/src/runtime/relax_vm/hexagon/builtin.cc index b32d0e14aa..586984dfc0 100644 --- a/src/runtime/relax_vm/hexagon/builtin.cc +++ b/src/runtime/relax_vm/hexagon/builtin.cc @@ -44,6 +44,9 @@ TVM_REGISTER_GLOBAL("vm.builtin.hexagon.dma_copy") CHECK_EQ(GetDataSize(*dptr), GetDataSize(*sptr)); auto size = GetDataSize(*dptr); ICHECK(size > 0); + if (bypass_cache) + qurt_mem_cache_clean(reinterpret_cast<qurt_addr_t>(src), size, QURT_MEM_CACHE_INVALIDATE, + QURT_MEM_DCACHE); do { ret = tvm::runtime::hexagon::HexagonDeviceAPI::Global()->UserDMA()->Copy( queue_id, dst, src, size, bypass_cache); @@ -52,10 +55,17 @@ TVM_REGISTER_GLOBAL("vm.builtin.hexagon.dma_copy") }); TVM_REGISTER_GLOBAL("vm.builtin.hexagon.dma_wait") - .set_body_typed([](TVMArgValue vm_ptr, int queue_id, int inflight_dma, + .set_body_typed([](TVMArgValue vm_ptr, int queue_id, int inflight_dma, bool bypass_cache, [[maybe_unused]] NDArray src_arr, [[maybe_unused]] NDArray dst_arr) { ICHECK(inflight_dma >= 0); tvm::runtime::hexagon::HexagonDeviceAPI::Global()->UserDMA()->Wait(queue_id, inflight_dma); + if (bypass_cache) { + const DLTensor* dptr = dst_arr.operator->(); + void* dst = dptr->data; + auto size = GetDataSize(*dptr); + qurt_mem_cache_clean(reinterpret_cast<qurt_addr_t>(dst), size, QURT_MEM_CACHE_FLUSH, + QURT_MEM_DCACHE); + } }); } // namespace relax_vm } // namespace runtime diff --git a/tests/python/contrib/test_hexagon/test_dma_builtin.py b/tests/python/contrib/test_hexagon/test_dma_builtin.py index e1c98ac356..86be640689 100644 --- a/tests/python/contrib/test_hexagon/test_dma_builtin.py +++ b/tests/python/contrib/test_hexagon/test_dma_builtin.py @@ -107,12 +107,12 @@ class Module_1D: ) __: R.Tuple = R.call_builtin_with_ctx( "vm.builtin.hexagon.dma_wait", - [0, 2, x, a], + [0, 2, True, x, a], sinfo_args=[], ) __: R.Tuple = R.call_builtin_with_ctx( "vm.builtin.hexagon.dma_wait", - [1, 1, y, b], + [1, 1, True, y, b], sinfo_args=[], ) ___: R.Tuple = cls.compute_add_in_vtcm(a, b, c) @@ -132,7 +132,7 @@ class Module_1D: ) __: R.Tuple = R.call_builtin_with_ctx( "vm.builtin.hexagon.dma_wait", - [0, 1, c, ret_val], + [0, 1, True, c, ret_val], sinfo_args=[], ) _t3: R.Tuple = R.vm.kill_object(vtcm_obj)