This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new 4fefeb0 [TENSOR] Allow strides to be null for zerodim case (#82)
4fefeb0 is described below
commit 4fefeb0f5913fc41cf860f517b9320f1bf1d0e98
Author: Tianqi Chen <[email protected]>
AuthorDate: Wed Oct 1 17:38:26 2025 -0400
[TENSOR] Allow strides to be null for zerodim case (#82)
This PR fixes an edge case where the strides may be null when the
tensor is zero-dimensional (ndim == 0).
---
include/tvm/ffi/container/tensor.h | 10 +++++-----
1 file changed, 5 insertions(+), 5 deletions(-)
diff --git a/include/tvm/ffi/container/tensor.h b/include/tvm/ffi/container/tensor.h
index 5197d4e..f451aaf 100644
--- a/include/tvm/ffi/container/tensor.h
+++ b/include/tvm/ffi/container/tensor.h
@@ -257,7 +257,7 @@ class Tensor : public ObjectRef {
*/
ShapeView strides() const {
const TensorObj* obj = get();
- TVM_FFI_ICHECK(obj->strides != nullptr);
+ TVM_FFI_ICHECK(obj->strides != nullptr || obj->ndim == 0);
return ShapeView(obj->strides, obj->ndim);
}
@@ -367,7 +367,7 @@ class Tensor : public ObjectRef {
throw ffi::Error(error_context.kind, error_context.message,
TVMFFIBacktrace(__FILE__, __LINE__, __func__, 0));
}
- if (tensor->dl_tensor.strides != nullptr) {
+ if (tensor->dl_tensor.strides != nullptr || tensor->dl_tensor.ndim == 0) {
return
Tensor(make_object<details::TensorObjFromDLPack<DLManagedTensorVersioned>>(
tensor, /*extra_strides_at_tail=*/false));
} else {
@@ -394,7 +394,7 @@ class Tensor : public ObjectRef {
if (require_contiguous && !ffi::IsContiguous(tensor->dl_tensor)) {
TVM_FFI_THROW(RuntimeError) << "FromDLPack: Tensor is not contiguous.";
}
- if (tensor->dl_tensor.strides != nullptr) {
+ if (tensor->dl_tensor.strides != nullptr || tensor->dl_tensor.ndim == 0) {
return Tensor(make_object<details::TensorObjFromDLPack<DLManagedTensor>>(
tensor, /*extra_strides_at_tail=*/false));
} else {
@@ -423,7 +423,7 @@ class Tensor : public ObjectRef {
if (tensor->flags & DLPACK_FLAG_BITMASK_IS_SUBBYTE_TYPE_PADDED) {
TVM_FFI_THROW(RuntimeError) << "Subbyte type padded is not yet
supported";
}
- if (tensor->dl_tensor.strides != nullptr) {
+ if (tensor->dl_tensor.strides != nullptr || tensor->dl_tensor.ndim == 0) {
return
Tensor(make_object<details::TensorObjFromDLPack<DLManagedTensorVersioned>>(
tensor, /*extra_strides_at_tail=*/false));
} else {
@@ -545,7 +545,7 @@ class TensorView {
* \return The strides of the Tensor.
*/
ShapeView strides() const {
- TVM_FFI_ICHECK(tensor_.strides != nullptr);
+ TVM_FFI_ICHECK(tensor_.strides != nullptr || tensor_.ndim == 0);
return ShapeView(tensor_.strides, tensor_.ndim);
}