This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new c5963d7 [REFACTOR][STUBGEN] refactor the stubgen logic. (#608)
c5963d7 is described below
commit c5963d7bfb9c7c30d992d9d95d736cfbb15368a5
Author: Linzhang Li <[email protected]>
AuthorDate: Fri Jun 12 10:17:41 2026 -0400
[REFACTOR][STUBGEN] refactor the stubgen logic. (#608)
## Summary
This PR refactors the Python stubgen implementation to improve
extensibility and separation of concerns. No functional changes are
intended, and the generated output should remain unchanged.
## Design
The current stubgen implementation contains both language-independent
logic (e.g., file processing and library analysis) and language-specific
code generation. This PR separates these responsibilities by introducing
a generator abstraction for the code generation layer, making it easier
to support additional target languages in the future.
The reorganized file structure is:
```text
python/tvm_ffi/stub/
├── __init__.py
├── cli.py
├── consts.py
├── file_utils.py
├── lib_state.py
├── utils.py
├── generator.py
└── python_generator/
    ├── __init__.py
    ├── generator.py
    ├── codegen.py
    ├── consts.py
    └── utils.py
```
This refactoring lays the groundwork for future support of Rust stub
generation and other language backends.
## Testing
This PR is a pure refactor and is not expected to change existing
behavior. Therefore, no new tests are added.
Signed-off-by: yuchuan <[email protected]>
---
python/tvm_ffi/stub/cli.py | 100 +++++---
python/tvm_ffi/stub/consts.py | 115 ++++-----
python/tvm_ffi/stub/file_utils.py | 71 +++---
python/tvm_ffi/stub/generator.py | 190 ++++++++++++++
python/tvm_ffi/stub/python_generator/__init__.py | 23 ++
.../tvm_ffi/stub/{ => python_generator}/codegen.py | 118 ++++++---
python/tvm_ffi/stub/python_generator/consts.py | 46 ++++
python/tvm_ffi/stub/python_generator/generator.py | 155 ++++++++++++
python/tvm_ffi/stub/python_generator/utils.py | 274 +++++++++++++++++++++
python/tvm_ffi/stub/utils.py | 205 +--------------
tests/python/test_stubgen.py | 210 +++++++++-------
11 files changed, 1067 insertions(+), 440 deletions(-)
diff --git a/python/tvm_ffi/stub/cli.py b/python/tvm_ffi/stub/cli.py
index 0cfd839..66688de 100644
--- a/python/tvm_ffi/stub/cli.py
+++ b/python/tvm_ffi/stub/cli.py
@@ -24,17 +24,21 @@ import importlib
import sys
import traceback
from pathlib import Path
+from typing import TYPE_CHECKING
-from . import codegen as G
from . import consts as C
-from .file_utils import FileInfo, collect_files
+from .file_utils import FileInfo, collect_files, syntax_for
+from .generator import get_generator
from .lib_state import (
collect_global_funcs,
collect_type_keys,
object_info_from_type_key,
toposort_objects,
)
-from .utils import FuncInfo, ImportItem, InitConfig, Options
+from .utils import FuncInfo, InitConfig, Options
+
+if TYPE_CHECKING:
+ from .generator import Generator
def __main__() -> int:
@@ -45,6 +49,7 @@ def __main__() -> int:
overview and examples of the block syntax.
"""
opt = _parse_args()
+ generator = get_generator(opt.target)
for imp in opt.imports or []:
importlib.import_module(imp)
dlls = [ctypes.CDLL(lib) for lib in opt.dlls]
@@ -60,7 +65,7 @@ def __main__() -> int:
# - type maps: `tvm-ffi-stubgen(ty-map)`
# - defined global functions: `tvm-ffi-stubgen(begin): global/...`
# - defined object types: `tvm-ffi-stubgen(begin): object/...`
- ty_map: dict[str, str] = C.TY_MAP_DEFAULTS.copy()
+ ty_map: dict[str, str] = generator.default_ty_map()
for file in files:
try:
_stage_1(file, ty_map)
@@ -70,14 +75,16 @@ def __main__() -> int:
)
# Stage 2. Generate stubs if they are not defined on the file.
+ generated_prefixes: set[str] = set()
if opt.init:
assert init_path is not None, "init-path could not be determined"
- _stage_2(
+ generated_prefixes = _stage_2(
files,
ty_map,
init_cfg=opt.init,
init_path=init_path,
global_funcs=global_funcs,
+ generator=generator,
)
# Stage 3: Process
@@ -87,11 +94,23 @@ def __main__() -> int:
if opt.verbose:
print(f"{C.TERM_CYAN}[File] {file.path}{C.TERM_RESET}")
try:
- _stage_3(file, opt, ty_map, global_funcs)
+ _stage_3(
+ file,
+ opt,
+ ty_map,
+ global_funcs,
+ generator=generator,
+ )
except Exception:
print(
f'{C.TERM_RED}[Failed] File "{file.path}":
{traceback.format_exc()}{C.TERM_RESET}'
)
+
+ # Stage 4. Let the generator stitch the generated tree together (runs
after the
+ # files are fully written, so language-specific wiring isn't clobbered).
+ if opt.init and generated_prefixes:
+ assert init_path is not None
+ generator.finalize_init(init_path, generated_prefixes)
del dlls
return 0
@@ -118,11 +137,12 @@ def _stage_2(
init_cfg: InitConfig,
init_path: Path,
global_funcs: dict[str, list[FuncInfo]],
-) -> None:
+ generator: Generator,
+) -> set[str]:
def _find_or_insert_file(path: Path) -> FileInfo:
ret: FileInfo | None
if not path.exists():
- ret = FileInfo(path=path, lines=(), code_blocks=[])
+ ret = FileInfo(path=path, lines=(), code_blocks=[],
syntax=syntax_for(path))
else:
for file in files:
if path.samefile(file.path):
@@ -148,6 +168,7 @@ def _stage_2(
prefixes: dict[str, list[str]] = collect_type_keys()
for prefix in global_funcs:
prefixes.setdefault(prefix, [])
+ generated_prefixes: set[str] = set()
for prefix, obj_names in prefixes.items():
if not (prefix == root_prefix or prefix.startswith(prefix_filter)):
continue
@@ -159,15 +180,17 @@ def _stage_2(
object_infos = toposort_objects(objs)
if not funcs and not object_infos:
continue
+ generated_prefixes.add(prefix)
# Step 1. Create target directory if not exists
directory = init_path / prefix.replace(".", "/")
directory.mkdir(parents=True, exist_ok=True)
- # Step 2. Generate `_ffi_api.py`
- target_path = directory / "_ffi_api.py"
+ # Step 2. Generate the API file.
+ api_filename = generator.api_filename()
+ target_path = directory / api_filename
target_file = _find_or_insert_file(target_path)
with target_path.open("a", encoding="utf-8") as f:
f.write(
- G.generate_ffi_api(
+ generator.generate_api_file(
target_file.code_blocks,
ty_map,
prefix,
@@ -177,12 +200,15 @@ def _stage_2(
)
)
target_file.reload()
- # Step 3. Generate `__init__.py`
- target_path = directory / "__init__.py"
+ # Step 3. Generate the package entry (Python `__init__.py`; re-exports
the
+ # API submodule). `submodule` is the API file's stem.
+ submodule = api_filename.rsplit(".", 1)[0]
+ target_path = directory / generator.init_filename()
target_file = _find_or_insert_file(target_path)
with target_path.open("a", encoding="utf-8") as f:
- f.write(G.generate_init(target_file.code_blocks, prefix,
submodule="_ffi_api"))
+ f.write(generator.generate_init_file(target_file.code_blocks,
prefix, submodule))
target_file.reload()
+ return generated_prefixes
def _stage_3( # noqa: PLR0912
@@ -190,35 +216,23 @@ def _stage_3( # noqa: PLR0912
opt: Options,
ty_map: dict[str, str],
global_funcs: dict[str, list[FuncInfo]],
+ generator: Generator,
) -> None:
defined_funcs: set[str] = set()
defined_types: set[str] = set()
- imports: list[ImportItem] = []
- ffi_load_lib_imported = False
+ imports = generator.new_imports()
# Stage 1. Collect `tvm-ffi-stubgen(import-object): ...`
for code in file.code_blocks:
if code.kind == "import-object":
name, type_checking_only, alias = code.param
- imports.append(
- ImportItem(
- name,
- type_checking_only=(
- bool(type_checking_only)
- and isinstance(type_checking_only, str)
- and type_checking_only.lower() == "true"
- ),
- alias=alias if alias else None,
- )
- )
- if (alias and alias == "_FFI_LOAD_LIB") or
name.endswith("libinfo.load_lib_module"):
- ffi_load_lib_imported = True
+ generator.add_imported_object(imports, name, type_checking_only,
alias)
# Stage 2. Process `tvm-ffi-stubgen(begin): global/...`
for code in file.code_blocks:
if code.kind == "global":
funcs = global_funcs.get(code.param[0], [])
for func in funcs:
defined_funcs.add(func.schema.name)
- G.generate_global_funcs(code, funcs, ty_map, imports, opt)
+ generator.generate_global_funcs_block(code, funcs, ty_map,
imports, opt)
# Stage 3. Process `tvm-ffi-stubgen(begin): object/...`
for code in file.code_blocks:
if code.kind == "object":
@@ -226,27 +240,23 @@ def _stage_3( # noqa: PLR0912
assert isinstance(type_key, str)
obj_info = object_info_from_type_key(type_key)
type_key = ty_map.get(type_key, type_key)
- full_name = ImportItem(type_key).full_name
- defined_types.add(full_name)
- G.generate_object(code, ty_map, imports, opt, obj_info)
+ defined_types.add(generator.canonical_type_name(type_key))
+ generator.generate_object_block(code, ty_map, imports, opt,
obj_info)
# Stage 4. Add imports for used types.
- imports = [i for i in imports if i.full_name not in defined_types]
for code in file.code_blocks:
if code.kind == "import-section":
- G.generate_import_section(code, imports, opt)
+ generator.generate_import_section_block(code, imports, opt,
defined_types)
break # Only one import block per file is supported for now.
# Stage 5. Add `__all__` for defined classes and functions.
for code in file.code_blocks:
if code.kind == "__all__":
- export_names = defined_funcs | defined_types
- if ffi_load_lib_imported:
- export_names = export_names | {"LIB"}
- G.generate_all(code, export_names, opt)
+ export_names = defined_funcs | defined_types |
generator.extra_export_names(imports)
+ generator.generate_all_block(code, export_names, opt)
break # Only one __all__ block per file is supported for now.
# Stage 6. Process `tvm-ffi-stubgen(begin): export/...`
for code in file.code_blocks:
if code.kind == "export":
- G.generate_export(code)
+ generator.generate_export_block(code)
# Finalize: write back to file
file.update(verbose=opt.verbose, dry_run=opt.dry_run)
@@ -328,7 +338,7 @@ def _parse_args() -> Options:
default=4,
help=(
"Extra spaces added inside each generated block, relative to the "
- f"indentation of the corresponding '{C.STUB_BEGIN}' line."
+ "indentation of the corresponding stub 'begin' marker line."
),
)
parser.add_argument(
@@ -341,6 +351,13 @@ def _parse_args() -> Options:
"select where stubs are generated."
),
)
+ parser.add_argument(
+ "--target",
+ type=str,
+ default="python",
+ choices=["python"],
+ help="Code generator target.",
+ )
parser.add_argument(
"--verbose",
action="store_true",
@@ -382,6 +399,7 @@ def _parse_args() -> Options:
files=args.files,
verbose=args.verbose,
dry_run=args.dry_run,
+ target=args.target,
)
diff --git a/python/tvm_ffi/stub/consts.py b/python/tvm_ffi/stub/consts.py
index 94b5545..46ba970 100644
--- a/python/tvm_ffi/stub/consts.py
+++ b/python/tvm_ffi/stub/consts.py
@@ -18,16 +18,65 @@
from __future__ import annotations
+import dataclasses
from typing import Literal
from typing_extensions import TypeAlias
-STUB_PREFIX = "# tvm-ffi-stubgen("
-STUB_BEGIN = f"{STUB_PREFIX}begin):"
-STUB_END = f"{STUB_PREFIX}end)"
-STUB_TY_MAP = f"{STUB_PREFIX}ty-map):"
-STUB_IMPORT_OBJECT = f"{STUB_PREFIX}import-object):"
-STUB_SKIP_FILE = f"{STUB_PREFIX}skip-file)"
+
[email protected](frozen=True)
+class MarkerSyntax:
+ """Comment-syntax-specific stub directive markers.
+
+ All stub directives are embedded inside single-line comments. The comment
+ token (currently ``#`` for Python sources) parameterizes the marker set,
+ while the directive grammar (``tvm-ffi-stubgen(begin): ...`` etc.) stays
+ uniform.
+ """
+
+ comment: str
+ """The line-comment token for the target language."""
+
+ @property
+ def prefix(self) -> str:
+ """Common prefix shared by every stub directive on a line."""
+ return f"{self.comment} tvm-ffi-stubgen("
+
+ @property
+ def begin(self) -> str:
+ """Marker that opens a generated block: ``<comment>
tvm-ffi-stubgen(begin):``."""
+ return f"{self.prefix}begin):"
+
+ @property
+ def end(self) -> str:
+ """Marker that closes a generated block: ``<comment>
tvm-ffi-stubgen(end)``."""
+ return f"{self.prefix}end)"
+
+ @property
+ def ty_map(self) -> str:
+ """One-line type-map directive: ``<comment>
tvm-ffi-stubgen(ty-map):``."""
+ return f"{self.prefix}ty-map):"
+
+ @property
+ def import_object(self) -> str:
+ """One-line import-object directive: ``<comment>
tvm-ffi-stubgen(import-object):``."""
+ return f"{self.prefix}import-object):"
+
+ @property
+ def skip_file(self) -> str:
+ """Whole-file opt-out directive: ``<comment>
tvm-ffi-stubgen(skip-file)``."""
+ return f"{self.prefix}skip-file)"
+
+
+PYTHON_SYNTAX = MarkerSyntax(comment="#")
+
+#: Map a source-file extension to the marker syntax used inside it. The block
+#: parser selects the syntax per file.
+SYNTAX_BY_EXT: dict[str, MarkerSyntax] = {
+ ".py": PYTHON_SYNTAX,
+ ".pyi": PYTHON_SYNTAX,
+}
+
STUB_BLOCK_KINDS: TypeAlias = Literal[
"global",
"object",
@@ -51,26 +100,10 @@ TERM_CYAN = "\033[36m"
TERM_WHITE = "\033[37m"
DOC_URL = "https://tvm.apache.org/ffi/packaging/stubgen.html"
-DEFAULT_SOURCE_EXTS = {".py", ".pyi"}
-TY_MAP_DEFAULTS = {
- "Any": "typing.Any",
- "Callable": "typing.Callable",
- "Array": "collections.abc.Sequence",
- "List": "collections.abc.MutableSequence",
- "Map": "collections.abc.Mapping",
- "Dict": "collections.abc.MutableMapping",
- "Object": "ffi.Object",
- "Tensor": "ffi.Tensor",
- "dtype": "ffi.dtype",
- "Device": "ffi.Device",
-}
-
-# TODO(@junrushao): Make it configurable
-MOD_MAP = {
- "testing": "tvm_ffi.testing",
- "ffi": "tvm_ffi",
-}
+DEFAULT_SOURCE_EXTS = set(SYNTAX_BY_EXT)
+# Language-neutral metadata transform applied while building `ObjectInfo` from
+# the FFI reflection registry (see `utils.ObjectInfo.from_type_info`).
FN_NAME_MAP: dict[str, str] = {}
BUILTIN_TYPE_KEYS = {
@@ -84,35 +117,3 @@ BUILTIN_TYPE_KEYS = {
"ffi.String",
"ffi.Tensor",
}
-
-
-def _prompt_globals(mod: str) -> str:
- return f"""{STUB_BEGIN} global/{mod}
-{STUB_END}
-"""
-
-
-def _prompt_class_def(type_name: str, type_key: str, parent_type_name: str) ->
str:
- return f'''@_FFI_REG_OBJ("{type_key}")
-class {type_name}({parent_type_name}):
- """FFI binding for `{type_key}`."""
-
- {STUB_BEGIN} object/{type_key}
- {STUB_END}\n\n'''
-
-
-def _prompt_import_object(type_key: str, type_name: str) -> str:
- return f"""{STUB_IMPORT_OBJECT} {type_key};False;{type_name}\n"""
-
-
-PROMPT_IMPORT_SECTION = f"""
-{STUB_BEGIN} import-section
-{STUB_END}
-"""
-
-PROMPT_ALL_SECTION = f"""
-__all__ = [
- {STUB_BEGIN} __all__
- {STUB_END}
-]
-"""
diff --git a/python/tvm_ffi/stub/file_utils.py
b/python/tvm_ffi/stub/file_utils.py
index e7ce609..b409f61 100644
--- a/python/tvm_ffi/stub/file_utils.py
+++ b/python/tvm_ffi/stub/file_utils.py
@@ -28,6 +28,11 @@ from typing import Callable, Generator, Iterable
from . import consts as C
+def syntax_for(path: Path) -> C.MarkerSyntax:
+ """Pick the comment-marker syntax for a file based on its extension."""
+ return C.SYNTAX_BY_EXT.get(path.suffix.lower(), C.PYTHON_SYNTAX)
+
+
@dataclasses.dataclass
class CodeBlock:
"""A block of code to be generated in a stub file."""
@@ -60,10 +65,10 @@ class CodeBlock:
return len(first_line) - len(first_line.lstrip(" "))
@staticmethod
- def from_begin_line(lineo: int, line: str) -> CodeBlock:
+ def from_begin_line(lineo: int, line: str, syntax: C.MarkerSyntax) ->
CodeBlock:
"""Parse a line to create a CodeBlock if it contains a stub begin
marker."""
- if line.startswith(C.STUB_TY_MAP):
- line = line[len(C.STUB_TY_MAP) :].strip()
+ if line.startswith(syntax.ty_map):
+ line = line[len(syntax.ty_map) :].strip()
return CodeBlock(
kind="ty-map",
param=line,
@@ -71,8 +76,8 @@ class CodeBlock:
lineno_end=lineo,
lines=[],
)
- elif line.startswith(C.STUB_IMPORT_OBJECT):
- line = line[len(C.STUB_IMPORT_OBJECT) :].strip()
+ elif line.startswith(syntax.import_object):
+ line = line[len(syntax.import_object) :].strip()
splits = [p.strip() for p in line.split(";")]
if len(splits) < 3:
splits += [""] * (3 - len(splits))
@@ -83,16 +88,13 @@ class CodeBlock:
lineno_end=lineo,
lines=[],
)
- assert line.startswith(C.STUB_BEGIN)
+ assert line.startswith(syntax.begin)
param: str | tuple[str, ...]
- stub = line[len(C.STUB_BEGIN) :].strip()
+ stub = line[len(syntax.begin) :].strip()
if stub.startswith("global/"):
kind = "global"
param = stub[len("global/") :].strip()
- if "@" in param:
- param = tuple(param.split("@"))
- else:
- param = (param, "")
+ param = tuple(param.split("@")) if "@" in param else (param, "")
elif stub.startswith("object/"):
kind = "object"
param = stub[len("object/") :].strip()
@@ -126,6 +128,7 @@ class FileInfo:
path: Path
lines: tuple[str, ...]
code_blocks: list[CodeBlock]
+ syntax: C.MarkerSyntax
def update(self, verbose: bool, dry_run: bool) -> bool:
"""Update the file's lines based on the current code blocks and
optionally show a diff."""
@@ -153,16 +156,24 @@ class FileInfo:
return True
@staticmethod
- def from_file(file: Path, include_empty: bool = False) -> FileInfo | None:
# noqa: PLR0912
- """Parse a file to extract code blocks based on stub markers."""
+ def from_file( # noqa: PLR0912
+ file: Path, include_empty: bool = False, syntax: C.MarkerSyntax | None
= None
+ ) -> FileInfo | None:
+ """Parse a file to extract code blocks based on stub markers.
+
+ The marker comment syntax is auto-detected from the file extension when
+ ``syntax`` is not given.
+ """
assert file.is_file(), f"Expected a file, but got: {file}"
file = file.resolve()
+ if syntax is None:
+ syntax = syntax_for(file)
has_marker = False
lines: list[str] = file.read_text(encoding="utf-8").splitlines()
for _, line in enumerate(lines, start=1):
- if line.strip().startswith(C.STUB_SKIP_FILE):
+ if line.strip().startswith(syntax.skip_file):
return None
- if line.strip().startswith(C.STUB_PREFIX):
+ if line.strip().startswith(syntax.prefix):
has_marker = True
if not has_marker and not include_empty:
return None
@@ -172,36 +183,36 @@ class FileInfo:
code: CodeBlock | None = None
for lineno, line in enumerate(lines, 1):
clean_line = line.strip()
- if clean_line.startswith(C.STUB_BEGIN):
- # Process "# tvm-ffi-stubgen(begin)"
+ if clean_line.startswith(syntax.begin):
+ # Process "<comment> tvm-ffi-stubgen(begin)"
if code is not None:
raise ValueError(f"Nested stub not permitted at line
{lineno}")
- code = CodeBlock.from_begin_line(lineno, clean_line)
+ code = CodeBlock.from_begin_line(lineno, clean_line, syntax)
code.lineno_start = lineno
code.lines.append(line)
- elif clean_line.startswith(C.STUB_END):
- # Process "# tvm-ffi-stubgen(end)"
+ elif clean_line.startswith(syntax.end):
+ # Process "<comment> tvm-ffi-stubgen(end)"
if code is None:
- raise ValueError(f"Unmatched `{C.STUB_END}` found at line
{lineno}")
+ raise ValueError(f"Unmatched `{syntax.end}` found at line
{lineno}")
code.lineno_end = lineno
code.lines.append(line)
codes.append(code)
code = None
- elif clean_line.startswith(C.STUB_TY_MAP):
- # Process "# tvm-ffi-stubgen(ty_map)"
- ty_code = CodeBlock.from_begin_line(lineno, clean_line)
+ elif clean_line.startswith(syntax.ty_map):
+ # Process "<comment> tvm-ffi-stubgen(ty_map)"
+ ty_code = CodeBlock.from_begin_line(lineno, clean_line, syntax)
ty_code.lineno_end = lineno
ty_code.lines.append(line)
codes.append(ty_code)
del ty_code
- elif clean_line.startswith(C.STUB_IMPORT_OBJECT):
- # Process "# tvm-ffi-stubgen(import-object)"
- imp_code = CodeBlock.from_begin_line(lineno, clean_line)
+ elif clean_line.startswith(syntax.import_object):
+ # Process "<comment> tvm-ffi-stubgen(import-object)"
+ imp_code = CodeBlock.from_begin_line(lineno, clean_line,
syntax)
imp_code.lineno_end = lineno
imp_code.lines.append(line)
codes.append(imp_code)
del imp_code
- elif clean_line.startswith(C.STUB_PREFIX):
+ elif clean_line.startswith(syntax.prefix):
raise ValueError(f"Unknown stub type at line {lineno}:
{clean_line}")
elif code is None:
# Process a plain line outside of any stub block
@@ -218,11 +229,11 @@ class FileInfo:
code.lines.append(line)
if code is not None:
raise ValueError("Unclosed stub block at end of file")
- return FileInfo(path=file, lines=tuple(lines), code_blocks=codes)
+ return FileInfo(path=file, lines=tuple(lines), code_blocks=codes,
syntax=syntax)
def reload(self) -> None:
"""Reload the code blocks from disk while preserving original
`lines`."""
- source = FileInfo.from_file(self.path)
+ source = FileInfo.from_file(self.path, syntax=self.syntax)
assert source is not None, f"File no longer exists or valid:
{self.path}"
self.code_blocks = source.code_blocks
diff --git a/python/tvm_ffi/stub/generator.py b/python/tvm_ffi/stub/generator.py
new file mode 100644
index 0000000..1651dd9
--- /dev/null
+++ b/python/tvm_ffi/stub/generator.py
@@ -0,0 +1,190 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Pluggable code generators for ``tvm-ffi-stubgen``.
+
+The stub generator separates two concerns:
+
+1. *Language-agnostic* infrastructure — reading the FFI reflection registry
+ (:mod:`.lib_state`), parsing/writing marker blocks (:mod:`.file_utils`), and
+ the abstract object/function metadata (:class:`.utils.ObjectInfo`,
+ :class:`.utils.FuncInfo`). None of this knows or cares about the target
+ language.
+2. *Language-specific* rendering — turning that metadata into concrete source
+ text, rendering a :class:`~tvm_ffi.core.TypeSchema` into a target-language
+ type expression, and modelling that language's imports.
+
+A :class:`Generator` encapsulates concern (2); ``cli.py`` drives concern (1)
and
+delegates every act of emitting text — and every act of collecting imports — to
+the active generator. The import collector is opaque to the pipeline:
``cli.py``
+asks the generator to create one, seed it from ``import-object`` directives,
and
+later render it, but never reaches inside. Adding a language is therefore
+"implement one more :class:`Generator`" rather than forking the pipeline.
+"""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, Protocol
+
+from . import consts as C
+from .python_generator import PythonGenerator
+
+if TYPE_CHECKING:
+ from pathlib import Path
+
+ from .file_utils import CodeBlock
+ from .utils import FuncInfo, InitConfig, ObjectInfo, Options
+
+
+class Generator(Protocol):
+ """Language-specific rendering surface used by the stub-generation
pipeline.
+
+ Each method that ends in ``_block`` mutates ``code.lines`` in place to hold
+ the freshly generated text between the ``begin``/``end`` markers. The
+ ``*_file`` methods return whole-file scaffolding text used by ``--init``
+ mode. Implementations must be stateless with respect to a single file so
the
+ pipeline can process files in any order.
+
+ The ``imports`` parameter threaded through the ``_block`` methods is an
+ opaque import collector created by :meth:`new_imports`; only the generator
+ that created it understands its contents.
+ """
+
+ #: Short identifier, e.g. ``"python"``.
+ name: str
+
+ #: Comment-marker syntax for the files this generator emits.
+ syntax: C.MarkerSyntax
+
+ def default_ty_map(self) -> dict[str, str]:
+ """Return the default FFI-origin -> target-type name map for this
language."""
+ ...
+
+ # --- import collection (representation is generator-private) ------------
+
+ def new_imports(self) -> Any:
+ """Create a fresh, empty import collector for one file."""
+ ...
+
+ def add_imported_object(
+ self, imports: Any, name: str, type_checking_only: str, alias: str
+ ) -> None:
+ """Record an ``import-object`` directive (raw directive fields) into
``imports``."""
+ ...
+
+ def canonical_type_name(self, type_key: str) -> str:
+ """Return the canonical identifier for a locally-defined type key.
+
+ Used to suppress importing a type the file itself defines, and to feed
+ the public-export list. Must be comparable to the names produced while
+ collecting imports.
+ """
+ ...
+
+ def extra_export_names(self, imports: Any) -> set[str]:
+ """Return extra public-export names implied by the collected
imports."""
+ ...
+
+ # --- per-block generation (mutates `code.lines`) ------------------------
+
+ def generate_global_funcs_block(
+ self,
+ code: CodeBlock,
+ global_funcs: list[FuncInfo],
+ ty_map: dict[str, str],
+ imports: Any,
+ opt: Options,
+ ) -> None:
+ """Emit free function signatures for a ``global/<prefix>`` block."""
+ ...
+
+ def generate_object_block(
+ self,
+ code: CodeBlock,
+ ty_map: dict[str, str],
+ imports: Any,
+ opt: Options,
+ obj_info: ObjectInfo,
+ ) -> None:
+ """Emit a type definition (fields + methods + init) for an
``object/<key>`` block."""
+ ...
+
+ def generate_import_section_block(
+ self, code: CodeBlock, imports: Any, opt: Options, defined_types:
set[str]
+ ) -> None:
+ """Emit the import/`use` statements collected while rendering other
blocks.
+
+ ``defined_types`` holds the canonical names defined in this file so the
+ generator can drop imports that would shadow a local definition.
+ """
+ ...
+
+ def generate_all_block(self, code: CodeBlock, names: set[str], opt:
Options) -> None:
+ """Emit the public-export list for this generator."""
+ ...
+
+ def generate_export_block(self, code: CodeBlock) -> None:
+ """Emit a submodule re-export for an ``export/<submodule>`` block."""
+ ...
+
+ def generate_helpers_block(self, code: CodeBlock, opt: Options) -> None:
+ """Emit shared per-file support code for generator-specific helper
blocks."""
+ ...
+
+ # --- whole-file scaffolding (used by `--init` mode) ---------------------
+
+ def api_filename(self) -> str:
+ """File name of the scaffolded API file."""
+ ...
+
+ def init_filename(self) -> str:
+ """File name of the scaffolded package entry (Python
``__init__.py``)."""
+ ...
+
+ def generate_api_file(
+ self,
+ code_blocks: list[CodeBlock],
+ ty_map: dict[str, str],
+ module_name: str,
+ object_infos: list[ObjectInfo],
+ init_cfg: InitConfig,
+ is_root: bool,
+ ) -> str:
+ """Return text appended to a freshly scaffolded API file (Python
``_ffi_api.py``)."""
+ ...
+
+ def generate_init_file(
+ self, code_blocks: list[CodeBlock], module_name: str, submodule: str
+ ) -> str:
+ """Return text appended to a freshly scaffolded package entry (Python
``__init__.py``)."""
+ ...
+
+ def finalize_init(self, init_path: Path, generated_prefixes: set[str]) ->
None:
+ """Post-``--init`` hook to stitch the generated tree after file
creation."""
+ ...
+
+
+_GENERATORS: dict[str, Generator] = {
+ "python": PythonGenerator(),
+}
+
+
+def get_generator(target: str) -> Generator:
+ """Resolve a generator by target name."""
+ if target not in _GENERATORS:
+ known = ", ".join(sorted(_GENERATORS))
+ raise ValueError(f"Unknown stubgen generator: {target!r}. Known
generators: {known}")
+ return _GENERATORS[target]
diff --git a/python/tvm_ffi/stub/python_generator/__init__.py
b/python/tvm_ffi/stub/python_generator/__init__.py
new file mode 100644
index 0000000..eb356cc
--- /dev/null
+++ b/python/tvm_ffi/stub/python_generator/__init__.py
@@ -0,0 +1,23 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Python code generator for ``tvm-ffi-stubgen``."""
+
+from __future__ import annotations
+
+from .generator import PythonGenerator
+
+__all__ = ["PythonGenerator"]
diff --git a/python/tvm_ffi/stub/codegen.py
b/python/tvm_ffi/stub/python_generator/codegen.py
similarity index 73%
rename from python/tvm_ffi/stub/codegen.py
rename to python/tvm_ffi/stub/python_generator/codegen.py
index bc31103..c3ffe76 100644
--- a/python/tvm_ffi/stub/codegen.py
+++ b/python/tvm_ffi/stub/python_generator/codegen.py
@@ -14,15 +14,72 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""Code generation logic for the `tvm-ffi-stubgen` tool."""
+"""Python code generation for the ``tvm-ffi-stubgen`` tool.
+
+This module owns the Python codegen orchestration for language-agnostic FFI
+metadata (:class:`tvm_ffi.stub.utils.FuncInfo` /
+:class:`~tvm_ffi.stub.utils.ObjectInfo`). Rendering helpers live in
+``python_generator.utils`` so the per-block generation pipeline here stays
focused
+on directive handling and source assembly.
+"""
from __future__ import annotations
from typing import Callable
-from . import consts as C
-from .file_utils import CodeBlock
-from .utils import FuncInfo, ImportItem, InitConfig, ObjectInfo, Options
+from .. import consts as C
+from ..file_utils import CodeBlock
+from ..utils import FuncInfo, InitConfig, ObjectInfo, Options
+from .utils import (
+ ImportItem,
+ render_func_signature,
+ render_object_ffi_init,
+ render_object_fields,
+ render_object_init,
+ render_object_methods,
+)
+
+# --- Python scaffolding templates (init mode) -------------------------------
+# These emit Python source plus stub-directive markers. The marker comment
token
+# comes from the supplied `MarkerSyntax`, so the directive structure stays
+# language-aware even though the surrounding code is Python-specific.
+
+
+def _prompt_globals(mod: str, syntax: C.MarkerSyntax) -> str:
+ return f"""{syntax.begin} global/{mod}
+{syntax.end}
+"""
+
+
+def _prompt_class_def(
+ type_name: str, type_key: str, parent_type_name: str, syntax:
C.MarkerSyntax
+) -> str:
+ return f'''@_FFI_REG_OBJ("{type_key}")
+class {type_name}({parent_type_name}):
+ """FFI binding for `{type_key}`."""
+
+ {syntax.begin} object/{type_key}
+ {syntax.end}\n\n'''
+
+
+def _prompt_import_object(type_key: str, type_name: str, syntax:
C.MarkerSyntax) -> str:
+ return f"""{syntax.import_object} {type_key};False;{type_name}\n"""
+
+
+def _prompt_import_section(syntax: C.MarkerSyntax) -> str:
+ return f"""
+{syntax.begin} import-section
+{syntax.end}
+"""
+
+
+def _prompt_all_section(syntax: C.MarkerSyntax) -> str:
+ return f"""
+__all__ = [
+ {syntax.begin} __all__
+ {syntax.end}
+]
+"""
def _type_suffix_and_record(
@@ -46,7 +103,7 @@ def _type_suffix_and_record(
return _run
-def generate_global_funcs(
+def generate_python_global_funcs(
code: CodeBlock,
global_funcs: list[FuncInfo],
ty_map: dict[str, str],
@@ -83,7 +140,7 @@ def generate_global_funcs(
"# fmt: off",
f'_FFI_INIT_FUNC("{prefix}", __name__)',
"if TYPE_CHECKING:",
- *[func.gen(fn_ty_map, indent=opt.indent) for func in global_funcs],
+ *[render_func_signature(func, fn_ty_map, opt.indent) for func in
global_funcs],
"# fmt: on",
]
indent = " " * code.indent
@@ -94,7 +151,7 @@ def generate_global_funcs(
]
-def generate_object(
+def generate_python_object(
code: CodeBlock,
ty_map: dict[str, str],
imports: list[ImportItem],
@@ -109,12 +166,12 @@ def generate_object(
info = obj_info
method_names = {m.schema.name.rsplit(".", 1)[-1] for m in info.methods}
fn_ty_map = _type_suffix_and_record(ty_map, imports,
func_names=method_names)
- init_lines = info.gen_init(fn_ty_map, indent=opt.indent)
- ffi_init_lines = info.gen_ffi_init(fn_ty_map, indent=opt.indent)
+ init_lines = render_object_init(info, fn_ty_map, opt.indent)
+ ffi_init_lines = render_object_ffi_init(info, fn_ty_map, opt.indent)
type_checking_lines = [
*init_lines,
*ffi_init_lines,
- *info.gen_methods(fn_ty_map, indent=opt.indent),
+ *render_object_methods(info, fn_ty_map, opt.indent),
]
if type_checking_lines:
imports.append(
@@ -125,7 +182,7 @@ def generate_object(
)
results = [
"# fmt: off",
- *info.gen_fields(fn_ty_map, indent=0),
+ *render_object_fields(info, fn_ty_map, 0),
"if TYPE_CHECKING:",
*type_checking_lines,
"# fmt: on",
@@ -133,7 +190,7 @@ def generate_object(
else:
results = [
"# fmt: off",
- *info.gen_fields(fn_ty_map, indent=0),
+ *render_object_fields(info, fn_ty_map, 0),
"# fmt: on",
]
indent = " " * code.indent
@@ -144,7 +201,7 @@ def generate_object(
]
-def generate_import_section(
+def generate_python_import_section(
code: CodeBlock,
imports: list[ImportItem],
opt: Options,
@@ -197,7 +254,7 @@ def generate_import_section(
]
-def generate_all(code: CodeBlock, names: set[str], opt: Options) -> None:
+def generate_python_all(code: CodeBlock, names: set[str], opt: Options) ->
None:
"""Generate an `__all__` variable for the given names."""
assert len(code.lines) >= 2
if not names:
@@ -220,7 +277,7 @@ def generate_all(code: CodeBlock, names: set[str], opt:
Options) -> None:
]
-def generate_export(code: CodeBlock) -> None:
+def generate_python_export(code: CodeBlock) -> None:
"""Generate an `__all__` variable for the given names."""
assert len(code.lines) >= 2
@@ -240,13 +297,14 @@ def generate_export(code: CodeBlock) -> None:
]
-def generate_ffi_api(
+def generate_python_ffi_api(
code_blocks: list[CodeBlock],
ty_map: dict[str, str],
module_name: str,
object_infos: list[ObjectInfo],
init_cfg: InitConfig,
is_root: bool,
+ syntax: C.MarkerSyntax,
) -> str:
"""Generate the initial FFI API stub code for a given module."""
# TODO(@junrus): New code is appended to the end of the file.
@@ -257,22 +315,24 @@ def generate_ffi_api(
if not code_blocks:
append += f"""\"\"\"FFI API bindings for {module_name}.\"\"\"\n"""
if not any(code.kind == "import-section" for code in code_blocks):
- append += C.PROMPT_IMPORT_SECTION
+ append += _prompt_import_section(syntax)
# Part 1. Library loading
if is_root:
- append += C._prompt_import_object("tvm_ffi.libinfo.load_lib_module",
"_FFI_LOAD_LIB")
+ append += _prompt_import_object("tvm_ffi.libinfo.load_lib_module",
"_FFI_LOAD_LIB", syntax)
append += f"""LIB = _FFI_LOAD_LIB("{init_cfg.pkg}",
"{init_cfg.shared_target}")\n"""
# Part 2. Global functions
if not any(code.kind == "global" for code in code_blocks):
- append += C._prompt_globals(module_name)
+ append += _prompt_globals(module_name, syntax)
# Part 3. Object types
if object_infos:
- append += C._prompt_import_object("tvm_ffi.register_object",
"_FFI_REG_OBJ")
+ append += _prompt_import_object("tvm_ffi.register_object",
"_FFI_REG_OBJ", syntax)
- defined_type_keys = {info.type_key for info in object_infos if
info.type_key}
+ defined_type_keys = {
+ ty_map.get(info.type_key, info.type_key) for info in object_infos if
info.type_key
+ }
for info in object_infos:
type_key = info.type_key
parent_type_key = info.parent_type_key
@@ -288,28 +348,30 @@ def generate_ffi_api(
# Import parent type keys if they are not defined in the current module
if parent_type_key and parent_type_key not in defined_type_keys:
parent_type_name = "_" + parent_type_key.replace(".", "_")
- append += C._prompt_import_object(parent_type_key,
parent_type_name)
+ append += _prompt_import_object(parent_type_key, parent_type_name,
syntax)
# Generate class definition
- append += C._prompt_class_def(
+ append += _prompt_class_def(
type_name,
type_key,
parent_type_name,
+ syntax,
)
# Part 4. __all__
if not any(code.kind == "__all__" for code in code_blocks):
- append += C.PROMPT_ALL_SECTION
+ append += _prompt_all_section(syntax)
return append
-def generate_init(
+def generate_python_init(
code_blocks: list[CodeBlock],
module_name: str,
- submodule: str = "_ffi_api",
+ submodule: str,
+ syntax: C.MarkerSyntax,
) -> str:
"""Generate the `__init__.py` file for the `tvm_ffi` package."""
code = f"""
-{C.STUB_BEGIN} export/{submodule}
-{C.STUB_END}
+{syntax.begin} export/{submodule}
+{syntax.end}
"""
if not code_blocks:
return f"""\"\"\"Package {module_name}.\"\"\"\n""" + code
diff --git a/python/tvm_ffi/stub/python_generator/consts.py
b/python/tvm_ffi/stub/python_generator/consts.py
new file mode 100644
index 0000000..a788bc6
--- /dev/null
+++ b/python/tvm_ffi/stub/python_generator/consts.py
@@ -0,0 +1,46 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Python-specific constants for the ``tvm-ffi-stubgen`` Python generator.
+
+These tables map FFI-origin names and module prefixes onto Python typing /
+import syntax. They are intentionally kept out of the language-agnostic
+:mod:`tvm_ffi.stub.consts` so that a non-Python generator never inherits Python
+typing assumptions.
+"""
+
+from __future__ import annotations
+
+#: Default FFI-origin -> Python-type name map used to seed a render.
+# Keys are FFI-side type names; values are dotted names. The ``ffi.`` prefix on
+# some values presumably relies on the ``"ffi"`` entry of ``MOD_MAP`` below
+# when an import path is built -- TODO confirm.
+TY_MAP_DEFAULTS = {
+ "Any": "typing.Any",
+ "Callable": "typing.Callable",
+ "Array": "collections.abc.Sequence",
+ "List": "collections.abc.MutableSequence",
+ "Map": "collections.abc.Mapping",
+ "Dict": "collections.abc.MutableMapping",
+ "Object": "ffi.Object",
+ "Tensor": "ffi.Tensor",
+ "dtype": "ffi.dtype",
+ "Device": "ffi.Device",
+}
+
+# TODO(@junrushao): Make it configurable
+#: Module-prefix rewrites applied when constructing a Python ``import`` path.
+# Each key is a leading module prefix and each value is its replacement.
+MOD_MAP = {
+ "testing": "tvm_ffi.testing",
+ "ffi": "tvm_ffi",
+}
diff --git a/python/tvm_ffi/stub/python_generator/generator.py
b/python/tvm_ffi/stub/python_generator/generator.py
new file mode 100644
index 0000000..da1fc16
--- /dev/null
+++ b/python/tvm_ffi/stub/python_generator/generator.py
@@ -0,0 +1,155 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""The Python code generator for ``tvm-ffi-stubgen``.
+
+:class:`PythonGenerator` implements the
:class:`tvm_ffi.stub.generator.Generator`
+protocol by delegating to :mod:`.codegen`. It owns the Python notion of an
+import (:class:`.utils.ImportItem` / :class:`.utils.PythonImports`); the
+language-agnostic pipeline only ever sees the opaque collector.
+"""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from .. import consts as C
+from . import codegen as G
+from . import consts as PC
+from .utils import ImportItem, PythonImports
+
+if TYPE_CHECKING:
+ from pathlib import Path
+
+ from ..file_utils import CodeBlock
+ from ..utils import FuncInfo, InitConfig, ObjectInfo, Options
+
+
+class PythonGenerator:
+ """Generator that emits Python type stubs by delegating to
:mod:`.codegen`."""
+
+ # Target identifier for this generator.
+ name = "python"
+ # Marker syntax handed to the scaffolding helpers in `generate_api_file`
+ # and `generate_init_file`.
+ syntax = C.PYTHON_SYNTAX
+
+ def default_ty_map(self) -> dict[str, str]:
+ """Return the default FFI-origin -> Python-type name map."""
+ # Return a copy so callers can mutate it without touching the defaults.
+ return PC.TY_MAP_DEFAULTS.copy()
+
+ # --- import collection (Python representation is private) ---------------
+
+ def new_imports(self) -> PythonImports:
+ """Create an empty import collector."""
+ return PythonImports()
+
+ def add_imported_object(
+ self, imports: PythonImports, name: str, type_checking_only: str,
alias: str
+ ) -> None:
+ """Record an ``import-object`` directive into the collector."""
+ # The flag arrives as text; only a case-insensitive "true" enables it.
+ tco = type_checking_only.lower() == "true"
+ # An empty alias is stored as None.
+ imports.items.append(ImportItem(name, type_checking_only=tco,
alias=alias or None))
+ # Either the conventional alias or the loader's dotted name marks this
+ # collector as having seen a library-loading import.
+ if alias == "_FFI_LOAD_LIB" or
name.endswith("libinfo.load_lib_module"):
+ imports.has_lib_load = True
+
+ def canonical_type_name(self, type_key: str) -> str:
+ """Return the canonical (import-comparable) full name for a defined
type key."""
+ # Round-trip through ImportItem so module-prefix rewrites are applied.
+ return ImportItem(type_key).full_name
+
+ def extra_export_names(self, imports: PythonImports) -> set[str]:
+ """Return extra ``__all__`` names implied by the collected imports."""
+ return {"LIB"} if imports.has_lib_load else set()
+
+ # --- per-block generation (mutates `code.lines`) ------------------------
+
+ def generate_global_funcs_block(
+ self,
+ code: CodeBlock,
+ global_funcs: list[FuncInfo],
+ ty_map: dict[str, str],
+ imports: PythonImports,
+ opt: Options,
+ ) -> None:
+ """Emit Python free-function signatures for a ``global/<prefix>``
block."""
+ # codegen works on the raw item list, not on the collector wrapper.
+ G.generate_python_global_funcs(code, global_funcs, ty_map,
imports.items, opt)
+
+ def generate_object_block(
+ self,
+ code: CodeBlock,
+ ty_map: dict[str, str],
+ imports: PythonImports,
+ opt: Options,
+ obj_info: ObjectInfo,
+ ) -> None:
+ """Emit a Python class definition for an ``object/<key>`` block."""
+ G.generate_python_object(code, ty_map, imports.items, opt, obj_info)
+
+ def generate_import_section_block(
+ self,
+ code: CodeBlock,
+ imports: PythonImports,
+ opt: Options,
+ defined_types: set[str],
+ ) -> None:
+ """Emit Python ``import`` statements for the collected imports.
+
+ Imports whose full name is a type defined in this same file are dropped
+ (you don't import what you define locally).
+ """
+ # Filter into a new list; the collector itself is left unchanged.
+ filtered = [i for i in imports.items if i.full_name not in
defined_types]
+ G.generate_python_import_section(code, filtered, opt)
+
+ def generate_all_block(self, code: CodeBlock, names: set[str], opt:
Options) -> None:
+ """Emit a Python ``__all__`` list."""
+ G.generate_python_all(code, names, opt)
+
+ def generate_export_block(self, code: CodeBlock) -> None:
+ """Emit a Python submodule re-export for an ``export/<submodule>``
block."""
+ G.generate_python_export(code)
+
+ def generate_helpers_block(self, code: CodeBlock, opt: Options) -> None:
+ """No-op: Python needs no per-file support code (Python files have no
helpers block)."""
+
+ # --- whole-file scaffolding (used by `--init` mode) ---------------------
+
+ def api_filename(self) -> str:
+ """Return the Python API file name."""
+ return "_ffi_api.py"
+
+ def init_filename(self) -> str:
+ """Return the Python package entry file name."""
+ return "__init__.py"
+
+ def generate_api_file(
+ self,
+ code_blocks: list[CodeBlock],
+ ty_map: dict[str, str],
+ module_name: str,
+ object_infos: list[ObjectInfo],
+ init_cfg: InitConfig,
+ is_root: bool,
+ ) -> str:
+ """Return text appended to a scaffolded ``_ffi_api.py``."""
+ # Same arguments as received, plus this generator's marker syntax.
+ return G.generate_python_ffi_api(
+ code_blocks, ty_map, module_name, object_infos, init_cfg, is_root,
self.syntax
+ )
+
+ def generate_init_file(
+ self, code_blocks: list[CodeBlock], module_name: str, submodule: str
+ ) -> str:
+ """Return text appended to a scaffolded ``__init__.py``."""
+ return G.generate_python_init(code_blocks, module_name, submodule,
self.syntax)
+
+ def finalize_init(self, init_path: Path, generated_prefixes: set[str]) ->
None:
+ """No-op: Python packages need no parent-declares-child wiring."""
diff --git a/python/tvm_ffi/stub/python_generator/utils.py
b/python/tvm_ffi/stub/python_generator/utils.py
new file mode 100644
index 0000000..02c0f7b
--- /dev/null
+++ b/python/tvm_ffi/stub/python_generator/utils.py
@@ -0,0 +1,274 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Python generator helpers for ``tvm-ffi-stubgen``.
+
+This module groups two Python-specific concerns:
+
+- import modelling (:class:`ImportItem`, :class:`PythonImports`)
+- stub rendering helpers for function/object signatures
+"""
+
+from __future__ import annotations
+
+import dataclasses
+from io import StringIO
+from typing import Callable
+
+from ..utils import FuncInfo, ObjectInfo
+from . import consts as C
+
+
+@dataclasses.dataclass(frozen=True, eq=True)
+class ImportItem:
+ """An import statement item.
+
+ Built from a dotted ``module.symbol`` name. ``mod`` holds the module part
+ after any ``MOD_MAP`` prefix rewrite (empty when the name has no dot),
+ ``name`` holds the final component, and ``str()`` renders the matching
+ ``from ... import ...`` or ``import ...`` statement.
+ """
+
+ mod: str
+ name: str
+ type_checking_only: bool = False
+ alias: str | None = None
+
+ def __init__(
+ self,
+ full_name: str,
+ type_checking_only: bool = False,
+ alias: str | None = None,
+ ) -> None:
+ """Initialize an `ImportItem` from a dotted ``module.symbol`` name and
optional alias."""
+ if "." in full_name:
+ # Split on the last dot: everything before it is the module path.
+ mod, name = full_name.rsplit(".", 1)
+ for mod_prefix, mod_replacement in C.MOD_MAP.items():
+ # Match whole leading components only, so a prefix such as "ffi"
+ # does not rewrite an unrelated module that merely starts with it.
+ if mod == mod_prefix or mod.startswith(mod_prefix + "."):
+ mod = mod.replace(mod_prefix, mod_replacement, 1)
+ # Only the first matching rewrite is applied.
+ break
+ else:
+ mod, name = "", full_name
+ # Fields are set through object.__setattr__ rather than plain assignment,
+ # presumably because the class is declared as a frozen dataclass.
+ object.__setattr__(self, "mod", mod)
+ object.__setattr__(self, "name", name)
+ object.__setattr__(self, "type_checking_only", type_checking_only)
+ object.__setattr__(self, "alias", alias)
+
+ @property
+ def name_with_alias(self) -> str:
+ """Generate a string of the form `name as alias` if an alias is set,
otherwise just `name`."""
+ return f"{self.name} as {self.alias}" if self.alias else self.name
+
+ @property
+ def full_name(self) -> str:
+ """Generate a string of the form `mod.name` or `name` if no module is
set."""
+ return f"{self.mod}.{self.name}" if self.mod else self.name
+
+ def __repr__(self) -> str:
+ """Generate an import statement string for this item."""
+ return str(self)
+
+ def __str__(self) -> str:
+ """Generate an import statement string for this item."""
+ if self.mod:
+ ret = f"from {self.mod} import {self.name_with_alias}"
+ else:
+ ret = f"import {self.name_with_alias}"
+ return ret
+
+
+@dataclasses.dataclass
+class PythonImports:
+ """Opaque import collector threaded through the Python generation pipeline.
+
+ The language-agnostic ``cli`` treats this as an opaque token: it asks the
+ generator to create one, seed it from ``import-object`` directives, and
later
+ render it. Only the Python generator reaches inside.
+ """
+
+ # Collected import items; each instance gets its own fresh list.
+ items: list[ImportItem] = dataclasses.field(default_factory=list)
+ has_lib_load: bool = False
+ """Whether an FFI library-loading import was seen (adds ``LIB`` to
``__all__``)."""
+
+
+def render_func_signature(
+ func: FuncInfo,
+ ty_map: Callable[[str], str],
+ indent: int,
+) -> str:
+ """Render a function signature string for ``func``.
+
+ The result is a one-line stub ``def <name>(<params>) -> <ret>: ...``
+ prefixed by ``indent`` repetitions of the indent unit, where ``<name>`` is
+ the last dotted component of ``func.schema.name``. Parameters are named
+ ``_<i>`` by position, except that the first parameter of a member function
+ is rendered as ``self``. Raises ``ValueError`` when the schema origin is
+ not ``"Callable"``.
+ """
+ func_name = func.schema.name.rsplit(".", 1)[-1]
+ buf = StringIO()
+ buf.write(" " * indent)
+ buf.write(f"def {func_name}(")
+ if func.schema.origin != "Callable":
+ raise ValueError(f"Expected Callable type schema, but got:
{func.schema}")
+ if not func.schema.args:
+ # No argument schema: fall back to a fully generic signature. The
+ # return value of ty_map("Any") is discarded; presumably the call is
+ # made so the mapper can register that ``Any`` is used -- TODO confirm.
+ ty_map("Any")
+ buf.write("*args: Any) -> Any: ...")
+ return buf.getvalue()
+ # args[0] is the return type; the remaining entries are parameter types.
+ arg_ret = func.schema.args[0]
+ arg_args = func.schema.args[1:]
+ for i, arg in enumerate(arg_args):
+ if func.is_member and i == 0:
+ buf.write("self, ")
+ else:
+ buf.write(f"_{i}: ")
+ buf.write(arg.repr(ty_map))
+ buf.write(", ")
+ if arg_args:
+ # A trailing "/" makes every rendered parameter positional-only.
+ buf.write("/")
+ buf.write(") -> ")
+ buf.write(arg_ret.repr(ty_map))
+ buf.write(": ...")
+ return buf.getvalue()
+
+
+def render_object_fields(
+ info: ObjectInfo,
+ ty_map: Callable[[str], str],
+ indent: int,
+) -> list[str]:
+ """Render field definitions for ``info``.
+
+ Returns one ``<name>: <type>`` annotation line per entry of
+ ``info.fields``, each prefixed by the indentation; the list is empty when
+ there are no fields.
+ """
+ indent_str = " " * indent
+ return [f"{indent_str}{field.name}: {field.repr(ty_map)}" for field in
info.fields]
+
+
+def render_object_methods(
+ info: ObjectInfo,
+ ty_map: Callable[[str], str],
+ indent: int,
+) -> list[str]:
+ """Render method definitions for ``info``.
+
+ Returns the stub lines for every entry of ``info.methods`` in order.
+ ``__ffi_init__`` gets a dedicated rendering; any other non-member method
+ is preceded by an ``@staticmethod`` line.
+ """
+ indent_str = " " * indent
+ ret = []
+ for method in info.methods:
+ func_name = method.schema.name.rsplit(".", 1)[-1]
+ if func_name == "__ffi_init__":
+ # __ffi_init__ is installed as an instance method (self, *args,
**kwargs) -> None
+ # by _install_ffi_init_attr, regardless of the C++ static
registration.
+ ret.append(_render_ffi_init_from_method(method, ty_map, indent))
+ continue
+ if not method.is_member:
+ ret.append(f"{indent_str}@staticmethod")
+ ret.append(render_func_signature(method, ty_map, indent))
+ return ret
+
+
+def _render_ffi_init_from_method(
+ method: FuncInfo,
+ ty_map: Callable[[str], str],
+ indent: int,
+) -> str:
+ """Render ``__ffi_init__`` TypeMethod as an instance method returning
None.
+
+ The schema's own return type is ignored: the stub always takes ``self``
+ first, names the remaining parameters ``_0``, ``_1``, ... as
+ positional-only, and returns ``None``. A non-Callable or argument-less
+ schema falls back to ``(self, *args: Any)``.
+ """
+ indent_str = " " * indent
+ schema = method.schema
+ # Subclass __ffi_init__ signatures legitimately differ from the parent
+ # (different fields -> different constructor params), so suppress LSP.
+ ignore = " # ty: ignore[invalid-method-override]"
+ if schema.origin != "Callable" or not schema.args:
+ # Return value discarded; presumably called so the mapper can register
+ # that ``Any`` is used -- TODO confirm.
+ ty_map("Any")
+ return f"{indent_str}def __ffi_init__(self, *args: Any) -> None:
...{ignore}"
+ # schema.args[0] is return type, schema.args[1:] are param types.
+ parts: list[str] = []
+ for i, arg in enumerate(schema.args[1:]):
+ parts.append(f"_{i}: {arg.repr(ty_map)}")
+ if parts:
+ params = ", ".join(parts)
+ return f"{indent_str}def __ffi_init__(self, {params}, /) -> None:
...{ignore}"
+ return f"{indent_str}def __ffi_init__(self) -> None: ...{ignore}"
+
+
+def render_object_ffi_init(
+ info: ObjectInfo,
+ ty_map: Callable[[str], str],
+ indent: int,
+) -> list[str]:
+ """Render a ``__ffi_init__`` stub when it's not already in TypeMethod.
+
+ For types whose ``__ffi_init__`` is auto-generated by ``RegisterFFIInit``
+ (TypeAttrColumn only), synthesize a stub from field metadata; it is
+ rendered as an instance method taking ``self``.
+ Types that already have ``__ffi_init__`` in TypeMethod (from explicit
+ ``refl::init<>``) get it via ``render_object_methods`` instead.
+
+ Returns an empty list when ``info.has_init`` is false or when
+ ``__ffi_init__`` is already among ``info.methods``.
+ """
+ if not info.has_init:
+ return []
+ # If __ffi_init__ is already in methods (from TypeMethod), methods render
it.
+ if any(m.schema.name.rsplit(".", 1)[-1] == "__ffi_init__" for m in
info.methods):
+ return []
+ return _render_ffi_init_from_fields(info, ty_map, indent)
+
+
+def render_object_init(
+ info: ObjectInfo,
+ ty_map: Callable[[str], str],
+ indent: int,
+) -> list[str]:
+ """Render an ``__init__`` stub from init-eligible field metadata.
+
+ Returns an empty list when ``info.has_init`` is false; otherwise a
+ one-element list holding the ``__init__`` stub line.
+ """
+ if not info.has_init:
+ return []
+ return _render_init_from_fields(info, ty_map, indent)
+
+
+def _format_field_params(
+ info: ObjectInfo,
+ ty_map: Callable[[str], str],
+) -> str:
+ """Format init-eligible fields as a parameter string with defaults and
kw_only.
+
+ Fields are emitted in four groups: positional without default, positional
+ with default, a bare ``*`` separator (only when keyword-only fields
+ exist), keyword-only without default, keyword-only with default. Defaults
+ are rendered as ``= ...``. Returns an empty string when
+ ``info.init_fields`` is empty.
+ """
+ positional = [f for f in info.init_fields if not f.kw_only]
+ kw_only = [f for f in info.init_fields if f.kw_only]
+
+ # Required parameters go first within each group: Python does not allow a
+ # positional parameter without a default after one that has a default.
+ pos_required = [f for f in positional if not f.has_default]
+ pos_default = [f for f in positional if f.has_default]
+ kw_required = [f for f in kw_only if not f.has_default]
+ kw_default = [f for f in kw_only if f.has_default]
+
+ parts: list[str] = []
+ for f in pos_required:
+ parts.append(f"{f.name}: {f.schema.repr(ty_map)}")
+ for f in pos_default:
+ parts.append(f"{f.name}: {f.schema.repr(ty_map)} = ...")
+ if kw_required or kw_default:
+ parts.append("*")
+ for f in kw_required:
+ parts.append(f"{f.name}: {f.schema.repr(ty_map)}")
+ for f in kw_default:
+ parts.append(f"{f.name}: {f.schema.repr(ty_map)} = ...")
+
+ return ", ".join(parts)
+
+
+def _render_init_from_fields(
+ info: ObjectInfo,
+ ty_map: Callable[[str], str],
+ indent: int,
+) -> list[str]:
+ """Render ``__init__`` from init-eligible field metadata (auto-generated
init).
+
+ Always returns a one-element list; with no init-eligible fields the stub
+ takes only ``self``.
+ """
+ indent_str = " " * indent
+ params = _format_field_params(info, ty_map)
+ if params:
+ return [f"{indent_str}def __init__(self, {params}) -> None: ..."]
+ return [f"{indent_str}def __init__(self) -> None: ..."]
+
+
+def _render_ffi_init_from_fields(
+ info: ObjectInfo,
+ ty_map: Callable[[str], str],
+ indent: int,
+) -> list[str]:
+ """Render ``__ffi_init__`` stub from field metadata for auto-generated
init.
+
+ Always returns a one-element list. The stub uses the same parameter
+ string as the ``__init__`` stub and ends with a type-checker ignore
+ comment.
+ """
+ indent_str = " " * indent
+ # Subclass __ffi_init__ signatures legitimately differ from the parent
+ # (different fields -> different constructor params), so suppress LSP.
+ ignore = " # ty: ignore[invalid-method-override]"
+ params = _format_field_params(info, ty_map)
+ if params:
+ return [f"{indent_str}def __ffi_init__(self, {params}) -> None:
...{ignore}"]
+ return [f"{indent_str}def __ffi_init__(self) -> None: ...{ignore}"]
diff --git a/python/tvm_ffi/stub/utils.py b/python/tvm_ffi/stub/utils.py
index 5ff79d2..172fac3 100644
--- a/python/tvm_ffi/stub/utils.py
+++ b/python/tvm_ffi/stub/utils.py
@@ -14,13 +14,18 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""Common utilities for the `tvm-ffi-stubgen` tool."""
+"""Language-agnostic data model for the `tvm-ffi-stubgen` tool.
+
+These dataclasses describe the FFI reflection metadata (functions, object
+fields/methods, init signatures) without committing to any target language.
+Turning this metadata into source text is the job of a target language
+generator (e.g. :mod:`tvm_ffi.stub.python_generator.codegen`).
+"""
from __future__ import annotations
import dataclasses
-from io import StringIO
-from typing import Any, Callable
+from typing import Any
from tvm_ffi.core import TypeInfo, TypeSchema, _lookup_type_attr
@@ -75,58 +80,8 @@ class Options:
files: list[str] = dataclasses.field(default_factory=list)
verbose: bool = False
dry_run: bool = False
-
-
-@dataclasses.dataclass(frozen=True, eq=True)
-class ImportItem:
- """An import statement item."""
-
- mod: str
- name: str
- type_checking_only: bool = False
- alias: str | None = None
-
- def __init__(
- self,
- name: str,
- type_checking_only: bool = False,
- alias: str | None = None,
- ) -> None:
- """Initialize an `ImportItem` with the given module name and optional
alias."""
- if "." in name:
- mod, name = name.rsplit(".", 1)
- for mod_prefix, mod_replacement in C.MOD_MAP.items():
- if mod.startswith(mod_prefix):
- mod = mod.replace(mod_prefix, mod_replacement, 1)
- break
- else:
- mod = ""
- object.__setattr__(self, "mod", mod)
- object.__setattr__(self, "name", name)
- object.__setattr__(self, "type_checking_only", type_checking_only)
- object.__setattr__(self, "alias", alias)
-
- @property
- def name_with_alias(self) -> str:
- """Generate a string of the form `name as alias` if an alias is set,
otherwise just `name`."""
- return f"{self.name} as {self.alias}" if self.alias else self.name
-
- @property
- def full_name(self) -> str:
- """Generate a string of the form `mod.name` or `name` if no module is
set."""
- return f"{self.mod}.{self.name}" if self.mod else self.name
-
- def __repr__(self) -> str:
- """Generate an import statement string for this item."""
- return str(self)
-
- def __str__(self) -> str:
- """Generate an import statement string for this item."""
- if self.mod:
- ret = f"from {self.mod} import {self.name_with_alias}"
- else:
- ret = f"import {self.name_with_alias}"
- return ret
+ target: str = "python"
+ """Code generator target to use."""
@dataclasses.dataclass(init=False)
@@ -136,7 +91,7 @@ class NamedTypeSchema(TypeSchema):
name: str
def __init__(self, name: str, schema: TypeSchema) -> None:
- """Initialize a `NamedTypeSchema` with the given name and type
schema."""
+ """Initialize a `NamedTypeSchema` with the given name and schema."""
super().__init__(origin=schema.origin, args=schema.args)
self.name = name
@@ -153,37 +108,6 @@ class FuncInfo:
"""Construct a `FuncInfo` from a name and its type schema."""
return FuncInfo(schema=NamedTypeSchema(name=name, schema=schema),
is_member=is_member)
- def gen(self, ty_map: Callable[[str], str], indent: int) -> str:
- """Generate a function signature string for this function."""
- try:
- _, func_name = self.schema.name.rsplit(".", 1)
- except ValueError:
- func_name = self.schema.name
- buf = StringIO()
- buf.write(" " * indent)
- buf.write(f"def {func_name}(")
- if self.schema.origin != "Callable":
- raise ValueError(f"Expected Callable type schema, but got:
{self.schema}")
- if not self.schema.args:
- ty_map("Any")
- buf.write("*args: Any) -> Any: ...")
- return buf.getvalue()
- arg_ret = self.schema.args[0]
- arg_args = self.schema.args[1:]
- for i, arg in enumerate(arg_args):
- if self.is_member and i == 0:
- buf.write("self, ")
- else:
- buf.write(f"_{i}: ")
- buf.write(arg.repr(ty_map))
- buf.write(", ")
- if arg_args:
- buf.write("/")
- buf.write(") -> ")
- buf.write(arg_ret.repr(ty_map))
- buf.write(": ...")
- return buf.getvalue()
-
@dataclasses.dataclass
class InitFieldInfo:
@@ -265,110 +189,3 @@ class ObjectInfo:
init_fields=init_fields,
has_init=has_init,
)
-
- def gen_fields(self, ty_map: Callable[[str], str], indent: int) ->
list[str]:
- """Generate field definitions for this object."""
- indent_str = " " * indent
- return [f"{indent_str}{field.name}: {field.repr(ty_map)}" for field in
self.fields]
-
- def gen_methods(self, ty_map: Callable[[str], str], indent: int) ->
list[str]:
- """Generate method definitions for this object."""
- indent_str = " " * indent
- ret = []
- for method in self.methods:
- func_name = method.schema.name.rsplit(".", 1)[-1]
- if func_name == "__ffi_init__":
- # __ffi_init__ is installed as an instance method (self,
*args, **kwargs) -> None
- # by _install_ffi_init_attr, regardless of the C++ static
registration.
- ret.append(self._gen_ffi_init_from_method(method, ty_map,
indent))
- continue
- if not method.is_member:
- ret.append(f"{indent_str}@staticmethod")
- ret.append(method.gen(ty_map, indent))
- return ret
-
- @staticmethod
- def _gen_ffi_init_from_method(
- method: FuncInfo, ty_map: Callable[[str], str], indent: int
- ) -> str:
- """Render ``__ffi_init__`` TypeMethod as an instance method returning
None."""
- indent_str = " " * indent
- schema = method.schema
- # Subclass __ffi_init__ signatures legitimately differ from the parent
- # (different fields → different constructor params), so suppress LSP.
- ignore = " # ty: ignore[invalid-method-override]"
- if schema.origin != "Callable" or not schema.args:
- ty_map("Any")
- return f"{indent_str}def __ffi_init__(self, *args: Any) -> None:
...{ignore}"
- # schema.args[0] is return type, schema.args[1:] are param types.
- parts: list[str] = []
- for i, arg in enumerate(schema.args[1:]):
- parts.append(f"_{i}: {arg.repr(ty_map)}")
- if parts:
- params = ", ".join(parts)
- return f"{indent_str}def __ffi_init__(self, {params}, /) -> None:
...{ignore}"
- return f"{indent_str}def __ffi_init__(self) -> None: ...{ignore}"
-
- def gen_ffi_init(self, ty_map: Callable[[str], str], indent: int) ->
list[str]:
- """Generate a ``__ffi_init__`` stub when it's not already in
TypeMethod.
-
- For types whose ``__ffi_init__`` is auto-generated by
``RegisterFFIInit``
- (TypeAttrColumn only), synthesize a static-method stub from field
metadata.
- Types that already have ``__ffi_init__`` in TypeMethod (from explicit
- ``refl::init<>``) get it via ``gen_methods`` instead.
- """
- if not self.has_init:
- return []
- # If __ffi_init__ is already in methods (from TypeMethod), gen_methods
handles it.
- if any(m.schema.name.rsplit(".", 1)[-1] == "__ffi_init__" for m in
self.methods):
- return []
- return self._gen_ffi_init_from_fields(ty_map, indent)
-
- def gen_init(self, ty_map: Callable[[str], str], indent: int) -> list[str]:
- """Generate an ``__init__`` stub from init-eligible field metadata."""
- if not self.has_init:
- return []
- return self._gen_init_from_fields(ty_map, indent)
-
- def _format_field_params(self, ty_map: Callable[[str], str]) -> str:
- """Format init-eligible fields as a parameter string with defaults and
kw_only."""
- positional = [f for f in self.init_fields if not f.kw_only]
- kw_only = [f for f in self.init_fields if f.kw_only]
-
- pos_required = [f for f in positional if not f.has_default]
- pos_default = [f for f in positional if f.has_default]
- kw_required = [f for f in kw_only if not f.has_default]
- kw_default = [f for f in kw_only if f.has_default]
-
- parts: list[str] = []
- for f in pos_required:
- parts.append(f"{f.name}: {f.schema.repr(ty_map)}")
- for f in pos_default:
- parts.append(f"{f.name}: {f.schema.repr(ty_map)} = ...")
- if kw_required or kw_default:
- parts.append("*")
- for f in kw_required:
- parts.append(f"{f.name}: {f.schema.repr(ty_map)}")
- for f in kw_default:
- parts.append(f"{f.name}: {f.schema.repr(ty_map)} = ...")
-
- return ", ".join(parts)
-
- def _gen_init_from_fields(self, ty_map: Callable[[str], str], indent: int)
-> list[str]:
- """Generate ``__init__`` from init-eligible field metadata
(auto-generated init)."""
- indent_str = " " * indent
- params = self._format_field_params(ty_map)
- if params:
- return [f"{indent_str}def __init__(self, {params}) -> None: ..."]
- return [f"{indent_str}def __init__(self) -> None: ..."]
-
- def _gen_ffi_init_from_fields(self, ty_map: Callable[[str], str], indent:
int) -> list[str]:
- """Generate ``__ffi_init__`` stub from field metadata for
auto-generated init."""
- indent_str = " " * indent
- # Subclass __ffi_init__ signatures legitimately differ from the parent
- # (different fields → different constructor params), so suppress LSP.
- ignore = " # ty: ignore[invalid-method-override]"
- params = self._format_field_params(ty_map)
- if params:
- return [f"{indent_str}def __ffi_init__(self, {params}) -> None:
...{ignore}"]
- return [f"{indent_str}def __ffi_init__(self) -> None: ...{ignore}"]
diff --git a/tests/python/test_stubgen.py b/tests/python/test_stubgen.py
index c32f6b2..5c00a04 100644
--- a/tests/python/test_stubgen.py
+++ b/tests/python/test_stubgen.py
@@ -23,19 +23,24 @@ import tvm_ffi.stub.cli as stub_cli
from tvm_ffi.core import TypeSchema
from tvm_ffi.stub import consts as C
from tvm_ffi.stub.cli import _stage_2, _stage_3
-from tvm_ffi.stub.codegen import (
- generate_all,
- generate_export,
- generate_ffi_api,
- generate_global_funcs,
- generate_import_section,
- generate_init,
- generate_object,
-)
from tvm_ffi.stub.file_utils import CodeBlock, FileInfo
+from tvm_ffi.stub.generator import get_generator
+from tvm_ffi.stub.python_generator import consts as PC
+from tvm_ffi.stub.python_generator.codegen import (
+ generate_python_all,
+ generate_python_export,
+ generate_python_ffi_api,
+ generate_python_global_funcs,
+ generate_python_import_section,
+ generate_python_init,
+ generate_python_object,
+ render_func_signature,
+ render_object_fields,
+ render_object_methods,
+)
+from tvm_ffi.stub.python_generator.utils import ImportItem
from tvm_ffi.stub.utils import (
FuncInfo,
- ImportItem,
InitConfig,
NamedTypeSchema,
ObjectInfo,
@@ -48,23 +53,23 @@ def _identity_ty_map(name: str) -> str:
def _default_ty_map() -> dict[str, str]:
- return C.TY_MAP_DEFAULTS.copy()
+ return PC.TY_MAP_DEFAULTS.copy()
def _type_suffix(name: str) -> str:
- return C.TY_MAP_DEFAULTS.get(name, name).rsplit(".", 1)[-1]
+ return PC.TY_MAP_DEFAULTS.get(name, name).rsplit(".", 1)[-1]
def test_codeblock_from_begin_line_variants() -> None:
cases = [
- (f"{C.STUB_BEGIN} global/demo", "global", ("demo", "")),
- (f"{C.STUB_BEGIN} global/demo@.registry", "global", ("demo",
".registry")),
- (f"{C.STUB_BEGIN} object/demo.TypeBase", "object", "demo.TypeBase"),
- (f"{C.STUB_BEGIN} ty-map/custom", "ty-map", "custom"),
- (f"{C.STUB_BEGIN} import-section", "import-section", ""),
+ (f"{C.PYTHON_SYNTAX.begin} global/demo", "global", ("demo", "")),
+ (f"{C.PYTHON_SYNTAX.begin} global/demo@.registry", "global", ("demo",
".registry")),
+ (f"{C.PYTHON_SYNTAX.begin} object/demo.TypeBase", "object",
"demo.TypeBase"),
+ (f"{C.PYTHON_SYNTAX.begin} ty-map/custom", "ty-map", "custom"),
+ (f"{C.PYTHON_SYNTAX.begin} import-section", "import-section", ""),
]
for lineno, (line, kind, param) in enumerate(cases, start=1):
- block = CodeBlock.from_begin_line(lineno, line)
+ block = CodeBlock.from_begin_line(lineno, line, C.PYTHON_SYNTAX)
assert block.kind == kind
assert block.param == param
assert block.lineno_start == lineno
@@ -73,20 +78,20 @@ def test_codeblock_from_begin_line_variants() -> None:
def test_codeblock_from_begin_line_ty_map_and_unknown() -> None:
- line = f"{C.STUB_TY_MAP} custom -> mapped"
- block = CodeBlock.from_begin_line(5, line)
+ line = f"{C.PYTHON_SYNTAX.ty_map} custom -> mapped"
+ block = CodeBlock.from_begin_line(5, line, C.PYTHON_SYNTAX)
assert block.kind == "ty-map"
assert block.param == "custom -> mapped"
assert block.lineno_start == 5
assert block.lineno_end == 5
with pytest.raises(ValueError):
- CodeBlock.from_begin_line(1, f"{C.STUB_BEGIN} unsupported/kind")
+ CodeBlock.from_begin_line(1, f"{C.PYTHON_SYNTAX.begin}
unsupported/kind", C.PYTHON_SYNTAX)
def test_fileinfo_from_file_skip_and_missing_markers(tmp_path: Path) -> None:
skip = tmp_path / "skip.py"
- skip.write_text(f"print('hi')\n{C.STUB_SKIP_FILE}\n", encoding="utf-8")
+ skip.write_text(f"print('hi')\n{C.PYTHON_SYNTAX.skip_file}\n",
encoding="utf-8")
assert FileInfo.from_file(skip) is None
plain = tmp_path / "plain.py"
@@ -98,10 +103,10 @@ def test_fileinfo_from_file_parses_blocks(tmp_path: Path)
-> None:
content = "\n".join(
[
"first = 1",
- f"{C.STUB_BEGIN} global/demo.func",
+ f"{C.PYTHON_SYNTAX.begin} global/demo.func",
"in_stub = True",
- C.STUB_END,
- f"{C.STUB_TY_MAP} x -> y",
+ C.PYTHON_SYNTAX.end,
+ f"{C.PYTHON_SYNTAX.ty_map} x -> y",
]
)
path = tmp_path / "demo.py"
@@ -120,15 +125,15 @@ def test_fileinfo_from_file_parses_blocks(tmp_path: Path)
-> None:
assert stub.lineno_start == 2
assert stub.lineno_end == 4
assert stub.lines == [
- f"{C.STUB_BEGIN} global/demo.func",
+ f"{C.PYTHON_SYNTAX.begin} global/demo.func",
"in_stub = True",
- C.STUB_END,
+ C.PYTHON_SYNTAX.end,
]
assert ty_map.kind == "ty-map"
assert ty_map.param == "x -> y"
assert ty_map.lineno_start == ty_map.lineno_end == 5
- assert ty_map.lines == [f"{C.STUB_TY_MAP} x -> y"]
+ assert ty_map.lines == [f"{C.PYTHON_SYNTAX.ty_map} x -> y"]
def test_fileinfo_from_file_error_paths(tmp_path: Path) -> None:
@@ -136,8 +141,8 @@ def test_fileinfo_from_file_error_paths(tmp_path: Path) ->
None:
nested.write_text(
"\n".join(
[
- f"{C.STUB_BEGIN} global/outer",
- f"{C.STUB_BEGIN} global/inner",
+ f"{C.PYTHON_SYNTAX.begin} global/outer",
+ f"{C.PYTHON_SYNTAX.begin} global/inner",
]
),
encoding="utf-8",
@@ -146,12 +151,12 @@ def test_fileinfo_from_file_error_paths(tmp_path: Path)
-> None:
FileInfo.from_file(nested)
unmatched_end = tmp_path / "unmatched.py"
- unmatched_end.write_text(C.STUB_END + "\n", encoding="utf-8")
+ unmatched_end.write_text(C.PYTHON_SYNTAX.end + "\n", encoding="utf-8")
with pytest.raises(ValueError, match="Unmatched"):
FileInfo.from_file(unmatched_end)
unclosed = tmp_path / "unclosed.py"
- unclosed.write_text(f"{C.STUB_BEGIN} global/method\n", encoding="utf-8")
+ unclosed.write_text(f"{C.PYTHON_SYNTAX.begin} global/method\n",
encoding="utf-8")
with pytest.raises(ValueError, match="Unclosed stub block"):
FileInfo.from_file(unclosed)
@@ -165,7 +170,7 @@ def test_funcinfo_gen_variants() -> None:
schema_no_args = NamedTypeSchema("demo.no_args", TypeSchema("Callable",
()))
func = FuncInfo(schema=schema_no_args, is_member=False)
- assert func.gen(ty_map, indent=2) == " def no_args(*args: Any) -> Any:
..."
+ assert render_func_signature(func, ty_map, indent=2) == " def
no_args(*args: Any) -> Any: ..."
assert called == ["Any"]
schema_member = NamedTypeSchema(
@@ -181,12 +186,15 @@ def test_funcinfo_gen_variants() -> None:
)
member_func = FuncInfo(schema=schema_member, is_member=True)
assert (
- member_func.gen(_identity_ty_map, indent=0) == "def method(self, _1:
float, /) -> str: ..."
+ render_func_signature(member_func, _identity_ty_map, indent=0)
+ == "def method(self, _1: float, /) -> str: ..."
)
schema_bad = NamedTypeSchema("bad", TypeSchema("int"))
with pytest.raises(ValueError):
- FuncInfo(schema=schema_bad, is_member=False).gen(_identity_ty_map,
indent=0)
+ render_func_signature(
+ FuncInfo(schema=schema_bad, is_member=False), _identity_ty_map,
indent=0
+ )
def test_objectinfo_gen_fields_and_methods() -> None:
@@ -218,13 +226,13 @@ def test_objectinfo_gen_fields_and_methods() -> None:
],
)
- assert info.gen_fields(ty_map, indent=2) == [
+ assert render_object_fields(info, ty_map, indent=2) == [
" field_a: Sequence[int]",
" field_b: Mapping[str, float]",
]
assert ty_calls.count("list") == 1 and ty_calls.count("dict") == 1
- methods = info.gen_methods(_identity_ty_map, indent=2)
+ methods = render_object_methods(info, _identity_ty_map, indent=2)
assert methods == [
" @staticmethod",
" def static() -> int: ...",
@@ -285,7 +293,7 @@ def test_objectinfo_gen_fields_container_types() -> None:
],
methods=[],
)
- assert info.gen_fields(_type_suffix, indent=0) == [
+ assert render_object_fields(info, _type_suffix, indent=0) == [
"arr: Sequence[int]",
"lst: MutableSequence[str]",
"mp: Mapping[str, int]",
@@ -299,7 +307,7 @@ def test_generate_global_funcs_updates_block() -> None:
param=("demo", "mockpkg"),
lineno_start=1,
lineno_end=2,
- lines=[f"{C.STUB_BEGIN} global/demo@mockpkg", C.STUB_END],
+ lines=[f"{C.PYTHON_SYNTAX.begin} global/demo@mockpkg",
C.PYTHON_SYNTAX.end],
)
funcs = [
FuncInfo(
@@ -312,19 +320,19 @@ def test_generate_global_funcs_updates_block() -> None:
]
opts = Options(indent=2)
imports: list[ImportItem] = []
- generate_global_funcs(code, funcs, _default_ty_map(), imports, opts)
+ generate_python_global_funcs(code, funcs, _default_ty_map(), imports, opts)
assert imports == [
ImportItem("mockpkg.init_ffi_api", alias="_FFI_INIT_FUNC"),
ImportItem("typing.TYPE_CHECKING"),
]
assert code.lines == [
- f"{C.STUB_BEGIN} global/demo@mockpkg",
+ f"{C.PYTHON_SYNTAX.begin} global/demo@mockpkg",
"# fmt: off",
'_FFI_INIT_FUNC("demo", __name__)',
"if TYPE_CHECKING:",
" def add_one(_0: int, /) -> int: ...",
"# fmt: on",
- C.STUB_END,
+ C.PYTHON_SYNTAX.end,
]
@@ -334,11 +342,11 @@ def test_generate_global_funcs_noop_on_empty_list() ->
None:
param=("empty", ""),
lineno_start=1,
lineno_end=2,
- lines=[f"{C.STUB_BEGIN} global/empty", C.STUB_END],
+ lines=[f"{C.PYTHON_SYNTAX.begin} global/empty", C.PYTHON_SYNTAX.end],
)
imports: list[ImportItem] = []
- generate_global_funcs(code, [], _default_ty_map(), imports, Options())
- assert code.lines == [f"{C.STUB_BEGIN} global/empty", C.STUB_END]
+ generate_python_global_funcs(code, [], _default_ty_map(), imports,
Options())
+ assert code.lines == [f"{C.PYTHON_SYNTAX.begin} global/empty",
C.PYTHON_SYNTAX.end]
assert imports == []
@@ -348,7 +356,7 @@ def
test_generate_global_funcs_respects_custom_import_from() -> None:
param=("demo", "custom.mod"),
lineno_start=1,
lineno_end=2,
- lines=[f"{C.STUB_BEGIN} global/demo@custom.mod", C.STUB_END],
+ lines=[f"{C.PYTHON_SYNTAX.begin} global/demo@custom.mod",
C.PYTHON_SYNTAX.end],
)
funcs = [
FuncInfo(
@@ -360,7 +368,7 @@ def
test_generate_global_funcs_respects_custom_import_from() -> None:
)
]
imports: list[ImportItem] = []
- generate_global_funcs(code, funcs, _default_ty_map(), imports,
Options(indent=0))
+ generate_python_global_funcs(code, funcs, _default_ty_map(), imports,
Options(indent=0))
assert ImportItem("custom.mod.init_ffi_api", alias="_FFI_INIT_FUNC") in
imports
@@ -371,7 +379,7 @@ def test_generate_global_funcs_aliases_colliding_type() ->
None:
param=("demo", "mockpkg"),
lineno_start=1,
lineno_end=2,
- lines=[f"{C.STUB_BEGIN} global/demo@mockpkg", C.STUB_END],
+ lines=[f"{C.PYTHON_SYNTAX.begin} global/demo@mockpkg",
C.PYTHON_SYNTAX.end],
)
# Function "demo.Foo" returns type "demo.Foo" — name collision
funcs = [
@@ -386,7 +394,7 @@ def test_generate_global_funcs_aliases_colliding_type() ->
None:
ty_map = _default_ty_map()
ty_map["demo.Foo"] = "somepkg.Foo"
imports: list[ImportItem] = []
- generate_global_funcs(code, funcs, ty_map, imports, Options(indent=4))
+ generate_python_global_funcs(code, funcs, ty_map, imports,
Options(indent=4))
# The type import should use an alias to avoid shadowing the function
assert ImportItem("somepkg.Foo", type_checking_only=True, alias="_Foo") in
imports
# The function annotation should use the alias
@@ -399,7 +407,7 @@ def test_generate_object_fields_only_block() -> None:
param="demo.TypeDerived",
lineno_start=1,
lineno_end=2,
- lines=[f"{C.STUB_BEGIN} object/demo.TypeDerived", C.STUB_END],
+ lines=[f"{C.PYTHON_SYNTAX.begin} object/demo.TypeDerived",
C.PYTHON_SYNTAX.end],
)
opts = Options(indent=4)
imports: list[ImportItem] = []
@@ -412,7 +420,7 @@ def test_generate_object_fields_only_block() -> None:
type_key="demo.TypeDerived",
parent_type_key="demo.Parent",
)
- generate_object(
+ generate_python_object(
code,
_default_ty_map(),
imports,
@@ -422,11 +430,14 @@ def test_generate_object_fields_only_block() -> None:
assert imports == []
expected = [
- f"{C.STUB_BEGIN} object/demo.TypeDerived",
+ f"{C.PYTHON_SYNTAX.begin} object/demo.TypeDerived",
" " * code.indent + "# fmt: off",
- *[(" " * code.indent) + line for line in info.gen_fields(_type_suffix,
indent=0)],
+ *[
+ (" " * code.indent) + line
+ for line in render_object_fields(info, _type_suffix, indent=0)
+ ],
" " * code.indent + "# fmt: on",
- C.STUB_END,
+ C.PYTHON_SYNTAX.end,
]
assert code.lines == expected
@@ -437,7 +448,7 @@ def test_generate_object_with_methods() -> None:
param="demo.IntPair",
lineno_start=1,
lineno_end=2,
- lines=[f"{C.STUB_BEGIN} object/demo.IntPair", C.STUB_END],
+ lines=[f"{C.PYTHON_SYNTAX.begin} object/demo.IntPair",
C.PYTHON_SYNTAX.end],
)
opts = Options(indent=4)
imports: list[ImportItem] = []
@@ -458,11 +469,11 @@ def test_generate_object_with_methods() -> None:
type_key="demo.IntPair",
parent_type_key="demo.Parent",
)
- generate_object(code, _default_ty_map(), imports, opts, info)
+ generate_python_object(code, _default_ty_map(), imports, opts, info)
assert set(imports) == {ImportItem("typing.TYPE_CHECKING")}
- assert code.lines[0] == f"{C.STUB_BEGIN} object/demo.IntPair"
- assert code.lines[-1] == C.STUB_END
+ assert code.lines[0] == f"{C.PYTHON_SYNTAX.begin} object/demo.IntPair"
+ assert code.lines[-1] == C.PYTHON_SYNTAX.end
assert "# fmt: off" in code.lines[1]
assert any("if TYPE_CHECKING:" in line for line in code.lines)
method_lines = [line for line in code.lines if "def __ffi_init__" in line
or "def sum" in line]
@@ -471,13 +482,23 @@ def test_generate_object_with_methods() -> None:
assert any(line.strip().startswith("def sum") for line in method_lines)
+def test_import_item_mod_map_prefix_rewrite() -> None:
+ # MOD_MAP rewrites must respect module-path boundaries.
+ assert ImportItem("ffi.Object").mod == "tvm_ffi"
+ assert ImportItem("testing.TestIntPair").mod == "tvm_ffi.testing"
+ assert ImportItem("testing.sub.Thing").mod == "tvm_ffi.testing.sub"
+ # A module that merely starts with a mapped prefix is NOT rewritten.
+ assert ImportItem("testingfoo.Thing").mod == "testingfoo"
+ assert ImportItem("ffi2.Thing").mod == "ffi2"
+
+
def test_generate_import_section_groups_modules() -> None:
code = CodeBlock(
kind="import-section",
param="",
lineno_start=1,
lineno_end=2,
- lines=[f"{C.STUB_BEGIN} import", C.STUB_END],
+ lines=[f"{C.PYTHON_SYNTAX.begin} import", C.PYTHON_SYNTAX.end],
)
imports = [
ImportItem("typing.Any", type_checking_only=True),
@@ -486,10 +507,10 @@ def test_generate_import_section_groups_modules() -> None:
ImportItem("custom.mod.Type", type_checking_only=True),
]
opts = Options(indent=4)
- generate_import_section(code, imports, opts)
+ generate_python_import_section(code, imports, opts)
expected_prefix = [
- f"{C.STUB_BEGIN} import",
+ f"{C.PYTHON_SYNTAX.begin} import",
"# fmt: off",
"# isort: off",
"from __future__ import annotations",
@@ -501,7 +522,7 @@ def test_generate_import_section_groups_modules() -> None:
assert " from demo_pkg import Tensor" in code.lines
assert " from custom.mod import Type" in code.lines
assert " from typing import Any" in code.lines
- assert code.lines[-2:] == ["# fmt: on", C.STUB_END]
+ assert code.lines[-2:] == ["# fmt: on", C.PYTHON_SYNTAX.end]
def test_generate_import_section_no_imports_noop() -> None:
@@ -510,10 +531,10 @@ def test_generate_import_section_no_imports_noop() ->
None:
param="",
lineno_start=1,
lineno_end=2,
- lines=[f"{C.STUB_BEGIN} import", C.STUB_END],
+ lines=[f"{C.PYTHON_SYNTAX.begin} import", C.PYTHON_SYNTAX.end],
)
before = list(code.lines)
- generate_import_section(code, [], Options())
+ generate_python_import_section(code, [], Options())
assert code.lines == before
@@ -523,19 +544,19 @@ def test_generate_all_builds_sorted_and_deduped_list() ->
None:
param="all",
lineno_start=1,
lineno_end=2,
- lines=[" " + C.STUB_BEGIN + " global/all", C.STUB_END],
+ lines=[" " + C.PYTHON_SYNTAX.begin + " global/all",
C.PYTHON_SYNTAX.end],
)
- generate_all(
+ generate_python_all(
code,
names={"tvm_ffi.foo", "bar", "pkg.baz", "bar"}, # duplicates stripped
opt=Options(indent=2),
)
assert code.lines == [
- " " + C.STUB_BEGIN + " global/all",
+ " " + C.PYTHON_SYNTAX.begin + " global/all",
' "bar",',
' "baz",',
' "foo",',
- C.STUB_END,
+ C.PYTHON_SYNTAX.end,
]
@@ -545,10 +566,10 @@ def test_generate_all_noop_on_empty_names() -> None:
param="all-empty",
lineno_start=1,
lineno_end=2,
- lines=[C.STUB_BEGIN + " global/all-empty", C.STUB_END],
+ lines=[C.PYTHON_SYNTAX.begin + " global/all-empty",
C.PYTHON_SYNTAX.end],
)
before = list(code.lines)
- generate_all(code, names=set(), opt=Options())
+ generate_python_all(code, names=set(), opt=Options())
assert code.lines == before
@@ -558,19 +579,19 @@ def test_generate_all_uses_isort_style_ordering() -> None:
param="all-mixed",
lineno_start=1,
lineno_end=2,
- lines=[C.STUB_BEGIN + " global/all-mixed", C.STUB_END],
+ lines=[C.PYTHON_SYNTAX.begin + " global/all-mixed",
C.PYTHON_SYNTAX.end],
)
names = {"foo", "Bar", "LIB", "baz", "Alpha", "CONST"}
- generate_all(code, names=names, opt=Options(indent=0))
+ generate_python_all(code, names=names, opt=Options(indent=0))
assert code.lines == [
- C.STUB_BEGIN + " global/all-mixed",
+ C.PYTHON_SYNTAX.begin + " global/all-mixed",
'"CONST",',
'"LIB",',
'"Alpha",',
'"Bar",',
'"baz",',
'"foo",',
- C.STUB_END,
+ C.PYTHON_SYNTAX.end,
]
@@ -581,21 +602,23 @@ def
test_stage_3_adds_LIB_when_load_lib_imported(tmp_path: Path) -> None:
param=("testing", ""),
lineno_start=2,
lineno_end=3,
- lines=[f"{C.STUB_BEGIN} global/testing", C.STUB_END],
+ lines=[f"{C.PYTHON_SYNTAX.begin} global/testing", C.PYTHON_SYNTAX.end],
)
import_obj_block = CodeBlock(
kind="import-object",
param=("tvm_ffi.libinfo.load_lib_module", "False", "_FFI_LOAD_LIB"),
lineno_start=1,
lineno_end=1,
- lines=[f"{C.STUB_IMPORT_OBJECT}
tvm_ffi.libinfo.load_lib_module;False;_FFI_LOAD_LIB"],
+ lines=[
+ f"{C.PYTHON_SYNTAX.import_object}
tvm_ffi.libinfo.load_lib_module;False;_FFI_LOAD_LIB"
+ ],
)
all_block = CodeBlock(
kind="__all__",
param="",
lineno_start=4,
lineno_end=5,
- lines=[f"{C.STUB_BEGIN} __all__", C.STUB_END],
+ lines=[f"{C.PYTHON_SYNTAX.begin} __all__", C.PYTHON_SYNTAX.end],
)
file_info = FileInfo(
path=path,
@@ -603,6 +626,7 @@ def test_stage_3_adds_LIB_when_load_lib_imported(tmp_path:
Path) -> None:
line for block in (import_obj_block, global_block, all_block) for
line in block.lines
),
code_blocks=[import_obj_block, global_block, all_block],
+ syntax=C.PYTHON_SYNTAX,
)
funcs = [
FuncInfo.from_schema(
@@ -615,6 +639,7 @@ def test_stage_3_adds_LIB_when_load_lib_imported(tmp_path:
Path) -> None:
Options(dry_run=True),
_default_ty_map(),
{"testing": funcs},
+ get_generator("python"),
)
lib_lines = [line for line in all_block.lines if "LIB" in line]
assert any("LIB" in line for line in lib_lines)
@@ -626,20 +651,20 @@ def test_generate_export_builds_all_extension() -> None:
param="ffi_api",
lineno_start=1,
lineno_end=2,
- lines=[f"{C.STUB_BEGIN} export/ffi_api", C.STUB_END],
+ lines=[f"{C.PYTHON_SYNTAX.begin} export/ffi_api", C.PYTHON_SYNTAX.end],
)
- generate_export(code)
+ generate_python_export(code)
full_text = "\n".join(code.lines)
assert "from .ffi_api import *" in full_text
assert "ffi_api__all__" in full_text
def test_generate_init_with_and_without_existing_export_block() -> None:
- code_no_blocks = generate_init([], "demo")
+ code_no_blocks = generate_python_init([], "demo", "_ffi_api",
C.PYTHON_SYNTAX)
assert "Package demo." in code_no_blocks
- assert f"{C.STUB_BEGIN} export/_ffi_api" in code_no_blocks
+ assert f"{C.PYTHON_SYNTAX.begin} export/_ffi_api" in code_no_blocks
- code_with_export = generate_init(
+ code_with_export = generate_python_init(
[
CodeBlock(
kind="export",
@@ -650,23 +675,26 @@ def
test_generate_init_with_and_without_existing_export_block() -> None:
)
],
"demo",
+ "_ffi_api",
+ C.PYTHON_SYNTAX,
)
assert code_with_export == ""
def test_generate_ffi_api_without_objects_includes_sections() -> None:
init_cfg = InitConfig(pkg="pkg", shared_target="pkg_shared", prefix="pkg.")
- code = generate_ffi_api(
+ code = generate_python_ffi_api(
[],
_default_ty_map(),
"demo.mod",
[],
init_cfg,
is_root=False,
+ syntax=C.PYTHON_SYNTAX,
)
- assert f"{C.STUB_BEGIN} import-section" in code
- assert f"{C.STUB_BEGIN} global/demo.mod" in code
- assert C.STUB_BEGIN + " __all__" in code
+ assert f"{C.PYTHON_SYNTAX.begin} import-section" in code
+ assert f"{C.PYTHON_SYNTAX.begin} global/demo.mod" in code
+ assert C.PYTHON_SYNTAX.begin + " __all__" in code
assert "LIB =" not in code
@@ -679,19 +707,20 @@ def test_generate_ffi_api_with_objects_imports_parents()
-> None:
parent_type_key="demo.Parent",
)
parent_key = obj_info.parent_type_key
- code = generate_ffi_api(
+ code = generate_python_ffi_api(
[],
_default_ty_map(),
"demo",
[obj_info],
init_cfg,
is_root=False,
+ syntax=C.PYTHON_SYNTAX,
)
- assert C.STUB_IMPORT_OBJECT in code # register_object prompt
- assert f"{C.STUB_BEGIN} object/{obj_info.type_key}" in code
+ assert C.PYTHON_SYNTAX.import_object in code # register_object prompt
+ assert f"{C.PYTHON_SYNTAX.begin} object/{obj_info.type_key}" in code
assert parent_key is not None
parent_import_prompt = (
- f"{C.STUB_IMPORT_OBJECT} {parent_key};False;_{parent_key.replace('.',
'_')}"
+ f"{C.PYTHON_SYNTAX.import_object}
{parent_key};False;_{parent_key.replace('.', '_')}"
)
assert parent_import_prompt in code
@@ -729,6 +758,7 @@ def test_stage_2_filters_prefix_and_marks_root(
init_cfg=InitConfig(pkg="demo-pkg", shared_target="demo_shared",
prefix="demo."),
init_path=tmp_path,
global_funcs=global_funcs,
+ generator=get_generator("python"),
)
root_api = tmp_path / "demo" / "_ffi_api.py"