This is an automated email from the ASF dual-hosted git repository. wenjin272 pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/flink-agents.git
commit 54422198d05a0664e763a403561d14a2e7fae8a3 Author: WenjinXie <[email protected]> AuthorDate: Mon Apr 20 15:43:20 2026 +0800 [plan][python] Add bash tool in python. Co-authored-by: Claude Opus 4.7 (1M context) <[email protected]> --- python/flink_agents/plan/tools/bash/__init__.py | 16 ++ python/flink_agents/plan/tools/bash/bash_tool.py | 144 +++++++++++ .../flink_agents/plan/tools/bash/bash_validator.py | 184 ++++++++++++++ .../flink_agents/plan/tools/bash/tests/__init__.py | 16 ++ .../plan/tools/bash/tests/test_bash_tool.py | 282 +++++++++++++++++++++ python/pyproject.toml | 2 + 6 files changed, 644 insertions(+) diff --git a/python/flink_agents/plan/tools/bash/__init__.py b/python/flink_agents/plan/tools/bash/__init__.py new file mode 100644 index 00000000..95362671 --- /dev/null +++ b/python/flink_agents/plan/tools/bash/__init__.py @@ -0,0 +1,16 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# limitations under the License. 
+################################################################################ diff --git a/python/flink_agents/plan/tools/bash/bash_tool.py b/python/flink_agents/plan/tools/bash/bash_tool.py new file mode 100644 index 00000000..d579c5ac --- /dev/null +++ b/python/flink_agents/plan/tools/bash/bash_tool.py @@ -0,0 +1,144 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# limitations under the License. +################################################################################ +"""Standalone bash execution tool. + +``BashTool`` is a general-purpose shell tool — it has no knowledge of skills. +The framework (e.g. ``chat_model_action``) injects ``allowed_commands`` and +``allowed_script_dirs`` at call time; the model only sees ``command``, +``timeout`` and ``cwd``. 
+""" + +from __future__ import annotations + +import logging +import subprocess +from typing import Any, List + +from pydantic import BaseModel, Field +from typing_extensions import override + +from flink_agents.api.tools.tool import Tool, ToolMetadata, ToolType +from flink_agents.plan.tools.bash.bash_validator import ( + is_under_allowed_dirs, + validate_command, +) + +logger = logging.getLogger(__name__) + + +class BashArgs(BaseModel): + """Arguments for BashTool that are visible to the LLM.""" + + command: str = Field( + ..., + description="The shell command to execute.", + ) + timeout: int = Field( + default=60, + description="Timeout in seconds. Defaults to 60.", + ) + cwd: str | None = Field( + default=None, + description=( + "The working directory to run the command in. Defaults to the " + "current directory. Use this instead of `cd` commands." + ), + ) + + +class BashTool(Tool): + """Standalone bash execution tool. + + Safety: + - The first token of each sub-command must be in ``allowed_commands``, or + resolve to a file under one of ``allowed_script_dirs``. + - ``allowed_commands`` and ``allowed_script_dirs`` are injected at call + time by the framework (not visible to the LLM through ``args_schema``). + """ + + metadata: ToolMetadata = Field(exclude=True) + + def __init__(self, **kwargs: Any) -> None: + """Initialize the tool.""" + super().__init__( + metadata=ToolMetadata( + name="bash", + description=( + "Execute a shell command. Only commands on the allowed " + "list or scripts under the allowed directories may run." + ), + args_schema=BashArgs, + ), + **kwargs, + ) + + @classmethod + @override + def tool_type(cls) -> ToolType: + """Return tool type of class.""" + return ToolType.FUNCTION + + @override + def call(self, *args: Any, **kwargs: Any) -> str: + """Execute the command after validation. + + Accepts ``command``, ``timeout`` and ``cwd`` from the LLM, plus + framework-injected ``allowed_commands`` and ``allowed_script_dirs``. 
+ """ + allowed_commands: List[str] = kwargs.pop("allowed_commands", None) or [] + allowed_script_dirs: List[str] = kwargs.pop("allowed_script_dirs", None) or [] + + if args: + parsed_args = BashArgs(command=args[0], **kwargs) + else: + parsed_args = BashArgs(**kwargs) + + command = parsed_args.command + timeout = parsed_args.timeout + cwd = parsed_args.cwd + + if cwd is not None and not is_under_allowed_dirs(cwd, allowed_script_dirs): + return ( + f"Command rejected: cwd '{cwd}' is not under any allowed script dir. " + f"Allowed script dirs: {sorted(allowed_script_dirs)}." + ) + + error = validate_command(command, allowed_commands, allowed_script_dirs, cwd) + if error is not None: + return f"Command rejected: {error}" + + logger.debug( + f"Executing bash command: {command} (timeout={timeout}s, cwd={cwd})" + ) + try: + result = subprocess.run( + command, + shell=True, + capture_output=True, + text=True, + timeout=timeout, + cwd=cwd, + executable="/bin/bash", + ) + if result.returncode == 0: + stdout = result.stdout.strip() + return stdout if stdout else "Success" + return f"Error (exit code {result.returncode}): {result.stderr.strip()}" + except subprocess.TimeoutExpired: + return f"Error: Command timed out after {timeout} seconds" + except Exception as e: + return f"Error: {e!s}" diff --git a/python/flink_agents/plan/tools/bash/bash_validator.py b/python/flink_agents/plan/tools/bash/bash_validator.py new file mode 100644 index 00000000..c53f41b8 --- /dev/null +++ b/python/flink_agents/plan/tools/bash/bash_validator.py @@ -0,0 +1,184 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. 
The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# limitations under the License. +################################################################################ +"""AST-based bash command validation using tree-sitter-bash. + +Parses the command with tree-sitter-bash and walks the AST. Any named node +whose type is not in the allowed set (e.g. ``command_substitution``, +``process_substitution``, ``subshell``, ``for_statement``) causes the whole +command to be rejected. Every ``command`` node's name is checked against the +``allowed_commands`` allowlist or resolved under ``allowed_script_dirs``. + +This lets the tool accept natural shell constructs like pipes, ``&&`` / ``||`` +chains and simple redirections while blocking common injection vectors (``$()``, +backticks, heredoc bodies containing substitutions, control flow, etc.). +""" + +from __future__ import annotations + +from functools import lru_cache +from pathlib import Path +from typing import TYPE_CHECKING, List + +if TYPE_CHECKING: + from tree_sitter import Node, Parser + + +# Named AST node types we accept. Anything named but missing is treated as a +# potentially dangerous shell construct and rejected. Unnamed nodes (literal +# punctuation like ``|``, ``&&``, ``(``, ``$(``) are always allowed — they're +# just syntax tokens, not semantic structures. 
+_ALLOWED_NAMED = frozenset( + { + "program", + "command", + "command_name", + # `export VAR=...`, `readonly`, `declare`, `local`, `typeset` + "declaration_command", + "pipeline", + "list", + "redirected_statement", + "file_redirect", + "file_descriptor", + "variable_assignment", + "variable_name", + "special_variable_name", # $@ $? $* $# + "word", + "string", + "string_content", + "raw_string", + "ansi_c_string", + "translated_string", + "concatenation", + "number", + "simple_expansion", # $VAR + "expansion", # ${VAR} + "arithmetic_expansion", # $((...)) + "binary_expression", + "unary_expression", + "parenthesized_expression", + "array", + } +) + + +@lru_cache(maxsize=1) +def _get_parser() -> Parser: + """Return a cached tree-sitter parser configured with the bash grammar.""" + import tree_sitter_bash + from tree_sitter import Language, Parser + + return Parser(Language(tree_sitter_bash.language())) + + +def validate_command( + command: str, + allowed_commands: List[str], + allowed_script_dirs: List[str], + cwd: str | None = None, +) -> str | None: + """Validate a bash command. + + Returns ``None`` if the command is allowed, or an error string otherwise. + """ + if not command.strip(): + return "Empty command." + + try: + tree = _get_parser().parse(command.encode("utf-8")) + except Exception as exc: + return f"Failed to parse command: {exc}" + + root = tree.root_node + if root.has_error: + return "Command has syntax errors." + if not root.children: + return "Empty command." 
+ + return _walk(root, allowed_commands, allowed_script_dirs, cwd) + + +def _walk( + node: Node, + allowed_commands: List[str], + allowed_script_dirs: List[str], + cwd: str | None, +) -> str | None: + if node.is_named and node.type not in _ALLOWED_NAMED: + snippet = node.text.decode("utf-8", errors="replace")[:80] + return f"Disallowed shell construct '{node.type}' in: {snippet!r}" + if node.type == "command": + err = _validate_command_node(node, allowed_commands, allowed_script_dirs, cwd) + if err is not None: + return err + for child in node.children: + err = _walk(child, allowed_commands, allowed_script_dirs, cwd) + if err is not None: + return err + return None + + +def _validate_command_node( + node: Node, + allowed_commands: List[str], + allowed_script_dirs: List[str], + cwd: str | None, +) -> str | None: + name_node = node.child_by_field_name("name") + if name_node is None: + # Commands without a resolvable name (edge case, e.g. bare + # variable-assignment parsed as `command`) — nothing to validate. + return None + executable = name_node.text.decode("utf-8", errors="replace") + if executable in allowed_commands: + return None + if is_under_allowed_dirs(executable, allowed_script_dirs, cwd): + return None + return ( + f"Command '{executable}' is not allowed. " + f"Allowed commands: {sorted(allowed_commands)}. " + f"Allowed script dirs: {sorted(allowed_script_dirs)}." + ) + + +def is_under_allowed_dirs( + path_str: str, + allowed_dirs: List[str], + cwd: str | None = None, +) -> bool: + """Return True if ``path_str`` resolves to a path under any allowed dir. + + When ``cwd`` is given, relative ``path_str`` is resolved against ``cwd``. 
+ """ + try: + base = Path(path_str) + if not base.is_absolute() and cwd is not None: + base = Path(cwd) / base + path = base.resolve() + except (OSError, ValueError): + return False + for allowed in allowed_dirs: + try: + allowed_root = Path(allowed).resolve() + except (OSError, ValueError): + continue + try: + path.relative_to(allowed_root) + except ValueError: + continue + else: + return True + return False diff --git a/python/flink_agents/plan/tools/bash/tests/__init__.py b/python/flink_agents/plan/tools/bash/tests/__init__.py new file mode 100644 index 00000000..95362671 --- /dev/null +++ b/python/flink_agents/plan/tools/bash/tests/__init__.py @@ -0,0 +1,16 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# limitations under the License. 
+################################################################################ diff --git a/python/flink_agents/plan/tools/bash/tests/test_bash_tool.py b/python/flink_agents/plan/tools/bash/tests/test_bash_tool.py new file mode 100644 index 00000000..e90f67c9 --- /dev/null +++ b/python/flink_agents/plan/tools/bash/tests/test_bash_tool.py @@ -0,0 +1,282 @@ +################################################################################ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# limitations under the License. 
+################################################################################ +"""Tests for BashTool and its AST-based command validation.""" + +from pathlib import Path + +import pytest + +from flink_agents.plan.tools.bash.bash_tool import BashTool +from flink_agents.plan.tools.bash.bash_validator import validate_command + + +@pytest.fixture(scope="module") +def script_dir(tmp_path_factory: pytest.TempPathFactory) -> Path: + """An allowed-script directory containing a small runnable script.""" + root = tmp_path_factory.mktemp("allowed_scripts") + scripts = root / "scripts" + scripts.mkdir() + script = scripts / "hello.py" + script.write_text("print('ok')\n") + return root.resolve() + + +class TestValidateCommand: + """Tests for the tree-sitter-backed `validate_command` helper.""" + + # -- allowlist basics -------------------------------------------------- + + def test_allow_whitelisted_command(self) -> None: + assert validate_command("gh issue list", ["gh", "git"], []) is None + + def test_reject_command_not_in_whitelist(self) -> None: + error = validate_command("rm -rf /", ["gh", "git"], []) + assert error is not None + assert "not allowed" in error + + def test_empty_allowed_commands_rejects(self) -> None: + assert validate_command("gh issue list", [], []) is not None + + def test_reject_empty_command(self) -> None: + assert validate_command("", ["gh"], []) is not None + assert validate_command(" ", ["gh"], []) is not None + + # -- script dirs ------------------------------------------------------- + + def test_allow_script_under_allowed_dir(self, script_dir: Path) -> None: + script = script_dir / "scripts" / "hello.py" + assert validate_command(f"{script} --flag", [], [str(script_dir)]) is None + + def test_reject_script_outside_allowed_dirs(self, script_dir: Path) -> None: + assert validate_command("/tmp/unknown.py", [], [str(script_dir)]) is not None + + # -- pipes, &&, ||, ; — each sub-command validated -------------------- + + def 
test_allow_pipe_all_whitelisted(self) -> None: + assert validate_command("gh issue list | git log", ["gh", "git"], []) is None + + def test_reject_pipe_one_not_whitelisted(self) -> None: + error = validate_command("gh issue list | grep bug", ["gh", "git"], []) + assert error is not None + assert "grep" in error + + def test_reject_and_one_not_whitelisted(self) -> None: + error = validate_command("gh issue list && rm -rf /", ["gh", "git"], []) + assert error is not None + assert "rm" in error + + def test_reject_or_one_not_whitelisted(self) -> None: + error = validate_command("gh issue list || rm -rf /", ["gh", "git"], []) + assert error is not None + assert "rm" in error + + def test_reject_semicolon_one_not_whitelisted(self) -> None: + assert ( + validate_command("gh issue list; rm -rf /", ["gh", "git"], []) is not None + ) + + def test_reject_newline_one_not_whitelisted(self) -> None: + assert ( + validate_command("gh issue list\nrm -rf /", ["gh", "git"], []) is not None + ) + + # -- quoting and escaping --------------------------------------------- + + def test_allow_operators_inside_single_quotes(self) -> None: + assert ( + validate_command("gh issue list --label 'bug|feature'", ["gh"], []) is None + ) + + def test_allow_operators_inside_double_quotes(self) -> None: + assert validate_command('gh issue list --label "a&&b"', ["gh"], []) is None + + def test_allow_escaped_operator(self) -> None: + assert validate_command("gh issue list --query a\\|b", ["gh"], []) is None + + # -- env var prefix inside a command ---------------------------------- + + def test_allow_env_prefix(self) -> None: + assert validate_command("FOO=bar echo hi", ["echo"], []) is None + + def test_reject_env_prefix_command_not_whitelisted(self) -> None: + assert validate_command("FOO=bar rm -rf /", ["echo"], []) is not None + + # -- injection vectors: MUST be rejected ------------------------------ + + def test_reject_dollar_paren_substitution(self) -> None: + error = validate_command("echo 
$(rm /)", ["echo"], []) + assert error is not None + assert "command_substitution" in error + + def test_reject_backtick_substitution(self) -> None: + error = validate_command("echo `rm /`", ["echo"], []) + assert error is not None + assert "command_substitution" in error + + def test_reject_substitution_in_double_quotes(self) -> None: + error = validate_command('echo "$(rm /)"', ["echo"], []) + assert error is not None + assert "command_substitution" in error + + def test_reject_substitution_via_default_expansion(self) -> None: + error = validate_command("echo ${FOO:-$(rm /)}", ["echo"], []) + assert error is not None + assert "command_substitution" in error + + def test_reject_process_substitution(self) -> None: + error = validate_command("cat <(true)", ["cat"], []) + assert error is not None + assert "process_substitution" in error + + def test_reject_subshell(self) -> None: + error = validate_command("(rm /)", ["rm"], []) + assert error is not None + assert "subshell" in error + + def test_reject_for_loop(self) -> None: + error = validate_command("for f in *; do echo $f; done", ["echo"], []) + assert error is not None + assert "for_statement" in error + + def test_reject_while_loop(self) -> None: + error = validate_command("while true; do echo x; done", ["echo"], []) + assert error is not None + + def test_reject_if_statement(self) -> None: + error = validate_command("if true; then echo hi; fi", ["echo"], []) + assert error is not None + + def test_reject_function_definition(self) -> None: + error = validate_command("f() { echo hi; }", ["echo"], []) + assert error is not None + + def test_reject_heredoc(self) -> None: + error = validate_command("cat <<EOF\n$(rm /)\nEOF", ["cat"], []) + assert error is not None + + # -- still allow common safe features --------------------------------- + + def test_allow_simple_expansion(self) -> None: + assert validate_command("echo $USER", ["echo"], []) is None + + def test_allow_brace_expansion(self) -> None: + assert 
validate_command("echo {a,b,c}", ["echo"], []) is None + + def test_allow_redirect_to_file(self) -> None: + assert validate_command("echo hi > /tmp/out", ["echo"], []) is None + + def test_allow_stderr_redirect(self) -> None: + assert validate_command("echo hi 2>&1", ["echo"], []) is None + + +class TestBashTool: + """Integration tests for BashTool.call.""" + + @pytest.fixture + def tool(self) -> BashTool: + return BashTool() + + def test_execute_whitelisted_command(self, tool: BashTool) -> None: + result = tool.call( + command="echo hello", + timeout=10, + allowed_commands=["echo"], + allowed_script_dirs=[], + ) + assert result == "hello" + + def test_execute_piped_whitelisted_commands(self, tool: BashTool) -> None: + result = tool.call( + command='echo "(2 ^ 3)" | bc', + timeout=10, + allowed_commands=["echo", "bc"], + allowed_script_dirs=[], + ) + assert result == "8" + + def test_reject_non_whitelisted_command(self, tool: BashTool) -> None: + result = tool.call( + command="rm -rf /", + timeout=10, + allowed_commands=["echo"], + allowed_script_dirs=[], + ) + assert "Command rejected" in result + + def test_reject_piped_non_whitelisted(self, tool: BashTool) -> None: + result = tool.call( + command="echo hello | cat", + timeout=10, + allowed_commands=["echo"], + allowed_script_dirs=[], + ) + assert "Command rejected" in result + + def test_reject_command_substitution(self, tool: BashTool) -> None: + """$(...) 
must be rejected even when outer command is allowed.""" + result = tool.call( + command="echo $(touch /tmp/injected)", + timeout=10, + allowed_commands=["echo"], + allowed_script_dirs=[], + ) + assert "Command rejected" in result + + def test_no_allowed_commands_rejects_everything(self, tool: BashTool) -> None: + result = tool.call(command="echo hello", timeout=10) + assert "Command rejected" in result + + # -- cwd parameter ----------------------------------------------------- + + def test_cwd_enables_relative_script_path( + self, tool: BashTool, script_dir: Path + ) -> None: + """With cwd set to an allowed dir, a relative script path resolves under it.""" + result = tool.call( + command="python3 scripts/hello.py", + timeout=10, + cwd=str(script_dir), + allowed_commands=[], + allowed_script_dirs=[str(script_dir)], + ) + assert "Command rejected" not in result + assert "ok" in result + + def test_cwd_outside_allowed_dirs_rejected( + self, tool: BashTool, script_dir: Path + ) -> None: + result = tool.call( + command="echo hello", + timeout=10, + cwd="/tmp", + allowed_commands=["echo"], + allowed_script_dirs=[str(script_dir)], + ) + assert "Command rejected" in result + assert "cwd" in result + + def test_relative_path_without_cwd_rejected( + self, tool: BashTool, script_dir: Path + ) -> None: + """Bare relative script path (no cwd) is rejected.""" + result = tool.call( + command="python3 scripts/hello.py", + timeout=10, + allowed_commands=[], + allowed_script_dirs=[str(script_dir)], + ) + assert "Command rejected" in result diff --git a/python/pyproject.toml b/python/pyproject.toml index 1c9e9e98..b9f839df 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -49,6 +49,8 @@ dependencies = [ "mcp>=1.8.0", "setuptools>=75.3,<82", "find_libpython", + "tree-sitter>=0.24,<0.26", + "tree-sitter-bash>=0.23,<0.26", #TODO: Seperate integration dependencies from project "ollama==0.6.1", "dashscope~=1.24.2",
