This is an automated email from the ASF dual-hosted git repository.

bolke pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new ab87cd02e4 Pass conn ID to ObjectStoragePath via URI (#35913)
ab87cd02e4 is described below

commit ab87cd02e4080df987a123d56cd9fa8c393a980d
Author: Tzu-ping Chung <uranu...@gmail.com>
AuthorDate: Fri Dec 1 21:03:10 2023 +0800

    Pass conn ID to ObjectStoragePath via URI (#35913)
    
    This enables an alternative ObjectStoragePath init syntax, using the
    auth section in the URI to supply conn ID instead of a separate keyword
    argument. The explicit keyword argument is honored if supplied.
---
 airflow/example_dags/tutorial_objectstorage.py     |  2 +-
 airflow/io/path.py                                 | 10 +++++++-
 airflow/io/store/__init__.py                       |  2 +-
 .../apache-airflow/core-concepts/objectstorage.rst | 16 +++++++------
 docs/apache-airflow/tutorial/objectstorage.rst     | 18 ++++++++++----
 tests/io/test_path.py                              | 28 +++++++++++++++++-----
 6 files changed, 56 insertions(+), 20 deletions(-)

diff --git a/airflow/example_dags/tutorial_objectstorage.py b/airflow/example_dags/tutorial_objectstorage.py
index 11d817400d..4660aa3c8e 100644
--- a/airflow/example_dags/tutorial_objectstorage.py
+++ b/airflow/example_dags/tutorial_objectstorage.py
@@ -43,7 +43,7 @@ aq_fields = {
 }
 
 # [START create_object_storage_path]
-base = ObjectStoragePath("s3://airflow-tutorial-data/", conn_id="aws_default")
+base = ObjectStoragePath("s3://aws_default@airflow-tutorial-data/")
 # [END create_object_storage_path]
 
 
diff --git a/airflow/io/path.py b/airflow/io/path.py
index f5eeb14eff..0e6f80254b 100644
--- a/airflow/io/path.py
+++ b/airflow/io/path.py
@@ -92,6 +92,7 @@ class ObjectStoragePath(CloudPath):
         cls: type[PT],
         *args: str | os.PathLike,
         scheme: str | None = None,
+        conn_id: str | None = None,
         **kwargs: typing.Any,
     ) -> PT:
         args_list = list(args)
@@ -137,7 +138,14 @@ class ObjectStoragePath(CloudPath):
         else:
             args_list.insert(0, parsed_url.path)
 
-        return cls._from_parts(args_list, url=parsed_url, **kwargs)  # type: ignore
+        # This matches the parsing logic in urllib.parse; see:
+        # https://github.com/python/cpython/blob/46adf6b701c440e047abf925df9a75a/Lib/urllib/parse.py#L194-L203
+        userinfo, have_info, hostinfo = parsed_url.netloc.rpartition("@")
+        if have_info:
+            conn_id = conn_id or userinfo or None
+            parsed_url = parsed_url._replace(netloc=hostinfo)
+
+        return cls._from_parts(args_list, url=parsed_url, conn_id=conn_id, **kwargs)  # type: ignore
 
     @functools.lru_cache
     def __hash__(self) -> int:
diff --git a/airflow/io/store/__init__.py b/airflow/io/store/__init__.py
index 6bf40c939f..a5a4bd12dd 100644
--- a/airflow/io/store/__init__.py
+++ b/airflow/io/store/__init__.py
@@ -131,7 +131,7 @@ def attach(
 
     if not alias:
         alias = f"{protocol}-{conn_id}" if conn_id else protocol
-        if store := _STORE_CACHE.get(alias, None):
+        if store := _STORE_CACHE.get(alias):
             return store
 
 _STORE_CACHE[alias] = store = ObjectStore(protocol=protocol, conn_id=conn_id, fs=fs)
diff --git a/docs/apache-airflow/core-concepts/objectstorage.rst b/docs/apache-airflow/core-concepts/objectstorage.rst
index 046cb48522..d72a734293 100644
--- a/docs/apache-airflow/core-concepts/objectstorage.rst
+++ b/docs/apache-airflow/core-concepts/objectstorage.rst
@@ -74,20 +74,22 @@ object you want to interact with. For example, to point to a bucket in s3, you w
 
     from airflow.io.path import ObjectStoragePath
 
-    base = ObjectStoragePath("s3://my-bucket/", conn_id="aws_default")  # conn_id is optional
+    base = ObjectStoragePath("s3://aws_default@my-bucket/")
 
+The username part of the URI is optional. It can alternatively be passed in as a separate keyword argument:
+
+.. code-block:: python
+
+    # Equivalent to the previous example.
+    base = ObjectStoragePath("s3://my-bucket/", conn_id="aws_default")
 
 Listing file-objects:
 
 .. code-block:: python
 
     @task
-    def list_files() -> list(ObjectStoragePath):
-        files = []
-        for f in base.iterdir():
-            if f.is_file():
-                files.append(f)
-
+    def list_files() -> list[ObjectStoragePath]:
+        files = [f for f in base.iterdir() if f.is_file()]
         return files
 
 
diff --git a/docs/apache-airflow/tutorial/objectstorage.rst b/docs/apache-airflow/tutorial/objectstorage.rst
index 89ffe0e8f9..610450b931 100644
--- a/docs/apache-airflow/tutorial/objectstorage.rst
+++ b/docs/apache-airflow/tutorial/objectstorage.rst
@@ -32,7 +32,7 @@ analytical database. You can do this by running ``pip install duckdb``. The tuto
 makes use of S3 Object Storage. This requires that the amazon provider is installed
 including ``s3fs`` by running ``pip install apache-airflow-providers-amazon[s3fs]``.
 If you would like to use a different storage provider, you can do so by changing the
-url in the ``create_object_storage_path`` function to the appropriate url for your
+URL in the ``create_object_storage_path`` function to the appropriate URL for your
 provider, for example by replacing ``s3://`` with ``gs://`` for Google Cloud Storage.
 You will also need the right provider to be installed then. Finally, you will need
 ``pandas``, which can be installed by running ``pip install pandas``.
@@ -49,9 +49,19 @@ It is the fundamental building block of the Object Storage API.
     :start-after: [START create_object_storage_path]
     :end-before: [END create_object_storage_path]
 
-The ObjectStoragePath constructor can take an optional connection id. If supplied
-it will use the connection to obtain the right credentials to access the backend.
-Otherwise it will revert to the default for that backend.
+The username part of the URL given to ObjectStoragePath should be a connection ID.
+The specified connection will be used to obtain the right credentials to access
+the backend. If it is omitted, the default connection for the backend will be used.
+
+The connection ID can alternatively be passed in with a keyword argument:
+
+.. code-block:: python
+
+    ObjectStoragePath("s3://airflow-tutorial-data/", conn_id="aws_default")
+
+This is useful when reusing a URL defined for another purpose (e.g. Dataset),
+which generally does not contain a username part. The explicit keyword argument
+takes precedence over the URL's username value if both are specified.
 
 It is safe to instantiate an ObjectStoragePath at the root of your DAG. Connections
 will not be created until the path is used. This means that you can create the
diff --git a/tests/io/test_path.py b/tests/io/test_path.py
index 1ac263c59f..54a675360f 100644
--- a/tests/io/test_path.py
+++ b/tests/io/test_path.py
@@ -52,6 +52,13 @@ class FakeRemoteFileSystem(LocalFileSystem):
 
 
 class TestFs:
+    def setup_class(self):
+        self._store_cache = _STORE_CACHE.copy()
+
+    def teardown(self):
+        _STORE_CACHE.clear()
+        _STORE_CACHE.update(self._store_cache)
+
     def test_alias(self):
         store = attach("file", alias="local")
         assert isinstance(store.fs, LocalFileSystem)
@@ -100,6 +107,19 @@ class TestFs:
 
         assert not o.exists()
 
+    @pytest.fixture()
+    def fake_fs(self):
+        fs = mock.Mock()
+        fs._strip_protocol.return_value = "/"
+        fs.conn_id = "fake"
+        return fs
+
+    def test_objectstoragepath_init_conn_id_in_uri(self, fake_fs):
+        fake_fs.stat.return_value = {"stat": "result"}
+        attach(protocol="fake", conn_id="fake", fs=fake_fs)
+        p = ObjectStoragePath("fake://fake@bucket/path")
+        assert p.stat() == {"stat": "result", "conn_id": "fake", "protocol": "fake"}
+
     @pytest.mark.parametrize(
         "fn, args, fn2, path, expected_args, expected_kwargs",
         [
@@ -124,12 +144,8 @@ class TestFs:
             ),
         ],
     )
-    def test_standard_extended_api(self, monkeypatch, fn, args, fn2, path, expected_args, expected_kwargs):
-        _fs = mock.Mock()
-        _fs._strip_protocol.return_value = "/"
-        _fs.conn_id = "fake"
-
-        store = attach(protocol="file", conn_id="fake", fs=_fs)
+    def test_standard_extended_api(self, fake_fs, fn, args, fn2, path, expected_args, expected_kwargs):
+        store = attach(protocol="file", conn_id="fake", fs=fake_fs)
         o = ObjectStoragePath(path, conn_id="fake")
 
         getattr(o, fn)(**args)

Reply via email to