This is an automated email from the ASF dual-hosted git repository.
shahar1 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 1bfa1f35760 Add BigQueryStreamingBufferEmptySensor for DML on
streaming tables (#66652)
1bfa1f35760 is described below
commit 1bfa1f35760ed75dea90dd599f4e8de91d057bbb
Author: Radhwène Dhouafli <[email protected]>
AuthorDate: Wed May 13 21:24:13 2026 +0200
Add BigQueryStreamingBufferEmptySensor for DML on streaming tables (#66652)
---
providers/google/docs/operators/cloud/bigquery.rst | 22 +++
.../providers/google/cloud/sensors/bigquery.py | 91 +++++++++++
.../providers/google/cloud/triggers/bigquery.py | 104 ++++++++++++
.../cloud/bigquery/example_bigquery_sensors.py | 69 +++++++-
.../unit/google/cloud/sensors/test_bigquery.py | 107 ++++++++++++
.../unit/google/cloud/triggers/test_bigquery.py | 182 ++++++++++++++++++++-
6 files changed, 572 insertions(+), 3 deletions(-)
diff --git a/providers/google/docs/operators/cloud/bigquery.rst
b/providers/google/docs/operators/cloud/bigquery.rst
index 82f43c319dc..2008e7d03d0 100644
--- a/providers/google/docs/operators/cloud/bigquery.rst
+++ b/providers/google/docs/operators/cloud/bigquery.rst
@@ -526,6 +526,28 @@ Also you can use deferrable mode in this operator if you
would like to free up t
:start-after: [START howto_sensor_bigquery_table_partition_async]
:end-before: [END howto_sensor_bigquery_table_partition_async]
+Check that the BigQuery Table Streaming Buffer is empty
+""""""""""""""""""""""""""""""""""""""""""""""""""""""""
+
+To check that the BigQuery streaming buffer of a table is empty, you can use
+:class:`~airflow.providers.google.cloud.sensors.bigquery.BigQueryStreamingBufferEmptySensor`.
+This sensor is useful in ETL pipelines that stream rows into a table and then modify them with DML:
+BigQuery rejects ``UPDATE``, ``DELETE`` and ``MERGE`` statements that would affect rows still in the
+streaming buffer, so place the sensor between the streaming insert and the DML step.
+
+.. exampleinclude:: /../../google/tests/system/google/cloud/bigquery/example_bigquery_sensors.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_sensor_bigquery_streaming_buffer_empty]
+ :end-before: [END howto_sensor_bigquery_streaming_buffer_empty]
+
+You can also use deferrable mode with this sensor if you would like to free up the worker slots while the sensor is running.
+
+.. exampleinclude:: /../../google/tests/system/google/cloud/bigquery/example_bigquery_sensors.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_sensor_bigquery_streaming_buffer_empty_deferred]
+ :end-before: [END howto_sensor_bigquery_streaming_buffer_empty_deferred]
+
Reference
^^^^^^^^^
diff --git
a/providers/google/src/airflow/providers/google/cloud/sensors/bigquery.py
b/providers/google/src/airflow/providers/google/cloud/sensors/bigquery.py
index 71e37bff784..aa40186df60 100644
--- a/providers/google/src/airflow/providers/google/cloud/sensors/bigquery.py
+++ b/providers/google/src/airflow/providers/google/cloud/sensors/bigquery.py
@@ -30,6 +30,7 @@ from airflow.exceptions import
AirflowProviderDeprecationWarning
from airflow.providers.common.compat.sdk import AirflowException,
BaseSensorOperator, conf
from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook
from airflow.providers.google.cloud.triggers.bigquery import (
+ BigQueryStreamingBufferEmptyTrigger,
BigQueryTableExistenceTrigger,
BigQueryTablePartitionExistenceTrigger,
)
@@ -319,3 +320,93 @@ class
BigQueryTablePartitionExistenceSensor(BaseSensorOperator):
message = "No event received in trigger callback"
raise AirflowException(message)
+
+
+class BigQueryStreamingBufferEmptySensor(BaseSensorOperator):
+    """
+    Wait for the streaming buffer of a BigQuery table to be empty.
+
+    BigQuery DML statements (UPDATE, DELETE, MERGE) cannot run against rows that
+    are still in the streaming buffer; the buffer is flushed within ~90 minutes.
+    Use this sensor between a streaming insert and a DML step to avoid
+    ``UPDATE/MERGE/DELETE statement over table ... would affect rows in the
+    streaming buffer`` errors.
+
+    :param project_id: Google Cloud project containing the table.
+    :param dataset_id: Dataset of the table to monitor.
+    :param table_id: Table to monitor.
+    :param gcp_conn_id: Airflow connection ID for GCP.
+    :param impersonation_chain: Optional service account to impersonate, or a
+        chained list of accounts. See the Google provider docs for details.
+    :param deferrable: Run in deferrable mode using
+        :class:`BigQueryStreamingBufferEmptyTrigger`. In this mode
+        ``poke_interval`` defaults to 30 seconds unless it is passed explicitly.
+    """
+
+    template_fields: Sequence[str] = (
+        "project_id",
+        "dataset_id",
+        "table_id",
+        "impersonation_chain",
+    )
+
+    ui_color = "#f0eee4"
+
+    def __init__(
+        self,
+        *,
+        project_id: str,
+        dataset_id: str,
+        table_id: str,
+        gcp_conn_id: str = "google_cloud_default",
+        impersonation_chain: str | Sequence[str] | None = None,
+        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
+        **kwargs,
+    ) -> None:
+        # Deferrable mode polls every 30s unless the caller chose an interval.
+        if deferrable and "poke_interval" not in kwargs:
+            kwargs["poke_interval"] = 30
+
+        super().__init__(**kwargs)
+
+        self.project_id = project_id
+        self.dataset_id = dataset_id
+        self.table_id = table_id
+        self.gcp_conn_id = gcp_conn_id
+        self.impersonation_chain = impersonation_chain
+        self.deferrable = deferrable
+
+    def execute(self, context: Context) -> None:
+        """Poke synchronously, or check once and then defer to the trigger."""
+        if not self.deferrable:
+            super().execute(context)
+            return
+        # Avoid a needless deferral round-trip when the buffer is already empty.
+        if self.poke(context=context):
+            return
+        self.defer(
+            timeout=timedelta(seconds=self.timeout),
+            trigger=BigQueryStreamingBufferEmptyTrigger(
+                project_id=self.project_id,
+                dataset_id=self.dataset_id,
+                table_id=self.table_id,
+                poll_interval=self.poke_interval,
+                gcp_conn_id=self.gcp_conn_id,
+                impersonation_chain=self.impersonation_chain,
+            ),
+            method_name="execute_complete",
+        )
+
+    def execute_complete(self, context: dict[str, Any], event: dict[str, str] | None = None) -> str:
+        """
+        Handle the event emitted by :class:`BigQueryStreamingBufferEmptyTrigger`.
+
+        :return: The success message carried by the event.
+        :raises ValueError: if no event was received.
+        :raises RuntimeError: if the trigger reported a non-success status.
+        """
+        if event is None:
+            raise ValueError("No event received in trigger callback")
+        if event["status"] == "success":
+            return event["message"]
+        raise RuntimeError(event["message"])
+
+    def poke(self, context: Context) -> bool:
+        """
+        Return ``True`` when the table reports no streaming buffer.
+
+        :raises ValueError: if the table does not exist.
+        """
+        # Human-readable identifier, used only for logs and error messages.
+        table_uri = f"{self.project_id}:{self.dataset_id}.{self.table_id}"
+        self.log.info("Checking streaming buffer state for table: %s", table_uri)
+
+        hook = BigQueryHook(gcp_conn_id=self.gcp_conn_id, impersonation_chain=self.impersonation_chain)
+        # ``Client.get_table`` expects a standard-SQL table ID (``project.dataset.table``);
+        # the legacy ``project:dataset.table`` form above is not a supported input there.
+        standard_sql_table_id = f"{self.project_id}.{self.dataset_id}.{self.table_id}"
+        try:
+            table = hook.get_client(project_id=self.project_id).get_table(standard_sql_table_id)
+        except NotFound as err:
+            raise ValueError(f"Table {table_uri} not found") from err
+        return table.streaming_buffer is None
diff --git
a/providers/google/src/airflow/providers/google/cloud/triggers/bigquery.py
b/providers/google/src/airflow/providers/google/cloud/triggers/bigquery.py
index c9f8acde0a5..ff31833a7c3 100644
--- a/providers/google/src/airflow/providers/google/cloud/triggers/bigquery.py
+++ b/providers/google/src/airflow/providers/google/cloud/triggers/bigquery.py
@@ -873,3 +873,107 @@ class
BigQueryTablePartitionExistenceTrigger(BigQueryTableExistenceTrigger):
if records:
records = [row[0] for row in records]
return self.partition_id in records
+
+
+class BigQueryStreamingBufferEmptyTrigger(BaseTrigger):
+    """
+    Trigger that polls a BigQuery table until it no longer reports a streaming buffer.
+
+    This is the deferrable counterpart of
+    :class:`~airflow.providers.google.cloud.sensors.bigquery.BigQueryStreamingBufferEmptySensor`.
+
+    :param project_id: Google Cloud project ID.
+    :param dataset_id: Dataset of the table to monitor.
+    :param table_id: Table to monitor.
+    :param gcp_conn_id: Airflow connection ID for GCP.
+    :param poll_interval: Seconds to wait between two table lookups.
+    :param impersonation_chain: Optional service account to impersonate, or a
+        chained list of accounts.
+    """
+
+    def __init__(
+        self,
+        project_id: str,
+        dataset_id: str,
+        table_id: str,
+        gcp_conn_id: str,
+        poll_interval: float = 30.0,
+        impersonation_chain: str | Sequence[str] | None = None,
+    ):
+        super().__init__()
+        # Table coordinates.
+        self.project_id, self.dataset_id, self.table_id = project_id, dataset_id, table_id
+        # Connection / polling settings.
+        self.gcp_conn_id = gcp_conn_id
+        self.poll_interval = poll_interval
+        self.impersonation_chain = impersonation_chain
+
+    def serialize(self) -> tuple[str, dict[str, Any]]:
+        """Return the classpath plus the keyword arguments needed to rebuild this trigger."""
+        trigger_kwargs: dict[str, Any] = {
+            "project_id": self.project_id,
+            "dataset_id": self.dataset_id,
+            "table_id": self.table_id,
+            "gcp_conn_id": self.gcp_conn_id,
+            "poll_interval": self.poll_interval,
+            "impersonation_chain": self.impersonation_chain,
+        }
+        classpath = "airflow.providers.google.cloud.triggers.bigquery.BigQueryStreamingBufferEmptyTrigger"
+        return classpath, trigger_kwargs
+
+    def _get_async_hook(self) -> BigQueryTableAsyncHook:
+        """Build the async table hook, forwarding the connection and impersonation chain."""
+        return BigQueryTableAsyncHook(
+            gcp_conn_id=self.gcp_conn_id,
+            impersonation_chain=self.impersonation_chain,
+        )
+
+    async def run(self) -> AsyncIterator[TriggerEvent]:
+        """
+        Poll the table until its streaming buffer is gone, then emit one event.
+
+        A ``success`` event is yielded once the table reports no streaming buffer;
+        any exception raised while polling is turned into an ``error`` event that
+        carries the exception text.
+        """
+        table_uri = f"{self.project_id}:{self.dataset_id}.{self.table_id}"
+        try:
+            hook = self._get_async_hook()
+            async with ClientSession() as session:
+                while True:
+                    self.log.info("Checking streaming buffer for table %s", table_uri)
+                    buffer_empty = await self._is_streaming_buffer_empty(
+                        hook=hook,
+                        session=session,
+                        project_id=self.project_id,
+                        dataset_id=self.dataset_id,
+                        table_id=self.table_id,
+                    )
+                    if not buffer_empty:
+                        self.log.info("Streaming buffer not empty, sleeping %ss", self.poll_interval)
+                        await asyncio.sleep(self.poll_interval)
+                        continue
+                    done_message = f"Streaming buffer is empty for table: {table_uri}"
+                    self.log.info(done_message)
+                    yield TriggerEvent({"status": "success", "message": done_message})
+                    return
+        except Exception as exc:
+            self.log.exception("Error while checking streaming buffer for table %s", table_uri)
+            yield TriggerEvent({"status": "error", "message": str(exc)})
+
+    async def _is_streaming_buffer_empty(
+        self,
+        hook: BigQueryTableAsyncHook,
+        session: ClientSession,
+        project_id: str,
+        dataset_id: str,
+        table_id: str,
+    ) -> bool:
+        """
+        Fetch the table resource once and report whether it lacks a ``streamingBuffer`` entry.
+
+        :raises ValueError: if the API answers 404 or returns an empty resource.
+        :raises ClientResponseError: for any other HTTP error.
+        """
+        try:
+            table_client = await hook.get_table_client(
+                dataset=dataset_id,
+                table_id=table_id,
+                project_id=project_id,
+                session=session,
+            )
+            resource = await table_client.get()
+        except ClientResponseError as err:
+            if err.status != 404:
+                raise
+            raise ValueError(f"Table {project_id}.{dataset_id}.{table_id} not found") from err
+
+        if not resource:
+            raise ValueError(f"Table {project_id}.{dataset_id}.{table_id} does not exist")
+
+        return resource.get("streamingBuffer") is None
diff --git
a/providers/google/tests/system/google/cloud/bigquery/example_bigquery_sensors.py
b/providers/google/tests/system/google/cloud/bigquery/example_bigquery_sensors.py
index 59d5b210715..849b978d4b6 100644
---
a/providers/google/tests/system/google/cloud/bigquery/example_bigquery_sensors.py
+++
b/providers/google/tests/system/google/cloud/bigquery/example_bigquery_sensors.py
@@ -25,6 +25,7 @@ import os
from datetime import datetime
from airflow.models.dag import DAG
+from airflow.providers.google.cloud.hooks.bigquery import BigQueryHook
from airflow.providers.google.cloud.operators.bigquery import (
BigQueryCreateEmptyDatasetOperator,
BigQueryCreateTableOperator,
@@ -32,14 +33,16 @@ from airflow.providers.google.cloud.operators.bigquery
import (
BigQueryInsertJobOperator,
)
from airflow.providers.google.cloud.sensors.bigquery import (
+ BigQueryStreamingBufferEmptySensor,
BigQueryTableExistenceSensor,
BigQueryTablePartitionExistenceSensor,
)
try:
- from airflow.sdk import TriggerRule
+ from airflow.sdk import TriggerRule, task
except ImportError:
# Compatibility for Airflow < 3.1
+ from airflow.decorators import task # type: ignore[no-redef,attr-defined]
from airflow.utils.trigger_rule import TriggerRule # type:
ignore[no-redef,attr-defined]
ENV_ID = os.environ.get("SYSTEM_TESTS_ENV_ID", "default")
@@ -54,6 +57,11 @@ PARTITION_NAME = "{{ ds_nodash }}"
INSERT_ROWS_QUERY = f"INSERT {DATASET_NAME}.{TABLE_NAME} VALUES (42, '{{{{ ds
}}}}')"
+# BigQuery rejects DML that would touch rows still in the streaming buffer, so the DAG
+# below runs: streaming insert -> streaming-buffer sensor -> DML.
+STREAMING_UPDATE_QUERY = (
+    f"UPDATE {DATASET_NAME}.{TABLE_NAME} "
+    "SET value = 200 WHERE value = 100"
+)
+STREAMING_DELETE_QUERY = f"DELETE FROM {DATASET_NAME}.{TABLE_NAME} WHERE value = 200"
+
SCHEMA = [
{"name": "value", "type": "INTEGER", "mode": "REQUIRED"},
{"name": "ds", "type": "DATE", "mode": "NULLABLE"},
@@ -152,6 +160,60 @@ with DAG(
)
# [END howto_sensor_bigquery_table_partition_async]
+    # [START howto_sensor_bigquery_streaming_buffer_empty]
+    check_streaming_buffer_empty = BigQueryStreamingBufferEmptySensor(
+        task_id="check_streaming_buffer_empty",
+        project_id=PROJECT_ID,
+        dataset_id=DATASET_NAME,
+        table_id=TABLE_NAME,
+        poke_interval=30,
+        timeout=5400,  # BigQuery flushes the streaming buffer within ~90 minutes
+    )
+    # [END howto_sensor_bigquery_streaming_buffer_empty]
+
+    @task(task_id="streaming_insert")
+    def streaming_insert(ds: str | None = None) -> None:
+        """Stream one row into the table so that it acquires a streaming buffer."""
+        BigQueryHook().insert_all(
+            project_id=PROJECT_ID,
+            dataset_id=DATASET_NAME,
+            table_id=TABLE_NAME,
+            rows=[{"json": {"value": 100, "ds": ds}}],
+            # Fail loudly if BigQuery rejects the row: otherwise the buffer stays empty
+            # and the sensor / DML steps below would pass without exercising anything.
+            fail_on_error=True,
+        )
+
+    streaming_insert_task = streaming_insert()
+
+    stream_update = BigQueryInsertJobOperator(
+        task_id="stream_update",
+        configuration={
+            "query": {
+                "query": STREAMING_UPDATE_QUERY,
+                "useLegacySql": False,
+            }
+        },
+    )
+
+    stream_delete = BigQueryInsertJobOperator(
+        task_id="stream_delete",
+        configuration={
+            "query": {
+                "query": STREAMING_DELETE_QUERY,
+                "useLegacySql": False,
+            }
+        },
+    )
+
+    # [START howto_sensor_bigquery_streaming_buffer_empty_deferred]
+    check_streaming_buffer_empty_def = BigQueryStreamingBufferEmptySensor(
+        task_id="check_streaming_buffer_empty_def",
+        project_id=PROJECT_ID,
+        dataset_id=DATASET_NAME,
+        table_id=TABLE_NAME,
+        deferrable=True,
+        poke_interval=30,
+        timeout=5400,  # BigQuery flushes the streaming buffer within ~90 minutes
+    )
+    # [END howto_sensor_bigquery_streaming_buffer_empty_deferred]
+
delete_dataset = BigQueryDeleteDatasetOperator(
task_id="delete_dataset",
dataset_id=DATASET_NAME,
@@ -169,6 +231,11 @@ with DAG(
check_table_partition_exists_async,
check_table_partition_exists_def,
]
+ >> streaming_insert_task
+ >> check_streaming_buffer_empty
+ >> stream_update
+ >> check_streaming_buffer_empty_def
+ >> stream_delete
>> delete_dataset
)
diff --git a/providers/google/tests/unit/google/cloud/sensors/test_bigquery.py
b/providers/google/tests/unit/google/cloud/sensors/test_bigquery.py
index 174f3b217bd..bba88a91025 100644
--- a/providers/google/tests/unit/google/cloud/sensors/test_bigquery.py
+++ b/providers/google/tests/unit/google/cloud/sensors/test_bigquery.py
@@ -24,10 +24,12 @@ from google.api_core.exceptions import NotFound
from airflow.providers.common.compat.sdk import AirflowException, TaskDeferred
from airflow.providers.google.cloud.sensors.bigquery import (
BigQueryRoutineExistenceSensor,
+ BigQueryStreamingBufferEmptySensor,
BigQueryTableExistenceSensor,
BigQueryTablePartitionExistenceSensor,
)
from airflow.providers.google.cloud.triggers.bigquery import (
+ BigQueryStreamingBufferEmptyTrigger,
BigQueryTableExistenceTrigger,
BigQueryTablePartitionExistenceTrigger,
)
@@ -291,3 +293,108 @@ def context():
"""
context = {}
return context
+
+
+def _make_streaming_sensor(**overrides):
+    """Build a ``BigQueryStreamingBufferEmptySensor`` with test defaults; ``overrides`` win."""
+    return BigQueryStreamingBufferEmptySensor(
+        **{
+            "task_id": "task-id",
+            "project_id": TEST_PROJECT_ID,
+            "dataset_id": TEST_DATASET_ID,
+            "table_id": TEST_TABLE_ID,
+            **overrides,
+        }
+    )
+
+
+class TestBigQueryStreamingBufferEmptySensor:
+    @mock.patch("airflow.providers.google.cloud.sensors.bigquery.BigQueryHook")
+    def test_poke_returns_true_when_buffer_absent(self, mock_hook):
+        sensor = _make_streaming_sensor(
+            gcp_conn_id=TEST_GCP_CONN_ID,
+            impersonation_chain=TEST_IMPERSONATION_CHAIN,
+        )
+        mock_table = mock.MagicMock(streaming_buffer=None)
+        mock_hook.return_value.get_client.return_value.get_table.return_value = mock_table
+
+        assert sensor.poke(mock.MagicMock()) is True
+
+        mock_hook.assert_called_once_with(
+            gcp_conn_id=TEST_GCP_CONN_ID,
+            impersonation_chain=TEST_IMPERSONATION_CHAIN,
+        )
+        mock_hook.return_value.get_client.assert_called_once_with(project_id=TEST_PROJECT_ID)
+        get_table = mock_hook.return_value.get_client.return_value.get_table
+        get_table.assert_called_once()
+        # Only check that the requested identifier names the right table; the exact
+        # string format handed to the client is an implementation detail of the sensor.
+        (requested_table,), _ = get_table.call_args
+        for part in (TEST_PROJECT_ID, TEST_DATASET_ID, TEST_TABLE_ID):
+            assert part in str(requested_table)
+
+    @mock.patch("airflow.providers.google.cloud.sensors.bigquery.BigQueryHook")
+    def test_poke_returns_false_when_buffer_present(self, mock_hook):
+        sensor = _make_streaming_sensor()
+        mock_hook.return_value.get_client.return_value.get_table.return_value = mock.MagicMock(
+            streaming_buffer={"estimatedRows": 10}
+        )
+
+        assert sensor.poke(mock.MagicMock()) is False
+
+    @mock.patch("airflow.providers.google.cloud.sensors.bigquery.BigQueryHook")
+    def test_poke_raises_value_error_when_table_not_found(self, mock_hook):
+        mock_hook.return_value.get_client.return_value.get_table.side_effect = NotFound("missing")
+
+        with pytest.raises(ValueError, match="not found"):
+            _make_streaming_sensor().poke(mock.MagicMock())
+
+    @mock.patch("airflow.providers.google.cloud.sensors.bigquery.BigQueryHook")
+    def test_poke_propagates_unexpected_errors(self, mock_hook):
+        mock_hook.return_value.get_client.return_value.get_table.side_effect = RuntimeError("boom")
+
+        with pytest.raises(RuntimeError, match="boom"):
+            _make_streaming_sensor().poke(mock.MagicMock())
+
+    @mock.patch("airflow.providers.google.cloud.sensors.bigquery.BigQueryHook")
+    @mock.patch("airflow.providers.google.cloud.sensors.bigquery.BigQueryStreamingBufferEmptySensor.defer")
+    def test_execute_does_not_defer_when_buffer_already_empty(self, mock_defer, mock_hook):
+        mock_hook.return_value.get_client.return_value.get_table.return_value = mock.MagicMock(
+            streaming_buffer=None
+        )
+
+        _make_streaming_sensor(deferrable=True).execute(mock.MagicMock())
+
+        mock_defer.assert_not_called()
+
+    @mock.patch("airflow.providers.google.cloud.sensors.bigquery.BigQueryHook")
+    def test_execute_defers_with_trigger_when_buffer_not_empty(self, mock_hook):
+        sensor = _make_streaming_sensor(
+            gcp_conn_id=TEST_GCP_CONN_ID,
+            impersonation_chain=TEST_IMPERSONATION_CHAIN,
+            deferrable=True,
+        )
+        mock_hook.return_value.get_client.return_value.get_table.return_value = mock.MagicMock(
+            streaming_buffer={"estimatedRows": 1}
+        )
+
+        with pytest.raises(TaskDeferred) as exc:
+            sensor.execute(mock.MagicMock())
+
+        trigger = exc.value.trigger
+        assert isinstance(trigger, BigQueryStreamingBufferEmptyTrigger)
+        # impersonation_chain must be passed directly to the trigger, not buried in hook_params,
+        # otherwise async hook construction silently drops it.
+        assert trigger.impersonation_chain == TEST_IMPERSONATION_CHAIN
+        assert trigger.gcp_conn_id == TEST_GCP_CONN_ID
+        assert trigger.project_id == TEST_PROJECT_ID
+        assert trigger.dataset_id == TEST_DATASET_ID
+        assert trigger.table_id == TEST_TABLE_ID
+
+    def test_execute_complete_returns_message_on_success(self):
+        sensor = _make_streaming_sensor(deferrable=True)
+        assert sensor.execute_complete(context={}, event={"status": "success", "message": "ok"}) == "ok"
+
+    def test_execute_complete_raises_runtime_error_on_error_event(self):
+        with pytest.raises(RuntimeError, match="boom"):
+            _make_streaming_sensor(deferrable=True).execute_complete(
+                context={}, event={"status": "error", "message": "boom"}
+            )
+
+    def test_execute_complete_raises_value_error_when_event_is_none(self):
+        with pytest.raises(ValueError, match="No event received in trigger callback"):
+            _make_streaming_sensor(deferrable=True).execute_complete(context={}, event=None)
diff --git a/providers/google/tests/unit/google/cloud/triggers/test_bigquery.py
b/providers/google/tests/unit/google/cloud/triggers/test_bigquery.py
index 78448c064a0..f974cccc338 100644
--- a/providers/google/tests/unit/google/cloud/triggers/test_bigquery.py
+++ b/providers/google/tests/unit/google/cloud/triggers/test_bigquery.py
@@ -17,15 +17,16 @@
from __future__ import annotations
import asyncio
+import contextlib
import logging
from typing import Any
from unittest import mock
from unittest.mock import AsyncMock
import pytest
-from aiohttp import ClientResponseError, RequestInfo
+from aiohttp import ClientResponseError, ClientSession, RequestInfo
from gcloud.aio.bigquery import Table
-from multidict import CIMultiDict
+from multidict import CIMultiDict, CIMultiDictProxy
from yarl import URL
from airflow.providers.google.cloud.hooks.bigquery import
BigQueryTableAsyncHook
@@ -34,6 +35,7 @@ from airflow.providers.google.cloud.triggers.bigquery import (
BigQueryGetDataTrigger,
BigQueryInsertJobTrigger,
BigQueryIntervalCheckTrigger,
+ BigQueryStreamingBufferEmptyTrigger,
BigQueryTableExistenceTrigger,
BigQueryTablePartitionExistenceTrigger,
BigQueryValueCheckTrigger,
@@ -996,3 +998,179 @@ class TestBigQueryTablePartitionExistenceTrigger:
"poll_interval": POLLING_PERIOD_SECONDS,
"hook_params": TEST_HOOK_PARAMS,
}
+
+
+@pytest.fixture
+def streaming_buffer_trigger():
+    """Return a trigger wired with the shared test constants."""
+    return BigQueryStreamingBufferEmptyTrigger(
+        project_id=TEST_GCP_PROJECT_ID,
+        dataset_id=TEST_DATASET_ID,
+        table_id=TEST_TABLE_ID,
+        gcp_conn_id=TEST_GCP_CONN_ID,
+        poll_interval=POLLING_PERIOD_SECONDS,
+        impersonation_chain=TEST_IMPERSONATION_CHAIN,
+    )
+
+
+def _make_client_response_error(status: int, message: str = "Not Found") -> ClientResponseError:
+    """Build an aiohttp ``ClientResponseError`` carrying ``status`` for use as a mock side effect."""
+    dummy_url = URL("https://example.com")
+    request_info = RequestInfo(
+        url=dummy_url,
+        method="GET",
+        headers=CIMultiDictProxy(CIMultiDict()),
+        real_url=dummy_url,
+    )
+    return ClientResponseError(
+        request_info=request_info,
+        history=(),
+        status=status,
+        message=message,
+    )
+
+
+# Dotted import path of the trigger class; used as the serialize() classpath and as
+# the prefix for ``mock.patch`` targets on its methods.
+_TRIGGER_PATH = "airflow.providers.google.cloud.triggers.bigquery.BigQueryStreamingBufferEmptyTrigger"
+
+
+class TestBigQueryStreamingBufferEmptyTrigger:
+    def test_serialization(self, streaming_buffer_trigger):
+        classpath, kwargs = streaming_buffer_trigger.serialize()
+        assert classpath == _TRIGGER_PATH
+        assert kwargs == {
+            "project_id": TEST_GCP_PROJECT_ID,
+            "dataset_id": TEST_DATASET_ID,
+            "table_id": TEST_TABLE_ID,
+            "gcp_conn_id": TEST_GCP_CONN_ID,
+            "poll_interval": POLLING_PERIOD_SECONDS,
+            "impersonation_chain": TEST_IMPERSONATION_CHAIN,
+        }
+
+    @mock.patch("airflow.providers.google.cloud.triggers.bigquery.BigQueryTableAsyncHook")
+    def test_async_hook_receives_impersonation_chain(self, mock_hook_cls, streaming_buffer_trigger):
+        streaming_buffer_trigger._get_async_hook()
+        mock_hook_cls.assert_called_once_with(
+            gcp_conn_id=TEST_GCP_CONN_ID,
+            impersonation_chain=TEST_IMPERSONATION_CHAIN,
+        )
+
+    @pytest.mark.asyncio
+    @mock.patch(f"{_TRIGGER_PATH}._is_streaming_buffer_empty")
+    @mock.patch(f"{_TRIGGER_PATH}._get_async_hook")
+    async def test_run_yields_success_when_buffer_empty(
+        self, _mock_hook, mock_is_empty, streaming_buffer_trigger
+    ):
+        mock_is_empty.return_value = True
+        actual = await streaming_buffer_trigger.run().asend(None)
+
+        table_uri = f"{TEST_GCP_PROJECT_ID}:{TEST_DATASET_ID}.{TEST_TABLE_ID}"
+        assert actual == TriggerEvent(
+            {"status": "success", "message": f"Streaming buffer is empty for table: {table_uri}"}
+        )
+        # The buffer was already empty, so exactly one lookup must have happened.
+        mock_is_empty.assert_awaited_once()
+
+    @pytest.mark.asyncio
+    @mock.patch(f"{_TRIGGER_PATH}._is_streaming_buffer_empty")
+    @mock.patch(f"{_TRIGGER_PATH}._get_async_hook")
+    async def test_run_keeps_polling_while_buffer_not_empty(
+        self, _mock_hook, mock_is_empty, streaming_buffer_trigger
+    ):
+        mock_is_empty.return_value = False
+        task = asyncio.create_task(streaming_buffer_trigger.run().__anext__())
+        try:
+            with pytest.raises(asyncio.TimeoutError):
+                await asyncio.wait_for(asyncio.shield(task), timeout=0.2)
+            assert not task.done()
+        finally:
+            task.cancel()
+            with contextlib.suppress(asyncio.CancelledError):
+                await task
+
+    @pytest.mark.asyncio
+    @mock.patch(f"{_TRIGGER_PATH}._is_streaming_buffer_empty")
+    @mock.patch(f"{_TRIGGER_PATH}._get_async_hook")
+    async def test_run_yields_error_when_table_not_found(
+        self, _mock_hook, mock_is_empty, streaming_buffer_trigger
+    ):
+        message = f"Table {TEST_GCP_PROJECT_ID}.{TEST_DATASET_ID}.{TEST_TABLE_ID} not found"
+        mock_is_empty.side_effect = ValueError(message)
+
+        actual = await streaming_buffer_trigger.run().asend(None)
+
+        assert actual == TriggerEvent({"status": "error", "message": message})
+
+    @pytest.mark.asyncio
+    @mock.patch(f"{_TRIGGER_PATH}._is_streaming_buffer_empty")
+    @mock.patch(f"{_TRIGGER_PATH}._get_async_hook")
+    async def test_run_yields_error_on_unexpected_exception(
+        self, _mock_hook, mock_is_empty, streaming_buffer_trigger
+    ):
+        mock_is_empty.side_effect = Exception("API failure")
+
+        actual = await streaming_buffer_trigger.run().asend(None)
+
+        assert actual == TriggerEvent({"status": "error", "message": "API failure"})
+
+    @pytest.mark.asyncio
+    async def test_is_streaming_buffer_empty_true_when_key_absent(self, streaming_buffer_trigger):
+        mock_hook = mock.MagicMock(spec=BigQueryTableAsyncHook)
+        mock_client = mock.MagicMock()
+        mock_client.get = AsyncMock(return_value={"id": "some-table"})
+        mock_hook.get_table_client = AsyncMock(return_value=mock_client)
+
+        async with ClientSession() as session:
+            result = await streaming_buffer_trigger._is_streaming_buffer_empty(
+                hook=mock_hook,
+                session=session,
+                project_id=TEST_GCP_PROJECT_ID,
+                dataset_id=TEST_DATASET_ID,
+                table_id=TEST_TABLE_ID,
+            )
+
+        assert result is True
+        # The table client must be built for exactly the requested table and session.
+        mock_hook.get_table_client.assert_awaited_once_with(
+            dataset=TEST_DATASET_ID,
+            table_id=TEST_TABLE_ID,
+            project_id=TEST_GCP_PROJECT_ID,
+            session=session,
+        )
+        mock_client.get.assert_awaited_once()
+
+    @pytest.mark.asyncio
+    async def test_is_streaming_buffer_empty_false_when_key_present(self, streaming_buffer_trigger):
+        mock_hook = mock.MagicMock(spec=BigQueryTableAsyncHook)
+        mock_client = mock.MagicMock()
+        mock_client.get = AsyncMock(
+            return_value={"streamingBuffer": {"estimatedRows": "10"}, "id": "some-table"}
+        )
+        mock_hook.get_table_client = AsyncMock(return_value=mock_client)
+
+        async with ClientSession() as session:
+            result = await streaming_buffer_trigger._is_streaming_buffer_empty(
+                hook=mock_hook,
+                session=session,
+                project_id=TEST_GCP_PROJECT_ID,
+                dataset_id=TEST_DATASET_ID,
+                table_id=TEST_TABLE_ID,
+            )
+
+        assert result is False
+
+    @pytest.mark.asyncio
+    async def test_is_streaming_buffer_empty_raises_value_error_on_empty_response(
+        self, streaming_buffer_trigger
+    ):
+        mock_hook = mock.MagicMock(spec=BigQueryTableAsyncHook)
+        mock_client = mock.MagicMock()
+        mock_client.get = AsyncMock(return_value={})
+        mock_hook.get_table_client = AsyncMock(return_value=mock_client)
+
+        async with ClientSession() as session:
+            with pytest.raises(ValueError, match="does not exist"):
+                await streaming_buffer_trigger._is_streaming_buffer_empty(
+                    hook=mock_hook,
+                    session=session,
+                    project_id=TEST_GCP_PROJECT_ID,
+                    dataset_id=TEST_DATASET_ID,
+                    table_id=TEST_TABLE_ID,
+                )
+
+    @pytest.mark.asyncio
+    async def test_is_streaming_buffer_empty_raises_value_error_on_404(self, streaming_buffer_trigger):
+        mock_hook = mock.MagicMock(spec=BigQueryTableAsyncHook)
+        mock_hook.get_table_client = AsyncMock(side_effect=_make_client_response_error(404))
+
+        async with ClientSession() as session:
+            with pytest.raises(ValueError, match="not found"):
+                await streaming_buffer_trigger._is_streaming_buffer_empty(
+                    hook=mock_hook,
+                    session=session,
+                    project_id=TEST_GCP_PROJECT_ID,
+                    dataset_id=TEST_DATASET_ID,
+                    table_id=TEST_TABLE_ID,
+                )
+
+    @pytest.mark.asyncio
+    async def test_is_streaming_buffer_empty_propagates_other_client_errors(self, streaming_buffer_trigger):
+        mock_hook = mock.MagicMock(spec=BigQueryTableAsyncHook)
+        mock_hook.get_table_client = AsyncMock(side_effect=_make_client_response_error(500, "Server Error"))
+
+        async with ClientSession() as session:
+            with pytest.raises(ClientResponseError):
+                await streaming_buffer_trigger._is_streaming_buffer_empty(
+                    hook=mock_hook,
+                    session=session,
+                    project_id=TEST_GCP_PROJECT_ID,
+                    dataset_id=TEST_DATASET_ID,
+                    table_id=TEST_TABLE_ID,
+                )