Coverage for backend/idaes_service/endpoints.py: 87%

69 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-11-06 23:27 +0000

1import io 

2import sys 

3import traceback 

4 

5from abc import ABC, abstractmethod 

6from pydantic import TypeAdapter 

7from typing import Callable, Any 

8from fastapi import FastAPI, Response 

9 

10from common.models.idaes import FlowsheetSchema 

11from common.models.idaes.payloads import BuildStateRequestSchema, BuildStateResponseSchema 

12from common.models.idaes.payloads.solve_request_schema import IdaesSolveCompletionPayload, IdaesSolveRequestPayload 

13from common.models.idaes.unit_model_schema import SolvedPropertyValueSchema 

14from .solver.solver import solve_model, SolveModelResult 

15from .solver.build_state import solve_state_block 

16from .solver.generate_python_file import generate_python_code 

17from .solver.ml_wizard import ml_generate, MLResult 

18from common.models.idaes.payloads.ml_request_schema import MLTrainRequestPayload, MLTrainingEvent 

19import idaes.logger as idaeslog 

20import logging 

21from common.models.idaes.payloads.ml_request_schema import MLTrainRequestPayload, MLTrainingCompletionPayload 

22from typing import Tuple 

23from functools import wraps 

24from common.models.idaes.payloads.solve_request_schema import IdaesSolveEvent, IdaesSolveRequestPayload 

25 

class IOCapture:
    """
    Capture stdout and IDAES init-logger output into an in-memory buffer.

    Typical use::

        logs = IOCapture().start_capture()
        try:
            ...  # work whose output should be captured
            text = logs.get_output()
        finally:
            logs.stop_capture()

    ``start_capture`` replaces ``sys.stdout`` and the handlers of the IDAES
    init logger's parent logger; ``stop_capture`` puts back exactly what was
    there before and echoes the captured text to the restored stdout.

    NOTE: ``sys.stdout`` and logger handlers are process-global, so
    overlapping captures (e.g. concurrent requests on different threads)
    will interfere with each other.
    """

    _LOG_FORMAT = '%(asctime)s [%(levelname)s] %(name)s: %(message)s'

    def __init__(self):
        self._captured_output = io.StringIO()
        # State recorded by start_capture so stop_capture can undo it exactly.
        self._active = False
        self._previous_stdout = None
        self._init_logger = None
        self._previous_handlers = []
        self._capture_handler = None

    def get_output(self) -> str:
        """Return all text captured so far (stdout writes and log records)."""
        return self._captured_output.getvalue()

    def start_capture(self) -> "IOCapture":
        """
        Start capturing stdout and IDAES init-logger output.

        Returns:
            ``self``, so construction and start can be chained.
        """
        # Redirecting sys.stdout alone does not capture the init logger's
        # output, so route its records into the buffer with a dedicated handler.
        handler = logging.StreamHandler(self._captured_output)
        handler.setFormatter(logging.Formatter(self._LOG_FORMAT))
        init_logger = idaeslog.getInitLogger("init").logger.parent
        # Remember the existing handlers before clearing them so they can be
        # reinstated afterwards instead of being lost for the whole process.
        self._previous_handlers = list(init_logger.handlers)
        init_logger.handlers = []
        init_logger.addHandler(handler)
        self._init_logger = init_logger
        self._capture_handler = handler

        # Redirect stdout last: if the logger setup above raises, stdout is
        # left untouched instead of being stuck pointing at this buffer.
        self._previous_stdout = sys.stdout
        sys.stdout = self._captured_output
        self._active = True

        return self  # allow chaining

    def stop_capture(self):
        """
        Stop capturing and undo everything ``start_capture`` changed.

        Restores the stdout object that was active when capture started,
        removes the capture handler and reinstates the init logger's previous
        handlers, then prints the captured text to the restored stdout.
        Without a prior ``start_capture`` it leaves stdout and logging as-is.
        """
        if self._active:
            sys.stdout = self._previous_stdout
            self._init_logger.removeHandler(self._capture_handler)
            for previous_handler in self._previous_handlers:
                self._init_logger.addHandler(previous_handler)
            # Handler.close() only deregisters the handler; StreamHandler does
            # not close the StringIO, so get_output() keeps working.
            self._capture_handler.close()
            self._active = False
            self._previous_stdout = None
            self._init_logger = None
            self._capture_handler = None
            self._previous_handlers = []
        # Print the captured output to the (restored) terminal
        print(self._captured_output.getvalue(), end='')

61 

def solve_endpoint(solve_request: IdaesSolveRequestPayload, response: Response) -> IdaesSolveCompletionPayload:
    """
    Solve the flowsheet described by ``solve_request`` and report the outcome.

    Output captured by ``IOCapture`` while solving is returned in the payload's
    ``log`` field. Any exception raised inside the ``try`` (by ``solve_model``
    or while building the success payload) is caught and reported as a payload
    with ``status="error"``; it is not re-raised.

    Args:
        solve_request: The solve request. Its flowsheet and identifiers are
            echoed back in the error payload.
        response: Accepted from FastAPI but not modified by this function.

    Returns:
        An ``IdaesSolveCompletionPayload`` whose ``status`` is "success" or "error".
    """
    capture = IOCapture().start_capture()
    try:
        outcome: SolveModelResult = solve_model(solve_request)
        return IdaesSolveCompletionPayload(
            flowsheet=outcome.output_flowsheet,
            input_flowsheet=outcome.input_flowsheet,
            timing=outcome.timing,
            log=capture.get_output(),
            traceback=None,
            solve_index=outcome.solve_index,
            scenario_id=outcome.scenario_id,
            task_id=outcome.task_id,
            status="success",
            error=None,
        )
    except Exception as exc:
        # TODO: Why are we using the traceback twice? does this need to be a subobject? see also endpoints.py in ../idaes_factory, LogPanel.tsx, and LogsSlice.tsx.
        # The traceback is reported only inside the error object; the
        # top-level ``traceback`` field stays None.
        failure_details = {
            "message": str(exc),
            "cause": "idaes_service_request",
            "traceback": traceback.format_exc(),
        }
        return IdaesSolveCompletionPayload(
            flowsheet=None,
            input_flowsheet=solve_request.flowsheet,
            timing={},
            log=capture.get_output(),
            traceback=None,
            solve_index=solve_request.solve_index,
            scenario_id=solve_request.scenario_id,
            task_id=solve_request.task_id,
            status="error",
            error=failure_details,
        )
    finally:
        # Always restore stdout/logging, whichever branch returned.
        capture.stop_capture()

98 

99 

100 

def build_state_endpoint(schema: BuildStateRequestSchema, response: Response) -> BuildStateResponseSchema:
    """
    Evaluate a state block for ``schema`` and return its solved property values.

    On success the HTTP status is set to 200 and ``properties`` holds the list
    returned by ``solve_state_block``. If anything inside the ``try`` raises
    (including building the success payload), the status is set to 400 and the
    payload carries the error message and formatted traceback instead.
    ``log`` is always ``None``.
    """
    try:
        solved_properties: list[SolvedPropertyValueSchema] = solve_state_block(schema)
        response.status_code = 200
        return BuildStateResponseSchema(
            properties=solved_properties, error=None, traceback=None, log=None
        )
    except Exception as exc:
        response.status_code = 400
        message = str(exc)
        formatted_trace = traceback.format_exc()
        return BuildStateResponseSchema(
            properties=None, error=message, traceback=formatted_trace, log=None
        )

119 

def generate_python_code_endpoint(flowsheet: FlowsheetSchema, response: Response) -> str:
    """
    Return the Python source that ``generate_python_code`` builds for ``flowsheet``.

    If generation raises, the HTTP status is set to 400 and the returned string
    is a diagnostic instead. Its first line is the exception type and message,
    followed by the formatted traceback.
    """
    try:
        generated_source = generate_python_code(flowsheet)
    except Exception as exc:
        response.status_code = 400
        return f"{type(exc).__name__}: {exc}\n{traceback.format_exc()}"
    return generated_source

126 

127 

def ml_endpoint(ml_request: MLTrainRequestPayload, response: Response) -> MLTrainingCompletionPayload:
    """
    Run ML training for ``ml_request`` and wrap the result in a completion payload.

    Output captured by ``IOCapture`` during training is returned in ``log``.
    On success ``json_response`` is the ``model_dump()`` of the training
    result. Any exception raised inside the ``try`` is caught and reported as
    a payload with ``status="error"``, the error message and the formatted
    traceback; it is not re-raised. ``response`` is accepted but not modified.
    """
    capture = IOCapture().start_capture()
    try:
        training_result = ml_generate(ml_request)
        dumped_result = training_result.model_dump()
        return MLTrainingCompletionPayload(
            json_response=dumped_result,
            error=None,
            traceback=None,
            log=capture.get_output(),
            task_id=ml_request.task_id,
            status="success"
        )
    except Exception as exc:
        message = str(exc)
        formatted_trace = traceback.format_exc()
        return MLTrainingCompletionPayload(
            json_response=None,
            error=message,
            traceback=formatted_trace,
            log=capture.get_output(),
            task_id=ml_request.task_id,
            status="error"
        )
    finally:
        # Always restore stdout/logging, whichever branch returned.
        capture.stop_capture()