This is an automated email from the ASF dual-hosted git repository.

ferruzzi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new ff28969ff3 fix: EmrServerlessStartJobOperator not serializing DAGs correctly when partial/expand is used. (#38022)
ff28969ff3 is described below

commit ff28969ff3370034ed9246d4ce9d0022129b3152
Author: jliu0812 <114856647+jliu0...@users.noreply.github.com>
AuthorDate: Mon Mar 25 16:47:53 2024 -0500

    fix: EmrServerlessStartJobOperator not serializing DAGs correctly when partial/expand is used. (#38022)
---
 airflow/providers/amazon/aws/operators/emr.py      | 62 +++++++++++++++++++---
 .../amazon/aws/operators/test_emr_serverless.py    | 55 +++++++++++++++++++
 2 files changed, 111 insertions(+), 6 deletions(-)

diff --git a/airflow/providers/amazon/aws/operators/emr.py b/airflow/providers/amazon/aws/operators/emr.py
index 7c4d86c5e8..01e1567eab 100644
--- a/airflow/providers/amazon/aws/operators/emr.py
+++ b/airflow/providers/amazon/aws/operators/emr.py
@@ -1253,27 +1253,77 @@ class EmrServerlessStartJobOperator(BaseOperator):
         op_extra_links = []
 
         if isinstance(self, MappedOperator):
+            operator_class = self.operator_class
             enable_application_ui_links = self.partial_kwargs.get(
                 "enable_application_ui_links"
             ) or self.expand_input.value.get("enable_application_ui_links")
-            job_driver = self.partial_kwargs.get("job_driver") or self.expand_input.value.get("job_driver")
+            job_driver = self.partial_kwargs.get("job_driver", {}) or self.expand_input.value.get(
+                "job_driver", {}
+            )
             configuration_overrides = self.partial_kwargs.get(
                 "configuration_overrides"
             ) or self.expand_input.value.get("configuration_overrides")
 
+            # Configuration overrides can either be a list or a dictionary, depending on whether it's passed in as partial or expand.
+            if isinstance(configuration_overrides, list):
+                if any(
+                    [
+                        operator_class.is_monitoring_in_job_override(
+                            self=operator_class,
+                            config_key="s3MonitoringConfiguration",
+                            job_override=job_override,
+                        )
+                        for job_override in configuration_overrides
+                    ]
+                ):
+                    op_extra_links.extend([EmrServerlessS3LogsLink()])
+                if any(
+                    [
+                        operator_class.is_monitoring_in_job_override(
+                            self=operator_class,
+                            config_key="cloudWatchLoggingConfiguration",
+                            job_override=job_override,
+                        )
+                        for job_override in configuration_overrides
+                    ]
+                ):
+                    op_extra_links.extend([EmrServerlessCloudWatchLogsLink()])
+            else:
+                if operator_class.is_monitoring_in_job_override(
+                    self=operator_class,
+                    config_key="s3MonitoringConfiguration",
+                    job_override=configuration_overrides,
+                ):
+                    op_extra_links.extend([EmrServerlessS3LogsLink()])
+                if operator_class.is_monitoring_in_job_override(
+                    self=operator_class,
+                    config_key="cloudWatchLoggingConfiguration",
+                    job_override=configuration_overrides,
+                ):
+                    op_extra_links.extend([EmrServerlessCloudWatchLogsLink()])
+
         else:
+            operator_class = self
             enable_application_ui_links = self.enable_application_ui_links
             configuration_overrides = self.configuration_overrides
             job_driver = self.job_driver
 
+            if operator_class.is_monitoring_in_job_override(
+                "s3MonitoringConfiguration", configuration_overrides
+            ):
+                op_extra_links.extend([EmrServerlessS3LogsLink()])
+            if operator_class.is_monitoring_in_job_override(
+                "cloudWatchLoggingConfiguration", configuration_overrides
+            ):
+                op_extra_links.extend([EmrServerlessCloudWatchLogsLink()])
+
         if enable_application_ui_links:
             op_extra_links.extend([EmrServerlessDashboardLink()])
-            if "sparkSubmit" in job_driver:
+            if isinstance(job_driver, list):
+                if any("sparkSubmit" in ind_job_driver for ind_job_driver in job_driver):
+                    op_extra_links.extend([EmrServerlessLogsLink()])
+            elif "sparkSubmit" in job_driver:
                 op_extra_links.extend([EmrServerlessLogsLink()])
-        if self.is_monitoring_in_job_override("s3MonitoringConfiguration", configuration_overrides):
-            op_extra_links.extend([EmrServerlessS3LogsLink()])
-        if self.is_monitoring_in_job_override("cloudWatchLoggingConfiguration", configuration_overrides):
-            op_extra_links.extend([EmrServerlessCloudWatchLogsLink()])
 
         return tuple(op_extra_links)
 
diff --git a/tests/providers/amazon/aws/operators/test_emr_serverless.py b/tests/providers/amazon/aws/operators/test_emr_serverless.py
index eed292c3cd..35eae39210 100644
--- a/tests/providers/amazon/aws/operators/test_emr_serverless.py
+++ b/tests/providers/amazon/aws/operators/test_emr_serverless.py
@@ -25,12 +25,21 @@ from botocore.exceptions import WaiterError
 
 from airflow.exceptions import AirflowException, TaskDeferred
 from airflow.providers.amazon.aws.hooks.emr import EmrServerlessHook
+from airflow.providers.amazon.aws.links.emr import (
+    EmrServerlessCloudWatchLogsLink,
+    EmrServerlessDashboardLink,
+    EmrServerlessLogsLink,
+    EmrServerlessS3LogsLink,
+)
 from airflow.providers.amazon.aws.operators.emr import (
     EmrServerlessCreateApplicationOperator,
     EmrServerlessDeleteApplicationOperator,
     EmrServerlessStartJobOperator,
     EmrServerlessStopApplicationOperator,
 )
+from airflow.serialization.serialized_objects import (
+    SerializedBaseOperator,
+)
 from airflow.utils.types import NOTSET
 
 if TYPE_CHECKING:
@@ -1096,6 +1105,52 @@ class TestEmrServerlessStartJobOperator:
             job_run_id=job_run_id,
         )
 
+    def test_operator_extra_links_mapped_without_applicationui_enabled(
+        self,
+    ):
+        operator = EmrServerlessStartJobOperator.partial(
+            task_id=task_id,
+            application_id=application_id,
+            execution_role_arn=execution_role_arn,
+            job_driver=spark_job_driver,
+            enable_application_ui_links=False,
+        ).expand(
+            configuration_overrides=[s3_configuration_overrides, cloudwatch_configuration_overrides],
+        )
+
+        serialize = SerializedBaseOperator.serialize
+        deserialize = SerializedBaseOperator.deserialize_operator
+        deserialized_operator = deserialize(serialize(operator))
+
+        assert deserialized_operator.operator_extra_links == [
+            EmrServerlessS3LogsLink(),
+            EmrServerlessCloudWatchLogsLink(),
+        ]
+
+    def test_operator_extra_links_mapped_with_applicationui_enabled_at_partial(
+        self,
+    ):
+        operator = EmrServerlessStartJobOperator.partial(
+            task_id=task_id,
+            application_id=application_id,
+            execution_role_arn=execution_role_arn,
+            job_driver=spark_job_driver,
+            enable_application_ui_links=True,
+        ).expand(
+            configuration_overrides=[s3_configuration_overrides, cloudwatch_configuration_overrides],
+        )
+
+        serialize = SerializedBaseOperator.serialize
+        deserialize = SerializedBaseOperator.deserialize_operator
+        deserialized_operator = deserialize(serialize(operator))
+
+        assert deserialized_operator.operator_extra_links == [
+            EmrServerlessS3LogsLink(),
+            EmrServerlessCloudWatchLogsLink(),
+            EmrServerlessDashboardLink(),
+            EmrServerlessLogsLink(),
+        ]
+
 
 class TestEmrServerlessDeleteOperator:
     @mock.patch.object(EmrServerlessHook, "get_waiter")

Reply via email to