ashb commented on code in PR #53082:
URL: https://github.com/apache/airflow/pull/53082#discussion_r2204748078
##########
task-sdk/tests/task_sdk/execution_time/test_supervisor.py:
##########
@@ -147,6 +147,66 @@ def client_with_ti_start(make_ti_context):
return client
+@pytest.mark.usefixtures("disable_capturing")
+class TestSupervisor:
+ @pytest.mark.parametrize(
+ "server, dry_run, error_pattern",
+ [
+ ("/execution/", False, "Invalid execution API server URL"),
+ ("", False, "Invalid execution API server URL"),
+ ("http://localhost:8080", True, "Can only specify one of"),
+ (None, True, None),
+ ("http://localhost:8080/execution/", False, None),
+ ("https://localhost:8080/execution/", False, None),
+ ],
+ )
+ def test_supervise(
+ self,
+ # mock_mask_secret,
+ patched_secrets_masker,
+ server,
+ dry_run,
+ error_pattern,
+ test_dags_dir,
+ client_with_ti_start,
+ ):
+ """
+ Test that the supervisor validates server URL and dry_run parameter combinations correctly.
+ """
+ ti = TaskInstance(
+ id=uuid7(),
+ task_id="async",
+ dag_id="super_basic_deferred_run",
+ run_id="d",
+ try_number=1,
+ dag_version_id=uuid7(),
+ )
+
+ bundle_info = BundleInfo(name="my-bundle", version=None)
+
+ with patch.dict(os.environ, local_dag_bundle_cfg(test_dags_dir, bundle_info.name)):
+ if error_pattern:
+ with pytest.raises(ValueError, match=error_pattern):
+ supervise(
+ ti=ti,
+ dag_rel_path="super_basic_deferred_run.py",
+ token="",
+ bundle_info=bundle_info,
+ dry_run=dry_run,
+ server=server,
+ )
+ else:
+ supervise(
+ ti=ti,
+ dag_rel_path="super_basic_deferred_run.py",
+ token="",
+ bundle_info=bundle_info,
+ dry_run=dry_run,
+ server=server,
+ client=client_with_ti_start,
+ )
Review Comment:
A better pattern than this is to use `contextlib.nullcontext` for the non-raising cases, as described in the pytest docs:
https://docs.pytest.org/en/stable/example/parametrize.html#parametrizing-conditional-raising
```python
from contextlib import nullcontext
import pytest
@pytest.mark.parametrize(
"example_input,expectation",
[
(3, nullcontext(2)),
(2, nullcontext(3)),
(1, nullcontext(6)),
(0, pytest.raises(ZeroDivisionError)),
],
)
def test_division(example_input, expectation):
"""Test how much I know division."""
with expectation as e:
assert (6 / example_input) == e
```
##########
task-sdk/src/airflow/sdk/execution_time/supervisor.py:
##########
@@ -1633,12 +1633,18 @@ def supervise(
:param subprocess_logs_to_stdout: Should task logs also be sent to stdout via the main logger.
:param client: Optional preconfigured client for communication with the server (Mostly for tests).
:return: Exit code of the process.
+ :raises ValueError: If server URL is empty or invalid.
"""
# One or the other
from airflow.sdk.execution_time.secrets_masker import reset_secrets_masker
- if not client and ((not server) ^ dry_run):
- raise ValueError(f"Can only specify one of {server=} or {dry_run=}")
+ if not client:
+ if dry_run and server:
+ raise ValueError(f"Can only specify one of {server=} or {dry_run=}")
+ if not dry_run and (not server or not server.startswith(("http://", "https://"))):
Review Comment:
It's a URL, not a host -- it should have a protocol, and it's no bad thing to require users to specify it precisely.
I would suggest parsing it with `urllib.parse` and then validating its components (such as the scheme), as that will catch a few more problems than a simple prefix check.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]