This is an automated email from the ASF dual-hosted git repository. wenjin272 pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/flink-agents.git
commit eac1221f9a9d8d8298dd6dbe577bfd8425a0031c Author: WenjinXie <[email protected]> AuthorDate: Thu May 14 14:43:22 2026 +0800 [api][plan][runtime] Cross-language Function/FunctionTool Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]> --- python/flink_agents/api/agents/agent.py | 15 ++- python/flink_agents/api/function.py | 109 ++++++++++++++++++ python/flink_agents/api/tools/function_tool.py | 33 ++++++ python/flink_agents/api/tools/tool.py | 39 +++---- .../anthropic/tests/test_anthropic_chat_model.py | 4 +- .../azure/tests/test_azure_openai_chat_model.py | 5 +- .../openai/tests/test_openai_chat_model.py | 5 +- .../chat_models/tests/test_ollama_chat_model.py | 5 +- .../chat_models/tests/test_tongyi_chat_model.py | 5 +- .../integrations/mcp/tests/test_mcp.py | 27 ++++- python/flink_agents/plan/agent_plan.py | 44 ++++++- python/flink_agents/plan/function.py | 42 ++++++- python/flink_agents/plan/tests/test_function.py | 35 +----- .../plan/tests/tools/resources/function_tool.json | 12 +- .../plan/tests/tools/test_function_tool.py | 127 ++++++++++++++++++++- python/flink_agents/plan/tools/bash/bash_tool.py | 2 - python/flink_agents/plan/tools/function_tool.py | 81 +++++++++---- .../runtime/java/java_resource_wrapper.py | 14 ++- python/flink_agents/runtime/resource_cache.py | 4 + python/flink_agents/runtime/skill/skill_tools.py | 2 - .../runtime/operator/ActionExecutionOperator.java | 3 +- .../runtime/operator/PythonBridgeManager.java | 9 +- .../runtime/python/utils/JavaResourceAdapter.java | 116 ++++++++++++++++++- .../runtime/operator/PythonBridgeManagerTest.java | 3 +- 24 files changed, 613 insertions(+), 128 deletions(-) diff --git a/python/flink_agents/api/agents/agent.py b/python/flink_agents/api/agents/agent.py index 01e6f8b2..3a6aed85 100644 --- a/python/flink_agents/api/agents/agent.py +++ b/python/flink_agents/api/agents/agent.py @@ -18,6 +18,7 @@ from abc import ABC from typing import Any, Callable, Dict, List, Tuple +from flink_agents.api.function import Function, PythonFunction from flink_agents.api.resource import ( ResourceDescriptor, ResourceType, @@ -85,7 +86,7 @@ class Agent(ABC): """ _actions: Dict[ - str, Tuple[List[str], Callable, Dict[str, Any]] + str, Tuple[List[str], Function, Dict[str, Any] | None] ] _resources: Dict[ResourceType, Dict[str, Any]] @@ -99,7 +100,7 @@ class Agent(ABC): @property def actions( self, - ) -> Dict[str, Tuple[List[str], Callable, Dict[str, Any]]]: + ) -> Dict[str, Tuple[List[str], Function, Dict[str, Any] | None]]: """Get added actions.""" return self._actions @@ -112,7 +113,7 @@ class Agent(ABC): self, name: str, events: List[str], - func: Callable, + func: Callable | Function, **config: Any, ) -> "Agent": """Add action to agent. @@ -123,8 +124,10 @@ class Agent(ABC): The name of the action, should be unique in the same Agent. events : list[str] Type-identifier strings listened by this action. - func : Callable - The function to be executed when receive listened events. + func : Callable | Function + Either a raw Python callable (it will be wrapped as a + ``PythonFunction``) or a pre-built flink-agents ``Function`` + (e.g. from the YAML loader). **config : Any Key named arguments can be used by this action in runtime. @@ -136,6 +139,8 @@ class Agent(ABC): if name in self._actions: msg = f"Action {name} already defined" raise ValueError(msg) + if not isinstance(func, Function): + func = PythonFunction.from_callable(func) self._actions[name] = (events, func, config if config else None) return self diff --git a/python/flink_agents/api/function.py b/python/flink_agents/api/function.py new file mode 100644 index 00000000..b5597664 --- /dev/null +++ b/python/flink_agents/api/function.py @@ -0,0 +1,109 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################# +"""Data-only descriptors for user-defined functions. + +These models carry the information needed to *identify* a Python or Java +function: ``module`` and ``qualname`` for Python; declaring class, +method name, and parameter types for Java. +""" + +import importlib +import inspect +from abc import ABC +from typing import Any, Callable, List + +from pydantic import BaseModel, model_serializer + + +class Function(BaseModel, ABC): + """Marker base class for function descriptors. Pure data — has no + ``__call__`` and no executable behavior. + """ + + +class PythonFunction(Function): + """Descriptor for a Python callable: module + qualified name. + + Attributes: + ---------- + module : str + Name of the Python module where the function is defined. + qualname : str + Qualified name of the function (e.g. ``ClassName.method`` for + class methods). + """ + + module: str + qualname: str + + @model_serializer + def __serialize(self) -> dict[str, Any]: + return { + "func_type": self.__class__.__qualname__, + "module": self.module, + "qualname": self.qualname, + } + + @staticmethod + def from_callable(func: Callable) -> "PythonFunction": + """Build a ``PythonFunction`` descriptor from a Python callable.""" + return PythonFunction( + module=inspect.getmodule(func).__name__, + qualname=func.__qualname__, + ) + + def as_callable(self) -> Callable: + """Resolve this descriptor to the underlying Python callable. + + Imports the target module and looks up ``qualname``. Pure Python + reflection — no execution, no JVM. ``ClassName.method`` is split + and resolved through the class attribute. + """ + module = importlib.import_module(self.module) + if "." in self.qualname: + classname, methodname = self.qualname.rsplit(".", 1) + clazz = getattr(module, classname) + return getattr(clazz, methodname) + return getattr(module, self.qualname) + + +class JavaFunction(Function): + """Descriptor for a Java method: class FQN + method name + parameter types. + + Attributes: + ---------- + qualname : str + Fully-qualified name of the declaring Java class. + method_name : str + The Java method name. + parameter_types : List[str] + The Java parameter types, in declaration order. + """ + + qualname: str + method_name: str + parameter_types: List[str] + + @model_serializer + def __serialize(self) -> dict[str, Any]: + return { + "func_type": self.__class__.__qualname__, + "qualname": self.qualname, + "method_name": self.method_name, + "parameter_types": self.parameter_types, + } diff --git a/python/flink_agents/api/tools/function_tool.py b/python/flink_agents/api/tools/function_tool.py new file mode 100644 index 00000000..48ecfb25 --- /dev/null +++ b/python/flink_agents/api/tools/function_tool.py @@ -0,0 +1,33 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +################################################################################# +from typing_extensions import override + +from flink_agents.api.function import JavaFunction, PythonFunction +from flink_agents.api.resource import ResourceType, SerializableResource + + +class FunctionTool(SerializableResource): + """Declarative function tool: carries a function descriptor.""" + + func: PythonFunction | JavaFunction + + @classmethod + @override + def resource_type(cls) -> ResourceType: + """Return resource type of class.""" + return ResourceType.TOOL diff --git a/python/flink_agents/api/tools/tool.py b/python/flink_agents/api/tools/tool.py index a2ee041a..1dd6ac13 100644 --- a/python/flink_agents/api/tools/tool.py +++ b/python/flink_agents/api/tools/tool.py @@ -18,14 +18,17 @@ import typing from abc import ABC, abstractmethod from enum import Enum -from typing import Any, Type +from typing import TYPE_CHECKING, Any, Type -from pydantic import BaseModel, Field, field_serializer, model_validator +from pydantic import BaseModel, field_serializer, model_validator from typing_extensions import override from flink_agents.api.resource import ResourceType, SerializableResource from flink_agents.api.tools.utils import create_model_from_schema +if TYPE_CHECKING: + from flink_agents.api.tools.function_tool import FunctionTool + class ToolType(Enum): """Tool type enum. @@ -99,34 +102,18 @@ class ToolMetadata(BaseModel): return parameters -class FunctionTool(SerializableResource): - """Tool container keeps a callable, mainly used to represent - a function which will be converted to BaseTool after compile. - """ - - func: typing.Callable = Field(exclude=True) - - @classmethod - def resource_type(cls) -> ResourceType: - """Get the resource type.""" - return ResourceType.TOOL - - class Tool(SerializableResource, ABC): - """Base abstract class of all kinds of tools. + """Base abstract class of all kinds of tools.""" - Attributes: - ---------- - metadata : ToolMetadata - The metadata of the tools, includes name, description and arguments schema. - """ - - metadata: ToolMetadata + metadata: ToolMetadata | None = None @staticmethod - def from_callable(func: typing.Callable) -> FunctionTool: - """Create a function tool from a callable.""" - return FunctionTool(func=func) + def from_callable(func: typing.Callable) -> "FunctionTool": + """Wrap a Python callable as a declarative ``FunctionTool``.""" + from flink_agents.api.function import PythonFunction + from flink_agents.api.tools.function_tool import FunctionTool + + return FunctionTool(func=PythonFunction.from_callable(func)) @property def name(self) -> str: diff --git a/python/flink_agents/integrations/chat_models/anthropic/tests/test_anthropic_chat_model.py b/python/flink_agents/integrations/chat_models/anthropic/tests/test_anthropic_chat_model.py index 247759a2..9741e054 100644 --- a/python/flink_agents/integrations/chat_models/anthropic/tests/test_anthropic_chat_model.py +++ b/python/flink_agents/integrations/chat_models/anthropic/tests/test_anthropic_chat_model.py @@ -23,12 +23,12 @@ import pytest from flink_agents.api.chat_message import ChatMessage, MessageRole from flink_agents.api.resource import Resource, ResourceType from flink_agents.api.resource_context import ResourceContext +from flink_agents.api.tools.tool import Tool from flink_agents.integrations.chat_models.anthropic.anthropic_chat_model import ( DEFAULT_ANTHROPIC_MODEL, AnthropicChatModelConnection, AnthropicChatModelSetup, ) -from flink_agents.plan.tools.function_tool import from_callable test_model = os.environ.get("TEST_MODEL") api_key = os.environ.get("TEST_API_KEY") @@ -84,7 +84,7 @@ def test_anthropic_chat_with_tools() -> None: if type == ResourceType.CHAT_MODEL_CONNECTION: return connection else: - return from_callable(func=add) + return Tool.from_callable(func=add) mock_ctx = MagicMock(spec=ResourceContext) mock_ctx.get_resource = get_resource diff --git a/python/flink_agents/integrations/chat_models/azure/tests/test_azure_openai_chat_model.py b/python/flink_agents/integrations/chat_models/azure/tests/test_azure_openai_chat_model.py index 30753d75..79d2fb5c 100644 --- a/python/flink_agents/integrations/chat_models/azure/tests/test_azure_openai_chat_model.py +++ b/python/flink_agents/integrations/chat_models/azure/tests/test_azure_openai_chat_model.py @@ -27,7 +27,8 @@ from flink_agents.integrations.chat_models.azure.azure_openai_chat_model import AzureOpenAIChatModelConnection, AzureOpenAIChatModelSetup, ) -from flink_agents.plan.tools.function_tool import from_callable +from flink_agents.plan.function import PythonFunction +from flink_agents.plan.tools.function_tool import FunctionTool test_deployment = os.environ.get("TEST_AZURE_DEPLOYMENT") api_key = os.environ.get("AZURE_OPENAI_API_KEY") @@ -95,7 +96,7 @@ def test_azure_openai_chat_with_tools() -> None: if type == ResourceType.CHAT_MODEL_CONNECTION: return connection else: - return from_callable(func=add) + return FunctionTool(func=PythonFunction.from_callable(add)) mock_ctx = MagicMock(spec=ResourceContext) mock_ctx.get_resource = get_resource diff --git a/python/flink_agents/integrations/chat_models/openai/tests/test_openai_chat_model.py b/python/flink_agents/integrations/chat_models/openai/tests/test_openai_chat_model.py index 2280bf44..7ccb6c22 100644 --- a/python/flink_agents/integrations/chat_models/openai/tests/test_openai_chat_model.py +++ b/python/flink_agents/integrations/chat_models/openai/tests/test_openai_chat_model.py @@ -28,7 +28,8 @@ from flink_agents.integrations.chat_models.openai.openai_chat_model import ( OpenAIChatModelConnection, OpenAIChatModelSetup, ) -from flink_agents.plan.tools.function_tool import from_callable +from flink_agents.plan.function import PythonFunction +from flink_agents.plan.tools.function_tool import FunctionTool test_model = os.environ.get("TEST_MODEL") api_key = os.environ.get("TEST_API_KEY") @@ -86,7 +87,7 @@ def test_openai_chat_with_tools() -> None: if type == ResourceType.CHAT_MODEL_CONNECTION: return connection else: - return from_callable(func=add) + return FunctionTool(func=PythonFunction.from_callable(add)) mock_ctx = MagicMock(spec=ResourceContext) mock_ctx.get_resource = get_resource diff --git a/python/flink_agents/integrations/chat_models/tests/test_ollama_chat_model.py b/python/flink_agents/integrations/chat_models/tests/test_ollama_chat_model.py index 5503185b..6a2a4711 100644 --- a/python/flink_agents/integrations/chat_models/tests/test_ollama_chat_model.py +++ b/python/flink_agents/integrations/chat_models/tests/test_ollama_chat_model.py @@ -31,7 +31,8 @@ from flink_agents.integrations.chat_models.ollama_chat_model import ( OllamaChatModelConnection, OllamaChatModelSetup, ) -from flink_agents.plan.tools.function_tool import FunctionTool, from_callable +from flink_agents.plan.function import PythonFunction +from flink_agents.plan.tools.function_tool import FunctionTool test_model = os.environ.get("OLLAMA_CHAT_MODEL", "qwen3:1.7b") current_dir = Path(__file__).parent @@ -90,7 +91,7 @@ def add(a: int, b: int) -> int: def get_tool(name: str, type: ResourceType) -> FunctionTool: - return from_callable(func=add) + return FunctionTool(func=PythonFunction.from_callable(add)) @pytest.mark.skipif( diff --git a/python/flink_agents/integrations/chat_models/tests/test_tongyi_chat_model.py b/python/flink_agents/integrations/chat_models/tests/test_tongyi_chat_model.py index fc80ac6a..c33a792c 100644 --- a/python/flink_agents/integrations/chat_models/tests/test_tongyi_chat_model.py +++ b/python/flink_agents/integrations/chat_models/tests/test_tongyi_chat_model.py @@ -29,7 +29,8 @@ from flink_agents.integrations.chat_models.tongyi_chat_model import ( TongyiChatModelConnection, TongyiChatModelSetup, ) -from flink_agents.plan.tools.function_tool import FunctionTool, from_callable +from flink_agents.plan.function import PythonFunction +from flink_agents.plan.tools.function_tool import FunctionTool test_model = os.environ.get("TONGYI_CHAT_MODEL", "qwen-plus") api_key_available = "DASHSCOPE_API_KEY" in os.environ @@ -68,7 +69,7 @@ def add(a: int, b: int) -> int: def get_tool(name: str, type: ResourceType) -> FunctionTool: """Helper function to create a tool for testing.""" - return from_callable(func=add) + return FunctionTool(func=PythonFunction.from_callable(add)) @pytest.mark.skipif(not api_key_available, reason="DashScope API key is not set") diff --git a/python/flink_agents/integrations/mcp/tests/test_mcp.py b/python/flink_agents/integrations/mcp/tests/test_mcp.py index 7a6e1487..1bb81ce8 100644 --- a/python/flink_agents/integrations/mcp/tests/test_mcp.py +++ b/python/flink_agents/integrations/mcp/tests/test_mcp.py @@ -26,7 +26,8 @@ from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAu from pydantic import AnyUrl from flink_agents.api.chat_message import ChatMessage, MessageRole -from flink_agents.integrations.mcp.mcp import MCPServer +from flink_agents.api.tools.tool import ToolMetadata +from flink_agents.integrations.mcp.mcp import MCPServer, MCPTool def run_server() -> None: @@ -124,3 +125,27 @@ def test_serialize_mcp_server() -> None: deserialized.auth.context.client_metadata == mcp_server.auth.context.client_metadata ) + + +def test_mcp_tool_roundtrip_preserves_metadata() -> None: + metadata = ToolMetadata( + name="add", + description="Add two integers.", + args_schema={ + "type": "object", + "properties": { + "a": {"type": "integer"}, + "b": {"type": "integer"}, + }, + "required": ["a", "b"], + }, + ) + tool = MCPTool(metadata=metadata, mcp_server=MCPServer(endpoint="http://x")) + + dumped = tool.model_dump() + assert "metadata" in dumped, "serialized form must expose `metadata` key" + assert "metadata_" not in dumped + + restored = MCPTool.model_validate(dumped) + assert restored.metadata == metadata + assert restored.name == "add" diff --git a/python/flink_agents/plan/agent_plan.py b/python/flink_agents/plan/agent_plan.py index f520ee63..f38c0fc7 100644 --- a/python/flink_agents/plan/agent_plan.py +++ b/python/flink_agents/plan/agent_plan.py @@ -20,6 +20,9 @@ from typing import TYPE_CHECKING, Any, Dict, List, cast from pydantic import BaseModel, field_serializer, model_validator from flink_agents.api.agents.agent import Agent +from flink_agents.api.function import Function as ApiFunction +from flink_agents.api.function import JavaFunction as ApiJavaFunction +from flink_agents.api.function import PythonFunction as ApiPythonFunction from flink_agents.api.resource import ( ResourceDescriptor, ResourceType, @@ -30,12 +33,14 @@ from flink_agents.api.skills import ( LOAD_SKILL_TOOL, Skills, ) +from flink_agents.api.tools.function_tool import FunctionTool as ApiFunctionTool +from flink_agents.api.tools.tool import Tool from flink_agents.plan.actions.action import Action from flink_agents.plan.actions.chat_model_action import CHAT_MODEL_ACTION from flink_agents.plan.actions.context_retrieval_action import CONTEXT_RETRIEVAL_ACTION from flink_agents.plan.actions.tool_call_action import TOOL_CALL_ACTION from flink_agents.plan.configuration import AgentConfiguration -from flink_agents.plan.function import PythonFunction +from flink_agents.plan.function import JavaFunction, PythonFunction from flink_agents.plan.resource_provider import ( JavaResourceProvider, JavaSerializableResourceProvider, @@ -43,7 +48,7 @@ from flink_agents.plan.resource_provider import ( PythonSerializableResourceProvider, ResourceProvider, ) -from flink_agents.plan.tools.function_tool import from_callable +from flink_agents.plan.tools.function_tool import FunctionTool if TYPE_CHECKING: from flink_agents.api.resource import ( @@ -262,7 +267,7 @@ def _get_actions(agent: Agent) -> List[Action]: actions.append( Action( name=name, - exec=PythonFunction.from_callable(action_tuple[1]), + exec=_to_plan_function(action_tuple[1]), listen_event_types=[ _resolve_event_type(et) for et in action_tuple[0] @@ -273,6 +278,27 @@ def _get_actions(agent: Agent) -> List[Action]: return actions +def _to_plan_function(func: ApiFunction) -> PythonFunction | JavaFunction: + """Promote an api Function descriptor to its executable plan counterpart. + + Agent stores api-layer descriptors (pure data). Action.exec needs the + plan-layer executable variants for ``check_signature`` and + ``__call__``, so we rebuild here. + """ + if isinstance(func, ApiPythonFunction): + return PythonFunction(module=func.module, qualname=func.qualname) + if isinstance(func, ApiJavaFunction): + return JavaFunction( + qualname=func.qualname, + method_name=func.method_name, + parameter_types=list(func.parameter_types), + ) + msg = f"Unsupported function descriptor: {type(func).__name__}" + raise TypeError(msg) + + + + def _get_resource_providers( agent: Agent, config: AgentConfiguration ) -> List[ResourceProvider]: @@ -307,10 +333,11 @@ def _get_resource_providers( if callable(value): # TODO: support other tool type. - tool = from_callable(func=value) + tool = Tool.from_callable(func=value) resource_providers.append( PythonSerializableResourceProvider.from_resource( - name=name, resource=tool + name=name, + resource=FunctionTool(func=_to_plan_function(tool.func)), ) ) elif hasattr(value, "_is_prompt"): @@ -342,7 +369,12 @@ def _get_resource_providers( for name, tool in agent.resources[ResourceType.TOOL].items(): resource_providers.append( PythonSerializableResourceProvider.from_resource( - name=name, resource=from_callable(tool.func) + name=name, + resource=( + FunctionTool(func=_to_plan_function(tool.func)) + if isinstance(tool, ApiFunctionTool) + else tool + ), ) ) diff --git a/python/flink_agents/plan/function.py b/python/flink_agents/plan/function.py index d8c10a4b..55086414 100644 --- a/python/flink_agents/plan/function.py +++ b/python/flink_agents/plan/function.py @@ -22,7 +22,7 @@ import logging from abc import ABC, abstractmethod from typing import Any, Callable, Dict, List, Tuple, get_type_hints -from pydantic import BaseModel, model_serializer +from pydantic import BaseModel, PrivateAttr, model_serializer from flink_agents.plan.utils import check_type_match @@ -102,7 +102,7 @@ def _is_function_cacheable(func: Callable) -> bool: class Function(BaseModel, ABC): - """Base interface for user defined functions, includes python and java.""" + """Base interface for user-defined functions.""" @abstractmethod def check_signature(self, *args: Tuple[Any, ...]) -> None: @@ -216,6 +216,10 @@ class PythonFunction(Function): """ return self.__get_func()(*args, **kwargs) + def as_callable(self) -> Callable: + """Return the underlying Python callable, importing the module if needed.""" + return self.__get_func() + def __get_func(self) -> Callable: if self.__func is None: module = importlib.import_module(self.module) @@ -237,14 +241,20 @@ class PythonFunction(Function): return self.__is_cacheable -# TODO: Implement JavaFunction. class JavaFunction(Function): - """Descriptor for a java callable function.""" + """Descriptor for a Java callable function. + + Invocation goes through the JVM resource adapter, injected by the + runtime via :meth:`set_java_resource_adapter`; until then + ``__call__`` raises ``RuntimeError``. + """ qualname: str method_name: str parameter_types: List[str] + _j_resource_adapter: Any = PrivateAttr(default=None) + @model_serializer def __custom_serializer(self) -> dict[str, Any]: data = { @@ -255,8 +265,30 @@ class JavaFunction(Function): } return data + def set_java_resource_adapter(self, adapter: Any) -> None: + """Inject the JVM adapter used to invoke this Java method.""" + self._j_resource_adapter = adapter + def __call__(self, *args: Tuple[Any, ...], **kwargs: Dict[str, Any]) -> Any: - """Execute the stored function with provided arguments.""" + """Invoke the Java method via the JVM resource adapter. + + LLM tool calls always arrive as keyword arguments — positional + ``*args`` are ignored because the Java side reorders parameters + by name via reflection. + """ + if self._j_resource_adapter is None: + msg = ( + "JavaFunction requires the JVM resource adapter; not set " + "on this descriptor. The runtime should inject it via " + "set_java_resource_adapter before invocation." + ) + raise RuntimeError(msg) + return self._j_resource_adapter.invokeJavaTool( + self.qualname, + self.method_name, + self.parameter_types, + kwargs, + ) def check_signature(self, *args: Tuple[Any, ...]) -> None: """Check function signature is legal or not.""" diff --git a/python/flink_agents/plan/tests/test_function.py b/python/flink_agents/plan/tests/test_function.py index a42ac01e..6ce4c853 100644 --- a/python/flink_agents/plan/tests/test_function.py +++ b/python/flink_agents/plan/tests/test_function.py @@ -34,7 +34,7 @@ from flink_agents.plan.function import ( ) if TYPE_CHECKING: - from flink_agents.plan.function import Function + from flink_agents.api.function import Function def check_class(input_event: InputEvent, output_event: OutputEvent) -> None: @@ -256,10 +256,6 @@ def test_cache_performance_benefit() -> None: """Test that caching provides performance benefits.""" clear_python_function_cache() - # This test verifies that the same PythonFunction instance is reused - # We can't easily test performance directly, but we can verify that - # the cache key mechanism works correctly - # First call creates cache entry call_python_function( "flink_agents.plan.tests.test_function", "function_for_caching", (1,) @@ -285,7 +281,6 @@ def test_selective_caching_pure_functions() -> None: """Test that pure functions are cached.""" clear_python_function_cache() - # Pure functions should be cached call_python_function( "flink_agents.plan.tests.test_function", "simple_pure_function", (1, 2) ) @@ -293,7 +288,6 @@ def test_selective_caching_pure_functions() -> None: "flink_agents.plan.tests.test_function", "function_for_caching", (5,) ) - # Both should be in cache assert get_python_function_cache_size() == 2 cache_keys = get_python_function_cache_keys() assert ( @@ -310,17 +304,13 @@ def test_selective_caching_generator_functions() -> None: """Test that generator functions are not cached.""" clear_python_function_cache() - # Generator function should not be cached result = call_python_function( "flink_agents.plan.tests.test_function", "generator_function", (3,) ) - # Result is now a generator directly (no wrapper) assert isinstance(result, Generator) - # Convert generator to list for testing result_list = list(result) assert result_list == [0, 1, 2] - # Should not be cached assert get_python_function_cache_size() == 0 @@ -328,7 +318,6 @@ def test_selective_caching_mutable_defaults() -> None: """Test that functions with mutable defaults are not cached.""" clear_python_function_cache() - # Function with mutable default should not be cached result1 = call_python_function( "flink_agents.plan.tests.test_function", "function_with_mutable_default", () ) @@ -336,57 +325,35 @@ def test_selective_caching_mutable_defaults() -> None: "flink_agents.plan.tests.test_function", "function_with_mutable_default", () ) - # Should not be cached (each call creates a new function instance) assert get_python_function_cache_size() == 0 - # Results should be different if function is correctly not cached - # (mutable default behavior depends on not caching) assert isinstance(result1, list) assert isinstance(result2, list) def test_is_function_cacheable() -> None: """Test the _is_function_cacheable function directly.""" - # Pure functions should be cacheable assert _is_function_cacheable(simple_pure_function) is True assert _is_function_cacheable(function_for_caching) is True - - # Generator functions should not be cacheable assert _is_function_cacheable(generator_function) is False - - # Functions with mutable defaults should not be cacheable assert _is_function_cacheable(function_with_mutable_default) is False - - # Closures should not be cacheable closure_func = make_closure(5) assert _is_function_cacheable(closure_func) is False - - # None should not be cacheable assert _is_function_cacheable(None) is False def test_python_function_cacheability_optimization() -> None: """Test that PythonFunction caches the cacheability check result.""" - # Test cacheable function cacheable_func = PythonFunction.from_callable(simple_pure_function) - # First call should compute and cache the result assert cacheable_func.is_cacheable() is True - - # Second call should use cached result (we can't directly test this, - # but we can verify it returns the same result) assert cacheable_func.is_cacheable() is True - # Test non-cacheable function non_cacheable_func = PythonFunction.from_callable(generator_function) - # First call should compute and cache the result assert non_cacheable_func.is_cacheable() is False - - # Second call should use cached result assert non_cacheable_func.is_cacheable() is False - # Test that the cacheability check is consistent with direct _is_function_cacheable assert cacheable_func.is_cacheable() == _is_function_cacheable(simple_pure_function) assert non_cacheable_func.is_cacheable() == _is_function_cacheable( generator_function diff --git a/python/flink_agents/plan/tests/tools/resources/function_tool.json b/python/flink_agents/plan/tests/tools/resources/function_tool.json index 50dd0e98..a218c83f 100644 --- a/python/flink_agents/plan/tests/tools/resources/function_tool.json +++ b/python/flink_agents/plan/tests/tools/resources/function_tool.json @@ -1,4 +1,9 @@ { + "func": { + "func_type": "PythonFunction", + "module": "flink_agents.plan.tests.tools.test_function_tool", + "qualname": "foo" + }, "metadata": { "name": "foo", "description": "Function for testing ToolMetadata.\n", @@ -22,10 +27,5 @@ "title": "foo", "type": "object" } - }, - "func": { - "module": "flink_agents.plan.tests.tools.test_function_tool", - "qualname": "foo", - "func_type": "PythonFunction" } -} \ No newline at end of file +} diff --git a/python/flink_agents/plan/tests/tools/test_function_tool.py b/python/flink_agents/plan/tests/tools/test_function_tool.py index d4a3f512..18fec91e 100644 --- a/python/flink_agents/plan/tests/tools/test_function_tool.py +++ b/python/flink_agents/plan/tests/tools/test_function_tool.py @@ -17,10 +17,12 @@ ################################################################################# import json from pathlib import Path +from unittest.mock import MagicMock import pytest -from flink_agents.plan.tools.function_tool import FunctionTool, from_callable +from flink_agents.plan.function import JavaFunction, PythonFunction +from flink_agents.plan.tools.function_tool import FunctionTool current_dir = Path(__file__).parent @@ -45,7 +47,7 @@ def foo(bar: int, baz: str) -> str: @pytest.fixture(scope="module") def func_tool() -> FunctionTool: - return from_callable(foo) + return FunctionTool(func=PythonFunction.from_callable(foo)) def test_serialize_function_tool(func_tool: FunctionTool) -> None: @@ -61,4 +63,123 @@ def test_deserialize_function_tool(func_tool: FunctionTool) -> None: with Path(f"{current_dir}/resources/function_tool.json").open() as f: json_value = f.read() actual_func_tool = FunctionTool.model_validate_json(json_value) - assert actual_func_tool == func_tool + # ``PythonFunction`` carries a private ``__func`` cache that is only + # populated once the callable has been resolved (e.g. via the eager + # metadata derivation in the fixture). The deserialized instance hasn't + # resolved the callable yet, so a full BaseModel ``==`` would differ on + # the cache. Compare the public, serialized state instead. + assert actual_func_tool.metadata == func_tool.metadata + assert actual_func_tool.func.module == func_tool.func.module + assert actual_func_tool.func.qualname == func_tool.func.qualname + + +def test_python_function_tool_metadata_filled_eagerly() -> None: + # ``PythonFunction`` metadata can be derived without external context, + # so ``FunctionTool`` fills it during model validation. The field is + # therefore already populated immediately after construction. + tool = FunctionTool(func=PythonFunction.from_callable(foo)) + assert tool.metadata is not None + assert tool.metadata.name == "foo" + + +# ---- Java function tool path ------------------------------------------------- + + +def _java_func() -> JavaFunction: + # Fresh instance per test — the adapter now lives on JavaFunction, so + # sharing one would leak state between tests. + return JavaFunction( + qualname="com.example.Tools", + method_name="add", + parameter_types=["int", "int"], + ) + +_FAKE_JAVA_SCHEMA = json.dumps( + { + "type": "object", + "properties": { + "a": {"type": "integer", "description": "First operand."}, + "b": {"type": "integer", "description": "Second operand."}, + }, + "required": ["a", "b"], + "title": "add", + } +) + + +def _fake_adapter() -> MagicMock: + """Build a mock ``_j_resource_adapter`` that mirrors the Java + ``JavaResourceAdapter`` surface used by ``plan.FunctionTool``. + + ``getJavaToolMetadata`` returns a flat ``Map<String, String>`` (see + ``JavaResourceAdapter.getJavaToolMetadata`` Java side for why), + so mock it as a plain Python dict. + """ + adapter = MagicMock() + adapter.getJavaToolMetadata.return_value = { + "name": "add", + "description": "Add two ints.", + "inputSchema": _FAKE_JAVA_SCHEMA, + } + adapter.invokeJavaTool.return_value = 1065 + return adapter + + +def test_java_function_tool_constructs_without_adapter() -> None: + # Plan compile time: no JVM adapter yet. Construction (and its + # SerializableResource self-validation) must not call into the adapter. + # ``metadata`` stays ``None`` until the adapter is injected. + tool = FunctionTool(func=_java_func()) + assert tool.func._j_resource_adapter is None + assert isinstance(tool.func, JavaFunction) + assert tool.metadata is None + + +def test_java_function_tool_metadata_filled_on_adapter_injection() -> None: + tool = FunctionTool(func=_java_func()) + adapter = _fake_adapter() + + tool.set_java_resource_adapter(adapter) + + # Adapter is consulted exactly once at injection time and the result is + # stored in the regular ``metadata`` field. Subsequent accesses just read + # the field, so the adapter is not hit again. + adapter.getJavaToolMetadata.assert_called_once_with( + "com.example.Tools", "add", ["int", "int"] + ) + assert tool.metadata is not None + assert tool.metadata.name == "add" + assert tool.metadata.description == "Add two ints." + assert set(tool.metadata.args_schema.model_fields) == {"a", "b"} + _ = tool.metadata + adapter.getJavaToolMetadata.assert_called_once() + + +def test_java_function_tool_metadata_is_none_without_adapter() -> None: + # Before the runtime injects the adapter, the metadata is intentionally + # absent — this is the only legal window where ``Tool.metadata`` is + # ``None``. Accessing the field must not raise. + tool = FunctionTool(func=_java_func()) + assert tool.metadata is None + + +def test_java_function_tool_call_dispatches_through_adapter() -> None: + tool = FunctionTool(func=_java_func()) + adapter = _fake_adapter() + tool.set_java_resource_adapter(adapter) + + result = tool.call(a=377, b=688) + + assert result == 1065 + adapter.invokeJavaTool.assert_called_once_with( + "com.example.Tools", + "add", + ["int", "int"], + {"a": 377, "b": 688}, + ) + + +def test_java_function_tool_call_without_adapter_raises() -> None: + tool = FunctionTool(func=_java_func()) + with pytest.raises(RuntimeError, match="JVM resource adapter"): + tool.call(a=1, b=2) diff --git a/python/flink_agents/plan/tools/bash/bash_tool.py b/python/flink_agents/plan/tools/bash/bash_tool.py index d579c5ac..54797e97 100644 --- a/python/flink_agents/plan/tools/bash/bash_tool.py +++ b/python/flink_agents/plan/tools/bash/bash_tool.py @@ -70,8 +70,6 @@ class BashTool(Tool): time by the framework (not visible to the LLM through ``args_schema``). """ - metadata: ToolMetadata = Field(exclude=True) - def __init__(self, **kwargs: Any) -> None: """Initialize the tool.""" super().__init__( diff --git a/python/flink_agents/plan/tools/function_tool.py b/python/flink_agents/plan/tools/function_tool.py index f686bf22..85be3df1 100644 --- a/python/flink_agents/plan/tools/function_tool.py +++ b/python/flink_agents/plan/tools/function_tool.py @@ -15,50 +15,89 @@ # See the License for the specific language governing permissions and # limitations under the License. ################################################################################# -from typing import Any, Callable +from typing import Any from docstring_parser import parse +from pydantic import model_validator from typing_extensions import override from flink_agents.api.tools.tool import Tool, ToolMetadata, ToolType -from flink_agents.api.tools.utils import create_schema_from_function +from flink_agents.api.tools.utils import ( + create_model_from_java_tool_schema_str, + create_schema_from_function, +) from flink_agents.plan.function import JavaFunction, PythonFunction class FunctionTool(Tool): - """Tool that takes in a function. + """Executable function tool. - Attributes: - ---------- - func : Function - User defined function. + ``metadata`` is filled eagerly as soon as the value is derivable — + during model validation for ``PythonFunction`` (from the callable's + docstring/signature), and inside :meth:`set_java_resource_adapter` + once the runtime injects the JVM bridge for ``JavaFunction``. Until + that injection the field stays ``None``. """ func: PythonFunction | JavaFunction + @model_validator(mode="after") + def _eager_derive_python_metadata(self) -> "FunctionTool": + if self.metadata is None and isinstance(self.func, PythonFunction): + self.metadata = _python_metadata(self.func) + return self + + def set_java_resource_adapter(self, adapter: Any) -> None: + """Inject the JVM resource adapter and derive ``metadata``. Called + by the runtime resource cache when the tool is first materialised; + no-op when ``func`` is not a ``JavaFunction``. + """ + if not isinstance(self.func, JavaFunction): + return + self.func.set_java_resource_adapter(adapter) + if self.metadata is None: + self.metadata = _java_metadata(self.func) + @classmethod @override def tool_type(cls) -> ToolType: """Get the tool type.""" return ToolType.FUNCTION + @override def call(self, *args: Any, **kwargs: Any) -> Any: - """Call the function tool.""" + """Invoke the underlying function.""" return self.func(*args, **kwargs) -def from_callable(func: Callable) -> FunctionTool: - """Create FunctionTool from a user defined function. - - Parameters - ---------- - func : Callable - The function to analyze. - """ - description = parse(func.__doc__).description - metadata = ToolMetadata( - name=func.__name__, +def _python_metadata(func: PythonFunction) -> ToolMetadata: + callable_ = func.as_callable() + description = parse(callable_.__doc__).description or "" + return ToolMetadata( + name=callable_.__name__, description=description, - args_schema=create_schema_from_function(func.__name__, func=func), + args_schema=create_schema_from_function(callable_.__name__, func=callable_), + ) + + +def _java_metadata(func: JavaFunction) -> ToolMetadata: + adapter = func._j_resource_adapter + if adapter is None: + msg = ( + "Java function tool metadata requires the JVM resource adapter; " + "not set on the underlying JavaFunction. The runtime should " + "inject it via FunctionTool.set_java_resource_adapter before " + "metadata access." + ) + raise RuntimeError(msg) + flat = adapter.getJavaToolMetadata( + func.qualname, func.method_name, func.parameter_types + ) + name = flat["name"] + return ToolMetadata( + name=name, + description=flat["description"], + args_schema=create_model_from_java_tool_schema_str( + name, flat["inputSchema"] + ), ) - return FunctionTool(func=PythonFunction.from_callable(func), metadata=metadata) diff --git a/python/flink_agents/runtime/java/java_resource_wrapper.py b/python/flink_agents/runtime/java/java_resource_wrapper.py index 83305731..886e4c84 100644 --- a/python/flink_agents/runtime/java/java_resource_wrapper.py +++ b/python/flink_agents/runtime/java/java_resource_wrapper.py @@ -17,19 +17,29 @@ ################################################################################# from typing import Any, List -from pydantic import Field +from pydantic import ConfigDict, Field from typing_extensions import override from flink_agents.api.chat_message import ChatMessage, MessageRole from flink_agents.api.prompts.prompt import Prompt from flink_agents.api.resource import Resource, ResourceType from flink_agents.api.resource_context import ResourceContext -from flink_agents.api.tools.tool import Tool, ToolType +from flink_agents.api.tools.tool import Tool, ToolMetadata, ToolType class JavaTool(Tool): """Java Tool that carries tool metadata and can be recognized by PythonChatModel.""" + model_config = ConfigDict(populate_by_name=True) + + metadata_: ToolMetadata = Field(exclude=True, alias="metadata") + + @property + @override + def metadata(self) -> ToolMetadata: + """Return the tool metadata.""" + return self.metadata_ + @classmethod @override def tool_type(cls) -> ToolType: diff --git a/python/flink_agents/runtime/resource_cache.py b/python/flink_agents/runtime/resource_cache.py index 0e7ded88..f66cc758 100644 --- a/python/flink_agents/runtime/resource_cache.py +++ b/python/flink_agents/runtime/resource_cache.py @@ -20,7 +20,9 @@ from typing import Any, Dict from flink_agents.api.resource import Resource, ResourceType from flink_agents.api.resource_context import ResourceContext from flink_agents.plan.configuration import AgentConfiguration +from flink_agents.plan.function import JavaFunction from flink_agents.plan.resource_provider import JavaResourceProvider, ResourceProvider +from flink_agents.plan.tools.function_tool import FunctionTool class ResourceCache: @@ -85,6 +87,8 @@ class ResourceCache: resource = resource_provider.provide( resource_context=self._resource_context, config=self._config ) + if isinstance(resource, FunctionTool) and isinstance(resource.func, JavaFunction): + resource.set_java_resource_adapter(self._j_resource_adapter) resource.open() self._cache.setdefault(type, {})[name] = resource return resource diff --git a/python/flink_agents/runtime/skill/skill_tools.py b/python/flink_agents/runtime/skill/skill_tools.py index b5cb8ede..3f5da277 100644 --- a/python/flink_agents/runtime/skill/skill_tools.py +++ b/python/flink_agents/runtime/skill/skill_tools.py @@ -53,8 +53,6 @@ class LoadSkillTool(Tool): (not the public ResourceContext interface). """ - metadata: ToolMetadata = Field(exclude=True) - def __init__(self, **kwargs: Any) -> None: """Initialize the load skill tool.""" super().__init__( diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java index 279d4652..899f463d 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/ActionExecutionOperator.java @@ -166,7 +166,8 @@ public class ActionExecutionOperator<IN, OUT> extends AbstractStreamOperator<OUT getRuntimeContext().getJobInfo().getJobId(), metricGroup, this::checkMailboxThread, - jobIdentifier); + jobIdentifier, + getRuntimeContext().getUserCodeClassLoader()); // Capture the wired Mem0 long-term memory, if any, so it can be plumbed into the Java // runner context created by ActionTaskContextManager. diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/PythonBridgeManager.java b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/PythonBridgeManager.java index 1f6b190d..73eb6a50 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/operator/PythonBridgeManager.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/operator/PythonBridgeManager.java @@ -111,6 +111,9 @@ class PythonBridgeManager implements AutoCloseable { * @param metricGroup the agent metric group, exposed to Python via the runner context. * @param mailboxThreadChecker hook used by the runner context to assert mailbox-thread access. * @param jobIdentifier the job identifier used to scope Python state. + * @param userCodeClassLoader the operator's user-code class loader, propagated to {@link + * JavaResourceAdapter} so reflective Java tool resolution sees user jars added via {@code + * env.add_jars(...)}. */ void open( AgentPlan agentPlan, @@ -121,7 +124,8 @@ class PythonBridgeManager implements AutoCloseable { JobID jobId, FlinkAgentsMetricGroupImpl metricGroup, Runnable mailboxThreadChecker, - String jobIdentifier) + String jobIdentifier, + ClassLoader userCodeClassLoader) throws Exception { boolean containPythonAction = agentPlan.getActions().values().stream() @@ -169,7 +173,8 @@ class PythonBridgeManager implements AutoCloseable { throw new RuntimeException(e); } }), - pythonInterpreter); + pythonInterpreter, + userCodeClassLoader); if (containPythonResource || mem0Configured) { initPythonResourceAdapter(agentPlan, resourceCache); } diff --git a/runtime/src/main/java/org/apache/flink/agents/runtime/python/utils/JavaResourceAdapter.java b/runtime/src/main/java/org/apache/flink/agents/runtime/python/utils/JavaResourceAdapter.java index f17d7ce7..3c00d96d 100644 --- a/runtime/src/main/java/org/apache/flink/agents/runtime/python/utils/JavaResourceAdapter.java +++ b/runtime/src/main/java/org/apache/flink/agents/runtime/python/utils/JavaResourceAdapter.java @@ -22,9 +22,16 @@ import org.apache.flink.agents.api.chat.messages.MessageRole; import org.apache.flink.agents.api.resource.Resource; import org.apache.flink.agents.api.resource.ResourceContext; import org.apache.flink.agents.api.resource.ResourceType; +import org.apache.flink.agents.api.tools.ToolMetadata; +import org.apache.flink.agents.api.tools.ToolParameters; +import org.apache.flink.agents.api.tools.ToolResponse; import org.apache.flink.agents.api.vectorstores.Document; +import org.apache.flink.agents.plan.tools.FunctionTool; +import org.apache.flink.agents.plan.tools.ToolMetadataFactory; import pemja.core.PythonInterpreter; +import java.lang.reflect.Method; +import java.util.HashMap; import java.util.List; import java.util.Map; @@ -34,9 +41,21 @@ public class JavaResourceAdapter { private final transient PythonInterpreter interpreter; - public JavaResourceAdapter(ResourceContext resourceContext, PythonInterpreter interpreter) { + /** + * Class loader used to resolve Java tool methods declared by name. Captured at construction + * (the operator passes its {@code RuntimeContext.getUserCodeClassLoader()}) because pemja + * worker threads inherit the JVM system loader as their context loader and would not see + * user-supplied jars added via {@code env.add_jars(...)}. + */ + private final transient ClassLoader userCodeClassLoader; + + public JavaResourceAdapter( + ResourceContext resourceContext, + PythonInterpreter interpreter, + ClassLoader userCodeClassLoader) { this.resourceContext = resourceContext; this.interpreter = interpreter; + this.userCodeClassLoader = userCodeClassLoader; } /** @@ -105,4 +124,99 @@ public class JavaResourceAdapter { Float score) { return new Document(content, metadata, id, embedding, score); } + + /** + * Resolve the metadata for a Java static tool method declared by fully-qualified class name, + * method name and parameter type names. + * + * <p>Invoked from the Python side via the {@code _j_resource_adapter} bridge when a {@code + * plan.FunctionTool} backed by a {@code JavaFunction} first materialises its metadata. + * Delegates to {@link ToolMetadataFactory#fromStaticMethod(Method)} once the {@code Method} is + * resolved, then flattens the resulting {@link ToolMetadata} into a {@code Map<String, String>} + * before returning. + * + * <p>The flattening is required because pemja can crash with a SIGSEGV inside {@code + * JcpPyJObject_New} when Java returns an arbitrary Java object to a Python call that originated + * on a non-main interpreter thread (e.g. a Flink mailbox worker that resolves a tool's + * metadata). Returning only String fields — which pemja maps natively to {@code str} — + * sidesteps the reverse Java→Python object wrap entirely. The Python side rebuilds {@link + * ToolMetadata} from the flat map. + */ + public Map<String, String> getJavaToolMetadata( + String className, String methodName, List<String> parameterTypes) throws Exception { + Method method = resolveMethod(className, methodName, parameterTypes); + ToolMetadata metadata = ToolMetadataFactory.fromStaticMethod(method); + Map<String, String> result = new HashMap<>(); + result.put("name", metadata.getName()); + result.put("description", metadata.getDescription()); + result.put("inputSchema", metadata.getInputSchema()); + return result; + } + + /** + * Invoke a Java static tool method with keyword arguments coming from a Python tool call. + * + * <p>Delegates to {@link FunctionTool#call(ToolParameters)} so the Python-driven tool-call path + * shares every detail of argument resolution with the Java agent path — {@link + * org.apache.flink.agents.api.annotation.ToolParam} name override, {@link ToolParameters} + * numeric coercion (covers the LLM-emitted JSON Number → Java box type mismatch that reflective + * {@code Method.invoke} otherwise rejects), required-parameter checking, and {@link + * ToolResponse} success / error semantics. The success result is unwrapped for the Python + * caller; an unsuccessful response is re-thrown as a {@link RuntimeException}. + */ + public Object invokeJavaTool( + String className, + String methodName, + List<String> parameterTypes, + Map<String, Object> arguments) + throws Exception { + Method method = resolveMethod(className, methodName, parameterTypes); + FunctionTool tool = FunctionTool.fromStaticMethod(method); + ToolResponse response = + tool.call(new ToolParameters(arguments == null ? new HashMap<>() : arguments)); + if (!response.isSuccess()) { + throw new RuntimeException(response.getError()); + } + return response.getResult(); + } + + private Method resolveMethod(String className, String methodName, List<String> parameterTypes) + throws ClassNotFoundException, NoSuchMethodException { + ClassLoader classLoader = + userCodeClassLoader != null + ? userCodeClassLoader + : Thread.currentThread().getContextClassLoader(); + Class<?> clazz = Class.forName(className, true, classLoader); + Class<?>[] paramClasses = new Class<?>[parameterTypes.size()]; + for (int i = 0; i < parameterTypes.size(); i++) { + paramClasses[i] = resolveType(parameterTypes.get(i), classLoader); + } + return clazz.getMethod(methodName, paramClasses); + } + + private static Class<?> resolveType(String typeName, ClassLoader classLoader) + throws ClassNotFoundException { + switch (typeName) { + case "boolean": + return boolean.class; + case "byte": + return byte.class; + case "short": + return short.class; + case "int": + return int.class; + case "long": + return long.class; + case "float": + return float.class; + case "double": + return double.class; + case "char": + return char.class; + case "void": + return void.class; + default: + return Class.forName(typeName, true, classLoader); + } + } } diff --git a/runtime/src/test/java/org/apache/flink/agents/runtime/operator/PythonBridgeManagerTest.java b/runtime/src/test/java/org/apache/flink/agents/runtime/operator/PythonBridgeManagerTest.java index f7226826..9ee68563 100644 --- a/runtime/src/test/java/org/apache/flink/agents/runtime/operator/PythonBridgeManagerTest.java +++ b/runtime/src/test/java/org/apache/flink/agents/runtime/operator/PythonBridgeManagerTest.java @@ -50,7 +50,8 @@ class PythonBridgeManagerTest { /* jobId */ new JobID(), /* metricGroup */ null, /* mailboxThreadChecker */ () -> {}, - /* jobIdentifier */ "job-1"); + /* jobIdentifier */ "job-1", + /* userCodeClassLoader */ Thread.currentThread().getContextClassLoader()); // No-op contract: nothing initialized, no Pemja interpreter created. assertThat(bridge.isInitialized()).isFalse();
