This is an automated email from the ASF dual-hosted git repository. taragolis pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push: new 5983506df3 Add operator to invoke Azure-Synapse pipeline (#35091) 5983506df3 is described below commit 5983506df370325f7b23a182798341d17d091a32 Author: ambika-garg <70703123+ambika-g...@users.noreply.github.com> AuthorDate: Thu Nov 16 04:51:42 2023 -0500 Add operator to invoke Azure-Synapse pipeline (#35091) * Update to resolve rebase conflicts and pass pre-commit hooks * Feature: Add Azure Synapse Pipeline run operator in Microsoft Provider * Add a hook to interact with Azure Synapse Analytics * Add an operator to trigger Synapse pipeline from DAG and operator link * Add unit tests for operator and hook * Update provider.yaml to support new operator, operator link and hook * Update provider_dependencies to install azure-synapse-artifacts * Add spellings to resolve build docs test * Fix: Pytest Tests * Add Mock Synapse Workspace URL * Set Default wait_for_termination to False * Rename files as per standards * Move AzureSynapsePipelineHook class to synapse.py for ease of find * Fix all imports for the class * Remove the file from provider.yaml * Delete the synapse_pipeline.py file --------- Co-authored-by: Ambika Garg <ambika.g...@growthjockey.com> --- airflow/providers/microsoft/azure/hooks/synapse.py | 191 +++++++++++++++++- .../providers/microsoft/azure/operators/synapse.py | 187 +++++++++++++++++- airflow/providers/microsoft/azure/provider.yaml | 8 +- docs/spelling_wordlist.txt | 2 + generated/provider_dependencies.json | 1 + .../microsoft/azure/hooks/test_synapse_pipeline.py | 159 +++++++++++++++ .../microsoft/azure/operators/test_synapse.py | 218 ++++++++++++++++++++- .../azure/example_synapse_run_pipeline.py | 59 ++++++ 8 files changed, 813 insertions(+), 12 deletions(-) diff --git a/airflow/providers/microsoft/azure/hooks/synapse.py b/airflow/providers/microsoft/azure/hooks/synapse.py index e284194376..d48109d694 100644 --- a/airflow/providers/microsoft/azure/hooks/synapse.py +++ 
b/airflow/providers/microsoft/azure/hooks/synapse.py @@ -19,10 +19,12 @@ from __future__ import annotations import time from typing import TYPE_CHECKING, Any, Union +from azure.core.exceptions import ServiceRequestError from azure.identity import ClientSecretCredential, DefaultAzureCredential +from azure.synapse.artifacts import ArtifactsClient from azure.synapse.spark import SparkClient -from airflow.exceptions import AirflowTaskTimeout +from airflow.exceptions import AirflowException, AirflowTaskTimeout from airflow.hooks.base import BaseHook from airflow.providers.microsoft.azure.utils import ( add_managed_identity_connection_widgets, @@ -31,6 +33,7 @@ from airflow.providers.microsoft.azure.utils import ( ) if TYPE_CHECKING: + from azure.synapse.artifacts.models import CreateRunResponse, PipelineRun from azure.synapse.spark.models import SparkBatchJobOptions Credentials = Union[ClientSecretCredential, DefaultAzureCredential] @@ -217,3 +220,189 @@ class AzureSynapseHook(BaseHook): :param job_id: The synapse spark job identifier. """ self.get_conn().spark_batch.cancel_spark_batch_job(job_id) + + +class AzureSynapsePipelineRunStatus: + """Azure Synapse pipeline operation statuses.""" + + QUEUED = "Queued" + IN_PROGRESS = "InProgress" + SUCCEEDED = "Succeeded" + FAILED = "Failed" + CANCELING = "Canceling" + CANCELLED = "Cancelled" + TERMINAL_STATUSES = {CANCELLED, FAILED, SUCCEEDED} + INTERMEDIATE_STATES = {QUEUED, IN_PROGRESS, CANCELING} + FAILURE_STATES = {FAILED, CANCELLED} + + +class AzureSynapsePipelineRunException(AirflowException): + """An exception that indicates a pipeline run failed to complete.""" + + +class AzureSynapsePipelineHook(BaseHook): + """ + A hook to interact with Azure Synapse Pipeline. + + :param azure_synapse_conn_id: The :ref:`Azure Synapse connection id<howto/connection:synapse>`. + :param azure_synapse_workspace_dev_endpoint: The Azure Synapse Workspace development endpoint. 
+ """ + + conn_type: str = "azure_synapse_pipeline" + conn_name_attr: str = "azure_synapse_conn_id" + default_conn_name: str = "azure_synapse_connection" + hook_name: str = "Azure Synapse Pipeline" + + @staticmethod + def get_connection_form_widgets() -> dict[str, Any]: + """Returns connection widgets to add to connection form.""" + from flask_appbuilder.fieldwidgets import BS3TextFieldWidget + from flask_babel import lazy_gettext + from wtforms import StringField + + return { + "tenantId": StringField(lazy_gettext("Tenant ID"), widget=BS3TextFieldWidget()), + } + + @staticmethod + def get_ui_field_behaviour() -> dict[str, Any]: + """Returns custom field behaviour.""" + return { + "hidden_fields": ["schema", "port", "extra"], + "relabeling": {"login": "Client ID", "password": "Secret", "host": "Synapse Workspace URL"}, + } + + def __init__( + self, azure_synapse_workspace_dev_endpoint: str, azure_synapse_conn_id: str = default_conn_name + ): + self._conn = None + self.conn_id = azure_synapse_conn_id + self.azure_synapse_workspace_dev_endpoint = azure_synapse_workspace_dev_endpoint + super().__init__() + + def _get_field(self, extras, name): + return get_field( + conn_id=self.conn_id, + conn_type=self.conn_type, + extras=extras, + field_name=name, + ) + + def get_conn(self) -> ArtifactsClient: + if self._conn is not None: + return self._conn + + conn = self.get_connection(self.conn_id) + extras = conn.extra_dejson + tenant = self._get_field(extras, "tenantId") + + credential: Credentials + if conn.login is not None and conn.password is not None: + if not tenant: + raise ValueError("A Tenant ID is required when authenticating with Client ID and Secret.") + + credential = ClientSecretCredential( + client_id=conn.login, client_secret=conn.password, tenant_id=tenant + ) + else: + credential = DefaultAzureCredential() + self._conn = self._create_client(credential, self.azure_synapse_workspace_dev_endpoint) + + if self._conn is not None: + return self._conn + else: + 
raise ValueError("Failed to create ArtifactsClient") + + @staticmethod + def _create_client(credential: Credentials, endpoint: str): + return ArtifactsClient(credential=credential, endpoint=endpoint) + + def run_pipeline(self, pipeline_name: str, **config: Any) -> CreateRunResponse: + """ + Run a Synapse pipeline. + + :param pipeline_name: The pipeline name. + :param config: Extra parameters for the Synapse Artifact Client. + :return: The pipeline run Id. + """ + return self.get_conn().pipeline.create_pipeline_run(pipeline_name, **config) + + def get_pipeline_run(self, run_id: str) -> PipelineRun: + """ + Get the pipeline run. + + :param run_id: The pipeline run identifier. + :return: The pipeline run. + """ + return self.get_conn().pipeline_run.get_pipeline_run(run_id=run_id) + + def get_pipeline_run_status(self, run_id: str) -> str: + """ + Get a pipeline run's current status. + + :param run_id: The pipeline run identifier. + + :return: The status of the pipeline run. + """ + pipeline_run_status = self.get_pipeline_run( + run_id=run_id, + ).status + + return str(pipeline_run_status) + + def refresh_conn(self) -> ArtifactsClient: + self._conn = None + return self.get_conn() + + def wait_for_pipeline_run_status( + self, + run_id: str, + expected_statuses: str | set[str], + check_interval: int = 60, + timeout: int = 60 * 60 * 24 * 7, + ) -> bool: + """ + Waits for a pipeline run to match an expected status. + + :param run_id: The pipeline run identifier. + :param expected_statuses: The desired status(es) to check against a pipeline run's current status. + :param check_interval: Time in seconds to check on a pipeline run's status. + :param timeout: Time in seconds to wait for a pipeline to reach a terminal status or the expected + status. + + :return: Boolean indicating if the pipeline run has reached the ``expected_status``. 
+ """ + pipeline_run_status = self.get_pipeline_run_status(run_id=run_id) + executed_after_token_refresh = True + start_time = time.monotonic() + + while ( + pipeline_run_status not in AzureSynapsePipelineRunStatus.TERMINAL_STATUSES + and pipeline_run_status not in expected_statuses + ): + if start_time + timeout < time.monotonic(): + raise AzureSynapsePipelineRunException( + f"Pipeline run {run_id} has not reached a terminal status after {timeout} seconds." + ) + + # Wait to check the status of the pipeline run based on the ``check_interval`` configured. + time.sleep(check_interval) + + try: + pipeline_run_status = self.get_pipeline_run_status(run_id=run_id) + executed_after_token_refresh = True + except ServiceRequestError: + if executed_after_token_refresh: + self.refresh_conn() + else: + raise + + return pipeline_run_status in expected_statuses + + def cancel_run_pipeline(self, run_id: str) -> None: + """ + Cancel the pipeline run. + + :param run_id: The pipeline run identifier. + """ + self.get_conn().pipeline_run.cancel_pipeline_run(run_id) diff --git a/airflow/providers/microsoft/azure/operators/synapse.py b/airflow/providers/microsoft/azure/operators/synapse.py index e7fde11528..f7a23d5f09 100644 --- a/airflow/providers/microsoft/azure/operators/synapse.py +++ b/airflow/providers/microsoft/azure/operators/synapse.py @@ -17,14 +17,24 @@ from __future__ import annotations from functools import cached_property -from typing import TYPE_CHECKING, Sequence +from typing import TYPE_CHECKING, Any, Sequence +from urllib.parse import urlencode -from airflow.models import BaseOperator -from airflow.providers.microsoft.azure.hooks.synapse import AzureSynapseHook, AzureSynapseSparkBatchRunStatus +from airflow.exceptions import AirflowException +from airflow.hooks.base import BaseHook +from airflow.models import BaseOperator, BaseOperatorLink, XCom +from airflow.providers.microsoft.azure.hooks.synapse import ( + AzureSynapseHook, + AzureSynapsePipelineHook, + 
AzureSynapsePipelineRunException, + AzureSynapsePipelineRunStatus, + AzureSynapseSparkBatchRunStatus, +) if TYPE_CHECKING: from azure.synapse.spark.models import SparkBatchJobOptions + from airflow.models.taskinstancekey import TaskInstanceKey from airflow.utils.context import Context @@ -108,3 +118,174 @@ class AzureSynapseRunSparkBatchOperator(BaseOperator): job_id=self.job_id, ) self.log.info("Job run %s has been cancelled successfully.", self.job_id) + + +class AzureSynapsePipelineRunLink(BaseOperatorLink): + """Constructs a link to monitor a pipeline run in Azure Synapse.""" + + name = "Monitor Pipeline Run" + + def get_fields_from_url(self, workspace_url): + """ + Extracts the workspace_name, subscription_id and resource_group from the Synapse workspace url. + + :param workspace_url: The workspace url. + """ + import re + from urllib.parse import unquote, urlparse + + pattern = r"https://web\.azuresynapse\.net\?workspace=(.*)" + match = re.search(pattern, workspace_url) + + if not match: + raise ValueError("Invalid workspace URL format") + + extracted_text = match.group(1) + parsed_url = urlparse(extracted_text) + path = unquote(parsed_url.path) + path_segments = path.split("/") + if len(path_segments) < 5: + raise + + return { + "workspace_name": path_segments[-1], + "subscription_id": path_segments[2], + "resource_group": path_segments[4], + } + + def get_link(self, operator: BaseOperator, *, ti_key: TaskInstanceKey): + run_id = XCom.get_value(key="run_id", ti_key=ti_key) or "" + conn_id = operator.azure_synapse_conn_id # type: ignore + conn = BaseHook.get_connection(conn_id) + self.synapse_workspace_url = conn.host + + fields = self.get_fields_from_url(self.synapse_workspace_url) + + params = { + "workspace": f"/subscriptions/{fields['subscription_id']}" + f"/resourceGroups/{fields['resource_group']}/providers/Microsoft.Synapse" + f"/workspaces/{fields['workspace_name']}", + } + encoded_params = urlencode(params) + base_url = 
f"https://ms.web.azuresynapse.net/en/monitoring/pipelineruns/{run_id}?" + + return base_url + encoded_params + + +class AzureSynapseRunPipelineOperator(BaseOperator): + """ + Executes a Synapse Pipeline. + + :param pipeline_name: The name of the pipeline to execute. + :param azure_synapse_conn_id: The Airflow connection ID for Azure Synapse. + :param azure_synapse_workspace_dev_endpoint: The Azure Synapse workspace development endpoint. + :param wait_for_termination: Flag to wait on a pipeline run's termination. + :param reference_pipeline_run_id: The pipeline run identifier. If this run ID is specified the parameters + of the specified run will be used to create a new run. + :param is_recovery: Recovery mode flag. If recovery mode is set to `True`, the specified referenced + pipeline run and the new run will be grouped under the same ``groupId``. + :param start_activity_name: In recovery mode, the rerun will start from this activity. If not specified, + all activities will run. + :param parameters: Parameters of the pipeline run. These parameters are referenced in a pipeline via + ``@pipeline().parameters.parameterName`` and will be used only if the ``reference_pipeline_run_id`` is + not specified. + :param timeout: Time in seconds to wait for a pipeline to reach a terminal status for non-asynchronous + waits. Used only if ``wait_for_termination`` is True. + :param check_interval: Time in seconds to check on a pipeline run's status for non-asynchronous waits. + Used only if ``wait_for_termination`` is True. 
+ + """ + + template_fields: Sequence[str] = ("azure_synapse_conn_id",) + + operator_extra_links = (AzureSynapsePipelineRunLink(),) + + def __init__( + self, + pipeline_name: str, + azure_synapse_conn_id: str, + azure_synapse_workspace_dev_endpoint: str, + wait_for_termination: bool = True, + reference_pipeline_run_id: str | None = None, + is_recovery: bool | None = None, + start_activity_name: str | None = None, + parameters: dict[str, Any] | None = None, + timeout: int = 60 * 60 * 24 * 7, + check_interval: int = 60, + **kwargs, + ) -> None: + super().__init__(**kwargs) + self.azure_synapse_conn_id = azure_synapse_conn_id + self.pipeline_name = pipeline_name + self.azure_synapse_workspace_dev_endpoint = azure_synapse_workspace_dev_endpoint + self.wait_for_termination = wait_for_termination + self.reference_pipeline_run_id = reference_pipeline_run_id + self.is_recovery = is_recovery + self.start_activity_name = start_activity_name + self.parameters = parameters + self.timeout = timeout + self.check_interval = check_interval + + @cached_property + def hook(self): + """Create and return an AzureSynapsePipelineHook (cached).""" + return AzureSynapsePipelineHook( + azure_synapse_conn_id=self.azure_synapse_conn_id, + azure_synapse_workspace_dev_endpoint=self.azure_synapse_workspace_dev_endpoint, + ) + + def execute(self, context) -> None: + self.log.info("Executing the %s pipeline.", self.pipeline_name) + response = self.hook.run_pipeline( + pipeline_name=self.pipeline_name, + reference_pipeline_run_id=self.reference_pipeline_run_id, + is_recovery=self.is_recovery, + start_activity_name=self.start_activity_name, + parameters=self.parameters, + ) + self.run_id = vars(response)["run_id"] + # Push the ``run_id`` value to XCom regardless of what happens during execution. This allows for + # retrieval the executed pipeline's ``run_id`` for downstream tasks especially if performing an + # asynchronous wait. 
+ context["ti"].xcom_push(key="run_id", value=self.run_id) + + if self.wait_for_termination: + self.log.info("Waiting for pipeline run %s to terminate.", self.run_id) + + if self.hook.wait_for_pipeline_run_status( + run_id=self.run_id, + expected_statuses=AzureSynapsePipelineRunStatus.SUCCEEDED, + check_interval=self.check_interval, + timeout=self.timeout, + ): + self.log.info("Pipeline run %s has completed successfully.", self.run_id) + else: + raise AzureSynapsePipelineRunException( + f"Pipeline run {self.run_id} has failed or has been cancelled." + ) + + def execute_complete(self, event: dict[str, str]) -> None: + """ + Callback for when the trigger fires - returns immediately. + + Relies on trigger to throw an exception, otherwise it assumes execution was successful. + """ + if event: + if event["status"] == "error": + raise AirflowException(event["message"]) + self.log.info(event["message"]) + + def on_kill(self) -> None: + if self.run_id: + self.hook.cancel_run_pipeline(run_id=self.run_id) + + # Check to ensure the pipeline run was cancelled as expected. 
+ if self.hook.wait_for_pipeline_run_status( + run_id=self.run_id, + expected_statuses=AzureSynapsePipelineRunStatus.CANCELLED, + check_interval=self.check_interval, + timeout=self.timeout, + ): + self.log.info("Pipeline run %s has been cancelled successfully.", self.run_id) + else: + raise AzureSynapsePipelineRunException(f"Pipeline run {self.run_id} was not cancelled.") diff --git a/airflow/providers/microsoft/azure/provider.yaml b/airflow/providers/microsoft/azure/provider.yaml index 975feb276a..53bccda6c7 100644 --- a/airflow/providers/microsoft/azure/provider.yaml +++ b/airflow/providers/microsoft/azure/provider.yaml @@ -82,6 +82,7 @@ dependencies: - azure-storage-file-share - azure-servicebus>=7.6.1 - azure-synapse-spark + - azure-synapse-artifacts>=0.17.0 - adal>=1.2.7 - azure-storage-file-datalake>=12.9.1 - azure-kusto-data>=4.1.0 @@ -279,15 +280,13 @@ connection-types: connection-type: azure_fileshare - hook-class-name: airflow.providers.microsoft.azure.hooks.container_volume.AzureContainerVolumeHook connection-type: azure_container_volume - - hook-class-name: >- - airflow.providers.microsoft.azure.hooks.container_instance.AzureContainerInstanceHook + - hook-class-name: airflow.providers.microsoft.azure.hooks.container_instance.AzureContainerInstanceHook connection-type: azure_container_instance - hook-class-name: airflow.providers.microsoft.azure.hooks.wasb.WasbHook connection-type: wasb - hook-class-name: airflow.providers.microsoft.azure.hooks.data_factory.AzureDataFactoryHook connection-type: azure_data_factory - - hook-class-name: >- - airflow.providers.microsoft.azure.hooks.container_registry.AzureContainerRegistryHook + - hook-class-name: airflow.providers.microsoft.azure.hooks.container_registry.AzureContainerRegistryHook connection-type: azure_container_registry - hook-class-name: airflow.providers.microsoft.azure.hooks.asb.BaseAzureServiceBusHook connection-type: azure_service_bus @@ -304,6 +303,7 @@ logging: extra-links: - 
airflow.providers.microsoft.azure.operators.data_factory.AzureDataFactoryPipelineRunLink + - airflow.providers.microsoft.azure.operators.synapse.AzureSynapsePipelineRunLink config: azure_remote_logging: diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt index 3979ae3c59..b787191fd0 100644 --- a/docs/spelling_wordlist.txt +++ b/docs/spelling_wordlist.txt @@ -315,6 +315,7 @@ cpus crd createDisposition CreateQueryOperator +CreateRunResponse creationTimestamp credssp Cron @@ -1144,6 +1145,7 @@ pinecone pinodb Pinot pinot +PipelineRun pkill plaintext platformVersion diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json index 0b234b4e12..058078a191 100644 --- a/generated/provider_dependencies.json +++ b/generated/provider_dependencies.json @@ -586,6 +586,7 @@ "azure-storage-blob>=12.14.0", "azure-storage-file-datalake>=12.9.1", "azure-storage-file-share", + "azure-synapse-artifacts>=0.17.0", "azure-synapse-spark" ], "cross-providers-deps": [ diff --git a/tests/providers/microsoft/azure/hooks/test_synapse_pipeline.py b/tests/providers/microsoft/azure/hooks/test_synapse_pipeline.py new file mode 100644 index 0000000000..d0309bd0a6 --- /dev/null +++ b/tests/providers/microsoft/azure/hooks/test_synapse_pipeline.py @@ -0,0 +1,159 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +from unittest.mock import MagicMock, patch + +import pytest +from azure.identity import ClientSecretCredential, DefaultAzureCredential +from azure.synapse.artifacts import ArtifactsClient + +from airflow.models.connection import Connection +from airflow.providers.microsoft.azure.hooks.synapse import ( + AzureSynapsePipelineHook, + AzureSynapsePipelineRunException, + AzureSynapsePipelineRunStatus, +) + +DEFAULT_CONNECTION_CLIENT_SECRET = "azure_synapse_test_client_secret" +DEFAULT_CONNECTION_DEFAULT_CREDENTIAL = "azure_synapse_test_default_credential" + +SYNAPSE_WORKSPACE_URL = "synapse_workspace_url" +AZURE_SYNAPSE_WORKSPACE_DEV_ENDPOINT = "azure_synapse_workspace_dev_endpoint" +PIPELINE_NAME = "pipeline_name" +RUN_ID = "run_id" + + +@pytest.fixture(autouse=True) +def setup_connections(create_mock_connections): + create_mock_connections( + # connection_client_secret + Connection( + conn_id=DEFAULT_CONNECTION_CLIENT_SECRET, + conn_type="azure_synapse_pipeline", + host=SYNAPSE_WORKSPACE_URL, + login="clientId", + password="clientSecret", + extra={"tenantId": "tenantId"}, + ), + # connection_default_credential + Connection( + conn_id=DEFAULT_CONNECTION_DEFAULT_CREDENTIAL, + conn_type="azure_synapse_pipeline", + host=SYNAPSE_WORKSPACE_URL, + extra={}, + ), + # connection_missing_tenant_id + Connection( + conn_id="azure_synapse_missing_tenant_id", + conn_type="azure_synapse_pipeline", + host=SYNAPSE_WORKSPACE_URL, + login="clientId", + password="clientSecret", + extra={}, + ), + ) + + +@pytest.fixture +def hook(): + client = AzureSynapsePipelineHook( + azure_synapse_conn_id=DEFAULT_CONNECTION_DEFAULT_CREDENTIAL, + azure_synapse_workspace_dev_endpoint=AZURE_SYNAPSE_WORKSPACE_DEV_ENDPOINT, + ) + client._conn = MagicMock(spec=["pipeline_run", "pipeline"]) + + return client + + +@pytest.mark.parametrize( + ("connection_id", 
"credential_type"), + [ + (DEFAULT_CONNECTION_CLIENT_SECRET, ClientSecretCredential), + (DEFAULT_CONNECTION_DEFAULT_CREDENTIAL, DefaultAzureCredential), + ], +) +def test_get_connection_by_credential_client_secret(connection_id: str, credential_type: type): + hook = AzureSynapsePipelineHook( + azure_synapse_conn_id=connection_id, + azure_synapse_workspace_dev_endpoint=AZURE_SYNAPSE_WORKSPACE_DEV_ENDPOINT, + ) + + with patch.object(hook, "_create_client") as mock_create_client: + mock_create_client.return_value = MagicMock() + connection = hook.get_conn() + assert connection is not None + mock_create_client.assert_called_once() + assert isinstance(mock_create_client.call_args.args[0], credential_type) + assert mock_create_client.call_args.args[1] == AZURE_SYNAPSE_WORKSPACE_DEV_ENDPOINT + + +def test_run_pipeline(hook: AzureSynapsePipelineHook): + hook.run_pipeline(PIPELINE_NAME) + + if hook._conn is not None and isinstance(hook._conn, ArtifactsClient): + hook._conn.pipeline.create_pipeline_run.assert_called_with(PIPELINE_NAME) + + +def test_get_pipeline_run(hook: AzureSynapsePipelineHook): + hook.get_pipeline_run(run_id=RUN_ID) + + if hook._conn is not None and isinstance(hook._conn, ArtifactsClient): + hook._conn.pipeline_run.get_pipeline_run.assert_called_with(run_id=RUN_ID) + + +def test_cancel_run_pipeline(hook: AzureSynapsePipelineHook): + hook.cancel_run_pipeline(RUN_ID) + + if hook._conn is not None and isinstance(hook._conn, ArtifactsClient): + hook._conn.pipeline_run.cancel_pipeline_run.assert_called_with(RUN_ID) + + +_wait_for_pipeline_run_status_test_args = [ + (AzureSynapsePipelineRunStatus.SUCCEEDED, AzureSynapsePipelineRunStatus.SUCCEEDED, True), + (AzureSynapsePipelineRunStatus.FAILED, AzureSynapsePipelineRunStatus.SUCCEEDED, False), + (AzureSynapsePipelineRunStatus.CANCELLED, AzureSynapsePipelineRunStatus.SUCCEEDED, False), + (AzureSynapsePipelineRunStatus.IN_PROGRESS, AzureSynapsePipelineRunStatus.SUCCEEDED, "timeout"), + 
(AzureSynapsePipelineRunStatus.QUEUED, AzureSynapsePipelineRunStatus.SUCCEEDED, "timeout"), + (AzureSynapsePipelineRunStatus.CANCELING, AzureSynapsePipelineRunStatus.SUCCEEDED, "timeout"), + (AzureSynapsePipelineRunStatus.SUCCEEDED, AzureSynapsePipelineRunStatus.TERMINAL_STATUSES, True), + (AzureSynapsePipelineRunStatus.FAILED, AzureSynapsePipelineRunStatus.TERMINAL_STATUSES, True), + (AzureSynapsePipelineRunStatus.CANCELLED, AzureSynapsePipelineRunStatus.TERMINAL_STATUSES, True), +] + + +@pytest.mark.parametrize( + argnames=("pipeline_run_status", "expected_status", "expected_output"), + argvalues=_wait_for_pipeline_run_status_test_args, + ids=[ + f"run_status_{argval[0]}_expected_{argval[1]}" + if isinstance(argval[1], str) + else f"run_status_{argval[0]}_expected_AnyTerminalStatus" + for argval in _wait_for_pipeline_run_status_test_args + ], +) +def test_wait_for_pipeline_run_status(hook, pipeline_run_status, expected_status, expected_output): + config = {"run_id": RUN_ID, "timeout": 3, "check_interval": 1, "expected_statuses": expected_status} + + with patch.object(AzureSynapsePipelineHook, "get_pipeline_run") as mock_pipeline_run: + mock_pipeline_run.return_value.status = pipeline_run_status + + if expected_output != "timeout": + assert hook.wait_for_pipeline_run_status(**config) == expected_output + else: + with pytest.raises(AzureSynapsePipelineRunException): + hook.wait_for_pipeline_run_status(**config) diff --git a/tests/providers/microsoft/azure/operators/test_synapse.py b/tests/providers/microsoft/azure/operators/test_synapse.py index 233e1c57fd..14ffd783ba 100644 --- a/tests/providers/microsoft/azure/operators/test_synapse.py +++ b/tests/providers/microsoft/azure/operators/test_synapse.py @@ -17,24 +17,42 @@ from __future__ import annotations from unittest import mock -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch import pytest from airflow.models import Connection -from 
airflow.providers.microsoft.azure.operators.synapse import AzureSynapseRunSparkBatchOperator +from airflow.providers.microsoft.azure.hooks.synapse import ( + AzureSynapsePipelineHook, + AzureSynapsePipelineRunException, + AzureSynapsePipelineRunStatus, +) +from airflow.providers.microsoft.azure.operators.synapse import ( + AzureSynapsePipelineRunLink, + AzureSynapseRunPipelineOperator, + AzureSynapseRunSparkBatchOperator, +) from airflow.utils import timezone DEFAULT_DATE = timezone.datetime(2021, 1, 1) -SUBSCRIPTION_ID = "my-subscription-id" +SUBSCRIPTION_ID = "subscription_id" +TENANT_ID = "tenant_id" TASK_ID = "run_spark_op" +AZURE_SYNAPSE_PIPELINE_TASK_ID = "run_pipeline_op" AZURE_SYNAPSE_CONN_ID = "azure_synapse_test" CONN_EXTRAS = { "synapse__subscriptionId": SUBSCRIPTION_ID, "synapse__tenantId": "my-tenant-id", "synapse__spark_pool": "my-spark-pool", } +SYNAPSE_PIPELINE_CONN_EXTRAS = {"tenantId": TENANT_ID} JOB_RUN_RESPONSE = {"id": 123} +PIPELINE_NAME = "Pipeline 1" +AZURE_SYNAPSE_WORKSPACE_DEV_ENDPOINT = "azure_synapse_workspace_dev_endpoint" +RESOURCE_GROUP = "op-resource-group" +WORKSPACE_NAME = "workspace-test" +AZURE_SYNAPSE_WORKSPACE_URL = f"https://web.azuresynapse.net?workspace=%2fsubscriptions%{SUBSCRIPTION_ID}%2fresourceGroups%2f{RESOURCE_GROUP}%2fproviders%2fMicrosoft.Synapse%2fworkspaces%2f{WORKSPACE_NAME}" +PIPELINE_RUN_RESPONSE = {"run_id": "run_id"} class TestAzureSynapseRunSparkBatchOperator: @@ -53,7 +71,7 @@ class TestAzureSynapseRunSparkBatchOperator: create_mock_connection( Connection( conn_id=AZURE_SYNAPSE_CONN_ID, - conn_type="azure_synapse", + conn_type="azure_synapse_pipeline", host="https://synapsetest.net", login="client-id", password="client-secret", @@ -102,3 +120,195 @@ class TestAzureSynapseRunSparkBatchOperator: op.execute(context=self.mock_context) op.on_kill() mock_cancel_job_run.assert_called_once_with(job_id=JOB_RUN_RESPONSE["id"]) + + +class TestAzureSynapseRunPipelineOperator: + @pytest.fixture(autouse=True) + def 
setup_test_cases(self, create_mock_connection): + self.mock_ti = MagicMock() + self.mock_context = {"ti": self.mock_ti} + self.config = { + "task_id": AZURE_SYNAPSE_PIPELINE_TASK_ID, + "azure_synapse_conn_id": AZURE_SYNAPSE_CONN_ID, + "pipeline_name": PIPELINE_NAME, + "azure_synapse_workspace_dev_endpoint": AZURE_SYNAPSE_WORKSPACE_DEV_ENDPOINT, + "check_interval": 1, + "timeout": 3, + } + + create_mock_connection( + Connection( + conn_id=AZURE_SYNAPSE_CONN_ID, + conn_type="azure_synapse_pipeline", + host=AZURE_SYNAPSE_WORKSPACE_URL, + login="client_id", + password="client_secret", + extra=SYNAPSE_PIPELINE_CONN_EXTRAS, + ) + ) + + @staticmethod + def create_pipeline_run(status: str): + """Helper function to create a mock pipeline run with a given execution status.""" + + run = MagicMock() + run.status = status + + return run + + @patch.object(AzureSynapsePipelineHook, "run_pipeline", return_value=MagicMock(**PIPELINE_RUN_RESPONSE)) + @pytest.mark.parametrize( + "pipeline_run_status,expected_output", + [ + (AzureSynapsePipelineRunStatus.SUCCEEDED, None), + (AzureSynapsePipelineRunStatus.FAILED, "exception"), + (AzureSynapsePipelineRunStatus.CANCELLED, "exception"), + (AzureSynapsePipelineRunStatus.IN_PROGRESS, "timeout"), + (AzureSynapsePipelineRunStatus.QUEUED, "timeout"), + (AzureSynapsePipelineRunStatus.CANCELING, "timeout"), + ], + ) + def test_execute_wait_for_termination(self, mock_run_pipeline, pipeline_run_status, expected_output): + # Initialize the operator with mock config, (**) unpacks the config dict. 
+ operator = AzureSynapseRunPipelineOperator(**self.config) + + assert operator.azure_synapse_conn_id == self.config["azure_synapse_conn_id"] + assert operator.pipeline_name == self.config["pipeline_name"] + assert ( + operator.azure_synapse_workspace_dev_endpoint + == self.config["azure_synapse_workspace_dev_endpoint"] + ) + assert operator.check_interval == self.config["check_interval"] + assert operator.timeout == self.config["timeout"] + assert operator.wait_for_termination + + with patch.object(AzureSynapsePipelineHook, "get_pipeline_run") as mock_get_pipeline_run: + mock_get_pipeline_run.return_value = TestAzureSynapseRunPipelineOperator.create_pipeline_run( + pipeline_run_status + ) + + if not expected_output: + # A successful operator execution should not return any values. + assert not operator.execute(context=self.mock_context) + elif expected_output == "exception": + # The operator should fail if the pipeline run fails or is canceled. + with pytest.raises( + AzureSynapsePipelineRunException, + match=f"Pipeline run {PIPELINE_RUN_RESPONSE['run_id']} has failed or has been cancelled.", + ): + operator.execute(context=self.mock_context) + else: + # Demonstrating the operator timing out after surpassing the configured timeout value. + with pytest.raises( + AzureSynapsePipelineRunException, + match=( + f"Pipeline run {PIPELINE_RUN_RESPONSE['run_id']} has not reached a terminal status " + f"after {self.config['timeout']} seconds." + ), + ): + operator.execute(context=self.mock_context) + + # Check the ``run_id`` attr is assigned after executing the pipeline. + assert operator.run_id == PIPELINE_RUN_RESPONSE["run_id"] + + # Check to ensure an `XCom` is pushed regardless of pipeline run result. + self.mock_ti.xcom_push.assert_called_once_with( + key="run_id", value=PIPELINE_RUN_RESPONSE["run_id"] + ) + + # Check if mock_run_pipeline called with particular set of arguments. 
+ mock_run_pipeline.assert_called_once_with( + pipeline_name=self.config["pipeline_name"], + reference_pipeline_run_id=None, + is_recovery=None, + start_activity_name=None, + parameters=None, + ) + + if pipeline_run_status in AzureSynapsePipelineRunStatus.TERMINAL_STATUSES: + mock_get_pipeline_run.assert_called_once_with(run_id=mock_run_pipeline.return_value.run_id) + else: + # When the pipeline run status is not in a terminal status or "Succeeded", the operator will + # continue to call ``get_pipeline_run()`` until a ``timeout`` number of seconds has passed + # (3 seconds for this test). Therefore, there should be 4 calls of this function: one + # initially and 3 for each check done at a 1 second interval. + assert mock_get_pipeline_run.call_count == 4 + + mock_get_pipeline_run.assert_called_with(run_id=mock_run_pipeline.return_value.run_id) + + @patch.object(AzureSynapsePipelineHook, "run_pipeline", return_value=MagicMock(**PIPELINE_RUN_RESPONSE)) + def test_execute_no_wait_for_termination(self, mock_run_pipeline): + operator = AzureSynapseRunPipelineOperator(wait_for_termination=False, **self.config) + + assert operator.azure_synapse_conn_id == self.config["azure_synapse_conn_id"] + assert operator.pipeline_name == self.config["pipeline_name"] + assert ( + operator.azure_synapse_workspace_dev_endpoint + == self.config["azure_synapse_workspace_dev_endpoint"] + ) + assert operator.check_interval == self.config["check_interval"] + assert operator.timeout == self.config["timeout"] + assert not operator.wait_for_termination + + with patch.object( + AzureSynapsePipelineHook, "get_pipeline_run", autospec=True + ) as mock_get_pipeline_run: + operator.execute(context=self.mock_context) + + # Check the ``run_id`` attr is assigned after executing the pipeline. + assert operator.run_id == PIPELINE_RUN_RESPONSE["run_id"] + + # Check to ensure an `XCom` is pushed regardless of pipeline run result. 
+ self.mock_ti.xcom_push.assert_called_once_with( + key="run_id", value=PIPELINE_RUN_RESPONSE["run_id"] + ) + + mock_run_pipeline.assert_called_once_with( + pipeline_name=self.config["pipeline_name"], + reference_pipeline_run_id=None, + is_recovery=None, + start_activity_name=None, + parameters=None, + ) + + # Checking the pipeline run status should _not_ be called when ``wait_for_termination`` is False. + mock_get_pipeline_run.assert_not_called() + + @pytest.mark.db_test + def test_run_pipeline_operator_link(self, create_task_instance_of_operator): + ti = create_task_instance_of_operator( + AzureSynapseRunPipelineOperator, + dag_id="test_synapse_run_pipeline_op_link", + execution_date=DEFAULT_DATE, + task_id=AZURE_SYNAPSE_PIPELINE_TASK_ID, + azure_synapse_conn_id=AZURE_SYNAPSE_CONN_ID, + pipeline_name=PIPELINE_NAME, + azure_synapse_workspace_dev_endpoint=AZURE_SYNAPSE_WORKSPACE_DEV_ENDPOINT, + ) + + ti.xcom_push(key="run_id", value=PIPELINE_RUN_RESPONSE["run_id"]) + + url = ti.task.get_extra_links(ti, "Monitor Pipeline Run") + + EXPECTED_PIPELINE_RUN_OP_EXTRA_LINK = ( + "https://ms.web.azuresynapse.net/en/monitoring/pipelineruns/{run_id}" + "?workspace=%2Fsubscriptions%2F{subscription_id}%2F" + "resourceGroups%2F{resource_group}%2Fproviders%2FMicrosoft.Synapse" + "%2Fworkspaces%2F{workspace_name}" + ) + + conn = AzureSynapsePipelineHook.get_connection(AZURE_SYNAPSE_CONN_ID) + conn_synapse_workspace_url = conn.host + + # Extract the workspace_name, subscription_id and resource_group from the Synapse workspace url. 
+ pipeline_run_object = AzureSynapsePipelineRunLink() + fields = pipeline_run_object.get_fields_from_url(workspace_url=conn_synapse_workspace_url) + + assert url == ( + EXPECTED_PIPELINE_RUN_OP_EXTRA_LINK.format( + run_id=PIPELINE_RUN_RESPONSE["run_id"], + subscription_id=fields["subscription_id"], + resource_group=fields["resource_group"], + workspace_name=fields["workspace_name"], + ) + ) diff --git a/tests/system/providers/microsoft/azure/example_synapse_run_pipeline.py b/tests/system/providers/microsoft/azure/example_synapse_run_pipeline.py new file mode 100644 index 0000000000..be11a665c4 --- /dev/null +++ b/tests/system/providers/microsoft/azure/example_synapse_run_pipeline.py @@ -0,0 +1,59 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
# Example / system-test DAG: a placeholder ``begin`` task followed by a single
# ``AzureSynapseRunPipelineOperator`` task, wired to the system-test watcher.
from __future__ import annotations

import os
from datetime import datetime

from airflow.models import DAG
from airflow.providers.microsoft.azure.operators.synapse import AzureSynapseRunPipelineOperator

try:
    from airflow.operators.empty import EmptyOperator
except ModuleNotFoundError:
    # Module path not available: fall back to the ``dummy`` module and alias the class so the
    # DAG body below can refer to a single name either way.
    from airflow.operators.dummy import DummyOperator as EmptyOperator  # type: ignore

# Read from the environment; not referenced anywhere else in this file.
ENV_ID = os.getenv("SYSTEM_TESTS_ENV_ID")

with DAG(
    dag_id="example_synapse_run_pipeline",
    schedule="@daily",
    start_date=datetime(year=2021, month=8, day=13),
    catchup=False,
    tags=["synapse", "example"],
) as dag:
    begin = EmptyOperator(task_id="begin")

    # NOTE(review): the connection id and dev endpoint below look like placeholder strings
    # rather than real values -- confirm they are meant to be overridden per environment.
    run_pipeline1 = AzureSynapseRunPipelineOperator(
        task_id="run_pipeline1",
        pipeline_name="Pipeline 1",
        azure_synapse_conn_id="azure_synapse_connection",
        azure_synapse_workspace_dev_endpoint="azure_synapse_workspace_dev_endpoint",
    )

    # ``begin`` runs first, then the pipeline task.
    begin.set_downstream(run_pipeline1)

    from tests.system.utils.watcher import watcher

    # The watcher is required so the run is marked success/failure correctly
    # when a "tearDown" task with a trigger rule is part of the DAG.
    # Keep this exact form: the left operand ``list(dag.tasks)`` is evaluated before
    # ``watcher()`` is called, so the list is a snapshot taken before the watcher exists
    # (presumably ``watcher()`` registers a task on this DAG -- verify against the helper).
    list(dag.tasks) >> watcher()

from tests.system.utils import get_test_run  # noqa: E402

# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
test_run = get_test_run(dag)