tkonolige commented on a change in pull request #8408:
URL: https://github.com/apache/tvm/pull/8408#discussion_r664877945
##########
File path: tests/python/unittest/test_tvmscript_ops.py
##########
@@ -103,5 +103,46 @@ def test_get_valid_counts_script_func():
_check_get_valid_counts_with_numpy(f, (1, 2500, 6), 0.0, 0, 1)
+@tvm.script.tir
+def ops_with_buffer_slice_indices(data: ty.handle, index: ty.handle) -> None:
+ data_buf = tir.match_buffer(data, (1, 5, 6), "float16")
+ index_buf = tir.match_buffer(index, (1,), "int32")
+ copy_buf = tir.alloc_buffer((1, 5), "float16")
+
+ with tir.block([1, 5], "init") as [vi, vj]:
+ copy_buf[vi, vj] = data_buf[vi, vj, index_buf[0]]
Review comment:
Can you add the following tests:
- a test that nests buffer accesses three deep, i.e.
`data_buf[index_buf[index_buf[0]]]`
- a test where the inner buffer dtype is not int32 (should fail).
- a test where the inner buffer is a slice, i.e. `data_buf[index_buf[0,1:2]]`
##########
File path: python/tvm/script/node.py
##########
@@ -133,18 +138,27 @@ def check_index(index: Union[int, PrimExpr]):
def __str__(self):
regions: List[str] = []
for s in self.slices:
- if s.stop is None:
- regions.append(str(s.start))
- else:
- regions.append(str(s.start) + ": " + str(s.stop))
+ if isinstance(s, Slice):
+ if s.stop is None:
+ regions.append(str(s.start))
+ else:
+ regions.append(str(s.start) + ": " + str(s.stop))
+ elif isinstance(s, BufferSlice):
+ regions.append(s.buffer.name + "[" + str(s.start) + "]")
return self.buffer.name + "[" + ", ".join(regions) + "]"
def asobject(self) -> BufferLoad:
"""Convert object."""
+ indices: List[PrimExpr] = []
for s in self.slices:
- if s.stop is not None:
- self.report_error("BufferLoad only accepts elementwise access", self.span)
-
- indices = [s.start for s in self.slices]
+ if isinstance(s, Slice):
+ if s.stop is not None:
+ self.report_error("BufferLoad only accepts elementwise access", self.span)
+ indices.append(s.start)
+ elif isinstance(s, BufferSlice):
+ args: List[PrimExpr] = []
+ for idx in s.slices:
+ args.append(idx.start)
Review comment:
Why is the conversion from BufferSlice to BufferLoad necessary
here? Also, it seems like this should be handled by the recursive case.
##########
File path: python/tvm/script/node.py
##########
@@ -133,18 +138,27 @@ def check_index(index: Union[int, PrimExpr]):
def __str__(self):
regions: List[str] = []
for s in self.slices:
- if s.stop is None:
- regions.append(str(s.start))
- else:
- regions.append(str(s.start) + ": " + str(s.stop))
+ if isinstance(s, Slice):
+ if s.stop is None:
+ regions.append(str(s.start))
+ else:
+ regions.append(str(s.start) + ": " + str(s.stop))
+ elif isinstance(s, BufferSlice):
+ regions.append(s.buffer.name + "[" + str(s.start) + "]")
Review comment:
You want this to recurse, so just call str on the BufferSlice.
##########
File path: python/tvm/script/node.py
##########
@@ -114,6 +114,11 @@ def check_index(index: Union[int, PrimExpr]):
check_index(index.start)
check_index(index.stop)
slices.append(index)
+ elif isinstance(index, BufferSlice):
+ for s in index.slices:
+ if isinstance(s, Slice):
+ check_index(s.start)
+ slices.append(index)
Review comment:
I don't think this is correct. I believe you should assume that
the BufferSlice already checked that its indices were correct. Instead, you want
to check whether the buffer dtype is correct (int32).
I am also a little confused as to why this case is not handled by the
PrimExpr case.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]