Caenorst commented on a change in pull request #16408: Add MXNet Ops for fast multihead attention
URL: https://github.com/apache/incubator-mxnet/pull/16408#discussion_r337828418
########## File path: tests/python/gpu/test_operator_gpu.py ##########

@@ -2493,13 +2493,334 @@ def test_arange_like_dtype():
         x = mx.sym.Variable('x', dtype=t)
         y = mx.sym.reshape(x, shape=(0, 0, -1))
         z = mx.sym.contrib.arange_like(y, axis=-1)
-
+
         mod = z.simple_bind(ctx=mx.gpu(0), x=(3, 4, 5, 6), grad_req='null')
         mod.arg_arrays[0][:] = np.random.normal(size=mod.arg_arrays[0].shape).astype(t)
         out = mod.forward(is_train=False)
         for v in out:
             assert v.dtype == t
+@with_seed()
+def check_multihead_attention_selfatt(bwd_ignore_zero_init):
+    def convert_weight(F, q_weight, k_weight, v_weight, num_heads):
+        q_weight = F.reshape(q_weight, shape=(num_heads, -1, 0), reverse=True)
+        k_weight = F.reshape(k_weight, shape=(num_heads, -1, 0), reverse=True)
+        v_weight = F.reshape(v_weight, shape=(num_heads, -1, 0), reverse=True)
+        all_weights = F.concat(q_weight, k_weight, v_weight, dim=-2)
+        all_weights = F.reshape(all_weights, shape=(-1, 0), reverse=True)
+        return all_weights
+
+    def convert_bias(F, q_bias, k_bias, v_bias, num_heads):
+        q_bias = F.reshape(q_bias, shape=(num_heads, -1))
+        k_bias = F.reshape(k_bias, shape=(num_heads, -1))
+        v_bias = F.reshape(v_bias, shape=(num_heads, -1))
+        all_bias = F.stack(q_bias, k_bias, v_bias, axis=1)
+        all_bias = F.reshape(all_bias, shape=(-1,))
+        return all_bias
+
+    dtype = 'float16'
+    batch_size = 2
+    qkv_length = 7  # length of a sequence
+    qkv_dim = 9  # dimension of encoding
+    num_heads = 3  # number of attention heads
+    head_dim = 5  # head size
+    out_dim = 13 * num_heads
+    qkv_units = num_heads * head_dim
+
+    arg_params = {
+        'qkv': mx.nd.array(np.random.rand(*(batch_size, qkv_length, qkv_dim)).astype(dtype) * 0.1, dtype=dtype),
+        'q_weight': mx.nd.array(np.random.rand(*(qkv_units, qkv_dim)).astype(dtype) * 0.1, dtype=dtype),
+        'k_weight': mx.nd.array(np.random.rand(*(qkv_units, qkv_dim)).astype(dtype) * 0.1, dtype=dtype),
+        'v_weight': mx.nd.array(np.random.rand(*(qkv_units, qkv_dim)).astype(dtype) * 0.1, dtype=dtype),
+        'q_bias': mx.nd.array(np.random.rand(*(qkv_units,)).astype(dtype) * 0.1, dtype=dtype),
+        'k_bias': mx.nd.array(np.random.rand(*(qkv_units,)).astype(dtype) * 0.1, dtype=dtype),
+        'v_bias': mx.nd.array(np.random.rand(*(qkv_units,)).astype(dtype) * 0.1, dtype=dtype),
+        'out_weight': mx.nd.array(np.random.rand(*(out_dim, qkv_units)).astype(dtype) * 0.1, dtype=dtype),
+        'out_bias': mx.nd.array(np.random.rand(*(out_dim,)).astype(dtype) * 0.1, dtype=dtype),
+    }
+
+    qkv = mx.sym.Variable('qkv')
+    sonde = mx.sym.Variable('sonde')
+    q_weight = mx.sym.Variable('q_weight')
+    k_weight = mx.sym.Variable('k_weight')
+    v_weight = mx.sym.Variable('v_weight')
+    q_bias = mx.sym.Variable('q_bias')
+    k_bias = mx.sym.Variable('k_bias')
+    v_bias = mx.sym.Variable('v_bias')
+    out_weight = mx.sym.Variable('out_weight')
+    out_bias = mx.sym.Variable('out_bias')
+    qkv_weight = convert_weight(mx.sym, q_weight, k_weight, v_weight, num_heads)
+    qkv_bias = convert_bias(mx.sym, q_bias, k_bias, v_bias, num_heads)
+    qkv = mx.sym.transpose(qkv, axes=(1, 0, 2))
+    qkv_proj = mx.sym.FullyConnected(qkv, weight=qkv_weight, bias=qkv_bias, flatten=False,
+                                     num_hidden=qkv_units * 3, no_bias=False)
+    att_score = mx.sym.contrib.interleaved_matmul_selfatt_qk(
+        qkv_proj, heads=num_heads, bwd_ignore_zero_init=bwd_ignore_zero_init)
+    att_score = att_score + sonde
+    weighted_value = mx.sym.contrib.interleaved_matmul_selfatt_valatt(
+        qkv_proj, att_score, heads=num_heads, bwd_ignore_zero_init=bwd_ignore_zero_init)
+    output = mx.sym.FullyConnected(weighted_value, weight=out_weight, bias=out_bias, flatten=False,
+                                   num_hidden=out_dim, no_bias=False)
+    output = mx.sym.transpose(output, axes=(1, 0, 2))
+    output = mx.sym.Group([output, att_score])
+    executor = output.simple_bind(ctx=mx.gpu(0),
+                                  qkv=(batch_size, qkv_length, qkv_dim),
+                                  q_weight=(qkv_units, qkv_dim),
+                                  q_bias=(qkv_units,),
+                                  k_weight=(qkv_units, qkv_dim),
+                                  k_bias=(qkv_units,),
+                                  v_weight=(qkv_units, qkv_dim),
+                                  v_bias=(qkv_units,),
+                                  type_dict={'qkv': dtype,
+                                             'q_weight': dtype,
+                                             'k_weight': dtype,
+                                             'v_weight': dtype,
+                                             'q_bias': dtype,
+                                             'k_bias': dtype,
+                                             'v_bias': dtype,
+                                             'sonde': dtype},
+                                  grad_req='write', force_rebind=True)
+    output_shape = executor.outputs[0].shape
+    output_grads = np.random.rand(*output_shape).astype(dtype) * 0.1
+    executor.copy_params_from(arg_params, {})
+    executor.arg_dict['sonde'][:] = 0.
+    executor.arg_dict['sonde'].wait_to_read()
+    executor.forward(is_train=True)
+    output_opti = executor.outputs[0].asnumpy()
+    att_score_opti = executor.outputs[1].asnumpy()
+    executor.backward([mx.nd.array(output_grads, dtype=dtype),
+                       mx.nd.zeros(att_score_opti.shape, dtype=dtype)])
+    grads_opti = {k: v.asnumpy() for k, v in executor.grad_dict.items()}
+    qkv = mx.sym.Variable('qkv')
+    sonde = mx.sym.Variable('sonde')
+    q_weight = mx.sym.Variable('q_weight')
+    k_weight = mx.sym.Variable('k_weight')
+    v_weight = mx.sym.Variable('v_weight')
+    q_bias = mx.sym.Variable('q_bias')
+    k_bias = mx.sym.Variable('k_bias')
+    v_bias = mx.sym.Variable('v_bias')
+    out_weight = mx.sym.Variable('out_weight')
+    out_bias = mx.sym.Variable('out_bias')
+
+    q = mx.sym.FullyConnected(qkv, weight=q_weight, bias=q_bias, flatten=False,
+                              num_hidden=qkv_units, no_bias=False)
+    k = mx.sym.FullyConnected(qkv, weight=k_weight, bias=k_bias, flatten=False,
+                              num_hidden=qkv_units, no_bias=False)
+    v = mx.sym.FullyConnected(qkv, weight=v_weight, bias=v_bias, flatten=False,
+                              num_hidden=qkv_units, no_bias=False)
+    q = mx.sym.reshape(q, shape=(0, 0, num_heads, -1))
+    q = mx.sym.transpose(q, axes=(0, 2, 1, 3))
+    q = mx.sym.reshape(q, shape=(-1, 0, 0), reverse=True)
+    k = mx.sym.reshape(k, shape=(0, 0, num_heads, -1))
+    k = mx.sym.transpose(k, axes=(0, 2, 1, 3))
+    k = mx.sym.reshape(k, shape=(-1, 0, 0), reverse=True)
+    q = mx.sym.contrib.div_sqrt_dim(q)
+    att_score = mx.sym.batch_dot(q, k, transpose_b=True)
+    att_score = att_score + sonde
+    v = mx.sym.reshape(v, shape=(0, 0, num_heads, -1))
+    v = mx.sym.transpose(v, axes=(0, 2, 1, 3))
+    v = mx.sym.reshape(v, shape=(-1, 0, 0), reverse=True)
+    weighted_value = mx.sym.batch_dot(att_score, v)
+    weighted_value = mx.sym.reshape(weighted_value, shape=(-1, num_heads, 0, 0),
+                                    reverse=True)
+    weighted_value = mx.sym.transpose(weighted_value, axes=(0, 2, 1, 3))
+    weighted_value = mx.sym.reshape(weighted_value, shape=(0, 0, -1))
+    output = mx.sym.FullyConnected(weighted_value, weight=out_weight, bias=out_bias, flatten=False,
+                                   num_hidden=out_dim, no_bias=False)
+    output = mx.sym.Group([output, att_score])
+    executor = output.simple_bind(ctx=mx.gpu(0),
+                                  qkv=(batch_size, qkv_length, qkv_dim),
+                                  type_dict={'qkv': dtype},
+                                  grad_req='write', force_rebind=True)
+    executor.copy_params_from(arg_params, {})
+    executor.arg_dict['sonde'][:] = 0.
+    executor.arg_dict['sonde'].wait_to_read()
+    executor.forward(is_train=True)
+    output_orig = executor.outputs[0].asnumpy()
+    att_score_orig = executor.outputs[1].asnumpy()
+    executor.backward([mx.nd.array(output_grads, dtype=dtype),
+                       mx.nd.zeros(att_score_orig.shape, dtype=dtype)])
+    grads_orig = {k: v.asnumpy() for k, v in executor.grad_dict.items()}
+    assert_allclose(att_score_orig, att_score_opti, rtol=1e-2, atol=1e-3)
+    assert_allclose(output_orig, output_opti, rtol=1e-2, atol=1e-3)
+
+    for k in grads_opti.keys():
+        assert grads_orig[k].dtype == grads_opti[k].dtype
+        assert grads_orig[k].shape == grads_opti[k].shape
+        assert_allclose(grads_orig[k], grads_opti[k], rtol=1e-2, atol=1e-3)
+
+def test_multihead_attention_selfatt():
+    # os.environ['MXNET_EXEC_ENABLE_ADDTO'] = '0'
+    check_multihead_attention_selfatt(bwd_ignore_zero_init=False)
+    # os.environ['MXNET_EXEC_ENABLE_ADDTO'] = '1'
+    # check_multihead_attention_selfatt(bwd_ignore_zero_init=False)
+    # check_multihead_attention_selfatt(bwd_ignore_zero_init=True)

Review comment:

My comment https://github.com/apache/incubator-mxnet/pull/16408#issuecomment-543947555 refers to this modification. Apparently the usage of kAddTo is not working yet in this repository, and the behavior with `bwd_ignore_zero_init=True` is only correct if you enable it. I can add a little TODO comment like "uncomment once the kAddTo feature is working".
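For illustration, a sketch of what that TODO could look like, based only on the commented-out lines in the diff above (the exact wording and placement are not final):

    def test_multihead_attention_selfatt():
        check_multihead_attention_selfatt(bwd_ignore_zero_init=False)
        # TODO: uncomment once the kAddTo feature is working.
        # kAddTo is only exercised when MXNET_EXEC_ENABLE_ADDTO=1, and
        # bwd_ignore_zero_init=True is only correct in that mode.
        # os.environ['MXNET_EXEC_ENABLE_ADDTO'] = '1'
        # check_multihead_attention_selfatt(bwd_ignore_zero_init=False)
        # check_multihead_attention_selfatt(bwd_ignore_zero_init=True)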