amoghrajesh commented on code in PR #45924:
URL: https://github.com/apache/airflow/pull/45924#discussion_r1925216358


##########
airflow/api_fastapi/execution_api/datamodels/taskinstance.py:
##########
@@ -54,12 +64,52 @@ class TIEnterRunningPayload(BaseModel):
 class TITerminalStatePayload(BaseModel):
     """Schema for updating TaskInstance to a terminal state (e.g., SUCCESS or 
FAILED)."""
 
-    state: TerminalTIState
+    state: Literal[
+        TerminalTIState.FAILED,
+        TerminalTIState.SKIPPED,
+        TerminalTIState.REMOVED,
+        TerminalTIState.FAIL_WITHOUT_RETRY,
+    ]
 
     end_date: UtcDateTime
     """When the task completed executing"""
 
 
+class TISuccessStatePayload(BaseModel):
+    """Schema for updating TaskInstance to success state."""
+
+    state: Annotated[
+        Literal[TerminalTIState.SUCCESS],
+        # Specify a default in the schema, but not in code, so Pydantic marks 
it as required.
+        WithJsonSchema(
+            {
+                "type": "string",
+                "enum": [TerminalTIState.SUCCESS],
+                "default": TerminalTIState.SUCCESS,
+            }
+        ),
+    ]
+
+    end_date: UtcDateTime
+    """When the task completed executing"""
+
+    task_outlets: Annotated[list[AssetNameAndUri], Field(default_factory=list)]
+    outlet_events: Annotated[list[Any], Field(default_factory=list)]
+    asset_type: str | None = None
+
+    @root_validator(pre=True)
+    def parse_json_fields(cls, values):
+        import json
+
+        if "task_outlets" in values and isinstance(values["task_outlets"], 
str):
+            values["task_outlets"] = json.loads(values["task_outlets"])
+
+        if "outlet_events" in values and isinstance(values["outlet_events"], 
str):
+            values["outlet_events"] = json.loads(values["outlet_events"])
+
+        return values

Review Comment:
   This is done so that the client and server can communicate in JSON string format.



##########
tests/api_fastapi/execution_api/routes/test_task_instances.py:
##########
@@ -322,6 +415,7 @@ def test_ti_update_state_database_error(self, client, 
session, create_task_insta
             "airflow.api_fastapi.common.db.common.Session.execute",
             side_effect=[
                 mock.Mock(one=lambda: ("running", 1, 0)),  # First call 
returns "queued"
+                mock.Mock(one=lambda: ("running", 1, 0)),  # Second call 
returns "queued"

Review Comment:
   The success state adds another query to the "ti" table, hence the extra mocked call is needed.



##########
airflow/api_fastapi/execution_api/routes/task_instances.py:
##########
@@ -440,3 +462,103 @@ def _is_eligible_to_retry(state: str, try_number: int, 
max_tries: int) -> bool:
     # max_tries is initialised with the retries defined at task level, we do 
not need to explicitly ask for
     # retries from the task SDK now, we can handle using max_tries
     return max_tries != 0 and try_number <= max_tries
+
+
+def register_asset_changes(task_instance, task_outlets, outlet_events, 
asset_type, session):
+    # One task only triggers one asset event for each asset with the same 
extra.
+    # This tuple[asset uri, extra] to sets alias names mapping is used to find 
whether
+    # there're assets with same uri but different extra that we need to emit 
more than one asset events.
+    asset_alias_names: dict[tuple[AssetUniqueKey, frozenset], set[str]] = 
defaultdict(set)
+    asset_name_refs: set[str] = set()
+    asset_uri_refs: set[str] = set()
+
+    for obj in task_outlets:
+        # Lineage can have other types of objects besides assets
+        if asset_type == "Asset":
+            asset_manager.register_asset_change(
+                task_instance=task_instance,
+                asset=Asset(name=obj.name, uri=obj.uri),
+                extra=outlet_events,
+                session=session,
+            )
+        elif asset_type == "AssetNameRef":
+            asset_name_refs.add(obj.name)
+        elif asset_type == "AssetUriRef":
+            asset_uri_refs.add(obj.uri)
+
+    if asset_type == "AssetAlias":
+        # deserialize to the expected type
+        outlet_events = list(
+            map(
+                lambda event: {**event, "dest_asset_key": 
AssetUniqueKey(**event["dest_asset_key"])},
+                outlet_events,
+            )
+        )

Review Comment:
   Converting it to the `AssetUniqueKey` format so that the access below can remain simple and in line with the legacy code. Otherwise, line 502 will complain that we cannot use a mutable key in a dictionary.



##########
task_sdk/src/airflow/sdk/execution_time/task_runner.py:
##########
@@ -479,12 +481,40 @@ def run(ti: RuntimeTaskInstance, log: Logger):
 
         _push_xcom_if_needed(result, ti)
 
+        task_outlets = []
+        outlet_events = []
+        events = context["outlet_events"]
+        asset_type = ""
+
+        for obj in ti.task.outlets or []:
+            # Lineage can have other types of objects besides assets
+            if isinstance(obj, Asset):
+                task_outlets.append(AssetNameAndUri(name=obj.name, 
uri=obj.uri))
+                outlet_events.append(attrs.asdict(events[obj]))  # type: ignore
+            elif isinstance(obj, AssetNameRef):
+                task_outlets.append(AssetNameAndUri(name=obj.name))
+                # send all as we do not know how to filter here yet
+                outlet_events.append(attrs.asdict(events))  # type: ignore
+            elif isinstance(obj, AssetUriRef):
+                task_outlets.append(AssetNameAndUri(uri=obj.uri))
+                # send all as we do not know how to filter here yet
+                outlet_events.append(attrs.asdict(events))  # type: ignore

Review Comment:
   We can probably use the asset client API to fetch the `AssetModel` here and send it.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscr...@airflow.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

Reply via email to