This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new 6887892 chore: Use `importlib.metadata` to locate DSOs (#306)
6887892 is described below
commit 6887892d888e0f69df8bd0a8167c5ebe98873a0b
Author: Junru Shao <[email protected]>
AuthorDate: Sat Dec 6 18:03:19 2025 -0800
chore: Use `importlib.metadata` to locate DSOs (#306)
Should fix #250.
This PR introduces `tvm_ffi.libinfo.load_lib_ctypes`, which downstream
libraries can reuse to locate and load their own DSOs when they link
against the tvm_ffi package.
For a package that ships a shared library built with tvm-ffi, a single
call loads that library into ctypes. For example, for tvm-ffi itself:
```
LIB = tvm_ffi.libinfo.load_lib_ctypes("apache-tvm-ffi", "tvm_ffi",
"RTLD_LOCAL")
```
Here:
- `apache-tvm-ffi` is the name of the Python package (distribution)
- `tvm_ffi` is the CMake target name of the shared library
- `mode` (the third argument) is the ctypes loading mode, usually
  `"RTLD_LOCAL"` or `"RTLD_GLOBAL"`
---
python/tvm_ffi/__init__.py | 9 +-
python/tvm_ffi/base.py | 56 ----
python/tvm_ffi/config.py | 19 +-
python/tvm_ffi/cpp/extension.py | 7 +-
python/tvm_ffi/libinfo.py | 365 +++++++++++++--------
.../utils/_build_optional_torch_c_dlpack.py | 2 +-
tests/python/test_dlpack_exchange_api.py | 4 +-
tests/python/test_libinfo.py | 119 +++++++
tests/python/test_optional_torch_c_dlpack.py | 2 +-
9 files changed, 362 insertions(+), 221 deletions(-)
diff --git a/python/tvm_ffi/__init__.py b/python/tvm_ffi/__init__.py
index 95ca145..8d68b5d 100644
--- a/python/tvm_ffi/__init__.py
+++ b/python/tvm_ffi/__init__.py
@@ -30,11 +30,13 @@ try:
except ImportError:
pass
-# base always go first to load the libtvm_ffi
-from . import base
+# Always load base libtvm_ffi before any other imports
from . import libinfo
-# package init part
+LIB = libinfo.load_lib_ctypes("apache-tvm-ffi", "tvm_ffi", "RTLD_GLOBAL")
+
+
+# Enable package initialization
from .registry import (
register_object,
register_global_func,
@@ -90,6 +92,7 @@ except ImportError:
__version_tuple__ = (0, 0, 0, "dev0", "7d34eb8ab.d20250913")
__all__ = [
+ "LIB",
"Array",
"DLDeviceType",
"Device",
diff --git a/python/tvm_ffi/base.py b/python/tvm_ffi/base.py
deleted file mode 100644
index 2e2ece3..0000000
--- a/python/tvm_ffi/base.py
+++ /dev/null
@@ -1,56 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-# coding: utf-8
-"""Base library for TVM FFI."""
-
-import ctypes
-import logging
-import os
-import sys
-
-from . import libinfo
-
-logger = logging.getLogger(__name__)
-
-# ----------------------------
-# Python3 version.
-# ----------------------------
-if sys.version_info[:2] < (3, 8): # noqa: UP036
- # Disables ruff(UP036): Version block is outdated for minimum Python
version
- # This is to ensure that the error message is sufficiently user friendly
- PY3STATEMENT = "The minimal Python requirement is Python 3.8"
- raise Exception(PY3STATEMENT)
-
-# ----------------------------
-# library loading
-# ----------------------------
-
-
-def _load_lib() -> ctypes.CDLL:
- """Load the tvm_ffi shared library by searching likely paths."""
- lib_path = libinfo.find_libtvm_ffi()
- # The dll search path need to be added explicitly in windows
- if sys.platform.startswith("win32"):
- for path in libinfo.get_dll_directories():
- os.add_dll_directory(path)
-
- lib = ctypes.CDLL(lib_path, ctypes.RTLD_GLOBAL)
- return lib
-
-
-# library instance
-_LIB = _load_lib()
diff --git a/python/tvm_ffi/config.py b/python/tvm_ffi/config.py
index 9a418b5..15a53d3 100644
--- a/python/tvm_ffi/config.py
+++ b/python/tvm_ffi/config.py
@@ -23,13 +23,10 @@ from pathlib import Path
from . import libinfo
-def find_windows_implib() -> str:
- """Find and return the Windows import library path for tvm_ffi.lib."""
- libdir = Path(libinfo.find_libtvm_ffi()).parent
- implib = libdir / "tvm_ffi.lib"
- if not implib.is_file():
- raise RuntimeError(f"Cannot find imp lib {implib}")
- return str(implib)
+def _find_libdir() -> str:
+ """Find the library directory for tvm-ffi."""
+ libtvm_ffi = Path(libinfo.find_libtvm_ffi())
+ return str(libtvm_ffi.parent)
def __main__() -> None: # noqa: PLR0912
@@ -68,10 +65,10 @@ def __main__() -> None: # noqa: PLR0912
if args.cmakedir:
print(libinfo.find_cmake_path())
if args.libdir:
- print(Path(libinfo.find_libtvm_ffi()).parent)
+ print(_find_libdir())
if args.libfiles:
if sys.platform.startswith("win32"):
- print(find_windows_implib())
+ print(libinfo.find_windows_implib())
else:
print(libinfo.find_libtvm_ffi())
if args.sourcedir:
@@ -88,12 +85,12 @@ def __main__() -> None: # noqa: PLR0912
print(f"-I{include_dir} -I{dlpack_include_dir}")
if args.libs:
if sys.platform.startswith("win32"):
- print(find_windows_implib())
+ print(libinfo.find_windows_implib())
else:
print("-ltvm_ffi")
if args.ldflags:
if not sys.platform.startswith("win32"):
- print(f"-L{Path(libinfo.find_libtvm_ffi()).parent}")
+ print(f"-L{_find_libdir()}")
if __name__ == "__main__":
diff --git a/python/tvm_ffi/cpp/extension.py b/python/tvm_ffi/cpp/extension.py
index d06caf3..b03ce67 100644
--- a/python/tvm_ffi/cpp/extension.py
+++ b/python/tvm_ffi/cpp/extension.py
@@ -234,10 +234,9 @@ def _generate_ninja_build( # noqa: PLR0915, PLR0912
) -> str:
"""Generate the content of build.ninja for building the module."""
default_include_paths = [find_include_path(), find_dlpack_include_path()]
-
- tvm_ffi_lib = find_libtvm_ffi()
- tvm_ffi_lib_path = str(Path(tvm_ffi_lib).parent)
- tvm_ffi_lib_name = Path(tvm_ffi_lib).stem
+ tvm_ffi_lib = Path(find_libtvm_ffi())
+ tvm_ffi_lib_path = str(tvm_ffi_lib.parent)
+ tvm_ffi_lib_name = tvm_ffi_lib.stem
if IS_WINDOWS:
default_cflags = [
"/std:c++17",
diff --git a/python/tvm_ffi/libinfo.py b/python/tvm_ffi/libinfo.py
index dcbfda4..35abd2c 100644
--- a/python/tvm_ffi/libinfo.py
+++ b/python/tvm_ffi/libinfo.py
@@ -17,84 +17,166 @@
"""Utilities to locate tvm_ffi libraries, headers, and helper include paths.
This module also provides helpers to locate and load platform-specific shared
-libraries by a base name (e.g., ``tvm_ffi`` -> ``libtvm_ffi.so`` on Linux).
+libraries by a target_name (e.g., ``tvm_ffi`` -> ``libtvm_ffi.so`` on Linux).
"""
from __future__ import annotations
+import ctypes
+import importlib.metadata as im
import os
import sys
from pathlib import Path
+from typing import Callable
-def split_env_var(env_var: str, split: str) -> list[str]:
- """Split an environment variable string.
-
- Parameters
- ----------
- env_var
- Name of environment variable.
-
- split
- String to split env_var on.
+def find_libtvm_ffi() -> str:
+ """Find libtvm_ffi.
Returns
-------
- splits
- If env_var exists, split env_var. Otherwise, empty list.
+ path
+ The full path to the located library.
"""
- if os.environ.get(env_var, None):
- return [p.strip() for p in os.environ[env_var].split(split)]
- return []
+ candidate = _find_library_by_basename("apache-tvm-ffi", "tvm_ffi")
+ if ret := _resolve_and_validate([candidate], cond=lambda _: True):
+ return ret
+ raise RuntimeError("Cannot find libtvm_ffi")
-def get_dll_directories() -> list[str]:
- """Get the possible dll directories."""
- ffi_dir = Path(__file__).expanduser().resolve().parent
- dll_path: list[Path] = [ffi_dir / "lib"]
- dll_path.append(ffi_dir / ".." / ".." / "build" / "lib")
- # in source build from parent if needed
- dll_path.append(ffi_dir / ".." / ".." / ".." / "build" / "lib")
- if sys.platform.startswith("linux") or sys.platform.startswith("freebsd"):
- dll_path.extend(Path(p) for p in split_env_var("LD_LIBRARY_PATH", ":"))
- dll_path.extend(Path(p) for p in split_env_var("PATH", ":"))
- elif sys.platform.startswith("darwin"):
- dll_path.extend(Path(p) for p in split_env_var("DYLD_LIBRARY_PATH",
":"))
- dll_path.extend(Path(p) for p in split_env_var("PATH", ":"))
- elif sys.platform.startswith("win32"):
- dll_path.extend(Path(p) for p in split_env_var("PATH", ";"))
+def find_windows_implib() -> str:
+ """Find and return the Windows import library path for tvm_ffi.lib."""
+ # implib = _find_library_by_basename("apache-tvm-ffi", "tvm_ffi").parent /
"tvm_ffi.lib"
+ # ret = _resolve_to_str(implib)
+ candidate = _find_library_by_basename("apache-tvm-ffi", "tvm_ffi").parent
/ "tvm_ffi.lib"
+ if ret := _resolve_and_validate([candidate], cond=lambda _: True):
+ return ret
+ raise RuntimeError("Cannot find implib tvm_ffi.lib")
- valid_paths = []
- for path in dll_path:
- try:
- if path.is_dir():
- valid_paths.append(str(path.resolve()))
- except OSError:
- # need to ignore as resolve may fail if
- # we don't have permission to access it
- pass
- return valid_paths
+
+def find_source_path() -> str:
+ """Find packaged source home path."""
+ if ret := _resolve_and_validate(
+ paths=[
+ _rel_top_directory(),
+ _dev_top_directory(),
+ ],
+ cond=lambda p: (p / "cmake").is_dir(),
+ ):
+ return ret
+ raise RuntimeError("Cannot find home path.")
-def find_libtvm_ffi() -> str:
- """Find libtvm_ffi.
+def find_cmake_path() -> str:
+ """Find the preferred cmake path."""
+ if ret := _resolve_and_validate(
+ paths=[
+ _rel_top_directory() / "share" / "cmake" / "tvm_ffi", # Standard
install
+ _dev_top_directory() / "cmake", # Development mode
+ ],
+ cond=lambda p: p.is_dir(),
+ ):
+ return ret
+ raise RuntimeError("Cannot find cmake path.")
+
+
+def find_include_path() -> str:
+ """Find header files for C compilation."""
+ if ret := _resolve_and_validate(
+ paths=[
+ _rel_top_directory() / "include",
+ _dev_top_directory() / "include",
+ ],
+ cond=lambda p: p.is_dir(),
+ ):
+ return ret
+ raise RuntimeError("Cannot find include path.")
+
+
+def find_dlpack_include_path() -> str:
+ """Find dlpack header files for C compilation."""
+ if ret := _resolve_and_validate(
+ paths=[
+ _rel_top_directory() / "include",
+ _dev_top_directory() / "3rdparty" / "dlpack" / "include",
+ ],
+ cond=lambda p: (p / "dlpack").is_dir(),
+ ):
+ return ret
+ raise RuntimeError("Cannot find dlpack include path.")
+
+
+def find_cython_lib() -> str:
+ """Find the path to tvm cython."""
+ from tvm_ffi import core # noqa: PLC0415
+
+ try:
+ return str(Path(core.__file__).resolve())
+ except OSError:
+ pass
+ raise RuntimeError("Cannot find tvm cython path.")
+
+
+def find_python_helper_include_path() -> str:
+ """Find header files for C compilation."""
+ if ret := _resolve_and_validate(
+ paths=[
+ _rel_top_directory() / "include",
+ _dev_top_directory() / "python" / "tvm_ffi" / "cython",
+ ],
+ cond=lambda p: (p / "tvm_ffi_python_helpers.h").is_file(),
+ ):
+ return ret
+ raise RuntimeError("Cannot find python helper include path.")
+
+
+def include_paths() -> list[str]:
+ """Find all include paths needed for FFI related compilation."""
+ return sorted(
+ {
+ find_include_path(),
+ find_dlpack_include_path(),
+ find_python_helper_include_path(),
+ }
+ )
+
+
+def load_lib_ctypes(package: str, target_name: str, mode: str) -> ctypes.CDLL:
+ """Load the tvm_ffi shared library by searching likely paths.
+
+ Parameters
+ ----------
+ package
+ The package name where the library is expected to be found. For
example,
+ ``"apache-tvm-ffi"`` is the package name of `tvm-ffi`.
+ target_name
+ Name of the CMake target, e.g., ``"tvm_ffi"``. It is used to derive
the platform-specific
+ shared library name, e.g., ``"libtvm_ffi.so"`` on Linux,
``"tvm_ffi.dll"`` on Windows.
+ mode
+ The mode to load the shared library. See `ctypes.${MODE}` for details.
+ Usually it is either ``"RTLD_LOCAL"`` or ``"RTLD_GLOBAL"``.
Returns
-------
- path
- The full path to the located library.
+ The loaded shared library.
"""
- return find_library_by_basename("tvm_ffi")
+ lib_path: Path = _find_library_by_basename(package, target_name)
+ # The dll search path need to be added explicitly in windows
+ if sys.platform.startswith("win32"):
+ os.add_dll_directory(str(lib_path.parent))
+ return ctypes.CDLL(str(lib_path), getattr(ctypes, mode))
-def find_library_by_basename(base: str) -> str:
- """Find a shared library by base name across known directories.
+def _find_library_by_basename(package: str, target_name: str) -> Path: #
noqa: PLR0912
+ """Find a shared library by target_name name across known directories.
Parameters
----------
- base
+ package
+ The package name where the library is expected to be found.
+ target_name
Base name (e.g., ``"tvm_ffi"`` or ``"tvm_ffi_testing"``).
Returns
@@ -108,111 +190,108 @@ def find_library_by_basename(base: str) -> str:
If the library cannot be found in any of the candidate directories.
"""
- dll_path = [Path(p) for p in get_dll_directories()]
if sys.platform.startswith("win32"):
- lib_dll_names = [f"{base}.dll"]
+ lib_dll_names = (f"{target_name}.dll",)
elif sys.platform.startswith("darwin"):
- lib_dll_names = [ # Prefer dylib, also allow .so for some toolchains
- f"lib{base}.dylib",
- f"lib{base}.so",
- ]
+ # Prefer dylib, also allow .so for some toolchains
+ lib_dll_names = (f"lib{target_name}.dylib", f"lib{target_name}.so")
else: # Linux, FreeBSD, etc
- lib_dll_names = [f"lib{base}.so"]
-
- lib_dll_path = [p / name for name in lib_dll_names for p in dll_path]
- lib_found = [p for p in lib_dll_path if p.exists() and p.is_file()]
-
- if not lib_found:
- candidate_list = "\n".join(str(p) for p in lib_dll_path)
- raise RuntimeError(
- f"Cannot find library: {', '.join(lib_dll_names)}\nList of
candidates:\n{candidate_list}"
- )
-
- return str(lib_found[0])
-
-
-def find_source_path() -> str:
- """Find packaged source home path."""
- candidates = [
- str(Path(__file__).resolve().parent),
- str(Path(__file__).resolve().parent / ".." / ".."),
- ]
- for candidate in candidates:
- if Path(candidate, "cmake").is_dir():
- return candidate
- raise RuntimeError("Cannot find home path.")
-
-
-def find_cmake_path() -> str:
- """Find the preferred cmake path."""
- candidates = [
- str(Path(__file__).resolve().parent / "share" / "cmake" / "tvm_ffi"),
# Standard install
- str(Path(__file__).resolve().parent / ".." / ".." / "cmake"), #
Development mode
- ]
- for candidate in candidates:
- if Path(candidate).is_dir():
- return candidate
- raise RuntimeError("Cannot find cmake path.")
-
+ lib_dll_names = (f"lib{target_name}.so",)
+
+ # Use `importlib.metadata` is the most reliable way to find package data
files
+ dist: im.PathDistribution = im.distribution(package) # type:
ignore[assignment]
+ record = dist.read_text("RECORD") or ""
+ for line in record.splitlines():
+ partial_path, *_ = line.split(",")
+ if partial_path.endswith(lib_dll_names):
+ try:
+ path = (dist._path.parent / partial_path).resolve()
+ except OSError:
+ continue
+ if path.name in lib_dll_names:
+ return path
+
+ # **Fallback**. it's possible that the library is not built as part of
Python ecosystem,
+ # e.g. Use PYTHONPATH to point to dev package, and CMake + Makefiles to
build the shared library.
+ dll_paths: list[Path] = []
+
+ # Case 1. It is under $PROJECT_ROOT/build/lib/ or $PROJECT_ROOT/lib/
+ dll_paths.append(_rel_top_directory() / "build" / "lib")
+ dll_paths.append(_rel_top_directory() / "lib")
+ dll_paths.append(_dev_top_directory() / "build" / "lib")
+ dll_paths.append(_dev_top_directory() / "lib")
+
+ # Case 2. It is specified in PATH-related environment variables
+ if sys.platform.startswith("win32"):
+ dll_paths.extend(Path(p) for p in _split_env_var("PATH", ";"))
+ elif sys.platform.startswith("darwin"):
+ dll_paths.extend(Path(p) for p in _split_env_var("DYLD_LIBRARY_PATH",
":"))
+ dll_paths.extend(Path(p) for p in _split_env_var("PATH", ":"))
+ else:
+ dll_paths.extend(Path(p) for p in _split_env_var("LD_LIBRARY_PATH",
":"))
+ dll_paths.extend(Path(p) for p in _split_env_var("PATH", ":"))
+
+ # Search for the library in candidate directories
+ for dll_dir in dll_paths:
+ for lib_dll_name in lib_dll_names:
+ try:
+ path = (dll_dir / lib_dll_name).resolve()
+ if path.is_file():
+ return path
+ except OSError:
+ continue
+ raise RuntimeError(f"Cannot find library: {', '.join(lib_dll_names)}")
+
+
+def _split_env_var(env_var: str, split: str) -> list[str]:
+ """Split an environment variable string.
-def find_include_path() -> str:
- """Find header files for C compilation."""
- candidates = [
- str(Path(__file__).resolve().parent / "include"),
- str(Path(__file__).resolve().parent / ".." / ".." / "include"),
- ]
- for candidate in candidates:
- if Path(candidate).is_dir():
- return candidate
- raise RuntimeError("Cannot find include path.")
+ Parameters
+ ----------
+ env_var
+ Name of environment variable.
+ split
+ String to split env_var on.
-def find_python_helper_include_path() -> str:
- """Find header files for C compilation."""
- candidates = [
- str(Path(__file__).resolve().parent / "include"),
- str(Path(__file__).resolve().parent / "cython"),
- ]
- for candidate in candidates:
- if Path(candidate, "tvm_ffi_python_helpers.h").is_file():
- return candidate
- raise RuntimeError("Cannot find python helper include path.")
+ Returns
+ -------
+ splits
+ If env_var exists, split env_var. Otherwise, empty list.
+ """
+ if os.environ.get(env_var, None):
+ return [p.strip() for p in os.environ[env_var].split(split)]
+ return []
-def find_dlpack_include_path() -> str:
- """Find dlpack header files for C compilation."""
- install_include_path = Path(__file__).resolve().parent / "include"
- if (install_include_path / "dlpack").is_dir():
- return str(install_include_path)
- source_include_path = (
- Path(__file__).resolve().parent / ".." / ".." / "3rdparty" / "dlpack"
/ "include"
- )
- if source_include_path.is_dir():
- return str(source_include_path)
+def _rel_top_directory() -> Path:
+ """Get the current directory of this file."""
+ return Path(__file__).parent
- raise RuntimeError("Cannot find include path.")
+def _dev_top_directory() -> Path:
+ """Get the top-level development directory."""
+ return _rel_top_directory() / ".." / ".."
-def find_cython_lib() -> str:
- """Find the path to tvm cython."""
- path_candidates = [
- Path(__file__).resolve().parent,
- Path(__file__).resolve().parent / ".." / ".." / "build",
- ]
- suffixes = "pyd" if sys.platform.startswith("win32") else "so"
- for candidate in path_candidates:
- for path in Path(candidate).glob(f"core*.{suffixes}"):
- return str(Path(path).resolve())
- raise RuntimeError("Cannot find tvm cython path.")
+def _resolve_and_validate(
+ paths: list[Path],
+ cond: Callable[[Path], bool | Path],
+) -> str | None:
+ """For all paths that resolve properly, find the 1st one that meets the
specified condition.
-def include_paths() -> list[str]:
- """Find all include paths needed for FFI related compilation."""
- include_path = find_include_path()
- python_helper_include_path = find_python_helper_include_path()
- dlpack_include_path = find_dlpack_include_path()
- result = [include_path, dlpack_include_path]
- if python_helper_include_path != include_path:
- result.append(python_helper_include_path)
- return result
+ M. B. This code path gracefully handles broken paths, symlinks, or
permission issues,
+ and is required for robust library discovery in all public APIs in this
file.
+ """
+ for path in paths:
+ try:
+ resolved = path.resolve()
+ ret = cond(resolved)
+ except (OSError, AssertionError):
+ continue
+ if isinstance(ret, Path):
+ return str(ret)
+ elif ret is True:
+ return str(resolved)
+ return None
diff --git a/python/tvm_ffi/utils/_build_optional_torch_c_dlpack.py
b/python/tvm_ffi/utils/_build_optional_torch_c_dlpack.py
index 99f76d4..4512caf 100644
--- a/python/tvm_ffi/utils/_build_optional_torch_c_dlpack.py
+++ b/python/tvm_ffi/utils/_build_optional_torch_c_dlpack.py
@@ -726,7 +726,7 @@ def get_torch_include_paths(build_with_cuda: bool) ->
Sequence[str]:
device_type="cuda" if build_with_cuda else "cpu"
)
else:
- return torch.utils.cpp_extension.include_paths(cuda=build_with_cuda)
+ return torch.utils.cpp_extension.include_paths(cuda=build_with_cuda)
# type: ignore[call-arg]
def main() -> None: # noqa: PLR0912, PLR0915
diff --git a/tests/python/test_dlpack_exchange_api.py
b/tests/python/test_dlpack_exchange_api.py
index 70e7586..9d1df21 100644
--- a/tests/python/test_dlpack_exchange_api.py
+++ b/tests/python/test_dlpack_exchange_api.py
@@ -32,7 +32,7 @@ try:
from torch.utils import cpp_extension # type: ignore
from tvm_ffi import libinfo
except ImportError:
- torch = None
+ torch = None # type: ignore[assignment]
# Check if DLPack Exchange API is available
_has_dlpack_api = torch is not None and hasattr(torch.Tensor,
"__dlpack_c_exchange_api__")
@@ -46,7 +46,7 @@ def test_dlpack_exchange_api() -> None:
assert torch is not None
assert hasattr(torch.Tensor, "__dlpack_c_exchange_api__")
- api_attr = torch.Tensor.__dlpack_c_exchange_api__
+ api_attr = torch.Tensor.__dlpack_c_exchange_api__ # type:
ignore[attr-defined]
# PyCapsule - extract the pointer as integer
pythonapi = ctypes.pythonapi
# Set restype to c_size_t to get integer directly (avoids c_void_p quirks)
diff --git a/tests/python/test_libinfo.py b/tests/python/test_libinfo.py
new file mode 100644
index 0000000..c15b130
--- /dev/null
+++ b/tests/python/test_libinfo.py
@@ -0,0 +1,119 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Tests for the tvm-ffi-config command line utility."""
+
+import subprocess
+import sys
+from pathlib import Path
+from typing import Callable
+
+import pytest
+from tvm_ffi import libinfo
+
+
+def _stdout_for(*args: str) -> str:
+ """Invoke tvm-ffi-config with the provided arguments and return stdout
with trailing whitespace removed."""
+ result = subprocess.run(
+ [
+ sys.executable,
+ "-m",
+ "tvm_ffi.config",
+ *args,
+ ],
+ check=True,
+ capture_output=True,
+ text=True,
+ )
+ assert result.stderr == ""
+ return result.stdout.strip()
+
+
[email protected](
+ ("flag", "expected_fn", "is_dir"),
+ [
+ ("--includedir", libinfo.find_include_path, True),
+ ("--dlpack-includedir", libinfo.find_dlpack_include_path, True),
+ ("--cmakedir", libinfo.find_cmake_path, True),
+ ("--sourcedir", libinfo.find_source_path, True),
+ ("--cython-lib-path", libinfo.find_cython_lib, False),
+ ],
+)
+def test_basic_path_flags(flag: str, expected_fn: Callable[[], str], is_dir:
bool) -> None:
+ output = _stdout_for(flag)
+ assert output == expected_fn()
+ path = Path(output)
+ assert path.exists()
+ assert path.is_dir() if is_dir else path.is_file()
+
+
+def test_libdir_matches_library_parent() -> None:
+ expected_dir = Path(libinfo.find_libtvm_ffi()).parent
+ output = _stdout_for("--libdir")
+ assert output == str(expected_dir)
+ assert Path(output).is_dir()
+ assert Path(libinfo.find_libtvm_ffi()).is_file()
+
+
+def test_libfiles_reports_platform_library() -> None:
+ output = _stdout_for("--libfiles")
+ if sys.platform.startswith("win32"):
+ expected = libinfo.find_windows_implib()
+ else:
+ expected = libinfo.find_libtvm_ffi()
+ assert output == expected
+ assert Path(output).is_file()
+
+
+def test_libs_reports_link_target() -> None:
+ output = _stdout_for("--libs")
+ if sys.platform.startswith("win32"):
+ assert output == libinfo.find_windows_implib()
+ else:
+ assert output == "-ltvm_ffi"
+
+
+def test_cxxflags_include_paths_and_standard() -> None:
+ include_dir = libinfo.find_include_path()
+ dlpack_dir = libinfo.find_dlpack_include_path()
+ assert _stdout_for("--cxxflags") == f"-I{include_dir} -I{dlpack_dir}
-std=c++17"
+
+
+def test_cflags_include_paths() -> None:
+ include_dir = libinfo.find_include_path()
+ dlpack_dir = libinfo.find_dlpack_include_path()
+ assert _stdout_for("--cflags") == f"-I{include_dir} -I{dlpack_dir}"
+
+
+def test_ldflags_only_on_unix() -> None:
+ output = _stdout_for("--ldflags")
+ if sys.platform.startswith("win32"):
+ assert output == ""
+ else:
+ libdir = Path(libinfo.find_libtvm_ffi()).parent
+ assert output == f"-L{libdir}"
+ assert libdir.is_dir()
+
+
+def test_cmakedir_contains_config_file() -> None:
+ cmake_dir = Path(_stdout_for("--cmakedir"))
+ assert (cmake_dir / "tvm_ffi-config.cmake").is_file()
+
+
+def test_find_python_helper_include_path() -> None:
+ path = libinfo.find_python_helper_include_path()
+ assert Path(path).is_dir()
+ assert (Path(path) / "tvm_ffi_python_helpers.h").is_file()
diff --git a/tests/python/test_optional_torch_c_dlpack.py
b/tests/python/test_optional_torch_c_dlpack.py
index 4666d5c..e8e4100 100644
--- a/tests/python/test_optional_torch_c_dlpack.py
+++ b/tests/python/test_optional_torch_c_dlpack.py
@@ -25,7 +25,7 @@ import pytest
try:
import torch
except ImportError:
- torch = None
+ torch = None # type: ignore[assignment]
import tvm_ffi