amoghrajesh commented on code in PR #45924: URL: https://github.com/apache/airflow/pull/45924#discussion_r1925216358
########## airflow/api_fastapi/execution_api/datamodels/taskinstance.py: ########## @@ -54,12 +64,52 @@ class TIEnterRunningPayload(BaseModel): class TITerminalStatePayload(BaseModel): """Schema for updating TaskInstance to a terminal state (e.g., SUCCESS or FAILED).""" - state: TerminalTIState + state: Literal[ + TerminalTIState.FAILED, + TerminalTIState.SKIPPED, + TerminalTIState.REMOVED, + TerminalTIState.FAIL_WITHOUT_RETRY, + ] end_date: UtcDateTime """When the task completed executing""" +class TISuccessStatePayload(BaseModel): + """Schema for updating TaskInstance to success state.""" + + state: Annotated[ + Literal[TerminalTIState.SUCCESS], + # Specify a default in the schema, but not in code, so Pydantic marks it as required. + WithJsonSchema( + { + "type": "string", + "enum": [TerminalTIState.SUCCESS], + "default": TerminalTIState.SUCCESS, + } + ), + ] + + end_date: UtcDateTime + """When the task completed executing""" + + task_outlets: Annotated[list[AssetNameAndUri], Field(default_factory=list)] + outlet_events: Annotated[list[Any], Field(default_factory=list)] + asset_type: str | None = None + + @root_validator(pre=True) + def parse_json_fields(cls, values): + import json + + if "task_outlets" in values and isinstance(values["task_outlets"], str): + values["task_outlets"] = json.loads(values["task_outlets"]) + + if "outlet_events" in values and isinstance(values["outlet_events"], str): + values["outlet_events"] = json.loads(values["outlet_events"]) + + return values Review Comment: This is done so that the client and server can communicate in json string format. 
########## tests/api_fastapi/execution_api/routes/test_task_instances.py: ########## @@ -322,6 +415,7 @@ def test_ti_update_state_database_error(self, client, session, create_task_insta "airflow.api_fastapi.common.db.common.Session.execute", side_effect=[ mock.Mock(one=lambda: ("running", 1, 0)), # First call returns "running" + mock.Mock(one=lambda: ("running", 1, 0)), # Second call returns "running" Review Comment: The success state adds another query to the "ti" table; hence the extra mocked result is needed. ########## airflow/api_fastapi/execution_api/routes/task_instances.py: ########## @@ -440,3 +462,103 @@ def _is_eligible_to_retry(state: str, try_number: int, max_tries: int) -> bool: # max_tries is initialised with the retries defined at task level, we do not need to explicitly ask for # retries from the task SDK now, we can handle using max_tries return max_tries != 0 and try_number <= max_tries + + +def register_asset_changes(task_instance, task_outlets, outlet_events, asset_type, session): + # One task only triggers one asset event for each asset with the same extra. + # This mapping from tuple[asset uri, extra] to a set of alias names is used to find whether + # there are assets with the same uri but different extra for which we need to emit more than one asset event. 
+ asset_alias_names: dict[tuple[AssetUniqueKey, frozenset], set[str]] = defaultdict(set) + asset_name_refs: set[str] = set() + asset_uri_refs: set[str] = set() + + for obj in task_outlets: + # Lineage can have other types of objects besides assets + if asset_type == "Asset": + asset_manager.register_asset_change( + task_instance=task_instance, + asset=Asset(name=obj.name, uri=obj.uri), + extra=outlet_events, + session=session, + ) + elif asset_type == "AssetNameRef": + asset_name_refs.add(obj.name) + elif asset_type == "AssetUriRef": + asset_uri_refs.add(obj.uri) + + if asset_type == "AssetAlias": + # deserialize to the expected type + outlet_events = list( + map( + lambda event: {**event, "dest_asset_key": AssetUniqueKey(**event["dest_asset_key"])}, + outlet_events, + ) + ) Review Comment: Converting it to `AssetUniqueKey` format so that the access below can remain simple and in line with the legacy code. Otherwise, Line 502 will complain that we cannot use a mutable key for a dictionary. ########## task_sdk/src/airflow/sdk/execution_time/task_runner.py: ########## @@ -479,12 +481,40 @@ def run(ti: RuntimeTaskInstance, log: Logger): _push_xcom_if_needed(result, ti) + task_outlets = [] + outlet_events = [] + events = context["outlet_events"] + asset_type = "" + + for obj in ti.task.outlets or []: + # Lineage can have other types of objects besides assets + if isinstance(obj, Asset): + task_outlets.append(AssetNameAndUri(name=obj.name, uri=obj.uri)) + outlet_events.append(attrs.asdict(events[obj])) # type: ignore + elif isinstance(obj, AssetNameRef): + task_outlets.append(AssetNameAndUri(name=obj.name)) + # send all as we do not know how to filter here yet + outlet_events.append(attrs.asdict(events)) # type: ignore + elif isinstance(obj, AssetUriRef): + task_outlets.append(AssetNameAndUri(uri=obj.uri)) + # send all as we do not know how to filter here yet + outlet_events.append(attrs.asdict(events)) # type: ignore Review Comment: We can probably use the asset client API to 
get AssetModel here and send it -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: commits-unsubscribe@airflow.apache.org For queries about this service, please contact Infrastructure at: users@infra.apache.org