This is an automated email from the ASF dual-hosted git repository.
shahar1 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 2746dcb1a24 Fix mypy errors for task_instance access in provider triggers (#68685)
2746dcb1a24 is described below
commit 2746dcb1a24287fc052d9c078b4a34bdd7a609ae
Author: Vincent <[email protected]>
AuthorDate: Wed Jun 17 17:23:11 2026 -0400
Fix mypy errors for task_instance access in provider triggers (#68685)
---
.../airflow/providers/amazon/aws/triggers/emr.py | 19 ++--
.../providers/cncf/kubernetes/triggers/pod.py | 19 ++--
.../providers/google/cloud/triggers/bigquery.py | 40 ++++---
.../providers/google/cloud/triggers/dataproc.py | 120 ++++++++++++---------
4 files changed, 114 insertions(+), 84 deletions(-)
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/triggers/emr.py b/providers/amazon/src/airflow/providers/amazon/aws/triggers/emr.py
index c98641b0c6b..48ddb4b0c31 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/triggers/emr.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/triggers/emr.py
@@ -423,19 +423,22 @@ class EmrServerlessStartJobTrigger(AwsBaseWaiterTrigger):
"""Get the task instance for the current trigger (Airflow 2.x
compatibility)."""
from sqlalchemy import select
+ ti = self.task_instance
+ if ti is None:
+ raise RuntimeError("task_instance is not set on the trigger")
query = select(TaskInstance).where(
- TaskInstance.dag_id == self.task_instance.dag_id,
- TaskInstance.task_id == self.task_instance.task_id,
- TaskInstance.run_id == self.task_instance.run_id,
- TaskInstance.map_index == self.task_instance.map_index,
+ TaskInstance.dag_id == ti.dag_id,
+ TaskInstance.task_id == ti.task_id,
+ TaskInstance.run_id == ti.run_id,
+ TaskInstance.map_index == ti.map_index,
)
task_instance = session.scalars(query).one_or_none()
if task_instance is None:
raise ValueError(
- f"TaskInstance with dag_id: {self.task_instance.dag_id}, "
- f"task_id: {self.task_instance.task_id}, "
- f"run_id: {self.task_instance.run_id} and "
- f"map_index: {self.task_instance.map_index} is not found"
+ f"TaskInstance with dag_id: {ti.dag_id}, "
+ f"task_id: {ti.task_id}, "
+ f"run_id: {ti.run_id} and "
+ f"map_index: {ti.map_index} is not found"
)
return task_instance
diff --git a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/triggers/pod.py b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/triggers/pod.py
index 0599b0519a6..4ee0b90c51f 100644
--- a/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/triggers/pod.py
+++ b/providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/triggers/pod.py
@@ -440,21 +440,24 @@ class KubernetesPodTrigger(BaseTrigger):
@provide_session
def get_task_instance(self, *, session: Session) -> TaskInstance:
"""Get the task instance for this trigger from the database
(Airflow 2.x only)."""
+ ti = self.task_instance
+ if ti is None:
+ raise RuntimeError("task_instance is not set on the trigger")
task_instance = session.scalar(
select(TaskInstance).where(
- TaskInstance.dag_id == self.task_instance.dag_id,
- TaskInstance.task_id == self.task_instance.task_id,
- TaskInstance.run_id == self.task_instance.run_id,
- TaskInstance.map_index == self.task_instance.map_index,
+ TaskInstance.dag_id == ti.dag_id,
+ TaskInstance.task_id == ti.task_id,
+ TaskInstance.run_id == ti.run_id,
+ TaskInstance.map_index == ti.map_index,
)
)
if task_instance is None:
raise AirflowException(
"TaskInstance with dag_id: %s, task_id: %s, run_id: %s and
map_index: %s is not found",
- self.task_instance.dag_id,
- self.task_instance.task_id,
- self.task_instance.run_id,
- self.task_instance.map_index,
+ ti.dag_id,
+ ti.task_id,
+ ti.run_id,
+ ti.map_index,
)
return task_instance
diff --git a/providers/google/src/airflow/providers/google/cloud/triggers/bigquery.py b/providers/google/src/airflow/providers/google/cloud/triggers/bigquery.py
index 659cc9f8bf3..e11059dfbb8 100644
--- a/providers/google/src/airflow/providers/google/cloud/triggers/bigquery.py
+++ b/providers/google/src/airflow/providers/google/cloud/triggers/bigquery.py
@@ -120,42 +120,48 @@ class BigQueryInsertJobTrigger(BaseTrigger):
@provide_session
def get_task_instance(self, *, session: Session) -> TaskInstance:
+ ti = self.task_instance
+ if ti is None:
+ raise RuntimeError("task_instance is not set on the trigger")
task_instance = session.scalar(
select(TaskInstance).where(
- TaskInstance.dag_id == self.task_instance.dag_id,
- TaskInstance.task_id == self.task_instance.task_id,
- TaskInstance.run_id == self.task_instance.run_id,
- TaskInstance.map_index == self.task_instance.map_index,
+ TaskInstance.dag_id == ti.dag_id,
+ TaskInstance.task_id == ti.task_id,
+ TaskInstance.run_id == ti.run_id,
+ TaskInstance.map_index == ti.map_index,
)
)
if task_instance is None:
raise AirflowException(
"TaskInstance with dag_id: %s, task_id: %s, run_id: %s and
map_index: %s is not found",
- self.task_instance.dag_id,
- self.task_instance.task_id,
- self.task_instance.run_id,
- self.task_instance.map_index,
+ ti.dag_id,
+ ti.task_id,
+ ti.run_id,
+ ti.map_index,
)
return task_instance
async def get_task_state(self):
from airflow.sdk.execution_time.task_runner import
RuntimeTaskInstance
+ ti = self.task_instance
+ if ti is None:
+ raise RuntimeError("task_instance is not set on the trigger")
task_states_response = await
sync_to_async(RuntimeTaskInstance.get_task_states)(
- dag_id=self.task_instance.dag_id,
- task_ids=[self.task_instance.task_id],
- run_ids=[self.task_instance.run_id],
- map_index=self.task_instance.map_index,
+ dag_id=ti.dag_id,
+ task_ids=[ti.task_id],
+ run_ids=[ti.run_id],
+ map_index=ti.map_index,
)
try:
- task_state =
task_states_response[self.task_instance.run_id][self.task_instance.task_id]
+ task_state = task_states_response[ti.run_id][ti.task_id]
except Exception:
raise AirflowException(
"TaskInstance with dag_id: %s, task_id: %s, run_id: %s and
map_index: %s is not found",
- self.task_instance.dag_id,
- self.task_instance.task_id,
- self.task_instance.run_id,
- self.task_instance.map_index,
+ ti.dag_id,
+ ti.task_id,
+ ti.run_id,
+ ti.map_index,
)
return task_state
diff --git a/providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py b/providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py
index 5a218c21564..17eba607d58 100644
--- a/providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py
+++ b/providers/google/src/airflow/providers/google/cloud/triggers/dataproc.py
@@ -141,42 +141,48 @@ class DataprocSubmitTrigger(DataprocBaseTrigger):
:param session: Sqlalchemy session
"""
+ ti = self.task_instance
+ if ti is None:
+ raise RuntimeError("task_instance is not set on the trigger")
task_instance = session.scalar(
select(TaskInstance).where(
- TaskInstance.dag_id == self.task_instance.dag_id,
- TaskInstance.task_id == self.task_instance.task_id,
- TaskInstance.run_id == self.task_instance.run_id,
- TaskInstance.map_index == self.task_instance.map_index,
+ TaskInstance.dag_id == ti.dag_id,
+ TaskInstance.task_id == ti.task_id,
+ TaskInstance.run_id == ti.run_id,
+ TaskInstance.map_index == ti.map_index,
)
)
if task_instance is None:
raise AirflowException(
"TaskInstance with dag_id: %s,task_id: %s, run_id: %s and
map_index: %s is not found",
- self.task_instance.dag_id,
- self.task_instance.task_id,
- self.task_instance.run_id,
- self.task_instance.map_index,
+ ti.dag_id,
+ ti.task_id,
+ ti.run_id,
+ ti.map_index,
)
return task_instance
async def get_task_state(self):
from airflow.sdk.execution_time.task_runner import
RuntimeTaskInstance
+ ti = self.task_instance
+ if ti is None:
+ raise RuntimeError("task_instance is not set on the trigger")
task_states_response = await
sync_to_async(RuntimeTaskInstance.get_task_states)(
- dag_id=self.task_instance.dag_id,
- task_ids=[self.task_instance.task_id],
- run_ids=[self.task_instance.run_id],
- map_index=self.task_instance.map_index,
+ dag_id=ti.dag_id,
+ task_ids=[ti.task_id],
+ run_ids=[ti.run_id],
+ map_index=ti.map_index,
)
try:
- task_state =
task_states_response[self.task_instance.run_id][self.task_instance.task_id]
+ task_state = task_states_response[ti.run_id][ti.task_id]
except Exception:
raise AirflowException(
"TaskInstance with dag_id: %s, task_id: %s, run_id: %s and
map_index: %s is not found",
- self.task_instance.dag_id,
- self.task_instance.task_id,
- self.task_instance.run_id,
- self.task_instance.map_index,
+ ti.dag_id,
+ ti.task_id,
+ ti.run_id,
+ ti.map_index,
)
return task_state
@@ -293,42 +299,48 @@ class DataprocSubmitJobDirectTrigger(DataprocBaseTrigger):
:param session: Sqlalchemy session
"""
+ ti = self.task_instance
+ if ti is None:
+ raise RuntimeError("task_instance is not set on the trigger")
task_instance = session.scalar(
select(TaskInstance).where(
- TaskInstance.dag_id == self.task_instance.dag_id,
- TaskInstance.task_id == self.task_instance.task_id,
- TaskInstance.run_id == self.task_instance.run_id,
- TaskInstance.map_index == self.task_instance.map_index,
+ TaskInstance.dag_id == ti.dag_id,
+ TaskInstance.task_id == ti.task_id,
+ TaskInstance.run_id == ti.run_id,
+ TaskInstance.map_index == ti.map_index,
)
)
if task_instance is None:
raise RuntimeError(
"TaskInstance with dag_id: %s,task_id: %s, run_id: %s and
map_index: %s is not found",
- self.task_instance.dag_id,
- self.task_instance.task_id,
- self.task_instance.run_id,
- self.task_instance.map_index,
+ ti.dag_id,
+ ti.task_id,
+ ti.run_id,
+ ti.map_index,
)
return task_instance
async def get_task_state(self):
from airflow.sdk.execution_time.task_runner import
RuntimeTaskInstance
+ ti = self.task_instance
+ if ti is None:
+ raise RuntimeError("task_instance is not set on the trigger")
task_states_response = await
sync_to_async(RuntimeTaskInstance.get_task_states)(
- dag_id=self.task_instance.dag_id,
- task_ids=[self.task_instance.task_id],
- run_ids=[self.task_instance.run_id],
- map_index=self.task_instance.map_index,
+ dag_id=ti.dag_id,
+ task_ids=[ti.task_id],
+ run_ids=[ti.run_id],
+ map_index=ti.map_index,
)
try:
- task_state =
task_states_response[self.task_instance.run_id][self.task_instance.task_id]
+ task_state = task_states_response[ti.run_id][ti.task_id]
except Exception:
raise RuntimeError(
"TaskInstance with dag_id: %s, task_id: %s, run_id: %s and
map_index: %s is not found",
- self.task_instance.dag_id,
- self.task_instance.task_id,
- self.task_instance.run_id,
- self.task_instance.map_index,
+ ti.dag_id,
+ ti.task_id,
+ ti.run_id,
+ ti.map_index,
)
return task_state
@@ -432,42 +444,48 @@ class DataprocClusterTrigger(DataprocBaseTrigger):
@provide_session
def get_task_instance(self, *, session: Session) -> TaskInstance:
+ ti = self.task_instance
+ if ti is None:
+ raise RuntimeError("task_instance is not set on the trigger")
task_instance = session.scalar(
select(TaskInstance).where(
- TaskInstance.dag_id == self.task_instance.dag_id,
- TaskInstance.task_id == self.task_instance.task_id,
- TaskInstance.run_id == self.task_instance.run_id,
- TaskInstance.map_index == self.task_instance.map_index,
+ TaskInstance.dag_id == ti.dag_id,
+ TaskInstance.task_id == ti.task_id,
+ TaskInstance.run_id == ti.run_id,
+ TaskInstance.map_index == ti.map_index,
)
)
if task_instance is None:
raise AirflowException(
"TaskInstance with dag_id: %s,task_id: %s, run_id: %s and
map_index: %s is not found.",
- self.task_instance.dag_id,
- self.task_instance.task_id,
- self.task_instance.run_id,
- self.task_instance.map_index,
+ ti.dag_id,
+ ti.task_id,
+ ti.run_id,
+ ti.map_index,
)
return task_instance
async def get_task_state(self):
from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
+ ti = self.task_instance
+ if ti is None:
+ raise RuntimeError("task_instance is not set on the trigger")
task_states_response = await
sync_to_async(RuntimeTaskInstance.get_task_states)(
- dag_id=self.task_instance.dag_id,
- task_ids=[self.task_instance.task_id],
- run_ids=[self.task_instance.run_id],
- map_index=self.task_instance.map_index,
+ dag_id=ti.dag_id,
+ task_ids=[ti.task_id],
+ run_ids=[ti.run_id],
+ map_index=ti.map_index,
)
try:
- task_state =
task_states_response[self.task_instance.run_id][self.task_instance.task_id]
+ task_state = task_states_response[ti.run_id][ti.task_id]
except Exception:
raise AirflowException(
"TaskInstance with dag_id: %s, task_id: %s, run_id: %s and
map_index: %s is not found",
- self.task_instance.dag_id,
- self.task_instance.task_id,
- self.task_instance.run_id,
- self.task_instance.map_index,
+ ti.dag_id,
+ ti.task_id,
+ ti.run_id,
+ ti.map_index,
)
return task_state