I am writing a custom op to reproduce the CurricularFace paper; below is my
implementation for computing the moving average t.
import numpy as np
import mxnet as mx
class CurricularFace(mx.operator.CustomOp):
    """Custom op that maintains the moving average ``t`` of the target logits.

    The running value lives in the operator's single auxiliary state and is
    broadcast to one value per sample on output.

    Parameters
    ----------
    momentum : float
        Weight given to the previous value of ``t`` when it is updated.
    """

    def __init__(self, momentum=0.99):
        super(CurricularFace, self).__init__()
        self.momentum = momentum

    def forward(self, is_train, req, in_data, out_data, aux):
        """Update ``t`` (training only) and emit it once per sample.

        ``in_data[0]`` holds the target logits and ``aux[0]`` holds ``t``.
        ``out_data[0]`` receives ``t`` repeated along the batch axis with a
        trailing axis of size 1.
        """
        target_logits = in_data[0]
        t = aux[0]
        batch = target_logits.shape[0]
        if is_train:
            # Only training passes may move the running statistic; an
            # inference pass must read t without modifying it.
            t[:] = (self.momentum * t
                    + (1 - self.momentum) * target_logits.mean())
        self.assign(out_data[0], req[0],
                    t.repeat(repeats=batch).expand_dims(axis=-1))

    def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
        """Write a zero gradient for the input.

        ``t`` is treated as a running statistic, not a differentiable
        function of the logits, so nothing is propagated. The gradient
        buffer is still written explicitly so that it never holds
        uninitialised values.
        """
        self.assign(in_grad[0], req[0], 0)
@mx.operator.register("CurricularFaceT")
class CurricularFaceProp(mx.operator.CustomOpProp):
    """Operator properties for the op registered as ``CurricularFaceT``.

    Declares one input (``data``), one output (``CurricularT``) and one
    auxiliary state (``coef_t_bias``), and builds ``CurricularFace``
    instances.

    Parameters
    ----------
    momentum : float
        Forwarded to ``CurricularFace``; coerced to ``float`` on storage.
    """

    def __init__(self, momentum=0.99):
        # The operator is declared as not needing the gradient of its output.
        super(CurricularFaceProp, self).__init__(need_top_grad=False)
        self.momentum = float(momentum)

    def list_arguments(self):
        """Return the names of the operator inputs."""
        return ['data']

    def list_outputs(self):
        """Return the names of the operator outputs."""
        return ['CurricularT']

    def list_auxiliary_states(self):
        """Return the names of the auxiliary states."""
        return ['coef_t_bias']

    def infer_shape(self, in_shapes):
        """Return (input shapes, output shapes, auxiliary shapes).

        The input shape is passed through unchanged, the output is
        ``(batch, 1)`` where ``batch`` is the input's first dimension, and
        the auxiliary state is a single-element vector.
        """
        input_shape = in_shapes[0]
        output_shape = (input_shape[0], 1)
        aux_shape = (1,)
        return [input_shape], [output_shape], [aux_shape]

    def infer_type(self, in_type):
        """Return float32 for the input, the output and the auxiliary state."""
        dtype = np.float32
        return [dtype], [dtype], [dtype]

    def create_operator(self, ctx, in_shapes, in_dtypes):
        """Build the operator instance with the configured momentum."""
        return CurricularFace(momentum=self.momentum)

    def declare_backward_dependency(self, out_grad, in_data, out_data):
        """Declare that the backward pass depends on no arrays."""
        return []
The above code will cause the following error:
Stack trace:
[bt] (0) /usr/local/lib/python3.6/dist-packages/mxnet/libmxnet.so(+0x4b04cb)
[0x7f3a4764b4cb]
[bt] (1) /usr/local/lib/python3.6/dist-packages/mxnet/libmxnet.so(+0x7fddcc)
[0x7f3a47998dcc]
[bt] (2) /usr/local/lib/python3.6/dist-packages/mxnet/libmxnet.so(+0x7e16bd)
[0x7f3a4797c6bd]
[bt] (3) /usr/lib/x86_64-linux-gnu/libstdc++.so.6(+0xb8c80) [0x7f3ab224bc80]
[bt] (4) /lib/x86_64-linux-gnu/libpthread.so.0(+0x76ba) [0x7f3b91c166ba]
[bt] (5) /lib/x86_64-linux-gnu/libc.so.6(clone+0x6d) [0x7f3b90df941d]
Traceback (most recent call last):
.......................
File "/usr/local/lib/python3.6/dist-packages/mxnet/ndarray/ndarray.py", line
1819, in wait_to_read
check_call(_LIB.MXNDArrayWaitToRead(self.handle))
File "/usr/local/lib/python3.6/dist-packages/mxnet/base.py", line 253, in
check_call
raise MXNetError(py_str(_LIB.MXGetLastError()))
mxnet.base.MXNetError: [22:49:32] src/operator/custom/custom.cc:417: Check
failed: reinterpret_cast<CustomOpFBFunc>(params.info->call
backs[kCustomOpBackward])( ptrs.size(), const_cast<void**>(ptrs.data()),
const_cast<int*>(tags.data()), reinterpret_cast<const int*>(req.data()),
static_cast<int>(ctx.is_train), params
.info->contexts[kCustomOpBackward]):
Stack trace:
[bt] (0) /usr/local/lib/python3.6/dist-packages/mxnet/libmxnet.so(+0x4b04cb)
[0x7f3a4764b4cb]
[bt] (1) /usr/local/lib/python3.6/dist-packages/mxnet/libmxnet.so(+0x7fddcc)
[0x7f3a47998dcc]
[bt] (2) /usr/local/lib/python3.6/dist-packages/mxnet/libmxnet.so(+0x7e16bd)
[0x7f3a4797c6bd]
[bt] (3) /usr/lib/x86_64-linux-gnu/libstdc++.so.6(+0xb8c80) [0x7f3ab224bc80]
[bt] (4) /lib/x86_64-linux-gnu/libpthread.so.0(+0x76ba) [0x7f3b91c166ba]
[bt] (5) /lib/x86_64-linux-gnu/libc.so.6(clone+0x6d) [0x7f3b90df941d]
Error in CustomOp.backward: Traceback (most recent call last):
File "/usr/local/lib/python3.6/dist-packages/mxnet/operator.py", line 1020,
in backward_entry
stype=stype))
File "/usr/local/lib/python3.6/dist-packages/mxnet/ndarray/sparse.py", line
1187, in _ndarray_cls
raise Exception("unknown storage type: %s"%stype)
Exception: unknown storage type: -1
Error in CustomOp.backward: Traceback (most recent call last):
File "/usr/local/lib/python3.6/dist-packages/mxnet/operator.py", line 1020,
in backward_entry
stype=stype))
File "/usr/local/lib/python3.6/dist-packages/mxnet/ndarray/sparse.py", line
1187, in _ndarray_cls
raise Exception("unknown storage type: %s"%stype)
Exception: unknown storage type: -1
---
[Visit
Topic](https://discuss.mxnet.apache.org/t/create-custom-op-with-auxiliary-states/6597/1)
or reply to this email to respond.
You are receiving this because you enabled mailing list mode.
To unsubscribe from these emails, [click
here](https://discuss.mxnet.apache.org/email/unsubscribe/f9038312333a4e61a818b318d3eaaf6d9bee196c814fd6d5e9d63ddec8a0cb5b).