Coverage for backend/idaes_service/endpoints.py: 72%

81 statements  

coverage.py v7.10.7, created at 2026-02-11 21:43 +0000

import io
import sys
import traceback
import logging
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING

from pydantic import TypeAdapter
import idaes.logger as idaeslog

if TYPE_CHECKING:
    from fastapi import Response

from common.models.idaes import FlowsheetSchema
from common.models.idaes.payloads.build_state_request_schema import BuildStateRequestSchema, BuildStateResponseSchema
from common.models.idaes.payloads.solve_request_schema import (
    IdaesSolveCompletionPayload,
    IdaesSolveRequestPayload,
    IdaesSolveEvent,
    MultiSolvePayload,
    DispatchMultiSolveEvent
)
from common.models.idaes.payloads.ml_request_schema import (
    MLTrainRequestPayload,
    MLTrainingCompletionPayload
)
from common.models.idaes.unit_model_schema import SolvedPropertyValueSchema
from .solver.solver import solve_model, SolveModelResult
from .solver.build_state import solve_state_block
from .solver.generate_python_file import generate_python_code
from .solver.ml_wizard import ml_generate, MLResult


class IOCapture:
    """
    Capture stdout and IDAES init-logger output during a call so the captured log
    can be returned alongside the result.
    """

    # These markers are printed by `DiagnosticsToolbox` and we use them to split the
    # diagnostics block out of the full solver log.
    DIAGNOSTICS_START_MARKER = "=== DIAGNOSTICS ==="
    DIAGNOSTICS_END_MARKER = "=== END DIAGNOSTICS ==="

    def __init__(self):
        self._captured_output = io.StringIO()


    @staticmethod
    def _extract_marked_section(text: str, start_marker: str, end_marker: str) -> str | None:
        """
        Extract a section of stdout between two markers.

        This pulls the `DiagnosticsToolbox` block out of the full solver log so we can
        store it separately (and parse it later) without changing normal solve logs.
        """

        if not text:
            return None

        start_idx = text.find(start_marker)
        if start_idx == -1:
            return None

        end_idx = text.find(end_marker, start_idx + len(start_marker))
        if end_idx == -1:
            return None

        section = text[start_idx + len(start_marker):end_idx]
        section = section.strip("\n")
        return section or None


    def get_output(self) -> str:
        return self._captured_output.getvalue()

    def start_capture(self):
        """
        Start capturing stdout and output from the IDAES init logger.
        """
        sys.stdout = self._captured_output
        # Redirecting stdout alone doesn't capture the init logger, so attach a handler
        stringio_handler = logging.StreamHandler(self._captured_output)
        formatter = logging.Formatter('%(asctime)s [%(levelname)s] %(name)s: %(message)s')
        stringio_handler.setFormatter(formatter)
        init_logger = idaeslog.getInitLogger("init").logger.parent
        # Clear existing handlers
        init_logger.handlers = []
        # Add our custom handler
        init_logger.addHandler(stringio_handler)

        return self  # allow chaining

    def stop_capture(self):
        """
        Stop capturing and echo the captured output to the real terminal.
        """
        sys.stdout = sys.__stdout__
        # Print the captured output to the terminal
        print(self._captured_output.getvalue(), end='')

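
# Endpoint handlers. Each returns its result in a payload/response schema; failures are
# reported inside that schema (and, for the state-block and code-generation handlers,
# via response.status_code) rather than raised to the caller.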

def solve_endpoint(solve_request: IdaesSolveRequestPayload, response: "Response") -> IdaesSolveCompletionPayload:
    logs = IOCapture().start_capture()
    try:
        solve_result: SolveModelResult = solve_model(solve_request)
        full_log = logs.get_output()
        diagnostics_raw_text = (
            logs._extract_marked_section(
                full_log, IOCapture.DIAGNOSTICS_START_MARKER, IOCapture.DIAGNOSTICS_END_MARKER
            )
            if solve_request.perform_diagnostics
            else None
        )
        return IdaesSolveCompletionPayload(
            flowsheet=solve_result.output_flowsheet,
            input_flowsheet=solve_result.input_flowsheet,
            timing=solve_result.timing,
            log=full_log,
            traceback=None,
            solve_index=solve_result.solve_index,
            scenario_id=solve_result.scenario_id,
            task_id=solve_result.task_id,
            status="success",
            error=None,
            diagnostics_raw_text=diagnostics_raw_text,
        )
    except Exception as error:
        return IdaesSolveCompletionPayload(
            flowsheet=None,
            input_flowsheet=solve_request.flowsheet,
            timing={},
            log=logs.get_output(),
            traceback=None,
            solve_index=solve_request.solve_index,
            scenario_id=solve_request.scenario_id,
            task_id=solve_request.task_id,
            status="error",
            # TODO: Why are we using the traceback twice? Does this need to be a subobject? See also endpoints.py in ../idaes_factory, LogPanel.tsx, and LogsSlice.tsx.
            error={
                "message": str(error),
                "cause": "idaes_service_request",
                "traceback": traceback.format_exc(),
            },
            diagnostics_raw_text=None,
        )
    finally:
        logs.stop_capture()


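# Solves a single state block for the requested properties; on failure the error message
# and traceback are returned in the response schema with a 400 status code.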

def build_state_endpoint(schema: BuildStateRequestSchema, response: "Response") -> BuildStateResponseSchema:
    try:
        output: list[SolvedPropertyValueSchema] = solve_state_block(schema)
        response.status_code = 200
        return BuildStateResponseSchema(
            properties=output,
            error=None,
            traceback=None,
            log=None
        )
    except Exception as error:
        response.status_code = 400
        return BuildStateResponseSchema(
            properties=None,
            error=str(error),
            traceback=traceback.format_exc(),
            log=None
        )

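
# Generates standalone Python source for a flowsheet; on failure the error text and
# traceback are returned as the response body with a 400 status code.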

def generate_python_code_endpoint(flowsheet: FlowsheetSchema, response: "Response") -> str:
    try:
        return generate_python_code(flowsheet)
    except Exception as error:
        response.status_code = 400
        return f"{type(error).__name__}: {error}\n{traceback.format_exc()}"


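# Runs the ML wizard training flow; as in solve_endpoint, captured output is returned in
# the completion payload and errors are reported with status="error" rather than raised.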

def ml_endpoint(ml_request: MLTrainRequestPayload, response: "Response") -> MLTrainingCompletionPayload:
    logs = IOCapture().start_capture()
    try:
        result = ml_generate(ml_request)
        return MLTrainingCompletionPayload(
            json_response=result.model_dump(),
            error=None,
            traceback=None,
            log=logs.get_output(),
            task_id=ml_request.task_id,
            status="success"
        )
    except Exception as error:
        return MLTrainingCompletionPayload(
            json_response=None,
            error=str(error),
            traceback=traceback.format_exc(),
            log=logs.get_output(),
            task_id=ml_request.task_id,
            status="error"
        )
    finally:
        logs.stop_capture()
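
# A minimal wiring sketch (hypothetical; the actual FastAPI app and route paths live
# elsewhere, so the names below are assumptions rather than this service's real routes):
#
#   from fastapi import FastAPI
#   app = FastAPI()
#   app.post("/solve")(solve_endpoint)
#   app.post("/build_state")(build_state_endpoint)
#   app.post("/generate_python_code")(generate_python_code_endpoint)
#   app.post("/ml")(ml_endpoint)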