Lunderberg commented on code in PR #77:
URL: https://github.com/apache/tvm-rfcs/pull/77#discussion_r891630186


##########
rfcs/0077-layout-transform-padding.md:
##########
@@ -0,0 +1,2522 @@
+- Feature Name: Layout Transformation Padding Roadmap
+- Authors: [Eric Lunderberg](https://github.com/Lunderberg/),
+           [Chris Sullivan](https://github.com/csullivan),
+           [Wuwei Lin](https://github.com/vinx13/),
+           [Junru Shao](https://github.com/junrushao1994)
+- Start Date: 2022-06-06
+- RFC PR: [apache/tvm-rfcs#0077](https://github.com/apache/tvm-rfcs/pull/0077)
+- GitHub Issue: TBD
+
+# Table of contents
+- [Table of contents](#table-of-contents)
+- [Summary](#summary)
+- [Motivation](#motivation)
+- [Guide-level explanation](#guide-level-explanation)
+  - [Padded Transformations](#padded-transformations)
+  - [Defining Padded Values](#defining-padded-values)
+  - [Overcompute vs Branching](#overcompute-vs-branching)
+- [Reference-level explanation](#reference-level-explanation)
+  - [TIR Changes](#tir-changes)
+    - [Buffer Annotation of Padding Predicate/Constraint Pairs](#buffer-annotation-of-padding-predicateconstraint-pairs)
+    - [New TIR Op, `tir::builtin::arbitrary`](#new-tir-op-tirbuiltinarbitrary)
+    - [Buffer Annotation of Layout Transforms](#buffer-annotation-of-layout-transforms)
+  - [Transformations/Metaschedule Primitives](#transformationsmetaschedule-primitives)
+    - [Enhancement - transform_layout](#enhancement---transform_layout)
+    - [New Primitive - Add buffer constraint](#new-primitive---add-buffer-constraint)
+    - [New Primitive - Reorder Loops According to Buffer](#new-primitive---reorder-loops-according-to-buffer)
+    - [Enhancement - Predicate for DomainTouched](#enhancement---predicate-for-domaintouched)
+    - [Enhancement - Remove No Op](#enhancement---remove-no-op)
+    - [Enhancement - Simplify](#enhancement---simplify)
+    - [New Transform - Hoist Expression](#new-transform---hoist-expression)
+    - [New Transform - Reduce Loop Extents](#new-transform---reduce-loop-extents)
+    - [Utility - Merge Adjacent Loops](#utility---merge-adjacent-loops)
+    - [New Primitive - Remove Branching Through Overcompute](#new-primitive---remove-branching-through-overcompute)
+    - [New Primitive - Remove Overcompute Through Branching](#new-primitive---remove-overcompute-through-branching)
+    - [New Lowering Transform - Remove T.Arbitrary](#new-lowering-transform---remove-tarbitrary)
+  - [Implementation options](#implementation-options)
+    - [Never write to transformation padding](#never-write-to-transformation-padding)
+    - [Never read from transformation padding](#never-read-from-transformation-padding)
+    - [Allocate internal buffer containing transformation padding](#allocate-internal-buffer-containing-transformation-padding)
+    - [Explicitly write next operator's desired default at end of function](#explicitly-write-next-operators-desired-default-at-end-of-function)
+    - [Implicitly write default value of next operator](#implicitly-write-default-value-of-next-operator)
+    - [Apply operator element-wise over the transformation padding](#apply-operator-element-wise-over-the-transformation-padding)
+    - [Multiple Buffer Semantics](#multiple-buffer-semantics)
+  - [Points of Communication](#points-of-communication)
+- [Drawbacks](#drawbacks)
+- [Rationale and alternatives](#rationale-and-alternatives)
+- [Prior art](#prior-art)
+- [Unresolved questions](#unresolved-questions)
+- [Future possibilities](#future-possibilities)
+
+# Summary
+[summary]: #summary
+
+Buffer layout transformations can require padding in the transformed
+buffer.  The efficiency of an operator depends on the semantics used
+for loads and stores to values in the required padding.  The choice of
+buffer semantics can reduce branch divergence and avoid repeated
+setting of default values, but also imposes constraints between the
+producer and consumer of a buffer.
+
+This RFC discusses a general plan for specifying buffer semantics to
+be used, and the constraints imposed.  Subsequent RFCs will follow
+describing the design for support of each of the semantics proposed in
+this roadmap.
+
+# Motivation
+[motivation]: #motivation
+
+Suppose a buffer of shape `[14]` is transformed such that each index
+`i` is mapped to `[i//4, i%4]`.  The first index can range from 0
+(`0//4`) to 3 (`13//4`), and the second index can range from 0 (`0%4`)
+to 3 (`3%4`).  Therefore, the transformed shape is `[4,4]`.  However,
+this has 16 elements, two more than the original buffer; the
+transformed coordinates `(3,2)` and `(3,3)` do not correspond to any
+index in the original range `0 <= i < 14`.  The final result in these
+locations is not determined by the compute definition, so we have
+flexibility in what to store in the padding that is introduced by the
+transformation, and what assumptions can be made when reading from
+those locations.
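+
+As a quick illustration, these padded locations can be enumerated in
+ordinary Python (this is not a TVM API, just a restatement of the
+mapping above):
+
+```python
+# Original indices 0..13, mapped through i -> [i//4, i%4].
+mapped = {(i // 4, i % 4) for i in range(14)}
+
+# Transformed coordinates with no corresponding original index.
+padding = [(io, ii) for io in range(4) for ii in range(4)
+           if (io, ii) not in mapped]
+print(padding)  # [(3, 2), (3, 3)]
+```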
+
+For example, an element-wise function may be most efficiently written
+using vectorized instructions over all values, regardless of whether
+they exist in the compute definition.  Or a maxpool may be most
+efficiently written if input tensors have `-INF` stored in the
+transformation padding.  Satisfying both of these at the same time may
+not be possible.  While the compute definition doesn't impose
+constraints on the values in the transformation padding, there are
+still constraints imposed by the usage of those values by different
+operators.
+
+
+```
+ ┌─Logical-index-space───────────────────┐
+ │                                       │
+┌▼─┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬──┬─▼┌──┬──┐
+│00│01│02│03│04│05│06│07│08│09│10│11│12│13│14│15│
+└▲─┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴──┴─▲┘
+ │                                             │
+ └─Physical-index-space────────────────────────┘
+
+ ┌─Transformed-index-space─┐
+ │                         │
+ │      ┌────┬────┬────┬───▼┐
+ │      │ 00 │ 01 │ 02 │ 03 │
+ │      ├────┼────┼────┼────┤
+ │      │ 04 │ 05 │ 06 │ 07 │
+ │      ├────┼────┼────┼────┤
+ │      │ 08 │ 09 │ 10 │ 11 │
+ │      ├────┼────┼────┼────┤
+ └──────► 12 │ 13 │ 14 │ 15 │
+        └────┴────┴────┴────┘
+```
+
+# Guide-level explanation
+[guide-level-explanation]: #guide-level-explanation
+
+## Padded Transformations
+
+In general, a transformation will introduce the minimum amount of
+padding such that all values in the original buffer can be stored in
+the layout specified.  As a result, whether a transformation
+introduces padding depends on the transformation being applied and the
+buffer shape on which it is being applied.  For example, consider a
+schedule that contains tensor `A` with shape `[16]` and tensor `B` with shape
+`[14]`.
+
+```python
+# This transformation does not introduce padding.  The original shape
+# of [16] produces the transformed shape [2,8], which contains the
+# original 16 values and no additional padding.
+sched[A].transform_layout(lambda i: [i//8, i%8])
+
+# This transform introduces padding.  The original shape of [14] also
+# produces the transformed shape [2,8], which contains the original 14
+# values and an additional 2 values of padding.  These are located at
+# transformed indices [1,6] and [1,7].
+sched[B].transform_layout(lambda i: [i//8, i%8])
+```
+
+The above example introduces padding at the end of a buffer.  By
+including an offset in the layout transformation, we can instead place
+the padding at the beginning of a buffer.
+
+```python
+# This transform introduces padding.  For 0 <= i < 14, the transformed
+# index (i+2)//8 can have values of 0 or 1, so the transformed shape
+# is [2,8].  There are no valid values of i that would produce [0,0]
+# or [0,1], so these transformed indices contain padding.
+sched[B].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
+
+In addition to moving the location of the padded indices, use of an
+offset in a layout transformation can introduce additional padding.
+
+```python
+# This transformation introduces padding.  For 0 <= i < 16, the
+# transformed index (i+2)//8 can have values of 0, 1, or 2, so the
+# transformed shape is [3,8].  Padding is introduced from [0,0] to
+# [0,1], and from [2,2] to [2,7].
+sched[A].transform_layout(lambda i: [(i+2)//8, (i+2)%8])
+```
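+
+The padded shape and padded indices in the examples above follow
+mechanically from the index map.  The plain-Python sketch below (not
+part of the TVM API) shows one way to derive them:
+
+```python
+import itertools
+
+
+def padded_layout(extent, index_map):
+    """Return the transformed shape and padded indices of a 1-d buffer."""
+    mapped = {index_map(i) for i in range(extent)}
+    shape = tuple(max(axis) + 1 for axis in zip(*mapped))
+    padding = [idx for idx in itertools.product(*(range(s) for s in shape))
+               if idx not in mapped]
+    return shape, padding
+
+
+# Shape [16] with the offset transformation from the example above.
+shape, padding = padded_layout(16, lambda i: ((i + 2) // 8, (i + 2) % 8))
+print(shape)    # (3, 8)
+print(padding)  # [(0, 0), (0, 1), (2, 2), (2, 3), ..., (2, 7)]
+```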
+
+
+## Defining Padded Values
+
+When a buffer is transformed, the majority of values in the
+transformed buffer are constrained to have the corresponding value in
+the original buffer.  However, when a buffer is padded to meet some
+alignment criteria, these additional padded values have no such
+constraint.
+
+To specify the values stored in the padding, the `transform_layout`
+function takes an optional argument `pad_value` that specifies the
+value that should be present in the padding.  This may be given
+either as a constant value or as a function that maps from
+transformed indices to an `Optional[PrimExpr]`, as in the examples
+below.
+
+```python
+# B.shape is [14]
+transform = lambda i: [i//4, i%4]
+
+# Three equivalent calls to perform the same layout transformation.
+# Padding is introduced, but access of the padding is forbidden.
+sched[B].transform_layout(transform)
+sched[B].transform_layout(transform, pad_value=None)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: None)
+
+# Padding is introduced, and contains zeros.
+sched[B].transform_layout(transform, pad_value=0.0)
+sched[B].transform_layout(transform, pad_value=lambda io,ii: 0.0)
+
+# Padding is introduced, and contains arbitrary values.
+sched[B].transform_layout(transform, pad_value=tir.arbitrary(dtype="float32"))
+sched[B].transform_layout(transform, pad_value=lambda io,ii: tir.arbitrary(dtype="float32"))
+
+# Padding is introduced, and wraps to the beginning of the array.
+sched[B].transform_layout(transform, pad_value=lambda io,ii: B[0, (io-14)%4])
+```
+
+The `Buffer` object stores a predicate to identify which indices
+contain padding, along with the expression given in `pad_value`.  This
+expression may only contain constants and the transformed buffer
+itself, and may not introduce dependencies on another buffer.
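+
+For the `[14]` buffer transformed by `lambda i: [i//4, i%4]` with
+`pad_value=0.0`, this stored pair would conceptually be the following
+(a sketch in terms of schedule-level lambdas, assuming `tir` is
+`tvm.tir`; the exact representation is covered in the reference-level
+sections below):
+
+```python
+# Indices that contain padding: only (3, 2) and (3, 3) in this example.
+padding_predicate = lambda io, ii: tir.all(io == 3, ii >= 2)
+
+# Value assumed there, and written by the producer before completion.
+pad_value = lambda io, ii: 0.0
+```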
+
+For a producer of the transformed buffer, if `pad_value` is defined,
+the padding value must be written to the padding prior to the
+completion of the operator.  Effectively, the producer must have a
+postlude as follows:
+
+```python
+for transformed_indices in T.grid(*transformed_shape):
+    if padding_predicate(*transformed_indices):
+        B[transformed_indices] = pad_value(*transformed_indices)
+```
+
+For a consumer of the transformed buffer, these padding values are
+initially unused, but may be used in later simplifications.
+
+## Overcompute vs Branching
+
+Depending on the computation being performed and the value stored in
+the padding, there can be trade-offs between branching and
+overcompute.  For example, consider the following `PrimFunc`, which
+computes the sum over each row of the input data.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 14), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j in T.serial(14):
+            B[i] = B[i] + A[i, j]
+```
+
+We'd like to transform the layout of buffer `A` from `[i, j]` to `[i,
+j//4, j%4]`, along with the loop iteration.  By default, after using
+the `transform_layout` and `split` metaschedule primitives, we have
+the following function.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            if 4*j_outer + j_inner < 14:
+                B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+If the conditional can be removed, this function would be much more
+amenable to later vectorization, or to reducing branch divergence when
+bound to a thread index.  If the padding in `A` is pre-filled with
+zero, then `B[i] = B[i] + 0.0` is a no-op, and can be performed
+without changing the final computation.
+
+```python
+@T.prim_func
+def row_summation(a: T.handle, b: T.handle):
+    A = T.match_buffer(a, shape=(16, 4, 4), dtype="float32")
+    B = T.match_buffer(b, shape=(16,), dtype="float32")
+    for i in T.serial(16):
+        B[i] = 0.0
+        for j_outer, j_inner in T.grid(4, 4):
+            B[i] = B[i] + A[i, j_outer, j_inner]
+```
+
+By annotating the layout transformation with the value stored in the
+padding, the compiler can prove that removing the conditional does not
+change the result, allowing the conditional to be removed
+automatically.  Since the tradeoff between branching and overcompute
+may or may not be beneficial depending on the schedule, these options
+are exposed as two additional transformations,
+`tir.transform.RemoveBranchingThroughOvercompute` and
+`tir.transform.RemoveOvercomputeThroughBranching`.
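+
+For illustration only, selecting between these two tradeoffs during
+lowering might look like the sketch below.  Neither pass exists yet;
+the names are exactly as proposed above, and the assumption that they
+would be applied like other `tir.transform` passes is just that, an
+assumption.
+
+```python
+import tvm
+
+
+def prefer_overcompute(mod: tvm.IRModule) -> tvm.IRModule:
+    # Replace branches with overcompute wherever the pad_value makes the
+    # extra work a provable no-op.
+    return tvm.tir.transform.RemoveBranchingThroughOvercompute()(mod)
+
+
+def prefer_branching(mod: tvm.IRModule) -> tvm.IRModule:
+    # Reintroduce branches wherever avoiding the overcompute is preferable.
+    return tvm.tir.transform.RemoveOvercomputeThroughBranching()(mod)
+```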
+
+
+# Reference-level explanation
+[reference-level-explanation]: #reference-level-explanation
+
+## TIR Changes
+
+### Buffer Annotation of Padding Predicate/Constraint Pairs

Review Comment:
   @tqchen @Hzfengsy Thank you, and I definitely agree on minimizing the number 
of IR changes being made.  (Also, phew, this ended up being a longer reply than 
I had expected, which probably means that whatever the result of this thread, 
the "Rationale and Alternatives" section should be updated.)
   
   @areusch The example transformations are largely present in the 
"Implementation Options" section.  The goal of that section was to describe 
different example transformations that we'd like to be able to make, and to 
ensure that they could be made using the functionality introduced earlier in 
the RFC.  It wasn't until this morning that I realized that there should also 
be links in the other direction, pointing from the proposed IR changes to the 
motivating use case.
   
   Below is the general rationale, with high-level implementations.
   
   Starting by listing out the desired properties of an implementation.
   
   1. No changes to existing TIR data structures
   2. No additional meaning to existing TIR data structures
   3. Simplifications can use constraints from multiple buffers
   4. No ordering requirement between `transform_layout` and fuse/split/reorder.
      (Conditional statements that are removed using the buffer constraints are
      typically introduced by loop rewrites.)
   5. Can be used to describe out-of-bounds access (e.g. texture memory 
clamping on a GPU) that returns a default value.
   6. Only allocate memory when required or requested
   
   
   Implementations considered
   
   - A. All buffer transformations introduce new stage
   
     - Pro: No coordination required between different operators.
     - Con: Any producer/consumer interactions must be recognized by operator 
fusion/device planning.
     - Con: Cannot apply to primfunc input/outputs.  (e.g. To de-duplicate 
operators that differ only by underlying layout, such as `topi.nn.conv2d_hwcn`, 
`topi.nn.conv2d_nchw`, `topi.nn.conv2d_NCHWc`, etc.)
     - Con: May introduce unnecessary data copies, if the constraint required 
by the consumer is already met.
   
   - B. Perform one `transform_layout` at a time.  For each one, simplify using 
provided constraints, do not store constraints afterward.
     
     - Con: Main downside is that it could only use the constraints of a single
       buffer at a time.  This wouldn't be able to express simplifications that
       rely on the padding in multiple buffers (e.g. the
       [elementwise operator](https://github.com/Lunderberg/tvm-rfcs/blob/buffer_layout_padding/rfcs/0077-layout-transform-padding.md#apply-operator-element-wise-over-the-transformation-padding)).
   
     - Con: Requires loop rewriting to be done either inside `layout_transform` 
or prior to calling `layout_transform`.
     - Con: Can't be applied to use cases outside of layout transformations 
(e.g. texture memory clamping on a GPU), where simplifications could benefit 
from assumed constraints.
     
     
   - C. Perform all `transform_layout` in a single function call, passing all 
layout transforms and padding constraints.
   
     - Pro: Simplifications may use constraints of all buffers being 
transformed.
     - Con: Requires changing the calling convention for layout transformations.
     - Con: Requires loop rewriting to be done either inside `layout_transform` 
or prior to calling `layout_transform`.
     - Con: Can't be applied to use cases outside of layout transformations 
(e.g. texture memory clamping on a GPU), where simplifications could benefit 
from assumed constraints.
       
   - D. Express buffer constraints using existing `AssertStmt`
   
     In pseudocode, each consumer would have roughly the loopnest below. 
However, this would still need to have some way of indicating that the 
constraint should be removed when lowering, and should not produce any runtime 
assertions.
     
     ```python
     for indices in T.grid(*transform_shape):
         if padding_predicate(indices):
             T.Assert(buf[indices] == pad_value(indices))
     ```
   
     - Pro: No change to TIR data structures
     - Pro: No change required for calling convention for layout 
transformations.
     - Pro: Simplifications may use constraints of all buffers being 
transformed.
     - Pro: Can be applied to use cases outside of layout transformations (e.g. 
texture memory clamping on a GPU), where simplifications could benefit from 
assumed constraints.
     - Pro: No ordering between loop/layout transforms, because the constraints 
can be determined from the TIR.
     - Con: Additional meaning attached to existing TIR data structures.
     - Con: Can only describe a fixed number of assertions, wouldn't be able to 
express a default value for all out-of-bounds reads.
   
   - E. Express buffer constraints as a field in `PrimFuncNode::attrs`
   
     - Con: Passes that replace `Buffer` objects must be aware of this 
attribute, in order to update the `Buffer` object stored in it.
       
   - F. Express buffer constraints as a new member variable in `Buffer`
   
     - Con: Requires changes to existing TIR data structures
     - Pro: No change required for calling convention for layout 
transformations.
     - Pro: Simplifications may use constraints of all buffers being 
transformed.
     - Pro: Can be applied to use cases outside of layout transformations (e.g. 
texture memory clamping on a GPU), where simplifications could benefit from 
assumed constraints.
     - Pro: Can rewrite loop structure later, use existing constraints.
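   
   To make option F concrete, a rough sketch (hypothetical names only, not a
   proposed API) of the information each per-buffer constraint would need to
   carry:
   
   ```python
   from dataclasses import dataclass
   
   from tvm import tir
   
   
   @dataclass
   class BufferConstraintSketch:
       """Hypothetical record a Buffer would hold under option F."""
   
       # Region of the transformed buffer where the constraint applies,
       # e.g. tir.all(io == 3, ii >= 2) for the [14] -> [4, 4] example.
       predicate: tir.PrimExpr
       # Value assumed within that region, e.g. tir.const(0.0, "float32").
       value: tir.PrimExpr
   ```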
     
     
   1. No changes to existing TIR data structures
   2. No additional meaning to existing TIR data structures
   3. Simplifications can use constraints from multiple buffers
   4. No ordering requirement between `transform_layout` and fuse/split/reorder.
   5. Can be used to describe out-of-bounds access (e.g. texture memory 
clamping on a GPU) that returns a default value.
   6. Only allocate memory when required or requested
     
   - A. All buffer transformations introduce new stage
   - B. Perform one `transform_layout` at a time.  For each one, simplify using 
provided constraints, do not store constraints afterward.
   - C. Perform all `transform_layout` in a single function call, passing all 
layout transforms and padding constraints.
   - D. Express buffer constraints using existing `AssertStmt`
   - E. Express buffer constraints as a field in `PrimFuncNode::attrs`
   - F. Express buffer constraints as a new member variable in `Buffer`
     
     
   |        | Goal 1             | Goal 2             | Goal 3             | Goal 4             | Goal 5             | Goal 6             |
   |--------|--------------------|--------------------|--------------------|--------------------|--------------------|--------------------|
   | Impl A | :heavy_check_mark: | :heavy_check_mark: | :x:                | :x:                | :x:                | :x:                |
   | Impl B | :heavy_check_mark: | :heavy_check_mark: | :x:                | :x:                | :x:                | :heavy_check_mark: |
   | Impl C | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :x:                | :x:                | :heavy_check_mark: |
   | Impl D | :heavy_check_mark: | :x:                | :heavy_check_mark: | :heavy_check_mark: | :x:                | :heavy_check_mark: |
   | Impl E | :heavy_check_mark: | :x:                | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
   | Impl F | :x:                | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
   
   The implementations that would satisfy the largest number of the desired 
goals would be adding the member variable `BufferNode::constraints`, or adding 
a field to `PrimFuncNode::attrs` that holds the constraints.  Between the two, 
I lean toward having it as an explicit member variable, so that incorrect usage 
appears as a compilation error when compiling TVM, but would find either 
implementation acceptable.


