Coverage for backend/notifications/consumers/NotificationsConsumer.py: 70%
40 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-11-06 23:27 +0000
« prev ^ index » next coverage.py v7.10.7, created at 2025-11-06 23:27 +0000
1"""Websocket consumer that streams flowsheet notifications to authenticated users."""
3from channels.generic.websocket import AsyncWebsocketConsumer
4from CoreRoot import settings
5from CoreRoot.helpers import get_asgi_header_value
6from authentication.user.AccessTable import AccessTable
7from authentication.user.models import User
async def _get_user(username):
    """Fetch the ``User`` named *username*, loading only its primary key.

    Websocket authorisation only needs ``id``, so the queryset is restricted
    with ``only("id")``. Any exception raised by ``aget`` (for example
    ``User.DoesNotExist`` when no row matches) propagates to the caller.
    """
    id_only_users = User.objects.only("id")
    return await id_only_users.aget(username=username)
async def _get_flowsheet_access_entry_for_user(user_id: int, flowsheet_id: int | str) -> AccessTable | None:
    """Return the first ``AccessTable`` row linking *user_id* to *flowsheet_id*.

    Args:
        user_id: Primary key of the user whose access is being checked.
        flowsheet_id: Identifier of the flowsheet. ``NotificationsConsumer.connect``
            passes the raw ``flowsheetId`` query-parameter value, so a ``str``
            is accepted as well as an ``int``; it is handed to the ORM filter
            unchanged.

    Returns:
        The matching ``AccessTable`` row, or ``None`` when the user has no
        entry for that flowsheet.
    """
    matching_entries = AccessTable.objects.filter(flowsheet_id=flowsheet_id, user_id=user_id)
    return await matching_entries.afirst()
class NotificationsConsumer(AsyncWebsocketConsumer):
    """Manage a flowsheet-scoped websocket connection for notification delivery.

    The user is identified by the ASGI remote-user header and the flowsheet by
    the ``flowsheetId`` query parameter. Once the user is confirmed to have an
    ``AccessTable`` entry for that flowsheet, the socket joins a channel-layer
    group named after the flowsheet id. Events dispatched to
    ``flowsheet_message`` are relayed to the client.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Set by ``connect`` only after the user has been authorised, so a
        # non-None value tells ``disconnect`` that this socket joined the
        # flowsheet group.
        self.user_id = None
        # Raw ``flowsheetId`` query-parameter value; also used as the
        # channel-layer group name.
        self.flowsheet_id = None

    def _group_name(self) -> str:
        """Return the channel-layer group name for this socket's flowsheet."""
        return f"{self.flowsheet_id}"

    async def connect(self):
        """Authenticate the user and join the flowsheet broadcast group.

        The handshake is closed with reason ``"Unauthorized"`` in three cases:
        the remote-user header is missing, it names a user with no ``User``
        row, or that user has no ``AccessTable`` entry for the flowsheet. It
        is closed with a missing-parameter reason when ``flowsheetId`` is
        absent. Otherwise the socket is accepted and added to the flowsheet
        group.
        """
        username = get_asgi_header_value(self.scope["headers"], settings.ASGI_REMOTE_USER_HEADER)
        # NOTE(review): ``query_params`` is not a standard ASGI scope key; it is
        # presumably populated by project middleware as a ``{name: [values]}``
        # mapping -- confirm.
        self.flowsheet_id = self.scope["query_params"].get("flowsheetId", [None])[0]

        if not username:
            await self.close(reason="Unauthorized")
            return

        if not self.flowsheet_id:
            await self.close(reason="Missing flowsheet_id query parameter")
            return

        try:
            user = await _get_user(username)
        except User.DoesNotExist:
            # The remote-user header named someone with no matching ``User``
            # row; reject the handshake rather than letting the lookup error
            # escape from ``connect``.
            await self.close(reason="Unauthorized")
            return

        flowsheet_access_entry = await _get_flowsheet_access_entry_for_user(user.id, self.flowsheet_id)
        if flowsheet_access_entry is None:
            await self.close(reason="Unauthorized")
            return

        # Record the user only once access is confirmed. ``disconnect`` relies
        # on ``user_id`` being set to decide whether to leave the group.
        self.user_id = user.id

        await self.accept()

        # Register this socket with the flowsheet-specific broadcast group, so
        # a broadcast to the flowsheet reaches every socket currently
        # subscribed to it.
        await self.channel_layer.group_add(self._group_name(), self.channel_name)

    async def receive(self, text_data=None, bytes_data=None):
        """Respond to heartbeat pings from the client to keep the socket alive.

        Only the exact text frame ``"ping"`` is answered (with ``"pong"``);
        every other frame, including binary frames, is ignored.
        """
        if text_data == "ping":
            await self.send(text_data="pong")

    async def disconnect(self, code):
        """Remove the socket from the flowsheet broadcast group on close.

        The group is left only when ``connect`` authorised the user, i.e. when
        this socket can have joined the group.
        """
        if self.user_id is not None and self.flowsheet_id:
            await self.channel_layer.group_discard(self._group_name(), self.channel_name)

    async def flowsheet_message(self, event):
        """Forward a channel-layer broadcast event payload to the client.

        Sends ``event["data"]`` as a text frame. Assumes the broadcaster puts
        a ``str`` under ``"data"`` -- TODO confirm against the sender.
        """
        await self.send(text_data=event["data"])