I am writing a custom op to reproduce the CurricularFace paper; below is my
implementation for computing the moving average t.

    import numpy as np
    import mxnet as mx


    class CurricularFace(mx.operator.CustomOp):
        """Maintain the CurricularFace moving-average statistic ``t``.

        ``t`` lives in the op's single auxiliary state. On every training
        forward pass it is updated as::

            t <- momentum * t + (1 - momentum) * mean(target_logits)

        and the current value is broadcast to a ``(batch, 1)`` output.
        ``t`` is treated as a constant with respect to the input, so the
        backward pass contributes a zero gradient.
        """

        def __init__(self, momentum=0.99):
            super(CurricularFace, self).__init__()
            # Weight kept on the previous t; (1 - momentum) goes to the new
            # batch mean.
            self.momentum = momentum

        def forward(self, is_train, req, in_data, out_data, aux):
            """Update the running ``t`` (training only) and emit it per sample.

            Args:
                is_train: True for training passes. The running ``t`` is only
                    updated then, so evaluation/inference batches do not
                    alter the statistic.
                req: write request for each output.
                in_data: ``[target_logits]``. Only ``shape[0]`` and the global
                    mean are used.
                out_data: ``[t_out]``, shape ``(batch, 1)`` per ``infer_shape``.
                aux: ``[t]``, shape ``(1,)`` per ``infer_shape``. Updated in
                    place.
            """
            target_logits = in_data[0]
            t = aux[0]
            batch = target_logits.shape[0]
            if is_train:
                # Write into the aux buffer in place (rather than rebinding)
                # so the updated value persists across iterations.
                t[:] = (self.momentum * t
                        + (1 - self.momentum) * target_logits.mean())
            self.assign(out_data[0], req[0],
                        t.repeat(repeats=batch).expand_dims(axis=-1))

        def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
            """Write a zero gradient for the input.

            ``t`` is a running statistic, not a differentiable function of the
            current batch, so no gradient flows back through this op.
            Overriding explicitly guarantees ``in_grad`` is initialized
            instead of relying on the inherited (no-op) default.

            NOTE(review): the reported "unknown storage type: -1" is raised
            inside MXNet's backward bridge before this method runs. It
            presumably occurs when the input's gradient buffer is not
            allocated (e.g. the input is itself stop-gradient). Wrapping this
            op's output in ``mx.sym.BlockGrad`` may avoid invoking backward
            entirely. Confirm against the graph that uses this op.
            """
            if req[0] == 'null':
                # Nothing is requested for the input gradient.
                return
            self.assign(in_grad[0], req[0], mx.nd.zeros_like(in_grad[0]))


    @mx.operator.register("CurricularFaceT")
    class CurricularFaceProp(mx.operator.CustomOpProp):
        """Shape, type and dependency metadata for the ``CurricularFaceT`` op.

        The op takes one argument (``data``), produces one output
        (``CurricularT``) and owns one auxiliary state (``coef_t_bias``),
        which holds the running value of t.
        """

        def __init__(self, momentum=0.99):
            # The backward pass is not given the output gradient.
            super(CurricularFaceProp, self).__init__(need_top_grad=False)
            # float() normalizes the value; presumably kwargs can arrive as
            # strings when invoked via mx.sym.Custom -- confirm.
            self.momentum = float(momentum)

        def list_arguments(self):
            """Return the names of the op's inputs."""
            return ['data']

        def list_outputs(self):
            """Return the names of the op's outputs."""
            return ['CurricularT']

        def list_auxiliary_states(self):
            """Return the names of the op's auxiliary states."""
            return ['coef_t_bias']

        def infer_shape(self, in_shapes):
            """Keep the input shape; output is (batch, 1); aux is (1,)."""
            first_shape = in_shapes[0]
            out_shape = (first_shape[0], 1)
            aux_shape = (1,)
            return [first_shape], [out_shape], [aux_shape]

        def infer_type(self, in_type):
            """Fix every input, output and aux state to float32."""
            f32 = np.float32
            return [f32], [f32], [f32]

        def create_operator(self, ctx, in_shapes, in_dtypes):
            """Instantiate the op with this prop's momentum."""
            return CurricularFace(momentum=self.momentum)

        def declare_backward_dependency(self, out_grad, in_data, out_data):
            """Declare that backward needs none of out_grad/in_data/out_data."""
            return []

The above code causes the following error during the backward pass:

Stack trace:
[bt] (0) /usr/local/lib/python3.6/dist-packages/mxnet/libmxnet.so(+0x4b04cb) 
[0x7f3a4764b4cb]
[bt] (1) /usr/local/lib/python3.6/dist-packages/mxnet/libmxnet.so(+0x7fddcc) 
[0x7f3a47998dcc]
[bt] (2) /usr/local/lib/python3.6/dist-packages/mxnet/libmxnet.so(+0x7e16bd) 
[0x7f3a4797c6bd]
[bt] (3) /usr/lib/x86_64-linux-gnu/libstdc++.so.6(+0xb8c80) [0x7f3ab224bc80]
[bt] (4) /lib/x86_64-linux-gnu/libpthread.so.0(+0x76ba) [0x7f3b91c166ba]
[bt] (5) /lib/x86_64-linux-gnu/libc.so.6(clone+0x6d) [0x7f3b90df941d]
Traceback (most recent call last):
.......................
 File "/usr/local/lib/python3.6/dist-packages/mxnet/ndarray/ndarray.py", line 
1819, in wait_to_read
   check_call(_LIB.MXNDArrayWaitToRead(self.handle))
 File "/usr/local/lib/python3.6/dist-packages/mxnet/base.py", line 253, in 
check_call
   raise MXNetError(py_str(_LIB.MXGetLastError()))
 mxnet.base.MXNetError: [22:49:32] src/operator/custom/custom.cc:417: Check 
failed: reinterpret_cast<CustomOpFBFunc>(params.info->call
backs[kCustomOpBackward])( ptrs.size(), const_cast<void**>(ptrs.data()), 
const_cast<int*>(tags.data()), reinterpret_cast<const int*>(req.data()), 
static_cast<int>(ctx.is_train), params
.info->contexts[kCustomOpBackward]): 
 Stack trace:
   [bt] (0) /usr/local/lib/python3.6/dist-packages/mxnet/libmxnet.so(+0x4b04cb) 
[0x7f3a4764b4cb]
   [bt] (1) /usr/local/lib/python3.6/dist-packages/mxnet/libmxnet.so(+0x7fddcc) 
[0x7f3a47998dcc]
   [bt] (2) /usr/local/lib/python3.6/dist-packages/mxnet/libmxnet.so(+0x7e16bd) 
[0x7f3a4797c6bd]
   [bt] (3) /usr/lib/x86_64-linux-gnu/libstdc++.so.6(+0xb8c80) [0x7f3ab224bc80]
   [bt] (4) /lib/x86_64-linux-gnu/libpthread.so.0(+0x76ba) [0x7f3b91c166ba]
   [bt] (5) /lib/x86_64-linux-gnu/libc.so.6(clone+0x6d) [0x7f3b90df941d]
 Error in CustomOp.backward: Traceback (most recent call last):
   File "/usr/local/lib/python3.6/dist-packages/mxnet/operator.py", line 1020, 
in backward_entry
     stype=stype))
   File "/usr/local/lib/python3.6/dist-packages/mxnet/ndarray/sparse.py", line 
1187, in _ndarray_cls
     raise Exception("unknown storage type: %s"%stype)
 Exception: unknown storage type: -1
 Error in CustomOp.backward: Traceback (most recent call last):
  File "/usr/local/lib/python3.6/dist-packages/mxnet/operator.py", line 1020, 
in backward_entry
     stype=stype))
   File "/usr/local/lib/python3.6/dist-packages/mxnet/ndarray/sparse.py", line 
1187, in _ndarray_cls
    raise Exception("unknown storage type: %s"%stype)
     Exception: unknown storage type: -1





---
[Visit 
Topic](https://discuss.mxnet.apache.org/t/create-custom-op-with-auxiliary-states/6597/1)
 or reply to this email to respond.

You are receiving this because you enabled mailing list mode.

To unsubscribe from these emails, [click 
here](https://discuss.mxnet.apache.org/email/unsubscribe/f9038312333a4e61a818b318d3eaaf6d9bee196c814fd6d5e9d63ddec8a0cb5b).

Reply via email to