This is an automated email from the ASF dual-hosted git repository. masahi pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push: new 073304dadb [TVM PyTorch Integration] libstdc++ CXX11 ABI Compatibility & boolean tensor support (#12232) 073304dadb is described below commit 073304dadb91ce70b3198cab8b3ae98ee4061b26 Author: Yaoda Zhou <judap...@sjtu.edu.cn> AuthorDate: Wed Aug 17 16:33:37 2022 +0800 [TVM PyTorch Integration] libstdc++ CXX11 ABI Compatibility & boolean tensor support (#12232) * first commit * rename * cmake * deprecated * newline * config * config * typo * skip tvm_class * rename * delete ptr * delete ptr * save progress * boolean support * cmake file * polish code * compile config * improving the codes * format * doc&errormsg * zero-cost copy * one step * to ndarray * extra output * delete extra codes * update test * boolean support * strong test * decrease memory copy * polish * reformat * polish * remove redundant import Co-authored-by: juda <yz...@octoml.ai> --- apps/pt_tvmdsoop/tests/test_as_torch.py | 7 +- apps/pt_tvmdsoop/tests/test_boolean_tensor.py | 129 ++++++++++ cmake/modules/contrib/PT_TVMDSOOP.cmake | 68 ++++-- python/tvm/contrib/torch/__init__.py | 25 +- python/tvm/contrib/torch/module.py | 17 ++ python/tvm/contrib/torch/pytorch_tvm.py | 21 ++ .../torch/pt_call_tvm/RuntimeModuleWrapper.cc | 259 -------------------- .../tvm_module_wrapper/RuntimeModuleWrapperTVM.cc | 266 +++++++++++++++++++++ .../RuntimeModuleWrapperTorch.cc | 215 +++++++++++++++++ .../torch/tvm_module_wrapper/runtime_bridge.h | 116 +++++++++ 10 files changed, 844 insertions(+), 279 deletions(-) diff --git a/apps/pt_tvmdsoop/tests/test_as_torch.py b/apps/pt_tvmdsoop/tests/test_as_torch.py index 2c454e9454..a13d669e7f 100644 --- a/apps/pt_tvmdsoop/tests/test_as_torch.py +++ b/apps/pt_tvmdsoop/tests/test_as_torch.py @@ -17,6 +17,8 @@ # specific language governing permissions and limitations # under the License. """Test script for tvm torch module""" +import tempfile + import numpy as np import torch @@ -190,7 +192,10 @@ def test_tvmscript_torch_gpu(): q1 = torch.arange(8, device=cuda0).type(torch.float32) q2 = torch.zeros((8,), dtype=torch.float32, device=cuda0) - ModuleGPU(q1, q2) + with tempfile.NamedTemporaryFile(suffix=".pt") as tmp: + torch.save(ModuleGPU, tmp.name) + loaded_mod = torch.load(tmp.name) + loaded_mod(q1, q2) tvm.testing.assert_allclose(q2.cpu().numpy(), (q1 + 1).cpu().numpy(), atol=1e-5, rtol=1e-5) diff --git a/apps/pt_tvmdsoop/tests/test_boolean_tensor.py b/apps/pt_tvmdsoop/tests/test_boolean_tensor.py new file mode 100644 index 0000000000..4718b40439 --- /dev/null +++ b/apps/pt_tvmdsoop/tests/test_boolean_tensor.py @@ -0,0 +1,129 @@ +#!/usr/bin/env python + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
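The change to test_as_torch.py above is small but deliberate: instead of calling ModuleGPU directly, the test now saves it with torch.save and calls the torch.load result, checking that the TVM-backed module survives a pickle round trip. In user code the same pattern looks roughly like the sketch below (the helper name is illustrative, not part of the diff):

import tempfile

import torch

def roundtrip_call(mod, *args):
    # TVM-backed wrapper modules are registered with def_pickle, so
    # torch.save/torch.load should round-trip them like any module.
    with tempfile.NamedTemporaryFile(suffix=".pt") as tmp:
        torch.save(mod, tmp.name)
        loaded = torch.load(tmp.name)
        return loaded(*args)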
+"""Test script for boolean tensor support""" +import tempfile + +import torch + +import tvm +import tvm.testing +from tvm.contrib.torch import as_torch, optimize_torch +from tvm.script import tir as T + + +def negate(x): + return x.logical_not() + + +def sum_up_tensor(x): + return x.size(dim=0) - torch.sum(x.int()) + + +def tensor_boolean_operation(x): + arr1 = (x + 0.3).floor().bool() + arr2 = (~((x + 0.7).int().bool())).bool() + ret = ((arr1 & arr2).byte() + 0.5).half() + return ~(ret.bool()) + + +def test_bool_tensor_negate(): + input = torch.ones(1, dtype=torch.bool) + optimized_negate = optimize_torch( + negate, + input, + ) + with tempfile.NamedTemporaryFile(suffix=".pt") as tmp: + torch.save(optimized_negate, tmp.name) + loaded_mod = torch.load(tmp.name) + output = loaded_mod(negate(input)) + tvm.testing.assert_allclose(input.numpy(), output.numpy(), atol=1e-5, rtol=1e-5) + + +def test_sum_up_tensor(): + x = torch.randint(0, 2, (16,)) + y = x.bool() + optimized_func = optimize_torch( + sum_up_tensor, + (y,), + ) + ret1 = (x[x == 0]).size(dim=0) + ret2 = optimized_func(y).numpy() + tvm.testing.assert_allclose(ret1, ret2, atol=1e-5, rtol=1e-5) + + +def test_tensor_boolean_operation(): + input = torch.rand(200) + model = optimize_torch( + tensor_boolean_operation, + input, + ) + ret1 = tensor_boolean_operation(input) + ret2 = model(input) + tvm.testing.assert_allclose(ret1, ret2, atol=1e-5, rtol=1e-5) + + +@as_torch +@T.prim_func +def negate_tvmscript( + X: T.Buffer[(8, 8), "bool"], + Y: T.Buffer[(8, 8), "float32"], + Z: T.Buffer[(8, 8), "bool"], + U: T.Buffer[(8, 8), "float32"], +) -> None: + for i, j in T.grid(8, 8): + with T.block(): + if Y[i, j] > 0.0: + Z[i, j] = X[i, j] + U[i, j] = Y[i, j] + else: + Z[i, j] = not X[i, j] + U[i, j] = 0.0 - Y[i, j] + + +def negate_vanila(x, y): + z = torch.zeros(8, 8).bool() + for i in range(8): + for j in range(8): + if y[i, j] > 0: + z[i, j] = x[i, j] + else: + z[i, j] = ~x[i, j] + return z + + +def test_tvmscript_torch_decorator(): + q1 = (torch.rand(8, 8) + 0.5).int().bool() + q2 = torch.rand(8, 8) - 0.5 + q3 = torch.zeros(8, 8).bool() + q4 = torch.zeros(8, 8) + + std1 = negate_vanila(q1, q2) + std2 = torch.abs(q2) + + negate_tvmscript(q1, q2, q3, q4) + + tvm.testing.assert_allclose(std1.numpy(), q3.numpy(), atol=1e-5, rtol=1e-5) + tvm.testing.assert_allclose(std2.numpy(), q4.numpy(), atol=1e-5, rtol=1e-5) + + +if __name__ == "__main__": + test_tvmscript_torch_decorator() + test_bool_tensor_negate() + test_sum_up_tensor() + test_tensor_boolean_operation() diff --git a/cmake/modules/contrib/PT_TVMDSOOP.cmake b/cmake/modules/contrib/PT_TVMDSOOP.cmake index 3bad3fd966..a73d3f38e9 100644 --- a/cmake/modules/contrib/PT_TVMDSOOP.cmake +++ b/cmake/modules/contrib/PT_TVMDSOOP.cmake @@ -6,7 +6,7 @@ # "License"); you may not use this file except in compliance # with the License. 
You may obtain a copy of the License at # -# http://www.apache.org/licenses/LICENSE-2.0 +# http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an @@ -17,42 +17,80 @@ if(NOT USE_PT_TVMDSOOP STREQUAL "OFF") find_package(PythonInterp REQUIRED) - execute_process(COMMAND ${PYTHON_EXECUTABLE} -c "import torch; print(torch.__path__[0].strip())" OUTPUT_VARIABLE PT_PATH RESULT_VARIABLE PT_STATUS) - if (NOT ${PT_STATUS} EQUAL 0) + + if(NOT ${PT_STATUS} EQUAL 0) message(FATAL_ERROR "Fail to get pytorch path") endif() string(REGEX REPLACE "\n" "" PT_PATH "${PT_PATH}") message(STATUS "PyTorch path: ${PT_PATH}") - set(PT_COMPILE_FLAGS_STR "-I${PT_PATH}/include -D_GLIBCXX_USE_CXX11_ABI=0") + execute_process(COMMAND ${PYTHON_EXECUTABLE} -c "import torch;print(torch.compiled_with_cxx11_abi())" + OUTPUT_VARIABLE PT_CXX_FLAG + RESULT_VARIABLE PT_STATUS) + + string(REGEX REPLACE "\n" "" PT_CXX_FLAG "${PT_CXX_FLAG}") + message(STATUS "Found TORCH_BUILT_WITH_CXX_ABI=${PT_CXX_FLAG} ") + + if(${PT_CXX_FLAG} STREQUAL "False") + set(CXX_ABI_ENABLED 0) + else() + set(CXX_ABI_ENABLED 1) + endif() + + set_property( + SOURCE + ${CMAKE_CURRENT_SOURCE_DIR}/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTorch.cc + APPEND PROPERTY + COMPILE_OPTIONS + "-D_GLIBCXX_USE_CXX11_ABI=${CXX_ABI_ENABLED}" + "-I${PT_PATH}/include" + ) + + set_property( + SOURCE + ${CMAKE_CURRENT_SOURCE_DIR}/src/contrib/torch/pt_call_tvm/tvm_class.cc + APPEND PROPERTY + COMPILE_OPTIONS + "-I${PT_PATH}/include" + ) + set(PT_LINK_FLAGS_STR "-L${PT_PATH}/lib -l:libtorch.so -l:libtorch_python.so") if(NOT USE_CUDA STREQUAL "OFF") add_definitions(-DPT_TVMDSOOP_ENABLE_GPU) endif() - string(REGEX REPLACE "\n" " " PT_FLAGS "${PT_COMPILE_FLAGS} ${PT_LINK_FLAGS}") - separate_arguments(PT_COMPILE_FLAGS UNIX_COMMAND ${PT_COMPILE_FLAGS_STR}) + separate_arguments(PT_COMPILE_FLAGS UNIX_COMMAND) separate_arguments(PT_LINK_FLAGS UNIX_COMMAND ${PT_LINK_FLAGS_STR}) + # This old version is deprecated and will be removed after TVM 0.11 + set(LIBRARY_OLD_NAME pt_tvmdsoop) - set(LIBRARY_NAME pt_tvmdsoop) - tvm_file_glob(GLOB_RECURSE PTTVM_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/src/contrib/torch/**/*.cc) - add_library(${LIBRARY_NAME} SHARED ${PTTVM_SRCS}) + # This new library is built for the PyTorch integration and solves the C++ ABI incompatibility issue + set(LIBRARY_NEW_NAME pt_tvmdsoop_new) + tvm_file_glob(GLOB_RECURSE PTTVM_TORCH ${CMAKE_CURRENT_SOURCE_DIR}/src/contrib/torch/tvm_module_wrapper/*.cc) + + tvm_file_glob(GLOB_RECURSE PTTVM_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/src/contrib/torch/pt_call_tvm/*.cc) + + add_library(${LIBRARY_OLD_NAME} SHARED ${PTTVM_SRCS}) + add_library(${LIBRARY_NEW_NAME} SHARED ${PTTVM_TORCH}) set(PTTVM_LINK_FLAGS -ltvm -L${CMAKE_CURRENT_BINARY_DIR}) - if (NOT BUILD_PT_TVMDSOOP_ONLY STREQUAL "ON") - add_dependencies(${LIBRARY_NAME} tvm) + if(NOT BUILD_PT_TVMDSOOP_ONLY STREQUAL "ON") + add_dependencies(${LIBRARY_OLD_NAME} tvm) + add_dependencies(${LIBRARY_NEW_NAME} tvm) endif() - target_compile_options(${LIBRARY_NAME} PUBLIC ${PTTVM_COMPILE_FLAGS} ${PT_COMPILE_FLAGS}) - target_link_libraries(${LIBRARY_NAME} PUBLIC ${PTTVM_LINK_FLAGS} ${PT_LINK_FLAGS}) - target_compile_definitions(${LIBRARY_NAME} PUBLIC DMLC_USE_LOGGING_LIBRARY=<tvm/runtime/logging.h>) + target_compile_options(${LIBRARY_OLD_NAME} PUBLIC ${PTTVM_COMPILE_FLAGS} ${PT_COMPILE_FLAGS}) + target_link_libraries(${LIBRARY_OLD_NAME} PUBLIC ${PTTVM_LINK_FLAGS} ${PT_LINK_FLAGS}) + 
target_compile_definitions(${LIBRARY_OLD_NAME} PUBLIC DMLC_USE_LOGGING_LIBRARY=<tvm/runtime/logging.h>) + target_compile_options(${LIBRARY_NEW_NAME} PUBLIC ${PTTVM_COMPILE_FLAGS} ${PT_COMPILE_FLAGS}) + target_link_libraries(${LIBRARY_NEW_NAME} PUBLIC ${PTTVM_LINK_FLAGS} ${PT_LINK_FLAGS}) + target_compile_definitions(${LIBRARY_NEW_NAME} PUBLIC DMLC_USE_LOGGING_LIBRARY=<tvm/runtime/logging.h>) endif() - diff --git a/python/tvm/contrib/torch/__init__.py b/python/tvm/contrib/torch/__init__.py index 340f9cef9e..c3dd34d470 100644 --- a/python/tvm/contrib/torch/__init__.py +++ b/python/tvm/contrib/torch/__init__.py @@ -18,11 +18,12 @@ """Module container of PyTorch custom class""" import os import platform +import warnings import torch from tvm._ffi import libinfo -def _load_platform_specific_library(lib_name="libpt_tvmdsoop"): +def _load_platform_specific_library(lib_name): system = platform.system() if system == "Darwin": lib_file_name = lib_name + ".dylib" @@ -33,11 +34,27 @@ def _load_platform_specific_library(lib_name): lib_path = libinfo.find_lib_path()[0] lib_dir = os.path.dirname(lib_path) lib_file_path = os.path.join(lib_dir, lib_file_name) - torch.classes.load_library(lib_file_path) + try: + torch.classes.load_library(lib_file_path) + except OSError as err: + errmsg = str(err) + if errmsg.find("undefined symbol") != -1: + reason = " ".join( + ( + "Got an undefined symbol error,", + "which might be due to the CXXABI incompatibility.", + ) + ) + else: + reason = errmsg + warnings.warn( + f"The library {lib_name} could not be loaded. {reason}", + RuntimeWarning, + ) -_load_platform_specific_library() - +_load_platform_specific_library("libpt_tvmdsoop") +_load_platform_specific_library("libpt_tvmdsoop_new") from . import module diff --git a/python/tvm/contrib/torch/module.py b/python/tvm/contrib/torch/module.py index 3da9c6f591..cfa3ad264c 100644 --- a/python/tvm/contrib/torch/module.py +++ b/python/tvm/contrib/torch/module.py @@ -16,7 +16,9 @@ # under the License.
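A note on the loader above: only one of the two libraries can match the C++ ABI of the installed libtorch, so warning rather than failing hard is the right behavior when the other library hits an undefined-symbol error. The flag that PT_TVMDSOOP.cmake probes at configure time is also queryable at runtime, which helps when diagnosing such errors; a minimal check:

import torch

# Same probe the CMake script runs: builds against this libtorch need
# _GLIBCXX_USE_CXX11_ABI=1 when this is True, and 0 when it is False.
if torch.compiled_with_cxx11_abi():
    print("libtorch was built with the CXX11 ABI")
else:
    print("libtorch was built with the pre-CXX11 ABI")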
# pylint: disable=invalid-name """Module container of PyTorch custom class""" +import warnings from typing import List + import torch @@ -29,6 +31,11 @@ class GraphModule(torch.nn.Module): return torch.ops.tvm_dsoop.tvm_shape_repr(input_shapes) def __init__(self, num_inputs, num_outputs, device=None): + warnings.warn( + "This module will be removed in TVM version 0.11", + DeprecationWarning, + stacklevel=2, + ) super().__init__() self.dummy_param = torch.nn.Parameter(torch.empty(0)) self.engine = None @@ -67,6 +74,11 @@ class VMModule(torch.nn.Module): return torch.ops.tvm_dsoop.tvm_shape_repr(input_shapes) def __init__(self, num_inputs, num_outputs, device=None): + warnings.warn( + "This module will be removed in TVM version 0.11", + DeprecationWarning, + stacklevel=2, + ) super().__init__() self.dummy_param = torch.nn.Parameter(torch.empty(0)) self.engine = None @@ -113,6 +125,11 @@ class TraceTvmModule(torch.nn.Module): """ def __init__(self, tvm_module): + warnings.warn( + "This module will be removed in TVM version 0.11", + DeprecationWarning, + stacklevel=2, + ) super().__init__() self.tvm_module = tvm_module diff --git a/python/tvm/contrib/torch/pytorch_tvm.py b/python/tvm/contrib/torch/pytorch_tvm.py index 1e50c98ab8..ffab4fa0d2 100644 --- a/python/tvm/contrib/torch/pytorch_tvm.py +++ b/python/tvm/contrib/torch/pytorch_tvm.py @@ -19,6 +19,7 @@ # pylint: disable=redefined-builtin """`compile` API that converts a Torch module to a Torch TVM module""" import os +import warnings import tvm import tvm.testing from tvm import relay, autotvm @@ -183,6 +184,16 @@ class PyTorchTVMModule: def build_pytorch_module(self, num_inputs, num_outputs, input_infos=None): """Build pytorch module containing TVM Graph Module""" + warnings.warn( + " ".join( + ( + "This function will be removed in TVM version 0.11;", + "we suggest users use `optimize_torch` for tuning Torch modules instead.", + ) + ), + DeprecationWarning, + stacklevel=2, + ) assert self.export_dir, "you must build_tvm or load_tvm before" input_infos = input_infos or self.input_infos assert input_infos @@ -224,6 +235,16 @@ def compile(script_module, option): pytorch_tvm_module = compile(script_module, option) pytorch_tvm_module("model_tvm.pt") """ + warnings.warn( + " ".join( + ( + "This function will be removed in TVM version 0.11;", + "we suggest users use `optimize_torch` for tuning Torch modules instead.", + ) + ), + DeprecationWarning, + stacklevel=2, + ) input_infos = option["input_infos"] default_dtype = option.get("default_dtype", "float32") export_dir = option.get("export_dir", "pytorch_compiled") diff --git a/src/contrib/torch/pt_call_tvm/RuntimeModuleWrapper.cc b/src/contrib/torch/pt_call_tvm/RuntimeModuleWrapper.cc deleted file mode 100644 index 12c1017bea..0000000000 --- a/src/contrib/torch/pt_call_tvm/RuntimeModuleWrapper.cc +++ /dev/null @@ -1,259 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License.
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -#include <ATen/DLConvertor.h> -#include <dlpack/dlpack.h> -#include <dmlc/memory_io.h> -#include <torch/custom_class.h> -#include <torch/script.h> -#include <tvm/runtime/module.h> -#include <tvm/runtime/registry.h> -#include <tvm/target/codegen.h> -#include <tvm/target/target.h> - -#include <cstdio> -#include <map> -#include <string> -#include <vector> - -#include "../../../runtime/graph_executor/graph_executor_factory.h" -#include "../base64.h" - -namespace tvm { -namespace contrib { - -/** - * We pass the TVM module by TVM's FFI because Torch's FFI cannot recognize such TVM objects - */ -struct ThreadLocalStore { - tvm::runtime::Module mod; - static ThreadLocalStore* ThreadLocal() { - thread_local ThreadLocalStore tls; - return &tls; - } -}; - -using SerializationType = std::string; // base64 stream - -SerializationType serialize(tvm::runtime::Module module) { - static const runtime::PackedFunc* f_to_str = - runtime::Registry::Get("script_torch.save_to_base64"); - ICHECK(f_to_str) << "IndexError: Cannot find the packed function " - "`script_torch.save_to_base64` in the global registry"; - return (*f_to_str)(module); -} - -struct Deleter { // deleter - explicit Deleter(std::string file_name) { this->file_name = file_name; } - void operator()(FILE* p) const { - fclose(p); - ICHECK(remove(file_name.c_str()) == 0) - << "Failed to remove temporary file (" << file_name << ")"; - } - std::string file_name; -}; - -tvm::runtime::Module deserialize(SerializationType state) { - auto length = tvm::support::b64strlen(state); - - std::vector<u_char> bytes(length); - tvm::support::b64decode(state, bytes.data()); - - const std::string name = tmpnam(NULL); - auto file_name = name + ".so"; - std::unique_ptr<FILE, Deleter> pFile(fopen(file_name.c_str(), "wb"), Deleter(file_name)); - fwrite(bytes.data(), sizeof(u_char), length, pFile.get()); - fflush(pFile.get()); - - std::string load_f_name = "runtime.module.loadfile_so"; - const PackedFunc* f = runtime::Registry::Get(load_f_name); - ICHECK(f != nullptr) << "Loader for `.so` files is not registered," - << " resolved to (" << load_f_name << ") in the global registry." - << "Ensure that you have loaded the correct runtime code, and" - << "that you are on the correct hardware architecture."; - - tvm::runtime::Module ret = (*f)(file_name, ""); - - return ret; -} - -/** - * @brief A Torch's module which wraps TVM's OperatorModule Class. - * The basic forward function calling TVM's runtime is provided. - * The TVM module can be serialized/deserialized as a Torch module. 
- */ -class OperatorModuleWrapper : public torch::jit::CustomClassHolder { - public: - OperatorModuleWrapper() { runtime_module = ThreadLocalStore::ThreadLocal()->mod; } - - void forward(const c10::List<at::Tensor>& inputs) { - int input_length = inputs.size(); - - std::vector<DLManagedTensor*> tensors; - - for (int i = 0; i < input_length; ++i) tensors.push_back(toDLPack(inputs[i])); - - tvm::runtime::PackedFunc run = runtime_module.GetFunction("__tvm_main__"); - - std::vector<TVMValue> tvm_values(input_length); - std::vector<int> tvm_type_codes(input_length); - tvm::runtime::TVMArgsSetter setter(tvm_values.data(), tvm_type_codes.data()); - for (int k = 0; k < input_length; ++k) { - setter(k, &tensors[k]->dl_tensor); - } - - run.CallPacked(tvm::runtime::TVMArgs(tvm_values.data(), tvm_type_codes.data(), input_length), - nullptr); - - for (int k = 0; k < input_length; ++k) { - tensors[k]->deleter(tensors[k]); - } - } - - SerializationType Serialize() { return serialize(runtime_module); } - - explicit OperatorModuleWrapper(SerializationType state) { runtime_module = deserialize(state); } - - private: - tvm::runtime::Module runtime_module; -}; - -tvm::Device getDevice(const at::Tensor& tensor) { - tvm::Device dev; - dev.device_id = tensor.get_device(); - switch (tensor.device().type()) { - case at::DeviceType::CPU: - dev.device_type = DLDeviceType::kDLCPU; - if (dev.device_id == -1) { - /* - * In PyTorch the device ID for cpu is -1, sometimes causing error during tuning - * Thus we manually set the device ID as 0 for avoiding potentially error of index out of - * bounds - */ - dev.device_id = 0; - } - break; - case at::DeviceType::CUDA: - dev.device_type = DLDeviceType::kDLCUDA; - break; - default: - TORCH_CHECK(false, "PyTorch TVM integration doesn't support device " + tensor.device().str()); - } - return dev; -} - -/** - * @brief A Torch's module which wraps TVM's GraphExecutorFactory Class. - * The basic forward function calling TVM's runtime is provided. - * The TVM module can be serialized/deserialized as a Torch module. 
- */ -class GraphExecutorFactoryWrapper : public torch::jit::CustomClassHolder { - public: - explicit GraphExecutorFactoryWrapper(tvm::runtime::Module executor_factory) - : executor_factory_(executor_factory) { - CHECK(executor_factory_->IsInstance<runtime::GraphExecutorFactory>()) - << "module is not an instance of GraphExecutorFactory"; - } - - GraphExecutorFactoryWrapper() - : GraphExecutorFactoryWrapper(ThreadLocalStore::ThreadLocal()->mod) {} - - c10::List<at::Tensor> forward(const c10::List<at::Tensor>& inputs) { - int input_length = inputs.size(); - - if (!executor_.defined()) { - TORCH_CHECK(input_length > 0, "Receive empty list of input tensors"); - DLDevice input_device = getDevice(inputs.get(0)); - - auto tmp = executor_factory_.GetFunction("default"); - - executor_ = tmp(input_device); - } - - std::vector<DLManagedTensor*> tensors; - - for (int i = 0; i < input_length; ++i) tensors.push_back(toDLPack(inputs[i])); - - tvm::runtime::PackedFunc run = executor_.GetFunction("run"); - tvm::runtime::PackedFunc set_input = executor_.GetFunction("set_input"); - tvm::runtime::PackedFunc get_output = executor_.GetFunction("get_output"); - tvm::runtime::PackedFunc get_num_outputs = executor_.GetFunction("get_num_outputs"); - - for (int k = 0; k < input_length; ++k) { - set_input(k, &tensors[k]->dl_tensor); - } - - run(); - - int64_t output_length = get_num_outputs(); - - c10::List<at::Tensor> outputs; - outputs.reserve(output_length); - - for (int k = 0; k < output_length; ++k) { - tvm::runtime::NDArray results = get_output(k); - at::Tensor atTensor = at::fromDLPack(results.ToDLPack()); - outputs.emplace_back(atTensor); - } - - for (int k = 0; k < input_length; ++k) { - tensors[k]->deleter(tensors[k]); - } - return outputs; - } - - SerializationType Serialize() { return serialize(executor_factory_); } - - explicit GraphExecutorFactoryWrapper(SerializationType state) { - executor_factory_ = deserialize(state); - } - - private: - tvm::runtime::Module executor_factory_; - tvm::runtime::Module executor_; -}; - -TVM_REGISTER_GLOBAL("tvmtorch.save_runtime_mod").set_body_typed([](tvm::runtime::Module mod) { - ThreadLocalStore::ThreadLocal()->mod = mod; -}); - -TORCH_LIBRARY(tvm_torch, m) { - m.class_<OperatorModuleWrapper>("OperatorModuleWrapper") - .def(torch::init<>()) - .def("forward", &OperatorModuleWrapper::forward) - .def_pickle( - [](const c10::intrusive_ptr<OperatorModuleWrapper>& self) -> SerializationType { - return self->Serialize(); - }, - [](SerializationType state) { - return c10::make_intrusive<OperatorModuleWrapper>(state); - }); - m.class_<GraphExecutorFactoryWrapper>("GraphExecutorFactoryWrapper") - .def(torch::init<>()) - .def("forward", &GraphExecutorFactoryWrapper::forward) - .def_pickle( - [](const c10::intrusive_ptr<GraphExecutorFactoryWrapper>& self) -> SerializationType { - return self->Serialize(); - }, - [](SerializationType state) { - return c10::make_intrusive<GraphExecutorFactoryWrapper>(state); - }); -} - -} // namespace contrib -} // namespace tvm diff --git a/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc b/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc new file mode 100644 index 0000000000..fb570c163f --- /dev/null +++ b/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc @@ -0,0 +1,266 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include <dlpack/dlpack.h> +#include <dmlc/memory_io.h> +#include <tvm/runtime/module.h> +#include <tvm/runtime/registry.h> +#include <tvm/target/codegen.h> +#include <tvm/target/target.h> + +#include <cstdio> +#include <map> +#include <string> +#include <vector> + +#include "../../../runtime/graph_executor/graph_executor_factory.h" +#include "../base64.h" +#include "runtime_bridge.h" + +namespace tvm { +namespace contrib { + +/* + * TVM's FFI for passing a module from Python to C++ + */ +struct ThreadLocalStore { + tvm::runtime::Module mod; + static ThreadLocalStore* ThreadLocal() { + thread_local ThreadLocalStore tls; + return &tls; + } +}; + +/* + * Encode TVM runtime module to base64 stream + */ +std::string serialize(tvm::runtime::Module module) { + static const runtime::PackedFunc* f_to_str = + runtime::Registry::Get("script_torch.save_to_base64"); + ICHECK(f_to_str) << "IndexError: Cannot find the packed function " + "`script_torch.save_to_base64` in the global registry"; + return (*f_to_str)(module); +} + +struct Deleter { // deleter + explicit Deleter(std::string file_name) { this->file_name = file_name; } + void operator()(FILE* p) const { + fclose(p); + ICHECK(remove(file_name.c_str()) == 0) + << "Failed to remove temporary file (" << file_name << ")"; + } + std::string file_name; +}; + +/* + * Decode TVM runtime module from base64 stream + */ +tvm::runtime::Module deserialize(std::string state) { + auto length = tvm::support::b64strlen(state); + + std::vector<u_char> bytes(length); // byte stream + tvm::support::b64decode(state, bytes.data()); + + const std::string name = tmpnam(NULL); + auto file_name = name + ".so"; + std::unique_ptr<FILE, Deleter> pFile(fopen(file_name.c_str(), "wb"), Deleter(file_name)); + fwrite(bytes.data(), sizeof(u_char), length, pFile.get()); + fflush(pFile.get()); + + std::string load_f_name = "runtime.module.loadfile_so"; + const PackedFunc* f = runtime::Registry::Get(load_f_name); + ICHECK(f != nullptr) << "Loader for `.so` files is not registered," + << " resolved to (" << load_f_name << ") in the global registry." + << "Ensure that you have loaded the correct runtime code, and" + << "that you are on the correct hardware architecture."; + + tvm::runtime::Module ret = (*f)(file_name, ""); + + return ret; +} + +TVM_REGISTER_GLOBAL("tvmtorch.save_runtime_mod").set_body_typed([](tvm::runtime::Module mod) { + ThreadLocalStore::ThreadLocal()->mod = mod; +}); + +/* + * Convert an NDArray to a DLPack extended tensor. It should be zero-cost.
+ * @param src Pointer to NDArray + * @return DLPack extended tensor + */ +DLPackTensorExt CreateDLpackTensorExt(tvm::runtime::NDArray* src) { + auto is_bool = src->DataType().is_bool(); + DLManagedTensor* tensor; + if (is_bool) { + // If we change DLDataType{kDLInt, 8, 1} to DataType::Bool() + // we will get `RuntimeError: Unsupported kUInt bits 1` + auto tmp = src->CreateView(src->Shape(), DLDataType{kDLInt, 8, 1}); + tensor = tmp.ToDLPack(); + } else { + tensor = src->ToDLPack(); + } + DLPackTensorExt ret{tensor, is_bool}; + return ret; +} + +/* + * Create an NDArray with boolean type. (One memory copy) + * @param src DLPack extended tensor + * @return a new NDArray + */ +tvm::runtime::NDArray CreateBoolNDarray(DLPackTensorExt* src) { + auto& tensor = src->dl_managed_tensor->dl_tensor; + std::vector<int64_t> shape; + for (int64_t i = 0; i < tensor.ndim; i++) { + shape.push_back(tensor.shape[i]); + } + auto ret = tvm::runtime::NDArray::Empty(shape, DataType::Bool(), tensor.device); + ret.CopyFrom(&src->dl_managed_tensor->dl_tensor); + return std::move(ret); +} + +bool IsZeroCopy(DLPackTensorExt* src) { + auto& dl_tensor = src->dl_managed_tensor->dl_tensor; + return tvm::runtime::NDArray::AbilityOfZeroCopyForDLTensor(&dl_tensor, dl_tensor.device); +} + +/* + * Create an NDArray from a DLPack extended tensor. + * @param src DLPack extended tensor + * @return a new NDArray + */ +tvm::runtime::NDArray NDarrayFromDLpack(DLPackTensorExt* src) { + using tvm::runtime::NDArray; + + NDArray array; + auto& dl_tensor = src->dl_managed_tensor->dl_tensor; + if (src->is_bool) { + // one memory copy + // the code is similar to NewFromDLTensor except for the type + array = CreateBoolNDarray(src); + } else if (IsZeroCopy(src)) { + array = NDArray::FromExternalDLTensor(src->dl_managed_tensor->dl_tensor); + } else { + // one memory copy + array = NDArray::NewFromDLTensor(&dl_tensor, dl_tensor.device); + } + return array; +} + +} // namespace contrib +} // namespace tvm + +extern "C" { + +struct TVMContribTorchRuntimeModule { + tvm::runtime::Module mod; + + explicit TVMContribTorchRuntimeModule(tvm::runtime::Module& mod) : mod(mod) {} +}; + +bool tvm_contrib_torch_tensor_ability_of_zero_copy(DLPackTensorExt* src) { + return (!src->is_bool) && (tvm::contrib::IsZeroCopy(src)); +} + +TVMContribTorchRuntimeModule* tvm_contrib_torch_get_last_saved_runtime_module() { + return new TVMContribTorchRuntimeModule(tvm::contrib::ThreadLocalStore::ThreadLocal()->mod); +} + +void tvm_contrib_torch_operator_module_forward(TVMContribTorchRuntimeModule* runtime_module, + DLPackTensorExt* inputs, size_t input_size) { + tvm::runtime::PackedFunc run = runtime_module->mod.GetFunction("__tvm_main__"); + + std::vector<TVMValue> tvm_values(input_size); + std::vector<int> tvm_type_codes(input_size); + tvm::runtime::TVMArgsSetter setter(tvm_values.data(), tvm_type_codes.data()); + + std::vector<tvm::runtime::NDArray> input_cache(input_size); + + for (size_t k = 0; k < input_size; ++k) { + auto datum = tvm::contrib::NDarrayFromDLpack(&inputs[k]); // could have one memory copy + input_cache[k] = datum; // we keep the datum in a vector for future use, otherwise the datum + // will be freed after the loop + setter(k, datum); + } + + run.CallPacked(tvm::runtime::TVMArgs(tvm_values.data(), tvm_type_codes.data(), input_size), + nullptr); + + for (size_t k = 0; k < input_size; ++k) { + if (!tvm_contrib_torch_tensor_ability_of_zero_copy(&inputs[k])) + input_cache[k].CopyTo(&inputs[k].dl_managed_tensor->dl_tensor); + } +} + 
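An aside on the boolean handling above: DLPack, as of this commit, has no boolean dtype, so a bool tensor has to travel as an 8-bit integer tensor and be reinterpreted on the other side, which is exactly what the is_bool flag tracks. The same constraint is visible from plain PyTorch; a minimal round trip, independent of this bridge:

import torch
from torch.utils import dlpack

t = torch.tensor([True, False, True])
# DLPack cannot carry kBool, so export the tensor as uint8 ...
capsule = dlpack.to_dlpack(t.to(torch.uint8))
# ... and view it as bool again on the consumer side.
restored = dlpack.from_dlpack(capsule).to(torch.bool)
assert torch.equal(t, restored)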
+TVMContribTorchRuntimeModule* tvm_contrib_torch_create_graph_runtime_module( + TVMContribTorchRuntimeModule* graph_executor_factory, DLManagedTensor* input_example) { + tvm::runtime::PackedFunc built_module = graph_executor_factory->mod.GetFunction("default"); + tvm::Device device_info = input_example->dl_tensor.device; + tvm::runtime::Module runtime_module = built_module(device_info); + return new TVMContribTorchRuntimeModule(runtime_module); +} + +size_t tvm_contrib_torch_graph_executor_module_forward(TVMContribTorchRuntimeModule* runtime_module, + DLPackTensorExt* inputs, size_t input_size, + DLPackTensorExt** outputs) { + tvm::runtime::PackedFunc run = runtime_module->mod.GetFunction("run"); + tvm::runtime::PackedFunc set_input = runtime_module->mod.GetFunction("set_input"); + tvm::runtime::PackedFunc get_output = runtime_module->mod.GetFunction("get_output"); + tvm::runtime::PackedFunc get_num_outputs = runtime_module->mod.GetFunction("get_num_outputs"); + + for (size_t k = 0; k < input_size; ++k) { + set_input(k, &inputs[k].dl_managed_tensor->dl_tensor); + } + + run(); + + int64_t output_length = get_num_outputs(); + + DLPackTensorExt* outputs_ptr = new DLPackTensorExt[output_length]; + *outputs = outputs_ptr; + + for (int64_t k = 0; k < output_length; ++k) { + tvm::runtime::NDArray results = get_output(k); + outputs_ptr[k] = tvm::contrib::CreateDLpackTensorExt(&results); + } + + return output_length; +} + +char* tvm_contrib_torch_encode(TVMContribTorchRuntimeModule* runtime_module) { + std::string std = tvm::contrib::serialize(runtime_module->mod); + char* ret = new char[std.length() + 1]; + snprintf(ret, std.length() + 1, "%s", std.c_str()); + return ret; +} + +TVMContribTorchRuntimeModule* tvm_contrib_torch_decode(const char* state) { + tvm::runtime::Module ret = tvm::contrib::deserialize(state); + return new TVMContribTorchRuntimeModule(ret); +} + +void tvm_contrib_torch_free_runtime_module(TVMContribTorchRuntimeModule* module_ptr) { + delete module_ptr; +} + +void tvm_contrib_torch_free_dlpack_tensor_ext_array(DLPackTensorExt* dlpack_ptr) { + delete[] dlpack_ptr; +} + +void tvm_contrib_torch_free_encoding(char* encoding) { delete[] encoding; } +} diff --git a/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTorch.cc b/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTorch.cc new file mode 100644 index 0000000000..3159438d72 --- /dev/null +++ b/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTorch.cc @@ -0,0 +1,215 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +#include <ATen/DLConvertor.h> +#include <torch/custom_class.h> +#include <torch/script.h> + +#include <iostream> + +#include "runtime_bridge.h" + +namespace tvm { +namespace contrib { + +/* + * Convert Torch tensor to DLPack extended tensor. + * A boolean Torch tensor is converted to a DLTensor with the `is_bool=True` flag. + * @param src Torch tensor + * @return DLPack extended tensor + */ +DLPackTensorExt ToDLPackExt(const at::Tensor& src) { + if (!src.is_contiguous()) { + return ToDLPackExt(src.contiguous()); + } + DLPackTensorExt ret; + if (src.dtype().isScalarType(torch::kBool)) { + auto temp = src.toType(torch::kUInt8); + ret.dl_managed_tensor = at::toDLPack(temp); + ret.is_bool = true; + } else { + ret.dl_managed_tensor = at::toDLPack(src); + ret.is_bool = false; + } + + return ret; +} + +/* + * Convert DLPack extended tensor to Torch tensor. + * @param src DLPack extended tensor + * @return Torch tensor + */ +at::Tensor FromDLPackExt(const DLPackTensorExt& src) { + if (src.is_bool) { + return at::fromDLPack(src.dl_managed_tensor).toType(torch::kBool); + } else { + return at::fromDLPack(src.dl_managed_tensor); + } +} + +/** + * @brief A Torch module which wraps TVM's OperatorModule Class. + * The basic forward function calling TVM's runtime is provided. + * The TVM module can be serialized/deserialized as a Torch module. + */ +class OperatorModuleWrapper : public torch::jit::CustomClassHolder { + public: + OperatorModuleWrapper() { runtime_module_ = tvm_contrib_torch_get_last_saved_runtime_module(); } + ~OperatorModuleWrapper() { tvm_contrib_torch_free_runtime_module(runtime_module_); } + + void forward(const c10::List<at::Tensor>& inputs) { + int input_length = inputs.size(); + + std::vector<DLPackTensorExt> tensors; + + // Torch tensors support the boolean type while DLPack does not, + // so we convert Torch tensors to an extension of the DLPack tensor + for (int i = 0; i < input_length; ++i) tensors.push_back(ToDLPackExt(inputs[i])); + tvm_contrib_torch_operator_module_forward(this->runtime_module_, tensors.data(), + tensors.size()); + + for (int k = 0; k < input_length; ++k) { + if (tvm_contrib_torch_tensor_ability_of_zero_copy(&tensors[k])) { + // We need to free memory manually + tensors[k].dl_managed_tensor->deleter(tensors[k].dl_managed_tensor); + } else { + // Ownership transferred + inputs[k].copy_(FromDLPackExt(tensors[k])); + } + } + } + + std::string Serialize() { + auto encoding = tvm_contrib_torch_encode(runtime_module_); + auto ret = std::string(encoding); + tvm_contrib_torch_free_encoding(encoding); + return ret; + } + + explicit OperatorModuleWrapper(std::string state) { + runtime_module_ = tvm_contrib_torch_decode(state.c_str()); + } + + private: + /* + * TVM runtime module wrapper + */ + TVMContribTorchRuntimeModule* runtime_module_; +}; + +/** + * @brief A Torch module which wraps TVM's GraphExecutorFactory Class. + * The basic forward function calling TVM's runtime is provided. + * The TVM module can be serialized/deserialized as a Torch module.
+ */ +class GraphExecutorFactoryWrapper : public torch::jit::CustomClassHolder { + public: + explicit GraphExecutorFactoryWrapper(TVMContribTorchRuntimeModule* executor_factory) + : executor_factory_(executor_factory), executor_factory_runtime_(nullptr) {} + + ~GraphExecutorFactoryWrapper() { + tvm_contrib_torch_free_runtime_module(executor_factory_); + tvm_contrib_torch_free_runtime_module(executor_factory_runtime_); + } + + GraphExecutorFactoryWrapper() + : GraphExecutorFactoryWrapper(tvm_contrib_torch_get_last_saved_runtime_module()) {} + + std::string Serialize() { + auto encoding = tvm_contrib_torch_encode(executor_factory_); + auto ret = std::string(encoding); + tvm_contrib_torch_free_encoding(encoding); + return ret; + } + + explicit GraphExecutorFactoryWrapper(std::string state) { + executor_factory_ = tvm_contrib_torch_decode(state.c_str()); + executor_factory_runtime_ = nullptr; + } + + c10::List<at::Tensor> forward(const c10::List<at::Tensor>& inputs) { + int input_length = inputs.size(); + + TORCH_CHECK(input_length > 0, "Receive empty list of input tensors"); + + std::vector<DLPackTensorExt> tensors; + + // Torch tensors support the boolean type while DLPack does not, + // so we convert Torch tensors to an extension of the DLPack tensor + for (int i = 0; i < input_length; ++i) tensors.push_back(ToDLPackExt(inputs[i])); + + DLPackTensorExt* outputs; + if (executor_factory_runtime_ == nullptr) { + executor_factory_runtime_ = tvm_contrib_torch_create_graph_runtime_module( + this->executor_factory_, tensors[0].dl_managed_tensor); + } + auto num_outputs = tvm_contrib_torch_graph_executor_module_forward( + executor_factory_runtime_, tensors.data(), tensors.size(), &outputs); + + c10::List<at::Tensor> ret; + ret.reserve(num_outputs); + + for (size_t k = 0; k < num_outputs; ++k) { + at::Tensor atTensor = FromDLPackExt(outputs[k]); + ret.emplace_back(atTensor); + } + + for (int k = 0; k < input_length; ++k) { + tensors[k].dl_managed_tensor->deleter(tensors[k].dl_managed_tensor); + } + tvm_contrib_torch_free_dlpack_tensor_ext_array(outputs); + + return ret; + } + + private: + /* + * TVM Graph Executor Factory module wrapper + */ + TVMContribTorchRuntimeModule* executor_factory_; + + /* + * TVM runtime module wrapper + */ + TVMContribTorchRuntimeModule* executor_factory_runtime_; +}; + +TORCH_LIBRARY(tvm_torch, m) { + m.class_<OperatorModuleWrapper>("OperatorModuleWrapper") + .def(torch::init<>()) + .def("forward", &OperatorModuleWrapper::forward) + .def_pickle( + [](const c10::intrusive_ptr<OperatorModuleWrapper>& self) -> std::string { + return self->Serialize(); + }, + [](std::string state) { return c10::make_intrusive<OperatorModuleWrapper>(state); }); + m.class_<GraphExecutorFactoryWrapper>("GraphExecutorFactoryWrapper") + .def(torch::init<>()) + .def("forward", &GraphExecutorFactoryWrapper::forward) + .def_pickle( + [](const c10::intrusive_ptr<GraphExecutorFactoryWrapper>& self) -> std::string { + return self->Serialize(); + }, + [](std::string state) { + return c10::make_intrusive<GraphExecutorFactoryWrapper>(state); + }); +} + +} // namespace contrib +} // namespace tvm diff --git a/src/contrib/torch/tvm_module_wrapper/runtime_bridge.h b/src/contrib/torch/tvm_module_wrapper/runtime_bridge.h new file mode 100644 index 0000000000..58cd53a284 --- /dev/null +++ b/src/contrib/torch/tvm_module_wrapper/runtime_bridge.h @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements.
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +/*! + * \file runtime_bridge.h + * \brief Utility functions for the PyTorch-TVM interaction. + */ +#ifndef TVM_CONTRIB_TORCH_TVM_MODULE_WRAPPER_RUNTIME_BRIDGE_H_ +#define TVM_CONTRIB_TORCH_TVM_MODULE_WRAPPER_RUNTIME_BRIDGE_H_ + +extern "C" { + +/* + * DLPack data structure extended with an `is_bool` flag. + * DLPack does not support boolean tensors + * (https://github.com/pytorch/pytorch/blob/4618371da56c887195e2e1d16dad2b9686302800/aten/src/ATen/DLConvertor.cpp#L42), + * thus a boolean tensor will be regarded as a UInt8 tensor + * (https://github.com/apache/tvm/blob/de124862714e747764aa8b7f41a90bcb25f3c6a8/python/tvm/_ffi/runtime_ctypes.py#L91). + */ +struct DLPackTensorExt { + DLManagedTensor* dl_managed_tensor; + bool is_bool; +}; + +/* + * A wrapper pointing to a TVM runtime module. + */ +struct TVMContribTorchRuntimeModule; + +/* + * Obtain a saved runtime module passed by TVM FFI. + * @return A TVM runtime module wrapper. + */ +TVMContribTorchRuntimeModule* tvm_contrib_torch_get_last_saved_runtime_module(); + +/* + * Delete a TVMContribTorchRuntimeModule pointer. + */ +void tvm_contrib_torch_free_runtime_module(TVMContribTorchRuntimeModule* module_ptr); + +/* + * Obtain an ExecutorFactory runtime module from an ExecutorFactory class. + * @param graph_executor_factory ExecutorFactory class + * @param input_example For obtaining device information + * @return ExecutorFactory TVM runtime module wrapper + */ +TVMContribTorchRuntimeModule* tvm_contrib_torch_create_graph_runtime_module( + TVMContribTorchRuntimeModule* graph_executor_factory, DLManagedTensor* input_example); + +/* + * Forward method for OperatorModuleWrapper. + * @param runtime_module TVM runtime module wrapper + * @param inputs Array pointer of the input tensors + * @param input_size The number of input tensors + */ +void tvm_contrib_torch_operator_module_forward(TVMContribTorchRuntimeModule* runtime_module, + DLPackTensorExt* inputs, size_t input_size); + +/* + * Forward method for GraphExecutorFactoryWrapper. + * @param graph_executor_factory TVM runtime module wrapper + * @param inputs Array pointer of the input tensors + * @param input_size The number of input tensors + * @param outputs The resulting output tensors pointer + * @return The number of output tensors + */ +size_t tvm_contrib_torch_graph_executor_module_forward( + TVMContribTorchRuntimeModule* graph_executor_factory, DLPackTensorExt* inputs, + size_t input_size, DLPackTensorExt** outputs); + +/* + * Encode TVM runtime module. + * @param runtime_module TVM runtime module wrapper + * @return The encoding stream (char array) + */ +char* tvm_contrib_torch_encode(TVMContribTorchRuntimeModule* runtime_module); + +/* + * Decode TVM runtime module.
+ * @param state The encoding stream (char array) of a TVM runtime module + * @return TVM runtime module wrapper + */ +TVMContribTorchRuntimeModule* tvm_contrib_torch_decode(const char* state); + +/* + * Delete a DLPackTensorExt array pointer. + */ +void tvm_contrib_torch_free_dlpack_tensor_ext_array(DLPackTensorExt*); + +/* + * Delete a char array (encoding) pointer. + */ +void tvm_contrib_torch_free_encoding(char* encoding); + +/* + * Check whether a DLPackTensorExt can be copied with zero cost, + * i.e., it is not boolean and its underlying DLTensor supports zero-copy. + */ +bool tvm_contrib_torch_tensor_ability_of_zero_copy(DLPackTensorExt*); +} + +#endif  // TVM_CONTRIB_TORCH_TVM_MODULE_WRAPPER_RUNTIME_BRIDGE_H_
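For reference, the encode/decode entry points declared above are backed by serialize()/deserialize() in RuntimeModuleWrapperTVM.cc, which round-trip a TVM runtime module through a base64 string so it can be embedded in a Torch pickle. The packed function `script_torch.save_to_base64` they rely on is registered on the Python side and is not part of this diff; conceptually it amounts to the following sketch (the temp-file handling is an assumption, not the actual registered implementation):

import base64
import tempfile

def save_to_base64(mod) -> bytes:
    # Export the runtime module as a shared library and base64-encode
    # its bytes; deserialize() reverses this by decoding into a
    # temporary .so and loading it via runtime.module.loadfile_so.
    with tempfile.NamedTemporaryFile(suffix=".so") as tmp:
        mod.export_library(tmp.name)
        with open(tmp.name, "rb") as f:
            return base64.b64encode(f.read())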