This is an automated email from the ASF dual-hosted git repository.

ferruzzi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push: new ff28969ff3 fix: EmrServerlessStartJobOperator not serializing DAGs correctly when partial/expand is used. (#38022) ff28969ff3 is described below commit ff28969ff3370034ed9246d4ce9d0022129b3152 Author: jliu0812 <114856647+jliu0...@users.noreply.github.com> AuthorDate: Mon Mar 25 16:47:53 2024 -0500 fix: EmrServerlessStartJobOperator not serializing DAGs correctly when partial/expand is used. (#38022) --- airflow/providers/amazon/aws/operators/emr.py | 62 +++++++++++++++++++--- .../amazon/aws/operators/test_emr_serverless.py | 55 +++++++++++++++++++ 2 files changed, 111 insertions(+), 6 deletions(-) diff --git a/airflow/providers/amazon/aws/operators/emr.py b/airflow/providers/amazon/aws/operators/emr.py index 7c4d86c5e8..01e1567eab 100644 --- a/airflow/providers/amazon/aws/operators/emr.py +++ b/airflow/providers/amazon/aws/operators/emr.py @@ -1253,27 +1253,77 @@ class EmrServerlessStartJobOperator(BaseOperator): op_extra_links = [] if isinstance(self, MappedOperator): + operator_class = self.operator_class enable_application_ui_links = self.partial_kwargs.get( "enable_application_ui_links" ) or self.expand_input.value.get("enable_application_ui_links") - job_driver = self.partial_kwargs.get("job_driver") or self.expand_input.value.get("job_driver") + job_driver = self.partial_kwargs.get("job_driver", {}) or self.expand_input.value.get( + "job_driver", {} + ) configuration_overrides = self.partial_kwargs.get( "configuration_overrides" ) or self.expand_input.value.get("configuration_overrides") + # Configuration overrides can either be a list or a dictionary, depending on whether it's passed in as partial or expand. 
+ if isinstance(configuration_overrides, list): + if any( + [ + operator_class.is_monitoring_in_job_override( + self=operator_class, + config_key="s3MonitoringConfiguration", + job_override=job_override, + ) + for job_override in configuration_overrides + ] + ): + op_extra_links.extend([EmrServerlessS3LogsLink()]) + if any( + [ + operator_class.is_monitoring_in_job_override( + self=operator_class, + config_key="cloudWatchLoggingConfiguration", + job_override=job_override, + ) + for job_override in configuration_overrides + ] + ): + op_extra_links.extend([EmrServerlessCloudWatchLogsLink()]) + else: + if operator_class.is_monitoring_in_job_override( + self=operator_class, + config_key="s3MonitoringConfiguration", + job_override=configuration_overrides, + ): + op_extra_links.extend([EmrServerlessS3LogsLink()]) + if operator_class.is_monitoring_in_job_override( + self=operator_class, + config_key="cloudWatchLoggingConfiguration", + job_override=configuration_overrides, + ): + op_extra_links.extend([EmrServerlessCloudWatchLogsLink()]) + else: + operator_class = self enable_application_ui_links = self.enable_application_ui_links configuration_overrides = self.configuration_overrides job_driver = self.job_driver + if operator_class.is_monitoring_in_job_override( + "s3MonitoringConfiguration", configuration_overrides + ): + op_extra_links.extend([EmrServerlessS3LogsLink()]) + if operator_class.is_monitoring_in_job_override( + "cloudWatchLoggingConfiguration", configuration_overrides + ): + op_extra_links.extend([EmrServerlessCloudWatchLogsLink()]) + if enable_application_ui_links: op_extra_links.extend([EmrServerlessDashboardLink()]) - if "sparkSubmit" in job_driver: + if isinstance(job_driver, list): + if any("sparkSubmit" in ind_job_driver for ind_job_driver in job_driver): + op_extra_links.extend([EmrServerlessLogsLink()]) + elif "sparkSubmit" in job_driver: op_extra_links.extend([EmrServerlessLogsLink()]) - if 
self.is_monitoring_in_job_override("s3MonitoringConfiguration", configuration_overrides): - op_extra_links.extend([EmrServerlessS3LogsLink()]) - if self.is_monitoring_in_job_override("cloudWatchLoggingConfiguration", configuration_overrides): - op_extra_links.extend([EmrServerlessCloudWatchLogsLink()]) return tuple(op_extra_links) diff --git a/tests/providers/amazon/aws/operators/test_emr_serverless.py b/tests/providers/amazon/aws/operators/test_emr_serverless.py index eed292c3cd..35eae39210 100644 --- a/tests/providers/amazon/aws/operators/test_emr_serverless.py +++ b/tests/providers/amazon/aws/operators/test_emr_serverless.py @@ -25,12 +25,21 @@ from botocore.exceptions import WaiterError from airflow.exceptions import AirflowException, TaskDeferred from airflow.providers.amazon.aws.hooks.emr import EmrServerlessHook +from airflow.providers.amazon.aws.links.emr import ( + EmrServerlessCloudWatchLogsLink, + EmrServerlessDashboardLink, + EmrServerlessLogsLink, + EmrServerlessS3LogsLink, +) from airflow.providers.amazon.aws.operators.emr import ( EmrServerlessCreateApplicationOperator, EmrServerlessDeleteApplicationOperator, EmrServerlessStartJobOperator, EmrServerlessStopApplicationOperator, ) +from airflow.serialization.serialized_objects import ( + SerializedBaseOperator, +) from airflow.utils.types import NOTSET if TYPE_CHECKING: @@ -1096,6 +1105,52 @@ class TestEmrServerlessStartJobOperator: job_run_id=job_run_id, ) + def test_operator_extra_links_mapped_without_applicationui_enabled( + self, + ): + operator = EmrServerlessStartJobOperator.partial( + task_id=task_id, + application_id=application_id, + execution_role_arn=execution_role_arn, + job_driver=spark_job_driver, + enable_application_ui_links=False, + ).expand( + configuration_overrides=[s3_configuration_overrides, cloudwatch_configuration_overrides], + ) + + serialize = SerializedBaseOperator.serialize + deserialize = SerializedBaseOperator.deserialize_operator + deserialized_operator = 
deserialize(serialize(operator)) + + assert deserialized_operator.operator_extra_links == [ + EmrServerlessS3LogsLink(), + EmrServerlessCloudWatchLogsLink(), + ] + + def test_operator_extra_links_mapped_with_applicationui_enabled_at_partial( + self, + ): + operator = EmrServerlessStartJobOperator.partial( + task_id=task_id, + application_id=application_id, + execution_role_arn=execution_role_arn, + job_driver=spark_job_driver, + enable_application_ui_links=True, + ).expand( + configuration_overrides=[s3_configuration_overrides, cloudwatch_configuration_overrides], + ) + + serialize = SerializedBaseOperator.serialize + deserialize = SerializedBaseOperator.deserialize_operator + deserialized_operator = deserialize(serialize(operator)) + + assert deserialized_operator.operator_extra_links == [ + EmrServerlessS3LogsLink(), + EmrServerlessCloudWatchLogsLink(), + EmrServerlessDashboardLink(), + EmrServerlessLogsLink(), + ] + class TestEmrServerlessDeleteOperator: @mock.patch.object(EmrServerlessHook, "get_waiter")