codeant-ai-for-open-source[bot] commented on code in PR #36368:
URL: https://github.com/apache/superset/pull/36368#discussion_r2733260391


##########
superset/tasks/manager.py:
##########
@@ -0,0 +1,585 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Task manager for the Global Task Framework (GTF)"""
+
+from __future__ import annotations
+
+import logging
+import threading
+from typing import Any, Callable, TYPE_CHECKING
+
+import redis
+from redis.sentinel import Sentinel
+from superset_core.api.tasks import TaskProperties, TaskScope, TaskStatus
+
+from superset.commands.tasks.exceptions import TaskCreateFailedError
+from superset.tasks.utils import generate_random_task_key
+
+if TYPE_CHECKING:
+    from flask import Flask
+
+    from superset.models.tasks import Task
+
+logger = logging.getLogger(__name__)
+
+
+class AbortListener:
+    """
+    Handle for a background abort listener.
+
+    Returned by TaskManager.listen_for_abort() to allow stopping the listener.
+    """
+
+    def __init__(
+        self,
+        task_uuid: str,
+        thread: threading.Thread,
+        stop_event: threading.Event,
+        pubsub: redis.client.PubSub | None = None,
+    ) -> None:
+        self._task_uuid = task_uuid
+        self._thread = thread
+        self._stop_event = stop_event
+        self._pubsub = pubsub
+
+    def stop(self) -> None:
+        """Stop the abort listener."""
+        self._stop_event.set()
+
+        # Close pub/sub subscription if active
+        if self._pubsub is not None:
+            try:
+                self._pubsub.unsubscribe()
+                self._pubsub.close()
+            except Exception as ex:
+                logger.debug("Error closing pub/sub during stop: %s", ex)
+
+        # Wait for thread to finish
+        if self._thread.is_alive():
+            self._thread.join(timeout=2.0)
+
+        logger.debug("Stopped abort listener for task %s", self._task_uuid)
+
+
+class TaskManager:
+    """
+    Handles task creation, scheduling, and abort notifications.
+
+    The TaskManager is responsible for:
+    1. Creating task entries in the metastore (Task model)
+    2. Scheduling task execution via Celery
+    3. Handling deduplication (returning existing active task if duplicate)
+    4. Managing real-time abort notifications (optional)
+
+    Redis pub/sub is opt-in via TASKS_BACKEND configuration. When not configured,
+    tasks fall back to database polling for abort detection.
+    """
+
+    # Class-level Redis state (initialized once via init_app)
+    _redis: redis.Redis[Any] | None = None
+    _channel_prefix: str = "gtf:abort:"
+    _config: dict[str, Any] | None = None
+    _initialized: bool = False
+
+    @classmethod
+    def init_app(cls, app: Flask) -> None:
+        """
+        Initialize the TaskManager with Flask app config.
+
+        Sets up Redis connection for pub/sub abort notifications if configured.
+
+        :param app: Flask application instance
+        """
+        if cls._initialized:
+            return
+
+        cls._config = app.config.get("TASKS_BACKEND")
+        cls._channel_prefix = app.config.get("TASKS_ABORT_CHANNEL_PREFIX", "gtf:abort:")
+
+        if cls._config is None:
+            logger.info(
+                "TASKS_BACKEND not configured, using database polling for abort"
+            )
+            cls._initialized = True
+            return
+
+        cache_type = cls._config.get("CACHE_TYPE")
+
+        if cache_type == "RedisCache":
+            cls._init_redis()
+        elif cache_type == "RedisSentinelCache":
+            cls._init_redis_sentinel()
+        else:
+            logger.warning(
+                "Unsupported TASKS_BACKEND cache type: %s, falling back to polling",
+                cache_type,
+            )
+
+        cls._initialized = True
+
+    @classmethod
+    def _init_redis(cls) -> None:
+        """Initialize standard Redis connection."""
+        if cls._config is None:
+            return
+
+        kwargs: dict[str, Any] = {
+            "host": cls._config.get("CACHE_REDIS_HOST", "localhost"),
+            "port": cls._config.get("CACHE_REDIS_PORT", 6379),
+            "db": cls._config.get("CACHE_REDIS_DB", 0),
+            "password": cls._config.get("CACHE_REDIS_PASSWORD"),
+            "decode_responses": True,
+        }
+
+        # Add username if provided
+        if configured_username := cls._config.get("CACHE_REDIS_USER"):
+            kwargs["username"] = configured_username
+
+        # Add SSL options if configured
+        if cls._config.get("CACHE_REDIS_SSL"):
+            kwargs["ssl"] = True
+            if ssl_certfile := cls._config.get("CACHE_REDIS_SSL_CERTFILE"):
+                kwargs["ssl_certfile"] = ssl_certfile
+            if ssl_keyfile := cls._config.get("CACHE_REDIS_SSL_KEYFILE"):
+                kwargs["ssl_keyfile"] = ssl_keyfile
+            if ssl_cert_reqs := cls._config.get("CACHE_REDIS_SSL_CERT_REQS"):
+                kwargs["ssl_cert_reqs"] = ssl_cert_reqs
+            if ssl_ca_certs := cls._config.get("CACHE_REDIS_SSL_CA_CERTS"):
+                kwargs["ssl_ca_certs"] = ssl_ca_certs
+
+        try:
+            cls._redis = redis.Redis(**kwargs)
+            # Test connection
+            cls._redis.ping()
+            logger.info("Initialized Redis backend for GTF abort pub/sub")
+        except redis.ConnectionError as ex:
+            logger.warning(
+                "Failed to connect to Redis for GTF pub/sub: %s. "
+                "Falling back to database polling.",
+                ex,
+            )
+            cls._redis = None
+
+    @classmethod
+    def _init_redis_sentinel(cls) -> None:
+        """Initialize Redis Sentinel connection."""
+        if cls._config is None:
+            return
+
+        sentinels = cls._config.get("CACHE_REDIS_SENTINELS", [("localhost", 26379)])
+        master_name = cls._config.get("CACHE_REDIS_SENTINEL_MASTER", "mymaster")
+
+        try:
+            sentinel = Sentinel(
+                sentinels,
+                sentinel_kwargs={
+                    "password": cls._config.get("CACHE_REDIS_SENTINEL_PASSWORD"),
+                },
+            )
+
+            # Prepare master connection kwargs
+            master_kwargs: dict[str, Any] = {
+                "password": cls._config.get("CACHE_REDIS_PASSWORD"),
+                "db": cls._config.get("CACHE_REDIS_DB", 0),
+                "decode_responses": True,
+            }
+
+            # Add SSL options if configured
+            if cls._config.get("CACHE_REDIS_SSL"):
+                master_kwargs["ssl"] = True
+                if ssl_certfile := cls._config.get("CACHE_REDIS_SSL_CERTFILE"):
+                    master_kwargs["ssl_certfile"] = ssl_certfile
+                if ssl_keyfile := cls._config.get("CACHE_REDIS_SSL_KEYFILE"):
+                    master_kwargs["ssl_keyfile"] = ssl_keyfile
+                if ssl_cert_reqs := cls._config.get("CACHE_REDIS_SSL_CERT_REQS"):
+                    master_kwargs["ssl_cert_reqs"] = ssl_cert_reqs
+                if ssl_ca_certs := cls._config.get("CACHE_REDIS_SSL_CA_CERTS"):
+                    master_kwargs["ssl_ca_certs"] = ssl_ca_certs
+
+            cls._redis = sentinel.master_for(master_name, **master_kwargs)
+            # Test connection
+            if cls._redis is not None:
+                cls._redis.ping()
+            logger.info("Initialized Redis Sentinel backend for GTF abort pub/sub")
+        except (redis.ConnectionError, redis.sentinel.MasterNotFoundError) as ex:
+            logger.warning(
+                "Failed to connect to Redis Sentinel for GTF pub/sub: %s. "
+                "Falling back to database polling.",
+                ex,
+            )
+            cls._redis = None
+
+    @classmethod
+    def is_pubsub_available(cls) -> bool:
+        """
+        Check if Redis pub/sub backend is configured and available.
+
+        :returns: True if Redis is available for pub/sub, False otherwise
+        """
+        return cls._redis is not None
+
+    @classmethod
+    def get_abort_channel(cls, task_uuid: str) -> str:
+        """
+        Get the abort channel name for a task.
+
+        :param task_uuid: UUID of the task
+        :returns: Channel name for the task's abort notifications
+        """
+        return f"{cls._channel_prefix}{task_uuid}"
+
+    @classmethod
+    def publish_abort(cls, task_uuid: str) -> bool:
+        """
+        Publish an abort message to the task's channel.
+
+        :param task_uuid: UUID of the task to abort
+        :returns: True if message was published, False if Redis unavailable
+        """
+        if not cls._redis:
+            return False
+
+        try:
+            channel = cls.get_abort_channel(task_uuid)
+            subscriber_count = cls._redis.publish(channel, "abort")
+            logger.debug(
+                "Published abort to channel %s (%d subscribers)",
+                channel,
+                subscriber_count,
+            )
+            return True
+        except redis.RedisError as ex:
+            logger.error("Failed to publish abort for task %s: %s", task_uuid, ex)
+            return False
+
+    @classmethod
+    def listen_for_abort(
+        cls,
+        task_uuid: str,
+        callback: Callable[[], None],
+        poll_interval: float,
+        app: Any = None,
+    ) -> AbortListener:
+        """
+        Start listening for abort notifications for a task.
+
+        Uses Redis pub/sub if available, otherwise falls back to database polling.
+        The callback is invoked when an abort is detected.
+
+        :param task_uuid: UUID of the task to monitor
+        :param callback: Function to call when abort is detected
+        :param poll_interval: Interval for database polling (fallback mode)
+        :param app: Flask app for database access in background thread
+        :returns: AbortListener handle to stop listening
+        """
+        stop_event = threading.Event()
+        pubsub: redis.client.PubSub | None = None
+
+        # Try to set up Redis pub/sub
+        if cls._redis is not None:
+            try:
+                pubsub = cls._redis.pubsub()
+                channel = cls.get_abort_channel(task_uuid)
+                pubsub.subscribe(channel)
+                logger.debug("Subscribed to abort channel: %s", channel)
+            except redis.RedisError as ex:
+                logger.warning(
+                    "Failed to subscribe to Redis for task %s: %s. Using polling.",
+                    task_uuid,
+                    ex,
+                )
+                pubsub = None
+
+        if pubsub is not None:
+            # Start pub/sub listener thread
+            thread = threading.Thread(
+                target=cls._listen_pubsub,
+                args=(task_uuid, pubsub, callback, stop_event, poll_interval, app),
+                daemon=True,
+                name=f"abort-listener-{task_uuid[:8]}",
+            )
+            logger.info("Started pub/sub abort listener for task %s", task_uuid)
+        else:
+            # Start polling thread (fallback)
+            thread = threading.Thread(
+                target=cls._poll_for_abort,
+                args=(task_uuid, callback, stop_event, poll_interval, app),
+                daemon=True,
+                name=f"abort-poller-{task_uuid[:8]}",
+            )
+            logger.info(
+                "Started database abort polling for task %s (interval=%ss)",
+                task_uuid,
+                poll_interval,
+            )
+
+        thread.start()
+        return AbortListener(task_uuid, thread, stop_event, pubsub)
+
+    @classmethod
+    def _listen_pubsub(  # noqa: C901
+        cls,
+        task_uuid: str,
+        pubsub: redis.client.PubSub,
+        callback: Callable[[], None],
+        stop_event: threading.Event,
+        fallback_interval: float,
+        app: Any,
+    ) -> None:
+        """
+        Listen for abort via Redis pub/sub.
+
+        If pub/sub connection fails, falls back to database polling.
+        """
+        try:
+            while not stop_event.is_set():
+                message = pubsub.get_message(
+                    ignore_subscribe_messages=True, timeout=1.0
+                )
+
+                if message is not None and message.get("type") == "message":
+                    # Abort message received
+                    logger.info(
+                        "Abort received via pub/sub for task %s",
+                        task_uuid,
+                    )
+                    # Invoke callback with app context if provided
+                    if app:
+                        with app.app_context():
+                            callback()
+                    else:
+                        callback()
+                    break
+
+        except redis.RedisError as ex:
+            # Check if we were asked to stop - if so, this is expected
+            if stop_event.is_set():
+                logger.debug(
+                    "Abort listener for task %s stopped (Redis error: %s)",
+                    task_uuid,
+                    ex,
+                )
+            else:
+                logger.warning(
+                    "Redis pub/sub error for task %s: %s. Falling back to polling.",
+                    task_uuid,
+                    ex,
+                )
+                # Fall back to database polling on pub/sub failure
+                cls._poll_for_abort(
+                    task_uuid, callback, stop_event, fallback_interval, app
+                )
+
+        except (ValueError, OSError) as ex:
+            # ValueError: "I/O operation on closed file" - expected when stop() closes
+            # OSError: Similar connection-closed errors
+            if stop_event.is_set():
+                # Clean shutdown, expected behavior
+                logger.debug(
+                    "Abort listener for task %s stopped cleanly",
+                    task_uuid,
+                )
+            else:
+                # Unexpected error while running
+                logger.error(
+                    "Error in abort listener for task %s: %s",
+                    task_uuid,
+                    str(ex),
+                    exc_info=True,
+                )
+
+        except Exception as ex:
+            # Only log as error if we weren't asked to stop
+            if stop_event.is_set():
+                logger.debug(
+                    "Abort listener for task %s stopped with exception: %s",
+                    task_uuid,
+                    ex,
+                )
+            else:
+                logger.error(
+                    "Error in abort listener for task %s: %s",
+                    task_uuid,
+                    str(ex),
+                    exc_info=True,
+                )
+
+        finally:
+            # Clean up pub/sub subscription
+            try:
+                pubsub.unsubscribe()
+                pubsub.close()
+            except Exception as ex:
+                logger.debug("Error closing pub/sub during cleanup: %s", ex)
+
+    @classmethod
+    def _poll_for_abort(
+        cls,
+        task_uuid: str,
+        callback: Callable[[], None],
+        stop_event: threading.Event,
+        interval: float,
+        app: Any,
+    ) -> None:
+        """Background polling loop - fallback when pub/sub is unavailable."""
+        # Lazy import to avoid circular dependencies
+        from superset.daos.tasks import TaskDAO
+
+        while not stop_event.is_set():
+            try:
+                # Wrap database access in Flask app context
+                if app:
+                    with app.app_context():
+                        task = TaskDAO.find_one_or_none(uuid=task_uuid)
+                        if task and task.status in [
+                            TaskStatus.ABORTING.value,
+                            TaskStatus.ABORTED.value,
+                        ]:
+                            logger.info(
+                                "Abort detected via polling for task %s (status=%s)",
+                                task_uuid,
+                                task.status,
+                            )
+                            callback()
+                            break
+                else:
+                    # Fallback without app context (e.g., in tests)
+                    task = TaskDAO.find_one_or_none(uuid=task_uuid)
+                    if task and task.status in [
+                        TaskStatus.ABORTING.value,
+                        TaskStatus.ABORTED.value,
+                    ]:
+                        logger.info(
+                            "Abort detected via polling for task %s (status=%s)",
+                            task_uuid,
+                            task.status,
+                        )
+                        callback()
+                        break
+
+                # Wait for interval or until stop is requested
+                stop_event.wait(timeout=interval)
+
+            except Exception as ex:
+                logger.error(
+                    "Error in abort polling for task %s: %s",
+                    task_uuid,
+                    str(ex),
+                    exc_info=True,
+                )
+                break
+
+    @staticmethod
+    def submit_task(
+        task_type: str,
+        task_key: str | None,
+        task_name: str | None,
+        scope: TaskScope,
+        timeout: int | None,
+        args: tuple[Any, ...],
+        kwargs: dict[str, Any],
+    ) -> "Task":
+        """
+        Create task entry and schedule for async execution.
+
+        Flow:
+        1. Generate task_id if not provided (random UUID)
+        2. Create Task record in metastore (with PENDING status)
+        3. If duplicate active task exists, return it instead
+        4. Submit to Celery for background execution
+        5. Return Task model to caller
+
+        :param task_type: Task type identifier (e.g., "superset.generate_thumbnail")
+        :param task_key: Optional deduplication key (None for random UUID)
+        :param task_name: Human readable task name
+        :param scope: Task scope (TaskScope.PRIVATE, SHARED, or SYSTEM)
+        :param timeout: Optional timeout in seconds
+        :param args: Positional arguments for the task function
+        :param kwargs: Keyword arguments for the task function
+        :returns: Task model representing the scheduled task
+        """
+        if task_key is None:
+            task_key = generate_random_task_key()
+
+        # Build properties with timeout if configured
+        properties: TaskProperties | None = {"timeout": timeout} if timeout else None
+
+        try:
+            # Create task entry in metastore
+            # Command automatically extracts current user for subscription
+            # Lazy import to avoid circular dependency
+            from superset.commands.tasks.create import CreateTaskCommand
+
+            task = CreateTaskCommand(
+                {
+                    "task_key": task_key,
+                    "task_type": task_type,
+                    "task_name": task_name,
+                    "scope": scope.value,
+                    "properties": properties,
+                }
+            ).run()
+
+            # Import here to avoid circular dependency
+            from superset.tasks.scheduler import execute_task
+
+            # Schedule Celery task for async execution
+            execute_task.delay(
+                task_uuid=task.uuid,
+                task_type=task_type,
+                args=args,
+                kwargs=kwargs,
+            )
+
+            logger.info(
+                "Scheduled task %s (uuid=%s) for async execution",
+                task_type,
+                task.uuid,
+            )
+
+            return task
+
+        except TaskCreateFailedError:
+            # Task with same task_key already exists and is active
+            # Return existing task instead of creating duplicate
+            # Lazy import to avoid circular dependency
+            from superset.daos.tasks import TaskDAO
+
+            existing = TaskDAO.find_by_task_key(task_type, task_key, scope.value)
+            if existing:
+                logger.info(
+                    "Task %s with key '%s' and scope '%s' already exists (uuid=%s), "
+                    "returning existing task",
+                    task_type,
+                    task_key,
+                    scope.value,
+                    existing.uuid,
+                )
+                return existing
+
+            # Race condition: task completed between check and here
+            # Try again to create new task
+            logger.warning(
+                "Race condition detected for task %s with key '%s' and "
+                "scope '%s', retrying",
+                task_type,
+                task_key,
+                scope.value,
+            )
+            return TaskManager.submit_task(
+                task_type, task_key, task_name, scope, timeout, args, kwargs
+            )

Review Comment:
   **Suggestion:** The recursive retry on task-creation failure can lead to unbounded recursion: when CreateTaskCommand raises TaskCreateFailedError and no existing task is found, the code calls TaskManager.submit_task(...) again with no backoff and no retry limit, so a persistent race or creation failure causes infinite recursion / RecursionError. Replace the recursive retry with a deterministic failure (or a bounded retry loop). [race condition]
   
   <details>
   <summary><b>Severity Level:</b> Critical 🚨</summary>
   
   ```mdx
   - ❌ Task creation may crash process via RecursionError.
   - ⚠️ Background task scheduling may fail silently.
   - ⚠️ Affects TaskManager.submit_task flows (task scheduling).
   ```
   </details>
   
   ```suggestion
                # Do not recurse indefinitely; surface failure so caller can handle retry if desired
               logger.error(
                   "Race condition detected for task %s with key '%s' and scope 
'%s' - failing task creation",
                   task_type,
                   task_key,
                   scope.value,
               )
               raise TaskCreateFailedError()
   ```
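   
   If retrying on a dedup race is still desirable, a bounded loop gives the same recovery without unbounded recursion. The sketch below is illustrative only: it reuses the module's existing `logger` and `TaskCreateFailedError` import, takes the create/lookup steps as callables so the retry logic can be shown in isolation, and the `MAX_SUBMIT_ATTEMPTS` constant is not part of the PR.
   
   ```python
   # Sketch only: a bounded retry in place of the recursive call.
   # MAX_SUBMIT_ATTEMPTS is a hypothetical constant; `create` and
   # `find_existing` stand in for CreateTaskCommand(...).run and
   # TaskDAO.find_by_task_key, which submit_task already imports lazily.
   MAX_SUBMIT_ATTEMPTS = 3

   def _create_or_reuse_task(task_type, task_key, scope, create, find_existing):
       """Create a task, reusing an existing active one on a dedup race."""
       for attempt in range(1, MAX_SUBMIT_ATTEMPTS + 1):
           try:
               return create()
           except TaskCreateFailedError:
               # Duplicate active task: reuse it instead of creating another
               existing = find_existing(task_type, task_key, scope)
               if existing is not None:
                   return existing
               logger.warning(
                   "Race creating task %s (attempt %d/%d), retrying",
                   task_type,
                   attempt,
                   MAX_SUBMIT_ATTEMPTS,
               )
       # Fail deterministically once the retry budget is spent
       raise TaskCreateFailedError()
   ```
   
   Two or three attempts should be plenty; if creation still fails after that, something other than the dedup race is wrong and surfacing the error (as in the suggestion above) is the right outcome.
   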
   <details>
   <summary><b>Steps of Reproduction ✅ </b></summary>
   
   ```mdx
   1. Trigger submission of a global task by calling TaskManager.submit_task() 
at
   superset/tasks/manager.py:486. This instantiates CreateTaskCommand and calls 
.run()
   (superset/tasks/manager.py:525-536).
   
   2. Inside CreateTaskCommand.run (superset/commands/tasks/create.py:38-86) a 
DAO create
   failure occurs (DAOCreateFailedError), causing CreateTaskCommand to raise
   TaskCreateFailedError which is caught by the except in 
superset/tasks/manager.py:556.
   
   3. The except block calls TaskDAO.find_by_task_key(...) at 
superset/tasks/manager.py:562
   to see if an existing active task exists. If TaskDAO.find_by_task_key 
returns None (no
   existing task found), the current code logs and immediately calls
   TaskManager.submit_task(...) again (superset/tasks/manager.py:576-585).
   
   4. If the underlying create keeps failing (for example persistent 
constraint/DB error or a
   repeat race), each retry will re-enter the same except branch and recurse 
again. With
   repeated persistent failures this recursion grows until Python raises 
RecursionError or
   the process stack overflows. This reproduces deterministically by simulating 
a persistent
   DAO create failure during TaskManager.submit_task (see
   superset/commands/tasks/create.py:38-86 and 
superset/tasks/manager.py:556-585).
   
   Note: The pattern is intentional in that it attempts a retry on a race, but 
unbounded
   recursion is dangerous; replacing recursion with a bounded retry or 
surfacing the error
   avoids RecursionError.
   ```
   </details>
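   
   A quick way to confirm this locally is to force the persistent-failure path with mocks. The test below is a sketch under those assumptions: the `task_type`/`task_key` values are made up, and the patch targets mirror the lazy imports inside `submit_task`.
   
   ```python
   from unittest import mock

   import pytest
   from superset_core.api.tasks import TaskScope

   from superset.commands.tasks.exceptions import TaskCreateFailedError
   from superset.tasks.manager import TaskManager


   def test_submit_task_recurses_on_persistent_create_failure() -> None:
       # Every create attempt fails and no existing task can be reused,
       # so the current code keeps calling submit_task recursively.
       with mock.patch(
           "superset.commands.tasks.create.CreateTaskCommand"
       ) as create_cmd, mock.patch(
           "superset.daos.tasks.TaskDAO.find_by_task_key", return_value=None
       ):
           create_cmd.return_value.run.side_effect = TaskCreateFailedError()

           with pytest.raises(RecursionError):
               TaskManager.submit_task(
                   task_type="superset.example",
                   task_key="dedup-key",
                   task_name="example",
                   scope=TaskScope.PRIVATE,
                   timeout=None,
                   args=(),
                   kwargs={},
               )
   ```
   
   With the fail-fast (or bounded-retry) change applied, the same setup should raise TaskCreateFailedError instead of exhausting the stack.
   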
   <details>
   <summary><b>Prompt for AI Agent 🤖 </b></summary>
   
   ```mdx
   This is a comment left during a code review.
   
   **Path:** superset/tasks/manager.py
   **Line:** 575:585
   **Comment:**
        *Race Condition: Recursive retry on task-creation failure can lead to 
unbounded recursion: when CreateTaskCommand raises TaskCreateFailedError and no 
existing task is found, the code calls TaskManager.submit_task(...) recursively 
with no backoff or retry limit, which can cause infinite recursion / 
RecursionError under persistent races; replace the recursive retry with a 
deterministic failure (or a bounded retry loop).
   
   Validate the correctness of the flagged issue. If correct, How can I resolve 
this? If you propose a fix, implement it and please make it concise.
   ```
   </details>



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

