Lunderberg commented on code in PR #12652:
URL: https://github.com/apache/tvm/pull/12652#discussion_r959701621


##########
tests/python/unittest/test_tir_transform_flatten_buffer.py:
##########
@@ -20,211 +20,219 @@
 from tvm.script import tir as T
 
 
-def _check(original, transformed):
-    func = original
-    mod = tvm.IRModule.from_expr(func)
-    mod = tvm.tir.transform.FlattenBuffer()(mod)
-    mod = tvm.tir.transform.Simplify()(mod)
-    tvm.ir.assert_structural_equal(mod["main"], transformed, True)
-
-
-@T.prim_func
-def elementwise_func(a: T.handle, c: T.handle) -> None:
-    A = T.match_buffer(a, (16, 16), "float32")
-    C = T.match_buffer(c, (16, 16), "float32")
-    for i in T.serial(0, 16):
-        B_new = T.allocate([1, 16], "float32", "global")
-        for j in T.serial(0, 16):
-            B_new[0, j] = A[i, j] + 1.0
-        for j in T.serial(0, 16):
-            C[i, j] = B_new[0, j] * 2.0
-
-
-@T.prim_func
-def flattened_elementwise_func(a: T.handle, c: T.handle) -> None:
-    A = T.match_buffer(a, 256, "float32")
-    C = T.match_buffer(c, 256, "float32")
-    T.preflattened_buffer(A, (16, 16), dtype="float32", data=A.data)
-    T.preflattened_buffer(C, (16, 16), dtype="float32", data=C.data)
-    for i in T.serial(0, 16):
-        B_new = T.allocate([16], "float32", "global")
-        for j in T.serial(0, 16):
-            B_new[j] = A[((i * 16) + j)] + 1.0
-        for j in T.serial(0, 16):
-            C[((i * 16) + j)] = B_new[j] * 2.0
-
-
-@T.prim_func
-def gpu_func(a: T.handle, c: T.handle) -> None:
-    A = T.match_buffer(a, (16, 16), "float32")
-    C = T.match_buffer(c, (16, 16), "float32")
-
-    i0 = T.env_thread("blockIdx.x")
-    i1 = T.env_thread("threadIdx.x")
-    i2 = T.env_thread("vthread")
-
-    T.launch_thread(i0, 4)
-    T.launch_thread(i1, 2)
-    T.launch_thread(i2, 2)
-    B = T.allocate([1, 16], "float32", "local")
-    for j in range(0, 16):
-        B[0, j] = A[i0 * 4 + i1 * 2 + i2, j] + 1.0
-    for j in range(0, 16):
-        C[i0 * 4 + i1 * 2 + i2, j] = B[0, j] * 2.0
-
-
-@T.prim_func
-def flattened_gpu_func(a: T.handle, c: T.handle) -> None:
-    A = T.match_buffer(a, 256, "float32")
-    C = T.match_buffer(c, 256, "float32")
-    T.preflattened_buffer(A, (16, 16), dtype="float32", data=A.data)
-    T.preflattened_buffer(C, (16, 16), dtype="float32", data=C.data)
-
-    i0 = T.env_thread("blockIdx.x")
-    i1 = T.env_thread("threadIdx.x")
-    i2 = T.env_thread("vthread")
-
-    T.launch_thread(i0, 4)
-    T.launch_thread(i1, 2)
-    T.launch_thread(i2, 2)
-    B = T.allocate([16], "float32", "local")
-    for j in range(0, 16):
-        B[j] = A[i0 * 64 + i1 * 32 + i2 * 16 + j] + 1.0
-    for j in range(0, 16):
-        C[i0 * 64 + i1 * 32 + i2 * 16 + j] = B[j] * 2.0
-
-
-@T.prim_func
-def symbolic_func(a: T.handle, c: T.handle, n: T.int32, m: T.int32) -> None:
-    A = T.match_buffer(a, (n, m), "float32")
-    C = T.match_buffer(c, (n, m), "float32")
-
-    for i in range(0, n):
-        B = T.allocate([m], "float32", "global")
-        for j in range(0, m):
-            B[j] = A[i, j] + 1.0
-        for j in range(0, m):
-            C[i, j] = B[j] * 2.0
-
-
-@T.prim_func
-def flattened_symbolic_func(a: T.handle, c: T.handle, n: T.int32, m: T.int32) 
-> None:
-    A = T.match_buffer(a, n * m, "float32")
-    C = T.match_buffer(c, n * m, "float32")
-    T.preflattened_buffer(A, (n, m), "float32", data=A.data)
-    T.preflattened_buffer(C, (n, m), "float32", data=C.data)
-
-    for i in range(0, n):
-        B = T.allocate([m], "float32", "global")
-        for j in range(0, m):
-            B[j] = A[i * m + j] + 1.0
-        for j in range(0, m):
-            C[i * m + j] = B[j] * 2.0
-
-
-@T.prim_func
-def multi_alloc_func(a: T.handle, d: T.handle) -> None:
-    A = T.match_buffer(a, (4, 32), "float32")
-    D = T.match_buffer(d, (4, 32), "float32")
-
-    for i, j in T.grid(4, 32):
-        B = T.allocate((4, 32), "float32", scope="global")
-        C = T.allocate((4, 32), "float32", scope="global")
-        B[i, j] = A[i, j] + 1.0
-        C[i, j] = A[i, j] + B[i, j]
-        D[i, j] = C[i, j] * 2.0
-
-
-@T.prim_func
-def flattened_multi_alloc_func(a: T.handle, d: T.handle) -> None:
-    A = T.match_buffer(a, 128, "float32")
-    D = T.match_buffer(d, 128, "float32")
-    T.preflattened_buffer(A, (4, 32), "float32", data=A.data)
-    T.preflattened_buffer(D, (4, 32), "float32", data=D.data)
-
-    for i, j in T.grid(4, 32):
-        B = T.allocate([128], "float32", "global")
-        C = T.allocate([128], "float32", "global")
-        B[i * 32 + j] = A[i * 32 + j] + 1.0
-        C[i * 32 + j] = A[i * 32 + j] + B[i * 32 + j]
-        D[i * 32 + j] = C[i * 32 + j] * 2.0
-
-
-@T.prim_func
-def strided_buffer_func(a: T.handle, c: T.handle) -> None:
-    A = T.match_buffer(a, (16, 16), "float32")
-    C = T.match_buffer(c, (16, 16), "float32")
-    for i0 in T.serial(4):
-        B = T.allocate([4, 17], "float32", "global")
-        B_1 = T.buffer_decl([4, 16], dtype="float32", data=B.data, 
strides=[17, 1])
-        for i1, j in T.grid(4, 16):
-            B_1[i1, j] = A[i0 * 4 + i1, j] + 1.0
-        for i1, j in T.grid(4, 16):
-            C[i0 * 4 + i1, j] = B_1[i1, j] * 2.0
-
-
-@T.prim_func
-def flattened_strided_buffer_func(a: T.handle, c: T.handle) -> None:
-    A = T.match_buffer(a, (256,), "float32")
-    C = T.match_buffer(c, (256,), "float32")
-    T.preflattened_buffer(A, [16, 16], dtype="float32", data=A.data)
-    T.preflattened_buffer(C, [16, 16], dtype="float32", data=C.data)
-    for i0 in T.serial(0, 4):
-        B_new = T.allocate([68], "float32", "global")
-        for i1 in T.serial(0, 4):
-            for j in T.serial(0, 16):
-                B_new[i1 * 17 + j] = A[i0 * 64 + i1 * 16 + j] + 1.0
-        for i1 in T.serial(0, 4):
-            for j in T.serial(0, 16):
-                C[i0 * 64 + i1 * 16 + j] = B_new[i1 * 17 + j] * 2.0
-
-
-@T.prim_func
-def boolean_handling_before(a: T.Buffer[10, "bool"], b: T.Buffer[10, "bool"]) 
-> None:
-    for i0 in T.serial(10):
-        b[i0] = a[i0]
-
-
-@T.prim_func
-def boolean_handling_after(a: T.Buffer[10, "int8"], b: T.Buffer[10, "int8"]) 
-> None:
-    T.preflattened_buffer(a, [10], dtype="bool", data=a.data)
-    T.preflattened_buffer(b, [10], dtype="bool", data=b.data)
-    # body
-    for i0 in T.serial(10):
-        b[i0] = T.cast(T.cast(a[i0], "bool"), "int8")
-
-
-def test_elementwise():
-    _check(elementwise_func, flattened_elementwise_func)
-
+class BaseCompare(tvm.testing.CompareBeforeAfter):
+    transform = tvm.transform.Sequential(
+        [
+            tvm.tir.transform.FlattenBuffer(),
+            tvm.tir.transform.Simplify(),
+        ]
+    )
 
-def test_gpu_workload():
-    _check(gpu_func, flattened_gpu_func)
 
+class TestElementwise(BaseCompare):
+    """2-d buffers are flattened to 1-d"""
 
-def test_symbolic_shape():
-    _check(symbolic_func, flattened_symbolic_func)
-
-
-def test_multi_alloc():
-    _check(multi_alloc_func, flattened_multi_alloc_func)
-
-
-def test_strided_buffer():
-    _check(strided_buffer_func, flattened_strided_buffer_func)
-
+    def before(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), 
"float32"]):
+        for i in T.serial(0, 16):
+            B_new = T.alloc_buffer([1, 16], "float32")
+            for j in T.serial(0, 16):
+                B_new[0, j] = A[i, j] + 1.0
+            for j in T.serial(0, 16):
+                C[i, j] = B_new[0, j] * 2.0
 
-def test_lower_te():
-    x = te.placeholder((1,))
-    y = te.compute((1,), lambda i: x[i] + 2)
-    s = te.create_schedule(y.op)
-    orig_mod = tvm.driver.build_module.schedule_to_module(s, [x, y])
-    mod = tvm.tir.transform.FlattenBuffer()(orig_mod)
-    tvm.ir.assert_structural_equal(mod, orig_mod)  # FlattenBuffer should do 
nothing on TE
+    def expected(A: T.Buffer[256, "float32"], C: T.Buffer[256, "float32"]):
+        T.preflattened_buffer(A, (16, 16), dtype="float32", data=A.data)
+        T.preflattened_buffer(C, (16, 16), dtype="float32", data=C.data)
+        for i in T.serial(0, 16):
+            B_new = T.alloc_buffer([16], "float32")
+            for j in T.serial(0, 16):
+                B_new[j] = A[((i * 16) + j)] + 1.0
+            for j in T.serial(0, 16):
+                C[((i * 16) + j)] = B_new[j] * 2.0
+
+
+class TestGPU(BaseCompare):
+    """Buffers allocated inside GPU-specific constructs are ignored.
+
+    These are assumed to be deliberate on the part of the
+    schedule-writer, and are left as-is.
+    """
+
+    def before(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), 
"float32"]):
+
+        i0 = T.env_thread("blockIdx.x")
+        i1 = T.env_thread("threadIdx.x")
+        i2 = T.env_thread("vthread")
+
+        T.launch_thread(i0, 4)
+        T.launch_thread(i1, 2)
+        T.launch_thread(i2, 2)
+        B = T.alloc_buffer([1, 16], "float32", scope="local")
+        for j in range(0, 16):
+            B[0, j] = A[i0 * 4 + i1 * 2 + i2, j] + 1.0
+        for j in range(0, 16):
+            C[i0 * 4 + i1 * 2 + i2, j] = B[0, j] * 2.0
+
+    def expected(A: T.Buffer[256, "float32"], C: T.Buffer[256, "float32"]):
+        T.preflattened_buffer(A, (16, 16), dtype="float32", data=A.data)
+        T.preflattened_buffer(C, (16, 16), dtype="float32", data=C.data)
+
+        i0 = T.env_thread("blockIdx.x")
+        i1 = T.env_thread("threadIdx.x")
+        i2 = T.env_thread("vthread")
+
+        T.launch_thread(i0, 4)
+        T.launch_thread(i1, 2)
+        T.launch_thread(i2, 2)
+        B = T.alloc_buffer([16], "float32", scope="local")
+        for j in range(0, 16):
+            B[j] = A[i0 * 64 + i1 * 32 + i2 * 16 + j] + 1.0
+        for j in range(0, 16):
+            C[i0 * 64 + i1 * 32 + i2 * 16 + j] = B[j] * 2.0
+
+
+class TestSymbolic(BaseCompare):
+    """Dynamically-sized arrrays are flattened"""
+
+    def before(a: T.handle, c: T.handle, n: T.int32, m: T.int32) -> None:
+        A = T.match_buffer(a, (n, m), "float32")
+        C = T.match_buffer(c, (n, m), "float32")
+
+        for i in range(0, n):
+            B = T.alloc_buffer([m], "float32")
+            for j in range(0, m):
+                B[j] = A[i, j] + 1.0
+            for j in range(0, m):
+                C[i, j] = B[j] * 2.0
+
+    def expected(a: T.handle, c: T.handle, n: T.int32, m: T.int32) -> None:
+        A = T.match_buffer(a, n * m, "float32")
+        C = T.match_buffer(c, n * m, "float32")
+        T.preflattened_buffer(A, (n, m), "float32", data=A.data)
+        T.preflattened_buffer(C, (n, m), "float32", data=C.data)
+
+        for i in range(0, n):
+            B = T.alloc_buffer([m], "float32")
+            for j in range(0, m):
+                B[j] = A[i * m + j] + 1.0
+            for j in range(0, m):
+                C[i * m + j] = B[j] * 2.0
+
+
+class TestMultiAlloc(BaseCompare):
+    """If multiple allocations occur, all are flattened."""
+
+    def before(A: T.Buffer[(4, 32), "float32"], D: T.Buffer[(4, 32), 
"float32"]):
+        for i, j in T.grid(4, 32):
+            B = T.alloc_buffer((4, 32), "float32", scope="global")
+            C = T.alloc_buffer((4, 32), "float32", scope="global")
+            B[i, j] = A[i, j] + 1.0
+            C[i, j] = A[i, j] + B[i, j]
+            D[i, j] = C[i, j] * 2.0
+
+    def expected(A: T.Buffer[128, "float32"], D: T.Buffer[128, "float32"]):
+        T.preflattened_buffer(A, (4, 32), "float32", data=A.data)
+        T.preflattened_buffer(D, (4, 32), "float32", data=D.data)
+
+        for i, j in T.grid(4, 32):
+            B = T.alloc_buffer([128], "float32")
+            C = T.alloc_buffer([128], "float32")
+            B[i * 32 + j] = A[i * 32 + j] + 1.0
+            C[i * 32 + j] = A[i * 32 + j] + B[i * 32 + j]
+            D[i * 32 + j] = C[i * 32 + j] * 2.0
+
+
+class TestStrided(BaseCompare):
+    """Indices for flattened buffers use the specified striding."""
+
+    def before(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), 
"float32"]):
+        for i0 in T.serial(4):
+            B = T.alloc_buffer([4, 17], "float32")
+            B_1 = T.buffer_decl([4, 16], dtype="float32", data=B.data, 
strides=[17, 1])
+            for i1, j in T.grid(4, 16):
+                B_1[i1, j] = A[i0 * 4 + i1, j] + 1.0
+            for i1, j in T.grid(4, 16):
+                C[i0 * 4 + i1, j] = B_1[i1, j] * 2.0
+
+    def expected(A: T.Buffer[256, "float32"], C: T.Buffer[256, "float32"]):
+        T.preflattened_buffer(A, [16, 16], dtype="float32", data=A.data)
+        T.preflattened_buffer(C, [16, 16], dtype="float32", data=C.data)
+        for i0 in T.serial(0, 4):
+            B_new = T.alloc_buffer([68], "float32")
+            for i1 in T.serial(0, 4):
+                for j in T.serial(0, 16):
+                    B_new[i1 * 17 + j] = A[i0 * 64 + i1 * 16 + j] + 1.0
+            for i1 in T.serial(0, 4):
+                for j in T.serial(0, 16):
+                    C[i0 * 64 + i1 * 16 + j] = B_new[i1 * 17 + j] * 2.0
+
+
+class TestBoolean(BaseCompare):
+    """Boolean buffers should be replaced by a backing int8 array"""
+
+    def before(A: T.Buffer[10, "bool"], B: T.Buffer[10, "bool"]) -> None:
+        for i0 in T.serial(10):
+            B[i0] = A[i0]
+
+    def expected(A: T.Buffer[10, "int8"], B: T.Buffer[10, "int8"]) -> None:
+        T.preflattened_buffer(A, [10], dtype="bool", data=A.data)
+        T.preflattened_buffer(B, [10], dtype="bool", data=B.data)
+        for i0 in T.serial(10):
+            B[i0] = T.cast(T.cast(A[i0], "bool"), "int8")
+
+
+class TestLowerTE(BaseCompare):
+    """FlattenBuffer should do nothing on TE-based functions"""
+
+    def before(self):
+        x = te.placeholder((1,))
+        y = te.compute((1,), lambda i: x[i] + 2)
+        s = te.create_schedule(y.op)
+        mod = tvm.driver.build_module.schedule_to_module(s, [x, y])
+        return mod["main"]
+
+    expected = before
+
+
+class TestFlattenInsideBlock(BaseCompare):
+    """Flattening access inside a block flattens the accessed region."""
+
+    def before():
+        A = T.alloc_buffer([32, 32])
+        for i, j in T.grid(32, 32):
+            with T.block("block"):
+                T.evaluate(A[i, j])
+
+    def expected():
+        A = T.alloc_buffer([1024])
+        for i, j in T.grid(32, 32):
+            with T.block("block"):
+                T.evaluate(A[i * 32 + j])
+
+
+class TestNoChangeTo2DPhysicalBuffer(BaseCompare):
+    """Flattening preserves axis separators."""
+
+    def before():
+        A = T.alloc_buffer([32, 32], axis_separators=[1])
+        for i, j in T.grid(32, 32):
+            T.evaluate(A[i, j])
 
+    expected = before
+
+
+class TestFlattenWithAxisSeparators(BaseCompare):
+    """Flattening preserves axis separators"""
+
+    def before():
+        A = T.alloc_buffer([2, 3, 5, 7, 11, 13], axis_separators=[3])

Review Comment:
   The axis separators are intended to represent memory spaces that should be 
exposed with multiple indices at the codegen level.  Currently, the only use is 
as part of the hexagon codegen to represent 2-d buffers (`T** buf; T value = 
buf[i][j]`), but I'd like to also use it to represent GPU texture memory.
   
   For comparison with strided transforms, applying the transform 
`(i,j) => stride*i + j` would be exposed to the codegen as a single index, 
whereas transforming `(i,j) => [i, AXIS_SEPARATOR, j]` would be exposed to the 
codegen as two separate values `(i, j)`.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to