Coverage for backend/idaes_service/solver/generate_python_file.py: 89%

255 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-11-06 23:27 +0000

1from typing import Any 

2from enum import Enum 

3from pyomo.core.base.units_container import units 

4from pyomo.environ import value as get_value 

5from common.models.idaes import FlowsheetSchema, UnitModelSchema, PropertiesSchema, PortsSchema 

6from common.models.idaes.flowsheet_schema import PropertyPackageType 

7from .methods.units_handler import idaes_specific_convert, attach_unit, get_attached_unit, get_attached_unit_str 

8from .methods.adapter_library import AdapterLibrary, UnitModelConstructor 

9from .methods import adapter_methods 

10from .build_state import get_state_vars 

11import json 

12from idaes_service.solver.custom.energy.power_property_package import PowerParameterBlock 

13 

14 

15 

class Section:
    """
    One named, contiguous chunk of a generated Python file.

    A section owns an ordered list of source lines plus three rendering flags
    that the file renderer reads directly (so their attribute names are fixed):

    * ``_header``   -- emit a ``### <name>`` heading before the lines
    * ``_new_line`` -- controls the newline(s) inserted before the section / after its heading
    * ``_optional`` -- the section is skipped entirely when it holds no lines
    """

    def __init__(self, name: str, header: bool = True, new_line: bool = True, optional: bool = False) -> None:
        # rendering flags
        self._header: bool = header
        self._new_line: bool = new_line
        self._optional: bool = optional
        # identity and content
        self._name: str = name
        self._lines: list[str] = list()

    def extend(self, lines: list[str]) -> None:
        """Append ``lines``, in order, to the end of this section."""
        self._lines += lines

    def header(self) -> str:
        """Return the heading comment for this section, e.g. ``### Imports``."""
        return "### {}".format(self._name)

    def lines(self) -> list[str]:
        """Return this section's lines (the live list, not a copy)."""
        return self._lines

36 

37 

38class PythonFileGenerator: 

39 """ 

40 Generate a Python file from the given model data 

41 """ 

42 

43 def __init__(self, schema: FlowsheetSchema) -> None: 

44 self._schema = schema 

45 

46 # set some global constants 

47 self._model = "m" 

48 self._flowsheet = "fs" 

49 self._solver = "ipopt" 

50 # store the unit models, property packages, and ports 

51 self._property_packages: dict[int, dict[str, Any]] = {} # {id: {name, vars: {vars}}} 

52 self._ports: dict[int, dict[str, Any]] = {} # {id: {name, arc_id}} 

53 self._arcs: dict[int, dict[str, Any]] = {} # {id: {name, source_id, destination_id}} 

54 

55 # create the sections 

56 self._sections = { 

57 "imports": Section("Imports"), 

58 "property_package_imports": Section("Property Package Imports", header=False, new_line=False), 

59 "unit_model_imports": Section("Unit Model Imports", header=False, new_line=False), 

60 "utility methods": Section("Utility Methods"), 

61 "build_model": Section("Build Model"), 

62 "create_property_packages": Section("Create Property Packages", header=False), 

63 "create_unit_models": Section("Create Unit Models", header=False), 

64 "create_arcs": Section("Connect Unit Models", optional=True), 

65 "check_model": Section("Check Model Status"), 

66 "initialize": Section("Initialize Model"), 

67 "solve": Section("Solve"), 

68 "report": Section("Report"), 

69 } 

70 

71 

72 def sections(self) -> dict: 

73 return self._sections 

74 

75 

76 def setup_sections(self) -> None: 

77 """ 

78 Set up the sections with the initial (constant) lines 

79 """ 

80 self.extend("imports", [ 

81 "from pyomo.environ import ConcreteModel, SolverFactory, SolverStatus, TerminationCondition, Block, TransformationFactory, assert_optimal_termination", 

82 "from pyomo.network import SequentialDecomposition, Port, Arc", 

83 "from pyomo.core.base.units_container import _PyomoUnit, units as pyomo_units", 

84 "from idaes.core import FlowsheetBlock", 

85 "from idaes.core.util.model_statistics import report_statistics, degrees_of_freedom", 

86 "from idaes.core.util.tables import _get_state_from_port", 

87 "import idaes.logger as idaeslog", 

88 ]) 

89 self.extend("property_package_imports", [ 

90 "from property_packages.build_package import build_package", 

91 ]) 

92 self.extend("utility methods", [ 

93 "def units(item: str) -> _PyomoUnit:", 

94 " ureg = pyomo_units._pint_registry", 

95 " pint_unit = getattr(ureg, item)", 

96 " return _PyomoUnit(pint_unit, ureg)", 

97 ]) 

98 self.extend("build_model", [ 

99 f"{self._model} = ConcreteModel()", 

100 f"{self._model}.{self._flowsheet} = FlowsheetBlock(dynamic=False)", 

101 ]) 

102 self.extend("create_property_packages", [ 

103 "# Set up property packages", 

104 ]) 

105 self.extend("create_unit_models", [ 

106 "# Create unit models", 

107 ]) 

108 self.extend("check_model", [ 

109 f"report_statistics({self._model})", 

110 f"print(\"Degrees of freedom:\", degrees_of_freedom({self._model}))", 

111 ]) 

112 self.extend("solve", [ 

113 f"opt = SolverFactory(\"{self._solver}\")", 

114 f"res = opt.solve({self._model}, tee=True)", 

115 f"assert_optimal_termination(res)", 

116 ]) 

117 

118 

119 def extend(self, section_name: str, lines: list[str] | str) -> None: 

120 # add lines to a section 

121 if isinstance(lines, str): 

122 lines = [lines] 

123 self._sections[section_name].extend(lines) 

124 

125 

126 def add_section_excl(self, section_name: str, line: str) -> None: 

127 # add a line to a section if it is not already present 

128 if line not in self._sections[section_name].lines(): 128 ↛ exitline 128 didn't return from function 'add_section_excl' because the condition on line 128 was always true

129 self.extend(section_name, line) 

130 

131 

132 def resolve_import(self, obj: type | Enum) -> tuple[str, str]: 

133 # get the class name and import statement for a given class 

134 module_name = obj.__module__ 

135 if isinstance(obj, Enum): 135 ↛ 136line 135 didn't jump to line 136 because the condition on line 135 was never true

136 import_name = obj.__class__.__name__ # name to be used in the import statement 

137 class_name = f"{obj.__class__.__name__}.{obj.name}" # name to be used in the code 

138 else: 

139 import_name = obj.__name__ 

140 class_name = import_name 

141 return class_name, f"from {module_name} import {import_name}" 

142 

143 

144 def get_name(self, name: str) -> str: 

145 # clean the name so that it can be used as a Python variable 

146 name = name.strip().replace("-", "_").replace(" ", "_") 

147 # remove any spaces and special characters 

148 name = "".join([char for char in name if char.isalnum() or char == "_"]) 

149 # if the name is empty, use a default name 

150 if len(name) == 0: 

151 name = f"_unnamed_unit" 

152 # if the name starts with a number, add an underscore 

153 if name[0].isnumeric(): 

154 name = "_" + name 

155 name = f"{self._model}.{self._flowsheet}.{name}" 

156 return name 

157 

158 

159 def create_property_packages(self) -> None: 

160 """ 

161 Create property packages 

162 """ 

163 for schema in self._schema.property_packages: 163 ↛ 164line 163 didn't jump to line 164 because the loop on line 163 never started

164 self.create_property_package(schema) 

165 

166 

167 def create_property_package(self, schema: PropertyPackageType) -> None: 

168 name = self.get_name("PP_" + str(schema.id)) 

169 compounds = schema.compounds 

170 phases = schema.phases 

171 type = schema.type 

172 self.extend("create_property_packages", [ 

173 f"{name} = build_package(", 

174 f" \"{type}\",", 

175 f" {json.dumps(compounds)},", 

176 ")", 

177 ]) 

178 self._property_packages[schema.id] = { "name": name, "vars": get_state_vars(schema) } 

179 

180 

181 def get_property_package(self, id: int) -> dict: 

182 """Get the name of a property package by ID""" 

183 if id == -1 and id not in self._property_packages: 

184 # add a default Helmholtz property package (for testing purposes) 

185 self.create_property_package(PropertyPackageType(id=id, type="helmholtz", compounds=["h2o"], phases=["Liq"])) 

186 return self._property_packages[id] 

187 

188 def get_power_property_package(self, id: str): 

189 power_package = PowerParameterBlock 

190 return self._property_packages[-1] 

191 

192 

193 def get_property_package_at_port(self, model_schema: UnitModelSchema, port: str) -> dict: 

194 """Get the property package that is used at a port""" 

195 # generally, the unitop has one property package that is used for all ports 

196 if model_schema.args.get("property_package") is not None: 

197 return self.get_property_package(model_schema.args["property_package"]) 

198 # this is a special case (hard-coded for now) for heat exchangers, which have two property packages 

199 # TODO: make this dynamic once we have parent stream inheritance 

200 package_arg = port.removesuffix("_inlet").removesuffix("_outlet") # eg. "hot_side_inlet" -> "hot_side" 

201 return self.get_property_package(model_schema.args[package_arg]["property_package"]) 

202 

203 

204 def serialise_dict(self, d: dict, indent: bool = False, indent_level: int = 1, nested_indent: bool = True) -> str: 

205 if len(d) == 0: 205 ↛ 206line 205 didn't jump to line 206 because the condition on line 205 was never true

206 return "{}" 

207 result = "{" 

208 for k, v in d.items(): 

209 if indent: 209 ↛ 210line 209 didn't jump to line 210 because the condition on line 209 was never true

210 result += "\n" + " " * indent_level 

211 adj_k = f"\"{k}\"" if isinstance(k, str) else k 

212 adj_v = f"\"{v}\"" if isinstance(v, str) and not v.startswith(f"{self._model}.{self._flowsheet}.") else v 

213 if isinstance(v, dict): 213 ↛ 215line 213 didn't jump to line 215 because the condition on line 213 was never true

214 # allow nested dictionaries 

215 adj_v = self.serialise_dict(v, indent=indent and nested_indent, indent_level=indent_level + 1) 

216 result += f"{adj_k}: {adj_v}," 

217 result = result[:-1] + ("\n" + " " * (indent_level - 1)) * indent + "}" 

218 return result 

219 

220 

221 def serialise_list(self, l: list) -> str: 

222 if len(l) == 0: 222 ↛ 224line 222 didn't jump to line 224 because the condition on line 222 was always true

223 return "[]" 

224 result = "[" 

225 for v in l: 

226 adj_v = f"\"{v}\"" if isinstance(v, str) and not v.startswith(f"{self._model}.{self._flowsheet}.") else v 

227 result += f"{adj_v}," 

228 result = result[:-1] + "]" 

229 return result 

230 

231 

232 def setup_args(self, args: dict, arg_parsers: dict) -> dict: 

233 """Setup the arguments for a unit model""" 

234 result: dict[str, Any] = {} 

235 print("args: " + str(args)) 

236 for arg_name, method in arg_parsers.items(): 

237 def match_method() -> Any: 

238 match method.__class__: 

239 case adapter_methods.Constant: 

240 # constant, defined in the method 

241 constant = method.run(None,None) 

242 # constant can be a function, in which case we need to resolve the import 

243 if callable(constant) or isinstance(constant, Enum): 

244 constant_class_name, constant_import = self.resolve_import(constant) 

245 self.add_section_excl("imports", constant_import) 

246 return constant_class_name 

247 return constant 

248 case adapter_methods.Value: 

249 # value, keep as is 

250 return args.get(arg_name, None) 

251 case adapter_methods.PropertyPackage: 

252 # property package 

253 property_package_id = args["property_package"] 

254 print(property_package_id) 

255 return self.get_property_package(property_package_id)["name"] 

256 case adapter_methods.PowerPropertyPackage: 

257 #power property package 

258 return "m.fs.power_property_package" 

259 

260 case adapter_methods.Dictionary: 260 ↛ 263line 260 didn't jump to line 263 because the pattern on line 260 always matched

261 # another dictionary of arg parsers, recursively setup the args 

262 return self.setup_args(args[arg_name], method._schema) 

263 case _: 

264 raise Exception(f"Method {method} not supported") 

265 result[arg_name] = match_method() 

266 return result 

267 

268 

269 def write_args(self, args: dict) -> str: 

270 args_str = "" 

271 if len(args) == 0: 271 ↛ 272line 271 didn't jump to line 272 because the condition on line 271 was never true

272 return args_str 

273 args_str += "\n" 

274 for key, value in args.items(): 

275 args_str += f" {key}=" 

276 if isinstance(value, dict): 

277 args_str += self.serialise_dict(value) 

278 else: 

279 args_str += f"{value}" 

280 args_str += ",\n" 

281 args_str = args_str[:-2] + "\n" # remove the last comma 

282 return args_str 

283 

284 

285 def create_unit_models(self) -> None: 

286 """ 

287 Create the unit models 

288 """ 

289 for unit_model in self._schema.unit_models: 

290 # add to imports 

291 adapter: Adapter = AdapterLibrary[unit_model.type] 

292 class_name, class_import = self.resolve_import(adapter.model_constructor) 

293 self.add_section_excl("unit_model_imports", class_import) 

294 # setup args 

295 args = self.setup_args(unit_model.args, adapter.arg_parsers) 

296 args_str = self.write_args(args) 

297 print("args_str: " + args_str) 

298 

299 # create the unit model 

300 name = self.get_name(unit_model.name) 

301 self.extend("create_unit_models", f"\n# {unit_model.name}") # comment 

302 self.extend("create_unit_models", f"{name} = {class_name}({args_str})") # constructor 

303 self.extend("create_unit_models", self.fix_properties(name, unit_model.properties)) # fix properties 

304 for port_name, port_data in unit_model.ports.items(): 

305 # save the port 

306 global_name = f"{name}.{port_name}" 

307 self._ports[port_data.id] = { "name": global_name, "arc": None } 

308 # available_vars = self.get_property_package_at_port(unit_model, port_name)["vars"] 

309 # fix the properties of the port 

310 self.extend("create_unit_models", self.fix_state_block(global_name, port_data.properties)) 

311 self.extend("report", f"{name}.report()") 

312 

313 

314 def fix_properties(self, prefix: str, properties_schema: PropertiesSchema) -> list[str]: 

315 lines = [] 

316 for key, property_info in properties_schema.items(): 

317 for property_value in property_info.data: 

318 if property_value.value is None: 318 ↛ 319line 318 didn't jump to line 319 because the condition on line 318 was never true

319 continue 

320 if property_value.discrete_indexes is not None: 320 ↛ 321line 320 didn't jump to line 321 because the condition on line 320 was never true

321 indexes_tuple = tuple(property_value.discrete_indexes) 

322 indexes_string = f"[{indexes_tuple}]" if len(property_value.discrete_indexes) > 0 else "" 

323 else: 

324 indexes_string = "" 

325 val = get_value(property_value.value) 

326 unit = property_info.unit 

327 # TODO: Handle dynamic indexes etc. 

328 lines.append(f"{prefix}.{key}{indexes_string}.fix({val} * units(\"{unit}\"))") 

329 

330 return lines 

331 

332 

333 def fix_state_block(self, prefix: str, properties: PropertiesSchema) -> list[str]: 

334 """ 

335 Fix the properties of a unit model 

336 """ 

337 lines = [] 

338 lines.append(f"sb = _get_state_from_port({prefix}, 0)") 

339 for key, property_info in properties.items(): 

340 for property_value in property_info.data: 

341 if property_value.value is None: 

342 continue 

343 if property_value.discrete_indexes is None: 

344 indexes_str = "" 

345 else: 

346 # We aren't worrying about time yet, but we will need to do in the future. 

347 indexes_tuple = tuple(property_value.discrete_indexes) 

348 indexes_str = f"[{indexes_tuple}]" if len(property_value.discrete_indexes) > 0 else "" 

349 val = get_value(property_value.value) 

350 unit = property_info.unit 

351 lines.append(f"sb.constrain_component(sb.{key}{indexes_str}, {val} * units(\"{unit}\"))") 

352 

353 if len(lines) == 1: 

354 return [] 

355 return lines 

356 

357 

358 def create_arcs(self) -> None: 

359 """ 

360 Create the arcs 

361 """ 

362 if len(self._schema.arcs) == 0: 

363 return 

364 for i, arc in enumerate(self._schema.arcs): 

365 source = self._ports[arc.source] 

366 destination = self._ports[arc.destination] 

367 name = f"{self._model}.{self._flowsheet}.arc_{i + 1}" 

368 self.extend("create_arcs", f"{name} = Arc(source={source['name']}, destination={destination['name']})") 

369 self._arcs[i] = { "name": name, "source": arc.source, "destination": arc.destination } 

370 source["arc"] = i 

371 destination["arc"] = i 

372 # add to initialization: expand the arcs 

373 self.extend("initialize", f"TransformationFactory(\"network.expand_arcs\").apply_to({self._model})") 

374 

375 

376 def initialize(self) -> None: 

377 """ 

378 Initialize the model 

379 """ 

380 def is_connected(ports: PortsSchema) -> bool: 

381 for _, port_data in ports.items(): 

382 port = self._ports[port_data.id] 

383 if port["arc"] is not None: 

384 return True 

385 return False 

386 # initialize everything that is not connected 

387 for unit_model in self._schema.unit_models: 

388 if not is_connected(unit_model.ports): 

389 name = self.get_name(unit_model.name) 

390 self.extend("initialize", f"{name}.initialize(outlvl=idaeslog.INFO)") 

391 if len(self._schema.arcs) == 0: 

392 return 

393 # setup sequential decomposition 

394 self.extend("utility methods", [ 

395 "\ndef init_unit(unit: Block) -> None:", 

396 " unit.initialize(outlvl=idaeslog.INFO)" 

397 ]) 

398 self.extend("initialize", "seq = SequentialDecomposition()") 

399 # set tear guesses 

400 tear_set = [] 

401 

402 # Need to rewrite the logic for dealing with tear sets from recycle 

403 

404 self.extend("initialize", f"seq.set_tear_set({self.serialise_list(tear_set)})") 

405 self.extend("initialize", f"seq.run({self._model}, init_unit)") 

406 

407 

def generate_python_code(model_data: FlowsheetSchema) -> str:
    """
    Build the complete source text of a runnable Python script for ``model_data``.

    All generator steps are run first, then every section is rendered in order:
    empty optional sections are dropped, sections flagged ``_new_line`` (other than
    the very first) are preceded by one blank line, or two if they have a header,
    and each section's lines are joined with newlines and terminated by one.
    """
    generator = PythonFileGenerator(model_data)
    generator.setup_sections()
    generator.create_property_packages()
    generator.create_unit_models()
    generator.create_arcs()
    generator.initialize()

    chunks: list[str] = []
    for position, (_, section) in enumerate(generator.sections().items()):
        body = section.lines()
        # optional sections with nothing in them are omitted entirely
        if section._optional and not body:
            continue
        # spacing between sections; the first section never gets leading newlines
        if section._new_line and position > 0:
            chunks.append("\n\n" if section._header else "\n")
        if section._header:
            chunks.append(section.header())
            if section._new_line:
                chunks.append("\n")
        chunks.append("\n".join(body) + "\n")
    return "".join(chunks)

440