This is an automated email from the ASF dual-hosted git repository.

jroesch pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new e35b7fc  [Relay][Training] Make AutoDiff thread through global function. (#6336)
e35b7fc is described below

commit e35b7fc4bcdcfe008c5dfea60c2297b93dbff99e
Author: 雾雨魔理沙 <lol...@marisa.moe>
AuthorDate: Thu Aug 27 11:32:40 2020 -0700

    [Relay][Training] Make AutoDiff thread through global function. (#6336)
    
    * save
    
    * lint
    
    * lint
    
    * fix warning
    
    * fix test
    
    * save
---
 src/printer/doc.cc                       |   2 +-
 src/relay/transforms/gradient.cc         | 106 ++++++++++++++++++++++++-------
 tests/python/relay/test_pass_gradient.py |  41 +++++++++++-
 3 files changed, 124 insertions(+), 25 deletions(-)

diff --git a/src/printer/doc.cc b/src/printer/doc.cc
index d487e3e..ab1eddb 100644
--- a/src/printer/doc.cc
+++ b/src/printer/doc.cc
@@ -129,7 +129,7 @@ Doc Doc::Indent(int indent, Doc doc) {
 }
 
 Doc Doc::StrLiteral(const std::string& value, std::string quote) {
-  // TODO(M.K.): add escape.
+  // TODO(@M.K.): add escape.
   Doc doc;
   return doc << quote << value << quote;
 }
diff --git a/src/relay/transforms/gradient.cc b/src/relay/transforms/gradient.cc
index 7894c34..9c47254 100644
--- a/src/relay/transforms/gradient.cc
+++ b/src/relay/transforms/gradient.cc
@@ -72,7 +72,7 @@ Type WithGradientType(const Type&);
 Expr FirstOrderGradient(const Expr& e, const Optional<IRModule>& mod);
 
 Type WithGradientType(const Type& t) {
-  // TODO(M.K.): stricter checking
+  // TODO(@M.K.): stricter checking
   auto ty = t.as<FuncTypeNode>();
   CHECK(ty) << "input should be a function";
   return FuncType(ty->arg_types, TupleType({ty->ret_type, 
TupleType(ty->arg_types)}), {}, {});
@@ -85,7 +85,7 @@ Expr DeGlobal(const Optional<IRModule>& mod, const Expr& e) {
   if (mod.defined() && x) {
     BaseFunc base_func = mod.value()->Lookup(GetRef<GlobalVar>(x));
     if (auto* n = base_func.as<FunctionNode>()) {
-      return n->body;
+      return GetRef<Function>(n);
     } else {
       return e;
     }
@@ -338,11 +338,22 @@ Expr FirstOrderGradient(const Expr& re, const Optional<IRModule>& mod) {
 
 
TVM_REGISTER_GLOBAL("relay._transform.first_order_gradient").set_body_typed(FirstOrderGradient);
 
+Type bpt = RelayRefType(FuncType({}, TupleType(Array<Type>()), {}, {}));
+
 struct ReverseADType : TypeMutator {
   Type VisitType_(const TensorTypeNode* ttn) final {
     Type t = GetRef<Type>(ttn);
     return TupleType({t, RelayRefType(t)});
   }
+
+  Type VisitType_(const FuncTypeNode* ftn) final {
+    std::vector<Type> arg_types;
+    for (const auto& t : ftn->arg_types) {
+      arg_types.push_back(VisitType(t));
+    }
+    arg_types.push_back(bpt);
+    return FuncType(arg_types, ftn->ret_type, ftn->type_params, 
ftn->type_constraints);
+  }
 };
 
 Type ReverseType(const Type& t) { return ReverseADType()(t); }
@@ -438,12 +449,18 @@ Expr BPEmpty() {
 
 struct ReverseAD : ExprMutator {
   using ADVarMap = std::unordered_map<Var, Var, ObjectPtrHash, ObjectPtrEqual>;
-
+  using ADGlobalVarMap = std::unordered_map<GlobalVar, GlobalVar, 
ObjectPtrHash, ObjectPtrEqual>;
+  Optional<IRModule> mod;
+  // TODO(@M.K.) refactor AD to always use mod.
   Var bp;
   std::shared_ptr<ADVarMap> ad_vars;
+  std::shared_ptr<ADGlobalVarMap> ad_gvars;
   const OpAttrMap<FPrimalGradient> rev_map = 
Op::GetAttrMap<FPrimalGradient>("FPrimalGradient");
 
-  explicit ReverseAD(const Var& bp, std::shared_ptr<ADVarMap> ad_vars) : 
bp(bp), ad_vars(ad_vars) {}
+  explicit ReverseAD(const Optional<IRModule>& mod, const Var& bp,
+                     const std::shared_ptr<ADVarMap>& ad_vars,
+                     const std::shared_ptr<ADGlobalVarMap>& ad_gvars)
+      : mod(mod), bp(bp), ad_vars(ad_vars), ad_gvars(ad_gvars) {}
 
   Expr VisitExpr_(const OpNode* op) final {
     LOG(FATAL) << "op should only be inside call";
@@ -481,9 +498,8 @@ struct ReverseAD : ExprMutator {
       Expr nbp = Function({}, LetList::With([&](LetList* ll) {
                             // we need a new ReverseAD visitor to avoid clobbering the bp local var
                             auto dup_bp = ll->Push(BPEmpty());
-                            ReverseAD dup_diff(dup_bp, ad_vars);
-                            auto dup_ad = 
ll->Push(dup_diff.VisitExpr(DeDup(x)));
-
+                            auto dup_ad =
+                                ll->Push(ReverseAD(mod, dup_bp, ad_vars, 
ad_gvars)(DeDup(x)));
                             TransferGrads(call->checked_type(), ret, dup_ad, 
ll);
                             ll->Push(Call(RefRead(dup_bp), {}));
                             return Call(bpv, {});
@@ -518,22 +534,29 @@ struct ReverseAD : ExprMutator {
         orig_var->checked_type_ = call->checked_type();
         auto ret = ll->Push(GetRev(call->checked_type(), orig_var, ll));
         auto bpv = ll->Push(RefRead(bp));
-        Expr nbp = Function({}, LetList::With([&](LetList* ll) {
-                              tvm::Array<Expr> rev =
-                                  rev_map[op_ref](orig, 
GetGrad(call->checked_type(), ret, ll));
-                              CHECK(args.size() == rev.size());
-                              for (size_t i = 0; i < args.size(); ++i) {
-                                UpdateGrad(call->args[i]->checked_type(), 
args[i], rev[i], ll);
-                              }
-                              return Call(bpv, {});
-                            }),
-                            TupleType::Empty(), {});
+        Expr nbp_body = LetList::With([&](LetList* ll) {
+          tvm::Array<Expr> rev = rev_map[op_ref](orig, 
GetGrad(call->checked_type(), ret, ll));
+          CHECK(args.size() == rev.size());
+          for (size_t i = 0; i < args.size(); ++i) {
+            UpdateGrad(call->args[i]->checked_type(), args[i], rev[i], ll);
+          }
+          return Call(bpv, {});
+        });
+        Expr nbp = Function({}, nbp_body, TupleType::Empty(), {});
         ll->Push(RefWrite(bp, transform::ToANormalForm(nbp)));
         // TODO(@M.K.): ToANF should be called on rev. Enhance ToANF for that.
         return ret;
       });
+    } else if (call->op.as<ConstructorNode>()) {
+      return ExprMutator::VisitExpr_(call);
+    } else {
+      std::vector<Expr> args;
+      for (const auto& arg : call->args) {
+        args.push_back(VisitExpr(arg));
+      }
+      args.push_back(bp);
+      return Call(VisitExpr(call->op), args);
     }
-    return ExprMutator::VisitExpr_(call);
   }
 
   Expr VisitExpr_(const ConstantNode* op) final {
@@ -559,6 +582,39 @@ struct ReverseAD : ExprMutator {
     return ad_vars->at(var_ref);
   }
 
+  Expr VisitExpr_(const GlobalVarNode* op) final {
+    // todo: concatenating string to add attribute seems like a brittle hack.
+    // maybe get module indexed by a rose tree of string?
+    CHECK(mod.defined());
+    auto orig_gv = GetRef<GlobalVar>(op);
+    if (ad_gvars->count(orig_gv) == 0) {
+      GlobalVar gv(op->name_hint + "_grad");
+      (*ad_gvars)[orig_gv] = gv;
+      Function orig_f = 
Downcast<Function>(DeDup(mod.value()->Lookup(orig_gv)));
+      std::vector<Var> params;
+      for (const auto& p : orig_f->params) {
+        params.push_back(Downcast<Var>(VisitExpr(p)));
+      }
+      params.push_back(bp);
+      Expr body = VisitExpr(orig_f->body);
+      Function f(params, body, VisitType(orig_f->ret_type), 
orig_f->type_params, orig_f->attrs);
+      std::cout << "gv " << op->name_hint << ": " << AsText(f, false) << 
std::endl;
+      mod.value()->Add(gv, f);
+    }
+    return ad_gvars->at(orig_gv);
+  }
+
+  Expr VisitExpr_(const FunctionNode* op) final {
+    std::vector<Var> params;
+    for (const auto& var : op->params) {
+      params.push_back(Downcast<Var>(VisitExpr(var)));
+    }
+    auto new_bp = Var("bp", bpt);
+    params.push_back(new_bp);
+    return Function(params, ReverseAD(mod, new_bp, ad_vars, 
ad_gvars)(op->body),
+                    VisitType(op->ret_type), op->type_params, op->attrs);
+  }
+
   Type VisitType(const Type& t) final { return t.defined() ? ReverseType(t) : 
t; }
 };
 
@@ -604,12 +660,16 @@ Expr Gradient(const Expr& re, const Optional<IRModule>& mod) {
   }
   CHECK(!MissingGrad(e)) << "input has operators with missing gradients";
   Expr body = LetList::With([&](LetList* ll) {
-    Var bp = ll->Push(BPEmpty());
-    Expr rev = ReverseAD(bp, std::make_shared<ReverseAD::ADVarMap>())(e);
-    std::vector<Expr> args;
+    Var bp = ll->Push(BPEmpty(), bpt);
+    Expr rev = ReverseAD(mod, bp, std::make_shared<ReverseAD::ADVarMap>(),
+                         std::make_shared<ReverseAD::ADGlobalVarMap>())(e);
+    std::vector<Expr> normal_args, args;
     for (const auto& p : f->params) {
-      args.push_back(ll->Push(Pair(p, RefCreate(ZerosLike(p)))));
+      auto x = ll->Push(Pair(p, RefCreate(ZerosLike(p))));
+      normal_args.push_back(x);
+      args.push_back(x);
     }
+    args.push_back(bp);
     auto c = ll->Push(Call(rev, args));
     std::function<void(const Expr&, const Type&)> init_grad;
     init_grad = [&](const Expr& e, const Type& t) {
@@ -626,7 +686,7 @@ Expr Gradient(const Expr& re, const Optional<IRModule>& mod) {
     init_grad(c, f->body->checked_type());
     ll->Push(Call(RefRead(bp), {}));
     std::vector<Expr> ret;
-    for (const auto& a : args) {
+    for (const auto& a : normal_args) {
       ret.push_back(RefRead(GetField(a, 1)));
     }
     std::function<Expr(const Expr&, const Type&)> get_final_result;
diff --git a/tests/python/relay/test_pass_gradient.py b/tests/python/relay/test_pass_gradient.py
index 296d3e5..b239ef4 100644
--- a/tests/python/relay/test_pass_gradient.py
+++ b/tests/python/relay/test_pass_gradient.py
@@ -21,6 +21,7 @@ import pytest
 import tvm
 from tvm import te
 from tvm import relay
+from tvm.relay import GlobalVar
 from tvm.relay.analysis import free_vars, free_type_vars
 from tvm.relay import create_executor, transform
 from tvm.relay.transform import gradient
@@ -29,7 +30,7 @@ from tvm.relay.testing import add_nat_definitions, make_nat_expr, run_infer_type
 import tvm.relay.op as op
 
 
-def test_id():
+def test_fo_id():
     shape = (10, 10)
     dtype = 'float32'
     t = relay.TensorType(shape, dtype)
@@ -44,6 +45,21 @@ def test_id():
     tvm.testing.assert_allclose(forward.asnumpy(), x.asnumpy())
     tvm.testing.assert_allclose(grad.asnumpy(), np.ones_like(x.asnumpy()))
 
+def test_id():
+    shape = (10, 10)
+    dtype = 'float32'
+    t = relay.TensorType(shape, dtype)
+    x = relay.var("x", t)
+    func = relay.Function([x], x)
+    func = run_infer_type(func)
+    back_func = run_infer_type(gradient(func))
+    assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, 
relay.TupleType([t])]))
+    ex = create_executor()
+    x = rand(dtype, *shape)
+    forward, (grad,) = ex.evaluate(back_func)(x)
+    tvm.testing.assert_allclose(forward.asnumpy(), x.asnumpy())
+    tvm.testing.assert_allclose(grad.asnumpy(), np.ones_like(x.asnumpy()))
+
 
 def test_relu():
     shape = (10, 10)
@@ -341,5 +357,28 @@ def test_no_duplication():
     counts = count_ops(gr)
     assert counts['nn.dense'] == 3, "We expect 3 dense (1 forward, two 
backward)"
 
+
+def test_global_function():
+    m = tvm.IRModule()
+    shape = (10, 10)
+    dtype = 'float32'
+    t = relay.TensorType(shape, dtype)
+    x = relay.Var('x', t)
+    d = GlobalVar('double')
+    m[d] = relay.Function([x], x + x)
+    y = relay.Var('y', t)
+    q = GlobalVar('q')
+    m[q] = relay.Function([y], d(d(y)))
+    g = GlobalVar('grad')
+    m[g] = tvm.relay.transform.gradient(q, m)
+    back_func = m[g]
+    assert back_func.checked_type == relay.FuncType([t], relay.TupleType([t, 
relay.TupleType([t])]))
+    ex = create_executor(mod=m)
+    x = rand(dtype, *shape)
+    forward, (grad,) = ex.evaluate(back_func)(x)
+    tvm.testing.assert_allclose(forward.asnumpy(), 4 * x.asnumpy())
+    tvm.testing.assert_allclose(grad.asnumpy(), 4 * np.ones_like(x.asnumpy()))
+
+
 if __name__ == "__main__":
     pytest.main([__file__])

Reply via email to