kaxil commented on code in PR #52234:
URL: https://github.com/apache/airflow/pull/52234#discussion_r2186076518


##########
task-sdk/src/airflow/sdk/bases/operator.py:
##########
@@ -1607,6 +1615,216 @@ def resume_execution(self, next_method: str, next_kwargs: dict[str, Any] | None,
         execute_callable = getattr(self, next_method)
         return execute_callable(context, **next_kwargs)
 
+    def dry_run(self) -> None:
+        """Perform dry run for the operator - just render template fields."""
+        self.log.info("Dry run")
+        for f in self.template_fields:
+            try:
+                content = getattr(self, f)
+            except AttributeError:
+                raise AttributeError(
+                    f"{f!r} is configured as a template field "
+                    f"but {self.task_type} does not have this attribute."
+                )
+
+            if content and isinstance(content, str):
+                self.log.info("Rendering template for %s", f)
+                self.log.info(content)
+
+    # TODO (GH-52141): Either port this, or somehow fix the tests to remove this from the sdk.
+    @staticmethod
+    def xcom_push(
+        context: Any,
+        key: str,
+        value: Any,
+    ) -> None:
+        """
+        Make an XCom available for tasks to pull.
+
+        :param context: Execution Context Dictionary
+        :param key: A key for the XCom
+        :param value: A value for the XCom. The value is pickled and stored
+            in the database.
+        """
+        context["ti"].xcom_push(key=key, value=value)
+
+    # TODO (GH-52141): Either port this, or somehow fix the tests to remove this from the sdk.
+    @staticmethod
+    def xcom_pull(
+        context: Any,
+        task_ids: str | list[str] | None = None,
+        dag_id: str | None = None,
+        key: str = "return_value",
+        include_prior_dates: bool | None = None,
+        session=None,
+    ) -> Any:
+        """
+        Pull XComs that optionally meet certain criteria.
+
+        The default value for `key` limits the search to XComs
+        that were returned by other tasks (as opposed to those that were pushed
+        manually). To remove this filter, pass key=None (or any desired value).
+
+        If a single task_id string is provided, the result is the value of the
+        most recent matching XCom from that task_id. If multiple task_ids are
+        provided, a tuple of matching values is returned. None is returned
+        whenever no matches are found.
+
+        :param context: Execution Context Dictionary
+        :param key: A key for the XCom. If provided, only XComs with matching
+            keys will be returned. The default key is 'return_value', also
+            available as a constant XCOM_RETURN_KEY. This key is automatically
+            given to XComs returned by tasks (as opposed to being pushed
+            manually). To remove the filter, pass key=None.
+        :param task_ids: Only XComs from tasks with matching ids will be
+            pulled. Can pass None to remove the filter.
+        :param dag_id: If provided, only pulls XComs from this DAG.
+            If None (default), the DAG of the calling task is used.
+        :param include_prior_dates: If False, only XComs from the current
+            logical_date are returned. If True, XComs from previous dates
+            are returned as well.
+        """
+        from airflow.settings import Session
+
+        if session is None:
+            session = Session()
+        return context["ti"].xcom_pull(
+            key=key,
+            task_ids=task_ids,
+            dag_id=dag_id,
+            include_prior_dates=include_prior_dates,
+            session=session,
+        )
+
+    # TODO (GH-52141): Either port this, or somehow fix the tests to remove this from the sdk.
+    def run(
+        self,
+        start_date: datetime | None = None,
+        end_date: datetime | None = None,
+        ignore_first_depends_on_past: bool = True,
+        wait_for_past_depends_before_skipping: bool = False,
+        ignore_ti_state: bool = False,
+        mark_success: bool = False,
+        test_mode: bool = False,
+        session=None,
+    ) -> None:
+        """Run a set of task instances for a date range."""
+        import pendulum
+        from sqlalchemy import select
+        from sqlalchemy.exc import NoResultFound
+
+        from airflow.models.dagrun import DagRun
+        from airflow.models.taskinstance import TaskInstance
+        from airflow.settings import Session
+        from airflow.utils.state import DagRunState
+        from airflow.utils.types import DagRunTriggeredByType, DagRunType
+
+        if session is None:
+            session = Session()
+
+        # Assertions for typing -- we need a dag, for this function, and when we have a DAG we are
+        # _guaranteed_ to have start_date (else we couldn't have been added to a DAG)
+        if TYPE_CHECKING:
+            from airflow.models.dag import DAG as SchedulerDAG
+
+            assert session
+            assert self.start_date
+
+            # TODO: Task-SDK: We need to set this to the scheduler DAG until we fully separate scheduling and
+            # definition code
+            assert isinstance(self.dag, SchedulerDAG)
+
+        start_date = pendulum.instance(start_date or self.start_date)
+        end_date = pendulum.instance(end_date or self.end_date or timezone.utcnow())
+
+        for info in self.dag.iter_dagrun_infos_between(start_date, end_date, align=False):
+            ignore_depends_on_past = info.logical_date == start_date and ignore_first_depends_on_past
+            try:
+                dag_run = session.scalars(
+                    select(DagRun).where(
+                        DagRun.dag_id == self.dag_id,
+                        DagRun.logical_date == info.logical_date,
+                    )
+                ).one()
+                ti = TaskInstance(self, run_id=dag_run.run_id)
+            except NoResultFound:
+                # This is _mostly_ only used in tests
+                dr = DagRun(
+                    dag_id=self.dag_id,
+                    run_id=DagRun.generate_run_id(
+                        run_type=DagRunType.MANUAL,
+                        logical_date=info.logical_date,
+                        run_after=info.run_after,
+                    ),
+                    run_type=DagRunType.MANUAL,
+                    logical_date=info.logical_date,
+                    data_interval=info.data_interval,
+                    run_after=info.run_after,
+                    triggered_by=DagRunTriggeredByType.TEST,
+                    state=DagRunState.RUNNING,
+                )
+                ti = TaskInstance(self, run_id=dr.run_id)
+                ti.dag_run = dr
+                session.add(dr)
+                session.flush()
+
+            ti.run(
+                mark_success=mark_success,
+                ignore_depends_on_past=ignore_depends_on_past,
+                wait_for_past_depends_before_skipping=wait_for_past_depends_before_skipping,
+                ignore_ti_state=ignore_ti_state,
+                test_mode=test_mode,
+                session=session,
+            )
+
+    # TODO (GH-52141): Either port this, or somehow fix the tests to remove this from the sdk.
+    def clear(

Review Comment:
   Ok, the tests have been fixed as part of https://github.com/apache/airflow/pull/52888
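
For anyone following the discussion, here is a minimal, illustrative sketch of how the static `xcom_push` / `xcom_pull` helpers above are typically called from a custom operator's `execute` method. The operator names, task id, key, and value are made up for illustration, the `airflow.sdk` import path is assumed for the Airflow 3 task SDK, and this is not part of the diff:

```python
from airflow.sdk import BaseOperator  # import path assumed for the Airflow 3 task SDK


class ProducerOperator(BaseOperator):
    """Hypothetical operator that pushes a value for downstream tasks."""

    def execute(self, context):
        # Delegates to context["ti"].xcom_push(key=..., value=...), as in the diff above.
        BaseOperator.xcom_push(context, key="my_key", value="some_value")


class ConsumerOperator(BaseOperator):
    """Hypothetical operator that pulls the value pushed by `producer_task`."""

    def execute(self, context):
        # The static xcom_pull helper wraps this same call, adding a database
        # session argument; pulling via the task instance directly avoids that
        # session handling.
        value = context["ti"].xcom_pull(task_ids="producer_task", key="my_key")
        self.log.info("Pulled value: %s", value)
```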


