This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 2226a1f558 [Unity] Multi-device support for Relax (#15447)
2226a1f558 is described below

commit 2226a1f5581cb77e599ce16d61227ed1088ff058
Author: Yong Wu <yongc...@gmail.com>
AuthorDate: Fri Aug 11 12:25:29 2023 -0700

    [Unity] Multi-device support for Relax (#15447)
---
 include/tvm/ir/global_info.h                       | 50 ++++++++++-
 include/tvm/relax/attrs/op.h                       |  8 ++
 include/tvm/relax/struct_info.h                    | 17 +++-
 include/tvm/relay/transform.h                      | 12 +++
 include/tvm/script/printer/ir_docsifier.h          |  7 ++
 include/tvm/target/target.h                        | 10 ---
 include/tvm/target/target_kind.h                   | 27 +-----
 include/tvm/tir/usmp/utils.h                       |  1 +
 python/tvm/ir/__init__.py                          |  2 +-
 python/tvm/ir/global_info.py                       | 12 +++
 python/tvm/ir/json_compact.py                      | 17 ++++
 python/tvm/relax/op/base.py                        | 21 +++++
 python/tvm/relax/struct_info.py                    | 14 ++--
 python/tvm/script/ir_builder/ir/__init__.py        |  2 +
 python/tvm/script/ir_builder/ir/ir.py              | 38 ++++++++-
 python/tvm/script/ir_builder/relax/ir.py           |  2 +
 python/tvm/script/parser/ir/__init__.py            |  2 +
 python/tvm/script/parser/relax/dist.py             |  8 +-
 python/tvm/script/parser/relax/entry.py            | 21 ++++-
 src/driver/driver_api.cc                           | 17 ++++
 src/ir/global_info.cc                              | 13 +++
 src/relax/analysis/struct_info_analysis.cc         | 25 ++++--
 src/relax/backend/vm/vm_builtin_lower.cc           | 21 +++++
 src/relax/ir/expr.cc                               |  2 +-
 src/relax/ir/struct_info.cc                        | 12 +--
 src/relax/ir/struct_info_functor.cc                |  7 +-
 src/relax/op/op.cc                                 | 28 +++++++
 src/relax/transform/convert_layout.cc              |  6 +-
 src/relax/transform/to_mixed_precision.cc          |  6 +-
 src/relay/backend/contrib/cmsisnn/target.cc        |  4 +-
 src/relay/backend/contrib/codegen_c/target.cc      |  2 +-
 src/relay/backend/contrib/cutlass/target.cc        |  2 +-
 src/relay/backend/contrib/ethosu/codegen.cc        |  4 +-
 .../backend/contrib/example_target_hooks/target.cc |  6 +-
 src/relay/backend/contrib/tensorrt/target.cc       |  2 +-
 src/relay/backend/contrib/uma/targets.cc           | 28 +++----
 src/runtime/relax_vm/builtin.cc                    |  6 ++
 src/script/ir_builder/ir/ir.cc                     | 30 +++++++
 src/script/printer/ir/ir.cc                        | 10 +++
 src/script/printer/ir_docsifier.cc                 |  6 ++
 src/script/printer/relax/struct_info.cc            |  7 ++
 src/script/printer/relax/utils.h                   | 16 ++++
 src/target/codegen.cc                              |  3 +
 src/target/target.cc                               | 15 +---
 tests/cpp/target_test.cc                           |  2 +-
 .../python/relax/test_json_compact.py              | 50 ++++++-----
 tests/python/relax/test_relax_operators.py         | 30 +++++++
 tests/python/relax/test_tvmscript_parser.py        | 97 +++++++++++++++++++++-
 .../relax/test_tvmscript_parser_op_manipulate.py   | 15 ++++
 tests/python/relax/test_vm_build.py                | 35 ++++++++
 tests/python/relax/test_vm_codegen_only.py         | 27 ++++++
 51 files changed, 680 insertions(+), 125 deletions(-)

diff --git a/include/tvm/ir/global_info.h b/include/tvm/ir/global_info.h
index bd6006864a..57bc57fd09 100644
--- a/include/tvm/ir/global_info.h
+++ b/include/tvm/ir/global_info.h
@@ -25,10 +25,16 @@
 #ifndef TVM_IR_GLOBAL_INFO_H_
 #define TVM_IR_GLOBAL_INFO_H_
 
-#include "tvm/ir/expr.h"
+#include <tvm/ir/expr.h>
+#include <tvm/target/target.h>
 
 namespace tvm {
 
+/*!
+ * \brief Abstract label for an area of memory.
+ */
+using MemoryScope = String;
+
 /*!
  * \brief GlobalInfo are globally static object that are referred by the IR 
itself.
  *        Base node for all global info that can appear in the IR
@@ -50,6 +56,55 @@ class GlobalInfo : public ObjectRef {
   TVM_DEFINE_OBJECT_REF_METHODS(GlobalInfo, ObjectRef, GlobalInfoNode);
 };
 
+/*!
+ * \brief A global info subclass for virtual devices.
+ */
+class VDeviceNode : public GlobalInfoNode {
+ public:
+  /*! \brief The \p Target describing how to compile for the virtual device. */
+  Target target;
+  /*! \brief The device identifier for the virtual device. This enables us to
+   * differentiate between distinct devices with same Target, such as multiple GPUs.
+   */
+  int vdevice_id;
+  /*! \brief The memory scope of the virtual device, e.g. "global". */
+  MemoryScope memory_scope;
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("target", &target);
+    v->Visit("vdevice_id", &vdevice_id);
+    v->Visit("memory_scope", &memory_scope);
+  }
+
+  TVM_DLL bool SEqualReduce(const VDeviceNode* other, SEqualReducer equal) const {
+    return equal(target, other->target) && equal(vdevice_id, other->vdevice_id) &&
+           equal(memory_scope, other->memory_scope);
+  }
+
+  TVM_DLL void SHashReduce(SHashReducer hash_reduce) const {
+    hash_reduce(target);
+    hash_reduce(vdevice_id);
+    hash_reduce(memory_scope);
+  }
+  static constexpr const char* _type_key = "VDevice";
+  TVM_DECLARE_FINAL_OBJECT_INFO(VDeviceNode, GlobalInfoNode);
+};
+
+/*!
+ * \brief Managed reference to VDeviceNode.
+ * \sa VDeviceNode
+ */
+class VDevice : public GlobalInfo {
+ public:
+  /*!
+   * \brief Construct a virtual device.
+   * \param tgt The Target describing how to compile for the virtual device.
+   * \param dev_id The device identifier, distinguishing devices with the same Target.
+   * \param mem_scope The memory scope of the virtual device.
+   */
+  TVM_DLL explicit VDevice(Target tgt, int dev_id, MemoryScope mem_scope);
+  TVM_DEFINE_OBJECT_REF_METHODS(VDevice, GlobalInfo, VDeviceNode);
+};
+
 /*!
  * \brief A dummy global info sub-class for testing purpose.
   */
diff --git a/include/tvm/relax/attrs/op.h b/include/tvm/relax/attrs/op.h
index 8c0ee6abc6..f33f3a9701 100644
--- a/include/tvm/relax/attrs/op.h
+++ b/include/tvm/relax/attrs/op.h
@@ -57,6 +57,14 @@ struct CallTIRInplaceAttrs : public 
tvm::AttrsNode<CallTIRInplaceAttrs> {
   }
 };  // struct CallTIRInplaceAttrs
 
+/*! \brief Attributes used in to_vdevice */
+struct ToVDeviceAttrs : public tvm::AttrsNode<ToVDeviceAttrs> {
+  VDevice dst_vdevice;
+  TVM_DECLARE_ATTRS(ToVDeviceAttrs, "relax.attrs.ToVDeviceAttrs") {
+    TVM_ATTR_FIELD(dst_vdevice).describe("The destination device where the data is copied to.");
+  }
+};  // struct ToVDeviceAttrs
+
 }  // namespace relax
 }  // namespace tvm
 
diff --git a/include/tvm/relax/struct_info.h b/include/tvm/relax/struct_info.h
index 385c320db1..d2bf525225 100644
--- a/include/tvm/relax/struct_info.h
+++ b/include/tvm/relax/struct_info.h
@@ -156,6 +156,10 @@ class TensorStructInfoNode : public StructInfoNode {
    * \note shape must be normalized: it can only be NullOpt or ShapeExpr or 
Var.
    */
   Optional<Expr> shape;
+  /*! \brief The virtual device, indicates where the tensor
+   *  is expected to be executed.
+   */
+  Optional<VDevice> vdevice;
   /*! \brief The content data type, use void to denote the dtype is unknown. */
   DataType dtype;
   /*!
@@ -180,17 +184,20 @@ class TensorStructInfoNode : public StructInfoNode {
   void VisitAttrs(AttrVisitor* v) {
     v->Visit("shape", &shape);
     v->Visit("dtype", &dtype);
+    v->Visit("vdevice", &vdevice);
     v->Visit("ndim", &ndim);
     v->Visit("span", &span);
   }
 
   bool SEqualReduce(const TensorStructInfoNode* other, SEqualReducer equal) 
const {
-    return equal(shape, other->shape) && equal(ndim, other->ndim) && 
equal(dtype, other->dtype);
+    return equal(shape, other->shape) && equal(ndim, other->ndim) &&
+           equal(vdevice, other->vdevice) && equal(dtype, other->dtype);
   }
 
   void SHashReduce(SHashReducer hash_reduce) const {
     hash_reduce(shape);
     hash_reduce(dtype);
+    hash_reduce(vdevice);
     hash_reduce(ndim);
   }
 
@@ -208,19 +215,23 @@ class TensorStructInfo : public StructInfo {
    * \brief Construction with a known shape expression.
    * \param shape The shape of the tensor.
    * \param dtype The data type of tensor's elements.
+   * \param vdevice The virtual device.
    * \param span The span of the AST.
    *
    * \note shape must already be normalized.
    */
-  TVM_DLL TensorStructInfo(Expr shape, DataType dtype, Span span = Span());
+  TVM_DLL TensorStructInfo(Expr shape, DataType dtype, VDevice vdevice = 
VDevice(),
+                           Span span = Span());
 
   /*!
    * \brief Construction with an unknown shape expression.
    * \param dtype The data type of tensor's elements.
    * \param ndim The number of dimensions
+   * \param vdevice The virtual device.
    * \param span The span of the AST.
    */
-  TVM_DLL TensorStructInfo(DataType dtype, int ndim, Span span = Span());
+  TVM_DLL TensorStructInfo(DataType dtype, int ndim, VDevice vdevice = 
VDevice(),
+                           Span span = Span());
 
   TVM_DEFINE_OBJECT_REF_METHODS(TensorStructInfo, StructInfo, 
TensorStructInfoNode);
 };
diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h
index 7a0e003038..f4286512e5 100644
--- a/include/tvm/relay/transform.h
+++ b/include/tvm/relay/transform.h
@@ -47,6 +47,30 @@ using PassInfoNode = tvm::transform::PassInfoNode;
 using PassContext = tvm::transform::PassContext;
 using PassContextNode = tvm::transform::PassContextNode;
 using Sequential = tvm::transform::Sequential;
+
+/*!
+ * \brief RelayToTIR tvm::transform::Pass specific to a TargetKind
+ *
+ * Called before the default lowering passes.
+ *
+ * \param mod The module that an optimization pass runs on.
+ * \param pass_ctx The pass context that can provide information for the optimization.
+ *
+ * \return The transformed module.
+ */
+using FTVMRelayToTIR = tvm::transform::Pass;
+
+/*!
+ * \brief TIRToRuntime conversion specific to a TargetKind
+ *
+ * This function is responsible for scanning an IRModule for appropriate Target-specific functions
+ * and generating a Runtime module representing the compiled output.
+ *
+ * \param ir_module Unified IRModule
+ * \param target Target to filter on or retrieve arguments from
+ * \return Runtime Module containing compiled functions
+ */
+using FTVMTIRToRuntime = tvm::runtime::TypedPackedFunc<runtime::Module(IRModule, Target)>;
 
 /*
  * \brief Create a function pass.
diff --git a/include/tvm/script/printer/ir_docsifier.h 
b/include/tvm/script/printer/ir_docsifier.h
index 156daebf00..1163464738 100644
--- a/include/tvm/script/printer/ir_docsifier.h
+++ b/include/tvm/script/printer/ir_docsifier.h
@@ -145,6 +145,8 @@ class IRDocsifierNode : public Object {
   std::unordered_map<ObjectRef, VariableInfo, ObjectPtrHash, ObjectPtrEqual> 
obj2info;
   /*! \brief Metadata printing */
   std::unordered_map<String, Array<ObjectRef>> metadata;
+  /*! \brief GlobalInfo printing */
+  std::unordered_map<String, Array<GlobalInfo>> global_infos;
   /*! \brief The variable names used already */
   std::unordered_set<String> defined_names;
   /*! \brief Common prefixes of variable usages */
@@ -206,6 +208,11 @@ class IRDocsifierNode : public Object {
   Optional<ExprDoc> GetVarDoc(const ObjectRef& obj) const;
   /*! \brief Add a TVM object to the metadata section*/
   ExprDoc AddMetadata(const ObjectRef& obj);
+  /*! \brief Add a GlobalInfo to the global_infos map.
+   * \param name The name of key of global_infos.
+   * \param ginfo The GlobalInfo to be added.
+   */
+  void AddGlobalInfo(const String& name, const GlobalInfo& ginfo);
   /*!
    * \brief Check if a variable exists in the table.
    * \param obj The variable object.
diff --git a/include/tvm/target/target.h b/include/tvm/target/target.h
index 56d6a596b9..d47ac94e06 100644
--- a/include/tvm/target/target.h
+++ b/include/tvm/target/target.h
@@ -25,7 +25,6 @@
 #define TVM_TARGET_TARGET_H_
 
 #include <tvm/ir/expr.h>
-#include <tvm/ir/module.h>
 #include <tvm/node/node.h>
 #include <tvm/support/with.h>
 #include <tvm/target/target_kind.h>
@@ -284,14 +283,5 @@ class Target : public ObjectRef {
  */
 void CheckAndUpdateHostConsistency(Target* target, Target* host);
 
-/*!
- * \brief Check and update host field of the given legacy heterogeneous 
targets and
- *  target host.Note that this function is for legacy target api compatibility 
issue only,
- *  not recommended for other use.
- * \param ir_modules The pointer to a Map objects with keys being Target 
objects
- * \param host The Target typed object for target host to be updated
- */
-void CheckAndUpdateHostConsistency(Map<Target, IRModule>* ir_modules, Target* 
host);
-
 }  // namespace tvm
 #endif  // TVM_TARGET_TARGET_H_
diff --git a/include/tvm/target/target_kind.h b/include/tvm/target/target_kind.h
index 19bcce3116..10808fd12d 100644
--- a/include/tvm/target/target_kind.h
+++ b/include/tvm/target/target_kind.h
@@ -24,7 +24,6 @@
 #ifndef TVM_TARGET_TARGET_KIND_H_
 #define TVM_TARGET_TARGET_KIND_H_
 
-#include <tvm/ir/transform.h>
 #include <tvm/node/attr_registry_map.h>
 #include <tvm/node/node.h>
 
@@ -50,31 +49,7 @@ using TargetFeatures = Map<String, ObjectRef>;
  * \return The transformed Target JSON object.
  */
 using TargetJSON = Map<String, ObjectRef>;
-using FTVMTargetParser = TypedPackedFunc<TargetJSON(TargetJSON)>;
-
-/*!
- * \brief RelayToTIR tvm::transform::Pass specific to a TargetKind
- *
- * Called before the default lowering passes.
- *
- * \param mod The module that an optimization pass runs on.
- * \param pass_ctx The pass context that can provide information for the 
optimization.
- *
- * \return The transformed module.
- */
-using FTVMRelayToTIR = transform::Pass;
-
-/*!
- * \brief TIRToRuntime conversion specific to a TargetKind
- *
- * This function is responsible for scanning an IRModule for appropriate 
Target-specific functions
- and generating a Runtime module representing the compiled output
- *
- * \param ir_module Unified IRModule
- * \param target Target to filter on or retrieve arguments from
- * \return Runtime Module containing compiled functions
- */
-using FTVMTIRToRuntime = runtime::TypedPackedFunc<runtime::Module(IRModule, 
Target)>;
+using FTVMTargetParser = tvm::runtime::TypedPackedFunc<TargetJSON(TargetJSON)>;
 
 namespace detail {
 template <typename, typename, typename>
diff --git a/include/tvm/tir/usmp/utils.h b/include/tvm/tir/usmp/utils.h
index f49e9ceef7..a67350a2bb 100644
--- a/include/tvm/tir/usmp/utils.h
+++ b/include/tvm/tir/usmp/utils.h
@@ -27,6 +27,7 @@
 
 #include <tvm/ir/expr.h>
 #include <tvm/ir/memory_pools.h>
+#include <tvm/ir/module.h>
 #include <tvm/runtime/device_api.h>
 #include <tvm/target/target.h>
 #include <tvm/tir/stmt.h>
diff --git a/python/tvm/ir/__init__.py b/python/tvm/ir/__init__.py
index 49c2cf6348..939a5f6383 100644
--- a/python/tvm/ir/__init__.py
+++ b/python/tvm/ir/__init__.py
@@ -35,7 +35,7 @@ from .base import (
 from .container import Array, Map
 from .expr import BaseExpr, GlobalVar, PrimExpr, Range, RelayExpr
 from .function import BaseFunc, CallingConv
-from .global_info import GlobalInfo, DummyGlobalInfo
+from .global_info import GlobalInfo, DummyGlobalInfo, VDevice
 from .memory_pools import (
     ConstantMemoryPools,
     ConstantPoolInfo,
diff --git a/python/tvm/ir/global_info.py b/python/tvm/ir/global_info.py
index 17011e76a6..458a16717b 100644
--- a/python/tvm/ir/global_info.py
+++ b/python/tvm/ir/global_info.py
@@ -40,3 +40,29 @@ class DummyGlobalInfo(GlobalInfo):
         self.__init_handle_by_constructor__(
             _ffi_api.DummyGlobalInfo,
         )
+
+
+class VDevice(GlobalInfo):
+    """Virtual device global info.
+
+    Parameters
+    ----------
+    target : Optional[Union[str, dict, Target]]
+        The target describing how to compile for the virtual device. A str or dict
+        is converted into a :py:class:`tvm.target.Target`.
+
+    vdevice_id : int
+        The device identifier, distinguishing devices that share the same target.
+
+    memory_scope : str
+        The memory scope of the virtual device.
+    """
+    def __init__(
+        self,
+        target=None,
+        vdevice_id: int = 0,
+        memory_scope: str = "global",
+    ) -> None:
+        if isinstance(target, (dict, str)):
+            target = tvm.target.Target(tvm.runtime.convert(target))
+        self.__init_handle_by_constructor__(_ffi_api.VDevice, target, vdevice_id, memory_scope)
diff --git a/python/tvm/ir/json_compact.py b/python/tvm/ir/json_compact.py
index 6ce2a8b9e2..224932b00c 100644
--- a/python/tvm/ir/json_compact.py
+++ b/python/tvm/ir/json_compact.py
@@ -57,6 +57,21 @@ def create_updater(node_map, from_ver, to_ver):
     return _updater
 
 
+def create_updater_13_to_14():
+    """Create an updater to upgrade JSON from v0.13 to v0.14 for TVM Unity."""
+
+    def _update_vdevice(item, _):
+        # Older TensorStructInfo nodes lack "vdevice"; "0" is presumably the null-node index.
+        item["attrs"].setdefault("vdevice", "0")
+        return item
+
+    node_map = {
+        "relax.TensorStructInfo": _update_vdevice,
+    }
+
+    return create_updater(node_map, "0.13", "0.14")
+
+
 def create_updater_08_to_09():
     """
     Create an update to upgrade json from v0.8 to v0.9
@@ -259,6 +274,8 @@ def upgrade_json(json_str):
         data = create_updater_08_to_09()(create_updater_07_to_08()(data))
     elif from_version.startswith("0.8"):
         data = create_updater_08_to_09()(data)
+    elif from_version.startswith("0.13"):
+        data = create_updater_13_to_14()(data)
     else:
         raise ValueError(f"Cannot update from version {from_version}")
     return json.dumps(data, indent=2)
diff --git a/python/tvm/relax/op/base.py b/python/tvm/relax/op/base.py
index 25c70e0493..1d49c00ea8 100644
--- a/python/tvm/relax/op/base.py
+++ b/python/tvm/relax/op/base.py
@@ -690,3 +690,24 @@ def invoke_pure_closure(
         sinfo_args = [sinfo_args]
 
     return _ffi_api.invoke_pure_closure(closure, args, sinfo_args)  # type: 
ignore
+
+
+def to_vdevice(data: Expr, dst_vdevice) -> Expr:
+    """Copy data to the destination virtual device. This
+    operator helps data transferring between different devices for
+    heterogeneous execution.
+
+    Parameters
+    ----------
+    data : Expr
+        The tensor to be copied.
+
+    dst_vdevice : VDevice
+        The destination virtual device where the data is copied to.
+
+    Returns
+    -------
+    result : Expr
+        The copied result.
+    """
+    return _ffi_api.to_vdevice(data, dst_vdevice)  # type: ignore
diff --git a/python/tvm/relax/struct_info.py b/python/tvm/relax/struct_info.py
index 3dcc3dc9a0..e78e1cf69a 100644
--- a/python/tvm/relax/struct_info.py
+++ b/python/tvm/relax/struct_info.py
@@ -16,14 +16,14 @@
 # under the License.
 # pylint: disable=invalid-name, unused-import
 """The struct info nodes of the Relax language."""
-from typing import List, Optional, Tuple, Union
+from typing import List, Optional, Union
 
 import tvm._ffi
 import tvm
 
-from tvm.ir import Span, Node, EnvFunc, Array, Type
+from tvm.ir import Span, EnvFunc, Array, VDevice
 from tvm.tir import PrimExpr
-from .expr import StructInfo, Var, Expr, ShapeExpr
+from .expr import StructInfo, Expr, ShapeExpr
 
 from . import _ffi_api, ty, expr
 
@@ -93,6 +93,9 @@ class TensorStructInfo(StructInfo):
     dtype : Optional[str]
         The content data type.
 
+    vdevice : Optional[VDevice]
+        The virtual device.
+
     ndim : Optional[int]
        The number of dimensions of the tensor.
 
@@ -103,6 +106,7 @@ class TensorStructInfo(StructInfo):
 
     shape: Optional[Expr]
     dtype: str
+    vdevice: Optional[VDevice]
     ndim: int
     span: Span
 
@@ -110,14 +114,14 @@ class TensorStructInfo(StructInfo):
         self,
         shape: Union[Optional[Expr], List[PrimExpr]] = None,
         dtype: str = "float32",
+        vdevice: Union[Optional[VDevice], str] = None,
         ndim: int = -1,
         span: Span = None,
     ) -> None:
         if isinstance(shape, (list, tuple, Array)):
             shape = ShapeExpr(shape)
-
         self.__init_handle_by_constructor__(
-            _ffi_api.TensorStructInfo, shape, dtype, ndim, span  # type: ignore
+            _ffi_api.TensorStructInfo, shape, dtype, ndim, vdevice, span  # 
type: ignore
         )
 
 
diff --git a/python/tvm/script/ir_builder/ir/__init__.py 
b/python/tvm/script/ir_builder/ir/__init__.py
index 68eda2cfee..fdf44b2b79 100644
--- a/python/tvm/script/ir_builder/ir/__init__.py
+++ b/python/tvm/script/ir_builder/ir/__init__.py
@@ -22,5 +22,7 @@ from .ir import (
     ir_module,
     module_attrs,
     module_global_infos,
+    lookup_vdevice,
+    vdevice,
     dummy_global_info,
 )
diff --git a/python/tvm/script/ir_builder/ir/ir.py 
b/python/tvm/script/ir_builder/ir/ir.py
index 53c48b4cc5..0d3523ec7d 100644
--- a/python/tvm/script/ir_builder/ir/ir.py
+++ b/python/tvm/script/ir_builder/ir/ir.py
@@ -18,7 +18,7 @@
 
 from typing import Dict, List
 
-from tvm.ir import BaseFunc, GlobalVar, GlobalInfo, DummyGlobalInfo
+from tvm.ir import BaseFunc, GlobalVar, GlobalInfo, VDevice, DummyGlobalInfo
 from tvm.runtime import Object as tvm_Object
 
 
@@ -104,3 +104,41 @@ def dummy_global_info() -> DummyGlobalInfo:
         The result dummy global info.
     """
     return DummyGlobalInfo()  # type: ignore[attr-defined] # pylint: disable=no-member
+
+
+def vdevice(target=None, vdevice_id: int = 0, memory_scope: str = "global") -> VDevice:
+    """Create a virtual device global info.
+
+    Parameters
+    ----------
+    target : Optional[Union[str, dict, Target]]
+        The target. A str or dict is converted into a Target.
+    vdevice_id : int
+        The virtual device index.
+    memory_scope : str
+        The memory scope, default is "global".
+
+    Returns
+    -------
+    res : VDevice
+        The result virtual device.
+    """
+    return VDevice(target, vdevice_id, memory_scope)
+
+
+def lookup_vdevice(target_kind: str = None, device_index: int = -1) -> VDevice:
+    """Retrieve a virtual device from the globalinfo vdevice list.
+
+    Parameters
+    ----------
+    target_kind : str
+        The target device kind, for example 'llvm' or 'cuda'.
+    device_index : int
+        The virtual device index.
+
+    Returns
+    -------
+    res : VDevice
+        The result virtual device.
+    """
+    return _ffi_api.LookupVDevice(target_kind, device_index)  # type: ignore[attr-defined] # pylint: disable=no-member
diff --git a/python/tvm/script/ir_builder/relax/ir.py 
b/python/tvm/script/ir_builder/relax/ir.py
index 5bb0374d35..8a538c1868 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -140,6 +140,7 @@ from tvm.relax.op import (
     tanh,
     erf,
     tile,
+    to_vdevice,
     tril,
     triu,
     unique,
@@ -693,6 +694,7 @@ __all__ = [
     "tan",
     "tanh",
     "tile",
+    "to_vdevice",
     "tril",
     "triu",
     "tuple",
diff --git a/python/tvm/script/parser/ir/__init__.py 
b/python/tvm/script/parser/ir/__init__.py
index ec518f8573..3a8196288d 100644
--- a/python/tvm/script/parser/ir/__init__.py
+++ b/python/tvm/script/parser/ir/__init__.py
@@ -27,4 +27,6 @@ __all__ = [
     "module_global_infos",
     "dummy_global_info",
     "Range",
+    "lookup_vdevice",
+    "vdevice",
 ]
diff --git a/python/tvm/script/parser/relax/dist.py 
b/python/tvm/script/parser/relax/dist.py
index 120d57ca56..f9c78f980f 100644
--- a/python/tvm/script/parser/relax/dist.py
+++ b/python/tvm/script/parser/relax/dist.py
@@ -78,7 +78,11 @@ def DTensor(
         raise ValueError(f"shape must be a list or tuple, but got: {shape}")
     if isinstance(device_mesh, str):
         if not IRBuilder.is_in_scope():
-            return (DTensorProxy(TensorProxy(shape, dtype, ndim), 
DeviceMesh([], Range(0, 1)), ""),)
+            return (
+                DTensorProxy(
+                    TensorProxy(shape, dtype, None, ndim), DeviceMesh([], 
Range(0, 1)), ""
+                ),
+            )
         name, index = device_mesh.split("[")
         index = int(index[:-1])
         frames = IRBuilder.current().frames
@@ -89,7 +93,7 @@ def DTensor(
         assert isinstance(device_mesh, DeviceMesh)
     if isinstance(placement, str):
         placement = Placement.from_text(placement)
-    return DTensorProxy(TensorProxy(shape, dtype, ndim), device_mesh, 
placement)
+    return DTensorProxy(TensorProxy(shape, dtype, None, ndim), device_mesh, 
placement)
 
 
 __all__ = ["DTensor", "device_mesh"]
diff --git a/python/tvm/script/parser/relax/entry.py 
b/python/tvm/script/parser/relax/entry.py
index ff237a5600..1c18d75be4 100644
--- a/python/tvm/script/parser/relax/entry.py
+++ b/python/tvm/script/parser/relax/entry.py
@@ -37,6 +37,7 @@ from tvm.runtime import ObjectGeneric
 from tvm.tir import PrimExpr
 
 from .._core import parse, utils
+from ..ir import lookup_vdevice
 
 FType = TypeVar("FType", bound=_Callable)
 
@@ -103,12 +104,14 @@ def _eval_shape(expr: Union[str, PrimExpr], dict_globals: 
Optional[Dict[str, Any
 class TensorProxy(StructInfoProxy):
     shape: Optional[List[Union[str, PrimExpr]]]
     dtype: str
+    vdevice: Optional[str]
     ndim: int
 
     def __init__(
         self,
         shape: Optional[Union[List[Union[PrimExpr, str]], Expr]] = None,
         dtype: Optional[str] = None,
+        vdevice: Optional[str] = None,
         ndim: int = -1,
     ) -> None:
         if isinstance(shape, Expr):
@@ -124,6 +127,7 @@ class TensorProxy(StructInfoProxy):
                 )
         self.shape = shape
         self.dtype = dtype
+        self.vdevice = vdevice
         self.ndim = ndim
 
     def get_symbolic_vars(self) -> Set[str]:
@@ -133,10 +137,18 @@ class TensorProxy(StructInfoProxy):
             return {s for s in self.shape if isinstance(s, str) and 
s.isidentifier()}
 
     def as_struct_info(self, dict_globals: Optional[Dict[str, Any]] = None) -> 
TensorStructInfo:
+        vdev = self.vdevice
+        if isinstance(self.vdevice, str):
+            if ":" in self.vdevice:
+                split_vdev = self.vdevice.split(":")
+                vdev = lookup_vdevice(split_vdev[0], int(split_vdev[1]))
+            else:
+                vdev = lookup_vdevice(self.vdevice, 0)
+
         if self.shape is None:
-            return TensorStructInfo(None, self.dtype, self.ndim)
+            return TensorStructInfo(None, self.dtype, vdev, self.ndim)
         elif isinstance(self.shape, (ShapeExpr, Var)):
-            return TensorStructInfo(self.shape, self.dtype, self.ndim)
+            return TensorStructInfo(self.shape, self.dtype, vdev, self.ndim)
         else:
             if dict_globals is None and any([isinstance(s, str) for s in 
self.shape]):
                 raise ValueError(
@@ -144,12 +156,13 @@ class TensorProxy(StructInfoProxy):
                     "and return annotations for TVMScript."
                 )
             shape = [_eval_shape(s, dict_globals) for s in self.shape]
-            return TensorStructInfo(shape, self.dtype, self.ndim)
+            return TensorStructInfo(shape, self.dtype, vdev, self.ndim)
 
 
 def Tensor(
     shape: Optional[Union[List[Union[PrimExpr, str]], Expr]] = None,
     dtype: Optional[str] = None,
+    vdevice: Optional[str] = None,
     ndim: int = -1,
 ) -> TensorProxy:
     # scalar tensor case
@@ -161,7 +174,7 @@ def Tensor(
 
     if shape is not None and not isinstance(shape, (tuple, list)) and not 
isinstance(shape, Expr):
         raise ValueError(f"shape must be a list/tuple or an Expr, but got: 
{shape}")
-    return TensorProxy(shape, dtype, ndim)
+    return TensorProxy(shape, dtype, vdevice, ndim)
 
 
 ############################## R.Callable ##############################
diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc
index d46fab7168..b7ba0ffe44 100644
--- a/src/driver/driver_api.cc
+++ b/src/driver/driver_api.cc
@@ -431,6 +431,23 @@ std::pair<IRModule, IRModule> SplitMixedModule(IRModule 
mod_mixed, const Target&
   return {host_mod, device_mod};
 }
 
+/*!
+ * \brief Check and update host field of the given legacy heterogeneous targets and
+ *  target host. Note that this function is for legacy target api compatibility issue only,
+ *  not recommended for other use.
+ * \param targets The pointer to a Map object with keys being Target objects
+ * \param host The Target typed object for target host to be updated
+ */
+void CheckAndUpdateHostConsistency(Map<Target, IRModule>* targets, Target* host) {
+  Map<Target, IRModule> new_targets;
+  for (const auto& kv : *targets) {
+    Target target = kv.first;
+    CheckAndUpdateHostConsistency(&target, host);
+    new_targets.Set(target, kv.second);
+  }
+  *targets = new_targets;
+}
+
 runtime::Module TIRToRuntime(const Map<Target, IRModule>& inputs_arg,
                              const Target& target_host_arg) {
   std::vector<runtime::Module> device_modules;
diff --git a/src/ir/global_info.cc b/src/ir/global_info.cc
index 48f56d60d6..f1ecc8cd04 100644
--- a/src/ir/global_info.cc
+++ b/src/ir/global_info.cc
@@ -29,4 +29,22 @@ TVM_REGISTER_GLOBAL("ir.DummyGlobalInfo").set_body_typed([]() {
   auto n = DummyGlobalInfo(make_object<DummyGlobalInfoNode>());
   return n;
 });
+
+/*!
+ * \brief Construct a VDevice from a target, a device id and a memory scope.
+ * \note No default arguments here: defaulting every parameter in this out-of-class
+ *  definition would make it a default constructor, which the C++ standard forbids.
+ */
+VDevice::VDevice(Target tgt, int dev_id, MemoryScope mem_scope) {
+  ObjectPtr<VDeviceNode> n = make_object<VDeviceNode>();
+  n->target = std::move(tgt);
+  n->vdevice_id = dev_id;
+  n->memory_scope = std::move(mem_scope);
+  data_ = std::move(n);
+}
+
+TVM_REGISTER_NODE_TYPE(VDeviceNode);
+TVM_REGISTER_GLOBAL("ir.VDevice").set_body_typed([](Target tgt, int dev_id, MemoryScope mem_scope) {
+  return VDevice(std::move(tgt), dev_id, std::move(mem_scope));
+});
 }  // namespace tvm
diff --git a/src/relax/analysis/struct_info_analysis.cc 
b/src/relax/analysis/struct_info_analysis.cc
index 7006f71198..82ccdf33ea 100644
--- a/src/relax/analysis/struct_info_analysis.cc
+++ b/src/relax/analysis/struct_info_analysis.cc
@@ -149,19 +149,24 @@ class WellDefinedEraser : public StructInfoMutator,
       std::swap(has_undefined_, has_undefined);
     }
 
+    VDevice vdev = VDevice();
+    if (op->vdevice.defined()) {
+      vdev = op->vdevice.value();
+    }
+
     // erase symbolic shape if we have undefined.
     if (!has_undefined) {
       if (shape.same_as(op->shape)) {
         return GetRef<StructInfo>(op);
       } else {
         if (shape.defined()) {
-          return TensorStructInfo(shape.value(), op->dtype, op->span);
+          return TensorStructInfo(shape.value(), op->dtype, vdev, op->span);
         } else {
-          return TensorStructInfo(op->dtype, op->ndim, op->span);
+          return TensorStructInfo(op->dtype, op->ndim, vdev, op->span);
         }
       }
     } else {
-      return TensorStructInfo(op->dtype, op->ndim, op->span);
+      return TensorStructInfo(op->dtype, op->ndim, vdev, op->span);
     }
   }
 
@@ -767,6 +772,16 @@ class StructInfoLCAFinder
     // find the target dtype and ndim.
     DataType dtype = lhs->dtype == rhs->dtype ? lhs->dtype : DataType::Void();
     int ndim = lhs->ndim == rhs->ndim ? lhs->ndim : kUnknownNDim;
+    VDevice vdev = VDevice();  // stays undefined if both sides carry different vdevices
+    if (lhs->vdevice.defined() && rhs->vdevice.defined()) {
+      if (lhs->vdevice.value().same_as(rhs->vdevice.value())) {
+        vdev = lhs->vdevice.value();
+      }
+    } else if (lhs->vdevice.defined()) {  // NOTE(review): one-sided vdevice is kept
+      vdev = lhs->vdevice.value();
+    } else if (rhs->vdevice.defined()) {
+      vdev = rhs->vdevice.value();
+    }
     // if ndim mismatch or one side of shape is missing
     // then we cannot keep in symbolic shape
     if (lhs->ndim != rhs->ndim || !lhs->shape.defined() || 
!rhs->shape.defined() ||
@@ -775,12 +790,12 @@ class StructInfoLCAFinder
       if (!lhs->shape.defined() && lhs->dtype == dtype && lhs->ndim == ndim) {
         return GetRef<StructInfo>(lhs);
       } else {
-        return TensorStructInfo(dtype, ndim, lhs->span);
+        return TensorStructInfo(dtype, ndim, vdev, lhs->span);
       }
     }
     // symbolic shape match but dtype mismatch
     if (lhs->dtype != dtype) {
-      return TensorStructInfo(lhs->shape.value(), dtype, lhs->span);
+      return TensorStructInfo(lhs->shape.value(), dtype, vdev, lhs->span);
     } else {
       return GetRef<StructInfo>(lhs);
     }
diff --git a/src/relax/backend/vm/vm_builtin_lower.cc 
b/src/relax/backend/vm/vm_builtin_lower.cc
index 6087c2bb25..784b3c9fd5 100644
--- a/src/relax/backend/vm/vm_builtin_lower.cc
+++ b/src/relax/backend/vm/vm_builtin_lower.cc
@@ -21,6 +21,7 @@
  * \brief Lowers most builtin functions and packed calls.
  */
 #include <tvm/relax/analysis.h>
+#include <tvm/relax/attrs/op.h>
 #include <tvm/relax/backend.h>
 #include <tvm/relax/expr_functor.h>
 #include <tvm/relax/type.h>
@@ -46,6 +47,8 @@ class VMBuiltinLowerMutator : public ExprMutator {
       return Reshape(call);
     } else if (call->op == shape_of_op_) {
       return ShapeOf(call);
+    } else if (call->op == to_vdevice_op_) {
+      return ToDevice(call);
     } else if (call->op == make_closure_op_) {
       return MakeClosure(call);
     } else if (call->op == invoke_closure_op_) {
@@ -156,6 +159,22 @@ class VMBuiltinLowerMutator : public ExprMutator {
     return Call(builtin_shape_of_, call_node->args, Attrs(), 
{GetStructInfo(call_node)});
   }
 
+  Expr ToDevice(const Call& call_node) {
+    // TODO(yongwww): replace ToVDeviceAttrs with related Expr
+    ICHECK(call_node->args.size() == 1);
+    ICHECK(call_node->struct_info_.defined());
+    const auto* attrs = call_node->attrs.as<ToVDeviceAttrs>();
+    ICHECK(attrs != nullptr) << "relax.to_vdevice expects ToVDeviceAttrs";
+    // Lower to vm.builtin.to_device(data, DLDeviceType, device_id) of the destination VDevice.
+    VDevice vdev = attrs->dst_vdevice;
+    ICHECK(vdev.defined()) << "relax.to_vdevice requires a defined dst_vdevice";
+    int dev_type = vdev->target->GetTargetDeviceType();
+    int dev_id = vdev->vdevice_id;
+    Array<Expr> args{call_node->args[0], PrimValue::Int64(dev_type),
+                     PrimValue::Int64(dev_id)};
+    return Call(builtin_to_device_, args, call_node->attrs, {GetStructInfo(call_node)});
+  }
+
   Expr MakeClosure(const Call& call_node) {
     ICHECK(call_node->args.size() == 2);
     ICHECK(call_node->args[0]->IsInstance<GlobalVarNode>());
@@ -198,6 +217,7 @@ class VMBuiltinLowerMutator : public ExprMutator {
   const Op& call_tir_dyn_op_ = Op::Get("relax.vm.call_tir_dyn");
   const Op& reshape_op_ = Op::Get("relax.reshape");
   const Op& shape_of_op_ = Op::Get("relax.shape_of");
+  const Op& to_vdevice_op_ = Op::Get("relax.to_vdevice");
   const Op& make_closure_op_ = Op::Get("relax.make_closure");
   const Op& invoke_closure_op_ = Op::Get("relax.invoke_closure");
   const Op& alloc_tensor_op_ = Op::Get("relax.builtin.alloc_tensor");
@@ -214,6 +234,7 @@ class VMBuiltinLowerMutator : public ExprMutator {
   const ExternFunc builtin_call_tir_dyn_{"vm.builtin.call_tir_dyn"};
   const ExternFunc builtin_reshape_{"vm.builtin.reshape"};
   const ExternFunc builtin_shape_of_{"vm.builtin.shape_of"};
+  const ExternFunc builtin_to_device_{"vm.builtin.to_device"};
   const ExternFunc builtin_make_closure_{"vm.builtin.make_closure"};
   const ExternFunc builtin_invoke_closure_{"vm.builtin.invoke_closure"};
 };
diff --git a/src/relax/ir/expr.cc b/src/relax/ir/expr.cc
index ccff18cd40..ac04096aaf 100644
--- a/src/relax/ir/expr.cc
+++ b/src/relax/ir/expr.cc
@@ -292,7 +292,7 @@ Constant::Constant(runtime::NDArray data, 
Optional<StructInfo> struct_info_annot
     n->struct_info_ = struct_info_annotation.value();
     n->checked_type_ = GetStaticType(struct_info_annotation.value());
   } else {
-    TensorStructInfo tinfo(ShapeExpr(values), n->data.DataType(), span);
+    TensorStructInfo tinfo(ShapeExpr(values), n->data.DataType(), VDevice(), 
span);
     n->struct_info_ = tinfo;
     n->checked_type_ = DynTensorType(tinfo->ndim, tinfo->dtype);
   }
diff --git a/src/relax/ir/struct_info.cc b/src/relax/ir/struct_info.cc
index c290711dcd..31784af000 100644
--- a/src/relax/ir/struct_info.cc
+++ b/src/relax/ir/struct_info.cc
@@ -92,7 +92,7 @@ TVM_REGISTER_GLOBAL("relax.ShapeStructInfo")
     });
 
 // Tensor
-TensorStructInfo::TensorStructInfo(Expr shape, DataType dtype, Span span) {
+TensorStructInfo::TensorStructInfo(Expr shape, DataType dtype, VDevice 
vdevice, Span span) {
   ObjectPtr<TensorStructInfoNode> n = make_object<TensorStructInfoNode>();
   // assign ndim before move
   Optional<ShapeStructInfo> sinfo = MatchStructInfo<ShapeStructInfo>(shape);
@@ -104,15 +104,17 @@ TensorStructInfo::TensorStructInfo(Expr shape, DataType 
dtype, Span span) {
   // assign rest of the fields.
   n->shape = std::move(shape);
   n->dtype = dtype;
+  n->vdevice = vdevice;
   n->span = span;
   data_ = std::move(n);
 }
 
-TensorStructInfo::TensorStructInfo(DataType dtype, int ndim, Span span) {
+TensorStructInfo::TensorStructInfo(DataType dtype, int ndim, VDevice vdevice, 
Span span) {
   ObjectPtr<TensorStructInfoNode> n = make_object<TensorStructInfoNode>();
   CHECK_GE(ndim, -1) << "ndim of TensorStructInfo must be >= -1, but got " << 
ndim;
   n->ndim = ndim;
   n->dtype = dtype;
+  n->vdevice = vdevice;
   n->span = span;
   data_ = std::move(n);
 }
@@ -120,12 +122,12 @@ TensorStructInfo::TensorStructInfo(DataType dtype, int 
ndim, Span span) {
 TVM_REGISTER_NODE_TYPE(TensorStructInfoNode);
 
 TVM_REGISTER_GLOBAL("relax.TensorStructInfo")
-    .set_body_typed([](Optional<Expr> shape, DataType dtype, int ndim, Span 
span) {
+    .set_body_typed([](Optional<Expr> shape, DataType dtype, int ndim, VDevice 
vdevice, Span span) {
       if (shape.defined()) {
         CHECK_EQ(ndim, kUnknownNDim) << "ValueError: Cannot both specify shape 
and ndim";
-        return TensorStructInfo(shape.value(), dtype, span);
+        return TensorStructInfo(shape.value(), dtype, vdevice, span);
       } else {
-        return TensorStructInfo(dtype, ndim, span);
+        return TensorStructInfo(dtype, ndim, vdevice, span);
       }
     });
 
diff --git a/src/relax/ir/struct_info_functor.cc 
b/src/relax/ir/struct_info_functor.cc
index 72ea623e07..c998d8c0b2 100644
--- a/src/relax/ir/struct_info_functor.cc
+++ b/src/relax/ir/struct_info_functor.cc
@@ -94,10 +94,15 @@ StructInfo StructInfoMutator::VisitStructInfo_(const 
TensorStructInfoNode* op) {
     shape = this->VisitStructInfoExprField(op->shape.value());
   }
 
+  VDevice vdev = VDevice();
+  if (op->vdevice.defined()) {
+    vdev = op->vdevice.value();
+  }
+
   if (shape.same_as(op->shape)) {
     return GetRef<StructInfo>(op);
   } else {
-    return TensorStructInfo(shape.value(), op->dtype, op->span);
+    return TensorStructInfo(shape.value(), op->dtype, vdev, op->span);
   }
 }
 
diff --git a/src/relax/op/op.cc b/src/relax/op/op.cc
index af93b43dcf..1f4f7d5c34 100644
--- a/src/relax/op/op.cc
+++ b/src/relax/op/op.cc
@@ -866,5 +866,33 @@ Expr MakeStopLiftParams(Expr x) {
 
 
TVM_REGISTER_GLOBAL("relax.op.builtin.stop_lift_params").set_body_typed(MakeStopLiftParams);
 
+// to_vdevice
+TVM_REGISTER_NODE_TYPE(ToVDeviceAttrs);
+
+StructInfo InferToVDeviceStructInfo(const Call& call, const BlockBuilder& ctx) 
{  // Output sinfo of relax.to_vdevice: the struct info of its single tensor input.
+  ICHECK(call->args.size() == 1);
+  ICHECK(call->args[0]->struct_info_.defined());
+  TensorStructInfo data_sinfo = GetUnaryInputTensorStructInfo(call, ctx);
+  return data_sinfo;  // NOTE(review): attrs->dst_vdevice is not applied here; confirm intent.
+}
+
+RELAY_REGISTER_OP("relax.to_vdevice")
+    .set_num_inputs(1)
+    .set_attrs_type<ToVDeviceAttrs>()
+    .add_argument("data", "Expr", "The input expression to be copied")
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferToVDeviceStructInfo)
+    .set_attr<Bool>("FPurity", Bool(true));
+
+Expr MakeToVDevice(Expr data, VDevice dst_vdevice) {
+  static const Op& op = Op::Get("relax.to_vdevice");
+  // TODO(@yongwww): replace Attr with TensorStructInfo
+  ObjectPtr<ToVDeviceAttrs> attrs = make_object<ToVDeviceAttrs>();
+  attrs->dst_vdevice = dst_vdevice;
+
+  return Call(op, {data}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.to_vdevice").set_body_typed(MakeToVDevice);
+
 }  // namespace relax
 }  // namespace tvm
diff --git a/src/relax/transform/convert_layout.cc 
b/src/relax/transform/convert_layout.cc
index 91dcd5d8e8..dd09dd67b8 100644
--- a/src/relax/transform/convert_layout.cc
+++ b/src/relax/transform/convert_layout.cc
@@ -267,7 +267,11 @@ class LayoutConvertMutator : public ExprMutator {
         new_shape.push_back(
             
shape->values[from.LeafValue()->layout.IndexOf(to.LeafValue()->layout[i])]);
       }
-      return TensorStructInfo(ShapeExpr(new_shape), tsinfo->dtype, 
tsinfo->span);
+      VDevice vdev = VDevice();
+      if (tsinfo->vdevice.defined()) {
+        vdev = tsinfo->vdevice.value();
+      }
+      return TensorStructInfo(ShapeExpr(new_shape), tsinfo->dtype, vdev, 
tsinfo->span);
     };
     StructInfo new_struct_info = TransformTupleLeaf<LayoutDecision>(
         binding->struct_info, std::array<NLayout, 2>({from_layout, 
input_layout}), fvisitleaf);
diff --git a/src/relax/transform/to_mixed_precision.cc 
b/src/relax/transform/to_mixed_precision.cc
index 64763276d0..d12d1080b9 100644
--- a/src/relax/transform/to_mixed_precision.cc
+++ b/src/relax/transform/to_mixed_precision.cc
@@ -289,7 +289,11 @@ class ToMixedPrecisionRewriter : public ExprMutator {
       if (fp16_input_names_.count(var->name_hint())) {
         auto sinfo = GetStructInfo(var);
         if (auto tensor_sinfo = sinfo.as<TensorStructInfoNode>()) {
-          TensorStructInfo fp16_sinfo(tensor_sinfo->shape.value(), 
DataType::Float(16),
+          VDevice vdev = VDevice();
+          if (tensor_sinfo->vdevice.defined()) {
+            vdev = tensor_sinfo->vdevice.value();
+          }
+          TensorStructInfo fp16_sinfo(tensor_sinfo->shape.value(), 
DataType::Float(16), vdev,
                                       tensor_sinfo->span);
           Var fp16_var(var->vid, fp16_sinfo, var->span);
           var_remap_[var->vid] = fp16_var;
diff --git a/src/relay/backend/contrib/cmsisnn/target.cc 
b/src/relay/backend/contrib/cmsisnn/target.cc
index f14c106703..527fba98c0 100644
--- a/src/relay/backend/contrib/cmsisnn/target.cc
+++ b/src/relay/backend/contrib/cmsisnn/target.cc
@@ -37,8 +37,8 @@ TVM_REGISTER_TARGET_KIND("cmsis-nn", kDLCPU)
     .add_attr_option<Array<String>>("mattr")
     .add_attr_option<String>("mcpu")
     .add_attr_option<Bool>("debug_last_error")
-    .set_attr<FTVMRelayToTIR>(tvm::attr::kRelayToTIR, RelayToTIR())
-    .set_attr<FTVMTIRToRuntime>("TIRToRuntime", TIRToRuntime)
+    .set_attr<relay::transform::FTVMRelayToTIR>(tvm::attr::kRelayToTIR, 
RelayToTIR())
+    .set_attr<relay::transform::FTVMTIRToRuntime>("TIRToRuntime", TIRToRuntime)
     .set_target_parser(tvm::target::parsers::cpu::ParseTarget);
 
 }  // namespace cmsisnn
diff --git a/src/relay/backend/contrib/codegen_c/target.cc 
b/src/relay/backend/contrib/codegen_c/target.cc
index 623057ac17..cd1e0283df 100644
--- a/src/relay/backend/contrib/codegen_c/target.cc
+++ b/src/relay/backend/contrib/codegen_c/target.cc
@@ -34,7 +34,7 @@ namespace contrib {
  */
 TVM_REGISTER_TARGET_KIND("ccompiler", kDLCPU)
     .set_attr<Bool>(tvm::attr::kIsExternalCodegen, Bool(true))
-    .set_attr<FTVMRelayToTIR>(tvm::attr::kRelayToTIR, CCompilerPass())
+    .set_attr<relay::transform::FTVMRelayToTIR>(tvm::attr::kRelayToTIR, 
CCompilerPass())
     // Value is prepended to every output CModule.
     .add_attr_option<String>("header", String(""));
 
diff --git a/src/relay/backend/contrib/cutlass/target.cc 
b/src/relay/backend/contrib/cutlass/target.cc
index 7b377f340a..50c8b84a90 100644
--- a/src/relay/backend/contrib/cutlass/target.cc
+++ b/src/relay/backend/contrib/cutlass/target.cc
@@ -40,7 +40,7 @@ namespace cutlass {
  */
 TVM_REGISTER_TARGET_KIND("cutlass", kDLCUDA)
     .set_attr<Bool>(tvm::attr::kIsExternalCodegen, Bool(true))
-    .set_attr<FTVMRelayToTIR>("RelayToTIR", CompileForCutlass())
+    .set_attr<tvm::transform::Pass>("RelayToTIR", CompileForCutlass())
     // An integer specifying the compute capability. For example, 75 for 
Turing and
     // 80 or 86 for Ampere.
     .add_attr_option<Integer>("sm", Integer(80))
diff --git a/src/relay/backend/contrib/ethosu/codegen.cc 
b/src/relay/backend/contrib/ethosu/codegen.cc
index f35d4c6d48..2e635455e9 100644
--- a/src/relay/backend/contrib/ethosu/codegen.cc
+++ b/src/relay/backend/contrib/ethosu/codegen.cc
@@ -320,8 +320,8 @@ runtime::Module TIRToRuntime(IRModule mod, Target target) {
 
 TVM_REGISTER_TARGET_KIND("ethos-u", kDLCPU)
     .set_attr<Bool>("use_device_api", Bool(true))
-    .set_attr<FTVMRelayToTIR>(tvm::attr::kRelayToTIR, RelayToTIR())
-    .set_attr<FTVMTIRToRuntime>("TIRToRuntime", TIRToRuntime);
+    .set_attr<relay::transform::FTVMRelayToTIR>(tvm::attr::kRelayToTIR, 
RelayToTIR())
+    .set_attr<relay::transform::FTVMTIRToRuntime>("TIRToRuntime", 
TIRToRuntime);
 
 }  // namespace ethosu
 }  // namespace contrib
diff --git a/src/relay/backend/contrib/example_target_hooks/target.cc 
b/src/relay/backend/contrib/example_target_hooks/target.cc
index b01c23ed80..275efaa933 100644
--- a/src/relay/backend/contrib/example_target_hooks/target.cc
+++ b/src/relay/backend/contrib/example_target_hooks/target.cc
@@ -33,8 +33,10 @@ runtime::Module TIRToRuntime(IRModule mod, Target target);
 
 TVM_REGISTER_TARGET_KIND("example_target_hook", kDLCPU)
     .set_attr<Bool>("use_device_api", Bool(true))
-    .set_attr<FTVMRelayToTIR>(attr::kRelayToTIR, 
relay::contrib::example_target_hooks::RelayToTIR())
-    .set_attr<FTVMTIRToRuntime>("TIRToRuntime", 
relay::contrib::example_target_hooks::TIRToRuntime)
+    .set_attr<relay::transform::FTVMRelayToTIR>(attr::kRelayToTIR,
+                                                
relay::contrib::example_target_hooks::RelayToTIR())
+    .set_attr<relay::transform::FTVMTIRToRuntime>(
+        "TIRToRuntime", relay::contrib::example_target_hooks::TIRToRuntime)
     .add_attr_option<Integer>("example_attribute", Integer(0));
 
 }  // namespace tvm
diff --git a/src/relay/backend/contrib/tensorrt/target.cc 
b/src/relay/backend/contrib/tensorrt/target.cc
index 2e4581d30a..0277787a8c 100644
--- a/src/relay/backend/contrib/tensorrt/target.cc
+++ b/src/relay/backend/contrib/tensorrt/target.cc
@@ -39,7 +39,7 @@ namespace tensorrt {
  */
 TVM_REGISTER_TARGET_KIND("tensorrt", kDLCUDA)
     .set_attr<Bool>(tvm::attr::kIsExternalCodegen, Bool(true))
-    .set_attr<FTVMRelayToTIR>("RelayToTIR", CompileForTensorRT())
+    .set_attr<tvm::transform::Pass>("RelayToTIR", CompileForTensorRT())
     // A array of three integers given the major, minor, and patch numbers for 
the supported
     // TensorRT compiler version. If empty will be auto-detected from linked 
library. Default empty.
     .add_attr_option<Array<Integer>>("tensorrt_version", Array<Integer>())
diff --git a/src/relay/backend/contrib/uma/targets.cc 
b/src/relay/backend/contrib/uma/targets.cc
index e2fe644cb9..d01f5b4c73 100644
--- a/src/relay/backend/contrib/uma/targets.cc
+++ b/src/relay/backend/contrib/uma/targets.cc
@@ -46,20 +46,20 @@ 
TVM_REGISTER_GLOBAL("relay.backend.contrib.uma.RegisterTarget")
         }
       }
 
-      auto target_kind =
-          TargetKindRegEntry::RegisterOrGet(target_name)
-              .set_name()
-              .set_default_device_type(kDLCPU)
-              .add_attr_option<Array<String>>("keys")
-              .add_attr_option<String>("tag")
-              .add_attr_option<String>("device")
-              .add_attr_option<String>("model")
-              .add_attr_option<Array<String>>("libs")
-              .add_attr_option<Target>("host")
-              .add_attr_option<Integer>("from_device")
-              .set_attr<FTVMRelayToTIR>(attr::kRelayToTIR,
-                                        
relay::contrib::uma::RelayToTIR(target_name))
-              .set_attr<FTVMTIRToRuntime>("TIRToRuntime", 
relay::contrib::uma::TIRToRuntime);
+      auto target_kind = TargetKindRegEntry::RegisterOrGet(target_name)
+                             .set_name()
+                             .set_default_device_type(kDLCPU)
+                             .add_attr_option<Array<String>>("keys")
+                             .add_attr_option<String>("tag")
+                             .add_attr_option<String>("device")
+                             .add_attr_option<String>("model")
+                             .add_attr_option<Array<String>>("libs")
+                             .add_attr_option<Target>("host")
+                             .add_attr_option<Integer>("from_device")
+                             .set_attr<relay::transform::FTVMRelayToTIR>(
+                                 attr::kRelayToTIR, 
relay::contrib::uma::RelayToTIR(target_name))
+                             .set_attr<relay::transform::FTVMTIRToRuntime>(
+                                 "TIRToRuntime", 
relay::contrib::uma::TIRToRuntime);
 
       // target kind attrs inventory
       auto kind = TargetKind::Get(target_name).value();
diff --git a/src/runtime/relax_vm/builtin.cc b/src/runtime/relax_vm/builtin.cc
index 86f0152ce7..8b27bb2d9e 100644
--- a/src/runtime/relax_vm/builtin.cc
+++ b/src/runtime/relax_vm/builtin.cc
@@ -327,6 +327,14 @@ TVM_REGISTER_GLOBAL("vm.builtin.null_value").set_body([](TVMArgs args, TVMRetVal
   *rv = nullptr;
 });
 
+// Copy `data` to the device (dev_type, dev_id); dev_type is a DLDeviceType code, as produced
+// by the relax.to_vdevice lowering from the VDevice target device type and vdevice_id.
+TVM_REGISTER_GLOBAL("vm.builtin.to_device")
+    .set_body_typed([](NDArray data, int dev_type, int dev_id) {
+      Device dst_device = {static_cast<DLDeviceType>(dev_type), dev_id};
+      return data.CopyTo(dst_device);
+    });
+
 /*!
  * \brief Load the scalar value in cond and return the result value.
  * \param cond The condition
diff --git a/src/script/ir_builder/ir/ir.cc b/src/script/ir_builder/ir/ir.cc
index fb51886a7d..2f2785ca44 100644
--- a/src/script/ir_builder/ir/ir.cc
+++ b/src/script/ir_builder/ir/ir.cc
@@ -110,11 +110,41 @@ void ModuleGlobalInfos(Map<String, Array<GlobalInfo>> 
global_infos) {
   }
 }
 
+// Resolve a "kind:index" vdevice annotation against the module's "vdevice" global infos.
+VDevice LookupVDevice(String target_kind, int device_index) {
+  if (IRBuilder::IsInScope()) {
+    IRModuleFrame frame = FindModuleFrame();
+    if (frame->global_infos.empty()) {
+      LOG(FATAL) << "ValueError: The GlobalInfos in the IRModule is not defined.";
+    }
+    // Check the key first: operator[] on the global_infos Map aborts when the key is absent.
+    Array<GlobalInfo> vdevices;
+    if (frame->global_infos.count("vdevice")) vdevices = frame->global_infos.at("vdevice");
+    if (vdevices.empty() || device_index < 0 ||
+        static_cast<size_t>(device_index) >= vdevices.size()) {
+      LOG(FATAL) << "ValueError: The target VDevice in the GlobalInfos was not found.";
+    }
+    if (target_kind == "vdevice") {
+      return Downcast<VDevice>(vdevices[device_index]);
+    }
+    // Otherwise index only among the vdevices whose target kind matches.
+    int count = 0;
+    for (const auto& info : vdevices) {
+      VDevice vdev = Downcast<VDevice>(info);
+      if (vdev->target->kind->name == target_kind && count++ == device_index) {
+        return vdev;
+      }
+    }
+  }
+  LOG(WARNING) << "The annotated device was not found, please check your vdevice list.";
+  return VDevice();
+}
+
 TVM_REGISTER_GLOBAL("script.ir_builder.ir.IRModule").set_body_typed(IRModule);
 
TVM_REGISTER_GLOBAL("script.ir_builder.ir.DeclFunction").set_body_typed(DeclFunction);
 
TVM_REGISTER_GLOBAL("script.ir_builder.ir.DefFunction").set_body_typed(DefFunction);
 
TVM_REGISTER_GLOBAL("script.ir_builder.ir.ModuleAttrs").set_body_typed(ModuleAttrs);
 
TVM_REGISTER_GLOBAL("script.ir_builder.ir.ModuleGlobalInfos").set_body_typed(ModuleGlobalInfos);
+TVM_REGISTER_GLOBAL("script.ir_builder.ir.LookupVDevice").set_body_typed(LookupVDevice);
 
 }  // namespace ir
 }  // namespace ir_builder
diff --git a/src/script/printer/ir/ir.cc b/src/script/printer/ir/ir.cc
index ecf92897a5..a239481d03 100644
--- a/src/script/printer/ir/ir.cc
+++ b/src/script/printer/ir/ir.cc
@@ -118,6 +118,16 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
       return IR(d, "dummy_global_info")->Call({});
     });
 
+TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
+    .set_dispatch<VDevice>("", [](VDevice vdev, ObjectPath p, IRDocsifier d) 
-> Doc {
+      d->AddGlobalInfo("vdevice", vdev);
+      Map<String, ObjectRef> config = vdev->target->Export();
+      return IR(d, "vdevice")
+          ->Call({d->AsDoc<ExprDoc>(config, p),
+                  LiteralDoc::Int(vdev->vdevice_id, p->Attr("vdevice_id")),
+                  LiteralDoc::Str(vdev->memory_scope, 
p->Attr("memory_scope"))});
+    });
+
 TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
     .set_dispatch<Op>("", [](Op op, ObjectPath p, IRDocsifier d) -> Doc {
       return IR(d, "Op")->Call({LiteralDoc::Str(op->name, p->Attr("name"))});
diff --git a/src/script/printer/ir_docsifier.cc 
b/src/script/printer/ir_docsifier.cc
index 521ab07359..a424863495 100644
--- a/src/script/printer/ir_docsifier.cc
+++ b/src/script/printer/ir_docsifier.cc
@@ -63,6 +63,12 @@ ExprDoc IRDocsifierNode::AddMetadata(const ObjectRef& obj) {
   return IdDoc("metadata")[{LiteralDoc::Str(key, 
NullOpt)}][{LiteralDoc::Int(index, NullOpt)}];
 }
 
+void IRDocsifierNode::AddGlobalInfo(const String& name, const GlobalInfo& 
ginfo) {
+  ICHECK(ginfo.defined()) << "TypeError: Cannot add nullptr to global_infos";
+  Array<GlobalInfo>& array = global_infos[name];
+  array.push_back(ginfo);
+}
+
 bool IRDocsifierNode::IsVarDefined(const ObjectRef& obj) const { return 
obj2info.count(obj); }
 
 void IRDocsifierNode::RemoveVar(const ObjectRef& obj) {
diff --git a/src/script/printer/relax/struct_info.cc 
b/src/script/printer/relax/struct_info.cc
index 49162bb824..7fab5b59a2 100644
--- a/src/script/printer/relax/struct_info.cc
+++ b/src/script/printer/relax/struct_info.cc
@@ -111,6 +111,13 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
             kwargs_keys.push_back("ndim");
             kwargs_values.push_back(LiteralDoc::Int(n->ndim, 
n_p->Attr("ndim")));
           }
+          if (n->vdevice.defined()) {
+            kwargs_keys.push_back("vdevice");
+            std::string dev_kind = n->vdevice.value()->target->kind->name;
+            int dev_index = FindVDeviceIndexByTargetKind(n->vdevice.value(), 
d);
+            kwargs_values.push_back(
+                LiteralDoc::Str(dev_kind + ":" + std::to_string(dev_index), 
n_p->Attr("vdevice")));
+          }
           if (args.empty() && kwargs_keys.empty()) {
             return Relax(d, "Tensor");
           }
diff --git a/src/script/printer/relax/utils.h b/src/script/printer/relax/utils.h
index 88fc7491c2..e0b5348d73 100644
--- a/src/script/printer/relax/utils.h
+++ b/src/script/printer/relax/utils.h
@@ -97,6 +97,29 @@ Array<StmtDoc> PrintSeqExpr(const relax::SeqExpr& n, const ObjectPath& n_p, cons
 
 ExprDoc PrintShapeVar(const PrimExpr& e, const ObjectPath& e_p, const IRDocsifier& d);
 
+/*!
+ * \brief Index of `vdevice` among the docsifier's "vdevice" global infos sharing its target kind.
+ * \return The per-kind index used in "kind:index" annotations, or -1 if not registered.
+ */
+inline int FindVDeviceIndexByTargetKind(const VDevice& vdevice, const IRDocsifier& d) {
+  // find() rather than operator[]: a lookup must not insert an empty "vdevice" entry.
+  auto it = d->global_infos.find("vdevice");
+  if (it != d->global_infos.end()) {
+    int kind_index = 0;
+    for (const auto& info : it->second) {
+      auto vdev = Downcast<VDevice>(info);
+      if (vdev.same_as(vdevice)) {
+        return kind_index;
+      }
+      if (vdev->target->kind->name == vdevice->target->kind->name) {
+        kind_index++;
+      }
+    }
+  }
+  LOG(WARNING) << "The VDevice was not found in the global_infos map: " << vdevice;
+  return -1;
+}
+
 }  // namespace printer
 }  // namespace script
 }  // namespace tvm
diff --git a/src/target/codegen.cc b/src/target/codegen.cc
index bbb2c15a64..55af8889e1 100644
--- a/src/target/codegen.cc
+++ b/src/target/codegen.cc
@@ -38,6 +38,9 @@
 #include <vector>
 
 namespace tvm {
+
+using FTVMTIRToRuntime = runtime::TypedPackedFunc<runtime::Module(IRModule, 
Target)>;
+
 namespace codegen {
 
 runtime::Module Build(IRModule mod, Target target) {
diff --git a/src/target/target.cc b/src/target/target.cc
index 2f585188d0..cd2e3714e4 100644
--- a/src/target/target.cc
+++ b/src/target/target.cc
@@ -21,6 +21,7 @@
  * \file src/target/target.cc
  */
 #include <dmlc/thread_local.h>
+#include <tvm/ir/transform.h>
 #include <tvm/runtime/device_api.h>
 #include <tvm/runtime/logging.h>
 #include <tvm/runtime/registry.h>
@@ -91,16 +92,6 @@ void CheckAndUpdateHostConsistency(Target* target, Target* 
host) {
   *host = (*target)->GetHost().value_or(Target());
 }
 
-void CheckAndUpdateHostConsistency(Map<Target, IRModule>* targets, Target* 
host) {
-  Map<Target, IRModule> new_targets;
-  for (auto& it : *targets) {
-    auto target = it.first;
-    CheckAndUpdateHostConsistency(&target, host);
-    new_targets.Set(target, it.second);
-  }
-  *targets = new_targets;
-}
-
 static std::vector<String> DeduplicateKeys(const std::vector<String>& keys) {
   std::vector<String> new_keys;
   for (size_t i = 0; i < keys.size(); ++i) {
@@ -614,8 +605,8 @@ Target::Target(TargetKind kind, Optional<ObjectRef> host, 
String tag, Array<Stri
 bool Target::IsExternalCodegen() const {
   TargetKindAttrMap<Bool> is_external_codegen_map =
       TargetKind::GetAttrMap<Bool>(tvm::attr::kIsExternalCodegen);
-  TargetKindAttrMap<FTVMRelayToTIR> relay_to_tir_map =
-      TargetKind::GetAttrMap<FTVMRelayToTIR>(tvm::attr::kRelayToTIR);
+  TargetKindAttrMap<tvm::transform::Pass> relay_to_tir_map =
+      TargetKind::GetAttrMap<tvm::transform::Pass>(tvm::attr::kRelayToTIR);
   return is_external_codegen_map.get(get()->kind, Bool(false)) ||
          relay_to_tir_map.count(get()->kind);
 }
diff --git a/tests/cpp/target_test.cc b/tests/cpp/target_test.cc
index 37a8eeb448..9d05023e79 100644
--- a/tests/cpp/target_test.cc
+++ b/tests/cpp/target_test.cc
@@ -458,7 +458,7 @@ TVM_REGISTER_TARGET_KIND("test_external_codegen_2", 
kDLMetal)
     .set_attr<Bool>(tvm::attr::kIsExternalCodegen, Bool(true));
 
 TVM_REGISTER_TARGET_KIND("test_external_codegen_3", kDLCPU)
-    .set_attr<FTVMRelayToTIR>(tvm::attr::kRelayToTIR, 
tvm::relay::transform::InferType());
+    .set_attr<tvm::transform::Pass>(tvm::attr::kRelayToTIR, 
tvm::relay::transform::InferType());
 
 TEST(Target, ExternalCodegen) {
   Target regular("cuda");
diff --git a/python/tvm/ir/global_info.py 
b/tests/python/relax/test_json_compact.py
similarity index 53%
copy from python/tvm/ir/global_info.py
copy to tests/python/relax/test_json_compact.py
index 17011e76a6..1320ff1cd6 100644
--- a/python/tvm/ir/global_info.py
+++ b/tests/python/relax/test_json_compact.py
@@ -14,29 +14,37 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""Global Info."""
-import tvm
-from tvm.runtime.object import Object
-from . import _ffi_api
-
 
-class GlobalInfo(Object):
-    """Base node for all global info that can appear in the IR"""
-
-    def __eq__(self, other):
-        """Compare two struct info for structural equivalence."""
-        return tvm.ir.structural_equal(self, other)
+import tvm
+import tvm.testing
+from tvm import relax
+import json
 
-    def __ne__(self, other):
-        return not self.__eq__(other)
 
-    def same_as(self, other):
-        """Overload with structural equality."""
-        return super().__eq__(other)
+# 0.13 BACKWARDS COMPATIBILITY TESTS
+def test_vdevice():
+    nodes = [
+        {"type_key": ""},
+        {
+            "type_key": "relax.TensorStructInfo",
+            "attrs": {
+                "dtype": "float32",
+                "ndim": "-1",
+                "shape": "0",
+                "span": "0",
+            },
+        },
+    ]
+    data = {
+        "root": 1,
+        "nodes": nodes,
+        "attrs": {"tvm_version": "0.13.0"},
+        "b64ndarrays": [],
+    }
+    tsinfo = tvm.ir.load_json(json.dumps(data))
+    assert isinstance(tsinfo, relax.TensorStructInfo)
+    assert not tsinfo.vdevice
 
 
-class DummyGlobalInfo(GlobalInfo):
-    def __init__(self) -> None:
-        self.__init_handle_by_constructor__(
-            _ffi_api.DummyGlobalInfo,
-        )
+if __name__ == "__main__":
+    tvm.testing.main()
diff --git a/tests/python/relax/test_relax_operators.py 
b/tests/python/relax/test_relax_operators.py
index 90608df4b6..e0904477d4 100644
--- a/tests/python/relax/test_relax_operators.py
+++ b/tests/python/relax/test_relax_operators.py
@@ -247,5 +247,35 @@ def test_op_call_pure_packed():
     assert (copy_found.numpy() == arr).all()
 
 
+def test_op_to_device():
+    @tvm.script.ir_module
+    class CallToDevice:
+        @R.function
+        def to_dev(x: R.Tensor((3, 4), "float32")):
+            z = R.call_pure_packed(
+                "vm.builtin.to_device",
+                x,
+                1,
+                0,
+                sinfo_args=(R.Tensor((3, 4), dtype="float32")),
+            )
+            return z
+
+    np.random.seed(0)  # to avoid flakiness
+    arr = np.random.rand(3, 4).astype("float32")
+    copy_found = run_cpu(CallToDevice, "to_dev", tvm.nd.array(arr))
+    assert (copy_found.numpy() == arr).all()
+
+
+def test_op_to_vdevice():
+    @tvm.script.ir_module
+    class ToVDevice:
+        @R.function
+        def to_vdev(x: R.Tensor((3, 4), "float32")):
+            dst_vdev = tvm.ir.VDevice("llvm", 0, "global")
+            ret = R.to_vdevice(x, dst_vdev)
+            return ret
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/relax/test_tvmscript_parser.py 
b/tests/python/relax/test_tvmscript_parser.py
index bc324fe364..39a4d33ca6 100644
--- a/tests/python/relax/test_tvmscript_parser.py
+++ b/tests/python/relax/test_tvmscript_parser.py
@@ -22,7 +22,7 @@ import tvm
 import tvm.script
 import tvm.testing
 from tvm import IRModule, relax, tir, topi
-from tvm.ir import DummyGlobalInfo
+from tvm.ir import VDevice, DummyGlobalInfo
 from tvm.script.parser import ir as I
 from tvm.script.parser import relax as R
 from tvm.script.parser import tir as T
@@ -303,6 +303,56 @@ def test_module_with_attr_and_global_info():
     _check(TestModule, mod)
 
 
+def test_global_info_vdevice():
+    vdevices = [
+        VDevice("llvm"),
+        VDevice("cuda", 0),
+        VDevice("cuda -arch=sm_80", 0),
+        VDevice("metal", 0, "global"),
+    ]
+
+    @I.ir_module
+    class TestModule:
+        I.module_attrs({"attr": 10})
+        I.module_global_infos(
+            {
+                "vdevice": [
+                    I.vdevice("llvm"),
+                    I.vdevice("cuda", 0),
+                    I.vdevice("cuda -arch=sm_80", 0),
+                    I.vdevice("metal", 0, "global"),
+                ]
+            }
+        )
+
+        @T.prim_func(private=True)
+        def tir_func(
+            x: T.Buffer((T.int64(128), T.int64(128)), "float32"),
+            y: T.Buffer((T.int64(128), T.int64(128)), "float32"),
+        ):
+            T.func_attr({"tir.noalias": True})
+            for i, j in T.grid(T.int64(128), T.int64(128)):
+                with T.block():
+                    vi, vj = T.axis.remap("SS", [i, j])
+                    y[vi, vj] = x[vi, vj] + 1.0
+
+        @R.function
+        def foo(x: R.Tensor((128, 128), "float32")) -> R.Tensor((128, 128), 
"float32"):
+            cls = TestModule
+            gv0 = R.call_tir(cls.tir_func, x, R.Tensor((128, 128), 
dtype="float32"))
+            return gv0
+
+    x = relax.Var("x", R.Tensor((128, 128), "float32"))
+    bb = relax.BlockBuilder()
+    with bb.function("foo", (x,)):
+        out = bb.emit_te(lambda x: x + 1, x, primfunc_name_hint="tir_func")
+        bb.emit_func_output(out)
+    mod = bb.get()
+    mod.update_global_info("vdevice", vdevices)
+    mod = mod.with_attr("attr", tvm.tir.IntImm("int32", 10))
+    _check(TestModule, mod)
+
+
 def test_relax_tensor_op():
     @R.function
     def foo(x: R.Tensor((4, 4), "float32")) -> R.Tensor((4, 4), "float32"):
@@ -714,6 +764,51 @@ def test_tensor_type_without_args():
     _check(foo, bb.get()["foo"])
 
 
+def test_tensor_with_vdevice():
+    vdevices = [
+        VDevice("llvm"),
+        VDevice("cuda", 0),
+        VDevice("metal", 0, "global"),
+        VDevice("cuda -arch=sm_80", 0),
+    ]
+
+    @I.ir_module
+    class TestModule:
+        I.module_attrs({"attr": 10})
+        I.module_global_infos(
+            {
+                "vdevice": [
+                    I.vdevice("llvm"),
+                    I.vdevice("cuda", 0),
+                    I.vdevice("metal", 0, "global"),
+                    I.vdevice("cuda -arch=sm_80", 0),
+                ]
+            }
+        )
+
+        @R.function
+        def foo(
+            a: R.Tensor((128, 128), "float32", "cuda:1"),  # noqa: F722
+            b: R.Tensor((128, 128), "float32", "llvm"),
+            c: R.Tensor((128, 128), "float32", "vdevice:3"),  # noqa: F722
+        ) -> R.Tensor((128, 128), "float32"):
+            s = R.add(a, c)
+            return s
+
+    a = relax.Var("a", R.Tensor((128, 128), "float32", vdevices[3]))
+    b = relax.Var("b", R.Tensor((128, 128), "float32", vdevices[0]))
+    c = relax.Var("c", R.Tensor((128, 128), "float32", vdevices[3]))
+    bb = relax.BlockBuilder()
+    with bb.function("foo", (a, b, c)):
+        out = bb.emit(relax.op.add(a, c))
+        bb.emit_func_output(out)
+    mod = bb.get()
+    mod = mod.with_attr("attr", tvm.tir.IntImm("int32", 10))
+    mod.update_global_info("vdevice", vdevices)
+
+    _check(TestModule, mod)
+
+
 def test_direct_return():
     @R.function
     def foo(x: R.Tensor((32, 32), "float32")) -> R.Tensor((32, 32), "float32"):
diff --git a/tests/python/relax/test_tvmscript_parser_op_manipulate.py 
b/tests/python/relax/test_tvmscript_parser_op_manipulate.py
index a80e8aad37..25f4f08520 100644
--- a/tests/python/relax/test_tvmscript_parser_op_manipulate.py
+++ b/tests/python/relax/test_tvmscript_parser_op_manipulate.py
@@ -403,5 +403,20 @@ def test_flip():
     _check(foo, bb.get()["foo"])
 
 
+def test_to_vdevice():
+    @R.function
+    def foo(x: R.Tensor((), "int32")) -> R.Tensor((), "int32"):
+        tensor = R.to_vdevice(x, tvm.ir.VDevice("llvm", 0, "global"))
+        return tensor
+
+    x = relax.Var("x", R.Tensor((), "int32"))
+    bb = relax.BlockBuilder()
+    with bb.function("foo", (x,)):
+        tensor = bb.emit(relax.op.to_vdevice(x, tvm.ir.VDevice("llvm", 0, "global")))
+        bb.emit_func_output(tensor)
+
+    _check(foo, bb.get()["foo"])  # parsed R.to_vdevice must match relax.op.to_vdevice
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git a/tests/python/relax/test_vm_build.py b/tests/python/relax/test_vm_build.py
index 0b34d24540..085b6137ac 100644
--- a/tests/python/relax/test_vm_build.py
+++ b/tests/python/relax/test_vm_build.py
@@ -725,6 +725,41 @@ def test_recursion(exec_mode):
     tvm.testing.assert_allclose(res.numpy(), np.power(2.0, recursion_runs), 
rtol=1e-7, atol=1e-7)
 
 
+@tvm.testing.requires_cuda  # foo1 copies to cuda(0) specifically, not just any GPU
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+def test_vm_to_device(exec_mode):
+    @tvm.script.ir_module
+    class TestToVDevice:
+        @R.function
+        def foo1(
+            x: R.Tensor((2, 3), "float32"),
+        ) -> R.Tensor((2, 3), "float32"):
+            copied = R.to_vdevice(x, tvm.ir.VDevice("cuda", 0, "global"))
+            return copied
+
+        @R.function
+        def foo2(
+            x: R.Tensor((2, 3), "float32"),
+        ) -> R.Tensor((2, 3), "float32"):
+            copied = R.to_vdevice(x, tvm.ir.VDevice("llvm", 0, "global"))
+            return copied
+
+    mod = TestToVDevice
+    target = tvm.target.Target("llvm", host="llvm")
+    ex = relax.build(mod, target, exec_mode=exec_mode)
+    vm = relax.VirtualMachine(ex, tvm.cpu())
+    x_inp = tvm.nd.array(np.random.rand(2, 3).astype("float32"))
+    res_1 = check_saved_func(vm, "foo1", x_inp)
+    res_2 = check_saved_func(vm, "foo2", x_inp)
+
+    # foo1's result should live on cuda(0), foo2's on cpu(0)
+    assert str(res_1.device) == "cuda(0)"
+    assert str(res_2.device) == "cpu(0)"
+
+    tvm.testing.assert_allclose(res_1.numpy(), x_inp.numpy())
+    tvm.testing.assert_allclose(res_2.numpy(), x_inp.numpy())
+
+
 @pytest.mark.parametrize("exec_mode", EXEC_MODE)
 def test_vm_closure(exec_mode):
     @tvm.script.ir_module
diff --git a/tests/python/relax/test_vm_codegen_only.py b/tests/python/relax/test_vm_codegen_only.py
index ffa9837d02..d9fb130f3c 100644
--- a/tests/python/relax/test_vm_codegen_only.py
+++ b/tests/python/relax/test_vm_codegen_only.py
@@ -57,6 +57,33 @@ def test_vm_copy(exec_mode):
     tvm.testing.assert_allclose(res.numpy(), inp.numpy(), rtol=1e-7, atol=1e-7)
 
 
+@pytest.mark.parametrize("exec_mode", EXEC_MODE)
+def test_vm_to_device(exec_mode):
+    @tvm.script.ir_module
+    class TestVMToDevice:
+        @R.function
+        def foo(x: R.Tensor((3, 4), "float32")):
+            R.func_attr({"global_symbol": "foo"})
+            # Copy x to the first cpu: device_type=1 and device_id=0.
+            # For other device type codes, see python/tvm/_ffi/runtime_ctypes.py
+            z = R.call_packed(
+                "vm.builtin.to_device", x, 1, 0, sinfo_args=(R.Tensor((3, 4), dtype="float32"))
+            )
+            return z
+
+    mod = TestVMToDevice
+    target = tvm.target.Target("llvm", host="llvm")
+    ex = codegen(mod, target, exec_mode)
+    inp = tvm.nd.array(np.random.rand(3, 4).astype(np.float32))
+    vm = relax.VirtualMachine(ex, tvm.cpu())
+    res = check_saved_func(vm, "foo", inp)
+    tvm.testing.assert_allclose(res.numpy(), inp.numpy(), rtol=1e-7, atol=1e-7)
+    # check the resulting tensor is on cpu:0
+    assert str(res.device) == "cpu(0)"
+    assert res.device.device_type == 1
+    assert res.device.device_id == 0
+
+
 @pytest.mark.parametrize("exec_mode", EXEC_MODE)
 def test_if_cond_const(exec_mode):
     @tvm.script.ir_module

Reply via email to