Coverage for backend/common/src/common/services/messaging.py: 90%

80 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2026-05-13 02:47 +0000

1import os 

2from dapr.clients import DaprClient 

3from opentelemetry import trace 

4from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator 

5from pydantic import BaseModel 

6 

7from common.models.scenario_import import ScenarioCsvImportRequestPayload 

8from common.models.pinch_import import PinchUtilityCsvImportRequestPayload 

9from ahuora_builder_types.payloads.build_state_request_schema import ( 

10 BuildStateCompletionPayload, 

11 BuildStateRequestSchema, 

12) 

13from common.models.solve_completion_email import SolveCompletionEmailRequestPayload 

14from ahuora_builder_types.payloads.solve_request_schema import ( 

15 IdaesSolveRequestPayload, 

16 IdaesSolveCompletionPayload, 

17 MultiSolvePayload, 

18) 

19from ahuora_builder_types.payloads.ml_request_schema import ( 

20 MLTrainRequestPayload, 

21 MLTrainingCompletionPayload, 

22) 

23from common.models.notifications import NotificationServiceEnvelope 

24from common.models.general import TaskPayload 

25from common.models.notifications.payloads import ( 

26 NotificationServiceMessage, 

27 NotificationServiceMessageType, 

28) 

29 

# Dapr pub/sub component names. Task cancellation is published through its own
# dedicated component (see send_task_cancel_message) instead of the default one.
DAPR_PUBSUB_NAME = os.getenv("DAPR_PUBSUB_NAME", "rabbitmq-pubsub")
TASK_CANCELLATION_DAPR_PUBSUB_NAME = os.getenv(
    "TASK_CANCELLATION_DAPR_PUBSUB_NAME", "rabbitmq-pubsub-task-cancellation"
)

# Topic names. Each one can be overridden by an environment variable of the
# same name; the second argument is the fallback default.
IDAES_SOLVE_TOPIC = os.getenv("IDAES_SOLVE_TOPIC", "idaes-solve")
IDAES_BUILD_STATE_REQUEST_TOPIC = os.getenv(
    "IDAES_BUILD_STATE_REQUEST_TOPIC", "idaes-build-state-request"
)
# Deliberately left as a string: it is placed directly into Dapr publish
# metadata (a str -> str mapping) by send_idaes_build_state_request_message.
IDAES_BUILD_STATE_REQUEST_TTL_SECONDS = os.getenv(
    "IDAES_BUILD_STATE_REQUEST_TTL_SECONDS", "300"
)
IDAES_BUILD_STATE_RESPONSE_TOPIC = os.getenv(
    "IDAES_BUILD_STATE_RESPONSE_TOPIC", "idaes-build-state-response"
)
IDAES_SOLVE_COMPLETION_TOPIC = os.getenv(
    "IDAES_SOLVE_COMPLETION_TOPIC", "idaes-solve-completion"
)
DISPATCH_MULTI_SOLVE_TOPIC = os.getenv(
    "DISPATCH_MULTI_SOLVE_TOPIC", "dispatch-multi-solve"
)
TASK_RUNNING_TOPIC = os.getenv("TASK_RUNNING_TOPIC", "task-running")
TASK_CANCEL_TOPIC = os.getenv("TASK_CANCEL_TOPIC", "task-cancel")
TASK_CANCELLED_TOPIC = os.getenv("TASK_CANCELLED_TOPIC", "task-cancelled")
ML_TRAINING_TOPIC = os.getenv("ML_TRAINING_TOPIC", "ml-training")
ML_TRAINING_COMPLETION_TOPIC = os.getenv(
    "ML_TRAINING_COMPLETION_TOPIC", "ml-training-completion"
)
SCENARIO_CSV_IMPORT_TOPIC = os.getenv(
    "SCENARIO_CSV_IMPORT_TOPIC", "scenario-csv-import"
)
PINCH_UTILITY_CSV_IMPORT_TOPIC = os.getenv(
    "PINCH_UTILITY_CSV_IMPORT_TOPIC", "pinch-utility-csv-import"
)
SOLVE_COMPLETION_EMAIL_TOPIC = os.getenv(
    "SOLVE_COMPLETION_EMAIL_TOPIC", "solve-completion-email"
)
USER_NOTIFICATION_TOPIC = os.getenv("USER_NOTIFICATION_TOPIC", "user-notification")

# Message priorities, sent as strings in the "priority" publish-metadata key.
# RabbitMQ quorum queues distinguish only two levels, "normal" and "high";
# any priority above 4 is treated as "high".
HIGH_PRIORITY = "5"
NORMAL_PRIORITY = "0"

74 

75 

def _send_message(
    payload: BaseModel,
    topic: str,
    priority: str = NORMAL_PRIORITY,
    pubsub_name: str = DAPR_PUBSUB_NAME,
    publish_metadata: dict[str, str] | None = None,
):
    """Publish a Pydantic payload to a Dapr pub/sub topic.

    Publish metadata is assembled in increasing order of precedence:
    W3C trace headers for the active span, then the ``priority`` entry, then
    any caller-supplied ``publish_metadata`` (which can therefore override
    ``priority``). Entries whose value is ``None`` are dropped before
    publishing.

    Args:
        payload: Pydantic model, serialised with ``model_dump_json`` and sent
            as ``application/json``.
        topic: Pub/sub topic that should receive the event.
        priority: RabbitMQ priority written to the ``priority`` metadata key.
        pubsub_name: Dapr pub/sub component name to publish through.
        publish_metadata: Extra Dapr publish metadata to merge onto the event.
    """
    metadata = _get_dapr_tracing_headers() | {"priority": priority}
    if publish_metadata:
        metadata |= publish_metadata
    # Dapr's publish metadata is a string-to-string map and cannot carry None
    # (e.g. from a missing trace header), so strip such entries defensively.
    metadata = {key: value for key, value in metadata.items() if value is not None}

    with DaprClient() as dapr:
        dapr.publish_event(
            pubsub_name=pubsub_name,
            topic_name=topic,
            data=payload.model_dump_json(),
            data_content_type="application/json",
            publish_metadata=metadata,
        )

104 

105 

def _get_dapr_tracing_headers():
    """Build Dapr publish metadata carrying the active W3C trace context.

    When the current OpenTelemetry span is recording, the W3C ``traceparent``
    header is injected with ``TraceContextTextMapPropagator``. It is then
    mapped onto Dapr's ``cloudevent.traceparent`` and ``cloudevent.traceid``
    metadata overrides so the published event continues the caller's trace.

    Returns:
        dict[str, str]: Metadata entries to merge into the publish request.
        The dict is empty when no span is recording, or when the propagator
        produced no ``traceparent``. Values are never ``None``, because Dapr
        publish metadata is a string-to-string map.
    """
    current_span = trace.get_current_span()
    if not current_span.is_recording():
        return {}

    carrier = {}
    TraceContextTextMapPropagator().inject(carrier)

    traceparent = carrier.get("traceparent")
    if traceparent is None:
        # Nothing to forward; emitting None-valued metadata would break the
        # publish request.
        return {}

    return {
        "cloudevent.traceparent": traceparent,
        # Dapr's CloudEvent `traceid` field carries the same value as
        # `traceparent`, so both overrides are populated identically.
        "cloudevent.traceid": traceparent,
    }

121 

122 

def send_idaes_solve_message(
    payload: IdaesSolveRequestPayload, high_priority: bool = False
):
    """Queue an asynchronous solve request for the IDAES service.

    Args:
        payload: Solve request published on ``IDAES_SOLVE_TOPIC``.
        high_priority: When true, the message is tagged with ``HIGH_PRIORITY``
            instead of ``NORMAL_PRIORITY``. High-priority solves are processed
            at a two-to-one ratio relative to normal-priority solves.
    """
    if high_priority:
        priority = HIGH_PRIORITY
    else:
        priority = NORMAL_PRIORITY
    _send_message(payload=payload, topic=IDAES_SOLVE_TOPIC, priority=priority)

133 

134 

def send_idaes_build_state_request_message(payload: BuildStateRequestSchema):
    """Queue an asynchronous build-state request for the IDAES service.

    The message carries a Dapr ``ttlInSeconds`` metadata entry whose value
    comes from ``IDAES_BUILD_STATE_REQUEST_TTL_SECONDS``.
    """
    ttl_metadata = {"ttlInSeconds": IDAES_BUILD_STATE_REQUEST_TTL_SECONDS}
    _send_message(
        payload=payload,
        topic=IDAES_BUILD_STATE_REQUEST_TOPIC,
        publish_metadata=ttl_metadata,
    )

142 

143 

def send_idaes_build_state_response_message(payload: BuildStateCompletionPayload):
    """Publish the IDAES service's build-state completion event.

    Sent on ``IDAES_BUILD_STATE_RESPONSE_TOPIC`` at normal priority.
    """
    _send_message(payload=payload, topic=IDAES_BUILD_STATE_RESPONSE_TOPIC)

147 

148 

def send_idaes_solve_completion_message(payload: IdaesSolveCompletionPayload):
    """Publish the IDAES solver's flowsheet solve-completion event.

    Sent on ``IDAES_SOLVE_COMPLETION_TOPIC`` at normal priority.
    """
    _send_message(payload=payload, topic=IDAES_SOLVE_COMPLETION_TOPIC)

152 

153 

def send_dispatch_multi_solve_message(payload: MultiSolvePayload):
    """Ask a dispatcher to build and send out the tasks of a multi-solve.

    Sent on ``DISPATCH_MULTI_SOLVE_TOPIC`` at normal priority.
    """
    _send_message(payload=payload, topic=DISPATCH_MULTI_SOLVE_TOPIC)

157 

158 

def send_ml_training_message(payload: MLTrainRequestPayload):
    """Request asynchronous ML training for a flowsheet.

    Sent on ``ML_TRAINING_TOPIC`` at normal priority.
    """
    _send_message(payload=payload, topic=ML_TRAINING_TOPIC)

162 

163 

def send_ml_training_completion_message(payload: MLTrainingCompletionPayload):
    """Report that a remote ML training task finished, including its results.

    Sent on ``ML_TRAINING_COMPLETION_TOPIC`` at normal priority.
    """
    _send_message(payload=payload, topic=ML_TRAINING_COMPLETION_TOPIC)

167 

168 

def send_scenario_csv_import_message(payload: ScenarioCsvImportRequestPayload):
    """Request an asynchronous scenario CSV import in Django.

    Sent on ``SCENARIO_CSV_IMPORT_TOPIC`` at normal priority.
    """
    _send_message(payload=payload, topic=SCENARIO_CSV_IMPORT_TOPIC)

172 

173 

def send_pinch_utility_csv_import_message(payload: PinchUtilityCsvImportRequestPayload):
    """Request an asynchronous Pinch utility CSV import in Django.

    Sent on ``PINCH_UTILITY_CSV_IMPORT_TOPIC`` at normal priority.
    """
    _send_message(payload=payload, topic=PINCH_UTILITY_CSV_IMPORT_TOPIC)

177 

178 

def send_solve_completion_email_message(payload: SolveCompletionEmailRequestPayload):
    """Request that a solve-completion email be rendered and delivered.

    Sent on ``SOLVE_COMPLETION_EMAIL_TOPIC`` at normal priority.
    """
    _send_message(payload=payload, topic=SOLVE_COMPLETION_EMAIL_TOPIC)

182 

183 

def send_task_running_message(task_id: int):
    """Announce that the task ``task_id`` is now in the running state.

    A ``TaskPayload`` is sent on ``TASK_RUNNING_TOPIC`` at normal priority.
    """
    running_payload = TaskPayload(task_id=task_id)
    _send_message(payload=running_payload, topic=TASK_RUNNING_TOPIC)

187 

188 

def send_task_cancel_message(task_id: int):
    """Broadcast a cancel request for ``task_id`` on the fanout component.

    The message goes through ``TASK_CANCELLATION_DAPR_PUBSUB_NAME`` instead of
    the default pub/sub component, so that every IDAES replica receives it.
    Whichever instance is actively solving the task (or one of its children)
    can then observe the signal.
    """
    cancel_payload = TaskPayload(task_id=task_id)
    _send_message(
        payload=cancel_payload,
        topic=TASK_CANCEL_TOPIC,
        pubsub_name=TASK_CANCELLATION_DAPR_PUBSUB_NAME,
    )

200 

201 

def send_task_cancelled_message(task_id: int, *, timed_out: bool = False):
    """Tell Django that the remote solver has acknowledged a cancellation.

    Args:
        task_id: Identifier of the cancelled task.
        timed_out: True when termination was caused by a timeout rather than
            an explicit user cancellation. This lets the Django and UI layers
            show an accurate reason.
    """
    cancelled_payload = TaskPayload(task_id=task_id, timed_out=timed_out)
    _send_message(payload=cancelled_payload, topic=TASK_CANCELLED_TOPIC)

212 

213 

def send_flowsheet_notification_message(
    flowsheet_id: int, message_data: dict, message_type: NotificationServiceMessageType
):
    """Push one notification for ``flowsheet_id`` to connected clients.

    The data and type are wrapped in a single ``NotificationServiceMessage``.
    That message is placed inside a ``NotificationServiceEnvelope`` and
    published on ``USER_NOTIFICATION_TOPIC``.
    """
    _send_message(
        NotificationServiceEnvelope(
            flowsheet_id=flowsheet_id,
            messages=[
                NotificationServiceMessage(data=message_data, message_type=message_type)
            ],
        ),
        USER_NOTIFICATION_TOPIC,
    )

225 

226 

def send_flowsheet_notification_messages(
    flowsheet_id: int, messages: list[NotificationServiceMessage]
):
    """Push a batch of notifications for ``flowsheet_id`` in one envelope.

    All ``messages`` share a single ``NotificationServiceEnvelope``, which is
    published on ``USER_NOTIFICATION_TOPIC``.
    """
    batch = NotificationServiceEnvelope(flowsheet_id=flowsheet_id, messages=messages)
    _send_message(payload=batch, topic=USER_NOTIFICATION_TOPIC)