Coverage for backend/core/validation.py: 84%

78 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-11-06 23:27 +0000

1import contextvars 

2from contextlib import contextmanager 

3from functools import wraps 

4 

5from authentication.user.models import User 

6from core.viewset import ModelViewSet, ReadOnlyModelViewSet 

7from django.urls.resolvers import URLPattern, URLResolver 

8from rest_framework.exceptions import ValidationError 

9from rest_framework.request import Request 

10from rest_framework.viewsets import ModelViewSet as MDVS 

11 

12 

def sanitize_flowsheet_id(flowsheet_id):
    """
    Ensure the client has provided a valid flowsheet id (positive integer).

    The value is converted with ``int()`` purely for validation; nothing is
    returned, so callers keep using the value they passed in.

    Raises:
        ValidationError: if the value cannot be converted to an integer, or
            if the converted integer is less than 1.
    """
    try:
        parsed_id = int(flowsheet_id)
    except (TypeError, ValueError, OverflowError) as err:
        # Translate only conversion failures (None, non-numeric text,
        # infinities). A bare ``except`` here would also hide unrelated
        # errors such as KeyboardInterrupt.
        raise ValidationError("Invalid flowsheet id") from err

    # Zero and negative ids are both rejected by this single check.
    if parsed_id < 1:
        raise ValidationError("Invalid flowsheet id")

25 

26 

def api_view_validate(view_func):
    """
    Decorator that runs an api_view inside the access-control context.

    The wrapper reads the ``flowsheet`` query parameter and the request's
    user, rejects the request when the flowsheet id is not valid, and then
    calls the wrapped view within ``flowsheet_context``. The returned wrapper
    is tagged with ``_is_api_view_validated`` so that ``validate_urlpatterns``
    can recognise decorated views.
    """

    @wraps(view_func)
    def _validated(request: Request, *args, **kwargs):
        raw_flowsheet_id = request.GET.get("flowsheet")
        current_user = request.user
        # Raises ValidationError before the view is ever invoked.
        sanitize_flowsheet_id(raw_flowsheet_id)

        with flowsheet_context(raw_flowsheet_id, current_user):
            response = view_func(request, *args, **kwargs)
        return response

    setattr(_validated, "_is_api_view_validated", True)
    return _validated

42 

def api_view_ignore_access_control(view_func):
    """
    Decorator marking an api_view as exempt from access control.

    Not for general use: attach it only to internal endpoints/handlers
    (such as those invoked by Dapr). The wrapper simply forwards the call
    and carries an ``ignore_access_control`` flag set to True.
    """

    @wraps(view_func)
    def _passthrough(request: Request, *args, **kwargs):
        result = view_func(request, *args, **kwargs)
        return result

    setattr(_passthrough, "ignore_access_control", True)
    return _passthrough

56 

################## Context manager ##################
# Context variable read and written by the helpers below. It holds None by
# default; flowsheet_context() stores a dict with the keys "flowsheet",
# "user" and "has_access" in it while a view is running.
flowsheet_ctx = contextvars.ContextVar(
    "flowsheet",
    default=None,
)

59 

@contextmanager
def flowsheet_context(flowsheet: int, user: User):
    """
    Expose the flowsheet and user to code running inside the ``with`` block.

    A dict with the keys "flowsheet", "user" and "has_access" (initially
    None) is stored in ``flowsheet_ctx`` on entry and the previous value is
    restored on exit, whether or not the body raised.
    """
    reset_token = flowsheet_ctx.set(
        dict(flowsheet=flowsheet, user=user, has_access=None)
    )

    try:
        # The wrapped view runs here with the context populated.
        yield
    finally:
        try:
            flowsheet_ctx.reset(reset_token)
        except ValueError:
            # The token could not be used to restore the previous value
            # (e.g. the server errored while processing the request); clear
            # the variable so the data does not leak into the next request.
            flowsheet_ctx.set(None)

83 

def get_current_flowsheet():
    """
    Return whatever is currently stored in ``flowsheet_ctx``.

    Inside ``flowsheet_context`` this is a dict of the form
        {
            "flowsheet": flowsheet_id,
            "user": User,
            "has_access": has_access
        }
    Outside of it the variable's default, None, is returned.
    """
    current = flowsheet_ctx.get()
    return current

94 

95 

def cache_result(has_access: bool = False):
    """
    Cache the access-check result on the current flowsheet context.

    Writes ``has_access`` under the "has_access" key of the dict currently
    held by ``flowsheet_ctx``.

    When no context is active the variable holds its default, None. There is
    then nothing to cache the result on, so the call is a no-op instead of
    failing with ``TypeError``.
    """
    data = flowsheet_ctx.get()
    if data is None:
        return

    # The dict is mutated in place, so the context variable already refers
    # to the updated data; setting it again is unnecessary.
    data["has_access"] = has_access

103 

104################## Router and urlpattern validation ################## 

105 

106 

def validate_router(router):
    """
    Validate every viewset registered on ``router``.

    Registrations whose prefix is "flowsheets", "flowsheetTemplates" or
    "compounds" are skipped. Every other viewset must:

    * inherit from ModelViewSet or ReadOnlyModelViewSet (from core.viewset),
    * override ``get_queryset`` rather than using the one defined on
      rest_framework's ModelViewSet, and
    * not define a ``queryset`` attribute.

    Raises:
        Exception: for the first registration that breaks one of the rules.
    """
    excluded_prefixes = ("flowsheets", "flowsheetTemplates", "compounds")
    for prefix, viewset, _basename in router.registry:
        if prefix in excluded_prefixes:
            continue

        if not issubclass(viewset, (ModelViewSet, ReadOnlyModelViewSet)):
            raise Exception(f"ModelViewSet (from core.viewset) is not being inherited at {prefix}!")

        if getattr(viewset, "get_queryset", None) == MDVS.get_queryset:
            raise Exception(f"get_queryset is not being overridden at {prefix}! Please override get_queryset method to provide the queryset")

        # Identity check: comparing with ``!=`` would go through the
        # attribute's own __eq__/__ne__ instead of simply testing for None.
        if getattr(viewset, "queryset", None) is not None:
            # Adjacent literals instead of a backslash continuation, which
            # embedded the next source line's indentation in the message.
            raise Exception(
                f"Please remove queryset from {prefix} viewset and create a get_queryset method instead. "
                "Avoid creating the queryset attribute to enforce access control."
            )

126 

127 

def validate_urlpatterns(urlpatterns):
    """
    Check that each view reachable from ``urlpatterns`` carries the
    ``_is_api_view_validated`` marker set by the api_view_validate decorator.

    Raises:
        Exception: for the first view found without the marker.
    """
    for route, callback in extract_views_from_urlpatterns(urlpatterns):
        if hasattr(callback, '_is_api_view_validated'):
            continue
        raise Exception(f"api_view_validate decorator (from core.validation) is not being used at {route}!!!")

136 

137 

def extract_views_from_urlpatterns(urlpatterns, base_path=''):
    """
    Walk ``urlpatterns`` recursively and collect its views.

    Returns a list of ``(path, callback)`` tuples in traversal order, where
    ``path`` is ``base_path`` followed by the string form of every pattern
    leading to the view. Entries that are neither URLPattern nor URLResolver
    are ignored.
    """
    collected = []
    for entry in urlpatterns:
        if isinstance(entry, URLPattern):
            collected.append((base_path + str(entry.pattern), entry.callback))
            continue

        if isinstance(entry, URLResolver):
            # Nested patterns (like routers): descend with the extended prefix.
            prefix = base_path + str(entry.pattern)
            collected += extract_views_from_urlpatterns(entry.url_patterns, prefix)
    return collected