Coverage for backend/notifications/consumers/NotificationsConsumer.py: 70%

40 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-11-06 23:27 +0000

1"""Websocket consumer that streams flowsheet notifications to authenticated users.""" 

2 

3from channels.generic.websocket import AsyncWebsocketConsumer 

4from CoreRoot import settings 

5from CoreRoot.helpers import get_asgi_header_value 

6from authentication.user.AccessTable import AccessTable 

7from authentication.user.models import User 

8 

9 

async def _get_user(username):
    """Fetch the user matching ``username`` for websocket authentication.

    Only the ``id`` column is requested, since that is all the consumer reads.
    The lookup goes through ``aget``, so a missing user is not turned into
    ``None`` here; whatever ``aget`` raises propagates to the caller.
    """
    id_only_users = User.objects.only("id")
    return await id_only_users.aget(username=username)

15 

16 

async def _get_flowsheet_access_entry_for_user(user_id: int, flowsheet_id: int) -> AccessTable | None:
    """Return the AccessTable row linking ``user_id`` to ``flowsheet_id``.

    Yields ``None`` when the user has no row for that flowsheet, which the
    caller treats as "no access".
    """
    matching_entries = AccessTable.objects.filter(flowsheet_id=flowsheet_id, user_id=user_id)
    return await matching_entries.afirst()

24 

25 

class NotificationsConsumer(AsyncWebsocketConsumer):
    """Manage a flowsheet-scoped websocket connection for notification delivery.

    On connect, the consumer reads the username from the configured remote-user
    header and the ``flowsheetId`` query parameter, checks that the user has an
    AccessTable row for that flowsheet, and then joins the channel-layer group
    named after the flowsheet id. Events delivered to that group are forwarded
    to the client by ``flowsheet_message``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Both are filled in by ``connect`` and stay ``None`` until then.
        self.user_id = None
        self.flowsheet_id = None
        # Set only after ``group_add`` succeeds, so ``disconnect`` never tries
        # to leave a group that a rejected socket did not join.
        self._joined_group = False

    async def connect(self):
        """Authenticate the user and join the flowsheet broadcast group.

        The socket is closed without being accepted when the username header
        is missing, the ``flowsheetId`` query parameter is missing, the user
        does not exist, or the user has no access entry for the flowsheet.
        """

        username = get_asgi_header_value(self.scope["headers"], settings.ASGI_REMOTE_USER_HEADER)
        # Values in ``query_params`` are indexed as lists; take the first one.
        self.flowsheet_id = self.scope["query_params"].get("flowsheetId", [None])[0]

        if not username:
            await self.close(reason="Unauthorized")
            return

        if not self.flowsheet_id:
            await self.close(reason="Missing flowsheet_id query parameter")
            return

        try:
            user = await _get_user(username)
        except User.DoesNotExist:
            # An unknown username is rejected the same way as a missing one,
            # instead of letting the lookup error crash the handshake.
            await self.close(reason="Unauthorized")
            return

        self.user_id = user.id

        flowsheet_access_entry = await _get_flowsheet_access_entry_for_user(user.id, self.flowsheet_id)
        if flowsheet_access_entry is None:
            await self.close(reason="Unauthorized")
            return

        await self.accept()

        # The group is named after the flowsheet id alone, so a group broadcast
        # (e.g. from `broadcast_view`) reaches every socket connected to this
        # flowsheet, not only this user's sessions.
        await self.channel_layer.group_add(f"{self.flowsheet_id}", self.channel_name)
        self._joined_group = True

    async def receive(self, text_data=None, bytes_data=None):
        """Respond to heartbeat pings from the client to keep the socket alive.

        Only the exact text frame ``"ping"`` is answered (with ``"pong"``);
        every other text or binary frame is ignored.
        """
        if text_data == "ping":
            await self.send(text_data="pong")

    async def disconnect(self, code):
        """Remove the socket from the flowsheet broadcast group on close.

        Does nothing for sockets that were rejected during ``connect`` and
        therefore never joined the group.
        """
        if self._joined_group:
            await self.channel_layer.group_discard(f"{self.flowsheet_id}", self.channel_name)
            self._joined_group = False

    async def flowsheet_message(self, event):
        """Forward a channel-layer broadcast event payload to the client.

        ``event["data"]`` is sent as a text frame unchanged.
        """
        await self.send(text_data=event["data"])