jroachgolf84 commented on code in PR #67839:
URL: https://github.com/apache/airflow/pull/67839#discussion_r3439952528


##########
airflow-core/tests/unit/jobs/test_triggerer_job.py:
##########
@@ -561,6 +571,255 @@ def test_create_workload_uses_supervisor_id_without_job(jobless_supervisor, mock
     assert factory.log_path == f"/logs/ti.trigger.{jobless_supervisor.id}.log"
 
 
+def test_create_workload_sets_watched_assets_for_asset_only_trigger(jobless_supervisor, mocker):
+    """_create_workload() should populate watched_assets when trigger.task_instance is None and assets exist."""
+    asset1 = mocker.Mock(spec=Asset)
+    asset1.name = "my_asset"
+    asset1.uri = "s3://bucket/key"
+
+    asset2 = mocker.Mock(spec=Asset)
+    asset2.name = "other_asset"
+    asset2.uri = "gs://bucket/path"
+
+    trigger = mocker.Mock(spec=BaseEventTrigger)
+    trigger.id = 42
+    trigger.classpath = "some.path.Trigger"
+    trigger.encrypted_kwargs = "encrypted"
+    trigger.task_instance = None  # Not tied to a Task (similar to a BaseEventTrigger)
+    trigger.assets = [asset1, asset2]
+
+    workload = jobless_supervisor._create_workload(
+        trigger=trigger,
+        dag_bag=mocker.Mock(),
+        render_log_fname=mocker.Mock(),
+        session=mocker.Mock(),
+    )
+
+    assert workload is not None
+    assert workload.watched_assets == {"my_asset": "s3://bucket/key", "other_asset": "gs://bucket/path"}
+
+
+def test_create_workload_watched_assets_none_when_no_assets(jobless_supervisor, mocker):
+    """_create_workload() should set watched_assets=None when trigger.task_instance is None and assets is empty."""
+    trigger = mocker.Mock(spec=BaseEventTrigger)
+    trigger.id = 43
+    trigger.classpath = "some.path.Trigger"
+    trigger.encrypted_kwargs = "encrypted"
+    trigger.task_instance = None
+    trigger.assets = []  # No Assets are attached to the trigger
+
+    workload = jobless_supervisor._create_workload(
+        trigger=trigger,
+        dag_bag=mocker.Mock(),
+        render_log_fname=mocker.Mock(),
+        session=mocker.Mock(),
+    )
+
+    assert workload is not None
+    assert workload.watched_assets is None
+
+
+def test_run_trigger_workload_includes_watched_assets_field():
+    """RunTrigger workload should accept and store watched_assets."""
+    workload = RunTrigger(
+        id=1,
+        classpath="airflow.triggers.testing.SuccessTrigger",
+        encrypted_kwargs="fake",
+        watched_assets={"asset_a": "s3://a", "asset_b": "gs://b"},
+    )
+    assert workload.watched_assets == {"asset_a": "s3://a", "asset_b": "gs://b"}
+
+
+def test_run_trigger_workload_watched_assets_defaults_to_none():
+    """RunTrigger workload watched_assets should default to None."""
+    workload = RunTrigger(
+        id=1,
+        classpath="airflow.triggers.testing.SuccessTrigger",
+        encrypted_kwargs="fake",
+    )
+    assert workload.watched_assets is None
+
+
+@pytest.fixture
+def make_watcher_trigger():
+    """Factory fixture: call with a list to get a BaseEventTrigger subclass that appends each new instance."""
+
+    def factory(injected_instances):
+        class WatcherTrigger(BaseEventTrigger):
+            def __init__(self, **kwargs):
+                super().__init__(**kwargs)
+                injected_instances.append(self)
+
+            def serialize(self):
+                return (f"{type(self).__module__}.{type(self).__qualname__}", {})
+
+            async def run(self):
+                yield TriggerEvent("done")
+
+        return WatcherTrigger
+
+    return factory
+
+
+@pytest.mark.asyncio
+@patch("airflow.jobs.triggerer_job_runner.TriggerRunner.get_trigger_by_classpath")
+async def test_create_triggers_injects_asset_state_store_for_base_event_trigger(
+    mock_get_classpath, session, make_watcher_trigger
+):
+    """asset_state_store is populated on BaseEventTrigger instances when watched_assets is set."""
+    injected_instances = []
+    mock_get_classpath.return_value = make_watcher_trigger(injected_instances)
+
+    runner = TriggerRunner()
+    runner.to_create.append(
+        workloads.RunTrigger.model_construct(
+            id=10,
+            ti=None,
+            classpath="fake.WatcherTrigger",
+            encrypted_kwargs="{}",
+            watched_assets={"my_asset": "s3://bucket/key"},
+        )
+    )
+
+    await runner.create_triggers()
+
+    # This is only testing that an exception was NOT thrown when creating the Trigger
+    assert 10 in runner.triggers
+
+    assert len(injected_instances) == 1
+    assert injected_instances[0].asset_state_store is not None
+    assert isinstance(injected_instances[0].asset_state_store, AssetStateStoreAccessors)
+
+    runner.triggers[10]["task"].cancel()
+    await runner.cleanup_finished_triggers()
+
+
+@pytest.mark.asyncio
+@patch("airflow.jobs.triggerer_job_runner.TriggerRunner.get_trigger_by_classpath")
+async def test_create_triggers_asset_state_store_none_when_no_watched_assets(
+    mock_get_classpath, session, make_watcher_trigger
+):
+    """asset_state_store stays None when watched_assets is not set on the workload."""
+    injected_instances = []
+    mock_get_classpath.return_value = make_watcher_trigger(injected_instances)
+
+    runner = TriggerRunner()
+    runner.to_create.append(
+        workloads.RunTrigger.model_construct(
+            id=11,
+            ti=None,
+            classpath="fake.WatcherTrigger",
+            encrypted_kwargs="{}",
+            watched_assets=None,
+        )
+    )
+
+    await runner.create_triggers()
+
+    assert len(injected_instances) == 1
+    assert injected_instances[0].asset_state_store is None
+
+    runner.triggers[11]["task"].cancel()
+    await runner.cleanup_finished_triggers()
+
+
+@pytest.mark.asyncio
+@patch("airflow.jobs.triggerer_job_runner.TriggerRunner.get_trigger_by_classpath")
+async def test_create_triggers_skips_asset_state_store_for_non_event_trigger(mock_get_classpath, session):
+    """asset_state_store injection is skipped for plain BaseTrigger (non-BaseEventTrigger) instances."""
+    mock_get_classpath.return_value = SuccessTrigger
+
+    runner = TriggerRunner()
+    runner.to_create.append(
+        workloads.RunTrigger.model_construct(
+            id=12, ti=None, classpath="airflow.triggers.testing.SuccessTrigger", encrypted_kwargs="{}"
+        )
+    )
+
+    await runner.create_triggers()
+
+    assert 12 in runner.triggers
+    assert not hasattr(runner.triggers[12]["task"], "asset_state_store")

Review Comment:
   Resolved in next commit.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to