Coverage for backend/common/src/common/services/task_cancellation_state.py: 53%

41 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2026-05-13 02:47 +0000

1"""Shared helpers for Redis-backed task-cancellation state.""" 

2 

3import logging 

4import os 

5from collections.abc import Iterable 

6 

7from dapr.clients import DaprClient 

8from dapr.clients.grpc._state import StateItem 

9 

logger = logging.getLogger(__name__)

# Dapr state-store component that holds cancellation markers; overridable per
# deployment through the environment.
TASK_CANCELLATION_STATE_STORE_NAME = os.environ.get(
    "TASK_CANCELLATION_STATE_STORE_NAME", "task-cancellation-store"
)
# Lifetime (seconds) attached to each marker. Parsed once at import time, so a
# non-integer value in the environment fails the import with ValueError.
TASK_CANCELLATION_STATE_TTL_SECONDS = int(
    os.environ.get("TASK_CANCELLATION_STATE_TTL_SECONDS", "1800")
)
# Value written for (and compared against when reading) a cancelled task.
CANCELLED_TASK_STATE_VALUE = "1"
# Upper bound on the parallelism requested from a bulk state read.
MAX_BULK_STATE_PARALLELISM = 50

21 

22 

23def _task_cancellation_key(task_id: int) -> str: 

24 """Build the Redis/Dapr state-store key for a cancelled task identifier. 

25 

26 This format must remain stable because cancellation keys can outlive a 

27 single service process while the shared-state TTL is still active. 

28 """ 

29 return f"cancelled-task:{task_id}" 

30 

31 

32def cache_cancelled_tasks( 

33 task_ids: Iterable[int], 

34 ttl_seconds: int | None = None, 

35) -> bool: 

36 """Persist cancelled task identifiers in the shared Dapr state store.""" 

37 ttl_seconds = ttl_seconds or TASK_CANCELLATION_STATE_TTL_SECONDS 

38 candidate_ids = list(dict.fromkeys(task_ids)) 

39 if not candidate_ids: 39 ↛ 40line 39 didn't jump to line 40 because the condition on line 39 was never true

40 return True 

41 

42 try: 

43 with DaprClient() as dapr: 

44 dapr.save_bulk_state( 

45 TASK_CANCELLATION_STATE_STORE_NAME, 

46 [ 

47 StateItem( 

48 key=_task_cancellation_key(task_id), 

49 value=CANCELLED_TASK_STATE_VALUE, 

50 metadata={"ttlInSeconds": str(ttl_seconds)}, 

51 ) 

52 for task_id in candidate_ids 

53 ], 

54 ) 

55 except Exception: 

56 logger.warning( 

57 "Failed to cache cancellation state for task IDs %s.", 

58 candidate_ids, 

59 exc_info=True, 

60 ) 

61 return False 

62 

63 return True 

64 

65 

def cache_cancelled_task(task_id: int, ttl_seconds: int | None = None) -> bool:
    """Mark a single task as cancelled in the shared Dapr state store.

    Convenience wrapper over ``cache_cancelled_tasks``; the TTL is forwarded
    unchanged and the returned success flag is passed straight through.
    """
    single_task = (task_id,)
    return cache_cancelled_tasks(single_task, ttl_seconds=ttl_seconds)

69 

70 

71def find_cancelled_task_id(task_ids: Iterable[int | None]) -> int | None: 

72 """Return the first task identifier that is currently cached as cancelled.""" 

73 candidate_ids = [task_id for task_id in task_ids if task_id is not None] 

74 if not candidate_ids: 

75 return None 

76 

77 try: 

78 with DaprClient() as dapr: 

79 state_items = dapr.get_bulk_state( 

80 TASK_CANCELLATION_STATE_STORE_NAME, 

81 [_task_cancellation_key(task_id) for task_id in candidate_ids], 

82 parallelism=min(len(candidate_ids), MAX_BULK_STATE_PARALLELISM), 

83 ).items 

84 except Exception: 

85 logger.warning( 

86 "Failed to read cancellation state for task IDs %s; continuing without shared-state gating.", 

87 candidate_ids, 

88 exc_info=True, 

89 ) 

90 return None 

91 

92 cancelled_keys = { 

93 item.key 

94 for item in state_items 

95 if item.data == CANCELLED_TASK_STATE_VALUE and not item.error 

96 } 

97 for task_id in candidate_ids: 

98 if _task_cancellation_key(task_id) in cancelled_keys: 

99 return task_id 

100 

101 return None