This is an automated email from the ASF dual-hosted git repository. liuyizhi pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push: new 73dda6b [Relay] Convert Layout Pass. (#4335) 73dda6b is described below commit 73dda6be503922da4a3861f5a48718f6d7b2ef1a Author: Animesh Jain <anij...@umich.edu> AuthorDate: Thu Dec 26 11:15:46 2019 -0800 [Relay] Convert Layout Pass. (#4335) --- include/tvm/relay/op_attr_types.h | 17 + include/tvm/relay/transform.h | 20 ++ python/tvm/relay/op/nn/_nn.py | 41 +++ python/tvm/relay/op/op.py | 17 + python/tvm/relay/transform.py | 28 ++ src/relay/op/annotation/annotation.cc | 2 +- src/relay/op/device_copy.cc | 2 +- src/relay/op/memory/memory.cc | 2 +- src/relay/op/nn/bitserial.cc | 2 +- src/relay/op/nn/convolution.cc | 2 +- src/relay/op/nn/nn.cc | 2 +- src/relay/op/nn/pooling.cc | 2 +- src/relay/op/nn/sparse.cc | 2 +- src/relay/op/op_common.h | 2 +- src/relay/op/tensor/transform.cc | 2 +- src/relay/pass/alter_op_layout.cc | 361 +++----------------- src/relay/pass/convert_layout.cc | 146 +++++++++ .../{alter_op_layout.h => infer_layout_util.h} | 45 ++- src/relay/pass/transform_layout.h | 362 +++++++++++++++++++++ tests/python/relay/test_pass_convert_op_layout.py | 360 ++++++++++++++++++++ 20 files changed, 1094 insertions(+), 323 deletions(-) diff --git a/include/tvm/relay/op_attr_types.h b/include/tvm/relay/op_attr_types.h index 741e8b4..54ea707 100644 --- a/include/tvm/relay/op_attr_types.h +++ b/include/tvm/relay/op_attr_types.h @@ -29,6 +29,7 @@ #include <tvm/build_module.h> #include <tvm/relay/type.h> #include <tvm/relay/expr.h> +#include <string> namespace tvm { namespace relay { @@ -133,6 +134,22 @@ using FTVMAlterOpLayout = runtime::TypedPackedFunc< const Array<Tensor>& tinfos)>; /*! + * \brief Convert the layout of operators or replace the + * operator with other expressions. This function will be invoked + * in ConvertLayout pass. + * \param attrs The attribute of the original node. + * \param inputs The input symbols of the original node. 
+ * \param tinfos An array of placeholders, use for getting the inferred shape + * and dtype of the inputs. + * \param desired_layout The desired layout. + * \return new_expr The modified expression. + */ +using FTVMConvertOpLayout = runtime::TypedPackedFunc< + Expr(const Attrs& attrs, + const Array<Expr>& args, + const Array<Tensor>& tinfos, + const std::string& desired_layout)>; +/*! * \brief Legalizes an expression with another expression. This function will be * invoked in Legalize pass. It is a target-dependent pass. * \param attrs The attribute of the original node. diff --git a/include/tvm/relay/transform.h b/include/tvm/relay/transform.h index ddadbe4..52be6a0 100644 --- a/include/tvm/relay/transform.h +++ b/include/tvm/relay/transform.h @@ -533,6 +533,26 @@ TVM_DLL Pass CanonicalizeOps(); TVM_DLL Pass AlterOpLayout(); /*! + * \brief Given a dest layout, this pass transforms the expr such that most of the ops input data + * layout is changed to the dest layout. In ideal situation, there are only 2 layout transforms, one + * at the start and one at the end. + * + * This pass is not a part of relay.build and is expected to be called between framework-relay + * parser and relay.build call. This is very helpful for hardware backends that support/prefer only + * type of data layout. + * + * RFC - https://discuss.tvm.ai/t/layout-conversion-pass/4009 + * + * This pass uses most of the AlterOpLayout and InferCorrectLayout infrastructure. We can define new + * layouts for conv2d ops for now. Most of the other operators try to adapt to their input layout + * using the InferCorrectLayout infrastructure. + * + * \param desired_layout The desired layout. + * \return The pass. + */ +TVM_DLL Pass ConvertLayout(const std::string& desired_layout); + +/*! * \brief Legalizes an expr with another expression. * \param legalize_map_attr_name The Op's attr name which corresponds to the legalize rule function. 
* One can collect and isolate similar type of legalize transformations using this param. For diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py index ce47736..761abc7 100644 --- a/python/tvm/relay/op/nn/_nn.py +++ b/python/tvm/relay/op/nn/_nn.py @@ -251,6 +251,47 @@ def legalize_conv2d(attrs, inputs, types): """ return topi.nn.conv2d_legalize(attrs, inputs, types) + +@reg.register_convert_op_layout("nn.conv2d") +def convert_conv2d(attrs, inputs, tinfos, desired_layout): + """Convert Layout pass registration for conv2d op. + + Parameters + ---------- + attrs : tvm.attrs.Attrs + Attributes of current convolution + inputs : list of tvm.relay.Expr + The args of the Relay expr to be legalized + tinfos : list of types + List of input and output types + desired_layout : str + The desired layout + + Returns + ------- + result : tvm.relay.Expr + The transformed expr + """ + + from tvm import relay + data_layout = attrs['data_layout'] + kernel_layout = attrs['kernel_layout'] + data, weight = inputs + assert desired_layout == 'NCHW', \ + "Currently only transformation to NCHW layout is supported." + if desired_layout == 'NCHW': + new_attrs = dict(attrs) + new_attrs['data_layout'] = desired_layout + new_attrs['kernel_layout'] = 'OIHW' + + if data_layout == 'NHWC' and kernel_layout == 'HWIO': + # Convert (NHWC, HWIO) to (NCHW, OIHW) + return relay.nn.conv2d(data, weight, **new_attrs) + if data_layout == 'NHWC' and kernel_layout == 'HWOI': + # Convert (NHWC, HWOI) to (NCHW, OIHW). Depthwise conv2d. 
+ return relay.nn.conv2d(data, weight, **new_attrs) + return None + reg.register_pattern("nn.conv2d", OpPattern.OUT_ELEMWISE_FUSABLE) diff --git a/python/tvm/relay/op/op.py b/python/tvm/relay/op/op.py index 355496e..382f667 100644 --- a/python/tvm/relay/op/op.py +++ b/python/tvm/relay/op/op.py @@ -196,6 +196,23 @@ def register_alter_op_layout(op_name, alter_layout=None, level=10): return register(op_name, "FTVMAlterOpLayout", alter_layout, level) +def register_convert_op_layout(op_name, convert_layout=None, level=10): + """Register convert op layout function for an op + + Parameters + ---------- + op_name : str + The name of the operator + + convert_layout: function (attrs: Attrs, inputs: List[Expr]) -> new_expr: Expr + The function for changing the layout or replacing the operator + + level : int + The priority level + """ + return register(op_name, "FTVMConvertOpLayout", convert_layout, level) + + def register_legalize(op_name, legal_op=None, level=10): """Register legal transformation function for an op diff --git a/python/tvm/relay/transform.py b/python/tvm/relay/transform.py index 540c1f5..1f91272 100644 --- a/python/tvm/relay/transform.py +++ b/python/tvm/relay/transform.py @@ -460,6 +460,34 @@ def AlterOpLayout(): return _transform.AlterOpLayout() +def ConvertLayout(desired_layout): + """ Given a dest layout, this pass transforms the expr such that most of the ops input data + layout is changed to the dest layout. In ideal situation, there are only 2 layout transforms, + one at the start and one at the end. + + This pass is not a part of relay.build and is expected to be called between framework-relay + parser and relay.build call. This is very helpful for hardware backends that support/prefer only + type of data layout. + + RFC - https://discuss.tvm.ai/t/layout-conversion-pass/4009 + + This pass uses most of the AlterOpLayout and InferCorrectLayout infrastructure. We can define + new layouts for conv2d ops for now. 
Most of the other operators try to adapt to their input + layout using the InferCorrectLayout infrastructure. + + Parameters + ---------- + desired_layout : str + The desired layout for the transformed expr. + + Returns + ------- + pass: FunctionPass + The pass. + """ + return _transform.ConvertLayout(desired_layout) + + def Legalize(legalize_map_attr_name="FTVMLegalize"): """Legalizes an expression with another expression. This pass can be used to replace an expr with another expr for target diff --git a/src/relay/op/annotation/annotation.cc b/src/relay/op/annotation/annotation.cc index f5674fa..6835525 100644 --- a/src/relay/op/annotation/annotation.cc +++ b/src/relay/op/annotation/annotation.cc @@ -30,7 +30,7 @@ #include <tvm/relay/op_attr_types.h> #include <topi/elemwise.h> -#include "../../pass/alter_op_layout.h" +#include "../../pass/infer_layout_util.h" #include "../type_relations.h" namespace tvm { diff --git a/src/relay/op/device_copy.cc b/src/relay/op/device_copy.cc index 51aff41..3b997a2 100644 --- a/src/relay/op/device_copy.cc +++ b/src/relay/op/device_copy.cc @@ -33,7 +33,7 @@ #include <tvm/relay/op_attr_types.h> #include "type_relations.h" -#include "../pass/alter_op_layout.h" +#include "../pass/infer_layout_util.h" namespace tvm { namespace relay { diff --git a/src/relay/op/memory/memory.cc b/src/relay/op/memory/memory.cc index c99cf0f..c535d76 100644 --- a/src/relay/op/memory/memory.cc +++ b/src/relay/op/memory/memory.cc @@ -29,7 +29,7 @@ #include <tvm/relay/attrs/memory.h> #include "../op_common.h" -#include "../../pass/alter_op_layout.h" +#include "../../pass/infer_layout_util.h" #include "../type_relations.h" namespace tvm { diff --git a/src/relay/op/nn/bitserial.cc b/src/relay/op/nn/bitserial.cc index d70f1af..d651bae 100644 --- a/src/relay/op/nn/bitserial.cc +++ b/src/relay/op/nn/bitserial.cc @@ -26,7 +26,7 @@ #include <tvm/relay/attrs/bitserial.h> #include <tvm/relay/op.h> -#include "../../pass/alter_op_layout.h" +#include 
"../../pass/infer_layout_util.h" namespace tvm { namespace relay { diff --git a/src/relay/op/nn/convolution.cc b/src/relay/op/nn/convolution.cc index df890b1..5f1b194 100644 --- a/src/relay/op/nn/convolution.cc +++ b/src/relay/op/nn/convolution.cc @@ -27,7 +27,7 @@ #include <tvm/relay/attrs/nn.h> #include <vector> -#include "../../pass/alter_op_layout.h" +#include "../../pass/infer_layout_util.h" #include "../op_common.h" #include "convolution.h" diff --git a/src/relay/op/nn/nn.cc b/src/relay/op/nn/nn.cc index ac38485..2cb0c28 100644 --- a/src/relay/op/nn/nn.cc +++ b/src/relay/op/nn/nn.cc @@ -33,7 +33,7 @@ #include <vector> #include <string> #include "../type_relations.h" -#include "../../pass/alter_op_layout.h" +#include "../../pass/infer_layout_util.h" #include "../op_common.h" #include "nn.h" diff --git a/src/relay/op/nn/pooling.cc b/src/relay/op/nn/pooling.cc index aa37559..e7529a9 100644 --- a/src/relay/op/nn/pooling.cc +++ b/src/relay/op/nn/pooling.cc @@ -27,7 +27,7 @@ #include <tvm/relay/attrs/nn.h> #include <topi/nn/pooling.h> #include <vector> -#include "../../pass/alter_op_layout.h" +#include "../../pass/infer_layout_util.h" namespace tvm { namespace relay { diff --git a/src/relay/op/nn/sparse.cc b/src/relay/op/nn/sparse.cc index a2d89c4..7cf8a27 100644 --- a/src/relay/op/nn/sparse.cc +++ b/src/relay/op/nn/sparse.cc @@ -27,7 +27,7 @@ #include <tvm/relay/attrs/nn.h> #include <vector> -#include "../../pass/alter_op_layout.h" +#include "../../pass/infer_layout_util.h" namespace tvm { namespace relay { diff --git a/src/relay/op/op_common.h b/src/relay/op/op_common.h index b960c75..42c1fc4 100644 --- a/src/relay/op/op_common.h +++ b/src/relay/op/op_common.h @@ -32,7 +32,7 @@ #include <string> #include <unordered_map> #include "type_relations.h" -#include "../pass/alter_op_layout.h" +#include "../pass/infer_layout_util.h" namespace tvm { namespace relay { diff --git a/src/relay/op/tensor/transform.cc b/src/relay/op/tensor/transform.cc index ecd20e5..ff018e4 
100644 --- a/src/relay/op/tensor/transform.cc +++ b/src/relay/op/tensor/transform.cc @@ -36,7 +36,7 @@ #include <vector> #include "../op_common.h" #include "../../../arithmetic/compute_expr.h" -#include "../../pass/alter_op_layout.h" +#include "../../pass/infer_layout_util.h" #include "../../pass/pattern_util.h" #include "transform.h" diff --git a/src/relay/pass/alter_op_layout.cc b/src/relay/pass/alter_op_layout.cc index d893d94..bd89c51 100644 --- a/src/relay/pass/alter_op_layout.cc +++ b/src/relay/pass/alter_op_layout.cc @@ -36,7 +36,7 @@ #include <utility> #include <unordered_map> -#include "alter_op_layout.h" +#include "transform_layout.h" #include "pattern_util.h" namespace tvm { @@ -44,328 +44,73 @@ namespace relay { namespace alter_op_layout { -// Make a transform CallNode -/* Performs 2 operations - * 1) If src_layout ndim is smaller then dst_layout, expand_dim is inserted to match the dim size. - * For example, src_layout = C, dst_layout = NCHW16c. The src is expanded to NHWC. - * 2) Call layout transform with new src layout. +/*! + * \brief Container to instantiate a Node for alter op layouts. */ -Expr TransformLayout(Expr raw, Layout src_layout, Layout dst_layout) { - if (src_layout.Equals(dst_layout)) { - return raw; - } - - // 1) Check if the shape lengths are different. If yes, expand dims. - Expr input_expr = raw; - Layout new_src_layout = src_layout; - if (src_layout.ndim_primal() < dst_layout.ndim_primal()) { - int num_new_axis = dst_layout.ndim_primal() - src_layout.ndim_primal(); - new_src_layout = src_layout.ExpandPrimal(dst_layout); - input_expr = MakeExpandDims(input_expr, 0, num_new_axis); - if (new_src_layout.Equals(dst_layout)) { - return input_expr; - } - } - - // 2) Insert layout transform on the transformed src. 
- CHECK(new_src_layout.defined() && dst_layout.defined()) - << "Cannot insert layout transform because there are undefined layouts"; - CHECK(BijectiveLayoutNode::make(new_src_layout, dst_layout).defined()) - << "Cannot insert layout transform because there are inconvertible layouts: " - << new_src_layout << " v.s. " << dst_layout; - return MakeLayoutTransform(input_expr, new_src_layout.name(), dst_layout.name()); -} - -// Memorize layout transform so we can reuse internal transformed nodes -class TransformMemorizerNode : public Node { +class AlterTransformMemorizerNode : public TransformMemorizerNode { public: - // map from (Expr, src_layout, dst_layout) to transformed Expr - using TransformKey = std::tuple<const Node*, std::string, std::string>; -struct key_hash : public std::function<std::size_t(TransformKey)> { - std::size_t operator()(const TransformKey& k) const { - return dmlc::HashCombine<std::string>(dmlc::HashCombine<std::string>( - std::hash<const Node*>()(std::get<0>(k)), std::get<1>(k)), (std::get<2>(k))); - } - }; - - std::unordered_map<TransformKey, Expr, key_hash> memo; - static constexpr const char *_type_key = "relay.alter_op_layout.TransformMemorizerNode"; - TVM_DECLARE_NODE_TYPE_INFO(TransformMemorizerNode, Node); -}; - -class TransformMemorizer : public NodeRef { - public: - TransformMemorizer() {} - explicit TransformMemorizer(ObjectPtr<Object> n) : NodeRef(n) {} - - TransformMemorizerNode* operator->() { - return static_cast<TransformMemorizerNode*>(get_mutable()); - } - - // Transform layout with memorizer - Expr Transform(Expr raw, const Layout& src_layout, const Layout& dst_layout) { - if (src_layout.Equals(dst_layout)) { return raw; } - - std::tuple<const Node*, std::string, std::string> key = - std::make_tuple<>(raw.get(), src_layout.name(), dst_layout.name()); - auto& memo = operator->()->memo; - - auto iter = memo.find(key); - if (iter != memo.end()) { - return iter->second; - } else { - Expr transform = TransformLayout(raw, src_layout, 
dst_layout); - memo[key] = transform; - return transform; - } - } - - using ContainerType = TransformMemorizerNode; + static constexpr const char* _type_key = "relay.alter_op_layout.AlterTransformMemorizerNode"; }; - -// TempExprNode during layout transform -// Instance of this expr will be Realized to normal expr ultimately -class LayoutAlternatedExprNode : public TempExprNode { +/*! + * \brief Container that provides the transformation function for alter layout.. + */ +class AlterTransformMemorizer : public TransformMemorizer { public: - Expr value; - Layout old_layout; - Layout new_layout; - TransformMemorizer memorizer; - - Expr Realize() const final { - // NOTE: use a copy to discard the "const" qualifier - TransformMemorizer tmp_memorizer = memorizer; - // fallback to old layout - return tmp_memorizer.Transform(value, new_layout, old_layout); - } - - void VisitAttrs(AttrVisitor *v) { - v->Visit("value", &value); - v->Visit("old_layout", &old_layout); - v->Visit("new_layout", &new_layout); - } - - static constexpr const char *_type_key = "relay.alter_op_layout.LayoutAlternatedExprNode"; - TVM_DECLARE_NODE_TYPE_INFO(LayoutAlternatedExprNode, TempExprNode); -}; - -RELAY_DEFINE_NODE_REF(LayoutAlternatedExpr, LayoutAlternatedExprNode, TempExpr); - -// Call registered FInferCorrectLayout of an op. 
-// Parameters are the same as the parameters for FInferCorrectLayout -// Returns inferred_input_layout, inferred_output_layout, success -std::tuple<Array<Layout>, Array<Layout>, bool> CallInfer( - const Call& call, - const Array<Layout>& new_in_layouts, - const Array<Layout>& old_in_layouts, - const Array<Array<IndexExpr> > &old_in_shapes) { - static auto finfer_layout = Op::GetAttr<FInferCorrectLayout>("FInferCorrectLayout"); - if (!call->op.as<OpNode>()) { - return std::make_tuple<>(Array<Layout>(nullptr), Array<Layout>(nullptr), false); - } - - Op op = Downcast<Op>(call->op); - if (finfer_layout.count(op)) { - Array<Array<Layout> > inferred_layouts; - inferred_layouts = finfer_layout[op](call->attrs, new_in_layouts, - old_in_layouts, old_in_shapes); - CHECK_EQ(inferred_layouts.size(), 2) - << "FInferCorrectLayout should return an array with size of 2"; - for (auto x : inferred_layouts) { - for (auto y : x) { - if (!y.defined()) { // inference fails - return std::make_tuple<>(Array<Layout>(nullptr), Array<Layout>(nullptr), false); - } + AlterTransformMemorizer() {} + explicit AlterTransformMemorizer(ObjectPtr<Object> n) : TransformMemorizer(n) {} + + AlterTransformMemorizerNode* operator->() { + return static_cast<AlterTransformMemorizerNode*>(get_mutable()); + } + + /*! + * \brief Defines the call transformation for AlterOpLayout pass. The new layouts are defined by + * used for different targets using a packed func. + * \param ref_call The original call. + * \param new_args The traversed/recursed args to the call. + * \return The new Call after calling the packed func. 
+ */ + Call CallWithNewLayouts(const Call& ref_call, const std::vector<Expr>& new_args) override { + static auto falter_layout = Op::GetAttr<FTVMAlterOpLayout>("FTVMAlterOpLayout"); + Op op = Downcast<Op>(ref_call->op); + + Expr new_e; + bool modified = false; + if (falter_layout.count(op)) { + tvm::Array<tvm::Tensor> tinfos; + for (auto expr : ref_call->args) { + auto ttype = expr->type_as<TensorTypeNode>(); + tinfos.push_back(tvm::placeholder(ttype->shape, ttype->dtype)); } - } - return std::make_tuple<>(inferred_layouts[0], inferred_layouts[1], true); - } else { - return std::make_tuple<>(Array<Layout>(nullptr), Array<Layout>(nullptr), false); - } -} - -// Call registered FTVMAlterOpLayout of an op -// Returns the altered expression -Call CallAlter(const Call& ref_call, - const std::vector<Expr>& new_args) { - static auto falter_layout = Op::GetAttr<FTVMAlterOpLayout>("FTVMAlterOpLayout"); - Op op = Downcast<Op>(ref_call->op); - - Expr new_e; - bool modified = false; - if (falter_layout.count(op)) { - tvm::Array<tvm::Tensor> tinfos; - for (auto expr : ref_call->args) { - auto ttype = expr->type_as<TensorTypeNode>(); - tinfos.push_back(tvm::placeholder(ttype->shape, ttype->dtype)); - } - Expr altered_value = falter_layout[op](ref_call->attrs, new_args, tinfos); - if (altered_value.defined()) { - new_e = altered_value; - modified = true; - } - } - if (!modified) { - new_e = CallNode::make(ref_call->op, new_args, - ref_call->attrs); - } - - const CallNode *new_call = new_e.as<CallNode>(); - CHECK(new_call) << "Can only replace the original operator with another call node"; - return GetRef<Call>(new_call); -} - -Expr AlterOpLayoutRewrite(const Call &ref_call, - const Array<Expr> &new_args, - const NodeRef& ctx) { - std::vector<LayoutAlternatedExpr> inputs; - std::vector<Expr> normal_new_args; - Array<Array<IndexExpr> > input_shapes; - - // NOTE: discard the "const" qualifier - TransformMemorizer memorizer = Downcast<TransformMemorizer>(ctx); - - // fill incomplete 
state and flatten tuple - auto push_back_one_arg = [&inputs, memorizer](Expr arg) { - // We always expect LayoutAlternatedExpr. - // This is used to convert the normal Expr to LayoutAlternatedExpr. - if (const LayoutAlternatedExprNode *inp = arg.as<LayoutAlternatedExprNode>()) { - inputs.push_back(GetRef<LayoutAlternatedExpr>(inp)); - return inp->value; - } else { - auto inode = make_node<LayoutAlternatedExprNode>(); - inode->value = arg; - inode->memorizer = memorizer; - inputs.push_back(LayoutAlternatedExpr(inode)); - return arg; - } - }; - - for (auto new_arg : new_args) { - // NOTE: do not support nested tuple - if (new_arg->IsInstance<TupleNode>()) { - Tuple tuple_new_arg = Downcast<Tuple>(new_arg); - std::vector<Expr> fields; - for (auto x : tuple_new_arg->fields) { - Expr tmp = push_back_one_arg(x); - fields.push_back(tmp); + Expr altered_value = falter_layout[op](ref_call->attrs, new_args, tinfos); + if (altered_value.defined()) { + new_e = altered_value; + modified = true; } - normal_new_args.push_back(TupleNode::make(fields)); - } else { - Expr tmp = push_back_one_arg(new_arg); - normal_new_args.push_back(tmp); } - } - - // old_in, new_in = state[inputs] - Array<Layout> old_in, old_out, new_in, new_out, new_in2; - for (auto inp : inputs) { - old_in.push_back(inp->old_layout); - new_in.push_back(inp->new_layout); - } - - for (auto arg : ref_call->args) { - if (arg->IsInstance<TupleNode>()) { // flatten tuple - Tuple tuple_arg = Downcast<Tuple>(arg); - for (auto x : tuple_arg->fields) { - input_shapes.push_back(x->type_as<TensorTypeNode>()->shape); - } - } else { - input_shapes.push_back(arg->type_as<TensorTypeNode>()->shape); + if (!modified) { + new_e = CallNode::make(ref_call->op, new_args, ref_call->attrs); } - } - - // old_in, old_out = op.infer(old_in) - bool success = false; - std::tie(old_in, old_out, success) = CallInfer(ref_call, - Array<Layout>(nullptr), - old_in, input_shapes); - if (!success) { return Expr(nullptr); } - CHECK_EQ(old_in.size(), 
new_in.size()); - - // if new_in == 'undef': new_in = old_in - for (size_t i = 0; i < new_in.size(); ++i) { - if (!new_in[i].defined()) { - new_in.Set(i, old_in[i]); - } - } - - // new_op = alter(op) - Call new_call = CallAlter(ref_call, normal_new_args); - - // new_in2, new_out = op.infer(new_in) - if (new_call->op->IsInstance<OpNode>()) { - success = false; - std::tie(new_in2, new_out, success) = CallInfer(new_call, new_in, old_in, input_shapes); - if (!success) { return Expr(nullptr); } - } else { - return Expr(nullptr); - } - - CHECK_EQ(new_out.size(), old_out.size()) - << "The number of output nodes should keep the same during alter_op_layout"; - CHECK_EQ(new_in.size(), new_in2.size()) - << "The number of input nodes should keep the same during alter_op_layout"; - // if (new_in != new_in2): insert transform (new_in -> new_in2) - Array<Expr> transformed_args; - size_t pt = 0; - for (auto arg : new_call->args) { - if (arg->IsInstance<TupleNode>()) { // unflatten tuple - Tuple tuple_arg = Downcast<Tuple>(arg); - std::vector<Expr> transformed_tuple_arg; - for (auto arg_item : tuple_arg->fields) { - transformed_tuple_arg.push_back( - memorizer.Transform(arg_item, new_in[pt], new_in2[pt])); - pt++; - } - transformed_args.push_back(TupleNode::make(transformed_tuple_arg)); - } else { - transformed_args.push_back( - memorizer.Transform(arg, new_in[pt], new_in2[pt])); - pt++; - } + const CallNode* new_call = new_e.as<CallNode>(); + CHECK(new_call) << "Can only replace the original operator with another call node"; + return GetRef<Call>(new_call); } - CHECK_EQ(pt, inputs.size()); - // state[node] = (old_out, new_out) - // (handle tuple output) - if (ref_call->checked_type()->IsInstance<TupleTypeNode>()) { - Expr tuple_output = CallNode::make(new_call->op, transformed_args, - new_call->attrs); - Array<Expr> fields; - for (size_t i = 0; i < new_out.size(); ++i) { - auto rnode = make_node<LayoutAlternatedExprNode>(); - rnode->value = TupleGetItemNode::make(tuple_output, i); 
- rnode->old_layout = old_out[i]; - rnode->new_layout = new_out[i]; - rnode->memorizer = memorizer; - fields.push_back(Expr(rnode)); - } - return TupleNode::make(fields); - } else { - auto rnode = make_node<LayoutAlternatedExprNode>(); - CHECK_EQ(new_out.size(), 1); - rnode->value = CallNode::make(new_call->op, transformed_args, - new_call->attrs); - rnode->old_layout = old_out[0]; - rnode->new_layout = new_out[0]; - rnode->memorizer = memorizer; - return Expr(rnode); - } -} + using ContainerType = AlterTransformMemorizerNode; +}; -// Limiations: -// 1. the altered op should have the same number of arguments as the previous one -// 2. do not support nested tuple arguments +/*! + * Limitations: + * 1. The altered op should have the same number of arguments as the previous one. + * 2. Do not support nested tuple arguments. + */ Expr AlterOpLayout(const Expr& expr) { - TransformMemorizer transformMemorizer(make_node<TransformMemorizerNode>()); - auto fcontext = [&](const Call& call) -> NodeRef{ - return transformMemorizer; - }; + AlterTransformMemorizer alterMemorizer(make_node<AlterTransformMemorizerNode>()); + auto fcontext = [&](const Call& call) -> NodeRef { return alterMemorizer; }; - return ForwardRewrite(expr, AlterOpLayoutRewrite, fcontext); + return ForwardRewrite(expr, LayoutRewriter<AlterTransformMemorizer>, fcontext); } } // namespace alter_op_layout diff --git a/src/relay/pass/convert_layout.cc b/src/relay/pass/convert_layout.cc new file mode 100644 index 0000000..1db4422 --- /dev/null +++ b/src/relay/pass/convert_layout.cc @@ -0,0 +1,146 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file convert_op_layout.cc + * \brief Alternate the layouts of operators or replace primitive operators with + other expressions. This pass can be used for computing convolution in + custom layouts or other general weight pre-transformation. + */ +#include <tvm/relay/analysis.h> +#include <tvm/relay/transform.h> +#include <tvm/relay/op_attr_types.h> +#include <tvm/relay/attrs/transform.h> +#include <tvm/relay/transform.h> +#include <tvm/operation.h> +#include <tuple> +#include <vector> +#include <functional> +#include <string> +#include <utility> +#include <unordered_map> + +#include "transform_layout.h" +#include "pattern_util.h" + +namespace tvm { +namespace relay { + +namespace convert_op_layout { + +/*! + * \brief Container for the transformations for ConvertLayout. + */ +class ConvertTransformMemorizerNode : public TransformMemorizerNode { + public: + /*! + * \brief Initializes the desired_layout. + * \param desired_layout The desired layout. + */ + explicit ConvertTransformMemorizerNode(const std::string& desired_layout) + : desired_layout_(desired_layout) {} + + /*! \brief The desired layout for the Convert Layout pass */ + std::string desired_layout_; +}; + +/*! + * \brief Container that provides the transformation function for convert layout. 
+ */ +class ConvertTransformMemorizer : public TransformMemorizer { + public: + ConvertTransformMemorizer() {} + explicit ConvertTransformMemorizer(ObjectPtr<Object> n) : TransformMemorizer(n) {} + + ConvertTransformMemorizerNode* operator->() { + return static_cast<ConvertTransformMemorizerNode*>(get_mutable()); + } + + /*! + * \brief Defines the call transformation for ConvertLayout pass. The new layouts should be the + * desired layout as specified by the user. + * \param ref_call The original call. + * \param new_args The traversed/recursed args to the call. + * \return The new Call after calling the packed func. + */ + Call CallWithNewLayouts(const Call& ref_call, const std::vector<Expr>& new_args) override { + static auto fconvert_layout = Op::GetAttr<FTVMConvertOpLayout>("FTVMConvertOpLayout"); + Op op = Downcast<Op>(ref_call->op); + + Expr new_e; + bool modified = false; + if (fconvert_layout.count(op)) { + tvm::Array<tvm::Tensor> tinfos; + for (auto expr : ref_call->args) { + auto ttype = expr->type_as<TensorTypeNode>(); + tinfos.push_back(tvm::placeholder(ttype->shape, ttype->dtype)); + } + Expr altered_value = + fconvert_layout[op](ref_call->attrs, new_args, tinfos, operator->()->desired_layout_); + if (altered_value.defined()) { + new_e = altered_value; + modified = true; + } + } + if (!modified) { + new_e = CallNode::make(ref_call->op, new_args, ref_call->attrs); + } + + const CallNode* new_call = new_e.as<CallNode>(); + CHECK(new_call) << "Can only replace the original operator with another call node"; + return GetRef<Call>(new_call); + } + + using ContainerType = ConvertTransformMemorizerNode; +}; + +/*! + * Limitations: + * 1. The altered op should have the same number of arguments as the previous one. + * 2. Do not support nested tuple arguments. 
+ */ +Expr ConvertLayout(const Expr& expr, const std::string& desired_layout) { + ConvertTransformMemorizer transformMemorizer( + make_node<ConvertTransformMemorizerNode>(desired_layout)); + auto fcontext = [&](const Call& call) -> NodeRef { return transformMemorizer; }; + + return ForwardRewrite(expr, LayoutRewriter<ConvertTransformMemorizer>, fcontext); +} + +} // namespace convert_op_layout + +namespace transform { + +Pass ConvertLayout(const std::string& desired_layout) { + runtime::TypedPackedFunc<Function(Function, Module, PassContext)> pass_func = + [=](Function f, Module m, PassContext pc) { + return Downcast<Function>(relay::convert_op_layout::ConvertLayout(f, desired_layout)); + }; + return CreateFunctionPass( + pass_func, 3, "ConvertLayout", + {ir::StringImm::make("InferType"), ir::StringImm::make("SimplifyInference"), + ir::StringImm::make("CanonicalizeOps")}); +} + +TVM_REGISTER_API("relay._transform.ConvertLayout").set_body_typed(ConvertLayout); + +} // namespace transform + +} // namespace relay +} // namespace tvm diff --git a/src/relay/pass/alter_op_layout.h b/src/relay/pass/infer_layout_util.h similarity index 82% rename from src/relay/pass/alter_op_layout.h rename to src/relay/pass/infer_layout_util.h index 49bf35c..94eeba1 100644 --- a/src/relay/pass/alter_op_layout.h +++ b/src/relay/pass/infer_layout_util.h @@ -18,18 +18,20 @@ */ /*! - * \file alter_op_layout.h - * \brief Alternate the layouts of operators or replace primitive operators with + * \file infer_layout_util.h + * \brief Utility functions to alter the layouts of operators or replace primitive operators with other expressions. This pass can be used for computing convolution in custom layouts or other general weight pre-transformation. 
*/ -#ifndef TVM_RELAY_PASS_ALTER_OP_LAYOUT_H_ -#define TVM_RELAY_PASS_ALTER_OP_LAYOUT_H_ +#ifndef TVM_RELAY_PASS_INFER_LAYOUT_UTIL_H_ +#define TVM_RELAY_PASS_INFER_LAYOUT_UTIL_H_ #include <tvm/data_layout.h> #include <tvm/relay/expr.h> #include <string> +#include <tuple> +#include "pattern_util.h" namespace tvm { namespace relay { @@ -193,7 +195,40 @@ inline Array<Array<Layout> > BinaryBroadcastLayout(const Attrs& attrs, } } +/*! + * Call registered FInferCorrectLayout of an op. + * Parameters are the same as the parameters for FInferCorrectLayout + * Returns inferred_input_layout, inferred_output_layout, success + */ +static inline std::tuple<Array<Layout>, Array<Layout>, bool> InferCorrectLayouts( + const Call& call, const Array<Layout>& new_in_layouts, const Array<Layout>& old_in_layouts, + const Array<Array<IndexExpr>>& old_in_shapes) { + static auto finfer_layout = Op::GetAttr<FInferCorrectLayout>("FInferCorrectLayout"); + if (!call->op.as<OpNode>()) { + return std::make_tuple<>(Array<Layout>(nullptr), Array<Layout>(nullptr), false); + } + + Op op = Downcast<Op>(call->op); + if (finfer_layout.count(op)) { + Array<Array<Layout>> inferred_layouts; + inferred_layouts = + finfer_layout[op](call->attrs, new_in_layouts, old_in_layouts, old_in_shapes); + CHECK_EQ(inferred_layouts.size(), 2) + << "FInferCorrectLayout should return an array with size of 2"; + for (auto x : inferred_layouts) { + for (auto y : x) { + if (!y.defined()) { // inference fails + return std::make_tuple<>(Array<Layout>(nullptr), Array<Layout>(nullptr), false); + } + } + } + return std::make_tuple<>(inferred_layouts[0], inferred_layouts[1], true); + } else { + return std::make_tuple<>(Array<Layout>(nullptr), Array<Layout>(nullptr), false); + } +} + } // namespace relay } // namespace tvm -#endif // TVM_RELAY_PASS_ALTER_OP_LAYOUT_H_ +#endif // TVM_RELAY_PASS_INFER_LAYOUT_UTIL_H_ diff --git a/src/relay/pass/transform_layout.h b/src/relay/pass/transform_layout.h new file mode 100644 index 
0000000..21a82a6 --- /dev/null +++ b/src/relay/pass/transform_layout.h @@ -0,0 +1,362 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * + * \file transform_layout.h + * \brief Common infrastructure for transforming the layouts. This is used for AlterOpLayout and + * ConvertLayout passes. */ + +#ifndef TVM_RELAY_PASS_TRANSFORM_LAYOUT_H_ +#define TVM_RELAY_PASS_TRANSFORM_LAYOUT_H_ + +#include <tvm/data_layout.h> +#include <tvm/relay/expr.h> +#include <string> +#include <unordered_map> +#include <tuple> +#include <vector> +#include "pattern_util.h" +#include "infer_layout_util.h" + +namespace tvm { +namespace relay { + +/*! + * \brief Memorizes layout transformations so that identical ones are reused. + */ +class TransformMemorizerNode : public Node { + public: + /*! \brief The key for the memorizer map is (Expr node, src_layout name, dst_layout name). */ + using TransformKey = std::tuple<const Node*, std::string, std::string>; + + struct key_hash : public std::function<std::size_t(TransformKey)> { + std::size_t operator()(const TransformKey& k) const { + return dmlc::HashCombine<std::string>( + dmlc::HashCombine<std::string>(std::hash<const Node*>()(std::get<0>(k)), std::get<1>(k)), + (std::get<2>(k))); + } + }; + + /*!
\brief The memorizer map from TransformKey to the transformed expr. */ + std::unordered_map<TransformKey, Expr, key_hash> memo; + + static constexpr const char* _type_key = "relay.alter_op_layout.TransformMemorizerNode"; + TVM_DECLARE_NODE_TYPE_INFO(TransformMemorizerNode, Node); +}; + +/*! + * \brief Container that transforms the layouts and memorizes them. + */ +class TransformMemorizer : public NodeRef { + public: + TransformMemorizer() {} + explicit TransformMemorizer(ObjectPtr<Object> n) : NodeRef(n) {} + + TransformMemorizerNode* operator->() { + return static_cast<TransformMemorizerNode*>(get_mutable()); + } + + /*! + * \brief Transforms raw from src_layout to dst_layout, reusing a memorized result if one exists. + * \param raw The initial expr. + * \param src_layout The source layout. + * \param dst_layout The dest layout. + * \return The new expr with the dst layout. + */ + Expr Transform(Expr raw, const Layout& src_layout, const Layout& dst_layout) { + if (src_layout.Equals(dst_layout)) { + return raw; + } + + std::tuple<const Node*, std::string, std::string> key = + std::make_tuple<>(raw.get(), src_layout.name(), dst_layout.name()); + auto& memo = operator->()->memo; + + auto iter = memo.find(key); + if (iter != memo.end()) { + return iter->second; + } else { + Expr transform = TransformHelper(raw, src_layout, dst_layout); + memo[key] = transform; + return transform; + } + } + + /*! + * \brief Helper to transform the layouts; it does not consult or update the memo. + * \param raw The initial expr. + * \param src_layout The source layout. + * \param dst_layout The dest layout. + * \return The new expr with the dst layout. + * \note It performs the following 2 operations + * 1) If the primal ndim of src_layout is smaller than that of dst_layout, expand_dims is inserted to match the dim + * size. For example, src_layout = C, dst_layout = NCHW16c. The src is expanded to NHWC. + * 2) Call layout transform with new src layout. + */ + Expr TransformHelper(Expr raw, Layout src_layout, Layout dst_layout) { + if (src_layout.Equals(dst_layout)) { + return raw; + } + + // 1) Check if the shape lengths are different.
If yes, expand dims. + Expr input_expr = raw; + Layout new_src_layout = src_layout; + if (src_layout.ndim_primal() < dst_layout.ndim_primal()) { + int num_new_axis = dst_layout.ndim_primal() - src_layout.ndim_primal(); + new_src_layout = src_layout.ExpandPrimal(dst_layout); + input_expr = MakeExpandDims(input_expr, 0, num_new_axis); + if (new_src_layout.Equals(dst_layout)) { + return input_expr; + } + } + + // 2) Insert layout transform on the transformed src. + CHECK(new_src_layout.defined() && dst_layout.defined()) + << "Cannot insert layout transform because there are undefined layouts"; + CHECK(BijectiveLayoutNode::make(new_src_layout, dst_layout).defined()) + << "Cannot insert layout transform because there are inconvertible layouts: " + << new_src_layout << " v.s. " << dst_layout; + return MakeLayoutTransform(input_expr, new_src_layout.name(), dst_layout.name()); + } + + /*! + * \brief Defines the call transformation for derived passes. The new layouts are chosen by + * the derived pass (for example, per target through a packed func). + * \param ref_call The original call. + * \param new_args The traversed/recursed args to the call. + * \return The new Call with the new layouts. + */ + virtual Call CallWithNewLayouts(const Call& ref_call, const std::vector<Expr>& new_args) = 0; + using ContainerType = TransformMemorizerNode; +}; + +/*! + * \brief TempExprNode during layout transform. Instances of this expr are ultimately Realized to a + * normal expr. + * \tparam TransformMemorizerT The derived TransformMemorizer type.
+ */ +template <class TransformMemorizerT> +class LayoutAlternatedExprNode : public TempExprNode { + public: + Expr value; + Layout old_layout; + Layout new_layout; + TransformMemorizerT memorizer; + + Expr Realize() const final { + // NOTE: use a copy to discard the "const" qualifier + TransformMemorizerT tmp_memorizer = memorizer; + // Fall back to the old layout: convert value from new_layout to old_layout. + return tmp_memorizer.Transform(value, new_layout, old_layout); + } + + void VisitAttrs(AttrVisitor* v) { + v->Visit("value", &value); + v->Visit("old_layout", &old_layout); + v->Visit("new_layout", &new_layout); + } + + static constexpr const char* _type_key = "relay.alter_op_layout.LayoutAlternatedExprNode"; + TVM_DECLARE_NODE_TYPE_INFO(LayoutAlternatedExprNode, TempExprNode); +}; + +/*! + * \brief Container for the layout alternated expr. + * \tparam TransformMemorizerT The derived TransformMemorizer type. + */ +template <class TransformMemorizerT> +class LayoutAlternatedExpr : public NodeRef { + public: + LayoutAlternatedExpr() {} + explicit LayoutAlternatedExpr(ObjectPtr<Object> n) : NodeRef(n) {} + + LayoutAlternatedExprNode<TransformMemorizerT>* operator->() { + return static_cast<LayoutAlternatedExprNode<TransformMemorizerT>*>(get_mutable()); + } + + using ContainerType = LayoutAlternatedExprNode<TransformMemorizerT>; +}; + +/*! + * \brief Used with ForwardRewrite to transform the expr. The input args are the same as + * FForwardRewrite. + * \param ref_call The reference old call type to be rewritten. + * We can make use of the op and type information. + * \param new_args The new arguments (some of them could be TempExpr). + * \param ctx The TransformMemorizerT handle; it is Downcast unconditionally, so it must be set. + * \tparam TransformMemorizerT The derived TransformMemorizer type. + * \return The rewritten result call; can also return nullptr, + * which indicates the rewriter should use the default fallback + * rule that realizes all its inputs and composes the call.
+ * + * \note The ctx can be used to provide extra information during transformation. The ctx is + * templated to reuse across AlterOpLayout and ConvertLayout passes. The steps are + * - Extract the original layouts. + * - Use ctx transformation to get a Call with new layouts - CallWithNewLayouts. + * - Extract the new layouts from the returned Call. + * - Transform the original call to reuse the new layouts using TransformMemorizer. + */ +template <class TransformMemorizerT> +Expr LayoutRewriter(const Call& ref_call, const Array<Expr>& new_args, const NodeRef& ctx) { + std::vector<LayoutAlternatedExpr<TransformMemorizerT>> inputs; + std::vector<Expr> normal_new_args; + Array<Array<IndexExpr>> input_shapes; + + // NOTE: Downcast the const ctx to a (non-const) handle of the derived memorizer type. + // Copies of this handle (the lambda capture below and the memorizer stored in each + // new LayoutAlternatedExprNode) refer to the same TransformMemorizerNode, so they + // all share one memo table. + TransformMemorizerT memorizer = Downcast<TransformMemorizerT>(ctx); + + // fill incomplete state and flatten tuple + auto push_back_one_arg = [&inputs, memorizer](Expr arg) { + // We always expect LayoutAlternatedExpr<TransformMemorizerT>. + // This is used to convert the normal Expr to LayoutAlternatedExpr<TransformMemorizerT>.
+ if (const LayoutAlternatedExprNode<TransformMemorizerT>* inp = + arg.as<LayoutAlternatedExprNode<TransformMemorizerT>>()) { + inputs.push_back(GetRef<LayoutAlternatedExpr<TransformMemorizerT>>(inp)); + return inp->value; + } else { + auto inode = make_node<LayoutAlternatedExprNode<TransformMemorizerT>>(); + inode->value = arg; + inode->memorizer = memorizer; + inputs.push_back(LayoutAlternatedExpr<TransformMemorizerT>(inode)); + return arg; + } + }; + + for (auto new_arg : new_args) { + // NOTE: nested tuples are not supported; only one tuple level is flattened + if (new_arg->IsInstance<TupleNode>()) { + Tuple tuple_new_arg = Downcast<Tuple>(new_arg); + std::vector<Expr> fields; + for (auto x : tuple_new_arg->fields) { + Expr tmp = push_back_one_arg(x); + fields.push_back(tmp); + } + normal_new_args.push_back(TupleNode::make(fields)); + } else { + Expr tmp = push_back_one_arg(new_arg); + normal_new_args.push_back(tmp); + } + } + + // old_in, new_in = state[inputs] + Array<Layout> old_in, old_out, new_in, new_out, new_in2; + for (auto inp : inputs) { + old_in.push_back(inp->old_layout); + new_in.push_back(inp->new_layout); + } + + for (auto arg : ref_call->args) { + if (arg->IsInstance<TupleNode>()) { // flatten tuple + Tuple tuple_arg = Downcast<Tuple>(arg); + for (auto x : tuple_arg->fields) { + input_shapes.push_back(x->type_as<TensorTypeNode>()->shape); + } + } else { + input_shapes.push_back(arg->type_as<TensorTypeNode>()->shape); + } + } + + // old_in, old_out = op.infer(old_in) + bool success = false; + std::tie(old_in, old_out, success) = + InferCorrectLayouts(ref_call, Array<Layout>(nullptr), old_in, input_shapes); + if (!success) { + return Expr(nullptr); + } + CHECK_EQ(old_in.size(), new_in.size()); + + // if new_in == 'undef': new_in = old_in + for (size_t i = 0; i < new_in.size(); ++i) { + if (!new_in[i].defined()) { + new_in.Set(i, old_in[i]); + } + } + + // new_op = alter(op), delegated to the derived memorizer + Call new_call = memorizer.CallWithNewLayouts(ref_call, normal_new_args); + + // new_in2, new_out =
op.infer(new_in) + if (new_call->op->IsInstance<OpNode>()) { + success = false; + std::tie(new_in2, new_out, success) = + InferCorrectLayouts(new_call, new_in, old_in, input_shapes); + if (!success) { + return Expr(nullptr); + } + } else { + return Expr(nullptr); + } + + CHECK_EQ(new_out.size(), old_out.size()) + << "The number of output nodes should keep the same during alter_op_layout"; + CHECK_EQ(new_in.size(), new_in2.size()) + << "The number of input nodes should keep the same during alter_op_layout"; + + // if (new_in != new_in2): insert transform (new_in -> new_in2) + Array<Expr> transformed_args; + size_t pt = 0; + for (auto arg : new_call->args) { + if (arg->IsInstance<TupleNode>()) { // unflatten tuple + Tuple tuple_arg = Downcast<Tuple>(arg); + std::vector<Expr> transformed_tuple_arg; + for (auto arg_item : tuple_arg->fields) { + transformed_tuple_arg.push_back(memorizer.Transform(arg_item, new_in[pt], new_in2[pt])); + pt++; + } + transformed_args.push_back(TupleNode::make(transformed_tuple_arg)); + } else { + transformed_args.push_back(memorizer.Transform(arg, new_in[pt], new_in2[pt])); + pt++; + } + } + CHECK_EQ(pt, inputs.size()); + + // state[node] = (old_out, new_out) + // (handle tuple output) + if (ref_call->checked_type()->IsInstance<TupleTypeNode>()) { + Expr tuple_output = CallNode::make(new_call->op, transformed_args, new_call->attrs); + Array<Expr> fields; + for (size_t i = 0; i < new_out.size(); ++i) { + auto rnode = make_node<LayoutAlternatedExprNode<TransformMemorizerT>>(); + rnode->value = TupleGetItemNode::make(tuple_output, i); + rnode->old_layout = old_out[i]; + rnode->new_layout = new_out[i]; + rnode->memorizer = memorizer; + fields.push_back(Expr(rnode)); + } + return TupleNode::make(fields); + } else { + auto rnode = make_node<LayoutAlternatedExprNode<TransformMemorizerT>>(); + CHECK_EQ(new_out.size(), 1); + rnode->value = CallNode::make(new_call->op, transformed_args, new_call->attrs); + rnode->old_layout = old_out[0]; +
rnode->new_layout = new_out[0]; + rnode->memorizer = memorizer; + return Expr(rnode); + } +} + +} // namespace relay +} // namespace tvm + +#endif // TVM_RELAY_PASS_TRANSFORM_LAYOUT_H_ diff --git a/tests/python/relay/test_pass_convert_op_layout.py b/tests/python/relay/test_pass_convert_op_layout.py new file mode 100644 index 0000000..9544525 --- /dev/null +++ b/tests/python/relay/test_pass_convert_op_layout.py @@ -0,0 +1,360 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License.
+"""Test convert op layout pass""" +import tvm + +from tvm import relay +from tvm.relay.op import register_alter_op_layout +from tvm.relay import transform, analysis + + +def run_opt_pass(expr, passes): + passes = passes if isinstance(passes, list) else [passes] + mod = relay.Module.from_expr(expr) + seq = transform.Sequential(passes) + with transform.PassContext(opt_level=3): + mod = seq(mod) + entry = mod["main"] + return entry if isinstance(expr, relay.Function) else entry.body + + +def test_no_convert_layout(): + def before(): + x = relay.var("x", shape=(1, 64, 56, 56)) + weight = relay.var('weight', shape=(64, 64, 3, 3)) + y = relay.nn.conv2d(x, weight, + channels=64, + kernel_size=(3, 3), + padding=(1, 1)) + y = relay.nn.relu(y) + y = relay.Function([x, weight], y) + return y + + def expected(): + return before() + + a = before() + a = run_opt_pass(a, transform.ConvertLayout('NCHW')) + b = run_opt_pass(expected(), transform.InferType()) + + assert analysis.alpha_equal(a, b), "Actual = \n" + str(a) + + +def test_conv_convert_layout(): + def before(): + x = relay.var("x", shape=(1, 56, 56, 64)) + weight = relay.var('weight', shape=(3, 3, 64, 64)) + y = relay.nn.conv2d(x, weight, + channels=64, + kernel_size=(3, 3), + padding=(1, 1), + data_layout='NHWC', + kernel_layout='HWIO') + y = relay.nn.relu(y) + y = relay.Function([x, weight], y) + return y + + def expected(): + x = relay.var("x", shape=(1, 56, 56, 64)) + weight = relay.var('weight', shape=(3, 3, 64, 64)) + x = relay.layout_transform(x, 'NHWC', 'NCHW') + weight = relay.layout_transform(weight, 'HWIO', 'OIHW') + y = relay.nn.conv2d(x, weight, + channels=64, + kernel_size=(3, 3), + padding=(1, 1)) + y = relay.nn.relu(y) + y = relay.layout_transform(y, 'NCHW', 'NHWC') + y = relay.Function(relay.analysis.free_vars(y), y) + return y + + a = before() + a = run_opt_pass(a, transform.ConvertLayout('NCHW')) + b = run_opt_pass(expected(), transform.InferType()) + + assert analysis.alpha_equal(a, b), "Actual = \n" +
str(a) + + +def test_conv_bias_pool_convert_layout(): + def before(): + x = relay.var("x", shape=(1, 56, 56, 64)) + bias = relay.var("bias", shape=(64,)) + weight = relay.var("weight", shape=(3, 3, 64, 64)) + y = relay.nn.conv2d(x, weight, channels=64, kernel_size=(3, 3), padding=(1, 1), + data_layout='NHWC', kernel_layout='HWIO') + y = relay.nn.bias_add(y, bias, axis=3) + # a useless tuple, which will be eliminated + y = relay.Tuple([y])[0] + y = relay.nn.relu(y) + y = relay.nn.max_pool2d(y, pool_size=(2, 2), layout='NHWC') + y = relay.cast(y, 'int32') + y = relay.nn.batch_flatten(y) + y = relay.Function(analysis.free_vars(y), y) + return y + + def expected(): + x = relay.var("x", shape=(1, 56, 56, 64)) + bias = relay.var("bias", shape=(64,)) + weight = relay.var("weight", shape=(3, 3, 64, 64)) + x = relay.layout_transform(x, 'NHWC', 'NCHW') + weight = relay.layout_transform(weight, 'HWIO', 'OIHW') + y = relay.nn.conv2d(x, weight, channels=64, kernel_size=(3, 3), padding=(1, 1)) + + bias = relay.expand_dims(bias, axis=0, num_newaxis=3) + bias = relay.layout_transform(bias, 'NHWC', 'NCHW') + y = relay.add(y, bias) + # a useless tuple, which will be eliminated + y = relay.Tuple([y])[0] + y = relay.nn.relu(y) + y = relay.nn.max_pool2d(y, pool_size=(2, 2)) + y = relay.cast(y, 'int32') + y = relay.layout_transform(y, 'NCHW', 'NHWC') + y = relay.nn.batch_flatten(y) + y = relay.Function(analysis.free_vars(y), y) + return y + + a = before() + a = run_opt_pass(a, transform.ConvertLayout('NCHW')) + b = run_opt_pass(expected(), transform.InferType()) + + assert analysis.alpha_equal(a, b), "Actual = \n" + str(a) + + +def test_conv_concat_convert_layout(): + def before(): + x = relay.var("x", shape=(1, 56, 56, 64)) + weight1 = relay.var('weight1', shape=(3, 3, 64, 64)) + weight2 = relay.var('weight2', shape=(3, 3, 64, 64)) + y = relay.nn.conv2d(x, weight1, + channels=64, + kernel_size=(3, 3), + padding=(1, 1), + data_layout='NHWC', + kernel_layout='HWIO') + y1 =
relay.nn.conv2d(y, weight2, + channels=64, + kernel_size=(3, 3), + padding=(1, 1), + data_layout='NHWC', + kernel_layout='HWIO') + ret = relay.concatenate([y, y1], axis=3) + y = relay.Function(analysis.free_vars(ret), ret) + return y + + def expected(): + x = relay.var("x", shape=(1, 56, 56, 64)) + weight1 = relay.var('weight1', shape=(3, 3, 64, 64)) + weight2 = relay.var('weight2', shape=(3, 3, 64, 64)) + weight1 = relay.layout_transform(weight1, 'HWIO', 'OIHW') + weight2 = relay.layout_transform(weight2, 'HWIO', 'OIHW') + y = relay.layout_transform(x, "NHWC", "NCHW") + y = relay.nn.conv2d(y, weight1, + channels=64, + kernel_size=(3, 3), + padding=(1, 1)) + y1 = relay.nn.conv2d(y, weight2, + channels=64, + kernel_size=(3, 3), + padding=(1, 1)) + ret = relay.concatenate([y, y1], axis=1) + ret = relay.layout_transform(ret, "NCHW", "NHWC") + y = relay.Function(analysis.free_vars(ret), ret) + return y + + a = before() + a = run_opt_pass(a, transform.ConvertLayout('NCHW')) + b = run_opt_pass(expected(), transform.InferType()) + + assert analysis.alpha_equal(a, b), "Actual = \n" + str(a) + + +def test_dual_path_convert_layout(): + def before(): + x = relay.var("x", shape=(1, 56, 56, 64)) + weight1 = relay.var('weight1', shape=(3, 3, 64, 32)) + weight2 = relay.var('weight2', shape=(3, 3, 32, 32)) + y = relay.nn.conv2d(x, weight1, + channels=32, + kernel_size=(3, 3), + padding=(1, 1), + data_layout='NHWC', + kernel_layout='HWIO') + y = relay.nn.relu(y) + y1 = relay.nn.conv2d(y, weight2, + channels=32, + kernel_size=(3, 3), + padding=(1, 1), + data_layout='NHWC', + kernel_layout='HWIO') + y1 = relay.nn.relu(y1) + y2 = relay.nn.batch_flatten(y) + ret = relay.Tuple([y1, y2]) + y = relay.Function(analysis.free_vars(ret), ret) + return y + + def expected(): + x = relay.var("x", shape=(1, 56, 56, 64)) + weight1 = relay.var('weight1', shape=(3, 3, 64, 32)) + weight2 = relay.var('weight2', shape=(3, 3, 32, 32)) + weight1 = relay.layout_transform(weight1, 'HWIO', 'OIHW') + weight2
= relay.layout_transform(weight2, 'HWIO', 'OIHW') + y = relay.layout_transform(x, "NHWC", "NCHW") + y = relay.nn.conv2d(y, weight1, + channels=32, + kernel_size=(3, 3), + padding=(1, 1)) + y = relay.nn.relu(y) + y1 = relay.nn.conv2d(y, weight2, + channels=32, + kernel_size=(3, 3), + padding=(1, 1)) + y1 = relay.nn.relu(y1) + y1 = relay.layout_transform(y1, "NCHW", "NHWC") + y2 = relay.layout_transform(y, "NCHW", "NHWC") + y2 = relay.nn.batch_flatten(y2) + ret = relay.Tuple([y1, y2]) + y = relay.Function(analysis.free_vars(ret), ret) + return y + + a = before() + a = run_opt_pass(a, transform.ConvertLayout('NCHW')) + b = run_opt_pass(expected(), transform.InferType()) + + assert analysis.alpha_equal(a, b), "Actual = \n" + str(a) + + +def test_bn_convert_layout(): + def before(): + x = relay.var("x", shape=(1, 56, 56, 64)) + weight1 = relay.var('weight1', shape=(3, 3, 64, 32)) + y = relay.nn.conv2d(x, weight1, + channels=32, + kernel_size=(3, 3), + padding=(1, 1), + data_layout='NHWC', + kernel_layout='HWIO') + gamma = relay.var("gamma") + beta = relay.var("beta") + mean = relay.var("mean") + variance = relay.var("variance") + y, _, _ = relay.nn.batch_norm(y , gamma, beta, mean, variance, axis=3) + return relay.Function(analysis.free_vars(y), y) + + a = before() + a = run_opt_pass(a, transform.ConvertLayout('NCHW')) + + # Check that there is only 1 NCHW to NHWC transform.
+ has_lt = list() + find_op = lambda x : \ + has_lt.append(isinstance(x, tvm.relay.expr.Call) and x.op.name == "layout_transform" \ + and x.attrs.src_layout == 'NCHW' and x.attrs.dst_layout == 'NHWC') + relay.analysis.post_order_visit(a, find_op) + has_lt = list(filter(lambda x: x, has_lt)) + assert len(has_lt) == 1 + + +def test_resnet_convert_layout(): + def before(): + x = relay.var("x", shape=(1, 56, 56, 64)) + weight1 = relay.var('weight1', shape=(3, 3, 64, 32)) + weight2 = relay.var('weight2', shape=(1, 1, 64, 32)) + y = relay.nn.conv2d(x, weight1, + channels=32, + kernel_size=(3, 3), + padding=(1, 1), + data_layout='NHWC', + kernel_layout='HWIO') + y = relay.nn.relu(y) + y2 = relay.nn.conv2d(x, weight2, + channels=32, + kernel_size=(1, 1), + data_layout='NHWC', + kernel_layout='HWIO') + y2 = relay.nn.relu(y2) + y = y + y2 + y = relay.nn.global_max_pool2d(y, layout='NHWC') + return relay.Function(analysis.free_vars(y), y) + + def expected(): + x = relay.var("x", shape=(1,56, 56, 64)) + weight1 = relay.var('weight1', shape=(3, 3, 64, 32)) + weight2 = relay.var('weight2', shape=(1, 1, 64, 32)) + weight1 = relay.layout_transform(weight1, 'HWIO', 'OIHW') + weight2 = relay.layout_transform(weight2, 'HWIO', 'OIHW') + x = relay.layout_transform(x, "NHWC", "NCHW") + y = relay.nn.conv2d(x, weight1, + channels=32, + kernel_size=(3, 3), + padding=(1, 1)) + y = relay.nn.relu(y) + y2 = relay.nn.conv2d(x, weight2, + channels=32, + kernel_size=(1, 1)) + y2 = relay.nn.relu(y2) + y = y + y2 + y = relay.nn.global_max_pool2d(y) + y = relay.layout_transform(y, "NCHW", "NHWC") + return relay.Function(analysis.free_vars(y), y) + + a = before() + a = run_opt_pass(a, transform.ConvertLayout('NCHW')) + b = run_opt_pass(expected(), transform.InferType()) + + assert analysis.alpha_equal(a, b), "Actual = \n" + str(a) + + +def test_scalar_convert_layout(): + def before(): + x = relay.var("x", shape=(1, 56, 56, 64)) + weight = relay.var("weight", shape=(3, 3, 64, 64)) + y =
relay.nn.conv2d(x, weight, channels=64, kernel_size=(3, 3), padding=(1, 1), + data_layout='NHWC', kernel_layout='HWIO') + y = relay.add(y, relay.const(1, "float32")) + y = relay.Function(analysis.free_vars(y), y) + return y + + def expected(): + x = relay.var("x", shape=(1, 56, 56, 64)) + w = relay.var("weight", shape=(3, 3, 64, 64)) + x = relay.layout_transform(x, 'NHWC', 'NCHW') + w = relay.layout_transform(w, 'HWIO', 'OIHW') + y = relay.nn.conv2d(x, w, + channels=64, + kernel_size=(3, 3), + padding=(1, 1)) + y = relay.add(y, relay.const(1.0, "float32")) + + y = relay.layout_transform(y, "NCHW", "NHWC") + y = relay.Function(analysis.free_vars(y), y) + return y + + a = before() + a = run_opt_pass(a, transform.ConvertLayout('NCHW')) + b = run_opt_pass(expected(), transform.InferType()) + + assert analysis.alpha_equal(a, b), "Actual = \n" + str(a) + + +if __name__ == "__main__": + test_no_convert_layout() + test_conv_convert_layout() + test_conv_bias_pool_convert_layout() + test_conv_concat_convert_layout() + test_dual_path_convert_layout() + test_bn_convert_layout() + test_resnet_convert_layout() + test_scalar_convert_layout()