Hzfengsy commented on code in PR #89:
URL: https://github.com/apache/tvm-rfcs/pull/89#discussion_r950165733
##########
rfcs/0089-relax-upstreaming.md:
##########
@@ -0,0 +1,701 @@
+- Feature Name: Relax Upstreaming
+- Start Date: 2022-08-17
+- RFC PR: [apache/tvm-rfcs#0089](https://github.com/apache/tvm-rfcs/pull/0089)
+- GitHub Issue: [apache/tvm#0000](https://github.com/apache/tvm/issues/0000)
+- Co-Authors: [@denise-k](https://github.com/denise-k), [@jwfromm](https://github.com/jwfromm)
+
+# 1. **Summary**
+
+This RFC proposes to upstream the core foundation of Relax (Relay Next). Relax is a new graph-level IR that enables new capabilities to address the [critical needs](https://discuss.tvm.apache.org/t/establish-tvm-unity-connection-a-technical-strategy/13344) identified by the TVM community over the years of using and developing deep learning compilers.
+
+# 2. **Motivation and goals**
+
+Relax is an effort within [TVM Unity](https://tvm.apache.org/2021/12/15/tvm-unity) that aims to evolve the graph-level IR to maximize **expressibility, performance, and portability** across today’s and tomorrow’s workloads. Relax has three key goals, motivated by the TVM community’s needs and by the lessons the community has learned about ML acceleration through years of using and developing TVM:
+
+- Build a unified interface that transcends the boundaries of TVM’s abstractions between graph-level IR, tensor programs (TensorIR), and runtime libraries (PackedFunc);
+- Enable and optimize dynamic shape workloads;
+- Support “computational graph” style optimizations with advanced dataflow semantics.
+
+For more details on the design goals of Relax, please check out the [discuss forum post](https://discuss.tvm.apache.org/t/relax-co-designing-high-level-abstraction-towards-tvm-unity/12496).
+
+The main focus of this upstreaming RFC is to upstream the **core foundation** of Relax as an **optional** compilation flow in TVM, guided by two principles:
+
+- **Minimize disruption:** This upstreaming should provide a **non-default** path that enables new capabilities for users and developers interested in what Relax brings, so it will not break the current default Relay flow.
+- **Minimize complexity:** This upstreaming should reuse existing TVM/Relay infrastructure as much as possible (for example, IRModule, runtime Module, and the TOPI library) to avoid duplicated effort and code.
+
+This initial upstreaming will open the path for TVM Unity and incrementally bring Relax into the community.
+
+# 3. **Guide-level explanation**
+
+This section introduces the three major design points of Relax, which map directly to the three key goals of Relax listed in the previous section. First, we introduce what the user-facing interface will look like after this RFC lands.
+
+(Most of the code examples in this RFC are written in [TVMScript](https://github.com/apache/tvm-rfcs/pull/74/files#diff-6965a40ad8df7618ae68e11c88f924542a506c74a931cc3011ae9f99989b5f51R21-R27), which enables users to write and print TVM programs containing both Relax and TIR functions with Python syntax.)
+
+## User-facing interface
+
+After this upstreaming lands, users will be able to write a Relax program in TVMScript or translate a model directly from Relay. Relax provides a simple API to compile an IRModule to a VM executable and run it on the Relax VM.
+
+```python
+import numpy as np
+import tvm
+import tvm.script
+from tvm import relax
+from tvm.script import relax as R, tir as T
+
+# Relax IRModule written in TVMScript
+@tvm.script.ir_module
+class MyIRModule:
+    # This is a TIR PrimFunc which calls the TIR intrinsic T.exp
+    @T.prim_func
+    def tir_exp_func(x: T.handle, y: T.handle): ## <= D2
+        n = T.var("int64")
+        X = T.match_buffer(x, (n,), "float32")
+        Y = T.match_buffer(y, (n,), "float32")
+        for i in T.grid(n):
+            Y[i] = T.exp(X[i])
+
+    # This is a Relax function which contains a dataflow block
+    # representing a computational graph, as well as a call to an
+    # opaque packed function which performs an in-place update to the
+    # data in variable gv0.
+    # We mark the corresponding design points (D0, D1, D2) that map to
+    # the following sections throughout the relax function below.
+    @R.function
+    def relax_func(x: R.Tensor[(n, k), "float32"], w: R.Tensor[_, "float32"]):
+        # n, k above are implicitly defined within the function signature
+        # so we will be able to refer to n, k within all of relax_func
+        with R.dataflow(): ## <= D2
+            lv0 = R.match_shape(w, (k, m)) ## <= D1
+            lv1: R.Tensor[(n, m), "float32"] = R.dot(x, lv0)
+            lv2: R.Tensor[(n * m,), "float32"] = R.flatten(lv1) ## <= D1
+            lv3: R.Shape = (n * m,) ## <= D1
+            gv0 = R.call_tir(tir_exp_func, [lv2], lv3, dtype="float32") ## <= D0
+            R.outputs(gv0)
+
+        R.call_packed("custom_inplace_update", gv0) ## <= D0, D2
+        return gv0
+
+# Print IRModule with syntax highlighting
+MyIRModule.show()
+
+# Build the Relax IRModule
+target = tvm.target.Target("llvm")
+ex = relax.vm.build(MyIRModule, target)
+
+# Dump the VM executable instructions as text
+print(ex.as_text())
+
+# Run the function on Relax VM runtime. relax_func takes x of shape (n, k)
+# and w of shape (k, m), so here n = 2, k = 3, m = 4.
+vm = relax.VirtualMachine(ex, tvm.cpu())
+x = tvm.nd.array(np.random.rand(2, 3).astype(np.float32))
+w = tvm.nd.array(np.random.rand(3, 4).astype(np.float32))
+res = vm["relax_func"](x, w)
+```
+
+## D0: **Unified abstractions and optimizations across layers**
+
+The first key design point is to allow the high-level graph IR to directly interact with and call into lower-level TensorIR and PackedFunc (TVM FFI).
+
+The TensorIR PrimFunc and many external libraries adopt a **destination-passing-style** (DPS) calling convention, in which both inputs and outputs are passed to the function as arguments and the outputs are mutated directly inside the function:
+
+```python
+def low_level_func(input0, input1, ..., output):
+    # implementations
+```
+
+The main idea of DPS is that inputs and outputs are explicitly allocated outside the low-level primitive function and then passed to it. This style is commonly used in low-level library designs (for example, TensorRT), so that higher-level frameworks (for example, the compiler) can handle memory allocation.
+
+### **call_tir**
+
+In Relax, we introduce `call_tir` to bridge graph-level IR and TIR. `call_tir` is an intrinsic that calls a TIR PrimFunc (that follows DPS) and returns the output. The semantics of `call_tir` can be demonstrated by the code below.

Review Comment:
`call_tir` is not only designed for dynamic shape support. It enables optimizations and transformations of an IRModule across both the GraphIR and TensorIR. We do support having Relay and TIR in the same IRModule, but we cannot optimize them together.
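The destination-passing style and the `call_tir` semantics described in the quoted RFC text can be modeled in a few lines of NumPy. This is a minimal sketch under stated assumptions: `exp_dps` and `call_tir_sketch` are hypothetical names invented for illustration, NumPy arrays stand in for TVM NDArrays, and nothing here reflects how Relax actually lowers `call_tir` (memory planning and device placement are not modeled).

```python
import numpy as np


def exp_dps(x: np.ndarray, y: np.ndarray) -> None:
    """Hypothetical DPS primitive, a stand-in for a PrimFunc like tir_exp_func.

    The caller allocates `y`; this function only writes into it and
    returns nothing.
    """
    np.exp(x, out=y)  # write results into the caller-provided buffer


def call_tir_sketch(dps_func, inputs, out_shape, dtype):
    """Illustrative model of call_tir semantics (not TVM's implementation).

    To the graph-level caller this looks like a pure function: inputs go in,
    a fresh output tensor comes back. Allocation happens here, outside the
    DPS primitive.
    """
    out = np.empty(out_shape, dtype=dtype)  # allocate the destination buffer
    dps_func(*inputs, out)                  # the primitive mutates only `out`
    return out


# Usage mirroring the RFC example: exp over a flattened (n * m,) tensor.
lv2 = np.arange(6, dtype="float32")
gv0 = call_tir_sketch(exp_dps, [lv2], (6,), "float32")
print(gv0)
```

From the caller's side, `call_tir_sketch` behaves like an ordinary value-returning operation even though the wrapped primitive only writes into a preallocated buffer. Keeping that pure view at the graph level while the DPS primitive stays a separate function in the same module is one way to read why `call_tir` helps graph-level and tensor-level transformations coexist.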