Coverage for backend/common/src/common/services/messaging.py: 94%
50 statements
« prev ^ index » next coverage.py v7.10.7, created at 2026-03-26 20:57 +0000
« prev ^ index » next coverage.py v7.10.7, created at 2026-03-26 20:57 +0000
1import os
3from dapr.clients import DaprClient
4from opentelemetry import trace
5from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
6from pydantic import BaseModel
8from common.models.general import TaskRunningPayload
9from ahuora_builder_types.payloads.solve_request_schema import (
10 IdaesSolveRequestPayload,
11 IdaesSolveCompletionPayload,
12 MultiSolvePayload,
13)
14from common.models.idaes.payloads.ml_request_schema import (
15 MLTrainRequestPayload,
16 MLTrainingCompletionEvent,
17)
18from common.models.notifications import NotificationServiceEnvelope
19from common.models.notifications.payloads import (
20 NotificationServiceMessage,
21 NotificationServiceMessageType,
22)
# Dapr pub/sub component name; overridable via the environment.
DAPR_PUBSUB_NAME = os.environ.get("DAPR_PUBSUB_NAME", "rabbitmq-pubsub")

# Topic names used by the publishing helpers below. Each one can be
# overridden through an environment variable of the same name.
IDAES_SOLVE_TOPIC = os.environ.get("IDAES_SOLVE_TOPIC", "idaes-solve")
IDAES_SOLVE_COMPLETION_TOPIC = os.environ.get(
    "IDAES_SOLVE_COMPLETION_TOPIC", "idaes-solve-completion"
)
DISPATCH_MULTI_SOLVE_TOPIC = os.environ.get(
    "DISPATCH_MULTI_SOLVE_TOPIC", "dispatch-multi-solve"
)
TASK_RUNNING_TOPIC = os.environ.get("TASK_RUNNING_TOPIC", "task-running")
ML_TRAINING_TOPIC = os.environ.get("ML_TRAINING_TOPIC", "ml-training")
ML_TRAINING_COMPLETION_TOPIC = os.environ.get(
    "ML_TRAINING_COMPLETION_TOPIC", "ml-training-completion"
)
USER_NOTIFICATION_TOPIC = os.environ.get(
    "USER_NOTIFICATION_TOPIC", "user-notification"
)

# Message priorities attached as string metadata. RabbitMQ quorum queues only
# distinguish "normal" and "high" priority: any value above 4 counts as high,
# so "5" is the smallest high value and "0" is the normal default.
HIGH_PRIORITY = "5"
NORMAL_PRIORITY = "0"
def _send_message(payload: BaseModel, topic: str, priority: str = NORMAL_PRIORITY):
    """Serialise a Pydantic model to JSON and publish it through Dapr pub/sub.

    A fresh ``DaprClient`` is opened (and closed) for every call.

    Args:
        payload: Pydantic model whose JSON form becomes the event body.
        topic: Name of the pub/sub topic to publish to.
        priority: RabbitMQ priority value placed in the publish metadata
            alongside any W3C trace-context entries.
    """
    with DaprClient() as client:
        body = payload.model_dump_json()
        # Trace metadata first, then "priority" so it always wins on a clash.
        metadata = {**_get_dapr_tracing_headers(), "priority": priority}
        client.publish_event(
            pubsub_name=DAPR_PUBSUB_NAME,
            topic_name=topic,
            data=body,
            data_content_type="application/json",
            publish_metadata=metadata,
        )
def _get_dapr_tracing_headers() -> dict[str, str]:
    """Build Dapr publish metadata carrying the current W3C trace context.

    Returns:
        An empty dict when the current span is not recording, or when the
        propagator produced no ``traceparent`` value. Otherwise a dict with
        the ``cloudevent.traceparent`` and ``cloudevent.traceid`` metadata
        keys, both set to the injected ``traceparent`` string.
    """
    current_span = trace.get_current_span()

    # Only forward trace context for spans that are actually being recorded.
    if not current_span.is_recording():
        return {}

    carrier: dict[str, str] = {}
    TraceContextTextMapPropagator().inject(carrier)
    traceparent = carrier.get("traceparent")

    # The propagator skips injection when the span context is invalid.
    # Publish metadata values are strings, so omit the tracing keys entirely
    # rather than sending None values.
    if traceparent is None:
        return {}

    # traceid intentionally mirrors traceparent (unchanged from the previous
    # behaviour). NOTE(review): Dapr's CloudEvent "traceid" field appears to
    # hold the full traceparent string rather than the bare trace id.
    # Confirm against the Dapr pub/sub CloudEvents docs.
    return {
        "cloudevent.traceparent": traceparent,
        "cloudevent.traceid": traceparent,
    }
def send_idaes_solve_message(
    payload: IdaesSolveRequestPayload, high_priority: bool = False
):
    """Queue an asynchronous IDAES solve request.

    Args:
        payload: The solve request to publish on the IDAES solve topic.
        high_priority: When truthy, the message is tagged with
            ``HIGH_PRIORITY``; otherwise with ``NORMAL_PRIORITY``.
            High-priority solves are intended to be consumed at a two-to-one
            ratio relative to normal ones. That scheduling happens on the
            consumer side; this function only sets the metadata.
    """
    if high_priority:
        level = HIGH_PRIORITY
    else:
        level = NORMAL_PRIORITY
    _send_message(payload, IDAES_SOLVE_TOPIC, priority=level)
def send_idaes_solve_completion_message(payload: IdaesSolveCompletionPayload):
    """Publish the completion event for a flowsheet solve run by the IDAES service.

    Args:
        payload: Completion details, sent at normal priority.
    """
    _send_message(payload, topic=IDAES_SOLVE_COMPLETION_TOPIC)
def send_dispatch_multi_solve_message(payload: MultiSolvePayload):
    """Ask the dispatcher to build and send out the individual multi-solve tasks.

    Args:
        payload: Multi-solve description, sent at normal priority.
    """
    _send_message(payload, topic=DISPATCH_MULTI_SOLVE_TOPIC)
def send_ml_training_message(payload: MLTrainRequestPayload):
    """Request asynchronous ML training for a flowsheet.

    Args:
        payload: Training request, sent at normal priority.
    """
    _send_message(payload, topic=ML_TRAINING_TOPIC)
def send_ml_training_completion_message(payload: MLTrainingCompletionEvent):
    """Broadcast that a remote ML training task finished, including its results.

    Args:
        payload: Completion event, sent at normal priority.
    """
    _send_message(payload, topic=ML_TRAINING_COMPLETION_TOPIC)
def send_task_running_message(task_id: int):
    """Tell listeners that the given task is now in the running state.

    Args:
        task_id: Identifier wrapped in a ``TaskRunningPayload`` before publishing.
    """
    running_event = TaskRunningPayload(task_id=task_id)
    _send_message(running_event, topic=TASK_RUNNING_TOPIC)
def send_flowsheet_notification_message(
    flowsheet_id: int, message_data: dict, message_type: NotificationServiceMessageType
):
    """Publish one client notification for a flowsheet.

    The message is wrapped in a single-item ``NotificationServiceEnvelope``
    and sent on the user-notification topic.

    Args:
        flowsheet_id: Flowsheet the notification belongs to.
        message_data: Arbitrary data carried by the notification.
        message_type: Kind of notification being sent.
    """
    _send_message(
        NotificationServiceEnvelope(
            flowsheet_id=flowsheet_id,
            messages=[
                NotificationServiceMessage(
                    data=message_data, message_type=message_type
                )
            ],
        ),
        USER_NOTIFICATION_TOPIC,
    )
def send_flowsheet_notification_messages(
    flowsheet_id: int, messages: list[NotificationServiceMessage]
):
    """Publish several client notifications for a flowsheet in one envelope.

    Args:
        flowsheet_id: Flowsheet the notifications belong to.
        messages: Notifications to deliver together, in order.
    """
    batch = NotificationServiceEnvelope(flowsheet_id=flowsheet_id, messages=messages)
    _send_message(batch, topic=USER_NOTIFICATION_TOPIC)