This is an automated email from the ASF dual-hosted git repository.

ephraimanierobi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new e7d8f5bdaba Use SSH to authenticate GitDagBundle (#44976)
e7d8f5bdaba is described below

commit e7d8f5bdaba74f9a3980f4b71e89879f4ff8bfc0
Author: Ephraim Anierobi <splendidzig...@gmail.com>
AuthorDate: Thu Jan 16 12:04:09 2025 +0100

    Use SSH to authenticate GitDagBundle (#44976)
    
    * Use SSH to authenticate GitDagBundle
    
    This uses SSH hook to authenticate GitDagBundle when provided.
    
    * Add tests
    
    * Account for remotes with ssh
    
    * renames
    
    * fix tests
    
    * Refactor code
    
    * Use githook
    
    * fixup! Use githook
    
    * Populate the connection form with git type connection
    
    * Mark test_dag_bundles as db test
    
    * Add names to the extra items
    
    * Update airflow/dag_processing/bundles/git.py
    
    Co-authored-by: Felix Uellendall <felue...@users.noreply.github.com>
    
    * Fix refresh
    
    * Apply suggestions from code review
    
    Co-authored-by: Jed Cunningham <66968678+jedcunning...@users.noreply.github.com>
    
    * Remove ssh hook inheritance
    
    * fixup! Remove ssh hook inheritance
    
    * Apply suggestions from code review
    
    Co-authored-by: Jed Cunningham <66968678+jedcunning...@users.noreply.github.com>
    
    * Fix code and link to dag processor
    
    * Apply suggestions from code review
    
    Co-authored-by: Jed Cunningham <66968678+jedcunning...@users.noreply.github.com>
    
    ---------
    
    Co-authored-by: Felix Uellendall <felue...@users.noreply.github.com>
    Co-authored-by: Jed Cunningham <66968678+jedcunning...@users.noreply.github.com>
---
 airflow/dag_processing/bundles/base.py       |  10 ++
 airflow/dag_processing/bundles/git.py        | 107 ++++++++++--
 airflow/dag_processing/bundles/manager.py    |   1 +
 airflow/dag_processing/bundles/provider.yaml |  44 +++++
 airflow/dag_processing/manager.py            |   4 +
 airflow/providers_manager.py                 |   5 +
 tests/dag_processing/test_dag_bundles.py     | 237 +++++++++++++++++++++++----
 7 files changed, 361 insertions(+), 47 deletions(-)

diff --git a/airflow/dag_processing/bundles/base.py b/airflow/dag_processing/bundles/base.py
index ea560f1be26..cf0467b372a 100644
--- a/airflow/dag_processing/bundles/base.py
+++ b/airflow/dag_processing/bundles/base.py
@@ -50,6 +50,16 @@ class BaseDagBundle(ABC):
         self.name = name
         self.version = version
         self.refresh_interval = refresh_interval
+        self.is_initialized: bool = False
+
+    def initialize(self) -> None:
+        """
+        Initialize the bundle.
+
+        This method is called by the DAG processor before the bundle is used,
+        and allows for deferring expensive operations until that point in time.
+        """
+        self.is_initialized = True
 
     @property
     def _dag_bundle_root_storage_path(self) -> Path:
diff --git a/airflow/dag_processing/bundles/git.py b/airflow/dag_processing/bundles/git.py
index d731f65db3b..4b2a19de364 100644
--- a/airflow/dag_processing/bundles/git.py
+++ b/airflow/dag_processing/bundles/git.py
@@ -17,8 +17,9 @@
 
 from __future__ import annotations
 
+import json
 import os
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any
 from urllib.parse import urlparse
 
 from git import Repo
@@ -26,63 +27,141 @@ from git.exc import BadName
 
 from airflow.dag_processing.bundles.base import BaseDagBundle
 from airflow.exceptions import AirflowException
+from airflow.hooks.base import BaseHook
+from airflow.utils.log.logging_mixin import LoggingMixin
 
 if TYPE_CHECKING:
     from pathlib import Path
 
 
-class GitDagBundle(BaseDagBundle):
+class GitHook(BaseHook):
+    """
+    Hook for git repositories.
+
+    :param git_conn_id: Connection ID for SSH connection to the repository
+
+    """
+
+    conn_name_attr = "git_conn_id"
+    default_conn_name = "git_default"
+    conn_type = "git"
+    hook_name = "GIT"
+
+    @classmethod
+    def get_ui_field_behaviour(cls) -> dict[str, Any]:
+        return {
+            "hidden_fields": ["schema"],
+            "relabeling": {
+                "login": "Username",
+                "host": "Repository URL",
+                "password": "Access Token (optional)",
+            },
+            "placeholders": {
+                "extra": json.dumps(
+                    {
+                        "key_file": "optional/path/to/keyfile",
+                    }
+                )
+            },
+        }
+
+    def __init__(self, git_conn_id="git_default", *args, **kwargs):
+        super().__init__()
+        connection = self.get_connection(git_conn_id)
+        self.repo_url = connection.host
+        self.auth_token = connection.password
+        self.key_file = connection.extra_dejson.get("key_file")
+        self.env: dict[str, str] = {}
+        if self.key_file:
+            self.env["GIT_SSH_COMMAND"] = f"ssh -i {self.key_file} -o IdentitiesOnly=yes"
+        self._process_git_auth_url()
+
+    def _process_git_auth_url(self):
+        if not isinstance(self.repo_url, str):
+            return
+        if self.auth_token and self.repo_url.startswith("https://"):
+            self.repo_url = self.repo_url.replace("https://", f"https://{self.auth_token}@")
+        elif not self.repo_url.startswith("git@") or not self.repo_url.startswith("https://"):
+            self.repo_url = os.path.expanduser(self.repo_url)
+
+
+class GitDagBundle(BaseDagBundle, LoggingMixin):
     """
     git DAG bundle - exposes a git repository as a DAG bundle.
 
     Instead of cloning the repository every time, we clone the repository once into a bare repo from the source
     and then do a clone for each version from there.
 
-    :param repo_url: URL of the git repository
     :param tracking_ref: Branch or tag for this DAG bundle
     :param subdir: Subdirectory within the repository where the DAGs are stored (Optional)
+    :param git_conn_id: Connection ID for SSH/token based connection to the repository (Optional)
     """
 
     supports_versioning = True
 
-    def __init__(self, *, repo_url: str, tracking_ref: str, subdir: str | None = None, **kwargs) -> None:
+    def __init__(
+        self,
+        *,
+        tracking_ref: str,
+        subdir: str | None = None,
+        git_conn_id: str = "git_default",
+        **kwargs,
+    ) -> None:
         super().__init__(**kwargs)
-        self.repo_url = repo_url
         self.tracking_ref = tracking_ref
         self.subdir = subdir
-
         self.bare_repo_path = self._dag_bundle_root_storage_path / "git" / self.name
         self.repo_path = (
             self._dag_bundle_root_storage_path / "git" / (self.name + f"+{self.version or self.tracking_ref}")
         )
+        self.git_conn_id = git_conn_id
+        self.hook = GitHook(git_conn_id=self.git_conn_id)
+        self.repo_url = self.hook.repo_url
+
+    def _initialize(self):
         self._clone_bare_repo_if_required()
         self._ensure_version_in_bare_repo()
         self._clone_repo_if_required()
         self.repo.git.checkout(self.tracking_ref)
-
         if self.version:
             if not self._has_version(self.repo, self.version):
                 self.repo.remotes.origin.fetch()
-
             self.repo.head.set_reference(self.repo.commit(self.version))
             self.repo.head.reset(index=True, working_tree=True)
         else:
             self.refresh()
 
+    def initialize(self) -> None:
+        if not self.repo_url:
+            raise AirflowException(f"Connection {self.git_conn_id} doesn't have a git_repo_url")
+        if isinstance(self.repo_url, os.PathLike):
+            self._initialize()
+        elif not self.repo_url.startswith("git@") or not self.repo_url.endswith(".git"):
+            raise AirflowException(
+                f"Invalid git URL: {self.repo_url}. URL must start with git@ and end with .git"
+            )
+        else:
+            self._initialize()
+        super().initialize()
+
     def _clone_repo_if_required(self) -> None:
         if not os.path.exists(self.repo_path):
+            self.log.info("Cloning repository to %s from %s", self.repo_path, self.bare_repo_path)
             Repo.clone_from(
                 url=self.bare_repo_path,
                 to_path=self.repo_path,
             )
+
         self.repo = Repo(self.repo_path)
 
     def _clone_bare_repo_if_required(self) -> None:
         if not os.path.exists(self.bare_repo_path):
+            self.log.info("Cloning bare repository to %s", self.bare_repo_path)
             Repo.clone_from(
                 url=self.repo_url,
                 to_path=self.bare_repo_path,
                 bare=True,
+                env=self.hook.env,
             )
         self.bare_repo = Repo(self.bare_repo_path)
 
@@ -90,7 +169,7 @@ class GitDagBundle(BaseDagBundle):
         if not self.version:
             return
         if not self._has_version(self.bare_repo, self.version):
-            self.bare_repo.remotes.origin.fetch("+refs/heads/*:refs/heads/*")
+            self._fetch_bare_repo()
             if not self._has_version(self.bare_repo, self.version):
                 raise AirflowException(f"Version {self.version} not found in the repository")
 
@@ -121,11 +200,17 @@ class GitDagBundle(BaseDagBundle):
         except BadName:
             return False
 
+    def _fetch_bare_repo(self):
+        if self.hook.env:
+            with self.bare_repo.git.custom_environment(GIT_SSH_COMMAND=self.hook.env.get("GIT_SSH_COMMAND")):
+                self.bare_repo.remotes.origin.fetch("+refs/heads/*:refs/heads/*")
+        else:
+            self.bare_repo.remotes.origin.fetch("+refs/heads/*:refs/heads/*")
+
     def refresh(self) -> None:
         if self.version:
             raise AirflowException("Refreshing a specific version is not supported")
-
-        self.bare_repo.remotes.origin.fetch("+refs/heads/*:refs/heads/*")
+        self._fetch_bare_repo()
         self.repo.remotes.origin.pull()
 
     def _convert_git_ssh_url_to_https(self) -> str:
diff --git a/airflow/dag_processing/bundles/manager.py b/airflow/dag_processing/bundles/manager.py
index 1ae751f8d33..ad1ebc58891 100644
--- a/airflow/dag_processing/bundles/manager.py
+++ b/airflow/dag_processing/bundles/manager.py
@@ -96,6 +96,7 @@ class DagBundlesManager(LoggingMixin):
             class_ = import_string(cfg["classpath"])
             kwargs = cfg["kwargs"]
             self._bundle_config[name] = (class_, kwargs)
+        self.log.info("DAG bundles loaded: %s", ", ".join(self._bundle_config.keys()))
 
     @provide_session
     def sync_bundles_to_db(self, *, session: Session = NEW_SESSION) -> None:
diff --git a/airflow/dag_processing/bundles/provider.yaml b/airflow/dag_processing/bundles/provider.yaml
new file mode 100644
index 00000000000..9ca5d1479f2
--- /dev/null
+++ b/airflow/dag_processing/bundles/provider.yaml
@@ -0,0 +1,44 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+---
+package-name: apache-airflow-providers-bundles
+name: GIT
+description: |
+    `GIT <https://git-scm.com/>`__
+
+state: not-ready
+source-date-epoch: 1726861127
+# note that those versions are maintained by release manager - do not update them manually
+versions:
+  - 1.0.0
+
+dependencies:
+  - apache-airflow-providers-ssh
+
+integrations:
+  - integration-name: GIT (Git)
+
+hooks:
+  - integration-name: GIT
+    python-modules:
+      - airflow.dag_processing.bundles.git
+
+
+connection-types:
+  - hook-class-name: airflow.dag_processing.bundles.git.GitHook
+    connection-type: git
diff --git a/airflow/dag_processing/manager.py b/airflow/dag_processing/manager.py
index 96c7fe4f0ed..220b55edce6 100644
--- a/airflow/dag_processing/manager.py
+++ b/airflow/dag_processing/manager.py
@@ -653,6 +653,10 @@ class DagFileProcessorManager:
         self.log.info("Refreshing DAG bundles")
 
         for bundle in self._dag_bundles:
+            # TODO: AIP-66 handle errors in the case of incomplete cloning? And test this.
+            #  What if the cloning/refreshing took too long(longer than the dag processor timeout)
+            if not bundle.is_initialized:
+                bundle.initialize()
             # TODO: AIP-66 test to make sure we get a fresh record from the db and it's not cached
             with create_session() as session:
                 bundle_model = session.get(DagBundleModel, bundle.name)
diff --git a/airflow/providers_manager.py b/airflow/providers_manager.py
index 575306a840b..9b39439384f 100644
--- a/airflow/providers_manager.py
+++ b/airflow/providers_manager.py
@@ -175,6 +175,9 @@ def _create_customized_form_field_behaviours_schema_validator():
 
 
 def _check_builtin_provider_prefix(provider_package: str, class_name: str) -> bool:
+    if "bundles" in provider_package:
+        # TODO: AIP-66: remove this when this package is moved to providers directory
+        return True
     if provider_package.startswith("apache-airflow"):
         provider_path = provider_package[len("apache-") :].replace("-", ".")
         if not class_name.startswith(provider_path):
@@ -676,6 +679,8 @@ class ProvidersManager(LoggingMixin, metaclass=Singleton):
                 self._add_provider_info_from_local_source_files_on_path(path)
             except Exception as e:
                 log.warning("Error when loading 'provider.yaml' files from %s airflow sources: %s", path, e)
+        # TODO: AIP-66: Remove this when the package is moved to providers
+        self._add_provider_info_from_local_source_files_on_path("airflow/dag_processing")
 
     def _add_provider_info_from_local_source_files_on_path(self, path) -> None:
         """
diff --git a/tests/dag_processing/test_dag_bundles.py b/tests/dag_processing/test_dag_bundles.py
index d450a561313..49b7da1a03a 100644
--- a/tests/dag_processing/test_dag_bundles.py
+++ b/tests/dag_processing/test_dag_bundles.py
@@ -26,11 +26,16 @@ from git import Repo
 
 from airflow.dag_processing.bundles.base import BaseDagBundle
 from airflow.dag_processing.bundles.dagfolder import DagsFolderDagBundle
-from airflow.dag_processing.bundles.git import GitDagBundle
+from airflow.dag_processing.bundles.git import GitDagBundle, GitHook
 from airflow.dag_processing.bundles.local import LocalDagBundle
 from airflow.exceptions import AirflowException
+from airflow.models import Connection
+from airflow.utils import db
 
 from tests_common.test_utils.config import conf_vars
+from tests_common.test_utils.db import clear_db_connections
+
+pytestmark = pytest.mark.db_test
 
 
 @pytest.fixture(autouse=True)
@@ -107,27 +112,111 @@ def git_repo(tmp_path_factory):
     return (directory, repo)
 
 
+AIRFLOW_HTTPS_URL = "https://github.com/apache/airflow.git"
+AIRFLOW_GIT = "git@github.com:apache/airflow.git"
+ACCESS_TOKEN = "my_access_token"
+CONN_DEFAULT = "git_default"
+CONN_HTTPS = "my_git_conn"
+CONN_HTTPS_PASSWORD = "my_git_conn_https_password"
+CONN_ONLY_PATH = "my_git_conn_only_path"
+CONN_NO_REPO_URL = "my_git_conn_no_repo_url"
+
+
+class TestGitHook:
+    @classmethod
+    def teardown_class(cls) -> None:
+        clear_db_connections()
+
+    @classmethod
+    def setup_class(cls) -> None:
+        db.merge_conn(
+            Connection(
+                conn_id=CONN_DEFAULT,
+                host=AIRFLOW_GIT,
+                conn_type="git",
+            )
+        )
+        db.merge_conn(
+            Connection(
+                conn_id=CONN_HTTPS,
+                host=AIRFLOW_HTTPS_URL,
+                password=ACCESS_TOKEN,
+                conn_type="git",
+            )
+        )
+        db.merge_conn(
+            Connection(
+                conn_id=CONN_HTTPS_PASSWORD,
+                host=AIRFLOW_HTTPS_URL,
+                conn_type="git",
+                password=ACCESS_TOKEN,
+            )
+        )
+        db.merge_conn(
+            Connection(
+                conn_id=CONN_ONLY_PATH,
+                host="path/to/repo",
+                conn_type="git",
+            )
+        )
+
+    @pytest.mark.parametrize(
+        "conn_id, expected_repo_url",
+        [
+            (CONN_DEFAULT, AIRFLOW_GIT),
+            (CONN_HTTPS, f"https://{ACCESS_TOKEN}@github.com/apache/airflow.git"),
+            (CONN_HTTPS_PASSWORD, f"https://{ACCESS_TOKEN}@github.com/apache/airflow.git"),
+            (CONN_ONLY_PATH, "path/to/repo"),
+        ],
+    )
+    def test_correct_repo_urls(self, conn_id, expected_repo_url):
+        hook = GitHook(git_conn_id=conn_id)
+        assert hook.repo_url == expected_repo_url
+
+
 class TestGitDagBundle:
+    @classmethod
+    def teardown_class(cls) -> None:
+        clear_db_connections()
+
+    @classmethod
+    def setup_class(cls) -> None:
+        db.merge_conn(
+            Connection(
+                conn_id="git_default",
+                host="git@github.com:apache/airflow.git",
+                conn_type="git",
+            )
+        )
+        db.merge_conn(
+            Connection(
+                conn_id=CONN_NO_REPO_URL,
+                conn_type="git",
+            )
+        )
+
     def test_supports_versioning(self):
         assert GitDagBundle.supports_versioning is True
 
     def test_uses_dag_bundle_root_storage_path(self, git_repo):
         repo_path, repo = git_repo
-        bundle = GitDagBundle(
-            name="test", refresh_interval=300, repo_url=repo_path, tracking_ref=GIT_DEFAULT_BRANCH
-        )
+        bundle = GitDagBundle(name="test", refresh_interval=300, tracking_ref=GIT_DEFAULT_BRANCH)
         assert str(bundle._dag_bundle_root_storage_path) in str(bundle.path)
 
-    def test_get_current_version(self, git_repo):
+    @mock.patch("airflow.dag_processing.bundles.git.GitHook")
+    def test_get_current_version(self, mock_githook, git_repo):
         repo_path, repo = git_repo
-        bundle = GitDagBundle(
-            name="test", refresh_interval=300, repo_url=repo_path, tracking_ref=GIT_DEFAULT_BRANCH
-        )
+        mock_githook.return_value.repo_url = repo_path
+        bundle = GitDagBundle(name="test", refresh_interval=300, tracking_ref=GIT_DEFAULT_BRANCH)
+
+        bundle.initialize()
 
         assert bundle.get_current_version() == repo.head.commit.hexsha
 
-    def test_get_specific_version(self, git_repo):
+    @mock.patch("airflow.dag_processing.bundles.git.GitHook")
+    def test_get_specific_version(self, mock_githook, git_repo):
         repo_path, repo = git_repo
+        mock_githook.return_value.repo_url = repo_path
         starting_commit = repo.head.commit
 
         # Add new file to the repo
@@ -141,17 +230,19 @@ class TestGitDagBundle:
             name="test",
             refresh_interval=300,
             version=starting_commit.hexsha,
-            repo_url=repo_path,
             tracking_ref=GIT_DEFAULT_BRANCH,
         )
+        bundle.initialize()
 
         assert bundle.get_current_version() == starting_commit.hexsha
 
         files_in_repo = {f.name for f in bundle.path.iterdir() if f.is_file()}
         assert {"test_dag.py"} == files_in_repo
 
-    def test_get_tag_version(self, git_repo):
+    @mock.patch("airflow.dag_processing.bundles.git.GitHook")
+    def test_get_tag_version(self, mock_githook, git_repo):
         repo_path, repo = git_repo
+        mock_githook.return_value.repo_url = repo_path
         starting_commit = repo.head.commit
 
         # add tag
@@ -169,17 +260,18 @@ class TestGitDagBundle:
             name="test",
             refresh_interval=300,
             version="test",
-            repo_url=repo_path,
             tracking_ref=GIT_DEFAULT_BRANCH,
         )
-
+        bundle.initialize()
         assert bundle.get_current_version() == starting_commit.hexsha
 
         files_in_repo = {f.name for f in bundle.path.iterdir() if f.is_file()}
         assert {"test_dag.py"} == files_in_repo
 
-    def test_get_latest(self, git_repo):
+    @mock.patch("airflow.dag_processing.bundles.git.GitHook")
+    def test_get_latest(self, mock_githook, git_repo):
         repo_path, repo = git_repo
+        mock_githook.return_value.repo_url = repo_path
         starting_commit = repo.head.commit
 
         file_path = repo_path / "new_test.py"
@@ -188,22 +280,22 @@ class TestGitDagBundle:
         repo.index.add([file_path])
         repo.index.commit("Another commit")
 
-        bundle = GitDagBundle(
-            name="test", refresh_interval=300, repo_url=repo_path, tracking_ref=GIT_DEFAULT_BRANCH
-        )
+        bundle = GitDagBundle(name="test", refresh_interval=300, tracking_ref=GIT_DEFAULT_BRANCH)
+        bundle.initialize()
 
         assert bundle.get_current_version() != starting_commit.hexsha
 
         files_in_repo = {f.name for f in bundle.path.iterdir() if f.is_file()}
         assert {"test_dag.py", "new_test.py"} == files_in_repo
 
-    def test_refresh(self, git_repo):
+    @mock.patch("airflow.dag_processing.bundles.git.GitHook")
+    def test_refresh(self, mock_githook, git_repo):
         repo_path, repo = git_repo
+        mock_githook.return_value.repo_url = repo_path
         starting_commit = repo.head.commit
 
-        bundle = GitDagBundle(
-            name="test", refresh_interval=300, repo_url=repo_path, tracking_ref=GIT_DEFAULT_BRANCH
-        )
+        bundle = GitDagBundle(name="test", refresh_interval=300, tracking_ref=GIT_DEFAULT_BRANCH)
+        bundle.initialize()
 
         assert bundle.get_current_version() == starting_commit.hexsha
 
@@ -223,27 +315,34 @@ class TestGitDagBundle:
         files_in_repo = {f.name for f in bundle.path.iterdir() if f.is_file()}
         assert {"test_dag.py", "new_test.py"} == files_in_repo
 
-    def test_head(self, git_repo):
+    @mock.patch("airflow.dag_processing.bundles.git.GitHook")
+    def test_head(self, mock_githook, git_repo):
         repo_path, repo = git_repo
+        mock_githook.return_value.repo_url = repo_path
 
         repo.create_head("test")
-        bundle = GitDagBundle(name="test", refresh_interval=300, repo_url=repo_path, tracking_ref="test")
+        bundle = GitDagBundle(name="test", refresh_interval=300, tracking_ref="test")
+        bundle.initialize()
         assert bundle.repo.head.ref.name == "test"
 
-    def test_version_not_found(self, git_repo):
+    @mock.patch("airflow.dag_processing.bundles.git.GitHook")
+    def test_version_not_found(self, mock_githook, git_repo):
         repo_path, repo = git_repo
+        mock_githook.return_value.repo_url = repo_path
+        bundle = GitDagBundle(
+            name="test",
+            refresh_interval=300,
+            version="not_found",
+            tracking_ref=GIT_DEFAULT_BRANCH,
+        )
 
         with pytest.raises(AirflowException, match="Version not_found not found in the repository"):
-            GitDagBundle(
-                name="test",
-                refresh_interval=300,
-                version="not_found",
-                repo_url=repo_path,
-                tracking_ref=GIT_DEFAULT_BRANCH,
-            )
+            bundle.initialize()
 
-    def test_subdir(self, git_repo):
+    @mock.patch("airflow.dag_processing.bundles.git.GitHook")
+    def test_subdir(self, mock_githook, git_repo):
         repo_path, repo = git_repo
+        mock_githook.return_value.repo_url = repo_path
 
         subdir = "somesubdir"
         subdir_path = repo_path / subdir
@@ -258,15 +357,75 @@ class TestGitDagBundle:
         bundle = GitDagBundle(
             name="test",
             refresh_interval=300,
-            repo_url=repo_path,
             tracking_ref=GIT_DEFAULT_BRANCH,
             subdir=subdir,
         )
+        bundle.initialize()
 
         files_in_repo = {f.name for f in bundle.path.iterdir() if f.is_file()}
         assert str(bundle.path).endswith(subdir)
         assert {"some_new_file.py"} == files_in_repo
 
+    def test_raises_when_no_repo_url(self):
+        bundle = GitDagBundle(
+            name="test",
+            refresh_interval=300,
+            git_conn_id=CONN_NO_REPO_URL,
+            tracking_ref=GIT_DEFAULT_BRANCH,
+        )
+        with pytest.raises(
+            AirflowException, match=f"Connection {CONN_NO_REPO_URL} doesn't have a git_repo_url"
+        ):
+            bundle.initialize()
+
+    @mock.patch("airflow.dag_processing.bundles.git.GitHook")
+    @mock.patch("airflow.dag_processing.bundles.git.Repo")
+    def test_with_path_as_repo_url(self, mock_gitRepo, mock_githook):
+        bundle = GitDagBundle(
+            name="test",
+            refresh_interval=300,
+            git_conn_id=CONN_ONLY_PATH,
+            tracking_ref=GIT_DEFAULT_BRANCH,
+        )
+        bundle.initialize()
+        assert mock_gitRepo.clone_from.call_count == 2
+        assert mock_gitRepo.return_value.git.checkout.call_count == 1
+
+    @mock.patch("airflow.dag_processing.bundles.git.Repo")
+    def test_refresh_with_git_connection(self, mock_gitRepo):
+        bundle = GitDagBundle(
+            name="test",
+            refresh_interval=300,
+            git_conn_id="git_default",
+            tracking_ref=GIT_DEFAULT_BRANCH,
+        )
+        bundle.initialize()
+        bundle.refresh()
+        # check remotes called twice. one at initialize and one at refresh above
+        assert mock_gitRepo.return_value.remotes.origin.fetch.call_count == 2
+
+    @pytest.mark.parametrize(
+        "repo_url",
+        [
+            pytest.param("https://github.com/apache/airflow", id="https_url"),
+            pytest.param("airflow@example:apache/airflow.git", id="does_not_start_with_git_at"),
+            pytest.param("git@example:apache/airflow", id="does_not_end_with_dot_git"),
+        ],
+    )
+    @mock.patch("airflow.dag_processing.bundles.git.GitHook")
+    def test_repo_url_validation_for_ssh(self, mock_hook, repo_url, session):
+        mock_hook.return_value.repo_url = repo_url
+        bundle = GitDagBundle(
+            name="test",
+            refresh_interval=300,
+            git_conn_id="git_default",
+            tracking_ref=GIT_DEFAULT_BRANCH,
+        )
+        with pytest.raises(
+            AirflowException, match=f"Invalid git URL: {repo_url}. URL must start with git@ and end with .git"
+        ):
+            bundle.initialize()
+
     @pytest.mark.parametrize(
         "repo_url, expected_url",
         [
@@ -280,11 +439,18 @@ class TestGitDagBundle:
         ],
     )
     @mock.patch("airflow.dag_processing.bundles.git.Repo")
-    def test_view_url(self, mock_gitrepo, repo_url, expected_url):
+    def test_view_url(self, mock_gitrepo, repo_url, expected_url, session):
+        session.query(Connection).delete()
+        conn = Connection(
+            conn_id="git_default",
+            host=repo_url,
+            conn_type="git",
+        )
+        session.add(conn)
+        session.commit()
         bundle = GitDagBundle(
             name="test",
             refresh_interval=300,
-            repo_url=repo_url,
             tracking_ref="main",
         )
         view_url = bundle.view_url("0f0f0f")
@@ -295,7 +461,6 @@ class TestGitDagBundle:
         bundle = GitDagBundle(
             name="test",
             refresh_interval=300,
-            repo_url="git@github.com:apache/airflow.git",
             tracking_ref="main",
         )
         view_url = bundle.view_url(None)

Reply via email to