This is an automated email from the ASF dual-hosted git repository. mbrookhart pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push: new aefa0c8 [Relay][dismantler] Added handling of packed func (#8004) aefa0c8 is described below commit aefa0c85e46fc5ed15e71805f52bf7be6e238e33 Author: Dmitriy Smirnov <dmitriy.smir...@arm.com> AuthorDate: Tue May 25 18:47:20 2021 +0100 [Relay][dismantler] Added handling of packed func (#8004) Added handling of CallNode objects created via packed functions invocation + test cases. Change-Id: I5374abc59a3b0f79f27364c45f1a5789536df940 --- include/tvm/relay/expr.h | 6 +++ src/relay/ir/expr.cc | 34 ++++++++++--- tests/cpp/relay_dismantler_test.cc | 77 +++++++++++++++++++++++++++++- tests/python/relay/test_ir_text_printer.py | 12 +++++ 4 files changed, 121 insertions(+), 8 deletions(-) diff --git a/include/tvm/relay/expr.h b/include/tvm/relay/expr.h index 17718d1..daad851 100644 --- a/include/tvm/relay/expr.h +++ b/include/tvm/relay/expr.h @@ -227,6 +227,11 @@ class Var : public Expr { class Call; /*! \brief Call container. */ class CallNode : public ExprNode { + protected: + // CallNode uses own deleter to indirectly call non-recursive destructor + Object::FDeleter saved_deleter_; + static void Deleter_(Object* ptr); + public: /*! 
* \brief The operator(function) being invoked @@ -290,6 +295,7 @@ class CallNode : public ExprNode { static constexpr const char* _type_key = "relay.Call"; TVM_DECLARE_FINAL_OBJECT_INFO(CallNode, ExprNode); + friend class Call; }; class Call : public Expr { diff --git a/src/relay/ir/expr.cc b/src/relay/ir/expr.cc index 62ff0b1..3b3c879 100644 --- a/src/relay/ir/expr.cc +++ b/src/relay/ir/expr.cc @@ -115,6 +115,8 @@ Call::Call(Expr op, Array<Expr> args, Attrs attrs, Array<Type> type_args, Span s n->attrs = std::move(attrs); n->type_args = std::move(type_args); n->span = std::move(span); + n->saved_deleter_ = n->deleter_; + n->deleter_ = CallNode::Deleter_; data_ = std::move(n); } @@ -288,16 +290,24 @@ inline void Dismantle(const Expr& expr) { // special handling if (const CallNode* op = node.as<CallNode>()) { - for (auto it = op->args.rbegin(); it != op->args.rend(); ++it) { - fpush_to_stack(*it); + // do not process args if used elsewhere + if (op->args.use_count() < 2) { + for (auto it = op->args.rbegin(); it != op->args.rend(); ++it) { + fpush_to_stack(*it); + } } - fpush_to_stack(op->op); } else if (const TupleNode* op = node.as<TupleNode>()) { - for (auto it = op->fields.rbegin(); it != op->fields.rend(); ++it) { - fpush_to_stack(*it); + // do not process fields if used elsewhere + if (op->fields.use_count() < 2) { + for (auto it = op->fields.rbegin(); it != op->fields.rend(); ++it) { + fpush_to_stack(*it); + } } } else if (const TupleGetItemNode* op = node.as<TupleGetItemNode>()) { - fpush_to_stack(op->tuple); + // do not process tuple if used elsewhere + if (op->tuple.use_count() < 2) { + fpush_to_stack(op->tuple); + } } } } @@ -306,7 +316,6 @@ inline void Dismantle(const Expr& expr) { /* * Non-recursive destructor */ - Call::~Call() { // attempt to dismantle if referenced one or zero times if (this->use_count() < 2) { @@ -316,5 +325,16 @@ Call::~Call() { } } +/* + * CallNode's deleter + */ +void CallNode::Deleter_(Object* ptr) { + auto p = 
reinterpret_cast<CallNode*>(ptr); + // restore original deleter + p->deleter_ = p->saved_deleter_; + // create Call reference in order to invoke ~Call + auto c = GetRef<Call>(p); +} + } // namespace relay } // namespace tvm diff --git a/tests/cpp/relay_dismantler_test.cc b/tests/cpp/relay_dismantler_test.cc index d5c089b..8c74d41 100644 --- a/tests/cpp/relay_dismantler_test.cc +++ b/tests/cpp/relay_dismantler_test.cc @@ -16,7 +16,6 @@ * specific language governing permissions and limitations * under the License. */ - #include <gtest/gtest.h> #include <tvm/ir/expr.h> #include <tvm/ir/type_functor.h> @@ -38,6 +37,8 @@ #include <tvm/topi/broadcast.h> #include <tvm/topi/generic/injective.h> +#include <memory> + using namespace tvm; using namespace tvm::relay; @@ -69,6 +70,80 @@ TEST(Relay, OutOfStack_cast) { ASSERT_EXIT((foo(), exit(0)), ::testing::ExitedWithCode(0), ".*"); } +TEST(Relay, OutOfStack_packed_func) { + constexpr int len = 1e6; + auto foo = [] { + auto x = relay::Var("x", relay::TensorType({3, 2}, DataType::Float(32))); + auto one = relay::Constant(tvm::runtime::NDArray::Empty({1}, {kDLFloat, 32, 1}, {kDLCPU, 0})); + auto add_func = tvm::runtime::Registry::Get("relay.op._make.add"); + auto y = (*add_func)(x, one); + for (int i = 0; i < len; ++i) { + y = (*add_func)(y, one); + } + + // check if still reachable + int k = 0; + Expr e = y; + while (e.defined() && e.as<CallNode>() != nullptr) { + e = e.as<CallNode>()->args[0]; + ++k; + } + ASSERT_EQ(len + 1, k); + }; + ASSERT_EXIT((foo(), exit(0)), ::testing::ExitedWithCode(0), ".*"); +} + +TEST(Relay, CallNodeSharedArgs) { + auto x = relay::Var("x", relay::TensorType({3, 2}, DataType::Float(32))); + auto one = relay::Constant(tvm::runtime::NDArray::Empty({1}, {kDLFloat, 32, 1}, {kDLCPU, 0})); + auto relu_op = relay::Op::Get("nn.relu"); + Call y = relay::Call(relu_op, {x}, Attrs(), {}); + y = relay::Call(relu_op, {y}, Attrs(), {}); + ASSERT_EQ(1, y.get()->args[0].as<CallNode>()->args.size()); + y = 
relay::Call(y.get()->op, y.get()->args, y.get()->attrs, y.get()->type_args); + ASSERT_EQ(1, y.get()->args[0].as<CallNode>()->args.size()); +} + +TEST(Relay, TupleSharedFields) { + auto x = relay::Var("x", relay::TensorType({3, 2}, DataType::Float(32))); + auto one = relay::Constant(tvm::runtime::NDArray::Empty({1}, {kDLFloat, 32, 1}, {kDLCPU, 0})); + auto relu_op = relay::Op::Get("nn.relu"); + Expr y = relay::Call(relu_op, {x}, Attrs(), {}); + y = relay::Call(relu_op, {y}, Attrs(), {}); + { + Expr y1 = relay::Tuple(y.as<CallNode>()->args); + Expr y2 = relay::Tuple(y.as<CallNode>()->args); + + y1 = relay::Call(relu_op, {y1}); + y2 = relay::Call(relu_op, {y2}); + y = y1; + } + ASSERT_EQ(1, y.as<CallNode>()->args[0].as<TupleNode>()->fields[0].as<CallNode>()->args.size()); +} + +TEST(Relay, TupleiGetItemSharedTuple) { + auto x = relay::Var("x", relay::TensorType({3, 2}, DataType::Float(32))); + auto one = relay::Constant(tvm::runtime::NDArray::Empty({1}, {kDLFloat, 32, 1}, {kDLCPU, 0})); + auto relu_op = relay::Op::Get("nn.relu"); + Expr y = relay::Call(relu_op, {x}, Attrs(), {}); + y = relay::Tuple({y}); + { + Expr y1 = relay::TupleGetItem(y, 0); + Expr y2 = relay::TupleGetItem(y, 0); + + y1 = relay::Call(relu_op, {y1}); + y2 = relay::Call(relu_op, {y2}); + y = y1; + } + ASSERT_EQ(1, y.as<CallNode>() + ->args[0] + .as<TupleGetItemNode>() + ->tuple.as<TupleNode>() + ->fields[0] + .as<CallNode>() + ->args.size()); +} + int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); testing::FLAGS_gtest_death_test_style = "threadsafe"; diff --git a/tests/python/relay/test_ir_text_printer.py b/tests/python/relay/test_ir_text_printer.py index b2ae286..4968660 100644 --- a/tests/python/relay/test_ir_text_printer.py +++ b/tests/python/relay/test_ir_text_printer.py @@ -30,6 +30,7 @@ SEMVER = '#[version = "0.0.5"]\n' def astext(program, unify_free_vars=False): text = program.astext() + print(text) if isinstance(program, Expr): roundtrip_program = 
tvm.parser.parse_expr(text) @@ -47,6 +48,17 @@ def show(text): print(text) +def test_large_graph(): + x = relay.var("x", shape=(3, 2)) + y = relay.var("y") + one = relay.const(10e10, dtype="float32") + z = relay.add(x, one) + for i in range(int(1e6)): + z = relay.add(z, one) + f = relay.Function([x, y], z) + show(astext(f)) + + def test_func(): x = relay.var("x", shape=(3, 2)) y = relay.var("y")