This is an automated email from the ASF dual-hosted git repository.
spectrometerHBH pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new d0a1973ee7 [REFACTOR] Isolate backend module creation via
ffi.Module.create.<kind> registry (#19447)
d0a1973ee7 is described below
commit d0a1973ee7fc4d97f72c32aab813cb4ef4a56617
Author: Tianqi Chen <[email protected]>
AuthorDate: Sun Apr 26 15:32:49 2026 -0400
[REFACTOR] Isolate backend module creation via ffi.Module.create.<kind>
registry (#19447)
## Summary
Introduces FFI registry indirection for all backend module-creation
functions, eliminating the hard linker dependency from
`libtvm_compiler.so` on runtime-defined creator symbols. After this PR,
`nm -u libtvm_compiler.so | grep ModuleCreate` returns nothing.
## Changes
- **Registry convention**: each backend runtime source file (`.cc`/`.mm`) registers
`ffi.Module.create.<kind>` in a `TVM_FFI_STATIC_INIT_BLOCK` with
FFI-compatible argument types. The static init runs when the `.so` is
loaded; until then the creator is unreachable from the registry.
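- **Inline header wrappers**: `XxxModuleCreate` in each backend header
becomes `inline`, performs a `static const auto fcreate =
ffi::Function::GetGlobal("ffi.Module.create.<kind>")` lookup (the result of
the first call, including a failed lookup, is cached in the function-local
static), and dispatches through FFI. Compiler-side codegen callsites are
unchanged apart from a small update in `codegen_metal.cc` (the Metal wrapper
now takes an `ffi::Map` shader map).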
- **Backends registered**: `ffi.Module.create.cuda`, `.rocm`, `.metal`,
`.hexagon`, `.opencl`, `.opencl.spirv`, `.vulkan`, `.const_loader`.
- **FFI arg type conversions**: `std::string` → `ffi::String`;
`std::unordered_map<string,string>` →
`ffi::Map<ffi::String,ffi::String>` (metal smap); `SPIRVShader` →
`ffi::Bytes` (serialised as flag+data uint32 stream);
`unordered_map<string,Tensor>` → `ffi::Map<ffi::String,Tensor>`;
`unordered_map<string,vector<string>>` →
`ffi::Map<ffi::String,ffi::Array<ffi::String>>`.
- **spirv\_shader.h moved**: `src/runtime/spirv/spirv_shader.h` →
`src/runtime/vulkan/spirv_shader.h` (consumed only by the Vulkan and OpenCL
runtimes). The old path becomes a redirect-include for backward
compatibility.
- **Off-build stubs updated**: cuda/rocm/metal off-stubs are now empty
(inline wrappers handle the "off" case by throwing a clear
registry-not-found error); opencl/hexagon off-stubs register fallback
creators.
- **DeviceSourceModuleCreate** left as-is — its `std::function<>`
argument is not FFI-serialisable.
## Verification
- `ninja tvm_runtime tvm_compiler cpptest` builds cleanly with
`HIDE_PRIVATE_SYMBOLS=ON`
- `./cpptest`: 128/128 tests pass
- `all-platform-minimal-test`: 54 passed, 98 skipped
- `tests/python/runtime/`: 81 passed, 2 skipped
- `tests/python/relax/test_vm_build.py`: 84 passed, 2 xfailed
- JVM: BUILD SUCCESS
- `nm -u build/lib/libtvm_compiler.so | grep ModuleCreate` → (empty)
---
src/runtime/const_loader_module.cc | 23 ++++++++++--
src/runtime/const_loader_module.h | 31 ++++++++++++++--
src/runtime/cuda/cuda_module.cc | 17 ++++++---
src/runtime/cuda/cuda_module.h | 17 +++++++--
src/runtime/hexagon/hexagon_module.cc | 19 ++++++++--
src/runtime/hexagon/hexagon_module.h | 25 +++++++++----
src/runtime/metal/metal_module.h | 22 ++++++++---
src/runtime/metal/metal_module.mm | 20 +++++++---
src/runtime/opencl/opencl_module.cc | 17 ++++++---
src/runtime/opencl/opencl_module.h | 43 +++++++++++++++++++---
src/runtime/opencl/opencl_module_spirv.cc | 24 ++++++++++--
src/runtime/rocm/rocm_module.cc | 18 ++++++---
src/runtime/rocm/rocm_module.h | 18 +++++++--
src/runtime/spirv/spirv_shader.h | 55 +++-------------------------
src/runtime/{spirv => vulkan}/spirv_shader.h | 6 +--
src/runtime/vulkan/vulkan_module.cc | 32 +++++++++++++---
src/runtime/vulkan/vulkan_module.h | 38 +++++++++++++++++--
src/target/opt/build_cuda_off.cc | 15 ++------
src/target/opt/build_hexagon_off.cc | 26 ++++++++++---
src/target/opt/build_metal_off.cc | 19 ++--------
src/target/opt/build_opencl_off.cc | 35 ++++++++++++------
src/target/opt/build_rocm_off.cc | 25 ++-----------
src/target/source/codegen_metal.cc | 5 ++-
23 files changed, 362 insertions(+), 188 deletions(-)
diff --git a/src/runtime/const_loader_module.cc
b/src/runtime/const_loader_module.cc
index ae0ea73bd0..006c1f1e1a 100644
--- a/src/runtime/const_loader_module.cc
+++ b/src/runtime/const_loader_module.cc
@@ -250,7 +250,7 @@ class ConstLoaderModuleObj : public ffi::ModuleObj {
std::unordered_map<std::string, std::vector<std::string>>
const_vars_by_symbol_;
};
-ffi::Module ConstLoaderModuleCreate(
+static ffi::Module ConstLoaderModuleCreateImpl(
const std::unordered_map<std::string, Tensor>& const_var_tensor,
const std::unordered_map<std::string, std::vector<std::string>>&
const_vars_by_symbol) {
auto n = ffi::make_object<ConstLoaderModuleObj>(const_var_tensor,
const_vars_by_symbol);
@@ -259,8 +259,25 @@ ffi::Module ConstLoaderModuleCreate(
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
- refl::GlobalDef().def("ffi.Module.load_from_bytes.const_loader",
- ConstLoaderModuleObj::LoadFromBytes);
+ refl::GlobalDef()
+ .def("ffi.Module.load_from_bytes.const_loader",
ConstLoaderModuleObj::LoadFromBytes)
+ .def("ffi.Module.create.const_loader",
+ [](ffi::Map<ffi::String, Tensor> const_var_tensor_ffi,
+ ffi::Map<ffi::String, ffi::Array<ffi::String>>
const_vars_by_symbol_ffi) {
+ std::unordered_map<std::string, Tensor> const_var_tensor;
+ for (const auto& kv : const_var_tensor_ffi) {
+ const_var_tensor[std::string(kv.first)] = kv.second;
+ }
+ std::unordered_map<std::string, std::vector<std::string>>
const_vars_by_symbol;
+ for (const auto& kv : const_vars_by_symbol_ffi) {
+ std::vector<std::string> vars;
+ for (const auto& v : kv.second) {
+ vars.push_back(std::string(v));
+ }
+ const_vars_by_symbol[std::string(kv.first)] = vars;
+ }
+ return ConstLoaderModuleCreateImpl(const_var_tensor,
const_vars_by_symbol);
+ });
}
} // namespace runtime
diff --git a/src/runtime/const_loader_module.h
b/src/runtime/const_loader_module.h
index 3bdbc1235c..c97232016d 100644
--- a/src/runtime/const_loader_module.h
+++ b/src/runtime/const_loader_module.h
@@ -25,7 +25,10 @@
#ifndef TVM_RUNTIME_CONST_LOADER_MODULE_H_
#define TVM_RUNTIME_CONST_LOADER_MODULE_H_
+#include <tvm/ffi/container/array.h>
+#include <tvm/ffi/container/map.h>
#include <tvm/ffi/extra/module.h>
+#include <tvm/ffi/function.h>
#include <tvm/runtime/base.h>
#include <tvm/runtime/tensor.h>
@@ -39,15 +42,37 @@ namespace runtime {
/*!
* \brief Create a ConstLoader module object.
*
- * \param const_var_tensor Maps consts var name to Tensor containing data for
the var.
+ * \param const_var_tensor Maps const var name to Tensor containing data for
the var.
* \param const_vars_by_symbol Maps the name of a module init function to a
list of names of
* const vars whose data will be passed to that init function.
*
* \return The created ConstLoaderModule.
+ *
+ * Dispatches through the FFI registry ("ffi.Module.create.const_loader").
+ * The creator is always available (ConstLoaderModule is a runtime-universal
module).
*/
-TVM_RUNTIME_DLL ffi::Module ConstLoaderModuleCreate(
+inline ffi::Module ConstLoaderModuleCreate(
const std::unordered_map<std::string, Tensor>& const_var_tensor,
- const std::unordered_map<std::string, std::vector<std::string>>&
const_vars_by_symbol);
+ const std::unordered_map<std::string, std::vector<std::string>>&
const_vars_by_symbol) {
+ static const auto fcreate =
ffi::Function::GetGlobal("ffi.Module.create.const_loader");
+ TVM_FFI_CHECK(fcreate.has_value(), RuntimeError)
+ << "ffi.Module.create.const_loader is not registered in runtime. "
+ << "Ensure libtvm_runtime is loaded.";
+ // Convert to FFI-compatible types.
+ ffi::Map<ffi::String, Tensor> ffi_const_var_tensor;
+ for (const auto& kv : const_var_tensor) {
+ ffi_const_var_tensor.Set(kv.first, kv.second);
+ }
+ ffi::Map<ffi::String, ffi::Array<ffi::String>> ffi_const_vars_by_symbol;
+ for (const auto& kv : const_vars_by_symbol) {
+ ffi::Array<ffi::String> vars;
+ for (const auto& v : kv.second) {
+ vars.push_back(ffi::String(v));
+ }
+ ffi_const_vars_by_symbol.Set(kv.first, vars);
+ }
+ return (*fcreate)(ffi_const_var_tensor,
ffi_const_vars_by_symbol).cast<ffi::Module>();
+}
} // namespace runtime
} // namespace tvm
diff --git a/src/runtime/cuda/cuda_module.cc b/src/runtime/cuda/cuda_module.cc
index d06f5a9d5c..691e44f0a0 100644
--- a/src/runtime/cuda/cuda_module.cc
+++ b/src/runtime/cuda/cuda_module.cc
@@ -308,8 +308,9 @@ ffi::Optional<ffi::Function>
CUDAModuleNode::GetFunction(const ffi::String& name
return PackFuncVoidAddr(f, info->arg_types, info->arg_extra_tags);
}
-ffi::Module CUDAModuleCreate(std::string data, std::string fmt,
- ffi::Map<ffi::String, FunctionInfo> fmap,
std::string cuda_source) {
+static ffi::Module CUDAModuleCreateImpl(std::string data, std::string fmt,
+ ffi::Map<ffi::String, FunctionInfo>
fmap,
+ std::string cuda_source) {
auto n = ffi::make_object<CUDAModuleNode>(data, fmt, fmap, cuda_source);
return ffi::Module(n);
}
@@ -322,7 +323,7 @@ ffi::Module CUDAModuleLoadFile(const std::string&
file_name, const ffi::String&
std::string meta_file = GetMetaFilePath(file_name);
LoadBinaryFromFile(file_name, &data);
LoadMetaDataFromFile(meta_file, &fmap);
- return CUDAModuleCreate(data, fmt, fmap, std::string());
+ return CUDAModuleCreateImpl(data, fmt, fmap, std::string());
}
ffi::Module CUDAModuleLoadFromBytes(const ffi::Bytes& bytes) {
@@ -333,7 +334,7 @@ ffi::Module CUDAModuleLoadFromBytes(const ffi::Bytes&
bytes) {
stream.Read(&fmt);
TVM_FFI_ICHECK(stream.Read(&fmap));
stream.Read(&data);
- return CUDAModuleCreate(data, fmt, fmap, std::string());
+ return CUDAModuleCreateImpl(data, fmt, fmap, std::string());
}
TVM_FFI_STATIC_INIT_BLOCK() {
@@ -342,7 +343,13 @@ TVM_FFI_STATIC_INIT_BLOCK() {
.def("ffi.Module.load_from_file.cuda", CUDAModuleLoadFile)
.def("ffi.Module.load_from_file.ptx", CUDAModuleLoadFile)
.def("ffi.Module.load_from_file.cubin", CUDAModuleLoadFile)
- .def("ffi.Module.load_from_bytes.cuda", CUDAModuleLoadFromBytes);
+ .def("ffi.Module.load_from_bytes.cuda", CUDAModuleLoadFromBytes)
+ .def("ffi.Module.create.cuda",
+ [](ffi::String data, ffi::String fmt, ffi::Map<ffi::String,
FunctionInfo> fmap,
+ ffi::String cuda_source) {
+ return CUDAModuleCreateImpl(std::string(data), std::string(fmt),
fmap,
+ std::string(cuda_source));
+ });
}
} // namespace runtime
} // namespace tvm
diff --git a/src/runtime/cuda/cuda_module.h b/src/runtime/cuda/cuda_module.h
index 2a2b1068d7..1bd94332ef 100644
--- a/src/runtime/cuda/cuda_module.h
+++ b/src/runtime/cuda/cuda_module.h
@@ -25,6 +25,7 @@
#define TVM_RUNTIME_CUDA_CUDA_MODULE_H_
#include <tvm/ffi/extra/module.h>
+#include <tvm/ffi/function.h>
#include <tvm/runtime/base.h>
#include <memory>
@@ -46,10 +47,20 @@ static constexpr const int kMaxNumGPUs = 32;
* \param fmt The format of the data, can be "ptx", "cubin"
* \param fmap The map function information map of each function.
* \param cuda_source Optional, CUDA source file
+ *
+ * Dispatches through the FFI registry ("ffi.Module.create.cuda").
+ * Requires libtvm_runtime built with USE_CUDA=ON to have registered the
creator.
*/
-TVM_RUNTIME_DLL ffi::Module CUDAModuleCreate(std::string data, std::string fmt,
- ffi::Map<ffi::String,
FunctionInfo> fmap,
- std::string cuda_source);
+inline ffi::Module CUDAModuleCreate(ffi::String data, ffi::String fmt,
+ ffi::Map<ffi::String, FunctionInfo> fmap,
+ ffi::String cuda_source) {
+ static const auto fcreate =
ffi::Function::GetGlobal("ffi.Module.create.cuda");
+ TVM_FFI_CHECK(fcreate.has_value(), RuntimeError)
+ << "ffi.Module.create.cuda is not registered in runtime. "
+ << "Link or load libtvm_runtime built with USE_CUDA=ON.";
+ return (*fcreate)(data, fmt, fmap, cuda_source).cast<ffi::Module>();
+}
+
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_CUDA_CUDA_MODULE_H_
diff --git a/src/runtime/hexagon/hexagon_module.cc
b/src/runtime/hexagon/hexagon_module.cc
index dd9d74c202..5546e31b8a 100644
--- a/src/runtime/hexagon/hexagon_module.cc
+++ b/src/runtime/hexagon/hexagon_module.cc
@@ -25,6 +25,7 @@
#include <tvm/ffi/extra/module.h>
#include <tvm/ffi/function.h>
+#include <tvm/ffi/reflection/registry.h>
#include <tvm/support/io.h>
#include <string>
@@ -89,12 +90,24 @@ ffi::Bytes HexagonModuleNode::SaveToBytes() const {
return ffi::Bytes(std::move(result));
}
-ffi::Module HexagonModuleCreate(std::string data, std::string fmt,
- ffi::Map<ffi::String, FunctionInfo> fmap,
std::string asm_str,
- std::string obj_str, std::string ir_str,
std::string bc_str) {
+static ffi::Module HexagonModuleCreateImpl(std::string data, std::string fmt,
+ ffi::Map<ffi::String, FunctionInfo>
fmap,
+ std::string asm_str, std::string
obj_str,
+ std::string ir_str, std::string
bc_str) {
auto n = ffi::make_object<HexagonModuleNode>(data, fmt, fmap, asm_str,
obj_str, ir_str, bc_str);
return ffi::Module(n);
}
+TVM_FFI_STATIC_INIT_BLOCK() {
+ namespace refl = tvm::ffi::reflection;
+ refl::GlobalDef().def("ffi.Module.create.hexagon", [](ffi::String data,
ffi::String fmt,
+ ffi::Map<ffi::String,
FunctionInfo> fmap,
+ ffi::String asm_str,
ffi::String obj_str,
+ ffi::String ir_str,
ffi::String bc_str) {
+ return HexagonModuleCreateImpl(std::string(data), std::string(fmt), fmap,
std::string(asm_str),
+ std::string(obj_str), std::string(ir_str),
std::string(bc_str));
+ });
+}
+
} // namespace runtime
} // namespace tvm
diff --git a/src/runtime/hexagon/hexagon_module.h
b/src/runtime/hexagon/hexagon_module.h
index df3fa7b5bb..eeae7b32b5 100644
--- a/src/runtime/hexagon/hexagon_module.h
+++ b/src/runtime/hexagon/hexagon_module.h
@@ -21,6 +21,7 @@
#define TVM_RUNTIME_HEXAGON_HEXAGON_MODULE_H_
#include <tvm/ffi/extra/module.h>
+#include <tvm/ffi/function.h>
#include <tvm/runtime/logging.h>
#include <array>
@@ -38,14 +39,24 @@ namespace runtime {
* \param data The module data.
* \param fmt The format of the data, can be "obj".
* \param fmap The function information map of each function.
- * \param asm_str ffi::String with the generated assembly source.
- * \param obj_str ffi::String with the object file data.
- * \param ir_str ffi::String with the disassembled LLVM IR source.
- * \param bc_str ffi::String with the bitcode LLVM IR.
+ * \param asm_str String with the generated assembly source.
+ * \param obj_str String with the object file data.
+ * \param ir_str String with the disassembled LLVM IR source.
+ * \param bc_str String with the bitcode LLVM IR.
+ *
+ * Dispatches through the FFI registry ("ffi.Module.create.hexagon").
+ * Requires libtvm_runtime built with USE_HEXAGON=ON to have registered the
creator.
*/
-ffi::Module HexagonModuleCreate(std::string data, std::string fmt,
- ffi::Map<ffi::String, FunctionInfo> fmap,
std::string asm_str,
- std::string obj_str, std::string ir_str,
std::string bc_str);
+inline ffi::Module HexagonModuleCreate(ffi::String data, ffi::String fmt,
+ ffi::Map<ffi::String, FunctionInfo>
fmap,
+ ffi::String asm_str, ffi::String
obj_str, ffi::String ir_str,
+ ffi::String bc_str) {
+ static const auto fcreate =
ffi::Function::GetGlobal("ffi.Module.create.hexagon");
+ TVM_FFI_CHECK(fcreate.has_value(), RuntimeError)
+ << "ffi.Module.create.hexagon is not registered in runtime. "
+ << "Link or load libtvm_runtime built with USE_HEXAGON=ON.";
+ return (*fcreate)(data, fmt, fmap, asm_str, obj_str, ir_str,
bc_str).cast<ffi::Module>();
+}
/*!
\brief Module implementation for compiled Hexagon binaries. It is suitable
diff --git a/src/runtime/metal/metal_module.h b/src/runtime/metal/metal_module.h
index 4534cede53..fe9454f674 100644
--- a/src/runtime/metal/metal_module.h
+++ b/src/runtime/metal/metal_module.h
@@ -24,7 +24,9 @@
#ifndef TVM_RUNTIME_METAL_METAL_MODULE_H_
#define TVM_RUNTIME_METAL_METAL_MODULE_H_
+#include <tvm/ffi/container/map.h>
#include <tvm/ffi/extra/module.h>
+#include <tvm/ffi/function.h>
#include <memory>
#include <string>
@@ -41,14 +43,24 @@ static constexpr const int kMetalMaxNumDevice = 32;
/*!
* \brief create a metal module from data.
*
- * \param smap The map from name to each shader kernel.
+ * \param smap The map from name to each shader kernel (FFI-typed).
* \param fmap The map function information map of each function.
* \param fmt The format of the source, can be "metal" or "metallib"
- * \param source Optional, source file, concatenaed for debug dump
+ * \param source Optional, source file, concatenated for debug dump
+ *
+ * Dispatches through the FFI registry ("ffi.Module.create.metal").
+ * Requires libtvm_runtime built with USE_METAL=ON to have registered the
creator.
*/
-ffi::Module MetalModuleCreate(std::unordered_map<std::string, std::string>
smap,
- ffi::Map<ffi::String, FunctionInfo> fmap,
std::string fmt,
- std::string source);
+inline ffi::Module MetalModuleCreate(ffi::Map<ffi::String, ffi::String> smap,
+ ffi::Map<ffi::String, FunctionInfo> fmap,
ffi::String fmt,
+ ffi::String source) {
+ static const auto fcreate =
ffi::Function::GetGlobal("ffi.Module.create.metal");
+ TVM_FFI_CHECK(fcreate.has_value(), RuntimeError)
+ << "ffi.Module.create.metal is not registered in runtime. "
+ << "Link or load libtvm_runtime built with USE_METAL=ON.";
+ return (*fcreate)(smap, fmap, fmt, source).cast<ffi::Module>();
+}
+
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_METAL_METAL_MODULE_H_
diff --git a/src/runtime/metal/metal_module.mm
b/src/runtime/metal/metal_module.mm
index 6837404ad3..22168538e3 100644
--- a/src/runtime/metal/metal_module.mm
+++ b/src/runtime/metal/metal_module.mm
@@ -272,9 +272,9 @@ ffi::Optional<ffi::Function>
MetalModuleNode::GetFunction(const ffi::String& nam
return ret;
}
-ffi::Module MetalModuleCreate(std::unordered_map<std::string, std::string>
smap,
- ffi::Map<ffi::String, FunctionInfo> fmap,
std::string fmt,
- std::string source) {
+static ffi::Module MetalModuleCreateImpl(std::unordered_map<std::string,
std::string> smap,
+ ffi::Map<ffi::String, FunctionInfo>
fmap, std::string fmt,
+ std::string source) {
ObjectPtr<MetalModuleNode> n;
AUTORELEASEPOOL { n = ffi::make_object<MetalModuleNode>(smap, fmap, fmt,
source); };
return ffi::Module(n);
@@ -295,7 +295,7 @@ TVM_FFI_STATIC_INIT_BLOCK() {
fmap.Set(kv.first.cast<ffi::String>(),
FunctionInfo(std::move(info_node)));
}
- return MetalModuleCreate(
+ return MetalModuleCreateImpl(
std::unordered_map<std::string, std::string>(smap.begin(),
smap.end()), fmap, fmt,
source);
});
@@ -315,12 +315,20 @@ ffi::Module MetalModuleLoadFromBytes(const ffi::Bytes&
bytes) {
TVM_FFI_ICHECK(stream.Read(&fmap));
stream.Read(&fmt);
- return MetalModuleCreate(smap, fmap, fmt, "");
+ return MetalModuleCreateImpl(smap, fmap, fmt, "");
}
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
- refl::GlobalDef().def("ffi.Module.load_from_bytes.metal",
MetalModuleLoadFromBytes);
+ refl::GlobalDef()
+ .def("ffi.Module.load_from_bytes.metal", MetalModuleLoadFromBytes)
+ .def("ffi.Module.create.metal",
+ [](ffi::Map<ffi::String, ffi::String> smap, ffi::Map<ffi::String,
FunctionInfo> fmap,
+ ffi::String fmt, ffi::String source) {
+ return MetalModuleCreateImpl(
+ std::unordered_map<std::string, std::string>(smap.begin(),
smap.end()), fmap,
+ std::string(fmt), std::string(source));
+ });
}
} // namespace runtime
} // namespace tvm
diff --git a/src/runtime/opencl/opencl_module.cc
b/src/runtime/opencl/opencl_module.cc
index c7f873a021..c5afaeba7d 100644
--- a/src/runtime/opencl/opencl_module.cc
+++ b/src/runtime/opencl/opencl_module.cc
@@ -360,8 +360,9 @@ ffi::Optional<ffi::Function>
OpenCLModuleNode::GetFunction(const ffi::String& na
return OpenCLModuleNodeBase::GetFunction(name);
}
-ffi::Module OpenCLModuleCreate(std::string data, std::string fmt,
- ffi::Map<ffi::String, FunctionInfo> fmap,
std::string source) {
+static ffi::Module OpenCLModuleCreateImpl(std::string data, std::string fmt,
+ ffi::Map<ffi::String, FunctionInfo>
fmap,
+ std::string source) {
auto n = ffi::make_object<OpenCLModuleNode>(data, fmt, fmap, source);
n->Init();
return ffi::Module(n);
@@ -375,7 +376,7 @@ ffi::Module OpenCLModuleLoadFile(const std::string&
file_name, const ffi::String
std::string meta_file = GetMetaFilePath(file_name);
LoadBinaryFromFile(file_name, &data);
LoadMetaDataFromFile(meta_file, &fmap);
- return OpenCLModuleCreate(data, fmt, fmap, std::string());
+ return OpenCLModuleCreateImpl(data, fmt, fmap, std::string());
}
ffi::Module OpenCLModuleLoadFromBytes(const ffi::Bytes& bytes) {
@@ -386,7 +387,7 @@ ffi::Module OpenCLModuleLoadFromBytes(const ffi::Bytes&
bytes) {
stream.Read(&fmt);
TVM_FFI_ICHECK(stream.Read(&fmap));
stream.Read(&data);
- return OpenCLModuleCreate(data, fmt, fmap, std::string());
+ return OpenCLModuleCreateImpl(data, fmt, fmap, std::string());
}
TVM_FFI_STATIC_INIT_BLOCK() {
@@ -394,7 +395,13 @@ TVM_FFI_STATIC_INIT_BLOCK() {
refl::GlobalDef()
.def("ffi.Module.load_from_file.cl", OpenCLModuleLoadFile)
.def("ffi.Module.load_from_file.clbin", OpenCLModuleLoadFile)
- .def("ffi.Module.load_from_bytes.opencl", OpenCLModuleLoadFromBytes);
+ .def("ffi.Module.load_from_bytes.opencl", OpenCLModuleLoadFromBytes)
+ .def("ffi.Module.create.opencl",
+ [](ffi::String data, ffi::String fmt, ffi::Map<ffi::String,
FunctionInfo> fmap,
+ ffi::String source) {
+ return OpenCLModuleCreateImpl(std::string(data),
std::string(fmt), fmap,
+ std::string(source));
+ });
}
} // namespace runtime
} // namespace tvm
diff --git a/src/runtime/opencl/opencl_module.h
b/src/runtime/opencl/opencl_module.h
index 1e9bb88c93..6697badd48 100644
--- a/src/runtime/opencl/opencl_module.h
+++ b/src/runtime/opencl/opencl_module.h
@@ -24,15 +24,18 @@
#ifndef TVM_RUNTIME_OPENCL_OPENCL_MODULE_H_
#define TVM_RUNTIME_OPENCL_OPENCL_MODULE_H_
+#include <tvm/ffi/container/map.h>
#include <tvm/ffi/function.h>
#include <tvm/runtime/base.h>
#include <memory>
#include <string>
+#include <unordered_map>
#include <vector>
+#include "../../support/bytes_io.h"
#include "../metadata.h"
-#include "../spirv/spirv_shader.h"
+#include "../vulkan/spirv_shader.h"
namespace tvm {
namespace runtime {
@@ -43,10 +46,19 @@ namespace runtime {
* \param fmt The format of the data, can be "clbin", "cl"
* \param fmap The map function information map of each function.
* \param source Generated OpenCL kernels.
+ *
+ * Dispatches through the FFI registry ("ffi.Module.create.opencl").
+ * Requires libtvm_runtime built with USE_OPENCL=ON to have registered the
creator.
*/
-TVM_RUNTIME_DLL ffi::Module OpenCLModuleCreate(std::string data, std::string
fmt,
- ffi::Map<ffi::String,
FunctionInfo> fmap,
- std::string source);
+inline ffi::Module OpenCLModuleCreate(ffi::String data, ffi::String fmt,
+ ffi::Map<ffi::String, FunctionInfo> fmap,
+ ffi::String source) {
+ static const auto fcreate =
ffi::Function::GetGlobal("ffi.Module.create.opencl");
+ TVM_FFI_CHECK(fcreate.has_value(), RuntimeError)
+ << "ffi.Module.create.opencl is not registered in runtime. "
+ << "Link or load libtvm_runtime built with USE_OPENCL=ON.";
+ return (*fcreate)(data, fmt, fmap, source).cast<ffi::Module>();
+}
/*!
* \brief Create a opencl module from SPIRV.
@@ -54,10 +66,29 @@ TVM_RUNTIME_DLL ffi::Module OpenCLModuleCreate(std::string
data, std::string fmt
* \param shaders The map from function names to SPIRV binaries.
* \param spirv_text The concatenated text representation of SPIRV modules.
* \param fmap The map function information map of each function.
+ *
+ * Dispatches through the FFI registry ("ffi.Module.create.opencl.spirv").
+ * Each SPIRVShader is serialised to ffi::Bytes before crossing the FFI
boundary.
+ * Requires libtvm_runtime built with USE_OPENCL=ON and TVM_ENABLE_SPIRV to
have
+ * registered the creator.
*/
-TVM_RUNTIME_DLL ffi::Module OpenCLModuleCreate(
+inline ffi::Module OpenCLModuleCreate(
const std::unordered_map<std::string, spirv::SPIRVShader>& shaders,
- const std::string& spirv_text, ffi::Map<ffi::String, FunctionInfo> fmap);
+ const std::string& spirv_text, ffi::Map<ffi::String, FunctionInfo> fmap) {
+ static const auto fcreate =
ffi::Function::GetGlobal("ffi.Module.create.opencl.spirv");
+ TVM_FFI_CHECK(fcreate.has_value(), RuntimeError)
+ << "ffi.Module.create.opencl.spirv is not registered in runtime. "
+ << "Link or load libtvm_runtime built with USE_OPENCL=ON and
TVM_ENABLE_SPIRV.";
+ // Serialise each SPIRVShader to ffi::Bytes for the FFI boundary.
+ ffi::Map<ffi::String, ffi::Bytes> shader_bytes;
+ for (const auto& kv : shaders) {
+ std::string buf;
+ support::BytesOutStream strm(&buf);
+ strm.Write(kv.second);
+ shader_bytes.Set(kv.first, ffi::Bytes(std::move(buf)));
+ }
+ return (*fcreate)(shader_bytes, ffi::String(spirv_text),
fmap).cast<ffi::Module>();
+}
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_OPENCL_OPENCL_MODULE_H_
diff --git a/src/runtime/opencl/opencl_module_spirv.cc
b/src/runtime/opencl/opencl_module_spirv.cc
index 0125c4121a..12b6df9aa1 100644
--- a/src/runtime/opencl/opencl_module_spirv.cc
+++ b/src/runtime/opencl/opencl_module_spirv.cc
@@ -18,6 +18,7 @@
*/
#include <tvm/ffi/function.h>
+#include <tvm/ffi/reflection/registry.h>
#include <tvm/support/io.h>
#include <string>
@@ -129,13 +130,30 @@ cl_kernel
OpenCLSPIRVModuleNode::InstallKernel(cl::OpenCLWorkspace* w, cl::OpenC
return kernel;
}
-ffi::Module OpenCLModuleCreate(const std::unordered_map<std::string,
SPIRVShader>& shaders,
- const std::string& spirv_text,
- ffi::Map<ffi::String, FunctionInfo> fmap) {
+static ffi::Module OpenCLSPIRVModuleCreateImpl(
+ const std::unordered_map<std::string, SPIRVShader>& shaders, const
std::string& spirv_text,
+ ffi::Map<ffi::String, FunctionInfo> fmap) {
auto n = ffi::make_object<OpenCLSPIRVModuleNode>(shaders, spirv_text, fmap);
n->Init();
return ffi::Module(n);
}
+TVM_FFI_STATIC_INIT_BLOCK() {
+ namespace refl = tvm::ffi::reflection;
+ refl::GlobalDef().def("ffi.Module.create.opencl.spirv",
+ [](ffi::Map<ffi::String, ffi::Bytes> shader_bytes,
ffi::String spirv_text,
+ ffi::Map<ffi::String, FunctionInfo> fmap) {
+ std::unordered_map<std::string, SPIRVShader> shaders;
+ for (const auto& kv : shader_bytes) {
+ support::BytesInStream stream(kv.second);
+ SPIRVShader shader;
+ TVM_FFI_ICHECK(stream.Read(&shader));
+ shaders[std::string(kv.first)] = shader;
+ }
+ return OpenCLSPIRVModuleCreateImpl(shaders,
std::string(spirv_text),
+ fmap);
+ });
+}
+
} // namespace runtime
} // namespace tvm
diff --git a/src/runtime/rocm/rocm_module.cc b/src/runtime/rocm/rocm_module.cc
index 56f929c3c2..4d8eabcf13 100644
--- a/src/runtime/rocm/rocm_module.cc
+++ b/src/runtime/rocm/rocm_module.cc
@@ -208,9 +208,9 @@ ffi::Optional<ffi::Function>
ROCMModuleNode::GetFunction(const ffi::String& name
return PackFuncPackedArgAligned(f, info->arg_types);
}
-ffi::Module ROCMModuleCreate(std::string data, std::string fmt,
- ffi::Map<ffi::String, FunctionInfo> fmap,
std::string hip_source,
- std::string assembly) {
+static ffi::Module ROCMModuleCreateImpl(std::string data, std::string fmt,
+ ffi::Map<ffi::String, FunctionInfo>
fmap,
+ std::string hip_source, std::string
assembly) {
auto n = ffi::make_object<ROCMModuleNode>(data, fmt, fmap, hip_source,
assembly);
return ffi::Module(n);
}
@@ -222,7 +222,7 @@ ffi::Module ROCMModuleLoadFile(const std::string&
file_name, const std::string&
std::string meta_file = GetMetaFilePath(file_name);
LoadBinaryFromFile(file_name, &data);
LoadMetaDataFromFile(meta_file, &fmap);
- return ROCMModuleCreate(data, fmt, fmap, std::string(), std::string());
+ return ROCMModuleCreateImpl(data, fmt, fmap, std::string(), std::string());
}
ffi::Module ROCMModuleLoadFromBytes(const ffi::Bytes& bytes) {
@@ -233,7 +233,7 @@ ffi::Module ROCMModuleLoadFromBytes(const ffi::Bytes&
bytes) {
stream.Read(&fmt);
TVM_FFI_ICHECK(stream.Read(&fmap));
stream.Read(&data);
- return ROCMModuleCreate(data, fmt, fmap, std::string(), std::string());
+ return ROCMModuleCreateImpl(data, fmt, fmap, std::string(), std::string());
}
TVM_FFI_STATIC_INIT_BLOCK() {
@@ -242,7 +242,13 @@ TVM_FFI_STATIC_INIT_BLOCK() {
.def("ffi.Module.load_from_bytes.hsaco", ROCMModuleLoadFromBytes)
.def("ffi.Module.load_from_bytes.hip", ROCMModuleLoadFromBytes)
.def("ffi.Module.load_from_file.hsaco", ROCMModuleLoadFile)
- .def("ffi.Module.load_from_file.hip", ROCMModuleLoadFile);
+ .def("ffi.Module.load_from_file.hip", ROCMModuleLoadFile)
+ .def("ffi.Module.create.rocm",
+ [](ffi::String data, ffi::String fmt, ffi::Map<ffi::String,
FunctionInfo> fmap,
+ ffi::String hip_source, ffi::String assembly) {
+ return ROCMModuleCreateImpl(std::string(data), std::string(fmt),
fmap,
+ std::string(hip_source),
std::string(assembly));
+ });
}
} // namespace runtime
} // namespace tvm
diff --git a/src/runtime/rocm/rocm_module.h b/src/runtime/rocm/rocm_module.h
index 78f6d86d9c..666f73493f 100644
--- a/src/runtime/rocm/rocm_module.h
+++ b/src/runtime/rocm/rocm_module.h
@@ -25,6 +25,7 @@
#define TVM_RUNTIME_ROCM_ROCM_MODULE_H_
#include <tvm/ffi/extra/module.h>
+#include <tvm/ffi/function.h>
#include <memory>
#include <string>
@@ -45,10 +46,21 @@ static constexpr const int kMaxNumGPUs = 32;
* \param fmt The format of the data, can be "hsaco"
* \param fmap The map function information map of each function.
* \param rocm_source Optional, rocm source file
+ * \param assembly Optional, GCN assembly source
+ *
+ * Dispatches through the FFI registry ("ffi.Module.create.rocm").
+ * Requires libtvm_runtime built with USE_ROCM=ON to have registered the
creator.
*/
-ffi::Module ROCMModuleCreate(std::string data, std::string fmt,
- ffi::Map<ffi::String, FunctionInfo> fmap,
std::string rocm_source,
- std::string assembly);
+inline ffi::Module ROCMModuleCreate(ffi::String data, ffi::String fmt,
+ ffi::Map<ffi::String, FunctionInfo> fmap,
+ ffi::String rocm_source, ffi::String
assembly) {
+ static const auto fcreate =
ffi::Function::GetGlobal("ffi.Module.create.rocm");
+ TVM_FFI_CHECK(fcreate.has_value(), RuntimeError)
+ << "ffi.Module.create.rocm is not registered in runtime. "
+ << "Link or load libtvm_runtime built with USE_ROCM=ON.";
+ return (*fcreate)(data, fmt, fmap, rocm_source,
assembly).cast<ffi::Module>();
+}
+
} // namespace runtime
} // namespace tvm
#endif // TVM_RUNTIME_ROCM_ROCM_MODULE_H_
diff --git a/src/runtime/spirv/spirv_shader.h b/src/runtime/spirv/spirv_shader.h
index 202b85b243..11d578c6e2 100644
--- a/src/runtime/spirv/spirv_shader.h
+++ b/src/runtime/spirv/spirv_shader.h
@@ -17,57 +17,14 @@
* under the License.
*/
+/*!
+ * \file src/runtime/spirv/spirv_shader.h
+ * \brief Deprecated include — SPIRVShader has moved to
src/runtime/vulkan/spirv_shader.h.
+ * This header is kept for backward compatibility; include the new path
directly.
+ */
#ifndef TVM_RUNTIME_SPIRV_SPIRV_SHADER_H_
#define TVM_RUNTIME_SPIRV_SPIRV_SHADER_H_
-#include <tvm/ffi/function.h>
-#include <tvm/runtime/base.h>
-#include <tvm/runtime/device_api.h>
-#include <tvm/runtime/logging.h>
-#include <tvm/support/io.h>
-#include <tvm/support/serializer.h>
-
-#include <vector>
-
-namespace tvm {
-namespace runtime {
-namespace spirv {
-
-struct SPIRVShader {
- /*! \brief header flag */
- uint32_t flag{0};
- /*! \brief Data segment */
- std::vector<uint32_t> data;
-
- void Save(support::Stream* writer) const {
- writer->Write(flag);
- writer->Write(data);
- }
- bool Load(support::Stream* reader) {
- if (!reader->Read(&flag)) return false;
- if (!reader->Read(&data)) return false;
- return true;
- }
-};
-
-} // namespace spirv
-
-using spirv::SPIRVShader;
-} // namespace runtime
-} // namespace tvm
+#include "../vulkan/spirv_shader.h"
-namespace tvm {
-namespace support {
-template <>
-struct Serializer<::tvm::runtime::spirv::SPIRVShader> {
- static constexpr bool enabled = true;
- static void Write(Stream* strm, const ::tvm::runtime::spirv::SPIRVShader&
data) {
- data.Save(strm);
- }
- static bool Read(Stream* strm, ::tvm::runtime::spirv::SPIRVShader* data) {
- return data->Load(strm);
- }
-};
-} // namespace support
-} // namespace tvm
#endif // TVM_RUNTIME_SPIRV_SPIRV_SHADER_H_
diff --git a/src/runtime/spirv/spirv_shader.h
b/src/runtime/vulkan/spirv_shader.h
similarity index 93%
copy from src/runtime/spirv/spirv_shader.h
copy to src/runtime/vulkan/spirv_shader.h
index 202b85b243..f290d0dbd1 100644
--- a/src/runtime/spirv/spirv_shader.h
+++ b/src/runtime/vulkan/spirv_shader.h
@@ -17,8 +17,8 @@
* under the License.
*/
-#ifndef TVM_RUNTIME_SPIRV_SPIRV_SHADER_H_
-#define TVM_RUNTIME_SPIRV_SPIRV_SHADER_H_
+#ifndef TVM_RUNTIME_VULKAN_SPIRV_SHADER_H_
+#define TVM_RUNTIME_VULKAN_SPIRV_SHADER_H_
#include <tvm/ffi/function.h>
#include <tvm/runtime/base.h>
@@ -70,4 +70,4 @@ struct Serializer<::tvm::runtime::spirv::SPIRVShader> {
};
} // namespace support
} // namespace tvm
-#endif // TVM_RUNTIME_SPIRV_SPIRV_SHADER_H_
+#endif // TVM_RUNTIME_VULKAN_SPIRV_SHADER_H_
diff --git a/src/runtime/vulkan/vulkan_module.cc b/src/runtime/vulkan/vulkan_module.cc
index 9267115351..bdb048adbe 100644
--- a/src/runtime/vulkan/vulkan_module.cc
+++ b/src/runtime/vulkan/vulkan_module.cc
@@ -25,14 +25,27 @@
#include "../../support/bytes_io.h"
#include "../file_utils.h"
+#include "spirv_shader.h"
#include "vulkan_wrapped_func.h"
namespace tvm {
namespace runtime {
namespace vulkan {
-ffi::Module VulkanModuleCreate(std::unordered_map<std::string, SPIRVShader> smap,
- ffi::Map<ffi::String, FunctionInfo> fmap, std::string source) {
+/*!
+ * \brief Deserialize a SPIRVShader from ffi::Bytes.
+ * Format: flag (uint32_t) followed by data (vector<uint32_t>).
+ */
+static SPIRVShader DeserializeSPIRVShader(const ffi::Bytes& bytes) {
+ support::BytesInStream stream(bytes);
+ SPIRVShader shader;
+ TVM_FFI_ICHECK(stream.Read(&shader));
+ return shader;
+}
+
+static ffi::Module VulkanModuleCreateImpl(std::unordered_map<std::string, SPIRVShader> smap,
+ ffi::Map<ffi::String, FunctionInfo> fmap,
+ std::string source) {
auto n = ffi::make_object<VulkanModuleNode>(smap, fmap, source);
return ffi::Module(n);
}
@@ -50,7 +63,7 @@ ffi::Module VulkanModuleLoadFile(const std::string& file_name, const ffi::String
stream.Read(&magic);
TVM_FFI_ICHECK_EQ(magic, kVulkanModuleMagic) << "VulkanModule Magic mismatch";
stream.Read(&smap);
- return VulkanModuleCreate(smap, fmap, "");
+ return VulkanModuleCreateImpl(smap, fmap, "");
}
ffi::Module VulkanModuleLoadFromBytes(const ffi::Bytes& bytes) {
@@ -62,14 +75,23 @@ ffi::Module VulkanModuleLoadFromBytes(const ffi::Bytes& bytes) {
ffi::Map<ffi::String, FunctionInfo> fmap;
TVM_FFI_ICHECK(stream.Read(&fmap));
stream.Read(&smap);
- return VulkanModuleCreate(smap, fmap, "");
+ return VulkanModuleCreateImpl(smap, fmap, "");
}
TVM_FFI_STATIC_INIT_BLOCK() {
namespace refl = tvm::ffi::reflection;
refl::GlobalDef()
.def("ffi.Module.load_from_file.vulkan", VulkanModuleLoadFile)
- .def("ffi.Module.load_from_bytes.vulkan", VulkanModuleLoadFromBytes);
+ .def("ffi.Module.load_from_bytes.vulkan", VulkanModuleLoadFromBytes)
+ .def("ffi.Module.create.vulkan",
+ [](ffi::Map<ffi::String, ffi::Bytes> shader_bytes,
+ ffi::Map<ffi::String, FunctionInfo> fmap, ffi::String source) {
+ std::unordered_map<std::string, SPIRVShader> smap;
+ for (const auto& kv : shader_bytes) {
+ smap[std::string(kv.first)] = DeserializeSPIRVShader(kv.second);
+ }
+ return VulkanModuleCreateImpl(smap, fmap, std::string(source));
+ });
}
} // namespace vulkan
diff --git a/src/runtime/vulkan/vulkan_module.h b/src/runtime/vulkan/vulkan_module.h
index 2337f3cc79..87df473753 100644
--- a/src/runtime/vulkan/vulkan_module.h
+++ b/src/runtime/vulkan/vulkan_module.h
@@ -20,21 +20,51 @@
#ifndef TVM_RUNTIME_VULKAN_VULKAN_MODULE_H_
#define TVM_RUNTIME_VULKAN_VULKAN_MODULE_H_
+#include <tvm/ffi/container/map.h>
#include <tvm/ffi/extra/module.h>
+#include <tvm/ffi/function.h>
#include <tvm/runtime/base.h>
#include <string>
#include <unordered_map>
+#include "../../support/bytes_io.h"
#include "../metadata.h"
-#include "../spirv/spirv_shader.h"
+#include "spirv_shader.h"
namespace tvm {
namespace runtime {
namespace vulkan {
-TVM_RUNTIME_DLL ffi::Module VulkanModuleCreate(std::unordered_map<std::string, SPIRVShader> smap,
- ffi::Map<ffi::String, FunctionInfo> fmap,
- std::string source);
+
+/*!
+ * \brief Create a Vulkan module from SPIRV shaders.
+ *
+ * \param smap Map from function name to SPIRVShader.
+ * \param fmap Map from function name to FunctionInfo.
+ * \param source Optional SPIRV text (for inspection).
+ *
+ * Dispatches through the FFI registry ("ffi.Module.create.vulkan").
+ * Each SPIRVShader is serialised to ffi::Bytes before crossing the FFI boundary
+ * and rehydrated on the runtime side.
+ * Requires libtvm_runtime built with USE_VULKAN=ON to have registered the creator.
+ */
+inline ffi::Module VulkanModuleCreate(std::unordered_map<std::string, SPIRVShader> smap,
+ ffi::Map<ffi::String, FunctionInfo> fmap,
+ std::string source) {
+ static const auto fcreate = ffi::Function::GetGlobal("ffi.Module.create.vulkan");
+ TVM_FFI_CHECK(fcreate.has_value(), RuntimeError)
+ << "ffi.Module.create.vulkan is not registered in runtime. "
+ << "Link or load libtvm_runtime built with USE_VULKAN=ON.";
+ // Serialise each SPIRVShader to ffi::Bytes for the FFI boundary.
+ ffi::Map<ffi::String, ffi::Bytes> shader_bytes;
+ for (const auto& kv : smap) {
+ std::string buf;
+ support::BytesOutStream strm(&buf);
+ strm.Write(kv.second);
+ shader_bytes.Set(kv.first, ffi::Bytes(std::move(buf)));
+ }
+ return (*fcreate)(shader_bytes, fmap, ffi::String(source)).cast<ffi::Module>();
+}
} // namespace vulkan
diff --git a/src/target/opt/build_cuda_off.cc b/src/target/opt/build_cuda_off.cc
index bf5c5d63d4..e9e4351c89 100644
--- a/src/target/opt/build_cuda_off.cc
+++ b/src/target/opt/build_cuda_off.cc
@@ -18,16 +18,7 @@
*/
/*!
- * Optional module when build CUDA is switched to off
+ * Optional module when build CUDA is switched to off.
+ * CUDAModuleCreate is now an inline registry-lookup wrapper in cuda_module.h,
+ * so no out-of-line stub is needed here.
*/
-#include "../../runtime/cuda/cuda_module.h"
-namespace tvm {
-namespace runtime {
-
-ffi::Module CUDAModuleCreate(std::string data, std::string fmt,
- ffi::Map<ffi::String, FunctionInfo> fmap, std::string cuda_source) {
- TVM_FFI_THROW(InternalError) << "CUDA is not enabled";
- TVM_FFI_UNREACHABLE();
-}
-} // namespace runtime
-} // namespace tvm
diff --git a/src/target/opt/build_hexagon_off.cc b/src/target/opt/build_hexagon_off.cc
index 7fcb2b51a4..08450cf171 100644
--- a/src/target/opt/build_hexagon_off.cc
+++ b/src/target/opt/build_hexagon_off.cc
@@ -17,16 +17,32 @@
* under the License.
*/
+/*!
+ * Optional module when Hexagon runtime is switched to off.
+ * When ffi.Module.create.hexagon is not registered, HexagonModuleCreate (the inline
+ * wrapper) raises a clear RuntimeError. Fall back to a DeviceSourceModule for
+ * compilation-only (source inspection) workflows instead.
+ */
+#include "../../runtime/hexagon/hexagon_module.h"
#include "../source/codegen_source_base.h"
namespace tvm {
namespace runtime {
-ffi::Module HexagonModuleCreate(std::string data, std::string fmt,
- ffi::Map<ffi::String, FunctionInfo> fmap, std::string asm_str,
- std::string obj_str, std::string ir_str, std::string bc_str) {
- LOG(WARNING) << "Hexagon runtime is not enabled, return a source module...";
- return codegen::DeviceSourceModuleCreate(data, fmt, fmap, "hex");
+// Register a fallback creator so that compiler-side code that calls
+// HexagonModuleCreate() when USE_HEXAGON=OFF still gets a usable
+// DeviceSourceModule (for source inspection / serialisation) rather than a
+// registry-not-found error.
+TVM_FFI_STATIC_INIT_BLOCK() {
+ namespace refl = tvm::ffi::reflection;
+ refl::GlobalDef().def(
+ "ffi.Module.create.hexagon",
+ [](ffi::String data, ffi::String fmt, ffi::Map<ffi::String, FunctionInfo> fmap,
+ ffi::String /*asm_str*/, ffi::String /*obj_str*/, ffi::String /*ir_str*/,
+ ffi::String /*bc_str*/) -> ffi::Module {
+ LOG(WARNING) << "Hexagon runtime is not enabled, returning a source module...";
+ return codegen::DeviceSourceModuleCreate(std::string(data), std::string(fmt), fmap, "hex");
+ });
}
} // namespace runtime
diff --git a/src/target/opt/build_metal_off.cc b/src/target/opt/build_metal_off.cc
index 7f544d92f6..fae5d511c6 100644
--- a/src/target/opt/build_metal_off.cc
+++ b/src/target/opt/build_metal_off.cc
@@ -18,20 +18,7 @@
*/
/*!
- * Optional module when build metal is switched to off
+ * Optional module when build metal is switched to off.
+ * MetalModuleCreate is now an inline registry-lookup wrapper in metal_module.h,
+ * so no out-of-line stub is needed here.
*/
-#include "../../runtime/metal/metal_module.h"
-#include "../source/codegen_source_base.h"
-
-namespace tvm {
-namespace runtime {
-
-ffi::Module MetalModuleCreate(std::unordered_map<std::string, std::string> smap,
- ffi::Map<ffi::String, FunctionInfo> fmap, std::string fmt,
- std::string source) {
- LOG(WARNING) << "Metal runtime not enabled, return a source module...";
- return codegen::DeviceSourceModuleCreate(source, fmt, fmap, "metal");
-}
-
-} // namespace runtime
-} // namespace tvm
diff --git a/src/target/opt/build_opencl_off.cc b/src/target/opt/build_opencl_off.cc
index 1a27866a4c..e30725f442 100644
--- a/src/target/opt/build_opencl_off.cc
+++ b/src/target/opt/build_opencl_off.cc
@@ -18,24 +18,35 @@
*/
/*!
- * Optional module when build opencl is switched to off
+ * Optional module when build opencl is switched to off.
+ * Register fallback creators so that compiler-side code (codegen_opencl.cc)
+ * that calls OpenCLModuleCreate() when USE_OPENCL=OFF still gets a usable
+ * DeviceSourceModule for source inspection / serialisation workflows.
*/
-#include "../../runtime/opencl/opencl_module.h"
+#include <tvm/ffi/reflection/registry.h>
+
+#include "../../runtime/metadata.h"
#include "../source/codegen_source_base.h"
namespace tvm {
namespace runtime {
-ffi::Module OpenCLModuleCreate(std::string data, std::string fmt,
- ffi::Map<ffi::String, FunctionInfo> fmap, std::string source) {
- return codegen::DeviceSourceModuleCreate(data, fmt, fmap, "opencl");
-}
-
-ffi::Module OpenCLModuleCreate(const std::unordered_map<std::string, SPIRVShader>& shaders,
- const std::string& spirv_text,
- ffi::Map<ffi::String, FunctionInfo> fmap) {
- TVM_FFI_THROW(InternalError) << "OpenCLModuleCreate is called but OpenCL is not enabled.";
- TVM_FFI_UNREACHABLE();
+TVM_FFI_STATIC_INIT_BLOCK() {
+ namespace refl = tvm::ffi::reflection;
+ refl::GlobalDef()
+ .def("ffi.Module.create.opencl",
+ [](ffi::String data, ffi::String fmt, ffi::Map<ffi::String, FunctionInfo> fmap,
+ ffi::String /*source*/) -> ffi::Module {
+ return codegen::DeviceSourceModuleCreate(std::string(data), std::string(fmt), fmap,
+ "opencl");
+ })
+ .def("ffi.Module.create.opencl.spirv",
+ [](ffi::Map<ffi::String, ffi::Bytes> /*shader_bytes*/, ffi::String /*spirv_text*/,
+ ffi::Map<ffi::String, FunctionInfo> /*fmap*/) -> ffi::Module {
+ TVM_FFI_THROW(InternalError)
+ << "OpenCLModuleCreate (SPIRV) is called but OpenCL is not enabled.";
+ TVM_FFI_UNREACHABLE();
+ });
}
} // namespace runtime
diff --git a/src/target/opt/build_rocm_off.cc b/src/target/opt/build_rocm_off.cc
index ea1265ad29..634c8252c8 100644
--- a/src/target/opt/build_rocm_off.cc
+++ b/src/target/opt/build_rocm_off.cc
@@ -18,26 +18,7 @@
*/
/*!
- * Optional module when build rocm is switched to off
+ * Optional module when build rocm is switched to off.
+ * ROCMModuleCreate is now an inline registry-lookup wrapper in rocm_module.h,
+ * so no out-of-line stub is needed here.
*/
-#include "../../runtime/rocm/rocm_module.h"
-#include "../source/codegen_source_base.h"
-
-namespace tvm {
-namespace runtime {
-
-ffi::Module ROCMModuleCreate(std::string data, std::string fmt,
- ffi::Map<ffi::String, FunctionInfo> fmap, std::string rocm_source,
- std::string assembly) {
- LOG(WARNING) << "ROCM runtime is not enabled, return a source module...";
- auto fget_source = [rocm_source, assembly](const std::string& format) {
- if (format.length() == 0) return assembly;
- if (format == "ll" || format == "llvm") return rocm_source;
- if (format == "asm") return assembly;
- return std::string("");
- };
- return codegen::DeviceSourceModuleCreate(data, fmt, fmap, "hsaco", fget_source);
-}
-
-} // namespace runtime
-} // namespace tvm
diff --git a/src/target/source/codegen_metal.cc b/src/target/source/codegen_metal.cc
index 6831596c81..c4734c54bb 100644
--- a/src/target/source/codegen_metal.cc
+++ b/src/target/source/codegen_metal.cc
@@ -22,6 +22,7 @@
*/
#include "codegen_metal.h"
+#include <tvm/ffi/container/map.h>
#include <tvm/ffi/reflection/registry.h>
#include <tvm/tirx/transform.h>
@@ -447,7 +448,7 @@ ffi::Module BuildMetal(IRModule mod, Target target) {
mod = tirx::transform::PointerValueTypeRewrite()(std::move(mod));
std::ostringstream source_maker;
- std::unordered_map<std::string, std::string> smap;
+ ffi::Map<ffi::String, ffi::String> smap;
const auto fmetal_compile = tvm::ffi::Function::GetGlobal("tvm_callback_metal_compile");
std::string fmt = fmetal_compile ? "metallib" : "metal";
@@ -472,7 +473,7 @@ ffi::Module BuildMetal(IRModule mod, Target target) {
if (fmetal_compile) {
fsource = (*fmetal_compile)(fsource, target).cast<std::string>();
}
- smap[func_name] = fsource;
+ smap.Set(func_name, fsource);
}
return MetalModuleCreate(smap, ExtractFuncInfo(mod), fmt, source_maker.str());