tkonolige commented on code in PR #89:
URL: https://github.com/apache/tvm-rfcs/pull/89#discussion_r949627983


##########
rfcs/0089-relax-upstreaming.md:
##########
@@ -0,0 +1,701 @@
+- Feature Name: Relax Upstreaming
+- Start Date: 2022-08-17
+- RFC PR: [apache/tvm-rfcs#0089](https://github.com/apache/tvm-rfcs/pull/0089)
+- GitHub Issue: [apache/tvm#0000](https://github.com/apache/tvm/issues/0000)
+- Co-Authors: [@denise-k](https://github.com/denise-k), 
[@jwfromm](https://github.com/jwfromm)
+
+# 1. **Summary**
+
+This RFC proposes to upstream the core foundation of Relax (Relay Next). Relax 
is a new graph-level IR that enables new capabilities to address the [critical 
needs](https://discuss.tvm.apache.org/t/establish-tvm-unity-connection-a-technical-strategy/13344)
 identified by the TVM community over the years of using and developing deep 
learning compilers.
+
+# 2. **Motivation and goals**
+
+Relax is an effort within [TVM Unity](https://tvm.apache.org/2021/12/15/tvm-unity) that aims to evolve the graph-level IR to maximize **expressibility, performance, and portability** across today’s and tomorrow’s workloads. Relax has three key goals, motivated by the TVM community’s needs and by the lessons the community has learned about ML acceleration through years of using and developing TVM:
+
+- Build a unified interface that transcends the boundaries between TVM’s abstractions: graph-level IR, tensor programs (TensorIR), and runtime libraries (PackedFunc);
+- Enable and optimize dynamic shape workloads;
+- Support “computational graph” style optimizations with advanced dataflow 
semantics.
+
+For more details on the design goals of Relax, please check out the [discuss 
forum 
post](https://discuss.tvm.apache.org/t/relax-co-designing-high-level-abstraction-towards-tvm-unity/12496).
+
+The main focus of this upstreaming RFC is to upstream the **core foundation** 
of Relax as an **optional** compilation flow in TVM with two principles:
+
+- **Minimize disruption:** This upstreaming should provide a **non-default** 
path to enable new capabilities for users/developers who are interested in what 
Relax brings, so it will not break the current default Relay flow.
+- **Minimize complexity:** This upstreaming should reuse existing TVM/Relay 
infrastructure as much as possible (for example IRModule, runtime Module, TOPI 
library, etc.) to avoid duplicated effort and code.
+
+This initial upstreaming will open the path for TVM Unity, and incrementally 
bring Relax into the community.
+
+# 3. **Guide-level explanation**
+
+This section introduces the three major design points of Relax, which map directly to the three key goals of Relax in the previous section. We begin by showing what the user-facing interface will look like after this RFC lands.
+
+(Most of the code examples in this RFC are written in 
[TVMScript](https://github.com/apache/tvm-rfcs/pull/74/files#diff-6965a40ad8df7618ae68e11c88f924542a506c74a931cc3011ae9f99989b5f51R21-R27),
 which enables users to write and print TVM programs containing both Relax and 
TIR functions with Python syntax.)
+
+## User-facing interface
+
+After this upstreaming lands, users will be able to write a Relax program in TVMScript or translate a model directly from Relay. Relax provides a simple API to compile the IRModule to a VM executable and run it on the Relax VM.
+
+```python
+import numpy as np
+import tvm
+import tvm.script
+from tvm import relax
+from tvm.script import relax as R, tir as T
+
+# Relax IRModule written in TVMScript
+@tvm.script.ir_module
+class MyIRModule:
+    # This is a TIR PrimFunc which calls the TIR intrinsic T.exp
+    @T.prim_func
+    def tir_exp_func(x: T.handle, y: T.handle): ## <= D2
+        n = T.var("int32")
+        X = T.match_buffer(x, (n,), "float32")
+        Y = T.match_buffer(y, (n,), "float32")
+        for i in T.serial(n):
+            Y[i] = T.exp(X[i])
+
+    # This is a Relax function which contains a dataflow block
+    # representing a computational graph, as well as a call to an
+    # opaque packed function which performs an in-place update to the
+    # data in variable gv0.
+    # We mark the corresponding design points (D0, D1, D2) that map to
+    # the following sections throughout the Relax function below.
+    @R.function
+    def relax_func(x: R.Tensor[(n, k), "float32"], w: R.Tensor[_, "float32"]):
+        # n, k above are implicitly defined within the function signature,
+        # so we can refer to n, k anywhere within relax_func
+        with R.dataflow(): ## <= D2
+            lv0 = R.match_shape(w, (k, m)) ## <= D1
+            lv1: R.Tensor[(n, m), "float32"] = R.dot(x, lv0)
+            lv2: R.Tensor[(n * m,), "float32"] = R.flatten(lv1) ## <= D1
+            lv3: R.Shape = (n * m,)  ## <= D1
+            gv0 = R.call_tir(tir_exp_func, [lv2], lv3, dtype="float32") ## <= D0
+            R.outputs(gv0)
+
+        # "custom_inplace_update" must be registered (e.g. via tvm.register_func) before running
+        R.call_packed("custom_inplace_update", gv0) ## <= D0, D2
+        return gv0
+
+# Print IRModule with syntax highlighting
+MyIRModule.show()
+
+# Build the Relax IRModule
+target = tvm.target.Target("llvm")
+ex = relax.vm.build(MyIRModule, target)
+
+# Dump the VM executable instructions as text
+print(ex.as_text())
+
+# Run the function on Relax VM runtime
+vm = relax.VirtualMachine(ex, tvm.cpu())
+x = tvm.nd.array(np.random.rand(2, 3).astype(np.float32))  # x: (n, k) = (2, 3)
+w = tvm.nd.array(np.random.rand(3, 4).astype(np.float32))  # w: (k, m) = (3, 4)
+res = vm["relax_func"](x, w)
+```
+
+## D0: **Unified abstractions and optimizations across layers**
+
+The first key design point is to allow the high-level graph IR to interact directly with, and call into, lower-level TensorIR and PackedFunc (TVM FFI).
+
+The TensorIR PrimFunc and many external libraries adopt a **destination-passing-style** (DPS) calling convention. In this convention, both inputs and outputs are passed to the function as arguments, and the function mutates the outputs in place:
+
+```python
+def low_level_func(input0, input1, ..., output):
+    # implementations
+```
+
+The main idea of DPS is that inputs and outputs are explicitly allocated outside the low-level primitive function and then passed into it. This style is common in low-level library designs (for example TensorRT), because it lets higher-level frameworks (for example, the compiler) handle memory allocation.
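The following minimal sketch in plain NumPy makes the contrast concrete. It is not TVM code, and `add_alloc` and `add_dps` are hypothetical names. It writes the same operation once in an allocating style and once in DPS. In the DPS version the caller owns the output buffer, which is the property that lets a compiler decide where that buffer lives.

```python
import numpy as np


def add_alloc(a, b):
    # Allocating (non-DPS) style: the callee creates and returns the result.
    return a + b


def add_dps(a, b, out):
    # Destination-passing style: the caller provides `out`; the callee only
    # writes into it and returns nothing.
    np.add(a, b, out=out)


a = np.arange(4, dtype="float32")
b = np.ones(4, dtype="float32")

# The caller (e.g. a compiler's memory planner) allocates the output up front.
out = np.empty(4, dtype="float32")
add_dps(a, b, out)
```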
+
+### **call_tir**
+
+In Relax, we introduce `call_tir` to bridge graph-level IR and TIR. `call_tir` is an intrinsic that calls a TIR PrimFunc (which follows DPS) and returns the output. The code below illustrates the semantics of `call_tir`.
+
+```python
+def call_tir(tir_primfunc: GlobalVar,
+             inputs: Tuple[Expr],
+             output_shape: Shape,
+             output_dtype: DataType) -> Expr:
+    """Example code to demonstrate the semantics of call_tir"""
+    # allocate the output, then let the DPS PrimFunc write into it
+    out_tensor = alloc_tensor(output_shape, output_dtype)
+    tir_primfunc(*inputs, out_tensor)
+    return out_tensor
+```
+
+`call_tir` takes in `tir_primfunc` (a GlobalVar that maps to a TIR PrimFunc in the IRModule), a tuple of inputs, and the output tensor's shape and datatype. Notably, the compiler does not have to allocate each output tensor individually when it lowers `call_tir`. It can instead build a memory plan for the intermediate tensors so that their storage is reused effectively.
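The memory-planning freedom described above can be sketched with a small NumPy mirror of `call_tir`. This is an illustration under stated assumptions, not TVM's actual lowering. The `out=` parameter is a hypothetical hook that stands in for a buffer chosen by a planner, and `exp_dps`/`add_dps` are toy DPS functions.

```python
import numpy as np


def call_tir(prim_func, inputs, output_shape, output_dtype, out=None):
    """NumPy mirror of the call_tir semantics above.

    `out` is a hypothetical hook (not part of call_tir's signature) standing
    in for a buffer that a memory planner has already chosen.
    """
    if out is None:
        out = np.empty(output_shape, dtype=output_dtype)  # naive: fresh buffer
    prim_func(*inputs, out)  # the DPS callee writes into `out`
    return out


def exp_dps(x, y):
    np.exp(x, out=y)


def add_dps(a, b, c):
    np.add(a, b, out=c)


x = np.zeros(4, dtype="float32")

# Naive lowering: every call_tir allocates its own output (three buffers).
t0 = call_tir(exp_dps, (x,), (4,), "float32")
t1 = call_tir(add_dps, (t0, t0), (4,), "float32")
t2 = call_tir(exp_dps, (t1,), (4,), "float32")

# Planned lowering: t0 is dead once t1 exists, so the third call reuses its buffer.
pool = [np.empty(4, dtype="float32"), np.empty(4, dtype="float32")]
p0 = call_tir(exp_dps, (x,), (4,), "float32", out=pool[0])
p1 = call_tir(add_dps, (p0, p0), (4,), "float32", out=pool[1])
p2 = call_tir(exp_dps, (p1,), (4,), "float32", out=pool[0])
```

The callees never allocate, so reusing `pool[0]` for the third call is entirely the caller's decision. A memory planner relies on exactly that property.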
+
+`call_tir` is implemented as a special Relax operator (instead of a standalone IR node) to minimize changes to the IR. From the AST point of view, it becomes:
+
+```python
+Call(
+    op=Op::Get("relax.call_tir"),   
+    tir_primfunc,
+    inputs,
+    output_shape,
+    output_dtype
+)
+```
+
+### **call_packed**
+
+In Relax, we introduce `call_packed` to bridge graph-level IR and PackedFunc. 
It indicates a call to a **non-DPS packed function** that is registered in the 
environment via TVM FFI. 
+
+From the AST’s point of view, we do not need to introduce an additional call node. Instead, we introduce an `ExternFunc` construct that represents a PackedFunc we can call into (the PackedFunc may or may not return a value):
+
+```python
+Call(op=ExternFunc("my_packed_func"), *args)
+```
+
+`R.call_packed("my_packed_func", gv0)` in TVMScript (as shown in the User-facing interface section) serves only as syntactic sugar for the above AST node.
+
+### **call_dps_packed**
+
+To call into a DPS packed function while letting the compiler handle the output memory directly, we introduce a `call_dps_packed` intrinsic. Many low-level library functions (e.g. in TensorRT) are designed in this DPS style. The intrinsic corresponds to the following AST:
+
+```python
+Call(
+    op=Op::Get("relax.call_dps_packed"),   
+    ExternFunc("my_packed_func"),
+    inputs,
+    output_shape,
+    output_dtype
+)
+```
+
+Suppose `custom_packed_func` is a user-defined packed function in DPS. The TVMScript call
+
+```python
+R.call_dps_packed("custom_packed_func", (input0, input1), output_shape=(3, 4), output_dtype="float32")
+```
+
+corresponds to the following AST:
+
+```python
+Call(
+    op=Op::Get("relax.call_dps_packed"),
+    ExternFunc("custom_packed_func"),
+    (input0, input1),
+    output_shape=(3, 4), 
+    output_dtype="float32"
+)
+```
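`call_dps_packed` behaves like `call_tir`, except that the callee is looked up by name instead of being a PrimFunc in the IRModule. The following NumPy sketch assumes a dictionary registry and uses a toy matrix multiply as `custom_packed_func`. Both are hypothetical.

```python
from typing import Callable, Dict, Sequence, Tuple

import numpy as np

# Hypothetical stand-in for the global PackedFunc registry.
REGISTRY: Dict[str, Callable[..., None]] = {}


def call_dps_packed(name: str, inputs: Sequence[np.ndarray],
                    output_shape: Tuple[int, ...], output_dtype: str) -> np.ndarray:
    # Same contract as call_tir, but the callee is resolved by name:
    # the caller allocates the output and the DPS callee fills it in place.
    out = np.empty(output_shape, dtype=output_dtype)
    REGISTRY[name](*inputs, out)
    return out


def custom_packed_func(input0, input1, out):
    # Toy DPS function: writes input0 @ input1 into the caller's buffer.
    np.matmul(input0, input1, out=out)


REGISTRY["custom_packed_func"] = custom_packed_func

input0 = np.ones((3, 2), dtype="float32")
input1 = np.ones((2, 4), dtype="float32")
res = call_dps_packed("custom_packed_func", (input0, input1), (3, 4), "float32")
```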
+
+The following program in TVMScript shows that with `call_tir`, `call_packed`, 
and `call_dps_packed`, we can directly embed and call the TIR and PackedFunc 
functions in the high-level Relax IR program.
+
+```python
+import numpy as np
+import tvm
+from tvm.script import relax as R, tir as T
+
+# User-defined packed functions
+# Non-DPS PackedFunc with return
+@tvm.register_func("custom_add")
+def add_packed(a, b):
+    ret = a.numpy() + b.numpy()
+    return tvm.nd.array(ret)
+
+# Non-DPS PackedFunc without return
+@tvm.register_func("custom_print")
+def print_packed(a):
+    print(a)
+
+# DPS PackedFunc
+@tvm.register_func("custom_tile")
+def tile_packed(a, b):
+    b[:] = tvm.nd.array(np.tile(a.numpy(), (1, 2)))
+
+@tvm.script.ir_module
+class MyIRModule:
+    # define a PrimFunc to do matrix multiply
+    # note TIR PrimFunc is in DPS, here z is the output
+    @T.prim_func
+    def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None:
+        m = T.var("int32")
+        n = T.var("int32")
+        k = T.var("int32")
+        A = T.match_buffer(x, (m, n))
+        B = T.match_buffer(y, (n, k))
+        C = T.match_buffer(z, (m, k))
+
+        # spatial axes i, j range over C's (m, k) shape; reduction axis r over n
+        for i0, j0, r0 in T.grid(m, k, n):
+            with T.block("matmul"):
+                i, j, r = T.axis.remap("SSR", [i0, j0, r0])
+                with T.init():
+                    C[i, j] = 0.0
+                C[i, j] += A[i, r] * B[r, j]
+
+    @R.function
+    def relax_func(x: R.Tensor[(m, n), "float32"], y: R.Tensor[(n, k), "float32"]):
+        with R.dataflow():
+            # call_tir calls into a PrimFunc and returns the matrix multiplication result
+            gv0 = R.call_tir(tir_matmul, (x, y), (m, k), dtype="float32")
+            R.outputs(gv0)
+
+        # call into a PackedFunc to print the value of gv0
+        R.call_packed("custom_print", gv0)
+
+        # call the registered "custom_add" non-DPS PackedFunc and return the result
+        gv1 = R.call_packed("custom_add", gv0, gv0)
+
+        # call the registered "custom_tile" DPS PackedFunc and return the result
+        gv2 = R.call_dps_packed("custom_tile", (gv1,), (m, k * 2), dtype="float32")
+        return gv2
+```
+
+This cross-level interaction unlocks many interesting things that were not 
possible before, including, but not limited to:
+
+- Incrementally lower different parts of a program using different strategies, 
instead of lowering the entire program to TIR directly from Relay as today.
+- Allow for more customized optimizations, such as whole program 
optimizations, cascading, and other post-schedule optimizations.
+- Enable automation (MetaSchedule) to analyze call_tir nodes and the callee 
TIR programs, perform optimizations and rewrites to one or more call_tir nodes, 
thus feeding decisions such as layout rewrite directly to the high-level IR.
+- By turning subgraphs into calls to PackedFunc (via call_dps_packed), BYOC 
becomes an IRModule ⇒ IRModule transformation as a natural part of compilation.
+- Provide a flexible way to incorporate TensorIR and existing libraries such 
as cuDNN.
+
+Through this unified interface, ML researchers, system engineers, and hardware 
vendors can collaborate better, since we can incrementally optimize and 
translate specific parts of the whole program in Relax.
+
+## D1: **Shape deduction as first-class computation**
+
+Shape deduction is essential to compiling dynamic workloads. Under a dynamic shape setting, the destination-passing call style adopted by `call_tir` and `call_dps_packed` requires the output tensor shapes to be computed before the call. One solution is to invoke a function that computes the shape before calling the operator function. However, in some cases the shape itself is data-dependent, for example the `unique` operation, which selects the unique elements of a tensor. Finally, most dynamic shape workloads still contain many (partially) static shapes, and ideally we want to take advantage of this static shape information for optimization.
+
+In Relax, a shape constraint of a tensor is represented by two fields of the `relax.Expr` (`RelayExpr`).
+
+- `checked_type_: Type` stores the generic rank and dtype constraints.
+- `shape_: Expr` stores how to compute the shape of the expression at runtime. It is `nullptr` when the expression’s `checked_type_` is not `DynTensorType` (meaning the expression is not a Tensor). Otherwise, this `shape_` field takes one of the 3 possible forms outlined below.
+
+**checked_type_**
+
+`Expr->checked_type_` stores the compile-time deduced type of an expression. We introduce a new type `DynTensorType` to represent the type of a Relax tensor Expr, which contains the following two fields:
+
+```python
+class DynTensorType(Type): 
+    ndim: int # ndim=-1 means unknown rank
+    dtype: DataType # dtype=DataType::Void() means unknown dtype
+```
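One way to read the unknown markers is that a `DynTensorType` with unknown fields is more general than one with known fields. The sketch below encodes that reading. The `is_base_of` helper and the use of `None` in place of `DataType::Void()` are assumptions made for illustration; they are not the RFC's type-checking rules.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class DynTensorType:
    ndim: int = -1               # -1 means the rank is unknown
    dtype: Optional[str] = None  # None stands in for DataType::Void() (unknown)


def is_base_of(base: DynTensorType, derived: DynTensorType) -> bool:
    """Hypothetical check: is every tensor of type `derived` also a `base`?"""
    ndim_ok = base.ndim == -1 or base.ndim == derived.ndim
    dtype_ok = base.dtype is None or base.dtype == derived.dtype
    return ndim_ok and dtype_ok


known = DynTensorType(ndim=2, dtype="float32")
any_rank = DynTensorType(dtype="float32")
```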
+
+**shape_**
+
+`DynTensorType` does not contain shape information. Instead, the shape of a 
Tensor is stored in an **optional** `shape_` field in an Expr.
+
+For an `Expr x`, `x.shape_` can contain the following values:
+
+- V0: `ShapeExpr` (see Section 4.1 for its definition), which contains an 
`Array<PrimExpr>`. Static shapes are always represented in this form by 
encoding each dimension as `IntImm`. Symbolic shapes can also be represented 
(see section 4.1 for more).
+- V1: Generic `Expr`, which is expected to, at runtime, result in something of 
type `Shape`. The `Expr` can call into opaque (shape) functions, or shape 
deduction intrinsics.
+- V2: `RuntimeDepShape` (see Section 4.1 for its definition), a special `Expr` 
to indicate that shape is unknown at compile time and cannot be determined at 
runtime without producing the attached Tensor (see Safety Net section for its 
handling).
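The next sketch shows why V0 and V1 can be evaluated at runtime while V2 cannot. Symbolic dimensions are modeled as functions of a runtime environment that binds variables such as `n`. The representation (lambdas, a `RuntimeDepShape` marker class, `eval_shape`, `padded_shape`) is hypothetical and much simpler than Relax's `ShapeExpr`/`PrimExpr` machinery. The V1 shape function is a made-up example.

```python
from typing import Callable, Dict, Optional, Tuple, Union

Env = Dict[str, int]  # runtime values of symbolic variables such as n

# V0: a ShapeExpr, i.e. one integer expression per dimension.
ShapeExpr = Tuple[Callable[[Env], int], ...]
# V1: an arbitrary (possibly opaque) computation that yields a whole shape.
ShapeFn = Callable[[Env], Tuple[int, ...]]


class RuntimeDepShape:
    """V2 marker: the shape is unknown until the tensor itself is produced."""


def eval_shape(shape: Union[ShapeExpr, ShapeFn, RuntimeDepShape],
               env: Env) -> Optional[Tuple[int, ...]]:
    if isinstance(shape, RuntimeDepShape):
        return None  # cannot be computed ahead of the producing call
    if callable(shape):
        return tuple(shape(env))  # V1: call the shape function
    return tuple(dim(env) for dim in shape)  # V0: evaluate each dimension


# V0 examples: the symbolic shapes (n, 4) and (n * 4,)
nx4_shape = (lambda env: env["n"], lambda env: 4)
flat_shape = (lambda env: env["n"] * 4,)


# V1 example: a hypothetical opaque shape function
def padded_shape(env: Env) -> Tuple[int, ...]:
    return (env["n"] * 4 + 1,)


env = {"n": 5}  # n only becomes known once the input tensor arrives
```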
+
+The following program covers typical scenarios in shape deduction (marked in 
comments). Importantly, shape is now part of the computation along with Tensor 
values. This reflects the fact that the computation of shapes can happen at 
runtime.
+
+```python
+from tvm.script import relax as R
+
+@R.function
+def shape_example(x: R.Tensor[(n, 2, 2), "float32"]):
+    with R.dataflow():
+        # V0: symbolic and static shape deduction
+        lv0: R.Tensor[(n, 4), "float32"] = R.reshape(x, (n, 4))
+        lv1: R.Tensor[(n * 4,), "float32"] = R.flatten(lv0)
+        lv2: R.Shape = (n * 4,)
+
+        # V1: external opaque shape function
+        lv3: R.Shape = R.call_packed("myshape_func", lv2)
+        lv4 = R.call_tir("custom_func", (lv1,), lv3, dtype="float32")
+
+        # V2: runtime dependent case: _ is used to represent RuntimeDepShape
+        lv5: R.Tensor[_, "float32"] = R.unique(lv4)
+
+        # re-match shape
+        lv6: R.Tensor[(m,), "float32"] = R.match_shape(lv5, (m,))
+        lv7: R.Shape = R.match_shape(lv3, (m,))
+
+        gv0: R.Tensor[(m,), "float32"] = R.exp(lv6)
+        R.outputs(gv0)
+
+    return gv0
+```
+
+The text format type annotation `lv0: R.Tensor[(n, 4), "float32"]` shows the shape of each value, but this is only syntactic sugar. From the IR’s point of view, the `shape_` field `(n, 4)` is not included in the type signature of `lv0`. The type signature of `lv0` is `DynTensorType(ndim=2, dtype="float32")`, and the shape is a special value field attached to each `Expr`. We made this explicit choice to simplify type inference. It keeps us out of [dependent typing](https://en.wikipedia.org/wiki/Dependent_type) territory, where types depend on values (shapes, in our case) and heavier machinery is needed to handle them.
+
+**match_shape**
+
+After a data-dependent computation (like `unique`) or an external call, we may need to recover or refine the shape information to enable more optimizations. The `match_shape` construct is used to perform such refinements.
+
+`var: Var = match_shape(value: Expr, pattern: List[PrimExpr])`
+
+The `match_shape` construct takes a **value** and a **pattern** (a list of `PrimExpr`, for example `(m, n)`), and returns a **var**. It has two overloaded semantics:
+
+- When value is a Tensor, `match_shape` matches `value.shape` to the pattern and populates each symbolic integer variable that occurs in the pattern for the first time in the scope. It then returns a new Tensor that is the same as value but whose shape field is updated to the pattern. In the V2 case in the above code snippet, `R.match_shape(lv5, (m,))` defines a symbolic TIR variable `m` and matches tensor `lv5`’s shape with the pattern `(m,)`.
+- When value is a Shape (for example `lv7: R.Shape = R.match_shape(lv3, (m,))` in the above code snippet), it directly matches the pattern and returns a Shape. This is useful when we want to isolate shape functions that do not correspond to any Tensor value.
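The Tensor case of `match_shape` can be sketched as follows. The sketch restricts patterns to integers (static dimensions) and plain variable names (symbolic dimensions), whereas real patterns can be arbitrary `PrimExpr`s such as `n * m`. The helper name is hypothetical, and raising an error on a mismatch is an assumption made for illustration.

```python
from typing import Dict, Sequence, Tuple, Union

Dim = Union[int, str]  # int: static dimension; str: name of a symbolic variable


def match_shape(actual: Tuple[int, ...], pattern: Sequence[Dim],
                env: Dict[str, int]) -> Tuple[int, ...]:
    """Match a runtime shape against a pattern, binding new symbolic vars in env."""
    if len(actual) != len(pattern):
        raise ValueError(f"rank mismatch: {actual} vs {tuple(pattern)}")
    for value, dim in zip(actual, pattern):
        if isinstance(dim, int):
            if value != dim:  # static dimension must match exactly
                raise ValueError(f"expected {dim}, got {value}")
        elif dim not in env:
            env[dim] = value  # first occurrence in scope: define the variable
        elif env[dim] != value:  # later occurrence: must agree with the binding
            raise ValueError(f"{dim} is bound to {env[dim]}, got {value}")
    return tuple(actual)


env: Dict[str, int] = {}
match_shape((7,), ("m",), env)      # like R.match_shape(lv5, (m,)): defines m = 7
match_shape((3, 7), (3, "m"), env)  # m is already bound, so it is checked, not redefined
```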
+
+**Safety Net (handle `RuntimeDepShape`)**
+
+Fixed-rank, dynamic symbolic shape relations cover most use cases. Inevitably, however, we also need to cover general cases that do not fall into this category:
+
+- C0: Dynamic shape relations where output shape is data dependent on the 
input (e.g. `unique` operator).
+- C1: Rank of a tensor is not known (can happen in rare cases of loops).
+- C2: dtype of a tensor is not known.
+- C3: Other cases, such as opaque runtime objects for low-level libraries (e.g. PRNG handle, cuDNN context).
+
+As a result, it is important to have a "safety net" solution so that we cover 
the general cases.
+
+Suppose we have a `unique` operation whose return tensor shape cannot be deduced at compile time:
+
+`y: R.Tensor[_, _] = R.unique(x)`
+
+During lowering, this call won't be translated into destination-passing style, because the shape information cannot be obtained and the memory cannot be pre-allocated. Instead, such calls are translated directly into calls that allocate and return the result tensor.
+
+- `R.unique` can be mapped to a runtime PackedFunc call that takes in an NDArray `x` and performs a unique operation.
+    - We can even dispatch to common runtime libraries such as `torch.unique`. For example, the above `R.unique(x)` can be lowered to `call_packed("torch.unique", x)`.
+
+The Relax VM supports these features as PackedFunc calls that return a TVM Object. We can then use `match_shape` to bring the tensors from shape-agnostic computation back into shape-aware computation. Shape-agnostic computation is by no means the most efficient way to handle things. It is, however, necessary for cases like data-dependent calculations and interfacing with external libraries that have weaker shape information.
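The round trip from shape-agnostic back to shape-aware computation can be sketched in NumPy. It mirrors `lv5`, `lv6`, and `gv0` in `shape_example`. Here `np.unique` stands in for a library call such as `torch.unique`, and `unique_packed` and `match_shape_1d` are hypothetical helpers.

```python
import numpy as np


def unique_packed(x: np.ndarray) -> np.ndarray:
    # Non-DPS: the output size depends on the data, so the callee allocates it.
    return np.unique(x)


def match_shape_1d(t: np.ndarray, env: dict, var: str) -> np.ndarray:
    # Recover a symbolic 1-D shape (var,) from a tensor whose shape was unknown.
    (length,) = t.shape
    env.setdefault(var, length)
    if env[var] != length:
        raise ValueError(f"{var}={env[var]} does not match {length}")
    return t


env = {}
x = np.array([3.0, 1.0, 3.0, 2.0, 1.0], dtype="float32")
y = unique_packed(x)             # shape unknown until this call returns
y = match_shape_1d(y, env, "m")  # m is now bound, so y has shape (m,)

# With (m,) known, the next operator can use DPS again.
out = np.empty((env["m"],), dtype="float32")
np.exp(y, out=out)
```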
+
+## D2: **Dataflow block as a first-class construct**
+
+Most machine learning models can be represented with a **pure**/**side-effect-free** computational graph. An operation is pure, or side-effect-free, if it only reads from its inputs and returns the result via its output, without changing other parts of the program (such as incrementing a global counter).
+
+A **dataflow graph** is one in which every operation is **side-effect-free** and there is no **control flow** (such as if-then-else). A **dataflow block** is how we mark the dataflow graph regions of a program in Relax. Specifically, all the operations inside a dataflow block are side-effect-free and contain no control flow (control flow is an advanced semantic that most pass writers do not consider). Outside a dataflow block, operations can contain side effects (for example doing in-place weight update during model training) and control flow. The program below is an example program that contains two dataflow blocks.
+

Review Comment:
   I think Relay is pure except for `Ref`s, and in practice those are not used because they are poorly supported by the compiler.


