gemini-code-assist[bot] commented on code in PR #440:
URL: https://github.com/apache/tvm-ffi/pull/440#discussion_r2792304116
##########
python/tvm_ffi/cpp/extension.py:
##########
@@ -154,6 +154,23 @@ def _get_cuda_target() -> str:
         major, minor = compute_cap.split(".")
         return f"-gencode=arch=compute_{major}{minor},code=sm_{major}{minor}"
     except Exception:
+        try:
+            # For old drivers, there is no compute_cap, but we can use the GPU name to determine the architecture.
+            status = subprocess.run(
+                args=["nvidia-smi", "--query-gpu=name", "--format=csv,noheader"],
+                capture_output=True,
+                check=True,
+            )
+            gpu_name = status.stdout.decode("utf-8").strip().split("\n")[0]
+            if "A100" in gpu_name:
+                major, minor = "8", "0"
+            elif "A10" in gpu_name:
+                major, minor = "8", "6"
+            else:
+                raise Exception(f"Unsupported GPU: {gpu_name}")
+            return f"-gencode=arch=compute_{major}{minor},code=sm_{major}{minor}"
+        except Exception:
+            pass
Review Comment:

The fallback logic for Ampere GPUs is a great addition. However, the
implementation can be improved for better maintainability and robustness.
1. **Hardcoded GPU checks**: The `if/elif` chain for GPU models is not
easily extensible. Using a dictionary to map GPU names to their compute
capabilities would be more maintainable.
2. **Broad exception handling**: `except Exception: pass` can hide
unexpected errors. It's better to catch specific exceptions like
`subprocess.CalledProcessError` and `FileNotFoundError`.
3. **`subprocess.run` usage**: You can use `text=True` to have `subprocess`
handle the decoding of stdout automatically.
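Points 2 and 3 can be seen in isolation in the minimal sketch below. The `nvidia-smi` invocation is the one from this PR. The helper name `query_gpu_name` and its `cmd` parameter are hypothetical and are not part of tvm_ffi; `cmd` exists only so the failure path can be exercised without a GPU. The helper catches only the two failure modes named above and lets `text=True` do the decoding.

```python
import subprocess
from typing import Optional, Sequence


def query_gpu_name(
    cmd: Sequence[str] = ("nvidia-smi", "--query-gpu=name", "--format=csv,noheader"),
) -> Optional[str]:
    """Return the first GPU name printed by `cmd`, or None if the query fails.

    Hypothetical helper for illustration only; not part of tvm_ffi.
    """
    try:
        status = subprocess.run(
            args=list(cmd),
            capture_output=True,
            check=True,  # raise CalledProcessError on a non-zero exit code
            text=True,  # decode stdout/stderr to str instead of bytes
        )
    except (subprocess.CalledProcessError, FileNotFoundError):
        # The binary is missing or exited with an error. Any other
        # exception (e.g. PermissionError) still propagates to the caller.
        return None
    lines = status.stdout.strip().splitlines()
    return lines[0] if lines else None
```

On a machine without `nvidia-smi`, `query_gpu_name()` returns `None` instead of raising. An unrelated error still surfaces, so it is not silently swallowed.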
Here is a suggested refactoring that incorporates these points. This version
is cleaner, more robust, and easier to extend with more Ampere GPUs in the
future. If no matching GPU is found in the map, it will naturally fall through
to the default `sm_70` architecture, which preserves the intended logic.
```suggestion
        try:
            # For old drivers, there is no compute_cap, but we can use the GPU name to determine the architecture.
            ampere_arch_map = {
                "A100": ("8", "0"),
                "A10": ("8", "6"),
            }
            status = subprocess.run(
                args=["nvidia-smi", "--query-gpu=name", "--format=csv,noheader"],
                capture_output=True,
                check=True,
                text=True,
            )
            gpu_name = status.stdout.strip().split("\n")[0]
            for gpu_key, (major, minor) in ampere_arch_map.items():
                if gpu_key in gpu_name:
                    return f"-gencode=arch=compute_{major}{minor},code=sm_{major}{minor}"
        except (subprocess.CalledProcessError, FileNotFoundError):
            # nvidia-smi is missing or failed; proceed to the default fallback.
            # An unmatched GPU name never reaches this handler: the loop simply
            # ends and execution falls through to the same default.
            pass
```
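One subtlety in the suggestion: `"A10"` is a substring of `"A100"`. The lookup is correct only because dicts iterate in insertion order (Python 3.7+) and `"A100"` happens to be listed first. The minimal sketch below makes that ordering explicit by checking longer keys first, so future entries cannot reintroduce the shadowing. It assumes a hypothetical standalone function `arch_flag_from_name` and a module-level `AMPERE_ARCH_MAP`; neither name comes from the PR.

```python
from typing import Dict, Optional, Tuple

# Hypothetical table mirroring the suggestion; values are (major, minor).
AMPERE_ARCH_MAP: Dict[str, Tuple[str, str]] = {
    "A100": ("8", "0"),
    "A10": ("8", "6"),
}


def arch_flag_from_name(
    gpu_name: str, arch_map: Dict[str, Tuple[str, str]] = AMPERE_ARCH_MAP
) -> Optional[str]:
    """Map an nvidia-smi GPU name to a -gencode flag, or None if unknown."""
    # Check longer keys first so "A10" can never shadow "A100",
    # regardless of the order in which entries were added to the map.
    for gpu_key in sorted(arch_map, key=len, reverse=True):
        if gpu_key in gpu_name:
            major, minor = arch_map[gpu_key]
            return f"-gencode=arch=compute_{major}{minor},code=sm_{major}{minor}"
    # No match: the caller falls through to its default architecture.
    return None
```

The function gives these results:

- `arch_flag_from_name("NVIDIA A100-SXM4-80GB")` returns `"-gencode=arch=compute_80,code=sm_80"`.
- `arch_flag_from_name("NVIDIA A10")` returns `"-gencode=arch=compute_86,code=sm_86"`.
- A name matching no key returns `None`.

Keeping the lookup as a pure function also lets it be unit-tested without a GPU.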
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.