Lunderberg commented on code in PR #15694:
URL: https://github.com/apache/tvm/pull/15694#discussion_r1321628770


##########
src/relax/ir/dataflow_matcher.cc:
##########
@@ -443,6 +444,92 @@ bool DFPatternMatcher::VisitDFPattern_(const ShapePatternNode* op, const Expr& e
   return false;
 }
 
+Optional<Bool> SameShapeConstraintNode::IsConstraintSatisfied(
+    std::function<Optional<Var>(const DFPatternNode*)> match_state,
+    arith::Analyzer* analyzer) const {
+  Optional<Array<PrimExpr>> expected_shape;
+  bool all_shapes_defined = true;
+
+  // The expression that must be true in order for all shapes to be equal.
+  PrimExpr all_dimensions_equal = Bool(true);
+
+  for (const auto& arg : args) {
+    if (auto opt_var = match_state(arg.get())) {
+      auto var = opt_var.value();
+      auto opt_var_shape = [&]() -> Optional<Array<PrimExpr>> {
+        auto sinfo = GetStructInfo(var);
+        if (auto tensor = sinfo.as<TensorStructInfoNode>()) {
+          return tensor->GetShape();
+        } else if (auto shape_expr = sinfo.as<ShapeStructInfoNode>()) {
+          return shape_expr->values;
+        } else {
+          return NullOpt;
+        }
+      }();
+
+      if (!opt_var_shape.defined()) {
+        // The pattern has matched to something without a shape.
+        // Therefore, it cannot have the same shape as something else.
+        return Bool(false);
+      }
+      auto var_shape = opt_var_shape.value();
+
+      if (expected_shape.defined()) {
+        auto prev_shape = expected_shape.value();
+        if (prev_shape.size() == var_shape.size()) {
+          // The dimensionalities match, so build up the expression
+          // that must be true for the shapes to be equivalent.
+          for (size_t i = 0; i < prev_shape.size(); i++) {
+            all_dimensions_equal = all_dimensions_equal && (var_shape[i] == prev_shape[i]);
+          }
+
+        } else {
+          // The shapes have different dimensionality.  No need to
+          // perform potentially-expensive simplifications, because
+          // the dimensions do not match.
+          return Bool(false);
+        }
+
+      } else {
+        // This is the first pattern with a known match.  Store the
+        // shape so it can be compared against later shapes.
+        expected_shape = var_shape;
+      }
+
+    } else {
+      // Missing an argument, so the constraint will either return
+      // NullOpt or false at this point.  However, delay the return of
+      // NullOpt until the end of the function, because we'd rather
+      // return "false" if it is possible to do so.
+      all_shapes_defined = false;
+    }
+  }
+
+  // We check for a false result first, because that can be applied
+  // even if some shapes are still unknown
+  // (e.g. SameShapeConstraint(A,B,C) can return false if A and B have
+  // different shapes, even if C is unknown).  The passing constraint
+  // is only valid if all of the shapes are known.
+  if (all_shapes_defined) {
+    // If all shapes are known and have the same dimensionality, then
+    // we just need to prove that the sizes of each dimension match.
+    // If we cannot prove it at this point, we won't get more
+    // information later.
+    return Bool(analyzer->CanProve(all_dimensions_equal));

Review Comment:
   The pattern matcher should be conservative in the case of ambiguity, as it may be used to enable or disable rewrite rules whose validity depends on the pattern being matched. Previously, I had been thinking of the return value as "yes", "no", and "ask me again later". At this point, the "ask me again later" option is no longer valid, because no further shape information will be received. As a result, if the shapes cannot be proven equal, the pattern match should fail.
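A minimal sketch of the "yes" / "no" / "ask me again later" reading, written as a self-contained C++ model rather than TVM code. `Dim`, `Shape`, `CanProveEqual`, and `SameShape` are hypothetical stand-ins for `PrimExpr`, `Array<PrimExpr>`, `arith::Analyzer::CanProve`, and `SameShapeConstraintNode::IsConstraintSatisfied`. The stand-in prover only recognizes identical integers or identically named symbols, arguments that matched something without a shape are not modeled, and the sketch assumes the function returns `NullOpt` when some argument is unmatched, as the comments in the diff describe.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <string>
#include <variant>
#include <vector>

// Hypothetical stand-in for a PrimExpr dimension: either a known
// integer extent or a named symbolic variable such as "n".
using Dim = std::variant<std::int64_t, std::string>;
using Shape = std::vector<Dim>;

// Deliberately weak stand-in for arith::Analyzer::CanProve: two
// dimensions are provably equal only if they are the same integer or
// the same symbol.  "n" vs "m" is unproven, not disproven.
bool CanProveEqual(const Dim& a, const Dim& b) { return a == b; }

// Three-valued result: true ("yes"), false ("no"), or std::nullopt
// ("ask me again later").  A disengaged argument models a pattern
// that has not been matched yet.
std::optional<bool> SameShape(const std::vector<std::optional<Shape>>& args) {
  std::optional<Shape> expected;
  bool all_known = true;
  bool all_equal = true;
  for (const auto& arg : args) {
    if (!arg) {
      all_known = false;  // defer, but keep looking for a definite "no"
      continue;
    }
    if (!expected) {
      expected = *arg;  // the first known shape becomes the reference
      continue;
    }
    if (expected->size() != arg->size()) {
      return false;  // a rank mismatch is a definite "no"
    }
    for (std::size_t i = 0; i < arg->size(); ++i) {
      all_equal = all_equal && CanProveEqual((*expected)[i], (*arg)[i]);
    }
  }
  if (!all_known) {
    return std::nullopt;
  }
  // Every shape is known, so no further information will arrive:
  // an equality that cannot be proven is reported as "no".
  return all_equal;
}
```

Under these assumptions, shapes `[16, n]` and `[16, m]` give `false` rather than `std::nullopt`: every argument is known, so an unproven `n == m` becomes a failed match instead of a deferred one.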
   
   This would be another argument for returning `(PrimExpr condition, bool condition_is_sufficient)` from this function, to avoid the potential confusion here. If the choice of whether to backtrack, attempt to combine with other expressions, etc. is made at the calling scope, this function wouldn't need to distinguish between these uses of `Bool(false)` and `NullOpt`, and could instead return `{condition, true}`.
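One possible shape for the `(condition, condition_is_sufficient)` alternative, again as a self-contained C++ model rather than TVM code. `ConstraintCondition`, `SameShapeCondition`, `CanProveNotEqual`, `Decision`, and `Resolve` are hypothetical names, and the condition is modeled as a conjunction of dimension equalities instead of a `PrimExpr`. The interpretation assumed here is that the condition is always necessary, and is also sufficient only when every argument has been matched; the caller then decides whether to accept, reject, or defer (for example, by backtracking).

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <string>
#include <utility>
#include <variant>
#include <vector>

// Hypothetical stand-ins for PrimExpr / Array<PrimExpr>.
using Dim = std::variant<std::int64_t, std::string>;
using Shape = std::vector<Dim>;

// Hypothetical result type: a conjunction of equalities that is always
// necessary, and is also sufficient only if every argument was matched.
struct ConstraintCondition {
  std::vector<std::pair<Dim, Dim>> equalities;
  bool is_sufficient;
};

ConstraintCondition SameShapeCondition(const std::vector<std::optional<Shape>>& args) {
  ConstraintCondition result{{}, true};
  std::optional<Shape> expected;
  for (const auto& arg : args) {
    if (!arg) {
      result.is_sufficient = false;  // unmatched argument: condition is only necessary
      continue;
    }
    if (!expected) {
      expected = *arg;
      continue;
    }
    // Equal rank is part of the condition, encoded as an integer equality.
    result.equalities.emplace_back(Dim{static_cast<std::int64_t>(expected->size())},
                                   Dim{static_cast<std::int64_t>(arg->size())});
    if (expected->size() != arg->size()) {
      continue;  // per-dimension terms are meaningless for different ranks
    }
    for (std::size_t i = 0; i < arg->size(); ++i) {
      result.equalities.emplace_back((*expected)[i], (*arg)[i]);
    }
  }
  return result;
}

// Toy prover: equal only if identical; unequal only if two different integers.
bool CanProveEqual(const Dim& a, const Dim& b) { return a == b; }
bool CanProveNotEqual(const Dim& a, const Dim& b) {
  return std::holds_alternative<std::int64_t>(a) && std::holds_alternative<std::int64_t>(b) &&
         std::get<std::int64_t>(a) != std::get<std::int64_t>(b);
}

enum class Decision { kAccept, kReject, kDefer };

// The calling scope, not the constraint, decides what to do with the condition.
Decision Resolve(const ConstraintCondition& cond) {
  bool proven = true;
  for (const auto& [lhs, rhs] : cond.equalities) {
    if (CanProveNotEqual(lhs, rhs)) {
      return Decision::kReject;  // a necessary condition is violated
    }
    proven = proven && CanProveEqual(lhs, rhs);
  }
  if (!cond.is_sufficient) {
    return Decision::kDefer;  // e.g. backtrack, or retry once more is matched
  }
  return proven ? Decision::kAccept : Decision::kReject;  // conservative
}
```

Under these assumptions, the constraint itself never has to choose between `Bool(false)` and `NullOpt`: a provably false condition rejects the match even while arguments are still unmatched, and an unproven but sufficient condition is rejected conservatively.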



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.