sxjscience opened a new pull request #6696: URL: https://github.com/apache/incubator-tvm/pull/6696
Fix the MXNet 2.0 integration in Relay. Tested the BERT and ALBERT models with the new GluonNLP v1, and the tests passed. Unit tests will later be added in GluonNLP to ensure that most backbones can be run with the graph runtime.

```python
import mxnet as mx
import numpy as np
import gluonnlp
from gluonnlp.models import get_backbone
import numpy.testing as npt

mx.npx.set_np()

model_cls, cfg, tokenizer, backbone_param_path, _ = get_backbone('google_albert_base_v2')
model = model_cls.from_cfg(cfg)
model.load_parameters(backbone_param_path)
model.hybridize()

batch_size = 1
seq_length = 128
token_ids = mx.np.random.randint(0, cfg.MODEL.vocab_size, (batch_size, seq_length), dtype=np.int32)
token_types = mx.np.random.randint(0, 2, (batch_size, seq_length), dtype=np.int32)
valid_length = mx.np.random.randint(seq_length // 2, seq_length, (batch_size,), dtype=np.int32)
mx_out = model(token_ids, token_types, valid_length)

import tvm
from tvm import relay
import tvm.contrib.graph_runtime as runtime

shape_dict = {
    'data0': (batch_size, seq_length),
    'data1': (batch_size, seq_length),
    'data2': (batch_size,)
}
dtype_dict = {
    'data0': 'int32',
    'data1': 'int32',
    'data2': 'int32'
}
sym = model._cached_graph[1]
params = {}
for k, v in model.collect_params().items():
    params[v._var_name] = tvm.nd.array(v.data().asnumpy())
mod, params = relay.frontend.from_mxnet(sym, shape=shape_dict, dtype=dtype_dict, arg_params=params)
print(mod)

# G4
target = "cuda -model=t4"
with relay.build_config(opt_level=3, required_pass=["FastMath"]):
    graph, lib, cparams = relay.build(mod, target, params=params)
ctx = tvm.gpu()
rt = runtime.create(graph, lib, ctx)
rt.set_input(**cparams)
rt.set_input(data0=token_ids, data1=token_types, data2=valid_length)
rt.run()
for i in range(rt.get_num_outputs()):
    out = rt.get_output(i)
    print(out.asnumpy())
    # verify the correctness
    npt.assert_allclose(out.asnumpy(), mx_out[i].asnumpy(), rtol=1e-3, atol=1e-2)
```
---------------------------------------------------------------- 
This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org