This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 6587faaec0 [Unity][Relax] Generalize CSE to work outside DataflowBlocks (#15047)
6587faaec0 is described below
commit 6587faaec0c0c7ba20eca4c7146b4572bbcda6df
Author: Steven S. Lyubomirsky <[email protected]>
AuthorDate: Thu Jun 22 01:50:03 2023 -0400
[Unity][Relax] Generalize CSE to work outside DataflowBlocks (#15047)
* No sense deduplicating lone PrimValues
* Don't deduplicate StringImms either
* Generalize CSE beyond dataflow blocks
* Update doc comments for CSE pass as well
* Whitespace fix
* Return FunctionPass, not DataflowBlockPass!
---
include/tvm/relax/transform.h | 5 +-
python/tvm/relax/transform/transform.py | 7 +-
src/relax/transform/eliminate_common_subexpr.cc | 118 ++++++++++++++----------
tests/python/relax/test_transform_cse.py | 63 ++++++++++++-
4 files changed, 131 insertions(+), 62 deletions(-)
diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
index 6f9841ba7a..b618672292 100644
--- a/include/tvm/relax/transform.h
+++ b/include/tvm/relax/transform.h
@@ -164,11 +164,10 @@ TVM_DLL Pass Normalize();
TVM_DLL Pass CanonicalizeBindings();
/*!
- * Eliminate common subexpressions within dataflow blocks.
+ * Eliminate common subexpressions within functions.
* \return The pass that eliminates common subexpressions.
*
- * \note For functions local to dataflow blocks, this pass performs
- * CSE *within* those functions.
+ * \note For nested functions, this pass performs CSE *within* those functions.
* \param call_only If true, enable eliminating only call nodes.
*/
TVM_DLL Pass EliminateCommonSubexpr(bool call_only = false);
diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py
index 7d390ed1f9..d8d11a50d8 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -290,11 +290,10 @@ def CanonicalizeBindings() -> tvm.ir.transform.Pass:
return _ffi_api.CanonicalizeBindings() # type: ignore
-def EliminateCommonSubexpr(call_only=False) -> DataflowBlockPass:
- """Eliminate common subexpressions within dataflow blocks.
+def EliminateCommonSubexpr(call_only=False) -> FunctionPass:
+ """Eliminate common subexpressions within functions.
- Note: For functions local to dataflow blocks, this pass performs
- CSE *within* those functions
+ Note: For nested functions, this pass performs CSE *within* those functions
Parameters
----------
diff --git a/src/relax/transform/eliminate_common_subexpr.cc b/src/relax/transform/eliminate_common_subexpr.cc
index 3087c409ac..3452b6352b 100644
--- a/src/relax/transform/eliminate_common_subexpr.cc
+++ b/src/relax/transform/eliminate_common_subexpr.cc
@@ -22,14 +22,56 @@
* \file tvm/relax/transform/eliminate_common_subexpr.cc
* \brief Eliminrate common subexpression pass.
*
- * Currently it removes common subexpressions within a DataflowBlock.
+ * Currently it removes common subexpressions within a Function.
*/
#include <tvm/relax/expr_functor.h>
#include <tvm/relax/transform.h>
+#include <tvm/relax/utils.h>
namespace tvm {
namespace relax {
+// Checks if a given expression contains an impure subexpression
+// Caches the results of checks to avoid revisiting subexpressions
+class ImpurityDetector : public ExprVisitor {
+ public:
+ bool Detect(const Expr& expr) {
+ impure_found_ = false;
+ VisitExpr(expr);
+ return impure_found_;
+ }
+
+ void VisitExpr(const Expr& expr) {
+ // already checked: do not revisit
+ if (purity_map_.count(expr)) {
+ impure_found_ = impure_found_ || !purity_map_.at(expr);
+ return;
+ }
+
+ // in principle, we could stop checking once we find an impurity,
+ // but not doing so lets us fully populate the cache
+
+ // store the previous state so we could assess the purity of this subexpression alone
+ bool prev_state = impure_found_;
+ impure_found_ = false;
+ ExprVisitor::VisitExpr(expr);
+ // if impure_found_ remains false, then the expression is pure
+ purity_map_[expr] = !impure_found_;
+ impure_found_ = prev_state || impure_found_;
+ }
+
+ void VisitExpr_(const CallNode* call) {
+ // the only possible impurities can come from call nodes
+ bool is_impure = IsImpureCall(GetRef<Call>(call));
+ impure_found_ = impure_found_ || is_impure;
+ ExprVisitor::VisitExpr_(call);
+ }
+
+ private:
+ bool impure_found_ = false;
+ std::unordered_map<Expr, bool, StructuralHash, StructuralEqual> purity_map_;
+};
+
class SubexprCounter : public ExprVisitor {
public:
// overriding VisitExpr ensures we do this for every subexpression
@@ -37,15 +79,21 @@ class SubexprCounter : public ExprVisitor {
// Cases we ignore because we will not substitute them:
// 1. Vars of all kinds
// 2. Op nodes (nothing we can do)
- // 3. Scalar constants (not much benefit from binding to a var)
+ // 3. PrimValue nodes (not much benefit from binding to a var)
+ // 4. StringImm nodes (not much benefit from binding to a var)
+ // 5. Scalar constants (not much benefit from binding to a var)
if (!(e->IsInstance<VarNode>() || e->IsInstance<DataflowVarNode>() ||
e->IsInstance<GlobalVarNode>() || e->IsInstance<tvm::OpNode>() ||
+ e->IsInstance<PrimValueNode>() || e->IsInstance<StringImmNode>() ||
(e.as<ConstantNode>() && (e.as<ConstantNode>()->is_scalar())))) {
- int count = 0;
- if (count_map_.count(e)) {
- count = count_map_.at(e);
+ // also if e has an impure subexpression, we will not deduplicate it
+ if (!impurity_detector_.Detect(e)) {
+ int count = 0;
+ if (count_map_.count(e)) {
+ count = count_map_.at(e);
+ }
+ count_map_[e] = count + 1;
}
- count_map_[e] = count + 1;
}
ExprVisitor::VisitExpr(e);
}
@@ -56,20 +104,18 @@ class SubexprCounter : public ExprVisitor {
// we are not going to do replacements inside struct info to avoid binding lots of reused shapes
void VisitExprDepStructInfoField(const StructInfo& struct_info) override {}
- std::unordered_map<Expr, int, StructuralHash, StructuralEqual> Count(
- const DataflowBlock& df_block) {
- for (auto binding : df_block->bindings) {
- VisitBinding(binding);
- }
+ std::unordered_map<Expr, int, StructuralHash, StructuralEqual> Count(const Function& func) {
+ VisitExpr(func->body);
return count_map_;
}
private:
std::unordered_map<Expr, int, StructuralHash, StructuralEqual> count_map_;
+ ImpurityDetector impurity_detector_;
};
// forward declaration
-DataflowBlock EliminateCommonSubexpr(const DataflowBlock&, bool call_only);
+Function EliminateCommonSubexpr(const Function&, bool call_only);
class CommonSubexprEliminator : public ExprMutator {
public:
@@ -104,37 +150,8 @@ class CommonSubexprEliminator : public ExprMutator {
}
Expr VisitExpr_(const FunctionNode* func) override {
- // for an inner function, we will do CSE on its body
- Expr new_body = ExprMutator::VisitExpr(func->body);
- if (new_body.same_as(func->body)) {
- return GetRef<Expr>(func);
- }
- return Function(func->params, new_body, func->ret_struct_info, func->is_pure, func->attrs,
- func->span);
- }
-
- // this should happen only for the inner function case
- Expr VisitExpr_(const SeqExprNode* seq) override {
- bool all_unchanged = true;
- Array<BindingBlock> new_blocks;
- // apply CSE within dataflow blocks only
- for (auto block : seq->blocks) {
- if (const DataflowBlockNode* df_block = block.as<DataflowBlockNode>()) {
- auto new_df_block = EliminateCommonSubexpr(GetRef<DataflowBlock>(df_block), call_only_);
- if (!new_df_block.same_as(block)) {
- new_blocks.push_back(new_df_block);
- all_unchanged = false;
- continue;
- }
- }
- new_blocks.push_back(block);
- }
-
- if (all_unchanged) {
- return GetRef<Expr>(seq);
- }
- // do not visit the body
- return SeqExpr(new_blocks, seq->body, seq->span);
+ // do full CSE within the function
+ return EliminateCommonSubexpr(GetRef<Function>(func), call_only_);
}
void VisitBinding_(const VarBindingNode* binding) override {
@@ -189,21 +206,22 @@ class CommonSubexprEliminator : public ExprMutator {
bool call_only_{false};
};
-DataflowBlock EliminateCommonSubexpr(const DataflowBlock& df_block, bool call_only) {
+Function EliminateCommonSubexpr(const Function& func, bool call_only) {
SubexprCounter counter;
- auto count_map = counter.Count(df_block);
+ auto count_map = counter.Count(func);
CommonSubexprEliminator eliminator(count_map, call_only);
- return Downcast<DataflowBlock>(eliminator.VisitBindingBlock(df_block));
+ return Function(func->params, eliminator.VisitExpr(func->body), func->ret_struct_info,
+ func->is_pure, func->attrs, func->span);
}
namespace transform {
Pass EliminateCommonSubexpr(bool call_only) {
- runtime::TypedPackedFunc<DataflowBlock(DataflowBlock, IRModule, PassContext)> pass_func =
- [=](DataflowBlock df_block, IRModule m, PassContext pc) {
- return Downcast<DataflowBlock>(EliminateCommonSubexpr(df_block, call_only));
+ runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
+ [=](Function func, IRModule m, PassContext pc) {
+ return Downcast<Function>(EliminateCommonSubexpr(func, call_only));
};
- return CreateDataflowBlockPass(pass_func, 1, "EliminateCommonSubexpr", {});
+ return CreateFunctionPass(pass_func, 1, "EliminateCommonSubexpr", {});
}
TVM_REGISTER_GLOBAL("relax.transform.EliminateCommonSubexpr")
diff --git a/tests/python/relax/test_transform_cse.py b/tests/python/relax/test_transform_cse.py
index 94897c1eae..89421ba07c 100644
--- a/tests/python/relax/test_transform_cse.py
+++ b/tests/python/relax/test_transform_cse.py
@@ -124,8 +124,6 @@ def test_inner_function():
# we are going to do CSE inside the local function
@R.function
def bar(y: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"):
- # not in dataflow: should not be touched
- z = R.add(R.add(y, y), R.add(y, y))
with R.dataflow():
# writing this out in ANF to illustrate why CSE behaves as it does
# result of ANF transforming R.add(R.add(y, y), R.add(y, y))
@@ -134,7 +132,7 @@ def test_inner_function():
lv2 = R.add(lv0, lv1)
gv = lv2
R.output(gv)
- return R.add(z, gv)
+ return R.add(gv, gv)
# also making the ANF explicit to better illustrate the result of CSE
# result of ANF transforming R.add(R.add(bar(x), bar(x)), R.add(bar(x), bar(x)))
@@ -157,14 +155,13 @@ def test_inner_function():
@R.function
def bar(y: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"):
- z = R.add(R.add(y, y), R.add(y, y))
with R.dataflow():
lv0 = R.add(y, y)
lv1 = lv0
lv2 = R.add(lv0, lv1)
gv = lv2
R.output(gv)
- return R.add(z, gv)
+ return R.add(gv, gv)
# can further clean this up
# using canonicalize bindings, eliminate unused bindings, and CSE again
@@ -210,5 +207,61 @@ def test_call_only():
verify(Before, Expected, call_only=True)
+def test_cse_outside_dataflow():
+ # same example as previously but it will work without a dataflow wrapper
+ @I.ir_module
+ class Before:
+ @R.function
+ def foo(x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="float32")):
+ lv0 = R.add(x, y)
+ lv1 = R.add(x, y)
+ gv = R.multiply(lv0, lv1)
+ return gv
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def foo(x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="float32")):
+ lv0 = R.add(x, y)
+ lv1 = lv0
+ gv = R.multiply(lv0, lv1)
+ return gv
+
+ verify(Before, Expected)
+
+
+def test_do_not_eliminate_impure():
+ @I.ir_module
+ class Before:
+ @R.function
+ def foo(x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="float32")):
+ R.is_impure()
+ # it's a repeated subexpression but it would be wrong to deduplicate it
+ p1 = R.print(format="Message")
+ p2 = R.print(format="Message")
+ a1 = R.assert_op(R.const(False), format="Always fails")
+ lv0 = R.add(x, y)
+ lv1 = R.add(x, y)
+ gv = R.multiply(lv0, lv1)
+ a2 = R.assert_op(R.const(False), format="Always fails")
+ return gv
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def foo(x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="float32")):
+ R.is_impure()
+ p1 = R.print(format="Message")
+ p2 = R.print(format="Message")
+ a1 = R.assert_op(R.const(False), format="Always fails")
+ lv0 = R.add(x, y)
+ lv1 = lv0
+ gv = R.multiply(lv0, lv1)
+ a2 = R.assert_op(R.const(False), format="Always fails")
+ return gv
+
+ verify(Before, Expected)
+
+
if __name__ == "__main__":
tvm.testing.main()