Coverage for backend/django/idaes_factory/endpoints.py: 69%

209 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-12-18 04:00 +0000

1import datetime 

2import json 

3import os 

4from typing import Any, TypedDict 

5import requests 

6import traceback 

7 

8from django.db import transaction 

9from django.utils import timezone 

10from opentelemetry import trace 

11from rest_framework.response import Response 

12 

13from CoreRoot import settings 

14from authentication.user.models import User 

15from common.models.idaes.payloads import BuildStateRequestSchema 

16from common.models.idaes.payloads.solve_request_schema import IdaesSolveRequestPayload, IdaesSolveCompletionPayload, \ 

17 MultiSolvePayload 

18from common.models.idaes.payloads.ml_request_schema import MLTrainRequestPayload, MLTrainingCompletionPayload 

19from core.auxiliary.models.MLModel import MLModel 

20from common.models.notifications.payloads import TaskCompletedPayload, NotificationServiceMessageType, \ 

21 NotificationServiceMessage 

22from core.auxiliary.enums.generalEnums import TaskStatus 

23from core.auxiliary.models.Task import Task, TaskType 

24from core.auxiliary.serializers import TaskSerializer 

25from core.exceptions import DetailedException 

26from .idaes_factory import IdaesFactory, save_all_initial_values, store_properties_schema 

27from flowsheetInternals.unitops.models.SimulationObject import SimulationObject 

28from .adapters.stream_properties import serialise_stream 

29from .adapters.property_package_adapter import PropertyPackageAdapter 

30from .idaes_factory_context import IdaesFactoryContext, LiveSolveParams 

31from core.auxiliary.models.Scenario import Scenario, ScenarioTabTypeEnum 

32from common.services import messaging 

33 

# Module-level OpenTelemetry tracer used by the span decorators below.
tracer = trace.get_tracer(settings.OPEN_TELEMETRY_TRACER_NAME)

35 

36 

class IdaesServiceRequestException(Exception):
    """Raised when the IDAES service answers a request with a non-200 status.

    The ``message`` attribute mirrors the constructor argument so callers can
    inspect the service's response body directly.
    """

    def __init__(self, message: str) -> None:
        self.message = message
        super().__init__(message)

41 

class SolveFlowsheetError(DetailedException):
    # Raised when a solve request cannot be dispatched to the message queue.
    pass

44 

class ResponseType(TypedDict, total=False):
    # Response envelope for endpoints that proxy the IDAES service.
    # ``total=False`` makes every key optional.
    status: str          # "success" or "error"
    error: dict | None   # message/traceback details, populated on failure
    log: str | None      # service/solver log text, when available
    debug: dict | None   # extra diagnostic data, when available

50 

51 

def idaes_service_request(endpoint, data: Any, timeout: float | None = None) -> Any:
    """Send a JSON payload to the configured IDAES service endpoint.

    Args:
        endpoint: Relative path of the IDAES service endpoint to call.
        data: Serialised payload that conforms to the endpoint schema.
        timeout: Optional request timeout in seconds. Defaults to ``None``
            (wait indefinitely) to preserve historical behaviour; callers can
            now opt in to a bounded wait.

    Returns:
        Parsed JSON response returned by the IDAES service.

    Raises:
        IdaesServiceRequestException: If the service responds with a non-200 status.
    """
    # Fall back to a local development URL when no service URL is configured.
    base_url = os.getenv('IDAES_SERVICE_URL') or "http://localhost:8080"
    url = base_url + "/" + endpoint

    result = requests.post(url, json=data, timeout=timeout)

    # Parse once; the body is the error detail on failure and the result on success.
    body = result.json()
    if result.status_code != 200:
        raise IdaesServiceRequestException(body)

    return body

72 

73 

@tracer.start_as_current_span("send_flowsheet_solve_request")
def _solve_flowsheet_request(
    task_id: int,
    built_factory: IdaesFactory,
    perform_diagnostics: bool = False,
    high_priority: bool = False
):
    """Queue an IDAES solve request for the provided flowsheet build.

    Args:
        task_id: Identifier of the task tracking the solve request.
        built_factory: Fully built factory containing flowsheet data to solve.
        perform_diagnostics: Whether to request diagnostic output from IDAES.
        high_priority: Whether the message should be prioritised over normal solves.

    Raises:
        SolveFlowsheetError: If the message cannot be dispatched to the queue.
    """
    try:
        idaes_payload = IdaesSolveRequestPayload(
            flowsheet=built_factory.flowsheet,
            solve_index=built_factory.solve_index,
            scenario_id=(built_factory.scenario.id if built_factory.scenario else None),
            task_id=task_id,
            perform_diagnostics=perform_diagnostics
        )

        messaging.send_idaes_solve_message(idaes_payload, high_priority=high_priority)
    except Exception as e:
        # Chain the original exception so the underlying cause survives in
        # tracebacks and structured logs instead of being swallowed.
        raise SolveFlowsheetError(e, "idaes_factory_solve_message") from e

105 

def start_flowsheet_solve_event(
    flowsheet_id: int,
    user: User,
    scenario: Scenario = None,
    perform_diagnostics: bool = False
) -> Response:
    """Start a single solve for the given flowsheet and return the tracking task.

    Args:
        flowsheet_id: Identifier of the flowsheet that should be solved.
        user: User initiating the solve request.
        scenario: Optional scenario providing context for the solve.
        perform_diagnostics: Whether the solve should run with diagnostic output enabled.

    Returns:
        REST response containing the serialised `Task` used to track the solve.
    """
    if scenario and scenario.state_name != ScenarioTabTypeEnum.SteadyState and not scenario.solveStates.exists():
        # if not steady state solves and no solve states exist, cannot proceed
        return Response(status=400, data=f"No data was provided for {scenario.state_name} scenario.")

    # Remove all previously created DynamicResults for this scenario.
    # Guard on ``scenario`` explicitly instead of using a bare except to
    # swallow the AttributeError raised when no scenario was supplied.
    if scenario is not None:
        try:
            scenario.solutions.all().delete()
        except Exception:
            # there's no solution yet, that's fine (best-effort cleanup)
            pass

    solve_task = Task.create(user, flowsheet_id, status=TaskStatus.Pending, save=True)

    try:
        factory = IdaesFactory(
            flowsheet_id=flowsheet_id,
            scenario=scenario
        )
        factory.build()

        # We send single solve requests as high priority to ensure
        # they are not blocked by large multi-solve requests.
        _solve_flowsheet_request(
            solve_task.id,
            factory,
            perform_diagnostics=perform_diagnostics,
            high_priority=True
        )
    except DetailedException as e:
        solve_task.set_failure_with_exception(e, save=True)

    task_serializer = TaskSerializer(solve_task)

    return Response(task_serializer.data, status=200)

157 

def start_multi_steady_state_solve_event(flowsheet_id: int, user: User, scenario: Scenario) -> Response:
    """Kick off a multi steady-state solve and return the parent tracking task.

    Args:
        flowsheet_id: Identifier of the flowsheet being solved.
        user: User who requested the multi-solve.
        scenario: Scenario containing the steady-state configurations to solve.

    Returns:
        REST response containing the parent `Task` that aggregates child solves.
    """
    if not scenario.solveStates.exists():  # empty rows
        return Response(status=400, data="No data was provided for multi steady-state scenario.")

    # Remove all previously created DynamicResults for this scenario
    scenario.solutions.all().delete()
    total_solves = scenario.solveStates.count()

    parent_task = Task.create_parent_task(
        creator=user,
        flowsheet_id=flowsheet_id,
        scheduled_tasks=total_solves,
        status=TaskStatus.Running
    )

    # One pending child task per solve state, persisted in a single query.
    children: list[Task] = []
    for _ in range(total_solves):
        children.append(
            Task.create(user, flowsheet_id, parent=parent_task, status=TaskStatus.Pending)
        )
    Task.objects.bulk_create(children)

    payload = MultiSolvePayload(task_id=parent_task.id, scenario_id=scenario.id)
    messaging.send_dispatch_multi_solve_message(payload)

    return Response(TaskSerializer(parent_task).data, status=200)

199 

def dispatch_multi_solves(parent_task_id: int, scenario_id: int):
    """Build and dispatch queued steady-state solves for each child task.

    Args:
        parent_task_id: Identifier of the parent multi-solve task.
        scenario_id: Scenario containing the solve states to iterate through.
    """
    parent_task = Task.objects.get(id=parent_task_id)
    scenario = Scenario.objects.get(id=scenario_id)

    factory = IdaesFactory(
        flowsheet_id=parent_task.flowsheet_id,
        scenario=scenario,
    )

    # Children were bulk-created together; ordering by start_time keeps
    # solve_index aligned with the scenario's solve states.
    child_tasks = list(parent_task.children.order_by('start_time'))

    for solve_index, task in enumerate(child_tasks):
        try:
            # Reuse one factory across iterations: reset it, point it at the
            # current solve state, then rebuild the flowsheet.
            factory.clear_flowsheet()
            factory.use_with_solve_index(solve_index)
            factory.build()

            _solve_flowsheet_request(
                task.id,
                factory
            )
        except DetailedException as e:
            # Dispatch failed for this child: record the failure, roll the
            # status up to the parent, and push updates for both tasks.
            # (Successful dispatches are reported later, when the solve
            # completion message is processed.)
            task.set_failure_with_exception(exception=e, save=True)
            parent_task.update_status_from_child(task)

            flowsheet_messages = [
                NotificationServiceMessage(
                    data=TaskSerializer(task).data,
                    message_type=NotificationServiceMessageType.TASK_UPDATED
                ) for task in [task, parent_task]
            ]

            messaging.send_flowsheet_notification_messages(parent_task.flowsheet_id, flowsheet_messages)

239 

240 

def start_ml_training_event(
    datapoints: list,
    columns: list[str],
    input_labels: list[str],
    output_labels: list[str],
    user: User,
    flowsheet_id: int,
):
    """Queue an asynchronous machine-learning training job for the given dataset.

    Args:
        datapoints: Training rows to supply to the ML service.
        columns: Column names describing the datapoint ordering.
        input_labels: Names of the input features.
        output_labels: Names of the predicted outputs.
        user: User requesting the training run.
        flowsheet_id: Flowsheet the training run is associated with.

    Returns:
        REST response containing the serialised `Task` for the training job.
    """
    # Persist a pending task up front so the client can track the job.
    training_task = Task.create(
        user,
        flowsheet_id,
        task_type=TaskType.ML_TRAINING,
        status=TaskStatus.Pending,
        save=True
    )

    try:
        request_payload = MLTrainRequestPayload(
            datapoints=datapoints,
            columns=columns,
            input_labels=input_labels,
            output_labels=output_labels,
            task_id=training_task.id,
        )
        messaging.send_ml_training_message(request_payload)
    except DetailedException as e:
        # The task is returned to the caller either way; mark it failed so
        # the client sees the dispatch error.
        training_task.set_failure_with_exception(e, save=True)

    return Response(TaskSerializer(training_task).data, status=200)

285 

def _send_task_notifications(task: Task):
    """Broadcast task completion or status updates to interested flowsheet clients.

    Args:
        task: Task whose status change should be pushed to subscribers.
    """
    messages = []

    parent = task.parent
    if parent:
        # Roll the child's status up before deciding which event to emit.
        parent.update_status_from_child(task)

        if parent.status is TaskStatus.Completed:
            parent_message_type = NotificationServiceMessageType.TASK_COMPLETED
        else:
            parent_message_type = NotificationServiceMessageType.TASK_UPDATED

        messages.append(NotificationServiceMessage(
            data=TaskSerializer(parent).data,
            message_type=parent_message_type
        ))

    # The task itself is always reported as completed.
    messages.append(NotificationServiceMessage(
        data=TaskSerializer(task).data,
        message_type=NotificationServiceMessageType.TASK_COMPLETED
    ))

    messaging.send_flowsheet_notification_messages(task.flowsheet_id, messages)

313 

def process_idaes_solve_response(solve_response: IdaesSolveCompletionPayload):
    """Persist the outcome of a completed IDAES solve and notify listeners.

    Args:
        solve_response: Payload describing the finished solve result.
    """
    # Use a transaction to ensure that either everything succeeds or nothing does
    with transaction.atomic():

        task = Task.objects.select_related("parent").get(id=solve_response.task_id)

        # Silently ignore if the task has already been marked as completed.
        # This allows us to simulate exactly-once delivery semantics (only process
        # a finished task once).
        if task.status == TaskStatus.Completed or task.status == TaskStatus.Cancelled:
            return

        task.completed_time = timezone.now()
        task.log = solve_response.log

        if solve_response.status == "success":
            task.status = TaskStatus.Completed
            task.debug = {
                "timing": solve_response.timing
            }
        else:
            task.status = TaskStatus.Failed
            task.error = {
                "message": solve_response.error["message"],
                "cause": "idaes_service_request",
                "traceback": solve_response.traceback
            }

        # NOTE(review): update_fields lists both "debug" and "error", but only
        # one of them is assigned above — the other is written back with the
        # value loaded from the database. Confirm it is intended that stale
        # error/debug data from earlier runs is not cleared.
        task.save(update_fields=["status", "completed_time", "log", "debug", "error"])

        # Save the solved flowsheet values
        if task.status == TaskStatus.Completed:
            store_properties_schema(solve_response.flowsheet.properties, task.flowsheet_id, solve_response.scenario_id, solve_response.solve_index)

            # For now, only save initial values for single and dynamic solves, not MSS
            # In future, we may need some more complex logic to handle MSS initial values
            if solve_response.solve_index is None:
                save_all_initial_values(solve_response.flowsheet.initial_values)

        _send_task_notifications(task)

359 

def process_failed_idaes_solve_response(solve_response: IdaesSolveCompletionPayload):
    """Handle final failure notifications for solves that could not be processed.

    Args:
        solve_response: Completion payload received from the dead-letter queue.
    """
    # A transaction guarantees the status change and notification happen
    # together or not at all.
    with transaction.atomic():
        task = Task.objects.select_related("parent").get(id=solve_response.task_id)

        # The dead letter queue is configured for "at least once" delivery;
        # skipping tasks already marked failed/cancelled gives us effectively
        # exactly-once processing of each failure.
        if task.status in (TaskStatus.Failed, TaskStatus.Cancelled):
            return

        task.status = TaskStatus.Failed
        task.completed_time = timezone.now()
        task.log = solve_response.log
        task.error = {
            "message": "Internal server error: several attempts to process finished solve failed."
        }
        task.save()

        _send_task_notifications(task)

385 

def process_ml_training_response(
    ml_training_response: MLTrainingCompletionPayload
):
    """Persist the result of a machine-learning training job and send updates.

    Args:
        ml_training_response: Completion payload returned by the ML service.
    """
    # All writes happen inside one transaction so a partial update is never visible.
    with transaction.atomic():
        task = Task.objects.select_related("parent").get(id=ml_training_response.task_id)

        # Silently ignore if the task has already been marked as completed.
        # This allows us to simulate exactly-once delivery semantics (only process
        # a finished task once).
        if task.status == TaskStatus.Completed or task.status == TaskStatus.Cancelled:
            return

        task.completed_time = timezone.now()
        task.log = ml_training_response.log

        if ml_training_response.status == "success":
            task.status = TaskStatus.Completed
            task.debug = {
                "timing": ml_training_response.json_response.get("timing", {})
            }
        else:
            task.status = TaskStatus.Failed
            task.error = {
                "message": ml_training_response.error,
                "traceback": ml_training_response.traceback
            }

        task.save()
        result = ml_training_response.json_response

        # NOTE(review): this is an UNFILTERED queryset update — it writes
        # these fields to every MLModel row, not just the model trained by
        # this task. Confirm whether a .filter(...) is missing here.
        # NOTE(review): it also runs on the failure branch; verify that
        # json_response is a dict (not None) when status != "success".
        MLModel.objects.update(
            surrogate_model=result.get("surrogate_model"),
            charts=result.get("charts"),
            metrics=result.get("metrics"),
            test_inputs=result.get("test_inputs"),
            test_outputs=result.get("test_outputs"),
            progress=3
        )
        _send_task_notifications(task)

430 

431 

def cancel_idaes_solve(task_id: int):
    """Mark an in-flight solve task as cancelled and notify subscribers.

    Args:
        task_id: Identifier of the `Task` being cancelled.
    """
    with transaction.atomic():
        task = Task.objects.select_related("parent").get(id=task_id)

        # Only Running/Pending tasks can be cancelled; anything else has
        # already reached a final status, so the request is ignored.
        if task.status not in (TaskStatus.Running, TaskStatus.Pending):
            return

        task.status = TaskStatus.Cancelled
        task.save()

        messaging.send_flowsheet_notification_message(
            task.flowsheet_id,
            TaskSerializer(task).data,
            NotificationServiceMessageType.TASK_CANCELLED
        )

453 

def generate_IDAES_python_request(flowsheet_id: int, return_json: bool = False) -> Response:
    """Generate Python IDAES code for a flowsheet by calling the IDAES service.

    Args:
        flowsheet_id: Identifier of the flowsheet to translate to Python code.
        return_json: Whether to bypass the remote call and return raw JSON.

    Returns:
        REST response containing either the generated code or the flowsheet JSON.
    """
    factory = IdaesFactory(flowsheet_id, scenario=None, require_variables_fixed=False)
    response_data = ResponseType(
        status="success",
        error=None,
        log=None,
        debug=None
    )

    # Build the flowsheet locally first; any build problem is reported
    # without contacting the IDAES service at all.
    try:
        factory.build()
        flowsheet = factory.flowsheet
        if return_json:
            return Response(flowsheet.model_dump_json(), status=200)
    except Exception as e:
        response_data["status"] = "error"
        response_data["error"] = {
            "message": str(e),
            "traceback": traceback.format_exc()
        }
        return Response(response_data, status=400)

    # Delegate code generation to the remote service and relay its answer.
    try:
        generated = idaes_service_request(endpoint="generate_python_code", data=flowsheet.model_dump())
        return Response(generated, status=200)
    except IdaesServiceRequestException as e:
        service_error = e.message
        response_data["status"] = "error"
        response_data["error"] = {
            "message": service_error["error"],
            "traceback": service_error["traceback"]
        }
        return Response(response_data, status=400)

495 

class BuildStateSolveError(Exception):
    # Raised when the IDAES service rejects a build_state request.
    pass

498 

def build_state_request(stream: SimulationObject):
    """Request a state build for the provided stream using the IDAES service.

    Args:
        stream: Stream object whose inlet properties should be used for the build.

    Returns:
        REST response returned by the IDAES service containing built properties.

    Raises:
        BuildStateSolveError: If the IDAES service rejects the build request.
        Exception: If preparing the payload or storing the result fails.
    """
    ctx = IdaesFactoryContext(stream.flowsheet_id)

    try:
        port = stream.connectedPorts.get(direction="inlet")
        unitop: SimulationObject = port.unitOp

        # find the property package key for this port (we have the value, the key of this port)
        property_package_key = None
        property_package_ports = unitop.schema.propertyPackagePorts
        for key, port_list in property_package_ports.items():
            if port.key in port_list:
                property_package_key = key
                PropertyPackageAdapter(property_package_key).serialise(ctx, unitop)

        # Fail with a clear message instead of an opaque IndexError below
        # when the port is not registered against any property package.
        if property_package_key is None:
            raise Exception(f"No property package found for port {port.key}")

        data = BuildStateRequestSchema(
            property_package=ctx.property_packages[0],
            properties=serialise_stream(ctx, stream, is_inlet=True)
        )
    except Exception as e:
        # Chain the cause so the original traceback is not lost.
        raise Exception(e) from e

    try:
        response = idaes_service_request(endpoint="build_state", data=data.model_dump())
        store_properties_schema(response["properties"], stream.flowsheet_id)

        return Response(response, status=200)
    except IdaesServiceRequestException as e:
        raise BuildStateSolveError(e.message) from e
    except Exception as e:
        raise Exception(e) from e