Coverage for backend/idaes_factory/endpoints.py: 70%

205 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-11-06 23:27 +0000

1import datetime 

2import json 

3import os 

4from typing import Any, TypedDict 

5import requests 

6import traceback 

7 

8from django.db import transaction 

9from django.utils import timezone 

10from opentelemetry import trace 

11from rest_framework.response import Response 

12 

13from CoreRoot import settings 

14from authentication.user.models import User 

15from common.models.idaes.payloads import BuildStateRequestSchema 

16from common.models.idaes.payloads.solve_request_schema import IdaesSolveRequestPayload, IdaesSolveCompletionPayload, \ 

17 MultiSolvePayload 

18from common.models.idaes.payloads.ml_request_schema import MLTrainRequestPayload, MLTrainingCompletionPayload 

19from core.auxiliary.models.MLModel import MLModel 

20from common.models.notifications.payloads import TaskCompletedPayload, NotificationServiceMessageType, \ 

21 NotificationServiceMessage 

22from core.auxiliary.enums.generalEnums import TaskStatus 

23from core.auxiliary.models.Task import Task, TaskType 

24from core.auxiliary.serializers import TaskSerializer 

25from core.exceptions import DetailedException 

26from .idaes_factory import IdaesFactory, save_all_initial_values, store_properties_schema 

27from flowsheetInternals.unitops.models.SimulationObject import SimulationObject 

28from .adapters.stream_properties import serialise_stream 

29from .adapters.property_package_adapter import PropertyPackageAdapter 

30from .idaes_factory_context import IdaesFactoryContext, LiveSolveParams 

31from core.auxiliary.models.Scenario import Scenario 

32from common.services import messaging 

33 

# Module-level OpenTelemetry tracer; spans created here are named under the
# tracer configured in project settings.
tracer = trace.get_tracer(settings.OPEN_TELEMETRY_TRACER_NAME)

35 

36 

class IdaesServiceRequestException(Exception):
    """Raised when the IDAES service responds with a non-200 status.

    The decoded response body is kept on ``message`` so callers can surface
    the service's own error details.
    """

    def __init__(self, message: str) -> None:
        self.message = message
        super().__init__(message)

41 

class SolveFlowsheetError(DetailedException):
    """Raised when a flowsheet solve request cannot be dispatched to the queue."""
    pass

44 

45class ResponseType(TypedDict, total=False): 

46 status: str 

47 error: dict | None 

48 log: str | None 

49 debug: dict | None 

50 

51 

def idaes_service_request(endpoint, data: Any, *, timeout: float | None = None) -> Any:
    """Send a JSON payload to the configured IDAES service endpoint.

    Args:
        endpoint: Relative path of the IDAES service endpoint to call.
        data: Serialised payload that conforms to the endpoint schema.
        timeout: Optional socket timeout in seconds for the HTTP request.
            Defaults to ``None`` (wait indefinitely), matching the previous
            behaviour; callers are encouraged to pass a finite value so a
            hung service cannot block the worker forever.

    Returns:
        Parsed JSON response returned by the IDAES service.

    Raises:
        IdaesServiceRequestException: If the service responds with a non-200 status.
    """
    base_url = os.getenv('IDAES_SERVICE_URL') or "http://localhost:8080"
    url = base_url + "/" + endpoint

    result = requests.post(url, json=data, timeout=timeout)

    # Decode the body exactly once (the previous implementation parsed the
    # JSON twice on the success path's error branch).
    payload = result.json()
    if result.status_code != 200:
        raise IdaesServiceRequestException(payload)

    return payload

72 

73 

@tracer.start_as_current_span("send_flowsheet_solve_request")
def _solve_flowsheet_request(
        task_id: int,
        built_factory: IdaesFactory,
        perform_diagnostics: bool = False,
        high_priority: bool = False
):
    """Queue an IDAES solve request for the provided flowsheet build.

    Args:
        task_id: Identifier of the task tracking the solve request.
        built_factory: Fully built factory containing flowsheet data to solve.
        perform_diagnostics: Whether to request diagnostic output from IDAES.
        high_priority: Whether the message should be prioritised over normal solves.

    Raises:
        SolveFlowsheetError: If the message cannot be dispatched to the queue.
    """
    try:
        idaes_payload = IdaesSolveRequestPayload(
            flowsheet=built_factory.flowsheet,
            solve_index=built_factory.solve_index,
            scenario_id=(built_factory.scenario.id if built_factory.scenario else None),
            task_id=task_id,
            perform_diagnostics=perform_diagnostics
        )

        messaging.send_idaes_solve_message(idaes_payload, high_priority=high_priority)
    except Exception as e:
        # Chain the original exception so the underlying traceback is not
        # lost when the error is surfaced to the task.
        raise SolveFlowsheetError(e, "idaes_factory_solve_message") from e

105 

def start_flowsheet_solve_event(
        flowsheet_id: int,
        user: User,
        scenario: Scenario | None = None,
        perform_diagnostics: bool = False
) -> Response:
    """Start a single solve for the given flowsheet and return the tracking task.

    Args:
        flowsheet_id: Identifier of the flowsheet that should be solved.
        user: User initiating the solve request.
        scenario: Optional scenario providing context for the solve.
        perform_diagnostics: Whether the solve should run with diagnostic output enabled.

    Returns:
        REST response containing the serialised `Task` used to track the solve.
    """
    # Remove all previously created DynamicResults for this scenario.
    # The previous bare ``except`` only existed to tolerate ``scenario`` being
    # None; an explicit check avoids masking genuine database errors.
    if scenario is not None:
        scenario.solutions.all().delete()

    solve_task = Task.create(user, flowsheet_id, status=TaskStatus.Pending, save=True)

    try:
        factory = IdaesFactory(
            flowsheet_id=flowsheet_id,
            scenario=scenario
        )
        factory.build()

        # Single solve requests go out high priority so they are not
        # starved by large multi-solve batches already in the queue.
        _solve_flowsheet_request(
            solve_task.id,
            factory,
            perform_diagnostics=perform_diagnostics,
            high_priority=True
        )
    except DetailedException as e:
        solve_task.set_failure_with_exception(e, save=True)

    return Response(TaskSerializer(solve_task).data, status=200)

154 

def start_multi_steady_state_solve_event(flowsheet_id: int, user: User, scenario: Scenario) -> Response:
    """Kick off a multi steady-state solve and return the parent tracking task.

    Args:
        flowsheet_id: Identifier of the flowsheet being solved.
        user: User who requested the multi-solve.
        scenario: Scenario containing the steady-state configurations to solve.

    Returns:
        REST response containing the parent `Task` that aggregates child solves.
    """
    # Drop any DynamicResults left over from a previous run of this scenario.
    scenario.solutions.all().delete()

    iteration_count = scenario.solveStates.count()

    parent_task = Task.create_parent_task(
        creator=user,
        flowsheet_id=flowsheet_id,
        scheduled_tasks=iteration_count,
        status=TaskStatus.Running
    )

    # One pending child task per steady-state configuration.
    children: list[Task] = []
    for _ in range(iteration_count):
        children.append(
            Task.create(
                user,
                flowsheet_id,
                parent=parent_task,
                status=TaskStatus.Pending
            )
        )
    Task.objects.bulk_create(children)

    payload = MultiSolvePayload(task_id=parent_task.id, scenario_id=scenario.id)
    messaging.send_dispatch_multi_solve_message(payload)

    return Response(TaskSerializer(parent_task).data, status=200)

194 

def dispatch_multi_solves(parent_task_id: int, scenario_id: int):
    """Build and dispatch queued steady-state solves for each child task.

    Args:
        parent_task_id: Identifier of the parent multi-solve task.
        scenario_id: Scenario containing the solve states to iterate through.
    """
    parent_task = Task.objects.get(id=parent_task_id)
    scenario = Scenario.objects.get(id=scenario_id)

    factory = IdaesFactory(
        flowsheet_id=parent_task.flowsheet_id,
        scenario=scenario,
    )

    child_tasks = list(parent_task.children.order_by('start_time'))

    task = None
    for solve_index, task in enumerate(child_tasks):
        try:
            # Reuse one factory across iterations: reset, point it at the
            # current solve state, and rebuild.
            factory.clear_flowsheet()
            factory.use_with_solve_index(solve_index)
            factory.build()

            _solve_flowsheet_request(
                task.id,
                factory
            )
        except DetailedException as e:
            task.set_failure_with_exception(exception=e, save=True)
            parent_task.update_status_from_child(task)

    # Notify clients about the last dispatched child task and the parent.
    # Previously this raised NameError when the parent had no children;
    # in that case only the parent is reported.
    tasks_to_report = [parent_task] if task is None else [task, parent_task]
    flowsheet_messages = [
        NotificationServiceMessage(
            data=TaskSerializer(t).data,
            message_type=NotificationServiceMessageType.TASK_UPDATED
        ) for t in tasks_to_report
    ]

    messaging.send_flowsheet_notification_messages(parent_task.flowsheet_id, flowsheet_messages)

234 

235 

def start_ml_training_event(
        datapoints: list,
        columns: list[str],
        input_labels: list[str],
        output_labels: list[str],
        user: User,
        flowsheet_id: int,
):
    """Queue an asynchronous machine-learning training job for the given dataset.

    Args:
        datapoints: Training rows to supply to the ML service.
        columns: Column names describing the datapoint ordering.
        input_labels: Names of the input features.
        output_labels: Names of the predicted outputs.
        user: User requesting the training run.
        flowsheet_id: Flowsheet the training run is associated with.

    Returns:
        REST response containing the serialised `Task` for the training job.
    """
    training_task = Task.create(
        user,
        flowsheet_id,
        task_type=TaskType.ML_TRAINING,
        status=TaskStatus.Pending,
        save=True
    )

    try:
        messaging.send_ml_training_message(
            MLTrainRequestPayload(
                datapoints=datapoints,
                columns=columns,
                input_labels=input_labels,
                output_labels=output_labels,
                task_id=training_task.id,
            )
        )
    except DetailedException as e:
        # Record the dispatch failure on the task so the client sees it.
        training_task.set_failure_with_exception(e, save=True)

    return Response(TaskSerializer(training_task).data, status=200)

280 

def _send_task_notifications(task: Task):
    """Broadcast task completion or status updates to interested flowsheet clients.

    Args:
        task: Task whose status change should be pushed to subscribers.
    """
    flowsheet_messages = []

    # If this is a child task, refresh the parent's aggregate status first
    # and include the parent in the broadcast.
    if task.parent:
        task.parent.update_status_from_child(task)

        # Use equality rather than identity here: if ``status`` is ever a
        # plain stored value (e.g. a string loaded from the database) rather
        # than the enum member itself, ``is`` would silently evaluate False.
        message_type = (NotificationServiceMessageType.TASK_COMPLETED
                        if task.parent.status == TaskStatus.Completed else
                        NotificationServiceMessageType.TASK_UPDATED)

        flowsheet_messages.append(NotificationServiceMessage(
            data=TaskSerializer(task.parent).data,
            message_type=message_type
        ))

    flowsheet_messages.append(NotificationServiceMessage(
        data=TaskSerializer(task).data,
        message_type=NotificationServiceMessageType.TASK_COMPLETED
    ))

    messaging.send_flowsheet_notification_messages(task.flowsheet_id, flowsheet_messages)

308 

def process_idaes_solve_response(solve_response: IdaesSolveCompletionPayload):
    """Persist the outcome of a completed IDAES solve and notify listeners.

    Args:
        solve_response: Payload describing the finished solve result.
    """
    # All-or-nothing: the task update and the solved values land together.
    with transaction.atomic():
        task = Task.objects.select_related("parent").get(id=solve_response.task_id)

        # Already finished or cancelled? Drop the message silently. The queue
        # is at-least-once, so this gives effectively exactly-once handling.
        if task.status in (TaskStatus.Completed, TaskStatus.Cancelled):
            return

        task.completed_time = timezone.now()
        task.log = solve_response.log

        if solve_response.status == "success":
            task.status = TaskStatus.Completed
            task.debug = {"timing": solve_response.timing}
        else:
            task.status = TaskStatus.Failed
            task.error = {
                "message": solve_response.error["message"],
                "cause": "idaes_service_request",
                "traceback": solve_response.traceback
            }

        task.save(update_fields=["status", "completed_time", "log", "debug", "error"])

        # Persist the solved flowsheet values for successful solves.
        if task.status == TaskStatus.Completed:
            store_properties_schema(
                solve_response.flowsheet.properties,
                task.flowsheet_id,
                solve_response.scenario_id,
                solve_response.solve_index
            )

            # Only single/dynamic solves (no solve_index) persist initial
            # values for now; MSS may need dedicated handling in future.
            if solve_response.solve_index is None:
                save_all_initial_values(solve_response.flowsheet.initial_values)

        # NOTE(review): kept inside the transaction, matching the original
        # layout as far as it can be read from the dump — confirm whether
        # notifications should instead go out only after commit.
        _send_task_notifications(task)

354 

def process_failed_idaes_solve_response(solve_response: IdaesSolveCompletionPayload):
    """Handle final failure notifications for solves that could not be processed.

    Args:
        solve_response: Completion payload received from the dead-letter queue.
    """
    # All-or-nothing: the status flip and notification happen together.
    with transaction.atomic():
        task = Task.objects.select_related("parent").get(id=solve_response.task_id)

        # The dead-letter queue delivers at least once; skip tasks that have
        # already been failed (or cancelled) so each failure is handled once.
        if task.status in (TaskStatus.Failed, TaskStatus.Cancelled):
            return

        task.completed_time = timezone.now()
        task.log = solve_response.log
        task.status = TaskStatus.Failed
        task.error = {
            "message": "Internal server error: several attempts to process finished solve failed."
        }
        task.save()

        # NOTE(review): kept inside the transaction, matching the original
        # layout as far as it can be read from the dump — confirm whether
        # notifications should instead wait for commit.
        _send_task_notifications(task)

380 

def process_ml_training_response(
        ml_training_response: MLTrainingCompletionPayload
):
    """Persist the result of a machine-learning training job and send updates.

    Args:
        ml_training_response: Completion payload returned by the ML service.
    """
    with transaction.atomic():
        task = Task.objects.select_related("parent").get(id=ml_training_response.task_id)

        # Skip tasks that already reached a final state: the queue is
        # at-least-once, so this gives effectively exactly-once handling.
        if task.status in (TaskStatus.Completed, TaskStatus.Cancelled):
            return

        task.completed_time = timezone.now()
        task.log = ml_training_response.log

        if ml_training_response.status == "success":
            task.status = TaskStatus.Completed
            task.debug = {
                "timing": ml_training_response.json_response.get("timing", {})
            }
        else:
            task.status = TaskStatus.Failed
            task.error = {
                "message": ml_training_response.error,
                "traceback": ml_training_response.traceback
            }

        task.save()

        # Only persist the trained model artefacts when training succeeded;
        # on failure ``json_response`` may be absent, and the previous code
        # would have stored failure payloads (or crashed on ``.get``).
        if task.status == TaskStatus.Completed:
            result = ml_training_response.json_response

            # HACK: ``MLModel.objects.update(...)`` with no filter updates
            # EVERY row in the table. TODO(review): scope this queryset to
            # the model belonging to this task/flowsheet.
            MLModel.objects.update(
                surrogate_model=result.get("surrogate_model"),
                charts=result.get("charts"),
                metrics=result.get("metrics"),
                test_inputs=result.get("test_inputs"),
                test_outputs=result.get("test_outputs"),
                progress=3
            )

        _send_task_notifications(task)

425 

426 

def cancel_idaes_solve(task_id: int):
    """Mark an in-flight solve task as cancelled and notify subscribers.

    Args:
        task_id: Identifier of the `Task` being cancelled.
    """
    with transaction.atomic():
        task = Task.objects.select_related("parent").get(id=task_id)

        # Only tasks still in flight can be cancelled; a final status
        # (completed, failed, cancelled) wins over the cancellation request.
        still_in_flight = (task.status == TaskStatus.Running
                           or task.status == TaskStatus.Pending)
        if not still_in_flight:
            return

        task.status = TaskStatus.Cancelled
        task.save()

        messaging.send_flowsheet_notification_message(
            task.flowsheet_id,
            TaskSerializer(task).data,
            NotificationServiceMessageType.TASK_CANCELLED
        )

448 

def generate_IDAES_python_request(flowsheet_id: int, return_json: bool = False) -> Response:
    """Generate Python IDAES code for a flowsheet by calling the IDAES service.

    Args:
        flowsheet_id: Identifier of the flowsheet to translate to Python code.
        return_json: Whether to bypass the remote call and return raw JSON.

    Returns:
        REST response containing either the generated code or the flowsheet JSON.
    """
    factory = IdaesFactory(flowsheet_id, scenario=None, require_variables_fixed=False)
    body = ResponseType(
        status="success",
        error=None,
        log=None,
        debug=None
    )

    try:
        factory.build()
        flowsheet_data = factory.flowsheet
        if return_json:
            # Short-circuit: hand back the serialised flowsheet directly.
            return Response(flowsheet_data.model_dump_json(), status=200)
    except Exception as e:
        body["status"] = "error"
        body["error"] = {
            "message": str(e),
            "traceback": traceback.format_exc()
        }
        return Response(body, status=400)

    try:
        generated = idaes_service_request(endpoint="generate_python_code", data=flowsheet_data.model_dump())
        return Response(generated, status=200)
    except IdaesServiceRequestException as e:
        # The service error body carries its own message and traceback.
        service_error = e.message
        body["status"] = "error"
        body["error"] = {
            "message": service_error["error"],
            "traceback": service_error["traceback"]
        }
        return Response(body, status=400)

490 

class BuildStateSolveError(Exception):
    """Raised when the IDAES service rejects a build_state request."""
    pass

493 

def build_state_request(stream: SimulationObject):
    """Request a state build for the provided stream using the IDAES service.

    Args:
        stream: Stream object whose inlet properties should be used for the build.

    Returns:
        REST response returned by the IDAES service containing built properties.

    Raises:
        BuildStateSolveError: If the IDAES service rejects the build request.
        Exception: If preparing the payload or storing the result fails.
    """
    ctx = IdaesFactoryContext(stream.flowsheet_id)

    try:
        port = stream.connectedPorts.get(direction="inlet")
        unitop: SimulationObject = port.unitOp

        # Find the property package key for this port (we have the value,
        # the key of this port) and serialise the matching package.
        property_package_ports = unitop.schema.propertyPackagePorts
        for key, port_list in property_package_ports.items():
            if port.key in port_list:
                property_package_key = key
                PropertyPackageAdapter(
                    property_package_key).serialise(ctx, unitop)

        data = BuildStateRequestSchema(
            property_package=ctx.property_packages[0],
            properties=serialise_stream(ctx, stream, is_inlet=True)
        )
    except Exception as e:
        # Chain so the original traceback survives the re-wrap.
        raise Exception(e) from e

    try:
        response = idaes_service_request(endpoint="build_state", data=data.model_dump())
        store_properties_schema(response["properties"], stream.flowsheet_id)

        return Response(response, status=200)
    except IdaesServiceRequestException as e:
        raise BuildStateSolveError(e.message) from e
    except Exception as e:
        raise Exception(e) from e