This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 45c654c573 [NDArray] Allow creating a view from a strided array (#15132)
45c654c573 is described below
commit 45c654c5733d5ae4201e53fabc6aa29ebe63572f
Author: masahi <[email protected]>
AuthorDate: Thu Jun 22 15:20:31 2023 +0900
[NDArray] Allow creating a view from a strided array (#15132)
* Allow creating a view from a strided array
* use IsContiguous
---
src/runtime/ndarray.cc | 25 ++++++++++++++++++++++++-
tests/python/unittest/test_runtime_dlpack.py | 13 +++++++++++++
2 files changed, 37 insertions(+), 1 deletion(-)
diff --git a/src/runtime/ndarray.cc b/src/runtime/ndarray.cc
index c7bfefa9a8..b7153ab50f 100644
--- a/src/runtime/ndarray.cc
+++ b/src/runtime/ndarray.cc
@@ -181,7 +181,30 @@ struct NDArray::Internal {
NDArray NDArray::CreateView(ShapeTuple shape, DLDataType dtype) {
ICHECK(data_ != nullptr);
- ICHECK(get_mutable()->dl_tensor.strides == nullptr) << "Can only create view
for compact tensor";
+
+ const DLTensor& orig = get_mutable()->dl_tensor;
+ ICHECK(IsContiguous()) << "Can only create view for compact tensor, but
found strides " <<
+ [&orig]() {
+ std::stringstream ss;
+ ss << "[";
+ for (int i = 0; i < orig.ndim; i++) {
+ if (i) ss << ", ";
+ ss << orig.strides[i];
+ }
+ ss << "]";
+ return ss.str();
+ }() << ", for shape "
+ << [&]() {
+ std::stringstream ss;
+ ss << "[";
+ for (int i = 0; i < orig.ndim; i++) {
+ if (i) ss << ", ";
+ ss << orig.shape[i];
+ }
+ ss << "]";
+ return ss.str();
+ }();
+
NDArray ret = Internal::Create(shape, dtype,
get_mutable()->dl_tensor.device);
ret.get_mutable()->dl_tensor.byte_offset =
this->get_mutable()->dl_tensor.byte_offset;
size_t curr_size = GetDataSize(this->get_mutable()->dl_tensor);
diff --git a/tests/python/unittest/test_runtime_dlpack.py b/tests/python/unittest/test_runtime_dlpack.py
index 3f13e2e5fe..cf12c89cdd 100644
--- a/tests/python/unittest/test_runtime_dlpack.py
+++ b/tests/python/unittest/test_runtime_dlpack.py
@@ -48,5 +48,18 @@ def test_from_dlpack_shape_one():
tvm.testing.assert_allclose(c.numpy(), a.numpy() + b.numpy())
+@tvm.testing.requires_package("torch")
+def test_from_dlpack_strided():  # _create_view on a DLPack-imported torch tensor
+    import torch
+    from torch.utils.dlpack import to_dlpack
+
+    rows = 1
+    inp = torch.randn(rows, 16)
+    a = tvm.runtime.ndarray.from_dlpack(to_dlpack(inp))
+    view = a._create_view((2, 8))  # (2, 8) target shape assumes rows == 1
+
+    np.testing.assert_equal(inp.numpy().reshape(2, 8), view.numpy())
+
+
if __name__ == "__main__":
tvm.testing.main()