SameerMesiah97 commented on code in PR #68277:
URL: https://github.com/apache/airflow/pull/68277#discussion_r3384904879
##########
providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py:
##########
@@ -275,6 +296,36 @@ def execute(self, context: Context) -> None:
hook._poll_k8s_driver_via_api()
return
hook.submit(self.application)
+
+ def execute_complete(self, context: Context, event: dict) -> None:
+ """
+ Handle the result emitted by SparkDriverTrigger.
+ Called by Airflow when the trigger fires after deferrable=True execution.
+ Raises AirflowException if the driver did not finish successfully.
+ """
+ from airflow.providers.common.compat.sdk import AirflowException
+ driver_state = event.get("driver_state", "UNKNOWN")
+ driver_id = event.get("driver_id", "unknown")
+ message = event.get("message", "")
+ if event.get("status") != "success":
+ raise AirflowException(
+ f"Spark driver {driver_id} did not finish successfully "
+ f"(state={driver_state}): {message}"
+ )
+ self.log.info("Spark driver %s finished successfully (state=%s)", driver_id, driver_state)
+ def _build_master_rest_urls(self) -> list[str]:
+ """
+ Build Spark master REST API base URLs for SparkDriverTrigger.
+ Supports HA (comma-separated master URL) and respects rest_scheme /
+ rest_port connection extras (same logic as get_job_status).
+ """
+ if self._hook is None:
+ self._hook = self._get_hook()
+ scheme = self._hook._connection.get("rest_scheme", "http")
+ rest_port = self._hook._connection.get("rest_port", 6066)
+ master_hosts = self._hook._connection["master"].replace("spark://", "").split(",")
+ return [f"{scheme}://{m.strip().split(':')[0]}:{rest_port}" for m in master_hosts]
Review Comment:
I think the list comprehension is doing way too much heavy lifting. An
explicit loop like this would be better:
```python
urls = []
for host in master_hosts:
hostname = host.strip().split(":")[0]
urls.append(f"{scheme}://{hostname}:{rest_port}")
return urls
```
This makes the string-parsing logic much easier to follow.
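A minimal, standalone sketch of the suggested loop, pulled out of the operator into a plain function so it can be exercised without a hook. The function name `build_master_rest_urls` and its parameters are hypothetical; the defaults (`http`, `6066`) mirror the fallbacks used in `_build_master_rest_urls` in the diff.

```python
def build_master_rest_urls(master: str, scheme: str = "http", rest_port: int = 6066) -> list[str]:
    """Turn a (possibly comma-separated, HA) Spark master URL into REST base URLs."""
    # Drop the spark:// prefix, then split the HA host list on commas.
    master_hosts = master.replace("spark://", "").split(",")
    urls = []
    for host in master_hosts:
        # Keep only the hostname; the submission port (e.g. 7077) is
        # replaced by the REST port.
        hostname = host.strip().split(":")[0]
        urls.append(f"{scheme}://{hostname}:{rest_port}")
    return urls


print(build_master_rest_urls("spark://master-a:7077, master-b:7077"))
# → ['http://master-a:6066', 'http://master-b:6066']
```

Each step (strip, drop the port, rebuild the URL) now sits on its own line, which is the readability gain the comment is after.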
##########
providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py:
##########
@@ -275,6 +296,36 @@ def execute(self, context: Context) -> None:
hook._poll_k8s_driver_via_api()
return
hook.submit(self.application)
+
+ def execute_complete(self, context: Context, event: dict) -> None:
+ """
+ Handle the result emitted by SparkDriverTrigger.
+ Called by Airflow when the trigger fires after deferrable=True execution.
+ Raises AirflowException if the driver did not finish successfully.
+ """
+ from airflow.providers.common.compat.sdk import AirflowException
+ driver_state = event.get("driver_state", "UNKNOWN")
+ driver_id = event.get("driver_id", "unknown")
Review Comment:
So I can see 2 issues here:
1) When the event payload contains no "driver_state", the fallback is `UNKNOWN`, which is itself a valid Spark driver state. Is that really the same as the state not being retrievable?
2) The fallback for `driver_id` (`"unknown"`) overlaps with the `UNKNOWN` fallback for `driver_state`, which may confuse users reading the logs.
I think it would be better to validate both like this:
```python
status = event.get("status")
driver_id = event.get("driver_id")
if status is None:
raise RuntimeError(f"Malformed trigger event: {event}")
if driver_id is None:
raise RuntimeError(f"Malformed trigger event: {event}")
```
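A minimal sketch of that validation, wrapped in a hypothetical standalone helper `validate_driver_event` so it can be tested outside the operator. The two checks are merged into one condition because they raise the same error; the event keys (`status`, `driver_id`) are the ones the diff reads.

```python
from typing import Any


def validate_driver_event(event: dict[str, Any]) -> tuple[str, str]:
    """Return (status, driver_id), failing loudly on a malformed trigger payload."""
    status = event.get("status")
    driver_id = event.get("driver_id")
    # No fallbacks: a missing key means the trigger payload is broken,
    # which is different from Spark reporting a driver state of UNKNOWN.
    if status is None or driver_id is None:
        raise RuntimeError(f"Malformed trigger event: {event}")
    return status, driver_id
```

Because nothing is defaulted, the log can never show a placeholder driver ID that looks like a real Spark state.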
##########
providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py:
##########
@@ -261,6 +269,19 @@ def execute(self, context: Context) -> None:
hook = self._hook
if self._track_driver_via_k8s_api:
hook._validate_track_driver_via_k8s_api_config()
+ if self.deferrable:
+ driver_id = self.submit_job(context)
+ master_urls = self._build_master_rest_urls()
+ from airflow.providers.apache.spark.triggers.spark_submit import SparkDriverTrigger
Review Comment:
This import belongs at the top of the file.
##########
providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py:
##########
@@ -275,6 +296,36 @@ def execute(self, context: Context) -> None:
hook._poll_k8s_driver_via_api()
return
hook.submit(self.application)
+
+ def execute_complete(self, context: Context, event: dict) -> None:
+ """
+ Handle the result emitted by SparkDriverTrigger.
+ Called by Airflow when the trigger fires after deferrable=True execution.
+ Raises AirflowException if the driver did not finish successfully.
+ """
+ from airflow.providers.common.compat.sdk import AirflowException
+ driver_state = event.get("driver_state", "UNKNOWN")
+ driver_id = event.get("driver_id", "unknown")
+ message = event.get("message", "")
+ if event.get("status") != "success":
+ raise AirflowException(
+ f"Spark driver {driver_id} did not finish successfully "
+ f"(state={driver_state}): {message}"
+ )
+ self.log.info("Spark driver %s finished successfully (state=%s)", driver_id, driver_state)
+ def _build_master_rest_urls(self) -> list[str]:
+ """
+ Build Spark master REST API base URLs for SparkDriverTrigger.
+ Supports HA (comma-separated master URL) and respects rest_scheme /
+ rest_port connection extras (same logic as get_job_status).
+ """
+ if self._hook is None:
+ self._hook = self._get_hook()
+ scheme = self._hook._connection.get("rest_scheme", "http")
Review Comment:
I am not sure why you are accessing the private attribute `_connection` to
get the scheme. Isn't there a public API you can use?
##########
providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py:
##########
@@ -134,6 +134,12 @@ class SparkSubmitOperator(ResumableJobMixin, BaseOperator):
omitted, Kerberos-enabled Spark connections with both ``keytab`` and
``principal`` configured use ``requests-kerberos`` automatically.
Defaults to ``None`` (no auth for non-Kerberos connections).
+ :param deferrable: If ``True``, submits the job then defers to
+ ``SparkDriverTrigger``; the worker slot is freed while the trigger
+ polls the Spark REST API. On crash the trigger is re-created from
+ its serialised state (no reconnect needed). On user-clear, execute()
+ runs again and a fresh job is submitted.
+ If ``False`` (default), the sync ``ResumableJobMixin`` path is used.
Review Comment:
This docstring entry is a bit too verbose. I don't think there is a need to go deep into the mechanics of deferrable mode here. I understand you want to explain how it interacts with an adjacent change, i.e. `ResumableJobMixin`, but I believe that is better suited to a code comment. I think you should use this instead:
`:param deferrable: Run operator in deferrable mode.`
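A minimal sketch of the split being suggested: a one-line `:param:` entry in the docstring, with the deferrable-versus-`ResumableJobMixin` mechanics moved into a code comment. `ExampleSparkOperator` is a hypothetical stand-in, not the real `SparkSubmitOperator`, and the comment only restates the behaviour described in the PR's original docstring.

```python
class ExampleSparkOperator:
    """
    Hypothetical operator showing where each piece of documentation lives.

    :param deferrable: Run operator in deferrable mode.
    """

    def __init__(self, deferrable: bool = False) -> None:
        # Per the PR: when deferrable=True, execute() submits the job and
        # defers to the trigger, freeing the worker slot while it polls.
        # When False (default), the synchronous ResumableJobMixin path is used.
        self.deferrable = deferrable
```

The rendered API docs stay short, while maintainers reading the code still see why the two paths differ.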
##########
providers/apache/spark/src/airflow/providers/apache/spark/operators/spark_submit.py:
##########
@@ -275,6 +296,36 @@ def execute(self, context: Context) -> None:
hook._poll_k8s_driver_via_api()
return
hook.submit(self.application)
+
+ def execute_complete(self, context: Context, event: dict) -> None:
+ """
+ Handle the result emitted by SparkDriverTrigger.
+ Called by Airflow when the trigger fires after deferrable=True execution.
+ Raises AirflowException if the driver did not finish successfully.
+ """
+ from airflow.providers.common.compat.sdk import AirflowException
Review Comment:
Same here: this import belongs at the top of the file. Also, there has been a consensus for a while to move away from `AirflowException` in favour of native Python exceptions, e.g. `ValueError` or `RuntimeError`, or SDK exceptions where available. I would remove `AirflowException`.
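A minimal sketch of the failure branch of `execute_complete` with `AirflowException` swapped for the built-in `RuntimeError`, pulled into a hypothetical helper `check_driver_result`. It assumes the payload has already been validated, so `status` and `driver_id` are read directly; the message format is copied from the diff.

```python
from typing import Any


def check_driver_result(event: dict[str, Any]) -> str:
    """Return the driver ID, or raise a built-in exception if the driver failed."""
    driver_id = event["driver_id"]
    if event["status"] != "success":
        # A built-in RuntimeError instead of AirflowException: any uncaught
        # exception fails the task, so no Airflow-specific type is required.
        raise RuntimeError(
            f"Spark driver {driver_id} did not finish successfully "
            f"(state={event.get('driver_state')}): {event.get('message', '')}"
        )
    return driver_id
```

This also drops the local `from airflow.providers.common.compat.sdk import AirflowException` import entirely.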
--
This is an automated message from the Apache Git Service.