[GitHub] [incubator-mxnet] Caenorst commented on a change in pull request #16408: Add MXNet Ops for fast multihead attention
Caenorst commented on a change in pull request #16408: Add MXNet Ops for fast multihead attention URL: https://github.com/apache/incubator-mxnet/pull/16408#discussion_r341646687 ## File path: src/operator/contrib/transformer-inl.h ## @@ -34,6 +34,19 @@ namespace mxnet { namespace op { +struct InterleavedMatMulParam : public dmlc::Parameter<InterleavedMatMulParam> { + int heads; + bool bwd_ignore_zero_init; + DMLC_DECLARE_PARAMETER(InterleavedMatMulParam) { +DMLC_DECLARE_FIELD(heads) +.describe("Set number of heads"); +DMLC_DECLARE_FIELD(bwd_ignore_zero_init) +.describe("Make backward pass ignore AddTo and not init to 0. " + " /!\\ Only enable with MXNET_ENABLE_EXEC_ADDTO fonctionality") Review comment: I found a way to raise an exception while covering my edge case, by checking with a custom ParamParser, as it's called during binding. Hope the solution suits you :) This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org With regards, Apache Git Services
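The ParamParser approach mentioned in the comment can be sketched as follows. This is a hypothetical Python illustration only (the actual fix is a C++ attribute parser registered for InterleavedMatMulParam; the function name and dict-based attributes here are invented), showing the idea of validating `bwd_ignore_zero_init` against the environment variable at the moment the parser runs:

```python
import os

def parse_interleaved_matmul_params(attrs):
    """Hypothetical sketch of the custom ParamParser check (names invented)."""
    heads = int(attrs.get("heads"))
    bwd_ignore_zero_init = attrs.get("bwd_ignore_zero_init", "False") == "True"
    # The parser runs while the symbol is being bound, which is also when
    # MXNET_ENABLE_EXEC_ADDTO is consulted -- so the check fires at the right
    # moment, unlike a check done at op registration or execution time.
    if bwd_ignore_zero_init and os.environ.get("MXNET_ENABLE_EXEC_ADDTO", "0") != "1":
        raise ValueError("bwd_ignore_zero_init=True requires MXNET_ENABLE_EXEC_ADDTO=1")
    return {"heads": heads, "bwd_ignore_zero_init": bwd_ignore_zero_init}

# With the feature enabled, parsing succeeds.
os.environ["MXNET_ENABLE_EXEC_ADDTO"] = "1"
params = parse_interleaved_matmul_params({"heads": "8", "bwd_ignore_zero_init": "True"})
assert params["heads"] == 8
```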
[GitHub] [incubator-mxnet] Caenorst commented on a change in pull request #16408: Add MXNet Ops for fast multihead attention
Caenorst commented on a change in pull request #16408: Add MXNet Ops for fast multihead attention URL: https://github.com/apache/incubator-mxnet/pull/16408#discussion_r341267939 ## File path: src/operator/contrib/transformer-inl.h ## @@ -34,6 +34,19 @@ namespace mxnet { namespace op { +struct InterleavedMatMulParam : public dmlc::Parameter<InterleavedMatMulParam> { + int heads; + bool bwd_ignore_zero_init; + DMLC_DECLARE_PARAMETER(InterleavedMatMulParam) { +DMLC_DECLARE_FIELD(heads) +.describe("Set number of heads"); +DMLC_DECLARE_FIELD(bwd_ignore_zero_init) +.describe("Make backward pass ignore AddTo and not init to 0. " + " /!\\ Only enable with MXNET_ENABLE_EXEC_ADDTO fonctionality") Review comment: It can happen, but probably not frequently. I can add a warning if you are fine with it.
[GitHub] [incubator-mxnet] Caenorst commented on a change in pull request #16408: Add MXNet Ops for fast multihead attention
Caenorst commented on a change in pull request #16408: Add MXNet Ops for fast multihead attention URL: https://github.com/apache/incubator-mxnet/pull/16408#discussion_r341182985 ## File path: src/operator/contrib/transformer-inl.h ## @@ -34,6 +34,19 @@ namespace mxnet { namespace op { +struct InterleavedMatMulParam : public dmlc::Parameter<InterleavedMatMulParam> { + int heads; + bool bwd_ignore_zero_init; + DMLC_DECLARE_PARAMETER(InterleavedMatMulParam) { +DMLC_DECLARE_FIELD(heads) +.describe("Set number of heads"); +DMLC_DECLARE_FIELD(bwd_ignore_zero_init) +.describe("Make backward pass ignore AddTo and not init to 0. " + " /!\\ Only enable with MXNET_ENABLE_EXEC_ADDTO fonctionality") Review comment: The problem is that MXNET_ENABLE_EXEC_ADDTO is used only during binding. So the user could potentially bind one symbol with MXNET_ENABLE_EXEC_ADDTO=0, then swap it to MXNET_ENABLE_EXEC_ADDTO=1 (for instance to bind another symbol). If I check for the flag, both networks will see it at 1, while only the 2nd network will actually use the functionality.
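The bind-time race described above can be illustrated with a toy model. This sketch (names invented, no MXNet involved) shows why reading the environment variable anywhere other than at binding gives the wrong answer for symbols bound earlier:

```python
import os

class FakeExecutor:
    """Toy stand-in for a bound executor: it snapshots the flag at bind time."""
    def __init__(self, name):
        self.name = name
        # Like MXNET_ENABLE_EXEC_ADDTO, the flag is read exactly once, during binding.
        self.addto = os.environ.get("MXNET_ENABLE_EXEC_ADDTO", "0") == "1"

os.environ["MXNET_ENABLE_EXEC_ADDTO"] = "0"
net_a = FakeExecutor("a")   # bound while the feature is off
os.environ["MXNET_ENABLE_EXEC_ADDTO"] = "1"
net_b = FakeExecutor("b")   # bound while the feature is on

# A check performed *after* both binds would read "1" for both networks,
# even though only net_b was actually bound with AddTo enabled.
late_check = os.environ["MXNET_ENABLE_EXEC_ADDTO"] == "1"
assert late_check and net_a.addto is False and net_b.addto is True
```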
[GitHub] [incubator-mxnet] Caenorst commented on a change in pull request #16408: Add MXNet Ops for fast multihead attention
Caenorst commented on a change in pull request #16408: Add MXNet Ops for fast multihead attention URL: https://github.com/apache/incubator-mxnet/pull/16408#discussion_r341173650 ## File path: tests/python/gpu/test_operator_gpu.py ## @@ -2493,13 +2493,334 @@ def test_arange_like_dtype(): x = mx.sym.Variable('x', dtype=t) y = mx.sym.reshape(x, shape=(0, 0, -1)) z = mx.sym.contrib.arange_like(y, axis=-1) - + mod = z.simple_bind(ctx=mx.gpu(0), x=(3, 4, 5, 6), grad_req='null') mod.arg_arrays[0][:] = np.random.normal(size=mod.arg_arrays[0].shape).astype(t) out = mod.forward(is_train=False) for v in out: assert v.dtype == t +@with_seed() +def check_multihead_attention_selfatt(bwd_ignore_zero_init): +def convert_weight(F, q_weight, k_weight, v_weight, num_heads): +q_weight = F.reshape(q_weight, shape=(num_heads, -1, 0), reverse=True) +k_weight = F.reshape(k_weight, shape=(num_heads, -1, 0), reverse=True) +v_weight = F.reshape(v_weight, shape=(num_heads, -1, 0), reverse=True) +all_weights = F.concat(q_weight, k_weight, v_weight, dim=-2) +all_weights = F.reshape(all_weights, shape=(-1, 0), reverse=True) +return all_weights + +def convert_bias(F, q_bias, k_bias, v_bias, num_heads): +q_bias = F.reshape(q_bias, shape=(num_heads, -1)) +k_bias = F.reshape(k_bias, shape=(num_heads, -1)) +v_bias = F.reshape(v_bias, shape=(num_heads, -1)) +all_bias = F.stack(q_bias, k_bias, v_bias, axis=1) +all_bias = F.reshape(all_bias, shape=(-1,)) +return all_bias + +dtype='float16' Review comment: I just added tests for fp32 (the ops support fp16 and fp32), and a warning on bwd_ignore_zero_init to avoid usage without the kAddTo functionality
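The `convert_weight` helper quoted above interleaves the per-head q/k/v projection rows into the single weight matrix the fused op expects. A numpy re-implementation of the same row layout (an illustrative sketch, not the test code itself):

```python
import numpy as np

def convert_weight(q_w, k_w, v_w, num_heads):
    # numpy equivalent of the test's convert_weight: split each projection
    # weight into per-head blocks, then interleave q/k/v blocks head by head.
    def split(w):
        return w.reshape(num_heads, -1, w.shape[1])   # (heads, head_dim, qkv_dim)
    all_w = np.concatenate([split(q_w), split(k_w), split(v_w)], axis=1)
    return all_w.reshape(-1, q_w.shape[1])            # (3 * heads * head_dim, qkv_dim)

num_heads, head_dim, qkv_dim = 2, 3, 4
q = np.zeros((num_heads * head_dim, qkv_dim))   # mark q rows with 0
k = np.ones_like(q)                             # mark k rows with 1
v = np.full_like(q, 2.0)                        # mark v rows with 2
w = convert_weight(q, k, v, num_heads)
# Per head: head_dim rows of q, then k, then v -- repeated for each head.
assert list(w[:, 0]) == [0, 0, 0, 1, 1, 1, 2, 2, 2] * num_heads
```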
[GitHub] [incubator-mxnet] Caenorst commented on a change in pull request #16408: Add MXNet Ops for fast multihead attention
Caenorst commented on a change in pull request #16408: Add MXNet Ops for fast multihead attention URL: https://github.com/apache/incubator-mxnet/pull/16408#discussion_r338185540 ## File path: src/operator/contrib/transformer-inl.h ## @@ -34,6 +34,18 @@ namespace mxnet { namespace op { +struct InterleavedMatMulParam : public dmlc::Parameter<InterleavedMatMulParam> { + int heads; + bool bwd_ignore_zero_init; + DMLC_DECLARE_PARAMETER(InterleavedMatMulParam) { +DMLC_DECLARE_FIELD(heads) +.describe("Set number of heads"); +DMLC_DECLARE_FIELD(bwd_ignore_zero_init) +.describe("Make backward pass ignore AddTo and not init to 0.") Review comment: Sure, it's kind of tricky to explain. Until we have the gradient accumulation feature working, the user should not use this flag; it cheats on purpose, relying on the fact that the two ops of self-attention use complementary parts of the input tensor (non-overlapping, and their sum is the full tensor). If you want to use gradient accumulation with another op, you should not enable this flag. Any indication on wording would be appreciated.
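The "complementary parts" argument can be made concrete with a small numpy sketch: because the two backward ops write disjoint slices whose union is the whole gradient buffer, plain assignment without zero-initialization gives the same result as zero-init followed by AddTo accumulation:

```python
import numpy as np

seq, rows = 6, 4
g_qk = np.zeros((rows, seq)); g_qk[:2] = 1.0    # op 1 contributes only to rows 0..1
g_val = np.zeros((rows, seq)); g_val[2:] = 2.0  # op 2 contributes only to rows 2..3

# AddTo semantics: initialize the gradient buffer to zero, then accumulate
# each op's contribution with +=.
grad_addto = np.zeros((rows, seq))
grad_addto += g_qk
grad_addto += g_val

# bwd_ignore_zero_init: skip the memset; each op assigns its own slice
# directly. Correct only because the slices are disjoint and cover the buffer.
grad_fast = np.empty((rows, seq))
grad_fast[:2] = g_qk[:2]
grad_fast[2:] = g_val[2:]

assert np.array_equal(grad_addto, grad_fast)
```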
[GitHub] [incubator-mxnet] Caenorst commented on a change in pull request #16408: Add MXNet Ops for fast multihead attention
Caenorst commented on a change in pull request #16408: Add MXNet Ops for fast multihead attention URL: https://github.com/apache/incubator-mxnet/pull/16408#discussion_r337828418 ## File path: tests/python/gpu/test_operator_gpu.py ## @@ -2493,13 +2493,334 @@ def test_arange_like_dtype(): x = mx.sym.Variable('x', dtype=t) y = mx.sym.reshape(x, shape=(0, 0, -1)) z = mx.sym.contrib.arange_like(y, axis=-1) - + mod = z.simple_bind(ctx=mx.gpu(0), x=(3, 4, 5, 6), grad_req='null') mod.arg_arrays[0][:] = np.random.normal(size=mod.arg_arrays[0].shape).astype(t) out = mod.forward(is_train=False) for v in out: assert v.dtype == t +@with_seed() +def check_multihead_attention_selfatt(bwd_ignore_zero_init): +def convert_weight(F, q_weight, k_weight, v_weight, num_heads): +q_weight = F.reshape(q_weight, shape=(num_heads, -1, 0), reverse=True) +k_weight = F.reshape(k_weight, shape=(num_heads, -1, 0), reverse=True) +v_weight = F.reshape(v_weight, shape=(num_heads, -1, 0), reverse=True) +all_weights = F.concat(q_weight, k_weight, v_weight, dim=-2) +all_weights = F.reshape(all_weights, shape=(-1, 0), reverse=True) +return all_weights + +def convert_bias(F, q_bias, k_bias, v_bias, num_heads): +q_bias = F.reshape(q_bias, shape=(num_heads, -1)) +k_bias = F.reshape(k_bias, shape=(num_heads, -1)) +v_bias = F.reshape(v_bias, shape=(num_heads, -1)) +all_bias = F.stack(q_bias, k_bias, v_bias, axis=1) +all_bias = F.reshape(all_bias, shape=(-1,)) +return all_bias + +dtype='float16' +batch_size = 2 +qkv_length = 7 # length of a sequence +qkv_dim = 9 # dimension of encoding +num_heads = 3 # number of attention head +head_dim = 5# head size +out_dim = 13 * num_heads +qkv_units = num_heads * head_dim + +arg_params = { +'qkv': mx.nd.array(np.random.rand(*(batch_size, qkv_length, qkv_dim)).astype(dtype) * 0.1, dtype=dtype), +'q_weight': mx.nd.array(np.random.rand(*(qkv_units, qkv_dim)).astype(dtype) * 0.1, dtype=dtype), +'k_weight': mx.nd.array(np.random.rand(*(qkv_units, qkv_dim)).astype(dtype) * 
0.1, dtype=dtype), +'v_weight': mx.nd.array(np.random.rand(*(qkv_units, qkv_dim)).astype(dtype) * 0.1, dtype=dtype), +'q_bias': mx.nd.array(np.random.rand(*(qkv_units,)).astype(dtype) * 0.1, dtype=dtype), +'k_bias': mx.nd.array(np.random.rand(*(qkv_units,)).astype(dtype) * 0.1, dtype=dtype), +'v_bias': mx.nd.array(np.random.rand(*(qkv_units,)).astype(dtype) * 0.1, dtype=dtype), +'out_weight': mx.nd.array(np.random.rand(*(out_dim, qkv_units)).astype(dtype) * 0.1, dtype=dtype), +'out_bias': mx.nd.array(np.random.rand(*(out_dim,)).astype(dtype) * 0.1, dtype=dtype), +} + +qkv = mx.sym.Variable('qkv') +sonde = mx.sym.Variable('sonde') +q_weight = mx.sym.Variable('q_weight') +k_weight = mx.sym.Variable('k_weight') +v_weight = mx.sym.Variable('v_weight') +q_bias = mx.sym.Variable('q_bias') +k_bias = mx.sym.Variable('k_bias') +v_bias = mx.sym.Variable('v_bias') +out_weight = mx.sym.Variable('out_weight') +out_bias = mx.sym.Variable('out_bias') +qkv_weight = convert_weight(mx.sym, q_weight, k_weight, v_weight, num_heads) +qkv_bias = convert_bias(mx.sym, q_bias, k_bias, v_bias, num_heads) +qkv = mx.sym.transpose(qkv, axes=(1, 0, 2)) +qkv_proj = mx.sym.FullyConnected(qkv, weight=qkv_weight, bias=qkv_bias, flatten=False, + num_hidden=qkv_units * 3, no_bias=False) +att_score = mx.sym.contrib.interleaved_matmul_selfatt_qk( +qkv_proj, heads=num_heads, bwd_ignore_zero_init=bwd_ignore_zero_init) +att_score = att_score + sonde +weighted_value = mx.sym.contrib.interleaved_matmul_selfatt_valatt( +qkv_proj, att_score, heads=num_heads, bwd_ignore_zero_init=bwd_ignore_zero_init) +output = mx.sym.FullyConnected(weighted_value, weight=out_weight, bias=out_bias, flatten=False, + num_hidden=out_dim, no_bias=False) +output = mx.sym.transpose(output, axes=(1, 0, 2)) +output = mx.sym.Group([output, att_score]) +executor = output.simple_bind(ctx=mx.gpu(0), + qkv=(batch_size, qkv_length, qkv_dim), + q_weight=(qkv_units, qkv_dim), + q_bias=(qkv_units,), + k_weight=(qkv_units, qkv_dim), + 
k_bias=(qkv_units,), + v_weight=(qkv_units, qkv_dim), + v_bias=(qkv_units,), + type_dict={'qkv': dtype, + 'q_weight': dtype, +
[GitHub] [incubator-mxnet] Caenorst commented on a change in pull request #16408: Add MXNet Ops for fast multihead attention
Caenorst commented on a change in pull request #16408: Add MXNet Ops for fast multihead attention URL: https://github.com/apache/incubator-mxnet/pull/16408#discussion_r337079894 ## File path: src/operator/contrib/transformer.cc ## @@ -29,6 +29,231 @@ namespace mxnet { namespace op { +DMLC_REGISTER_PARAMETER(InterleavedMatMulParam); + +static bool InterleavedMatMulSelfAttQKShape(const NodeAttrs& attrs, +mxnet::ShapeVector* in_shape, +mxnet::ShapeVector* out_shape) { + const auto& params = nnvm::get<InterleavedMatMulParam>(attrs.parsed); + CHECK_EQ(in_shape->size(), 1U) << "Input:[queries_keys_values] currently have, " + << in_shape->size() << " inputs"; + auto qkv_shape = in_shape->at(0); + CHECK_EQ(qkv_shape.ndim(), 3U) +<< "Input queries_keys_values should be 3D in seq_length-batch-proj_dim, " +<< "currently is: " << qkv_shape.ndim() << "D"; + out_shape->resize(1); + SHAPE_ASSIGN_CHECK(*out_shape, 0, +mxnet::TShape({params.heads * qkv_shape[1], qkv_shape[0], qkv_shape[0]})); + return true; +} + +static bool InterleavedMatMulSelfAttValAttShape(const NodeAttrs& attrs, +mxnet::ShapeVector* in_shape, +mxnet::ShapeVector* out_shape) { + CHECK_EQ(in_shape->size(), 2U) << "Input:[queries_keys_values, attention] currently have, " + << in_shape->size() << " inputs"; + auto qkv_shape = in_shape->at(0); + auto att_shape = in_shape->at(1); + CHECK_EQ(qkv_shape.ndim(), 3U) +<< "Input queries_keys_values should be 3D in seq_length-batch-3*proj_dim, " +<< "currently is: " << qkv_shape.ndim() << "D"; + CHECK_EQ(att_shape.ndim(), 3U) +<< "Input attention should be 3D in batch-seq_length-seq_length, " +<< "currently is: " << att_shape.ndim() << "D"; + CHECK_EQ(qkv_shape[0], att_shape[1]) +<< "queries_keys_values.shape[0] and attention.shape[1] should be the same, " +<< "currently are " << qkv_shape[0] << " and " << att_shape[1]; + CHECK_EQ(qkv_shape[0], att_shape[2]) +<< "queries_keys_values.shape[0] and attention.shape[2] should be the same, " +<< "currently are " << qkv_shape[0] << " and " << 
att_shape[2]; + CHECK_EQ(qkv_shape[2] % 3, 0) +<< "queries_keys_values.shape[2] should be a multiple of 3, " +<< "currently is " << qkv_shape[2]; + SHAPE_ASSIGN_CHECK(*out_shape, 0, +mxnet::TShape({qkv_shape[0], qkv_shape[1], qkv_shape[2] / 3})); + return true; +} + +static bool InterleavedMatMulEncDecQKShape(const NodeAttrs& attrs, + mxnet::ShapeVector* in_shape, + mxnet::ShapeVector* out_shape) { + const auto& params = nnvm::get<InterleavedMatMulParam>(attrs.parsed); + CHECK_EQ(in_shape->size(), 2U) << "Input:[queries, keys_values], currently have " + << in_shape->size() << " inputs"; + auto q_shape = in_shape->at(0); + auto kv_shape = in_shape->at(1); + CHECK_EQ(q_shape.ndim(), 3U) << "Input queries should be 3D in seq_length-batch-proj_dim, " + << "currently is " << q_shape.ndim() << "D"; + CHECK_EQ(kv_shape.ndim(), 3U) << "Input queries should be 3D in seq_length-batch-2*proj_dim, " +<< "currently is " << kv_shape.ndim() << "D"; + CHECK_EQ(q_shape[2] * 2, kv_shape[2]) +<< "keys_values.shape[2] should be equal to queries.shape[2] * 2, " +<< "currently are: " << kv_shape[2] << " and " << q_shape[2]; + CHECK_EQ(q_shape[1], kv_shape[1]) +<< "queries.shape[1] should be equal to keys_values.shape[1], " +<< "currently are: " << q_shape[1] << " and " << kv_shape[1]; + SHAPE_ASSIGN_CHECK(*out_shape, 0, + mxnet::TShape({q_shape[1] * params.heads, q_shape[0], kv_shape[0]})); + return true; +} + +static bool InterleavedMatMulEncDecValAttShape(const NodeAttrs& attrs, + mxnet::ShapeVector* in_shape, + mxnet::ShapeVector* out_shape) { + const auto& params = nnvm::get<InterleavedMatMulParam>(attrs.parsed); + CHECK_EQ(in_shape->size(), 2U) << "Input: [keys_values, attention], currently have " + << in_shape->size() << " inputs"; + auto kv_shape = in_shape->at(0); + auto att_shape = in_shape->at(1); + CHECK_EQ(kv_shape.ndim(), 3U) +<< "Input keys_values should be 3D in seq_length-batch-2*proj_dim, " +<< "currently is " << kv_shape.ndim() << "D"; + CHECK_EQ(att_shape.ndim(), 3U) +<< "Input attention should be 3D in 
batch-seq_length-seq_length, " +<< "currently is " << att_shape.ndim() << "D"; + CHECK_EQ(kv_shape[0], att_shape[2]) +<< "keys_values.shape[0] should be equal to attention.shape[2], currently are " +<< kv_shape[0] << " and " << att_shape[2]; + CHECK_EQ(kv_shape[1] * params.heads, att_shape[0])
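The shape rules these functions enforce can be summarized in a few lines of Python (a sketch mirroring the CHECKs above, not the registered C++ code):

```python
def selfatt_qk_out_shape(qkv_shape, heads):
    # Mirrors InterleavedMatMulSelfAttQKShape: qkv is (seq_length, batch, 3*proj_dim),
    # output is (heads * batch, seq_length, seq_length).
    seq, batch, _ = qkv_shape
    return (heads * batch, seq, seq)

def selfatt_valatt_out_shape(qkv_shape, att_shape):
    # Mirrors InterleavedMatMulSelfAttValAttShape, including its CHECKs.
    seq, batch, three_proj = qkv_shape
    assert att_shape[1] == seq and att_shape[2] == seq
    assert three_proj % 3 == 0
    return (seq, batch, three_proj // 3)

# seq_length=7, batch=2, proj_dim=15 (so qkv carries 3*15=45), heads=3:
assert selfatt_qk_out_shape((7, 2, 45), heads=3) == (6, 7, 7)
assert selfatt_valatt_out_shape((7, 2, 45), (6, 7, 7)) == (7, 2, 15)
```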
[GitHub] [incubator-mxnet] Caenorst commented on a change in pull request #16408: Add MXNet Ops for fast multihead attention
Caenorst commented on a change in pull request #16408: Add MXNet Ops for fast multihead attention URL: https://github.com/apache/incubator-mxnet/pull/16408#discussion_r336220153 ## File path: src/operator/contrib/transformer.cc
[GitHub] [incubator-mxnet] Caenorst commented on a change in pull request #16408: Add MXNet Ops for fast multihead attention
Caenorst commented on a change in pull request #16408: Add MXNet Ops for fast multihead attention URL: https://github.com/apache/incubator-mxnet/pull/16408#discussion_r336058241 ## File path: src/operator/contrib/transformer.cc
[GitHub] [incubator-mxnet] Caenorst commented on a change in pull request #16408: Add MXNet Ops for fast multihead attention
Caenorst commented on a change in pull request #16408: Add MXNet Ops for fast multihead attention URL: https://github.com/apache/incubator-mxnet/pull/16408#discussion_r335057869 ## File path: .gitmodules ## @@ -26,3 +26,6 @@ [submodule "3rdparty/nvidia_cub"] path = 3rdparty/nvidia_cub url = https://github.com/NVlabs/cub.git +[submodule "3rdparty/cutlass"] Review comment: Update: we decided to drop Cutlass for now; there are some cases in which cuBLAS actually works better, and it makes the PR simpler.
[GitHub] [incubator-mxnet] Caenorst commented on a change in pull request #16408: Add MXNet Ops for fast multihead attention
Caenorst commented on a change in pull request #16408: Add MXNet Ops for fast multihead attention URL: https://github.com/apache/incubator-mxnet/pull/16408#discussion_r333262363 ## File path: .gitmodules ## @@ -26,3 +26,6 @@ [submodule "3rdparty/nvidia_cub"] path = 3rdparty/nvidia_cub url = https://github.com/NVlabs/cub.git +[submodule "3rdparty/cutlass"] Review comment: Sure, you can start one. On my side, if there is no Cutlass then these ops will only work for certain dimensions.
[GitHub] [incubator-mxnet] Caenorst commented on a change in pull request #16408: Add MXNet Ops for fast multihead attention
Caenorst commented on a change in pull request #16408: Add MXNet Ops for fast multihead attention URL: https://github.com/apache/incubator-mxnet/pull/16408#discussion_r333172142 ## File path: src/operator/contrib/transformer.cu ## @@ -22,12 +22,898 @@ * \file transformer.cu * \brief GPU implementation of the operators used in Transformer */ + +#include +#include +#include +#include + #include #include "./transformer-inl.h" +#include "../../common/cuda_utils.h" + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/wmma_matrix.h" +#ifdef CUTLASS_USE_WMMA_API +#include "cutlass/gemm/wmma_gemm_traits.h" namespace mxnet { namespace op { +// gemm_switch_fp32accum and the functions called are almost fully copied from: +// MLPerf v0.6 submission repository from NVIDIA by https://github.com/kevinstephano +template<typename DType> +void CublasStridedBatchedGemm(mshadow::Stream<gpu>* s, bool transA, bool transB, + int32_t m, int32_t n, int32_t k, + float alpha, const DType* a, int32_t lda, int32_t strideA, + const DType *b, int32_t ldb, int32_t strideB, float beta, + DType *c, int32_t ldc, int32_t strideC, int32_t batchCount, + cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP) { + using namespace mxnet::common::cuda; + CHECK_EQ(s->blas_handle_ownership_, mshadow::Stream<gpu>::OwnHandle) + << "Must init CuBLAS handle in stream"; + + cublasHandle_t blas_handle = mshadow::Stream<gpu>::GetBlasHandle(s); + auto err = CUBLAS_STATUS_SUCCESS; + // TODO(cfujitsang): handle computation_precision + err = cublasGemmStridedBatchedEx( + blas_handle, CublasTransposeOp(transA), CublasTransposeOp(transB), + static_cast<int>(m), static_cast<int>(n), static_cast<int>(k), + reinterpret_cast<void*>(&alpha), + a, CublasType<DType>::kCudaFlag, static_cast<int>(lda), strideA, + b, CublasType<DType>::kCudaFlag, static_cast<int>(ldb), strideB, + reinterpret_cast<void*>(&beta), + c, CublasType<DType>::kCudaFlag, static_cast<int>(ldc), strideC, + static_cast<int>(batchCount), CUDA_R_32F, algo); + CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas gemmEx fail."; +} + 
+template<::cutlass::MatrixLayout::Kind A_LAYOUT, + ::cutlass::MatrixLayout::Kind B_LAYOUT, + int SRC_A, int SRC_B, int DST_C, typename DType> +void CutlassGemm_FP32Accum(cudaStream_t, int32_t m, int32_t n, int32_t k, + float alpha, const DType *a, int32_t lda, + int32_t strideA, const DType *b, int32_t ldb, + int32_t strideB, float beta, DType *c, int32_t ldc, + int32_t strideC, int32_t batchCount) { + LOG(FATAL) << "Not implemented with this DType and shape (Cutlass)"; +} + + +template<::cutlass::MatrixLayout::Kind A_LAYOUT, + ::cutlass::MatrixLayout::Kind B_LAYOUT, int SRC_A, int SRC_B, int DST_C> +void CutlassGemm_FP32Accum(cudaStream_t stream, int32_t m, int32_t n, int32_t k, + float alpha, const mshadow::half::half_t *a, int32_t lda, + int32_t strideA, const mshadow::half::half_t *b, int32_t ldb, + int32_t strideB, float beta, mshadow::half::half_t *c, int32_t ldc, + int32_t strideC, int32_t batchCount) { + typedef cutlass::gemm::WmmaGemmTraits< +A_LAYOUT, +B_LAYOUT, +cutlass::Shape<32, 16, 16>, +half, +half, +half, +cutlass::gemm::LinearScaling<float>, +float, +typename cutlass::gemm::WmmaGemmAccumulatorsPerWarp< + typename cutlass::Shape<32, 16, 16> >::Shape, + typename cutlass::Shape<16, 16, 16>, + SRC_A, // kScalarsPerLdgA_ + SRC_B, // kScalarsPerLdgB_ + SRC_A, // KScalarsPerLdsA_ + SRC_B, // KScalarsPerLdsB_ + DST_C, // kScalarsPerLdgCAndStgD_ + DST_C/2, // kScalarsPerStsD_ + DST_C/2 // kScalarsPerLdsD_ +> +WmmaGemmTraits; + + typedef cutlass::gemm::Gemm<WmmaGemmTraits> Gemm; + typename Gemm::Params params; + + + int result = params.initialize( +m, // M dimension for each batch +n, // N dimension for each batch +k, // K dimension for each batch +alpha, // scalar alpha +reinterpret_cast<const __half*>(a), +lda, +strideA, // distance in memory between the first element of neighboring batch +reinterpret_cast<const __half*>(b), +ldb, +strideB, // distance in memory between the first element of neighboring batch +beta, // scalar beta +reinterpret_cast<__half*>(c), // source matrix C +ldc, +strideC, // distance 
in memory between the first element of neighboring batch +reinterpret_cast<__half*>(c), // destination matrix C (may be different memory than C) +ldc, +strideC, // distance in memory between the first element of neighboring batch +batchCount); + + CHECK_EQ(result, 0) << "Failed to initialize CUTLASS Gemm::Params object."; + + //
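For reference, the strided batched GEMM used above performs one independent matrix multiply per batch element (here, per attention head and batch) in a single launch. A numpy model of its math (simplified: the explicit leading-dimension and stride arguments are left implicit in the 3-D arrays):

```python
import numpy as np

def strided_batched_gemm(a, b, c, alpha, beta, batch_count):
    # Model of what a strided batched GEMM (e.g. cublasGemmStridedBatchedEx)
    # computes for each batch index i:
    #   C[i] = alpha * A[i] @ B[i] + beta * C[i]
    # where A[i], B[i], C[i] are strideA/strideB/strideC elements apart.
    for i in range(batch_count):
        c[i] = alpha * (a[i] @ b[i]) + beta * c[i]
    return c

rng = np.random.default_rng(0)
batch, m, k, n = 4, 3, 5, 2
a = rng.standard_normal((batch, m, k))
b = rng.standard_normal((batch, k, n))
c = np.zeros((batch, m, n))
out = strided_batched_gemm(a, b, c, alpha=2.0, beta=0.0, batch_count=batch)
assert np.allclose(out, 2.0 * np.einsum("bij,bjk->bik", a, b))
```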
[GitHub] [incubator-mxnet] Caenorst commented on a change in pull request #16408: Add MXNet Ops for fast multihead attention
Caenorst commented on a change in pull request #16408: Add MXNet Ops for fast multihead attention URL: https://github.com/apache/incubator-mxnet/pull/16408#discussion_r333172227 ## File path: src/operator/contrib/transformer.cu ## @@ -22,12 +22,898 @@ * \file transformer.cu * \brief GPU implementation of the operators used in Transformer */ + +#include +#include +#include +#include + #include #include "./transformer-inl.h" +#include "../../common/cuda_utils.h" + +#include "cutlass/cutlass.h" Review comment: Nice catch, gonna fix that