Caenorst commented on a change in pull request #16408: Add MXNet Ops for fast multihead attention
URL: https://github.com/apache/incubator-mxnet/pull/16408#discussion_r336220153
##########
File path: src/operator/contrib/transformer.cc
##########
@@ -29,6 +29,231 @@ namespace mxnet {
 namespace op {
 
+DMLC_REGISTER_PARAMETER(InterleavedMatMulParam);
+
+static bool InterleavedMatMulSelfAttQKShape(const NodeAttrs& attrs,
+                                            mxnet::ShapeVector* in_shape,
+                                            mxnet::ShapeVector* out_shape) {
+  const auto& params = nnvm::get<InterleavedMatMulParam>(attrs.parsed);
+  CHECK_EQ(in_shape->size(), 1U) << "Input: [queries_keys_values], currently have "
+                                 << in_shape->size() << " inputs";
+  auto qkv_shape = in_shape->at(0);
+  CHECK_EQ(qkv_shape.ndim(), 3U)
+    << "Input queries_keys_values should be 3D in seq_length-batch-proj_dim, "
+    << "currently is: " << qkv_shape.ndim() << "D";
+  out_shape->resize(1);
+  SHAPE_ASSIGN_CHECK(*out_shape, 0,
+    mxnet::TShape({params.heads * qkv_shape[1], qkv_shape[0], qkv_shape[0]}));
+  return true;
+}
+
+static bool InterleavedMatMulSelfAttValAttShape(const NodeAttrs& attrs,
+                                                mxnet::ShapeVector* in_shape,
+                                                mxnet::ShapeVector* out_shape) {
+  CHECK_EQ(in_shape->size(), 2U) << "Input: [queries_keys_values, attention], currently have "
+                                 << in_shape->size() << " inputs";
+  auto qkv_shape = in_shape->at(0);
+  auto att_shape = in_shape->at(1);
+  CHECK_EQ(qkv_shape.ndim(), 3U)
+    << "Input queries_keys_values should be 3D in seq_length-batch-3*proj_dim, "
+    << "currently is: " << qkv_shape.ndim() << "D";
+  CHECK_EQ(att_shape.ndim(), 3U)
+    << "Input attention should be 3D in batch-seq_length-seq_length, "
+    << "currently is: " << att_shape.ndim() << "D";
+  CHECK_EQ(qkv_shape[0], att_shape[1])
+    << "queries_keys_values.shape[0] and attention.shape[1] should be the same, "
+    << "currently are " << qkv_shape[0] << " and " << att_shape[1];
+  CHECK_EQ(qkv_shape[0], att_shape[2])
+    << "queries_keys_values.shape[0] and attention.shape[2] should be the same, "
+    << "currently are " << qkv_shape[0] << " and " << att_shape[2];
+  CHECK_EQ(qkv_shape[2] % 3, 0)
+    << "queries_keys_values.shape[2] should be a multiple of 3, "
+    << "currently is " << qkv_shape[2];
+  SHAPE_ASSIGN_CHECK(*out_shape, 0,
+    mxnet::TShape({qkv_shape[0], qkv_shape[1], qkv_shape[2] / 3}));
+  return true;
+}
+
+static bool InterleavedMatMulEncDecQKShape(const NodeAttrs& attrs,
+                                           mxnet::ShapeVector* in_shape,
+                                           mxnet::ShapeVector* out_shape) {
+  const auto& params = nnvm::get<InterleavedMatMulParam>(attrs.parsed);
+  CHECK_EQ(in_shape->size(), 2U) << "Input: [queries, keys_values], currently have "
+                                 << in_shape->size() << " inputs";
+  auto q_shape = in_shape->at(0);
+  auto kv_shape = in_shape->at(1);
+  CHECK_EQ(q_shape.ndim(), 3U) << "Input queries should be 3D in seq_length-batch-proj_dim, "
+                               << "currently is " << q_shape.ndim() << "D";
+  CHECK_EQ(kv_shape.ndim(), 3U) << "Input keys_values should be 3D in seq_length-batch-2*proj_dim, "
+                                << "currently is " << kv_shape.ndim() << "D";
+  CHECK_EQ(q_shape[2] * 2, kv_shape[2])
+    << "keys_values.shape[2] should be equal to queries.shape[2] * 2, "
+    << "currently are: " << kv_shape[2] << " and " << q_shape[2];
+  CHECK_EQ(q_shape[1], kv_shape[1])
+    << "queries.shape[1] should be equal to keys_values.shape[1], "
+    << "currently are: " << q_shape[1] << " and " << kv_shape[1];
+  SHAPE_ASSIGN_CHECK(*out_shape, 0,
+    mxnet::TShape({q_shape[1] * params.heads, q_shape[0], kv_shape[0]}));
+  return true;
+}
+
+static bool InterleavedMatMulEncDecValAttShape(const NodeAttrs& attrs,
+                                               mxnet::ShapeVector* in_shape,
+                                               mxnet::ShapeVector* out_shape) {
+  const auto& params = nnvm::get<InterleavedMatMulParam>(attrs.parsed);
+  CHECK_EQ(in_shape->size(), 2U) << "Input: [keys_values, attention], currently have "
+                                 << in_shape->size() << " inputs";
+  auto kv_shape = in_shape->at(0);
+  auto att_shape = in_shape->at(1);
+  CHECK_EQ(kv_shape.ndim(), 3U)
+    << "Input keys_values should be 3D in seq_length-batch-2*proj_dim, "
+    << "currently is " << kv_shape.ndim() << "D";
+  CHECK_EQ(att_shape.ndim(), 3U)
+    << "Input attention should be 3D in batch-seq_length-seq_length, "
+    << "currently is " << att_shape.ndim() << "D";
+  CHECK_EQ(kv_shape[0], att_shape[2])
+    << "keys_values.shape[0] should be equal to attention.shape[2], currently are "
+    << kv_shape[0] << " and " << att_shape[2];
+  CHECK_EQ(kv_shape[1] * params.heads, att_shape[0]) << "attention.shape[0] "
+    << "should be equal to keys_values.shape[1] * heads, currently are: "
+    << att_shape[0] << " and " << kv_shape[1] * params.heads;
+  SHAPE_ASSIGN_CHECK(*out_shape, 0,
+    mxnet::TShape({att_shape[1], kv_shape[1], kv_shape[2] / 2}));
+  return true;
+}
+
+NNVM_REGISTER_OP(interleaved_matmul_selfatt_qk)
+.describe(R"code(Compute the matrix multiplication between the projections of
+queries and keys in multihead attention, used as self-attention.
+
+The input must be a single tensor of interleaved projections
+of queries, keys and values following the layout:
+(seq_length, batch_size, num_heads * head_dim * 3)

Review comment:
   Adding the concatenation reduces the speedup from multihead attention by about 20%. I think we can look into an improvement, but meanwhile this is still a speedup. I would encourage an analysis of how the LAMB coefficients differ within multihead attention blocks; maybe directly applying LAMB on the concatenation of the weights would be fine :man_shrugging:.
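   For reference, here is a rough NumPy sketch of the computation `interleaved_matmul_selfatt_qk` performs under the layout described above. The per-head `[q, k, v]` interleave order and the `1/sqrt(head_dim)` scaling are assumptions for illustration, not taken verbatim from this PR:

```python
import numpy as np

def interleaved_matmul_selfatt_qk_ref(qkv, heads):
    # qkv: (seq_len, batch, heads * head_dim * 3), assumed interleaved
    # per head as [q, k, v] along the last axis.
    seq_len, batch, proj3 = qkv.shape
    head_dim = proj3 // (heads * 3)
    # Split the last axis into (heads, 3, head_dim) and pick out q and k.
    tmp = qkv.reshape(seq_len, batch, heads, 3, head_dim)
    q = tmp[:, :, :, 0, :]  # (seq_len, batch, heads, head_dim)
    k = tmp[:, :, :, 1, :]
    # Fold heads into the batch axis: (batch * heads, seq_len, head_dim).
    q = q.transpose(1, 2, 0, 3).reshape(batch * heads, seq_len, head_dim)
    k = k.transpose(1, 2, 0, 3).reshape(batch * heads, seq_len, head_dim)
    # Scaled q @ k^T; the (batch * heads, seq_len, seq_len) result matches
    # the output shape declared in InterleavedMatMulSelfAttQKShape.
    return np.matmul(q / np.sqrt(head_dim), k.transpose(0, 2, 1))
```

   The fused op can read q and k directly out of the interleaved buffer, so the reshape/transpose copies above never need to be materialized, which is presumably where the speedup being discussed comes from.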