This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 45c654c573 [NDArray] Allow creating a view from a strided array (#15132)
45c654c573 is described below

commit 45c654c5733d5ae4201e53fabc6aa29ebe63572f
Author: masahi <[email protected]>
AuthorDate: Thu Jun 22 15:20:31 2023 +0900

    [NDArray] Allow creating a view from a strided array (#15132)
    
    * Allow creating a view from a strided array
    
    * use IsContiguous
---
 src/runtime/ndarray.cc                       | 25 ++++++++++++++++++++++++-
 tests/python/unittest/test_runtime_dlpack.py | 13 +++++++++++++
 2 files changed, 37 insertions(+), 1 deletion(-)

diff --git a/src/runtime/ndarray.cc b/src/runtime/ndarray.cc
index c7bfefa9a8..b7153ab50f 100644
--- a/src/runtime/ndarray.cc
+++ b/src/runtime/ndarray.cc
@@ -181,7 +181,30 @@ struct NDArray::Internal {
 
 NDArray NDArray::CreateView(ShapeTuple shape, DLDataType dtype) {
   ICHECK(data_ != nullptr);
-  ICHECK(get_mutable()->dl_tensor.strides == nullptr) << "Can only create view for compact tensor";
+
+  const DLTensor& orig = get_mutable()->dl_tensor;
+  ICHECK(IsContiguous()) << "Can only create view for compact tensor, but found strides " <<
+      [&orig]() {
+        std::stringstream ss;
+        ss << "[";
+        for (int i = 0; i < orig.ndim; i++) {
+          if (i) ss << ", ";
+          ss << orig.strides[i];
+        }
+        ss << "]";
+        return ss.str();
+      }() << ", for shape "
+                         << [&]() {
+                              std::stringstream ss;
+                              ss << "[";
+                              for (int i = 0; i < orig.ndim; i++) {
+                                if (i) ss << ", ";
+                                ss << orig.shape[i];
+                              }
+                              ss << "]";
+                              return ss.str();
+                            }();
+
   NDArray ret = Internal::Create(shape, dtype, get_mutable()->dl_tensor.device);
   ret.get_mutable()->dl_tensor.byte_offset = this->get_mutable()->dl_tensor.byte_offset;
   size_t curr_size = GetDataSize(this->get_mutable()->dl_tensor);
diff --git a/tests/python/unittest/test_runtime_dlpack.py b/tests/python/unittest/test_runtime_dlpack.py
index 3f13e2e5fe..cf12c89cdd 100644
--- a/tests/python/unittest/test_runtime_dlpack.py
+++ b/tests/python/unittest/test_runtime_dlpack.py
@@ -48,5 +48,18 @@ def test_from_dlpack_shape_one():
     tvm.testing.assert_allclose(c.numpy(), a.numpy() + b.numpy())
 
 
[email protected]_package("torch")
+def test_from_dlpack_strided():
+    import torch
+    from torch.utils.dlpack import to_dlpack
+
+    rows = 1
+    inp = torch.randn(rows, 16)
+    a = tvm.runtime.ndarray.from_dlpack(to_dlpack(inp))
+    view = a._create_view((2, 8))
+
+    np.testing.assert_equal(inp.numpy().reshape(2, 8), view.numpy())
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to