Coverage for backend/django/core/validation.py: 82%

179 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2026-05-13 02:47 +0000

1import contextvars 

2from contextlib import contextmanager 

3from django.apps import apps 

4from functools import wraps 

5from authentication.user.models import User 

6from core.viewset import ModelViewSet, ReadOnlyModelViewSet 

7from django.urls.resolvers import URLPattern, URLResolver 

8from rest_framework.exceptions import ValidationError 

9from rest_framework.request import Request 

10from rest_framework.viewsets import ModelViewSet as MDVS 

11from typing import TypedDict 

12 

def sanitize_flowsheet_id(flowsheet_id):
    """
    Ensure the client has provided a valid flowsheet id (positive integer).

    Args:
        flowsheet_id: Raw value supplied by the client (for example the
            ``flowsheet`` query parameter, which is ``None`` when absent).

    Returns:
        int: The id converted with ``int()``.

    Raises:
        ValidationError: If the value cannot be converted to an int or is
            smaller than 1.
    """
    try:
        parsed_id = int(flowsheet_id)
    except (TypeError, ValueError, OverflowError):
        # TypeError: None / unsupported types; ValueError: non-numeric text or
        # NaN; OverflowError: float infinity. The low-level cause is suppressed
        # so the client only sees the validation message.
        raise ValidationError("Invalid flowsheet id") from None

    # One bound check covers both zero and negative ids.
    if parsed_id < 1:
        raise ValidationError("Invalid flowsheet id")

    return parsed_id

27 

28 

def api_view_validate(view_func):
    """
    Wrap an api_view so each call runs inside a flowsheet access-control context.

    On every request the ``flowsheet`` query parameter is validated with
    ``sanitize_flowsheet_id``. It is then installed, together with
    ``request.user``, via ``flowsheet_context`` for the duration of the view
    call. The wrapper is tagged with ``_is_api_view_validated`` so that
    ``validate_urlpatterns`` can confirm the decorator was applied.
    """

    @wraps(view_func)
    def _wrapped_view(request: Request, *args, **kwargs):
        raw_id = request.GET.get("flowsheet")
        # Read the user before validating the id, matching the original order.
        requesting_user = request.user
        with flowsheet_context(sanitize_flowsheet_id(raw_id), requesting_user):
            return view_func(request, *args, **kwargs)

    _wrapped_view._is_api_view_validated = True
    return _wrapped_view

44 

def api_view_ignore_access_control(view_func):
    """
    Mark an api_view as exempt from access control.

    This is not for general use. Attach it only to internal endpoints/handlers
    (such as those for Dapr to invoke). The wrapper forwards the call unchanged
    and carries an ``ignore_access_control = True`` marker attribute.
    """

    @wraps(view_func)
    def _wrapped_view(request: Request, *args, **kwargs):
        # Pure pass-through: no flowsheet context is established here.
        result = view_func(request, *args, **kwargs)
        return result

    setattr(_wrapped_view, "ignore_access_control", True)
    return _wrapped_view

58 

59 

60 

################## Context manager ##################
class FlowsheetContext(TypedDict):
    """Request-local access-control state held by ``flowsheet_ctx``."""

    flowsheet: int  # Active flowsheet id.
    user: User  # User making the request.
    has_access: bool | None  # Backward-compatible alias for has_read_access.
    has_read_access: bool | None  # Cached read-access decision; None means not yet cached.
    has_write_access: bool | None  # Cached write-access decision; None means not yet cached.
    write_intent: bool  # Lets manager/queryset code distinguish read vs write views.
    bypass_write_checks: bool  # Temporary escape hatch for explicitly validated internal flows.


# Module-level context variable holding the active FlowsheetContext;
# its default is None when nothing has been set.
flowsheet_ctx: contextvars.ContextVar[FlowsheetContext | None] = contextvars.ContextVar(
    "flowsheet", default=None
)

72 

@contextmanager
def flowsheet_context(flowsheet: int, user: User, write_intent: bool = False):
    """
    Install a fresh FlowsheetContext in ``flowsheet_ctx`` for the ``with`` body.

    Managers/querysets can then read the active flowsheet, the requesting user
    and the write-intent flag without the request being passed through every
    model method explicitly. All cached access flags start as ``None`` and the
    write-check bypass starts disabled.

    Args:
        flowsheet: Id of the flowsheet being accessed.
        user: The requesting user.
        write_intent: Whether the view intends to write.
    """
    token = flowsheet_ctx.set(
        FlowsheetContext(
            flowsheet=flowsheet,
            user=user,
            has_access=None,
            has_read_access=None,
            has_write_access=None,
            write_intent=write_intent,
            bypass_write_checks=False,
        )
    )
    try:
        # The wrapped view executes here with the context set.
        yield
    finally:
        try:
            flowsheet_ctx.reset(token)
        except ValueError:
            # reset() raises ValueError if the token was created in a different
            # Context; clear the variable instead so this request's state does
            # not leak into the next request.
            flowsheet_ctx.set(None)

102 

def get_current_flowsheet() -> "FlowsheetContext | None":
    """
    Return the active flowsheet context, or ``None`` if none has been set.

    When a ``flowsheet_context`` block is active, the result is the live
    context dict with the keys:

        {
            "flowsheet": flowsheet_id,
            "user": User,
            "has_access": bool | None,        # alias of has_read_access
            "has_read_access": bool | None,
            "has_write_access": bool | None,
            "write_intent": bool,
            "bypass_write_checks": bool,
        }

    The dict is returned as stored (not copied), so mutating it mutates the
    shared context state.
    """
    return flowsheet_ctx.get()

113 

114 

def cache_result(has_access: bool = False):
    """
    Cache a read-access decision in the active flowsheet context.

    Both ``has_access`` and ``has_read_access`` are set to ``has_access``.
    Nothing happens when no flowsheet context is set.
    """
    current = flowsheet_ctx.get()
    if current:
        current.update(has_access=has_access, has_read_access=has_access)
        flowsheet_ctx.set(current)

126 

127 

def cache_access_result(
    *,
    has_read_access: bool | None = None,
    has_write_access: bool | None = None,
):
    """
    Cache read and/or write access decisions in the active flowsheet context.

    Arguments left as ``None`` are not written, so a caller can record one kind
    of access without clobbering the other. A read decision is mirrored to the
    legacy ``has_access`` key. Nothing happens when no flowsheet context is set.
    """
    current = flowsheet_ctx.get()
    if not current:
        return

    updates = {}
    if has_read_access is not None:
        updates["has_read_access"] = has_read_access
        updates["has_access"] = has_read_access
    if has_write_access is not None:
        updates["has_write_access"] = has_write_access

    current.update(updates)
    flowsheet_ctx.set(current)

148 

149 

@contextmanager
def bypass_write_access_checks():
    """
    Temporarily bypass manager/queryset write checks inside the ``with`` body.

    Meant for special internal flows that already perform their own explicit
    access validation. The ``bypass_write_checks`` flag of the active context is
    set to True and its previous value is restored on exit, even if the body
    raises. When no flowsheet context is set this is a no-op.
    """
    current = flowsheet_ctx.get()
    if current:
        saved_flag = current.get("bypass_write_checks", False)
        current["bypass_write_checks"] = True
        flowsheet_ctx.set(current)
        try:
            yield
        finally:
            # Prefer whatever context is set now; fall back to the dict that
            # was modified above if the variable was cleared in the body.
            restored = flowsheet_ctx.get() or current
            restored["bypass_write_checks"] = saved_flag
            flowsheet_ctx.set(restored)
    else:
        yield

171 

################## Router and urlpattern validation ##################


def validate_router(router):
    """
    Validate that every registered viewset is a subclass of ModelViewSet.

    For each ``(prefix, viewset, basename)`` in ``router.registry`` (except the
    prefixes in ``exclude_list``) the viewset must:

    * inherit from ``core.viewset.ModelViewSet`` or ``ReadOnlyModelViewSet``;
    * override ``get_queryset`` (not use DRF's default implementation);
    * not define a class-level ``queryset`` attribute.

    Raises:
        Exception: Describing the first offending prefix.
    """
    exclude_list = {"flowsheets", "flowsheetTemplates", "compounds"}
    allowed_bases = (ModelViewSet, ReadOnlyModelViewSet)
    # DRF's stock implementation; a viewset still pointing at it did not override it.
    default_get_queryset = getattr(MDVS, "get_queryset")

    for prefix, viewset, _basename in router.registry:
        if prefix in exclude_list:
            continue

        if not issubclass(viewset, allowed_bases):
            raise Exception(f"ModelViewSet (from core.viewset) is not being inherited at {prefix}!")

        if getattr(viewset, "get_queryset", None) == default_get_queryset:
            raise Exception(
                f"get_queryset is not being overridden at {prefix}! "
                "Please override get_queryset method to provide the queryset"
            )

        if getattr(viewset, "queryset", None) is not None:
            # Implicit concatenation: the previous backslash continuation
            # embedded the next line's indentation inside the message.
            raise Exception(
                f"Please remove queryset from {prefix} viewset and create a get_queryset method instead. "
                "Avoid creating the queryset attribute to enforce access control."
            )

194 

195 

def validate_urlpatterns(urlpatterns):
    """
    Validate that every view in urlpatterns is decorated with api_view_validate.

    Raises:
        Exception: For the first view (in traversal order) whose callback lacks
            the ``_is_api_view_validated`` marker.
    """
    unvalidated_paths = (
        path
        for path, view in extract_views_from_urlpatterns(urlpatterns)
        if not hasattr(view, '_is_api_view_validated')
    )
    first_offender = next(unvalidated_paths, None)
    if first_offender is not None:
        raise Exception(f"api_view_validate decorator (from core.validation) is not being used at {first_offender}!!!")

204 

205 

def extract_views_from_urlpatterns(urlpatterns, base_path=''):
    """
    Recursively flatten urlpatterns into ``(path, callback)`` pairs.

    ``URLPattern`` entries contribute their route appended to ``base_path``.
    ``URLResolver`` entries (nested patterns such as routers) are descended
    into with their route as the new prefix. Entries of any other type are
    skipped.

    Returns:
        list of ``(path, callback)`` tuples in traversal order.
    """
    collected = []
    for entry in urlpatterns:
        if isinstance(entry, URLPattern):
            collected.append((base_path + str(entry.pattern), entry.callback))
        elif isinstance(entry, URLResolver):
            prefix = base_path + str(entry.pattern)
            collected += extract_views_from_urlpatterns(entry.url_patterns, prefix)
    return collected

220 

221 

# Check that all models have a flowsheet attribute and use AccessControlManager or SoftDeleteManager
def validate_models():
    """
    Check that every installed model participates in flowsheet access control.

    For each model from ``apps.get_models()``, two groups are skipped: models
    whose module name starts with ``silk``, and the names in
    ``exclude_models``. Every other model must satisfy these rules:

    * it must not use ``flowsheetOwner`` (as a field or attribute); it should
      be renamed to ``flowsheet``;
    * its ``objects`` manager must be an AccessControlManager or
      SoftDeleteManager (from core.managers);
    * it must have a ``flowsheet`` attribute.

    Raises:
        ValueError: If a model uses ``flowsheetOwner``.
        AssertionError: If a model lacks a supported manager or a ``flowsheet``
            attribute. These are raised explicitly rather than via ``assert``,
            so the checks still run under ``python -O``.
    """
    # Local import kept from the original; presumably avoids import-order
    # issues with core.managers at module load time (confirm).
    from core.managers import AccessControlManager, SoftDeleteManager

    exclude_models = {
        'User',
        'Permission',
        'Group',
        'ContentType',
        'Flowsheet',
        'AccessTable',
        'Session',
        'TaskMeta',
    }
    owner_error = (
        "To enforce access control, `flowsheetOwner` should not be used. "
        "Please rename to `flowsheet` instead"
    )

    for model in apps.get_models():
        model_name = model.__name__

        # Skip silk models
        if model.__module__.startswith('silk'):
            continue
        # Check exclusions before touching the manager, so excluded models are
        # never inspected.
        if model_name in exclude_models:
            continue

        field_names = {field.name for field in model._meta.get_fields()}
        if 'flowsheetOwner' in field_names:
            raise ValueError(owner_error)

        # getattr: a model without an `objects` manager gets the descriptive
        # message below instead of a bare AttributeError.
        manager = getattr(model, 'objects', None)
        if not isinstance(manager, (AccessControlManager, SoftDeleteManager)):
            raise AssertionError(
                f"Model {model_name} doesn't have AccessControlManager or SoftDeleteManager "
                "(from core.managers). Either is required to enforce access control."
            )

        # Catches non-field attributes/properties named flowsheetOwner too.
        if hasattr(model, 'flowsheetOwner'):
            raise ValueError(owner_error)

        if not hasattr(model, 'flowsheet'):
            raise AssertionError(
                f"Models should have a flowsheet attribute for handling access control, model: {model_name} doesn't"
            )

266 

def validate_routers():
    """
    Ensure every ``routers.py`` below the current working directory calls validate_router.

    The function walks ``./`` relative to the process's current working
    directory and parses each ``routers.py`` with ``ast``. A file passes if it
    contains a call whose callee is named ``validate_router``, either
    ``validate_router(...)`` or ``something.validate_router(...)``. Directories
    named ``site-packages`` and the file ``./authentication/routers.py`` are
    skipped.

    Raises:
        Exception: Listing every routers.py missing a validate_router(...) call.
        SyntaxError: If a routers.py cannot be parsed (the file name is included).
    """
    import ast
    import os

    IGNORED_DIRS = {"site-packages"}
    IGNORED_FILES = {os.path.normpath("./authentication/routers.py")}

    def is_ignored(path: str) -> bool:
        """Return True if path is an ignored file or lies inside an ignored directory."""
        norm = os.path.normpath(path)
        if norm in IGNORED_FILES:
            return True
        return any(part in IGNORED_DIRS for part in norm.split(os.sep))

    def calls_validate_router(tree: ast.AST) -> bool:
        """Return True if the AST contains a call to something named validate_router."""
        for node in ast.walk(tree):
            if not isinstance(node, ast.Call):
                continue
            func = node.func
            if isinstance(func, ast.Name):
                name = func.id
            elif isinstance(func, ast.Attribute):
                name = func.attr
            else:
                continue
            if name == "validate_router":
                return True
        return False

    issues = []

    for root, dirs, files in os.walk("./"):
        if is_ignored(root):
            dirs[:] = []
            continue

        # Prune ignored directories in place so os.walk never descends into them.
        # Their contents would be skipped anyway. Sorting gives a deterministic
        # order for the error report.
        dirs[:] = sorted(d for d in dirs if d not in IGNORED_DIRS)

        if "routers.py" not in files:
            continue

        path = os.path.join(root, "routers.py")
        if is_ignored(path):
            continue

        with open(path, "r", encoding="utf-8") as f:
            source = f.read()

        # filename= makes any SyntaxError point at the offending routers.py.
        tree = ast.parse(source, filename=path)

        if not calls_validate_router(tree):
            issues.append(path)

    if issues:
        issue_list = "\n".join(issues)
        raise Exception(f"The following routers.py files are missing validate_router(...) calls:\n{issue_list}")