Coverage for backend/idaes_factory/endpoints.py: 70%
205 statements
« prev ^ index » next — coverage.py v7.10.7, created at 2025-11-06 23:27 +0000
1import datetime
2import json
3import os
4from typing import Any, TypedDict
5import requests
6import traceback
8from django.db import transaction
9from django.utils import timezone
10from opentelemetry import trace
11from rest_framework.response import Response
13from CoreRoot import settings
14from authentication.user.models import User
15from common.models.idaes.payloads import BuildStateRequestSchema
16from common.models.idaes.payloads.solve_request_schema import IdaesSolveRequestPayload, IdaesSolveCompletionPayload, \
17 MultiSolvePayload
18from common.models.idaes.payloads.ml_request_schema import MLTrainRequestPayload, MLTrainingCompletionPayload
19from core.auxiliary.models.MLModel import MLModel
20from common.models.notifications.payloads import TaskCompletedPayload, NotificationServiceMessageType, \
21 NotificationServiceMessage
22from core.auxiliary.enums.generalEnums import TaskStatus
23from core.auxiliary.models.Task import Task, TaskType
24from core.auxiliary.serializers import TaskSerializer
25from core.exceptions import DetailedException
26from .idaes_factory import IdaesFactory, save_all_initial_values, store_properties_schema
27from flowsheetInternals.unitops.models.SimulationObject import SimulationObject
28from .adapters.stream_properties import serialise_stream
29from .adapters.property_package_adapter import PropertyPackageAdapter
30from .idaes_factory_context import IdaesFactoryContext, LiveSolveParams
31from core.auxiliary.models.Scenario import Scenario
32from common.services import messaging
# Module-level OpenTelemetry tracer used to create spans for solve-related operations.
tracer = trace.get_tracer(settings.OPEN_TELEMETRY_TRACER_NAME)
class IdaesServiceRequestException(Exception):
    """Raised when the IDAES service responds with a non-200 status.

    The parsed response body is kept on ``message`` for callers to inspect.
    """

    def __init__(self, message: str) -> None:
        self.message = message
        super().__init__(message)
class SolveFlowsheetError(DetailedException):
    """Raised when a flowsheet solve request cannot be dispatched to the queue."""
    pass
class ResponseType(TypedDict, total=False):
    """Shape of the JSON body returned by code-generation endpoints.

    ``total=False`` makes every key optional.
    """
    status: str          # "success" or "error"
    error: dict | None   # populated with "message"/"traceback" keys on failure
    log: str | None
    debug: dict | None
def idaes_service_request(endpoint, data: Any) -> Any:
    """Send a JSON payload to the configured IDAES service endpoint.

    Args:
        endpoint: Relative path of the IDAES service endpoint to call.
        data: Serialised payload that conforms to the endpoint schema.

    Returns:
        Parsed JSON response returned by the IDAES service.

    Raises:
        IdaesServiceRequestException: If the service responds with a non-200 status.
    """
    base_url = os.getenv('IDAES_SERVICE_URL') or "http://localhost:8080"
    result = requests.post(base_url + "/" + endpoint, json=data)

    # Parse the body exactly once: the original called result.json() twice,
    # decoding the payload a second time on the success path.
    body = result.json()
    if result.status_code != 200:
        raise IdaesServiceRequestException(body)

    return body
@tracer.start_as_current_span("send_flowsheet_solve_request")
def _solve_flowsheet_request(
        task_id: int,
        built_factory: IdaesFactory,
        perform_diagnostics: bool = False,
        high_priority: bool = False
):
    """Queue an IDAES solve request for the provided flowsheet build.

    Args:
        task_id: Identifier of the task tracking the solve request.
        built_factory: Fully built factory containing flowsheet data to solve.
        perform_diagnostics: Whether to request diagnostic output from IDAES.
        high_priority: Whether the message should be prioritised over normal solves.

    Raises:
        SolveFlowsheetError: If the message cannot be dispatched to the queue.
    """
    try:
        idaes_payload = IdaesSolveRequestPayload(
            flowsheet=built_factory.flowsheet,
            solve_index=built_factory.solve_index,
            scenario_id=(built_factory.scenario.id if built_factory.scenario else None),
            task_id=task_id,
            perform_diagnostics=perform_diagnostics
        )

        messaging.send_idaes_solve_message(idaes_payload, high_priority=high_priority)
    except Exception as e:
        # Chain the original exception so the root cause survives in tracebacks.
        raise SolveFlowsheetError(e, "idaes_factory_solve_message") from e
def start_flowsheet_solve_event(
        flowsheet_id: int,
        user: User,
        scenario: Scenario = None,
        perform_diagnostics: bool = False
) -> Response:
    """Start a single solve for the given flowsheet and return the tracking task.

    Args:
        flowsheet_id: Identifier of the flowsheet that should be solved.
        user: User initiating the solve request.
        scenario: Optional scenario providing context for the solve.
        perform_diagnostics: Whether the solve should run with diagnostic output enabled.

    Returns:
        REST response containing the serialised `Task` used to track the solve.
    """
    # Remove all previously created DynamicResults for this scenario.
    # The original used a bare `except:` to cover both a None scenario and a
    # scenario without solutions; guard None explicitly and only swallow
    # ordinary exceptions (a bare except also traps KeyboardInterrupt/SystemExit).
    if scenario is not None:
        try:
            scenario.solutions.all().delete()
        except Exception:
            # there's no solution yet, that's fine
            pass

    solve_task = Task.create(user, flowsheet_id, status=TaskStatus.Pending, save=True)

    try:
        factory = IdaesFactory(
            flowsheet_id=flowsheet_id,
            scenario=scenario
        )
        factory.build()

        # We send single solve requests as high priority to ensure
        # they are not blocked by large multi-solve requests.
        _solve_flowsheet_request(
            solve_task.id,
            factory,
            perform_diagnostics=perform_diagnostics,
            high_priority=True
        )
    except DetailedException as e:
        solve_task.set_failure_with_exception(e, save=True)

    task_serializer = TaskSerializer(solve_task)

    return Response(task_serializer.data, status=200)
def start_multi_steady_state_solve_event(flowsheet_id: int, user: User, scenario: Scenario) -> Response:
    """Kick off a multi steady-state solve and return the parent tracking task.

    Args:
        flowsheet_id: Identifier of the flowsheet being solved.
        user: User who requested the multi-solve.
        scenario: Scenario containing the steady-state configurations to solve.

    Returns:
        REST response containing the parent `Task` that aggregates child solves.
    """
    # Drop any stale DynamicResults before starting a fresh multi-solve.
    scenario.solutions.all().delete()

    solve_iterations = scenario.solveStates.count()

    parent_task = Task.create_parent_task(
        creator=user,
        flowsheet_id=flowsheet_id,
        scheduled_tasks=solve_iterations,
        status=TaskStatus.Running
    )

    # One pending child task per solve state, all persisted in a single query.
    child_tasks: list[Task] = []
    for _ in range(solve_iterations):
        child_tasks.append(Task.create(
            user,
            flowsheet_id,
            parent=parent_task,
            status=TaskStatus.Pending
        ))
    Task.objects.bulk_create(child_tasks)

    messaging.send_dispatch_multi_solve_message(MultiSolvePayload(
        task_id=parent_task.id,
        scenario_id=scenario.id)
    )

    return Response(TaskSerializer(parent_task).data, status=200)
def dispatch_multi_solves(parent_task_id: int, scenario_id: int):
    """Build and dispatch queued steady-state solves for each child task.

    Args:
        parent_task_id: Identifier of the parent multi-solve task.
        scenario_id: Scenario containing the solve states to iterate through.
    """
    parent_task = Task.objects.get(id=parent_task_id)
    scenario = Scenario.objects.get(id=scenario_id)

    factory = IdaesFactory(
        flowsheet_id=parent_task.flowsheet_id,
        scenario=scenario,
    )

    child_tasks = list(parent_task.children.order_by('start_time'))

    # Reuse one factory across iterations, rebuilding it for each solve state.
    for solve_index, task in enumerate(child_tasks):
        try:
            factory.clear_flowsheet()
            factory.use_with_solve_index(solve_index)
            factory.build()

            _solve_flowsheet_request(
                task.id,
                factory
            )
        except DetailedException as e:
            task.set_failure_with_exception(exception=e, save=True)
            parent_task.update_status_from_child(task)

    # NOTE(review): `task` below is the loop variable left over from the final
    # iteration, so only the last child (plus the parent) is notified — and if
    # there are no child tasks at all this raises NameError. Confirm whether
    # the intent was `child_tasks + [parent_task]`.
    flowsheet_messages = [
        NotificationServiceMessage(
            data=TaskSerializer(task).data,
            message_type=NotificationServiceMessageType.TASK_UPDATED
        ) for task in [task, parent_task]
    ]

    messaging.send_flowsheet_notification_messages(parent_task.flowsheet_id, flowsheet_messages)
def start_ml_training_event(
        datapoints: list,
        columns: list[str],
        input_labels: list[str],
        output_labels: list[str],
        user: User,
        flowsheet_id: int,
):
    """Queue an asynchronous machine-learning training job for the given dataset.

    Args:
        datapoints: Training rows to supply to the ML service.
        columns: Column names describing the datapoint ordering.
        input_labels: Names of the input features.
        output_labels: Names of the predicted outputs.
        user: User requesting the training run.
        flowsheet_id: Flowsheet the training run is associated with.

    Returns:
        REST response containing the serialised `Task` for the training job.
    """
    training_task = Task.create(
        user,
        flowsheet_id,
        task_type=TaskType.ML_TRAINING,
        status=TaskStatus.Pending,
        save=True
    )

    try:
        # Hand the dataset off to the ML service queue, tagged with our task id.
        messaging.send_ml_training_message(MLTrainRequestPayload(
            datapoints=datapoints,
            columns=columns,
            input_labels=input_labels,
            output_labels=output_labels,
            task_id=training_task.id,
        ))
    except DetailedException as e:
        training_task.set_failure_with_exception(e, save=True)

    return Response(TaskSerializer(training_task).data, status=200)
def _send_task_notifications(task: Task):
    """Broadcast task completion or status updates to interested flowsheet clients.

    Args:
        task: Task whose status change should be pushed to subscribers.
    """
    flowsheet_messages = []

    # If this is a child task, update the parent task status
    if task.parent:
        task.parent.update_status_from_child(task)

        # Compare with `==` rather than `is`: every other status comparison in
        # this module uses equality, and an identity check silently fails when
        # the status is loaded from the database as a plain string rather than
        # the enum member.
        message_type = (NotificationServiceMessageType.TASK_COMPLETED
                        if task.parent.status == TaskStatus.Completed else
                        NotificationServiceMessageType.TASK_UPDATED)

        flowsheet_messages.append(NotificationServiceMessage(
            data=TaskSerializer(task.parent).data,
            message_type=message_type
        ))

    flowsheet_messages.append(NotificationServiceMessage(
        data=TaskSerializer(task).data,
        message_type=NotificationServiceMessageType.TASK_COMPLETED
    ))

    messaging.send_flowsheet_notification_messages(task.flowsheet_id, flowsheet_messages)
def process_idaes_solve_response(solve_response: IdaesSolveCompletionPayload):
    """Persist the outcome of a completed IDAES solve and notify listeners.

    Args:
        solve_response: Payload describing the finished solve result.
    """
    # Use a transaction to ensure that either everything succeeds or nothing does
    with transaction.atomic():
        task = Task.objects.select_related("parent").get(id=solve_response.task_id)

        # Silently ignore if the task has already been marked as completed.
        # This allows us to simulate exactly-once delivery semantics (only process
        # a finished task once).
        if task.status == TaskStatus.Completed or task.status == TaskStatus.Cancelled:
            return

        task.completed_time = timezone.now()
        task.log = solve_response.log

        if solve_response.status == "success":
            task.status = TaskStatus.Completed
            task.debug = {
                "timing": solve_response.timing
            }
        else:
            task.status = TaskStatus.Failed
            task.error = {
                "message": solve_response.error["message"],
                "cause": "idaes_service_request",
                "traceback": solve_response.traceback
            }

        # Restrict the save to the fields touched above; whichever of
        # debug/error was not set this time keeps its previously stored value.
        task.save(update_fields=["status", "completed_time", "log", "debug", "error"])

        # Save the solved flowsheet values
        if task.status == TaskStatus.Completed:
            store_properties_schema(solve_response.flowsheet.properties, task.flowsheet_id, solve_response.scenario_id, solve_response.solve_index)

            # For now, only save initial values for single and dynamic solves, not MSS
            # In future, we may need some more complex logic to handle MSS initial values
            if solve_response.solve_index is None:
                save_all_initial_values(solve_response.flowsheet.initial_values)

    _send_task_notifications(task)
def process_failed_idaes_solve_response(solve_response: IdaesSolveCompletionPayload):
    """Handle final failure notifications for solves that could not be processed.

    Args:
        solve_response: Completion payload received from the dead-letter queue.
    """
    # Run inside a transaction so the task update is all-or-nothing.
    with transaction.atomic():
        task = Task.objects.select_related("parent").get(id=solve_response.task_id)

        # The dead letter queue delivers "at least once", so the same failure
        # may arrive repeatedly; silently skip tasks that have already reached
        # a failed or cancelled state to simulate exactly-once semantics.
        if task.status in (TaskStatus.Failed, TaskStatus.Cancelled):
            return

        task.status = TaskStatus.Failed
        task.log = solve_response.log
        task.completed_time = timezone.now()
        task.error = {
            "message": "Internal server error: several attempts to process finished solve failed."
        }
        task.save()

    _send_task_notifications(task)
def process_ml_training_response(
        ml_training_response: MLTrainingCompletionPayload
):
    """Persist the result of a machine-learning training job and send updates.

    Args:
        ml_training_response: Completion payload returned by the ML service.
    """
    with transaction.atomic():
        task = Task.objects.select_related("parent").get(id=ml_training_response.task_id)

        # Silently ignore if the task has already been marked as completed.
        # This allows us to simulate exactly-once delivery semantics (only process
        # a finished task once).
        if task.status == TaskStatus.Completed or task.status == TaskStatus.Cancelled:
            return

        task.completed_time = timezone.now()
        task.log = ml_training_response.log

        if ml_training_response.status == "success":
            task.status = TaskStatus.Completed
            task.debug = {
                "timing": ml_training_response.json_response.get("timing", {})
            }
        else:
            task.status = TaskStatus.Failed
            task.error = {
                "message": ml_training_response.error,
                "traceback": ml_training_response.traceback
            }

        task.save()
        result = ml_training_response.json_response
        # NOTE(review): this updates EVERY MLModel row — there is no .filter()
        # narrowing it to the model belonging to this task/flowsheet. Confirm
        # whether a filter is missing here.
        # NOTE(review): on the failure path json_response may be empty — the
        # .get() calls below would then raise. Verify the payload contract for
        # failed training runs.
        MLModel.objects.update(
            surrogate_model=result.get("surrogate_model"),
            charts=result.get("charts"),
            metrics=result.get("metrics"),
            test_inputs=result.get("test_inputs"),
            test_outputs=result.get("test_outputs"),
            progress=3
        )
    _send_task_notifications(task)
def cancel_idaes_solve(task_id: int):
    """Mark an in-flight solve task as cancelled and notify subscribers.

    Args:
        task_id: Identifier of the `Task` being cancelled.
    """
    with transaction.atomic():
        task = Task.objects.select_related("parent").get(id=task_id)

        # Only tasks still in flight can be cancelled; a terminal status
        # (e.g. completed or failed) wins over a late cancellation request.
        if task.status not in (TaskStatus.Running, TaskStatus.Pending):
            return

        task.status = TaskStatus.Cancelled
        task.save()

    messaging.send_flowsheet_notification_message(
        task.flowsheet_id,
        TaskSerializer(task).data,
        NotificationServiceMessageType.TASK_CANCELLED
    )
def generate_IDAES_python_request(flowsheet_id: int, return_json: bool = False) -> Response:
    """Generate Python IDAES code for a flowsheet by calling the IDAES service.

    Args:
        flowsheet_id: Identifier of the flowsheet to translate to Python code.
        return_json: Whether to bypass the remote call and return raw JSON.

    Returns:
        REST response containing either the generated code or the flowsheet JSON.
    """
    scenario = None
    factory = IdaesFactory(flowsheet_id, scenario=scenario, require_variables_fixed=False)

    # Optimistically assume success; the error branches below overwrite this.
    response_data = ResponseType(
        status="success",
        error=None,
        log=None,
        debug=None
    )

    try:
        factory.build()
        data = factory.flowsheet
        if return_json:
            return Response(data.model_dump_json(), status=200)
    except Exception as e:
        response_data["status"] = "error"
        response_data["error"] = {
            "message": str(e),
            "traceback": traceback.format_exc()
        }
        return Response(response_data, status=400)

    try:
        return Response(
            idaes_service_request(endpoint="generate_python_code", data=data.model_dump()),
            status=200
        )
    except IdaesServiceRequestException as e:
        service_error = e.message
        response_data["status"] = "error"
        response_data["error"] = {
            "message": service_error["error"],
            "traceback": service_error["traceback"]
        }
        return Response(response_data, status=400)
class BuildStateSolveError(Exception):
    """Raised when the IDAES service rejects a build_state request."""
    pass
def build_state_request(stream: SimulationObject):
    """Request a state build for the provided stream using the IDAES service.

    Args:
        stream: Stream object whose inlet properties should be used for the build.

    Returns:
        REST response returned by the IDAES service containing built properties.

    Raises:
        BuildStateSolveError: If the IDAES service rejects the build request.
        Exception: If preparing the payload fails for any reason.
    """
    ctx = IdaesFactoryContext(stream.flowsheet_id)

    try:
        port = stream.connectedPorts.get(direction="inlet")
        unitop: SimulationObject = port.unitOp

        # find the property package key for this port (we have the value, the key of this port)
        property_package_key = None
        property_package_ports = unitop.schema.propertyPackagePorts
        for key, port_list in property_package_ports.items():
            if port.key in port_list:
                property_package_key = key
        # Previously an unmatched port surfaced as an opaque NameError further
        # down; fail with an explicit message instead.
        if property_package_key is None:
            raise Exception(f"No property package found for port {port.key}")

        PropertyPackageAdapter(
            property_package_key).serialise(ctx, unitop)

        data = BuildStateRequestSchema(
            property_package=ctx.property_packages[0],
            properties=serialise_stream(ctx, stream, is_inlet=True)
        )
    except Exception as e:
        # Chain the cause so the original traceback is preserved.
        raise Exception(e) from e

    try:
        response = idaes_service_request(endpoint="build_state", data=data.model_dump())
        store_properties_schema(response["properties"], stream.flowsheet_id)

        return Response(response, status=200)
    except IdaesServiceRequestException as e:
        raise BuildStateSolveError(e.message) from e
    except Exception as e:
        raise Exception(e) from e