This is an automated email from the ASF dual-hosted git repository.
villebro pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git
The following commit(s) were added to refs/heads/master by this push:
new 0f1278fa614 fix(gtf): set dedup_key on atomic sql (#37820)
0f1278fa614 is described below
commit 0f1278fa61419f322930612951d189e98dbb2173
Author: Ville Brofeldt <[email protected]>
AuthorDate: Tue Feb 10 06:56:14 2026 -0800
fix(gtf): set dedup_key on atomic sql (#37820)
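Before this change, dedup_key was rotated from the active key to the
finished key only in the Task model's status setter; the atomic
compare-and-swap path in TaskDAO.conditional_status_update left the
active key in place on terminal transitions, so the deduplication slot
for that task_key stayed occupied. The atomic UPDATE now also sets
dedup_key to get_finished_dedup_key(task_uuid) whenever the new status
is terminal. Related cleanup: inline terminal-status lists are replaced
by the shared TERMINAL_STATES constant, the TaskManager.TERMINAL_STATES
backward-compatibility alias is removed, dedup_key is exposed in
TaskResponseSchema, and stale superset_core.api.types import paths in
the docs now point to superset_core.api.tasks.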
---
docs/developer_portal/extensions/tasks.md | 8 +--
superset-core/src/superset_core/api/tasks.py | 2 +-
superset/daos/tasks.py | 10 ++-
superset/models/tasks.py | 15 +---
superset/tasks/manager.py | 11 ++-
superset/tasks/schemas.py | 4 ++
.../integration_tests/tasks/test_sync_join_wait.py | 15 ----
tests/unit_tests/daos/test_tasks.py | 83 ++++++++++++++++++++++
tests/unit_tests/tasks/test_manager.py | 5 --
9 files changed, 106 insertions(+), 47 deletions(-)
diff --git a/docs/developer_portal/extensions/tasks.md b/docs/developer_portal/extensions/tasks.md
index 25b6569b4ad..c833cd93680 100644
--- a/docs/developer_portal/extensions/tasks.md
+++ b/docs/developer_portal/extensions/tasks.md
@@ -50,7 +50,7 @@ When GTF is considered stable, it will replace legacy Celery tasks for built-in
### Define a Task
```python
-from superset_core.api.types import task, get_context
+from superset_core.api.tasks import task, get_context
@task
def process_data(dataset_id: int) -> None:
@@ -245,7 +245,7 @@ Always implement an abort handler for long-running tasks. This allows users to c
Set a timeout to automatically abort tasks that run too long:
```python
-from superset_core.api.types import task, get_context, TaskOptions
+from superset_core.api.tasks import task, get_context, TaskOptions
# Set default timeout in decorator
@task(timeout=300) # 5 minutes
@@ -299,7 +299,7 @@ Timeouts require an abort handler to be effective. Without one, the timeout trig
Use `task_key` to prevent duplicate task execution:
```python
-from superset_core.api.types import TaskOptions
+from superset_core.api.tasks import TaskOptions
# Without key - creates new task each time (random UUID)
task1 = my_task.schedule(x=1)
@@ -331,7 +331,7 @@ print(task2.status) # "success" (terminal status)
## Task Scopes
```python
-from superset_core.api.types import task, TaskScope
+from superset_core.api.tasks import task, TaskScope
@task # Private by default
def private_task(): ...
diff --git a/superset-core/src/superset_core/api/tasks.py b/superset-core/src/superset_core/api/tasks.py
index 1adcd9ab327..cc00689d622 100644
--- a/superset-core/src/superset_core/api/tasks.py
+++ b/superset-core/src/superset_core/api/tasks.py
@@ -259,7 +259,7 @@ def task(
is discarded; only side effects and context updates matter.
Example:
- from superset_core.api.types import task, get_context, TaskScope
+ from superset_core.api.tasks import task, get_context, TaskScope
# Private task (default scope)
@task
diff --git a/superset/daos/tasks.py b/superset/daos/tasks.py
index 8253cf6d579..c8b92cd948a 100644
--- a/superset/daos/tasks.py
+++ b/superset/daos/tasks.py
@@ -28,9 +28,9 @@ from superset.daos.exceptions import DAODeleteFailedError
from superset.extensions import db
from superset.models.task_subscribers import TaskSubscriber
from superset.models.tasks import Task
-from superset.tasks.constants import ABORTABLE_STATES
+from superset.tasks.constants import ABORTABLE_STATES, TERMINAL_STATES
from superset.tasks.filters import TaskFilter
-from superset.tasks.utils import get_active_dedup_key, json
+from superset.tasks.utils import get_active_dedup_key, get_finished_dedup_key, json
logger = logging.getLogger(__name__)
@@ -243,7 +243,7 @@ class TaskDAO(BaseDAO[Task]):
)
# Transition to ABORTING (not ABORTED yet)
- task.status = TaskStatus.ABORTING.value
+ task.set_status(TaskStatus.ABORTING)
db.session.merge(task)
logger.info("Set task %s to ABORTING (scope: %s)", task_uuid, task.scope)
@@ -444,6 +444,10 @@ class TaskDAO(BaseDAO[Task]):
if set_ended_at:
update_values["ended_at"] = datetime.now(timezone.utc)
+ # Update dedup_key if transitioning to terminal state
+ if new_status_val in TERMINAL_STATES:
+ update_values["dedup_key"] = get_finished_dedup_key(task_uuid)
+
# Atomic compare-and-swap: only update if status matches expected
rows_updated = (
db.session.query(Task)
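The hunk above is the heart of the fix. As a minimal sketch of the
resulting compare-and-swap (the diff only shows the query's first line,
so the filter conditions and update() flags below are assumptions, not
part of this patch):

    # Build the UPDATE values for this transition.
    update_values = {"status": new_status_val}
    if set_ended_at:
        update_values["ended_at"] = datetime.now(timezone.utc)

    # New in this patch: rotate the dedup key on terminal transitions
    # so the active slot is freed for new tasks with the same task_key.
    if new_status_val in TERMINAL_STATES:
        update_values["dedup_key"] = get_finished_dedup_key(task_uuid)

    # Atomic compare-and-swap: the UPDATE applies only while the status
    # still matches the expected value, so a concurrent transition
    # cannot be silently overwritten. (Filter and flags are assumed.)
    rows_updated = (
        db.session.query(Task)
        .filter(Task.uuid == str(task_uuid), Task.status == expected_status_val)
        .update(update_values, synchronize_session=False)
    )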
diff --git a/superset/models/tasks.py b/superset/models/tasks.py
index 6c6995e9563..e7c3992f2bb 100644
--- a/superset/models/tasks.py
+++ b/superset/models/tasks.py
@@ -37,6 +37,7 @@ from superset_core.api.tasks import TaskProperties, TaskStatus
from superset.models.helpers import AuditMixinNullable
from superset.models.task_subscribers import TaskSubscriber
+from superset.tasks.constants import TERMINAL_STATES
from superset.tasks.utils import (
error_update,
get_finished_dedup_key,
@@ -218,12 +219,7 @@ class Task(CoreTask, AuditMixinNullable, Model):
# (will be set to True if/when an abort handler is registered)
if self.properties_dict.get("is_abortable") is None:
self.update_properties({"is_abortable": False})
- elif status in [
- TaskStatus.SUCCESS.value,
- TaskStatus.FAILURE.value,
- TaskStatus.ABORTED.value,
- TaskStatus.TIMED_OUT.value,
- ]:
+ elif status in TERMINAL_STATES:
if not self.ended_at:
self.ended_at = now
# Update dedup_key to UUID to free up the slot for new tasks
@@ -244,12 +240,7 @@ class Task(CoreTask, AuditMixinNullable, Model):
@property
def is_finished(self) -> bool:
"""Check if task has finished (success, failure, aborted, or timed
out)."""
- return self.status in [
- TaskStatus.SUCCESS.value,
- TaskStatus.FAILURE.value,
- TaskStatus.ABORTED.value,
- TaskStatus.TIMED_OUT.value,
- ]
+ return self.status in TERMINAL_STATES
@property
def is_successful(self) -> bool:
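TERMINAL_STATES itself is imported from superset.tasks.constants and is
not shown in this patch. Judging by the tests removed further down, it
is presumably equivalent to:

    # Presumed definition; the membership values are confirmed by the
    # removed tests below, but frozenset is an assumption.
    TERMINAL_STATES = frozenset(
        {
            TaskStatus.SUCCESS.value,
            TaskStatus.FAILURE.value,
            TaskStatus.ABORTED.value,
            TaskStatus.TIMED_OUT.value,
        }
    )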
diff --git a/superset/tasks/manager.py b/superset/tasks/manager.py
index f4595c51167..21b28c7d42b 100644
--- a/superset/tasks/manager.py
+++ b/superset/tasks/manager.py
@@ -112,9 +112,6 @@ class TaskManager:
_completion_channel_prefix: str = "gtf:complete:"
_initialized: bool = False
- # Backward compatibility alias - prefer importing from superset.tasks.constants
- TERMINAL_STATES = TERMINAL_STATES
-
@classmethod
def init_app(cls, app: Flask) -> None:
"""
@@ -271,7 +268,7 @@ class TaskManager:
if not task:
raise ValueError(f"Task {task_uuid} not found")
- if task.status in cls.TERMINAL_STATES:
+ if task.status in TERMINAL_STATES:
return task
logger.debug(
@@ -342,13 +339,13 @@ class TaskManager:
message.get("data"),
)
task = get_task()
- if task and task.status in cls.TERMINAL_STATES:
+ if task and task.status in TERMINAL_STATES:
return task
# Also check database periodically in case we missed the message
# (e.g., task completed before we subscribed)
task = get_task()
- if task and task.status in cls.TERMINAL_STATES:
+ if task and task.status in TERMINAL_STATES:
logger.debug(
"Task %s completed (detected via db check): status=%s",
task_uuid,
@@ -384,7 +381,7 @@ class TaskManager:
if not task:
raise ValueError(f"Task {task_uuid} not found")
- if task.status in cls.TERMINAL_STATES:
+ if task.status in TERMINAL_STATES:
logger.debug(
"Task %s completed (detected via polling): status=%s",
task_uuid,
diff --git a/superset/tasks/schemas.py b/superset/tasks/schemas.py
index 9fe0b31ec7b..bf93c5d0c47 100644
--- a/superset/tasks/schemas.py
+++ b/superset/tasks/schemas.py
@@ -25,6 +25,9 @@ get_delete_ids_schema = {"type": "array", "items": {"type": "string"}}
# Field descriptions
uuid_description = "The unique identifier (UUID) of the task"
task_key_description = "The task identifier used for deduplication"
+dedup_key_description = (
+ "The hashed deduplication key used internally for task deduplication"
+)
task_type_description = (
"The type of task (e.g., 'sql_execution', 'thumbnail_generation')"
)
@@ -74,6 +77,7 @@ class TaskResponseSchema(Schema):
id = fields.Int(metadata={"description": "Internal task ID"})
uuid = fields.UUID(metadata={"description": uuid_description})
task_key = fields.String(metadata={"description": task_key_description})
+ dedup_key = fields.String(metadata={"description": dedup_key_description})
task_type = fields.String(metadata={"description": task_type_description})
task_name = fields.String(
metadata={"description": task_name_description}, allow_none=True
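With dedup_key added to TaskResponseSchema, task API responses now carry
the hashed key alongside the caller-supplied task_key. An illustrative
fragment (all values below are placeholders, not from this patch):

    {
        "uuid": "11111111-2222-3333-4444-555555555555",
        "task_key": "my-report-refresh",
        "dedup_key": "3f8a...",  # rotates to a finished key once terminal
        "status": "success",
    }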
diff --git a/tests/integration_tests/tasks/test_sync_join_wait.py b/tests/integration_tests/tasks/test_sync_join_wait.py
index 9379efca1c7..9a611cd6b29 100644
--- a/tests/integration_tests/tasks/test_sync_join_wait.py
+++ b/tests/integration_tests/tasks/test_sync_join_wait.py
@@ -68,21 +68,6 @@ def test_submit_task_distinguishes_new_vs_existing(
db.session.commit()
-def test_terminal_states_recognized_correctly(app_context) -> None:
- """
- Test that TaskManager.TERMINAL_STATES contains the expected values.
- """
- assert TaskStatus.SUCCESS.value in TaskManager.TERMINAL_STATES
- assert TaskStatus.FAILURE.value in TaskManager.TERMINAL_STATES
- assert TaskStatus.ABORTED.value in TaskManager.TERMINAL_STATES
- assert TaskStatus.TIMED_OUT.value in TaskManager.TERMINAL_STATES
-
- # Non-terminal states should not be in the set
- assert TaskStatus.PENDING.value not in TaskManager.TERMINAL_STATES
- assert TaskStatus.IN_PROGRESS.value not in TaskManager.TERMINAL_STATES
- assert TaskStatus.ABORTING.value not in TaskManager.TERMINAL_STATES
-
-
def test_wait_for_completion_timeout(app_context, login_as, get_user) -> None:
"""
Test that wait_for_completion raises TimeoutError on timeout.
diff --git a/tests/unit_tests/daos/test_tasks.py b/tests/unit_tests/daos/test_tasks.py
index f8f3bdc073a..8d767a17373 100644
--- a/tests/unit_tests/daos/test_tasks.py
+++ b/tests/unit_tests/daos/test_tasks.py
@@ -418,3 +418,86 @@ def test_get_status_not_found(session_with_task: Session) -> None:
result = TaskDAO.get_status(UUID("00000000-0000-0000-0000-000000000000"))
assert result is None
+
+
+def test_conditional_status_update_non_terminal_state_keeps_dedup_key(
+ session_with_task: Session,
+) -> None:
+ """Test that conditional_status_update preserves dedup_key for
+ non-terminal transitions"""
+ from superset.daos.tasks import TaskDAO
+
+ # Create task in PENDING state
+ task = create_task(
+ session_with_task,
+ task_uuid=TASK_UUID,
+ task_key="non-terminal-test-task",
+ status=TaskStatus.PENDING,
+ )
+
+ # Store original active dedup_key
+ original_dedup_key = task.dedup_key
+
+ # Transition to non-terminal state (IN_PROGRESS)
+ result = TaskDAO.conditional_status_update(
+ task_uuid=TASK_UUID,
+ new_status=TaskStatus.IN_PROGRESS,
+ expected_status=TaskStatus.PENDING,
+ set_started_at=True,
+ )
+
+ # Should succeed
+ assert result is True
+
+ # Refresh task and verify dedup_key was NOT changed
+ session_with_task.refresh(task)
+ assert task.status == TaskStatus.IN_PROGRESS.value
+ assert task.dedup_key == original_dedup_key # Should remain the same
+ assert task.started_at is not None
+
+
[email protected](
+ "terminal_state",
+ [
+ TaskStatus.SUCCESS,
+ TaskStatus.FAILURE,
+ TaskStatus.ABORTED,
+ TaskStatus.TIMED_OUT,
+ ],
+)
+def test_conditional_status_update_terminal_state_updates_dedup_key(
+ session_with_task: Session, terminal_state: TaskStatus
+) -> None:
+ """Test that terminal states (SUCCESS, FAILURE, ABORTED, TIMED_OUT)
+ update dedup_key"""
+ from superset.daos.tasks import TaskDAO
+
+ task = create_task(
+ session_with_task,
+ task_uuid=TASK_UUID,
+ task_key=f"terminal-test-{terminal_state.value}",
+ status=TaskStatus.IN_PROGRESS,
+ )
+
+ original_dedup_key = task.dedup_key
+ expected_finished_key = get_finished_dedup_key(TASK_UUID)
+
+ # Transition to terminal state
+ result = TaskDAO.conditional_status_update(
+ task_uuid=TASK_UUID,
+ new_status=terminal_state,
+ expected_status=TaskStatus.IN_PROGRESS,
+ set_ended_at=True,
+ )
+
+ assert result is True, f"Failed to update to {terminal_state.value}"
+
+ # Verify dedup_key was updated
+ session_with_task.refresh(task)
+ assert task.status == terminal_state.value
+ assert task.dedup_key == expected_finished_key, (
+ f"dedup_key not updated for {terminal_state.value}"
+ )
+ assert task.dedup_key != original_dedup_key, (
+ f"dedup_key should have changed for {terminal_state.value}"
+ )
diff --git a/tests/unit_tests/tasks/test_manager.py b/tests/unit_tests/tasks/test_manager.py
index 13997a7f113..4fc77c2e080 100644
--- a/tests/unit_tests/tasks/test_manager.py
+++ b/tests/unit_tests/tasks/test_manager.py
@@ -455,8 +455,3 @@ class TestTaskManagerCompletion:
timeout=5.0,
poll_interval=0.1,
)
-
- def test_terminal_states_constant(self):
- """Test TERMINAL_STATES contains expected values"""
- expected = {"success", "failure", "aborted", "timed_out"}
- assert TaskManager.TERMINAL_STATES == expected