crawlingcub opened a new issue, #12627: URL: https://github.com/apache/tvm/issues/12627
The outputs of a `Linear` layer differ when there are very large/small values in the weights/inputs. PyTorch produces `-inf` for some extreme values, but TVM does not. Is that expected? Is there a subtle difference in how float32/float64 arithmetic is applied in TVM vs. PyTorch? Let me know if you need more info.

The model and input files are [here](https://drive.google.com/drive/folders/1VN9xWk773GRBxEqkL6xEMpWbpCsbUid9?usp=sharing). The model is a single linear layer: `Linear(in_features=120, out_features=84, bias=True)`.

Repro code:

```python
import os
import pickle as pkl
import sys

import numpy as np
import torch
import tvm
from tvm import relay
from tvm.contrib import graph_executor

DEVICE = 'cuda'

# Reference output from PyTorch on the GPU
model = torch.load(os.path.join(sys.argv[1], "model.pt"))
pt_inp = torch.Tensor(pkl.load(open(os.path.join(sys.argv[1], "data.pkl"), "rb")))
pt_inp = pt_inp.to(DEVICE)
model.to(DEVICE)
with torch.no_grad():
    pt_out = model(pt_inp).cpu().numpy()
model.to('cpu')
pt_inp = pt_inp.to('cpu')

# Import the traced model into TVM and build for CUDA
input_name = "input0"
target = tvm.target.cuda()
dev = tvm.cuda(0)
scripted_model = torch.jit.trace(model, pt_inp).eval()
input_data = pt_inp.numpy()
shape_list = [(input_name, input_data.shape)]
mod, params = relay.frontend.from_pytorch(scripted_model, shape_list)
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(mod, target=target, params=params)

# Run the compiled module and compare against PyTorch
m = graph_executor.GraphModule(lib["default"](dev))
m.set_input(input_name, tvm.nd.array(input_data))
m.run()
tvm_out = m.get_output(0).numpy()
print(np.max(np.abs(tvm_out - pt_out)), np.mean(np.abs(tvm_out - pt_out)))
print(pt_out[0][50:60])
print(tvm_out[0][50:60])
np.testing.assert_allclose(tvm_out[0][50:60], pt_out[0][50:60], rtol=1e-5, atol=1e-5)
```

On my machine, the difference shows up at indices 50-60; it could be different on your machine.

### Expected behavior

The outputs should be almost the same.
### Actual behavior

The outputs differ beyond the typical tolerance:

```
[-5.2627316e+37 -8.1805114e+37  1.8148990e+38            -inf  3.4717639e+37
 -4.2125961e+37  1.1323989e+37 -8.9163937e+37  4.2243163e+37  8.5123610e+37]
[-5.2627316e+37 -8.1805114e+37  1.8148992e+38 -3.0406243e+38  3.4717624e+37
 -4.2125966e+37  1.1323974e+37 -8.9163927e+37  4.2243173e+37  8.5123630e+37]
Traceback (most recent call last):
    np.testing.assert_allclose(tvm_out[0][50:60], pt_out[0][50:60], rtol=1e-5, atol=1e-5)
  File "pt1121/lib/python3.8/site-packages/numpy/testing/_private/utils.py", line 1527, in assert_allclose
    assert_array_compare(compare, actual, desired, err_msg=str(err_msg),
  File "pt1121/lib/python3.8/site-packages/numpy/testing/_private/utils.py", line 774, in assert_array_compare
    flagged |= func_assert_same_pos(x, y,
  File "pt1121/lib/python3.8/site-packages/numpy/testing/_private/utils.py", line 745, in func_assert_same_pos
    raise AssertionError(msg)
AssertionError:
Not equal to tolerance rtol=1e-05, atol=1e-05

x and y -inf location mismatch:
 x: array([-5.262732e+37, -8.180511e+37,  1.814899e+38, -3.040624e+38,
        3.471762e+37, -4.212597e+37,  1.132397e+37, -8.916393e+37,
        4.224317e+37,  8.512363e+37], dtype=float32)
 y: array([-5.262732e+37, -8.180511e+37,  1.814899e+38,           -inf,
        3.471764e+37, -4.212596e+37,  1.132399e+37, -8.916394e+37,
        4.224316e+37,  8.512361e+37], dtype=float32)
```

### Environment

```
torch==1.12.1
torchvision==0.13.1
python 3.8.13
Ubuntu 18.04
```

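To check whether the exact results really fall outside the float32 range, the reduction can be redone in float64. A self-contained numpy sketch using hypothetical random large-magnitude data as a stand-in for the issue's actual weights/inputs:

```python
import numpy as np

rng = np.random.default_rng(0)
# Hypothetical stand-ins with magnitudes like the issue's tensors: dot
# products over 120 terms of size ~1e19 * 1e19 land near the float32 max.
x = rng.uniform(-1e19, 1e19, size=(1, 120)).astype(np.float32)
w = rng.uniform(-1e19, 1e19, size=(84, 120)).astype(np.float32)

out32 = x @ w.T                                        # float32 accumulation
out64 = x.astype(np.float64) @ w.astype(np.float64).T  # float64 reference

# Entries whose exact value exceeds the float32 range become non-finite
# in the float32 product but remain finite in the float64 reference.
print(np.isfinite(out32).all(), np.isfinite(out64).all())
```

If the float64 reference for the mismatching index is beyond ±3.4e38, then PyTorch's `-inf` is the faithful float32 rounding and TVM's finite value is the artifact of a different accumulation order.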