sxjscience opened a new pull request #6696:
URL: https://github.com/apache/incubator-tvm/pull/6696


   Fix the MXNet 2.0 integration in Relay. Tested the BERT and ALBERT models in 
the new GluonNLP v1, and both pass. Unit tests will be added to GluonNLP later 
to ensure that most backbones can be run with the graph runtime.
   
   ```python
   import mxnet as mx
   import numpy as np
   import gluonnlp
   from gluonnlp.models import get_backbone
   import numpy.testing as npt

   # Turn on MXNet's NumPy-compatible mode before the model is constructed.
   mx.npx.set_np()

   # Wrapped inside the call parentheses so the statement survives 80-column
   # mail/markdown re-wrapping without becoming a syntax error.
   model_cls, cfg, tokenizer, backbone_param_path, _ = get_backbone(
       'google_albert_base_v2')

   model = model_cls.from_cfg(cfg)
   model.load_parameters(backbone_param_path)
   model.hybridize()


   # Random inputs; the forward pass below produces the MXNet reference output
   # that the TVM result is compared against at the end of the script.
   batch_size = 1
   seq_length = 128
   token_ids = mx.np.random.randint(
       0, cfg.MODEL.vocab_size, (batch_size, seq_length), dtype=np.int32)
   token_types = mx.np.random.randint(
       0, 2, (batch_size, seq_length), dtype=np.int32)
   valid_length = mx.np.random.randint(
       seq_length // 2, seq_length, (batch_size,), dtype=np.int32)
   mx_out = model(token_ids, token_types, valid_length)

   import tvm
   from tvm import relay
   import tvm.contrib.graph_runtime as runtime

   # Shapes and dtypes of the three graph inputs, keyed by the names that are
   # also used for rt.set_input below (data0, data1, data2).
   shape_dict = {
       'data0': (batch_size, seq_length),
       'data1': (batch_size, seq_length),
       'data2': (batch_size,)
   }

   dtype_dict = {
       'data0': 'int32',
       'data1': 'int32',
       'data2': 'int32'
   }

   # NOTE(review): relies on the private `_cached_graph` attribute, which is
   # only read here after hybridize() and one forward call -- confirm it stays
   # available in the targeted MXNet 2.0 builds.
   sym = model._cached_graph[1]

   # Copy every model parameter into a TVM NDArray, keyed by its variable name.
   params = {}
   for k, v in model.collect_params().items():
       params[v._var_name] = tvm.nd.array(v.data().asnumpy())
   mod, params = relay.frontend.from_mxnet(
       sym, shape=shape_dict, dtype=dtype_dict, arg_params=params)
   print(mod)
   # G4
   target = "cuda -model=t4"

   with relay.build_config(opt_level=3, required_pass=["FastMath"]):
       graph, lib, cparams = relay.build(mod, target, params=params)

   ctx = tvm.gpu()
   rt = runtime.create(graph, lib, ctx)
   rt.set_input(**cparams)
   rt.set_input(data0=token_ids, data1=token_types, data2=valid_length)
   rt.run()
   for i in range(rt.get_num_outputs()):
       out = rt.get_output(i)
       print(out.asnumpy())
       # verify the correctness
       npt.assert_allclose(
           out.asnumpy(), mx_out[i].asnumpy(), rtol=1e-3, atol=1e-2)
   ```


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


Reply via email to