This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm-ffi.git
The following commit(s) were added to refs/heads/main by this push:
new 3115b23 [FEAT] kwargs wrapping utility (#309)
3115b23 is described below
commit 3115b237d43fa2c7a24157ec88e1a9f9ec403900
Author: Tianqi Chen <[email protected]>
AuthorDate: Thu Dec 4 00:05:25 2025 -0500
[FEAT] kwargs wrapping utility (#309)
This PR introduces a kwargs_wrapper that can wrap an existing positional-only
function into a function with an optional-kwargs signature through code
generation. We carefully limit the scope of code generation so that it cannot
be influenced by untrusted input, which keeps the generated code safe.
---
python/tvm_ffi/utils/__init__.py | 1 +
python/tvm_ffi/utils/kwargs_wrapper.py | 360 ++++++++++++++++++++++++++++++
tests/python/utils/test_kwargs_wrapper.py | 321 ++++++++++++++++++++++++++
tests/scripts/benchmark_kwargs_wrapper.py | 85 +++++++
4 files changed, 767 insertions(+)
diff --git a/python/tvm_ffi/utils/__init__.py b/python/tvm_ffi/utils/__init__.py
index 896001e..f142067 100644
--- a/python/tvm_ffi/utils/__init__.py
+++ b/python/tvm_ffi/utils/__init__.py
@@ -16,4 +16,5 @@
# under the License.
"""Utilities used by the tvm_ffi Python package."""
+from . import kwargs_wrapper
from .lockfile import FileLock
diff --git a/python/tvm_ffi/utils/kwargs_wrapper.py
b/python/tvm_ffi/utils/kwargs_wrapper.py
new file mode 100644
index 0000000..5f6395b
--- /dev/null
+++ b/python/tvm_ffi/utils/kwargs_wrapper.py
@@ -0,0 +1,360 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Utilities for creating high-performance keyword argument wrapper functions.
+
+This module provides tools for wrapping positional-only callables with
+keyword argument support using code generation techniques.
+"""
+
+from __future__ import annotations
+
+import functools
+import inspect
+from typing import Any, Callable
+
# Unique sentinel object meaning "no argument supplied". A fresh ``object()``
# is identical only to itself, so unlike ``None`` it can never collide with a
# value a caller passes; always test for it with ``is``.
MISSING = object()
+
+
+def _validate_argument_names(names: list[str], arg_type: str) -> None:
+ """Validate that argument names are valid Python identifiers and unique.
+
+ Parameters
+ ----------
+ names
+ List of argument names to validate.
+ arg_type
+ Description of the argument type (e.g., "Argument", "Keyword-only
argument").
+
+ """
+ # Check for duplicate names
+ if len(set(names)) != len(names):
+ raise ValueError(f"Duplicate {arg_type.lower()} names found in:
{names}")
+
+ # Validate each name is a valid identifier
+ for name in names:
+ if not isinstance(name, str):
+ raise TypeError(
+ f"{arg_type} name must be a string, got {type(name).__name__}:
{name!r}"
+ )
+ if not name.isidentifier():
+ raise ValueError(
+ f"Invalid {arg_type.lower()} name: {name!r} is not a valid
Python identifier"
+ )
+
+
def _validate_wrapper_args(
    args_names: list[str],
    args_defaults: tuple,
    kwargsonly_names: list[str],
    kwargsonly_defaults: dict[str, Any],
    reserved_names: set[str],
) -> None:
    """Check every input of make_kwargs_wrapper before any code is generated.

    Checks run in a fixed order and the first failure raises.

    Parameters
    ----------
    args_names
        Positional argument names, in order.
    args_defaults
        Right-aligned default values for the positional arguments.
    kwargsonly_names
        Keyword-only argument names.
    kwargsonly_defaults
        Default values for a subset of the keyword-only arguments.
    reserved_names
        Internal names used by the generated code; user names may not collide with them.

    Raises
    ------
    TypeError
        If ``args_defaults`` is not a tuple, or an argument name is not a string.
    ValueError
        If names are invalid or duplicated, there are more positional defaults
        than positional names, a keyword-only default targets an unknown name,
        a name is both positional and keyword-only, or a name is reserved.

    """
    # Positional names first: identifiers, no duplicates.
    _validate_argument_names(args_names, "Argument")

    if not isinstance(args_defaults, tuple):
        raise TypeError(f"args_defaults must be a tuple, got {type(args_defaults).__name__}")

    n_defaults, n_positional = len(args_defaults), len(args_names)
    if n_defaults > n_positional:
        raise ValueError(
            f"args_defaults has {n_defaults} values but only "
            f"{n_positional} positional arguments"
        )

    # Then keyword-only names, with the same rules.
    _validate_argument_names(kwargsonly_names, "Keyword-only argument")

    kwonly_set = set(kwargsonly_names)
    # Report the first default (in mapping order) that has no matching name.
    stray_keys = [key for key in kwargsonly_defaults if key not in kwonly_set]
    if stray_keys:
        raise ValueError(
            f"Default provided for '{stray_keys[0]}' which is not in kwargsonly_names: {kwargsonly_names}"
        )

    positional_set = set(args_names)
    in_both = positional_set & kwonly_set
    if in_both:
        raise ValueError(f"Arguments cannot be both positional and keyword-only: {in_both}")

    # Finally, no user name may shadow a name the generated code relies on.
    clashes = (positional_set | kwonly_set) & reserved_names
    if clashes:
        raise ValueError(
            f"Argument names conflict with internal names: {clashes}. "
            'Please avoid using names starting with "__i_"'
        )
+
+
def make_kwargs_wrapper(
    target_func: Callable,
    args_names: list[str],
    args_defaults: tuple = (),
    kwargsonly_names: list[str] | None = None,
    kwargsonly_defaults: dict[str, Any] | None = None,
    prototype_func: Callable | None = None,
) -> Callable:
    """Create a wrapper with kwargs support for a function that only accepts positional arguments.

    This function dynamically generates a wrapper using code generation to minimize overhead.
    The wrapper fills in defaults and then calls ``target_func`` with every value passed
    positionally, in the order ``args_names`` followed by ``kwargsonly_names``.

    Parameters
    ----------
    target_func
        The underlying function to be called by the wrapper. This function must only
        accept positional arguments.
    args_names
        A list of ALL positional argument names in order. These define the positional
        parameters that the wrapper will accept. Must not overlap with kwargsonly_names.
    args_defaults
        A tuple of default values for positional arguments, right-aligned to args_names
        (matching Python's __defaults__ behavior). The length of this tuple determines
        how many trailing arguments have defaults.
        Example: (10, 20) with args_names=['a', 'b', 'c', 'd'] means c=10, d=20.
        Empty tuple () means no defaults.
    kwargsonly_names
        A list of keyword-only argument names. These arguments can only be passed by name,
        not positionally, and appear after a '*' separator in the signature. Can include both
        required and optional keyword-only arguments. Must not overlap with args_names.
        Example: ['debug', 'timeout'] creates wrapper(..., *, debug, timeout).
    kwargsonly_defaults
        Optional dictionary of default values for keyword-only arguments (matching Python's
        __kwdefaults__ behavior). Keys must be a subset of kwargsonly_names. Keyword-only
        arguments not in this dict are required.
        Example: {'timeout': 30} with kwargsonly_names=['debug', 'timeout'] means 'debug'
        is required and 'timeout' is optional.
    prototype_func
        Optional prototype function to copy metadata (__name__, __doc__, __module__,
        __qualname__, __annotations__) from. If None, no metadata is copied.

    Returns
    -------
    Callable
        A dynamically generated wrapper function with the specified signature.

    Raises
    ------
    TypeError
        If ``args_defaults`` is not a tuple or an argument name is not a string.
    ValueError
        If argument names are invalid, duplicated, overlap between positional and
        keyword-only, or collide with the internal names ``__i_target_func``,
        ``__i_MISSING`` or ``__i_args_defaults``; or if the defaults do not match
        the names.

    Notes
    -----
    No default value is ever rendered into the generated source text, so the
    generated code never depends on a (possibly overridden) ``__repr__``/``__str__``.
    Instead each default is referenced by name from an internal lookup table in the
    parameter list, e.g. ``def wrapper(a, c=__i_args_defaults["c"])``. Python
    evaluates default expressions once, when the ``def`` executes, so the wrapper's
    ``__defaults__``/``__kwdefaults__`` hold the actual default objects:
    ``inspect.signature`` and ``help()`` report the real defaults, and calls pay no
    per-argument sentinel check. As with any Python default, the default object is
    passed by identity (never copied), so a mutable default is shared between calls.

    """
    # Normalize inputs
    if kwargsonly_names is None:
        kwargsonly_names = []
    if kwargsonly_defaults is None:
        kwargsonly_defaults = {}

    # Internal names visible to the generated code; user arguments must not use them.
    # ``__i_MISSING`` is not referenced by the generated code but stays reserved so
    # that argument-name validation behaves the same for callers.
    _INTERNAL_TARGET_FUNC = "__i_target_func"
    _INTERNAL_MISSING = "__i_MISSING"
    _INTERNAL_DEFAULTS_DICT = "__i_args_defaults"
    _INTERNAL_NAMES = {_INTERNAL_TARGET_FUNC, _INTERNAL_MISSING, _INTERNAL_DEFAULTS_DICT}

    # Validate all input arguments before generating any code.
    _validate_wrapper_args(
        args_names, args_defaults, kwargsonly_names, kwargsonly_defaults, _INTERNAL_NAMES
    )

    # Positional defaults are right-aligned, matching ``__defaults__``:
    # args_names=["a","b","c","d"], args_defaults=(10,20) -> {"c": 10, "d": 20}
    args_defaults_dict = (
        dict(zip(args_names[-len(args_defaults) :], args_defaults)) if args_defaults else {}
    )

    # Default objects read by the generated parameter list, keyed by argument name.
    default_table: dict[str, Any] = {}

    def _param_source(name: str, defaults: dict[str, Any]) -> str:
        """Return the source text declaring parameter ``name`` (recording its default)."""
        if name not in defaults:
            return name
        default_table[name] = defaults[name]
        # ``name`` is a validated identifier, so it cannot break out of the quoted key.
        return f'{name}={_INTERNAL_DEFAULTS_DICT}["{name}"]'

    arg_parts: list[str] = []
    for name in args_names:
        arg_parts.append(_param_source(name, args_defaults_dict))
    if kwargsonly_names:
        arg_parts.append("*")  # Separator for keyword-only args
        for name in kwargsonly_names:
            arg_parts.append(_param_source(name, kwargsonly_defaults))

    arg_list = ", ".join(arg_parts)
    # Every value, explicit or defaulted, is forwarded positionally in declaration order.
    call_list = ", ".join([*args_names, *kwargsonly_names])

    code_str = f"def wrapper({arg_list}):\n    return {_INTERNAL_TARGET_FUNC}({call_list})\n"

    exec_globals: dict[str, Any] = {
        _INTERNAL_TARGET_FUNC: target_func,
        _INTERNAL_DEFAULTS_DICT: default_table,
    }
    namespace: dict[str, Any] = {}
    # Note: this is a limited use of exec that is safe.
    # The generated code contains only validated argument names and fixed internal
    # names; default values reach the wrapper solely through exec_globals and are
    # never converted to source text.
    # This is a practice adopted by `dataclasses` and `pydantic`
    exec(code_str, exec_globals, namespace)
    new_func = namespace["wrapper"]

    # Copy metadata from prototype_func if provided
    if prototype_func is not None:
        functools.update_wrapper(new_func, prototype_func, updated=())

    return new_func
+
+
def make_kwargs_wrapper_from_signature(
    target_func: Callable,
    signature: inspect.Signature,
    prototype_func: Callable | None = None,
) -> Callable:
    """Create a wrapper with kwargs support for a function that only accepts positional arguments.

    Translates an ``inspect.Signature`` into the arguments of ``make_kwargs_wrapper``.
    Positional-only and positional-or-keyword parameters become positional arguments
    of the wrapper (in signature order, with their defaults); keyword-only parameters
    become keyword-only arguments, required or optional as in the signature.

    Parameters
    ----------
    target_func
        The underlying function to be called by the wrapper.
    signature
        An inspect.Signature object describing the desired wrapper signature.
    prototype_func
        Optional prototype function to copy metadata (__name__, __doc__, __module__,
        __qualname__, __annotations__) from. If None, no metadata is copied.

    Returns
    -------
    Callable
        A dynamically generated wrapper function with the specified signature.

    Raises
    ------
    ValueError
        If the signature contains *args or **kwargs, or a required positional
        parameter follows one with a default.

    """
    Parameter = inspect.Parameter
    positional_kinds = (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD)

    positional_names = []
    positional_defaults = []
    keyword_only_names = []
    keyword_only_defaults = {}
    defaults_started = False  # becomes True at the first defaulted positional

    for name, param in signature.parameters.items():
        kind = param.kind
        has_default = param.default is not Parameter.empty

        if kind == Parameter.VAR_POSITIONAL:
            raise ValueError("*args not supported in wrapper generation")
        if kind == Parameter.VAR_KEYWORD:
            raise ValueError("**kwargs not supported in wrapper generation")

        if kind in positional_kinds:
            positional_names.append(name)
            if has_default:
                defaults_started = True
                positional_defaults.append(param.default)
            elif defaults_started:
                # Mirrors Python's own rule: no required positional after a defaulted one.
                raise ValueError(
                    f"Required positional parameter '{name}' cannot follow "
                    "parameters with defaults"
                )
        elif kind == Parameter.KEYWORD_ONLY:
            keyword_only_names.append(name)
            if has_default:
                keyword_only_defaults[name] = param.default

    # Collected defaults are already right-aligned with positional_names.
    return make_kwargs_wrapper(
        target_func,
        positional_names,
        tuple(positional_defaults),
        keyword_only_names,
        keyword_only_defaults,
        prototype_func,
    )
diff --git a/tests/python/utils/test_kwargs_wrapper.py
b/tests/python/utils/test_kwargs_wrapper.py
new file mode 100644
index 0000000..244c50f
--- /dev/null
+++ b/tests/python/utils/test_kwargs_wrapper.py
@@ -0,0 +1,321 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import inspect
+from typing import Any
+
+import pytest
+from tvm_ffi.utils.kwargs_wrapper import make_kwargs_wrapper,
make_kwargs_wrapper_from_signature
+
+
def test_basic_wrapper() -> None:
    """Wrap a variadic target with required, partially-defaulted and fully-defaulted positionals."""

    def total(*values: Any) -> int:
        return sum(values)

    abc = ["a", "b", "c"]

    # Every argument is required; positional and keyword forms may be mixed.
    required_only = make_kwargs_wrapper(total, abc)
    assert required_only(1, 2, 3) == 6
    assert required_only(a=1, b=2, c=3) == 6
    assert required_only(1, b=2, c=3) == 6

    # One trailing default: c=10.
    last_defaulted = make_kwargs_wrapper(total, abc, args_defaults=(10,))
    assert last_defaulted(1, 2) == 13
    assert last_defaulted(1, 2, 3) == 6
    assert last_defaulted(1, 2, c=5) == 8

    # Defaults are right-aligned: (20, 30) means b=20, c=30.
    two_defaulted = make_kwargs_wrapper(total, abc, args_defaults=(20, 30))
    assert two_defaulted(1) == 51
    assert two_defaulted(1, 2) == 33
    assert two_defaulted(1, 2, 3) == 6

    # Every argument has a default.
    all_defaulted = make_kwargs_wrapper(total, abc, args_defaults=(1, 2, 3))
    assert all_defaulted() == 6
    assert all_defaulted(10) == 15
    assert all_defaulted(10, 20, 30) == 60

    # A bound method works as the target as well.
    class Calculator:
        def __init__(self, base: int) -> None:
            self.base = base

        def add(self, a: int, b: int) -> int:
            return self.base + a + b

    bound_add = make_kwargs_wrapper(Calculator(100).add, ["a", "b"], args_defaults=(5,))
    assert bound_add(1) == 106
+
+
def test_keyword_only_arguments() -> None:
    """Wrap a variadic target with optional, required and mixed keyword-only arguments."""

    def total(*values: Any) -> int:
        return sum(values)

    # Both keyword-only arguments optional.
    optional_kw = make_kwargs_wrapper(
        total,
        ["a", "b"],
        args_defaults=(),
        kwargsonly_names=["c", "d"],
        kwargsonly_defaults={"c": 100, "d": 200},
    )
    assert optional_kw(1, 2) == 303
    assert optional_kw(1, 2, c=10) == 213
    assert optional_kw(1, 2, c=10, d=20) == 33

    # Both keyword-only arguments required (empty defaults mapping).
    required_kw = make_kwargs_wrapper(
        total,
        ["a", "b"],
        args_defaults=(),
        kwargsonly_names=["c", "d"],
        kwargsonly_defaults={},
    )
    assert required_kw(1, 2, c=10, d=20) == 33

    # "c" required, "d" optional.
    mixed_kw = make_kwargs_wrapper(
        total,
        ["a", "b"],
        args_defaults=(),
        kwargsonly_names=["c", "d"],
        kwargsonly_defaults={"d": 100},
    )
    assert mixed_kw(1, 2, c=10) == 113
    assert mixed_kw(1, 2, c=10, d=20) == 33

    # A positional default combined with optional keyword-only arguments.
    combined = make_kwargs_wrapper(
        total,
        ["a", "b", "c"],
        args_defaults=(10,),
        kwargsonly_names=["d", "e"],
        kwargsonly_defaults={"d": 20, "e": 30},
    )
    assert combined(1, 2) == 63
    assert combined(1, 2, 5, d=15) == 53
+
+
def test_validation_errors() -> None:
    """Every malformed argument specification is rejected with a descriptive error."""

    def target(*args: Any) -> Any:
        return sum(args)

    # (expected exception, message regex, keyword arguments for make_kwargs_wrapper)
    bad_specs: list[tuple[type[Exception], str, dict[str, Any]]] = [
        (ValueError, "Duplicate argument names found", {"args_names": ["a", "b", "a"]}),
        (
            ValueError,
            "Duplicate keyword-only argument names found",
            {"args_names": ["a"], "kwargsonly_names": ["b", "c", "b"]},
        ),
        (TypeError, "Argument name must be a string", {"args_names": ["a", 123]}),
        (ValueError, "not a valid Python identifier", {"args_names": ["a", "b-c"]}),
        (
            TypeError,
            "args_defaults must be a tuple",
            {"args_names": ["a", "b"], "args_defaults": [10]},
        ),
        (
            ValueError,
            r"args_defaults has .* values but only",
            {"args_names": ["a"], "args_defaults": (1, 2, 3)},
        ),
        (
            ValueError,
            "cannot be both positional and keyword-only",
            {"args_names": ["a", "b"], "kwargsonly_names": ["b"]},
        ),
        (
            ValueError,
            "not in kwargsonly_names",
            {"args_names": ["a", "b"], "kwargsonly_names": ["c"], "kwargsonly_defaults": {"d": 10}},
        ),
        (ValueError, "conflict with internal names", {"args_names": ["__i_target_func", "b"]}),
    ]

    for expected_error, message_pattern, spec in bad_specs:
        with pytest.raises(expected_error, match=message_pattern):
            make_kwargs_wrapper(target, **spec)
+
+
def test_special_default_values() -> None:
    """``None`` defaults are honored and object defaults reach the target by identity."""

    def echo(a: Any, b: Any, c: Any) -> tuple[Any, Any, Any]:
        return (a, b, c)

    abc = ["a", "b", "c"]

    none_defaults = make_kwargs_wrapper(echo, abc, args_defaults=(None, None))
    assert none_defaults(1) == (1, None, None)

    # The very same list object must be forwarded, not a copy.
    shared_list = [1, 2, 3]
    object_default = make_kwargs_wrapper(echo, abc, args_defaults=(shared_list, None))
    _, b_value, _ = object_default(1)
    assert b_value is shared_list
+
+
def test_wrapper_with_signature() -> None:
    """Build wrappers from ``inspect.Signature`` objects and reject variadic parameters."""

    def target(*args: Any) -> Any:
        return sum(args)

    def source_func(a: Any, b: Any, c: int = 10, d: int = 20) -> None:
        """Source function documentation."""

    source_sig = inspect.signature(source_func)
    from_sig = make_kwargs_wrapper_from_signature(target, source_sig)
    assert from_sig(1, 2) == 33  # 1 + 2 + 10 + 20
    assert from_sig(1, 2, 3) == 26  # 1 + 2 + 3 + 20
    assert from_sig(1, 2, 3, 4) == 10  # 1 + 2 + 3 + 4

    # Passing a prototype copies its name and docstring onto the wrapper.
    with_meta = make_kwargs_wrapper_from_signature(target, source_sig, source_func)
    assert with_meta.__name__ == "source_func"
    assert with_meta.__doc__ == "Source function documentation."

    def source_kwonly(a: Any, b: Any, *, c: int = 10, d: int = 20) -> None: ...

    kwonly_wrapper = make_kwargs_wrapper_from_signature(target, inspect.signature(source_kwonly))
    assert kwonly_wrapper(1, 2) == 33
    assert kwonly_wrapper(1, 2, c=5, d=6) == 14

    def source_required_kwonly(a: Any, b: Any, *, c: Any, d: int = 20) -> None: ...

    required_kw_wrapper = make_kwargs_wrapper_from_signature(
        target, inspect.signature(source_required_kwonly)
    )
    assert required_kw_wrapper(1, 2, c=10) == 33  # c required, d=20
    assert required_kw_wrapper(1, 2, c=10, d=5) == 18

    # *args and **kwargs cannot be expressed by the generated wrapper.
    def with_varargs(a: Any, *args: Any) -> None: ...

    def with_kwargs(a: Any, **kwargs: Any) -> None: ...

    for variadic, pattern in (
        (with_varargs, r"\*args not supported"),
        (with_kwargs, r"\*\*kwargs not supported"),
    ):
        with pytest.raises(ValueError, match=pattern):
            make_kwargs_wrapper_from_signature(target, inspect.signature(variadic))
+
+
def test_exception_propagation() -> None:
    """Exceptions raised by the target surface unchanged through the wrapper."""

    def raising_func(a: int, b: int, c: str) -> int:
        if a == 0:
            raise ValueError("a cannot be zero")
        if b < 0:
            raise RuntimeError(f"b must be non-negative, got {b}")
        if c != "valid":
            raise TypeError(f"c must be 'valid', got {c!r}")
        return a + b

    positional = make_kwargs_wrapper(raising_func, ["a", "b", "c"], args_defaults=(10, "valid"))
    keyword_only = make_kwargs_wrapper(
        raising_func,
        ["a"],
        kwargsonly_names=["b", "c"],
        kwargsonly_defaults={"b": 10, "c": "valid"},
    )

    # Happy path through both flavors.
    assert positional(5) == 15
    assert keyword_only(5) == 15

    failing_calls = [
        (ValueError, "a cannot be zero", lambda: positional(0)),
        (RuntimeError, "b must be non-negative", lambda: positional(1, -5)),
        (ValueError, "a cannot be zero", lambda: keyword_only(0)),
        (RuntimeError, "b must be non-negative", lambda: keyword_only(5, b=-5)),
        (TypeError, "c must be 'valid'", lambda: keyword_only(5, c="invalid")),
    ]
    for error_type, pattern, call in failing_calls:
        with pytest.raises(error_type, match=pattern):
            call()
+
+
def test_metadata_preservation() -> None:
    """A prototype function donates its name, docstring and annotations to the wrapper."""

    def my_function(x: int, y: int = 10) -> int:
        """Document the function."""
        return x + y

    def target(*args: Any) -> Any:
        return sum(args)

    wrapped = make_kwargs_wrapper(
        target, ["x", "y"], args_defaults=(10,), prototype_func=my_function
    )
    assert wrapped(5) == 15
    assert wrapped.__name__ == "my_function"
    assert wrapped.__doc__ == "Document the function."
    assert wrapped.__annotations__ == my_function.__annotations__
+
+
def test_optimized_default_types() -> None:
    """Defaults of type None, bool and str reach the target unchanged.

    Covers positional and keyword-only parameters, with each default either
    used implicitly or overridden by the caller.
    """

    def target(*args: Any) -> tuple[Any, ...]:
        return args

    # None default on a positional parameter.
    none_default = make_kwargs_wrapper(target, ["a", "b", "c"], args_defaults=(None,))
    assert none_default(1, 2) == (1, 2, None)
    assert none_default(1, 2, 3) == (1, 2, 3)
    assert none_default(1, 2, c=None) == (1, 2, None)

    # bool defaults on positional parameters.
    bool_defaults = make_kwargs_wrapper(target, ["a", "flag", "debug"], args_defaults=(True, False))
    assert bool_defaults(1) == (1, True, False)
    assert bool_defaults(1, False) == (1, False, False)
    assert bool_defaults(1, flag=False, debug=True) == (1, False, True)

    # str default on a positional parameter.
    str_default = make_kwargs_wrapper(target, ["a", "b", "name"], args_defaults=("default",))
    assert str_default(1, 2) == (1, 2, "default")
    assert str_default(1, 2, "custom") == (1, 2, "custom")
    assert str_default(1, 2, name="custom") == (1, 2, "custom")

    # All three kinds as keyword-only defaults.
    kwonly = make_kwargs_wrapper(
        target,
        ["a"],
        kwargsonly_names=["b", "flag", "name"],
        kwargsonly_defaults={"b": None, "flag": True, "name": "default"},
    )
    assert kwonly(1) == (1, None, True, "default")
    assert kwonly(1, b=2) == (1, 2, True, "default")
    assert kwonly(1, flag=False) == (1, None, False, "default")
    assert kwonly(1, name="custom") == (1, None, True, "custom")
    assert kwonly(1, b=2, flag=False, name="test") == (1, 2, False, "test")
diff --git a/tests/scripts/benchmark_kwargs_wrapper.py
b/tests/scripts/benchmark_kwargs_wrapper.py
new file mode 100644
index 0000000..6eb0890
--- /dev/null
+++ b/tests/scripts/benchmark_kwargs_wrapper.py
@@ -0,0 +1,85 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Benchmark API overhead of kwargs wrapper."""
+
+from __future__ import annotations
+
+import time
+from typing import Any
+
+from tvm_ffi.utils.kwargs_wrapper import make_kwargs_wrapper
+
+
def print_speed(name: str, speed: float) -> None:
    """Print one benchmark result: a 60-column left-aligned label, then seconds per call."""
    line = "{:<60} {} sec/call".format(name, speed)
    print(line)
+
+
def target_func(*args: Any) -> None:
    """No-op call target, used to measure pure call overhead."""
    return None
+
+
def benchmark_kwargs_wrapper(repeat: int = 1000000) -> None:
    """Benchmark kwargs wrapper with integer arguments.

    Times a direct call to ``target_func`` as the baseline, then several ways of
    calling a wrapper with two optional arguments, and prints seconds per call.

    Parameters
    ----------
    repeat
        Number of calls timed for each call form.

    Notes
    -----
    Timing uses :func:`timeit.timeit`, which reads ``time.perf_counter`` (a
    monotonic, high-resolution clock, unlike the wall-clock ``time.time``). It
    compiles each statement directly into its timing loop, so no extra function
    call is added per iteration. By default ``timeit`` also disables garbage
    collection while timing.

    """
    import timeit

    # Create wrapper with two optional kwargs
    wrapper = make_kwargs_wrapper(target_func, ["x", "y", "z"], args_defaults=(None, None))

    # Callables are looked up from this namespace; the test arguments are bound
    # by ``setup`` as locals of timeit's inner timing function.
    namespace = {"target_func": target_func, "wrapper": wrapper}
    setup = "x, y, z = 1, 2, 3"

    # (printed label, statement being timed); the first entry is the baseline.
    cases = [
        ("target_func(x, y, z)", "target_func(x, y, z)"),
        ("wrapper(x, y, z)", "wrapper(x, y, z)"),
        ("wrapper(x, y=y, z=z)", "wrapper(x, y=y, z=z)"),
        ("wrapper(x=x, y=y, z=z)", "wrapper(x=x, y=y, z=z)"),
        ("wrapper(x) [with defaults]", "wrapper(x)"),
    ]
    for label, stmt in cases:
        elapsed = timeit.timeit(stmt, setup=setup, number=repeat, globals=namespace)
        print_speed(label, elapsed / repeat)
+
+
if __name__ == "__main__":
    # Horizontal rule framing the benchmark table.
    _rule = "-" * 90
    print("Benchmarking kwargs_wrapper overhead...")
    print(_rule)
    benchmark_kwargs_wrapper()
    print(_rule)