anishgirianish commented on code in PR #62343:
URL: https://github.com/apache/airflow/pull/62343#discussion_r3177844042
##########
airflow-core/src/airflow/jobs/scheduler_job_runner.py:
##########
@@ -3189,6 +3208,88 @@ def _activate_assets_generate_warnings() -> Iterator[tuple[str, str]]:
session.add(warning)
existing_warned_dag_ids.add(warning.dag_id)
+ def _enqueue_connection_tests(self, *, session: Session) -> None:
+ """Enqueue pending connection tests to executors that support them."""
+ max_concurrency = conf.getint("scheduler", "max_connection_test_concurrency", fallback=4)
+ timeout = conf.getint("scheduler", "connection_test_timeout", fallback=60)
+
+ num_occupied_slots = sum(executor.slots_occupied for executor in self.executors)
+ parallelism_budget = conf.getint("core", "parallelism") - num_occupied_slots
+ if parallelism_budget <= 0:
+ return
+
+ active_count = session.scalar(
+ select(func.count(ConnectionTestRequest.id)).where(
+ ConnectionTestRequest.state.in_(DISPATCHED_STATES)
+ )
+ )
+ concurrency_budget = max_concurrency - (active_count or 0)
+ budget = min(concurrency_budget, parallelism_budget)
+ if budget <= 0:
+ return
+
+ pending_stmt = (
+ select(ConnectionTestRequest)
+ .where(ConnectionTestRequest.state == ConnectionTestState.PENDING)
+ .order_by(ConnectionTestRequest.created_at)
+ .limit(budget)
+ )
+ pending_stmt = with_row_locks(pending_stmt, session, of=ConnectionTestRequest, skip_locked=True)
+ pending_tests = session.scalars(pending_stmt).all()
+
+ if not pending_tests:
+ return
+
+ for ct in pending_tests:
+ executor = self._try_to_load_executor(ct, session)
+ if executor is None:
+ reason = f"No executor matches '{ct.executor}'"
+ ct.state = ConnectionTestState.FAILED
+ ct.result_message = reason
+ self.log.warning("Failing connection test %s: %s", ct.id, reason)
+ continue
+ if not executor.supports_connection_test:
+ exec_name = executor.name
+ name = ct.executor or (exec_name and (exec_name.alias or exec_name.module_path))
+ reason = f"Executor '{name}' does not support connection testing"
+ ct.state = ConnectionTestState.FAILED
+ ct.result_message = reason
+ self.log.warning("Failing connection test %s: %s", ct.id, reason)
+ continue
+
+ workload = workloads.TestConnection.make(
+ connection_test_id=ct.id,
+ connection_id=ct.connection_id,
+ timeout=timeout,
+ queue=ct.queue,
+ generator=executor.jwt_generator,
+ )
+ executor.queue_workload(workload, session=session)
+ ct.state = ConnectionTestState.QUEUED
+
+ session.flush()
+
+ @provide_session
+ def _reap_stale_connection_tests(self, *, session: Session = NEW_SESSION) -> None:
+ """Mark connection tests that have exceeded their timeout as FAILED."""
+ timeout = conf.getint("scheduler", "connection_test_timeout", fallback=60)
+ grace_period = max(30, timeout // 2)
+ cutoff = timezone.utcnow() - timedelta(seconds=timeout + grace_period)
+
+ stale_stmt = select(ConnectionTestRequest).where(
+ ConnectionTestRequest.state.in_(ACTIVE_STATES),
+ ConnectionTestRequest.updated_at < cutoff,
+ )
+ stale_stmt = with_row_locks(stale_stmt, session, of=ConnectionTestRequest, skip_locked=True)
+ stale_tests = session.scalars(stale_stmt).all()
+
+ for ct in stale_tests:
+ ct.state = ConnectionTestState.FAILED
+ ct.result_message = f"Connection test timed out (exceeded {timeout}s + {grace_period}s grace)"
+ self.log.warning("Reaped stale connection test %s", ct.id)
Review Comment:
Done: added `BaseExecutor.fail_connection_test(key)`; the reaper now calls it on every executor that supports connection testing, after marking the test FAILED. Thank you!
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]