tqchen commented on a change in pull request #9016: URL: https://github.com/apache/tvm/pull/9016#discussion_r711350425
########## File path: src/arith/pattern_match.h ########## @@ -210,6 +210,73 @@ class PVar : public Pattern<PVar<T>> { mutable bool filled_{false}; }; +/*! + * \brief Wrapper for pattern variable container with extra match logic. + * + * \tparam Derived the type of derived class. + * \tparam T the type of the hole. + */ +template <typename Derived, typename T> +class PVarWithCheck : public arith::Pattern<PVarWithCheck<Derived, T>> { + public: + // Store by reference in the expression. + using Nested = const PVarWithCheck<Derived, T>&; + + void InitMatch_() const { pvar_.InitMatch_(); } + + bool Match_(const T& value) const { + if (!static_cast<const Derived*>(this)->Match_(value)) return false; + return pvar_.Match_(value); + } + + template <typename NodeRefType, + typename = typename std::enable_if<std::is_base_of<NodeRefType, T>::value>::type> + bool Match_(const NodeRefType& value) const { + if (const auto* ptr = value.template as<typename T::ContainerType>()) { + return Match_(GetRef<T>(ptr)); + } else { + return false; + } + } + + T Eval() const { return pvar_.Eval(); } + + protected: + arith::PVar<T> pvar_; +}; + +/*! + * \brief Pattern variable container with expr type check. + * + * \tparam T the type of the hole. + * \tparam DType the Pattern type of dtype. + */ +template <typename T, typename DType, + typename = std::enable_if<std::is_base_of<T, PrimExpr>::value>> +class PVarWithType : public PVarWithCheck<PVarWithType<T, DType>, T> { + public: + explicit PVarWithType(const DType& dtype) : dtype_(dtype) {} + + bool Match_(const T& value) const { return dtype_.Match_(value->dtype); } + + protected: + typename DType::Nested dtype_; +}; + +/*! + * \brief Pattern variable container for data type with lanes. + */ +class PVecType : public PVarWithCheck<PVecType, DataType> { Review comment: PVecDataType ########## File path: src/arith/pattern_match.h ########## @@ -210,6 +210,73 @@ class PVar : public Pattern<PVar<T>> { mutable bool filled_{false}; }; +/*! 
+ * \brief Wrapper for pattern variable container with extra match logic. + * + * \tparam Derived the type of derived class. + * \tparam T the type of the hole. + */ +template <typename Derived, typename T> +class PVarWithCheck : public arith::Pattern<PVarWithCheck<Derived, T>> { + public: + // Store by reference in the expression. + using Nested = const PVarWithCheck<Derived, T>&; + + void InitMatch_() const { pvar_.InitMatch_(); } + + bool Match_(const T& value) const { + if (!static_cast<const Derived*>(this)->Match_(value)) return false; + return pvar_.Match_(value); + } + + template <typename NodeRefType, + typename = typename std::enable_if<std::is_base_of<NodeRefType, T>::value>::type> + bool Match_(const NodeRefType& value) const { + if (const auto* ptr = value.template as<typename T::ContainerType>()) { + return Match_(GetRef<T>(ptr)); + } else { + return false; + } + } + + T Eval() const { return pvar_.Eval(); } + + protected: + arith::PVar<T> pvar_; +}; + +/*! + * \brief Pattern variable container with expr type check. + * + * \tparam T the type of the hole. + * \tparam DType the Pattern type of dtype. + */ +template <typename T, typename DType, + typename = std::enable_if<std::is_base_of<T, PrimExpr>::value>> +class PVarWithType : public PVarWithCheck<PVarWithType<T, DType>, T> { Review comment: Would be great to rename to PVarWithDataType -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: commits-unsubscribe@tvm.apache.org For queries about this service, please contact Infrastructure at: users@infra.apache.org