This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git


The following commit(s) were added to refs/heads/main by this push:
     new 3115b23  [FEAT] kwargs wrapping utility (#309)
3115b23 is described below

commit 3115b237d43fa2c7a24157ec88e1a9f9ec403900
Author: Tianqi Chen <[email protected]>
AuthorDate: Thu Dec 4 00:05:25 2025 -0500

    [FEAT] kwargs wrapping utility (#309)
    
    This PR introduces a kwargs_wrapper that can wrap an existing positional-only
    function into a function with an optional-kwargs signature through code
    generation. We carefully limit the scope of code generation so that it cannot
    be influenced by untrusted input, which keeps the generated code safe.
---
 python/tvm_ffi/utils/__init__.py          |   1 +
 python/tvm_ffi/utils/kwargs_wrapper.py    | 360 ++++++++++++++++++++++++++++++
 tests/python/utils/test_kwargs_wrapper.py | 321 ++++++++++++++++++++++++++
 tests/scripts/benchmark_kwargs_wrapper.py |  85 +++++++
 4 files changed, 767 insertions(+)

diff --git a/python/tvm_ffi/utils/__init__.py b/python/tvm_ffi/utils/__init__.py
index 896001e..f142067 100644
--- a/python/tvm_ffi/utils/__init__.py
+++ b/python/tvm_ffi/utils/__init__.py
@@ -16,4 +16,5 @@
 # under the License.
 """Utilities used by the tvm_ffi Python package."""
 
+from . import kwargs_wrapper
 from .lockfile import FileLock
diff --git a/python/tvm_ffi/utils/kwargs_wrapper.py b/python/tvm_ffi/utils/kwargs_wrapper.py
new file mode 100644
index 0000000..5f6395b
--- /dev/null
+++ b/python/tvm_ffi/utils/kwargs_wrapper.py
@@ -0,0 +1,360 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Utilities for creating high-performance keyword argument wrapper functions.
+
+This module provides tools for wrapping positional-only callables with
+keyword argument support using code generation techniques.
+"""
+
+from __future__ import annotations
+
+import functools
+import inspect
+from typing import Any, Callable
+
class _MissingType:
    """Type of the :data:`MISSING` sentinel; only the identity of its single instance matters."""

    __slots__ = ()

    def __repr__(self) -> str:
        # A readable repr so that introspected wrapper defaults show ``MISSING``
        # instead of an opaque ``<object object at 0x...>``.
        return "MISSING"


# Sentinel object for missing arguments. Generated wrappers use it as the
# placeholder default for parameters whose real default is looked up at call time.
MISSING = _MissingType()
+
+
+def _validate_argument_names(names: list[str], arg_type: str) -> None:
+    """Validate that argument names are valid Python identifiers and unique.
+
+    Parameters
+    ----------
+    names
+        List of argument names to validate.
+    arg_type
+        Description of the argument type (e.g., "Argument", "Keyword-only 
argument").
+
+    """
+    # Check for duplicate names
+    if len(set(names)) != len(names):
+        raise ValueError(f"Duplicate {arg_type.lower()} names found in: 
{names}")
+
+    # Validate each name is a valid identifier
+    for name in names:
+        if not isinstance(name, str):
+            raise TypeError(
+                f"{arg_type} name must be a string, got {type(name).__name__}: 
{name!r}"
+            )
+        if not name.isidentifier():
+            raise ValueError(
+                f"Invalid {arg_type.lower()} name: {name!r} is not a valid 
Python identifier"
+            )
+
+
def _validate_wrapper_args(
    args_names: list[str],
    args_defaults: tuple,
    kwargsonly_names: list[str],
    kwargsonly_defaults: dict[str, Any],
    reserved_names: set[str],
) -> None:
    """Validate all input arguments for make_kwargs_wrapper.

    Parameters
    ----------
    args_names
        List of positional argument names.
    args_defaults
        Tuple of default values for positional arguments.
    kwargsonly_names
        List of keyword-only argument names.
    kwargsonly_defaults
        Dictionary of default values for keyword-only arguments.
    reserved_names
        Set of reserved internal names that cannot be used as argument names.

    Raises
    ------
    TypeError
        If a name is not a string or ``args_defaults`` is not a tuple.
    ValueError
        If names are invalid or duplicated, ``args_defaults`` is longer than
        ``args_names``, a keyword-only default has no matching name, a name is
        both positional and keyword-only, or a name clashes with ``reserved_names``.

    """
    # Validate args_names are valid Python identifiers and unique
    _validate_argument_names(args_names, "Argument")

    # Validate args_defaults is a tuple
    if not isinstance(args_defaults, tuple):
        raise TypeError(f"args_defaults must be a tuple, got {type(args_defaults).__name__}")

    # Validate args_defaults length doesn't exceed args_names length
    if len(args_defaults) > len(args_names):
        raise ValueError(
            f"args_defaults has {len(args_defaults)} values but only "
            f"{len(args_names)} positional arguments"
        )

    # Validate kwargsonly_names are valid identifiers and unique
    _validate_argument_names(kwargsonly_names, "Keyword-only argument")

    # Validate kwargsonly_defaults keys are in kwargsonly_names
    kwargsonly_names_set = set(kwargsonly_names)
    for key in kwargsonly_defaults:
        if key not in kwargsonly_names_set:
            raise ValueError(
                f"Default provided for {key!r} which is not in kwargsonly_names: "
                f"{kwargsonly_names}"
            )

    # Validate no overlap between positional and keyword-only arguments.
    # Names are sorted so the error message is deterministic.
    args_names_set = set(args_names)
    overlap = args_names_set & kwargsonly_names_set
    if overlap:
        raise ValueError(
            f"Arguments cannot be both positional and keyword-only: {sorted(overlap)}"
        )

    # Validate no conflict between user argument names and internal names
    conflicts = (args_names_set | kwargsonly_names_set) & reserved_names
    if conflicts:
        raise ValueError(
            f"Argument names conflict with internal names: {sorted(conflicts)}. "
            'Please avoid using names starting with "__i_"'
        )
+
+
def make_kwargs_wrapper(
    target_func: Callable,
    args_names: list[str],
    args_defaults: tuple = (),
    kwargsonly_names: list[str] | None = None,
    kwargsonly_defaults: dict[str, Any] | None = None,
    prototype_func: Callable | None = None,
) -> Callable:
    """Create a wrapper with kwargs support for a function that only accepts positional arguments.

    This function dynamically generates a wrapper using code generation to minimize
    overhead. The wrapper binds its arguments like an ordinary Python function. It
    then calls ``target_func`` with every value passed positionally, in the order
    ``args_names`` followed by ``kwargsonly_names``.

    Parameters
    ----------
    target_func
        The underlying function to be called by the wrapper. This function must only
        accept positional arguments.
    args_names
        A list of ALL positional argument names in order. These define the positional
        parameters that the wrapper will accept. Must not overlap with kwargsonly_names.
    args_defaults
        A tuple of default values for positional arguments, right-aligned to args_names
        (matching Python's __defaults__ behavior). The length of this tuple determines
        how many trailing arguments have defaults.
        Example: (10, 20) with args_names=['a', 'b', 'c', 'd'] means c=10, d=20.
        Empty tuple () means no defaults.
    kwargsonly_names
        A list of keyword-only argument names. These arguments can only be passed by
        name, not positionally, and appear after a '*' separator in the signature. Can
        include both required and optional keyword-only arguments. Must not overlap
        with args_names.
        Example: ['debug', 'timeout'] creates wrapper(..., *, debug, timeout).
    kwargsonly_defaults
        Optional dictionary of default values for keyword-only arguments (matching
        Python's __kwdefaults__ behavior). Keys must be a subset of kwargsonly_names.
        Keyword-only arguments not in this dict are required.
        Example: {'timeout': 30} with kwargsonly_names=['debug', 'timeout'] means
        'debug' is required and 'timeout' is optional.
    prototype_func
        Optional prototype function whose metadata is copied onto the wrapper with
        ``functools.update_wrapper`` (``__module__``, ``__name__``, ``__qualname__``,
        ``__doc__``, ``__annotations__``; the ``__dict__`` is not merged). This also
        sets ``__wrapped__``, so ``inspect.signature(wrapper)`` reports the
        prototype's signature. If None, no metadata is copied.

    Returns
    -------
        A dynamically generated wrapper function with the specified signature.

    Raises
    ------
    TypeError
        If ``target_func`` is not callable, ``args_defaults`` is not a tuple, or an
        argument name is not a string.
    ValueError
        If argument names are invalid, duplicated, overlap between positional and
        keyword-only, or clash with internal names, or if the defaults do not fit
        the declared names.

    Notes
    -----
    The generated wrapper embeds default values of None and bool directly in its
    signature. For every other default type it uses a MISSING sentinel object to
    distinguish explicitly passed arguments from those that should take the
    default. This ensures the generated code never contains an unexpected str repr.

    """
    # Fail fast here instead of with an opaque error on the first wrapper call.
    if not callable(target_func):
        raise TypeError(f"target_func must be callable, got {type(target_func).__name__}")

    # Normalize inputs
    if kwargsonly_names is None:
        kwargsonly_names = []
    if kwargsonly_defaults is None:
        kwargsonly_defaults = {}

    # Internal variable names used in generated code to avoid user argument conflicts
    _INTERNAL_TARGET_FUNC = "__i_target_func"
    _INTERNAL_MISSING = "__i_MISSING"
    _INTERNAL_DEFAULTS_DICT = "__i_args_defaults"
    _INTERNAL_NAMES = {_INTERNAL_TARGET_FUNC, _INTERNAL_MISSING, _INTERNAL_DEFAULTS_DICT}

    # Validate all inputs before any of them reach the generated source
    _validate_wrapper_args(
        args_names, args_defaults, kwargsonly_names, kwargsonly_defaults, _INTERNAL_NAMES
    )

    # Build positional defaults dictionary (right-aligned)
    # Example: args_names=["a","b","c","d"], args_defaults=(10,20) -> {"c":10, "d":20}
    # The emptiness guard matters: args_names[-0:] would select every name.
    args_defaults_dict = (
        dict(zip(args_names[-len(args_defaults) :], args_defaults)) if args_defaults else {}
    )

    # Build wrapper signature and call arguments.
    # Note: this code must be in this function so all code generation and exec are
    # self-contained. runtime_defaults only holds the defaults that need the MISSING
    # sentinel.
    arg_parts: list[str] = []
    call_parts: list[str] = []
    runtime_defaults: dict[str, Any] = {}

    def _add_param_with_default(name: str, default_value: Any) -> None:
        """Add a parameter with a default value to arg_parts and call_parts."""
        # Rationale: we directly embed default values for None and bool since they
        # are common and safe to include verbatim in generated code.
        #
        # Every other type (including int/str) uses the MISSING sentinel and is
        # passed through runtime_defaults[name]. Embedding int/str would require
        # their string representation, which may involve __str__ / __repr__
        # overridden by subclasses. The MISSING check is fast and more controllable.
        if default_value is None:
            # Safe to use the default value None directly in the signature
            arg_parts.append(f"{name}=None")
            call_parts.append(name)
        elif type(default_value) is bool:
            # Deliberately not isinstance, to exclude subclasses of bool; we also
            # spell the literal ourselves instead of calling repr.
            default_value_str = "True" if default_value else "False"
            arg_parts.append(f"{name}={default_value_str}")
            call_parts.append(name)
        else:
            # For all other cases, we use the MISSING sentinel
            arg_parts.append(f"{name}={_INTERNAL_MISSING}")
            runtime_defaults[name] = default_value
            # Look up the real default at call time only when the caller omitted the
            # argument; an explicitly passed value is forwarded unchanged.
            call_parts.append(
                f'{_INTERNAL_DEFAULTS_DICT}["{name}"] if {name} is {_INTERNAL_MISSING} else {name}'
            )

    # Handle positional arguments
    for name in args_names:
        if name in args_defaults_dict:
            _add_param_with_default(name, args_defaults_dict[name])
        else:
            arg_parts.append(name)
            call_parts.append(name)

    # Handle keyword-only arguments
    if kwargsonly_names:
        arg_parts.append("*")  # Separator for keyword-only args
        for name in kwargsonly_names:
            if name in kwargsonly_defaults:
                _add_param_with_default(name, kwargsonly_defaults[name])
            else:
                # Required keyword-only arg (no default)
                arg_parts.append(name)
                call_parts.append(name)

    arg_list = ", ".join(arg_parts)
    call_list = ", ".join(call_parts)

    code_str = f"""
def wrapper({arg_list}):
    return {_INTERNAL_TARGET_FUNC}({call_list})
"""
    # Execute the generated code
    exec_globals = {
        _INTERNAL_TARGET_FUNC: target_func,
        _INTERNAL_MISSING: MISSING,
        _INTERNAL_DEFAULTS_DICT: runtime_defaults,
    }
    namespace: dict[str, Any] = {}
    # Note: this is a limited use of exec that is safe.
    # We ensure generated code does not contain any untrusted input.
    # The argument names are validated and the default values are not part of
    # generated code. Instead, default values are replaced by the MISSING sentinel
    # object and explicitly passed via exec_globals.
    # This is a practice adopted by `dataclasses` and `pydantic`.
    # A descriptive filename makes tracebacks through the wrapper easier to read.
    code_obj = compile(code_str, "<kwargs_wrapper>", "exec")
    exec(code_obj, exec_globals, namespace)
    new_func = namespace["wrapper"]

    # Copy metadata from prototype_func if provided
    if prototype_func is not None:
        functools.update_wrapper(new_func, prototype_func, updated=())

    return new_func
+
+
def make_kwargs_wrapper_from_signature(
    target_func: Callable,
    signature: inspect.Signature,
    prototype_func: Callable | None = None,
) -> Callable:
    """Build a kwargs-capable wrapper whose parameters mirror ``signature``.

    The parameters of ``signature`` are split into two groups, each with its
    defaults: positional parameters (positional-only and positional-or-keyword)
    and keyword-only parameters. Both groups are handed to ``make_kwargs_wrapper``.
    Positional-only parameters therefore become ordinary positional parameters
    that may also be passed by name. Required and optional keyword-only
    parameters are both supported.

    Parameters
    ----------
    target_func
        The underlying function to be called by the wrapper.
    signature
        An inspect.Signature object describing the desired wrapper signature.
    prototype_func
        Optional prototype function to copy metadata (__name__, __doc__, __module__,
        __qualname__, __annotations__) from. If None, no metadata is copied.

    Returns
    -------
        A dynamically generated wrapper function with the specified signature.

    Raises
    ------
    ValueError
        If the signature contains *args or **kwargs, or a required positional
        parameter follows one with a default.

    """
    Param = inspect.Parameter
    positional_kinds = (Param.POSITIONAL_ONLY, Param.POSITIONAL_OR_KEYWORD)

    positional_names: list[str] = []
    positional_defaults: list[Any] = []
    keyword_names: list[str] = []
    keyword_defaults: dict[str, Any] = {}

    for name, param in signature.parameters.items():
        kind = param.kind
        has_default = param.default is not Param.empty

        if kind == Param.VAR_POSITIONAL:
            raise ValueError("*args not supported in wrapper generation")
        if kind == Param.VAR_KEYWORD:
            raise ValueError("**kwargs not supported in wrapper generation")

        if kind in positional_kinds:
            if has_default:
                positional_defaults.append(param.default)
            elif positional_defaults:
                # A non-empty defaults list means an earlier positional parameter
                # already had a default, which Python forbids before a required one.
                raise ValueError(
                    f"Required positional parameter '{name}' cannot follow "
                    "parameters with defaults"
                )
            positional_names.append(name)
        elif kind == Param.KEYWORD_ONLY:
            keyword_names.append(name)
            if has_default:
                keyword_defaults[name] = param.default

    # Defaults were collected in order, so they are already right-aligned to the names.
    return make_kwargs_wrapper(
        target_func,
        positional_names,
        tuple(positional_defaults),
        keyword_names,
        keyword_defaults,
        prototype_func,
    )
diff --git a/tests/python/utils/test_kwargs_wrapper.py b/tests/python/utils/test_kwargs_wrapper.py
new file mode 100644
index 0000000..244c50f
--- /dev/null
+++ b/tests/python/utils/test_kwargs_wrapper.py
@@ -0,0 +1,321 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import inspect
+from typing import Any
+
+import pytest
+from tvm_ffi.utils.kwargs_wrapper import make_kwargs_wrapper, 
make_kwargs_wrapper_from_signature
+
+
def test_basic_wrapper() -> None:
    """Cover positional/keyword invocation, right-aligned defaults, and bound methods."""

    def summing_target(*values: Any) -> int:
        return sum(values)

    # args_defaults -> list of (positional args, keyword args, expected sum)
    cases: dict[tuple[int, ...], list[tuple[tuple[int, ...], dict[str, int], int]]] = {
        # No defaults: every argument is required
        (): [
            ((1, 2, 3), {}, 6),
            ((), {"a": 1, "b": 2, "c": 3}, 6),
            ((1,), {"b": 2, "c": 3}, 6),
        ],
        # One default: c=10
        (10,): [
            ((1, 2), {}, 13),
            ((1, 2, 3), {}, 6),
            ((1, 2), {"c": 5}, 8),
        ],
        # Two right-aligned defaults: b=20, c=30
        (20, 30): [
            ((1,), {}, 51),
            ((1, 2), {}, 33),
            ((1, 2, 3), {}, 6),
        ],
        # Every argument has a default
        (1, 2, 3): [
            ((), {}, 6),
            ((10,), {}, 15),
            ((10, 20, 30), {}, 60),
        ],
    }
    for defaults, calls in cases.items():
        fn = make_kwargs_wrapper(summing_target, ["a", "b", "c"], args_defaults=defaults)
        for positional, keywords, expected in calls:
            assert fn(*positional, **keywords) == expected

    # A bound method works as the target as well
    class Calculator:
        def __init__(self, base: int) -> None:
            self.base = base

        def add(self, a: int, b: int) -> int:
            return self.base + a + b

    bound_wrapper = make_kwargs_wrapper(Calculator(100).add, ["a", "b"], args_defaults=(5,))
    assert bound_wrapper(1) == 106
+
+
def test_keyword_only_arguments() -> None:
    """Keyword-only parameters may be all optional, all required, or mixed."""

    def summing_target(*values: Any) -> int:
        return sum(values)

    def build(
        positional: list[str],
        kwonly: list[str],
        kwonly_defaults: dict[str, Any],
        positional_defaults: tuple = (),
    ) -> Any:
        return make_kwargs_wrapper(
            summing_target,
            positional,
            args_defaults=positional_defaults,
            kwargsonly_names=kwonly,
            kwargsonly_defaults=kwonly_defaults,
        )

    # Every keyword-only argument has a default (c=100, d=200)
    all_optional = build(["a", "b"], ["c", "d"], {"c": 100, "d": 200})
    assert all_optional(1, 2) == 303
    assert all_optional(1, 2, c=10) == 213
    assert all_optional(1, 2, c=10, d=20) == 33

    # No defaults: c and d must both be supplied
    all_required = build(["a", "b"], ["c", "d"], {})
    assert all_required(1, 2, c=10, d=20) == 33

    # c is required, d defaults to 100
    mixed = build(["a", "b"], ["c", "d"], {"d": 100})
    assert mixed(1, 2, c=10) == 113
    assert mixed(1, 2, c=10, d=20) == 33

    # Positional default (c=10) combined with keyword-only defaults (d=20, e=30)
    combined = build(["a", "b", "c"], ["d", "e"], {"d": 20, "e": 30}, positional_defaults=(10,))
    assert combined(1, 2) == 63
    assert combined(1, 2, 5, d=15) == 53
+
+
def test_validation_errors() -> None:
    """Test input validation and error handling."""

    # A def instead of a lambda bound to a name (PEP 8, ruff E731)
    def target(*args: Any) -> int:
        return sum(args)

    # Duplicate positional argument names
    with pytest.raises(ValueError, match="Duplicate argument names found"):
        make_kwargs_wrapper(target, ["a", "b", "a"])

    # Duplicate keyword-only argument names
    with pytest.raises(ValueError, match="Duplicate keyword-only argument names found"):
        make_kwargs_wrapper(target, ["a"], kwargsonly_names=["b", "c", "b"])

    # Invalid argument name types
    with pytest.raises(TypeError, match="Argument name must be a string"):
        make_kwargs_wrapper(target, ["a", 123])  # type: ignore[list-item]

    # Invalid Python identifiers
    with pytest.raises(ValueError, match="not a valid Python identifier"):
        make_kwargs_wrapper(target, ["a", "b-c"])

    # args_defaults not a tuple
    with pytest.raises(TypeError, match="args_defaults must be a tuple"):
        make_kwargs_wrapper(target, ["a", "b"], args_defaults=[10])  # type: ignore[arg-type]

    # args_defaults too long
    with pytest.raises(ValueError, match=r"args_defaults has .* values but only"):
        make_kwargs_wrapper(target, ["a"], args_defaults=(1, 2, 3))

    # Overlap between positional and keyword-only
    with pytest.raises(ValueError, match="cannot be both positional and keyword-only"):
        make_kwargs_wrapper(target, ["a", "b"], kwargsonly_names=["b"])

    # kwargsonly_defaults key not in kwargsonly_names
    with pytest.raises(ValueError, match="not in kwargsonly_names"):
        make_kwargs_wrapper(
            target, ["a", "b"], kwargsonly_names=["c"], kwargsonly_defaults={"d": 10}
        )

    # Internal name conflict
    with pytest.raises(ValueError, match="conflict with internal names"):
        make_kwargs_wrapper(target, ["__i_target_func", "b"])
+
+
def test_special_default_values() -> None:
    """Defaults of None and of arbitrary objects are forwarded unchanged."""

    def echo(a: Any, b: Any, c: Any) -> tuple[Any, Any, Any]:
        return (a, b, c)

    params = ["a", "b", "c"]

    # None for both trailing parameters
    none_wrapper = make_kwargs_wrapper(echo, params, args_defaults=(None, None))
    assert none_wrapper(1) == (1, None, None)

    # A mutable default must reach the target by reference, not as a copy
    shared = [1, 2, 3]
    object_wrapper = make_kwargs_wrapper(echo, params, args_defaults=(shared, None))
    assert object_wrapper(1)[1] is shared
+
+
def test_wrapper_with_signature() -> None:
    """Test make_kwargs_wrapper_from_signature."""

    # A def instead of a lambda bound to a name (PEP 8, ruff E731)
    def target(*args: Any) -> int:
        return sum(args)

    def source_func(a: Any, b: Any, c: int = 10, d: int = 20) -> None:
        """Source function documentation."""

    sig = inspect.signature(source_func)
    wrapper = make_kwargs_wrapper_from_signature(target, sig)
    assert wrapper(1, 2) == 33  # 1 + 2 + 10 + 20
    assert wrapper(1, 2, 3) == 26  # 1 + 2 + 3 + 20
    assert wrapper(1, 2, 3, 4) == 10  # 1 + 2 + 3 + 4

    # Test metadata preservation when prototype_func is provided
    wrapper_with_metadata = make_kwargs_wrapper_from_signature(target, sig, source_func)
    assert wrapper_with_metadata.__name__ == "source_func"
    assert wrapper_with_metadata.__doc__ == "Source function documentation."

    # With keyword-only arguments
    def source_kwonly(a: Any, b: Any, *, c: int = 10, d: int = 20) -> None:
        pass

    wrapper = make_kwargs_wrapper_from_signature(target, inspect.signature(source_kwonly))
    assert wrapper(1, 2) == 33
    assert wrapper(1, 2, c=5, d=6) == 14

    # With required keyword-only arguments
    def source_required_kwonly(a: Any, b: Any, *, c: Any, d: int = 20) -> None:
        pass

    wrapper = make_kwargs_wrapper_from_signature(target, inspect.signature(source_required_kwonly))
    assert wrapper(1, 2, c=10) == 33  # c required, d=20
    assert wrapper(1, 2, c=10, d=5) == 18  # both explicit

    # Reject *args and **kwargs
    def with_varargs(a: Any, *args: Any) -> None:
        pass

    with pytest.raises(ValueError, match=r"\*args not supported"):
        make_kwargs_wrapper_from_signature(target, inspect.signature(with_varargs))

    def with_kwargs(a: Any, **kwargs: Any) -> None:
        pass

    with pytest.raises(ValueError, match=r"\*\*kwargs not supported"):
        make_kwargs_wrapper_from_signature(target, inspect.signature(with_kwargs))
+
+
def test_exception_propagation() -> None:
    """Errors raised inside the target must reach the caller unchanged."""

    def checked_add(a: int, b: int, c: str) -> int:
        if a == 0:
            raise ValueError("a cannot be zero")
        if b < 0:
            raise RuntimeError(f"b must be non-negative, got {b}")
        if c != "valid":
            raise TypeError(f"c must be 'valid', got {c!r}")
        return a + b

    # Same target, once with positional defaults and once with keyword-only defaults
    positional = make_kwargs_wrapper(checked_add, ["a", "b", "c"], args_defaults=(10, "valid"))
    keyword_only = make_kwargs_wrapper(
        checked_add,
        ["a"],
        kwargsonly_names=["b", "c"],
        kwargsonly_defaults={"b": 10, "c": "valid"},
    )

    # Happy path through the defaults
    assert positional(5) == 15
    assert keyword_only(5) == 15

    # (wrapper, positional args, keyword args, expected exception, message pattern)
    failing_calls: list[tuple[Any, tuple[Any, ...], dict[str, Any], type[Exception], str]] = [
        (positional, (0,), {}, ValueError, "a cannot be zero"),
        (positional, (1, -5), {}, RuntimeError, "b must be non-negative"),
        (keyword_only, (0,), {}, ValueError, "a cannot be zero"),
        (keyword_only, (5,), {"b": -5}, RuntimeError, "b must be non-negative"),
        (keyword_only, (5,), {"c": "invalid"}, TypeError, "c must be 'valid'"),
    ]
    for fn, call_args, call_kwargs, exc_type, pattern in failing_calls:
        with pytest.raises(exc_type, match=pattern):
            fn(*call_args, **call_kwargs)
+
+
def test_metadata_preservation() -> None:
    """Test that function metadata is preserved when prototype_func is provided."""

    def my_function(x: int, y: int = 10) -> int:
        """Document the function."""
        return x + y

    # A def instead of a lambda bound to a name (PEP 8, ruff E731)
    def target(*args: Any) -> int:
        return sum(args)

    wrapper = make_kwargs_wrapper(
        target, ["x", "y"], args_defaults=(10,), prototype_func=my_function
    )
    assert wrapper.__name__ == "my_function"
    assert wrapper.__qualname__ == my_function.__qualname__
    assert wrapper.__doc__ == "Document the function."
    assert wrapper.__annotations__ == my_function.__annotations__
    # functools.update_wrapper records the prototype as __wrapped__, so signature
    # introspection reports the prototype's signature.
    assert inspect.unwrap(wrapper) is my_function
    assert inspect.signature(wrapper) == inspect.signature(my_function)
    assert wrapper(5) == 15
+
+
def test_optimized_default_types() -> None:
    """Check that None, bool and str defaults all behave like ordinary defaults.

    None and bool defaults are written straight into the generated signature,
    whereas str defaults go through the MISSING sentinel; callers must not be
    able to tell the difference.
    """

    def collect(*args: Any) -> tuple[Any, ...]:
        return args

    def check(fn: Any, cases: list[tuple[tuple[Any, ...], dict[str, Any], tuple[Any, ...]]]) -> None:
        for call_args, call_kwargs, expected in cases:
            assert fn(*call_args, **call_kwargs) == expected

    # None default (embedded directly)
    check(
        make_kwargs_wrapper(collect, ["a", "b", "c"], args_defaults=(None,)),
        [
            ((1, 2), {}, (1, 2, None)),
            ((1, 2, 3), {}, (1, 2, 3)),
            ((1, 2), {"c": None}, (1, 2, None)),
        ],
    )

    # bool defaults (embedded directly)
    check(
        make_kwargs_wrapper(collect, ["a", "flag", "debug"], args_defaults=(True, False)),
        [
            ((1,), {}, (1, True, False)),
            ((1, False), {}, (1, False, False)),
            ((1,), {"flag": False, "debug": True}, (1, False, True)),
        ],
    )

    # str default (routed through the MISSING sentinel)
    check(
        make_kwargs_wrapper(collect, ["a", "b", "name"], args_defaults=("default",)),
        [
            ((1, 2), {}, (1, 2, "default")),
            ((1, 2, "custom"), {}, (1, 2, "custom")),
            ((1, 2), {"name": "custom"}, (1, 2, "custom")),
        ],
    )

    # Keyword-only parameters mixing all three kinds of default
    check(
        make_kwargs_wrapper(
            collect,
            ["a"],
            kwargsonly_names=["b", "flag", "name"],
            kwargsonly_defaults={"b": None, "flag": True, "name": "default"},
        ),
        [
            ((1,), {}, (1, None, True, "default")),
            ((1,), {"b": 2}, (1, 2, True, "default")),
            ((1,), {"flag": False}, (1, None, False, "default")),
            ((1,), {"name": "custom"}, (1, None, True, "custom")),
            ((1,), {"b": 2, "flag": False, "name": "test"}, (1, 2, False, "test")),
        ],
    )
diff --git a/tests/scripts/benchmark_kwargs_wrapper.py b/tests/scripts/benchmark_kwargs_wrapper.py
new file mode 100644
index 0000000..6eb0890
--- /dev/null
+++ b/tests/scripts/benchmark_kwargs_wrapper.py
@@ -0,0 +1,85 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Benchmark API overhead of kwargs wrapper."""
+
+from __future__ import annotations
+
+import time
+from typing import Any
+
+from tvm_ffi.utils.kwargs_wrapper import make_kwargs_wrapper
+
+
def print_speed(name: str, speed: float) -> None:
    """Print one benchmark result: the label padded to 60 columns, then seconds per call."""
    padded_label = f"{name:<60}"
    print(padded_label, f"{speed} sec/call")
+
+
def target_func(*args: Any) -> None:
    """No-op callee used as the zero-work baseline; it ignores every argument."""
+
+
def benchmark_kwargs_wrapper(repeat: int = 1000000) -> None:
    """Benchmark kwargs wrapper overhead with integer arguments.

    Each scenario calls the function ``repeat`` times in a plain ``for`` loop and
    prints the average time per call. The loop's own overhead is included in
    every measurement, so compare scenarios against the direct-call baseline.

    Parameters
    ----------
    repeat
        Number of calls per scenario; must be positive.

    Raises
    ------
    ValueError
        If ``repeat`` is not positive (the per-call average would divide by zero).

    """
    if repeat <= 0:
        raise ValueError(f"repeat must be positive, got {repeat}")

    # Create test arguments
    x = 1
    y = 2
    z = 3

    # Wrapper where y and z default to None (the directly embedded default path)
    wrapper = make_kwargs_wrapper(target_func, ["x", "y", "z"], args_defaults=(None, None))

    # time.perf_counter is monotonic and has the highest available resolution,
    # unlike time.time, which can jump if the system clock is adjusted.
    timer = time.perf_counter

    # Benchmark 1: Direct call to target function (baseline)
    start = timer()
    for _ in range(repeat):
        target_func(x, y, z)
    print_speed("target_func(x, y, z)", (timer() - start) / repeat)

    # Benchmark 2: Wrapper with all positional arguments
    start = timer()
    for _ in range(repeat):
        wrapper(x, y, z)
    print_speed("wrapper(x, y, z)", (timer() - start) / repeat)

    # Benchmark 3: Wrapper with positional + kwargs
    start = timer()
    for _ in range(repeat):
        wrapper(x, y=y, z=z)
    print_speed("wrapper(x, y=y, z=z)", (timer() - start) / repeat)

    # Benchmark 4: Wrapper with all kwargs
    start = timer()
    for _ in range(repeat):
        wrapper(x=x, y=y, z=z)
    print_speed("wrapper(x=x, y=y, z=z)", (timer() - start) / repeat)

    # Benchmark 5: Wrapper relying on the defaults for y and z
    start = timer()
    for _ in range(repeat):
        wrapper(x)
    print_speed("wrapper(x) [with defaults]", (timer() - start) / repeat)
+
+
if __name__ == "__main__":
    # Frame the benchmark output between two horizontal rules
    rule = "-" * 90
    print("Benchmarking kwargs_wrapper overhead...")
    print(rule)
    benchmark_kwargs_wrapper()
    print(rule)

Reply via email to