This is an automated email from the ASF dual-hosted git repository.

eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 16ffc87ff4 Improve error handling in AWS Links (#35518)
16ffc87ff4 is described below

commit 16ffc87ff42f702221f6cb7e42e08bc208183cf1
Author: Andrey Anshin <andrey.ans...@taragol.is>
AuthorDate: Wed Nov 8 21:34:06 2023 +0400

    Improve error handling in AWS Links (#35518)
    
    * Improve error handling in AWS Links
    
    * Remove debug code
---
 airflow/providers/amazon/aws/links/base_aws.py    |  3 ++
 airflow/providers/amazon/aws/links/emr.py         |  2 +-
 tests/providers/amazon/aws/links/test_base_aws.py | 35 +++++++++++++++++++----
 tests/providers/amazon/aws/links/test_emr.py      |  2 +-
 4 files changed, 35 insertions(+), 7 deletions(-)

diff --git a/airflow/providers/amazon/aws/links/base_aws.py b/airflow/providers/amazon/aws/links/base_aws.py
index 97130fabd6..83da4dd93c 100644
--- a/airflow/providers/amazon/aws/links/base_aws.py
+++ b/airflow/providers/amazon/aws/links/base_aws.py
@@ -20,6 +20,7 @@ from __future__ import annotations
 from typing import TYPE_CHECKING, ClassVar
 
 from airflow.models import BaseOperatorLink, XCom
+from airflow.providers.amazon.aws.utils.suppress import return_on_error
 
 if TYPE_CHECKING:
     from airflow.models import BaseOperator
@@ -60,6 +61,7 @@ class BaseAwsLink(BaseOperatorLink):
         except KeyError:
             return ""
 
+    @return_on_error("")
     def get_link(
         self,
         operator: BaseOperator,
@@ -77,6 +79,7 @@ class BaseAwsLink(BaseOperatorLink):
         return self.format_link(**conf) if conf else ""
 
     @classmethod
+    @return_on_error(None)
     def persist(
         cls, context: Context, operator: BaseOperator, region_name: str, aws_partition: str, **kwargs
     ) -> None:
diff --git a/airflow/providers/amazon/aws/links/emr.py b/airflow/providers/amazon/aws/links/emr.py
index 6c8cd2181e..1bd651a00c 100644
--- a/airflow/providers/amazon/aws/links/emr.py
+++ b/airflow/providers/amazon/aws/links/emr.py
@@ -43,7 +43,7 @@ class EmrLogsLink(BaseAwsLink):
     format_str = BASE_AWS_CONSOLE_LINK + "/s3/buckets/{log_uri}?region={region_name}&prefix={job_flow_id}/"
 
     def format_link(self, **kwargs) -> str:
-        if not kwargs["log_uri"]:
+        if not kwargs.get("log_uri"):
             return ""
         return super().format_link(**kwargs)
 
diff --git a/tests/providers/amazon/aws/links/test_base_aws.py b/tests/providers/amazon/aws/links/test_base_aws.py
index a8bf17c3db..546ead164d 100644
--- a/tests/providers/amazon/aws/links/test_base_aws.py
+++ b/tests/providers/amazon/aws/links/test_base_aws.py
@@ -18,16 +18,17 @@ from __future__ import annotations
 
 from abc import abstractmethod
 from typing import TYPE_CHECKING, NamedTuple
-from unittest.mock import MagicMock
+from unittest import mock
 
 import pytest
 
+from airflow.models.xcom import XCom
 from airflow.providers.amazon.aws.links.base_aws import BaseAwsLink
 from airflow.serialization.serialized_objects import SerializedDAG
 from tests.test_utils.mock_operators import MockOperator
 
 if TYPE_CHECKING:
-    from airflow.models import TaskInstance
+    from airflow.models.taskinstance import TaskInstance
 
 XCOM_KEY = "test_xcom_key"
 CUSTOM_KEYS = {
@@ -63,7 +64,7 @@ class TestBaseAwsLink:
         ],
     )
     def test_persist(self, region_name, aws_partition, keywords, expected_value):
-        mock_context = MagicMock()
+        mock_context = mock.MagicMock()
 
         SimpleBaseAwsLink.persist(
             context=mock_context,
@@ -81,7 +82,7 @@ class TestBaseAwsLink:
         )
 
     def test_disable_xcom_push(self):
-        mock_context = MagicMock()
+        mock_context = mock.MagicMock()
         SimpleBaseAwsLink.persist(
             context=mock_context,
             operator=MockOperator(task_id="test_task_id", do_xcom_push=False),
@@ -91,6 +92,21 @@ class TestBaseAwsLink:
         ti = mock_context["ti"]
         ti.xcom_push.assert_not_called()
 
+    def test_suppress_error_on_xcom_push(self):
+        mock_context = mock.MagicMock()
+        with mock.patch.object(MockOperator, "xcom_push", side_effect=PermissionError("FakeError")) as m:
+            SimpleBaseAwsLink.persist(
+                context=mock_context,
+                operator=MockOperator(task_id="test_task_id"),
+                region_name="eu-east-1",
+                aws_partition="aws",
+            )
+            m.assert_called_once_with(
+                mock_context,
+                key="test_xcom_key",
+                value={"region_name": "eu-east-1", "aws_domain": "aws.amazon.com"},
+            )
+
 
 def link_test_operator(*links):
     """Helper for create mock operator class with extra links"""
@@ -162,7 +178,7 @@ class BaseAwsLinksTestCase:
         """Helper method for create extra link URL from the parameters."""
         task, ti = self.create_op_and_ti(self.link_class, dag_id="test_extra_link", task_id=self.task_id)
 
-        mock_context = MagicMock()
+        mock_context = mock.MagicMock()
         mock_context.__getitem__.side_effect = {"ti": ti}.__getitem__
 
         self.link_class.persist(
@@ -209,6 +225,15 @@ class BaseAwsLinksTestCase:
             deserialized_task.get_extra_links(ti, self.link_class.name) == ""
         ), "Operator link should be empty for deserialized task with no XCom push"
 
+    def test_suppress_error_on_xcom_pull(self):
+        """Test ignore any error on XCom pull"""
+        with mock.patch.object(XCom, "get_value", side_effect=OSError("FakeError")) as m:
+            op, ti = self.create_op_and_ti(
+                self.link_class, dag_id="test_error_on_xcom_pull", task_id=self.task_id
+            )
+            self.link_class().get_link(op, ti_key=ti.key)
+            m.assert_called_once()
+
     @abstractmethod
     def test_extra_link(self, **kwargs):
         """Test: Expected URL Link."""
diff --git a/tests/providers/amazon/aws/links/test_emr.py b/tests/providers/amazon/aws/links/test_emr.py
index 59c883362a..c7f12983e8 100644
--- a/tests/providers/amazon/aws/links/test_emr.py
+++ b/tests/providers/amazon/aws/links/test_emr.py
@@ -68,7 +68,7 @@ class TestEmrLogsLink(BaseAwsLinksTestCase):
     @pytest.mark.parametrize(
         "log_url_extra",
         [
-            pytest.param({}, id="no-log-uri", marks=pytest.mark.xfail),
+            pytest.param({}, id="no-log-uri"),
             pytest.param({"log_uri": None}, id="log-uri-none"),
             pytest.param({"log_uri": ""}, id="log-uri-empty"),
         ],

Reply via email to