This is an automated email from the ASF dual-hosted git repository.
yongwww pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 6a2459f185 [Unity][Doc] Document passes that depend on
`DataflowBlock`s and encourage using `ConvertToDataflow` (#16514)
6a2459f185 is described below
commit 6a2459f185d09e33abfbe67ad4b2512610f2868b
Author: Steven S. Lyubomirsky <[email protected]>
AuthorDate: Tue Feb 6 11:00:10 2024 -0500
[Unity][Doc] Document passes that depend on `DataflowBlock`s and encourage
using `ConvertToDataflow` (#16514)
* Indicate in doc comments which passes need dataflow blocks
* Also encourage users to use ConvertToDataflow
* Whitespace
---
include/tvm/relax/transform.h | 21 +++++++++++++++++----
python/tvm/relax/transform/transform.py | 25 ++++++++++++++++++++-----
src/relax/transform/dead_code_elimination.cc | 8 +++++---
3 files changed, 42 insertions(+), 12 deletions(-)
diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
index efe30e5cbb..027fd6f824 100644
--- a/include/tvm/relax/transform.h
+++ b/include/tvm/relax/transform.h
@@ -214,7 +214,9 @@ TVM_DLL Pass BindSymbolicVars(Map<ObjectRef, PrimExpr>
binding_map,
Optional<String> func_name = NullOpt);
/*!
- * \brief Fold constant expressions.
+ * \brief Fold constant expressions within dataflow blocks.
+ *
+ * \note ConvertToDataflow may need to be called first to provide dataflow blocks.
*
* \return The Pass.
*/
@@ -458,6 +460,8 @@ class PatternCheckContext : public ObjectRef {
* of the return value as the target. If it is not specified, the first return
value will be the
* target.
* \return The Pass.
+ *
+ * \note ConvertToDataflow may need to be called first to provide dataflow blocks.
*/
TVM_DLL Pass Gradient(String func_name, Optional<Array<Var>> require_grads =
NullOpt,
int target_index = 0);
@@ -477,6 +481,8 @@ TVM_DLL Pass Gradient(String func_name,
Optional<Array<Var>> require_grads = Nul
* This must be True if the created composite functions are intended to be
offloaded to
* an external backend without using the MergeCompositeFunctions pass.
* \return The Pass.
+ *
+ * \note Only operates within dataflow blocks. ConvertToDataflow may need to be called first.
*/
TVM_DLL Pass FuseOpsByPattern(const tvm::Array<FusionPattern>& patterns, bool
bind_constants = true,
bool annotate_codegen = false);
@@ -548,6 +554,7 @@ TVM_DLL Pass AlterOpImpl(const Map<String, tir::PrimFunc>&
op_impl_map,
* \brief Layout conversion pass.
* \param desired_layouts The desired layouts for some operators.
* \return The Pass.
+ * \note Operates only on dataflow blocks. ConvertToDataflow may need to be called first.
*/
TVM_DLL Pass ConvertLayout(Map<String, Array<String>> desired_layouts);
@@ -564,10 +571,13 @@ TVM_DLL Pass ConvertToDataflow(int min_size = 2);
* \brief Dead code elimination.
* \sa RemoveAllUnused
* Currently it removes:
- * 1. Unused local VarBindings in a DataflowBlock.
- * 2. Unused DataflowBlocks in a function.
- * 3. Unused Relax functions in the module.
+ * 1. Unused local VarBindings
+ * (those where the bound var is unused and no impure operation is used).
+ * 2. Unused Relax functions in the module.
* We detect the call chain from the entry function, and remove all
unused functions.
+ *
+ * Any binding blocks that are left empty will be removed by the normalizer.
+ *
* \return The Pass.
*/
TVM_DLL Pass DeadCodeElimination(Array<runtime::String> entry_functions);
@@ -578,6 +588,7 @@ TVM_DLL Pass DeadCodeElimination(Array<runtime::String>
entry_functions);
* Supported operators will be replaced by calls to `call_tir_inplace` that
invoke in-place
* PrimFunc implementations of those operators (which are based on the
legalizations of those
* operators).
+ * \note ConvertToDataflow may need to be called first to provide dataflow blocks.
* \return The pass.
*/
TVM_DLL Pass DataflowUseInplaceCalls();
@@ -589,6 +600,8 @@ TVM_DLL Pass DataflowUseInplaceCalls();
* \param fp16_input_names The names of function parameters whose dtype should
become fp16. The
* function signature would change accordingly.
* \return The Pass.
+ *
+ * \note Mainly operates within dataflow blocks. ConvertToDataflow may need to be called first.
*/
TVM_DLL Pass ToMixedPrecision(const DataType& out_dtype,
Optional<Array<String>> fp16_input_names =
NullOpt);
diff --git a/python/tvm/relax/transform/transform.py
b/python/tvm/relax/transform/transform.py
index e360c09392..b2aaa3e331 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -52,7 +52,7 @@ def Gradient(
"""Reverse-mode automatic differentiation.
This pass will differentiate one function in the IRModule. Now the input
function must have only
- one dataflow block.
+ one dataflow block (ConvertToDataflow may need to be called first).
For a given function specified by `func_name`, it generates a new function
with the name
`func_name + "_adjoint"`. The new function computes the gradient of the
**differentiation
@@ -260,6 +260,8 @@ def DataflowUseInplaceCalls() -> tvm.ir.transform.Pass:
in-place PrimFunc implementations of those operators (which are based on
the legalizations of
those operators).
+ Note: ConvertToDataflow may need to be called first to provide dataflow blocks.
+
Returns
-------
ret: tvm.ir.transform.Pass
@@ -282,6 +284,8 @@ def ConvertToDataflow(min_size: int = 2) ->
tvm.ir.transform.Pass:
"""A pass that converts consecutive dataflow operations
inside binding blocks into dataflow blocks.
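+ Note: Run this pass before passes that only operate within dataflow blocks (e.g. FoldConstant, FuseOps, FuseOpsByPattern) when the input does not already contain dataflow blocks.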
+
Params
------
min_size: int
@@ -395,6 +399,8 @@ def RewriteDataflowReshape() -> tvm.ir.transform.Pass:
operation at runtime, instead of doing real data copy.
Here "reshape-like" includes reshape, expand_dims, flatten, etc.
+ Note: Operates only in dataflow blocks. ConvertToDataflow may need to be called first.
+
Returns
-------
ret : tvm.ir.transform.Pass
@@ -584,7 +590,9 @@ def RunCodegen(
def FoldConstant() -> tvm.ir.transform.Pass:
- """Fold constant expressions.
+ """Fold constant expressions within dataflow blocks.
+
+ Note: ConvertToDataflow may need to be called first to provide dataflow blocks.
Returns
-------
@@ -651,6 +659,8 @@ def FuseOps(fuse_opt_level=-1) -> tvm.ir.transform.Pass:
A follow-up pass named "FuseTIR" will generate a TIR PrimFunc for each
grouped function.
+ Note: ConvertToDataflow may need to be called first to provide dataflow blocks.
+
Parameters
----------
fuse_opt_level : int
@@ -764,6 +774,8 @@ def FuseOpsByPattern(
The end result is similar to FuseOps, but fusion is driven completely by
the provided patterns.
+ Note: Only operates within dataflow blocks. ConvertToDataflow may need to be called first.
+
Parameters
----------
patterns : List[Union[FusionPattern, Tuple]]
@@ -1172,11 +1184,12 @@ def DeadCodeElimination(entry_functions:
Optional[List[str]] = None) -> tvm.ir.t
"""Remove dead code in the IRModule.
Currently it removes:
- 1. Unused local VarBindings in a DataflowBlock.
- 2. Unused DataflowBlocks in a function.
- 3. Unused Relax functions in the module.
+ 1. Unused local VarBindings
+ (those where the bound var is unused and no impure operation is used).
+ 2. Unused Relax functions in the module.
We detect the call chain from the entry function, and remove all
unused functions.
+ Any binding blocks that are left empty will be removed by the normalizer.
Notes
-----
@@ -1203,6 +1216,8 @@ def ToMixedPrecision(
"""Automatic mixed precision pass. Currently the pass assumes the input
module to be fp32
only, and will automatically cast fp32 to fp16 for certain ops.
+ Note: Mainly operates within dataflow blocks. ConvertToDataflow may need to be called first.
+
Parameters
----------
out_dtype : str
diff --git a/src/relax/transform/dead_code_elimination.cc
b/src/relax/transform/dead_code_elimination.cc
index 248e4c1c00..73f66d2ef3 100644
--- a/src/relax/transform/dead_code_elimination.cc
+++ b/src/relax/transform/dead_code_elimination.cc
@@ -24,10 +24,12 @@
* \sa tvm/relax/ir/binding_rewrite.cc
*
* Currently it removes:
- * 1. Unused local VarBindings in a DataflowBlock.
- * 2. Unused DataflowBlocks in a function.
- * 3. Unused Relax functions in the module.
+ * 1. Unused local VarBindings
+ * (those where the bound var is unused and no impure operation is used).
+ * 2. Unused Relax functions in the module.
* We detect the call chain from the entry function, and remove all
unused functions.
+ *
+ * Any binding blocks that are left empty will be removed by the normalizer.
*/
#include <tvm/relax/analysis.h>