This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity-staging
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit de321945959f09e4db20106e99e71c1ea448e55a
Author: Tianqi Chen <tqc...@users.noreply.github.com>
AuthorDate: Tue Feb 14 16:58:14 2023 -0500

    [Unity] NestedMsg Support utility (#13995)
    
    This PR introduces NestedMsg to robustly handle nested-tuple analysis.
    
    Relax supports nested tuple structures in the IR.
    Nested tuple structure is important to support advanced groupings in
    cases such as gradient calculation and other scenarios.
    
    The possible presence of nested tuples means that we need to
    robustly handle analyses that involve nested tuple structures in a
    dataflow graph.
    
    This PR introduces a NestedMsg<T> class that corresponds to a possibly
    nested message tuple for a given leaf message class T.
    We also introduce various helper functions to compose and decompose
    messages.
    
    Co-authored-by: Bohan Hou <32121147+spectrometer...@users.noreply.github.com>
    Co-authored-by: Yixin Dong <ubosp...@gmail.com>
    Co-authored-by: Ruihang Lai <ruiha...@cs.cmu.edu>
---
 include/tvm/relax/nested_msg.h | 536 +++++++++++++++++++++++++++++++++++++++++
 tests/cpp/nested_msg_test.cc   | 318 ++++++++++++++++++++++++
 2 files changed, 854 insertions(+)

diff --git a/include/tvm/relax/nested_msg.h b/include/tvm/relax/nested_msg.h
new file mode 100644
index 0000000000..93fc9a36c5
--- /dev/null
+++ b/include/tvm/relax/nested_msg.h
@@ -0,0 +1,536 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relax/nested_msg.h
+ * \brief Helper container to store nested message for robust tuple-aware analysis.
+ *
+ * Please see NestedMsg for description of usage.
+ *
+ * \sa NestedMsg
+ */
+#ifndef TVM_RELAX_NESTED_MSG_H_
+#define TVM_RELAX_NESTED_MSG_H_
+
+#include <tvm/relax/expr.h>
+#include <tvm/relax/struct_info.h>
+#include <tvm/runtime/container/array.h>
+#include <tvm/runtime/container/optional.h>
+
+#include <array>
+#include <initializer_list>
+#include <type_traits>
+#include <utility>
+#include <vector>
+
+namespace tvm {
+namespace relax {
+
+/*!
+ * \brief Container that stores possibly nested message with leaf message type T.
+ *
+ * NestedMsg is a helper structure to store intermediate
+ * message state in pass analysis so we can robustly handle message
+ * passing with the presence of nested tuple types.
+ *
+ * Under the hood, NestedMsg[T] = Union[T, NullOpt, Array[NestedMsg[T]]].
+ * Each nested message corresponds to the same nesting structure as
+ * the nested tuple types when we encounter them in analysis.
+ *
+ * Relax supports nested tuple structures in the IR. Nested tuple structure
+ * is important to support advanced groupings in cases such as gradient calculation
+ * and other scenarios.
+ *
+ * The possible presence of nested tuples means that we need to
+ * robustly handle analyses that involve nested tuple structures
+ * in a dataflow graph.
+ *
+ * \code
+ *
+ * v1 = relu(v0)
+ * v2 = exp(v0)
+ * t = ((v0, v1), (v2,), v0)
+ * t1 = t[0]
+ * v3 = concat(t1)
+ * v4 = t[2]
+ * v5 = add(v4, v3)
+ *
+ * \endcode
+ *
+ * Consider the above code sequence that contains a mixture of tuple
+ * nesting and normal operations. A common message-passing-based analysis
+ * will track messages attached to each intermediate variable.
+ *
+ * Because the intermediate value can contain nested tuples, we need the ability
+ * to nest messages according to the tuple structure and propagate them
+ * along the way. In python, this simply corresponds to using a tuple to hold
+ * nested messages. This class provides a helper wrapper in C++ to present such
+ * possibly nested message for a given leaf message.
+ *
+ * This design pattern is necessary to handle tuple values regardless of
+ * the normal form design of the IR to enable different messages for each
+ * tuple component without enforcing all tuple elements to have the same message.
+ *
+ * Please consider the following patterns in our pass:
+ *
+ * On a forward propagation message passing analysis:
+ * - Create map [leafnode=>NestedMsg<T>], scan forward
+ * - input_msg = [MapToNestedMsg<T>(x, lookup_map) for x in call->args]
+ * - output_msg = ForwardProp[call->op](input_msg, call)
+ * - map[binding->var] = output_msg
+ * - Use MapToNestedMsg to remap the remaining body.
+ *
+ * On a backward propagation message passing analysis:
+ * - Create map [leafnode=>NestedMsg<T>], scan backward
+ * - output_msg = lookup map(binding->var)
+ * - handle case when output_msg is null
+ * - input_msg = BackProp[call->op](output_msg, call)
+ * - for arg, msg in zip(call->args, input_msg),
+ *     DecomposeNestedMsg(arg, msg, lambda node, m: update_map(node, m))
+ * - update_map(node, m) => CombineNestedMsg(map[node], m)
+ *
+ * Here leafnode is a node that you would like to propagate messages to,
+ * such as a constant or a var; it should not be a tuple.
+ *
+ * We also recommend writing unit-test cases that involve nested tuple composition
+ * and decomposition.
+ *
+ * \sa MapToNestedMsg, DecomposeNestedMsg, CombineNestedMsg, ForEachLeaf, Equal
+ *
+ * \note If you want to write robust message passing-based analysis for
+ *       programs that can contain nested tuples, you likely need to
+ *       use this class or logic of a similar kind.
+ */
+template <typename T>
+class NestedMsg : public ObjectRef {
+ public:
+  static_assert(std::is_base_of<ObjectRef, T>::value, "NestedMsg is only defined for ObjectRef.");
+
+  /*! \brief Generic container type: the stored object may be a leaf or an array. */
+  using ContainerType = Object;
+  /*! \brief Node type that identifies a leaf message. */
+  using LeafContainerType = typename T::ContainerType;
+  /*! \brief A NestedMsg is allowed to hold nothing (see IsNull). */
+  static constexpr bool _type_is_nullable = true;
+
+  // ---------------------------------------------------------------------
+  // Construction
+  // ---------------------------------------------------------------------
+  NestedMsg() = default;
+  NestedMsg(const NestedMsg<T>&) = default;
+  NestedMsg(NestedMsg<T>&&) = default;
+  /*!
+   * \brief Wrap an existing object pointer.
+   * \param ptr Pointer whose referent already satisfies the NestedMsg constraint
+   *        (null, a leaf of type T, or an array of NestedMsg<T>).
+   */
+  explicit NestedMsg(ObjectPtr<Object> ptr) : ObjectRef(ptr) {}
+  /*! \brief Build a null message from NullOpt. */
+  NestedMsg(runtime::NullOptType) {}  // NOLINT(*)
+  /*!
+   * \brief Build a null message from nullptr.
+   * \note Kept explicit because the literal 0 converts implicitly to nullptr_t.
+   */
+  explicit NestedMsg(std::nullptr_t) {}
+  /*! \brief Build a leaf message. */
+  NestedMsg(T leaf)  // NOLINT(*)
+      : ObjectRef(std::move(leaf)) {}
+  /*! \brief Build a nested message from an array of sub-messages. */
+  NestedMsg(Array<NestedMsg<T>, void> fields)  // NOLINT(*)
+      : ObjectRef(std::move(fields)) {}
+  /*! \brief Build a nested message from a brace-enclosed list of sub-messages. */
+  NestedMsg(std::initializer_list<NestedMsg<T>> fields)  // NOLINT(*)
+      : NestedMsg(Array<NestedMsg<T>, void>(fields)) {}
+  /*!
+   * \brief Construction from int is forbidden.
+   * \note NestedMsg<Integer>(0) would be ambiguous, since 0 also converts to nullptr_t.
+   */
+  explicit NestedMsg(int val) = delete;
+
+  // ---------------------------------------------------------------------
+  // Assignment
+  // ---------------------------------------------------------------------
+  NestedMsg<T>& operator=(const NestedMsg<T>&) = default;
+  NestedMsg<T>& operator=(NestedMsg<T>&&) = default;
+  /*! \brief Reset to the null message. */
+  NestedMsg<T>& operator=(std::nullptr_t) {
+    data_ = nullptr;
+    return *this;
+  }
+  /*! \brief Replace the content with a leaf value. */
+  NestedMsg<T>& operator=(T leaf) {
+    ObjectRef::operator=(std::move(leaf));
+    return *this;
+  }
+  /*! \brief Replace the content with an array of sub-messages. */
+  NestedMsg<T>& operator=(Array<NestedMsg<T>, void> fields) {
+    ObjectRef::operator=(std::move(fields));
+    return *this;
+  }
+  /*! \brief Replace the content with a brace-enclosed list of sub-messages. */
+  NestedMsg<T>& operator=(std::initializer_list<NestedMsg<T>> fields) {
+    return operator=(Array<NestedMsg<T>, void>(fields));
+  }
+  /*! \brief Assignment from int is forbidden for the same reason as the int constructor. */
+  NestedMsg<T>& operator=(int val) = delete;
+
+  // ---------------------------------------------------------------------
+  // Queries
+  // ---------------------------------------------------------------------
+  /*! \return true iff the message is null. */
+  bool operator==(std::nullptr_t) const { return data_ == nullptr; }
+  /*! \return true iff the message is not null. */
+  bool operator!=(std::nullptr_t) const { return data_ != nullptr; }
+
+  /*! \return true iff nothing is stored. */
+  bool IsNull() const { return data_ == nullptr; }
+  /*! \return true iff a non-null leaf (an instance of LeafContainerType) is stored. */
+  bool IsLeaf() const { return !IsNull() && data_->IsInstance<LeafContainerType>(); }
+  /*! \return true iff an array of sub-messages is stored. */
+  bool IsNested() const { return !IsNull() && data_->IsInstance<ArrayNode>(); }
+
+  /*!
+   * \return The stored leaf value.
+   * \note Fails an ICHECK unless IsLeaf() holds.
+   */
+  T LeafValue() const {
+    ICHECK(IsLeaf());
+    return T(data_);
+  }
+
+  /*!
+   * \return The stored array of sub-messages.
+   * \note Fails an ICHECK unless IsNested() holds.
+   */
+  Array<NestedMsg<T>, void> NestedArray() const {
+    ICHECK(IsNested());
+    return Array<NestedMsg<T>, void>(data_);
+  }
+};
+
+/*!
+ * \brief Invoke fvisit on every leaf of a nested message, left to right, depth first.
+ * \param msg The nested message to walk; null messages at any depth are skipped.
+ * \param fvisit Callback with signature void fvisit(T), called once per leaf.
+ * \tparam T The leaf type of the nested message.
+ * \tparam FType The callback type.
+ */
+template <typename T, typename FType>
+void ForEachLeaf(const NestedMsg<T>& msg, FType fvisit) {
+  if (msg.IsLeaf()) {
+    fvisit(msg.LeafValue());
+    return;
+  }
+  if (msg.IsNull()) return;
+  for (const NestedMsg<T>& child : msg.NestedArray()) {
+    ForEachLeaf(child, fvisit);
+  }
+}
+
+/*!
+ * \brief Structurally compare two nested messages.
+ *
+ * Null only equals null, a leaf never equals an array, leaves are compared
+ * with fequal, and arrays must have the same length and pairwise-equal entries.
+ *
+ * \param lhs The left operand.
+ * \param rhs The right operand.
+ * \param fequal Leaf comparator with signature bool fequal(T, T).
+ * \tparam T The leaf type of the nested messages.
+ * \tparam FType The comparator type.
+ * \return Whether the two messages are equal.
+ */
+template <typename T, typename FType>
+bool Equal(const NestedMsg<T>& lhs, const NestedMsg<T>& rhs, FType fequal) {
+  if (lhs.IsNull() || rhs.IsNull()) return lhs.IsNull() && rhs.IsNull();
+  if (lhs.IsLeaf()) {
+    if (!rhs.IsLeaf()) return false;
+    return fequal(lhs.LeafValue(), rhs.LeafValue());
+  }
+  if (!rhs.IsNested()) return false;
+  Array<NestedMsg<T>> lhs_fields = lhs.NestedArray();
+  Array<NestedMsg<T>> rhs_fields = rhs.NestedArray();
+  if (lhs_fields.size() != rhs_fields.size()) return false;
+  for (size_t k = 0; k < lhs_fields.size(); ++k) {
+    if (!Equal(lhs_fields[k], rhs_fields[k], fequal)) return false;
+  }
+  return true;
+}
+
+/*!
+ * \brief Map an expression, possibly a nested tuple, to a nested message.
+ *
+ * Tuple expressions are unpacked recursively. fmapleaf is called on every
+ * non-tuple sub-expression, and the results are assembled into a NestedMsg
+ * whose nesting mirrors the tuple nesting of expr.
+ *
+ * \param expr The input expression.
+ * \param fmapleaf Leaf mapping function with signature `NestedMsg<T> fmapleaf(Expr)`.
+ * \tparam T The leaf type of the nested message.
+ * \tparam FType The mapping function type.
+ * \return The nested message corresponding to expr.
+ */
+template <typename T, typename FType>
+NestedMsg<T> MapToNestedMsg(Expr expr, FType fmapleaf) {
+  const auto* tuple = expr.as<TupleNode>();
+  if (tuple == nullptr) {
+    return fmapleaf(expr);
+  }
+  Array<NestedMsg<T>> fields;
+  fields.reserve(tuple->fields.size());
+  for (const Expr& field : tuple->fields) {
+    fields.push_back(MapToNestedMsg<T, FType>(field, fmapleaf));
+  }
+  return fields;
+}
+
+/*!
+ * \brief Map a struct info, possibly a nested tuple struct info, to a nested message.
+ *
+ * Tuple struct infos are unpacked recursively. fmapleaf is called on every
+ * non-tuple struct info, and the results are assembled into a NestedMsg
+ * whose nesting mirrors the tuple nesting of sinfo.
+ *
+ * \param sinfo The input struct info.
+ * \param fmapleaf Leaf mapping function with signature `NestedMsg<T> fmapleaf(StructInfo)`.
+ * \tparam T The leaf type of the nested message.
+ * \tparam FType The mapping function type.
+ * \return The nested message corresponding to sinfo.
+ */
+template <typename T, typename FType>
+NestedMsg<T> MapToNestedMsg(StructInfo sinfo, FType fmapleaf) {
+  const auto* tuple = sinfo.as<TupleStructInfoNode>();
+  if (tuple == nullptr) {
+    return fmapleaf(sinfo);
+  }
+  Array<NestedMsg<T>> fields;
+  fields.reserve(tuple->fields.size());
+  for (const StructInfo& field : tuple->fields) {
+    fields.push_back(MapToNestedMsg<T, FType>(field, fmapleaf));
+  }
+  return fields;
+}
+
+/*!
+ * \brief Map an expression to a nested message following its struct info.
+ *
+ * Unlike MapToNestedMsg(Expr, ...), the nesting is taken from the struct info
+ * of expr rather than from its syntax. A literal Tuple exposes its fields
+ * directly. Any other tuple-typed expression is projected with TupleGetItem,
+ * and each projection gets the matching field struct info. fmapleaf is called
+ * on every non-tuple-typed leaf.
+ *
+ * \param expr The input expression; it must have struct info.
+ * \param fmapleaf Leaf mapping function with signature `NestedMsg<T> fmapleaf(Expr)`.
+ * \tparam T The leaf type of the nested message.
+ * \tparam FType The mapping function type.
+ * \return The nested message corresponding to expr.
+ */
+template <typename T, typename FType>
+NestedMsg<T> MapToNestedMsgBySInfo(Expr expr, FType fmapleaf) {
+  StructInfo sinfo = GetStructInfo(expr);
+  const auto* tuple_sinfo = sinfo.as<TupleStructInfoNode>();
+  if (tuple_sinfo == nullptr) {
+    return fmapleaf(expr);
+  }
+  // expr is not modified below, so this lookup is loop invariant.
+  const auto* tuple_expr = expr.as<TupleNode>();
+  Array<NestedMsg<T>> fields;
+  fields.reserve(tuple_sinfo->fields.size());
+  for (size_t idx = 0; idx < tuple_sinfo->fields.size(); ++idx) {
+    Expr field;
+    if (tuple_expr != nullptr) {
+      field = tuple_expr->fields[idx];
+    } else {
+      field = TupleGetItem(expr, idx);
+      UpdateStructInfo(field, tuple_sinfo->fields[idx]);
+    }
+    fields.push_back(MapToNestedMsgBySInfo<T, FType>(field, fmapleaf));
+  }
+  return fields;
+}
+
+/*!
+ * \brief Map nested message back to the expr.
+ *
+ * This function decomposes the nested message recursively. fmapleaf is
+ * called on each leaf message, and with NullOpt on each null message, to get
+ * the leaf expr. The results of a nested message are then combined as a
+ * tuple expr.
+ *
+ * As a simplification, a rebuilt tuple (t[0], t[1], ..., t[n-1]) is replaced
+ * by t itself. This applies when all fields are TupleGetItem projections of
+ * the same tuple t at consecutive indices starting from 0. The simplification
+ * is skipped when the struct info of t is populated and shows that t does not
+ * have exactly n fields: (t[0], t[1]) is not t when t has three fields.
+ *
+ * \param msg The input nested message.
+ * \param fmapleaf The mapping function for each leaf with signature `Expr fmapleaf(Optional<T>)`.
+ * \tparam T the content type of nested msg.
+ * \tparam FType The mapping function type.
+ * \return The reconstructed expression.
+ */
+template <typename T, typename FType>
+Expr NestedMsgToExpr(NestedMsg<T> msg, FType fmapleaf) {
+  if (msg.IsNull()) {
+    return fmapleaf(NullOpt);
+  } else if (msg.IsLeaf()) {
+    return fmapleaf(msg.LeafValue());
+  }
+  ICHECK(msg.IsNested());
+  Array<NestedMsg<T>> arr = msg.NestedArray();
+  Array<Expr> subexpr;
+  subexpr.reserve(arr.size());
+  for (size_t i = 0; i < arr.size(); ++i) {
+    subexpr.push_back(NestedMsgToExpr<T, FType>(arr[i], fmapleaf));
+  }
+  if (subexpr.empty()) return Tuple(subexpr);
+
+  // Try to fold (t[0], t[1], ..., t[n-1]) back into t.
+  const auto* first = subexpr[0].as<TupleGetItemNode>();
+  if (first == nullptr) return Tuple(subexpr);
+  Expr source = first->tuple;
+  ICHECK(source.defined());
+  for (size_t i = 0; i < subexpr.size(); ++i) {
+    const auto* node = subexpr[i].as<TupleGetItemNode>();
+    if (node == nullptr || node->index != static_cast<int>(i) || !node->tuple.same_as(source)) {
+      return Tuple(subexpr);
+    }
+  }
+  // Only fold when t is not known to have a different number of fields;
+  // otherwise a prefix like (t[0], t[1]) would be replaced by a longer tuple.
+  if (const auto* source_sinfo = source->struct_info_.as<TupleStructInfoNode>()) {
+    if (source_sinfo->fields.size() != subexpr.size()) return Tuple(subexpr);
+  }
+  return source;
+}
+
+/*!
+ * \brief Recursively merge two compatible nested messages into one.
+ *
+ * Rules:
+ * - combine(null, msg) => msg and combine(msg, null) => msg
+ * - combine(leaf1, leaf2) => fcombine(leaf1, leaf2)
+ * - combine(array1, array2) => [combine(x, y) for x, y in zip(array1, array2)]
+ * - A leaf/array mismatch, or arrays of different sizes, fail an ICHECK.
+ *
+ * \param lhs The left operand.
+ * \param rhs The right operand.
+ * \param fcombine Leaf merge function with signature T fcombine(T lhs, T rhs).
+ * \tparam T The leaf type of the nested messages.
+ * \tparam FType The merge function type.
+ * \return The merged message.
+ */
+template <typename T, typename FType>
+NestedMsg<T> CombineNestedMsg(NestedMsg<T> lhs, NestedMsg<T> rhs, FType fcombine) {
+  // A null side contributes nothing.
+  if (lhs.IsNull()) return rhs;
+  if (rhs.IsNull()) return lhs;
+
+  if (!lhs.IsLeaf()) {
+    ICHECK(lhs.IsNested());
+    ICHECK(rhs.IsNested()) << "Cannot combine leaf with nested";
+    Array<NestedMsg<T>> lhs_fields = lhs.NestedArray();
+    Array<NestedMsg<T>> rhs_fields = rhs.NestedArray();
+    ICHECK_EQ(lhs_fields.size(), rhs_fields.size())
+        << "Cannot combine two nested array with different sizes";
+    Array<NestedMsg<T>> combined;
+    combined.reserve(lhs_fields.size());
+    for (size_t k = 0; k < lhs_fields.size(); ++k) {
+      combined.push_back(CombineNestedMsg<T, FType>(lhs_fields[k], rhs_fields[k], fcombine));
+    }
+    return NestedMsg<T>(combined);
+  }
+  ICHECK(rhs.IsLeaf()) << "Cannot combine leaf with nested";
+  return NestedMsg<T>(fcombine(lhs.LeafValue(), rhs.LeafValue()));
+}
+
+/*!
+ * \brief Rebuild a nested message with each leaf replaced by fmapleaf(leaf).
+ *
+ * Null messages are returned unchanged. Because fmapleaf returns a NestedMsg,
+ * a leaf may be mapped to a null, a leaf, or a nested message.
+ *
+ * \param msg The nested message to be mapped.
+ * \param fmapleaf Leaf map function with signature NestedMsg<T> fmapleaf(T msg).
+ * \tparam T The content type of nested message.
+ * \tparam FType The leaf map function type.
+ * \return The new nested message.
+ */
+template <typename T, typename FType>
+NestedMsg<T> MapNestedMsg(NestedMsg<T> msg, FType fmapleaf) {
+  if (msg.IsNull()) return msg;
+  if (msg.IsLeaf()) return fmapleaf(msg.LeafValue());
+  ICHECK(msg.IsNested());
+  Array<NestedMsg<T>> children = msg.NestedArray();
+  Array<NestedMsg<T>> mapped;
+  mapped.reserve(children.size());
+  for (const NestedMsg<T>& child : children) {
+    mapped.push_back(MapNestedMsg<T, FType>(child, fmapleaf));
+  }
+  return NestedMsg<T>(mapped);
+}
+
+/*!
+ * \brief Walk the tuple structure of expr in lockstep with msg.
+ *
+ * For every non-tuple sub-expression of expr, fvisitleaf is called with that
+ * sub-expression and the message at the matching position. A tuple in expr
+ * must be matched by a nested message of the same size, otherwise an ICHECK
+ * fails.
+ *
+ * \param expr The input expression to be decomposed.
+ * \param msg The input nested message.
+ * \param fvisitleaf Visitor with signature fvisitleaf(Expr expr, NestedMsg<T> msg).
+ * \tparam T The leaf type of the nested message.
+ * \tparam FType The visit function type.
+ */
+template <typename T, typename FType>
+void DecomposeNestedMsg(Expr expr, NestedMsg<T> msg, FType fvisitleaf) {
+  const auto* tuple = expr.as<TupleNode>();
+  if (tuple == nullptr) {
+    // A non-tuple expression is a leaf, whatever shape msg has.
+    fvisitleaf(expr, msg);
+    return;
+  }
+  ICHECK(msg.IsNested()) << "Expected nested to match tuple";
+  Array<NestedMsg<T>> fields = msg.NestedArray();
+  ICHECK_EQ(fields.size(), tuple->fields.size())
+      << "Expected nested array size to match tuple size";
+  for (size_t k = 0; k < fields.size(); ++k) {
+    DecomposeNestedMsg(tuple->fields[k], fields[k], fvisitleaf);
+  }
+}
+
+/*!
+ * \brief Recursively transform the tuple structure in expr, guided by the messages.
+ *
+ * The tuple structure is taken from the struct info of expr. A literal Tuple
+ * exposes its fields directly, while any other tuple-typed expression is
+ * projected with TupleGetItem. ftransleaf is called on each non-tuple leaf,
+ * together with the N messages at the matching position. If every field comes
+ * back unchanged (same_as), expr itself is returned. Otherwise a new Tuple of
+ * the transformed fields is built.
+ *
+ * An ICHECK fails if the nesting structure of any message does not match the
+ * tuple nesting structure of expr, including a mismatch in the number of fields.
+ *
+ * \param expr The input expression to be transformed; it must have struct info.
+ * \param msgs The input messages to guide the transformation.
+ * \param ftransleaf Leaf transform with signature
+ *        `Expr ftransleaf(Expr leaf, std::array<NestedMsg<T>, N> leaf_msgs)`.
+ * \tparam T the content type of nested msg
+ * \tparam N the number of messages
+ * \tparam FType The leaf transform function type.
+ * \return The transformed expression.
+ */
+template <typename T, std::size_t N, typename FType>
+Expr TransformTupleLeaf(Expr expr, std::array<NestedMsg<T>, N> msgs, FType ftransleaf) {
+  StructInfo sinfo = GetStructInfo(expr);
+  const auto* tuple = sinfo.as<TupleStructInfoNode>();
+  if (tuple == nullptr) {
+    for (const auto& msg : msgs) {
+      ICHECK(msg.IsLeaf()) << "Expected leaf to match non-tuple";
+    }
+    return ftransleaf(expr, msgs);
+  }
+  const size_t num_fields = tuple->fields.size();
+  std::array<Array<NestedMsg<T>>, N> msg_arrays;
+  for (size_t i = 0; i < N; ++i) {
+    ICHECK(msgs[i].IsNested()) << "Expected nested to match tuple";
+    msg_arrays[i] = msgs[i].NestedArray();
+    // Same contract as DecomposeNestedMsg: message arity must equal tuple arity.
+    ICHECK_EQ(msg_arrays[i].size(), num_fields)
+        << "Expected nested array size to match tuple size";
+  }
+  // expr is not modified below, so this lookup is loop invariant.
+  const auto* expr_tuple = expr.as<TupleNode>();
+  bool same = true;
+  Array<Expr> fields;
+  fields.reserve(num_fields);
+  for (size_t i = 0; i < num_fields; ++i) {
+    Expr field;
+    if (expr_tuple != nullptr) {
+      field = expr_tuple->fields[i];
+    } else {
+      field = TupleGetItem(expr, i);
+      UpdateStructInfo(field, tuple->fields[i]);
+    }
+    std::array<NestedMsg<T>, N> sub_msgs;
+    for (size_t j = 0; j < N; ++j) {
+      sub_msgs[j] = msg_arrays[j][i];
+    }
+    fields.push_back(TransformTupleLeaf(field, std::move(sub_msgs), ftransleaf));
+    same &= fields.back().same_as(field);
+  }
+  return same ? expr : Tuple(fields);
+}
+
+}  // namespace relax
+}  // namespace tvm
+#endif  // TVM_RELAX_NESTED_MSG_H_
diff --git a/tests/cpp/nested_msg_test.cc b/tests/cpp/nested_msg_test.cc
new file mode 100644
index 0000000000..48af552007
--- /dev/null
+++ b/tests/cpp/nested_msg_test.cc
@@ -0,0 +1,318 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <dmlc/logging.h>
+#include <gtest/gtest.h>
+#include <tvm/relax/nested_msg.h>
+#include <tvm/relax/struct_info.h>
+#include <tvm/runtime/data_type.h>
+#include <tvm/tir/expr.h>
+
+#include <algorithm>
+#include <cstring>
+#include <functional>
+#include <iterator>
+#include <new>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+using namespace tvm;
+using namespace tvm::runtime;
+using namespace tvm::relax;
+
+TEST(NestedMsg, Basic) {
+  // Variables without any struct info annotation.
+  relax::Var x("x", NullOpt);
+  relax::Var y("y", NullOpt);
+
+  // Construct from a brace list mixing leaves and NullOpt.
+  NestedMsg<relax::Expr> msg({x, NullOpt, x});
+  EXPECT_TRUE(msg.IsNested());
+  EXPECT_FALSE(msg.IsLeaf());
+  EXPECT_TRUE(msg != nullptr);
+  EXPECT_ANY_THROW(msg.LeafValue());
+
+  auto fields = msg.NestedArray();
+  EXPECT_TRUE(fields[0].same_as(x));
+  EXPECT_TRUE(fields[1] == nullptr);
+  EXPECT_TRUE(fields[1].IsNull());
+  EXPECT_TRUE(fields[2].LeafValue().same_as(x));
+
+  auto elem = fields[0];
+  EXPECT_TRUE(elem.IsLeaf());
+
+  // Assigning NullOpt makes the message null.
+  elem = NullOpt;
+  EXPECT_TRUE(elem == nullptr);
+
+  // Assigning a nested brace list makes it nested.
+  elem = {x, {x, NullOpt, y}};
+  EXPECT_TRUE(elem.IsNested());
+  auto inner = elem.NestedArray()[1];
+  EXPECT_TRUE(inner.IsNested());
+  EXPECT_TRUE(inner.NestedArray()[2].same_as(y));
+
+  // Assigning a leaf value makes it a leaf again.
+  elem = x;
+  EXPECT_TRUE(elem.IsLeaf());
+  EXPECT_TRUE(elem.same_as(x));
+}
+
+TEST(NestedMsg, ForEachLeaf) {
+  relax::Var x("x", NullOpt);
+  relax::Var y("y", NullOpt);
+  NestedMsg<Expr> msg = {x, {x, y}, NullOpt, {x, {x, y}}};
+
+  int num_x = 0;
+  int num_y = 0;
+  auto count_leaf = [&](const Expr& leaf) {
+    if (leaf.same_as(x)) ++num_x;
+    if (leaf.same_as(y)) ++num_y;
+  };
+  ForEachLeaf(msg, count_leaf);
+
+  EXPECT_EQ(num_x, 4);
+  EXPECT_EQ(num_y, 2);
+}
+
+TEST(NestedMsg, Equal) {
+  relax::Var x("x", NullOpt);
+  relax::Var y("y", NullOpt);
+  relax::Var z("z", NullOpt);
+
+  using M = NestedMsg<relax::Expr>;
+  auto same_leaf = [](Expr lhs, Expr rhs) { return lhs.same_as(rhs); };
+
+  // Structurally identical messages compare equal.
+  EXPECT_TRUE(Equal(M(NullOpt), M(NullOpt), same_leaf));
+  EXPECT_TRUE(Equal(M(x), M(x), same_leaf));
+  EXPECT_TRUE(Equal(M({x, y}), M({x, y}), same_leaf));
+  EXPECT_TRUE(Equal(M({x, NullOpt}), M({x, NullOpt}), same_leaf));
+  EXPECT_TRUE(Equal(M({x, {NullOpt, y}}), M({x, {NullOpt, y}}), same_leaf));
+  EXPECT_TRUE(
+      Equal(M({x, {NullOpt, y}, {x, z}}), M({x, {NullOpt, y}, {x, z}}), same_leaf));
+
+  // Leaf vs. array, leaf vs. null, and length mismatches compare unequal.
+  EXPECT_FALSE(
+      Equal(M({x, {NullOpt, y}, x}), M({x, {NullOpt, y}, {x, z}}), same_leaf));
+  EXPECT_FALSE(
+      Equal(M({x, {NullOpt, y}, {x, NullOpt}}), M({x, {NullOpt, y}, {x, z}}), same_leaf));
+  EXPECT_FALSE(Equal(M({x, {NullOpt, y}}), M({x, {NullOpt, y}, {x, z}}), same_leaf));
+  EXPECT_FALSE(Equal(M(x), M(NullOpt), same_leaf));
+  EXPECT_FALSE(Equal(M(NullOpt), M(x), same_leaf));
+  EXPECT_FALSE(Equal(M(x), M(Array<M>({x})), same_leaf));
+  EXPECT_FALSE(Equal(M(Array<M>({x})), M(x), same_leaf));
+}
+
+TEST(NestedMsg, MapAndDecompose) {
+  relax::Var x("x", PrimStructInfo(runtime::DataType::Int(16)));
+  relax::Var y("y", PrimStructInfo(runtime::DataType::Int(32)));
+  relax::Var z("z", PrimStructInfo(runtime::DataType::Int(64)));
+
+  BlockBuilder builder = BlockBuilder::Create(NullOpt);
+  relax::Expr inner = builder->Normalize(Tuple({x, y}));
+  relax::Expr outer = builder->Normalize(Tuple({inner, x, z, inner}));
+
+  auto c0 = Integer(0);
+  auto c1 = Integer(1);
+  auto c2 = Integer(2);
+
+  auto int_equal = [](Integer lhs, Integer rhs) -> bool { return lhs->value == rhs->value; };
+  NestedMsg<Integer> expected = {{c0, c1}, c0, c2, {c0, c1}};
+
+  // Follow the syntactic tuple structure of the expression.
+  auto by_expr = MapToNestedMsg<Integer>(outer, [&](Expr value) {
+    if (value.same_as(x)) return c0;
+    if (value.same_as(y)) return c1;
+    return c2;
+  });
+  EXPECT_TRUE(Equal(by_expr, expected, int_equal));
+
+  // Follow the tuple structure of the struct info.
+  auto by_sinfo =
+      MapToNestedMsg<Integer>(GetStructInfo(outer), [&](StructInfo sinfo) -> NestedMsg<Integer> {
+        const auto* prim = sinfo.as<PrimStructInfoNode>();
+        if (prim == nullptr) return NullOpt;
+        switch (prim->dtype.bits()) {
+          case 16:
+            return c0;
+          case 32:
+            return c1;
+          case 64:
+            return c2;
+          default:
+            return NullOpt;
+        }
+      });
+  EXPECT_TRUE(Equal(by_sinfo, expected, int_equal));
+
+  int x_count = 0;
+  int y_count = 0;
+  int z_count = 0;
+  DecomposeNestedMsg(outer, expected, [&](Expr value, NestedMsg<Integer> msg) {
+    if (value.same_as(x)) {
+      EXPECT_TRUE(msg.same_as(c0));
+      ++x_count;
+      return;
+    }
+    if (value.same_as(y)) {
+      EXPECT_TRUE(msg.same_as(c1));
+      ++y_count;
+      return;
+    }
+    EXPECT_TRUE(msg.same_as(c2));
+    ++z_count;
+  });
+  EXPECT_EQ(x_count, 3);
+  EXPECT_EQ(y_count, 2);
+  EXPECT_EQ(z_count, 1);
+}
+
+TEST(NestedMsg, MapToNestedMsgBySInfo) {
+  auto scalar = TensorStructInfo(DataType::Float(32), /*ndim=*/0);
+  auto pair_a = TupleStructInfo({scalar, scalar});
+  auto pair_b = TupleStructInfo({scalar, scalar});
+  auto x = relax::Var("x", TupleStructInfo({pair_a, pair_b, scalar}));
+
+  auto msg = MapToNestedMsgBySInfo<Expr>(x, [](Expr leaf) { return leaf; });
+  EXPECT_TRUE(msg.IsNested());
+
+  auto top = msg.NestedArray();
+  EXPECT_TRUE(top[1].IsNested());
+  auto second = top[1].NestedArray();
+  EXPECT_TRUE(second[0].IsLeaf());
+  EXPECT_TRUE(StructuralEqual()(second[0].LeafValue(), TupleGetItem(TupleGetItem(x, 1), 0)));
+
+  EXPECT_TRUE(top[2].IsLeaf());
+  EXPECT_TRUE(StructuralEqual()(top[2].LeafValue(), TupleGetItem(x, 2)));
+}
+
+TEST(NestedMsg, NestedMsgToExpr) {
+  auto scalar = TensorStructInfo(DataType::Float(32), /*ndim=*/0);
+  auto pair = TupleStructInfo({scalar, scalar});
+
+  auto c0 = Integer(0);
+  auto c1 = Integer(1);
+  auto c2 = Integer(2);
+
+  relax::Var x("x", scalar);
+  relax::Var y("y", scalar);
+  relax::Var z("z", scalar);
+
+  NestedMsg<Integer> msg = {c0, {c0, c1}, {c0, {c1, c2}}};
+  auto expr = NestedMsgToExpr<Integer>(msg, [&](Optional<Integer> leaf) {
+    ICHECK(leaf.defined());
+    const int value = leaf.value().IntValue();
+    if (value == 0) return x;
+    if (value == 1) return y;
+    return z;
+  });
+  Expr expected = Tuple({x, Tuple({x, y}), Tuple({x, Tuple({y, z})})});
+  EXPECT_TRUE(StructuralEqual()(expr, expected));
+
+  // (t[0], t[1]) is folded back into t.
+  relax::Var t("t", pair);
+  NestedMsg<Expr> projections = {TupleGetItem(t, 0), TupleGetItem(t, 1)};
+  auto folded =
+      NestedMsgToExpr<Expr>(projections, [](Optional<Expr> leaf) { return leaf.value(); });
+  EXPECT_TRUE(StructuralEqual()(folded, t));
+}
+
+TEST(NestedMsg, CombineNestedMsg) {
+  auto c0 = Integer(0);
+  auto c1 = Integer(1);
+  auto c2 = Integer(2);
+
+  NestedMsg<Integer> lhs = {c0, {c0, c1}, NullOpt, {c0, {c1, c2}}};
+  NestedMsg<Integer> rhs = {c1, {c2, NullOpt}, NullOpt, {c1, {c2, c2}}};
+
+  // Keep the larger leaf; ties keep the right-hand one.
+  auto take_max = [](Integer a, Integer b) { return a->value > b->value ? a : b; };
+  auto combined = CombineNestedMsg(lhs, rhs, take_max);
+
+  NestedMsg<Integer> expected = {c1, {c2, c1}, NullOpt, {c1, {c2, c2}}};
+  EXPECT_TRUE(Equal(combined, expected,
+                    [](Integer a, Integer b) -> bool { return a->value == b->value; }));
+}
+
+TEST(NestedMsg, MapNestedMsg) {
+  auto c0 = Integer(0);
+  auto c1 = Integer(1);
+  auto c2 = Integer(2);
+  auto c3 = Integer(3);
+
+  NestedMsg<Integer> msg = {c0, {c0, c1}, NullOpt, {c0, {c2, c1}}};
+
+  // 0 -> 3, 1 -> null, anything else is kept.
+  auto output = MapNestedMsg(msg, [](Integer leaf) -> NestedMsg<Integer> {
+    switch (leaf->value) {
+      case 0:
+        return NestedMsg<Integer>(Integer(3));
+      case 1:
+        return NestedMsg<Integer>();
+      default:
+        return NestedMsg<Integer>(leaf);
+    }
+  });
+
+  NestedMsg<Integer> expected = {c3, {c3, NullOpt}, NullOpt, {c3, {c2, NullOpt}}};
+  EXPECT_TRUE(Equal(output, expected,
+                    [](Integer a, Integer b) -> bool { return a->value == b->value; }));
+}
+
+TEST(NestedMsg, TransformTupleLeaf) {
+  using NInt = NestedMsg<Integer>;
+  auto c0 = Integer(0);
+  auto c1 = Integer(1);
+  auto c2 = Integer(2);
+
+  NInt msg1 = {c0, {c0, c1}, c2, {c0, {c1, c2}}};
+  NInt msg2 = {c1, {c2, c0}, c2, {c1, {c2, c0}}};
+
+  PrimStructInfo s = PrimStructInfo(runtime::DataType::Int(32));
+  relax::Var x("x", s);
+  relax::Var y("y", s);
+  relax::Var z("z", s);
+  BlockBuilder bb = BlockBuilder::Create(NullOpt);
+  Expr expr = bb->Normalize(Tuple({x, Tuple({x, x}), x, Tuple({x, Tuple({x, x})})}));
+
+  // lhs > rhs -> z, lhs == rhs -> keep the leaf, lhs < rhs -> y.
+  auto ftransleaf = [&](Expr value, std::array<NInt, 2> msgs) -> Expr {
+    int lhs = Downcast<Integer>(msgs[0].LeafValue())->value;
+    int rhs = Downcast<Integer>(msgs[1].LeafValue())->value;
+    if (lhs == rhs) return value;
+    return lhs > rhs ? z : y;
+  };
+
+  Expr expected = Tuple({y, Tuple({y, z}), x, Tuple({y, Tuple({y, z})})});
+  Expr transformed = TransformTupleLeaf(expr, std::array<NInt, 2>({msg1, msg2}), ftransleaf);
+  EXPECT_TRUE(StructuralEqual()(transformed, expected));
+
+  // Identical messages leave every leaf untouched, so expr itself comes back.
+  Expr unchanged = TransformTupleLeaf(expr, std::array<NInt, 2>({msg1, msg1}), ftransleaf);
+  EXPECT_TRUE(expr.same_as(unchanged));
+}

Reply via email to