Coverage for backend/common/src/common/services/task_cancellation_state.py: 53% (41 statements)
1"""Shared helpers for Redis-backed task-cancellation state."""
3import logging
4import os
5from collections.abc import Iterable
7from dapr.clients import DaprClient
8from dapr.clients.grpc._state import StateItem
10logger = logging.getLogger(__name__)
TASK_CANCELLATION_STATE_STORE_NAME = os.getenv(
    "TASK_CANCELLATION_STATE_STORE_NAME",
    "task-cancellation-store",
)
TASK_CANCELLATION_STATE_TTL_SECONDS = int(
    os.getenv("TASK_CANCELLATION_STATE_TTL_SECONDS", "1800")
)
CANCELLED_TASK_STATE_VALUE = "1"
MAX_BULK_STATE_PARALLELISM = 50
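
# Note on expiry: each cancellation marker is written with Dapr's "ttlInSeconds"
# state metadata (see cache_cancelled_tasks below), so markers expire on their
# own after TASK_CANCELLATION_STATE_TTL_SECONDS (30 minutes by default) instead
# of accumulating in the Redis-backed store.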

def _task_cancellation_key(task_id: int) -> str:
    """Build the Redis/Dapr state-store key for a cancelled task identifier.

    This format must remain stable because cancellation keys can outlive a
    single service process while the shared-state TTL is still active.
    """
    return f"cancelled-task:{task_id}"

def cache_cancelled_tasks(
    task_ids: Iterable[int],
    ttl_seconds: int | None = None,
) -> bool:
    """Persist cancelled task identifiers in the shared Dapr state store."""
    ttl_seconds = ttl_seconds or TASK_CANCELLATION_STATE_TTL_SECONDS
    candidate_ids = list(dict.fromkeys(task_ids))
    if not candidate_ids:
        return True

    try:
        with DaprClient() as dapr:
            dapr.save_bulk_state(
                TASK_CANCELLATION_STATE_STORE_NAME,
                [
                    StateItem(
                        key=_task_cancellation_key(task_id),
                        value=CANCELLED_TASK_STATE_VALUE,
                        metadata={"ttlInSeconds": str(ttl_seconds)},
                    )
                    for task_id in candidate_ids
                ],
            )
    except Exception:
        logger.warning(
            "Failed to cache cancellation state for task IDs %s.",
            candidate_ids,
            exc_info=True,
        )
        return False

    return True

def cache_cancelled_task(task_id: int, ttl_seconds: int | None = None) -> bool:
    """Persist a cancelled task identifier in the shared Dapr state store."""
    return cache_cancelled_tasks([task_id], ttl_seconds=ttl_seconds)

def find_cancelled_task_id(task_ids: Iterable[int | None]) -> int | None:
    """Return the first task identifier that is currently cached as cancelled."""
    candidate_ids = [task_id for task_id in task_ids if task_id is not None]
    if not candidate_ids:
        return None

    try:
        with DaprClient() as dapr:
            state_items = dapr.get_bulk_state(
                TASK_CANCELLATION_STATE_STORE_NAME,
                [_task_cancellation_key(task_id) for task_id in candidate_ids],
                parallelism=min(len(candidate_ids), MAX_BULK_STATE_PARALLELISM),
            ).items
    except Exception:
        logger.warning(
            "Failed to read cancellation state for task IDs %s; continuing without shared-state gating.",
            candidate_ids,
            exc_info=True,
        )
        return None
    # The Dapr SDK may return the stored value as bytes rather than str, so
    # accept either representation of the cancellation marker when matching.
    cancelled_keys = {
        item.key
        for item in state_items
        if not item.error
        and item.data in (CANCELLED_TASK_STATE_VALUE, CANCELLED_TASK_STATE_VALUE.encode())
    }
    for task_id in candidate_ids:
        if _task_cancellation_key(task_id) in cancelled_keys:
            return task_id

    return None
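

# --- Usage sketch (illustrative addition, not part of the original module) ---
# A minimal example, under the assumption that a long-running worker wants to
# stop between processing steps once its task has been cancelled elsewhere
# (e.g. an API handler calls cache_cancelled_task(task_id) on request).
# `_process_chunk` and `_run_with_cancellation` are hypothetical names; only
# `find_cancelled_task_id`, `cache_cancelled_task`, and `logger` come from
# this module.


def _process_chunk(chunk: bytes) -> None:
    """Hypothetical unit of work; stands in for real processing."""
    logger.debug("Processed %d bytes.", len(chunk))


def _run_with_cancellation(task_id: int, chunks: list[bytes]) -> None:
    """Hypothetical worker loop that polls the shared store between chunks."""
    for chunk in chunks:
        # Once a cancellation marker for task_id has been cached,
        # find_cancelled_task_id returns the matching identifier.
        if find_cancelled_task_id([task_id]) == task_id:
            logger.info("Task %s was cancelled; stopping early.", task_id)
            return
        _process_chunk(chunk)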