Coverage for backend/django/idaes_factory/endpoints.py: 78%

219 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2026-03-26 20:57 +0000

1import datetime 

2import json 

3import logging 

4import os 

5from typing import TypedDict 

6import requests 

7import traceback 

8 

9from django.db import transaction 

10from django.utils import timezone 

11from opentelemetry import trace 

12from rest_framework.response import Response 

13from pydantic import JsonValue 

14 

15from CoreRoot import settings 

16from authentication.user.models import User 

17from ahuora_builder_types.payloads.build_state_request_schema import BuildStateRequestSchema 

18from ahuora_builder_types.payloads.solve_request_schema import IdaesSolveRequestPayload, IdaesSolveCompletionPayload, \ 

19 MultiSolvePayload 

20from ahuora_builder_types.payloads.ml_request_schema import MLTrainRequestPayload, MLTrainingCompletionPayload 

21from core.auxiliary.models.MLModel import MLModel 

22from common.models.notifications.payloads import TaskCompletedPayload, NotificationServiceMessageType, \ 

23 NotificationServiceMessage 

24from core.auxiliary.enums.generalEnums import TaskStatus 

25from core.auxiliary.models.Task import Task, TaskType 

26from core.auxiliary.models.Flowsheet import Flowsheet 

27from core.auxiliary.serializers import TaskSerializer 

28from core.exceptions import DetailedException 

29from .idaes_factory import IdaesFactory, save_all_initial_values, store_properties_schema 

30from flowsheetInternals.unitops.models.SimulationObject import SimulationObject 

31from .adapters.stream_properties import serialise_stream 

32from .adapters.property_package_adapter import PropertyPackageAdapter 

33from .idaes_factory_context import IdaesFactoryContext, LiveSolveParams 

34from core.auxiliary.models.Scenario import Scenario, ScenarioTabTypeEnum 

35from common.services import messaging 

36 

# Module-level logger, named after this module per Django/stdlib convention.
logger = logging.getLogger(__name__)
# OpenTelemetry tracer used to create spans around solve/build dispatching.
tracer = trace.get_tracer(settings.OPEN_TELEMETRY_TRACER_NAME)

39 

40 

class IdaesServiceRequestException(Exception):
    """Raised when the IDAES service answers with a non-200 status.

    The `message` attribute carries the service's response body so callers
    can inspect the error payload.
    """

    def __init__(self, message: str) -> None:
        self.message = message
        super().__init__(message)

45 

class SolveFlowsheetError(DetailedException):
    """Raised when a flowsheet solve request cannot be dispatched."""
    pass

48 

class ResponseType(TypedDict, total=False):
    """Shape of the JSON body returned by the code-generation endpoint.

    `total=False`: every key is optional.
    """
    # "success" or "error".
    status: str
    # Populated on failure with "message" / "traceback" keys.
    error: dict | None
    log: str | None
    debug: dict | None

54 

55 

def idaes_service_request(endpoint: str, data: JsonValue, timeout: float = 300.0) -> JsonValue:
    """Send a JSON payload to the configured IDAES service endpoint.

    Args:
        endpoint: Relative path of the IDAES service endpoint to call.
        data: Serialised payload that conforms to the endpoint schema.
        timeout: Seconds to wait for the service before aborting. Without
            an explicit timeout, ``requests.post`` would block indefinitely
            if the service hangs.

    Returns:
        Parsed JSON response returned by the IDAES service.

    Raises:
        IdaesServiceRequestException: If the service responds with a non-200 status.
    """
    # `or` (rather than a getenv default) also falls back when the variable
    # is set but empty.
    base_url = os.getenv('IDAES_SERVICE_URL') or "http://localhost:8080"
    result = requests.post(base_url + "/" + endpoint, json=data, timeout=timeout)
    if result.status_code != 200:
        # The service's JSON error body (a dict) is passed as the exception
        # "message"; callers index into it (see generate_IDAES_python_request).
        raise IdaesServiceRequestException(result.json())

    return result.json()

76 

77 

@tracer.start_as_current_span("send_flowsheet_solve_request")
def _solve_flowsheet_request(
    task_id: int,
    built_factory: IdaesFactory,
    perform_diagnostics: bool = False,
    high_priority: bool = False
):
    """Queue an IDAES solve request for the provided flowsheet build.

    Args:
        task_id: Identifier of the task tracking the solve request.
        built_factory: Fully built factory containing flowsheet data to solve.
        perform_diagnostics: Whether to request diagnostic output from IDAES.
        high_priority: Whether the message should be prioritised over normal solves.

    Raises:
        SolveFlowsheetError: If the message cannot be dispatched to the queue.
    """

    try:
        idaes_payload = IdaesSolveRequestPayload(
            flowsheet=built_factory.flowsheet,
            solve_index=built_factory.solve_index,
            scenario_id=(built_factory.scenario.id if built_factory.scenario else None),
            task_id=task_id,
            perform_diagnostics=perform_diagnostics
        )

        messaging.send_idaes_solve_message(idaes_payload, high_priority=high_priority)
    except Exception as e:
        # Chain the cause (`from e`) so the underlying failure is preserved
        # in the traceback instead of being swallowed by the wrapper.
        raise SolveFlowsheetError(e, "idaes_factory_solve_message") from e

109 

def start_flowsheet_solve_event(
    flowsheet_id: int,
    group_id: int,
    user: User,
    scenario: Scenario = None,
    perform_diagnostics: bool = False
) -> Response:
    """Start a single solve for the given flowsheet and return the tracking task.

    Args:
        flowsheet_id: Identifier of the flowsheet that should be solved.
        group_id: Identifier of the grouping the factory builds the model from.
        user: User initiating the solve request.
        scenario: Optional scenario providing context for the solve.
        perform_diagnostics: Whether the solve should run with diagnostic output enabled.

    Returns:
        REST response containing the serialised `Task` used to track the solve.
    """
    if scenario and scenario.state_name != ScenarioTabTypeEnum.SteadyState and not scenario.dataRows.exists():
        # if not steady state solves and no data rows exist, cannot proceed
        return Response(status=400, data=f"No data was provided for {scenario.state_name} scenario.")

    if scenario is not None:
        # Remove all previously created DynamicResults for this scenario.
        # (Replaces a bare `except:` whose only realistic trigger was
        # `scenario` being None — deleting an empty queryset is a no-op.)
        scenario.solutions.all().delete()

    solve_task = Task.create(user, flowsheet_id, status=TaskStatus.Pending, save=True)
    # Persist the user's intent alongside the task so downstream consumers
    # can tell whether this solve was launched with diagnostics enabled.
    debug: dict[str, JsonValue] = {
        **(solve_task.debug or {}),
        "perform_diagnostics": bool(perform_diagnostics),
    }
    if scenario is not None:
        debug["scenario_id"] = scenario.id
    solve_task.debug = debug
    solve_task.save(update_fields=["debug"])

    try:
        factory = IdaesFactory(
            group_id=group_id,
            scenario=scenario
        )
        factory.build()

        # We send single solve requests as high priority to ensure
        # they are not blocked by large multi-solve requests.
        _solve_flowsheet_request(
            solve_task.id,
            factory,
            perform_diagnostics=perform_diagnostics,
            high_priority=True
        )
    except DetailedException as e:
        solve_task.set_failure_with_exception(e, save=True)

    task_serializer = TaskSerializer(solve_task)

    return Response(task_serializer.data, status=200)

172 

def start_multi_steady_state_solve_event(flowsheet_id: int, user: User, scenario: Scenario) -> Response:
    """Kick off a multi steady-state solve and return the parent tracking task.

    Args:
        flowsheet_id: Identifier of the flowsheet being solved.
        user: User who requested the multi-solve.
        scenario: Scenario containing the steady-state configurations to solve.

    Returns:
        REST response containing the parent `Task` that aggregates child solves.
    """
    # Nothing to iterate over without data rows.
    if not scenario.dataRows.exists():
        return Response(status=400, data="No data was provided for multi steady-state scenario.")

    # Drop any DynamicResults left over from a previous run of this scenario.
    scenario.solutions.all().delete()
    iterations = scenario.dataRows.count()

    parent_task = Task.create_parent_task(
        creator=user,
        flowsheet_id=flowsheet_id,
        scheduled_tasks=iterations,
        status=TaskStatus.Running
    )

    # One pending child task per data row, persisted in a single bulk insert.
    pending_children: list[Task] = []
    for _ in range(iterations):
        pending_children.append(
            Task.create(
                user,
                flowsheet_id,
                parent=parent_task,
                status=TaskStatus.Pending
            )
        )

    Task.objects.bulk_create(pending_children)

    payload = MultiSolvePayload(
        task_id=parent_task.id,
        scenario_id=scenario.id
    )
    messaging.send_dispatch_multi_solve_message(payload)

    return Response(TaskSerializer(parent_task).data, status=200)

214 

def dispatch_multi_solves(parent_task_id: int, scenario_id: int):
    """Build and dispatch queued steady-state solves for each child task.

    Args:
        parent_task_id: Identifier of the parent multi-solve task.
        scenario_id: Scenario containing the data rows to iterate through.
    """
    parent_task = Task.objects.get(id=parent_task_id)
    scenario = Scenario.objects.get(id=scenario_id)

    rootGroup = scenario.flowsheet.rootGrouping
    factory = IdaesFactory(
        group_id=rootGroup.id,
        scenario=scenario,
    )

    children = list(parent_task.children.order_by('start_time'))

    for solve_index, child in enumerate(children):
        try:
            # Rebuild the factory for this data row before dispatching.
            factory.clear_flowsheet()
            factory.use_with_solve_index(solve_index)
            factory.build()

            _solve_flowsheet_request(
                child.id,
                factory
            )
        except DetailedException as exc:
            child.set_failure_with_exception(exception=exc, save=True)
        parent_task.update_status_from_child(child)

        # Push status updates for both the child and the aggregating parent.
        updates = []
        for current in (child, parent_task):
            updates.append(NotificationServiceMessage(
                data=TaskSerializer(current).data,
                message_type=NotificationServiceMessageType.TASK_UPDATED
            ))

        messaging.send_flowsheet_notification_messages(parent_task.flowsheet_id, updates)

255 

256 

def start_ml_training_event(
    datapoints: list,
    columns: list[str],
    input_labels: list[str],
    output_labels: list[str],
    user: User,
    flowsheet_id: int,
):
    """Queue an asynchronous machine-learning training job for the given dataset.

    Args:
        datapoints: Training rows to supply to the ML service.
        columns: Column names describing the datapoint ordering.
        input_labels: Names of the input features.
        output_labels: Names of the predicted outputs.
        user: User requesting the training run.
        flowsheet_id: Flowsheet the training run is associated with.

    Returns:
        REST response containing the serialised `Task` for the training job.
    """
    task = Task.create(
        user,
        flowsheet_id,
        task_type=TaskType.ML_TRAINING,
        status=TaskStatus.Pending,
        save=True
    )

    try:
        request_payload = MLTrainRequestPayload(
            datapoints=datapoints,
            columns=columns,
            input_labels=input_labels,
            output_labels=output_labels,
            task_id=task.id,
        )
        messaging.send_ml_training_message(request_payload)

    except DetailedException as exc:
        # Dispatch failed: mark the tracking task as failed up front.
        task.set_failure_with_exception(exc, save=True)

    return Response(TaskSerializer(task).data, status=200)

301 

def _send_task_notifications(task: Task):
    """Broadcast task completion or status updates to interested flowsheet clients.

    Args:
        task: Task whose status change should be pushed to subscribers.
    """
    messages = []

    parent = task.parent
    if parent:
        # Child finished: refresh the parent's aggregate status first.
        parent.update_status_from_child(task)

        if parent.status is TaskStatus.Completed:
            parent_message_type = NotificationServiceMessageType.TASK_COMPLETED
        else:
            parent_message_type = NotificationServiceMessageType.TASK_UPDATED

        messages.append(NotificationServiceMessage(
            data=TaskSerializer(parent).data,
            message_type=parent_message_type
        ))

    messages.append(NotificationServiceMessage(
        data=TaskSerializer(task).data,
        message_type=NotificationServiceMessageType.TASK_COMPLETED
    ))

    messaging.send_flowsheet_notification_messages(task.flowsheet_id, messages)

329 

def process_idaes_solve_response(solve_response: IdaesSolveCompletionPayload):
    """Persist the outcome of a completed IDAES solve and notify listeners.

    Args:
        solve_response: Payload describing the finished solve result.
    """
    # Use a transaction to ensure that either everything succeeds or nothing does
    with transaction.atomic():

        task = Task.objects.select_related("parent").get(id=solve_response.task_id)

        # Silently ignore if the task has already been marked as completed.
        # This allows us to simulate exactly-once delivery semantics (only process
        # a finished task once).
        if task.status == TaskStatus.Completed or task.status == TaskStatus.Cancelled:
            return

        task.completed_time = timezone.now()
        task.log = solve_response.log
        # Merge timing info into any debug data already stored on the task.
        task.debug = {
            **(task.debug or {}),
            "timing": solve_response.timing or {},
        }

        if solve_response.status == "success":
            task.status = TaskStatus.Completed
            # NOTE(review): task.error is left untouched on success even though
            # "error" is in update_fields below — a stale error from a previous
            # attempt would be re-saved. Confirm whether it should be cleared.
        else:
            task.status = TaskStatus.Failed
            task.error = {
                "message": solve_response.error["message"],
                "cause": "idaes_service_request",
                "traceback": solve_response.traceback
            }

        task.save(update_fields=["status", "completed_time", "log", "debug", "error"])

        # Save the solved flowsheet values
        if task.status == TaskStatus.Completed:
            store_properties_schema(solve_response.flowsheet.properties, task.flowsheet_id, solve_response.scenario_id, solve_response.solve_index)

            # For now, only save initial values for single and dynamic solves, not MSS
            # In future, we may need some more complex logic to handle MSS initial values
            if solve_response.solve_index is None:
                save_all_initial_values(solve_response.flowsheet.initial_values)

        # NOTE(review): notifications are sent while the transaction is still
        # open; if the commit later fails, clients were already notified.
        _send_task_notifications(task)

376 

def process_failed_idaes_solve_response(solve_response: IdaesSolveCompletionPayload):
    """Handle final failure notifications for solves that could not be processed.

    Args:
        solve_response: Completion payload received from the dead-letter queue.
    """
    # All-or-nothing: the status change and bookkeeping commit together.
    with transaction.atomic():
        task = Task.objects.select_related("parent").get(id=solve_response.task_id)

        # The dead-letter queue delivers "at least once", so a task may arrive
        # here more than once; skip any task that has already been finalised
        # as failed (or cancelled) to approximate exactly-once processing.
        if task.status in (TaskStatus.Failed, TaskStatus.Cancelled):
            return

        task.completed_time = timezone.now()
        task.log = solve_response.log
        task.status = TaskStatus.Failed
        task.error = {
            "message": "Internal server error: several attempts to process finished solve failed."
        }
        task.save()

        _send_task_notifications(task)

402 

def process_ml_training_response(
    ml_training_response: MLTrainingCompletionPayload
):
    """Persist the result of a machine-learning training job and send updates.

    Args:
        ml_training_response: Completion payload returned by the ML service.
    """
    with transaction.atomic():
        task = Task.objects.select_related("parent").get(id=ml_training_response.task_id)

        # Silently ignore if the task has already been marked as completed.
        # This allows us to simulate exactly-once delivery semantics (only process
        # a finished task once).
        if task.status == TaskStatus.Completed or task.status == TaskStatus.Cancelled:
            return

        task.completed_time = timezone.now()
        task.log = ml_training_response.log

        if ml_training_response.status == "success":
            task.status = TaskStatus.Completed
            task.debug = {
                "timing": ml_training_response.json_response.get("timing", {})
            }
        else:
            task.status = TaskStatus.Failed
            task.error = {
                "message": ml_training_response.error,
                "traceback": ml_training_response.traceback
            }

        task.save()
        result = ml_training_response.json_response

        # NOTE(review): an unfiltered Model.objects.update() updates EVERY
        # MLModel row in the table — confirm whether this should be filtered
        # (e.g. to the model linked to this task/flowsheet).
        # NOTE(review): this also runs on the failure branch, where
        # json_response presumably lacks these keys (writing NULLs) — verify.
        MLModel.objects.update(
            surrogate_model=result.get("surrogate_model"),
            charts=result.get("charts"),
            metrics=result.get("metrics"),
            test_inputs=result.get("test_inputs"),
            test_outputs=result.get("test_outputs"),
            progress=3
        )
        _send_task_notifications(task)

447 

448 

def cancel_idaes_solve(task_id: int):
    """Mark an in-flight solve task as cancelled and notify subscribers.

    Args:
        task_id: Identifier of the `Task` being cancelled.
    """
    with transaction.atomic():
        task = Task.objects.select_related("parent").get(id=task_id)

        # A late cancel request loses to any final status (completed, failed,
        # cancelled) that has already been recorded.
        if task.status not in (TaskStatus.Running, TaskStatus.Pending):
            return

        task.status = TaskStatus.Cancelled
        task.save()

        messaging.send_flowsheet_notification_message(
            task.flowsheet_id,
            TaskSerializer(task).data,
            NotificationServiceMessageType.TASK_CANCELLED
        )

470 

def generate_IDAES_python_request(flowsheet_id: int, return_json: bool = False) -> Response:
    """Generate Python IDAES code for a flowsheet by calling the IDAES service.

    Args:
        flowsheet_id: Identifier of the flowsheet to translate to Python code.
        return_json: Whether to bypass the remote call and return raw JSON.

    Returns:
        REST response containing either the generated code or the flowsheet JSON.
    """
    scenario = None
    flowsheet = Flowsheet.objects.get(id=flowsheet_id)
    factory = IdaesFactory(group_id=flowsheet.rootGrouping.id, scenario=scenario, require_variables_fixed=False)

    result = ResponseType(
        status="success",
        error=None,
        log=None,
        debug=None
    )

    # Stage 1: build the flowsheet model locally.
    try:
        factory.build()
        data = factory.flowsheet
        if return_json:
            # Caller asked for the raw flowsheet JSON — skip the service call.
            return Response(data.model_dump_json(), status=200)
    except Exception as e:
        result["status"] = "error"
        result["error"] = {
            "message": str(e),
            "traceback": traceback.format_exc()
        }
        return Response(result, status=400)

    # Stage 2: ask the IDAES service to translate the model to Python code.
    try:
        generated = idaes_service_request(endpoint="generate_python_code", data=data.model_dump())
        return Response(generated, status=200)
    except IdaesServiceRequestException as e:
        service_body = e.message
        result["status"] = "error"
        result["error"] = {
            "message": service_body["error"],
            "traceback": service_body["traceback"]
        }
        return Response(result, status=400)

513 

class BuildStateSolveError(Exception):
    """Raised when the IDAES service rejects a build-state request."""
    pass

516 

def state_request_build(stream: SimulationObject) -> BuildStateRequestSchema:
    """Assemble a build-state payload for the given stream's inlet state.

    Serialises the property package of the unit op attached to the stream's
    inlet port, then serialises the stream itself against that context.

    Args:
        stream: Stream simulation object to build a state request for.

    Returns:
        Payload combining the resolved property package and stream properties.
    """
    ctx = IdaesFactoryContext(stream.flowsheet.rootGrouping.id)

    port = stream.connectedPorts.get(direction="inlet")
    unitop: SimulationObject = port.unitOp
    # find the property package key for this port (we have the value, the key of this port)
    property_package_ports = unitop.schema.propertyPackagePorts
    for key, port_list in property_package_ports.items():
        if port.key in port_list:
            property_package_key = key
    # NOTE(review): if no port list contains port.key, property_package_key is
    # never bound and the next line raises NameError — confirm that the schema
    # guarantees a match.
    PropertyPackageAdapter(
        property_package_key).serialise(ctx, unitop)

    return BuildStateRequestSchema(
        property_package=ctx.property_packages[0],
        properties=serialise_stream(ctx, stream, is_inlet=True)
    )

534 

def build_state_request(stream: SimulationObject):
    """Request a state build for the provided stream using the IDAES service.

    Args:
        stream: Stream object whose inlet properties should be used for the build.

    Returns:
        REST response returned by the IDAES service containing built properties.

    Raises:
        BuildStateSolveError: If the IDAES service rejects the build request.
        Exception: If preparing the payload fails for any reason.
    """

    try:
        data = state_request_build(stream)
        response = idaes_service_request(endpoint="build_state", data=data.model_dump())
        # Persist the built properties against the stream's flowsheet.
        store_properties_schema(response["properties"], stream.flowsheet_id)

        return Response(response, status=200)
    except IdaesServiceRequestException as e:
        # Chain the cause so the service's error detail survives in the traceback.
        raise BuildStateSolveError(e.message) from e
    # Any other exception propagates unchanged: the previous
    # `raise Exception(e)` wrapper discarded the exception type and traceback
    # while adding no information.