This is an automated email from the ASF dual-hosted git repository.

iblis pushed a commit to branch ib/autograd-custom-func
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git

commit 9a9d5663d2c586df6d4a9190cbee668813e46b9d
Author: Iblis Lin <ib...@hs.ntnu.edu.tw>
AuthorDate: Wed Nov 27 10:34:45 2019 +0800

    wip
---
 julia/src/MXNet.jl    |  14 +++
 julia/src/autograd.jl | 236 ++++++++++++++++++++++++++++++++++++++++++++++++--
 julia/src/base.jl     |   1 +
 3 files changed, 243 insertions(+), 8 deletions(-)

diff --git a/julia/src/MXNet.jl b/julia/src/MXNet.jl
index 89ec88b..6a01fb0 100644
--- a/julia/src/MXNet.jl
+++ b/julia/src/MXNet.jl
@@ -64,6 +64,20 @@ export NDArray,
        broadcast_axis,
        broadcast_axes
 
+# autograd.jl
+export attach_grad!,
+       backward!,
+       getgrad,
+       is_recording,
+       is_training,
+       mark_variables,
+       pause,
+       predict_mode,
+       record,
+       symbol,
+       train_mode,
+       @custom
+
 # executor.jl
 export Executor,
        bind,
diff --git a/julia/src/autograd.jl b/julia/src/autograd.jl
index 8b5edae..3a32c08 100644
--- a/julia/src/autograd.jl
+++ b/julia/src/autograd.jl
@@ -19,6 +19,9 @@
 # this is a port of Python's autograd module
 # https://github.com/apache/incubator-mxnet/blob/master/python/mxnet/autograd.py
 
+using Base.Meta: isexpr
+using Base.GC  # FIXME
+
 ###############################################################################
 #  Private util functions
 ###############################################################################
@@ -211,7 +214,7 @@ Compute the gradients of heads w.r.t previously marked variables.
 
 - `head::NDArray`: output NDArray
 
-- `head_grad::NDArray` or `Cvoid`: gradient coefficient with respect to head.
+- `head_grad::NDArray` or `Nothing`: gradient coefficient with respect to head.
 
 - `heads::Vector{NDArray}`: a list of output NDArray
 
@@ -227,11 +230,14 @@ Compute the gradients of heads w.r.t previously marked variables.
 backward!(head::NDArray, head_grad::NDArray; kws...) =
   backward!([head], [head_grad]; kws...)
 
-backward!(head::NDArray, head_grad::Cvoid = nothing; kws...) =
+backward!(head::NDArray, head_grad::Nothing = nothing; kws...) =
   backward!([head], head_grad; kws...)
 
-function backward!(heads::VecOfNDArray, head_grad::Cvoid;
+function backward!(heads::VecOfNDArray, ::Nothing;
                    retain_graph::Bool = false, train_mode::Bool = true)
+  cblist_ref = first(keys(_cblists))
+
+  # TODO check MXAutogradBackwardEx usage
   @mxcall(
     :MXAutogradBackwardEx,
     (MX_uint,
@@ -242,8 +248,8 @@ function backward!(heads::VecOfNDArray, head_grad::Cvoid;
      Cint,
      Cint,
      Cint,
-     Ptr{MX_handle},
-     Ptr{MX_handle}),
+     Ptr{Ptr{MX_handle}},
+     Ptr{Ptr{Cint}}),
     length(heads),
     map(x -> x.handle, heads),
     C_NULL,
@@ -279,8 +285,8 @@ function backward!(heads::VecOfNDArray, head_grads::Vector;
      Cint,
      Cint,
      Cint,
-     Ptr{MX_handle},
-     Ptr{MX_handle}),
+     Ptr{Ptr{MX_handle}},
+     Ptr{Ptr{Cint}}),
     length(output_handles),
     output_handles,
     ograd_handles,
@@ -400,5 +406,219 @@ function symbol(x::NDArray)
 end
 
 ###############################################################################
-#  TODO: User-defined differentiable function
+#  User-defined differentiable function
 ###############################################################################
+
+
+# GC roots for the custom-function machinery: libmxnet keeps raw pointers into
+# these objects, so they must stay reachable from Julia.
+
+# C function pointers of the [backward, delete] callbacks, filled in by
+# `_init_customfunc`; `_cbs_r` mirrors the same pointers as `Ref`s.
+const _cbs_r  = [Ref{Ptr{Cvoid}}(C_NULL), Ref{Ptr{Cvoid}}(C_NULL)]
+const _cbs    = [Ptr{Cvoid}(C_NULL), Ptr{Cvoid}(C_NULL)]
+# pointer to the data of `_cbs`; used as the `cbs` field of each `MXCallbackList`
+const _cbsref = Ref{Ptr{Ptr{Cvoid}}}(C_NULL)
+# custom function instance => (Ref(inputs), Ref(outputs)) of a recorded call
+const _frefs  = Dict{Any,Any}()
+# `Base.AsyncCondition`s whose libuv handles are handed to libmxnet as contexts
+const _conds  = Base.AsyncCondition[]
+
+# Backward callback handed to libmxnet through `@cfunction` (see `_init_customfunc`),
+# with the C prototype
+#   int (*)(int num_ograds, int num_igrads, void **ptrs, const int *reqs,
+#           const int is_train, void *state)
+# `state` (`handle`) is the libuv handle of an `AsyncCondition`; instead of doing
+# the work on the calling thread we only wake that condition up.
+#
+# There must be exactly one method here: a second, more specific method typed
+# `fptr::Ptr{Cvoid}` would always win dispatch from the `@cfunction` and the
+# notification below would never be sent.
+#
+# The callback's result is treated as a success flag (non-zero == success, as
+# the earlier draft's `Cint(true)` assumed), whereas `uv_async_send` returns 0
+# on success, hence `ret == 0`.
+#
+# TODO: move the actual gradient computation into the AsyncCondition callback.
+# Draft (it received `f` through the state pointer `fptr`; that slot now
+# carries the async handle, so `f` must be obtained another way):
+#
+#   hdls = unsafe_wrap(Array, ptrs, num_ograds + num_igrads)
+#   ograds = map(x -> NDArray(MX_NDArrayHandle(x), false), hdls[1:num_ograds])
+#   igrads = map(NDArray ∘ MX_NDArrayHandle, hdls[num_ograds+1:num_ograds+num_igrads])
+#   reqs = unsafe_wrap(Array, reqs, num_igrads)
+#   f = unsafe_pointer_to_objref(fptr)  # passing closure via raw pointer
+#   Δs = backward!(f, ograds...)
+#   Δs = Δs isa NDArray ? [Δs] : Δs
+#   for (i, Δ, req) ∈ zip(igrads, Δs, reqs)  # update gradient
+#     req = GRAD_REQ(req)
+#     if req == GRAD_NOP
+#       continue
+#     elseif req ∈ (GRAD_WRITE, GRAD_INPLACE)
+#       i[:] = Δ
+#     elseif req == GRAD_ADD
+#       i[:] += Δ
+#     end
+#   end
+#   delete!(_frefs, f)  # release ref for gc
+function _back_wrapper(num_ograds, num_igrads, ptrs, reqs, is_train, handle::Ptr{Cvoid})
+  ret = ccall(:uv_async_send, Cint, (Ptr{Cvoid},), handle)
+  Cint(ret == 0)
+end
+
+# Delete callback handed to libmxnet through `@cfunction` (see `_init_customfunc`).
+# `handle` is the libuv handle of the `AsyncCondition` whose Julia callback
+# handles deletion; we only wake it up.
+# Returns non-zero on success, like `_back_wrapper`; `uv_async_send` itself
+# returns 0 on success.
+function _del_wrapper(handle)
+  ret = ccall(:uv_async_send, Cint, (Ptr{Cvoid},), handle)
+  Cint(ret == 0)
+end
+
+# Debug helper: signal the libuv async handle `handle` and return the raw
+# `uv_async_send` status.
+_wtf_wrapper(handle) = ccall(:uv_async_send, Cint, (Ptr{Cvoid},), handle)
+
+# Create the C-callable trampolines used by custom functions and publish them
+# in `_cbs`/`_cbs_r`; `_cbsref` receives a pointer to `_cbs`'s data, which every
+# `MXCallbackList` uses as its `cbs` field. `_cbs` is a `const` global and is
+# never resized, so that pointer stays valid.
+#
+# Invoked from `__init__` so the `@cfunction` pointers are created in the
+# running session.
+#
+# Callback prototypes:
+# https://github.com/apache/incubator-mxnet/blob/ca565a00285d4fb0ca77ba9dc651a07ce1f01b24/include/mxnet/c_api.h#L209-L212
+function _init_customfunc()  # will be invoked in __init__
+  # the holders are const containers mutated in place, so no `global` needed
+  _cbs_r[1][] = _cbs[1] = @cfunction(_back_wrapper, Cint,
+                                     (Cint, Cint, Ptr{Ptr{Cvoid}}, Ptr{Cint},
+                                      Cint, Ptr{Cvoid}))
+  _cbs_r[2][] = _cbs[2] = @cfunction(_del_wrapper, Cint, (Ptr{Cvoid},))
+  _cbsref[] = pointer(_cbs)
+  @debug "_init_customfunc" _cbs _cbsref[]
+  nothing
+end
+
+"""
+    MXCallbackList(f)
+
+Julia mirror of libmxnet's `MXCallbackList` C struct, built for one recorded
+call of the custom function instance `f`.
+
+Two callbacks must be provided, in this order
+(https://github.com/apache/incubator-mxnet/blob/2f8c1e83f94e84a25a48d2cd43136030fb3f2d1e/include/mxnet/c_api.h#L174-L182):
+
+1. backward, served by `_back_wrapper`
+2. delete, served by `_del_wrapper`
+
+`ctxs` has the same length as `cbs`; `ctxs[i]` is passed as the last (`state`)
+argument of callback `i`. Both contexts are libuv handles of
+`Base.AsyncCondition`s: the C-side callbacks only wake the conditions and the
+real work runs in their Julia callbacks.
+
+The constructor returns a `Ref{MXCallbackList}`, which is what gets passed to
+`MXCustomFunctionRecord`. That ref, the `ctxs` array it points into, `f` and
+both conditions are kept in module-level holders so that the GC cannot free
+memory libmxnet still points at; the delete callback releases them again.
+"""
+struct MXCallbackList
+  n::Cint                # int num_callbacks;
+  cbs::Ptr{Ptr{Cvoid}}   # int (**callbacks)(Cvoid);
+  ctxs::Ptr{Ptr{Cvoid}}  # void **contexts;
+
+  function MXCallbackList(f)
+    # keep the custom function instance alive until the delete callback fires
+    fr = Ref(f)
+    push!(_fholder, fr)
+
+    # Julia-side half of the backward callback
+    cond = Base.AsyncCondition() do _
+      # TODO: run `backward!(f, ograds...)` and write the input gradients
+      @debug "custom function backward callback"
+    end
+    # Julia-side half of the delete callback: libmxnet is done with this
+    # callback list, so drop every reference held on its behalf
+    cond2 = Base.AsyncCondition() do c2
+      @debug "custom function delete callback"
+      delete!(_cblists, cblist_ref)
+      filter!(x -> x !== fr, _fholder)
+      filter!(c -> c !== cond && c !== c2, _conds)
+      close(cond)
+      close(c2)
+    end
+    push!(_conds, cond, cond2)
+
+    # ctxs[1] / ctxs[2] are the `state` arguments of `_back_wrapper` /
+    # `_del_wrapper`; both must stay libuv async handles.
+    ctxs = Ptr{Cvoid}[cond.handle, cond2.handle]
+    cblist = new(length(ctxs), _cbsref[], pointer(ctxs))
+    # libmxnet keeps the address of the struct, so hand it over by reference and
+    # keep both the ref and `ctxs` (the memory `cblist.ctxs` points into) alive.
+    cblist_ref = Ref{MXCallbackList}(cblist)
+    _cblists[cblist_ref] = Ref(ctxs)
+    cblist_ref
+  end
+end
+
+# GC roots for objects created by `MXCallbackList`:
+# the `Ref{MXCallbackList}` handed to libmxnet => `Ref` of its `ctxs` array
+const _cblists = Dict{Ref{MXCallbackList},Ref}()
+# `Ref`s of the custom function instances
+const _fholder = Ref[]
+
+"""
+    @custom function name(args...; kwargs...) ... end
+
+Create a callable custom function.
+All positional arguments should be `NDArray`s.
+The function body should evaluate to an instance of your custom type; the
+generated function returns `forward(instance, args...; kwargs...)`, which may be
+a single `NDArray` or a collection (e.g. a tuple) of `NDArray`s.
+While autograd is recording, the call is also registered with libmxnet via
+`MXCustomFunctionRecord`; the custom type is expected to overload
+`MXNet.forward` and `MXNet.backward!`.
+Please check out `examples/autograd/customfunc.jl` for an example.
+"""
+macro custom(ex::Expr)
+  fdef = splitdef(ex)  # by MacroTools
+  sig = ex.args[1]
+  # create a new scope via `let`
+  body = esc(Expr(:let, Expr(:block), ex.args[2]))
+
+  # only extract symbols, get rid of all annotations and default values
+  args = map(x -> esc(splitarg(x)[1]), fdef[:args])
+  # forward(f, xs...; kwargs...)
+  forward_expr = Expr(:call, :forward, :f, args...)
+  # forward keyword args by name (annotations and default values stripped)
+  for x ∈ fdef[:kwargs]
+    sym = splitarg(x)[1]
+    push!(forward_expr.args, Expr(:kw, sym, esc(sym)))
+  end
+
+  # xs, FIXME: a list of NDArray from positional argument
+  xs_len = length(args)
+  xs_expr = Expr(:vect, args...)
+
+  body′ = quote
+    f, ys = _record(false, nothing) do
+      f = $body  # f is the object instance
+      ys = $forward_expr
+      f, ys
+    end
+
+    !is_recording() && return ys
+
+    xs = $xs_expr
+    # `MXCustomFunctionRecord` needs an array of handles; accept a single
+    # NDArray as well as any collection (e.g. a tuple) of NDArrays.
+    ys′ = ys isa NDArray ? [ys] : collect(NDArray, ys)
+
+    # struct MXCallbackList
+    cblist_ref = MXCallbackList(f)
+
+    # keep inputs and outputs reachable from Julia (gc-free)
+    xsr, ysr = Ref(xs), Ref(ys′)
+    _frefs[f] = (xsr, ysr)
+
+    @mxcall(
+      :MXCustomFunctionRecord,
+      (Cint,            # num_inputs
+       Ref{MX_handle},  # inputs
+
+       Cint,            # num_outputs
+       Ref{MX_handle},  # outputs
+
+       Ref{MXCallbackList}),  # callbacks
+      $xs_len,
+      xs,
+
+      length(ys′),
+      ys′,
+
+      cblist_ref)
+
+    ys
+  end
+
+  Expr(:function, esc(sig), body′)
+end
+
+# Generic functions that a custom function type (see `@custom`) should overload:
+#   forward(f, xs...; kwargs...) -- compute the output(s) from the inputs
+#   backward!(f, Δys...)         -- expected to return the gradient(s) w.r.t.
+#                                   the inputs (an NDArray or a collection)
+# `backward!` receives one output gradient per value returned by `forward`.
+function forward end
+function backward! end
diff --git a/julia/src/base.jl b/julia/src/base.jl
index 6831464..d10be39 100644
--- a/julia/src/base.jl
+++ b/julia/src/base.jl
@@ -69,6 +69,7 @@ function __init__()
   _get_libmx_op_names()
   _populate_iter_creator_cache!()
   _get_lib_version!()
+  _init_customfunc()
 
   atexit() do
     # notify libmxnet we are shutting down

Reply via email to