Coverage for backend/common/src/common/services/messaging.py: 94%

50 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2026-03-26 20:57 +0000

1import os 

2 

3from dapr.clients import DaprClient 

4from opentelemetry import trace 

5from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator 

6from pydantic import BaseModel 

7 

8from common.models.general import TaskRunningPayload 

9from ahuora_builder_types.payloads.solve_request_schema import ( 

10 IdaesSolveRequestPayload, 

11 IdaesSolveCompletionPayload, 

12 MultiSolvePayload, 

13) 

14from common.models.idaes.payloads.ml_request_schema import ( 

15 MLTrainRequestPayload, 

16 MLTrainingCompletionEvent, 

17) 

18from common.models.notifications import NotificationServiceEnvelope 

19from common.models.notifications.payloads import ( 

20 NotificationServiceMessage, 

21 NotificationServiceMessageType, 

22) 

23 

# --- Dapr pub/sub wiring -----------------------------------------------------
# Each name below can be overridden by an environment variable of the same
# name; the second argument to os.getenv is the fallback used otherwise.
DAPR_PUBSUB_NAME = os.getenv("DAPR_PUBSUB_NAME", "rabbitmq-pubsub")

# Solver topics.
IDAES_SOLVE_TOPIC = os.getenv("IDAES_SOLVE_TOPIC", "idaes-solve")
IDAES_SOLVE_COMPLETION_TOPIC = os.getenv(
    "IDAES_SOLVE_COMPLETION_TOPIC",
    "idaes-solve-completion",
)
DISPATCH_MULTI_SOLVE_TOPIC = os.getenv(
    "DISPATCH_MULTI_SOLVE_TOPIC",
    "dispatch-multi-solve",
)

# Task lifecycle topic.
TASK_RUNNING_TOPIC = os.getenv("TASK_RUNNING_TOPIC", "task-running")

# Machine-learning topics.
ML_TRAINING_TOPIC = os.getenv("ML_TRAINING_TOPIC", "ml-training")
ML_TRAINING_COMPLETION_TOPIC = os.getenv(
    "ML_TRAINING_COMPLETION_TOPIC",
    "ml-training-completion",
)

# Client-facing notifications.
USER_NOTIFICATION_TOPIC = os.getenv("USER_NOTIFICATION_TOPIC", "user-notification")

# --- Message priorities ------------------------------------------------------
# RabbitMQ quorum queues distinguish only two priority levels: values of 4 or
# less are "normal", anything above 4 is "high". The values are kept as strings
# because they are placed directly into the Dapr publish metadata dict.
NORMAL_PRIORITY = "0"
HIGH_PRIORITY = "5"

43 

44 

def _send_message(payload: BaseModel, topic: str, priority: str = NORMAL_PRIORITY):
    """Publish a Pydantic payload to the configured Dapr pub/sub topic.

    The payload is serialised to JSON and published through the
    ``DAPR_PUBSUB_NAME`` component. Any active trace context is attached to the
    publish metadata alongside the message priority.

    Args:
        payload: Pydantic model to serialise and send.
        topic: Pub/sub topic that should receive the event.
        priority: RabbitMQ priority to tag onto the message metadata
            (``NORMAL_PRIORITY`` unless the caller asks otherwise).
    """
    # Priority is merged last so it always wins over any tracing key.
    publish_metadata = _get_dapr_tracing_headers() | {"priority": priority}

    with DaprClient() as dapr:
        dapr.publish_event(
            pubsub_name=DAPR_PUBSUB_NAME,
            topic_name=topic,
            data=payload.model_dump_json(),
            data_content_type="application/json",
            publish_metadata=publish_metadata,
        )

61 

62 

def _get_dapr_tracing_headers() -> dict[str, str]:
    """Build Dapr publish metadata carrying the current W3C trace context.

    When the current OpenTelemetry span is recording, the W3C trace context is
    injected into a carrier dict and mapped onto Dapr's ``cloudevent.*``
    metadata overrides, so the published event continues the caller's trace.

    Returns:
        A dict with ``cloudevent.traceparent`` and ``cloudevent.traceid`` (plus
        ``cloudevent.tracestate`` when the context carries one). Returns an
        empty dict when no span is recording or no ``traceparent`` was produced.
        Values are always strings, never ``None``.
    """
    if not trace.get_current_span().is_recording():
        return {}

    carrier: dict[str, str] = {}
    TraceContextTextMapPropagator().inject(carrier)

    traceparent = carrier.get("traceparent")
    if not traceparent:
        # Publish metadata is a string-to-string map; never forward None values.
        return {}

    headers = {
        "cloudevent.traceparent": traceparent,
        # Dapr fills its CloudEvent ``traceid`` with the full traceparent string,
        # so both keys carry the same value.
        "cloudevent.traceid": traceparent,
    }

    # Preserve vendor-specific trace state instead of silently dropping it.
    tracestate = carrier.get("tracestate")
    if tracestate:
        headers["cloudevent.tracestate"] = tracestate

    return headers

78 

79 

def send_idaes_solve_message(
    payload: IdaesSolveRequestPayload, high_priority: bool = False
):
    """Queue an IDAES solve request for asynchronous processing.

    Args:
        payload: The solve request, published on ``IDAES_SOLVE_TOPIC``.
        high_priority: When true, the message is tagged with ``HIGH_PRIORITY``.
            High-priority solves are intended to be processed at a two-to-one
            ratio relative to normal-priority solves.
    """
    if high_priority:
        chosen_priority = HIGH_PRIORITY
    else:
        chosen_priority = NORMAL_PRIORITY

    _send_message(payload, IDAES_SOLVE_TOPIC, priority=chosen_priority)

90 

91 

def send_idaes_solve_completion_message(payload: IdaesSolveCompletionPayload):
    """Broadcast that the IDAES solver service finished a flowsheet solve.

    Args:
        payload: Completion details, published on
            ``IDAES_SOLVE_COMPLETION_TOPIC`` at the default (normal) priority.
    """
    completion_topic = IDAES_SOLVE_COMPLETION_TOPIC
    _send_message(payload, topic=completion_topic)

95 

96 

def send_dispatch_multi_solve_message(payload: MultiSolvePayload):
    """Ask the dispatcher to build and send off the tasks of a multi-solve.

    Args:
        payload: Multi-solve description, published on
            ``DISPATCH_MULTI_SOLVE_TOPIC`` at the default (normal) priority.
    """
    _send_message(payload=payload, topic=DISPATCH_MULTI_SOLVE_TOPIC)

100 

101 

def send_ml_training_message(payload: MLTrainRequestPayload):
    """Kick off asynchronous ML training for a flowsheet.

    Args:
        payload: Training request, published on ``ML_TRAINING_TOPIC`` at the
            default (normal) priority.
    """
    training_topic = ML_TRAINING_TOPIC
    _send_message(payload, topic=training_topic)

105 

106 

def send_ml_training_completion_message(payload: MLTrainingCompletionEvent):
    """Report that a remote ML training task finished, including its results.

    Args:
        payload: Completion event, published on
            ``ML_TRAINING_COMPLETION_TOPIC`` at the default (normal) priority.
    """
    _send_message(payload=payload, topic=ML_TRAINING_COMPLETION_TOPIC)

110 

111 

def send_task_running_message(task_id: int):
    """Tell listeners that the given task has entered the running state.

    Args:
        task_id: Identifier of the task, wrapped in a ``TaskRunningPayload``
            and published on ``TASK_RUNNING_TOPIC``.
    """
    running_event = TaskRunningPayload(task_id=task_id)
    _send_message(running_event, topic=TASK_RUNNING_TOPIC)

115 

116 

def send_flowsheet_notification_message(
    flowsheet_id: int, message_data: dict, message_type: NotificationServiceMessageType
):
    """Push one notification for a flowsheet out to connected clients.

    Args:
        flowsheet_id: Flowsheet the notification belongs to.
        message_data: Body of the notification.
        message_type: Kind of notification being sent.
    """
    # A single notification is just a one-element batch, so reuse the batch sender.
    single = NotificationServiceMessage(data=message_data, message_type=message_type)
    send_flowsheet_notification_messages(flowsheet_id, [single])

127 

128 

def send_flowsheet_notification_messages(
    flowsheet_id: int, messages: list[NotificationServiceMessage]
):
    """Push several notifications for one flowsheet in a single envelope.

    Args:
        flowsheet_id: Flowsheet the notifications belong to.
        messages: Notifications to deliver together, published on
            ``USER_NOTIFICATION_TOPIC``.
    """
    _send_message(
        NotificationServiceEnvelope(flowsheet_id=flowsheet_id, messages=messages),
        topic=USER_NOTIFICATION_TOPIC,
    )