This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new c6bdb9171ad SFTP: add async retrieve_file / store_file / mkdir to
SFTPHookAsync and introduce SFTPClientPool (#64465)
c6bdb9171ad is described below
commit c6bdb9171ad6208270938d2f9ff484551f2f1408
Author: David Blain <[email protected]>
AuthorDate: Sun May 10 19:05:29 2026 +0200
SFTP: add async retrieve_file / store_file / mkdir to SFTPHookAsync and
introduce SFTPClientPool (#64465)
* refactor: Added additional async methods to SFTPHookAsync like the
SFTPHook counterpart and added SFTPClientPool
* refactor: Also specify version for types-aiofiles
* refactor: Updated aiofiles dependencies
* refactor: Downgraded aiofiles dependencies
* refactor: Make sure SFTPHookAsync is lazily imported
* fix: Fixed asyncssh.misc.ConnectionLost: Connection lost in SFTPClientPool
* refactor: Lazily construct pool from current running event loop
* refactor: Don't close pool when exiting it; only close it explicitly,
otherwise the singleton has no value
* refactor: Derive the pool size from os.cpu_count instead of the
core.parallelism configuration parameter, which is actually used by the scheduler
* refactor: Made SFTP acquire/create resilient
* refactor: Keep original pool size if pool size changes over time for same
conn_id
* refactor: Preserve non-recursive SFTP entries
* refactor: Added new features in changelog
* refactor: Moved types-aiofiles dependency to dev
* refactor: Drop unreachable return statement
* refactor: Split async list_directory with recursion option into 2
methods: a list_directory without recursion, like the sync variant, and
an async walktree method which respects the same contract as the sync version
and allows us to do recursion
* refactor: Added symlink-loop guard in walktree method
* refactor: Make sure makedirs is always called first independently of
which branch is chosen in store_file method
* refactor: Make sure retrieve_files always returns bytes
* refactor: Improve exception handling in get_sftp_client of SFTPClientPool
* refactor: Clarify async SFTP upload input types for store_file
* refactor: Test blocking of acquire method on SFTPClientPool when all
connections are already in use
* refactor: Improved mocking of aiofiles in test_retrieve_file_to_path
* refactor: Removed new line
* refactor: Added recursive option to list_directory
* refactor: Remove new line in TestSFTPClientPool
* refactor: Updated uv.lock
* refactor: Reformatted SFTPClientPool
* refactor: Reformatted docstring walktree method of SFTPHookAsync
* refactor: Removed types-aiofiles from dev dependencies in sftp provider
* refactor: Reformatted retrieve_file method of SFTPHookAsync
* refactor: Reformatted TestSFTPHookAsync
* refactor: Fixed mypy remark concerning BytesIO cast in store_file
* refactor: Fixed tests
* refactor: Removed newline
---
providers/sftp/README.rst | 1 +
providers/sftp/docs/changelog.rst | 8 +
providers/sftp/docs/index.rst | 1 +
providers/sftp/pyproject.toml | 1 +
.../sftp/src/airflow/providers/sftp/hooks/sftp.py | 245 ++++++++++--
.../airflow/providers/sftp/pools/__init__.py} | 3 -
.../sftp/src/airflow/providers/sftp/pools/sftp.py | 282 ++++++++++++++
providers/sftp/tests/conftest.py | 32 ++
providers/sftp/tests/unit/sftp/hooks/test_sftp.py | 420 ++++++++++++++++++---
.../{conftest.py => unit/sftp/pools/__init__.py} | 3 -
providers/sftp/tests/unit/sftp/pools/test_sftp.py | 372 ++++++++++++++++++
uv.lock | 2 +
12 files changed, 1296 insertions(+), 74 deletions(-)
diff --git a/providers/sftp/README.rst b/providers/sftp/README.rst
index dff6ef0b5e9..529cb75a5ea 100644
--- a/providers/sftp/README.rst
+++ b/providers/sftp/README.rst
@@ -53,6 +53,7 @@ Requirements
==========================================
======================================
PIP package Version required
==========================================
======================================
+``aiofiles`` ``>=23.2.0``
``apache-airflow`` ``>=2.11.0``
``apache-airflow-providers-ssh`` ``>=4.0.0``
``apache-airflow-providers-common-compat`` ``>=1.12.0``
diff --git a/providers/sftp/docs/changelog.rst
b/providers/sftp/docs/changelog.rst
index a63b9cc5d60..89f3209a2c1 100644
--- a/providers/sftp/docs/changelog.rst
+++ b/providers/sftp/docs/changelog.rst
@@ -27,6 +27,14 @@
Changelog
---------
+5.7.5
+.....
+
+Features
+~~~~~~~~
+
+* ``SFTP: add async retrieve_file / store_file / mkdir to SFTPHookAsync and
introduce SFTPClientPool``
+
5.7.4
.....
diff --git a/providers/sftp/docs/index.rst b/providers/sftp/docs/index.rst
index 4dd775a5cdf..30c69edf0d1 100644
--- a/providers/sftp/docs/index.rst
+++ b/providers/sftp/docs/index.rst
@@ -100,6 +100,7 @@ The minimum Apache Airflow version supported by this
provider distribution is ``
==========================================
======================================
PIP package Version required
==========================================
======================================
+``aiofiles`` ``>=23.2.0``
``apache-airflow`` ``>=2.11.0``
``apache-airflow-providers-ssh`` ``>=4.0.0``
``apache-airflow-providers-common-compat`` ``>=1.12.0``
diff --git a/providers/sftp/pyproject.toml b/providers/sftp/pyproject.toml
index 4d0e4c91360..4cb85c5ee63 100644
--- a/providers/sftp/pyproject.toml
+++ b/providers/sftp/pyproject.toml
@@ -59,6 +59,7 @@ requires-python = ">=3.10"
# Make sure to run ``prek update-providers-dependencies --all-files``
# After you modify the dependencies, and rebuild your Breeze CI image with
``breeze ci-image build``
dependencies = [
+ "aiofiles>=23.2.0",
"apache-airflow>=2.11.0",
"apache-airflow-providers-ssh>=4.0.0",
"apache-airflow-providers-common-compat>=1.12.0",
diff --git a/providers/sftp/src/airflow/providers/sftp/hooks/sftp.py
b/providers/sftp/src/airflow/providers/sftp/hooks/sftp.py
index 80ceb729082..2f31c66d3b2 100644
--- a/providers/sftp/src/airflow/providers/sftp/hooks/sftp.py
+++ b/providers/sftp/src/airflow/providers/sftp/hooks/sftp.py
@@ -23,15 +23,17 @@ import concurrent.futures
import datetime
import functools
import os
+import posixpath
import stat
import warnings
from collections.abc import Callable, Generator, Sequence
-from contextlib import contextmanager
+from contextlib import contextmanager, suppress
from fnmatch import fnmatch
from io import BytesIO
-from pathlib import Path
+from pathlib import Path, PurePosixPath
from typing import IO, TYPE_CHECKING, Any, cast
+import aiofiles
import asyncssh
from paramiko.config import SSH_PORT
@@ -46,6 +48,8 @@ if TYPE_CHECKING:
from paramiko.sftp_attr import SFTPAttributes
from paramiko.sftp_client import SFTPClient
+CHUNK_SIZE = 64 * 1024 # 64KB
+
def handle_connection_management(func: Callable) -> Callable:
@functools.wraps(func)
@@ -197,13 +201,39 @@ class SFTPHook(SSHHook):
}
@handle_connection_management
- def list_directory(self, path: str) -> list[str]:
+ def list_directory(self, path: str, recursive: bool = False) -> list[str]
| None:
"""
List files in a directory on the remote system.
+ Lists one-level entry names under the given directory path.
+
+ If ``recursive=True``, returns files recursively as paths relative to
``path``.
+
:param path: full path to the remote directory to list
+ :param recursive: Whether to recursively list descendants.
+ :return: List of entry names found under the directory, or None if the
directory does not exist.
"""
- return sorted(self.conn.listdir(path)) # type: ignore[union-attr]
+ if recursive:
+ files: list[str] = []
+
+ def append_relative(item: str) -> None:
+ files.append(os.path.relpath(item, path))
+
+ try:
+ self.walktree(
+ path=path,
+ fcallback=append_relative,
+ dcallback=lambda _: None,
+ ucallback=lambda _: None,
+ )
+ except OSError:
+ return None
+ return sorted(files)
+
+ try:
+ return sorted(self.conn.listdir(path)) # type: ignore[union-attr]
+ except OSError:
+ return None
@handle_connection_management
def list_directory_with_attr(self, path: str) -> list[SFTPAttributes]:
@@ -789,24 +819,193 @@ class SFTPHookAsync(BaseHook):
ssh_client_conn = await asyncssh.connect(**conn_config)
return ssh_client_conn
- async def list_directory(self, path: str = "") -> list[str] | None: #
type: ignore[return]
- """Return a list of files on the SFTP server at the provided path."""
+ async def retrieve_file(
+ self,
+ remote_full_path: str,
+ local_full_path: str | os.PathLike[str] | IO[bytes],
+ chunk_size: int = CHUNK_SIZE,
+ ) -> None:
+ """
+ Transfer the remote file to a local location asynchronously.
+
+ If local_full_path is a string or PathLike path, the file will be put
at that location.
+ If it is a BytesIO or other binary file-like object, the file will be
streamed into it.
+
+ :param remote_full_path: Full path to the remote file.
+ :param local_full_path: Full path to the local file or a binary
file-like buffer.
+ :param chunk_size: Size of chunks to read at a time (default: 64KB).
+ """
+ async with await self._get_conn() as ssh_conn:
+ async with ssh_conn.start_sftp_client() as sftp:
+ async with sftp.open(remote_full_path, "rb") as remote_file:
+ if isinstance(local_full_path, (str, os.PathLike)):
+ async with aiofiles.open(local_full_path, "wb") as f:
+ while True:
+ chunk = await remote_file.read(chunk_size)
+ if not chunk:
+ break
+ await f.write(cast("bytes", chunk))
+ else:
+ while True:
+ chunk = await remote_file.read(chunk_size)
+ if not chunk:
+ break
+ local_full_path.write(cast("bytes", chunk))
+ if hasattr(local_full_path, "seek"):
+ local_full_path.seek(0)
+
+ async def store_file(
+ self, remote_full_path: str, local_full_path: str | os.PathLike[str] |
IO[bytes]
+ ) -> None:
+ """
+ Transfer a local file to the remote location.
+
+ If ``local_full_path`` is a path, the file will be read from that
location.
+ If it is a binary file-like object, the content will be uploaded from
the stream.
+ Raw ``bytes`` values are not accepted directly; wrap bytes in
``BytesIO``.
+
+ Parent directories for ``remote_full_path`` are created when missing.
+
+ :param remote_full_path: full path to the remote file
+ :param local_full_path: full path to the local file or a binary
file-like buffer
+ """
+ if isinstance(local_full_path, bytes):
+ raise TypeError("Unsupported type for local_full_path: bytes. Wrap
raw bytes in BytesIO.")
+
async with await self._get_conn() as ssh_conn:
- sftp_client = await ssh_conn.start_sftp_client()
+ async with ssh_conn.start_sftp_client() as sftp:
+ with suppress(asyncssh.SFTPFailure):
+ remote_path = PurePosixPath(remote_full_path)
+ await sftp.makedirs(str(remote_path.parent))
+
+ if isinstance(local_full_path, (str, os.PathLike)):
+ await sftp.put(str(local_full_path), remote_full_path)
+ elif hasattr(local_full_path, "read"):
+ async with sftp.open(remote_full_path, "wb") as f:
+ stream = local_full_path
+ if hasattr(stream, "seek"):
+ stream.seek(0)
+ data = stream.read()
+ await f.write(data)
+ else:
+ raise TypeError(
+ f"Unsupported type for local_full_path:
{type(local_full_path)}. "
+ "Expected a binary file-like object or a path-like
object."
+ )
+
+ async def mkdir(self, path: str) -> None:
+ """
+ Create a directory on the remote system asynchronously.
+
+ The default permissions are determined by the server. Parent
directories are created as needed.
+
+ :param path: Full path to the remote directory to create.
+ """
+ async with await self._get_conn() as ssh_conn:
+ async with ssh_conn.start_sftp_client() as sftp:
+ await sftp.makedirs(path)
+
+ async def list_directory(self, path: str = "", recursive: bool = False) ->
list[str] | None:
+ """
+ List files in a directory on the remote system asynchronously.
+
+ Lists one-level entry names under the given directory path.
+
+ If ``recursive=True``, returns files recursively as paths relative to
``path``.
+
+ :param path: Full path to the remote directory to list.
+ :param recursive: Whether to recursively list descendants.
+ :return: List of entry names found under the directory, or None if the
directory does not exist.
+ """
+ if recursive:
+ files: list[str] = []
+
+ def append_relative(item: str) -> None:
+ files.append(posixpath.relpath(item, path))
+
try:
- files = await sftp_client.listdir(path)
- return sorted(files)
+ await self.walktree(
+ path=path,
+ fcallback=append_relative,
+ dcallback=lambda _: None,
+ ucallback=lambda _: None,
+ )
except asyncssh.SFTPNoSuchFile:
return None
+ return sorted(files)
+
+ async with await self._get_conn() as ssh_conn:
+ async with ssh_conn.start_sftp_client() as sftp:
+ try:
+ entries = await sftp.readdir(path)
+ except asyncssh.SFTPNoSuchFile:
+ return None
+ return sorted(os.fsdecode(entry.filename) for entry in entries)
+
+ return None
+
+ async def walktree(
+ self,
+ path: str,
+ fcallback: Callable[[str], Any | None],
+ dcallback: Callable[[str], Any | None],
+ ucallback: Callable[[str], Any | None],
+ recurse: bool = True,
+ ) -> None:
+ """
+ Recursively descend, depth first, the directory tree at ``path``.
+
+ This mirrors :meth:`SFTPHook.walktree` contract and calls callback
functions for
+ regular files, directories, and unknown file types.
+ """
+ async with await self._get_conn() as ssh_conn:
+ async with ssh_conn.start_sftp_client() as sftp:
+ visited_dirs: set[str] = set()
+
+ async def _canonical_dir(dir_path: str) -> str:
+ with suppress(asyncssh.SFTPError):
+ return os.fsdecode(await sftp.realpath(dir_path))
+ return posixpath.normpath(dir_path)
+
+ async def _walk(dir_path: str) -> None:
+ canonical_dir = await _canonical_dir(dir_path)
+ if canonical_dir in visited_dirs:
+ return
+ visited_dirs.add(canonical_dir)
+
+ try:
+ entries = await sftp.readdir(dir_path)
+ except asyncssh.SFTPNoSuchFile:
+ # Directory may disappear mid-walk on busy drops; skip
and continue.
+ return
+
+ for entry in sorted(entries, key=lambda file:
os.fsdecode(file.filename)):
+ filename = os.fsdecode(entry.filename)
+ if filename in {".", ".."}:
+ continue
+
+ pathname = posixpath.join(dir_path, filename)
+ permissions = entry.attrs.permissions
+
+ if permissions is not None and
stat.S_ISDIR(permissions):
+ dcallback(pathname)
+ if recurse:
+ await _walk(pathname)
+ elif permissions is not None and
stat.S_ISREG(permissions):
+ fcallback(pathname)
+ else:
+ ucallback(pathname)
+
+ await _walk(path)
async def read_directory(self, path: str = "") ->
Sequence[asyncssh.sftp.SFTPName] | None: # type: ignore[return]
"""Return a list of files along with their attributes on the SFTP
server at the provided path."""
async with await self._get_conn() as ssh_conn:
- sftp_client = await ssh_conn.start_sftp_client()
- try:
- return await sftp_client.readdir(path)
- except asyncssh.SFTPNoSuchFile:
- return None
+ async with ssh_conn.start_sftp_client() as sftp:
+ try:
+ return await sftp.readdir(path)
+ except asyncssh.SFTPNoSuchFile:
+ return None
async def get_files_and_attrs_by_pattern(
self, path: str = "", fnmatch_pattern: str = ""
@@ -832,12 +1031,12 @@ class SFTPHookAsync(BaseHook):
:param path: full path to the remote file
"""
async with await self._get_conn() as ssh_conn:
- try:
- sftp_client = await ssh_conn.start_sftp_client()
- ftp_mdtm = await sftp_client.stat(path)
- modified_time = ftp_mdtm.mtime
- mod_time =
datetime.datetime.fromtimestamp(modified_time).strftime("%Y%m%d%H%M%S") #
type: ignore[arg-type]
- self.log.info("Found File %s last modified: %s", str(path),
str(mod_time))
- return mod_time
- except asyncssh.SFTPNoSuchFile:
- raise AirflowException("No files matching")
+ async with ssh_conn.start_sftp_client() as sftp:
+ try:
+ ftp_mdtm = await sftp.stat(path)
+ modified_time = ftp_mdtm.mtime
+ mod_time =
datetime.datetime.fromtimestamp(modified_time).strftime("%Y%m%d%H%M%S") #
type: ignore[arg-type]
+ self.log.info("Found File %s last modified: %s",
str(path), str(mod_time))
+ return mod_time
+ except asyncssh.SFTPNoSuchFile:
+ raise AirflowException("No files matching")
diff --git a/providers/sftp/tests/conftest.py
b/providers/sftp/src/airflow/providers/sftp/pools/__init__.py
similarity index 90%
copy from providers/sftp/tests/conftest.py
copy to providers/sftp/src/airflow/providers/sftp/pools/__init__.py
index f56ccce0a3f..13a83393a91 100644
--- a/providers/sftp/tests/conftest.py
+++ b/providers/sftp/src/airflow/providers/sftp/pools/__init__.py
@@ -14,6 +14,3 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from __future__ import annotations
-
-pytest_plugins = "tests_common.pytest_plugin"
diff --git a/providers/sftp/src/airflow/providers/sftp/pools/sftp.py
b/providers/sftp/src/airflow/providers/sftp/pools/sftp.py
new file mode 100644
index 00000000000..1115b4e19b4
--- /dev/null
+++ b/providers/sftp/src/airflow/providers/sftp/pools/sftp.py
@@ -0,0 +1,282 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import asyncio
+import os
+from contextlib import asynccontextmanager, suppress
+from dataclasses import dataclass, field
+from threading import Lock
+from typing import TYPE_CHECKING
+from weakref import WeakKeyDictionary
+
+from airflow.providers.sftp.hooks.sftp import SFTPHookAsync
+from airflow.utils.log.logging_mixin import LoggingMixin
+
+if TYPE_CHECKING:
+ import asyncssh
+
+
+@dataclass
+class _LoopState:
+ """Per-event-loop state for SFTP client pool."""
+
+ idle: asyncio.LifoQueue = field(default_factory=asyncio.LifoQueue)
+ in_use: set[tuple[asyncssh.SSHClientConnection, asyncssh.SFTPClient]] =
field(default_factory=set)
+ semaphore: asyncio.Semaphore | None = None
+ init_lock: asyncio.Lock | None = None
+ initialized: bool = False
+ closed: bool = False
+
+
+class SFTPClientPool(LoggingMixin):
+ """Lazy Thread-safe and Async-safe Singleton SFTP pool that keeps SSH and
SFTP clients alive until exit, and limits concurrent usage to pool_size."""
+
+ _instances: dict[str, SFTPClientPool] = {}
+ _lock = Lock()
+ _create_connection_max_retries = 2
+ _create_connection_retry_base_delay = 0.2
+ _create_connection_retry_max_delay = 1.0
+
+ @staticmethod
+ def _resolve_pool_size(pool_size: int | None) -> int:
+ resolved_pool_size = (os.cpu_count() or 1) if pool_size is None else
pool_size
+ if resolved_pool_size < 1:
+ raise ValueError(f"pool_size must be greater than or equal to 1,
got {resolved_pool_size}.")
+ return resolved_pool_size
+
+ def __new__(cls, sftp_conn_id: str, pool_size: int | None = None):
+ with cls._lock:
+ if sftp_conn_id not in cls._instances:
+ instance = super().__new__(cls)
+ instance._pre_init(sftp_conn_id, pool_size)
+ cls._instances[sftp_conn_id] = instance
+ else:
+ instance = cls._instances[sftp_conn_id]
+ if pool_size is not None and pool_size != instance.pool_size:
+ instance.log.debug(
+ "SFTPClientPool for sftp_conn_id '%s' is already
initialized with "
+ "pool_size=%d; ignoring requested pool_size=%d and
reusing the "
+ "existing singleton.",
+ sftp_conn_id,
+ instance.pool_size,
+ pool_size,
+ )
+ return cls._instances[sftp_conn_id]
+
+ def __init__(self, sftp_conn_id: str, pool_size: int | None = None):
+ # Prevent parent __init__ argument errors
+ pass
+
+ def _pre_init(self, sftp_conn_id: str, pool_size: int | None):
+ """Initialize the singleton synchronously, deferring asyncio
primitives to the active event loop."""
+ LoggingMixin.__init__(self)
+ self.sftp_conn_id = sftp_conn_id
+ self.pool_size = self._resolve_pool_size(pool_size)
+ self._loop_states: WeakKeyDictionary[asyncio.AbstractEventLoop,
_LoopState] = WeakKeyDictionary()
+ self._loop_states_lock = Lock()
+ self.log.info("SFTPClientPool with size %d initialised...",
self.pool_size)
+
+ def _get_loop_state(self) -> _LoopState:
+ """Get or create the state container for the current event loop."""
+ running_loop = asyncio.get_running_loop()
+ with self._loop_states_lock:
+ state = self._loop_states.get(running_loop)
+ if state is None:
+ state = _LoopState(
+ semaphore=asyncio.Semaphore(self.pool_size),
+ init_lock=asyncio.Lock(),
+ )
+ self._loop_states[running_loop] = state
+ return state
+
+ async def _ensure_initialized(self):
+ """Ensure pool primitives exist for the current loop and the pool is
open."""
+ state = self._get_loop_state()
+ if state.init_lock is None:
+ raise RuntimeError("SFTPClientPool init lock is not initialized")
+
+ if state.initialized and not state.closed:
+ return
+
+ async with state.init_lock:
+ if not state.initialized or state.closed:
+ self.log.info(
+ "Initializing / resetting SFTPClientPool for '%s' with
size %d",
+ self.sftp_conn_id,
+ self.pool_size,
+ )
+ state.idle = asyncio.LifoQueue()
+ state.in_use.clear()
+ state.closed = False
+ state.initialized = True
+
+ async def _create_connection(
+ self,
+ ) -> tuple[asyncssh.SSHClientConnection, asyncssh.SFTPClient]:
+ ssh_conn = await
SFTPHookAsync(sftp_conn_id=self.sftp_conn_id)._get_conn()
+ sftp = await ssh_conn.start_sftp_client()
+ self.log.info("Created new SFTP connection for sftp_conn_id '%s'",
self.sftp_conn_id)
+ return ssh_conn, sftp
+
+ async def _create_connection_with_retry(
+ self,
+ ) -> tuple[asyncssh.SSHClientConnection, asyncssh.SFTPClient]:
+ max_attempts = self._create_connection_max_retries + 1
+ for attempt in range(1, max_attempts + 1):
+ try:
+ return await self._create_connection()
+ except Exception as exc:
+ if attempt >= max_attempts:
+ self.log.warning(
+ "Failed creating SFTP connection for '%s' after %d
attempts: %s",
+ self.sftp_conn_id,
+ max_attempts,
+ exc,
+ )
+ raise
+
+ delay = min(
+ self._create_connection_retry_base_delay * (2 ** (attempt
- 1)),
+ self._create_connection_retry_max_delay,
+ )
+ self.log.warning(
+ "Failed creating SFTP connection for '%s' (attempt %d/%d):
%s. Retrying in %.2fs",
+ self.sftp_conn_id,
+ attempt,
+ max_attempts,
+ exc,
+ delay,
+ )
+ await asyncio.sleep(delay)
+
+ # Unreachable, but keeps type checkers happy.
+ raise RuntimeError("Unable to create SFTP connection")
+
+ async def acquire(self):
+ await self._ensure_initialized()
+ state = self._get_loop_state()
+
+ if state.closed:
+ raise RuntimeError("Cannot acquire from a closed SFTPClientPool")
+
+ if state.semaphore is None:
+ raise RuntimeError("SFTPClientPool is not initialized")
+
+ self.log.debug("Acquiring SFTP connection for '%s'", self.sftp_conn_id)
+
+ await state.semaphore.acquire()
+
+ try:
+ try:
+ pair = state.idle.get_nowait()
+ except asyncio.QueueEmpty:
+ pair = await self._create_connection_with_retry()
+
+ state.in_use.add(pair)
+ return pair
+ except Exception:
+ state.semaphore.release()
+ raise
+
+ def _close_connection_pair(self, pair) -> None:
+ ssh, sftp = pair
+ with suppress(Exception):
+ sftp.exit()
+ with suppress(Exception):
+ ssh.close()
+
+ async def _release_pair(self, pair, state: _LoopState, *, faulty: bool) ->
None:
+ if pair not in state.in_use:
+ self.log.warning("Attempted to release unknown or already released
connection")
+ return
+
+ if state.semaphore is None:
+ raise RuntimeError("SFTPClientPool is not initialized")
+
+ state.in_use.discard(pair)
+
+ if faulty or state.closed:
+ self._close_connection_pair(pair)
+ else:
+ await state.idle.put(pair)
+
+ self.log.debug("Releasing SFTP connection for '%s'", self.sftp_conn_id)
+ state.semaphore.release()
+
+ async def release(self, pair):
+ state = self._get_loop_state()
+ await self._release_pair(pair, state, faulty=False)
+
+ @asynccontextmanager
+ async def get_sftp_client(self):
+ await self._ensure_initialized()
+ state = self._get_loop_state()
+ pair = None
+ try:
+ pair = await self.acquire()
+ ssh, sftp = pair
+ yield sftp
+ except asyncio.CancelledError:
+ if pair:
+ await self._release_pair(pair, state, faulty=True)
+ raise
+ except Exception as e:
+ self.log.warning("Dropping faulty connection for '%s': %s",
self.sftp_conn_id, e)
+ if pair:
+ await self._release_pair(pair, state, faulty=True)
+ raise
+ else:
+ await self._release_pair(pair, state, faulty=False)
+
+ async def close(self):
+ """Gracefully shutdown all connections in the pool for the current
event loop."""
+ await self._ensure_initialized()
+ state = self._get_loop_state()
+ if state.init_lock is None:
+ raise RuntimeError("SFTPClientPool is not initialized")
+
+ async with state.init_lock:
+ if state.closed:
+ return
+
+ state.closed = True
+
+ self.log.info("Closing all SFTP connections for '%s'",
self.sftp_conn_id)
+
+ while not state.idle.empty():
+ pair = await state.idle.get()
+ self._close_connection_pair(pair)
+
+ active_in_use = len(state.in_use)
+ for pair in list(state.in_use):
+ self._close_connection_pair(pair)
+ state.in_use.discard(pair)
+
+ if active_in_use:
+ self.log.warning("Pool closed with %d active connections",
active_in_use)
+
+ async def __aenter__(self):
+ await self._ensure_initialized()
+ return self
+
+ async def __aexit__(self, exc_type, exc, tb):
+ # Intentionally a no-op: this pool is a process-wide singleton, so
+ # exiting a single `async with` block must not close it for all other
+ # concurrent users. Call `close()` explicitly when you truly want to
+ # shut down all connections for the current event loop.
+ pass
diff --git a/providers/sftp/tests/conftest.py b/providers/sftp/tests/conftest.py
index f56ccce0a3f..352f17c22a3 100644
--- a/providers/sftp/tests/conftest.py
+++ b/providers/sftp/tests/conftest.py
@@ -16,4 +16,36 @@
# under the License.
from __future__ import annotations
+from collections.abc import Generator
+from typing import TYPE_CHECKING, Any
+from unittest.mock import AsyncMock, patch
+
+import pytest
+from asyncssh import SFTPClient, SSHClientConnection
+
pytest_plugins = "tests_common.pytest_plugin"
+
+if TYPE_CHECKING:
+ from airflow.providers.sftp.hooks.sftp import SFTPHookAsync
+
+
+@pytest.fixture
+def sftp_hook_mocked() -> Generator[tuple[SFTPHookAsync, SFTPClient], Any,
None]:
+ """
+ Fixture that mocks SFTPHookAsync._get_conn with SSH + SFTP async mocks.
+ Returns a tuple (hook, sftp_client_mock) so tests can easily set readdir.
+ """
+ from airflow.providers.sftp.hooks.sftp import SFTPHookAsync
+
+ sftp_client_mock = AsyncMock(spec=SFTPClient)
+ sftp_client_mock.readdir.return_value = []
+
+ client_connection_mock = AsyncMock(spec=SSHClientConnection)
+ sftp_cm_mock = client_connection_mock.start_sftp_client.return_value
+ sftp_cm_mock.__aenter__ = AsyncMock(return_value=sftp_client_mock)
+ sftp_cm_mock.__aexit__ = AsyncMock(return_value=None)
+
+ with patch("airflow.providers.sftp.hooks.sftp.SFTPHookAsync._get_conn") as
mock_get_conn:
+ mock_get_conn.return_value.__aenter__.return_value =
client_connection_mock
+
+ yield SFTPHookAsync(), sftp_cm_mock
diff --git a/providers/sftp/tests/unit/sftp/hooks/test_sftp.py
b/providers/sftp/tests/unit/sftp/hooks/test_sftp.py
index 835bfcc0221..902a191041b 100644
--- a/providers/sftp/tests/unit/sftp/hooks/test_sftp.py
+++ b/providers/sftp/tests/unit/sftp/hooks/test_sftp.py
@@ -21,8 +21,9 @@ import datetime
import json
import os
import shutil
+import stat
from io import BytesIO, StringIO
-from unittest.mock import AsyncMock, MagicMock, PropertyMock, patch
+from unittest.mock import AsyncMock, MagicMock, Mock, PropertyMock, patch
import paramiko
import pytest
@@ -199,13 +200,41 @@ class TestSFTPHook:
def test_list_directory(self):
output = self.hook.list_directory(path=os.path.join(self.temp_dir,
TMP_DIR_FOR_TESTS))
+ assert output is not None
assert output == [SUB_DIR, FIFO_FOR_TESTS]
- def test_list_directory_with_attr(self):
- output =
self.hook.list_directory_with_attr(path=os.path.join(self.temp_dir,
TMP_DIR_FOR_TESTS))
- file_names = [f.filename for f in output]
- assert all(isinstance(f, paramiko.SFTPAttributes) for f in output)
- assert sorted(file_names) == [SUB_DIR, FIFO_FOR_TESTS]
+ def test_list_directory_non_recursive_default(self):
+ """Assert that default list_directory returns one-level entry names
only."""
+ output = self.hook.list_directory(
+ path=os.path.join(self.temp_dir, TMP_DIR_FOR_TESTS),
recursive=False
+ )
+ assert output is not None
+ assert output == [SUB_DIR, FIFO_FOR_TESTS]
+
+ def test_list_directory_recursive_returns_relative_file_entries(self):
+ """Assert that recursive list_directory returns file-only relative
paths."""
+ # The temp directory structure is:
+ # - TMP_DIR_FOR_TESTS/
+ # - SUB_DIR/
+ # - TMP_FILE_FOR_TESTS (file)
+ # - FIFO_FOR_TESTS (fifo)
+ output = self.hook.list_directory(path=os.path.join(self.temp_dir,
TMP_DIR_FOR_TESTS), recursive=True)
+ assert output is not None
+ # Should only return files, not directories or fifos, with relative
paths
+ expected_files = [os.path.join(SUB_DIR, TMP_FILE_FOR_TESTS)]
+ assert sorted(output) == sorted(expected_files)
+
+ def test_list_directory_path_does_not_exist(self):
+ """Assert that list_directory returns None when path does not exist."""
+ non_existent_path = os.path.join(self.temp_dir, "non_existent_dir")
+ output = self.hook.list_directory(path=non_existent_path)
+ assert output is None
+
+ def test_list_directory_recursive_path_does_not_exist(self):
+ """Assert that recursive list_directory returns None when path does
not exist."""
+ non_existent_path = os.path.join(self.temp_dir, "non_existent_dir")
+ output = self.hook.list_directory(path=non_existent_path,
recursive=True)
+ assert output is None
def test_mkdir(self):
new_dir_name = "mk_dir"
@@ -632,7 +661,7 @@ class TestSFTPHook:
host_proxy_cmd=host_proxy_cmd,
)
- with hook.get_managed_conn():
+ with hook.get_managed_conn() as _:
mock_proxy_command.assert_called_once_with(host_proxy_cmd)
mock_ssh_client.return_value.connect.assert_called_once_with(
hostname="example.com",
@@ -947,86 +976,387 @@ class TestSFTPHookAsync:
),
]
- @patch("airflow.providers.sftp.hooks.sftp.SFTPHookAsync._get_conn")
@pytest.mark.asyncio
- async def test_list_directory_path_does_not_exist(self,
mock_hook_get_conn):
+ async def test_list_directory_path_does_not_exist(self, sftp_hook_mocked):
"""
- Assert that AirflowException is raised when path does not exist on
SFTP server
+ Assert that None is returned when path does not exist on SFTP server
"""
- mock_hook_get_conn.return_value.__aenter__.return_value =
MockSSHClient()
+ hook, sftp_cm_mock = sftp_hook_mocked
- hook = SFTPHookAsync()
-
- expected_files = None
files = await hook.list_directory(path="/path/does_not/exist/")
- assert files == expected_files
- mock_hook_get_conn.return_value.__aexit__.assert_called()
+ assert not files
+ sftp_cm_mock.__aexit__.assert_awaited()
+
+ @pytest.mark.asyncio
+ async def test_list_directory_recursive_path_does_not_exist(self,
sftp_hook_mocked):
+ """Assert that recursive list_directory returns None when path does
not exist."""
+ hook, sftp_cm_mock = sftp_hook_mocked
+
+ files = await hook.list_directory(path="/path/does_not/exist/",
recursive=True)
+ assert not files
+ sftp_cm_mock.__aexit__.assert_awaited()
- @patch("airflow.providers.sftp.hooks.sftp.SFTPHookAsync._get_conn")
@pytest.mark.asyncio
- async def test_read_directory_path_does_not_exist(self,
mock_hook_get_conn):
+ async def test_read_directory_path_does_not_exist(self, sftp_hook_mocked):
"""
Assert that AirflowException is raised when path does not exist on
SFTP server
"""
- mock_hook_get_conn.return_value.__aenter__.return_value =
MockSSHClient()
- hook = SFTPHookAsync()
+ hook, sftp_client_mock = sftp_hook_mocked
- expected_files = None
files = await hook.read_directory(path="/path/does_not/exist/")
- assert files == expected_files
- mock_hook_get_conn.return_value.__aexit__.assert_called()
+ assert not files
+ sftp_client_mock.__aexit__.assert_awaited()
- @patch("airflow.providers.sftp.hooks.sftp.SFTPHookAsync._get_conn")
@pytest.mark.asyncio
- async def test_list_directory_path_has_files(self, mock_hook_get_conn):
+ async def test_list_directory_path_has_files(self, sftp_hook_mocked):
"""
Assert that file list is returned when path exists on SFTP server
"""
- mock_hook_get_conn.return_value.__aenter__.return_value =
MockSSHClient()
- hook = SFTPHookAsync()
+ hook, sftp_client_mock = sftp_hook_mocked
+
+ sftp_client_mock.__aenter__.return_value.readdir.return_value = [
+ Mock(spec=SFTPName, filename="..", attrs=Mock(permissions=0)),
+ Mock(spec=SFTPName, filename=".", attrs=Mock(permissions=0)),
+ Mock(spec=SFTPName, filename="file", attrs=Mock(permissions=0)),
+ ]
- expected_files = ["..", ".", "file"]
files = await hook.list_directory(path="/path/exists/")
- assert sorted(files) == sorted(expected_files)
- mock_hook_get_conn.return_value.__aexit__.assert_called()
+ assert files is not None
+ assert sorted(files) == sorted(["..", ".", "file"])
+ sftp_client_mock.__aexit__.assert_awaited()
- @patch("airflow.providers.sftp.hooks.sftp.SFTPHookAsync._get_conn")
@pytest.mark.asyncio
- async def test_get_file_by_pattern_with_match(self, mock_hook_get_conn):
+ async def test_get_file_by_pattern_with_match(self, sftp_hook_mocked):
"""
Assert that filename is returned when file pattern is matched on SFTP
server
"""
- mock_hook_get_conn.return_value.__aenter__.return_value =
MockSSHClient()
- hook = SFTPHookAsync()
+ hook, sftp_client_mock = sftp_hook_mocked
+
+ sftp_client_mock.__aenter__.return_value.readdir.return_value = [
+ Mock(spec=SFTPName, filename="..", attrs=Mock(permissions=0)),
+ Mock(spec=SFTPName, filename=".", attrs=Mock(permissions=0)),
+ Mock(spec=SFTPName, filename="file", attrs=Mock(permissions=0)),
+ ]
files = await
hook.get_files_and_attrs_by_pattern(path="/path/exists/",
fnmatch_pattern="file")
assert len(files) == 1
assert files[0].filename == "file"
- mock_hook_get_conn.return_value.__aexit__.assert_called()
+ sftp_client_mock.__aexit__.assert_awaited()
@pytest.mark.asyncio
- @patch("airflow.providers.sftp.hooks.sftp.SFTPHookAsync._get_conn")
- async def test_get_mod_time(self, mock_hook_get_conn):
+ async def test_get_mod_time(self, sftp_hook_mocked):
"""
Assert that file attribute and return the modified time of the file
"""
- mock_hook_get_conn.return_value.__aenter__.return_value =
MockSSHClient()
- hook = SFTPHookAsync()
+ hook, sftp_client_mock = sftp_hook_mocked
+
+ mtime = 1667302566 # This is a valid Unix timestamp
+ expected =
datetime.datetime.fromtimestamp(mtime).strftime("%Y%m%d%H%M%S")
+ sftp_client_mock.__aenter__.return_value.stat.return_value =
Mock(spec=SFTPAttrs, mtime=mtime)
mod_time = await hook.get_mod_time("/path/exists/file")
- expected_value =
datetime.datetime.fromtimestamp(1667302566).strftime("%Y%m%d%H%M%S")
- assert mod_time == expected_value
+ assert mod_time == expected
@pytest.mark.asyncio
- @patch("airflow.providers.sftp.hooks.sftp.SFTPHookAsync._get_conn")
- async def test_get_mod_time_exception(self, mock_hook_get_conn):
+ async def test_get_mod_time_exception(self, sftp_hook_mocked):
"""
Assert that get_mod_time raise exception when file does not exist
"""
- mock_hook_get_conn.return_value.__aenter__.return_value =
MockSSHClient()
- hook = SFTPHookAsync()
+ hook, sftp_client_mock = sftp_hook_mocked
+
+ sftp_client_mock.__aenter__.return_value.stat.side_effect =
SFTPNoSuchFile(
+ reason="File does not exist"
+ )
with pytest.raises(AirflowException) as exc:
await hook.get_mod_time("/path/does_not/exist/")
assert str(exc.value) == "No files matching"
+
+ @pytest.mark.asyncio
+ async def test_mkdir_creates_directory(self, sftp_hook_mocked):
+ """
+ Assert that mkdir calls makedirs on the SFTP client
+ """
+ hook, sftp_client_mock = sftp_hook_mocked
+
+ sftp_client = sftp_client_mock.__aenter__.return_value
+ sftp_client.makedirs = AsyncMock()
+
+ await hook.mkdir("/remote/newdir")
+ sftp_client.makedirs.assert_awaited_once_with("/remote/newdir")
+ sftp_client_mock.__aexit__.assert_awaited()
+
+ @pytest.mark.asyncio
+ @patch("aiofiles.open")
+ async def test_retrieve_file_to_path(self, mock_aiofiles_open,
sftp_hook_mocked):
+ """
+ Assert that retrieve_file writes to a local file using aiofiles
+ """
+ hook, sftp_client_mock = sftp_hook_mocked
+
+ sftp_client = sftp_client_mock.__aenter__.return_value
+ mock_remote_file = AsyncMock()
+ mock_remote_file.read = AsyncMock(side_effect=[b"abc", b"",
StopAsyncIteration])
+ sftp_client.open.return_value.__aenter__.return_value =
mock_remote_file
+
+ mock_file = AsyncMock()
+ aiofiles_cm = AsyncMock()
+ aiofiles_cm.__aenter__.return_value = mock_file
+ aiofiles_cm.__aexit__.return_value = None
+ mock_aiofiles_open.return_value = aiofiles_cm
+
+ await hook.retrieve_file("/remote/file", "/local/file")
+ sftp_client.open.assert_called_once_with("/remote/file", "rb")
+ mock_file.write.assert_awaited()
+ sftp_client_mock.__aexit__.assert_awaited()
+
+ @pytest.mark.asyncio
+ async def test_retrieve_file_to_bytesio(self, sftp_hook_mocked):
+ """
+ Assert that retrieve_file writes to a BytesIO buffer
+ """
+ hook, sftp_client_mock = sftp_hook_mocked
+
+ sftp_client = sftp_client_mock.__aenter__.return_value
+ mock_remote_file = AsyncMock()
+ mock_remote_file.read = AsyncMock(side_effect=[b"abc", b""])
+ sftp_client.open.return_value.__aenter__.return_value =
mock_remote_file
+ buf = BytesIO()
+
+ await hook.retrieve_file("/remote/file", buf)
+ assert buf.getvalue() == b"abc"
+ sftp_client.open.assert_called_once_with("/remote/file", "rb")
+ sftp_client_mock.__aexit__.assert_awaited()
+
+ @pytest.mark.asyncio
+ async def test_store_file_bytesio_creates_parent_directories(self,
sftp_hook_mocked):
+ """Assert that BytesIO uploads create parent dirs before writing
file."""
+ hook, sftp_client_mock = sftp_hook_mocked
+ sftp_client = sftp_client_mock.__aenter__.return_value
+ sftp_client.makedirs = AsyncMock()
+
+ await hook.store_file("/remote/new/dir/file.txt", BytesIO(b"abc"))
+
+ sftp_client.makedirs.assert_awaited_once_with("/remote/new/dir")
+ sftp_client.open.assert_called_once_with("/remote/new/dir/file.txt",
"wb")
+
+ @pytest.mark.asyncio
+ async def test_store_file_path_creates_parent_directories(self,
sftp_hook_mocked):
+ """Assert that local-path uploads create parent dirs before put()."""
+ hook, sftp_client_mock = sftp_hook_mocked
+ sftp_client = sftp_client_mock.__aenter__.return_value
+ sftp_client.makedirs = AsyncMock()
+
+ await hook.store_file("/remote/new/dir/file.txt", "/local/file.txt")
+
+ sftp_client.makedirs.assert_awaited_once_with("/remote/new/dir")
+ sftp_client.put.assert_awaited_once_with("/local/file.txt",
"/remote/new/dir/file.txt")
+
+ @pytest.mark.asyncio
+ async def test_store_file_rejects_raw_bytes(self, sftp_hook_mocked):
+ """Raw bytes must be wrapped in BytesIO to avoid path/content
ambiguity."""
+ hook, _ = sftp_hook_mocked
+
+ with pytest.raises(TypeError, match="Wrap raw bytes in BytesIO"):
+ await hook.store_file("/remote/new/dir/file.txt", b"abc") # type:
ignore[arg-type]
+
+ @pytest.mark.asyncio
+ async def test_walktree_recursive(self, sftp_hook_mocked):
+ """Assert that walktree recursively traverses files and directories."""
+ hook, sftp_client_mock = sftp_hook_mocked
+
+ sftp_client = sftp_client_mock.__aenter__.return_value
+
+ async def readdir_side_effect(path):
+ if path == "/dir":
+
+ class File:
+ filename = "file1"
+ attrs = type("attrs", (), {"permissions": stat.S_IFREG})
+
+ class Subdir:
+ filename = "subdir"
+ attrs = type("attrs", (), {"permissions": stat.S_IFDIR})
+
+ return [File(), Subdir()]
+ if path == "/dir/subdir":
+
+ class File:
+ filename = "file2"
+ attrs = type("attrs", (), {"permissions": stat.S_IFREG})
+
+ return [File()]
+ return []
+
+ sftp_client.readdir.side_effect = readdir_side_effect
+ sftp_client.realpath.side_effect = lambda path: path
+
+ files: list[str] = []
+ dirs: list[str] = []
+ unknowns: list[str] = []
+
+ await hook.walktree(
+ path="/dir",
+ fcallback=files.append,
+ dcallback=dirs.append,
+ ucallback=unknowns.append,
+ )
+
+ assert sorted(files) == sorted(["/dir/file1", "/dir/subdir/file2"])
+ assert dirs == ["/dir/subdir"]
+ assert unknowns == []
+ sftp_client_mock.__aexit__.assert_awaited()
+
+ @pytest.mark.asyncio
+ async def test_walktree_skips_disappearing_subdirectory(self,
sftp_hook_mocked):
+ """walktree should continue when a subdirectory disappears during
traversal."""
+ hook, sftp_client_mock = sftp_hook_mocked
+ sftp_client = sftp_client_mock.__aenter__.return_value
+
+ async def readdir_side_effect(path):
+ if path == "/dir":
+
+ class File:
+ filename = "file1"
+ attrs = type("attrs", (), {"permissions": stat.S_IFREG})
+
+ class Subdir:
+ filename = "subdir"
+ attrs = type("attrs", (), {"permissions": stat.S_IFDIR})
+
+ return [File(), Subdir()]
+ if path == "/dir/subdir":
+ raise SFTPNoSuchFile("gone")
+ return []
+
+ sftp_client.readdir.side_effect = readdir_side_effect
+ sftp_client.realpath.side_effect = lambda path: path
+
+ files: list[str] = []
+ dirs: list[str] = []
+ unknowns: list[str] = []
+
+ await hook.walktree(
+ path="/dir",
+ fcallback=files.append,
+ dcallback=dirs.append,
+ ucallback=unknowns.append,
+ )
+
+ assert files == ["/dir/file1"]
+ assert dirs == ["/dir/subdir"]
+ assert unknowns == []
+
+ @pytest.mark.asyncio
+ async def test_walktree_avoids_cycle_via_canonical_path(self,
sftp_hook_mocked):
+ """walktree should avoid infinite recursion when two directory paths
resolve to one target."""
+ hook, sftp_client_mock = sftp_hook_mocked
+ sftp_client = sftp_client_mock.__aenter__.return_value
+
+ async def readdir_side_effect(path):
+ if path == "/dir":
+
+ class Link:
+ filename = "link"
+ attrs = type("attrs", (), {"permissions": stat.S_IFDIR})
+
+ return [Link()]
+ if path == "/dir/link":
+
+ class LinkBack:
+ filename = "again"
+ attrs = type("attrs", (), {"permissions": stat.S_IFDIR})
+
+ return [LinkBack()]
+ return []
+
+ async def realpath_side_effect(path):
+ # Both paths resolve to same canonical target, which should stop
recursion.
+ if path in {"/dir/link", "/dir/link/again"}:
+ return "/dir/link"
+ return path
+
+ sftp_client.readdir.side_effect = readdir_side_effect
+ sftp_client.realpath.side_effect = realpath_side_effect
+
+ dirs: list[str] = []
+ await hook.walktree(
+ path="/dir",
+ fcallback=lambda _: None,
+ dcallback=dirs.append,
+ ucallback=lambda _: None,
+ )
+
+ assert dirs == ["/dir/link", "/dir/link/again"]
+
+ @pytest.mark.asyncio
+ async def test_list_directory_non_recursive(self, sftp_hook_mocked):
+ """Assert that default list_directory returns one-level entry names
only."""
+ hook, sftp_client_mock = sftp_hook_mocked
+
+ sftp_client = sftp_client_mock.__aenter__.return_value
+
+ async def readdir_side_effect(path):
+ if path == "/dir":
+
+ class File:
+ filename = "file1"
+ attrs = type("attrs", (), {"permissions": stat.S_IFREG})
+
+ class Subdir:
+ filename = "subdir"
+ attrs = type("attrs", (), {"permissions": stat.S_IFDIR})
+
+ return [File(), Subdir()]
+ if path == "/dir/subdir":
+
+ class File:
+ filename = "file2"
+ attrs = type("attrs", (), {"permissions": stat.S_IFREG})
+
+ return [File()]
+ return []
+
+ sftp_client.readdir.side_effect = readdir_side_effect
+
+ files = await hook.list_directory("/dir")
+ assert files is not None
+ assert sorted(files) == sorted(["file1", "subdir"])
+ sftp_client_mock.__aexit__.assert_awaited()
+
+ @pytest.mark.asyncio
+ async def test_list_directory_recursive_returns_relative_entries(self,
sftp_hook_mocked):
+ """Assert that recursive list_directory delegates to walktree and
returns relative paths."""
+ hook, sftp_client_mock = sftp_hook_mocked
+
+ sftp_client = sftp_client_mock.__aenter__.return_value
+
+ async def readdir_side_effect(path):
+ if path == "/dir":
+
+ class File:
+ filename = "file1"
+ attrs = type("attrs", (), {"permissions": stat.S_IFREG})
+
+ class Subdir:
+ filename = "subdir"
+ attrs = type("attrs", (), {"permissions": stat.S_IFDIR})
+
+ return [File(), Subdir()]
+ if path == "/dir/subdir":
+
+ class File:
+ filename = "file2"
+ attrs = type("attrs", (), {"permissions": stat.S_IFREG})
+
+ return [File()]
+ return []
+
+ sftp_client.readdir.side_effect = readdir_side_effect
+ sftp_client.realpath.side_effect = lambda path: path
+
+ files = await hook.list_directory("/dir", recursive=True)
+ assert files is not None
+ assert sorted(files) == sorted(["file1", "subdir/file2"])
+ sftp_client_mock.__aexit__.assert_awaited()
diff --git a/providers/sftp/tests/conftest.py
b/providers/sftp/tests/unit/sftp/pools/__init__.py
similarity index 90%
copy from providers/sftp/tests/conftest.py
copy to providers/sftp/tests/unit/sftp/pools/__init__.py
index f56ccce0a3f..13a83393a91 100644
--- a/providers/sftp/tests/conftest.py
+++ b/providers/sftp/tests/unit/sftp/pools/__init__.py
@@ -14,6 +14,3 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from __future__ import annotations
-
-pytest_plugins = "tests_common.pytest_plugin"
diff --git a/providers/sftp/tests/unit/sftp/pools/test_sftp.py
b/providers/sftp/tests/unit/sftp/pools/test_sftp.py
new file mode 100644
index 00000000000..98125f47198
--- /dev/null
+++ b/providers/sftp/tests/unit/sftp/pools/test_sftp.py
@@ -0,0 +1,372 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import asyncio
+
+import pytest
+
+from airflow.providers.sftp.pools.sftp import SFTPClientPool
+
+
+class TestSFTPClientPool:
+ @pytest.fixture(autouse=True)
+ def cleanup_singleton(self):
+ """Clear SFTPClientPool._instances before and after each test to
ensure test isolation."""
+ # Clear before test
+ SFTPClientPool._instances.clear()
+ yield
+ # Clear after test
+ SFTPClientPool._instances.clear()
+
+ @pytest.mark.asyncio
+ async def test_acquire_and_release(self, sftp_hook_mocked):
+ async with SFTPClientPool("test_conn", pool_size=2) as pool:
+ ssh, sftp = await pool.acquire()
+ assert ssh is not None
+ assert sftp is not None
+
+ await pool.release((ssh, sftp))
+ ssh2, sftp2 = await pool.acquire()
+ assert ssh2 is not None
+ assert sftp2 is not None
+ await pool.release((ssh2, sftp2))
+
+ @pytest.mark.asyncio
+ async def test_acquire_blocks_when_pool_full(self, sftp_hook_mocked):
+ async with SFTPClientPool("blocking_conn", pool_size=2) as pool:
+ first = await pool.acquire()
+ second = await pool.acquire()
+
+ # Third acquire should block until one of the held connections is
released.
+ with pytest.raises(asyncio.TimeoutError):
+ await asyncio.wait_for(pool.acquire(), timeout=0.1)
+
+ await pool.release(first)
+ third = await asyncio.wait_for(pool.acquire(), timeout=1)
+
+ await pool.release(second)
+ await pool.release(third)
+
+ @pytest.mark.asyncio
+ async def test_get_sftp_client_context_manager(self, sftp_hook_mocked,
mocker):
+ async with SFTPClientPool("test_conn", pool_size=1) as pool:
+ release_spy = mocker.spy(pool, "_release_pair")
+ async with pool.get_sftp_client() as sftp:
+ assert sftp is not None
+
+ # If the context manager releases correctly, the single slot can
be acquired again.
+ ssh2, sftp2 = await asyncio.wait_for(pool.acquire(), timeout=1)
+ assert ssh2 is not None
+ assert sftp2 is not None
+ await pool.release((ssh2, sftp2))
+
+ assert any(call.kwargs.get("faulty") is False for call in
release_spy.call_args_list)
+
+ @pytest.mark.asyncio
+ async def test_get_sftp_client_marks_connection_faulty_on_exception(self,
sftp_hook_mocked, mocker):
+ pool = SFTPClientPool("faulty_conn", pool_size=1)
+ release_spy = mocker.spy(pool, "_release_pair")
+
+ with pytest.raises(ValueError, match="boom"):
+ async with pool.get_sftp_client():
+ raise ValueError("boom")
+
+ assert any(call.kwargs.get("faulty") is True for call in
release_spy.call_args_list)
+
+ @pytest.mark.asyncio
+ async def
test_get_sftp_client_marks_connection_faulty_on_cancellation(self,
sftp_hook_mocked, mocker):
+ pool = SFTPClientPool("cancel_conn", pool_size=1)
+ release_spy = mocker.spy(pool, "_release_pair")
+
+ with pytest.raises(asyncio.CancelledError):
+ async with pool.get_sftp_client():
+ raise asyncio.CancelledError()
+
+ assert any(call.kwargs.get("faulty") is True for call in
release_spy.call_args_list)
+
+ @pytest.mark.asyncio
+ async def test_acquire_failure_releases_semaphore(self, sftp_hook_mocked,
monkeypatch):
+ from airflow.providers.sftp.hooks.sftp import SFTPHookAsync
+
+ orig_get_conn = SFTPHookAsync._get_conn
+
+ async def fail_get_conn(self):
+ raise Exception("fail")
+
+ monkeypatch.setattr(SFTPHookAsync, "_get_conn", fail_get_conn)
+
+ async with SFTPClientPool("test_conn", pool_size=2) as pool:
+ monkeypatch.setattr(pool, "_create_connection_retry_base_delay", 0)
+ monkeypatch.setattr(pool, "_create_connection_retry_max_delay", 0)
+ with pytest.raises(Exception, match="fail"):
+ await pool.acquire()
+
+ monkeypatch.setattr(SFTPHookAsync, "_get_conn", orig_get_conn)
+ ssh, sftp = await pool.acquire()
+ assert ssh is not None
+ assert sftp is not None
+ await pool.release((ssh, sftp))
+
+ @pytest.mark.asyncio
+ async def test_close(self, sftp_hook_mocked, mocker):
+ pool = SFTPClientPool("test_conn", pool_size=2)
+ close_spy = mocker.spy(pool, "close")
+
+ async with pool:
+ ssh, sftp = await pool.acquire()
+ await pool.release((ssh, sftp))
+
+ # __aexit__ must NOT call close() — the pool is a process-wide
singleton
+ assert close_spy.call_count == 0
+
+ # close() works when called explicitly
+ await pool.close()
+ assert close_spy.call_count == 1
+
+ @pytest.mark.asyncio
+ async def test_aexit_does_not_close_pool(self, sftp_hook_mocked):
+ """Exiting async-with must not close the singleton for other
concurrent users."""
+ pool = SFTPClientPool("no_close_conn", pool_size=2)
+ async with pool:
+ pass # exit block
+
+ state = pool._get_loop_state()
+ assert not state.closed
+
+ @pytest.mark.asyncio
+ async def test_close_warns_when_active_connections_exist(self,
sftp_hook_mocked):
+ class DummySSH:
+ def __init__(self):
+ self.closed = False
+
+ def close(self):
+ self.closed = True
+
+ class DummySFTP:
+ def __init__(self):
+ self.exited = False
+
+ def exit(self):
+ self.exited = True
+
+ pool = SFTPClientPool("warn_conn", pool_size=2)
+ await pool._ensure_initialized()
+ state = pool._get_loop_state()
+ ssh = DummySSH()
+ sftp = DummySFTP()
+ pair = (ssh, sftp)
+ state.in_use.add(pair)
+
+ await pool.close()
+
+ assert pair not in state.in_use
+ assert ssh.closed is True
+ assert sftp.exited is True
+
+ def test_pool_size_consistency_validation(self, sftp_hook_mocked):
+ """Second construction with a different pool_size keeps the original
singleton configuration."""
+ pool1 = SFTPClientPool("consistent_conn", pool_size=2)
+ pool2 = SFTPClientPool("consistent_conn", pool_size=5)
+
+ assert pool2 is pool1
+ assert pool2.pool_size == 2
+
+ def test_pool_size_consistency_with_default(self, sftp_hook_mocked,
monkeypatch):
+ """Default-first construction keeps its pool_size even if later calls
pass another value."""
+ monkeypatch.setattr("airflow.providers.sftp.pools.sftp.os.cpu_count",
lambda: 3)
+
+ pool1 = SFTPClientPool("default_conn")
+ pool2 = SFTPClientPool("default_conn", pool_size=10)
+
+ assert pool2 is pool1
+ assert pool2.pool_size == 3
+
+ def test_pool_size_defaults_to_one_when_cpu_count_is_none(self,
sftp_hook_mocked, monkeypatch):
+ """Default pool size falls back to 1 when os.cpu_count() is
unavailable."""
+ monkeypatch.setattr("airflow.providers.sftp.pools.sftp.os.cpu_count",
lambda: None)
+
+ pool = SFTPClientPool("none_cpu_conn")
+ assert pool.pool_size == 1
+
+ def test_pool_size_consistency_same_pool_size(self, sftp_hook_mocked):
+ """Test that creating a pool with the same pool_size for the same conn_id
succeeds."""
+ # Create first instance with pool_size=4
+ pool1 = SFTPClientPool("same_pool_conn", pool_size=4)
+ assert pool1.pool_size == 4
+
+ # Creating another instance with the same conn_id and pool_size should
succeed
+ pool2 = SFTPClientPool("same_pool_conn", pool_size=4)
+ assert pool2 is pool1 # Should be the same instance (singleton)
+ assert pool2.pool_size == 4
+
+ def test_pool_size_is_ignored_on_reuse_even_if_invalid(self,
sftp_hook_mocked):
+ """Subsequent constructions ignore pool_size and keep first singleton
configuration."""
+ pool1 = SFTPClientPool("reuse_invalid_conn", pool_size=2)
+
+ # This would be invalid for first construction, but must be ignored on
reuse.
+ pool2 = SFTPClientPool("reuse_invalid_conn", pool_size=0)
+
+ assert pool2 is pool1
+ assert pool2.pool_size == 2
+
+ def test_pool_works_across_separate_asyncio_run_calls(self,
sftp_hook_mocked):
+ """Regression: pool must not raise when used from two separate
asyncio.run() calls.
+
+ Each ``asyncio.run()`` creates and then destroys its own event loop.
The singleton
+ pool must lazily create fresh asyncio primitives for each new loop
rather than
+ reusing primitives that are bound to the now-dead previous loop.
Crucially,
+ ``async with`` must NOT close the pool between the two calls.
+ """
+ results: list[bool] = []
+
+ async def use_pool_once():
+ pool = SFTPClientPool("multi_loop_conn", pool_size=2)
+ async with pool.get_sftp_client() as sftp:
+ results.append(sftp is not None)
+
+ # First run — creates a new event loop, uses the pool, then shuts the
loop down.
+ asyncio.run(use_pool_once())
+ # Second run — creates a *different* event loop; the pool must not
reuse the
+ # primitives from the first (now-closed) loop.
+ asyncio.run(use_pool_once())
+
+ assert results == [True, True]
+
+ def test_pool_singleton_is_preserved_across_asyncio_run_calls(self,
sftp_hook_mocked):
+ """The same pool instance (singleton) must be returned for the same
conn_id
+ even when called from successive asyncio.run() invocations."""
+ pools: list[SFTPClientPool] = []
+
+ async def capture_pool():
+ pools.append(SFTPClientPool("singleton_loop_conn", pool_size=1))
+
+ asyncio.run(capture_pool())
+ asyncio.run(capture_pool())
+
+ assert pools[0] is pools[1]
+
+ def test_pool_per_loop_state_isolation(self, sftp_hook_mocked):
+ """Regression: two event loops must maintain isolated per-loop state
in same singleton pool."""
+ loop_states: list[tuple[asyncio.AbstractEventLoop, int]] = []
+
+ async def capture_loop_state():
+ pool = SFTPClientPool("isolated_conn", pool_size=2)
+ state = pool._get_loop_state()
+ loop = asyncio.get_running_loop()
+ # Acquire and hold a connection to track state per loop
+ con = await pool.acquire()
+ loop_states.append((loop, len(state.in_use)))
+ await pool.release(con)
+
+ asyncio.run(capture_loop_state())
+ asyncio.run(capture_loop_state())
+
+ # Both runs should have recorded state, but from different loops
+ assert len(loop_states) == 2
+ loop1, count1 = loop_states[0]
+ loop2, count2 = loop_states[1]
+ assert loop1 is not loop2
+ assert count1 == 1 # One connection in use during acquire
+ assert count2 == 1 # Same state captured in the second loop's context
+
+ def test_connection_reuse_within_same_loop(self, sftp_hook_mocked):
+ """Verify that connections acquired and released in the same loop are
reused."""
+
+ async def acquire_release_reuse():
+ pool = SFTPClientPool("reuse_conn", pool_size=1)
+ # First acquire
+ con1 = await pool.acquire()
+ await pool.release(con1)
+ # Second acquire should get the same connection back (from idle
queue)
+ con2 = await pool.acquire()
+ assert con1 is con2 # Same object reused
+ await pool.release(con2)
+
+ asyncio.run(acquire_release_reuse())
+
+ def test_no_semaphore_leak_on_cross_loop_release(self, sftp_hook_mocked):
+ """Regression: releasing a connection in a different loop than the one
it was acquired in
+ should not leak semaphore permits or raise errors.
+
+ Since asyncio primitives are per-loop, we expect release to work on
the current loop's state only.
+ """
+ acquired_pair = None
+
+ async def acquire_in_loop1():
+ nonlocal acquired_pair
+ pool = SFTPClientPool("leak_test_conn", pool_size=1)
+ acquired_pair = await pool.acquire()
+
+ async def release_in_loop2():
+ # This will be called with a different event loop
+ pool = SFTPClientPool("leak_test_conn", pool_size=1)
+ state = pool._get_loop_state()
+ # The pair was acquired in loop1, but we're now in loop2
+ # The loop2 state should show empty in_use (since the pair was
acquired in loop1's state)
+ # Attempting to release should not crash but log a warning
+ if acquired_pair:
+ await pool.release(acquired_pair)
+ # Verify the new loop's state is clean
+ assert len(state.in_use) == 0
+
+ asyncio.run(acquire_in_loop1())
+ asyncio.run(release_in_loop2())
+
+ @pytest.mark.asyncio
+ async def test_create_connection_retries_then_succeeds(self,
sftp_hook_mocked, monkeypatch):
+ pool = SFTPClientPool("retry_success_conn", pool_size=1)
+ await pool._ensure_initialized()
+
+ call_count = 0
+
+ async def flaky_create_connection():
+ nonlocal call_count
+ call_count += 1
+ if call_count == 1:
+ raise ConnectionError("transient")
+ return "ssh", "sftp"
+
+ monkeypatch.setattr(pool, "_create_connection",
flaky_create_connection)
+ monkeypatch.setattr(pool, "_create_connection_retry_base_delay", 0)
+ monkeypatch.setattr(pool, "_create_connection_retry_max_delay", 0)
+
+ pair = await pool.acquire()
+ assert pair == ("ssh", "sftp")
+ assert call_count == 2
+ await pool.release(pair)
+
+ @pytest.mark.asyncio
+ async def test_create_connection_retries_exhaust_and_releases_permit(self,
sftp_hook_mocked, monkeypatch):
+ pool = SFTPClientPool("retry_fail_conn", pool_size=1)
+ await pool._ensure_initialized()
+ state = pool._get_loop_state()
+
+ async def always_fail_create_connection():
+ raise ConnectionError("still failing")
+
+ monkeypatch.setattr(pool, "_create_connection",
always_fail_create_connection)
+ monkeypatch.setattr(pool, "_create_connection_retry_base_delay", 0)
+ monkeypatch.setattr(pool, "_create_connection_retry_max_delay", 0)
+ monkeypatch.setattr(pool, "_create_connection_max_retries", 1)
+
+ with pytest.raises(ConnectionError, match="still failing"):
+ await pool.acquire()
+
+ # Ensure permit is returned after failure so a later acquire can
proceed.
+ assert state.semaphore is not None
+ assert state.semaphore._value == pool.pool_size
diff --git a/uv.lock b/uv.lock
index 76d632fe2d4..30e362221b8 100644
--- a/uv.lock
+++ b/uv.lock
@@ -7113,6 +7113,7 @@ name = "apache-airflow-providers-sftp"
version = "5.7.4"
source = { editable = "providers/sftp" }
dependencies = [
+ { name = "aiofiles" },
{ name = "apache-airflow" },
{ name = "apache-airflow-providers-common-compat" },
{ name = "apache-airflow-providers-ssh" },
@@ -7143,6 +7144,7 @@ docs = [
[package.metadata]
requires-dist = [
+ { name = "aiofiles", specifier = ">=23.2.0" },
{ name = "apache-airflow", editable = "." },
{ name = "apache-airflow-providers-common-compat", editable =
"providers/common/compat" },
{ name = "apache-airflow-providers-openlineage", marker = "extra ==
'openlineage'", editable = "providers/openlineage" },