uranusjr commented on code in PR #37498: URL: https://github.com/apache/airflow/pull/37498#discussion_r1523526971
########## airflow/ti_deps/deps/mapped_task_upstream_dep.py: ########## @@ -0,0 +1,92 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from collections.abc import Iterator +from typing import TYPE_CHECKING + +from airflow.ti_deps.deps.base_ti_dep import BaseTIDep +from airflow.utils.state import State, TaskInstanceState + +if TYPE_CHECKING: + from sqlalchemy.orm import Session + + from airflow.models.taskinstance import TaskInstance + from airflow.ti_deps.dep_context import DepContext + from airflow.ti_deps.deps.base_ti_dep import TIDepStatus + + +class MappedTaskUpstreamDep(BaseTIDep): + """ + Determines if a mapped task's upstream tasks that provide XComs used by this task for task mapping are in + a state that allows a given task instance to run. + """ + + NAME = "Mapped dependencies have succeeded" + IGNORABLE = True + IS_TASK_DEP = True + + def _get_dep_statuses( + self, + ti: TaskInstance, + session: Session, + dep_context: DepContext, + ) -> Iterator[TIDepStatus]: + from airflow.models.mappedoperator import MappedOperator + + if isinstance(ti.task, MappedOperator): + mapped_dependencies = ti.task.iter_mapped_dependencies() + elif (task_group := ti.task.get_closest_mapped_task_group()) is not None: + mapped_dependencies = task_group.iter_mapped_dependencies() + else: + return + + mapped_dependency_tis = [ + ti.get_dagrun(session).get_task_instance(operator.task_id, session=session) + for operator in mapped_dependencies + ] Review Comment: I would just do ```python mapped_dependency_tis = session.scalars(select(TaskInstance).where(...)).all() ``` I believe your version would not work if a dependency is itself mapped. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: commits-unsubscr...@airflow.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org