This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch refactor-s2 in repository https://gitbox.apache.org/repos/asf/tvm.git
commit 520ffd650eba7c1ab6d295e4b9e9c0aa4a14a5d2 Author: tqchen <[email protected]> AuthorDate: Tue Apr 22 12:39:39 2025 -0400 fix structural hash and relax tir codegen --- include/tvm/runtime/container/shape_tuple.h | 4 +++- src/contrib/msc/core/printer/msc_base_printer.cc | 3 +-- src/node/structural_hash.cc | 5 ++--- src/tir/transforms/arg_binder.cc | 2 +- src/tir/transforms/make_packed_api.cc | 11 ++++++----- tests/python/relax/test_vm_codegen_tir.py | 2 +- 6 files changed, 14 insertions(+), 13 deletions(-) diff --git a/include/tvm/runtime/container/shape_tuple.h b/include/tvm/runtime/container/shape_tuple.h index 6a0497049f..61f44d30be 100644 --- a/include/tvm/runtime/container/shape_tuple.h +++ b/include/tvm/runtime/container/shape_tuple.h @@ -24,10 +24,12 @@ #ifndef TVM_RUNTIME_CONTAINER_SHAPE_TUPLE_H_ #define TVM_RUNTIME_CONTAINER_SHAPE_TUPLE_H_ +#include <tvm/ffi/container/shape.h> + #include <ostream> #include <utility> #include <vector> -#include <tvm/ffi/container/shape.h> + #include "./base.h" namespace tvm { diff --git a/src/contrib/msc/core/printer/msc_base_printer.cc b/src/contrib/msc/core/printer/msc_base_printer.cc index dac732aaa6..838d284d13 100644 --- a/src/contrib/msc/core/printer/msc_base_printer.cc +++ b/src/contrib/msc/core/printer/msc_base_printer.cc @@ -114,8 +114,7 @@ void MSCBasePrinter::PrintTypedDoc(const LiteralDoc& doc) { output_ << float_imm->value; } } else if (const auto* string_obj = value.as<StringObj>()) { - output_ << "\"" << tvm::support::StrEscape(string_obj->data, string_obj->size) - << "\""; + output_ << "\"" << tvm::support::StrEscape(string_obj->data, string_obj->size) << "\""; } else { LOG(FATAL) << "TypeError: Unsupported literal value type: " << value.GetTypeKey(); } diff --git a/src/node/structural_hash.cc b/src/node/structural_hash.cc index 4c864d6613..50b8cb4e9b 100644 --- a/src/node/structural_hash.cc +++ b/src/node/structural_hash.cc @@ -316,8 +316,7 @@ struct StringObjTrait { static constexpr const std::nullptr_t VisitAttrs = nullptr; static void SHashReduce(const runtime::StringObj* key, SHashReducer hash_reduce) { - hash_reduce->SHashReduceHashedValue( - ffi::details::StableHashBytes(key->data, key->size)); + hash_reduce->SHashReduceHashedValue(ffi::details::StableHashBytes(key->data, key->size)); } static bool SEqualReduce(const runtime::StringObj* lhs, const runtime::StringObj* rhs, @@ -497,7 +496,7 @@ struct ShapeTupleObjTrait { static constexpr const std::nullptr_t VisitAttrs = nullptr; static void SHashReduce(const ShapeTupleObj* self, SHashReducer hash_reduce) { - hash_reduce(self->size); + hash_reduce(static_cast<uint64_t>(self->size)); for (uint32_t i = 0; i < self->size; ++i) { hash_reduce(self->data[i]); } diff --git a/src/tir/transforms/arg_binder.cc b/src/tir/transforms/arg_binder.cc index 9270a14df9..5b9e005b7e 100644 --- a/src/tir/transforms/arg_binder.cc +++ b/src/tir/transforms/arg_binder.cc @@ -151,7 +151,7 @@ inline PrimExpr TVMArrayGet(DataType t, Var arr, builtin::TVMStructFieldKind kin void ArgBinder::BindDLTensor(const Buffer& buffer, const PrimExpr& device_type, const PrimExpr& device_id, const Var& handle, - const std::string& arg_name) { + const std::string& arg_name) { const DataType tvm_shape_type = DataType::ShapeIndex(); const DataType tvm_ndim_type = DataType::Int(32); const Stmt nop = Evaluate(0); diff --git a/src/tir/transforms/make_packed_api.cc b/src/tir/transforms/make_packed_api.cc index d1931ebced..d241d43c19 100644 --- a/src/tir/transforms/make_packed_api.cc +++ b/src/tir/transforms/make_packed_api.cc @@ -293,9 +293,11 @@ PrimFunc MakePackedAPI(PrimFunc func) { // if type_index is NDArray, we need to add the offset of the DLTensor header // which always equals 16 bytes, this ensures that T.handle always shows up as a DLTensor* arg_value = f_load_arg_value(param.dtype(), i); - PrimExpr handle_from_ndarray = Call(DataType::Handle(), tir::builtin::handle_add_byte_offset(), - {arg_value, IntImm(DataType::Int(32), 16)}); - arg_value = Select(type_index == ffi::TypeIndex::kTVMFFINDArray, handle_from_ndarray, arg_value); + PrimExpr handle_from_ndarray = + Call(DataType::Handle(), tir::builtin::handle_add_byte_offset(), + {arg_value, IntImm(DataType::Int(32), 16)}); + arg_value = + Select(type_index == ffi::TypeIndex::kTVMFFINDArray, handle_from_ndarray, arg_value); } else if (dtype.is_bool()) { std::ostringstream msg; msg << name_hint << ": Expect arg[" << i << "] to be boolean"; @@ -348,8 +350,7 @@ PrimFunc MakePackedAPI(PrimFunc func) { } for (const auto& [var, buffer] : buffer_def) { - binder.BindDLTensor(buffer, device_type, device_id, var, - name_hint + "." + var->name_hint); + binder.BindDLTensor(buffer, device_type, device_id, var, name_hint + "." + var->name_hint); arg_buffer_declarations.push_back(DeclBuffer(buffer, nop)); } diff --git a/tests/python/relax/test_vm_codegen_tir.py b/tests/python/relax/test_vm_codegen_tir.py index 60f096585d..41f8e81735 100644 --- a/tests/python/relax/test_vm_codegen_tir.py +++ b/tests/python/relax/test_vm_codegen_tir.py @@ -89,7 +89,7 @@ def test_tir_call(): def __vmtir__foo(ctx_ptr: T.handle, r: T.handle, c: T.handle, f: T.handle): T.func_attr({"global_symbol": "__vmtir__foo"}) T.call_cpacked( - "shape_func", T.anylist_getitem(r, T.int32(0)), T.reinterpret("handle", T.uint64(0)) + "shape_func", T.anylist_getitem(r, T.int32(0)) ) T.anylist_setitem_call_packed( r, T.int32(1), "vm.builtin.copy", T.anylist_getitem(r, T.int32(0))
