This is an automated email from the ASF dual-hosted git repository.

taragolis pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 5983506df3 Add operator to invoke Azure-Synapse pipeline (#35091)
5983506df3 is described below

commit 5983506df370325f7b23a182798341d17d091a32
Author: ambika-garg <70703123+ambika-g...@users.noreply.github.com>
AuthorDate: Thu Nov 16 04:51:42 2023 -0500

    Add operator to invoke Azure-Synapse pipeline (#35091)
    
    * Update to resolve rebase conflicts and pass pre-commit hooks
    
    * Feature: Add Azure Synapse Pipeline run operator in Microsoft Provider
    
            * Add a hook to interact with Azure Synapse Analytics
    
            * Add an operator to trigger a Synapse pipeline from a DAG, plus an operator link
    
            * Add unit tests for operator and hook
    
            * Update provider.yaml to support the new operator, operator link, and hook
    
            * Update provider_dependencies to install azure-synapse-artifacts
    
    * Add new words to the spelling wordlist so the docs build check passes
    
    * Fix: Pytest Tests
    
           * Add Mock Synapse Workspace URL
    
           * Set Default wait_for_termination to False
    
    * Rename files as per standards
    
    * Move the AzureSynapsePipelineHook class into synapse.py so it is easier to find
        * Fix all imports for the class
    
        * Remove the file from provider.yaml
    
        * Delete the synapse_pipeline.py file
    
    ---------
    
    Co-authored-by: Ambika Garg <ambika.g...@growthjockey.com>
---
 airflow/providers/microsoft/azure/hooks/synapse.py | 191 +++++++++++++++++-
 .../providers/microsoft/azure/operators/synapse.py | 187 +++++++++++++++++-
 airflow/providers/microsoft/azure/provider.yaml    |   8 +-
 docs/spelling_wordlist.txt                         |   2 +
 generated/provider_dependencies.json               |   1 +
 .../microsoft/azure/hooks/test_synapse_pipeline.py | 159 +++++++++++++++
 .../microsoft/azure/operators/test_synapse.py      | 218 ++++++++++++++++++++-
 .../azure/example_synapse_run_pipeline.py          |  59 ++++++
 8 files changed, 813 insertions(+), 12 deletions(-)

diff --git a/airflow/providers/microsoft/azure/hooks/synapse.py 
b/airflow/providers/microsoft/azure/hooks/synapse.py
index e284194376..d48109d694 100644
--- a/airflow/providers/microsoft/azure/hooks/synapse.py
+++ b/airflow/providers/microsoft/azure/hooks/synapse.py
@@ -19,10 +19,12 @@ from __future__ import annotations
 import time
 from typing import TYPE_CHECKING, Any, Union
 
+from azure.core.exceptions import ServiceRequestError
 from azure.identity import ClientSecretCredential, DefaultAzureCredential
+from azure.synapse.artifacts import ArtifactsClient
 from azure.synapse.spark import SparkClient
 
-from airflow.exceptions import AirflowTaskTimeout
+from airflow.exceptions import AirflowException, AirflowTaskTimeout
 from airflow.hooks.base import BaseHook
 from airflow.providers.microsoft.azure.utils import (
     add_managed_identity_connection_widgets,
@@ -31,6 +33,7 @@ from airflow.providers.microsoft.azure.utils import (
 )
 
 if TYPE_CHECKING:
+    from azure.synapse.artifacts.models import CreateRunResponse, PipelineRun
     from azure.synapse.spark.models import SparkBatchJobOptions
 
 Credentials = Union[ClientSecretCredential, DefaultAzureCredential]
@@ -217,3 +220,189 @@ class AzureSynapseHook(BaseHook):
         :param job_id: The synapse spark job identifier.
         """
         self.get_conn().spark_batch.cancel_spark_batch_job(job_id)
+
+
+class AzureSynapsePipelineRunStatus:
+    """Azure Synapse pipeline operation statuses."""
+
+    QUEUED = "Queued"
+    IN_PROGRESS = "InProgress"
+    SUCCEEDED = "Succeeded"
+    FAILED = "Failed"
+    CANCELING = "Canceling"
+    CANCELLED = "Cancelled"
+    TERMINAL_STATUSES = {CANCELLED, FAILED, SUCCEEDED}
+    INTERMEDIATE_STATES = {QUEUED, IN_PROGRESS, CANCELING}
+    FAILURE_STATES = {FAILED, CANCELLED}
+
+
+class AzureSynapsePipelineRunException(AirflowException):
+    """An exception that indicates a pipeline run failed to complete."""
+
+
+class AzureSynapsePipelineHook(BaseHook):
+    """
+    A hook to interact with Azure Synapse Pipeline.
+
+    :param azure_synapse_conn_id: The :ref:`Azure Synapse connection 
id<howto/connection:synapse>`.
+    :param azure_synapse_workspace_dev_endpoint: The Azure Synapse Workspace 
development endpoint.
+    """
+
+    conn_type: str = "azure_synapse_pipeline"
+    conn_name_attr: str = "azure_synapse_conn_id"
+    default_conn_name: str = "azure_synapse_connection"
+    hook_name: str = "Azure Synapse Pipeline"
+
+    @staticmethod
+    def get_connection_form_widgets() -> dict[str, Any]:
+        """Returns connection widgets to add to connection form."""
+        from flask_appbuilder.fieldwidgets import BS3TextFieldWidget
+        from flask_babel import lazy_gettext
+        from wtforms import StringField
+
+        return {
+            "tenantId": StringField(lazy_gettext("Tenant ID"), 
widget=BS3TextFieldWidget()),
+        }
+
+    @staticmethod
+    def get_ui_field_behaviour() -> dict[str, Any]:
+        """Returns custom field behaviour."""
+        return {
+            "hidden_fields": ["schema", "port", "extra"],
+            "relabeling": {"login": "Client ID", "password": "Secret", "host": 
"Synapse Workspace URL"},
+        }
+
+    def __init__(
+        self, azure_synapse_workspace_dev_endpoint: str, 
azure_synapse_conn_id: str = default_conn_name
+    ):
+        self._conn = None
+        self.conn_id = azure_synapse_conn_id
+        self.azure_synapse_workspace_dev_endpoint = 
azure_synapse_workspace_dev_endpoint
+        super().__init__()
+
+    def _get_field(self, extras, name):
+        return get_field(
+            conn_id=self.conn_id,
+            conn_type=self.conn_type,
+            extras=extras,
+            field_name=name,
+        )
+
+    def get_conn(self) -> ArtifactsClient:
+        if self._conn is not None:
+            return self._conn
+
+        conn = self.get_connection(self.conn_id)
+        extras = conn.extra_dejson
+        tenant = self._get_field(extras, "tenantId")
+
+        credential: Credentials
+        if conn.login is not None and conn.password is not None:
+            if not tenant:
+                raise ValueError("A Tenant ID is required when authenticating 
with Client ID and Secret.")
+
+            credential = ClientSecretCredential(
+                client_id=conn.login, client_secret=conn.password, 
tenant_id=tenant
+            )
+        else:
+            credential = DefaultAzureCredential()
+        self._conn = self._create_client(credential, 
self.azure_synapse_workspace_dev_endpoint)
+
+        if self._conn is not None:
+            return self._conn
+        else:
+            raise ValueError("Failed to create ArtifactsClient")
+
+    @staticmethod
+    def _create_client(credential: Credentials, endpoint: str):
+        return ArtifactsClient(credential=credential, endpoint=endpoint)
+
+    def run_pipeline(self, pipeline_name: str, **config: Any) -> 
CreateRunResponse:
+        """
+        Run a Synapse pipeline.
+
+        :param pipeline_name: The pipeline name.
+        :param config: Extra parameters for the Synapse Artifact Client.
+        :return: The pipeline run Id.
+        """
+        return self.get_conn().pipeline.create_pipeline_run(pipeline_name, 
**config)
+
+    def get_pipeline_run(self, run_id: str) -> PipelineRun:
+        """
+        Get the pipeline run.
+
+        :param run_id: The pipeline run identifier.
+        :return: The pipeline run.
+        """
+        return self.get_conn().pipeline_run.get_pipeline_run(run_id=run_id)
+
+    def get_pipeline_run_status(self, run_id: str) -> str:
+        """
+        Get a pipeline run's current status.
+
+        :param run_id: The pipeline run identifier.
+
+        :return: The status of the pipeline run.
+        """
+        pipeline_run_status = self.get_pipeline_run(
+            run_id=run_id,
+        ).status
+
+        return str(pipeline_run_status)
+
+    def refresh_conn(self) -> ArtifactsClient:
+        self._conn = None
+        return self.get_conn()
+
+    def wait_for_pipeline_run_status(
+        self,
+        run_id: str,
+        expected_statuses: str | set[str],
+        check_interval: int = 60,
+        timeout: int = 60 * 60 * 24 * 7,
+    ) -> bool:
+        """
+        Waits for a pipeline run to match an expected status.
+
+        :param run_id: The pipeline run identifier.
+        :param expected_statuses: The desired status(es) to check against a 
pipeline run's current status.
+        :param check_interval: Time in seconds to check on a pipeline run's 
status.
+        :param timeout: Time in seconds to wait for a pipeline to reach a 
terminal status or the expected
+            status.
+
+        :return: Boolean indicating if the pipeline run has reached the 
``expected_status``.
+        """
+        pipeline_run_status = self.get_pipeline_run_status(run_id=run_id)
+        executed_after_token_refresh = True
+        start_time = time.monotonic()
+
+        while (
+            pipeline_run_status not in 
AzureSynapsePipelineRunStatus.TERMINAL_STATUSES
+            and pipeline_run_status not in expected_statuses
+        ):
+            if start_time + timeout < time.monotonic():
+                raise AzureSynapsePipelineRunException(
+                    f"Pipeline run {run_id} has not reached a terminal status 
after {timeout} seconds."
+                )
+
+            # Wait to check the status of the pipeline run based on the 
``check_interval`` configured.
+            time.sleep(check_interval)
+
+            try:
+                pipeline_run_status = 
self.get_pipeline_run_status(run_id=run_id)
+                executed_after_token_refresh = True
+            except ServiceRequestError:
+                if executed_after_token_refresh:
+                    self.refresh_conn()
+                else:
+                    raise
+
+        return pipeline_run_status in expected_statuses
+
+    def cancel_run_pipeline(self, run_id: str) -> None:
+        """
+        Cancel the pipeline run.
+
+        :param run_id: The pipeline run identifier.
+        """
+        self.get_conn().pipeline_run.cancel_pipeline_run(run_id)
diff --git a/airflow/providers/microsoft/azure/operators/synapse.py 
b/airflow/providers/microsoft/azure/operators/synapse.py
index e7fde11528..f7a23d5f09 100644
--- a/airflow/providers/microsoft/azure/operators/synapse.py
+++ b/airflow/providers/microsoft/azure/operators/synapse.py
@@ -17,14 +17,24 @@
 from __future__ import annotations
 
 from functools import cached_property
-from typing import TYPE_CHECKING, Sequence
+from typing import TYPE_CHECKING, Any, Sequence
+from urllib.parse import urlencode
 
-from airflow.models import BaseOperator
-from airflow.providers.microsoft.azure.hooks.synapse import AzureSynapseHook, 
AzureSynapseSparkBatchRunStatus
+from airflow.exceptions import AirflowException
+from airflow.hooks.base import BaseHook
+from airflow.models import BaseOperator, BaseOperatorLink, XCom
+from airflow.providers.microsoft.azure.hooks.synapse import (
+    AzureSynapseHook,
+    AzureSynapsePipelineHook,
+    AzureSynapsePipelineRunException,
+    AzureSynapsePipelineRunStatus,
+    AzureSynapseSparkBatchRunStatus,
+)
 
 if TYPE_CHECKING:
     from azure.synapse.spark.models import SparkBatchJobOptions
 
+    from airflow.models.taskinstancekey import TaskInstanceKey
     from airflow.utils.context import Context
 
 
@@ -108,3 +118,174 @@ class AzureSynapseRunSparkBatchOperator(BaseOperator):
                 job_id=self.job_id,
             )
             self.log.info("Job run %s has been cancelled successfully.", 
self.job_id)
+
+
+class AzureSynapsePipelineRunLink(BaseOperatorLink):
+    """Constructs a link to monitor a pipeline run in Azure Synapse."""
+
+    name = "Monitor Pipeline Run"
+
+    def get_fields_from_url(self, workspace_url):
+        """
+        Extracts the workspace_name, subscription_id and resource_group from 
the Synapse workspace url.
+
+        :param workspace_url: The workspace url.
+        """
+        import re
+        from urllib.parse import unquote, urlparse
+
+        pattern = r"https://web\.azuresynapse\.net\?workspace=(.*)"
+        match = re.search(pattern, workspace_url)
+
+        if not match:
+            raise ValueError("Invalid workspace URL format")
+
+        extracted_text = match.group(1)
+        parsed_url = urlparse(extracted_text)
+        path = unquote(parsed_url.path)
+        path_segments = path.split("/")
+        if len(path_segments) < 5:
+            raise
+
+        return {
+            "workspace_name": path_segments[-1],
+            "subscription_id": path_segments[2],
+            "resource_group": path_segments[4],
+        }
+
+    def get_link(self, operator: BaseOperator, *, ti_key: TaskInstanceKey):
+        run_id = XCom.get_value(key="run_id", ti_key=ti_key) or ""
+        conn_id = operator.azure_synapse_conn_id  # type: ignore
+        conn = BaseHook.get_connection(conn_id)
+        self.synapse_workspace_url = conn.host
+
+        fields = self.get_fields_from_url(self.synapse_workspace_url)
+
+        params = {
+            "workspace": f"/subscriptions/{fields['subscription_id']}"
+            
f"/resourceGroups/{fields['resource_group']}/providers/Microsoft.Synapse"
+            f"/workspaces/{fields['workspace_name']}",
+        }
+        encoded_params = urlencode(params)
+        base_url = 
f"https://ms.web.azuresynapse.net/en/monitoring/pipelineruns/{run_id}?"
+
+        return base_url + encoded_params
+
+
+class AzureSynapseRunPipelineOperator(BaseOperator):
+    """
+    Executes a Synapse Pipeline.
+
+    :param pipeline_name: The name of the pipeline to execute.
+    :param azure_synapse_conn_id: The Airflow connection ID for Azure Synapse.
+    :param azure_synapse_workspace_dev_endpoint: The Azure Synapse workspace 
development endpoint.
+    :param wait_for_termination: Flag to wait on a pipeline run's termination.
+    :param reference_pipeline_run_id: The pipeline run identifier. If this run 
ID is specified the parameters
+        of the specified run will be used to create a new run.
+    :param is_recovery: Recovery mode flag. If recovery mode is set to `True`, 
the specified referenced
+        pipeline run and the new run will be grouped under the same 
``groupId``.
+    :param start_activity_name: In recovery mode, the rerun will start from 
this activity. If not specified,
+        all activities will run.
+    :param parameters: Parameters of the pipeline run. These parameters are 
referenced in a pipeline via
+        ``@pipeline().parameters.parameterName`` and will be used only if the 
``reference_pipeline_run_id`` is
+        not specified.
+    :param timeout: Time in seconds to wait for a pipeline to reach a terminal 
status for non-asynchronous
+        waits. Used only if ``wait_for_termination`` is True.
+    :param check_interval: Time in seconds to check on a pipeline run's status 
for non-asynchronous waits.
+        Used only if ``wait_for_termination`` is True.
+
+    """
+
+    template_fields: Sequence[str] = ("azure_synapse_conn_id",)
+
+    operator_extra_links = (AzureSynapsePipelineRunLink(),)
+
+    def __init__(
+        self,
+        pipeline_name: str,
+        azure_synapse_conn_id: str,
+        azure_synapse_workspace_dev_endpoint: str,
+        wait_for_termination: bool = True,
+        reference_pipeline_run_id: str | None = None,
+        is_recovery: bool | None = None,
+        start_activity_name: str | None = None,
+        parameters: dict[str, Any] | None = None,
+        timeout: int = 60 * 60 * 24 * 7,
+        check_interval: int = 60,
+        **kwargs,
+    ) -> None:
+        super().__init__(**kwargs)
+        self.azure_synapse_conn_id = azure_synapse_conn_id
+        self.pipeline_name = pipeline_name
+        self.azure_synapse_workspace_dev_endpoint = 
azure_synapse_workspace_dev_endpoint
+        self.wait_for_termination = wait_for_termination
+        self.reference_pipeline_run_id = reference_pipeline_run_id
+        self.is_recovery = is_recovery
+        self.start_activity_name = start_activity_name
+        self.parameters = parameters
+        self.timeout = timeout
+        self.check_interval = check_interval
+
+    @cached_property
+    def hook(self):
+        """Create and return an AzureSynapsePipelineHook (cached)."""
+        return AzureSynapsePipelineHook(
+            azure_synapse_conn_id=self.azure_synapse_conn_id,
+            
azure_synapse_workspace_dev_endpoint=self.azure_synapse_workspace_dev_endpoint,
+        )
+
+    def execute(self, context) -> None:
+        self.log.info("Executing the %s pipeline.", self.pipeline_name)
+        response = self.hook.run_pipeline(
+            pipeline_name=self.pipeline_name,
+            reference_pipeline_run_id=self.reference_pipeline_run_id,
+            is_recovery=self.is_recovery,
+            start_activity_name=self.start_activity_name,
+            parameters=self.parameters,
+        )
+        self.run_id = vars(response)["run_id"]
+        # Push the ``run_id`` value to XCom regardless of what happens during 
execution. This allows for
+        # retrieval the executed pipeline's ``run_id`` for downstream tasks 
especially if performing an
+        # asynchronous wait.
+        context["ti"].xcom_push(key="run_id", value=self.run_id)
+
+        if self.wait_for_termination:
+            self.log.info("Waiting for pipeline run %s to terminate.", 
self.run_id)
+
+            if self.hook.wait_for_pipeline_run_status(
+                run_id=self.run_id,
+                expected_statuses=AzureSynapsePipelineRunStatus.SUCCEEDED,
+                check_interval=self.check_interval,
+                timeout=self.timeout,
+            ):
+                self.log.info("Pipeline run %s has completed successfully.", 
self.run_id)
+            else:
+                raise AzureSynapsePipelineRunException(
+                    f"Pipeline run {self.run_id} has failed or has been 
cancelled."
+                )
+
+    def execute_complete(self, event: dict[str, str]) -> None:
+        """
+        Callback for when the trigger fires - returns immediately.
+
+        Relies on trigger to throw an exception, otherwise it assumes 
execution was successful.
+        """
+        if event:
+            if event["status"] == "error":
+                raise AirflowException(event["message"])
+            self.log.info(event["message"])
+
+    def on_kill(self) -> None:
+        if self.run_id:
+            self.hook.cancel_run_pipeline(run_id=self.run_id)
+
+            # Check to ensure the pipeline run was cancelled as expected.
+            if self.hook.wait_for_pipeline_run_status(
+                run_id=self.run_id,
+                expected_statuses=AzureSynapsePipelineRunStatus.CANCELLED,
+                check_interval=self.check_interval,
+                timeout=self.timeout,
+            ):
+                self.log.info("Pipeline run %s has been cancelled 
successfully.", self.run_id)
+            else:
+                raise AzureSynapsePipelineRunException(f"Pipeline run 
{self.run_id} was not cancelled.")
diff --git a/airflow/providers/microsoft/azure/provider.yaml 
b/airflow/providers/microsoft/azure/provider.yaml
index 975feb276a..53bccda6c7 100644
--- a/airflow/providers/microsoft/azure/provider.yaml
+++ b/airflow/providers/microsoft/azure/provider.yaml
@@ -82,6 +82,7 @@ dependencies:
   - azure-storage-file-share
   - azure-servicebus>=7.6.1
   - azure-synapse-spark
+  - azure-synapse-artifacts>=0.17.0
   - adal>=1.2.7
   - azure-storage-file-datalake>=12.9.1
   - azure-kusto-data>=4.1.0
@@ -279,15 +280,13 @@ connection-types:
     connection-type: azure_fileshare
   - hook-class-name: 
airflow.providers.microsoft.azure.hooks.container_volume.AzureContainerVolumeHook
     connection-type: azure_container_volume
-  - hook-class-name: >-
-      
airflow.providers.microsoft.azure.hooks.container_instance.AzureContainerInstanceHook
+  - hook-class-name: 
airflow.providers.microsoft.azure.hooks.container_instance.AzureContainerInstanceHook
     connection-type: azure_container_instance
   - hook-class-name: airflow.providers.microsoft.azure.hooks.wasb.WasbHook
     connection-type: wasb
   - hook-class-name: 
airflow.providers.microsoft.azure.hooks.data_factory.AzureDataFactoryHook
     connection-type: azure_data_factory
-  - hook-class-name: >-
-      
airflow.providers.microsoft.azure.hooks.container_registry.AzureContainerRegistryHook
+  - hook-class-name: 
airflow.providers.microsoft.azure.hooks.container_registry.AzureContainerRegistryHook
     connection-type: azure_container_registry
   - hook-class-name: 
airflow.providers.microsoft.azure.hooks.asb.BaseAzureServiceBusHook
     connection-type: azure_service_bus
@@ -304,6 +303,7 @@ logging:
 
 extra-links:
   - 
airflow.providers.microsoft.azure.operators.data_factory.AzureDataFactoryPipelineRunLink
+  - 
airflow.providers.microsoft.azure.operators.synapse.AzureSynapsePipelineRunLink
 
 config:
   azure_remote_logging:
diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index 3979ae3c59..b787191fd0 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -315,6 +315,7 @@ cpus
 crd
 createDisposition
 CreateQueryOperator
+CreateRunResponse
 creationTimestamp
 credssp
 Cron
@@ -1144,6 +1145,7 @@ pinecone
 pinodb
 Pinot
 pinot
+PipelineRun
 pkill
 plaintext
 platformVersion
diff --git a/generated/provider_dependencies.json 
b/generated/provider_dependencies.json
index 0b234b4e12..058078a191 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -586,6 +586,7 @@
       "azure-storage-blob>=12.14.0",
       "azure-storage-file-datalake>=12.9.1",
       "azure-storage-file-share",
+      "azure-synapse-artifacts>=0.17.0",
       "azure-synapse-spark"
     ],
     "cross-providers-deps": [
diff --git a/tests/providers/microsoft/azure/hooks/test_synapse_pipeline.py 
b/tests/providers/microsoft/azure/hooks/test_synapse_pipeline.py
new file mode 100644
index 0000000000..d0309bd0a6
--- /dev/null
+++ b/tests/providers/microsoft/azure/hooks/test_synapse_pipeline.py
@@ -0,0 +1,159 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest.mock import MagicMock, patch
+
+import pytest
+from azure.identity import ClientSecretCredential, DefaultAzureCredential
+from azure.synapse.artifacts import ArtifactsClient
+
+from airflow.models.connection import Connection
+from airflow.providers.microsoft.azure.hooks.synapse import (
+    AzureSynapsePipelineHook,
+    AzureSynapsePipelineRunException,
+    AzureSynapsePipelineRunStatus,
+)
+
+DEFAULT_CONNECTION_CLIENT_SECRET = "azure_synapse_test_client_secret"
+DEFAULT_CONNECTION_DEFAULT_CREDENTIAL = "azure_synapse_test_default_credential"
+
+SYNAPSE_WORKSPACE_URL = "synapse_workspace_url"
+AZURE_SYNAPSE_WORKSPACE_DEV_ENDPOINT = "azure_synapse_workspace_dev_endpoint"
+PIPELINE_NAME = "pipeline_name"
+RUN_ID = "run_id"
+
+
+@pytest.fixture(autouse=True)
+def setup_connections(create_mock_connections):
+    create_mock_connections(
+        # connection_client_secret
+        Connection(
+            conn_id=DEFAULT_CONNECTION_CLIENT_SECRET,
+            conn_type="azure_synapse_pipeline",
+            host=SYNAPSE_WORKSPACE_URL,
+            login="clientId",
+            password="clientSecret",
+            extra={"tenantId": "tenantId"},
+        ),
+        # connection_default_credential
+        Connection(
+            conn_id=DEFAULT_CONNECTION_DEFAULT_CREDENTIAL,
+            conn_type="azure_synapse_pipeline",
+            host=SYNAPSE_WORKSPACE_URL,
+            extra={},
+        ),
+        # connection_missing_tenant_id
+        Connection(
+            conn_id="azure_synapse_missing_tenant_id",
+            conn_type="azure_synapse_pipeline",
+            host=SYNAPSE_WORKSPACE_URL,
+            login="clientId",
+            password="clientSecret",
+            extra={},
+        ),
+    )
+
+
+@pytest.fixture
+def hook():
+    client = AzureSynapsePipelineHook(
+        azure_synapse_conn_id=DEFAULT_CONNECTION_DEFAULT_CREDENTIAL,
+        
azure_synapse_workspace_dev_endpoint=AZURE_SYNAPSE_WORKSPACE_DEV_ENDPOINT,
+    )
+    client._conn = MagicMock(spec=["pipeline_run", "pipeline"])
+
+    return client
+
+
+@pytest.mark.parametrize(
+    ("connection_id", "credential_type"),
+    [
+        (DEFAULT_CONNECTION_CLIENT_SECRET, ClientSecretCredential),
+        (DEFAULT_CONNECTION_DEFAULT_CREDENTIAL, DefaultAzureCredential),
+    ],
+)
+def test_get_connection_by_credential_client_secret(connection_id: str, 
credential_type: type):
+    hook = AzureSynapsePipelineHook(
+        azure_synapse_conn_id=connection_id,
+        
azure_synapse_workspace_dev_endpoint=AZURE_SYNAPSE_WORKSPACE_DEV_ENDPOINT,
+    )
+
+    with patch.object(hook, "_create_client") as mock_create_client:
+        mock_create_client.return_value = MagicMock()
+        connection = hook.get_conn()
+        assert connection is not None
+        mock_create_client.assert_called_once()
+        assert isinstance(mock_create_client.call_args.args[0], 
credential_type)
+        assert mock_create_client.call_args.args[1] == 
AZURE_SYNAPSE_WORKSPACE_DEV_ENDPOINT
+
+
+def test_run_pipeline(hook: AzureSynapsePipelineHook):
+    hook.run_pipeline(PIPELINE_NAME)
+
+    if hook._conn is not None and isinstance(hook._conn, ArtifactsClient):
+        
hook._conn.pipeline.create_pipeline_run.assert_called_with(PIPELINE_NAME)
+
+
+def test_get_pipeline_run(hook: AzureSynapsePipelineHook):
+    hook.get_pipeline_run(run_id=RUN_ID)
+
+    if hook._conn is not None and isinstance(hook._conn, ArtifactsClient):
+        
hook._conn.pipeline_run.get_pipeline_run.assert_called_with(run_id=RUN_ID)
+
+
+def test_cancel_run_pipeline(hook: AzureSynapsePipelineHook):
+    hook.cancel_run_pipeline(RUN_ID)
+
+    if hook._conn is not None and isinstance(hook._conn, ArtifactsClient):
+        hook._conn.pipeline_run.cancel_pipeline_run.assert_called_with(RUN_ID)
+
+
+_wait_for_pipeline_run_status_test_args = [
+    (AzureSynapsePipelineRunStatus.SUCCEEDED, 
AzureSynapsePipelineRunStatus.SUCCEEDED, True),
+    (AzureSynapsePipelineRunStatus.FAILED, 
AzureSynapsePipelineRunStatus.SUCCEEDED, False),
+    (AzureSynapsePipelineRunStatus.CANCELLED, 
AzureSynapsePipelineRunStatus.SUCCEEDED, False),
+    (AzureSynapsePipelineRunStatus.IN_PROGRESS, 
AzureSynapsePipelineRunStatus.SUCCEEDED, "timeout"),
+    (AzureSynapsePipelineRunStatus.QUEUED, 
AzureSynapsePipelineRunStatus.SUCCEEDED, "timeout"),
+    (AzureSynapsePipelineRunStatus.CANCELING, 
AzureSynapsePipelineRunStatus.SUCCEEDED, "timeout"),
+    (AzureSynapsePipelineRunStatus.SUCCEEDED, 
AzureSynapsePipelineRunStatus.TERMINAL_STATUSES, True),
+    (AzureSynapsePipelineRunStatus.FAILED, 
AzureSynapsePipelineRunStatus.TERMINAL_STATUSES, True),
+    (AzureSynapsePipelineRunStatus.CANCELLED, 
AzureSynapsePipelineRunStatus.TERMINAL_STATUSES, True),
+]
+
+
+@pytest.mark.parametrize(
+    argnames=("pipeline_run_status", "expected_status", "expected_output"),
+    argvalues=_wait_for_pipeline_run_status_test_args,
+    ids=[
+        f"run_status_{argval[0]}_expected_{argval[1]}"
+        if isinstance(argval[1], str)
+        else f"run_status_{argval[0]}_expected_AnyTerminalStatus"
+        for argval in _wait_for_pipeline_run_status_test_args
+    ],
+)
+def test_wait_for_pipeline_run_status(hook, pipeline_run_status, 
expected_status, expected_output):
+    config = {"run_id": RUN_ID, "timeout": 3, "check_interval": 1, 
"expected_statuses": expected_status}
+
+    with patch.object(AzureSynapsePipelineHook, "get_pipeline_run") as 
mock_pipeline_run:
+        mock_pipeline_run.return_value.status = pipeline_run_status
+
+        if expected_output != "timeout":
+            assert hook.wait_for_pipeline_run_status(**config) == 
expected_output
+        else:
+            with pytest.raises(AzureSynapsePipelineRunException):
+                hook.wait_for_pipeline_run_status(**config)
diff --git a/tests/providers/microsoft/azure/operators/test_synapse.py 
b/tests/providers/microsoft/azure/operators/test_synapse.py
index 233e1c57fd..14ffd783ba 100644
--- a/tests/providers/microsoft/azure/operators/test_synapse.py
+++ b/tests/providers/microsoft/azure/operators/test_synapse.py
@@ -17,24 +17,42 @@
 from __future__ import annotations
 
 from unittest import mock
-from unittest.mock import MagicMock
+from unittest.mock import MagicMock, patch
 
 import pytest
 
 from airflow.models import Connection
-from airflow.providers.microsoft.azure.operators.synapse import 
AzureSynapseRunSparkBatchOperator
+from airflow.providers.microsoft.azure.hooks.synapse import (
+    AzureSynapsePipelineHook,
+    AzureSynapsePipelineRunException,
+    AzureSynapsePipelineRunStatus,
+)
+from airflow.providers.microsoft.azure.operators.synapse import (
+    AzureSynapsePipelineRunLink,
+    AzureSynapseRunPipelineOperator,
+    AzureSynapseRunSparkBatchOperator,
+)
 from airflow.utils import timezone
 
 DEFAULT_DATE = timezone.datetime(2021, 1, 1)
-SUBSCRIPTION_ID = "my-subscription-id"
+SUBSCRIPTION_ID = "subscription_id"
+TENANT_ID = "tenant_id"
 TASK_ID = "run_spark_op"
+AZURE_SYNAPSE_PIPELINE_TASK_ID = "run_pipeline_op"
 AZURE_SYNAPSE_CONN_ID = "azure_synapse_test"
 CONN_EXTRAS = {
     "synapse__subscriptionId": SUBSCRIPTION_ID,
     "synapse__tenantId": "my-tenant-id",
     "synapse__spark_pool": "my-spark-pool",
 }
+SYNAPSE_PIPELINE_CONN_EXTRAS = {"tenantId": TENANT_ID}
 JOB_RUN_RESPONSE = {"id": 123}
+PIPELINE_NAME = "Pipeline 1"
+AZURE_SYNAPSE_WORKSPACE_DEV_ENDPOINT = "azure_synapse_workspace_dev_endpoint"
+RESOURCE_GROUP = "op-resource-group"
+WORKSPACE_NAME = "workspace-test"
+AZURE_SYNAPSE_WORKSPACE_URL = 
f"https://web.azuresynapse.net?workspace=%2fsubscriptions%{SUBSCRIPTION_ID}%2fresourceGroups%2f{RESOURCE_GROUP}%2fproviders%2fMicrosoft.Synapse%2fworkspaces%2f{WORKSPACE_NAME}"
+PIPELINE_RUN_RESPONSE = {"run_id": "run_id"}
 
 
 class TestAzureSynapseRunSparkBatchOperator:
@@ -53,7 +71,7 @@ class TestAzureSynapseRunSparkBatchOperator:
         create_mock_connection(
             Connection(
                 conn_id=AZURE_SYNAPSE_CONN_ID,
-                conn_type="azure_synapse",
+                conn_type="azure_synapse_pipeline",
                 host="https://synapsetest.net",
                 login="client-id",
                 password="client-secret",
@@ -102,3 +120,195 @@ class TestAzureSynapseRunSparkBatchOperator:
         op.execute(context=self.mock_context)
         op.on_kill()
         
mock_cancel_job_run.assert_called_once_with(job_id=JOB_RUN_RESPONSE["id"])
+
+
+class TestAzureSynapseRunPipelineOperator:
+    @pytest.fixture(autouse=True)
+    def setup_test_cases(self, create_mock_connection):
+        self.mock_ti = MagicMock()
+        self.mock_context = {"ti": self.mock_ti}
+        self.config = {
+            "task_id": AZURE_SYNAPSE_PIPELINE_TASK_ID,
+            "azure_synapse_conn_id": AZURE_SYNAPSE_CONN_ID,
+            "pipeline_name": PIPELINE_NAME,
+            "azure_synapse_workspace_dev_endpoint": 
AZURE_SYNAPSE_WORKSPACE_DEV_ENDPOINT,
+            "check_interval": 1,
+            "timeout": 3,
+        }
+
+        create_mock_connection(
+            Connection(
+                conn_id=AZURE_SYNAPSE_CONN_ID,
+                conn_type="azure_synapse_pipeline",
+                host=AZURE_SYNAPSE_WORKSPACE_URL,
+                login="client_id",
+                password="client_secret",
+                extra=SYNAPSE_PIPELINE_CONN_EXTRAS,
+            )
+        )
+
+    @staticmethod
+    def create_pipeline_run(status: str):
+        """Helper function to create a mock pipeline run with a given 
execution status."""
+
+        run = MagicMock()
+        run.status = status
+
+        return run
+
+    @patch.object(AzureSynapsePipelineHook, "run_pipeline", 
return_value=MagicMock(**PIPELINE_RUN_RESPONSE))
+    @pytest.mark.parametrize(
+        "pipeline_run_status,expected_output",
+        [
+            (AzureSynapsePipelineRunStatus.SUCCEEDED, None),
+            (AzureSynapsePipelineRunStatus.FAILED, "exception"),
+            (AzureSynapsePipelineRunStatus.CANCELLED, "exception"),
+            (AzureSynapsePipelineRunStatus.IN_PROGRESS, "timeout"),
+            (AzureSynapsePipelineRunStatus.QUEUED, "timeout"),
+            (AzureSynapsePipelineRunStatus.CANCELING, "timeout"),
+        ],
+    )
+    def test_execute_wait_for_termination(self, mock_run_pipeline, 
pipeline_run_status, expected_output):
+        # Initialize the operator with mock config, (**) unpacks the config 
dict.
+        operator = AzureSynapseRunPipelineOperator(**self.config)
+
+        assert operator.azure_synapse_conn_id == 
self.config["azure_synapse_conn_id"]
+        assert operator.pipeline_name == self.config["pipeline_name"]
+        assert (
+            operator.azure_synapse_workspace_dev_endpoint
+            == self.config["azure_synapse_workspace_dev_endpoint"]
+        )
+        assert operator.check_interval == self.config["check_interval"]
+        assert operator.timeout == self.config["timeout"]
+        assert operator.wait_for_termination
+
+        with patch.object(AzureSynapsePipelineHook, "get_pipeline_run") as 
mock_get_pipeline_run:
+            mock_get_pipeline_run.return_value = 
TestAzureSynapseRunPipelineOperator.create_pipeline_run(
+                pipeline_run_status
+            )
+
+            if not expected_output:
+                # A successful operator execution should not return any values.
+                assert not operator.execute(context=self.mock_context)
+            elif expected_output == "exception":
+                # The operator should fail if the pipeline run fails or is 
canceled.
+                with pytest.raises(
+                    AzureSynapsePipelineRunException,
+                    match=f"Pipeline run {PIPELINE_RUN_RESPONSE['run_id']} has 
failed or has been cancelled.",
+                ):
+                    operator.execute(context=self.mock_context)
+            else:
+                # Demonstrating the operator timing out after surpassing the 
configured timeout value.
+                with pytest.raises(
+                    AzureSynapsePipelineRunException,
+                    match=(
+                        f"Pipeline run {PIPELINE_RUN_RESPONSE['run_id']} has 
not reached a terminal status "
+                        f"after {self.config['timeout']} seconds."
+                    ),
+                ):
+                    operator.execute(context=self.mock_context)
+
+            # Check the ``run_id`` attr is assigned after executing the 
pipeline.
+            assert operator.run_id == PIPELINE_RUN_RESPONSE["run_id"]
+
+            # Check to ensure an `XCom` is pushed regardless of pipeline run 
result.
+            self.mock_ti.xcom_push.assert_called_once_with(
+                key="run_id", value=PIPELINE_RUN_RESPONSE["run_id"]
+            )
+
+            # Check if mock_run_pipeline called with particular set of 
arguments.
+            mock_run_pipeline.assert_called_once_with(
+                pipeline_name=self.config["pipeline_name"],
+                reference_pipeline_run_id=None,
+                is_recovery=None,
+                start_activity_name=None,
+                parameters=None,
+            )
+
+            if pipeline_run_status in 
AzureSynapsePipelineRunStatus.TERMINAL_STATUSES:
+                
mock_get_pipeline_run.assert_called_once_with(run_id=mock_run_pipeline.return_value.run_id)
+            else:
+                # When the pipeline run status is not in a terminal status or 
"Succeeded", the operator will
+                # continue to call ``get_pipeline_run()`` until a ``timeout`` 
number of seconds has passed
+                # (3 seconds for this test).  Therefore, there should be 4 
calls of this function: one
+                # initially and 3 for each check done at a 1 second interval.
+                assert mock_get_pipeline_run.call_count == 4
+
+                
mock_get_pipeline_run.assert_called_with(run_id=mock_run_pipeline.return_value.run_id)
+
+    @patch.object(AzureSynapsePipelineHook, "run_pipeline", 
return_value=MagicMock(**PIPELINE_RUN_RESPONSE))
+    def test_execute_no_wait_for_termination(self, mock_run_pipeline):
+        operator = AzureSynapseRunPipelineOperator(wait_for_termination=False, 
**self.config)
+
+        assert operator.azure_synapse_conn_id == 
self.config["azure_synapse_conn_id"]
+        assert operator.pipeline_name == self.config["pipeline_name"]
+        assert (
+            operator.azure_synapse_workspace_dev_endpoint
+            == self.config["azure_synapse_workspace_dev_endpoint"]
+        )
+        assert operator.check_interval == self.config["check_interval"]
+        assert operator.timeout == self.config["timeout"]
+        assert not operator.wait_for_termination
+
+        with patch.object(
+            AzureSynapsePipelineHook, "get_pipeline_run", autospec=True
+        ) as mock_get_pipeline_run:
+            operator.execute(context=self.mock_context)
+
+            # Check the ``run_id`` attr is assigned after executing the 
pipeline.
+            assert operator.run_id == PIPELINE_RUN_RESPONSE["run_id"]
+
+            # Check to ensure an `XCom` is pushed regardless of pipeline run 
result.
+            self.mock_ti.xcom_push.assert_called_once_with(
+                key="run_id", value=PIPELINE_RUN_RESPONSE["run_id"]
+            )
+
+            mock_run_pipeline.assert_called_once_with(
+                pipeline_name=self.config["pipeline_name"],
+                reference_pipeline_run_id=None,
+                is_recovery=None,
+                start_activity_name=None,
+                parameters=None,
+            )
+
+            # Checking the pipeline run status should _not_ be called when 
``wait_for_termination`` is False.
+            mock_get_pipeline_run.assert_not_called()
+
+    @pytest.mark.db_test
+    def test_run_pipeline_operator_link(self, 
create_task_instance_of_operator):
+        ti = create_task_instance_of_operator(
+            AzureSynapseRunPipelineOperator,
+            dag_id="test_synapse_run_pipeline_op_link",
+            execution_date=DEFAULT_DATE,
+            task_id=AZURE_SYNAPSE_PIPELINE_TASK_ID,
+            azure_synapse_conn_id=AZURE_SYNAPSE_CONN_ID,
+            pipeline_name=PIPELINE_NAME,
+            
azure_synapse_workspace_dev_endpoint=AZURE_SYNAPSE_WORKSPACE_DEV_ENDPOINT,
+        )
+
+        ti.xcom_push(key="run_id", value=PIPELINE_RUN_RESPONSE["run_id"])
+
+        url = ti.task.get_extra_links(ti, "Monitor Pipeline Run")
+
+        EXPECTED_PIPELINE_RUN_OP_EXTRA_LINK = (
+            
"https://ms.web.azuresynapse.net/en/monitoring/pipelineruns/{run_id}";
+            "?workspace=%2Fsubscriptions%2F{subscription_id}%2F"
+            "resourceGroups%2F{resource_group}%2Fproviders%2FMicrosoft.Synapse"
+            "%2Fworkspaces%2F{workspace_name}"
+        )
+
+        conn = AzureSynapsePipelineHook.get_connection(AZURE_SYNAPSE_CONN_ID)
+        conn_synapse_workspace_url = conn.host
+
+        # Extract the workspace_name, subscription_id and resource_group from 
the Synapse workspace url.
+        pipeline_run_object = AzureSynapsePipelineRunLink()
+        fields = 
pipeline_run_object.get_fields_from_url(workspace_url=conn_synapse_workspace_url)
+
+        assert url == (
+            EXPECTED_PIPELINE_RUN_OP_EXTRA_LINK.format(
+                run_id=PIPELINE_RUN_RESPONSE["run_id"],
+                subscription_id=fields["subscription_id"],
+                resource_group=fields["resource_group"],
+                workspace_name=fields["workspace_name"],
+            )
+        )
diff --git a/tests/system/providers/microsoft/azure/example_synapse_run_pipeline.py b/tests/system/providers/microsoft/azure/example_synapse_run_pipeline.py
new file mode 100644
index 0000000000..be11a665c4
--- /dev/null
+++ b/tests/system/providers/microsoft/azure/example_synapse_run_pipeline.py
@@ -0,0 +1,59 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import os
+from datetime import datetime
+
+from airflow.models import DAG
+from airflow.providers.microsoft.azure.operators.synapse import 
AzureSynapseRunPipelineOperator
+
+try:
+    from airflow.operators.empty import EmptyOperator
+except ModuleNotFoundError:
+    from airflow.operators.dummy import DummyOperator as EmptyOperator  # 
type: ignore
+
+ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID")
+
+with DAG(
+    dag_id="example_synapse_run_pipeline",
+    start_date=datetime(2021, 8, 13),
+    schedule="@daily",
+    catchup=False,
+    tags=["synapse", "example"],
+) as dag:
+    begin = EmptyOperator(task_id="begin")
+
+    run_pipeline1 = AzureSynapseRunPipelineOperator(
+        task_id="run_pipeline1",
+        azure_synapse_conn_id="azure_synapse_connection",
+        pipeline_name="Pipeline 1",
+        
azure_synapse_workspace_dev_endpoint="azure_synapse_workspace_dev_endpoint",
+    )
+
+    begin >> run_pipeline1
+
+    from tests.system.utils.watcher import watcher
+
+    # This test needs watcher in order to properly mark success/failure
+    # when "tearDown" task with trigger rule is part of the DAG
+    list(dag.tasks) >> watcher()
+
+from tests.system.utils import get_test_run  # noqa: E402
+
+# Needed to run the example DAG with pytest (see: 
tests/system/README.md#run_via_pytest)
+test_run = get_test_run(dag)


Reply via email to