sxjscience opened a new pull request #6699: URL: https://github.com/apache/incubator-tvm/pull/6699
Fix the MXNet 2.0 integration in Relay. Tested the BERT and ALBERT models with the new [GluonNLP 1.0](https://github.com/dmlc/gluon-nlp/tree/master) and they passed the test. I will later add unit tests on the GluonNLP side to ensure that the backbones can be run with the graph runtime.

```python
import mxnet as mx
import numpy as np
import gluonnlp
from gluonnlp.models import get_backbone
import numpy.testing as npt

mx.npx.set_np()
model_cls, cfg, tokenizer, backbone_param_path, _ = get_backbone('google_albert_base_v2')
model = model_cls.from_cfg(cfg)
model.load_parameters(backbone_param_path)
model.hybridize()
batch_size = 1
seq_length = 128
token_ids = mx.np.random.randint(0, cfg.MODEL.vocab_size, (batch_size, seq_length), dtype=np.int32)
token_types = mx.np.random.randint(0, 2, (batch_size, seq_length), dtype=np.int32)
valid_length = mx.np.random.randint(seq_length // 2, seq_length, (batch_size,), dtype=np.int32)
mx_out = model(token_ids, token_types, valid_length)

import tvm
from tvm import relay
import tvm.contrib.graph_runtime as runtime

shape_dict = {
    'data0': (batch_size, seq_length),
    'data1': (batch_size, seq_length),
    'data2': (batch_size,)
}
dtype_dict = {
    'data0': 'int32',
    'data1': 'int32',
    'data2': 'int32'
}
sym = model._cached_graph[1]
params = {}
for k, v in model.collect_params().items():
    params[v._var_name] = tvm.nd.array(v.data().asnumpy())
mod, params = relay.frontend.from_mxnet(sym, shape=shape_dict, dtype=dtype_dict, arg_params=params)
print(mod)

# G4
target = "cuda -model=t4"
with relay.build_config(opt_level=3, required_pass=["FastMath"]):
    graph, lib, cparams = relay.build(mod, target, params=params)
ctx = tvm.gpu()
rt = runtime.create(graph, lib, ctx)
rt.set_input(**cparams)
rt.set_input(data0=token_ids, data1=token_types, data2=valid_length)
rt.run()
for i in range(rt.get_num_outputs()):
    out = rt.get_output(i)
    print(out.asnumpy())
    # verify the correctness
    npt.assert_allclose(out.asnumpy(), mx_out[i].asnumpy(), rtol=1e-3, atol=1e-2)
```
---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org