Coverage for backend/idaes_service/endpoints.py: 87%
69 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-11-06 23:27 +0000
1import io
2import sys
3import traceback
5from abc import ABC, abstractmethod
6from pydantic import TypeAdapter
7from typing import Callable, Any
8from fastapi import FastAPI, Response
10from common.models.idaes import FlowsheetSchema
11from common.models.idaes.payloads import BuildStateRequestSchema, BuildStateResponseSchema
12from common.models.idaes.payloads.solve_request_schema import IdaesSolveCompletionPayload, IdaesSolveRequestPayload
13from common.models.idaes.unit_model_schema import SolvedPropertyValueSchema
14from .solver.solver import solve_model, SolveModelResult
15from .solver.build_state import solve_state_block
16from .solver.generate_python_file import generate_python_code
17from .solver.ml_wizard import ml_generate, MLResult
18from common.models.idaes.payloads.ml_request_schema import MLTrainRequestPayload, MLTrainingEvent
19import idaes.logger as idaeslog
20import logging
21from common.models.idaes.payloads.ml_request_schema import MLTrainRequestPayload, MLTrainingCompletionPayload
22from typing import Tuple
23from functools import wraps
24from common.models.idaes.payloads.solve_request_schema import IdaesSolveEvent, IdaesSolveRequestPayload
class IOCapture:
    """
    Capture stdout and IDAES init-logger output so it can be returned to a caller.

    Intended usage::

        logs = IOCapture().start_capture()
        try:
            ...  # work whose output should be recorded
            text = logs.get_output()
        finally:
            logs.stop_capture()

    While capturing, ``sys.stdout`` is redirected into an in-memory buffer, and the
    handlers of the IDAES init logger's parent logger are replaced by a single
    handler that writes into the same buffer. ``stop_capture`` undoes both changes.

    NOTE(review): ``sys.stdout`` and the logger are process-wide, so captures running
    concurrently in different threads would clobber each other — confirm endpoints
    using this are not served in parallel if exact per-request logs matter.
    """

    def __init__(self):
        self._captured_output = io.StringIO()
        # State recorded by start_capture() so that stop_capture() can undo it.
        self._capturing = False
        self._previous_stdout = None
        self._logger = None
        self._handler = None
        self._previous_handlers = []

    def get_output(self) -> str:
        """Return all text captured so far (stdout plus init-logger records)."""
        return self._captured_output.getvalue()

    def start_capture(self) -> "IOCapture":
        """
        Start redirecting stdout and init-logger output into the capture buffer.

        Calling this again while already capturing has no effect.

        Returns:
            This instance, so that construction and start can be chained.
        """
        if self._capturing:
            return self
        # Can't capture the output of the init logger directly, so attach a handler
        # to its parent logger instead. This is done before touching sys.stdout so
        # that a failure here cannot leave stdout redirected with nothing to undo it.
        logger = idaeslog.getInitLogger("init").logger.parent
        handler = logging.StreamHandler(self._captured_output)
        handler.setFormatter(
            logging.Formatter('%(asctime)s [%(levelname)s] %(name)s: %(message)s')
        )
        # Remember the existing handlers so stop_capture() can put them back, then
        # clear them so records go only to the capture buffer while capturing.
        self._previous_handlers = list(logger.handlers)
        logger.handlers = []
        logger.addHandler(handler)
        self._logger = logger
        self._handler = handler

        # Save whatever stdout is currently active (it may itself be a redirect).
        self._previous_stdout = sys.stdout
        sys.stdout = self._captured_output
        self._capturing = True
        return self  # allow chaining

    def stop_capture(self):
        """
        Stop capturing, restore the previous state and echo the captured text.

        ``sys.stdout`` is restored to the stream that was active when capture
        started (not unconditionally to ``sys.__stdout__``, which would break any
        outer redirection), and the parent logger gets its original handlers back.
        The captured text is then printed to the restored stdout. If capture is not
        active, only the echo happens.
        """
        if self._capturing:
            sys.stdout = self._previous_stdout
            self._logger.removeHandler(self._handler)
            for handler in self._previous_handlers:
                # addHandler() ignores handlers that are already attached.
                self._logger.addHandler(handler)
            # Handler.close() does not close the underlying StringIO, so
            # get_output() keeps working after capture stops.
            self._handler.close()
            self._capturing = False
            self._previous_stdout = None
            self._logger = None
            self._handler = None
            self._previous_handlers = []
        # Print the captured output to the (restored) terminal stream
        print(self._captured_output.getvalue(), end='')
def solve_endpoint(solve_request: IdaesSolveRequestPayload, response: Response) -> IdaesSolveCompletionPayload:
    """
    Solve the flowsheet in ``solve_request`` and report the outcome as a payload.

    Output printed or logged during the solve is captured with IOCapture and
    returned in the payload's ``log`` field. Any exception raised while solving
    (or while building the success payload) is turned into a payload with
    ``status="error"`` rather than propagating. ``response`` is not modified, so
    the HTTP status code is whatever the route's default is.

    Args:
        solve_request: The solve request to run.
        response: FastAPI response object (not used by this handler).

    Returns:
        An IdaesSolveCompletionPayload describing either success or failure.
    """
    capture = IOCapture().start_capture()
    try:
        outcome: SolveModelResult = solve_model(solve_request)
        payload = IdaesSolveCompletionPayload(
            flowsheet=outcome.output_flowsheet,
            input_flowsheet=outcome.input_flowsheet,
            timing=outcome.timing,
            log=capture.get_output(),
            traceback=None,
            solve_index=outcome.solve_index,
            scenario_id=outcome.scenario_id,
            task_id=outcome.task_id,
            status="success",
            error=None,
        )
    except Exception as exc:
        # Identify the failed run using the request's own ids, since no result exists.
        payload = IdaesSolveCompletionPayload(
            flowsheet=None,
            input_flowsheet=solve_request.flowsheet,
            timing={},
            log=capture.get_output(),
            traceback=None,
            solve_index=solve_request.solve_index,
            scenario_id=solve_request.scenario_id,
            task_id=solve_request.task_id,
            status="error",
            # TODO: Why are we using the traceback twice? does this need to be a subobject? see also endpoints.py in ../idaes_factory, LogPanel.tsx, and LogsSlice.tsx.
            error={
                "message": str(exc),
                "cause": "idaes_service_request",
                "traceback": traceback.format_exc(),
            },
        )
    finally:
        # Always restore stdout/logging, whether the solve succeeded or not.
        capture.stop_capture()
    return payload
def build_state_endpoint(schema: BuildStateRequestSchema, response: Response) -> BuildStateResponseSchema:
    """
    Solve a state block for ``schema`` and return its solved property values.

    On success the HTTP status is set to 200 and ``properties`` carries the values.
    If anything raises (including building the success response), the status is set
    to 400 and the error message and formatted traceback are returned instead.
    ``log`` is always None.

    Args:
        schema: The build-state request.
        response: FastAPI response object whose status code is set here.

    Returns:
        A BuildStateResponseSchema with either ``properties`` or ``error`` filled in.
    """
    try:
        solved_properties: list[SolvedPropertyValueSchema] = solve_state_block(schema)
        response.status_code = 200
        result = BuildStateResponseSchema(
            properties=solved_properties,
            error=None,
            traceback=None,
            log=None,
        )
    except Exception as exc:
        response.status_code = 400
        result = BuildStateResponseSchema(
            properties=None,
            error=str(exc),
            traceback=traceback.format_exc(),
            log=None,
        )
    return result
def generate_python_code_endpoint(flowsheet: FlowsheetSchema, response: Response) -> str:
    """
    Return Python source code generated from ``flowsheet``.

    If generation raises, the HTTP status is set to 400 and the returned text is a
    ``"<ExceptionType>: <message>"`` line followed by the formatted traceback.

    Args:
        flowsheet: The flowsheet to turn into code.
        response: FastAPI response object; only touched on failure.

    Returns:
        The generated code, or the error text described above.
    """
    try:
        code = generate_python_code(flowsheet)
    except Exception as error:
        response.status_code = 400
        return f"{type(error).__name__}: {error}\n{traceback.format_exc()}"
    return code
def ml_endpoint(ml_request: MLTrainRequestPayload, response: Response) -> MLTrainingCompletionPayload:
    """
    Run ML training for ``ml_request`` and report the outcome as a payload.

    Output printed or logged during training is captured with IOCapture and
    returned in ``log``. On success ``json_response`` holds ``model_dump()`` of the
    training result. Any exception is turned into a payload with ``status="error"``
    carrying the message and formatted traceback. ``response`` is not modified.

    Args:
        ml_request: The training request.
        response: FastAPI response object (not used by this handler).

    Returns:
        An MLTrainingCompletionPayload describing either success or failure.
    """
    capture = IOCapture().start_capture()
    try:
        trained = ml_generate(ml_request)
        completion = MLTrainingCompletionPayload(
            json_response=trained.model_dump(),
            error=None,
            traceback=None,
            log=capture.get_output(),
            task_id=ml_request.task_id,
            status="success",
        )
    except Exception as exc:
        completion = MLTrainingCompletionPayload(
            json_response=None,
            error=str(exc),
            traceback=traceback.format_exc(),
            log=capture.get_output(),
            task_id=ml_request.task_id,
            status="error",
        )
    finally:
        # Always restore stdout/logging, whether training succeeded or not.
        capture.stop_capture()
    return completion