Coverage for backend/common/src/common/services/messaging.py: 90%
80 statements
« prev ^ index » next — coverage.py v7.10.7, created at 2026-05-13 02:47 +0000
1import os
2from dapr.clients import DaprClient
3from opentelemetry import trace
4from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
5from pydantic import BaseModel
7from common.models.scenario_import import ScenarioCsvImportRequestPayload
8from common.models.pinch_import import PinchUtilityCsvImportRequestPayload
9from ahuora_builder_types.payloads.build_state_request_schema import (
10 BuildStateCompletionPayload,
11 BuildStateRequestSchema,
12)
13from common.models.solve_completion_email import SolveCompletionEmailRequestPayload
14from ahuora_builder_types.payloads.solve_request_schema import (
15 IdaesSolveRequestPayload,
16 IdaesSolveCompletionPayload,
17 MultiSolvePayload,
18)
19from ahuora_builder_types.payloads.ml_request_schema import (
20 MLTrainRequestPayload,
21 MLTrainingCompletionPayload,
22)
23from common.models.notifications import NotificationServiceEnvelope
24from common.models.general import TaskPayload
25from common.models.notifications.payloads import (
26 NotificationServiceMessage,
27 NotificationServiceMessageType,
28)
# Dapr pub/sub component names (overridable via environment variables).
DAPR_PUBSUB_NAME = os.getenv("DAPR_PUBSUB_NAME", "rabbitmq-pubsub")
TASK_CANCELLATION_DAPR_PUBSUB_NAME = os.getenv(
    "TASK_CANCELLATION_DAPR_PUBSUB_NAME", "rabbitmq-pubsub-task-cancellation"
)

# IDAES solve / build-state topics.
IDAES_SOLVE_TOPIC = os.getenv("IDAES_SOLVE_TOPIC", "idaes-solve")
IDAES_BUILD_STATE_REQUEST_TOPIC = os.getenv(
    "IDAES_BUILD_STATE_REQUEST_TOPIC", "idaes-build-state-request"
)
IDAES_BUILD_STATE_REQUEST_TTL_SECONDS = os.getenv(
    "IDAES_BUILD_STATE_REQUEST_TTL_SECONDS", "300"
)
IDAES_BUILD_STATE_RESPONSE_TOPIC = os.getenv(
    "IDAES_BUILD_STATE_RESPONSE_TOPIC", "idaes-build-state-response"
)
IDAES_SOLVE_COMPLETION_TOPIC = os.getenv(
    "IDAES_SOLVE_COMPLETION_TOPIC", "idaes-solve-completion"
)
DISPATCH_MULTI_SOLVE_TOPIC = os.getenv(
    "DISPATCH_MULTI_SOLVE_TOPIC", "dispatch-multi-solve"
)

# Task lifecycle topics.
TASK_RUNNING_TOPIC = os.getenv("TASK_RUNNING_TOPIC", "task-running")
TASK_CANCEL_TOPIC = os.getenv("TASK_CANCEL_TOPIC", "task-cancel")
TASK_CANCELLED_TOPIC = os.getenv("TASK_CANCELLED_TOPIC", "task-cancelled")

# ML training topics.
ML_TRAINING_TOPIC = os.getenv("ML_TRAINING_TOPIC", "ml-training")
ML_TRAINING_COMPLETION_TOPIC = os.getenv(
    "ML_TRAINING_COMPLETION_TOPIC", "ml-training-completion"
)

# CSV import, email, and user-notification topics.
SCENARIO_CSV_IMPORT_TOPIC = os.getenv(
    "SCENARIO_CSV_IMPORT_TOPIC", "scenario-csv-import"
)
PINCH_UTILITY_CSV_IMPORT_TOPIC = os.getenv(
    "PINCH_UTILITY_CSV_IMPORT_TOPIC", "pinch-utility-csv-import"
)
SOLVE_COMPLETION_EMAIL_TOPIC = os.getenv(
    "SOLVE_COMPLETION_EMAIL_TOPIC", "solve-completion-email"
)
USER_NOTIFICATION_TOPIC = os.getenv("USER_NOTIFICATION_TOPIC", "user-notification")

# RabbitMQ quorum queues distinguish only "normal" and "high" message
# priorities: any priority value above 4 is treated as "high".
HIGH_PRIORITY = "5"
NORMAL_PRIORITY = "0"
def _send_message(
    payload: BaseModel,
    topic: str,
    priority: str = NORMAL_PRIORITY,
    pubsub_name: str = DAPR_PUBSUB_NAME,
    publish_metadata: dict[str, str] | None = None,
):
    """Publish a Pydantic payload to the configured Dapr pub/sub topic.

    Args:
        payload: Pydantic model to serialise and send.
        topic: Pub/sub topic that should receive the event. Callers pass the
            module-level topic constants, which are already resolved from
            environment variables.
        priority: Optional RabbitMQ priority to tag onto the message metadata.
        pubsub_name: Dapr pub/sub component name to publish through.
        publish_metadata: Extra Dapr publish metadata to merge onto the event;
            keys given here override the tracing/priority defaults.
    """
    metadata = _get_dapr_tracing_headers() | {"priority": priority}
    if publish_metadata:
        metadata |= publish_metadata

    # NOTE(review): a fresh DaprClient is created per publish; fine at low
    # message volume, worth revisiting if publishing becomes a hot path.
    with DaprClient() as dapr:
        dapr.publish_event(
            pubsub_name=pubsub_name,
            topic_name=topic,
            data=payload.model_dump_json(),
            data_content_type="application/json",
            publish_metadata=metadata,
        )
def _get_dapr_tracing_headers():
    """Build Dapr cloudevent tracing metadata from the active OTel span.

    Returns an empty dict when no recording span is active, so callers can
    merge the result unconditionally.
    """
    span = trace.get_current_span()
    if not span.is_recording():
        return {}

    carrier = {}
    TraceContextTextMapPropagator().inject(carrier)
    traceparent = carrier.get("traceparent", None)
    # Both cloudevent fields intentionally carry the same W3C traceparent
    # value — NOTE(review): confirm against Dapr's pub/sub tracing docs that
    # "cloudevent.traceid" expects the full traceparent rather than a bare
    # trace ID.
    return {
        "cloudevent.traceparent": traceparent,
        "cloudevent.traceid": traceparent,
    }
def send_idaes_solve_message(
    payload: IdaesSolveRequestPayload, high_priority: bool = False
):
    """Queue a solve request for the IDAES service asynchronously.

    High-priority solves are processed at a two-to-one ratio relative to
    normal-priority solves.
    """
    _send_message(
        payload,
        IDAES_SOLVE_TOPIC,
        priority=HIGH_PRIORITY if high_priority else NORMAL_PRIORITY,
    )
def send_idaes_build_state_request_message(payload: BuildStateRequestSchema):
    """Ask the IDAES service to build flowsheet state, attaching a message TTL
    so stale requests are not consumed long after being queued."""
    ttl_metadata = {"ttlInSeconds": IDAES_BUILD_STATE_REQUEST_TTL_SECONDS}
    _send_message(
        payload, IDAES_BUILD_STATE_REQUEST_TOPIC, publish_metadata=ttl_metadata
    )
def send_idaes_build_state_response_message(payload: BuildStateCompletionPayload):
    """Publish the build-state completion event produced by the IDAES service."""
    _send_message(payload, topic=IDAES_BUILD_STATE_RESPONSE_TOPIC)
def send_idaes_solve_completion_message(payload: IdaesSolveCompletionPayload):
    """Publish the flowsheet solve completion event from the IDAES solver service."""
    _send_message(payload, topic=IDAES_SOLVE_COMPLETION_TOPIC)
def send_dispatch_multi_solve_message(payload: MultiSolvePayload):
    """Request that a dispatcher fan the work out into individual multi-solve tasks."""
    _send_message(payload, topic=DISPATCH_MULTI_SOLVE_TOPIC)
def send_ml_training_message(payload: MLTrainRequestPayload):
    """Trigger asynchronous ML training for a flowsheet."""
    _send_message(payload, topic=ML_TRAINING_TOPIC)
def send_ml_training_completion_message(payload: MLTrainingCompletionPayload):
    """Publish the results of a finished remote ML training task."""
    _send_message(payload, topic=ML_TRAINING_COMPLETION_TOPIC)
def send_scenario_csv_import_message(payload: ScenarioCsvImportRequestPayload):
    """Trigger an asynchronous scenario CSV import in Django."""
    _send_message(payload, topic=SCENARIO_CSV_IMPORT_TOPIC)
def send_pinch_utility_csv_import_message(payload: PinchUtilityCsvImportRequestPayload):
    """Trigger an asynchronous Pinch utility CSV import in Django."""
    _send_message(payload, topic=PINCH_UTILITY_CSV_IMPORT_TOPIC)
def send_solve_completion_email_message(payload: SolveCompletionEmailRequestPayload):
    """Request that a solve completion email be rendered and delivered."""
    _send_message(payload, topic=SOLVE_COMPLETION_EMAIL_TOPIC)
def send_task_running_message(task_id: int):
    """Announce that the task identified by *task_id* is now running."""
    running_payload = TaskPayload(task_id=task_id)
    _send_message(running_payload, topic=TASK_RUNNING_TOPIC)
def send_task_cancel_message(task_id: int):
    """Broadcast a cancellation request for *task_id*.

    Goes through the dedicated fanout pub/sub component so that every IDAES
    replica observes the signal — whichever instance is actively solving the
    task (or one of its children) must be able to react.
    """
    cancel_payload = TaskPayload(task_id=task_id)
    _send_message(
        cancel_payload,
        topic=TASK_CANCEL_TOPIC,
        pubsub_name=TASK_CANCELLATION_DAPR_PUBSUB_NAME,
    )
def send_task_cancelled_message(task_id: int, *, timed_out: bool = False):
    """Tell Django that the remote solver has acknowledged cancellation.

    The keyword-only `timed_out` flag distinguishes timeout-driven termination
    from an explicit user cancellation so Django/UI layers can report an
    accurate reason.
    """
    cancelled_payload = TaskPayload(task_id=task_id, timed_out=timed_out)
    _send_message(cancelled_payload, topic=TASK_CANCELLED_TOPIC)
def send_flowsheet_notification_message(
    flowsheet_id: int, message_data: dict, message_type: NotificationServiceMessageType
):
    """Publish a single notification for one flowsheet to connected clients."""
    single_message = NotificationServiceMessage(
        data=message_data,
        message_type=message_type,
    )
    envelope = NotificationServiceEnvelope(
        flowsheet_id=flowsheet_id,
        messages=[single_message],
    )
    _send_message(envelope, topic=USER_NOTIFICATION_TOPIC)
def send_flowsheet_notification_messages(
    flowsheet_id: int, messages: list[NotificationServiceMessage]
):
    """Publish a batch of notifications for one flowsheet in a single envelope."""
    envelope = NotificationServiceEnvelope(
        flowsheet_id=flowsheet_id,
        messages=messages,
    )
    _send_message(envelope, topic=USER_NOTIFICATION_TOPIC)