tinywisdom opened a new issue, #18651:
URL: https://github.com/apache/tvm/issues/18651

   ### Summary
   
   When TVM is used in the same Python process as a PyTorch Lightning model that depends on `torchmetrics.Accuracy`, the process consistently segfaults inside LLVM initialization code.
   
   The minimal pattern is:
   
   1. Import TVM and create a CUDA target (which initializes TVM’s LLVM/CUDA stack).
   2. Then import `torchmetrics` and `pytorch_lightning`, define a simple `LightningModule`, and run a single forward pass.
   3. The process crashes with a segmentation fault while the imports in step (2) are loading shared libraries, before the forward pass runs. The top of the native backtrace points at `llvm::opt::OptTable::buildPrefixChars()` and `COFFOptTable::COFFOptTable()`.
   
   Without step (1), the same PyTorch Lightning + torchmetrics code runs normally in this environment.
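The with/without comparison can be checked mechanically by running each import sequence in a fresh interpreter and inspecting the child's exit status. A minimal sketch, assuming a POSIX system; the `WITH_TVM`/`WITHOUT_TVM` snippets and the `died_from_segfault` helper are hypothetical condensations of the reproducer, not part of it:

```python
import signal
import subprocess
import sys

# Hypothetical condensation of the reproducer: same import order,
# with and without step (1).
WITH_TVM = (
    "import torch\n"
    "import tvm\n"
    "tvm.target.Target('cuda')\n"
    "import torchmetrics, pytorch_lightning\n"
)
WITHOUT_TVM = (
    "import torch\n"
    "import torchmetrics, pytorch_lightning\n"
)


def died_from_segfault(code: str) -> bool:
    """Run `code` in a fresh interpreter; return True if SIGSEGV killed the child."""
    proc = subprocess.run([sys.executable, "-c", code], capture_output=True)
    # On POSIX, returncode == -N means the child was terminated by signal N.
    return proc.returncode == -signal.SIGSEGV
```

On the affected environment, `died_from_segfault(WITH_TVM)` would be expected to return `True` and `died_from_segfault(WITHOUT_TVM)` to return `False`, assuming the crash also reproduces in this reduced form.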
   
   ### Environment
   
   From the script output:
   
   * OS: Linux x86_64 (glibc-based, from backtrace paths such as `./elf/dl-open.c`)
   * Python: `3.10.16 | packaged by conda-forge | (main, Apr  8 2025, 20:53:32) [GCC 13.3.0]`
   * NumPy: `2.2.6`
   * PyTorch: `2.9.0+cu128`
   * TVM:
   
     * Version: `0.22.0`
     * LLVM version (reported by `tvm.support.libinfo()`): `17.0.6`
     * GIT_COMMIT_HASH: `9dbf3f22ff6f44962472f9af310fda368ca85ef2`
   * GPU / CUDA:
   
     * TVM target: `cuda -keys=cuda,gpu -arch=sm_86 -max_num_threads=1024 -thread_warp_size=32`
     * CUDA toolkit likely 12.8 (from PyTorch build tag `+cu128`)
   
   The installed `torchmetrics` and `pytorch_lightning` packages are standard pip/conda releases compatible with the PyTorch version above; their exact versions are not listed here.
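Because importing `torchmetrics` or `pytorch_lightning` is what triggers the crash, their exact versions can be collected more safely from installed package metadata using the standard-library `importlib.metadata`, which does not import the packages or load their native extensions. A minimal sketch; the distribution names in `PACKAGES` are assumptions about how the packages were installed:

```python
from importlib import metadata

# Assumed PyPI distribution names; "lightning" is the newer unified package,
# listed in case it is installed alongside (or instead of) pytorch-lightning.
PACKAGES = ["torch", "torchmetrics", "pytorch-lightning", "lightning"]


def installed_versions(names):
    """Map each distribution name to its installed version without importing it."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = "not installed"
    return versions


if __name__ == "__main__":
    for name, version in installed_versions(PACKAGES).items():
        print(f"{name}: {version}")
```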
   
   ### Minimal Reproduction
   
   ```python
   #!/usr/bin/env python3
   # -*- coding: utf-8 -*-
   
   """
   Minimal reproducer for a segfault when using TVM together with
   PyTorch Lightning + torchmetrics (issue 70426 model).
   
   Reproduction pattern:
     1) Import TVM and create a CUDA target (so LLVM/CUDA libraries are loaded).
     2) Import torchmetrics / pytorch_lightning.
     3) Define the following LightningModule (MyModel) and run a forward pass
        with input torch.rand(B, 1, 28, 28, dtype=torch.float32).
   
   On my environment this script crashes with a segmentation fault
   (before printing the final success message).
   """
   
   import sys
   import numpy as np
   import torch
   
   import tvm
   from tvm import relax, tir  # keep the same style as in my TVM-based tools
   
   
   def print_env_info():
       print("==== Environment ====")
       print("Python:", sys.version)
       print("NumPy version:", np.__version__)
       print("Torch version:", torch.__version__)
       print("TVM version:", getattr(tvm, "__version__", "unknown"))
       try:
           from tvm import support
           info = support.libinfo()
           print("TVM LLVM version:", info.get("LLVM_VERSION", "unknown"))
           print("TVM GIT_COMMIT_HASH:", info.get("GIT_COMMIT_HASH", "unknown"))
       except Exception as e:
           print("TVM libinfo not available:", repr(e))
       print("=====================\n")
   
   
   def main():
       print_env_info()
   
       # 1) Force TVM to initialize CUDA / LLVM stack in this process
       print("[REPRO] Creating TVM cuda target to load LLVM/CUDA libraries ...")
       try:
           target = tvm.target.Target("cuda")
           print("[REPRO] TVM Target:", target)
       except Exception as e:
           print("[REPRO] Failed to create cuda target:", repr(e))
           # Even if this fails, we still continue to import the model stack.
   
       # 2) Now import torchmetrics / pytorch_lightning and define the model
       print("[REPRO] Importing torchmetrics / pytorch_lightning and defining MyModel ...")
       from torch import nn
       from torchmetrics import Accuracy
       import pytorch_lightning as pl
   
       class MyModel(pl.LightningModule):
           def __init__(self):
               super().__init__()
               self.encoder = nn.Sequential(
                   nn.Linear(28 * 28, 64),
                   nn.ReLU(),
                   nn.Linear(64, 3),
               )
               self.decoder = nn.Sequential(
                   nn.Linear(3, 64),
                   nn.ReLU(),
                   nn.Linear(64, 28 * 28),
               )
   
           def forward(self, x):
               # original forward from issue 70426: return embedding
               embedding = self.encoder(x.view(x.size(0), -1))
               return embedding
   
           def training_step(self, batch, batch_idx):
               # problematic code from the original Lightning training snippet
               device = self.device
               num_samples = 1000
               num_classes = 34
               Y = torch.ones(num_samples, dtype=torch.long, device=device)
               X = torch.zeros(num_samples, num_classes, device=device)
               # torchmetrics >= 0.11 requires the `task` argument
               accuracy = Accuracy(
                   task="multiclass", num_classes=num_classes, average="none"
               ).to(device)
               accuracy(X, Y)  # triggers computation during step
   
               # Original autoencoder training logic
               x, y = batch
               x = x.view(x.size(0), -1)
               z = self.encoder(x)
               x_hat = self.decoder(z)
               loss = nn.MSELoss()(x_hat, x)
               self.log("train_loss", loss)
               return loss
   
           def configure_optimizers(self):
               return torch.optim.Adam(self.parameters(), lr=1e-3)
   
       def GetInput():
           # same input shape as in the original issue:
           # torch.rand(B, 1, 28, 28, dtype=torch.float32)
           return torch.rand(32, 1, 28, 28, dtype=torch.float32)
   
       # 3) Instantiate the model and run a simple forward pass
       print("[REPRO] Instantiating MyModel and running a forward pass ...")
       model = MyModel()
       x = GetInput()
       with torch.no_grad():
           out = model(x)
       print("[REPRO] Forward output shape:", tuple(out.shape))
   
       print("[REPRO] Script finished without segfault.")
   
   
   if __name__ == "__main__":
       main()
   ```
   
   On my machine, the script crashes with a segmentation fault right after:
   
   ```text
   [REPRO] Importing torchmetrics / pytorch_lightning and defining MyModel ...
   ```
   
   The final `[REPRO] Script finished without segfault.` line is never printed.
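Since the crash happens somewhere in the imports that follow that message, one way to narrow it down is to import the modules one at a time, logging each name (with a flush) before importing it, so the last logged name points at the import in flight. The standard `faulthandler` module can additionally dump the Python-level traceback when SIGSEGV arrives. A minimal sketch; the `MODULES` list and the `import_stepwise` helper are hypothetical:

```python
import faulthandler
import importlib
import sys

# Hypothetical split of step (2); torch.nn is already loaded once torch is imported.
MODULES = ["torchmetrics", "pytorch_lightning"]


def import_stepwise(names, log=None):
    """Import modules one at a time, logging each name *before* importing it."""
    log = log if log is not None else sys.stderr
    # On SIGSEGV, write the Python traceback of all threads to the original stderr.
    # faulthandler needs a file with a usable fileno(), hence sys.__stderr__.
    faulthandler.enable(file=sys.__stderr__, all_threads=True)
    loaded = []
    for name in names:
        print(f"[REPRO] importing {name} ...", file=log, flush=True)
        importlib.import_module(name)
        loaded.append(name)
    return loaded
```

In the reproducer, `import_stepwise(MODULES)` would take the place of the `torchmetrics` / `pytorch_lightning` imports, after the TVM target has been created.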
   
   ### Actual Behavior
   
   Console output (truncated):
   
   ```text
   ==== Environment ====
   Python: 3.10.16 | packaged by conda-forge | (main, Apr  8 2025, 20:53:32) [GCC 13.3.0]
   NumPy version: 2.2.6
   Torch version: 2.9.0+cu128
   TVM version: 0.22.0
   TVM LLVM version: 17.0.6
   TVM GIT_COMMIT_HASH: 9dbf3f22ff6f44962472f9af310fda368ca85ef2
   =====================
   
   [REPRO] Creating TVM cuda target to load LLVM/CUDA libraries ...
   [REPRO] TVM Target: cuda -keys=cuda,gpu -arch=sm_86 -max_num_threads=1024 -thread_warp_size=32
   [REPRO] Importing torchmetrics / pytorch_lightning and defining MyModel ...
   !!!!!!! Segfault encountered !!!!!!!
     File "./signal/../sysdeps/unix/sysv/linux/x86_64/libc_sigaction.c", line 0, in 0x00007998d2c4251f
     File "<unknown>", line 0, in llvm::opt::OptTable::buildPrefixChars()
     File "<unknown>", line 0, in COFFOptTable::COFFOptTable()
     File "<unknown>", line 0, in _GLOBAL__sub_I_COFFDirectiveParser.cpp
     File "./elf/dl-init.c", line 70, in call_init
     File "./elf/dl-init.c", line 33, in call_init
     File "./elf/dl-init.c", line 117, in _dl_init
     File "./elf/dl-error-skeleton.c", line 182, in __GI__dl_catch_exception
     File "./elf/dl-open.c", line 808, in dl_open_worker
     File "./elf/dl-open.c", line 771, in dl_open_worker
     File "./elf/dl-error-skeleton.c", line 208, in __GI__dl_catch_exception
     File "./elf/dl-open.c", line 883, in _dl_open
     File "./dlfcn/dlopen.c", line 56, in dlopen_doit
     File "./elf/dl-error-skeleton.c", line 208, in __GI__dl_catch_exception
     File "./elf/dl-error-skeleton.c", line 227, in __GI__dl_catch_error
     File "./dlfcn/dlerror.c", line 138, in _dlerror_run
     File "./dlfcn/dlopen.c", line 71, in dlopen_implementation
     File "./dlfcn/dlopen.c", line 81, in ___dlopen
     File "/usr/local/src/conda/python-3.10.16/Python/dynload_shlib.c", line 100, in _PyImport_FindSharedFuncptr
     File "/usr/local/src/conda/python-3.10.16/Python/importdl.c", line 137, in _PyImport_LoadDynamicModuleWithSpec
     ...
   
   Segmentation fault (core dumped)
   ```
   
   The key part of the native backtrace is the LLVM initialization:
   
   ```text
     in llvm::opt::OptTable::buildPrefixChars()
     in COFFOptTable::COFFOptTable()
     in _GLOBAL__sub_I_COFFDirectiveParser.cpp
   ```
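These frames come from LLVM's option-parsing code (`llvm::opt::OptTable`) running inside a static initializer (`_GLOBAL__sub_I_COFFDirectiveParser.cpp`) of a library being initialized by the `dlopen` call, while TVM's own LLVM 17.0.6 is already present in the process. To see which LLVM-related shared objects are already mapped at that point, one option is to scan `/proc/self/maps` (Linux only) just before step (2). A minimal sketch; the name filter is an assumption, and since TVM may link LLVM statically into `libtvm.so`, that library is matched as well:

```python
import re

# Hypothetical filter: mapped files whose path mentions LLVM or TVM.
# TVM may link LLVM statically into libtvm.so, so "tvm" is matched too.
DEFAULT_PATTERN = re.compile(r"llvm|tvm", re.IGNORECASE)


def mapped_shared_objects(maps_text, pattern=DEFAULT_PATTERN):
    """Return sorted, de-duplicated file paths from /proc/<pid>/maps text matching `pattern`."""
    paths = set()
    for line in maps_text.splitlines():
        # Format: address perms offset dev inode [pathname]; the pathname may contain spaces.
        fields = line.split(maxsplit=5)
        if len(fields) == 6 and fields[5].startswith("/") and pattern.search(fields[5]):
            paths.add(fields[5])
    return sorted(paths)
```

For example, calling `mapped_shared_objects(open("/proc/self/maps").read())` right before step (2) in the reproducer would list the matching libraries loaded at that point.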
   
   The segfault happens during dynamic loading of a shared library triggered by a module import, after TVM has already created a CUDA `Target`.
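The `_dl_init` and `call_init` frames indicate that the crash occurs while the dynamic loader runs the static constructors of a newly loaded library. With glibc, setting `LD_DEBUG=files` makes the loader print a `calling init: <path>` line to stderr before running each library's initializers, so the last such line before the segfault should name the library whose constructor crashed. A minimal sketch, assuming glibc; `run_with_ld_debug` and `last_initialized_library` are hypothetical helpers:

```python
import os
import re
import subprocess
import sys

# glibc's loader logs "<pid>:\tcalling init: <path>" just before it runs a
# library's constructors when LD_DEBUG=files is set.
INIT_RE = re.compile(r"calling init:\s*(\S+)")


def last_initialized_library(ld_debug_log):
    """Return the path from the last 'calling init:' line, or None if there is none."""
    matches = INIT_RE.findall(ld_debug_log)
    return matches[-1] if matches else None


def run_with_ld_debug(script_path):
    """Run a Python script under LD_DEBUG=files and return its stderr (loader log included)."""
    env = dict(os.environ, LD_DEBUG="files")
    proc = subprocess.run(
        [sys.executable, script_path],
        env=env,
        capture_output=True,
        text=True,
        errors="replace",  # the log may contain bytes that are not valid UTF-8
    )
    return proc.stderr
```

For example, `last_initialized_library(run_with_ld_debug("repro.py"))`, where `repro.py` is a hypothetical file name for the reproducer above.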
   
   
   ### Triage
   
   
   * needs-triage
   * bug
   


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

