mbs-octoml commented on code in PR #11981: URL: https://github.com/apache/tvm/pull/11981#discussion_r918212033
########## src/relay/collage/sub_graph.h: ########## @@ -0,0 +1,451 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file src/relay/collage/sub_graph.h + * \brief Represents a sub-graph of an overall Relay expression. + */ + +#ifndef TVM_RELAY_COLLAGE_SUB_GRAPH_H_ +#define TVM_RELAY_COLLAGE_SUB_GRAPH_H_ + +#include <tvm/ir/transform.h> +#include <tvm/relay/op_attr_types.h> + +#include <string> +#include <unordered_map> +#include <utility> +#include <vector> + +#include "../ir/dataflow_matcher_impl.h" +#include "../ir/indexed_graph.h" +#include "./dataflow_graph.h" +#include "./index_set.h" + +namespace tvm { +namespace relay { +namespace collage { + +/*! \brief Returns operator pattern kind as a single-letter string. */ +std::string KindToString(OpPatternKind kind); + +/*! + * \brief Returns a kind and label for the single \p sub_expr, ignoring its sub-sub-expressions. + */ +std::pair<OpPatternKind, std::string> SubExprKindAndLabel(const Expr& sub_expr); + +/*! + * \brief Returns a kind and label for all the nodes in \p inside. + */ +std::pair<OpPatternKind, std::string> SubGraphKindAndLabel(const DataflowGraph& dataflow_graph, + const IndexSet& inside); + +/*!
+ * \brief Returns the index set representing all the sub-expressions matched by \p matcher. + */ +IndexSet MatcherToIndexSet(const DFPatternMatcher& matcher); + +/*! + * \brief Configuration controlling which sub-graphs are considered valid. + */ +struct SubGraphConfig { + /*! \brief Maximum number of exit nodes in the sub-graph, or zero if no limit. */ + size_t max_exits = 0; + /*! + * \brief Whether a node inside the sub-graph may flow to nodes both inside and outside + * the sub-graph (which we call a 'tap'). Note that it is still possible to have multiple outputs + * even with this flag false. + */ + bool allow_taps = false; + /*! + * \brief Maximum allowed maximum depth, or zero if no limit. + */ + size_t max_max_depth = 0; + + std::string ToString() const; +}; + +class SubGraph; +using FunctionAttrsMap = Map<String, ObjectRef>; + +/*! + * \brief A sub-sub-graph is a sub-graph which is to be nested inside a function as part of some + * enclosing sub-graph. + * + * Extraction yields a function with input nodes replaced by parameters and exit nodes in the + * function result. Rewriting replaces the sub-graph with a call to that function, and all + * outputs with (projections from) the call result. + * + * (Note that it's tempting to move attrs_ into \p SubGraphNode and thus avoid this class. + * However, we found the implementation was easier to understand in this form since it makes + * the result of \p Extract unambiguous.) + */ +class SubSubGraphNode : public Object { + public: + /*! \brief The nested sub-graph. */ + ObjectRef /* actually SubGraph */ sub_graph_obj_; + /*! \brief Attributes (possibly empty) to attach to the extracted function.
*/ + FunctionAttrsMap attrs_; + + void VisitAttrs(AttrVisitor* v); + + SubGraph sub_graph() const; + + bool operator==(const SubSubGraphNode& that) const; + bool operator!=(const SubSubGraphNode& that) const { return !(*this == that); } + bool operator<(const SubSubGraphNode& that) const; + size_t hash() const; + + std::string ToString() const; + + /*! + * \brief Returns the function representing this sub-sub-graph within the overall expression + * represented by \p dataflow_graph: + * - All sub-graph inputs become parameters. + * - All sub-graph outputs become function results (either directly or as a field in a tuple). + * - The function has attrs_ for attributes (which may be empty). + * - The function body accounts for any rewrites implied by the nested sub-graph. + */ + Function Extract(const DataflowGraph& dataflow_graph) const; + + /*! + * \brief Returns \p expr rewritten to encode the partitioning implied by this sub-sub-graph. + * + * It is valid for \p expr to not be the same as \p dataflow_graph.expr(); however, all nodes + * inside this sub-sub-graph must correspond to nodes shared between \p dataflow_graph.expr() and + * \p expr. See \p SubGraph::ParallelRewrite below. + */ + Expr Rewrite(const DataflowGraph& dataflow_graph, const Expr& expr) const; + + static constexpr const char* _type_key = "relay.collage.SubSubGraph"; + TVM_DECLARE_FINAL_OBJECT_INFO(SubSubGraphNode, Object); +}; + +class SubSubGraph : public ObjectRef { Review Comment: I hadn't thought of that, actually. At first I had SubGraph be directly recursive until I realized things were much clearer with the intermediate NestedSubGraph, and was so happy with that result that I didn't push further. You are right that there's some signature sharing, but there's no implementation sharing that I can see, and I think making code polymorphic on SubGraph vs NestedSubGraph would only make things even more confusing. So let me leave it as is. -- This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: commits-unsubscr...@tvm.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org