Coverage for backend/common/services/messaging.py: 81%

50 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-11-06 23:27 +0000

1import os 

2 

3from dapr.clients import DaprClient 

4from opentelemetry import trace 

5from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator 

6from pydantic import BaseModel 

7 

8from common.models.general import TaskRunningPayload 

9from common.models.idaes.payloads.solve_request_schema import IdaesSolveRequestPayload, IdaesSolveCompletionPayload, \ 

10 MultiSolvePayload 

11from common.models.idaes.payloads.ml_request_schema import MLTrainRequestPayload, MLTrainingCompletionEvent 

12from common.models.notifications import NotificationServiceEnvelope 

13from common.models.notifications.payloads import NotificationServiceMessage, NotificationServiceMessageType 

14 

# Name of the Dapr pub/sub component; overridable through the environment.
DAPR_PUBSUB_NAME = os.environ.get("DAPR_PUBSUB_NAME", "rabbitmq-pubsub")

# Topic names, each overridable through an environment variable of the same name.
IDAES_SOLVE_TOPIC = os.environ.get("IDAES_SOLVE_TOPIC", "idaes-solve")
IDAES_SOLVE_COMPLETION_TOPIC = os.environ.get(
    "IDAES_SOLVE_COMPLETION_TOPIC", "idaes-solve-completion"
)
DISPATCH_MULTI_SOLVE_TOPIC = os.environ.get(
    "DISPATCH_MULTI_SOLVE_TOPIC", "dispatch-multi-solve"
)
TASK_RUNNING_TOPIC = os.environ.get("TASK_RUNNING_TOPIC", "task-running")
ML_TRAINING_TOPIC = os.environ.get("ML_TRAINING_TOPIC", "ml-training")
ML_TRAINING_COMPLETION_TOPIC = os.environ.get(
    "ML_TRAINING_COMPLETION_TOPIC", "ml-training-completion"
)
USER_NOTIFICATION_TOPIC = os.environ.get("USER_NOTIFICATION_TOPIC", "user-notification")

# RabbitMQ quorum queues distinguish only two priority levels, "normal" and
# "high"; any priority above 4 is treated as "high".
NORMAL_PRIORITY = "0"
HIGH_PRIORITY = "5"

28 

def _send_message(payload: BaseModel, topic: str, priority: str = NORMAL_PRIORITY):
    """Serialise a Pydantic model to JSON and publish it through Dapr pub/sub.

    The event is published on the pub/sub component named by
    ``DAPR_PUBSUB_NAME``. The publish metadata consists of whatever
    ``_get_dapr_tracing_headers()`` returns plus a ``priority`` entry.

    Args:
        payload: Pydantic model whose JSON dump becomes the event body.
        topic: Name of the pub/sub topic to publish on.
        priority: Value stored under the ``priority`` metadata key;
            defaults to ``NORMAL_PRIORITY``.
    """
    with DaprClient() as client:
        body = payload.model_dump_json()
        # The priority entry is merged last, so it overrides any same-named
        # key coming from the tracing headers.
        metadata = {**_get_dapr_tracing_headers(), 'priority': priority}
        client.publish_event(
            pubsub_name=DAPR_PUBSUB_NAME,
            topic_name=topic,
            data=body,
            data_content_type='application/json',
            publish_metadata=metadata,
        )

45 

def _get_dapr_tracing_headers() -> dict[str, str]:
    """Build Dapr publish metadata carrying the current W3C trace context.

    Returns:
        A dict with ``cloudevent.traceparent`` and ``cloudevent.traceid``,
        both set to the injected W3C ``traceparent`` value, when the current
        span is recording and a ``traceparent`` was injected. Otherwise an
        empty dict, so the caller publishes without trace metadata.
    """
    # Only a recording span has trace context worth forwarding to Dapr.
    if not trace.get_current_span().is_recording():
        return {}

    carrier = {}
    TraceContextTextMapPropagator().inject(carrier)
    traceparent = carrier.get('traceparent')

    # The result is merged into the publish metadata, which is a str -> str
    # map. Returning None values here would make the publish call fail, so
    # omit the trace headers entirely when nothing was injected.
    if traceparent is None:
        return {}

    # NOTE(review): both keys carry the full traceparent value, as in the
    # original code; presumably this matches what Dapr expects for the
    # CloudEvent ``traceid`` field -- confirm against the Dapr docs.
    return {
        'cloudevent.traceparent': traceparent,
        'cloudevent.traceid': traceparent,
    }

61 

def send_idaes_solve_message(payload: IdaesSolveRequestPayload, high_priority: bool = False):
    """Queue a solve request for the IDAES service on ``IDAES_SOLVE_TOPIC``.

    Args:
        payload: The solve request to publish.
        high_priority: When true the message is published with
            ``HIGH_PRIORITY`` instead of ``NORMAL_PRIORITY``. High priority
            solves are processed at a two to one ratio to normal priority
            solves.
    """
    chosen_priority = NORMAL_PRIORITY
    if high_priority:
        chosen_priority = HIGH_PRIORITY

    _send_message(payload, IDAES_SOLVE_TOPIC, priority=chosen_priority)

70 

def send_idaes_solve_completion_message(payload: IdaesSolveCompletionPayload):
    """Publish a solve completion event on ``IDAES_SOLVE_COMPLETION_TOPIC``.

    Args:
        payload: Completion event to publish, sent at the default priority.
    """
    _send_message(payload=payload, topic=IDAES_SOLVE_COMPLETION_TOPIC)

74 

def send_dispatch_multi_solve_message(payload: MultiSolvePayload):
    """Publish a multi-solve dispatch request on ``DISPATCH_MULTI_SOLVE_TOPIC``.

    The message instructs a dispatcher to construct and send off the
    individual multi-solve tasks.

    Args:
        payload: Multi-solve request to publish, sent at the default priority.
    """
    _send_message(payload=payload, topic=DISPATCH_MULTI_SOLVE_TOPIC)

78 

def send_ml_training_message(payload: MLTrainRequestPayload):
    """Publish an ML training request on ``ML_TRAINING_TOPIC``.

    Args:
        payload: Training request to publish, sent at the default priority.
    """
    _send_message(payload=payload, topic=ML_TRAINING_TOPIC)

82 

def send_ml_training_completion_message(payload: MLTrainingCompletionEvent):
    """Publish an ML training completion event on ``ML_TRAINING_COMPLETION_TOPIC``.

    Args:
        payload: Completion event (with the training results) to publish,
            sent at the default priority.
    """
    _send_message(payload=payload, topic=ML_TRAINING_COMPLETION_TOPIC)

86 

def send_task_running_message(task_id: int):
    """Publish a ``TaskRunningPayload`` for the given task on ``TASK_RUNNING_TOPIC``.

    Args:
        task_id: Identifier of the task that has entered the running state.
    """
    running_event = TaskRunningPayload(task_id=task_id)
    _send_message(running_event, TASK_RUNNING_TOPIC)

90 

def send_flowsheet_notification_message(flowsheet_id: int, message_data: dict, message_type: NotificationServiceMessageType):
    """Publish one notification for a flowsheet on ``USER_NOTIFICATION_TOPIC``.

    The data and type are wrapped in a ``NotificationServiceMessage``, which
    is sent as the only entry of a ``NotificationServiceEnvelope``.

    Args:
        flowsheet_id: Flowsheet the notification belongs to.
        message_data: Content of the notification.
        message_type: Kind of notification being sent.
    """
    envelope = NotificationServiceEnvelope(
        flowsheet_id=flowsheet_id,
        messages=[
            NotificationServiceMessage(data=message_data, message_type=message_type),
        ],
    )
    _send_message(envelope, USER_NOTIFICATION_TOPIC)

97 

def send_flowsheet_notification_messages(flowsheet_id: int, messages: list[NotificationServiceMessage]):
    """Publish several notifications for a flowsheet in a single envelope.

    Args:
        flowsheet_id: Flowsheet the notifications belong to.
        messages: Notifications to place in the envelope, passed through
            unchanged.
    """
    batch = NotificationServiceEnvelope(flowsheet_id=flowsheet_id, messages=messages)
    _send_message(payload=batch, topic=USER_NOTIFICATION_TOPIC)

101 _send_message(envelope, USER_NOTIFICATION_TOPIC)