This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 073304dadb [TVM PyTorch Integration] libstdc++ CXX11 ABI Compatibility & boolean tensor support (#12232)
073304dadb is described below

commit 073304dadb91ce70b3198cab8b3ae98ee4061b26
Author: Yaoda Zhou <judap...@sjtu.edu.cn>
AuthorDate: Wed Aug 17 16:33:37 2022 +0800

    [TVM PyTorch Integration] libstdc++ CXX11 ABI Compatibility & boolean tensor support (#12232)
    
    * first commit
    
    * rename
    
    * cmake
    
    * deprecated
    
    * newline
    
    * config
    
    * config
    
    * typo
    
    * skip tvm_class
    
    * rename
    
    * delete ptr
    
    * delete ptr
    
    * save progress
    
    * boolean support
    
    * cmake file
    
    * polish code
    
    * compile config
    
    * improving the codes
    
    * format
    
    * doc&errormsg
    
    * zero-cost copy
    
    * one step
    
    * to ndarray
    
    * extra output
    
    * delete extra codes
    
    * update test
    
    * boolean support
    
    * strong test
    
    * decrease memory copy
    
    * polish
    
    * reformat
    
    * polish
    
    * remove redundant import
    
    Co-authored-by: juda <yz...@octoml.ai>
---
 apps/pt_tvmdsoop/tests/test_as_torch.py            |   7 +-
 apps/pt_tvmdsoop/tests/test_boolean_tensor.py      | 129 ++++++++++
 cmake/modules/contrib/PT_TVMDSOOP.cmake            |  68 ++++--
 python/tvm/contrib/torch/__init__.py               |  25 +-
 python/tvm/contrib/torch/module.py                 |  17 ++
 python/tvm/contrib/torch/pytorch_tvm.py            |  21 ++
 .../torch/pt_call_tvm/RuntimeModuleWrapper.cc      | 259 --------------------
 .../tvm_module_wrapper/RuntimeModuleWrapperTVM.cc  | 266 +++++++++++++++++++++
 .../RuntimeModuleWrapperTorch.cc                   | 215 +++++++++++++++++
 .../torch/tvm_module_wrapper/runtime_bridge.h      | 116 +++++++++
 10 files changed, 844 insertions(+), 279 deletions(-)
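In user-facing terms, the change lets a TVM-optimized function be saved and loaded as an ordinary Torch module, with boolean inputs supported end to end. A minimal sketch (assuming a build with USE_PT_TVMDSOOP enabled; it mirrors the new test file below):

    import tempfile
    import torch
    from tvm.contrib.torch import optimize_torch

    def negate(x):
        return x.logical_not()

    inp = torch.ones(1, dtype=torch.bool)
    mod = optimize_torch(negate, inp)  # tune and wrap as a torch.nn.Module
    with tempfile.NamedTemporaryFile(suffix=".pt") as tmp:
        torch.save(mod, tmp.name)      # serialization works across the ABI boundary
        loaded = torch.load(tmp.name)
    print(loaded(inp))                 # boolean tensors round-trip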

diff --git a/apps/pt_tvmdsoop/tests/test_as_torch.py b/apps/pt_tvmdsoop/tests/test_as_torch.py
index 2c454e9454..a13d669e7f 100644
--- a/apps/pt_tvmdsoop/tests/test_as_torch.py
+++ b/apps/pt_tvmdsoop/tests/test_as_torch.py
@@ -17,6 +17,8 @@
 # specific language governing permissions and limitations
 # under the License.
 """Test script for tvm torch module"""
+import tempfile
+
 import numpy as np
 
 import torch
@@ -190,7 +192,10 @@ def test_tvmscript_torch_gpu():
     q1 = torch.arange(8, device=cuda0).type(torch.float32)
     q2 = torch.zeros((8,), dtype=torch.float32, device=cuda0)
 
-    ModuleGPU(q1, q2)
+    with tempfile.NamedTemporaryFile(suffix=".pt") as tmp:
+        torch.save(ModuleGPU, tmp.name)
+        loaded_mod = torch.load(tmp.name)
+        loaded_mod(q1, q2)
 
 tvm.testing.assert_allclose(q2.cpu().numpy(), (q1 + 1).cpu().numpy(), atol=1e-5, rtol=1e-5)
 
diff --git a/apps/pt_tvmdsoop/tests/test_boolean_tensor.py b/apps/pt_tvmdsoop/tests/test_boolean_tensor.py
new file mode 100644
index 0000000000..4718b40439
--- /dev/null
+++ b/apps/pt_tvmdsoop/tests/test_boolean_tensor.py
@@ -0,0 +1,129 @@
+#!/usr/bin/env python
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Test script for boolean tensor support"""
+import tempfile
+
+import torch
+
+import tvm
+import tvm.testing
+from tvm.contrib.torch import as_torch, optimize_torch
+from tvm.script import tir as T
+
+
+def negate(x):
+    return x.logical_not()
+
+
+def sum_up_tensor(x):
+    return x.size(dim=0) - torch.sum(x.int())
+
+
+def tensor_boolean_operation(x):
+    arr1 = (x + 0.3).floor().bool()
+    arr2 = (~((x + 0.7).int().bool())).bool()
+    ret = ((arr1 & arr2).byte() + 0.5).half()
+    return ~(ret.bool())
+
+
+def test_bool_tensor_negate():
+    input = torch.ones(1, dtype=torch.bool)
+    optimized_negate = optimize_torch(
+        negate,
+        input,
+    )
+    with tempfile.NamedTemporaryFile(suffix=".pt") as tmp:
+        torch.save(optimized_negate, tmp.name)
+        loaded_mod = torch.load(tmp.name)
+        output = loaded_mod(negate(input))
+    tvm.testing.assert_allclose(input.numpy(), output.numpy(), atol=1e-5, rtol=1e-5)
+
+
+def test_sum_up_tensor():
+    x = torch.randint(0, 2, (16,))
+    y = x.bool()
+    optimized_func = optimize_torch(
+        sum_up_tensor,
+        (y,),
+    )
+    ret1 = (x[x == 0]).size(dim=0)
+    ret2 = optimized_func(y).numpy()
+    tvm.testing.assert_allclose(ret1, ret2, atol=1e-5, rtol=1e-5)
+
+
+def test_tensor_boolean_operation():
+    input = torch.rand(200)
+    model = optimize_torch(
+        tensor_boolean_operation,
+        input,
+    )
+    ret1 = tensor_boolean_operation(input)
+    ret2 = model(input)
+    tvm.testing.assert_allclose(ret1, ret2, atol=1e-5, rtol=1e-5)
+
+
+@as_torch
+@T.prim_func
+def negate_tvmscript(
+    X: T.Buffer[(8, 8), "bool"],
+    Y: T.Buffer[(8, 8), "float32"],
+    Z: T.Buffer[(8, 8), "bool"],
+    U: T.Buffer[(8, 8), "float32"],
+) -> None:
+    for i, j in T.grid(8, 8):
+        with T.block():
+            if Y[i, j] > 0.0:
+                Z[i, j] = X[i, j]
+                U[i, j] = Y[i, j]
+            else:
+                Z[i, j] = not X[i, j]
+                U[i, j] = 0.0 - Y[i, j]
+
+
+def negate_vanilla(x, y):
+    z = torch.zeros(8, 8).bool()
+    for i in range(8):
+        for j in range(8):
+            if y[i, j] > 0:
+                z[i, j] = x[i, j]
+            else:
+                z[i, j] = ~x[i, j]
+    return z
+
+
+def test_tvmscript_torch_decorator():
+    q1 = (torch.rand(8, 8) + 0.5).int().bool()
+    q2 = torch.rand(8, 8) - 0.5
+    q3 = torch.zeros(8, 8).bool()
+    q4 = torch.zeros(8, 8)
+
+    std1 = negate_vanilla(q1, q2)
+    std2 = torch.abs(q2)
+
+    negate_tvmscript(q1, q2, q3, q4)
+
+    tvm.testing.assert_allclose(std1.numpy(), q3.numpy(), atol=1e-5, rtol=1e-5)
+    tvm.testing.assert_allclose(std2.numpy(), q4.numpy(), atol=1e-5, rtol=1e-5)
+
+
+if __name__ == "__main__":
+    test_tvmscript_torch_decorator()
+    test_bool_tensor_negate()
+    test_sum_up_tensor()
+    test_tensor_boolean_operation()
diff --git a/cmake/modules/contrib/PT_TVMDSOOP.cmake b/cmake/modules/contrib/PT_TVMDSOOP.cmake
index 3bad3fd966..a73d3f38e9 100644
--- a/cmake/modules/contrib/PT_TVMDSOOP.cmake
+++ b/cmake/modules/contrib/PT_TVMDSOOP.cmake
@@ -6,7 +6,7 @@
 # "License"); you may not use this file except in compliance
 # with the License.  You may obtain a copy of the License at
 #
-#   http://www.apache.org/licenses/LICENSE-2.0
+# http://www.apache.org/licenses/LICENSE-2.0
 #
 # Unless required by applicable law or agreed to in writing,
 # software distributed under the License is distributed on an
@@ -17,42 +17,80 @@
 
 if(NOT USE_PT_TVMDSOOP STREQUAL "OFF")
   find_package(PythonInterp REQUIRED)
-
  execute_process(COMMAND ${PYTHON_EXECUTABLE} -c "import torch; print(torch.__path__[0].strip())"
     OUTPUT_VARIABLE PT_PATH
     RESULT_VARIABLE PT_STATUS)
-  if (NOT ${PT_STATUS} EQUAL 0)
+
+  if(NOT ${PT_STATUS} EQUAL 0)
     message(FATAL_ERROR "Fail to get pytorch path")
   endif()
 
   string(REGEX REPLACE "\n" "" PT_PATH "${PT_PATH}")
   message(STATUS "PyTorch path: ${PT_PATH}")
 
-  set(PT_COMPILE_FLAGS_STR "-I${PT_PATH}/include -D_GLIBCXX_USE_CXX11_ABI=0")
+  execute_process(COMMAND ${PYTHON_EXECUTABLE} -c "import torch;print(torch.compiled_with_cxx11_abi())"
+    OUTPUT_VARIABLE PT_CXX_FLAG
+    RESULT_VARIABLE PT_STATUS)
+
+  string(REGEX REPLACE "\n" "" PT_CXX_FLAG "${PT_CXX_FLAG}")
+  message(STATUS "Found TORCH_BUILT_WITH_CXX_ABI=${PT_CXX_FLAG} ")
+
+  if(${PT_CXX_FLAG} STREQUAL "False")
+    set(CXX_ABI_ENABLED 0)
+  else()
+    set(CXX_ABI_ENABLED 1)
+  endif()
+
+  set_property(
+    SOURCE
+    ${CMAKE_CURRENT_SOURCE_DIR}/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTorch.cc
+    APPEND PROPERTY
+    COMPILE_OPTIONS
+    "-D_GLIBCXX_USE_CXX11_ABI=${CXX_ABI_ENABLED}"
+    "-I${PT_PATH}/include"
+  )
+
+  set_property(
+    SOURCE
+    ${CMAKE_CURRENT_SOURCE_DIR}/src/contrib/torch/pt_call_tvm/tvm_class.cc
+    APPEND PROPERTY
+    COMPILE_OPTIONS
+    "-I${PT_PATH}/include"
+  )
+
  set(PT_LINK_FLAGS_STR "-L${PT_PATH}/lib -l:libtorch.so -l:libtorch_python.so")
 
   if(NOT USE_CUDA STREQUAL "OFF")
     add_definitions(-DPT_TVMDSOOP_ENABLE_GPU)
   endif()
 
-
  string(REGEX REPLACE "\n" " " PT_FLAGS "${PT_COMPILE_FLAGS} ${PT_LINK_FLAGS}")
-  separate_arguments(PT_COMPILE_FLAGS UNIX_COMMAND ${PT_COMPILE_FLAGS_STR})
+  separate_arguments(PT_COMPILE_FLAGS UNIX_COMMAND)
   separate_arguments(PT_LINK_FLAGS UNIX_COMMAND ${PT_LINK_FLAGS_STR})
 
+  # This old library is deprecated and will be removed after TVM 0.11
+  set(LIBRARY_OLD_NAME pt_tvmdsoop)
 
-  set(LIBRARY_NAME pt_tvmdsoop)
-  tvm_file_glob(GLOB_RECURSE PTTVM_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/src/contrib/torch/**/*.cc)
-  add_library(${LIBRARY_NAME} SHARED ${PTTVM_SRCS})
+  # This new library is for the PyTorch integration and solves the C++ ABI incompatibility issue
+  set(LIBRARY_NEW_NAME pt_tvmdsoop_new)
+  tvm_file_glob(GLOB_RECURSE PTTVM_TORCH ${CMAKE_CURRENT_SOURCE_DIR}/src/contrib/torch/tvm_module_wrapper/*.cc)
+
+  tvm_file_glob(GLOB_RECURSE PTTVM_SRCS ${CMAKE_CURRENT_SOURCE_DIR}/src/contrib/torch/pt_call_tvm/*.cc)
+
+  add_library(${LIBRARY_OLD_NAME} SHARED ${PTTVM_SRCS})
+  add_library(${LIBRARY_NEW_NAME} SHARED ${PTTVM_TORCH})
   set(PTTVM_LINK_FLAGS -ltvm -L${CMAKE_CURRENT_BINARY_DIR})
 
-  if (NOT BUILD_PT_TVMDSOOP_ONLY STREQUAL "ON")
-    add_dependencies(${LIBRARY_NAME} tvm) 
+  if(NOT BUILD_PT_TVMDSOOP_ONLY STREQUAL "ON")
+    add_dependencies(${LIBRARY_OLD_NAME} tvm)
+    add_dependencies(${LIBRARY_NEW_NAME} tvm)
   endif()
 
-  target_compile_options(${LIBRARY_NAME} PUBLIC ${PTTVM_COMPILE_FLAGS} ${PT_COMPILE_FLAGS})
-  target_link_libraries(${LIBRARY_NAME} PUBLIC ${PTTVM_LINK_FLAGS} ${PT_LINK_FLAGS})
-  target_compile_definitions(${LIBRARY_NAME} PUBLIC DMLC_USE_LOGGING_LIBRARY=<tvm/runtime/logging.h>)
+  target_compile_options(${LIBRARY_OLD_NAME} PUBLIC ${PTTVM_COMPILE_FLAGS} ${PT_COMPILE_FLAGS})
+  target_link_libraries(${LIBRARY_OLD_NAME} PUBLIC ${PTTVM_LINK_FLAGS} ${PT_LINK_FLAGS})
+  target_compile_definitions(${LIBRARY_OLD_NAME} PUBLIC DMLC_USE_LOGGING_LIBRARY=<tvm/runtime/logging.h>)
 
+  target_compile_options(${LIBRARY_NEW_NAME} PUBLIC ${PTTVM_COMPILE_FLAGS} ${PT_COMPILE_FLAGS})
+  target_link_libraries(${LIBRARY_NEW_NAME} PUBLIC ${PTTVM_LINK_FLAGS} ${PT_LINK_FLAGS})
+  target_compile_definitions(${LIBRARY_NEW_NAME} PUBLIC DMLC_USE_LOGGING_LIBRARY=<tvm/runtime/logging.h>)
 endif()
-
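The ABI probe added above reduces to a single call against the local PyTorch build; the same check run by hand (a sketch, nothing TVM-specific):

    import torch

    # True  -> PyTorch was built with -D_GLIBCXX_USE_CXX11_ABI=1
    # False -> pre-CXX11 ABI; PT_TVMDSOOP.cmake passes the matching flag
    # when compiling RuntimeModuleWrapperTorch.cc
    print(torch.compiled_with_cxx11_abi())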
diff --git a/python/tvm/contrib/torch/__init__.py b/python/tvm/contrib/torch/__init__.py
index 340f9cef9e..c3dd34d470 100644
--- a/python/tvm/contrib/torch/__init__.py
+++ b/python/tvm/contrib/torch/__init__.py
@@ -18,11 +18,12 @@
 """Module container of Pytorch custom class"""
 import os
 import platform
+import warnings
 import torch
 from tvm._ffi import libinfo
 
 
-def _load_platform_specific_library(lib_name="libpt_tvmdsoop"):
+def _load_platform_specific_library(lib_name):
     system = platform.system()
     if system == "Darwin":
         lib_file_name = lib_name + ".dylib"
@@ -33,11 +34,27 @@ def _load_platform_specific_library(lib_name="libpt_tvmdsoop"):
     lib_path = libinfo.find_lib_path()[0]
     lib_dir = os.path.dirname(lib_path)
     lib_file_path = os.path.join(lib_dir, lib_file_name)
-    torch.classes.load_library(lib_file_path)
+    try:
+        torch.classes.load_library(lib_file_path)
+    except OSError as err:
+        errmsg = str(err)
+        if errmsg.find("undefined symbol") != -1:
+            reason = " ".join(
+                (
+                    "Got undefined symbol error,",
+                    "which might be due to the CXXABI incompatibility.",
+                )
+            )
+        else:
+            reason = errmsg
+        warnings.warn(
+            f"The library {lib_name} is not built successfully. {reason}",
+            RuntimeWarning,
+        )
 
 
-_load_platform_specific_library()
-
+_load_platform_specific_library("libpt_tvmdsoop")
+_load_platform_specific_library("libpt_tvmdsoop_new")
 
 from . import module
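Since a failed load is now reported as a RuntimeWarning rather than a hard error, callers can check for a broken library programmatically. A minimal sketch (assuming a fresh interpreter, so the import actually executes):

    import warnings

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        import tvm.contrib.torch  # triggers both library loads

    for w in caught:
        if issubclass(w.category, RuntimeWarning):
            print("library load problem:", w.message)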
 
diff --git a/python/tvm/contrib/torch/module.py b/python/tvm/contrib/torch/module.py
index 3da9c6f591..cfa3ad264c 100644
--- a/python/tvm/contrib/torch/module.py
+++ b/python/tvm/contrib/torch/module.py
@@ -16,7 +16,9 @@
 # under the License.
 # pylint: disable=invalid-name
 """Module container of PyTorch custom class"""
+import warnings
 from typing import List
+
 import torch
 
 
@@ -29,6 +31,11 @@ class GraphModule(torch.nn.Module):
         return torch.ops.tvm_dsoop.tvm_shape_repr(input_shapes)
 
     def __init__(self, num_inputs, num_outputs, device=None):
+        warnings.warn(
+            "This module will be removed at TVM version 0.11",
+            DeprecationWarning,
+            stacklevel=2,
+        )
         super().__init__()
         self.dummy_param = torch.nn.Parameter(torch.empty(0))
         self.engine = None
@@ -67,6 +74,11 @@ class VMModule(torch.nn.Module):
         return torch.ops.tvm_dsoop.tvm_shape_repr(input_shapes)
 
     def __init__(self, num_inputs, num_outputs, device=None):
+        warnings.warn(
+            "This module will be removed at TVM version 0.11",
+            DeprecationWarning,
+            stacklevel=2,
+        )
         super().__init__()
         self.dummy_param = torch.nn.Parameter(torch.empty(0))
         self.engine = None
@@ -113,6 +125,11 @@ class TraceTvmModule(torch.nn.Module):
     """
 
     def __init__(self, tvm_module):
+        warnings.warn(
+            "This module will be removed at TVM version 0.11",
+            DeprecationWarning,
+            stacklevel=2,
+        )
         super().__init__()
         self.tvm_module = tvm_module
 
diff --git a/python/tvm/contrib/torch/pytorch_tvm.py b/python/tvm/contrib/torch/pytorch_tvm.py
index 1e50c98ab8..ffab4fa0d2 100644
--- a/python/tvm/contrib/torch/pytorch_tvm.py
+++ b/python/tvm/contrib/torch/pytorch_tvm.py
@@ -19,6 +19,7 @@
 # pylint: disable=redefined-builtin
 """`compile` api that convert torch module to torch tvm module"""
 import os
+import warnings
 import tvm
 import tvm.testing
 from tvm import relay, autotvm
@@ -183,6 +184,16 @@ class PyTorchTVMModule:
 
     def build_pytorch_module(self, num_inputs, num_outputs, input_infos=None):
         """Build pytorch module containing TVM Graph Module"""
+        warnings.warn(
+            " ".join(
+                (
+                    "This function will be removed at TVM version 0.11,",
+                    "we suggest users to use `optimize_torch` for tuning Torch modules instead.",
+                )
+            ),
+            DeprecationWarning,
+            stacklevel=2,
+        )
         assert self.export_dir, "you must build_tvm or load_tvm before"
         input_infos = input_infos or self.input_infos
         assert input_infos
@@ -224,6 +235,16 @@ def compile(script_module, option):
     pytorch_tvm_module = compile(script_module, option)
     pytorch_tvm_module("model_tvm.pt")
     """
+    warnings.warn(
+        " ".join(
+            (
+                "This function will be removed at TVM version 0.11,",
+                "we suggest users to use `optimize_torch` for tuning Torch modules instead.",
+            )
+        ),
+        DeprecationWarning,
+        stacklevel=2,
+    )
     input_infos = option["input_infos"]
     default_dtype = option.get("default_dtype", "float32")
     export_dir = option.get("export_dir", "pytorch_compiled")
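For reference, the migration these deprecation warnings point to looks roughly as follows (a hedged sketch: MyModel and the example input are placeholders, and optimize_torch is used as in the new tests above):

    import torch
    from tvm.contrib.torch import optimize_torch

    class MyModel(torch.nn.Module):
        def forward(self, x):
            return x * 2

    example = torch.rand(4)

    # old: pytorch_tvm_module = compile(script_module, option)
    optimized = optimize_torch(MyModel(), example)
    print(optimized(example))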
diff --git a/src/contrib/torch/pt_call_tvm/RuntimeModuleWrapper.cc b/src/contrib/torch/pt_call_tvm/RuntimeModuleWrapper.cc
deleted file mode 100644
index 12c1017bea..0000000000
--- a/src/contrib/torch/pt_call_tvm/RuntimeModuleWrapper.cc
+++ /dev/null
@@ -1,259 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-#include <ATen/DLConvertor.h>
-#include <dlpack/dlpack.h>
-#include <dmlc/memory_io.h>
-#include <torch/custom_class.h>
-#include <torch/script.h>
-#include <tvm/runtime/module.h>
-#include <tvm/runtime/registry.h>
-#include <tvm/target/codegen.h>
-#include <tvm/target/target.h>
-
-#include <cstdio>
-#include <map>
-#include <string>
-#include <vector>
-
-#include "../../../runtime/graph_executor/graph_executor_factory.h"
-#include "../base64.h"
-
-namespace tvm {
-namespace contrib {
-
-/**
- * We pass the TVM module by TVM's FFI because Torch's FFI cannot recognize such TVM objects
- */
-struct ThreadLocalStore {
-  tvm::runtime::Module mod;
-  static ThreadLocalStore* ThreadLocal() {
-    thread_local ThreadLocalStore tls;
-    return &tls;
-  }
-};
-
-using SerializationType = std::string;  // base64 stream
-
-SerializationType serialize(tvm::runtime::Module module) {
-  static const runtime::PackedFunc* f_to_str =
-      runtime::Registry::Get("script_torch.save_to_base64");
-  ICHECK(f_to_str) << "IndexError: Cannot find the packed function "
-                      "`script_torch.save_to_base64` in the global registry";
-  return (*f_to_str)(module);
-}
-
-struct Deleter {  // deleter
-  explicit Deleter(std::string file_name) { this->file_name = file_name; }
-  void operator()(FILE* p) const {
-    fclose(p);
-    ICHECK(remove(file_name.c_str()) == 0)
-        << "Failed to  remove temporary file (" << file_name << ")";
-  }
-  std::string file_name;
-};
-
-tvm::runtime::Module deserialize(SerializationType state) {
-  auto length = tvm::support::b64strlen(state);
-
-  std::vector<u_char> bytes(length);
-  tvm::support::b64decode(state, bytes.data());
-
-  const std::string name = tmpnam(NULL);
-  auto file_name = name + ".so";
-  std::unique_ptr<FILE, Deleter> pFile(fopen(file_name.c_str(), "wb"), Deleter(file_name));
-  fwrite(bytes.data(), sizeof(u_char), length, pFile.get());
-  fflush(pFile.get());
-
-  std::string load_f_name = "runtime.module.loadfile_so";
-  const PackedFunc* f = runtime::Registry::Get(load_f_name);
-  ICHECK(f != nullptr) << "Loader for `.so` files is not registered,"
-                       << " resolved to (" << load_f_name << ") in the global registry."
-                       << "Ensure that you have loaded the correct runtime code, and"
-                       << "that you are on the correct hardware architecture.";
-
-  tvm::runtime::Module ret = (*f)(file_name, "");
-
-  return ret;
-}
-
-/**
- * @brief A Torch's module which wraps TVM's OperatorModule Class.
- * The basic forward function calling TVM's runtime is provided.
- * The TVM module can be serialized/deserialized as a Torch module.
- */
-class OperatorModuleWrapper : public torch::jit::CustomClassHolder {
- public:
-  OperatorModuleWrapper() { runtime_module = ThreadLocalStore::ThreadLocal()->mod; }
-
-  void forward(const c10::List<at::Tensor>& inputs) {
-    int input_length = inputs.size();
-
-    std::vector<DLManagedTensor*> tensors;
-
-    for (int i = 0; i < input_length; ++i) tensors.push_back(toDLPack(inputs[i]));
-
-    tvm::runtime::PackedFunc run = runtime_module.GetFunction("__tvm_main__");
-
-    std::vector<TVMValue> tvm_values(input_length);
-    std::vector<int> tvm_type_codes(input_length);
-    tvm::runtime::TVMArgsSetter setter(tvm_values.data(), tvm_type_codes.data());
-    for (int k = 0; k < input_length; ++k) {
-      setter(k, &tensors[k]->dl_tensor);
-    }
-
-    run.CallPacked(tvm::runtime::TVMArgs(tvm_values.data(), tvm_type_codes.data(), input_length),
-                   nullptr);
-
-    for (int k = 0; k < input_length; ++k) {
-      tensors[k]->deleter(tensors[k]);
-    }
-  }
-
-  SerializationType Serialize() { return serialize(runtime_module); }
-
-  explicit OperatorModuleWrapper(SerializationType state) { runtime_module = deserialize(state); }
-
- private:
-  tvm::runtime::Module runtime_module;
-};
-
-tvm::Device getDevice(const at::Tensor& tensor) {
-  tvm::Device dev;
-  dev.device_id = tensor.get_device();
-  switch (tensor.device().type()) {
-    case at::DeviceType::CPU:
-      dev.device_type = DLDeviceType::kDLCPU;
-      if (dev.device_id == -1) {
-        /*
-         * In PyTorch the device ID for cpu is -1, sometimes causing error during tuning
-         * Thus we manually set the device ID as 0 for avoiding potentially error of index out of
-         * bounds
-         */
-        dev.device_id = 0;
-      }
-      break;
-    case at::DeviceType::CUDA:
-      dev.device_type = DLDeviceType::kDLCUDA;
-      break;
-    default:
-      TORCH_CHECK(false, "PyTorch TVM integration doesn't support device " + tensor.device().str());
-  }
-  return dev;
-}
-
-/**
- * @brief A Torch's module which wraps TVM's GraphExecutorFactory Class.
- * The basic forward function calling TVM's runtime is provided.
- * The TVM module can be serialized/deserialized as a Torch module.
- */
-class GraphExecutorFactoryWrapper : public torch::jit::CustomClassHolder {
- public:
-  explicit GraphExecutorFactoryWrapper(tvm::runtime::Module executor_factory)
-      : executor_factory_(executor_factory) {
-    CHECK(executor_factory_->IsInstance<runtime::GraphExecutorFactory>())
-        << "module is not an instance of GraphExecutorFactory";
-  }
-
-  GraphExecutorFactoryWrapper()
-      : GraphExecutorFactoryWrapper(ThreadLocalStore::ThreadLocal()->mod) {}
-
-  c10::List<at::Tensor> forward(const c10::List<at::Tensor>& inputs) {
-    int input_length = inputs.size();
-
-    if (!executor_.defined()) {
-      TORCH_CHECK(input_length > 0, "Receive empty list of input tensors");
-      DLDevice input_device = getDevice(inputs.get(0));
-
-      auto tmp = executor_factory_.GetFunction("default");
-
-      executor_ = tmp(input_device);
-    }
-
-    std::vector<DLManagedTensor*> tensors;
-
-    for (int i = 0; i < input_length; ++i) tensors.push_back(toDLPack(inputs[i]));
-
-    tvm::runtime::PackedFunc run = executor_.GetFunction("run");
-    tvm::runtime::PackedFunc set_input = executor_.GetFunction("set_input");
-    tvm::runtime::PackedFunc get_output = executor_.GetFunction("get_output");
-    tvm::runtime::PackedFunc get_num_outputs = executor_.GetFunction("get_num_outputs");
-
-    for (int k = 0; k < input_length; ++k) {
-      set_input(k, &tensors[k]->dl_tensor);
-    }
-
-    run();
-
-    int64_t output_length = get_num_outputs();
-
-    c10::List<at::Tensor> outputs;
-    outputs.reserve(output_length);
-
-    for (int k = 0; k < output_length; ++k) {
-      tvm::runtime::NDArray results = get_output(k);
-      at::Tensor atTensor = at::fromDLPack(results.ToDLPack());
-      outputs.emplace_back(atTensor);
-    }
-
-    for (int k = 0; k < input_length; ++k) {
-      tensors[k]->deleter(tensors[k]);
-    }
-    return outputs;
-  }
-
-  SerializationType Serialize() { return serialize(executor_factory_); }
-
-  explicit GraphExecutorFactoryWrapper(SerializationType state) {
-    executor_factory_ = deserialize(state);
-  }
-
- private:
-  tvm::runtime::Module executor_factory_;
-  tvm::runtime::Module executor_;
-};
-
-TVM_REGISTER_GLOBAL("tvmtorch.save_runtime_mod").set_body_typed([](tvm::runtime::Module mod) {
-  ThreadLocalStore::ThreadLocal()->mod = mod;
-});
-
-TORCH_LIBRARY(tvm_torch, m) {
-  m.class_<OperatorModuleWrapper>("OperatorModuleWrapper")
-      .def(torch::init<>())
-      .def("forward", &OperatorModuleWrapper::forward)
-      .def_pickle(
-          [](const c10::intrusive_ptr<OperatorModuleWrapper>& self) -> SerializationType {
-            return self->Serialize();
-          },
-          [](SerializationType state) {
-            return c10::make_intrusive<OperatorModuleWrapper>(state);
-          });
-  m.class_<GraphExecutorFactoryWrapper>("GraphExecutorFactoryWrapper")
-      .def(torch::init<>())
-      .def("forward", &GraphExecutorFactoryWrapper::forward)
-      .def_pickle(
-          [](const c10::intrusive_ptr<GraphExecutorFactoryWrapper>& self) -> SerializationType {
-            return self->Serialize();
-          },
-          [](SerializationType state) {
-            return c10::make_intrusive<GraphExecutorFactoryWrapper>(state);
-          });
-}
-
-}  // namespace contrib
-}  // namespace tvm
diff --git a/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc b/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc
new file mode 100644
index 0000000000..fb570c163f
--- /dev/null
+++ b/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTVM.cc
@@ -0,0 +1,266 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <dlpack/dlpack.h>
+#include <dmlc/memory_io.h>
+#include <tvm/runtime/module.h>
+#include <tvm/runtime/registry.h>
+#include <tvm/target/codegen.h>
+#include <tvm/target/target.h>
+
+#include <cstdio>
+#include <map>
+#include <string>
+#include <vector>
+
+#include "../../../runtime/graph_executor/graph_executor_factory.h"
+#include "../base64.h"
+#include "runtime_bridge.h"
+
+namespace tvm {
+namespace contrib {
+
+/*
+ * TVM's FFI for passing module from python to C++
+ */
+struct ThreadLocalStore {
+  tvm::runtime::Module mod;
+  static ThreadLocalStore* ThreadLocal() {
+    thread_local ThreadLocalStore tls;
+    return &tls;
+  }
+};
+
+/*
+ * Encode TVM runtime module to base64 stream
+ */
+std::string serialize(tvm::runtime::Module module) {
+  static const runtime::PackedFunc* f_to_str =
+      runtime::Registry::Get("script_torch.save_to_base64");
+  ICHECK(f_to_str) << "IndexError: Cannot find the packed function "
+                      "`script_torch.save_to_base64` in the global registry";
+  return (*f_to_str)(module);
+}
+
+struct Deleter {  // deleter
+  explicit Deleter(std::string file_name) { this->file_name = file_name; }
+  void operator()(FILE* p) const {
+    fclose(p);
+    ICHECK(remove(file_name.c_str()) == 0)
+        << "Failed to remove temporary file (" << file_name << ")";
+  }
+  std::string file_name;
+};
+
+/*
+ * Decode TVM runtime module from base64 stream
+ */
+tvm::runtime::Module deserialize(std::string state) {
+  auto length = tvm::support::b64strlen(state);
+
+  std::vector<u_char> bytes(length);  // bytes stream
+  tvm::support::b64decode(state, bytes.data());
+
+  const std::string name = tmpnam(NULL);
+  auto file_name = name + ".so";
+  std::unique_ptr<FILE, Deleter> pFile(fopen(file_name.c_str(), "wb"), Deleter(file_name));
+  fwrite(bytes.data(), sizeof(u_char), length, pFile.get());
+  fflush(pFile.get());
+
+  std::string load_f_name = "runtime.module.loadfile_so";
+  const PackedFunc* f = runtime::Registry::Get(load_f_name);
+  ICHECK(f != nullptr) << "Loader for `.so` files is not registered,"
+                       << " resolved to (" << load_f_name << ") in the global registry."
+                       << "Ensure that you have loaded the correct runtime code, and"
+                       << "that you are on the correct hardware architecture.";
+
+  tvm::runtime::Module ret = (*f)(file_name, "");
+
+  return ret;
+}
+
+TVM_REGISTER_GLOBAL("tvmtorch.save_runtime_mod").set_body_typed([](tvm::runtime::Module mod) {
+  ThreadLocalStore::ThreadLocal()->mod = mod;
+});
+
+/*
+ * Convert NDArray to DLPack extended tensor. It should be zero-cost.
+ * @param src Pointer to NDArray
+ * @return DLPack extended tensor
+ */
+DLPackTensorExt CreateDLpackTensorExt(tvm::runtime::NDArray* src) {
+  auto is_bool = src->DataType().is_bool();
+  DLManagedTensor* tensor;
+  if (is_bool) {
+    // If we change DLDataType{kDLInt, 8, 1} to DataType::Bool()
+    // we will get `RuntimeError: Unsupported kUInt bits 1`
+    auto tmp = src->CreateView(src->Shape(), DLDataType{kDLInt, 8, 1});
+    tensor = tmp.ToDLPack();
+  } else {
+    tensor = src->ToDLPack();
+  }
+  DLPackTensorExt ret{tensor, is_bool};
+  return ret;
+}
+
+/*
+ * Create an NDArray with boolean type. (One memory copy)
+ * @param src DLpack extended tensor
+ * @return a new NDArray
+ */
+tvm::runtime::NDArray CreateBoolNDarray(DLPackTensorExt* src) {
+  auto& tensor = src->dl_managed_tensor->dl_tensor;
+  std::vector<int64_t> shape;
+  for (int64_t i = 0; i < tensor.ndim; i++) {
+    shape.push_back(tensor.shape[i]);
+  }
+  auto ret = tvm::runtime::NDArray::Empty(shape, DataType::Bool(), tensor.device);
+  ret.CopyFrom(&src->dl_managed_tensor->dl_tensor);
+  return std::move(ret);
+}
+
+bool IsZeroCopy(DLPackTensorExt* src) {
+  auto& dl_tensor = src->dl_managed_tensor->dl_tensor;
+  return tvm::runtime::NDArray::AbilityOfZeroCopyForDLTensor(&dl_tensor, dl_tensor.device);
+}
+
+/*
+ * Create an NDArray from DLpack extended tensor.
+ * @param src DLpack extended tensor
+ * @return a new NDArray
+ */
+tvm::runtime::NDArray NDarrayFromDLpack(DLPackTensorExt* src) {
+  using tvm::runtime::NDArray;
+
+  NDArray array;
+  auto& dl_tensor = src->dl_managed_tensor->dl_tensor;
+  if (src->is_bool) {
+    // one memory copy
+    // the code is similar to NewFromDLTensor except for the type
+    array = CreateBoolNDarray(src);
+  } else if (IsZeroCopy(src)) {
+    array = NDArray::FromExternalDLTensor(src->dl_managed_tensor->dl_tensor);
+  } else {
+    // one memory copy
+    array = NDArray::NewFromDLTensor(&dl_tensor, dl_tensor.device);
+  }
+  return array;
+}
+
+}  // namespace contrib
+}  // namespace tvm
+
+extern "C" {
+
+struct TVMContribTorchRuntimeModule {
+  tvm::runtime::Module mod;
+
+  explicit TVMContribTorchRuntimeModule(tvm::runtime::Module& mod) : mod(mod) {}
+};
+
+bool tvm_contrib_torch_tensor_ability_of_zero_copy(DLPackTensorExt* src) {
+  return (!src->is_bool) && (tvm::contrib::IsZeroCopy(src));
+}
+
+TVMContribTorchRuntimeModule* tvm_contrib_torch_get_last_saved_runtime_module() {
+  return new TVMContribTorchRuntimeModule(tvm::contrib::ThreadLocalStore::ThreadLocal()->mod);
+}
+
+void tvm_contrib_torch_operator_module_forward(TVMContribTorchRuntimeModule* runtime_module,
+                                               DLPackTensorExt* inputs, size_t input_size) {
+  tvm::runtime::PackedFunc run = runtime_module->mod.GetFunction("__tvm_main__");
+
+  std::vector<TVMValue> tvm_values(input_size);
+  std::vector<int> tvm_type_codes(input_size);
+  tvm::runtime::TVMArgsSetter setter(tvm_values.data(), tvm_type_codes.data());
+
+  std::vector<tvm::runtime::NDArray> input_cache(input_size);
+
+  for (size_t k = 0; k < input_size; ++k) {
+    auto datum = tvm::contrib::NDarrayFromDLpack(&inputs[k]);  // could have one memory copy
+    input_cache[k] = datum;  // we keep the datum in a vector for future use, otherwise the datum
+                             // will be freed after the loop
+    setter(k, datum);
+  }
+
+  run.CallPacked(tvm::runtime::TVMArgs(tvm_values.data(), tvm_type_codes.data(), input_size),
+                 nullptr);
+
+  for (size_t k = 0; k < input_size; ++k) {
+    if (!tvm_contrib_torch_tensor_ability_of_zero_copy(&inputs[k]))
+      input_cache[k].CopyTo(&inputs[k].dl_managed_tensor->dl_tensor);
+  }
+}
+
+TVMContribTorchRuntimeModule* tvm_contrib_torch_create_graph_runtime_module(
+    TVMContribTorchRuntimeModule* graph_executor_factory, DLManagedTensor* input_example) {
+  tvm::runtime::PackedFunc built_module = graph_executor_factory->mod.GetFunction("default");
+  tvm::Device device_info = input_example->dl_tensor.device;
+  tvm::runtime::Module runtime_module = built_module(device_info);
+  return new TVMContribTorchRuntimeModule(runtime_module);
+}
+
+size_t tvm_contrib_torch_graph_executor_module_forward(TVMContribTorchRuntimeModule* runtime_module,
+                                                       DLPackTensorExt* inputs, size_t input_size,
+                                                       DLPackTensorExt** outputs) {
+  tvm::runtime::PackedFunc run = runtime_module->mod.GetFunction("run");
+  tvm::runtime::PackedFunc set_input = runtime_module->mod.GetFunction("set_input");
+  tvm::runtime::PackedFunc get_output = runtime_module->mod.GetFunction("get_output");
+  tvm::runtime::PackedFunc get_num_outputs = runtime_module->mod.GetFunction("get_num_outputs");
+
+  for (size_t k = 0; k < input_size; ++k) {
+    set_input(k, &inputs[k].dl_managed_tensor->dl_tensor);
+  }
+
+  run();
+
+  int64_t output_length = get_num_outputs();
+
+  DLPackTensorExt* outputs_ptr = new DLPackTensorExt[output_length];
+  *outputs = outputs_ptr;
+
+  for (int64_t k = 0; k < output_length; ++k) {
+    tvm::runtime::NDArray results = get_output(k);
+    outputs_ptr[k] = tvm::contrib::CreateDLpackTensorExt(&results);
+  }
+
+  return output_length;
+}
+
+char* tvm_contrib_torch_encode(TVMContribTorchRuntimeModule* runtime_module) {
+  std::string std = tvm::contrib::serialize(runtime_module->mod);
+  char* ret = new char[std.length() + 1];
+  snprintf(ret, std.length() + 1, "%s", std.c_str());
+  return ret;
+}
+
+TVMContribTorchRuntimeModule* tvm_contrib_torch_decode(const char* state) {
+  tvm::runtime::Module ret = tvm::contrib::deserialize(state);
+  return new TVMContribTorchRuntimeModule(ret);
+}
+
+void tvm_contrib_torch_free_runtime_module(TVMContribTorchRuntimeModule* module_ptr) {
+  delete module_ptr;
+}
+
+void tvm_contrib_torch_free_dlpack_tensor_ext_array(DLPackTensorExt* dlpack_ptr) {
+  delete[] dlpack_ptr;
+}
+
+void tvm_contrib_torch_free_encoding(char* encoding) { delete[] encoding; }
+}
diff --git a/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTorch.cc b/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTorch.cc
new file mode 100644
index 0000000000..3159438d72
--- /dev/null
+++ b/src/contrib/torch/tvm_module_wrapper/RuntimeModuleWrapperTorch.cc
@@ -0,0 +1,215 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <ATen/DLConvertor.h>
+#include <torch/custom_class.h>
+#include <torch/script.h>
+
+#include <iostream>
+
+#include "runtime_bridge.h"
+
+namespace tvm {
+namespace contrib {
+
+/*
+ * Convert Torch tensor to DLPack extended tensor.
+ * A boolean Torch tensor is converted to a DLPack tensor with the `is_bool=true` flag set.
+ * @param src Torch tensor
+ * @return DLPack extended tensor
+ */
+DLPackTensorExt ToDLPackExt(const at::Tensor& src) {
+  if (!src.is_contiguous()) {
+    return ToDLPackExt(src.contiguous());
+  }
+  DLPackTensorExt ret;
+  if (src.dtype().isScalarType(torch::kBool)) {
+    auto temp = src.toType(torch::kUInt8);
+    ret.dl_managed_tensor = at::toDLPack(temp);
+    ret.is_bool = true;
+  } else {
+    ret.dl_managed_tensor = at::toDLPack(src);
+    ret.is_bool = false;
+  }
+
+  return ret;
+}
+
+/*
+ * Convert DLPack extended tensor to Torch tensor.
+ * @param src DLPack extended tensor
+ * @return Torch tensor
+ */
+at::Tensor FromDLPackExt(const DLPackTensorExt& src) {
+  if (src.is_bool) {
+    return at::fromDLPack(src.dl_managed_tensor).toType(torch::kBool);
+  } else {
+    return at::fromDLPack(src.dl_managed_tensor);
+  }
+}
+
+/**
+ * @brief A Torch's module which wraps TVM's OperatorModule Class.
+ * The basic forward function calling TVM's runtime is provided.
+ * The TVM module can be serialized/deserialized as a Torch module.
+ */
+class OperatorModuleWrapper : public torch::jit::CustomClassHolder {
+ public:
+  OperatorModuleWrapper() { runtime_module_ = tvm_contrib_torch_get_last_saved_runtime_module(); }
+  ~OperatorModuleWrapper() { tvm_contrib_torch_free_runtime_module(runtime_module_); }
+
+  void forward(const c10::List<at::Tensor>& inputs) {
+    int input_length = inputs.size();
+
+    std::vector<DLPackTensorExt> tensors;
+
+    // Torch tensors support the boolean type while DLPack does not,
+    // so we convert each Torch tensor to an extended DLPack tensor
+    for (int i = 0; i < input_length; ++i) tensors.push_back(ToDLPackExt(inputs[i]));
+    tvm_contrib_torch_operator_module_forward(this->runtime_module_, tensors.data(),
+                                              tensors.size());
+
+    for (int k = 0; k < input_length; ++k) {
+      if (tvm_contrib_torch_tensor_ability_of_zero_copy(&tensors[k])) {
+        // We need to free memory manually
+        tensors[k].dl_managed_tensor->deleter(tensors[k].dl_managed_tensor);
+      } else {
+        // Ownership transferred
+        inputs[k].copy_(FromDLPackExt(tensors[k]));
+      }
+    }
+  }
+
+  std::string Serialize() {
+    auto encoding = tvm_contrib_torch_encode(runtime_module_);
+    auto ret = std::string(encoding);
+    tvm_contrib_torch_free_encoding(encoding);
+    return ret;
+  }
+
+  explicit OperatorModuleWrapper(std::string state) {
+    runtime_module_ = tvm_contrib_torch_decode(state.c_str());
+  }
+
+ private:
+  /*
+   * TVM runtime module wrapper
+   */
+  TVMContribTorchRuntimeModule* runtime_module_;
+};
+
+/**
+ * @brief A Torch's module which wraps TVM's GraphExecutorFactory Class.
+ * The basic forward function calling TVM's runtime is provided.
+ * The TVM module can be serialized/deserialized as a Torch module.
+ */
+class GraphExecutorFactoryWrapper : public torch::jit::CustomClassHolder {
+ public:
+  explicit GraphExecutorFactoryWrapper(TVMContribTorchRuntimeModule* executor_factory)
+      : executor_factory_(executor_factory), executor_factory_runtime_(nullptr) {}
+
+  ~GraphExecutorFactoryWrapper() {
+    tvm_contrib_torch_free_runtime_module(executor_factory_);
+    tvm_contrib_torch_free_runtime_module(executor_factory_runtime_);
+  }
+
+  GraphExecutorFactoryWrapper()
+      : GraphExecutorFactoryWrapper(tvm_contrib_torch_get_last_saved_runtime_module()) {}
+
+  std::string Serialize() {
+    auto encoding = tvm_contrib_torch_encode(executor_factory_);
+    auto ret = std::string(encoding);
+    tvm_contrib_torch_free_encoding(encoding);
+    return ret;
+  }
+
+  explicit GraphExecutorFactoryWrapper(std::string state) {
+    executor_factory_ = tvm_contrib_torch_decode(state.c_str());
+    executor_factory_runtime_ = nullptr;
+  }
+
+  c10::List<at::Tensor> forward(const c10::List<at::Tensor>& inputs) {
+    int input_length = inputs.size();
+
+    TORCH_CHECK(input_length > 0, "Receive empty list of input tensors");
+
+    std::vector<DLPackTensorExt> tensors;
+
+    // Torch tensors support the boolean type while DLPack does not,
+    // so we convert each Torch tensor to an extended DLPack tensor
+    for (int i = 0; i < input_length; ++i) tensors.push_back(ToDLPackExt(inputs[i]));
+
+    DLPackTensorExt* outputs;
+    if (executor_factory_runtime_ == nullptr) {
+      executor_factory_runtime_ = tvm_contrib_torch_create_graph_runtime_module(
+          this->executor_factory_, tensors[0].dl_managed_tensor);
+    }
+    auto num_outputs = tvm_contrib_torch_graph_executor_module_forward(
+        executor_factory_runtime_, tensors.data(), tensors.size(), &outputs);
+
+    c10::List<at::Tensor> ret;
+    ret.reserve(num_outputs);
+
+    for (size_t k = 0; k < num_outputs; ++k) {
+      at::Tensor atTensor = FromDLPackExt(outputs[k]);
+      ret.emplace_back(atTensor);
+    }
+
+    for (int k = 0; k < input_length; ++k) {
+      tensors[k].dl_managed_tensor->deleter(tensors[k].dl_managed_tensor);
+    }
+    tvm_contrib_torch_free_dlpack_tensor_ext_array(outputs);
+
+    return ret;
+  }
+
+ private:
+  /*
+   * TVM Graph Executor Factory module wrapper
+   */
+  TVMContribTorchRuntimeModule* executor_factory_;
+
+  /*
+   * TVM runtime module wrapper
+   */
+  TVMContribTorchRuntimeModule* executor_factory_runtime_;
+};
+
+TORCH_LIBRARY(tvm_torch, m) {
+  m.class_<OperatorModuleWrapper>("OperatorModuleWrapper")
+      .def(torch::init<>())
+      .def("forward", &OperatorModuleWrapper::forward)
+      .def_pickle(
+          [](const c10::intrusive_ptr<OperatorModuleWrapper>& self) -> std::string {
+            return self->Serialize();
+          },
+          [](std::string state) { return c10::make_intrusive<OperatorModuleWrapper>(state); });
+  m.class_<GraphExecutorFactoryWrapper>("GraphExecutorFactoryWrapper")
+      .def(torch::init<>())
+      .def("forward", &GraphExecutorFactoryWrapper::forward)
+      .def_pickle(
+          [](const c10::intrusive_ptr<GraphExecutorFactoryWrapper>& self) -> std::string {
+            return self->Serialize();
+          },
+          [](std::string state) {
+            return c10::make_intrusive<GraphExecutorFactoryWrapper>(state);
+          });
+}
+
+}  // namespace contrib
+}  // namespace tvm
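The boolean handling in ToDLPackExt/FromDLPackExt amounts to a uint8 round-trip on the Torch side; the equivalent conversion in plain PyTorch (an illustration of the data path only, not the actual FFI):

    import torch

    x = torch.tensor([True, False, True])
    as_uint8 = x.to(torch.uint8)        # what ToDLPackExt does before exporting via DLPack
    restored = as_uint8.to(torch.bool)  # what FromDLPackExt applies on the way back
    assert torch.equal(x, restored)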
diff --git a/src/contrib/torch/tvm_module_wrapper/runtime_bridge.h b/src/contrib/torch/tvm_module_wrapper/runtime_bridge.h
new file mode 100644
index 0000000000..58cd53a284
--- /dev/null
+++ b/src/contrib/torch/tvm_module_wrapper/runtime_bridge.h
@@ -0,0 +1,116 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file runtime_bridge.h
+ * \brief Utility functions for PyTorch/TVM interaction.
+ */
+#ifndef TVM_CONTRIB_TORCH_TVM_MODULE_WRAPPER_RUNTIME_BRIDGE_H_
+#define TVM_CONTRIB_TORCH_TVM_MODULE_WRAPPER_RUNTIME_BRIDGE_H_
+
+extern "C" {
+
+/*
+ * DLPack data structure extended with an `is_bool` flag.
+ * DLPack does not support boolean tensors yet
+ * (https://github.com/pytorch/pytorch/blob/4618371da56c887195e2e1d16dad2b9686302800/aten/src/ATen/DLConvertor.cpp#L42),
+ * so a boolean tensor is represented as a UInt8 tensor
+ * (https://github.com/apache/tvm/blob/de124862714e747764aa8b7f41a90bcb25f3c6a8/python/tvm/_ffi/runtime_ctypes.py#L91).
+ */
+struct DLPackTensorExt {
+  DLManagedTensor* dl_managed_tensor;
+  bool is_bool;
+};
+
+/*
+ * A wrapper pointing to TVM runtime module.
+ */
+struct TVMContribTorchRuntimeModule;
+
+/*
+ * Obtain a saved runtime module passed by TVM FFI.
+ * @return A TVM runtime module wrapper.
+ */
+TVMContribTorchRuntimeModule* tvm_contrib_torch_get_last_saved_runtime_module();
+
+/*
+ * Delete TVMContribTorchRuntimeModule pointer.
+ */
+void tvm_contrib_torch_free_runtime_module(TVMContribTorchRuntimeModule* module_ptr);
+
+/*
+ * Obtain ExecutorFactory runtime module from ExecutorFactory class.
+ * @param graph_executor_factory ExecutorFactory class
+ * @param input_example For obtaining device information
+ * @return ExecutorFactory TVM runtime module wrapper
+ */
+TVMContribTorchRuntimeModule* tvm_contrib_torch_create_graph_runtime_module(
+    TVMContribTorchRuntimeModule* graph_executor_factory, DLManagedTensor* input_example);
+
+/*
+ * Forward method for OperatorModuleWrapper.
+ * @param runtime_module TVM runtime module wrapper
+ * @param inputs Array pointer of the input tensors
+ * @param input_size The number of input tensors
+ */
+void tvm_contrib_torch_operator_module_forward(TVMContribTorchRuntimeModule* runtime_module,
+                                               DLPackTensorExt* inputs, size_t input_size);
+
+/*
+ * Forward method for GraphExecutorFactoryWrapper.
+ * @param graph_executor_factory TVM runtime module wrapper
+ * @param inputs Array pointer of the input tensors
+ * @param input_size The number of input tensors
+ * @param outputs The resulting output tensors pointer
+ * @return The number of output tensors
+ */
+size_t tvm_contrib_torch_graph_executor_module_forward(
+    TVMContribTorchRuntimeModule* graph_executor_factory, DLPackTensorExt* inputs,
+    size_t input_size, DLPackTensorExt** outputs);
+
+/*
+ * Encode TVM runtime module.
+ * @param runtime_module TVM runtime module wrapper
+ * @return The encoding stream (char array)
+ */
+char* tvm_contrib_torch_encode(TVMContribTorchRuntimeModule* runtime_module);
+
+/*
+ * Decode TVM runtime module.
+ * @param state The encoding stream (char array) of TVM runtime module
+ * @return TVM runtime module wrapper
+ */
+TVMContribTorchRuntimeModule* tvm_contrib_torch_decode(const char* state);
+
+/*
+ * Delete DLPackTensorExt pointer.
+ */
+void tvm_contrib_torch_free_dlpack_tensor_ext_array(DLPackTensorExt*);
+
+/*
+ * Delete char array pointer.
+ */
+void tvm_contrib_torch_free_encoding(char* encoding);
+
+/*
+ * Check whether a DLPackTensorExt can be passed with zero copy, i.e. it is not boolean and zero-copy conversion is possible.
+ */
+bool tvm_contrib_torch_tensor_ability_of_zero_copy(DLPackTensorExt*);
+}
+
+#endif  // TVM_CONTRIB_TORCH_TVM_MODULE_WRAPPER_RUNTIME_BRIDGE_H_
