This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 6008483efe8 Create operators for working with Consumer Groups for GCP Apache Kafka (#47056)
6008483efe8 is described below
commit 6008483efe83b149d0fd3732e4889aba06b2e5c0
Author: Maksim <[email protected]>
AuthorDate: Mon Mar 3 11:10:18 2025 -0800
Create operators for working with Consumer Groups for GCP Apache Kafka (#47056)
---
dev/breeze/tests/test_selective_checks.py | 8 +-
generated/provider_dependencies.json | 4 +-
providers/apache/kafka/README.rst | 19 ++
providers/apache/kafka/pyproject.toml | 7 +
.../providers/apache/kafka/get_provider_info.py | 1 +
.../airflow/providers/apache/kafka/hooks/base.py | 11 +
.../google/docs/operators/cloud/managed_kafka.rst | 60 +++++
providers/google/provider.yaml | 1 +
.../providers/google/cloud/hooks/managed_kafka.py | 227 +++++++++++++++++-
.../providers/google/cloud/links/managed_kafka.py | 30 +++
.../google/cloud/operators/managed_kafka.py | 265 +++++++++++++++++++++
.../airflow/providers/google/get_provider_info.py | 1 +
.../example_managed_kafka_consumer_group.py | 254 ++++++++++++++++++++
.../unit/google/cloud/hooks/test_managed_kafka.py | 189 +++++++++++++++
.../unit/google/cloud/links/test_managed_kafka.py | 40 ++++
.../google/cloud/operators/test_managed_kafka.py | 131 ++++++++++
16 files changed, 1241 insertions(+), 7 deletions(-)
diff --git a/dev/breeze/tests/test_selective_checks.py b/dev/breeze/tests/test_selective_checks.py
index b4fdfa36a21..1fcf0e7aa20 100644
--- a/dev/breeze/tests/test_selective_checks.py
+++ b/dev/breeze/tests/test_selective_checks.py
@@ -1621,7 +1621,7 @@ def test_expected_output_push(
"providers/google/tests/unit/google/file.py",
),
{
- "selected-providers-list-as-string": "amazon apache.beam
apache.cassandra "
+ "selected-providers-list-as-string": "amazon apache.beam
apache.cassandra apache.kafka "
"cncf.kubernetes common.compat common.sql "
"facebook google hashicorp microsoft.azure microsoft.mssql
mysql "
"openlineage oracle postgres presto salesforce samba sftp ssh
trino",
@@ -1635,14 +1635,14 @@ def test_expected_output_push(
"test-groups": "['core', 'providers']",
"docs-build": "true",
"docs-list-as-string": "apache-airflow helm-chart amazon
apache.beam apache.cassandra "
- "cncf.kubernetes common.compat common.sql facebook google
hashicorp microsoft.azure "
+ "apache.kafka cncf.kubernetes common.compat common.sql
facebook google hashicorp microsoft.azure "
"microsoft.mssql mysql openlineage oracle postgres "
"presto salesforce samba sftp ssh trino",
"skip-pre-commits":
"identity,mypy-airflow,mypy-dev,mypy-docs,mypy-providers,mypy-task-sdk,ts-compile-format-lint-ui",
"run-kubernetes-tests": "true",
"upgrade-to-newer-dependencies": "false",
"core-test-types-list-as-string": "Always CLI",
- "providers-test-types-list-as-string": "Providers[amazon]
Providers[apache.beam,apache.cassandra,cncf.kubernetes,common.compat,common.sql,facebook,"
+ "providers-test-types-list-as-string": "Providers[amazon]
Providers[apache.beam,apache.cassandra,apache.kafka,cncf.kubernetes,common.compat,common.sql,facebook,"
"hashicorp,microsoft.azure,microsoft.mssql,mysql,openlineage,oracle,postgres,presto,"
"salesforce,samba,sftp,ssh,trino] Providers[google]",
"needs-mypy": "true",
@@ -1890,7 +1890,7 @@ def test_upgrade_to_newer_dependencies(
pytest.param(
("providers/google/docs/some_file.rst",),
{
- "docs-list-as-string": "amazon apache.beam apache.cassandra "
+ "docs-list-as-string": "amazon apache.beam apache.cassandra
apache.kafka "
"cncf.kubernetes common.compat common.sql facebook google
hashicorp "
"microsoft.azure microsoft.mssql mysql openlineage oracle "
"postgres presto salesforce samba sftp ssh trino",
diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json
index 87657a13aac..b935b9c7d03 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -222,7 +222,9 @@
],
"devel-deps": [],
"plugins": [],
- "cross-providers-deps": [],
+ "cross-providers-deps": [
+ "google"
+ ],
"excluded-python-versions": [],
"state": "ready"
},
diff --git a/providers/apache/kafka/README.rst b/providers/apache/kafka/README.rst
index cfa6d23c9f7..f85a6682b27 100644
--- a/providers/apache/kafka/README.rst
+++ b/providers/apache/kafka/README.rst
@@ -58,5 +58,24 @@ PIP package Version required
``confluent-kafka`` ``>=2.3.0``
=================== ==================
+Cross provider package dependencies
+-----------------------------------
+
+Those are dependencies that might be needed in order to use all the features of the package.
+You need to install the specified provider packages in order to use them.
+
+You can install such cross-provider dependencies when installing from PyPI. For example:
+
+.. code-block:: bash
+
+ pip install apache-airflow-providers-apache-kafka[google]
+
+
+==================================================================================================== ==========
+Dependent package                                                                                    Extra
+==================================================================================================== ==========
+`apache-airflow-providers-google <https://airflow.apache.org/docs/apache-airflow-providers-google>`_ ``google``
+==================================================================================================== ==========
+
The changelog for the provider package can be found in the
`changelog <https://airflow.apache.org/docs/apache-airflow-providers-apache-kafka/1.7.0/changelog.html>`_.
diff --git a/providers/apache/kafka/pyproject.toml b/providers/apache/kafka/pyproject.toml
index 51c6126c5ca..c236cf0e662 100644
--- a/providers/apache/kafka/pyproject.toml
+++ b/providers/apache/kafka/pyproject.toml
@@ -62,6 +62,13 @@ dependencies = [
"confluent-kafka>=2.3.0",
]
+# The optional dependencies should be modified in place in the generated file
+# Any change in the dependencies is preserved when the file is regenerated
+[project.optional-dependencies]
+"google" = [
+ "apache-airflow-providers-google"
+]
+
[project.urls]
"Documentation" =
"https://airflow.apache.org/docs/apache-airflow-providers-apache-kafka/1.7.0"
"Changelog" =
"https://airflow.apache.org/docs/apache-airflow-providers-apache-kafka/1.7.0/changelog.html"
diff --git a/providers/apache/kafka/src/airflow/providers/apache/kafka/get_provider_info.py b/providers/apache/kafka/src/airflow/providers/apache/kafka/get_provider_info.py
index 97287ea00e2..d53eee67dc7 100644
--- a/providers/apache/kafka/src/airflow/providers/apache/kafka/get_provider_info.py
+++ b/providers/apache/kafka/src/airflow/providers/apache/kafka/get_provider_info.py
@@ -90,4 +90,5 @@ def get_provider_info():
}
],
"dependencies": ["apache-airflow>=2.9.0", "asgiref>=2.3.0",
"confluent-kafka>=2.3.0"],
+ "optional-dependencies": {"google":
["apache-airflow-providers-google"]},
}
diff --git a/providers/apache/kafka/src/airflow/providers/apache/kafka/hooks/base.py b/providers/apache/kafka/src/airflow/providers/apache/kafka/hooks/base.py
index 9b20c7dfc91..5d02903a4d6 100644
--- a/providers/apache/kafka/src/airflow/providers/apache/kafka/hooks/base.py
+++ b/providers/apache/kafka/src/airflow/providers/apache/kafka/hooks/base.py
@@ -22,6 +22,7 @@ from typing import Any
from confluent_kafka.admin import AdminClient
from airflow.hooks.base import BaseHook
+from airflow.providers.google.cloud.hooks.managed_kafka import ManagedKafkaHook
class KafkaBaseHook(BaseHook):
@@ -63,6 +64,16 @@ class KafkaBaseHook(BaseHook):
if not (config.get("bootstrap.servers", None)):
raise ValueError("config['bootstrap.servers'] must be provided.")
+ bootstrap_servers = config.get("bootstrap.servers")
+ if (
+ bootstrap_servers
+ and bootstrap_servers.find("cloud.goog") != -1
+ and bootstrap_servers.find("managedkafka") != -1
+ ):
+ self.log.info("Adding token generation for Google Auth to the
confluent configuration.")
+ hook = ManagedKafkaHook()
+ token = hook.get_confluent_token
+ config.update({"oauth_cb": token})
return self._get_client(config)
def test_connection(self) -> tuple[bool, str]:
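For reference, confluent-kafka treats the "oauth_cb" entry as a callable that it invokes with the sasl.oauthbearer.config string and from which it expects a (token, expiry-in-unix-seconds) tuple; get_confluent_token above fills that role. A minimal sketch of that contract, with placeholder hostnames and helper names:

    import time

    def example_oauth_cb(oauthbearer_config: str) -> tuple[str, float]:
        # Return a bearer token plus its absolute expiry; in the hook above,
        # the token is built from Google credentials by ManagedKafkaTokenProvider.
        return "header.claims.access-token", time.time() + 3600

    config = {
        "bootstrap.servers": "bootstrap.my-cluster.us-central1.managedkafka.my-project.cloud.goog:9092",
        "security.protocol": "SASL_SSL",
        "sasl.mechanisms": "OAUTHBEARER",
        "oauth_cb": example_oauth_cb,  # invoked by librdkafka before each token refresh
    }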
diff --git a/providers/google/docs/operators/cloud/managed_kafka.rst b/providers/google/docs/operators/cloud/managed_kafka.rst
index a81f81592ee..791d721827d 100644
--- a/providers/google/docs/operators/cloud/managed_kafka.rst
+++ b/providers/google/docs/operators/cloud/managed_kafka.rst
@@ -117,6 +117,66 @@ To update topic you can use
:start-after: [START how_to_cloud_managed_kafka_update_topic_operator]
:end-before: [END how_to_cloud_managed_kafka_update_topic_operator]
+Interacting with Apache Kafka Consumer Groups
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+To delete consumer group you can use
+:class:`~airflow.providers.google.cloud.operators.managed_kafka.ManagedKafkaDeleteConsumerGroupOperator`.
+
+.. exampleinclude:: /../../providers/google/tests/system/google/cloud/managed_kafka/example_managed_kafka_consumer_group.py
+ :language: python
+ :dedent: 4
+ :start-after: [START how_to_cloud_managed_kafka_delete_consumer_group_operator]
+ :end-before: [END how_to_cloud_managed_kafka_delete_consumer_group_operator]
+
+To get consumer group you can use
+:class:`~airflow.providers.google.cloud.operators.managed_kafka.ManagedKafkaGetConsumerGroupOperator`.
+
+.. exampleinclude:: /../../providers/google/tests/system/google/cloud/managed_kafka/example_managed_kafka_consumer_group.py
+ :language: python
+ :dedent: 4
+ :start-after: [START how_to_cloud_managed_kafka_get_consumer_group_operator]
+ :end-before: [END how_to_cloud_managed_kafka_get_consumer_group_operator]
+
+To get a list of consumer groups you can use
+:class:`~airflow.providers.google.cloud.operators.managed_kafka.ManagedKafkaListConsumerGroupsOperator`.
+
+.. exampleinclude:: /../../providers/google/tests/system/google/cloud/managed_kafka/example_managed_kafka_consumer_group.py
+ :language: python
+ :dedent: 4
+ :start-after: [START how_to_cloud_managed_kafka_list_consumer_group_operator]
+ :end-before: [END how_to_cloud_managed_kafka_list_consumer_group_operator]
+
+To update consumer group you can use
+:class:`~airflow.providers.google.cloud.operators.managed_kafka.ManagedKafkaUpdateConsumerGroupOperator`.
+
+.. exampleinclude:: /../../providers/google/tests/system/google/cloud/managed_kafka/example_managed_kafka_consumer_group.py
+ :language: python
+ :dedent: 4
+ :start-after: [START how_to_cloud_managed_kafka_update_consumer_group_operator]
+ :end-before: [END how_to_cloud_managed_kafka_update_consumer_group_operator]
+
+Using Apache Kafka provider with Google Cloud Managed Service for Apache Kafka
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+To produce data to topic you can use
+:class:`~airflow.providers.apache.kafka.operators.produce.ProduceToTopicOperator`.
+
+.. exampleinclude:: /../../providers/google/tests/system/google/cloud/managed_kafka/example_managed_kafka_consumer_group.py
+ :language: python
+ :dedent: 4
+ :start-after: [START how_to_cloud_managed_kafka_produce_to_topic_operator]
+ :end-before: [END how_to_cloud_managed_kafka_produce_to_topic_operator]
+
+To consume data from topic you can use
+:class:`~airflow.providers.apache.kafka.operators.consume.ConsumeFromTopicOperator`.
+
+.. exampleinclude:: /../../providers/google/tests/system/google/cloud/managed_kafka/example_managed_kafka_consumer_group.py
+ :language: python
+ :dedent: 4
+ :start-after: [START how_to_cloud_managed_kafka_consume_from_topic_operator]
+ :end-before: [END how_to_cloud_managed_kafka_consume_from_topic_operator]
+
Reference
^^^^^^^^^
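The connection extras behind these two operators mirror what the example DAG below creates; a sketch of a Kafka connection pointed at a Managed Kafka bootstrap address (all identifiers are placeholders):

    import json

    from airflow.models import Connection

    conn = Connection(conn_id="kafka_managed", conn_type="kafka")
    conn.set_extra(
        json.dumps(
            {
                # The managedkafka/cloud.goog hostname pattern is what makes
                # KafkaBaseHook attach the Google Auth token callback.
                "bootstrap.servers": "bootstrap.my-cluster.us-central1.managedkafka.my-project.cloud.goog:9092",
                "security.protocol": "SASL_SSL",
                "sasl.mechanisms": "OAUTHBEARER",
                "group.id": "my-consumer-group",
            }
        )
    )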
diff --git a/providers/google/provider.yaml b/providers/google/provider.yaml
index 3e3a9e7c106..a251ef1db9a 100644
--- a/providers/google/provider.yaml
+++ b/providers/google/provider.yaml
@@ -1231,6 +1231,7 @@ extra-links:
- airflow.providers.google.cloud.links.managed_kafka.ApacheKafkaClusterLink
- airflow.providers.google.cloud.links.managed_kafka.ApacheKafkaClusterListLink
- airflow.providers.google.cloud.links.managed_kafka.ApacheKafkaTopicLink
+ - airflow.providers.google.cloud.links.managed_kafka.ApacheKafkaConsumerGroupLink
secrets-backends:
diff --git a/providers/google/src/airflow/providers/google/cloud/hooks/managed_kafka.py b/providers/google/src/airflow/providers/google/cloud/hooks/managed_kafka.py
index f71e8a158c8..738727332d4 100644
--- a/providers/google/src/airflow/providers/google/cloud/hooks/managed_kafka.py
+++ b/providers/google/src/airflow/providers/google/cloud/hooks/managed_kafka.py
@@ -19,12 +19,17 @@
from __future__ import annotations
+import base64
+import datetime
+import json
+import time
from collections.abc import Sequence
from copy import deepcopy
from typing import TYPE_CHECKING
from google.api_core.gapic_v1.method import DEFAULT, _MethodDefault
-from google.cloud.managedkafka_v1 import Cluster, ManagedKafkaClient, Topic, types
+from google.auth.transport import requests as google_requests
+from google.cloud.managedkafka_v1 import Cluster, ConsumerGroup, ManagedKafkaClient, Topic, types
from airflow.exceptions import AirflowException
from airflow.providers.google.common.consts import CLIENT_INFO
@@ -33,10 +38,62 @@ from airflow.providers.google.common.hooks.base_google import GoogleBaseHook
if TYPE_CHECKING:
from google.api_core.operation import Operation
from google.api_core.retry import Retry
- from google.cloud.managedkafka_v1.services.managed_kafka.pagers import ListClustersPager, ListTopicsPager
+ from google.auth.credentials import Credentials
+ from google.cloud.managedkafka_v1.services.managed_kafka.pagers import (
+ ListClustersPager,
+ ListConsumerGroupsPager,
+ ListTopicsPager,
+ )
from google.protobuf.field_mask_pb2 import FieldMask
+class ManagedKafkaTokenProvider:
+ """Helper for providing authentication token for establishing connection
via confluent to Apache Kafka cluster managed by Google Cloud."""
+
+ def __init__(
+ self,
+ credentials: Credentials,
+ ):
+ self._credentials = credentials
+ self._header = json.dumps(dict(typ="JWT", alg="GOOG_OAUTH2_TOKEN"))
+
+ def _valid_credentials(self):
+ if not self._credentials.valid:
+ self._credentials.refresh(google_requests.Request())
+ return self._credentials
+
+ def _get_jwt(self, credentials):
+ return json.dumps(
+ dict(
+ exp=credentials.expiry.timestamp(),
+ iss="Google",
+ iat=datetime.datetime.now(datetime.timezone.utc).timestamp(),
+ scope="kafka",
+ sub=credentials.service_account_email,
+ )
+ )
+
+ def _b64_encode(self, source):
+ return base64.urlsafe_b64encode(source.encode("utf-8")).decode("utf-8").rstrip("=")
+
+ def _get_kafka_access_token(self, credentials):
+ return ".".join(
+ [
+ self._b64_encode(self._header),
+ self._b64_encode(self._get_jwt(credentials)),
+ self._b64_encode(credentials.token),
+ ]
+ )
+
+ def confluent_token(self):
+ credentials = self._valid_credentials()
+
+ utc_expiry = credentials.expiry.replace(tzinfo=datetime.timezone.utc)
+ expiry_seconds = (utc_expiry - datetime.datetime.now(datetime.timezone.utc)).total_seconds()
+
+ return self._get_kafka_access_token(credentials), time.time() + expiry_seconds
+
+
class ManagedKafkaHook(GoogleBaseHook):
"""Hook for Managed Service for Apache Kafka APIs."""
@@ -63,6 +120,12 @@ class ManagedKafkaHook(GoogleBaseHook):
error = operation.exception(timeout=timeout)
raise AirflowException(error)
+ def get_confluent_token(self):
+ """Get the authentication token for confluent client."""
+ token_provider = ManagedKafkaTokenProvider(credentials=self.get_credentials())
+ token = token_provider.confluent_token()
+ return token
+
@GoogleBaseHook.fallback_to_default_project_id
def create_cluster(
self,
@@ -481,3 +544,163 @@ class ManagedKafkaHook(GoogleBaseHook):
timeout=timeout,
metadata=metadata,
)
+
+ @GoogleBaseHook.fallback_to_default_project_id
+ def list_consumer_groups(
+ self,
+ project_id: str,
+ location: str,
+ cluster_id: str,
+ page_size: int | None = None,
+ page_token: str | None = None,
+ retry: Retry | _MethodDefault = DEFAULT,
+ timeout: float | None = None,
+ metadata: Sequence[tuple[str, str]] = (),
+ ) -> ListConsumerGroupsPager:
+ """
+ List the consumer groups in a given cluster.
+
+ :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+ :param location: Required. The ID of the Google Cloud region that the service belongs to.
+ :param cluster_id: Required. The ID of the cluster whose consumer groups are to be listed.
+ :param page_size: Optional. The maximum number of consumer groups to return. The service may return
+ fewer than this value. If unset or zero, all consumer groups for the parent are returned.
+ :param page_token: Optional. A page token, received from a previous ``ListConsumerGroups`` call.
+ Provide this to retrieve the subsequent page. When paginating, all other parameters provided to
+ ``ListConsumerGroups`` must match the call that provided the page token.
+ :param retry: Designation of what errors, if any, should be retried.
+ :param timeout: The timeout for this request.
+ :param metadata: Strings which should be sent along with the request as metadata.
+ """
+ client = self.get_managed_kafka_client()
+ parent = client.cluster_path(project_id, location, cluster_id)
+
+ result = client.list_consumer_groups(
+ request={
+ "parent": parent,
+ "page_size": page_size,
+ "page_token": page_token,
+ },
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+ return result
+
+ @GoogleBaseHook.fallback_to_default_project_id
+ def get_consumer_group(
+ self,
+ project_id: str,
+ location: str,
+ cluster_id: str,
+ consumer_group_id: str,
+ retry: Retry | _MethodDefault = DEFAULT,
+ timeout: float | None = None,
+ metadata: Sequence[tuple[str, str]] = (),
+ ) -> types.ConsumerGroup:
+ """
+ Return the properties of a single consumer group.
+
+ :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+ :param location: Required. The ID of the Google Cloud region that the service belongs to.
+ :param cluster_id: Required. The ID of the cluster whose consumer group is to be returned.
+ :param consumer_group_id: Required. The ID of the consumer group whose configuration to return.
+ :param retry: Designation of what errors, if any, should be retried.
+ :param timeout: The timeout for this request.
+ :param metadata: Strings which should be sent along with the request as metadata.
+ """
+ client = self.get_managed_kafka_client()
+ name = client.consumer_group_path(project_id, location, cluster_id, consumer_group_id)
+
+ result = client.get_consumer_group(
+ request={
+ "name": name,
+ },
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+ return result
+
+ @GoogleBaseHook.fallback_to_default_project_id
+ def update_consumer_group(
+ self,
+ project_id: str,
+ location: str,
+ cluster_id: str,
+ consumer_group_id: str,
+ consumer_group: types.ConsumerGroup | dict,
+ update_mask: FieldMask | dict,
+ retry: Retry | _MethodDefault = DEFAULT,
+ timeout: float | None = None,
+ metadata: Sequence[tuple[str, str]] = (),
+ ) -> types.ConsumerGroup:
+ """
+ Update the properties of a single consumer group.
+
+ :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+ :param location: Required. The ID of the Google Cloud region that the service belongs to.
+ :param cluster_id: Required. The ID of the cluster whose consumer group is to be updated.
+ :param consumer_group_id: Required. The ID of the consumer group whose configuration to update.
+ :param consumer_group: Required. The consumer_group to update. Its ``name`` field must be populated.
+ :param update_mask: Required. Field mask is used to specify the fields to be overwritten in the
+ ConsumerGroup resource by the update. The fields specified in the update_mask are relative to the
+ resource, not the full request. A field will be overwritten if it is in the mask.
+ :param retry: Designation of what errors, if any, should be retried.
+ :param timeout: The timeout for this request.
+ :param metadata: Strings which should be sent along with the request as metadata.
+ """
+ client = self.get_managed_kafka_client()
+ _consumer_group = (
+ deepcopy(consumer_group)
+ if isinstance(consumer_group, dict)
+ else ConsumerGroup.to_dict(consumer_group)
+ )
+ _consumer_group["name"] = client.consumer_group_path(
+ project_id, location, cluster_id, consumer_group_id
+ )
+
+ result = client.update_consumer_group(
+ request={
+ "update_mask": update_mask,
+ "consumer_group": _consumer_group,
+ },
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
+ return result
+
+ @GoogleBaseHook.fallback_to_default_project_id
+ def delete_consumer_group(
+ self,
+ project_id: str,
+ location: str,
+ cluster_id: str,
+ consumer_group_id: str,
+ retry: Retry | _MethodDefault = DEFAULT,
+ timeout: float | None = None,
+ metadata: Sequence[tuple[str, str]] = (),
+ ) -> None:
+ """
+ Delete a single consumer group.
+
+ :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+ :param location: Required. The ID of the Google Cloud region that the service belongs to.
+ :param cluster_id: Required. The ID of the cluster whose consumer group is to be deleted.
+ :param consumer_group_id: Required. The ID of the consumer group to delete.
+ :param retry: Designation of what errors, if any, should be retried.
+ :param timeout: The timeout for this request.
+ :param metadata: Strings which should be sent along with the request as metadata.
+ """
+ client = self.get_managed_kafka_client()
+ name = client.consumer_group_path(project_id, location, cluster_id, consumer_group_id)
+
+ client.delete_consumer_group(
+ request={
+ "name": name,
+ },
+ retry=retry,
+ timeout=timeout,
+ metadata=metadata,
+ )
diff --git a/providers/google/src/airflow/providers/google/cloud/links/managed_kafka.py b/providers/google/src/airflow/providers/google/cloud/links/managed_kafka.py
index 0aafe2f202d..45b62901c55 100644
--- a/providers/google/src/airflow/providers/google/cloud/links/managed_kafka.py
+++ b/providers/google/src/airflow/providers/google/cloud/links/managed_kafka.py
@@ -31,6 +31,10 @@ MANAGED_KAFKA_CLUSTER_LIST_LINK = MANAGED_KAFKA_BASE_LINK + "/clusters?project={
MANAGED_KAFKA_TOPIC_LINK = (
MANAGED_KAFKA_BASE_LINK + "/{location}/clusters/{cluster_id}/topics/{topic_id}?project={project_id}"
)
+MANAGED_KAFKA_CONSUMER_GROUP_LINK = (
+ MANAGED_KAFKA_BASE_LINK
+ + "/{location}/clusters/{cluster_id}/consumer_groups/{consumer_group_id}?project={project_id}"
+)
class ApacheKafkaClusterLink(BaseGoogleLink):
@@ -102,3 +106,29 @@ class ApacheKafkaTopicLink(BaseGoogleLink):
"project_id": task_instance.project_id,
},
)
+
+
+class ApacheKafkaConsumerGroupLink(BaseGoogleLink):
+ """Helper class for constructing Apache Kafka Consumer Group link."""
+
+ name = "Apache Kafka Consumer Group"
+ key = "consumer_group_conf"
+ format_str = MANAGED_KAFKA_CONSUMER_GROUP_LINK
+
+ @staticmethod
+ def persist(
+ context: Context,
+ task_instance,
+ cluster_id: str,
+ consumer_group_id: str,
+ ):
+ task_instance.xcom_push(
+ context=context,
+ key=ApacheKafkaConsumerGroupLink.key,
+ value={
+ "location": task_instance.location,
+ "cluster_id": cluster_id,
+ "consumer_group_id": consumer_group_id,
+ "project_id": task_instance.project_id,
+ },
+ )
diff --git a/providers/google/src/airflow/providers/google/cloud/operators/managed_kafka.py b/providers/google/src/airflow/providers/google/cloud/operators/managed_kafka.py
index 0ded649858f..b649149ccc0 100644
--- a/providers/google/src/airflow/providers/google/cloud/operators/managed_kafka.py
+++ b/providers/google/src/airflow/providers/google/cloud/operators/managed_kafka.py
@@ -32,6 +32,7 @@ from airflow.providers.google.cloud.hooks.managed_kafka import ManagedKafkaHook
from airflow.providers.google.cloud.links.managed_kafka import (
ApacheKafkaClusterLink,
ApacheKafkaClusterListLink,
+ ApacheKafkaConsumerGroupLink,
ApacheKafkaTopicLink,
)
from airflow.providers.google.cloud.operators.cloud_base import GoogleCloudBaseOperator
@@ -788,3 +789,267 @@ class ManagedKafkaDeleteTopicOperator(ManagedKafkaBaseOperator):
except NotFound as not_found_err:
self.log.info("The Apache Kafka topic ID %s does not exist.",
self.topic_id)
raise AirflowException(not_found_err)
+
+
+class ManagedKafkaListConsumerGroupsOperator(ManagedKafkaBaseOperator):
+ """
+ List the consumer groups in a given cluster.
+
+ :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+ :param location: Required. The ID of the Google Cloud region that the service belongs to.
+ :param cluster_id: Required. The ID of the cluster whose consumer groups are to be listed.
+ :param page_size: Optional. The maximum number of consumer groups to return. The service may return
+ fewer than this value. If unset or zero, all consumer groups for the parent are returned.
+ :param page_token: Optional. A page token, received from a previous ``ListConsumerGroups`` call.
+ Provide this to retrieve the subsequent page. When paginating, all other parameters provided to
+ ``ListConsumerGroups`` must match the call that provided the page token.
+ :param retry: Designation of what errors, if any, should be retried.
+ :param timeout: The timeout for this request.
+ :param metadata: Strings which should be sent along with the request as metadata.
+ :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
+ :param impersonation_chain: Optional service account to impersonate using short-term
+ credentials, or chained list of accounts required to get the access_token
+ of the last account in the list, which will be impersonated in the request.
+ If set as a string, the account must grant the originating account
+ the Service Account Token Creator IAM role.
+ If set as a sequence, the identities from the list must grant
+ Service Account Token Creator IAM role to the directly preceding identity, with first
+ account from the list granting this role to the originating account (templated).
+ """
+
+ template_fields: Sequence[str] = tuple({"cluster_id"} |
set(ManagedKafkaBaseOperator.template_fields))
+ operator_extra_links = (ApacheKafkaClusterLink(),)
+
+ def __init__(
+ self,
+ cluster_id: str,
+ page_size: int | None = None,
+ page_token: str | None = None,
+ *args,
+ **kwargs,
+ ) -> None:
+ super().__init__(*args, **kwargs)
+ self.cluster_id = cluster_id
+ self.page_size = page_size
+ self.page_token = page_token
+
+ def execute(self, context: Context):
+ ApacheKafkaClusterLink.persist(context=context, task_instance=self, cluster_id=self.cluster_id)
+ self.log.info("Listing Consumer Groups for cluster %s.", self.cluster_id)
+ try:
+ consumer_group_list_pager = self.hook.list_consumer_groups(
+ project_id=self.project_id,
+ location=self.location,
+ cluster_id=self.cluster_id,
+ page_size=self.page_size,
+ page_token=self.page_token,
+ retry=self.retry,
+ timeout=self.timeout,
+ metadata=self.metadata,
+ )
+ self.xcom_push(
+ context=context,
+ key="consumer_group_page",
+ value=types.ListConsumerGroupsResponse.to_dict(consumer_group_list_pager._response),
+ )
+ except Exception as error:
+ raise AirflowException(error)
+ return [types.ConsumerGroup.to_dict(consumer_group) for consumer_group in consumer_group_list_pager]
+
+
+class ManagedKafkaGetConsumerGroupOperator(ManagedKafkaBaseOperator):
+ """
+ Return the properties of a single consumer group.
+
+ :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+ :param location: Required. The ID of the Google Cloud region that the service belongs to.
+ :param cluster_id: Required. The ID of the cluster whose consumer group is to be returned.
+ :param consumer_group_id: Required. The ID of the consumer group whose configuration to return.
+ :param retry: Designation of what errors, if any, should be retried.
+ :param timeout: The timeout for this request.
+ :param metadata: Strings which should be sent along with the request as metadata.
+ :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
+ :param impersonation_chain: Optional service account to impersonate using short-term
+ credentials, or chained list of accounts required to get the access_token
+ of the last account in the list, which will be impersonated in the request.
+ If set as a string, the account must grant the originating account
+ the Service Account Token Creator IAM role.
+ If set as a sequence, the identities from the list must grant
+ Service Account Token Creator IAM role to the directly preceding identity, with first
+ account from the list granting this role to the originating account (templated).
+ """
+
+ template_fields: Sequence[str] = tuple(
+ {"cluster_id", "consumer_group_id"} | set(ManagedKafkaBaseOperator.template_fields)
+ )
+ operator_extra_links = (ApacheKafkaConsumerGroupLink(),)
+
+ def __init__(
+ self,
+ cluster_id: str,
+ consumer_group_id: str,
+ *args,
+ **kwargs,
+ ) -> None:
+ super().__init__(*args, **kwargs)
+ self.cluster_id = cluster_id
+ self.consumer_group_id = consumer_group_id
+
+ def execute(self, context: Context):
+ ApacheKafkaConsumerGroupLink.persist(
+ context=context,
+ task_instance=self,
+ cluster_id=self.cluster_id,
+ consumer_group_id=self.consumer_group_id,
+ )
+ self.log.info("Getting Consumer Group: %s", self.consumer_group_id)
+ try:
+ consumer_group = self.hook.get_consumer_group(
+ project_id=self.project_id,
+ location=self.location,
+ cluster_id=self.cluster_id,
+ consumer_group_id=self.consumer_group_id,
+ retry=self.retry,
+ timeout=self.timeout,
+ metadata=self.metadata,
+ )
+ self.log.info(
+ "The consumer group %s from cluster %s was retrieved.",
+ self.consumer_group_id,
+ self.cluster_id,
+ )
+ return types.ConsumerGroup.to_dict(consumer_group)
+ except NotFound as not_found_err:
+ self.log.info("The Consumer Group %s does not exist.",
self.consumer_group_id)
+ raise AirflowException(not_found_err)
+
+
+class ManagedKafkaUpdateConsumerGroupOperator(ManagedKafkaBaseOperator):
+ """
+ Update the properties of a single consumer group.
+
+ :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+ :param location: Required. The ID of the Google Cloud region that the service belongs to.
+ :param cluster_id: Required. The ID of the cluster whose consumer group is to be updated.
+ :param consumer_group_id: Required. The ID of the consumer group whose configuration to update.
+ :param consumer_group: Required. The consumer_group to update. Its ``name`` field must be populated.
+ :param update_mask: Required. Field mask is used to specify the fields to be overwritten in the
+ ConsumerGroup resource by the update. The fields specified in the update_mask are relative to the
+ resource, not the full request. A field will be overwritten if it is in the mask.
+ :param retry: Designation of what errors, if any, should be retried.
+ :param timeout: The timeout for this request.
+ :param metadata: Strings which should be sent along with the request as metadata.
+ :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
+ :param impersonation_chain: Optional service account to impersonate using short-term
+ credentials, or chained list of accounts required to get the access_token
+ of the last account in the list, which will be impersonated in the request.
+ If set as a string, the account must grant the originating account
+ the Service Account Token Creator IAM role.
+ If set as a sequence, the identities from the list must grant
+ Service Account Token Creator IAM role to the directly preceding identity, with first
+ account from the list granting this role to the originating account (templated).
+ """
+
+ template_fields: Sequence[str] = tuple(
+ {"cluster_id", "consumer_group_id", "consumer_group", "update_mask"}
+ | set(ManagedKafkaBaseOperator.template_fields)
+ )
+ operator_extra_links = (ApacheKafkaConsumerGroupLink(),)
+
+ def __init__(
+ self,
+ cluster_id: str,
+ consumer_group_id: str,
+ consumer_group: types.ConsumerGroup | dict,
+ update_mask: FieldMask | dict,
+ *args,
+ **kwargs,
+ ) -> None:
+ super().__init__(*args, **kwargs)
+ self.cluster_id = cluster_id
+ self.consumer_group_id = consumer_group_id
+ self.consumer_group = consumer_group
+ self.update_mask = update_mask
+
+ def execute(self, context: Context):
+ ApacheKafkaConsumerGroupLink.persist(
+ context=context,
+ task_instance=self,
+ cluster_id=self.cluster_id,
+ consumer_group_id=self.consumer_group_id,
+ )
+ self.log.info("Updating an Apache Kafka consumer group.")
+ try:
+ consumer_group_obj = self.hook.update_consumer_group(
+ project_id=self.project_id,
+ location=self.location,
+ cluster_id=self.cluster_id,
+ consumer_group_id=self.consumer_group_id,
+ consumer_group=self.consumer_group,
+ update_mask=self.update_mask,
+ retry=self.retry,
+ timeout=self.timeout,
+ metadata=self.metadata,
+ )
+ self.log.info("Apache Kafka consumer group %s was updated.",
self.consumer_group_id)
+ return types.ConsumerGroup.to_dict(consumer_group_obj)
+ except NotFound as not_found_err:
+ self.log.info("The Consumer Group %s does not exist.",
self.consumer_group_id)
+ raise AirflowException(not_found_err)
+ except Exception as error:
+ raise AirflowException(error)
+
+
+class ManagedKafkaDeleteConsumerGroupOperator(ManagedKafkaBaseOperator):
+ """
+ Delete a single consumer group.
+
+ :param project_id: Required. The ID of the Google Cloud project that the service belongs to.
+ :param location: Required. The ID of the Google Cloud region that the service belongs to.
+ :param cluster_id: Required. The ID of the cluster whose consumer group is to be deleted.
+ :param consumer_group_id: Required. The ID of the consumer group to delete.
+ :param retry: Designation of what errors, if any, should be retried.
+ :param timeout: The timeout for this request.
+ :param metadata: Strings which should be sent along with the request as metadata.
+ :param gcp_conn_id: The connection ID to use connecting to Google Cloud.
+ :param impersonation_chain: Optional service account to impersonate using short-term
+ credentials, or chained list of accounts required to get the access_token
+ of the last account in the list, which will be impersonated in the request.
+ If set as a string, the account must grant the originating account
+ the Service Account Token Creator IAM role.
+ If set as a sequence, the identities from the list must grant
+ Service Account Token Creator IAM role to the directly preceding identity, with first
+ account from the list granting this role to the originating account (templated).
+ """
+
+ template_fields: Sequence[str] = tuple(
+ {"cluster_id", "consumer_group_id"} | set(ManagedKafkaBaseOperator.template_fields)
+ )
+
+ def __init__(
+ self,
+ cluster_id: str,
+ consumer_group_id: str,
+ *args,
+ **kwargs,
+ ) -> None:
+ super().__init__(*args, **kwargs)
+ self.cluster_id = cluster_id
+ self.consumer_group_id = consumer_group_id
+
+ def execute(self, context: Context):
+ try:
+ self.log.info("Deleting Apache Kafka consumer group: %s",
self.consumer_group_id)
+ self.hook.delete_consumer_group(
+ project_id=self.project_id,
+ location=self.location,
+ cluster_id=self.cluster_id,
+ consumer_group_id=self.consumer_group_id,
+ retry=self.retry,
+ timeout=self.timeout,
+ metadata=self.metadata,
+ )
+ self.log.info("Apache Kafka consumer group was deleted.")
+ except NotFound as not_found_err:
+ self.log.info("The Apache Kafka consumer group ID %s does not
exist.", self.consumer_group_id)
+ raise AirflowException(not_found_err)
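For the update operator, ``update_mask`` accepts either a dict or a ``FieldMask``; a sketch with an explicit mask (resource IDs are placeholders, and ``topics`` is assumed here as the ConsumerGroup field holding per-topic metadata):

    from google.protobuf.field_mask_pb2 import FieldMask

    from airflow.providers.google.cloud.operators.managed_kafka import (
        ManagedKafkaUpdateConsumerGroupOperator,
    )

    update_group = ManagedKafkaUpdateConsumerGroupOperator(
        task_id="update_group",
        project_id="my-project",
        location="us-central1",
        cluster_id="my-cluster",
        consumer_group_id="my-consumer-group",
        consumer_group={"topics": {}},  # only the masked fields are overwritten
        update_mask=FieldMask(paths=["topics"]),
    )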
diff --git a/providers/google/src/airflow/providers/google/get_provider_info.py b/providers/google/src/airflow/providers/google/get_provider_info.py
index cd9b0a75acc..f97a32f8a5f 100644
--- a/providers/google/src/airflow/providers/google/get_provider_info.py
+++ b/providers/google/src/airflow/providers/google/get_provider_info.py
@@ -1570,6 +1570,7 @@ def get_provider_info():
"airflow.providers.google.cloud.links.managed_kafka.ApacheKafkaClusterLink",
"airflow.providers.google.cloud.links.managed_kafka.ApacheKafkaClusterListLink",
"airflow.providers.google.cloud.links.managed_kafka.ApacheKafkaTopicLink",
+ "airflow.providers.google.cloud.links.managed_kafka.ApacheKafkaConsumerGroupLink",
],
"secrets-backends": [
"airflow.providers.google.cloud.secrets.secret_manager.CloudSecretManagerBackend"
diff --git a/providers/google/tests/system/google/cloud/managed_kafka/example_managed_kafka_consumer_group.py b/providers/google/tests/system/google/cloud/managed_kafka/example_managed_kafka_consumer_group.py
new file mode 100644
index 00000000000..1c9929a569e
--- /dev/null
+++ b/providers/google/tests/system/google/cloud/managed_kafka/example_managed_kafka_consumer_group.py
@@ -0,0 +1,254 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+"""
+Example Airflow DAG for Google Cloud Managed Service for Apache Kafka testing Consumer Group operations.
+"""
+
+from __future__ import annotations
+
+import json
+import logging
+import os
+import random
+from datetime import datetime
+
+from airflow.decorators import task
+from airflow.models import Connection
+from airflow.models.dag import DAG
+from airflow.providers.apache.kafka.operators.consume import ConsumeFromTopicOperator
+from airflow.providers.apache.kafka.operators.produce import ProduceToTopicOperator
+from airflow.providers.google.cloud.operators.managed_kafka import (
+ ManagedKafkaCreateClusterOperator,
+ ManagedKafkaCreateTopicOperator,
+ ManagedKafkaDeleteClusterOperator,
+ ManagedKafkaDeleteConsumerGroupOperator,
+ ManagedKafkaGetConsumerGroupOperator,
+ ManagedKafkaListConsumerGroupsOperator,
+ ManagedKafkaUpdateConsumerGroupOperator,
+)
+from airflow.settings import Session
+from airflow.utils.trigger_rule import TriggerRule
+
+ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID", "default")
+PROJECT_ID = os.environ.get("SYSTEM_TESTS_GCP_PROJECT", "default")
+DAG_ID = "managed_kafka_consumer_group_operations"
+LOCATION = "us-central1"
+
+CLUSTER_ID = f"cluster_{DAG_ID}_{ENV_ID}".replace("_", "-")
+CLUSTER_CONF = {
+ "gcp_config": {
+ "access_config": {
+ "network_configs": [
+ {"subnet":
f"projects/{PROJECT_ID}/regions/{LOCATION}/subnetworks/default"},
+ ],
+ },
+ },
+ "capacity_config": {
+ "vcpu_count": 3,
+ "memory_bytes": 3221225472,
+ },
+}
+TOPIC_ID = f"topic_{DAG_ID}_{ENV_ID}".replace("_", "-")
+TOPIC_CONF = {
+ "partition_count": 3,
+ "replication_factor": 3,
+}
+CONSUMER_GROUP_ID = f"consumer_group_{DAG_ID}_{ENV_ID}".replace("_", "-")
+CONNECTION_ID = f"connection_{DAG_ID}_{ENV_ID}"
+PORT = "9092"
+BOOTSTRAP_URL = f"bootstrap.{CLUSTER_ID}.{LOCATION}.managedkafka.{PROJECT_ID}.cloud.goog:{PORT}"
+
+log = logging.getLogger(__name__)
+
+
+def producer():
+ """Produce and submit 10 messages"""
+
+ for i in range(10):
+ now = datetime.now()
+ datetime_string = now.strftime("%Y-%m-%d %H:%M:%S")
+
+ message_data = {"random_id": f"{ENV_ID}_{random.randint(1, 100)}",
"date_time": datetime_string}
+
+ yield (
+ json.dumps(i),
+ json.dumps(message_data),
+ )
+
+
+def consumer(message):
+ "Take in a consumed message and print its contents to the logs."
+
+ message_content = json.loads(message.value())
+ random_id = message_content["random_id"]
+ date_time = message_content["date_time"]
+ log.info("id: %s, date_time: %s", random_id, date_time)
+
+
+with DAG(
+ DAG_ID,
+ schedule="@once",
+ start_date=datetime(2021, 1, 1),
+ catchup=False,
+ tags=["example", "managed_kafka", "consumer_group"],
+) as dag:
+ create_cluster = ManagedKafkaCreateClusterOperator(
+ task_id="create_cluster",
+ project_id=PROJECT_ID,
+ location=LOCATION,
+ cluster=CLUSTER_CONF,
+ cluster_id=CLUSTER_ID,
+ )
+
+ create_topic = ManagedKafkaCreateTopicOperator(
+ task_id="create_topic",
+ project_id=PROJECT_ID,
+ location=LOCATION,
+ cluster_id=CLUSTER_ID,
+ topic_id=TOPIC_ID,
+ topic=TOPIC_CONF,
+ )
+
+ @task
+ def create_connection(connection_id: str):
+ conn = Connection(
+ conn_id=connection_id,
+ conn_type="kafka",
+ )
+ conn_extra = {
+ "bootstrap.servers": BOOTSTRAP_URL,
+ "security.protocol": "SASL_SSL",
+ "sasl.mechanisms": "OAUTHBEARER",
+ "group.id": CONSUMER_GROUP_ID,
+ }
+ conn_extra_json = json.dumps(conn_extra)
+ conn.set_extra(conn_extra_json)
+
+ session = Session()
+ log.info("Removing connection %s if it exists", connection_id)
+ query = session.query(Connection).filter(Connection.conn_id == connection_id)
+ query.delete()
+
+ session.add(conn)
+ session.commit()
+ log.info("Connection %s created", connection_id)
+
+ create_connection_task = create_connection(connection_id=CONNECTION_ID)
+
+ # [START how_to_cloud_managed_kafka_produce_to_topic_operator]
+ produce_to_topic = ProduceToTopicOperator(
+ task_id="produce_to_topic",
+ kafka_config_id=CONNECTION_ID,
+ topic=TOPIC_ID,
+ producer_function=producer,
+ poll_timeout=10,
+ )
+ # [END how_to_cloud_managed_kafka_produce_to_topic_operator]
+
+ # [START how_to_cloud_managed_kafka_consume_from_topic_operator]
+ consume_from_topic = ConsumeFromTopicOperator(
+ task_id="consume_from_topic",
+ kafka_config_id=CONNECTION_ID,
+ topics=[TOPIC_ID],
+ apply_function=consumer,
+ poll_timeout=20,
+ max_messages=20,
+ max_batch_size=20,
+ )
+ # [END how_to_cloud_managed_kafka_consume_from_topic_operator]
+
+ # [START how_to_cloud_managed_kafka_update_consumer_group_operator]
+ update_consumer_group = ManagedKafkaUpdateConsumerGroupOperator(
+ task_id="update_consumer_group",
+ project_id=PROJECT_ID,
+ location=LOCATION,
+ cluster_id=CLUSTER_ID,
+ consumer_group_id=CONSUMER_GROUP_ID,
+ consumer_group={},
+ update_mask={},
+ )
+ # [END how_to_cloud_managed_kafka_update_consumer_group_operator]
+
+ # [START how_to_cloud_managed_kafka_get_consumer_group_operator]
+ get_consumer_group = ManagedKafkaGetConsumerGroupOperator(
+ task_id="get_consumer_group",
+ project_id=PROJECT_ID,
+ location=LOCATION,
+ cluster_id=CLUSTER_ID,
+ consumer_group_id=CONSUMER_GROUP_ID,
+ )
+ # [END how_to_cloud_managed_kafka_get_consumer_group_operator]
+
+ # [START how_to_cloud_managed_kafka_delete_consumer_group_operator]
+ delete_consumer_group = ManagedKafkaDeleteConsumerGroupOperator(
+ task_id="delete_consumer_group",
+ project_id=PROJECT_ID,
+ location=LOCATION,
+ cluster_id=CLUSTER_ID,
+ consumer_group_id=CONSUMER_GROUP_ID,
+ trigger_rule=TriggerRule.ALL_DONE,
+ )
+ # [END how_to_cloud_managed_kafka_delete_consumer_group_operator]
+
+ delete_cluster = ManagedKafkaDeleteClusterOperator(
+ task_id="delete_cluster",
+ project_id=PROJECT_ID,
+ location=LOCATION,
+ cluster_id=CLUSTER_ID,
+ trigger_rule=TriggerRule.ALL_DONE,
+ )
+
+ # [START how_to_cloud_managed_kafka_list_consumer_group_operator]
+ list_consumer_groups = ManagedKafkaListConsumerGroupsOperator(
+ task_id="list_consumer_groups",
+ project_id=PROJECT_ID,
+ location=LOCATION,
+ cluster_id=CLUSTER_ID,
+ )
+ # [END how_to_cloud_managed_kafka_list_consumer_group_operator]
+
+ (
+ # TEST SETUP
+ create_cluster
+ >> create_topic
+ >> create_connection_task
+ >> produce_to_topic
+ >> consume_from_topic
+ # TEST BODY
+ >> update_consumer_group
+ >> get_consumer_group
+ >> list_consumer_groups
+ >> delete_consumer_group
+ # TEST TEARDOWN
+ >> delete_cluster
+ )
+
+ # ### Everything below this line is not part of example ###
+ # ### Just for system tests purpose ###
+ from tests_common.test_utils.watcher import watcher
+
+ # This test needs watcher in order to properly mark success/failure
+ # when "tearDown" task with trigger rule is part of the DAG
+ list(dag.tasks) >> watcher()
+
+from tests_common.test_utils.system_tests import get_test_run # noqa: E402
+
+# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
+test_run = get_test_run(dag)
diff --git a/providers/google/tests/unit/google/cloud/hooks/test_managed_kafka.py b/providers/google/tests/unit/google/cloud/hooks/test_managed_kafka.py
index 7261f079555..c8e2a131fca 100644
--- a/providers/google/tests/unit/google/cloud/hooks/test_managed_kafka.py
+++ b/providers/google/tests/unit/google/cloud/hooks/test_managed_kafka.py
@@ -66,6 +66,8 @@ TEST_UPDATED_TOPIC: dict = {
"replication_factor": 1912,
}
+TEST_CONSUMER_GROUP_ID: str = "test-consumer-group-id"
+
BASE_STRING = "airflow.providers.google.common.hooks.base_google.{}"
MANAGED_KAFKA_STRING = "airflow.providers.google.cloud.hooks.managed_kafka.{}"
@@ -301,6 +303,98 @@ class TestManagedKafkaWithDefaultProjectIdHook:
TEST_PROJECT_ID, TEST_LOCATION, TEST_CLUSTER_ID
)
+ @mock.patch(MANAGED_KAFKA_STRING.format("ManagedKafkaHook.get_managed_kafka_client"))
+ def test_delete_consumer_group(self, mock_client) -> None:
+ self.hook.delete_consumer_group(
+ project_id=TEST_PROJECT_ID,
+ location=TEST_LOCATION,
+ cluster_id=TEST_CLUSTER_ID,
+ consumer_group_id=TEST_CONSUMER_GROUP_ID,
+ )
+ mock_client.assert_called_once()
+ mock_client.return_value.delete_consumer_group.assert_called_once_with(
+ request=dict(name=mock_client.return_value.consumer_group_path.return_value),
+ metadata=(),
+ retry=DEFAULT,
+ timeout=None,
+ )
+ mock_client.return_value.consumer_group_path.assert_called_once_with(
+ TEST_PROJECT_ID, TEST_LOCATION, TEST_CLUSTER_ID, TEST_CONSUMER_GROUP_ID
+ )
+
+ @mock.patch(MANAGED_KAFKA_STRING.format("ManagedKafkaHook.get_managed_kafka_client"))
+ def test_get_consumer_group(self, mock_client) -> None:
+ self.hook.get_consumer_group(
+ project_id=TEST_PROJECT_ID,
+ location=TEST_LOCATION,
+ cluster_id=TEST_CLUSTER_ID,
+ consumer_group_id=TEST_CONSUMER_GROUP_ID,
+ )
+ mock_client.assert_called_once()
+ mock_client.return_value.get_consumer_group.assert_called_once_with(
+ request=dict(
+ name=mock_client.return_value.consumer_group_path.return_value,
+ ),
+ metadata=(),
+ retry=DEFAULT,
+ timeout=None,
+ )
+ mock_client.return_value.consumer_group_path.assert_called_once_with(
+ TEST_PROJECT_ID,
+ TEST_LOCATION,
+ TEST_CLUSTER_ID,
+ TEST_CONSUMER_GROUP_ID,
+ )
+
+ @mock.patch(MANAGED_KAFKA_STRING.format("ManagedKafkaHook.get_managed_kafka_client"))
+ def test_update_consumer_group(self, mock_client) -> None:
+ self.hook.update_consumer_group(
+ project_id=TEST_PROJECT_ID,
+ location=TEST_LOCATION,
+ cluster_id=TEST_CLUSTER_ID,
+ consumer_group_id=TEST_CONSUMER_GROUP_ID,
+ consumer_group={},
+ update_mask={},
+ )
+ mock_client.assert_called_once()
+ mock_client.return_value.update_consumer_group.assert_called_once_with(
+ request=dict(
+ update_mask={},
+ consumer_group={
+ "name":
mock_client.return_value.consumer_group_path.return_value,
+ **{},
+ },
+ ),
+ metadata=(),
+ retry=DEFAULT,
+ timeout=None,
+ )
+ mock_client.return_value.consumer_group_path.assert_called_once_with(
+ TEST_PROJECT_ID, TEST_LOCATION, TEST_CLUSTER_ID, TEST_CONSUMER_GROUP_ID
+ )
+
+ @mock.patch(MANAGED_KAFKA_STRING.format("ManagedKafkaHook.get_managed_kafka_client"))
+ def test_list_consumer_groups(self, mock_client) -> None:
+ self.hook.list_consumer_groups(
+ project_id=TEST_PROJECT_ID,
+ location=TEST_LOCATION,
+ cluster_id=TEST_CLUSTER_ID,
+ )
+ mock_client.assert_called_once()
+ mock_client.return_value.list_consumer_groups.assert_called_once_with(
+ request=dict(
+ parent=mock_client.return_value.cluster_path.return_value,
+ page_size=None,
+ page_token=None,
+ ),
+ metadata=(),
+ retry=DEFAULT,
+ timeout=None,
+ )
+ mock_client.return_value.cluster_path.assert_called_once_with(
+ TEST_PROJECT_ID, TEST_LOCATION, TEST_CLUSTER_ID
+ )
+
class TestManagedKafkaWithoutDefaultProjectIdHook:
def setup_method(self):
@@ -535,3 +629,98 @@ class TestManagedKafkaWithoutDefaultProjectIdHook:
mock_client.return_value.cluster_path.assert_called_once_with(
TEST_PROJECT_ID, TEST_LOCATION, TEST_CLUSTER_ID
)
+
+ @mock.patch(MANAGED_KAFKA_STRING.format("ManagedKafkaHook.get_managed_kafka_client"))
+ def test_delete_consumer_group(self, mock_client) -> None:
+ self.hook.delete_consumer_group(
+ project_id=TEST_PROJECT_ID,
+ location=TEST_LOCATION,
+ cluster_id=TEST_CLUSTER_ID,
+ consumer_group_id=TEST_CONSUMER_GROUP_ID,
+ )
+ mock_client.assert_called_once()
+ mock_client.return_value.delete_consumer_group.assert_called_once_with(
+ request=dict(name=mock_client.return_value.consumer_group_path.return_value),
+ metadata=(),
+ retry=DEFAULT,
+ timeout=None,
+ )
+ mock_client.return_value.consumer_group_path.assert_called_once_with(
+ TEST_PROJECT_ID,
+ TEST_LOCATION,
+ TEST_CLUSTER_ID,
+ TEST_CONSUMER_GROUP_ID,
+ )
+
+ @mock.patch(MANAGED_KAFKA_STRING.format("ManagedKafkaHook.get_managed_kafka_client"))
+ def test_get_consumer_group(self, mock_client) -> None:
+ self.hook.get_consumer_group(
+ project_id=TEST_PROJECT_ID,
+ location=TEST_LOCATION,
+ cluster_id=TEST_CLUSTER_ID,
+ consumer_group_id=TEST_CONSUMER_GROUP_ID,
+ )
+ mock_client.assert_called_once()
+ mock_client.return_value.get_consumer_group.assert_called_once_with(
+ request=dict(
+ name=mock_client.return_value.consumer_group_path.return_value,
+ ),
+ metadata=(),
+ retry=DEFAULT,
+ timeout=None,
+ )
+ mock_client.return_value.consumer_group_path.assert_called_once_with(
+ TEST_PROJECT_ID,
+ TEST_LOCATION,
+ TEST_CLUSTER_ID,
+ TEST_CONSUMER_GROUP_ID,
+ )
+
+ @mock.patch(MANAGED_KAFKA_STRING.format("ManagedKafkaHook.get_managed_kafka_client"))
+ def test_update_consumer_group(self, mock_client) -> None:
+ self.hook.update_consumer_group(
+ project_id=TEST_PROJECT_ID,
+ location=TEST_LOCATION,
+ cluster_id=TEST_CLUSTER_ID,
+ consumer_group_id=TEST_CONSUMER_GROUP_ID,
+ consumer_group={},
+ update_mask={},
+ )
+ mock_client.assert_called_once()
+ mock_client.return_value.update_consumer_group.assert_called_once_with(
+ request=dict(
+ update_mask={},
+ consumer_group={
+ "name":
mock_client.return_value.consumer_group_path.return_value,
+ **{},
+ },
+ ),
+ metadata=(),
+ retry=DEFAULT,
+ timeout=None,
+ )
+ mock_client.return_value.consumer_group_path.assert_called_once_with(
+ TEST_PROJECT_ID, TEST_LOCATION, TEST_CLUSTER_ID, TEST_CONSUMER_GROUP_ID
+ )
+
+ @mock.patch(MANAGED_KAFKA_STRING.format("ManagedKafkaHook.get_managed_kafka_client"))
+ def test_list_consumer_groups(self, mock_client) -> None:
+ self.hook.list_consumer_groups(
+ project_id=TEST_PROJECT_ID,
+ location=TEST_LOCATION,
+ cluster_id=TEST_CLUSTER_ID,
+ )
+ mock_client.assert_called_once()
+ mock_client.return_value.list_consumer_groups.assert_called_once_with(
+ request=dict(
+ parent=mock_client.return_value.cluster_path.return_value,
+ page_size=None,
+ page_token=None,
+ ),
+ metadata=(),
+ retry=DEFAULT,
+ timeout=None,
+ )
+ mock_client.return_value.cluster_path.assert_called_once_with(
+ TEST_PROJECT_ID, TEST_LOCATION, TEST_CLUSTER_ID
+ )
diff --git a/providers/google/tests/unit/google/cloud/links/test_managed_kafka.py b/providers/google/tests/unit/google/cloud/links/test_managed_kafka.py
index 7bf671c68e6..8867b8c0616 100644
--- a/providers/google/tests/unit/google/cloud/links/test_managed_kafka.py
+++ b/providers/google/tests/unit/google/cloud/links/test_managed_kafka.py
@@ -22,6 +22,7 @@ from unittest import mock
from airflow.providers.google.cloud.links.managed_kafka import (
ApacheKafkaClusterLink,
ApacheKafkaClusterListLink,
+ ApacheKafkaConsumerGroupLink,
ApacheKafkaTopicLink,
)
@@ -29,6 +30,7 @@ TEST_LOCATION = "test-location"
TEST_CLUSTER_ID = "test-cluster-id"
TEST_PROJECT_ID = "test-project-id"
TEST_TOPIC_ID = "test-topic-id"
+TEST_CONSUMER_GROUP_ID = "test-consumer-group-id"
EXPECTED_MANAGED_KAFKA_CLUSTER_LINK_NAME = "Apache Kafka Cluster"
EXPECTED_MANAGED_KAFKA_CLUSTER_LINK_KEY = "cluster_conf"
EXPECTED_MANAGED_KAFKA_CLUSTER_LINK_FORMAT_STR = (
@@ -42,6 +44,11 @@ EXPECTED_MANAGED_KAFKA_TOPIC_LINK_KEY = "topic_conf"
EXPECTED_MANAGED_KAFKA_TOPIC_LINK_FORMAT_STR = (
"/managedkafka/{location}/clusters/{cluster_id}/topics/{topic_id}?project={project_id}"
)
+EXPECTED_MANAGED_KAFKA_CONSUMER_GROUP_LINK_NAME = "Apache Kafka Consumer Group"
+EXPECTED_MANAGED_KAFKA_CONSUMER_GROUP_LINK_KEY = "consumer_group_conf"
+EXPECTED_MANAGED_KAFKA_CONSUMER_GROUP_LINK_FORMAT_STR = (
+ "/managedkafka/{location}/clusters/{cluster_id}/consumer_groups/{consumer_group_id}?project={project_id}"
+)
class TestApacheKafkaClusterLink:
@@ -125,3 +132,36 @@ class TestApacheKafkaTopicLink:
"project_id": TEST_PROJECT_ID,
},
)
+
+
+class TestApacheKafkaConsumerGroupLink:
+ def test_class_attributes(self):
+ assert ApacheKafkaConsumerGroupLink.key == EXPECTED_MANAGED_KAFKA_CONSUMER_GROUP_LINK_KEY
+ assert ApacheKafkaConsumerGroupLink.name == EXPECTED_MANAGED_KAFKA_CONSUMER_GROUP_LINK_NAME
+ assert (
+ ApacheKafkaConsumerGroupLink.format_str == EXPECTED_MANAGED_KAFKA_CONSUMER_GROUP_LINK_FORMAT_STR
+ )
+
+ def test_persist(self):
+ mock_context, mock_task_instance = (
+ mock.MagicMock(),
+ mock.MagicMock(location=TEST_LOCATION, project_id=TEST_PROJECT_ID),
+ )
+
+ ApacheKafkaConsumerGroupLink.persist(
+ context=mock_context,
+ task_instance=mock_task_instance,
+ cluster_id=TEST_CLUSTER_ID,
+ consumer_group_id=TEST_CONSUMER_GROUP_ID,
+ )
+
+ mock_task_instance.xcom_push.assert_called_once_with(
+ context=mock_context,
+ key=EXPECTED_MANAGED_KAFKA_CONSUMER_GROUP_LINK_KEY,
+ value={
+ "location": TEST_LOCATION,
+ "cluster_id": TEST_CLUSTER_ID,
+ "consumer_group_id": TEST_CONSUMER_GROUP_ID,
+ "project_id": TEST_PROJECT_ID,
+ },
+ )
diff --git a/providers/google/tests/unit/google/cloud/operators/test_managed_kafka.py b/providers/google/tests/unit/google/cloud/operators/test_managed_kafka.py
index e9407cc0a50..fd410682014 100644
--- a/providers/google/tests/unit/google/cloud/operators/test_managed_kafka.py
+++ b/providers/google/tests/unit/google/cloud/operators/test_managed_kafka.py
@@ -24,12 +24,16 @@ from airflow.providers.google.cloud.operators.managed_kafka import (
ManagedKafkaCreateClusterOperator,
ManagedKafkaCreateTopicOperator,
ManagedKafkaDeleteClusterOperator,
+ ManagedKafkaDeleteConsumerGroupOperator,
ManagedKafkaDeleteTopicOperator,
ManagedKafkaGetClusterOperator,
+ ManagedKafkaGetConsumerGroupOperator,
ManagedKafkaGetTopicOperator,
ManagedKafkaListClustersOperator,
+ ManagedKafkaListConsumerGroupsOperator,
ManagedKafkaListTopicsOperator,
ManagedKafkaUpdateClusterOperator,
+ ManagedKafkaUpdateConsumerGroupOperator,
ManagedKafkaUpdateTopicOperator,
)
@@ -80,6 +84,8 @@ TEST_UPDATED_TOPIC: dict = {
"replication_factor": 1912,
}
+TEST_CONSUMER_GROUP_ID: str = "test-consumer-group-id"
+
class TestManagedKafkaCreateClusterOperator:
@mock.patch(MANAGED_KAFKA_PATH.format("types.Cluster.to_dict"))
@@ -393,3 +399,128 @@ class TestManagedKafkaDeleteTopicOperator:
timeout=TIMEOUT,
metadata=METADATA,
)
+
+
+class TestManagedKafkaListConsumerGroupsOperator:
+ @mock.patch(MANAGED_KAFKA_PATH.format("types.ListConsumerGroupsResponse.to_dict"))
+ @mock.patch(MANAGED_KAFKA_PATH.format("types.ConsumerGroup.to_dict"))
+ @mock.patch(MANAGED_KAFKA_PATH.format("ManagedKafkaHook"))
+ def test_execute(self, mock_hook, to_cluster_dict_mock, to_clusters_dict_mock):
+ page_token = "page_token"
+ page_size = 42
+
+ op = ManagedKafkaListConsumerGroupsOperator(
+ task_id=TASK_ID,
+ gcp_conn_id=GCP_CONN_ID,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ location=GCP_LOCATION,
+ project_id=GCP_PROJECT,
+ cluster_id=TEST_CLUSTER_ID,
+ page_size=page_size,
+ page_token=page_token,
+ retry=RETRY,
+ timeout=TIMEOUT,
+ metadata=METADATA,
+ )
+ op.execute(context={"ti": mock.MagicMock()})
+ mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
+ mock_hook.return_value.list_consumer_groups.assert_called_once_with(
+ location=GCP_LOCATION,
+ project_id=GCP_PROJECT,
+ cluster_id=TEST_CLUSTER_ID,
+ page_size=page_size,
+ page_token=page_token,
+ retry=RETRY,
+ timeout=TIMEOUT,
+ metadata=METADATA,
+ )
+
+
+class TestManagedKafkaGetConsumerGroupOperator:
+ @mock.patch(MANAGED_KAFKA_PATH.format("types.ConsumerGroup.to_dict"))
+ @mock.patch(MANAGED_KAFKA_PATH.format("ManagedKafkaHook"))
+ def test_execute(self, mock_hook, to_dict_mock):
+ op = ManagedKafkaGetConsumerGroupOperator(
+ task_id=TASK_ID,
+ gcp_conn_id=GCP_CONN_ID,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ location=GCP_LOCATION,
+ project_id=GCP_PROJECT,
+ cluster_id=TEST_CLUSTER_ID,
+ consumer_group_id=TEST_CONSUMER_GROUP_ID,
+ retry=RETRY,
+ timeout=TIMEOUT,
+ metadata=METADATA,
+ )
+ op.execute(context={"ti": mock.MagicMock()})
+ mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
+ mock_hook.return_value.get_consumer_group.assert_called_once_with(
+ location=GCP_LOCATION,
+ project_id=GCP_PROJECT,
+ cluster_id=TEST_CLUSTER_ID,
+ consumer_group_id=TEST_CONSUMER_GROUP_ID,
+ retry=RETRY,
+ timeout=TIMEOUT,
+ metadata=METADATA,
+ )
+
+
+class TestManagedKafkaUpdateConsumerGroupOperator:
+ @mock.patch(MANAGED_KAFKA_PATH.format("types.ConsumerGroup.to_dict"))
+ @mock.patch(MANAGED_KAFKA_PATH.format("ManagedKafkaHook"))
+ def test_execute(self, mock_hook, to_dict_mock):
+ op = ManagedKafkaUpdateConsumerGroupOperator(
+ task_id=TASK_ID,
+ gcp_conn_id=GCP_CONN_ID,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ project_id=GCP_PROJECT,
+ location=GCP_LOCATION,
+ cluster_id=TEST_CLUSTER_ID,
+ consumer_group_id=TEST_CONSUMER_GROUP_ID,
+ consumer_group={},
+ update_mask={},
+ retry=RETRY,
+ timeout=TIMEOUT,
+ metadata=METADATA,
+ )
+ op.execute(context={"ti": mock.MagicMock()})
+ mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
+ mock_hook.return_value.update_consumer_group.assert_called_once_with(
+ project_id=GCP_PROJECT,
+ location=GCP_LOCATION,
+ cluster_id=TEST_CLUSTER_ID,
+ consumer_group_id=TEST_CONSUMER_GROUP_ID,
+ consumer_group={},
+ update_mask={},
+ retry=RETRY,
+ timeout=TIMEOUT,
+ metadata=METADATA,
+ )
+
+
+class TestManagedKafkaDeleteConsumerGroupOperator:
+ @mock.patch(MANAGED_KAFKA_PATH.format("ManagedKafkaHook"))
+ def test_execute(self, mock_hook):
+ op = ManagedKafkaDeleteConsumerGroupOperator(
+ task_id=TASK_ID,
+ gcp_conn_id=GCP_CONN_ID,
+ impersonation_chain=IMPERSONATION_CHAIN,
+ location=GCP_LOCATION,
+ project_id=GCP_PROJECT,
+ cluster_id=TEST_CLUSTER_ID,
+ consumer_group_id=TEST_CONSUMER_GROUP_ID,
+ retry=RETRY,
+ timeout=TIMEOUT,
+ metadata=METADATA,
+ )
+ op.execute(context={})
+ mock_hook.assert_called_once_with(gcp_conn_id=GCP_CONN_ID, impersonation_chain=IMPERSONATION_CHAIN)
+ mock_hook.return_value.delete_consumer_group.assert_called_once_with(
+ location=GCP_LOCATION,
+ project_id=GCP_PROJECT,
+ cluster_id=TEST_CLUSTER_ID,
+ consumer_group_id=TEST_CONSUMER_GROUP_ID,
+ retry=RETRY,
+ timeout=TIMEOUT,
+ metadata=METADATA,
+ )