This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 04217f1d9b airflow.models.xcom deprecations removed (#41803)
04217f1d9b is described below
commit 04217f1d9b64b85d1d49acf97421e7a38ad61b50
Author: Gopal Dirisala <[email protected]>
AuthorDate: Sun Sep 1 21:19:14 2024 +0530
airflow.models.xcom deprecations removed (#41803)
---
airflow/models/baseoperator.py | 6 +-
airflow/models/taskinstance.py | 16 +-
airflow/models/xcom.py | 325 +++------------------
airflow/serialization/pydantic/taskinstance.py | 3 -
tests/models/test_xcom.py | 166 -----------
tests/providers/amazon/aws/links/test_base_aws.py | 13 +-
.../google/cloud/operators/test_bigquery_dts.py | 11 +-
.../google/cloud/operators/test_dataproc.py | 154 +++++++---
tests/providers/microsoft/conftest.py | 10 +-
tests/providers/yandex/links/test_yq.py | 13 +-
tests/providers/yandex/operators/test_yq.py | 30 +-
tests/test_utils/compat.py | 1 +
12 files changed, 191 insertions(+), 557 deletions(-)
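For DAG authors, the user-visible effect of this commit is that xcom_push no longer accepts execution_date. A minimal sketch of the surviving call pattern (the TaskFlow DAG below is hypothetical, not part of this commit):

    # Hypothetical illustration: only `key` and `value` survive on xcom_push.
    from airflow.decorators import dag, task
    from airflow.utils import timezone

    @dag(schedule=None, start_date=timezone.datetime(2024, 9, 1))
    def xcom_push_example():
        @task
        def producer(**context):
            ti = context["ti"]
            ti.xcom_push(key="my_key", value={"answer": 42})  # still valid
            # ti.xcom_push(key="k", value=1, execution_date=...)  # now a TypeError

        producer()

    xcom_push_example()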
diff --git a/airflow/models/baseoperator.py b/airflow/models/baseoperator.py
index 7d18aa1474..2347289428 100644
--- a/airflow/models/baseoperator.py
+++ b/airflow/models/baseoperator.py
@@ -1592,7 +1592,6 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
context: Any,
key: str,
value: Any,
- execution_date: datetime | None = None,
) -> None:
"""
Make an XCom available for tasks to pull.
@@ -1601,11 +1600,8 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
:param key: A key for the XCom
:param value: A value for the XCom. The value is pickled and stored
in the database.
- :param execution_date: if provided, the XCom will not be visible until
- this date. This can be used, for example, to send a message to a
- task on a future date without it being immediately visible.
"""
- context["ti"].xcom_push(key=key, value=value,
execution_date=execution_date)
+ context["ti"].xcom_push(key=key, value=value)
@staticmethod
@provide_session
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index 0d633f8bf3..165f5c7987 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -79,7 +79,6 @@ from airflow.exceptions import (
AirflowSkipException,
AirflowTaskTerminated,
AirflowTaskTimeout,
- RemovedInAirflow3Warning,
TaskDeferralError,
TaskDeferred,
UnmappableXComLengthPushed,
@@ -3473,7 +3472,6 @@ class TaskInstance(Base, LoggingMixin):
self,
key: str,
value: Any,
- execution_date: datetime | None = None,
session: Session = NEW_SESSION,
) -> None:
"""
@@ -3483,19 +3481,7 @@ class TaskInstance(Base, LoggingMixin):
:param value: Value to store. What types are possible depends on whether
``enable_xcom_pickling`` is true or not. If so, this can be any
picklable object; otherwise only JSON-serializable values may be used.
- :param execution_date: Deprecated parameter that has no effect.
- """
- if execution_date is not None:
- self_execution_date = self.get_dagrun(session).execution_date
- if execution_date < self_execution_date:
- raise ValueError(
- f"execution_date can not be in the past (current
execution_date is "
- f"{self_execution_date}; received {execution_date})"
- )
- elif execution_date is not None:
- message = "Passing 'execution_date' to
'TaskInstance.xcom_push()' is deprecated."
- warnings.warn(message, RemovedInAirflow3Warning, stacklevel=3)
-
+ """
XCom.set(
key=key,
value=value,
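After this change, XCom.set accepts only run_id to locate the DAG run. A minimal direct-call sketch under the new contract (the IDs below are made up):

    # Hypothetical direct call against the run_id-only XCom.set contract.
    from airflow.models.xcom import XCom

    XCom.set(
        key="my_key",
        value={"answer": 42},
        dag_id="example_dag",
        task_id="example_task",
        run_id="manual__2024-09-01T00:00:00+00:00",  # required; no execution_date fallback
    )
    # Passing run_id=None now raises:
    #   ValueError: run_id must be passed. Passed run_id=None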
diff --git a/airflow/models/xcom.py b/airflow/models/xcom.py
index 9829f11fbb..87c72d5bf7 100644
--- a/airflow/models/xcom.py
+++ b/airflow/models/xcom.py
@@ -21,9 +21,7 @@ import inspect
import json
import logging
import pickle
-import warnings
-from functools import wraps
-from typing import TYPE_CHECKING, Any, Iterable, cast, overload
+from typing import TYPE_CHECKING, Any, Iterable, cast
from sqlalchemy import (
Column,
@@ -40,15 +38,13 @@ from sqlalchemy import (
from sqlalchemy.dialects.mysql import LONGBLOB
from sqlalchemy.ext.associationproxy import association_proxy
from sqlalchemy.orm import Query, reconstructor, relationship
-from sqlalchemy.orm.exc import NoResultFound
from airflow.api_internal.internal_api_call import internal_api_call
from airflow.configuration import conf
-from airflow.exceptions import RemovedInAirflow3Warning
from airflow.models.base import COLLATION_ARGS, ID_LEN, TaskInstanceDependencies
from airflow.utils import timezone
from airflow.utils.db import LazySelectSequence
-from airflow.utils.helpers import exactly_one, is_container
+from airflow.utils.helpers import is_container
from airflow.utils.json import XComDecoder, XComEncoder
from airflow.utils.log.logging_mixin import LoggingMixin
from airflow.utils.session import NEW_SESSION, provide_session
@@ -64,9 +60,6 @@ from airflow.utils.xcom import (
log = logging.getLogger(__name__)
if TYPE_CHECKING:
- import datetime
-
- import pendulum
from sqlalchemy.engine import Row
from sqlalchemy.orm import Session
from sqlalchemy.sql.expression import Select, TextClause
@@ -134,8 +127,9 @@ class BaseXCom(TaskInstanceDependencies, LoggingMixin):
return f'<XCom "{self.key}" ({self.task_id} @ {self.run_id})>'
return f'<XCom "{self.key}" ({self.task_id}[{self.map_index}] @
{self.run_id})>'
- @overload
@classmethod
+ @internal_api_call
+ @provide_session
def set(
cls,
key: str,
@@ -150,9 +144,6 @@ class BaseXCom(TaskInstanceDependencies, LoggingMixin):
"""
Store an XCom value.
- A deprecated form of this function accepts ``execution_date`` instead of
- ``run_id``. The two arguments are mutually exclusive.
-
:param key: Key to store the XCom.
:param value: XCom value to store.
:param dag_id: DAG ID.
@@ -163,67 +154,14 @@ class BaseXCom(TaskInstanceDependencies, LoggingMixin):
:param session: Database session. If not given, a new session will be
created for this function.
"""
-
- @overload
- @classmethod
- def set(
- cls,
- key: str,
- value: Any,
- task_id: str,
- dag_id: str,
- execution_date: datetime.datetime,
- session: Session = NEW_SESSION,
- ) -> None:
- """
- Store an XCom value.
-
- :sphinx-autoapi-skip:
- """
-
- @classmethod
- @internal_api_call
- @provide_session
- def set(
- cls,
- key: str,
- value: Any,
- task_id: str,
- dag_id: str,
- execution_date: datetime.datetime | None = None,
- session: Session = NEW_SESSION,
- *,
- run_id: str | None = None,
- map_index: int = -1,
- ) -> None:
- """
- Store an XCom value.
-
- :sphinx-autoapi-skip:
- """
from airflow.models.dagrun import DagRun
- if not exactly_one(execution_date is not None, run_id is not None):
- raise ValueError(
- f"Exactly one of run_id or execution_date must be passed. "
- f"Passed execution_date={execution_date}, run_id={run_id}"
- )
+ if not run_id:
+ raise ValueError(f"run_id must be passed. Passed run_id={run_id}")
- if run_id is None:
- message = "Passing 'execution_date' to 'XCom.set()' is deprecated.
Use 'run_id' instead."
- warnings.warn(message, RemovedInAirflow3Warning, stacklevel=3)
- try:
- dag_run_id, run_id = (
- session.query(DagRun.id, DagRun.run_id)
- .filter(DagRun.dag_id == dag_id, DagRun.execution_date ==
execution_date)
- .one()
- )
- except NoResultFound:
- raise ValueError(f"DAG run not found on DAG {dag_id!r} at
{execution_date}") from None
- else:
- dag_run_id = session.query(DagRun.id).filter_by(dag_id=dag_id,
run_id=run_id).scalar()
- if dag_run_id is None:
- raise ValueError(f"DAG run not found on DAG {dag_id!r} with ID
{run_id!r}")
+ dag_run_id = session.query(DagRun.id).filter_by(dag_id=dag_id,
run_id=run_id).scalar()
+ if dag_run_id is None:
+ raise ValueError(f"DAG run not found on DAG {dag_id!r} with ID
{run_id!r}")
# Seamlessly resolve LazySelectSequence to a list. This intends to work
# as a "lazy list" to avoid pulling a ton of XComs unnecessarily, but
if
@@ -242,7 +180,7 @@ class BaseXCom(TaskInstanceDependencies, LoggingMixin):
"return value" if key == XCOM_RETURN_KEY else f"value {key}",
task_id,
dag_id,
- run_id or execution_date,
+ run_id,
)
value = list(value)
@@ -311,17 +249,18 @@ class BaseXCom(TaskInstanceDependencies, LoggingMixin):
session=session,
)
- @overload
@staticmethod
+ @provide_session
@internal_api_call
def get_one(
*,
key: str | None = None,
dag_id: str | None = None,
task_id: str | None = None,
- run_id: str | None = None,
+ run_id: str,
map_index: int | None = None,
session: Session = NEW_SESSION,
+ include_prior_dates: bool = False,
) -> Any | None:
"""
Retrieve an XCom value, optionally meeting certain criteria.
@@ -333,9 +272,6 @@ class BaseXCom(TaskInstanceDependencies, LoggingMixin):
If there are no results, *None* is returned. If multiple XCom entries
match the criteria, an arbitrary one is returned.
- A deprecated form of this function accepts ``execution_date`` instead of
- ``run_id``. The two arguments are mutually exclusive.
-
.. seealso:: ``get_value()`` is a convenience function if you already
have a structured TaskInstance or TaskInstanceKey object available.
@@ -354,83 +290,27 @@ class BaseXCom(TaskInstanceDependencies, LoggingMixin):
:param session: Database session. If not given, a new session will be
created for this function.
"""
-
- @overload
- @staticmethod
- @internal_api_call
- def get_one(
- execution_date: datetime.datetime,
- key: str | None = None,
- task_id: str | None = None,
- dag_id: str | None = None,
- include_prior_dates: bool = False,
- session: Session = NEW_SESSION,
- ) -> Any | None:
- """
- Retrieve an XCom value, optionally meeting certain criteria.
-
- :sphinx-autoapi-skip:
- """
-
- @staticmethod
- @provide_session
- @internal_api_call
- def get_one(
- execution_date: datetime.datetime | None = None,
- key: str | None = None,
- task_id: str | None = None,
- dag_id: str | None = None,
- include_prior_dates: bool = False,
- session: Session = NEW_SESSION,
- *,
- run_id: str | None = None,
- map_index: int | None = None,
- ) -> Any | None:
- """
- Retrieve an XCom value, optionally meeting certain criteria.
-
- :sphinx-autoapi-skip:
- """
- if not exactly_one(execution_date is not None, run_id is not None):
- raise ValueError("Exactly one of run_id or execution_date must be
passed")
-
- if run_id:
- query = BaseXCom.get_many(
- run_id=run_id,
- key=key,
- task_ids=task_id,
- dag_ids=dag_id,
- map_indexes=map_index,
- include_prior_dates=include_prior_dates,
- limit=1,
- session=session,
- )
- elif execution_date is not None:
- message = "Passing 'execution_date' to 'XCom.get_one()' is
deprecated. Use 'run_id' instead."
- warnings.warn(message, RemovedInAirflow3Warning, stacklevel=3)
-
- with warnings.catch_warnings():
- warnings.simplefilter("ignore", RemovedInAirflow3Warning)
- query = BaseXCom.get_many(
- execution_date=execution_date,
- key=key,
- task_ids=task_id,
- dag_ids=dag_id,
- map_indexes=map_index,
- include_prior_dates=include_prior_dates,
- limit=1,
- session=session,
- )
- else:
- raise RuntimeError("Should not happen?")
+ query = BaseXCom.get_many(
+ run_id=run_id,
+ key=key,
+ task_ids=task_id,
+ dag_ids=dag_id,
+ map_indexes=map_index,
+ include_prior_dates=include_prior_dates,
+ limit=1,
+ session=session,
+ )
result = query.with_entities(BaseXCom.value).first()
if result:
return XCom.deserialize_value(result)
return None
- @overload
+ # The 'get_many` is not supported via database isolation mode. Attempting to use it in DB isolation
+ # mode will result in a crash - Resulting Query object cannot be **really** serialized
+ # TODO(potiuk) - document it in AIP-44 docs
@staticmethod
+ @provide_session
def get_many(
*,
run_id: str,
@@ -448,9 +328,6 @@ class BaseXCom(TaskInstanceDependencies, LoggingMixin):
This function returns an SQLAlchemy query of full XCom objects. If you
just want one stored value, use :meth:`get_one` instead.
- A deprecated form of this function accepts ``execution_date`` instead of
- ``run_id``. The two arguments are mutually exclusive.
-
:param run_id: DAG run ID for the task.
:param key: A key for the XComs. If provided, only XComs with matching
keys will be returned. Pass *None* (default) to remove the filter.
@@ -467,58 +344,10 @@ class BaseXCom(TaskInstanceDependencies, LoggingMixin):
created for this function.
:param limit: Limit the number of returned XComs.
"""
-
- @overload
- @staticmethod
- @internal_api_call
- def get_many(
- execution_date: datetime.datetime,
- key: str | None = None,
- task_ids: str | Iterable[str] | None = None,
- dag_ids: str | Iterable[str] | None = None,
- map_indexes: int | Iterable[int] | None = None,
- include_prior_dates: bool = False,
- limit: int | None = None,
- session: Session = NEW_SESSION,
- ) -> Query:
- """
- Composes a query to get one or more XCom entries.
-
- :sphinx-autoapi-skip:
- """
-
- # The 'get_many` is not supported via database isolation mode. Attempting to use it in DB isolation
- # mode will result in a crash - Resulting Query object cannot be **really** serialized
- # TODO(potiuk) - document it in AIP-44 docs
- @staticmethod
- @provide_session
- def get_many(
- execution_date: datetime.datetime | None = None,
- key: str | None = None,
- task_ids: str | Iterable[str] | None = None,
- dag_ids: str | Iterable[str] | None = None,
- map_indexes: int | Iterable[int] | None = None,
- include_prior_dates: bool = False,
- limit: int | None = None,
- session: Session = NEW_SESSION,
- *,
- run_id: str | None = None,
- ) -> Query:
- """
- Composes a query to get one or more XCom entries.
-
- :sphinx-autoapi-skip:
- """
from airflow.models.dagrun import DagRun
- if not exactly_one(execution_date is not None, run_id is not None):
- raise ValueError(
- f"Exactly one of run_id or execution_date must be passed. "
- f"Passed execution_date={execution_date}, run_id={run_id}"
- )
- if execution_date is not None:
- message = "Passing 'execution_date' to 'XCom.get_many()' is
deprecated. Use 'run_id' instead."
- warnings.warn(message, RemovedInAirflow3Warning, stacklevel=3)
+ if not run_id:
+ raise ValueError(f"run_id must be passed. Passed run_id={run_id}")
query = session.query(BaseXCom).join(BaseXCom.dag_run)
@@ -545,13 +374,8 @@ class BaseXCom(TaskInstanceDependencies, LoggingMixin):
query = query.filter(BaseXCom.map_index == map_indexes)
if include_prior_dates:
- if execution_date is not None:
- query = query.filter(DagRun.execution_date <= execution_date)
- else:
- dr = session.query(DagRun.execution_date).filter(DagRun.run_id == run_id).subquery()
- query = query.filter(BaseXCom.execution_date <= dr.c.execution_date)
- elif execution_date is not None:
- query = query.filter(DagRun.execution_date == execution_date)
+ dr = session.query(DagRun.execution_date).filter(DagRun.run_id == run_id).subquery()
+ query = query.filter(BaseXCom.execution_date <= dr.c.execution_date)
else:
query = query.filter(BaseXCom.run_id == run_id)
@@ -578,8 +402,8 @@ class BaseXCom(TaskInstanceDependencies, LoggingMixin):
"""Purge an XCom entry from underlying storage implementations."""
pass
- @overload
@staticmethod
+ @provide_session
@internal_api_call
def clear(
*,
@@ -592,9 +416,6 @@ class BaseXCom(TaskInstanceDependencies, LoggingMixin):
"""
Clear all XCom data from the database for the given task instance.
- A deprecated form of this function accepts ``execution_date`` instead of
- ``run_id``. The two arguments are mutually exclusive.
-
:param dag_id: ID of DAG to clear the XCom for.
:param task_id: ID of task to clear the XCom for.
:param run_id: ID of DAG run to clear the XCom for.
@@ -603,41 +424,6 @@ class BaseXCom(TaskInstanceDependencies, LoggingMixin):
:param session: Database session. If not given, a new session will be
created for this function.
"""
-
- @overload
- @staticmethod
- @internal_api_call
- def clear(
- execution_date: pendulum.DateTime,
- dag_id: str,
- task_id: str,
- session: Session = NEW_SESSION,
- ) -> None:
- """
- Clear all XCom data from the database for the given task instance.
-
- :sphinx-autoapi-skip:
- """
-
- @staticmethod
- @provide_session
- @internal_api_call
- def clear(
- execution_date: pendulum.DateTime | None = None,
- dag_id: str | None = None,
- task_id: str | None = None,
- session: Session = NEW_SESSION,
- *,
- run_id: str | None = None,
- map_index: int | None = None,
- ) -> None:
- """
- Clear all XCom data from the database for the given task instance.
-
- :sphinx-autoapi-skip:
- """
- from airflow.models import DagRun
-
# Given the historic order of this function (execution_date was first argument) to add a new optional
# param we need to add default values for everything :(
if dag_id is None:
@@ -645,20 +431,8 @@ class BaseXCom(TaskInstanceDependencies, LoggingMixin):
if task_id is None:
raise TypeError("clear() missing required argument: task_id")
- if not exactly_one(execution_date is not None, run_id is not None):
- raise ValueError(
- f"Exactly one of run_id or execution_date must be passed. "
- f"Passed execution_date={execution_date}, run_id={run_id}"
- )
-
- if execution_date is not None:
- message = "Passing 'execution_date' to 'XCom.clear()' is
deprecated. Use 'run_id' instead."
- warnings.warn(message, RemovedInAirflow3Warning, stacklevel=3)
- run_id = (
- session.query(DagRun.run_id)
- .filter(DagRun.dag_id == dag_id, DagRun.execution_date ==
execution_date)
- .scalar()
- )
+ if not run_id:
+ raise ValueError(f"run_id must be passed. Passed run_id={run_id}")
query = session.query(BaseXCom).filter_by(dag_id=dag_id, task_id=task_id, run_id=run_id)
if map_index is not None:
@@ -747,33 +521,6 @@ class LazyXComSelectSequence(LazySelectSequence[Any]):
return XCom.deserialize_value(row)
-def _patch_outdated_serializer(clazz: type[BaseXCom], params: Iterable[str]) -> None:
- """
- Patch a custom ``serialize_value`` to accept the modern signature.
-
- To give custom XCom backends more flexibility with how they store values, we
- now forward all params passed to ``XCom.set`` to ``XCom.serialize_value``.
- In order to maintain compatibility with custom XCom backends written with
- the old signature, we check the signature and, if necessary, patch with a
- method that ignores kwargs the backend does not accept.
- """
- old_serializer = clazz.serialize_value
-
- @wraps(old_serializer)
- def _shim(**kwargs):
- kwargs = {k: kwargs.get(k) for k in params}
- warnings.warn(
- f"Method `serialize_value` in XCom backend {XCom.__name__} is
using outdated signature and"
- f"must be updated to accept all params in `BaseXCom.set` except
`session`. Support will be "
- f"removed in a future release.",
- RemovedInAirflow3Warning,
- stacklevel=1,
- )
- return old_serializer(**kwargs)
-
- clazz.serialize_value = _shim # type: ignore[assignment]
-
-
def _get_function_params(function) -> list[str]:
"""
Return the list of variables names of a function.
@@ -801,10 +548,6 @@ def resolve_xcom_backend() -> type[BaseXCom]:
raise TypeError(
f"Your custom XCom class `{clazz.__name__}` is not a subclass of
`{BaseXCom.__name__}`."
)
- base_xcom_params = _get_function_params(BaseXCom.serialize_value)
- xcom_params = _get_function_params(clazz.serialize_value)
- if set(base_xcom_params) != set(xcom_params):
- _patch_outdated_serializer(clazz=clazz, params=xcom_params)
return clazz
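With _patch_outdated_serializer removed, a custom XCom backend whose serialize_value only accepts ``value`` is no longer shimmed at import time; it must declare the full modern parameter set itself. A sketch of a conforming backend (the class is hypothetical; the keyword list mirrors BaseXCom.serialize_value):

    # Hypothetical backend written against the post-removal contract.
    import json

    from airflow.models.xcom import BaseXCom

    class JsonOnlyXCom(BaseXCom):
        @staticmethod
        def serialize_value(value, *, key=None, task_id=None, dag_id=None, run_id=None, map_index=None):
            # All kwargs forwarded by XCom.set must be accepted here, since the
            # compatibility shim that used to swallow them is gone.
            return json.dumps(value).encode("utf-8")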
diff --git a/airflow/serialization/pydantic/taskinstance.py b/airflow/serialization/pydantic/taskinstance.py
index 0dcb7880eb..549b03680d 100644
--- a/airflow/serialization/pydantic/taskinstance.py
+++ b/airflow/serialization/pydantic/taskinstance.py
@@ -203,7 +203,6 @@ class TaskInstancePydantic(BaseModelPydantic, LoggingMixin):
self,
key: str,
value: Any,
- execution_date: datetime | None = None,
session: Session | None = None,
) -> None:
"""
@@ -211,13 +210,11 @@ class TaskInstancePydantic(BaseModelPydantic, LoggingMixin):
:param key: the key to identify the XCom value
:param value: the value of the XCom
- :param execution_date: the execution date to push the XCom for
"""
return TaskInstance.xcom_push(
self=self, # type: ignore[arg-type]
key=key,
value=value,
- execution_date=execution_date,
session=session,
)
diff --git a/tests/models/test_xcom.py b/tests/models/test_xcom.py
index e9db3d946d..07533ec944 100644
--- a/tests/models/test_xcom.py
+++ b/tests/models/test_xcom.py
@@ -214,33 +214,6 @@ class TestXCom:
assert value == {"key": "value"}
XCom.orm_deserialize_value.assert_not_called()
- @pytest.mark.skip_if_database_isolation_mode
- @conf_vars({("core", "enable_xcom_pickling"): "False"})
- @mock.patch("airflow.models.xcom.conf.getimport")
- def test_set_serialize_call_old_signature(self, get_import, task_instance):
- """
- When XCom.serialize_value takes only param ``value``, other kwargs should be ignored.
- """
- serialize_watcher = MagicMock()
-
- class OldSignatureXCom(BaseXCom):
- @staticmethod
- def serialize_value(value, **kwargs):
- serialize_watcher(value=value, **kwargs)
- return json.dumps(value).encode("utf-8")
-
- get_import.return_value = OldSignatureXCom
-
- XCom = resolve_xcom_backend()
- XCom.set(
- key=XCOM_RETURN_KEY,
- value={"my_xcom_key": "my_xcom_value"},
- dag_id=task_instance.dag_id,
- task_id=task_instance.task_id,
- run_id=task_instance.run_id,
- )
- serialize_watcher.assert_called_once_with(value={"my_xcom_key": "my_xcom_value"})
-
@pytest.mark.skip_if_database_isolation_mode
@conf_vars({("core", "enable_xcom_pickling"): "False"})
@mock.patch("airflow.models.xcom.conf.getimport")
@@ -335,19 +308,6 @@ class TestXComGet:
)
assert stored_value == {"key": "value"}
- @pytest.mark.skip_if_database_isolation_mode
- @pytest.mark.usefixtures("setup_for_xcom_get_one")
- def test_xcom_get_one_with_execution_date(self, session, task_instance):
- with pytest.deprecated_call():
- stored_value = XCom.get_one(
- key="xcom_1",
- dag_id=task_instance.dag_id,
- task_id=task_instance.task_id,
- execution_date=task_instance.execution_date,
- session=session,
- )
- assert stored_value == {"key": "value"}
-
@pytest.fixture
def tis_for_xcom_get_one_from_prior_date(self, task_instance_factory, push_simple_json_xcom):
date1 = timezone.datetime(2021, 12, 3, 4, 56)
@@ -376,24 +336,6 @@ class TestXComGet:
)
assert retrieved_value == {"key": "value"}
- @pytest.mark.skip_if_database_isolation_mode
- def test_xcom_get_one_from_prior_with_execution_date(
- self,
- session,
- tis_for_xcom_get_one_from_prior_date,
- ):
- _, ti2 = tis_for_xcom_get_one_from_prior_date
- with pytest.deprecated_call():
- retrieved_value = XCom.get_one(
- execution_date=ti2.execution_date,
- key="xcom_1",
- task_id="task_1",
- dag_id="dag",
- include_prior_dates=True,
- session=session,
- )
- assert retrieved_value == {"key": "value"}
-
@pytest.mark.skip_if_database_isolation_mode
@pytest.fixture
def setup_for_xcom_get_many_single_argument_value(self, task_instance, push_simple_json_xcom):
@@ -413,21 +355,6 @@ class TestXComGet:
assert stored_xcoms[0].key == "xcom_1"
assert stored_xcoms[0].value == {"key": "value"}
- @pytest.mark.skip_if_database_isolation_mode
- @pytest.mark.usefixtures("setup_for_xcom_get_many_single_argument_value")
- def test_xcom_get_many_single_argument_value_with_execution_date(self, session, task_instance):
- with pytest.deprecated_call():
- stored_xcoms = XCom.get_many(
- execution_date=task_instance.execution_date,
- key="xcom_1",
- dag_ids=task_instance.dag_id,
- task_ids=task_instance.task_id,
- session=session,
- ).all()
- assert len(stored_xcoms) == 1
- assert stored_xcoms[0].key == "xcom_1"
- assert stored_xcoms[0].value == {"key": "value"}
-
@pytest.mark.skip_if_database_isolation_mode
@pytest.fixture
def setup_for_xcom_get_many_multiple_tasks(self, task_instances, push_simple_json_xcom):
@@ -448,20 +375,6 @@ class TestXComGet:
sorted_values = [x.value for x in sorted(stored_xcoms, key=operator.attrgetter("task_id"))]
assert sorted_values == [{"key1": "value1"}, {"key2": "value2"}]
- @pytest.mark.skip_if_database_isolation_mode
- @pytest.mark.usefixtures("setup_for_xcom_get_many_multiple_tasks")
- def test_xcom_get_many_multiple_tasks_with_execution_date(self, session, task_instance):
- with pytest.deprecated_call():
- stored_xcoms = XCom.get_many(
- execution_date=task_instance.execution_date,
- key="xcom_1",
- dag_ids=task_instance.dag_id,
- task_ids=["task_1", "task_2"],
- session=session,
- )
- sorted_values = [x.value for x in sorted(stored_xcoms, key=operator.attrgetter("task_id"))]
- assert sorted_values == [{"key1": "value1"}, {"key2": "value2"}]
-
@pytest.fixture
def tis_for_xcom_get_many_from_prior_dates(self, task_instance_factory, push_simple_json_xcom):
date1 = timezone.datetime(2021, 12, 3, 4, 56)
@@ -488,27 +401,6 @@ class TestXComGet:
assert [x.value for x in stored_xcoms] == [{"key2": "value2"}, {"key1": "value1"}]
assert [x.execution_date for x in stored_xcoms] == [ti2.execution_date, ti1.execution_date]
- @pytest.mark.skip_if_database_isolation_mode
- def test_xcom_get_many_from_prior_dates_with_execution_date(
- self,
- session,
- tis_for_xcom_get_many_from_prior_dates,
- ):
- ti1, ti2 = tis_for_xcom_get_many_from_prior_dates
- with pytest.deprecated_call():
- stored_xcoms = XCom.get_many(
- execution_date=ti2.execution_date,
- key="xcom_1",
- dag_ids="dag",
- task_ids="task_1",
- include_prior_dates=True,
- session=session,
- )
-
- # The retrieved XComs should be ordered by logical date, latest first.
- assert [x.value for x in stored_xcoms] == [{"key2": "value2"}, {"key1": "value1"}]
- assert [x.execution_date for x in stored_xcoms] == [ti2.execution_date, ti1.execution_date]
-
@pytest.mark.usefixtures("setup_xcom_pickling")
class TestXComSet:
@@ -528,24 +420,6 @@ class TestXComSet:
assert stored_xcoms[0].task_id == "task_1"
assert stored_xcoms[0].execution_date == task_instance.execution_date
- @pytest.mark.skip_if_database_isolation_mode
- def test_xcom_set_with_execution_date(self, session, task_instance):
- with pytest.deprecated_call():
- XCom.set(
- key="xcom_1",
- value={"key": "value"},
- dag_id=task_instance.dag_id,
- task_id=task_instance.task_id,
- execution_date=task_instance.execution_date,
- session=session,
- )
- stored_xcoms = session.query(XCom).all()
- assert stored_xcoms[0].key == "xcom_1"
- assert stored_xcoms[0].value == {"key": "value"}
- assert stored_xcoms[0].dag_id == "dag"
- assert stored_xcoms[0].task_id == "task_1"
- assert stored_xcoms[0].execution_date == task_instance.execution_date
-
@pytest.fixture
def setup_for_xcom_set_again_replace(self, task_instance, push_simple_json_xcom):
push_simple_json_xcom(ti=task_instance, key="xcom_1", value={"key1": "value1"})
@@ -563,21 +437,6 @@ class TestXComSet:
)
assert session.query(XCom).one().value == {"key2": "value2"}
- @pytest.mark.skip_if_database_isolation_mode
- @pytest.mark.usefixtures("setup_for_xcom_set_again_replace")
- def test_xcom_set_again_replace_with_execution_date(self, session, task_instance):
- assert session.query(XCom).one().value == {"key1": "value1"}
- with pytest.deprecated_call():
- XCom.set(
- key="xcom_1",
- value={"key2": "value2"},
- dag_id=task_instance.dag_id,
- task_id="task_1",
- execution_date=task_instance.execution_date,
- session=session,
- )
- assert session.query(XCom).one().value == {"key2": "value2"}
-
@pytest.mark.usefixtures("setup_xcom_pickling")
class TestXComClear:
@@ -598,19 +457,6 @@ class TestXComClear:
assert session.query(XCom).count() == 0
assert mock_purge.call_count == 0 if is_db_isolation_mode() else 1
- @pytest.mark.skip_if_database_isolation_mode
- @pytest.mark.usefixtures("setup_for_xcom_clear")
- def test_xcom_clear_with_execution_date(self, session, task_instance):
- assert session.query(XCom).count() == 1
- with pytest.deprecated_call():
- XCom.clear(
- dag_id=task_instance.dag_id,
- task_id=task_instance.task_id,
- execution_date=task_instance.execution_date,
- session=session,
- )
- assert session.query(XCom).count() == 0
-
@pytest.mark.usefixtures("setup_for_xcom_clear")
def test_xcom_clear_different_run(self, session, task_instance):
XCom.clear(
@@ -620,15 +466,3 @@ class TestXComClear:
session=session,
)
assert session.query(XCom).count() == 1
-
- @pytest.mark.skip_if_database_isolation_mode
- @pytest.mark.usefixtures("setup_for_xcom_clear")
- def test_xcom_clear_different_execution_date(self, session, task_instance):
- with pytest.deprecated_call():
- XCom.clear(
- dag_id=task_instance.dag_id,
- task_id=task_instance.task_id,
- execution_date=timezone.utcnow(),
- session=session,
- )
- assert session.query(XCom).count() == 1
diff --git a/tests/providers/amazon/aws/links/test_base_aws.py b/tests/providers/amazon/aws/links/test_base_aws.py
index 446d584edf..1afcfea0a8 100644
--- a/tests/providers/amazon/aws/links/test_base_aws.py
+++ b/tests/providers/amazon/aws/links/test_base_aws.py
@@ -25,6 +25,7 @@ import pytest
from airflow.models.xcom import XCom
from airflow.providers.amazon.aws.links.base_aws import BaseAwsLink
from airflow.serialization.serialized_objects import SerializedDAG
+from tests.test_utils.compat import AIRFLOW_V_3_0_PLUS
from tests.test_utils.mock_operators import MockOperator
if TYPE_CHECKING:
@@ -75,11 +76,13 @@ class TestBaseAwsLink:
)
ti = mock_context["ti"]
- ti.xcom_push.assert_called_once_with(
- execution_date=None,
- key=XCOM_KEY,
- value=expected_value,
- )
+ if AIRFLOW_V_3_0_PLUS:
+ ti.xcom_push.assert_called_once_with(
+ key=XCOM_KEY,
+ value=expected_value,
+ )
+ else:
+ ti.xcom_push.assert_called_once_with(key=XCOM_KEY, value=expected_value, execution_date=None)
def test_disable_xcom_push(self):
mock_context = mock.MagicMock()
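The provider-test changes here and below all repeat one gating pattern: expect the bare two-argument call on Airflow 3+, and keep expecting execution_date=None on older versions. A condensed sketch of that pattern as a reusable helper (the helper itself is illustrative, not part of this commit):

    # Illustrative helper capturing the version-gated assertion used in these tests.
    from unittest import mock

    from tests.test_utils.compat import AIRFLOW_V_3_0_PLUS

    def assert_xcom_pushed(ti: mock.MagicMock, key, value) -> None:
        if AIRFLOW_V_3_0_PLUS:
            ti.xcom_push.assert_called_once_with(key=key, value=value)
        else:
            # Pre-3.0 code paths still forward execution_date=None.
            ti.xcom_push.assert_called_once_with(key=key, value=value, execution_date=None)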
diff --git a/tests/providers/google/cloud/operators/test_bigquery_dts.py b/tests/providers/google/cloud/operators/test_bigquery_dts.py
index b3145151d3..f44479bbce 100644
--- a/tests/providers/google/cloud/operators/test_bigquery_dts.py
+++ b/tests/providers/google/cloud/operators/test_bigquery_dts.py
@@ -27,6 +27,7 @@ from airflow.providers.google.cloud.operators.bigquery_dts import (
BigQueryDataTransferServiceStartTransferRunsOperator,
BigQueryDeleteDataTransferConfigOperator,
)
+from tests.test_utils.compat import AIRFLOW_V_3_0_PLUS
PROJECT_ID = "id"
@@ -71,7 +72,10 @@ class TestBigQueryCreateDataTransferOperator:
retry=DEFAULT,
timeout=None,
)
- ti.xcom_push.assert_called_with(execution_date=None, key="transfer_config_id", value="1a2b3c")
+ if AIRFLOW_V_3_0_PLUS:
+ ti.xcom_push.assert_called_with(key="transfer_config_id", value="1a2b3c")
+ else:
+ ti.xcom_push.assert_called_with(key="transfer_config_id", value="1a2b3c", execution_date=None)
assert "secret_access_key" not in return_value.get("params", {})
assert "access_key_id" not in return_value.get("params", {})
@@ -126,7 +130,10 @@ class TestBigQueryDataTransferServiceStartTransferRunsOperator:
retry=DEFAULT,
timeout=None,
)
- ti.xcom_push.assert_called_with(execution_date=None, key="run_id",
value="123")
+ if AIRFLOW_V_3_0_PLUS:
+ ti.xcom_push.assert_called_with(key="run_id", value="123")
+ else:
+ ti.xcom_push.assert_called_with(key="run_id", value="123",
execution_date=None)
@mock.patch(
f"{OPERATOR_MODULE_PATH}.BiqQueryDataTransferServiceHook",
diff --git a/tests/providers/google/cloud/operators/test_dataproc.py b/tests/providers/google/cloud/operators/test_dataproc.py
index 1d1f2a1ef8..58b38125ee 100644
--- a/tests/providers/google/cloud/operators/test_dataproc.py
+++ b/tests/providers/google/cloud/operators/test_dataproc.py
@@ -79,7 +79,7 @@ from airflow.providers.google.cloud.triggers.dataproc import (
from airflow.providers.google.common.consts import GOOGLE_DEFAULT_DEFERRABLE_METHOD_NAME
from airflow.serialization.serialized_objects import SerializedDAG
from airflow.utils.timezone import datetime
-from tests.test_utils.compat import AIRFLOW_VERSION
+from tests.test_utils.compat import AIRFLOW_V_3_0_PLUS, AIRFLOW_VERSION
from tests.test_utils.db import clear_db_runs, clear_db_xcom
AIRFLOW_VERSION_LABEL = "v" + str(AIRFLOW_VERSION).replace(".", "-").replace("+", "-")
@@ -440,19 +440,32 @@ class DataprocTestBase:
class DataprocJobTestBase(DataprocTestBase):
@classmethod
def setup_class(cls):
- cls.extra_links_expected_calls = [
- call.ti.xcom_push(execution_date=None, key="conf", value=DATAPROC_JOB_CONF_EXPECTED),
- call.hook().wait_for_job(job_id=TEST_JOB_ID, region=GCP_REGION, project_id=GCP_PROJECT),
- ]
+ if AIRFLOW_V_3_0_PLUS:
+ cls.extra_links_expected_calls = [
+ call.ti.xcom_push(key="conf", value=DATAPROC_JOB_CONF_EXPECTED),
+ call.hook().wait_for_job(job_id=TEST_JOB_ID, region=GCP_REGION, project_id=GCP_PROJECT),
+ ]
+ else:
+ cls.extra_links_expected_calls = [
+ call.ti.xcom_push(key="conf", value=DATAPROC_JOB_CONF_EXPECTED, execution_date=None),
+ call.hook().wait_for_job(job_id=TEST_JOB_ID, region=GCP_REGION, project_id=GCP_PROJECT),
+ ]
class DataprocClusterTestBase(DataprocTestBase):
@classmethod
def setup_class(cls):
super().setup_class()
- cls.extra_links_expected_calls_base = [
- call.ti.xcom_push(execution_date=None, key="dataproc_cluster",
value=DATAPROC_CLUSTER_EXPECTED)
- ]
+ if AIRFLOW_V_3_0_PLUS:
+ cls.extra_links_expected_calls_base = [
+ call.ti.xcom_push(key="dataproc_cluster",
value=DATAPROC_CLUSTER_EXPECTED)
+ ]
+ else:
+ cls.extra_links_expected_calls_base = [
+ call.ti.xcom_push(
+ key="dataproc_cluster", value=DATAPROC_CLUSTER_EXPECTED,
execution_date=None
+ )
+ ]
class TestsClusterGenerator:
@@ -758,11 +771,17 @@ class TestDataprocCreateClusterOperator(DataprocClusterTestBase):
self.extra_links_manager_mock.assert_has_calls(expected_calls, any_order=False)
to_dict_mock.assert_called_once_with(mock_hook().wait_for_operation())
- self.mock_ti.xcom_push.assert_called_once_with(
- key="dataproc_cluster",
- value=DATAPROC_CLUSTER_EXPECTED,
- execution_date=None,
- )
+ if AIRFLOW_V_3_0_PLUS:
+ self.mock_ti.xcom_push.assert_called_once_with(
+ key="dataproc_cluster",
+ value=DATAPROC_CLUSTER_EXPECTED,
+ )
+ else:
+ self.mock_ti.xcom_push.assert_called_once_with(
+ key="dataproc_cluster",
+ value=DATAPROC_CLUSTER_EXPECTED,
+ execution_date=None,
+ )
@mock.patch(DATAPROC_PATH.format("Cluster.to_dict"))
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
@@ -808,11 +827,17 @@ class TestDataprocCreateClusterOperator(DataprocClusterTestBase):
self.extra_links_manager_mock.assert_has_calls(expected_calls, any_order=False)
to_dict_mock.assert_called_once_with(mock_hook().wait_for_operation())
- self.mock_ti.xcom_push.assert_called_once_with(
- key="dataproc_cluster",
- value=DATAPROC_CLUSTER_EXPECTED,
- execution_date=None,
- )
+ if AIRFLOW_V_3_0_PLUS:
+ self.mock_ti.xcom_push.assert_called_once_with(
+ key="dataproc_cluster",
+ value=DATAPROC_CLUSTER_EXPECTED,
+ )
+ else:
+ self.mock_ti.xcom_push.assert_called_once_with(
+ key="dataproc_cluster",
+ value=DATAPROC_CLUSTER_EXPECTED,
+ execution_date=None,
+ )
@mock.patch(DATAPROC_PATH.format("Cluster.to_dict"))
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
@@ -1095,9 +1120,14 @@ class TestDataprocClusterScaleOperator(DataprocClusterTestBase):
@classmethod
def setup_class(cls):
super().setup_class()
- cls.extra_links_expected_calls_base = [
- call.ti.xcom_push(execution_date=None, key="conf",
value=DATAPROC_CLUSTER_CONF_EXPECTED)
- ]
+ if AIRFLOW_V_3_0_PLUS:
+ cls.extra_links_expected_calls_base = [
+ call.ti.xcom_push(key="conf",
value=DATAPROC_CLUSTER_CONF_EXPECTED)
+ ]
+ else:
+ cls.extra_links_expected_calls_base = [
+ call.ti.xcom_push(key="conf",
value=DATAPROC_CLUSTER_CONF_EXPECTED, execution_date=None)
+ ]
def test_deprecation_warning(self):
with pytest.warns(AirflowProviderDeprecationWarning) as warnings:
@@ -1142,11 +1172,17 @@ class TestDataprocClusterScaleOperator(DataprocClusterTestBase):
# Test whether xcom push occurs before cluster is updated
self.extra_links_manager_mock.assert_has_calls(expected_calls, any_order=False)
- self.mock_ti.xcom_push.assert_called_once_with(
- key="conf",
- value=DATAPROC_CLUSTER_CONF_EXPECTED,
- execution_date=None,
- )
+ if AIRFLOW_V_3_0_PLUS:
+ self.mock_ti.xcom_push.assert_called_once_with(
+ key="conf",
+ value=DATAPROC_CLUSTER_CONF_EXPECTED,
+ )
+ else:
+ self.mock_ti.xcom_push.assert_called_once_with(
+ key="conf",
+ value=DATAPROC_CLUSTER_CONF_EXPECTED,
+ execution_date=None,
+ )
@pytest.mark.db_test
@@ -1310,9 +1346,12 @@ class TestDataprocClusterDeleteOperator:
class TestDataprocSubmitJobOperator(DataprocJobTestBase):
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
def test_execute(self, mock_hook):
- xcom_push_call = call.ti.xcom_push(
- execution_date=None, key="dataproc_job",
value=DATAPROC_JOB_EXPECTED
- )
+ if AIRFLOW_V_3_0_PLUS:
+ xcom_push_call = call.ti.xcom_push(key="dataproc_job",
value=DATAPROC_JOB_EXPECTED)
+ else:
+ xcom_push_call = call.ti.xcom_push(
+ key="dataproc_job", value=DATAPROC_JOB_EXPECTED,
execution_date=None
+ )
wait_for_job_call = call.hook().wait_for_job(
job_id=TEST_JOB_ID, region=GCP_REGION, project_id=GCP_PROJECT, timeout=None
)
@@ -1358,9 +1397,12 @@ class TestDataprocSubmitJobOperator(DataprocJobTestBase):
job_id=TEST_JOB_ID, project_id=GCP_PROJECT, region=GCP_REGION, timeout=None
)
- self.mock_ti.xcom_push.assert_called_once_with(
- key="dataproc_job", value=DATAPROC_JOB_EXPECTED,
execution_date=None
- )
+ if AIRFLOW_V_3_0_PLUS:
+ self.mock_ti.xcom_push.assert_called_once_with(key="dataproc_job",
value=DATAPROC_JOB_EXPECTED)
+ else:
+ self.mock_ti.xcom_push.assert_called_once_with(
+ key="dataproc_job", value=DATAPROC_JOB_EXPECTED,
execution_date=None
+ )
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
def test_execute_async(self, mock_hook):
@@ -1398,9 +1440,12 @@ class TestDataprocSubmitJobOperator(DataprocJobTestBase):
)
mock_hook.return_value.wait_for_job.assert_not_called()
- self.mock_ti.xcom_push.assert_called_once_with(
- key="dataproc_job", value=DATAPROC_JOB_EXPECTED,
execution_date=None
- )
+ if AIRFLOW_V_3_0_PLUS:
+ self.mock_ti.xcom_push.assert_called_once_with(key="dataproc_job",
value=DATAPROC_JOB_EXPECTED)
+ else:
+ self.mock_ti.xcom_push.assert_called_once_with(
+ key="dataproc_job", value=DATAPROC_JOB_EXPECTED,
execution_date=None
+ )
@mock.patch(DATAPROC_PATH.format("DataprocHook"))
@mock.patch(DATAPROC_TRIGGERS_PATH.format("DataprocAsyncHook"))
@@ -1633,11 +1678,17 @@ class TestDataprocUpdateClusterOperator(DataprocClusterTestBase):
# Test whether the xcom push happens before updating the cluster
self.extra_links_manager_mock.assert_has_calls(expected_calls, any_order=False)
- self.mock_ti.xcom_push.assert_called_once_with(
- key="dataproc_cluster",
- value=DATAPROC_CLUSTER_EXPECTED,
- execution_date=None,
- )
+ if AIRFLOW_V_3_0_PLUS:
+ self.mock_ti.xcom_push.assert_called_once_with(
+ key="dataproc_cluster",
+ value=DATAPROC_CLUSTER_EXPECTED,
+ )
+ else:
+ self.mock_ti.xcom_push.assert_called_once_with(
+ key="dataproc_cluster",
+ value=DATAPROC_CLUSTER_EXPECTED,
+ execution_date=None,
+ )
def test_missing_region_parameter(self):
with pytest.raises(AirflowException):
@@ -2399,10 +2450,16 @@ class TestDataProcSparkSqlOperator:
class TestDataProcSparkOperator(DataprocJobTestBase):
@classmethod
def setup_class(cls):
- cls.extra_links_expected_calls = [
- call.ti.xcom_push(execution_date=None, key="conf",
value=DATAPROC_JOB_CONF_EXPECTED),
- call.hook().wait_for_job(job_id=TEST_JOB_ID, region=GCP_REGION,
project_id=GCP_PROJECT),
- ]
+ if AIRFLOW_V_3_0_PLUS:
+ cls.extra_links_expected_calls = [
+ call.ti.xcom_push(key="conf",
value=DATAPROC_JOB_CONF_EXPECTED),
+ call.hook().wait_for_job(job_id=TEST_JOB_ID,
region=GCP_REGION, project_id=GCP_PROJECT),
+ ]
+ else:
+ cls.extra_links_expected_calls = [
+ call.ti.xcom_push(key="conf",
value=DATAPROC_JOB_CONF_EXPECTED, execution_date=None),
+ call.hook().wait_for_job(job_id=TEST_JOB_ID,
region=GCP_REGION, project_id=GCP_PROJECT),
+ ]
main_class = "org.apache.spark.examples.SparkPi"
jars = ["file:///usr/lib/spark/examples/jars/spark-examples.jar"]
@@ -2446,9 +2503,12 @@ class TestDataProcSparkOperator(DataprocJobTestBase):
assert self.job == job
op.execute(context=self.mock_context)
- self.mock_ti.xcom_push.assert_called_once_with(
- key="conf", value=DATAPROC_JOB_CONF_EXPECTED, execution_date=None
- )
+ if AIRFLOW_V_3_0_PLUS:
+ self.mock_ti.xcom_push.assert_called_once_with(key="conf",
value=DATAPROC_JOB_CONF_EXPECTED)
+ else:
+ self.mock_ti.xcom_push.assert_called_once_with(
+ key="conf", value=DATAPROC_JOB_CONF_EXPECTED,
execution_date=None
+ )
# Test whether xcom push occurs before polling for job
self.extra_links_manager_mock.assert_has_calls(self.extra_links_expected_calls, any_order=False)
diff --git a/tests/providers/microsoft/conftest.py b/tests/providers/microsoft/conftest.py
index c77dd7747d..8a25873529 100644
--- a/tests/providers/microsoft/conftest.py
+++ b/tests/providers/microsoft/conftest.py
@@ -111,8 +111,6 @@ def mock_response(status_code, content: Any = None, headers: dict | None = None)
def mock_context(task) -> Context:
- from datetime import datetime
-
from airflow.models import TaskInstance
from airflow.utils.session import NEW_SESSION
from airflow.utils.state import TaskInstanceState
@@ -146,13 +144,7 @@ def mock_context(task) -> Context:
return values.get(f"{task_ids or self.task_id}_{dag_id or
self.dag_id}_{key}_{map_indexes}")
return values.get(f"{task_ids or self.task_id}_{dag_id or
self.dag_id}_{key}")
- def xcom_push(
- self,
- key: str,
- value: Any,
- execution_date: datetime | None = None,
- session: Session = NEW_SESSION,
- ) -> None:
+ def xcom_push(self, key: str, value: Any, session: Session = NEW_SESSION, **kwargs) -> None:
values[f"{self.task_id}_{self.dag_id}_{key}_{self.map_index}"] = value
values["ti"] = MockedTaskInstance(task=task)
diff --git a/tests/providers/yandex/links/test_yq.py b/tests/providers/yandex/links/test_yq.py
index 06f1e83939..d46862f1c7 100644
--- a/tests/providers/yandex/links/test_yq.py
+++ b/tests/providers/yandex/links/test_yq.py
@@ -23,6 +23,7 @@ import pytest
from airflow.models.taskinstance import TaskInstance
from airflow.models.xcom import XCom
from airflow.providers.yandex.links.yq import YQLink
+from tests.test_utils.compat import AIRFLOW_V_3_0_PLUS
from tests.test_utils.mock_operators import MockOperator
yandexcloud = pytest.importorskip("yandexcloud")
@@ -34,11 +35,13 @@ def test_persist():
YQLink.persist(context=mock_context, task_instance=MockOperator(task_id="test_task_id"), web_link="g.com")
ti = mock_context["ti"]
- ti.xcom_push.assert_called_once_with(
- execution_date=None,
- key="web_link",
- value="g.com",
- )
+ if AIRFLOW_V_3_0_PLUS:
+ ti.xcom_push.assert_called_once_with(
+ key="web_link",
+ value="g.com",
+ )
+ else:
+ ti.xcom_push.assert_called_once_with(key="web_link", value="g.com",
execution_date=None)
def test_default_link():
diff --git a/tests/providers/yandex/operators/test_yq.py b/tests/providers/yandex/operators/test_yq.py
index e342a2f961..034f050551 100644
--- a/tests/providers/yandex/operators/test_yq.py
+++ b/tests/providers/yandex/operators/test_yq.py
@@ -22,6 +22,8 @@ from unittest.mock import MagicMock, call, patch
import pytest
+from tests.test_utils.compat import AIRFLOW_V_3_0_PLUS
+
yandexcloud = pytest.importorskip("yandexcloud")
import responses
@@ -89,15 +91,25 @@ class TestYQExecuteQueryOperator:
results = operator.execute(context)
assert results == {"rows": [[777]], "columns": [{"name": "column0",
"type": "Int32"}]}
- context["ti"].xcom_push.assert_has_calls(
- [
- call(
- key="web_link",
- value=f"https://yq.cloud.yandex.ru/folders/{FOLDER_ID}/ide/queries/query1",
- execution_date=None,
- ),
- ]
- )
+ if AIRFLOW_V_3_0_PLUS:
+ context["ti"].xcom_push.assert_has_calls(
+ [
+ call(
+ key="web_link",
+ value=f"https://yq.cloud.yandex.ru/folders/{FOLDER_ID}/ide/queries/query1",
+ ),
+ ]
+ )
+ else:
+ context["ti"].xcom_push.assert_has_calls(
+ [
+ call(
+ key="web_link",
+ value=f"https://yq.cloud.yandex.ru/folders/{FOLDER_ID}/ide/queries/query1",
+ execution_date=None,
+ ),
+ ]
+ )
responses.get(
"https://api.yandex-query.cloud.yandex.net/api/fq/v1/queries/query1/status",
diff --git a/tests/test_utils/compat.py b/tests/test_utils/compat.py
index b09973903b..5daf429cf6 100644
--- a/tests/test_utils/compat.py
+++ b/tests/test_utils/compat.py
@@ -46,6 +46,7 @@ AIRFLOW_V_2_7_PLUS = Version(AIRFLOW_VERSION.base_version) >= Version("2.7.0")
AIRFLOW_V_2_8_PLUS = Version(AIRFLOW_VERSION.base_version) >= Version("2.8.0")
AIRFLOW_V_2_9_PLUS = Version(AIRFLOW_VERSION.base_version) >= Version("2.9.0")
AIRFLOW_V_2_10_PLUS = Version(AIRFLOW_VERSION.base_version) >= Version("2.10.0")
+AIRFLOW_V_3_0_PLUS = Version(AIRFLOW_VERSION.base_version) >= Version("3.0.0")
try:
from airflow.models.baseoperatorlink import BaseOperatorLink
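The new constant follows the existing base_version comparisons, which is what lets pre-release builds on main (e.g. 3.0.0.dev0) satisfy AIRFLOW_V_3_0_PLUS. A quick demonstration, assuming the packaging library this module already relies on:

    # base_version strips pre-release/local suffixes before comparison.
    from packaging.version import Version

    for raw in ("3.0.0.dev0", "3.0.0b1", "2.10.1"):
        print(raw, Version(Version(raw).base_version) >= Version("3.0.0"))
    # -> 3.0.0.dev0 True / 3.0.0b1 True / 2.10.1 False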