Coverage for backend/django/idaes_factory/endpoints.py: 78%
219 statements
« prev ^ index » next — coverage.py v7.10.7, created at 2026-03-26 20:57 +0000
1import datetime
2import json
3import logging
4import os
5from typing import TypedDict
6import requests
7import traceback
9from django.db import transaction
10from django.utils import timezone
11from opentelemetry import trace
12from rest_framework.response import Response
13from pydantic import JsonValue
15from CoreRoot import settings
16from authentication.user.models import User
17from ahuora_builder_types.payloads.build_state_request_schema import BuildStateRequestSchema
18from ahuora_builder_types.payloads.solve_request_schema import IdaesSolveRequestPayload, IdaesSolveCompletionPayload, \
19 MultiSolvePayload
20from ahuora_builder_types.payloads.ml_request_schema import MLTrainRequestPayload, MLTrainingCompletionPayload
21from core.auxiliary.models.MLModel import MLModel
22from common.models.notifications.payloads import TaskCompletedPayload, NotificationServiceMessageType, \
23 NotificationServiceMessage
24from core.auxiliary.enums.generalEnums import TaskStatus
25from core.auxiliary.models.Task import Task, TaskType
26from core.auxiliary.models.Flowsheet import Flowsheet
27from core.auxiliary.serializers import TaskSerializer
28from core.exceptions import DetailedException
29from .idaes_factory import IdaesFactory, save_all_initial_values, store_properties_schema
30from flowsheetInternals.unitops.models.SimulationObject import SimulationObject
31from .adapters.stream_properties import serialise_stream
32from .adapters.property_package_adapter import PropertyPackageAdapter
33from .idaes_factory_context import IdaesFactoryContext, LiveSolveParams
34from core.auxiliary.models.Scenario import Scenario, ScenarioTabTypeEnum
35from common.services import messaging
# Module-level logger and OpenTelemetry tracer shared by every endpoint helper below.
logger = logging.getLogger(__name__)
tracer = trace.get_tracer(settings.OPEN_TELEMETRY_TRACER_NAME)
class IdaesServiceRequestException(Exception):
    """Raised when the IDAES service responds with a non-200 status.

    The ``message`` attribute carries the parsed JSON body of the failed
    response so callers can surface the service's error details.
    """

    def __init__(self, message: str) -> None:
        self.message = message
        super().__init__(message)
class SolveFlowsheetError(DetailedException):
    """Raised when a flowsheet solve request cannot be dispatched."""
    pass
class ResponseType(TypedDict, total=False):
    """Shape of the JSON body returned by the code-generation endpoint."""
    # Overall outcome: "success" or "error".
    status: str
    # {"message", "traceback"} details when status is "error", else None.
    error: dict | None
    log: str | None
    debug: dict | None
def idaes_service_request(endpoint: str, data: JsonValue, *, timeout: float = 300.0) -> JsonValue:
    """Send a JSON payload to the configured IDAES service endpoint.

    Args:
        endpoint: Relative path of the IDAES service endpoint to call.
        data: Serialised payload that conforms to the endpoint schema.
        timeout: Maximum seconds to wait for the service before aborting.

    Returns:
        Parsed JSON response returned by the IDAES service.

    Raises:
        IdaesServiceRequestException: If the service responds with a non-200 status.
        requests.exceptions.RequestException: If the request fails or times out.
    """
    base_url = os.getenv('IDAES_SERVICE_URL') or "http://localhost:8080"
    url = base_url + "/" + endpoint
    # A POST without a timeout can block this worker forever if the service
    # stalls; bound the wait while keeping the default generous for big solves.
    result = requests.post(url, json=data, timeout=timeout)
    if result.status_code != 200:
        raise IdaesServiceRequestException(result.json())

    return result.json()
@tracer.start_as_current_span("send_flowsheet_solve_request")
def _solve_flowsheet_request(
        task_id: int,
        built_factory: IdaesFactory,
        perform_diagnostics: bool = False,
        high_priority: bool = False
):
    """Queue an IDAES solve request for the provided flowsheet build.

    Args:
        task_id: Identifier of the task tracking the solve request.
        built_factory: Fully built factory containing flowsheet data to solve.
        perform_diagnostics: Whether to request diagnostic output from IDAES.
        high_priority: Whether the message should be prioritised over normal solves.

    Raises:
        SolveFlowsheetError: If the message cannot be dispatched to the queue.
    """
    try:
        idaes_payload = IdaesSolveRequestPayload(
            flowsheet=built_factory.flowsheet,
            solve_index=built_factory.solve_index,
            scenario_id=(built_factory.scenario.id if built_factory.scenario else None),
            task_id=task_id,
            perform_diagnostics=perform_diagnostics
        )

        messaging.send_idaes_solve_message(idaes_payload, high_priority=high_priority)
    except Exception as e:
        # Chain the original exception so the underlying cause stays visible
        # in tracebacks instead of being flattened into the wrapper.
        raise SolveFlowsheetError(e, "idaes_factory_solve_message") from e
def start_flowsheet_solve_event(
        flowsheet_id: int,
        group_id: int,
        user: User,
        scenario: Scenario = None,
        perform_diagnostics: bool = False
) -> Response:
    """Start a single solve for the given flowsheet and return the tracking task.

    Args:
        flowsheet_id: Identifier of the flowsheet that should be solved.
        group_id: Identifier of the root grouping used to build the factory.
        user: User initiating the solve request.
        scenario: Optional scenario providing context for the solve.
        perform_diagnostics: Whether the solve should run with diagnostic output enabled.

    Returns:
        REST response containing the serialised `Task` used to track the solve.
    """
    if scenario and scenario.state_name != ScenarioTabTypeEnum.SteadyState and not scenario.dataRows.exists():
        # if not steady state solves and no data rows exist, cannot proceed
        return Response(status=400, data=f"No data was provided for {scenario.state_name} scenario.")

    # Remove all previously created DynamicResults for this scenario.
    # Guard on None explicitly instead of hiding the AttributeError behind a
    # bare except, which also swallowed KeyboardInterrupt/SystemExit.
    if scenario is not None:
        try:
            scenario.solutions.all().delete()
        except Exception:
            # there's no solution yet, that's fine — best-effort cleanup only
            logger.debug("No previous solutions to delete for scenario %s", scenario.id)

    solve_task = Task.create(user, flowsheet_id, status=TaskStatus.Pending, save=True)
    # Persist the user's intent alongside the task so downstream consumers
    # can tell whether this solve was launched with diagnostics enabled.
    debug: dict[str, JsonValue] = {
        **(solve_task.debug or {}),
        "perform_diagnostics": bool(perform_diagnostics),
    }
    if scenario is not None:
        debug["scenario_id"] = scenario.id
    solve_task.debug = debug
    solve_task.save(update_fields=["debug"])

    try:
        factory = IdaesFactory(
            group_id=group_id,
            scenario=scenario
        )
        factory.build()

        # We send single solve requests as high priority to ensure
        # they are not blocked by large multi-solve requests.
        _solve_flowsheet_request(
            solve_task.id,
            factory,
            perform_diagnostics=perform_diagnostics,
            high_priority=True
        )
    except DetailedException as e:
        solve_task.set_failure_with_exception(e, save=True)

    task_serializer = TaskSerializer(solve_task)

    return Response(task_serializer.data, status=200)
def start_multi_steady_state_solve_event(flowsheet_id: int, user: User, scenario: Scenario) -> Response:
    """Kick off a multi steady-state solve and return the parent tracking task.

    Args:
        flowsheet_id: Identifier of the flowsheet being solved.
        user: User who requested the multi-solve.
        scenario: Scenario containing the steady-state configurations to solve.

    Returns:
        REST response containing the parent `Task` that aggregates child solves.
    """
    if not scenario.dataRows.exists():  # empty rows
        return Response(status=400, data="No data was provided for multi steady-state scenario.")

    # Clear out DynamicResults left over from any previous run of this scenario.
    scenario.solutions.all().delete()

    iteration_count = scenario.dataRows.count()

    parent_task = Task.create_parent_task(
        creator=user,
        flowsheet_id=flowsheet_id,
        scheduled_tasks=iteration_count,
        status=TaskStatus.Running
    )

    # One pending child task per data row; persisted in bulk to limit queries.
    pending_children: list[Task] = []
    for _ in range(iteration_count):
        pending_children.append(
            Task.create(
                user,
                flowsheet_id,
                parent=parent_task,
                status=TaskStatus.Pending
            )
        )

    Task.objects.bulk_create(pending_children)

    messaging.send_dispatch_multi_solve_message(MultiSolvePayload(
        task_id=parent_task.id,
        scenario_id=scenario.id)
    )

    return Response(TaskSerializer(parent_task).data, status=200)
def dispatch_multi_solves(parent_task_id: int, scenario_id: int):
    """Build and dispatch queued steady-state solves for each child task.

    Args:
        parent_task_id: Identifier of the parent multi-solve task.
        scenario_id: Scenario containing the data rows to iterate through.
    """
    parent_task = Task.objects.get(id=parent_task_id)
    scenario = Scenario.objects.get(id=scenario_id)

    root_group = scenario.flowsheet.rootGrouping
    factory = IdaesFactory(
        group_id=root_group.id,
        scenario=scenario,
    )

    child_tasks = list(parent_task.children.order_by('start_time'))

    last_task = None
    for solve_index, child_task in enumerate(child_tasks):
        last_task = child_task
        try:
            factory.clear_flowsheet()
            factory.use_with_solve_index(solve_index)
            factory.build()

            _solve_flowsheet_request(
                child_task.id,
                factory
            )
        except DetailedException as e:
            child_task.set_failure_with_exception(exception=e, save=True)
            parent_task.update_status_from_child(child_task)

    # The original code relied on the loop variable leaking out of the for
    # loop, which raised NameError when the parent had no children. Notify
    # about the last dispatched child (if any) plus the parent.
    notified_tasks = [parent_task] if last_task is None else [last_task, parent_task]
    flowsheet_messages = [
        NotificationServiceMessage(
            data=TaskSerializer(notified).data,
            message_type=NotificationServiceMessageType.TASK_UPDATED
        ) for notified in notified_tasks
    ]

    messaging.send_flowsheet_notification_messages(parent_task.flowsheet_id, flowsheet_messages)
def start_ml_training_event(
        datapoints: list,
        columns: list[str],
        input_labels: list[str],
        output_labels: list[str],
        user: User,
        flowsheet_id: int,
):
    """Queue an asynchronous machine-learning training job for the given dataset.

    Args:
        datapoints: Training rows to supply to the ML service.
        columns: Column names describing the datapoint ordering.
        input_labels: Names of the input features.
        output_labels: Names of the predicted outputs.
        user: User requesting the training run.
        flowsheet_id: Flowsheet the training run is associated with.

    Returns:
        REST response containing the serialised `Task` for the training job.
    """
    training_task = Task.create(
        user,
        flowsheet_id,
        task_type=TaskType.ML_TRAINING,
        status=TaskStatus.Pending,
        save=True
    )

    try:
        messaging.send_ml_training_message(MLTrainRequestPayload(
            datapoints=datapoints,
            columns=columns,
            input_labels=input_labels,
            output_labels=output_labels,
            task_id=training_task.id,
        ))
    except DetailedException as e:
        # Record the dispatch failure on the task itself.
        training_task.set_failure_with_exception(e, save=True)

    return Response(TaskSerializer(training_task).data, status=200)
def _send_task_notifications(task: Task):
    """Broadcast task completion or status updates to interested flowsheet clients.

    Args:
        task: Task whose status change should be pushed to subscribers.
    """
    messages = []

    parent = task.parent
    if parent:
        # Roll the child's outcome up into the parent before notifying.
        parent.update_status_from_child(task)

        if parent.status is TaskStatus.Completed:
            parent_message_type = NotificationServiceMessageType.TASK_COMPLETED
        else:
            parent_message_type = NotificationServiceMessageType.TASK_UPDATED

        messages.append(NotificationServiceMessage(
            data=TaskSerializer(parent).data,
            message_type=parent_message_type
        ))

    messages.append(NotificationServiceMessage(
        data=TaskSerializer(task).data,
        message_type=NotificationServiceMessageType.TASK_COMPLETED
    ))

    messaging.send_flowsheet_notification_messages(task.flowsheet_id, messages)
def process_idaes_solve_response(solve_response: IdaesSolveCompletionPayload):
    """Persist the outcome of a completed IDAES solve and notify listeners.

    Args:
        solve_response: Payload describing the finished solve result.
    """
    # All-or-nothing: either the whole result is persisted or none of it is.
    with transaction.atomic():
        task = Task.objects.select_related("parent").get(id=solve_response.task_id)

        # Silently ignore if the task has already been marked as completed.
        # This allows us to simulate exactly-once delivery semantics (only process
        # a finished task once).
        if task.status in (TaskStatus.Completed, TaskStatus.Cancelled):
            return

        task.completed_time = timezone.now()
        task.log = solve_response.log
        task.debug = {
            **(task.debug or {}),
            "timing": solve_response.timing or {},
        }

        if solve_response.status == "success":
            task.status = TaskStatus.Completed
        else:
            task.status = TaskStatus.Failed
            task.error = {
                "message": solve_response.error["message"],
                "cause": "idaes_service_request",
                "traceback": solve_response.traceback
            }

        task.save(update_fields=["status", "completed_time", "log", "debug", "error"])

        # Save the solved flowsheet values
        if task.status == TaskStatus.Completed:
            store_properties_schema(solve_response.flowsheet.properties, task.flowsheet_id, solve_response.scenario_id, solve_response.solve_index)

            # For now, only save initial values for single and dynamic solves, not MSS
            # In future, we may need some more complex logic to handle MSS initial values
            if solve_response.solve_index is None:
                save_all_initial_values(solve_response.flowsheet.initial_values)

        _send_task_notifications(task)
def process_failed_idaes_solve_response(solve_response: IdaesSolveCompletionPayload):
    """Handle final failure notifications for solves that could not be processed.

    Args:
        solve_response: Completion payload received from the dead-letter queue.
    """
    # All-or-nothing: either the whole failure is persisted or none of it is.
    with transaction.atomic():
        task = Task.objects.select_related("parent").get(id=solve_response.task_id)

        # Silently ignore if the task has already been marked as failed.
        # This allows us to simulate exactly-once delivery semantics (only process
        # a failed task once). Our dead letter queue is configured in "at least once" delivery mode.
        if task.status in (TaskStatus.Failed, TaskStatus.Cancelled):
            return

        task.completed_time = timezone.now()
        task.log = solve_response.log
        task.status = TaskStatus.Failed
        task.error = {
            "message": "Internal server error: several attempts to process finished solve failed."
        }
        task.save()

        _send_task_notifications(task)
def process_ml_training_response(
        ml_training_response: MLTrainingCompletionPayload
):
    """Persist the result of a machine-learning training job and send updates.

    Args:
        ml_training_response: Completion payload returned by the ML service.
    """
    with transaction.atomic():
        task = Task.objects.select_related("parent").get(id=ml_training_response.task_id)

        # Silently ignore if the task has already been marked as completed.
        # This allows us to simulate exactly-once delivery semantics (only process
        # a finished task once).
        if task.status == TaskStatus.Completed or task.status == TaskStatus.Cancelled:
            return

        task.completed_time = timezone.now()
        task.log = ml_training_response.log

        if ml_training_response.status == "success":
            task.status = TaskStatus.Completed
            task.debug = {
                "timing": ml_training_response.json_response.get("timing", {})
            }
        else:
            task.status = TaskStatus.Failed
            task.error = {
                "message": ml_training_response.error,
                "traceback": ml_training_response.traceback
            }

        task.save()
        result = ml_training_response.json_response
        # NOTE(review): Model.objects.update() with no filter updates EVERY
        # MLModel row — confirm whether this should target a specific model.
        # NOTE(review): this update also runs on the failure path, where
        # json_response may not be populated — verify against the ML service
        # contract before relying on it.
        MLModel.objects.update(
            surrogate_model=result.get("surrogate_model"),
            charts=result.get("charts"),
            metrics=result.get("metrics"),
            test_inputs=result.get("test_inputs"),
            test_outputs=result.get("test_outputs"),
            progress=3
        )
        _send_task_notifications(task)
def cancel_idaes_solve(task_id: int):
    """Mark an in-flight solve task as cancelled and notify subscribers.

    Args:
        task_id: Identifier of the `Task` being cancelled.
    """
    with transaction.atomic():
        task = Task.objects.select_related("parent").get(id=task_id)

        # Ignore cancellation request if a final status (e.g. completed or failed) has already been set
        if task.status not in (TaskStatus.Running, TaskStatus.Pending):
            return

        task.status = TaskStatus.Cancelled
        task.save()

        messaging.send_flowsheet_notification_message(
            task.flowsheet_id,
            TaskSerializer(task).data,
            NotificationServiceMessageType.TASK_CANCELLED
        )
def generate_IDAES_python_request(flowsheet_id: int, return_json: bool = False) -> Response:
    """Generate Python IDAES code for a flowsheet by calling the IDAES service.

    Args:
        flowsheet_id: Identifier of the flowsheet to translate to Python code.
        return_json: Whether to bypass the remote call and return raw JSON.

    Returns:
        REST response containing either the generated code or the flowsheet JSON.
    """
    scenario = None
    flowsheet = Flowsheet.objects.get(id=flowsheet_id)
    factory = IdaesFactory(group_id=flowsheet.rootGrouping.id, scenario=scenario, require_variables_fixed=False)

    response_data = ResponseType(
        status="success",
        error=None,
        log=None,
        debug=None
    )

    try:
        factory.build()
        data = factory.flowsheet
        if return_json:
            return Response(data.model_dump_json(), status=200)
    except Exception as e:
        response_data["status"] = "error"
        response_data["error"] = {
            "message": str(e),
            "traceback": traceback.format_exc()
        }
        return Response(response_data, status=400)

    try:
        remote_result = idaes_service_request(endpoint="generate_python_code", data=data.model_dump())
        return Response(remote_result, status=200)
    except IdaesServiceRequestException as e:
        # The service attaches its parsed JSON error body to the exception.
        service_body = e.message
        response_data["status"] = "error"
        response_data["error"] = {
            "message": service_body["error"],
            "traceback": service_body["traceback"]
        }
        return Response(response_data, status=400)
class BuildStateSolveError(Exception):
    """Raised when the IDAES service rejects a build_state request."""
    pass
def state_request_build(stream: SimulationObject) -> BuildStateRequestSchema:
    """Assemble the payload for an IDAES ``build_state`` request.

    Args:
        stream: Stream whose inlet port determines the property package.

    Returns:
        Schema containing the property package and serialised inlet stream
        properties.

    Raises:
        ValueError: If no property package key matches the stream's inlet port.
    """
    ctx = IdaesFactoryContext(stream.flowsheet.rootGrouping.id)

    port = stream.connectedPorts.get(direction="inlet")
    unitop: SimulationObject = port.unitOp
    # find the property package key for this port (we have the value, the key of this port)
    property_package_ports = unitop.schema.propertyPackagePorts
    property_package_key = None
    for key, port_list in property_package_ports.items():
        if port.key in port_list:
            property_package_key = key
    if property_package_key is None:
        # Previously this fell through to an opaque NameError when no schema
        # entry listed the port; fail with a descriptive message instead.
        raise ValueError(f"No property package key found for port {port.key!r}")

    PropertyPackageAdapter(
        property_package_key).serialise(ctx, unitop)

    return BuildStateRequestSchema(
        property_package=ctx.property_packages[0],
        properties=serialise_stream(ctx, stream, is_inlet=True)
    )
def build_state_request(stream: SimulationObject):
    """Request a state build for the provided stream using the IDAES service.

    Args:
        stream: Stream object whose inlet properties should be used for the build.

    Returns:
        REST response returned by the IDAES service containing built properties.

    Raises:
        BuildStateSolveError: If the IDAES service rejects the build request.
        Exception: If preparing the payload fails for any reason.
    """
    try:
        data = state_request_build(stream)
        response = idaes_service_request(endpoint="build_state", data=data.model_dump())
        store_properties_schema(response["properties"], stream.flowsheet_id)

        return Response(response, status=200)
    except IdaesServiceRequestException as e:
        # Chain so the original service error stays visible in tracebacks.
        raise BuildStateSolveError(e.message) from e
    # Other exceptions propagate unchanged: re-wrapping them in a bare
    # Exception(e) (as before) discarded the original type and traceback.