This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new 8888eb4  [TENSOR] Add API to create strided view (#335)
8888eb4 is described below

commit 8888eb4b254486fb1fb5baad7e9f24bc1cfac63a
Author: Tianqi Chen <[email protected]>
AuthorDate: Fri Dec 12 12:32:19 2025 -0500

    [TENSOR] Add API to create strided view (#335)
    
    This PR adds an API to create a strided view, which is needed in
    certain settings.
---
 docs/guides/kernel_library_guide.rst |   8 ++
 include/tvm/ffi/c_api.h              |  14 +++
 include/tvm/ffi/container/tensor.h   | 126 +++++++++++++++++++++++++-
 src/ffi/tensor.cc                    |  28 ++++++
 tests/cpp/test_tensor.cc             | 169 +++++++++++++++++++++++++++++++++++
 5 files changed, 342 insertions(+), 3 deletions(-)

diff --git a/docs/guides/kernel_library_guide.rst b/docs/guides/kernel_library_guide.rst
index 82b1bff..01faf20 100644
--- a/docs/guides/kernel_library_guide.rst
+++ b/docs/guides/kernel_library_guide.rst
@@ -94,6 +94,14 @@ However, the tensors allocated by :cpp:func:`tvm::ffi::Tensor::FromNDAlloc` only
 
 But in the scenario of linked runtime libraries and C++ applications, the libraries stay alive
 globally throughout the entire lifetime of the process, so :cpp:func:`tvm::ffi::Tensor::FromNDAlloc`
 works well there without the use-after-delete issue above. In general, however,
 :cpp:func:`tvm::ffi::Tensor::FromEnvAlloc` is free of this issue and is the **recommended**
 choice in practice.
 
+
+FromNDAllocStrided
+^^^^^^^^^^^^^^^^^^
+
+:cpp:func:`tvm::ffi::Tensor::FromNDAllocStrided` can be used to create a tensor with a custom memory allocator and strided layout (e.g. column major layout).
+Note that for tensor memory that will be returned from the kernel library to the caller, we instead recommend using :cpp:func:`tvm::ffi::Tensor::FromEnvAlloc`
+followed by :cpp:func:`tvm::ffi::Tensor::as_strided` to create a strided view of the tensor.
+
 FromDLPack
 ^^^^^^^^^^
 
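
For illustration, a minimal sketch of the allocator path described in the docs above (CPUNDAlloc
is the toy CPU allocator used in tests/cpp/test_tensor.cc; substitute your own NDAllocator):

    // Allocate a 2x3 float32 CPU tensor laid out column-major (strides are in elements).
    tvm::ffi::Tensor t = tvm::ffi::Tensor::FromNDAllocStrided(
        CPUNDAlloc(), tvm::ffi::Shape({2, 3}), tvm::ffi::Shape({1, 2}),
        DLDataType({kDLFloat, 32, 1}), DLDevice({kDLCPU, 0}));
    // For memory that is handed back to the caller, the guide instead recommends
    // FromEnvAlloc followed by as_strided; a {3, 2} view with strides {1, 3} is the
    // transpose of a contiguous 2x3 tensor.
    tvm::ffi::Tensor view = t.as_strided(tvm::ffi::Shape({3, 2}), tvm::ffi::Shape({1, 3}));
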
diff --git a/include/tvm/ffi/c_api.h b/include/tvm/ffi/c_api.h
index 6dde4fc..7548899 100644
--- a/include/tvm/ffi/c_api.h
+++ b/include/tvm/ffi/c_api.h
@@ -674,6 +674,20 @@ TVM_FFI_DLL int TVMFFITensorFromDLPackVersioned(DLManagedTensorVersioned* from,
  */
 TVM_FFI_DLL int TVMFFITensorToDLPackVersioned(TVMFFIObjectHandle from,
                                               DLManagedTensorVersioned** out);
+
+/*!
+ * \brief Create a Tensor view from source using metadata in the prototype while retaining the
+ * source tensor.
+ * \param source The source tensor whose data memory will be shared by the view.
+ * \param prototype The prototype DLTensor that contains the metadata for the view.
+ * \param out The output Tensor handle.
+ * \return 0 on success, nonzero on failure.
+ * \note This function is unsafe and the caller must ensure the prototype is valid and that
+ *       the prototype's data pointer points to memory owned by the source tensor. The callee
+ *       allocates shape and strides arrays in the output tensor and copies them from the prototype.
+ */
+TVM_FFI_DLL int TVMFFITensorCreateUnsafeView(TVMFFIObjectHandle source, const DLTensor* prototype,
+                                             TVMFFIObjectHandle* out);
 //---------------------------------------------------------------
 // Section: string/bytes support APIs.
 // These APIs are used to simplify the string/bytes construction
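
At the C ABI level, the new entry point can be exercised roughly as follows (a minimal sketch;
`src` is a tensor handle and `src_dl` a pointer to its DLTensor metadata, both assumed to have
been obtained elsewhere):

    // Build a prototype describing a 3x2 transposed view over a 2x3 source tensor.
    int64_t view_shape[2] = {3, 2};
    int64_t view_strides[2] = {1, 3};
    DLTensor proto = *src_dl;      // reuse data pointer, dtype, device, byte_offset of the source
    proto.ndim = 2;
    proto.shape = view_shape;      // the callee copies shape/strides into the new tensor object
    proto.strides = view_strides;
    TVMFFIObjectHandle view = nullptr;
    if (TVMFFITensorCreateUnsafeView(src, &proto, &view) != 0) {
      // handle the error; on success the view retains a reference to `src`
    }
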
diff --git a/include/tvm/ffi/container/tensor.h b/include/tvm/ffi/container/tensor.h
index 2155924..6d7e637 100644
--- a/include/tvm/ffi/container/tensor.h
+++ b/include/tvm/ffi/container/tensor.h
@@ -32,6 +32,7 @@
 
 #include <atomic>
 #include <memory>
+#include <optional>
 #include <string>
 #include <utility>
 
@@ -202,6 +203,19 @@ class TensorObjFromNDAlloc : public TensorObj {
     alloc_.AllocData(static_cast<DLTensor*>(this), std::forward<ExtraArgs>(extra_args)...);
   }
 
+  template <typename... ExtraArgs>
+  TensorObjFromNDAlloc(TNDAlloc alloc, const DLTensor& prototype, ExtraArgs&&... extra_args)
+      : alloc_(alloc) {
+    *static_cast<DLTensor*>(this) = prototype;
+    this->shape = reinterpret_cast<int64_t*>(reinterpret_cast<char*>(this) + sizeof(Self));
+    this->strides = this->shape + prototype.ndim;
+    TVM_FFI_ICHECK_NOTNULL(prototype.strides);
+    std::copy(prototype.shape, prototype.shape + prototype.ndim, this->shape);
+    std::copy(prototype.strides, prototype.strides + prototype.ndim, this->strides);
+    // call allocator to alloc data
+    alloc_.AllocData(static_cast<DLTensor*>(this), std::forward<ExtraArgs>(extra_args)...);
+  }
+
   ~TensorObjFromNDAlloc() { alloc_.FreeData(static_cast<DLTensor*>(this)); }
 
  private:
@@ -348,6 +362,40 @@ class Tensor : public ObjectRef {
   * \return True if the Tensor data is aligned to the given alignment, false otherwise.
   */
  bool IsAligned(size_t alignment) const { return tvm::ffi::IsAligned(*get(), alignment); }
+
+  /*!
+   * \brief Create a new Tensor as a strided view of the current Tensor.
+   * \param shape The shape of the new Tensor.
+   * \param strides The strides of the new Tensor.
+   * \param element_offset The element offset of the new Tensor in units of dtype elements.
+   * \return The new Tensor.
+   * \note element_offset is in units of dtype elements, not bytes.
+   */
+  Tensor as_strided(ShapeView shape, ShapeView strides,
+                    std::optional<int64_t> element_offset = std::nullopt) const {
+    DLTensor prototype;
+    prototype = *static_cast<const DLTensor*>(get());
+    prototype.shape = const_cast<int64_t*>(shape.data());
+    prototype.ndim = static_cast<int>(shape.size());
+    prototype.strides = const_cast<int64_t*>(strides.data());
+    int64_t elem_offset_as_i64 = element_offset.value_or(0);
+
+    TVM_FFI_ICHECK_GE(elem_offset_as_i64, 0);
+    prototype.byte_offset += GetDataSize(static_cast<size_t>(elem_offset_as_i64), prototype.dtype);
+
+    if (prototype.byte_offset != 0 && IsDirectAddressDevice(prototype.device)) {
+      // If the device supports direct addressing, we can simply add the byte offset to the data pointer.
+      prototype.data =
+          reinterpret_cast<void*>(reinterpret_cast<char*>(prototype.data) + prototype.byte_offset);
+      prototype.byte_offset = 0;
+    }
+
+    TVMFFIObjectHandle out;
+    Object* obj_handle = const_cast<TensorObj*>(get());
+    TVM_FFI_CHECK_SAFE_CALL(TVMFFITensorCreateUnsafeView(obj_handle, &prototype, &out));
+    return Tensor(
+        details::ObjectUnsafe::ObjectPtrFromOwned<TensorObj>(static_cast<TVMFFIObject*>(out)));
+  }
   /*!
    * \brief Create a Tensor from a NDAllocator.
    *
@@ -417,6 +465,41 @@ class Tensor : public ObjectRef {
         num_extra_i64_at_tail, alloc, shape, dtype, device,
         std::forward<ExtraArgs>(extra_args)...));
   }
+
+  /*!
+   * \brief A variant of FromNDAlloc that allows explicitly passing strides.
+   *
+   * \note The caller must ensure that the strides are well-defined
+   *       with respect to the allocated compact shape.
+   *
+   * \param alloc The NDAllocator.
+   * \param shape The shape of the Tensor.
+   * \param strides The strides of the Tensor.
+   * \param dtype The data type of the Tensor.
+   * \param device The device of the Tensor.
+   * \param extra_args Extra arguments to be forwarded to TNDAlloc.
+   * \return The created Tensor.
+   * \tparam TNDAlloc The type of the NDAllocator, implements Alloc and Free.
+   * \tparam ExtraArgs Extra arguments to be passed to Alloc.
+   */
+  template <typename TNDAlloc, typename... ExtraArgs>
+  static Tensor FromNDAllocStrided(TNDAlloc alloc, ffi::ShapeView shape, ffi::ShapeView strides,
+                                   DLDataType dtype, DLDevice device, ExtraArgs&&... extra_args) {
+    TVM_FFI_CHECK(shape.size() == strides.size(), ValueError)
+        << "shape and strides must have the same size.";
+    // allocate shape and strides in place after the data structure (hence the factor of 2)
+    size_t num_extra_i64_at_tail = shape.size() * 2;
+    DLTensor prototype;
+    prototype.data = nullptr;
+    prototype.device = device;
+    prototype.dtype = dtype;
+    prototype.shape = const_cast<int64_t*>(shape.data());
+    prototype.ndim = static_cast<int>(shape.size());
+    prototype.strides = const_cast<int64_t*>(strides.data());
+    prototype.byte_offset = 0;
+    return Tensor(make_inplace_array_object<details::TensorObjFromNDAlloc<TNDAlloc>, int64_t>(
+        num_extra_i64_at_tail, alloc, prototype, std::forward<ExtraArgs>(extra_args)...));
+  }
   /*!
    * \brief Create a Tensor from the TVMFFIEnvTensorAlloc API
    *
@@ -704,17 +787,54 @@ class TensorView {
   * \brief This function redirects to ndim().
    * \return The number of dimensions in the Tensor.
    */
-  inline int32_t dim() { return ndim(); }
+  int32_t dim() const { return ndim(); }
   /*!
   * \brief This function redirects to shape().
    * \return The shape of the Tensor.
    */
-  inline ShapeView sizes() const { return shape(); }
+  ShapeView sizes() const { return shape(); }
   /*!
   * \brief This function redirects to IsContiguous().
    * \return True if the Tensor is contiguous, false otherwise.
    */
-  inline bool is_contiguous() const { return IsContiguous(); }
+  bool is_contiguous() const { return IsContiguous(); }
+  /*!
+   * \brief Create a new TensorView as a strided view of the current TensorView.
+   *
+   * Use this function with extreme caution. The user must ensure that the shape and strides
+   * arrays, as well as the data pointer, remain valid for the lifetime of the returned TensorView.
+   *
+   * One common anti-pattern is to create temporary shape/strides arrays and pass them in, but
+   * then deallocate the temporary arrays immediately after the call to as_strided, which
+   * causes the returned TensorView to point to invalid memory.
+   *
+   * \param shape The shape of the new TensorView.
+   * \param strides The strides of the new TensorView.
+   * \param element_offset The element offset of the new TensorView in units of dtype elements,
+   *        not bytes.
+   * \return The new TensorView.
+   *
+   * \note The caller must ensure that the shape and strides arrays remain valid for the lifetime
+   *       of the returned TensorView.
+   */
+  TensorView as_strided(ShapeView shape, ShapeView strides,
+                        std::optional<int64_t> element_offset = std::nullopt) const {
+    DLTensor prototype = tensor_;
+    prototype.shape = const_cast<int64_t*>(shape.data());
+    prototype.ndim = static_cast<int>(shape.size());
+    prototype.strides = const_cast<int64_t*>(strides.data());
+    TVM_FFI_ICHECK_EQ(shape.size(), strides.size());
+    int64_t elem_offset_as_i64 = element_offset.value_or(0);
+    TVM_FFI_ICHECK_GE(elem_offset_as_i64, 0);
+    prototype.byte_offset += GetDataSize(static_cast<size_t>(elem_offset_as_i64), prototype.dtype);
+    if (prototype.byte_offset != 0 && IsDirectAddressDevice(prototype.device)) {
+      // If the device supports direct addressing, we can simply add the byte offset to the data pointer.
+      prototype.data =
+          reinterpret_cast<void*>(reinterpret_cast<char*>(prototype.data) + prototype.byte_offset);
+      prototype.byte_offset = 0;
+    }
+    return TensorView(&prototype);
+  }
 
  private:
   DLTensor tensor_;
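
The lifetime rule for TensorView::as_strided above is easy to trip over; the sketch below
(hypothetical helper name, CPU float tensor assumed) shows the safe pattern of keeping the
shape/strides storage alive next to the view:

    void UseTransposedView(tvm::ffi::TensorView src) {
      // OK: `shape` and `strides` outlive `view`, so the view stays valid while it is used.
      tvm::ffi::Shape shape = {3, 2};
      tvm::ffi::Shape strides = {1, 3};
      tvm::ffi::TensorView view = src.as_strided(shape, strides);
      // Anti-pattern: src.as_strided({3, 2}, {1, 3}) would leave the view pointing at
      // temporaries that are destroyed at the end of the full expression.
      (void)view;  // ... use the view here ...
    }

Tensor::as_strided, in contrast, copies shape and strides into the newly created object via
TVMFFITensorCreateUnsafeView, so its arguments only need to live for the duration of the call.
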
diff --git a/src/ffi/tensor.cc b/src/ffi/tensor.cc
index d408280..2033690 100644
--- a/src/ffi/tensor.cc
+++ b/src/ffi/tensor.cc
@@ -47,6 +47,34 @@ TVM_FFI_STATIC_INIT_BLOCK() {
 }  // namespace ffi
 }  // namespace tvm
 
+int TVMFFITensorCreateUnsafeView(TVMFFIObjectHandle source, const DLTensor* prototype,
+                                 TVMFFIObjectHandle* out) {
+  TVM_FFI_SAFE_CALL_BEGIN();
+  tvm::ffi::ObjectPtr<tvm::ffi::TensorObj> source_tensor =
+      tvm::ffi::details::ObjectUnsafe::ObjectPtrFromUnowned<tvm::ffi::TensorObj>(
+          static_cast<tvm::ffi::Object*>(source));
+
+  // Allocator that installs the provided data pointer and keeps the source tensor alive.
+  class ViewNDAlloc {
+   public:
+    explicit ViewNDAlloc(tvm::ffi::ObjectPtr<tvm::ffi::TensorObj> tensor) : tensor_(tensor) {}
+    void AllocData(DLTensor* tensor, void* data_ptr) { tensor->data = data_ptr; }
+    void FreeData(DLTensor* tensor) {}
+
+   private:
+    tvm::ffi::ObjectPtr<tvm::ffi::TensorObj> tensor_;
+  };
+
+  void* source_data_ptr = prototype->data;
+  size_t num_extra_i64_at_tail = prototype->ndim * 2;
+  ViewNDAlloc alloc(source_tensor);
+  tvm::ffi::Tensor ret_tensor(
+      tvm::ffi::make_inplace_array_object<tvm::ffi::details::TensorObjFromNDAlloc<ViewNDAlloc>,
+                                          int64_t>(num_extra_i64_at_tail, alloc, *prototype,
+                                                   source_data_ptr));
+  *out = tvm::ffi::details::ObjectUnsafe::MoveObjectRefToTVMFFIObjectPtr(std::move(ret_tensor));
+  TVM_FFI_SAFE_CALL_END();
+}
+
 int TVMFFITensorFromDLPack(DLManagedTensor* from, int32_t min_alignment, int32_t require_contiguous,
                            TVMFFIObjectHandle* out) {
   TVM_FFI_SAFE_CALL_BEGIN();
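
For reference, the addressing rule such a strided view encodes follows the DLPack convention
(strides counted in elements, byte_offset in bytes). A small illustrative helper, not part of
this patch, makes it explicit; it assumes t.strides is non-null, which holds for views created
here since the prototype's strides are checked and copied:

    // Byte address of element idx[0..ndim) of a (possibly strided) DLTensor.
    inline const char* ElementAddr(const DLTensor& t, const int64_t* idx) {
      int64_t elem = 0;
      for (int i = 0; i < t.ndim; ++i) elem += idx[i] * t.strides[i];
      const int64_t elem_bytes = (t.dtype.bits * t.dtype.lanes + 7) / 8;
      return static_cast<const char*>(t.data) + t.byte_offset + elem * elem_bytes;
    }
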
diff --git a/tests/cpp/test_tensor.cc b/tests/cpp/test_tensor.cc
index 052a8ec..ac3dbdb 100644
--- a/tests/cpp/test_tensor.cc
+++ b/tests/cpp/test_tensor.cc
@@ -32,6 +32,11 @@ inline Tensor Empty(const Shape& shape, DLDataType dtype, DLDevice device) {
   return Tensor::FromNDAlloc(CPUNDAlloc(), shape, dtype, device);
 }
 
+inline Tensor EmptyStrided(const Shape& shape, const Shape& strides, DLDataType dtype,
+                           DLDevice device) {
+  return Tensor::FromNDAllocStrided(CPUNDAlloc(), shape, strides, dtype, device);
+}
+
 int TestEnvTensorAllocator(DLTensor* prototype, TVMFFIObjectHandle* out) {
   Shape shape(prototype->shape, prototype->shape + prototype->ndim);
   Tensor nd = Empty(shape, prototype->dtype, prototype->device);
@@ -74,6 +79,16 @@ TEST(Tensor, Basic) {
 
   EXPECT_EQ(nd.IsContiguous(), true);
   EXPECT_EQ(nd2.use_count(), 3);
+
+  Tensor nd3 = EmptyStrided({2, 3}, {1, 2}, DLDataType({kDLFloat, 32, 1}), DLDevice({kDLCPU, 0}));
+  Shape shape3 = nd3.shape();
+  Shape strides3 = nd3.strides();
+  EXPECT_EQ(shape3.size(), 2);
+  EXPECT_EQ(shape3[0], 2);
+  EXPECT_EQ(shape3[1], 3);
+  EXPECT_EQ(strides3.size(), 2);
+  EXPECT_EQ(strides3[0], 1);
+  EXPECT_EQ(strides3[1], 2);
 }
 
 TEST(Tensor, DLPack) {
@@ -192,4 +207,158 @@ TEST(Tensor, TensorView) {
   EXPECT_EQ(tensor_view2.dtype().lanes, 1);
 }
 
+TEST(Tensor, TensorViewAsStrided) {
+  // Create a base tensor with shape [2, 3] = 6 elements
+  Tensor tensor = Empty({2, 3}, DLDataType({kDLFloat, 32, 1}), DLDevice({kDLCPU, 0}));
+
+  // Fill with sequential values: [0, 1, 2, 3, 4, 5]
+  float* data = reinterpret_cast<float*>(tensor.data_ptr());
+  for (int64_t i = 0; i < tensor.numel(); ++i) {
+    data[i] = static_cast<float>(i);
+  }
+
+  TensorView tensor_view = tensor;
+  void* original_data_ptr = tensor_view.data_ptr();
+  EXPECT_EQ(tensor_view.byte_offset(), 0);
+
+  // Create a strided view with shape [3, 2] and custom strides
+  // Use local variables to ensure they stay in scope for the TensorView
+  Shape new_shape = {3, 2};
+  Shape new_strides = {1, 3};
+  TensorView strided_view = tensor_view.as_strided(new_shape, new_strides);
+
+  // Verify the view has correct shape and strides
+  EXPECT_EQ(strided_view.shape().size(), 2);
+  EXPECT_EQ(strided_view.shape()[0], 3);
+  EXPECT_EQ(strided_view.shape()[1], 2);
+  EXPECT_EQ(strided_view.strides().size(), 2);
+  EXPECT_EQ(strided_view.strides()[0], 1);
+  EXPECT_EQ(strided_view.strides()[1], 3);
+
+  // Verify the view shares the same underlying data pointer (no offset)
+  EXPECT_EQ(strided_view.data_ptr(), original_data_ptr);
+  EXPECT_EQ(strided_view.byte_offset(), 0);
+  EXPECT_EQ(strided_view.dtype(), tensor_view.dtype());
+
+  // Test with element_offset - for float32, 1 element = 4 bytes
+  Shape offset_shape = {2, 2};
+  Shape offset_strides = {3, 1};
+  int64_t element_offset = 1;
+  TensorView offset_view = tensor_view.as_strided(offset_shape, offset_strides, element_offset);
+
+  EXPECT_EQ(offset_view.shape().size(), 2);
+  EXPECT_EQ(offset_view.shape()[0], 2);
+  EXPECT_EQ(offset_view.shape()[1], 2);
+  EXPECT_EQ(offset_view.strides().size(), 2);
+  EXPECT_EQ(offset_view.strides()[0], 3);
+  EXPECT_EQ(offset_view.strides()[1], 1);
+
+  // For CPU (direct address device), byte_offset should be added to data pointer
+  // and byte_offset field should be 0
+  // element_offset=1 for float32 = 4 bytes
+  size_t expected_byte_offset =
+      GetDataSize(static_cast<size_t>(element_offset), DLDataType({kDLFloat, 32, 1}));
+  EXPECT_EQ(expected_byte_offset, 4);  // 1 element * 32 bits / 8 = 4 bytes
+
+  // The data pointer should be advanced by 4 bytes (1 float element)
+  void* expected_offset_ptr = reinterpret_cast<char*>(original_data_ptr) + expected_byte_offset;
+  EXPECT_EQ(offset_view.data_ptr(), expected_offset_ptr);
+  EXPECT_EQ(offset_view.byte_offset(), 0);  // Should be 0 for direct address devices
+
+  // Verify data access through the offset view
+  float* offset_data = reinterpret_cast<float*>(offset_view.data_ptr());
+  EXPECT_EQ(offset_data[0 * 3 + 0 * 1], 1.0f);  // Points to data[1]
+  EXPECT_EQ(offset_data[1 * 3 + 0 * 1], 4.0f);  // Points to data[4]
+
+  // Test with larger element_offset
+  int64_t element_offset2 = 2;
+  Shape offset_shape2 = {1, 2};
+  Shape offset_strides2 = {3, 1};
+  TensorView offset_view2 = tensor_view.as_strided(offset_shape2, offset_strides2, element_offset2);
+  size_t expected_byte_offset2 =
+      GetDataSize(static_cast<size_t>(element_offset2), DLDataType({kDLFloat, 32, 1}));
+  EXPECT_EQ(expected_byte_offset2, 8);  // 2 elements * 32 bits / 8 = 8 bytes
+  void* expected_offset_ptr2 = reinterpret_cast<char*>(original_data_ptr) + expected_byte_offset2;
+  EXPECT_EQ(offset_view2.data_ptr(), expected_offset_ptr2);
+  EXPECT_EQ(offset_view2.byte_offset(), 0);
+
+  float* offset_data2 = reinterpret_cast<float*>(offset_view2.data_ptr());
+  EXPECT_EQ(offset_data2[0 * 3 + 0 * 1], 2.0f);  // Points to data[2]
+}
+
+TEST(Tensor, AsStrided) {
+  // Create a base tensor with shape [2, 3] = 6 elements
+  Tensor tensor = Empty({2, 3}, DLDataType({kDLFloat, 32, 1}), DLDevice({kDLCPU, 0}));
+
+  // Fill with sequential values: [0, 1, 2, 3, 4, 5]
+  float* data = reinterpret_cast<float*>(tensor.data_ptr());
+  for (int64_t i = 0; i < tensor.numel(); ++i) {
+    data[i] = static_cast<float>(i);
+  }
+
+  void* original_data_ptr = tensor.data_ptr();
+  EXPECT_EQ(tensor.byte_offset(), 0);
+
+  // Create a strided view with shape [3, 2] and custom strides
+  Shape new_shape = {3, 2};
+  Shape new_strides = {1, 3};
+  Tensor strided_view = tensor.as_strided(new_shape, new_strides);
+
+  // Verify the view has correct shape and strides
+  EXPECT_EQ(strided_view.shape().size(), 2);
+  EXPECT_EQ(strided_view.shape()[0], 3);
+  EXPECT_EQ(strided_view.shape()[1], 2);
+  EXPECT_EQ(strided_view.strides().size(), 2);
+  EXPECT_EQ(strided_view.strides()[0], 1);
+  EXPECT_EQ(strided_view.strides()[1], 3);
+
+  // Verify the view shares the same underlying data pointer (no offset)
+  EXPECT_EQ(strided_view.data_ptr(), original_data_ptr);
+  EXPECT_EQ(strided_view.byte_offset(), 0);
+  EXPECT_EQ(strided_view.dtype(), tensor.dtype());
+
+  // Test with element_offset - for float32, 1 element = 4 bytes
+  Shape offset_shape = {2, 2};
+  Shape offset_strides = {3, 1};
+  int64_t element_offset = 1;
+  Tensor offset_view = tensor.as_strided(offset_shape, offset_strides, element_offset);
+
+  EXPECT_EQ(offset_view.shape().size(), 2);
+  EXPECT_EQ(offset_view.shape()[0], 2);
+  EXPECT_EQ(offset_view.shape()[1], 2);
+  EXPECT_EQ(offset_view.strides().size(), 2);
+  EXPECT_EQ(offset_view.strides()[0], 3);
+  EXPECT_EQ(offset_view.strides()[1], 1);
+
+  // For CPU (direct address device), byte_offset should be added to data pointer
+  // and byte_offset field should be 0
+  // element_offset=1 for float32 = 4 bytes
+  size_t expected_byte_offset =
+      GetDataSize(static_cast<size_t>(element_offset), DLDataType({kDLFloat, 32, 1}));
+  EXPECT_EQ(expected_byte_offset, 4);  // 1 element * 32 bits / 8 = 4 bytes
+
+  // The data pointer should be advanced by 4 bytes (1 float element)
+  void* expected_offset_ptr = reinterpret_cast<char*>(original_data_ptr) + expected_byte_offset;
+  EXPECT_EQ(offset_view.data_ptr(), expected_offset_ptr);
+  EXPECT_EQ(offset_view.byte_offset(), 0);  // Should be 0 for direct address devices
+
+  // Verify data access through the offset view
+  float* offset_data = reinterpret_cast<float*>(offset_view.data_ptr());
+  EXPECT_EQ(offset_data[0 * 3 + 0 * 1], 1.0f);  // Points to data[1]
+  EXPECT_EQ(offset_data[1 * 3 + 0 * 1], 4.0f);  // Points to data[4]
+
+  // Test with larger element_offset
+  int64_t element_offset2 = 2;
+  Tensor offset_view2 = tensor.as_strided({1, 2}, {3, 1}, element_offset2);
+  size_t expected_byte_offset2 =
+      GetDataSize(static_cast<size_t>(element_offset2), DLDataType({kDLFloat, 32, 1}));
+  EXPECT_EQ(expected_byte_offset2, 8);  // 2 elements * 32 bits / 8 = 8 bytes
+  void* expected_offset_ptr2 = reinterpret_cast<char*>(original_data_ptr) + expected_byte_offset2;
+  EXPECT_EQ(offset_view2.data_ptr(), expected_offset_ptr2);
+  EXPECT_EQ(offset_view2.byte_offset(), 0);
+
+  float* offset_data2 = reinterpret_cast<float*>(offset_view2.data_ptr());
+  EXPECT_EQ(offset_data2[0 * 3 + 0 * 1], 2.0f);  // Points to data[2]
+}
+
 }  // namespace
