This is an automated email from the ASF dual-hosted git repository.

eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 645d52f129 Add `use_krb5ccache` option to `SparkSubmitHook` (#34386)
645d52f129 is described below

commit 645d52f1298c49b2111d058971e1a9f159f1e257
Author: zeotuan <48720253+zeot...@users.noreply.github.com>
AuthorDate: Sat Oct 21 20:19:29 2023 +1100

    Add `use_krb5ccache` option to `SparkSubmitHook` (#34386)
---
 .../providers/apache/spark/hooks/spark_submit.py   | 33 +++++++++++++++++++++-
 .../apache/spark/hooks/test_spark_submit.py        | 27 +++++++++++++++++-
 2 files changed, 58 insertions(+), 2 deletions(-)

diff --git a/airflow/providers/apache/spark/hooks/spark_submit.py b/airflow/providers/apache/spark/hooks/spark_submit.py
index 75c13c8099..d519eb3e6e 100644
--- a/airflow/providers/apache/spark/hooks/spark_submit.py
+++ b/airflow/providers/apache/spark/hooks/spark_submit.py
@@ -78,6 +78,8 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
     :param verbose: Whether to pass the verbose flag to spark-submit process for debugging
     :param spark_binary: The command to use for spark submit.
                          Some distros may use spark2-submit or spark3-submit.
+    :param use_krb5ccache: if True, configure spark to use ticket cache instead of relying
+        on keytab for Kerberos login
     """
 
     conn_name_attr = "conn_id"
@@ -120,6 +122,8 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
         env_vars: dict[str, Any] | None = None,
         verbose: bool = False,
         spark_binary: str | None = None,
+        *,
+        use_krb5ccache: bool = False,
     ) -> None:
         super().__init__()
         self._conf = conf or {}
@@ -138,7 +142,8 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
         self._executor_memory = executor_memory
         self._driver_memory = driver_memory
         self._keytab = keytab
-        self._principal = principal
+        self._principal = self._resolve_kerberos_principal(principal) if use_krb5ccache else principal
+        self._use_krb5ccache = use_krb5ccache
         self._proxy_user = proxy_user
         self._name = name
         self._num_executors = num_executors
@@ -317,6 +322,12 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
             connection_cmd += ["--keytab", self._keytab]
         if self._principal:
             connection_cmd += ["--principal", self._principal]
+        if self._use_krb5ccache:
+            if not os.getenv("KRB5CCNAME"):
+                raise AirflowException(
+                    "KRB5CCNAME environment variable required to use ticket ccache is missing."
+                )
+            connection_cmd += ["--conf", "spark.kerberos.renewal.credentials=ccache"]
         if self._proxy_user:
             connection_cmd += ["--proxy-user", self._proxy_user]
         if self._name:
@@ -383,6 +394,26 @@ class SparkSubmitHook(BaseHook, LoggingMixin):
 
         return connection_cmd
 
+    def _resolve_kerberos_principal(self, principal: str | None) -> str:
+        """Resolve kerberos principal if airflow > 2.8.
+
+        TODO: delete when min airflow version >= 2.8 and import directly from airflow.security.kerberos
+        """
+        from packaging.version import Version
+
+        from airflow.version import version
+
+        if Version(version) < Version("2.8"):
+            from airflow.utils.net import get_hostname
+
+            return principal or airflow_conf.get_mandatory_value("kerberos", "principal").replace(
+                "_HOST", get_hostname()
+            )
+        else:
+            from airflow.security.kerberos import get_kerberos_principle
+
+            return get_kerberos_principle(principal)
+
     def submit(self, application: str = "", **kwargs: Any) -> None:
         """
         Remote Popen to execute the spark-submit job.
diff --git a/tests/providers/apache/spark/hooks/test_spark_submit.py b/tests/providers/apache/spark/hooks/test_spark_submit.py
index d3feccb081..9bf828e94d 100644
--- a/tests/providers/apache/spark/hooks/test_spark_submit.py
+++ b/tests/providers/apache/spark/hooks/test_spark_submit.py
@@ -61,6 +61,7 @@ class TestSparkSubmitHook:
             "args should keep embedded spaces",
             "baz",
         ],
+        "use_krb5ccache": True,
     }
 
     @staticmethod
@@ -141,7 +142,10 @@ class TestSparkSubmitHook:
             )
         )
 
-    def test_build_spark_submit_command(self):
+    @patch(
+        "airflow.providers.apache.spark.hooks.spark_submit.os.getenv", return_value="/tmp/airflow_krb5_ccache"
+    )
+    def test_build_spark_submit_command(self, mock_get_env):
         # Given
         hook = SparkSubmitHook(**self._config)
 
@@ -183,6 +187,8 @@ class TestSparkSubmitHook:
             "privileged_user.keytab",
             "--principal",
             "user/sp...@airflow.org",
+            "--conf",
+            "spark.kerberos.renewal.credentials=ccache",
             "--proxy-user",
             "sample_user",
             "--name",
@@ -200,6 +206,25 @@ class TestSparkSubmitHook:
             "baz",
         ]
         assert expected_build_cmd == cmd
+        mock_get_env.assert_called_with("KRB5CCNAME")
+
+    @patch("airflow.configuration.conf.get_mandatory_value")
+    def test_resolve_spark_submit_env_vars_use_krb5ccache_missing_principal(self, mock_get_madantory_value):
+        mock_principle = "airflow"
+        mock_get_madantory_value.return_value = mock_principle
+        hook = SparkSubmitHook(conn_id="spark_yarn_cluster", principal=None, use_krb5ccache=True)
+        mock_get_madantory_value.assert_called_with("kerberos", "principal")
+        assert hook._principal == mock_principle
+
+    def test_resolve_spark_submit_env_vars_use_krb5ccache_missing_KRB5CCNAME_env(self):
+        hook = SparkSubmitHook(
+            conn_id="spark_yarn_cluster", principal="user/sp...@airflow.org", use_krb5ccache=True
+        )
+        with pytest.raises(
+            AirflowException,
+            match="KRB5CCNAME environment variable required to use ticket ccache is missing.",
+        ):
+            hook._build_spark_submit_command(self._spark_job_file)
 
     def test_build_track_driver_status_command(self):
         # note this function is only relevant for spark setup matching below condition

Reply via email to