This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new 0dcd4d2  [CPP] Make Tensor/TensorView API to method-based (#121)
0dcd4d2 is described below

commit 0dcd4d2b9a3ca2c7b3614cdbe7792798c00e7513
Author: Tianqi Chen <[email protected]>
AuthorDate: Tue Oct 14 20:08:56 2025 -0400

    [CPP] Make Tensor/TensorView API to method-based (#121)
    
    This PR makes the Tensor/TensorView API mostly method-based, instead of
    relying on operator->() to get the underlying DLTensor*, which can be
    error-prone.
---
 docs/get_started/quick_start.md                    |  20 +--
 docs/guides/python_guide.md                        |  14 +-
 examples/inline_module/main.py                     |  32 ++---
 examples/packaging/python/my_ffi_extension/base.py |  11 +-
 examples/packaging/src/extension.cc                |  14 +-
 examples/quick_start/src/add_one_cpu.cc            |  14 +-
 examples/quick_start/src/add_one_cuda.cu           |  18 +--
 examples/quick_start/src/run_example.cc            |   4 +-
 examples/quick_start/src/run_example_cuda.cc       |   4 +-
 include/tvm/ffi/container/tensor.h                 | 157 ++++++++++++++++-----
 python/tvm_ffi/cpp/load_inline.py                  |  28 ++--
 rust/tvm-ffi/scripts/generate_example_lib.py       |  14 +-
 src/ffi/extra/structural_equal.cc                  |  17 +--
 src/ffi/extra/structural_hash.cc                   |  14 +-
 tests/cpp/test_tensor.cc                           |  22 +--
 tests/python/test_build_inline.py                  |  14 +-
 tests/python/test_load_inline.py                   | 132 ++++++++---------
 17 files changed, 313 insertions(+), 216 deletions(-)
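
In short, the change replaces raw DLTensor* field access with accessor
methods. A minimal before/after sketch (illustrative only; all names below
are taken from the diffs that follow):

    // Before: dereference the underlying DLTensor* via operator->()
    int64_t n = x->shape[0];
    float* data = static_cast<float*>(x->data);
    DLDevice dev = x->device;

    // After: method-based accessors on Tensor/TensorView
    int64_t n = x.size(0);
    float* data = static_cast<float*>(x.data_ptr());
    DLDevice dev = x.device();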

diff --git a/docs/get_started/quick_start.md b/docs/get_started/quick_start.md
index 2ad8a41..26a7b86 100644
--- a/docs/get_started/quick_start.md
+++ b/docs/get_started/quick_start.md
@@ -105,16 +105,16 @@ namespace ffi = tvm::ffi;
 
 void AddOne(ffi::TensorView x, ffi::TensorView y) {
   // Validate inputs
-  TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
+  TVM_FFI_ICHECK(x.ndim() == 1) << "x must be a 1D tensor";
   DLDataType f32_dtype{kDLFloat, 32, 1};
-  TVM_FFI_ICHECK(x->dtype == f32_dtype) << "x must be a float tensor";
-  TVM_FFI_ICHECK(y->ndim == 1) << "y must be a 1D tensor";
-  TVM_FFI_ICHECK(y->dtype == f32_dtype) << "y must be a float tensor";
-  TVM_FFI_ICHECK(x->shape[0] == y->shape[0]) << "x and y must have the same shape";
+  TVM_FFI_ICHECK(x.dtype() == f32_dtype) << "x must be a float tensor";
+  TVM_FFI_ICHECK(y.ndim() == 1) << "y must be a 1D tensor";
+  TVM_FFI_ICHECK(y.dtype() == f32_dtype) << "y must be a float tensor";
+  TVM_FFI_ICHECK(x.size(0) == y.size(0)) << "x and y must have the same shape";
 
   // Perform the computation
-  for (int i = 0; i < x->shape[0]; ++i) {
-    static_cast<float*>(y->data)[i] = static_cast<float*>(x->data)[i] + 1;
+  for (int i = 0; i < x.size(0); ++i) {
+    static_cast<float*>(y.data_ptr())[i] = static_cast<float*>(x.data_ptr())[i] + 1;
   }
 }
 
@@ -135,17 +135,17 @@ void AddOneCUDA(ffi::TensorView x, ffi::TensorView y) {
   // Validation (same as CPU version)
   // ...
 
-  int64_t n = x->shape[0];
+  int64_t n = x.size(0);
   int64_t nthread_per_block = 256;
   int64_t nblock = (n + nthread_per_block - 1) / nthread_per_block;
 
   // Get current CUDA stream from environment
   cudaStream_t stream = static_cast<cudaStream_t>(
-      TVMFFIEnvGetStream(x->device.device_type, x->device.device_id));
+      TVMFFIEnvGetStream(x.device().device_type, x.device().device_id));
 
   // Launch kernel
   AddOneKernel<<<nblock, nthread_per_block, 0, stream>>>(
-      static_cast<float*>(x->data), static_cast<float*>(y->data), n);
+      static_cast<float*>(x.data_ptr()), static_cast<float*>(y.data_ptr()), n);
 }
 
 TVM_FFI_DLL_EXPORT_TYPED_FUNC(add_one_cuda, tvm_ffi_example::AddOneCUDA);
diff --git a/docs/guides/python_guide.md b/docs/guides/python_guide.md
index 7b977e7..dc34cfe 100644
--- a/docs/guides/python_guide.md
+++ b/docs/guides/python_guide.md
@@ -153,14 +153,14 @@ import tvm_ffi.cpp
 cpp_source = '''
      void add_one_cpu(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
        // implementation of a library function
-       TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
+       TVM_FFI_ICHECK(x.ndim() == 1) << "x must be a 1D tensor";
        DLDataType f32_dtype{kDLFloat, 32, 1};
-       TVM_FFI_ICHECK(x->dtype == f32_dtype) << "x must be a float tensor";
-       TVM_FFI_ICHECK(y->ndim == 1) << "y must be a 1D tensor";
-       TVM_FFI_ICHECK(y->dtype == f32_dtype) << "y must be a float tensor";
-       TVM_FFI_ICHECK(x->shape[0] == y->shape[0]) << "x and y must have the same shape";
-       for (int i = 0; i < x->shape[0]; ++i) {
-         static_cast<float*>(y->data)[i] = static_cast<float*>(x->data)[i] + 1;
+       TVM_FFI_ICHECK(x.dtype() == f32_dtype) << "x must be a float tensor";
+       TVM_FFI_ICHECK(y.ndim() == 1) << "y must be a 1D tensor";
+       TVM_FFI_ICHECK(y.dtype() == f32_dtype) << "y must be a float tensor";
+       TVM_FFI_ICHECK(x.size(0) == y.size(0)) << "x and y must have the same shape";
+       for (int i = 0; i < x.size(0); ++i) {
+         static_cast<float*>(y.data_ptr())[i] = static_cast<float*>(x.data_ptr())[i] + 1;
        }
      }
 '''
diff --git a/examples/inline_module/main.py b/examples/inline_module/main.py
index d0ba6fe..83f36ad 100644
--- a/examples/inline_module/main.py
+++ b/examples/inline_module/main.py
@@ -28,14 +28,14 @@ def main() -> None:
         cpp_sources=r"""
             void add_one_cpu(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
               // implementation of a library function
-              TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
+              TVM_FFI_ICHECK(x.ndim() == 1) << "x must be a 1D tensor";
               DLDataType f32_dtype{kDLFloat, 32, 1};
-              TVM_FFI_ICHECK(x->dtype == f32_dtype) << "x must be a float tensor";
-              TVM_FFI_ICHECK(y->ndim == 1) << "y must be a 1D tensor";
-              TVM_FFI_ICHECK(y->dtype == f32_dtype) << "y must be a float tensor";
-              TVM_FFI_ICHECK(x->shape[0] == y->shape[0]) << "x and y must have the same shape";
-              for (int i = 0; i < x->shape[0]; ++i) {
-                static_cast<float*>(y->data)[i] = static_cast<float*>(x->data)[i] + 1;
+              TVM_FFI_ICHECK(x.dtype() == f32_dtype) << "x must be a float tensor";
+              TVM_FFI_ICHECK(y.ndim() == 1) << "y must be a 1D tensor";
+              TVM_FFI_ICHECK(y.dtype() == f32_dtype) << "y must be a float tensor";
+              TVM_FFI_ICHECK(x.size(0) == y.size(0)) << "x and y must have the same shape";
+              for (int i = 0; i < x.size(0); ++i) {
+                static_cast<float*>(y.data_ptr())[i] = static_cast<float*>(x.data_ptr())[i] + 1;
               }
             }
 
@@ -51,24 +51,24 @@ def main() -> None:
 
             void add_one_cuda(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
               // implementation of a library function
-              TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
+              TVM_FFI_ICHECK(x.ndim() == 1) << "x must be a 1D tensor";
               DLDataType f32_dtype{kDLFloat, 32, 1};
-              TVM_FFI_ICHECK(x->dtype == f32_dtype) << "x must be a float tensor";
-              TVM_FFI_ICHECK(y->ndim == 1) << "y must be a 1D tensor";
-              TVM_FFI_ICHECK(y->dtype == f32_dtype) << "y must be a float tensor";
-              TVM_FFI_ICHECK(x->shape[0] == y->shape[0]) << "x and y must have the same shape";
+              TVM_FFI_ICHECK(x.dtype() == f32_dtype) << "x must be a float tensor";
+              TVM_FFI_ICHECK(y.ndim() == 1) << "y must be a 1D tensor";
+              TVM_FFI_ICHECK(y.dtype() == f32_dtype) << "y must be a float tensor";
+              TVM_FFI_ICHECK(x.size(0) == y.size(0)) << "x and y must have the same shape";
 
-              int64_t n = x->shape[0];
+              int64_t n = x.size(0);
               int64_t nthread_per_block = 256;
               int64_t nblock = (n + nthread_per_block - 1) / nthread_per_block;
               // Obtain the current stream from the environment
               // it will be set to torch.cuda.current_stream() when calling the function
               // with torch.Tensors
               cudaStream_t stream = static_cast<cudaStream_t>(
-                  TVMFFIEnvGetStream(x->device.device_type, x->device.device_id));
+                  TVMFFIEnvGetStream(x.device().device_type, x.device().device_id));
               // launch the kernel
-              AddOneKernel<<<nblock, nthread_per_block, 0, stream>>>(static_cast<float*>(x->data),
-                                                                     static_cast<float*>(y->data), n);
+              AddOneKernel<<<nblock, nthread_per_block, 0, stream>>>(static_cast<float*>(x.data_ptr()),
+                                                                     static_cast<float*>(y.data_ptr()), n);
             }
         """,
         functions=["add_one_cpu", "add_one_cuda"],
diff --git a/examples/packaging/python/my_ffi_extension/base.py b/examples/packaging/python/my_ffi_extension/base.py
index fb6f6c2..4165023 100644
--- a/examples/packaging/python/my_ffi_extension/base.py
+++ b/examples/packaging/python/my_ffi_extension/base.py
@@ -26,15 +26,22 @@ def _load_lib() -> tvm_ffi.Module:
     # first look at the directory of the current file
     file_dir = Path(__file__).resolve().parent
 
+    path_candidates = [
+        file_dir,
+        file_dir / ".." / ".." / "build",
+    ]
+
     if sys.platform.startswith("win32"):
         lib_dll_name = "my_ffi_extension.dll"
     elif sys.platform.startswith("darwin"):
         lib_dll_name = "my_ffi_extension.dylib"
     else:
         lib_dll_name = "my_ffi_extension.so"
+    for candidate in path_candidates:
+        for path in Path(candidate).glob(lib_dll_name):
+            return tvm_ffi.load_module(str(path))
 
-    lib_path = file_dir / lib_dll_name
-    return tvm_ffi.load_module(str(lib_path))
+    raise RuntimeError(f"Cannot find {lib_dll_name} in {path_candidates}")
 
 
 _LIB = _load_lib()
diff --git a/examples/packaging/src/extension.cc b/examples/packaging/src/extension.cc
index c99bd07..8d2c504 100644
--- a/examples/packaging/src/extension.cc
+++ b/examples/packaging/src/extension.cc
@@ -46,14 +46,14 @@ void RaiseError(ffi::String msg) { TVM_FFI_THROW(RuntimeError) << msg; }
 
 void AddOne(ffi::TensorView x, ffi::TensorView y) {
   // implementation of a library function
-  TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
+  TVM_FFI_ICHECK(x.ndim() == 1) << "x must be a 1D tensor";
   DLDataType f32_dtype{kDLFloat, 32, 1};
-  TVM_FFI_ICHECK(x->dtype == f32_dtype) << "x must be a float tensor";
-  TVM_FFI_ICHECK(y->ndim == 1) << "y must be a 1D tensor";
-  TVM_FFI_ICHECK(y->dtype == f32_dtype) << "y must be a float tensor";
-  TVM_FFI_ICHECK(x->shape[0] == y->shape[0]) << "x and y must have the same shape";
-  for (int i = 0; i < x->shape[0]; ++i) {
-    static_cast<float*>(y->data)[i] = static_cast<float*>(x->data)[i] + 1;
+  TVM_FFI_ICHECK(x.dtype() == f32_dtype) << "x must be a float tensor";
+  TVM_FFI_ICHECK(y.ndim() == 1) << "y must be a 1D tensor";
+  TVM_FFI_ICHECK(y.dtype() == f32_dtype) << "y must be a float tensor";
+  TVM_FFI_ICHECK(x.size(0) == y.size(0)) << "x and y must have the same shape";
+  for (int i = 0; i < x.size(0); ++i) {
+    static_cast<float*>(y.data_ptr())[i] = static_cast<float*>(x.data_ptr())[i] + 1;
   }
 }
 
diff --git a/examples/quick_start/src/add_one_cpu.cc b/examples/quick_start/src/add_one_cpu.cc
index 886af13..abc188e 100644
--- a/examples/quick_start/src/add_one_cpu.cc
+++ b/examples/quick_start/src/add_one_cpu.cc
@@ -25,14 +25,14 @@ namespace tvm_ffi_example {
 
 void AddOne(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
   // implementation of a library function
-  TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
+  TVM_FFI_ICHECK(x.ndim() == 1) << "x must be a 1D tensor";
   DLDataType f32_dtype{kDLFloat, 32, 1};
-  TVM_FFI_ICHECK(x->dtype == f32_dtype) << "x must be a float tensor";
-  TVM_FFI_ICHECK(y->ndim == 1) << "y must be a 1D tensor";
-  TVM_FFI_ICHECK(y->dtype == f32_dtype) << "y must be a float tensor";
-  TVM_FFI_ICHECK(x->shape[0] == y->shape[0]) << "x and y must have the same shape";
-  for (int i = 0; i < x->shape[0]; ++i) {
-    static_cast<float*>(y->data)[i] = static_cast<float*>(x->data)[i] + 1;
+  TVM_FFI_ICHECK(x.dtype() == f32_dtype) << "x must be a float tensor";
+  TVM_FFI_ICHECK(y.ndim() == 1) << "y must be a 1D tensor";
+  TVM_FFI_ICHECK(y.dtype() == f32_dtype) << "y must be a float tensor";
+  TVM_FFI_ICHECK(x.size(0) == y.size(0)) << "x and y must have the same shape";
+  for (int i = 0; i < x.size(0); ++i) {
+    static_cast<float*>(y.data_ptr())[i] = static_cast<float*>(x.data_ptr())[i] + 1;
   }
 }
 
diff --git a/examples/quick_start/src/add_one_cuda.cu b/examples/quick_start/src/add_one_cuda.cu
index b15f807..07acfdb 100644
--- a/examples/quick_start/src/add_one_cuda.cu
+++ b/examples/quick_start/src/add_one_cuda.cu
@@ -33,24 +33,24 @@ __global__ void AddOneKernel(float* x, float* y, int n) {
 
 void AddOneCUDA(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
   // implementation of a library function
-  TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
+  TVM_FFI_ICHECK(x.ndim() == 1) << "x must be a 1D tensor";
   DLDataType f32_dtype{kDLFloat, 32, 1};
-  TVM_FFI_ICHECK(x->dtype == f32_dtype) << "x must be a float tensor";
-  TVM_FFI_ICHECK(y->ndim == 1) << "y must be a 1D tensor";
-  TVM_FFI_ICHECK(y->dtype == f32_dtype) << "y must be a float tensor";
-  TVM_FFI_ICHECK(x->shape[0] == y->shape[0]) << "x and y must have the same shape";
+  TVM_FFI_ICHECK(x.dtype() == f32_dtype) << "x must be a float tensor";
+  TVM_FFI_ICHECK(y.ndim() == 1) << "y must be a 1D tensor";
+  TVM_FFI_ICHECK(y.dtype() == f32_dtype) << "y must be a float tensor";
+  TVM_FFI_ICHECK(x.size(0) == y.size(0)) << "x and y must have the same shape";
 
-  int64_t n = x->shape[0];
+  int64_t n = x.size(0);
   int64_t nthread_per_block = 256;
   int64_t nblock = (n + nthread_per_block - 1) / nthread_per_block;
   // Obtain the current stream from the environment
   // it will be set to torch.cuda.current_stream() when calling the function
   // with torch.Tensors
   cudaStream_t stream =
-      static_cast<cudaStream_t>(TVMFFIEnvGetStream(x->device.device_type, x->device.device_id));
+      static_cast<cudaStream_t>(TVMFFIEnvGetStream(x.device().device_type, x.device().device_id));
   // launch the kernel
-  AddOneKernel<<<nblock, nthread_per_block, 0, stream>>>(static_cast<float*>(x->data),
-                                                         static_cast<float*>(y->data), n);
+  AddOneKernel<<<nblock, nthread_per_block, 0, stream>>>(static_cast<float*>(x.data_ptr()),
+                                                         static_cast<float*>(y.data_ptr()), n);
 }
 
 // Expose global symbol `add_one_cpu` that follows tvm-ffi abi
diff --git a/examples/quick_start/src/run_example.cc b/examples/quick_start/src/run_example.cc
index 90e61d1..4b38343 100644
--- a/examples/quick_start/src/run_example.cc
+++ b/examples/quick_start/src/run_example.cc
@@ -38,7 +38,7 @@ int main() {
   // create an Tensor, alternatively, one can directly pass in a DLTensor*
   ffi::Tensor x = Empty({5}, DLDataType({kDLFloat, 32, 1}), DLDevice({kDLCPU, 0}));
   for (int i = 0; i < 5; ++i) {
-    reinterpret_cast<float*>(x->data)[i] = static_cast<float>(i);
+    reinterpret_cast<float*>(x.data_ptr())[i] = static_cast<float>(i);
   }
 
   ffi::Function add_one_cpu = mod->GetFunction("add_one_cpu").value();
@@ -46,7 +46,7 @@ int main() {
 
   std::cout << "x after add_one_cpu(x, x)" << std::endl;
   for (int i = 0; i < 5; ++i) {
-    std::cout << reinterpret_cast<float*>(x->data)[i] << " ";
+    std::cout << reinterpret_cast<float*>(x.data_ptr())[i] << " ";
   }
   std::cout << std::endl;
   return 0;
diff --git a/examples/quick_start/src/run_example_cuda.cc b/examples/quick_start/src/run_example_cuda.cc
index 1fdd27c..21e7f49 100644
--- a/examples/quick_start/src/run_example_cuda.cc
+++ b/examples/quick_start/src/run_example_cuda.cc
@@ -70,7 +70,7 @@ int main() {
   }
 
   size_t nbytes = host_x.size() * sizeof(float);
-  cudaError_t err = cudaMemcpy(x->data, host_x.data(), nbytes, cudaMemcpyHostToDevice);
+  cudaError_t err = cudaMemcpy(x.data_ptr(), host_x.data(), nbytes, cudaMemcpyHostToDevice);
   TVM_FFI_ICHECK_EQ(err, cudaSuccess)
       << "cudaMemcpy host to device failed: " << cudaGetErrorString(err);
 
@@ -80,7 +80,7 @@ int main() {
   add_one_cuda(x, y);
 
   std::vector<float> host_y(host_x.size());
-  err = cudaMemcpy(host_y.data(), y->data, nbytes, cudaMemcpyDeviceToHost);
+  err = cudaMemcpy(host_y.data(), y.data_ptr(), nbytes, cudaMemcpyDeviceToHost);
   TVM_FFI_ICHECK_EQ(err, cudaSuccess)
       << "cudaMemcpy device to host failed: " << cudaGetErrorString(err);
 
diff --git a/include/tvm/ffi/container/tensor.h b/include/tvm/ffi/container/tensor.h
index eb0d900..d99a79a 100644
--- a/include/tvm/ffi/container/tensor.h
+++ b/include/tvm/ffi/container/tensor.h
@@ -243,6 +243,47 @@ class TensorObjFromDLPack : public TensorObj {
  */
 class Tensor : public ObjectRef {
  public:
+  /*!
+   * \brief Default constructor.
+   */
+  Tensor() = default;
+  /*!
+   * \brief Constructor from a ObjectPtr<TensorObj>.
+   * \param n The ObjectPtr<TensorObj>.
+   */
+  explicit Tensor(::tvm::ffi::ObjectPtr<TensorObj> n) : ObjectRef(std::move(n)) {}
+  /*!
+   * \brief Constructor from a UnsafeInit tag.
+   * \param tag The UnsafeInit tag.
+   */
+  explicit Tensor(::tvm::ffi::UnsafeInit tag) : ObjectRef(tag) {}
+  /// \cond Doxygen_Suppress
+  TVM_FFI_DEFINE_DEFAULT_COPY_MOVE_AND_ASSIGN(Tensor)
+  /// \endcond
+  /*!
+   * \brief Get the data pointer of the Tensor.
+   * \return The data pointer of the Tensor.
+   */
+  void* data_ptr() const { return get()->data; }
+
+  /*!
+   * \brief Get the device of the Tensor.
+   * \return The device of the Tensor.
+   */
+  DLDevice device() const { return get()->device; }
+
+  /*!
+   * \brief Get the number of dimensions in the Tensor.
+   * \return The number of dimensions in the Tensor.
+   */
+  int32_t ndim() const { return get()->ndim; }
+
+  /*!
+   * \brief Get the data type of the Tensor.
+   * \return The data type of the Tensor.
+   */
+  DLDataType dtype() const { return get()->dtype; }
+
   /*!
    * \brief Get the shape of the Tensor.
    * \return The shape of the Tensor.
@@ -251,6 +292,7 @@ class Tensor : public ObjectRef {
     const TensorObj* obj = get();
     return tvm::ffi::ShapeView(obj->shape, obj->ndim);
   }
+
   /*!
    * \brief Get the strides of the Tensor.
    * \return The strides of the Tensor.
@@ -262,28 +304,29 @@ class Tensor : public ObjectRef {
   }
 
   /*!
-   * \brief Get the data pointer of the Tensor.
-   * \return The data pointer of the Tensor.
+   * \brief Get the size of the idx-th dimension.
+   * \param idx The index of the size.
+   * \return The size of the idx-th dimension.
    */
-  void* data_ptr() const { return (*this)->data; }
+  int64_t size(size_t idx) const { return get()->shape[idx]; }
 
   /*!
-   * \brief Get the number of dimensions in the Tensor.
-   * \return The number of dimensions in the Tensor.
+   * \brief Get the stride of the idx-th dimension.
+   * \param idx The index of the stride.
+   * \return The stride of the idx-th dimension.
    */
-  int32_t ndim() const { return (*this)->ndim; }
+  int64_t stride(size_t idx) const { return get()->strides[idx]; }
 
   /*!
    * \brief Get the number of elements in the Tensor.
    * \return The number of elements in the Tensor.
    */
   int64_t numel() const { return this->shape().Product(); }
-
   /*!
-   * \brief Get the data type of the Tensor.
-   * \return The data type of the Tensor.
+   * \brief Get the byte offset of the Tensor.
+   * \return The byte offset of the Tensor.
    */
-  DLDataType dtype() const { return (*this)->dtype; }
+  uint64_t byte_offset() const { return get()->byte_offset; }
   /*!
    * \brief Check if the Tensor is contiguous.
    * \return True if the Tensor is contiguous, false otherwise.
@@ -445,12 +488,22 @@ class Tensor : public ObjectRef {
    * \return The converted DLPack managed tensor.
    */
   DLManagedTensorVersioned* ToDLPackVersioned() const { return get_mutable()->ToDLPackVersioned(); }
-
+  /*!
+   * \brief Get the underlying DLTensor pointer.
+   * \return The underlying DLTensor pointer.
+   */
+  const DLTensor* GetDLTensorPtr() const { return get(); }
   /// \cond Doxygen_Suppress
-  TVM_FFI_DEFINE_OBJECT_REF_METHODS_NULLABLE(Tensor, ObjectRef, TensorObj);
+  [[maybe_unused]] static constexpr bool _type_is_nullable = true;
+  using ContainerType = TensorObj;
   /// \endcond
 
  protected:
+  /*!
+   * \brief Get const internal container pointer.
+   * \return a const container pointer.
+   */
+  const TensorObj* get() const { return static_cast<const TensorObj*>(ObjectRef::get()); }
   /*!
    * \brief Get mutable internal container pointer.
    * \return a mutable container pointer.
@@ -481,7 +534,7 @@ class TensorView {
    */
   TensorView(const Tensor& tensor) {  // NOLINT(*)
     TVM_FFI_ICHECK(tensor.defined());
-    tensor_ = *tensor.operator->();
+    tensor_ = *tensor.GetDLTensorPtr();
   }  // NOLINT(*)
   /*!
    * \brief Create a TensorView from a DLTensor.
@@ -520,7 +573,7 @@ class TensorView {
    */
   TensorView& operator=(const Tensor& tensor) {
     TVM_FFI_ICHECK(tensor.defined());
-    tensor_ = *tensor.operator->();
+    tensor_ = *tensor.GetDLTensorPtr();
     return *this;
   }
 
@@ -529,17 +582,37 @@ class TensorView {
   // delete move assignment operator from owned tensor
   TensorView& operator=(Tensor&& tensor) = delete;
   /*!
-   * \brief Get the underlying DLTensor pointer.
-   * \return The underlying DLTensor pointer.
+   * \brief Get the data pointer of the Tensor.
+   * \return The data pointer of the Tensor.
    */
-  const DLTensor* operator->() const { return &tensor_; }
-
+  void* data_ptr() const { return tensor_.data; }
+  /*!
+   * \brief Get the device of the Tensor.
+   * \return The device of the Tensor.
+   */
+  DLDevice device() const { return tensor_.device; }
+  /*!
+   * \brief Get the number of dimensions in the Tensor.
+   * \return The number of dimensions in the Tensor.
+   */
+  int32_t ndim() const { return tensor_.ndim; }
+  /*!
+   * \brief Get the data type of the Tensor.
+   * \return The data type of the Tensor.
+   */
+  DLDataType dtype() const { return tensor_.dtype; }
   /*!
    * \brief Get the shape of the Tensor.
    * \return The shape of the Tensor.
    */
   ShapeView shape() const { return ShapeView(tensor_.shape, tensor_.ndim); }
 
+  /*!
+   * \brief Get the number of elements in the Tensor.
+   * \return The number of elements in the Tensor.
+   */
+  int64_t numel() const { return this->shape().Product(); }
+
   /*!
    * \brief Get the strides of the Tensor.
    * \return The strides of the Tensor.
@@ -550,28 +623,24 @@ class TensorView {
   }
 
   /*!
-   * \brief Get the data pointer of the Tensor.
-   * \return The data pointer of the Tensor.
-   */
-  void* data_ptr() const { return tensor_.data; }
-
-  /*!
-   * \brief Get the number of dimensions in the Tensor.
-   * \return The number of dimensions in the Tensor.
+   * \brief Get the size of the idx-th dimension.
+   * \param idx The index of the size.
+   * \return The size of the idx-th dimension.
    */
-  int32_t ndim() const { return tensor_.ndim; }
+  int64_t size(size_t idx) const { return tensor_.shape[idx]; }
 
   /*!
-   * \brief Get the number of elements in the Tensor.
-   * \return The number of elements in the Tensor.
+   * \brief Get the stride of the idx-th dimension.
+   * \param idx The index of the stride.
+   * \return The stride of the idx-th dimension.
    */
-  int64_t numel() const { return this->shape().Product(); }
+  int64_t stride(size_t idx) const { return tensor_.strides[idx]; }
 
   /*!
-   * \brief Get the data type of the Tensor.
-   * \return The data type of the Tensor.
+   * \brief Get the byte offset of the Tensor.
+   * \return The byte offset of the Tensor.
    */
-  DLDataType dtype() const { return tensor_.dtype; }
+  uint64_t byte_offset() const { return tensor_.byte_offset; }
 
   /*!
    * \brief Check if the Tensor is contiguous.
@@ -581,8 +650,28 @@ class TensorView {
 
  private:
   DLTensor tensor_;
+  template <typename, typename>
+  friend struct TypeTraits;
 };
 
+/*!
+ * \brief Get the data size of the Tensor.
+ * \param tensor The input Tensor.
+ * \return The data size of the Tensor.
+ */
+inline size_t GetDataSize(const Tensor& tensor) {
+  return GetDataSize(tensor.numel(), tensor.dtype());
+}
+
+/*!
+ * \brief Get the data size of the TensorView.
+ * \param tensor The input TensorView.
+ * \return The data size of the TensorView.
+ */
+inline size_t GetDataSize(const TensorView& tensor) {
+  return GetDataSize(tensor.numel(), tensor.dtype());
+}
+
 // TensorView type, allow implicit casting from DLTensor*
 // NOTE: we deliberately do not support MoveToAny and MoveFromAny since it does not retain ownership
 template <>
@@ -594,7 +683,7 @@ struct TypeTraits<TensorView> : public TypeTraitsBase {
     result->type_index = TypeIndex::kTVMFFIDLTensorPtr;
     result->zero_padding = 0;
     TVM_FFI_CLEAR_PTR_PADDING_IN_FFI_ANY(result);
-    result->v_ptr = const_cast<DLTensor*>(src.operator->());
+    result->v_ptr = const_cast<DLTensor*>(&(src.tensor_));
   }
 
   TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny* src) {
diff --git a/python/tvm_ffi/cpp/load_inline.py b/python/tvm_ffi/cpp/load_inline.py
index 260ccc2..4d7087c 100644
--- a/python/tvm_ffi/cpp/load_inline.py
+++ b/python/tvm_ffi/cpp/load_inline.py
@@ -455,14 +455,14 @@ def build_inline(  # noqa: PLR0915, PLR0912
         cpp_source = '''
              void add_one_cpu(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
                // implementation of a library function
-               TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
+               TVM_FFI_ICHECK(x.ndim() == 1) << "x must be a 1D tensor";
                DLDataType f32_dtype{kDLFloat, 32, 1};
-               TVM_FFI_ICHECK(x->dtype == f32_dtype) << "x must be a float tensor";
-               TVM_FFI_ICHECK(y->ndim == 1) << "y must be a 1D tensor";
-               TVM_FFI_ICHECK(y->dtype == f32_dtype) << "y must be a float tensor";
-               TVM_FFI_ICHECK(x->shape[0] == y->shape[0]) << "x and y must have the same shape";
-               for (int i = 0; i < x->shape[0]; ++i) {
-                 static_cast<float*>(y->data)[i] = static_cast<float*>(x->data)[i] + 1;
+               TVM_FFI_ICHECK(x.dtype() == f32_dtype) << "x must be a float tensor";
+               TVM_FFI_ICHECK(y.ndim() == 1) << "y must be a 1D tensor";
+               TVM_FFI_ICHECK(y.dtype() == f32_dtype) << "y must be a float tensor";
+               TVM_FFI_ICHECK(x.size(0) == y.size(0)) << "x and y must have the same shape";
+               for (int i = 0; i < x.size(0); ++i) {
+                 static_cast<float*>(y.data_ptr())[i] = static_cast<float*>(x.data_ptr())[i] + 1;
                }
              }
         '''
@@ -663,14 +663,14 @@ def load_inline(
         cpp_source = '''
              void add_one_cpu(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
                // implementation of a library function
-               TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
+               TVM_FFI_ICHECK(x.ndim() == 1) << "x must be a 1D tensor";
                DLDataType f32_dtype{kDLFloat, 32, 1};
-               TVM_FFI_ICHECK(x->dtype == f32_dtype) << "x must be a float tensor";
-               TVM_FFI_ICHECK(y->ndim == 1) << "y must be a 1D tensor";
-               TVM_FFI_ICHECK(y->dtype == f32_dtype) << "y must be a float tensor";
-               TVM_FFI_ICHECK(x->shape[0] == y->shape[0]) << "x and y must have the same shape";
-               for (int i = 0; i < x->shape[0]; ++i) {
-                 static_cast<float*>(y->data)[i] = static_cast<float*>(x->data)[i] + 1;
+               TVM_FFI_ICHECK(x.dtype() == f32_dtype) << "x must be a float tensor";
+               TVM_FFI_ICHECK(y.ndim() == 1) << "y must be a 1D tensor";
+               TVM_FFI_ICHECK(y.dtype() == f32_dtype) << "y must be a float tensor";
+               TVM_FFI_ICHECK(x.size(0) == y.size(0)) << "x and y must have the same shape";
+               for (int i = 0; i < x.size(0); ++i) {
+                 static_cast<float*>(y.data_ptr())[i] = static_cast<float*>(x.data_ptr())[i] + 1;
                }
              }
         '''
diff --git a/rust/tvm-ffi/scripts/generate_example_lib.py b/rust/tvm-ffi/scripts/generate_example_lib.py
index 43e822b..8f07eea 100644
--- a/rust/tvm-ffi/scripts/generate_example_lib.py
+++ b/rust/tvm-ffi/scripts/generate_example_lib.py
@@ -34,14 +34,14 @@ def main() -> None:
         cpp_sources=r"""
             void add_one_cpu(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
               // implementation of a library function
-              TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
+              TVM_FFI_ICHECK(x.ndim() == 1) << "x must be a 1D tensor";
               DLDataType f32_dtype{kDLFloat, 32, 1};
-              TVM_FFI_ICHECK(x->dtype == f32_dtype) << "x must be a float tensor";
-              TVM_FFI_ICHECK(y->ndim == 1) << "y must be a 1D tensor";
-              TVM_FFI_ICHECK(y->dtype == f32_dtype) << "y must be a float tensor";
-              TVM_FFI_ICHECK(x->shape[0] == y->shape[0]) << "x and y must have the same shape";
-              for (int i = 0; i < x->shape[0]; ++i) {
-                static_cast<float*>(y->data)[i] = static_cast<float*>(x->data)[i] + 1;
+              TVM_FFI_ICHECK(x.dtype() == f32_dtype) << "x must be a float tensor";
+              TVM_FFI_ICHECK(y.ndim() == 1) << "y must be a 1D tensor";
+              TVM_FFI_ICHECK(y.dtype() == f32_dtype) << "y must be a float tensor";
+              TVM_FFI_ICHECK(x.size(0) == y.size(0)) << "x and y must have the same shape";
+              for (int i = 0; i < x.size(0); ++i) {
+                static_cast<float*>(y.data_ptr())[i] = static_cast<float*>(x.data_ptr())[i] + 1;
               }
             }
         """,
diff --git a/src/ffi/extra/structural_equal.cc b/src/ffi/extra/structural_equal.cc
index 531a233..b236828 100644
--- a/src/ffi/extra/structural_equal.cc
+++ b/src/ffi/extra/structural_equal.cc
@@ -349,18 +349,19 @@ class StructEqualHandler {
   // NOLINTNEXTLINE(performance-unnecessary-value-param)
   bool CompareTensor(Tensor lhs, Tensor rhs) {
     if (lhs.same_as(rhs)) return true;
-    if (lhs->ndim != rhs->ndim) return false;
-    for (int i = 0; i < lhs->ndim; ++i) {
-      if (lhs->shape[i] != rhs->shape[i]) return false;
+    if (lhs.ndim() != rhs.ndim()) return false;
+    for (int i = 0; i < lhs.ndim(); ++i) {
+      if (lhs.size(i) != rhs.size(i)) return false;
     }
-    if (lhs->dtype != rhs->dtype) return false;
+
+    if (lhs.dtype() != rhs.dtype()) return false;
     if (!skip_tensor_content_) {
-      TVM_FFI_ICHECK_EQ(lhs->device.device_type, kDLCPU) << "can only compare CPU tensor";
-      TVM_FFI_ICHECK_EQ(rhs->device.device_type, kDLCPU) << "can only compare CPU tensor";
+      TVM_FFI_ICHECK_EQ(lhs.device().device_type, kDLCPU) << "can only compare CPU tensor";
+      TVM_FFI_ICHECK_EQ(rhs.device().device_type, kDLCPU) << "can only compare CPU tensor";
       TVM_FFI_ICHECK(lhs.IsContiguous()) << "Can only compare contiguous tensor";
       TVM_FFI_ICHECK(rhs.IsContiguous()) << "Can only compare contiguous tensor";
-      size_t data_size = GetDataSize(*(lhs.operator->()));
-      return std::memcmp(lhs->data, rhs->data, data_size) == 0;
+      size_t data_size = GetDataSize(lhs);
+      return std::memcmp(lhs.data_ptr(), rhs.data_ptr(), data_size) == 0;
     } else {
       return true;
     }
diff --git a/src/ffi/extra/structural_hash.cc b/src/ffi/extra/structural_hash.cc
index 154ef3c..271a0db 100644
--- a/src/ffi/extra/structural_hash.cc
+++ b/src/ffi/extra/structural_hash.cc
@@ -272,21 +272,21 @@ class StructuralHashHandler {
 
   // NOLINTNEXTLINE(performance-unnecessary-value-param)
   uint64_t HashTensor(Tensor tensor) {
-    uint64_t hash_value = details::StableHashCombine(tensor->GetTypeKeyHash(), tensor->ndim);
-    for (int i = 0; i < tensor->ndim; ++i) {
-      hash_value = details::StableHashCombine(hash_value, tensor->shape[i]);
+    uint64_t hash_value = details::StableHashCombine(tensor->GetTypeKeyHash(), tensor.ndim());
+    for (int i = 0; i < tensor.ndim(); ++i) {
+      hash_value = details::StableHashCombine(hash_value, tensor.size(i));
     }
     TVMFFIAny temp;
     temp.v_uint64 = 0;
-    temp.v_dtype = tensor->dtype;
+    temp.v_dtype = tensor.dtype();
     hash_value = details::StableHashCombine(hash_value, temp.v_int64);
 
     if (!skip_tensor_content_) {
-      TVM_FFI_ICHECK_EQ(tensor->device.device_type, kDLCPU) << "can only hash CPU tensor";
+      TVM_FFI_ICHECK_EQ(tensor.device().device_type, kDLCPU) << "can only hash CPU tensor";
       TVM_FFI_ICHECK(tensor.IsContiguous()) << "Can only hash contiguous tensor";
-      size_t data_size = GetDataSize(*(tensor.operator->()));
+      size_t data_size = GetDataSize(tensor.numel(), tensor.dtype());
       uint64_t data_hash =
-          details::StableHashBytes(static_cast<const char*>(tensor->data), data_size);
+          details::StableHashBytes(static_cast<const char*>(tensor.data_ptr()), data_size);
       hash_value = details::StableHashCombine(hash_value, data_hash);
     }
     return hash_value;
diff --git a/tests/cpp/test_tensor.cc b/tests/cpp/test_tensor.cc
index 65c6eeb..1c45e8a 100644
--- a/tests/cpp/test_tensor.cc
+++ b/tests/cpp/test_tensor.cc
@@ -63,18 +63,18 @@ TEST(Tensor, Basic) {
   EXPECT_EQ(strides[2], 1);
   EXPECT_EQ(nd.dtype(), DLDataType({kDLFloat, 32, 1}));
   for (int64_t i = 0; i < shape.Product(); ++i) {
-    reinterpret_cast<float*>(nd->data)[i] = static_cast<float>(i);
+    reinterpret_cast<float*>(nd.data_ptr())[i] = static_cast<float>(i);
   }
 
   EXPECT_EQ(nd.numel(), 6);
   EXPECT_EQ(nd.ndim(), 3);
-  EXPECT_EQ(nd.data_ptr(), nd->data);
+  EXPECT_EQ(nd.data_ptr(), nd.GetDLTensorPtr()->data);
 
   Any any0 = nd;
   Tensor nd2 = any0.as<Tensor>().value();  // NOLINT(bugprone-unchecked-optional-access)
   EXPECT_EQ(nd2.dtype(), DLDataType({kDLFloat, 32, 1}));
   for (int64_t i = 0; i < shape.Product(); ++i) {
-    EXPECT_EQ(reinterpret_cast<float*>(nd2->data)[i], i);
+    EXPECT_EQ(reinterpret_cast<float*>(nd2.data_ptr())[i], i);
   }
 
   EXPECT_EQ(nd.IsContiguous(), true);
@@ -101,7 +101,7 @@ TEST(Tensor, DLPack) {
   {
     Tensor tensor2 = Tensor::FromDLPack(dlpack);
     EXPECT_EQ(tensor2.use_count(), 1);
-    EXPECT_EQ(tensor2->data, tensor->data);
+    EXPECT_EQ(tensor2.data_ptr(), tensor.data_ptr());
     EXPECT_EQ(tensor.use_count(), 2);
     EXPECT_EQ(tensor2.use_count(), 1);
   }
@@ -129,7 +129,7 @@ TEST(Tensor, DLPackVersioned) {
   {
     Tensor tensor2 = Tensor::FromDLPackVersioned(dlpack);
     EXPECT_EQ(tensor2.use_count(), 1);
-    EXPECT_EQ(tensor2->data, tensor->data);
+    EXPECT_EQ(tensor2.data_ptr(), tensor.data_ptr());
     EXPECT_EQ(tensor.use_count(), 2);
     EXPECT_EQ(tensor2.use_count(), 1);
   }
@@ -142,15 +142,15 @@ TEST(Tensor, DLPackAlloc) {
                                           DLDataType({kDLFloat, 32, 1}), DLDevice({kDLCPU, 0}));
   EXPECT_EQ(tensor.use_count(), 1);
   EXPECT_EQ(tensor.shape().size(), 3);
-  EXPECT_EQ(tensor.shape()[0], 1);
-  EXPECT_EQ(tensor.shape()[1], 2);
-  EXPECT_EQ(tensor.shape()[2], 3);
+  EXPECT_EQ(tensor.size(0), 1);
+  EXPECT_EQ(tensor.size(1), 2);
+  EXPECT_EQ(tensor.size(2), 3);
   EXPECT_EQ(tensor.dtype().code, kDLFloat);
   EXPECT_EQ(tensor.dtype().bits, 32);
   EXPECT_EQ(tensor.dtype().lanes, 1);
-  EXPECT_EQ(tensor->device.device_type, kDLCPU);
-  EXPECT_EQ(tensor->device.device_id, 0);
-  EXPECT_NE(tensor->data, nullptr);
+  EXPECT_EQ(tensor.device().device_type, kDLCPU);
+  EXPECT_EQ(tensor.device().device_id, 0);
+  EXPECT_NE(tensor.data_ptr(), nullptr);
 }
 
 TEST(Tensor, DLPackAllocError) {
diff --git a/tests/python/test_build_inline.py b/tests/python/test_build_inline.py
index 5e3b191..fa0e6c9 100644
--- a/tests/python/test_build_inline.py
+++ b/tests/python/test_build_inline.py
@@ -26,14 +26,14 @@ def test_build_inline_cpp() -> None:
         cpp_sources=r"""
             void add_one_cpu(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
               // implementation of a library function
-              TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
+              TVM_FFI_ICHECK(x.ndim() == 1) << "x must be a 1D tensor";
               DLDataType f32_dtype{kDLFloat, 32, 1};
-              TVM_FFI_ICHECK(x->dtype == f32_dtype) << "x must be a float tensor";
-              TVM_FFI_ICHECK(y->ndim == 1) << "y must be a 1D tensor";
-              TVM_FFI_ICHECK(y->dtype == f32_dtype) << "y must be a float tensor";
-              TVM_FFI_ICHECK(x->shape[0] == y->shape[0]) << "x and y must have the same shape";
-              for (int i = 0; i < x->shape[0]; ++i) {
-                static_cast<float*>(y->data)[i] = static_cast<float*>(x->data)[i] + 1;
+              TVM_FFI_ICHECK(x.dtype() == f32_dtype) << "x must be a float tensor";
+              TVM_FFI_ICHECK(y.ndim() == 1) << "y must be a 1D tensor";
+              TVM_FFI_ICHECK(y.dtype() == f32_dtype) << "y must be a float tensor";
+              TVM_FFI_ICHECK(x.size(0) == y.size(0)) << "x and y must have the same shape";
+              for (int i = 0; i < x.size(0); ++i) {
+                static_cast<float*>(y.data_ptr())[i] = static_cast<float*>(x.data_ptr())[i] + 1;
               }
             }
         """,
diff --git a/tests/python/test_load_inline.py b/tests/python/test_load_inline.py
index 795e691..299e4c8 100644
--- a/tests/python/test_load_inline.py
+++ b/tests/python/test_load_inline.py
@@ -38,14 +38,14 @@ def test_load_inline_cpp() -> None:
         cpp_sources=r"""
             void add_one_cpu(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
               // implementation of a library function
-              TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
+              TVM_FFI_ICHECK(x.ndim() == 1) << "x must be a 1D tensor";
               DLDataType f32_dtype{kDLFloat, 32, 1};
-              TVM_FFI_ICHECK(x->dtype == f32_dtype) << "x must be a float tensor";
-              TVM_FFI_ICHECK(y->ndim == 1) << "y must be a 1D tensor";
-              TVM_FFI_ICHECK(y->dtype == f32_dtype) << "y must be a float tensor";
-              TVM_FFI_ICHECK(x->shape[0] == y->shape[0]) << "x and y must have the same shape";
-              for (int i = 0; i < x->shape[0]; ++i) {
-                static_cast<float*>(y->data)[i] = static_cast<float*>(x->data)[i] + 1;
+              TVM_FFI_ICHECK(x.dtype() == f32_dtype) << "x must be a float tensor";
+              TVM_FFI_ICHECK(y.ndim() == 1) << "y must be a 1D tensor";
+              TVM_FFI_ICHECK(y.dtype() == f32_dtype) << "y must be a float tensor";
+              TVM_FFI_ICHECK(x.size(0) == y.size(0)) << "x and y must have the same shape";
+              for (int i = 0; i < x.size(0); ++i) {
+                static_cast<float*>(y.data_ptr())[i] = static_cast<float*>(x.data_ptr())[i] + 1;
               }
             }
         """,
@@ -64,14 +64,14 @@ def test_load_inline_cpp_with_docstrings() -> None:
         cpp_sources=r"""
             void add_one_cpu(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
               // implementation of a library function
-              TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
+              TVM_FFI_ICHECK(x.ndim() == 1) << "x must be a 1D tensor";
               DLDataType f32_dtype{kDLFloat, 32, 1};
-              TVM_FFI_ICHECK(x->dtype == f32_dtype) << "x must be a float tensor";
-              TVM_FFI_ICHECK(y->ndim == 1) << "y must be a 1D tensor";
-              TVM_FFI_ICHECK(y->dtype == f32_dtype) << "y must be a float tensor";
-              TVM_FFI_ICHECK(x->shape[0] == y->shape[0]) << "x and y must have the same shape";
-              for (int i = 0; i < x->shape[0]; ++i) {
-                static_cast<float*>(y->data)[i] = static_cast<float*>(x->data)[i] + 1;
+              TVM_FFI_ICHECK(x.dtype() == f32_dtype) << "x must be a float tensor";
+              TVM_FFI_ICHECK(y.ndim() == 1) << "y must be a 1D tensor";
+              TVM_FFI_ICHECK(y.dtype() == f32_dtype) << "y must be a float tensor";
+              TVM_FFI_ICHECK(x.size(0) == y.size(0)) << "x and y must have the same shape";
+              for (int i = 0; i < x.size(0); ++i) {
+                static_cast<float*>(y.data_ptr())[i] = static_cast<float*>(x.data_ptr())[i] + 1;
               }
             }
         """,
@@ -91,28 +91,28 @@ def test_load_inline_cpp_multiple_sources() -> None:
             r"""
             void add_one_cpu(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
               // implementation of a library function
-              TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
+              TVM_FFI_ICHECK(x.ndim() == 1) << "x must be a 1D tensor";
               DLDataType f32_dtype{kDLFloat, 32, 1};
-              TVM_FFI_ICHECK(x->dtype == f32_dtype) << "x must be a float tensor";
-              TVM_FFI_ICHECK(y->ndim == 1) << "y must be a 1D tensor";
-              TVM_FFI_ICHECK(y->dtype == f32_dtype) << "y must be a float tensor";
-              TVM_FFI_ICHECK(x->shape[0] == y->shape[0]) << "x and y must have the same shape";
-              for (int i = 0; i < x->shape[0]; ++i) {
-                static_cast<float*>(y->data)[i] = static_cast<float*>(x->data)[i] + 1;
+              TVM_FFI_ICHECK(x.dtype() == f32_dtype) << "x must be a float tensor";
+              TVM_FFI_ICHECK(y.ndim() == 1) << "y must be a 1D tensor";
+              TVM_FFI_ICHECK(y.dtype() == f32_dtype) << "y must be a float tensor";
+              TVM_FFI_ICHECK(x.size(0) == y.size(0)) << "x and y must have the same shape";
+              for (int i = 0; i < x.size(0); ++i) {
+                static_cast<float*>(y.data_ptr())[i] = static_cast<float*>(x.data_ptr())[i] + 1;
               }
             }
         """,
             r"""
             void add_two_cpu(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
               // implementation of a library function
-              TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
+              TVM_FFI_ICHECK(x.ndim() == 1) << "x must be a 1D tensor";
               DLDataType f32_dtype{kDLFloat, 32, 1};
-              TVM_FFI_ICHECK(x->dtype == f32_dtype) << "x must be a float tensor";
-              TVM_FFI_ICHECK(y->ndim == 1) << "y must be a 1D tensor";
-              TVM_FFI_ICHECK(y->dtype == f32_dtype) << "y must be a float tensor";
-              TVM_FFI_ICHECK(x->shape[0] == y->shape[0]) << "x and y must have the same shape";
-              for (int i = 0; i < x->shape[0]; ++i) {
-                static_cast<float*>(y->data)[i] = static_cast<float*>(x->data)[i] + 2;
+              TVM_FFI_ICHECK(x.dtype() == f32_dtype) << "x must be a float tensor";
+              TVM_FFI_ICHECK(y.ndim() == 1) << "y must be a 1D tensor";
+              TVM_FFI_ICHECK(y.dtype() == f32_dtype) << "y must be a float tensor";
+              TVM_FFI_ICHECK(x.size(0) == y.size(0)) << "x and y must have the same shape";
+              for (int i = 0; i < x.size(0); ++i) {
+                static_cast<float*>(y.data_ptr())[i] = static_cast<float*>(x.data_ptr())[i] + 2;
               }
             }
         """,
@@ -132,14 +132,14 @@ def test_load_inline_cpp_build_dir() -> None:
         cpp_sources=r"""
             void add_one_cpu(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
               // implementation of a library function
-              TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
+              TVM_FFI_ICHECK(x.ndim() == 1) << "x must be a 1D tensor";
               DLDataType f32_dtype{kDLFloat, 32, 1};
-              TVM_FFI_ICHECK(x->dtype == f32_dtype) << "x must be a float tensor";
-              TVM_FFI_ICHECK(y->ndim == 1) << "y must be a 1D tensor";
-              TVM_FFI_ICHECK(y->dtype == f32_dtype) << "y must be a float tensor";
-              TVM_FFI_ICHECK(x->shape[0] == y->shape[0]) << "x and y must have the same shape";
-              for (int i = 0; i < x->shape[0]; ++i) {
-                static_cast<float*>(y->data)[i] = static_cast<float*>(x->data)[i] + 1;
+              TVM_FFI_ICHECK(x.dtype() == f32_dtype) << "x must be a float tensor";
+              TVM_FFI_ICHECK(y.ndim() == 1) << "y must be a 1D tensor";
+              TVM_FFI_ICHECK(y.dtype() == f32_dtype) << "y must be a float tensor";
+              TVM_FFI_ICHECK(x.size(0) == y.size(0)) << "x and y must have the same shape";
+              for (int i = 0; i < x.size(0); ++i) {
+                static_cast<float*>(y.data_ptr())[i] = static_cast<float*>(x.data_ptr())[i] + 1;
               }
             }
         """,
@@ -169,26 +169,26 @@ def test_load_inline_cuda() -> None:
 
             void add_one_cuda(tvm::ffi::TensorView x, tvm::ffi::TensorView y, int64_t raw_stream) {
               // implementation of a library function
-              TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
+              TVM_FFI_ICHECK(x.ndim() == 1) << "x must be a 1D tensor";
               DLDataType f32_dtype{kDLFloat, 32, 1};
-              TVM_FFI_ICHECK(x->dtype == f32_dtype) << "x must be a float tensor";
-              TVM_FFI_ICHECK(y->ndim == 1) << "y must be a 1D tensor";
-              TVM_FFI_ICHECK(y->dtype == f32_dtype) << "y must be a float tensor";
-              TVM_FFI_ICHECK(x->shape[0] == y->shape[0]) << "x and y must have the same shape";
+              TVM_FFI_ICHECK(x.dtype() == f32_dtype) << "x must be a float tensor";
+              TVM_FFI_ICHECK(y.ndim() == 1) << "y must be a 1D tensor";
+              TVM_FFI_ICHECK(y.dtype() == f32_dtype) << "y must be a float tensor";
+              TVM_FFI_ICHECK(x.size(0) == y.size(0)) << "x and y must have the same shape";
 
-              int64_t n = x->shape[0];
+              int64_t n = x.size(0);
               int64_t nthread_per_block = 256;
               int64_t nblock = (n + nthread_per_block - 1) / nthread_per_block;
               // Obtain the current stream from the environment
               // it will be set to torch.cuda.current_stream() when calling the function
               // with torch.Tensors
               cudaStream_t stream = static_cast<cudaStream_t>(
-                  TVMFFIEnvGetStream(x->device.device_type, x->device.device_id));
+                  TVMFFIEnvGetStream(x.device().device_type, x.device().device_id));
               TVM_FFI_ICHECK_EQ(reinterpret_cast<int64_t>(stream), raw_stream)
                 << "stream must be the same as raw_stream";
               // launch the kernel
-              AddOneKernel<<<nblock, nthread_per_block, 0, stream>>>(static_cast<float*>(x->data),
-                                                                     static_cast<float*>(y->data), n);
+              AddOneKernel<<<nblock, nthread_per_block, 0, stream>>>(static_cast<float*>(x.data_ptr()),
+                                                                     static_cast<float*>(y.data_ptr()), n);
             }
         """,
         functions=["add_one_cuda"],
@@ -227,16 +227,16 @@ def test_load_inline_with_env_tensor_allocator() -> None:
             ffi::Tensor return_add_one(ffi::Map<ffi::String, ffi::Tuple<ffi::Tensor>> kwargs) {
               ffi::Tensor x = kwargs["x"].get<0>();
               // implementation of a library function
-              TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
+              TVM_FFI_ICHECK(x.ndim() == 1) << "x must be a 1D tensor";
               DLDataType f32_dtype{kDLFloat, 32, 1};
-              TVM_FFI_ICHECK(x->dtype == f32_dtype) << "x must be a float tensor";
+              TVM_FFI_ICHECK(x.dtype() == f32_dtype) << "x must be a float tensor";
               // allocate a new tensor with the env tensor allocator
               // it will be redirected to torch.empty when calling the function
               ffi::Tensor y = ffi::Tensor::FromDLPackAlloc(
-                TVMFFIEnvGetTensorAllocator(), ffi::Shape({x->shape[0]}), f32_dtype, x->device);
-              int64_t n = x->shape[0];
+                TVMFFIEnvGetTensorAllocator(), ffi::Shape({x.size(0)}), f32_dtype, x.device());
+              int64_t n = x.size(0);
               for (int i = 0; i < n; ++i) {
-                static_cast<float*>(y->data)[i] = static_cast<float*>(x->data)[i] + 1;
+                static_cast<float*>(y.data_ptr())[i] = static_cast<float*>(x.data_ptr())[i] + 1;
               }
               return y;
             }
@@ -263,14 +263,14 @@ def test_load_inline_both() -> None:
         cpp_sources=r"""
             void add_one_cpu(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
               // implementation of a library function
-              TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
+              TVM_FFI_ICHECK(x.ndim() == 1) << "x must be a 1D tensor";
               DLDataType f32_dtype{kDLFloat, 32, 1};
-              TVM_FFI_ICHECK(x->dtype == f32_dtype) << "x must be a float tensor";
-              TVM_FFI_ICHECK(y->ndim == 1) << "y must be a 1D tensor";
-              TVM_FFI_ICHECK(y->dtype == f32_dtype) << "y must be a float tensor";
-              TVM_FFI_ICHECK(x->shape[0] == y->shape[0]) << "x and y must have the same shape";
-              for (int i = 0; i < x->shape[0]; ++i) {
-                static_cast<float*>(y->data)[i] = static_cast<float*>(x->data)[i] + 1;
+              TVM_FFI_ICHECK(x.dtype() == f32_dtype) << "x must be a float tensor";
+              TVM_FFI_ICHECK(y.ndim() == 1) << "y must be a 1D tensor";
+              TVM_FFI_ICHECK(y.dtype() == f32_dtype) << "y must be a float tensor";
+              TVM_FFI_ICHECK(x.size(0) == y.size(0)) << "x and y must have the same shape";
+              for (int i = 0; i < x.size(0); ++i) {
+                static_cast<float*>(y.data_ptr())[i] = static_cast<float*>(x.data_ptr())[i] + 1;
               }
             }
 
@@ -286,24 +286,24 @@ def test_load_inline_both() -> None:
 
             void add_one_cuda(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
               // implementation of a library function
-              TVM_FFI_ICHECK(x->ndim == 1) << "x must be a 1D tensor";
+              TVM_FFI_ICHECK(x.ndim() == 1) << "x must be a 1D tensor";
               DLDataType f32_dtype{kDLFloat, 32, 1};
-              TVM_FFI_ICHECK(x->dtype == f32_dtype) << "x must be a float tensor";
-              TVM_FFI_ICHECK(y->ndim == 1) << "y must be a 1D tensor";
-              TVM_FFI_ICHECK(y->dtype == f32_dtype) << "y must be a float tensor";
-              TVM_FFI_ICHECK(x->shape[0] == y->shape[0]) << "x and y must have the same shape";
+              TVM_FFI_ICHECK(x.dtype() == f32_dtype) << "x must be a float tensor";
+              TVM_FFI_ICHECK(y.ndim() == 1) << "y must be a 1D tensor";
+              TVM_FFI_ICHECK(y.dtype() == f32_dtype) << "y must be a float tensor";
+              TVM_FFI_ICHECK(x.size(0) == y.size(0)) << "x and y must have the same shape";
 
-              int64_t n = x->shape[0];
+              int64_t n = x.size(0);
               int64_t nthread_per_block = 256;
               int64_t nblock = (n + nthread_per_block - 1) / nthread_per_block;
               // Obtain the current stream from the environment
               // it will be set to torch.cuda.current_stream() when calling the function
               // with torch.Tensors
               cudaStream_t stream = static_cast<cudaStream_t>(
-                  TVMFFIEnvGetStream(x->device.device_type, x->device.device_id));
+                  TVMFFIEnvGetStream(x.device().device_type, x.device().device_id));
               // launch the kernel
-              AddOneKernel<<<nblock, nthread_per_block, 0, stream>>>(static_cast<float*>(x->data),
-                                                                     static_cast<float*>(y->data), n);
+              AddOneKernel<<<nblock, nthread_per_block, 0, stream>>>(static_cast<float*>(x.data_ptr()),
+                                                                     static_cast<float*>(y.data_ptr()), n);
             }
         """,
         functions=["add_one_cpu", "add_one_cuda"],
@@ -335,7 +335,7 @@ def test_cuda_memory_alloc_noleak() -> None:
 
             ffi::Tensor return_tensor(tvm::ffi::TensorView x) {
                 ffi::Tensor y = ffi::Tensor::FromDLPackAlloc(
-                    TVMFFIEnvGetTensorAllocator(), x.shape(), x.dtype(), x->device);
+                    TVMFFIEnvGetTensorAllocator(), x.shape(), x.dtype(), x.device());
                 return y;
             }
         """,
