This is an automated email from the ASF dual-hosted git repository. iblis pushed a commit to branch ib/autograd-custom-func in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
commit 9a9d5663d2c586df6d4a9190cbee668813e46b9d Author: Iblis Lin <ib...@hs.ntnu.edu.tw> AuthorDate: Wed Nov 27 10:34:45 2019 +0800 wip --- julia/src/MXNet.jl | 14 +++ julia/src/autograd.jl | 236 ++++++++++++++++++++++++++++++++++++++++++++++++-- julia/src/base.jl | 1 + 3 files changed, 243 insertions(+), 8 deletions(-) diff --git a/julia/src/MXNet.jl b/julia/src/MXNet.jl index 89ec88b..6a01fb0 100644 --- a/julia/src/MXNet.jl +++ b/julia/src/MXNet.jl @@ -64,6 +64,20 @@ export NDArray, broadcast_axis, broadcast_axes +# autograd.jl +export attach_grad!, + backward!, + getgrad, + is_recording, + is_training, + mark_variables, + pause, + predict_mode, + record, + symbol, + train_mode, + @custom + # executor.jl export Executor, bind, diff --git a/julia/src/autograd.jl b/julia/src/autograd.jl index 8b5edae..3a32c08 100644 --- a/julia/src/autograd.jl +++ b/julia/src/autograd.jl @@ -19,6 +19,9 @@ # this is a port of Python's autograd module # https://github.com/apache/incubator-mxnet/blob/master/python/mxnet/autograd.py +using Base.Meta: isexpr +using Base.GC # FIXME + ############################################################################### # Private util functions ############################################################################### @@ -211,7 +214,7 @@ Compute the gradients of heads w.r.t previously marked variables. - `head::NDArray`: output NDArray -- `head_grad::NDArray` or `Cvoid`: gradient coefficient with respect to head. +- `head_grad::NDArray` or `Nothing`: gradient coefficient with respect to head. - `heads::Vector{NDArray}`: a list of output NDArray @@ -227,11 +230,14 @@ Compute the gradients of heads w.r.t previously marked variables. backward!(head::NDArray, head_grad::NDArray; kws...) = backward!([head], [head_grad]; kws...) -backward!(head::NDArray, head_grad::Cvoid = nothing; kws...) = +backward!(head::NDArray, head_grad::Nothing = nothing; kws...) = backward!([head], head_grad; kws...) -function backward!(heads::VecOfNDArray, head_grad::Cvoid; +function backward!(heads::VecOfNDArray, ::Nothing; retain_graph::Bool = false, train_mode::Bool = true) + cblist_ref = first(keys(_cblists)) + + # TODO check MXAutogradBackwardEx usage @mxcall( :MXAutogradBackwardEx, (MX_uint, @@ -242,8 +248,8 @@ function backward!(heads::VecOfNDArray, head_grad::Cvoid; Cint, Cint, Cint, - Ptr{MX_handle}, - Ptr{MX_handle}), + Ptr{Ptr{MX_handle}}, + Ptr{Ptr{Cint}}), length(heads), map(x -> x.handle, heads), C_NULL, @@ -279,8 +285,8 @@ function backward!(heads::VecOfNDArray, head_grads::Vector; Cint, Cint, Cint, - Ptr{MX_handle}, - Ptr{MX_handle}), + Ptr{Ptr{MX_handle}}, + Ptr{Ptr{Cint}}), length(output_handles), output_handles, ograd_handles, @@ -400,5 +406,219 @@ function symbol(x::NDArray) end ############################################################################### -# TODO: User-defined differentiable function +# User-defined differentiable function ############################################################################### + + +# gc-free holder +const _cbs_r = [Ref{Ptr{Cvoid}}(C_NULL), Ref{Ptr{Cvoid}}(C_NULL)] +const _cbs = [Ptr{Cvoid}(C_NULL), Ptr{Cvoid}(C_NULL)] +const _cbsref = Ref{Ptr{Ptr{Cvoid}}}(C_NULL) +const _frefs = Dict() # hold custom function instance and its args +const _conds = [] + +function _back_wrapper(num_ograds, num_igrads, ptrs, reqs, is_train, fptr::Ptr{Cvoid}) + # @info "_back_wrapper" + # hdls = unsafe_wrap(Array, ptrs, num_ograds + num_igrads) + # @info "_back_wrapper" hdls + # ograds = map(x -> NDArray(MX_NDArrayHandle(x), false), hdls[1:num_ograds]) + # @info "_back_wrapper" ograds + # igrads = map(NDArray ∘ MX_NDArrayHandle, hdls[num_ograds+1:num_ograds+num_igrads]) + # @info "_back_wrapper" igrads + # reqs = unsafe_wrap(Array, reqs, num_igrads) + # @info "_back_wrapper" reqs + # + # # passing closure via raw pointer + # f = unsafe_pointer_to_objref(fptr) + # + # Δs = backward!(f, ograds...) + # Δs = Δs isa NDArray ? [Δs] : Δs + # + # # update gradient + # for (i, Δ, req) ∈ zip(igrads, Δs, reqs) + # req = GRAD_REQ(req) + # if req == GRAD_NOP + # continue + # elseif req ∈ (GRAD_WRITE, GRAD_INPLACE) + # i[:] = Δ + # elseif req == GRAD_ADD + # i[:] += Δ + # end + # end + # + # # release ref for gc + # delete!(_frefs, f) + + Cint(true) +end + +function _back_wrapper(num_ograds, num_igrads, ptrs, reqs, is_train, handle) + ccall(:uv_async_send, Cint, (Ptr{Cvoid},), handle) +end + +function _del_wrapper(handle) + ccall(:uv_async_send, Cint, (Ptr{Cvoid},), handle) +end + +function _wtf_wrapper(handle) + ccall(:uv_async_send, Cint, (Ptr{Cvoid},), handle) +end + +function _init_customfunc() # will be invoked in __init__ + global _cbs_r + global _cbs + global _cbsref + + # the callback function prototype: + # https://github.com/apache/incubator-mxnet/blob/ca565a00285d4fb0ca77ba9dc651a07ce1f01b24/include/mxnet/c_api.h#L209-L212 + _cbs_r[1][] = _cbs[1] = @cfunction(_back_wrapper, Cint, + (Cint, Cint, Ptr{Ptr{Cvoid}}, Ptr{Cint}, + Cint, Ptr{Cvoid})) + # _cbs_r[1][] = _cbs[1] = @cfunction(_wtf_wrapper, Cvoid, (Ptr{Cvoid},)) + + _cbs_r[2][] = _cbs[2] = @cfunction(_del_wrapper, Cint, (Ptr{Cvoid},)) + _cbsref[] = Base.unsafe_convert(Ptr{Ptr{Cvoid}}, _cbs) + @info "_init_customfunc" _cbs _cbsref[] +end + +struct MXCallbackList + n::Cint # int num_callbacks; + cbs::Ptr{Ptr{Cvoid}} # int (**callbacks)(Cvoid); + ctxs::Ptr{Ptr{Cvoid}} # void **contexts; + + # we must provide two callback functions + # the first is backward function `_back_wrapper` + # the second is delete callback `_del_wrapper` + # https://github.com/apache/incubator-mxnet/blob/2f8c1e83f94e84a25a48d2cd43136030fb3f2d1e/include/mxnet/c_api.h#L174-L182 + + # `ctxs` is a array which is same size as `cbs` + # its elements will be passed as `state` for callback functions, + # usually the last argument. + # In our case, we will push the pointer of custom func instance as + # first element of `ctxs`; the pointer of MXCallbackList instance as + # the second element. + # The purpose of first pointer is to pass closure into `cfunction`. + # The second pointer is to free the reference of MXCallbackList, + # and let the function instance be GC-ed properly. + + function MXCallbackList(f) # where all args are Refs + fr = Ref(f) + push!(_fholder, fr) + @info "f ref" Base.unsafe_convert(Ptr{Cvoid}, fr) + cond = Base.AsyncCondition() do cond + @info "real back callback" + A = ones(10000000) + for i ∈ 1:10000 + B = A * A + end + @info "long run op end" + end + cond2 = Base.AsyncCondition() do cond + @info "real del callback" + end + push!(_conds, cond) + push!(_conds, cond2) + @info "conds" cond.handle cond2.handle + ctxs = [ + cond.handle, + cond2.handle, + ] + ctxsptr = Base.unsafe_convert(Ptr{Ptr{Cvoid}}, ctxs) + cblist = new(length(ctxs), _cbsref[], ctxsptr) + # get the reference, and make a self-reference in ctxs[2] + cblist_ref = Ref{MXCallbackList}(cblist) + ctxs[2] = Base.unsafe_convert(Ptr{Cvoid}, cblist_ref) + # insert ref into a holder to prevent from being GC-ed. + # hold `xs` and `ys` which is passed into `MXCustomFunctionRecord`. + _cblists[cblist_ref] = Ref(ctxs) + cblist_ref + end +end + +# hold MXCallbackList to prevent from gc +const _cblists = Dict{Ref{MXCallbackList},Ref}() +const _fholder = [] + +""" + @custom +Create callable custom function. +All the position-arguments should be `NDArray`. +The return value should be a instance of your custom type. +Please checkout `examples/autograd/customfunc.jl` for example. +""" +macro custom(ex::Expr) + fdef = splitdef(ex) # by MacroTools + sig = ex.args[1] + body = esc(Expr(:let, Expr(:block), ex.args[2])) # create a new scope via `let` + + # only extract symbols, get rid of all annotations and default values + args = map(x -> esc(splitarg(x)[1]), fdef[:args]) + # forward(f, xs...) + forward_expr = Expr(:call, :forward, :f, args...) + # insert keyword args + if !isempty(fdef[:kwargs]) + # only extract symbols, get rid of all annotations and default values + kwargs = map(fdef[:kwargs]) do x + sym = splitarg(x)[1] + Expr(:kw, sym, esc(sym)) + end + append!(forward_expr.args, kwargs) + end + + # xs, FIXME: a list of NDArray from positional argument + xs_len = length(args) + xs_expr = Expr(:vect, args...) + + body′ = quote + f, ys = _record(false, nothing) do + f = $body # f is the object instance + ys = $forward_expr + f, ys + end + + !is_recording() && return ys + + xs = $xs_expr + ys′ = ys isa NDArray ? [ys] : ys + + # struct MXCallbackList + cblist_ref = MXCallbackList(f) + + # gc-free + xsr, ysr = Ref(xs), Ref(ys′) + _frefs[f] = (xsr, ysr) + # @info _frefs + + @mxcall( + :MXCustomFunctionRecord, + (Cint, # num_inputs + Ref{MX_handle}, # inputs + + Cint, # num_outputs + Ref{MX_handle}, # outputs + + Ref{MXCallbackList}), # callbacks + $xs_len, + xs, + + length(ys′), + ys′, + + cblist_ref) + + @info "inputs xs" Base.unsafe_convert(Ref{MX_handle}, xs) + @info "outputs ys" Base.unsafe_convert(Ref{MX_handle}, ys′) + + + ys + end + + GC.enable(false) # FIXME + + Expr(:function, esc(sig), body′) +end + +# custom function should overload these functions. +# the # of forward return values is the inputs of backward!. +function forward end +function backward! end diff --git a/julia/src/base.jl b/julia/src/base.jl index 6831464..d10be39 100644 --- a/julia/src/base.jl +++ b/julia/src/base.jl @@ -69,6 +69,7 @@ function __init__() _get_libmx_op_names() _populate_iter_creator_cache!() _get_lib_version!() + _init_customfunc() atexit() do # notify libmxnet we are shutting down