This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 6587faaec0 [Unity][Relax] Generalize CSE to work outside DataflowBlocks (#15047)
6587faaec0 is described below

commit 6587faaec0c0c7ba20eca4c7146b4572bbcda6df
Author: Steven S. Lyubomirsky <[email protected]>
AuthorDate: Thu Jun 22 01:50:03 2023 -0400

    [Unity][Relax] Generalize CSE to work outside DataflowBlocks (#15047)
    
    * No sense deduplicating lone PrimValues
    
    * Don't deduplicate StringImms either
    
    * Generalize CSE beyond dataflow blocks
    
    * Update doc comments for CSE pass as well
    
    * Whitespace fix
    
    * Return FunctionPass, not DataflowBlockPass!
---
 include/tvm/relax/transform.h                   |   5 +-
 python/tvm/relax/transform/transform.py         |   7 +-
 src/relax/transform/eliminate_common_subexpr.cc | 118 ++++++++++++++----------
 tests/python/relax/test_transform_cse.py        |  63 ++++++++++++-
 4 files changed, 131 insertions(+), 62 deletions(-)

diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
index 6f9841ba7a..b618672292 100644
--- a/include/tvm/relax/transform.h
+++ b/include/tvm/relax/transform.h
@@ -164,11 +164,10 @@ TVM_DLL Pass Normalize();
 TVM_DLL Pass CanonicalizeBindings();
 
 /*!
- * Eliminate common subexpressions within dataflow blocks.
+ * Eliminate common subexpressions within functions.
  * \return The pass that eliminates common subexpressions.
  *
- * \note For functions local to dataflow blocks, this pass performs
- * CSE *within* those functions.
+ * \note For nested functions, this pass performs CSE *within* those functions.
  * \param call_only If true, enable eliminating only call nodes.
  */
 TVM_DLL Pass EliminateCommonSubexpr(bool call_only = false);
diff --git a/python/tvm/relax/transform/transform.py b/python/tvm/relax/transform/transform.py
index 7d390ed1f9..d8d11a50d8 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -290,11 +290,10 @@ def CanonicalizeBindings() -> tvm.ir.transform.Pass:
     return _ffi_api.CanonicalizeBindings()  # type: ignore
 
 
-def EliminateCommonSubexpr(call_only=False) -> DataflowBlockPass:
-    """Eliminate common subexpressions within dataflow blocks.
+def EliminateCommonSubexpr(call_only=False) -> FunctionPass:
+    """Eliminate common subexpressions within functions.
 
-    Note: For functions local to dataflow blocks, this pass performs
-    CSE *within* those functions
+    Note: For nested functions, this pass performs CSE *within* those functions
 
     Parameters
     ----------
diff --git a/src/relax/transform/eliminate_common_subexpr.cc b/src/relax/transform/eliminate_common_subexpr.cc
index 3087c409ac..3452b6352b 100644
--- a/src/relax/transform/eliminate_common_subexpr.cc
+++ b/src/relax/transform/eliminate_common_subexpr.cc
@@ -22,14 +22,56 @@
  * \file tvm/relax/transform/eliminate_common_subexpr.cc
  * \brief Eliminrate common subexpression pass.
  *
- * Currently it removes common subexpressions within a DataflowBlock.
+ * Currently it removes common subexpressions within a Function.
  */
 #include <tvm/relax/expr_functor.h>
 #include <tvm/relax/transform.h>
+#include <tvm/relax/utils.h>
 
 namespace tvm {
 namespace relax {
 
+// Checks if a given expression contains an impure subexpression
+// Caches the results of checks to avoid revisiting subexpressions
+class ImpurityDetector : public ExprVisitor {
+ public:
+  bool Detect(const Expr& expr) {
+    impure_found_ = false;
+    VisitExpr(expr);
+    return impure_found_;
+  }
+
+  void VisitExpr(const Expr& expr) {
+    // already checked: do not revisit
+    if (purity_map_.count(expr)) {
+      impure_found_ = impure_found_ || !purity_map_.at(expr);
+      return;
+    }
+
+    // in principle, we could stop checking once we find an impurity,
+    // but not doing so lets us fully populate the cache
+
+    // store the previous state so we could assess the purity of this subexpression alone
+    bool prev_state = impure_found_;
+    impure_found_ = false;
+    ExprVisitor::VisitExpr(expr);
+    // if impure_found_ remains false, then the expression is pure
+    purity_map_[expr] = !impure_found_;
+    impure_found_ = prev_state || impure_found_;
+  }
+
+  void VisitExpr_(const CallNode* call) {
+    // the only possible impurities can come from call nodes
+    bool is_impure = IsImpureCall(GetRef<Call>(call));
+    impure_found_ = impure_found_ || is_impure;
+    ExprVisitor::VisitExpr_(call);
+  }
+
+ private:
+  bool impure_found_ = false;
+  std::unordered_map<Expr, bool, StructuralHash, StructuralEqual> purity_map_;
+};
+
 class SubexprCounter : public ExprVisitor {
  public:
   // overriding VisitExpr ensures we do this for every subexpression
@@ -37,15 +79,21 @@ class SubexprCounter : public ExprVisitor {
     // Cases we ignore because we will not substitute them:
     // 1. Vars of all kinds
     // 2. Op nodes (nothing we can do)
-    // 3. Scalar constants (not much benefit from binding to a var)
+    // 3. PrimValue nodes (not much benefit from binding to a var)
+    // 4. StringImm nodes (not much benefit from binding to a var)
+    // 5. Scalar constants (not much benefit from binding to a var)
     if (!(e->IsInstance<VarNode>() || e->IsInstance<DataflowVarNode>() ||
           e->IsInstance<GlobalVarNode>() || e->IsInstance<tvm::OpNode>() ||
+          e->IsInstance<PrimValueNode>() || e->IsInstance<StringImmNode>() ||
           (e.as<ConstantNode>() && (e.as<ConstantNode>()->is_scalar())))) {
-      int count = 0;
-      if (count_map_.count(e)) {
-        count = count_map_.at(e);
+      // also if e has an impure subexpression, we will not deduplicate it
+      if (!impurity_detector_.Detect(e)) {
+        int count = 0;
+        if (count_map_.count(e)) {
+          count = count_map_.at(e);
+        }
+        count_map_[e] = count + 1;
       }
-      count_map_[e] = count + 1;
     }
     ExprVisitor::VisitExpr(e);
   }
@@ -56,20 +104,18 @@ class SubexprCounter : public ExprVisitor {
   // we are not going to do replacements inside struct info to avoid binding lots of reused shapes
   void VisitExprDepStructInfoField(const StructInfo& struct_info) override {}
 
-  std::unordered_map<Expr, int, StructuralHash, StructuralEqual> Count(
-      const DataflowBlock& df_block) {
-    for (auto binding : df_block->bindings) {
-      VisitBinding(binding);
-    }
+  std::unordered_map<Expr, int, StructuralHash, StructuralEqual> Count(const Function& func) {
+    VisitExpr(func->body);
     return count_map_;
   }
 
  private:
   std::unordered_map<Expr, int, StructuralHash, StructuralEqual> count_map_;
+  ImpurityDetector impurity_detector_;
 };
 
 // forward declaration
-DataflowBlock EliminateCommonSubexpr(const DataflowBlock&, bool call_only);
+Function EliminateCommonSubexpr(const Function&, bool call_only);
 
 class CommonSubexprEliminator : public ExprMutator {
  public:
@@ -104,37 +150,8 @@ class CommonSubexprEliminator : public ExprMutator {
   }
 
   Expr VisitExpr_(const FunctionNode* func) override {
-    // for an inner function, we will do CSE on its body
-    Expr new_body = ExprMutator::VisitExpr(func->body);
-    if (new_body.same_as(func->body)) {
-      return GetRef<Expr>(func);
-    }
-    return Function(func->params, new_body, func->ret_struct_info, func->is_pure, func->attrs,
-                    func->span);
-  }
-
-  // this should happen only for the inner function case
-  Expr VisitExpr_(const SeqExprNode* seq) override {
-    bool all_unchanged = true;
-    Array<BindingBlock> new_blocks;
-    // apply CSE within dataflow blocks only
-    for (auto block : seq->blocks) {
-      if (const DataflowBlockNode* df_block = block.as<DataflowBlockNode>()) {
-        auto new_df_block = EliminateCommonSubexpr(GetRef<DataflowBlock>(df_block), call_only_);
-        if (!new_df_block.same_as(block)) {
-          new_blocks.push_back(new_df_block);
-          all_unchanged = false;
-          continue;
-        }
-      }
-      new_blocks.push_back(block);
-    }
-
-    if (all_unchanged) {
-      return GetRef<Expr>(seq);
-    }
-    // do not visit the body
-    return SeqExpr(new_blocks, seq->body, seq->span);
+    // do full CSE within the function
+    return EliminateCommonSubexpr(GetRef<Function>(func), call_only_);
   }
 
   void VisitBinding_(const VarBindingNode* binding) override {
@@ -189,21 +206,22 @@ class CommonSubexprEliminator : public ExprMutator {
   bool call_only_{false};
 };
 
-DataflowBlock EliminateCommonSubexpr(const DataflowBlock& df_block, bool call_only) {
+Function EliminateCommonSubexpr(const Function& func, bool call_only) {
   SubexprCounter counter;
-  auto count_map = counter.Count(df_block);
+  auto count_map = counter.Count(func);
   CommonSubexprEliminator eliminator(count_map, call_only);
-  return Downcast<DataflowBlock>(eliminator.VisitBindingBlock(df_block));
+  return Function(func->params, eliminator.VisitExpr(func->body), func->ret_struct_info,
+                  func->is_pure, func->attrs, func->span);
 }
 
 namespace transform {
 
 Pass EliminateCommonSubexpr(bool call_only) {
-  runtime::TypedPackedFunc<DataflowBlock(DataflowBlock, IRModule, PassContext)> pass_func =
-      [=](DataflowBlock df_block, IRModule m, PassContext pc) {
-        return Downcast<DataflowBlock>(EliminateCommonSubexpr(df_block, call_only));
+  runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
+      [=](Function func, IRModule m, PassContext pc) {
+        return Downcast<Function>(EliminateCommonSubexpr(func, call_only));
       };
-  return CreateDataflowBlockPass(pass_func, 1, "EliminateCommonSubexpr", {});
+  return CreateFunctionPass(pass_func, 1, "EliminateCommonSubexpr", {});
 }
 
 TVM_REGISTER_GLOBAL("relax.transform.EliminateCommonSubexpr")
diff --git a/tests/python/relax/test_transform_cse.py b/tests/python/relax/test_transform_cse.py
index 94897c1eae..89421ba07c 100644
--- a/tests/python/relax/test_transform_cse.py
+++ b/tests/python/relax/test_transform_cse.py
@@ -124,8 +124,6 @@ def test_inner_function():
                 # we are going to do CSE inside the local function
                 @R.function
                 def bar(y: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"):
-                    # not in dataflow: should not be touched
-                    z = R.add(R.add(y, y), R.add(y, y))
                     with R.dataflow():
                         # writing this out in ANF to illustrate why CSE behaves as it does
                         # result of ANF transforming R.add(R.add(y, y), R.add(y, y))
@@ -134,7 +132,7 @@ def test_inner_function():
                         lv2 = R.add(lv0, lv1)
                         gv = lv2
                         R.output(gv)
-                    return R.add(z, gv)
+                    return R.add(gv, gv)
 
                 # also making the ANF explicit to better illustrate the result of CSE
                 # result of ANF transforming R.add(R.add(bar(x), bar(x)), R.add(bar(x), bar(x)))
@@ -157,14 +155,13 @@ def test_inner_function():
 
                 @R.function
                 def bar(y: R.Tensor((), dtype="int32")) -> R.Tensor((), dtype="int32"):
-                    z = R.add(R.add(y, y), R.add(y, y))
                     with R.dataflow():
                         lv0 = R.add(y, y)
                         lv1 = lv0
                         lv2 = R.add(lv0, lv1)
                         gv = lv2
                         R.output(gv)
-                    return R.add(z, gv)
+                    return R.add(gv, gv)
 
                 # can further clean this up
                 # using canonicalize bindings, eliminate unused bindings, and CSE again
@@ -210,5 +207,61 @@ def test_call_only():
     verify(Before, Expected, call_only=True)
 
 
+def test_cse_outside_dataflow():
+    # same example as previously but it will work without a dataflow wrapper
+    @I.ir_module
+    class Before:
+        @R.function
+        def foo(x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="float32")):
+            lv0 = R.add(x, y)
+            lv1 = R.add(x, y)
+            gv = R.multiply(lv0, lv1)
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def foo(x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="float32")):
+            lv0 = R.add(x, y)
+            lv1 = lv0
+            gv = R.multiply(lv0, lv1)
+            return gv
+
+    verify(Before, Expected)
+
+
+def test_do_not_eliminate_impure():
+    @I.ir_module
+    class Before:
+        @R.function
+        def foo(x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="float32")):
+            R.is_impure()
+            # it's a repeated subexpression but it would be wrong to deduplicate it
+            p1 = R.print(format="Message")
+            p2 = R.print(format="Message")
+            a1 = R.assert_op(R.const(False), format="Always fails")
+            lv0 = R.add(x, y)
+            lv1 = R.add(x, y)
+            gv = R.multiply(lv0, lv1)
+            a2 = R.assert_op(R.const(False), format="Always fails")
+            return gv
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def foo(x: R.Tensor((2, 3), dtype="float32"), y: R.Tensor((2, 3), dtype="float32")):
+            R.is_impure()
+            p1 = R.print(format="Message")
+            p2 = R.print(format="Message")
+            a1 = R.assert_op(R.const(False), format="Always fails")
+            lv0 = R.add(x, y)
+            lv1 = lv0
+            gv = R.multiply(lv0, lv1)
+            a2 = R.assert_op(R.const(False), format="Always fails")
+            return gv
+
+    verify(Before, Expected)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to