Coverage for backend/django/idaes_factory/endpoints.py: 78%
241 statements
« prev ^ index » next coverage.py v7.10.7, created at 2026-02-12 01:47 +0000
« prev ^ index » next coverage.py v7.10.7, created at 2026-02-12 01:47 +0000
1import datetime
2import json
3import logging
4import os
5from typing import TypedDict
6import requests
7import traceback
9from django.db import transaction
10from django.utils import timezone
11from opentelemetry import trace
12from rest_framework.response import Response
13from pydantic import JsonValue
15from CoreRoot import settings
16from authentication.user.models import User
17from common.models.idaes.payloads.build_state_request_schema import BuildStateRequestSchema
18from common.models.idaes.payloads.solve_request_schema import IdaesSolveRequestPayload, IdaesSolveCompletionPayload, \
19 MultiSolvePayload
20from common.models.idaes.payloads.ml_request_schema import MLTrainRequestPayload, MLTrainingCompletionPayload
21from core.auxiliary.models.MLModel import MLModel
22from common.models.notifications.payloads import TaskCompletedPayload, NotificationServiceMessageType, \
23 NotificationServiceMessage
24from core.auxiliary.enums.generalEnums import TaskStatus
25from core.auxiliary.models.Task import Task, TaskType
26from core.auxiliary.serializers import TaskSerializer
27from core.exceptions import DetailedException
28from .idaes_factory import IdaesFactory, save_all_initial_values, store_properties_schema
29from flowsheetInternals.unitops.models.SimulationObject import SimulationObject
30from .adapters.stream_properties import serialise_stream
31from .adapters.property_package_adapter import PropertyPackageAdapter
32from .idaes_factory_context import IdaesFactoryContext, LiveSolveParams
33from core.auxiliary.models.Scenario import Scenario, ScenarioTabTypeEnum
34from common.services import messaging
35from diagnostics.constants import DiagnosticTrigger
# Module-level logger and OpenTelemetry tracer shared by all endpoint helpers below.
logger = logging.getLogger(__name__)
tracer = trace.get_tracer(settings.OPEN_TELEMETRY_TRACER_NAME)
class IdaesServiceRequestException(Exception):
    """Raised when the IDAES service replies with a non-200 HTTP status.

    The decoded response body is kept on ``message`` so callers can inspect
    the structured error details returned by the service.
    """

    def __init__(self, message: str) -> None:
        # Keep the payload accessible as an attribute in addition to the
        # standard Exception args.
        self.message = message
        super().__init__(message)
class SolveFlowsheetError(DetailedException):
    """Raised when a flowsheet solve request cannot be dispatched to the queue."""
    pass
class ResponseType(TypedDict, total=False):
    """Shape of the JSON body returned by code-generation endpoints.

    ``total=False`` makes every key optional.
    """
    # Either "success" or "error".
    status: str
    # Populated with "message"/"traceback" details when status is "error".
    error: dict | None
    log: str | None
    debug: dict | None
def idaes_service_request(endpoint: str, data: JsonValue, *, timeout: float | None = None) -> JsonValue:
    """Send a JSON payload to the configured IDAES service endpoint.

    Args:
        endpoint: Relative path of the IDAES service endpoint to call.
        data: Serialised payload that conforms to the endpoint schema.
        timeout: Optional per-request timeout in seconds. Defaults to ``None``
            (wait indefinitely) to preserve the previous behaviour; callers
            making interactive requests should pass a bound so a hung service
            cannot block the worker forever.

    Returns:
        Parsed JSON response returned by the IDAES service.

    Raises:
        IdaesServiceRequestException: If the service responds with a non-200 status.
    """
    base_url = os.getenv('IDAES_SERVICE_URL') or "http://localhost:8080"
    result = requests.post(base_url + "/" + endpoint, json=data, timeout=timeout)
    # Decode the body exactly once (previously parsed twice on the error path).
    body = result.json()
    if result.status_code != 200:
        # NOTE: the exception "message" is the decoded error payload (a dict);
        # generate_IDAES_python_request indexes it by key, so keep passing the
        # parsed body rather than a plain string.
        raise IdaesServiceRequestException(body)

    return body
@tracer.start_as_current_span("send_flowsheet_solve_request")
def _solve_flowsheet_request(
        task_id: int,
        built_factory: IdaesFactory,
        perform_diagnostics: bool = False,
        high_priority: bool = False
):
    """Queue an IDAES solve request for the provided flowsheet build.

    Args:
        task_id: Identifier of the task tracking the solve request.
        built_factory: Fully built factory containing flowsheet data to solve.
        perform_diagnostics: Whether to request diagnostic output from IDAES.
        high_priority: Whether the message should be prioritised over normal solves.

    Raises:
        SolveFlowsheetError: If the message cannot be dispatched to the queue.
    """

    try:
        idaes_payload = IdaesSolveRequestPayload(
            flowsheet=built_factory.flowsheet,
            solve_index=built_factory.solve_index,
            scenario_id=(built_factory.scenario.id if built_factory.scenario else None),
            task_id=task_id,
            perform_diagnostics=perform_diagnostics
        )

        messaging.send_idaes_solve_message(idaes_payload, high_priority=high_priority)
    except Exception as e:
        # Chain the original exception so the underlying traceback is preserved
        # for debugging (previously lost by re-raising without `from e`).
        raise SolveFlowsheetError(e, "idaes_factory_solve_message") from e
def start_flowsheet_solve_event(
        flowsheet_id: int,
        user: User,
        scenario: Scenario = None,
        perform_diagnostics: bool = False
) -> Response:
    """Start a single solve for the given flowsheet and return the tracking task.

    Args:
        flowsheet_id: Identifier of the flowsheet that should be solved.
        user: User initiating the solve request.
        scenario: Optional scenario providing context for the solve.
        perform_diagnostics: Whether the solve should run with diagnostic output enabled.

    Returns:
        REST response containing the serialised `Task` used to track the solve.
    """
    if scenario and scenario.state_name != ScenarioTabTypeEnum.SteadyState and not scenario.dataRows.exists():
        # if not steady state solves and no data rows exist, cannot proceed
        return Response(status=400, data=f"No data was provided for {scenario.state_name} scenario.")

    # Remove all previously created DynamicResults for this scenario.
    # An explicit None check replaces the previous bare `except:` which covered
    # the scenario=None case but also silently swallowed genuine DB errors.
    # Deleting an empty queryset is a harmless no-op.
    if scenario is not None:
        scenario.solutions.all().delete()

    solve_task = Task.create(user, flowsheet_id, status=TaskStatus.Pending, save=True)
    # Persist the user's intent alongside the task so downstream consumers
    # (diagnostics orchestrator + UI) can tell whether this solve was launched
    # with diagnostics enabled.
    debug: dict[str, JsonValue] = {
        **(solve_task.debug or {}),
        "perform_diagnostics": bool(perform_diagnostics),
    }
    if scenario is not None:
        debug["scenario_id"] = scenario.id
    solve_task.debug = debug
    solve_task.save(update_fields=["debug"])

    try:
        factory = IdaesFactory(
            flowsheet_id=flowsheet_id,
            scenario=scenario
        )
        factory.build()

        # We send single solve requests as high priority to ensure
        # they are not blocked by large multi-solve requests.
        _solve_flowsheet_request(
            solve_task.id,
            factory,
            perform_diagnostics=perform_diagnostics,
            high_priority=True
        )
    except DetailedException as e:
        solve_task.set_failure_with_exception(e, save=True)

    task_serializer = TaskSerializer(solve_task)

    return Response(task_serializer.data, status=200)
def start_multi_steady_state_solve_event(flowsheet_id: int, user: User, scenario: Scenario) -> Response:
    """Kick off a multi steady-state solve and return the parent tracking task.

    Args:
        flowsheet_id: Identifier of the flowsheet being solved.
        user: User who requested the multi-solve.
        scenario: Scenario containing the steady-state configurations to solve.

    Returns:
        REST response containing the parent `Task` that aggregates child solves.
    """
    # Nothing to solve without at least one data row.
    if not scenario.dataRows.exists():
        return Response(status=400, data="No data was provided for multi steady-state scenario.")

    # Remove all previously created DynamicResults for this scenario
    scenario.solutions.all().delete()

    iteration_count = scenario.dataRows.count()

    parent_task = Task.create_parent_task(
        creator=user,
        flowsheet_id=flowsheet_id,
        scheduled_tasks=iteration_count,
        status=TaskStatus.Running
    )
    # One pending child task per data row; persisted in a single bulk insert.
    pending_children: list[Task] = []
    for _ in range(iteration_count):
        pending_children.append(Task.create(
            user,
            flowsheet_id,
            parent=parent_task,
            status=TaskStatus.Pending
        ))

    Task.objects.bulk_create(pending_children)

    messaging.send_dispatch_multi_solve_message(MultiSolvePayload(
        task_id=parent_task.id,
        scenario_id=scenario.id)
    )

    return Response(TaskSerializer(parent_task).data, status=200)
def dispatch_multi_solves(parent_task_id: int, scenario_id: int):
    """Build and dispatch queued steady-state solves for each child task.

    Args:
        parent_task_id: Identifier of the parent multi-solve task.
        scenario_id: Scenario containing the data rows to iterate through.
    """
    parent_task = Task.objects.get(id=parent_task_id)
    scenario = Scenario.objects.get(id=scenario_id)

    factory = IdaesFactory(
        flowsheet_id=parent_task.flowsheet_id,
        scenario=scenario,
    )

    child_tasks = list(parent_task.children.order_by('start_time'))

    for solve_index, child_task in enumerate(child_tasks):
        try:
            factory.clear_flowsheet()
            factory.use_with_solve_index(solve_index)
            factory.build()

            _solve_flowsheet_request(
                child_task.id,
                factory
            )
        except DetailedException as e:
            child_task.set_failure_with_exception(exception=e, save=True)
            parent_task.update_status_from_child(child_task)

    # Notify clients about the most recently dispatched child (if any) and the
    # parent task. The previous implementation reused the loop variable here,
    # which shadowed it and raised a NameError when the parent had no children.
    flowsheet_messages = [
        NotificationServiceMessage(
            data=TaskSerializer(notify_task).data,
            message_type=NotificationServiceMessageType.TASK_UPDATED
        ) for notify_task in child_tasks[-1:] + [parent_task]
    ]

    messaging.send_flowsheet_notification_messages(parent_task.flowsheet_id, flowsheet_messages)
def start_ml_training_event(
        datapoints: list,
        columns: list[str],
        input_labels: list[str],
        output_labels: list[str],
        user: User,
        flowsheet_id: int,
):
    """Queue an asynchronous machine-learning training job for the given dataset.

    Args:
        datapoints: Training rows to supply to the ML service.
        columns: Column names describing the datapoint ordering.
        input_labels: Names of the input features.
        output_labels: Names of the predicted outputs.
        user: User requesting the training run.
        flowsheet_id: Flowsheet the training run is associated with.

    Returns:
        REST response containing the serialised `Task` for the training job.
    """
    # Create the tracking task up front so a dispatch failure can be recorded on it.
    task = Task.create(
        user,
        flowsheet_id,
        task_type=TaskType.ML_TRAINING,
        status=TaskStatus.Pending,
        save=True
    )

    try:
        messaging.send_ml_training_message(MLTrainRequestPayload(
            datapoints=datapoints,
            columns=columns,
            input_labels=input_labels,
            output_labels=output_labels,
            task_id=task.id,
        ))
    except DetailedException as e:
        task.set_failure_with_exception(e, save=True)

    return Response(TaskSerializer(task).data, status=200)
def _send_task_notifications(task: Task):
    """Broadcast task completion or status updates to interested flowsheet clients.

    Args:
        task: Task whose status change should be pushed to subscribers.
    """
    messages: list[NotificationServiceMessage] = []

    parent = task.parent
    if parent:
        # A child finishing may change the aggregate parent status, so refresh
        # it before deciding which message type to broadcast for the parent.
        parent.update_status_from_child(task)

        parent_message_type = (
            NotificationServiceMessageType.TASK_COMPLETED
            if parent.status is TaskStatus.Completed
            else NotificationServiceMessageType.TASK_UPDATED
        )
        messages.append(NotificationServiceMessage(
            data=TaskSerializer(parent).data,
            message_type=parent_message_type
        ))

    messages.append(NotificationServiceMessage(
        data=TaskSerializer(task).data,
        message_type=NotificationServiceMessageType.TASK_COMPLETED
    ))

    messaging.send_flowsheet_notification_messages(task.flowsheet_id, messages)
def process_idaes_solve_response(solve_response: IdaesSolveCompletionPayload):
    """Persist the outcome of a completed IDAES solve and notify listeners.

    Args:
        solve_response: Payload describing the finished solve result.
    """
    # Use a transaction to ensure that either everything succeeds or nothing does
    with transaction.atomic():
        task = Task.objects.select_related("parent").get(id=solve_response.task_id)

        # Silently ignore if the task has already been marked as completed.
        # This allows us to simulate exactly-once delivery semantics (only process
        # a finished task once).
        if task.status == TaskStatus.Completed or task.status == TaskStatus.Cancelled:
            return

        task.completed_time = timezone.now()
        task.log = solve_response.log

        # Store timing and diagnostics_raw_text from idaes_service in task.debug.
        # NOTE: diagnostics_blob was removed; IDAES DiagnosticsToolbox only outputs
        # plain text which we capture as diagnostics_raw_text.
        debug: dict[str, JsonValue] = {
            **(task.debug or {}),
            "timing": solve_response.timing,
        }
        if getattr(solve_response, "diagnostics_raw_text", None):
            debug["diagnostics_raw_text"] = solve_response.diagnostics_raw_text
        task.debug = debug

        if solve_response.status == "success":
            task.status = TaskStatus.Completed
        else:
            task.status = TaskStatus.Failed
            task.error = {
                "message": solve_response.error["message"],
                "cause": "idaes_service_request",
                "traceback": solve_response.traceback
            }

        task.save(update_fields=["status", "completed_time", "log", "debug", "error"])

        # When a solve fails, automatically run the deterministic diagnostics
        # rulesets so the Diagnostics tab has something actionable without the
        # user needing to manually re-run anything.
        if task.status == TaskStatus.Failed:
            try:
                from diagnostics.orchestrator import build_failure_bundle_from_payload, run_diagnostics_for_task

                failure_bundle = build_failure_bundle_from_payload(solve_response)
                run_diagnostics_for_task(task, failure_bundle, trigger=DiagnosticTrigger.SOLVE_FAILURE)
            except Exception:
                # Diagnostics should never prevent the original task result from being persisted.
                logger.exception("Diagnostics run failed for task %s (SOLVE_FAILURE trigger)", task.id)
        # If the user explicitly requested diagnostics for this solve, also persist
        # a DiagnosticRun on success so the Diagnostics tab can show findings/evidence.
        #
        # NOTE: We intentionally do NOT create a second "manual" run for failures
        # to avoid duplicates; failures are handled by the SOLVE_FAILURE trigger above.
        elif task.status == TaskStatus.Completed:
            try:
                requested = False
                if isinstance(task.debug, dict):
                    requested = bool(task.debug.get("perform_diagnostics"))
                if requested:
                    from diagnostics.orchestrator import build_failure_bundle_from_payload, run_diagnostics_for_task

                    failure_bundle = build_failure_bundle_from_payload(
                        solve_response,
                        trigger=DiagnosticTrigger.MANUAL,
                    )
                    run_diagnostics_for_task(task, failure_bundle, trigger=DiagnosticTrigger.MANUAL)
            except Exception:
                # Diagnostics should never prevent the original task result from being persisted.
                logger.exception("Diagnostics run failed for task %s (MANUAL trigger)", task.id)

        # Save the solved flowsheet values
        if task.status == TaskStatus.Completed:
            store_properties_schema(solve_response.flowsheet.properties, task.flowsheet_id, solve_response.scenario_id, solve_response.solve_index)

            # For now, only save initial values for single and dynamic solves, not MSS
            # In future, we may need some more complex logic to handle MSS initial values
            if solve_response.solve_index is None:
                save_all_initial_values(solve_response.flowsheet.initial_values)

        _send_task_notifications(task)
def process_failed_idaes_solve_response(solve_response: IdaesSolveCompletionPayload):
    """Handle final failure notifications for solves that could not be processed.

    Args:
        solve_response: Completion payload received from the dead-letter queue.
    """
    # All-or-nothing: the status update and notification bookkeeping commit together.
    with transaction.atomic():
        task = Task.objects.select_related("parent").get(id=solve_response.task_id)

        # The dead letter queue is configured for "at least once" delivery, so
        # the same failure can arrive more than once. Skip tasks that already
        # reached a terminal failed/cancelled state to emulate exactly-once
        # processing.
        if task.status in (TaskStatus.Failed, TaskStatus.Cancelled):
            return

        task.completed_time = timezone.now()
        task.log = solve_response.log
        task.status = TaskStatus.Failed
        task.error = {
            "message": "Internal server error: several attempts to process finished solve failed."
        }
        task.save()

        _send_task_notifications(task)
def process_ml_training_response(
        ml_training_response: MLTrainingCompletionPayload
):
    """Persist the result of a machine-learning training job and send updates.

    Args:
        ml_training_response: Completion payload returned by the ML service.
    """
    with transaction.atomic():
        task = Task.objects.select_related("parent").get(id=ml_training_response.task_id)

        # Silently ignore if the task has already been marked as completed.
        # This allows us to simulate exactly-once delivery semantics (only process
        # a finished task once).
        if task.status == TaskStatus.Completed or task.status == TaskStatus.Cancelled:
            return

        task.completed_time = timezone.now()
        task.log = ml_training_response.log

        if ml_training_response.status == "success":
            task.status = TaskStatus.Completed
            task.debug = {
                "timing": ml_training_response.json_response.get("timing", {})
            }
        else:
            task.status = TaskStatus.Failed
            task.error = {
                "message": ml_training_response.error,
                "traceback": ml_training_response.traceback
            }

        task.save()

        result = ml_training_response.json_response
        # Guard against a missing/empty JSON body (plausible on the failure
        # path); previously `result.get(...)` on None raised AttributeError and
        # rolled back the whole transaction, losing the task status update.
        if result:
            # NOTE(review): this updates EVERY MLModel row — there is no
            # .filter() narrowing the queryset to the model being trained.
            # Presumably a filter (e.g. by task or flowsheet) is intended;
            # confirm against the MLModel schema before changing behaviour.
            MLModel.objects.update(
                surrogate_model=result.get("surrogate_model"),
                charts=result.get("charts"),
                metrics=result.get("metrics"),
                test_inputs=result.get("test_inputs"),
                test_outputs=result.get("test_outputs"),
                progress=3
            )

        _send_task_notifications(task)
def cancel_idaes_solve(task_id: int):
    """Mark an in-flight solve task as cancelled and notify subscribers.

    Args:
        task_id: Identifier of the `Task` being cancelled.
    """
    with transaction.atomic():
        task = Task.objects.select_related("parent").get(id=task_id)

        # Only tasks still in flight can be cancelled; a task that already
        # reached a final status (e.g. completed or failed) is left untouched.
        if task.status not in (TaskStatus.Running, TaskStatus.Pending):
            return

        task.status = TaskStatus.Cancelled
        task.save()

        messaging.send_flowsheet_notification_message(
            task.flowsheet_id,
            TaskSerializer(task).data,
            NotificationServiceMessageType.TASK_CANCELLED
        )
def generate_IDAES_python_request(flowsheet_id: int, return_json: bool = False) -> Response:
    """Generate Python IDAES code for a flowsheet by calling the IDAES service.

    Args:
        flowsheet_id: Identifier of the flowsheet to translate to Python code.
        return_json: Whether to bypass the remote call and return raw JSON.

    Returns:
        REST response containing either the generated code or the flowsheet JSON.
    """
    factory = IdaesFactory(flowsheet_id, scenario=None, require_variables_fixed=False)
    result = ResponseType(
        status="success",
        error=None,
        log=None,
        debug=None
    )

    # Stage 1: build the flowsheet locally.
    try:
        factory.build()
        flowsheet = factory.flowsheet
        if return_json:
            return Response(flowsheet.model_dump_json(), status=200)
    except Exception as e:
        result["status"] = "error"
        result["error"] = {
            "message": str(e),
            "traceback": traceback.format_exc()
        }
        return Response(result, status=400)

    # Stage 2: ask the IDAES service to translate the build into Python code.
    try:
        generated = idaes_service_request(endpoint="generate_python_code", data=flowsheet.model_dump())
        return Response(generated, status=200)
    except IdaesServiceRequestException as e:
        service_error = e.message
        result["status"] = "error"
        result["error"] = {
            "message": service_error["error"],
            "traceback": service_error["traceback"]
        }
        return Response(result, status=400)
class BuildStateSolveError(Exception):
    """Raised when the IDAES service rejects a build_state request."""
    pass
def build_state_request(stream: SimulationObject):
    """Request a state build for the provided stream using the IDAES service.

    Args:
        stream: Stream object whose inlet properties should be used for the build.

    Returns:
        REST response returned by the IDAES service containing built properties.

    Raises:
        BuildStateSolveError: If the IDAES service rejects the build request.
        Exception: If preparing the payload fails for any reason.
    """
    ctx = IdaesFactoryContext(stream.flowsheet_id)

    try:
        port = stream.connectedPorts.get(direction="inlet")
        unitop: SimulationObject = port.unitOp
        # find the property package key for this port (we have the value, the key of this port)
        # Last match wins, matching the original behaviour; an explicit None
        # guard replaces the previous unbound-variable NameError when no
        # property package references this port.
        property_package_key = None
        property_package_ports = unitop.schema.propertyPackagePorts
        for key, port_list in property_package_ports.items():
            if port.key in port_list:
                property_package_key = key
        if property_package_key is None:
            raise ValueError(f"No property package found for port {port.key!r}")
        PropertyPackageAdapter(
            property_package_key).serialise(ctx, unitop)

        data = BuildStateRequestSchema(
            property_package=ctx.property_packages[0],
            properties=serialise_stream(ctx, stream, is_inlet=True)
        )
    except Exception as e:
        # Chain the cause so the original traceback survives the re-raise.
        raise Exception(e) from e
    try:
        response = idaes_service_request(endpoint="build_state", data=data.model_dump())
        store_properties_schema(response["properties"], stream.flowsheet_id)

        return Response(response, status=200)
    except IdaesServiceRequestException as e:
        raise BuildStateSolveError(e.message) from e
    except Exception as e:
        raise Exception(e) from e