uranusjr commented on code in PR #39426:
URL: https://github.com/apache/airflow/pull/39426#discussion_r1592596293


##########
airflow/utils/db.py:
##########
@@ -1888,12 +1914,119 @@ def get_query_count(query_stmt: Select, *, session: 
Session) -> int:
     return session.scalar(count_stmt)
 
 
+def get_query_exists(query_stmt: Select, *, session: Session) -> bool:
+    """Check whether there is at least one row matching a query.
+
+    A SELECT 1 FROM is issued against the subquery built from the given
+    statement. The ORDER BY clause is stripped from the statement since it's
+    unnecessary, and can impact query planning and degrade performance.
+
+    :meta private:
+    """
+    count_stmt = 
select(literal(True)).select_from(query_stmt.order_by(None).subquery())
+    return session.scalar(count_stmt)
+
+
 def exists_query(*where: ClauseElement, session: Session) -> bool:
-    """Check whether there is at least one row matching given clause.
+    """Check whether there is at least one row matching given clauses.
 
     This does a SELECT 1 WHERE ... LIMIT 1 and checks the result.
 
     :meta private:
     """
     stmt = select(literal(True)).where(*where).limit(1)
     return session.scalar(stmt) is not None
+
+
+@attrs.define(slots=True)
+class LazySelectSequence(Sequence[T]):
+    """List-like interface to lazily access a database model query.
+
+    The intended use case is inside a task execution context, where we manage an
+    active SQLAlchemy session in the background.
+
+    :meta private:
+    """
+
+    _select_asc: ClauseElement
+    _select_desc: ClauseElement
+    _process_row: Callable[[Any], Any] = attrs.field(kw_only=True, 
default=operator.itemgetter(0))
+    _session: Session = attrs.field(kw_only=True, 
factory=get_current_task_instance_session)
+    _len: int | None = attrs.field(init=False, default=None)
+
+    @classmethod
+    def from_select(
+        cls,
+        select: Select,
+        *,
+        order_by: Sequence[ClauseElement],
+        process_row: Callable[[Any], Any] = operator.itemgetter(0),
+        session: Session | None = None,
+    ) -> Self:
+        s1 = select
+        for col in order_by:
+            s1 = s1.order_by(col.asc())
+        s2 = select
+        for col in order_by:
+            s2 = s2.order_by(col.desc())
+        return cls(s1, s2, process_row=process_row, session=session or 
get_current_task_instance_session())
+
+    def __repr__(self) -> str:
+        return f"LazySelectSequence([{len(self)} items])"
+
+    def __str__(self) -> str:
+        return str(list(self))
+
+    def __getstate__(self) -> Any:
+        # We don't want to go to the trouble of serializing SQLAlchemy objects.
+        # Converting the statement into a SQL string is the best we can get.
+        # The literal_binds compile argument inlines all the values into the SQL
+        # string to simplify cross-process communication as much as possible.
+        # Theoretically we can do the same for count(), but I think it should be
+        # performant enough to calculate only that eagerly.
+        s1 = str(self._select_asc.compile(self._session.get_bind(), 
compile_kwargs={"literal_binds": True}))
+        s2 = str(self._select_desc.compile(self._session.get_bind(), 
compile_kwargs={"literal_binds": True}))
+        return (s1, s2, self._process_row, len(self))
+
+    def __setstate__(self, state: Any) -> None:
+        s1, s2, self._process_row, self._len = state
+        self._select_asc = text(s1)
+        self._select_desc = text(s2)
+        self._session = get_current_task_instance_session()
+
+    def __bool__(self) -> bool:
+        return get_query_exists(self._select_asc, session=self._session)
+
+    def __eq__(self, other: Any) -> bool:
+        if not isinstance(other, collections.abc.Sequence):
+            return NotImplemented
+        z = itertools.zip_longest(iter(self), iter(other), fillvalue=object())
+        return all(x == y for x, y in z)
+
+    def __reversed__(self) -> Iterator[T]:
+        return iter(self._process_row(r) for r in 
self._session.execute(self._select_desc))
+
+    def __iter__(self) -> Iterator[T]:
+        return iter(self._process_row(r) for r in 
self._session.execute(self._select_asc))
+
+    def __len__(self) -> int:
+        if (le := self._len) is None:
+            le = self._len = get_query_count(self._select_asc, 
session=self._session)
+        return le
+

Review Comment:
   Not required



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscr...@airflow.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

Reply via email to