Coverage for backend/django/Economics/costing/cost_curves/serializers.py: 80%

152 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2026-06-23 21:51 +0000

1from drf_spectacular.utils import extend_schema_field 

2from rest_framework import serializers 

3 

4from Economics.costing.models import CostCurve 

5from Economics.shared.choices import CostCurveEvaluationKind 

6from Economics.shared.serializer_base import FlowsheetScopedSerializer, _current_flowsheet_id 

7from Economics.shared.serializers import UnitOptionSerializer 

8from Economics.costing.cost_curves.catalog import cost_curve_category_requires_subtype 

9from Economics.costing.cost_curves.driver_specs import ( 

10 CostCurveDiscreteVariant, 

11 CostCurveDiscreteVariantPayload, 

12 CostCurveDriverSpec, 

13 CostCurveDriverSpecPayload, 

14 CostCurveDriverSpecRead, 

15 driver_spec_read_payload, 

16 normalize_discrete_variants, 

17 normalize_required_driver_specs, 

18) 

19from Economics.costing.cost_curves.unit_options import ( 

20 cost_curve_common_driver_unit_options, 

21 cost_curve_output_unit_options, 

22) 

23from Economics.formulas.engine.core import FormulaError 

24from Economics.formulas.engine.parsing import parse_cost_expression 

25 

26 

@extend_schema_field(
    CostCurveDriverSpecRead.model_json_schema(),
    component_name="CostCurveDriverSpec",
)
class CostCurveDriverSpecField(serializers.JSONField):
    """JSON transport field for one cost-curve driver spec.

    The OpenAPI component is generated from the Pydantic ``CostCurveDriverSpecRead``
    model. On output, values are validated as ``CostCurveDriverSpec`` and rendered
    with ``driver_spec_read_payload``; anything that fails with ``ValueError`` is
    emitted unchanged by the plain ``JSONField`` representation.
    """

    def to_representation(self, value):
        try:
            spec = CostCurveDriverSpec.model_validate(value)
            payload = driver_spec_read_payload(spec)
        except ValueError:
            # Not a valid driver spec: fall back to the raw JSON value.
            return super().to_representation(value)
        return payload

39 

40 

@extend_schema_field(
    CostCurveDiscreteVariant.model_json_schema(),
    component_name="CostCurveDiscreteVariant",
)
class CostCurveDiscreteVariantField(serializers.JSONField):
    """JSON transport field whose OpenAPI component comes from Pydantic.

    Output is normalised through ``CostCurveDiscreteVariant`` so responses match the
    advertised schema. Behaviour matches ``CostCurveDriverSpecField``:
    - model instances are dumped directly;
    - other values (e.g. plain dicts) are validated first;
    - values that fail validation with ``ValueError`` are passed through unchanged.
    """

    def to_representation(self, value):
        if isinstance(value, CostCurveDiscreteVariant):
            return value.model_dump(mode="json")
        try:
            variant = CostCurveDiscreteVariant.model_validate(value)
        except ValueError:
            # Not a valid variant payload: emit the raw JSON value rather than
            # failing the whole response.
            return super().to_representation(value)
        return variant.model_dump(mode="json")

52 

53 

class CostCurveSerializer(FlowsheetScopedSerializer):
    """Full model serializer for ``CostCurve``.

    Exposes every model field, plus ``output_unit_options`` derived from the
    curve's output unit and currency. ``id``, ``flowsheet`` and the timestamps are
    read-only.
    """

    # NOTE: declaration order is kept as-is; DRF uses it when ordering output keys.
    evaluation_kind = serializers.ChoiceField(
        choices=CostCurveEvaluationKind.choices,
        required=True,
    )
    output_unit_options = serializers.SerializerMethodField()
    required_driver_specs = serializers.ListField(
        child=CostCurveDriverSpecField(),
        required=True,
    )
    discrete_variants = serializers.ListField(
        child=CostCurveDiscreteVariantField(),
        required=True,
    )

    class Meta:
        model = CostCurve
        fields = "__all__"
        read_only_fields = ("id", "flowsheet", "created_at", "updated_at")

    @extend_schema_field(UnitOptionSerializer(many=True))
    def get_output_unit_options(self, instance) -> list[dict[str, str]]:
        # A blank/None currency falls back to "NZD".
        currency = instance.currency or "NZD"
        unit = instance.output_unit
        return cost_curve_output_unit_options(unit, currency=currency)

77 

78 

class CostCurveAuthoringSerializer(FlowsheetScopedSerializer):
    """Expression-first write contract for user-authored cost curves.

    Validation steps:
    - Payloads carrying any legacy coefficient-style field listed in
      ``removed_formula_fields`` are rejected.
    - ``validate`` normalises driver specs and discrete variants through their
      Pydantic contracts.
    - The cost expression(s) are parsed against the declared formula inputs.
    - The equipment-subtype requirement for the category is enforced.

    For partial updates, values missing from the payload fall back to those stored
    on ``self.instance``.
    """

    evaluation_kind = serializers.ChoiceField(
        choices=CostCurveEvaluationKind.choices,
        required=True,
    )
    required_driver_specs = serializers.ListField(
        child=CostCurveDriverSpecField(),
        required=True,
    )
    discrete_variants = serializers.ListField(
        child=CostCurveDiscreteVariantField(),
        required=True,
    )
    # Legacy formula inputs: sending any of them is a validation error rather than
    # being silently ignored.
    removed_formula_fields = frozenset(
        {
            "expression_type",
            "coefficient_a",
            "coefficient_b",
            "coefficient_c",
            "exponent",
            "manual_quote_amount",
        }
    )

    class Meta:
        model = CostCurve
        fields = (
            "id",
            "flowsheet",
            "curve_key",
            "name",
            "equipment_category",
            "equipment_subtype",
            "cost_basis",
            "evaluation_kind",
            "output_unit",
            "expression_text",
            "required_driver_specs",
            "discrete_variants",
            "valid_min",
            "valid_max",
            "valid_range_note",
            "currency",
            "basis_date",
            "basis_index_name",
            "basis_index_value",
            "source_document_title",
            "source_page",
            "source_figure",
            "source_data_origin",
            "source_range_precision",
            "source_license_status",
            "source_reference",
            "applicability_warning",
            "notes",
            "active",
            "created_at",
            "updated_at",
        )
        read_only_fields = ("id", "flowsheet", "created_at", "updated_at")

    def to_internal_value(self, data):
        """Reject removed formula fields, then run normal field validation.

        Raises:
            serializers.ValidationError: one entry per removed field present in
                ``data``. The check only runs when ``data`` is a dict.
        """
        if isinstance(data, dict):
            removed_fields = sorted(self.removed_formula_fields.intersection(data.keys()))
            if removed_fields:
                raise serializers.ValidationError(
                    {
                        field: "Cost curves are expression-only. Use expression_text."
                        for field in removed_fields
                    }
                )
        return super().to_internal_value(data)

    def validate(self, attrs):
        """Normalise JSON payloads and enforce the formula and subtype contracts.

        Returns:
            ``attrs``, with ``required_driver_specs`` / ``discrete_variants``
            replaced by their normalised payloads when they were supplied.

        Raises:
            serializers.ValidationError: when
                - an expression curve has no ``expression_text``;
                - driver specs or variants are invalid;
                - an expression fails to parse;
                - a required equipment subtype is missing.
        """
        attrs = super().validate(attrs)
        evaluation_kind = attrs.get(
            "evaluation_kind",
            getattr(self.instance, "evaluation_kind", CostCurveEvaluationKind.EXPRESSION),
        )
        expression_text = attrs.get("expression_text", getattr(self.instance, "expression_text", ""))
        if evaluation_kind == CostCurveEvaluationKind.EXPRESSION and not expression_text:
            raise serializers.ValidationError({"expression_text": "Expression cost curves require expression_text."})
        # Normalise before the formula check so it sees canonical payloads.
        if "required_driver_specs" in attrs:
            attrs["required_driver_specs"] = _normalized_required_driver_specs(attrs["required_driver_specs"])
        if "discrete_variants" in attrs:
            attrs["discrete_variants"] = _normalized_discrete_variants(attrs["discrete_variants"])
        self._validate_formula_contract(attrs, evaluation_kind=evaluation_kind, expression_text=expression_text)
        self._validate_equipment_subtype_requirement(attrs)
        return attrs

    def _validate_formula_contract(self, attrs, *, evaluation_kind: str, expression_text: str) -> None:
        """Parse the curve's expression(s) against its declared formula inputs.

        - Expression curves parse ``expression_text``.
        - Discrete-family curves are checked by ``_validate_discrete_family``.
        - Other evaluation kinds are not checked here.

        Raises:
            serializers.ValidationError: any ``FormulaError`` is re-raised with its
                message under ``"expression_text"`` and its context under
                ``"context"``.
        """
        specs_payload = (
            attrs.get(
                "required_driver_specs",
                getattr(self.instance, "required_driver_specs", []),
            )
            # Treat a None value like an empty list so iteration cannot raise TypeError.
            or []
        )
        formula_specs = [spec for spec in specs_payload if spec["role"] == "formula_input"]
        selector_specs = [spec for spec in specs_payload if spec["role"] == "discrete_selector"]
        variable_symbols = [spec["variable_symbol"] for spec in formula_specs]
        try:
            if evaluation_kind == CostCurveEvaluationKind.EXPRESSION:
                parse_cost_expression(expression_text, variable_symbols=variable_symbols)
            elif evaluation_kind == CostCurveEvaluationKind.DISCRETE_FAMILY:
                self._validate_discrete_family(
                    attrs,
                    selector_specs=selector_specs,
                    variable_symbols=variable_symbols,
                )
        except FormulaError as exc:
            raise serializers.ValidationError({"expression_text": exc.message, "context": exc.context}) from exc

    def _validate_discrete_family(self, attrs, *, selector_specs, variable_symbols) -> None:
        """Check the selector inputs and variants of a discrete-family curve.

        Requirements:
        - exactly one selector key is declared;
        - at least one variant exists;
        - every variant's ``selector_values`` keys match the selector keys exactly;
        - every variant's expression parses against the formula-input symbols.

        Raises:
            FormulaError: on the first violated rule.
        """
        expected_selector_keys = {spec["key"] for spec in selector_specs}
        if not expected_selector_keys:
            raise FormulaError(
                "missing_discrete_selectors",
                "Discrete-family curves require at least one selector input.",
            )
        if len(expected_selector_keys) > 1:
            raise FormulaError(
                "unsupported_multi_selector_discrete_family",
                "Discrete-family curves currently support exactly one capacity selector.",
                context={"selector_keys": sorted(expected_selector_keys)},
            )
        variants = attrs.get("discrete_variants", getattr(self.instance, "discrete_variants", []))
        if not variants:
            raise FormulaError(
                "missing_discrete_variants",
                "Discrete-family curves require at least one variant.",
            )
        for variant in variants:
            selector_keys = set(variant["selector_values"])
            if selector_keys != expected_selector_keys:
                raise FormulaError(
                    "invalid_discrete_variant_selectors",
                    "Discrete variant selector values must exactly match selector inputs.",
                    context={
                        "variant_key": variant["key"],
                        "missing_selectors": sorted(expected_selector_keys - selector_keys),
                        "extra_selectors": sorted(selector_keys - expected_selector_keys),
                    },
                )
            parse_cost_expression(variant["expression_text"], variable_symbols=variable_symbols)

    def _validate_equipment_subtype_requirement(self, attrs) -> None:
        """Reject a blank subtype when the category is reported to require one.

        The decision is delegated to ``cost_curve_category_requires_subtype``, which
        receives the category and the peer curves from ``_peer_cost_curves``.

        Raises:
            serializers.ValidationError: under ``"equipment_subtype"``.
        """
        category = attrs.get("equipment_category", getattr(self.instance, "equipment_category", ""))
        subtype = attrs.get("equipment_subtype", getattr(self.instance, "equipment_subtype", ""))
        # `subtype or ""` makes None count as blank. Plain str(None) would be the
        # truthy "None" and silently skip this check.
        if not category or str(subtype or "").strip():
            return

        if cost_curve_category_requires_subtype(str(category), self._peer_cost_curves()):
            raise serializers.ValidationError(
                {
                    "equipment_subtype": (
                        "Choose an equipment subtype before saving a cost curve for this category."
                    )
                }
            )

    def _peer_cost_curves(self):
        """Return the cost curves used as peers for the subtype check.

        The flowsheet is taken from ``_current_flowsheet_id()``, falling back to the
        instance's ``flowsheet_id``.
        - With no flowsheet at all, the queryset is unfiltered (every flowsheet).
        - When updating, the queryset also includes the instance being edited;
          nothing is excluded.
        """
        flowsheet_id = _current_flowsheet_id()
        if flowsheet_id is None and self.instance is not None:
            flowsheet_id = self.instance.flowsheet_id

        queryset = CostCurve.objects.all()
        return queryset.filter(flowsheet_id=flowsheet_id) if flowsheet_id is not None else queryset

239 

240 

class CostCurveTemplateSerializer(serializers.Serializer):
    """Read-only shape of a built-in cost-curve template.

    Most fields are plain strings; blank values are allowed where marked.
    ``output_unit_options`` is derived from the template's output unit and currency.
    Templates may be attribute-style objects or mappings: the plain declared fields
    read both through DRF's attribute lookup, and ``get_output_unit_options`` does
    the same.
    """

    # NOTE: declaration order is kept as-is; DRF uses it when ordering output keys.
    output_unit_options = serializers.SerializerMethodField()
    required_driver_specs = serializers.ListField(child=CostCurveDriverSpecField())
    discrete_variants = serializers.ListField(child=CostCurveDiscreteVariantField())
    value = serializers.CharField()
    name = serializers.CharField()
    equipment_category = serializers.CharField()
    equipment_subtype = serializers.CharField(allow_blank=True)
    cost_basis = serializers.CharField()
    evaluation_kind = serializers.CharField()
    output_unit = serializers.CharField()
    expression_text = serializers.CharField(allow_blank=True)
    valid_min = serializers.CharField(allow_blank=True)
    valid_max = serializers.CharField(allow_blank=True)
    valid_range_note = serializers.CharField(allow_blank=True)
    currency = serializers.CharField()
    basis_date = serializers.CharField(allow_blank=True)
    basis_index_name = serializers.CharField(allow_blank=True)
    basis_index_value = serializers.CharField(allow_blank=True)
    source_document_title = serializers.CharField(allow_blank=True)
    source_page = serializers.CharField(allow_blank=True)
    source_figure = serializers.CharField(allow_blank=True)
    source_data_origin = serializers.CharField(allow_blank=True)
    source_range_precision = serializers.CharField(allow_blank=True)
    source_license_status = serializers.CharField(allow_blank=True)
    source_reference = serializers.CharField(allow_blank=True)
    notes = serializers.CharField(allow_blank=True)
    applicability_warning = serializers.CharField(allow_blank=True)
    active = serializers.BooleanField()

    @staticmethod
    def _template_attr(instance, name):
        """Read ``name`` from a mapping (``.get``) or an object (``getattr``)."""
        from collections.abc import Mapping

        if isinstance(instance, Mapping):
            return instance.get(name)
        return getattr(instance, name)

    @extend_schema_field(UnitOptionSerializer(many=True))
    def get_output_unit_options(self, instance) -> list[dict[str, str]]:
        # A blank/None currency falls back to "NZD".
        output_unit = self._template_attr(instance, "output_unit")
        currency = self._template_attr(instance, "currency") or "NZD"
        return cost_curve_output_unit_options(output_unit, currency=currency)

274 

275 

def _normalized_required_driver_specs(specs) -> list[CostCurveDriverSpecPayload]:
    """Normalise incoming driver specs through the Pydantic driver-spec contract.

    Raises:
        serializers.ValidationError: keyed by ``"required_driver_specs"``, carrying
            the text of the ``ValueError`` raised by the normaliser.
    """
    try:
        normalized = normalize_required_driver_specs(specs)
    except ValueError as exc:
        message = str(exc)
        raise serializers.ValidationError({"required_driver_specs": message}) from exc
    return normalized

282 

283 

284 

def _normalized_discrete_variants(variants) -> list[CostCurveDiscreteVariantPayload]:
    """Normalise incoming discrete variants through the Pydantic variant contract.

    Raises:
        serializers.ValidationError: keyed by ``"discrete_variants"``, carrying the
            text of the ``ValueError`` raised by the normaliser.
    """
    try:
        normalized = normalize_discrete_variants(variants)
    except ValueError as exc:
        message = str(exc)
        raise serializers.ValidationError({"discrete_variants": message}) from exc
    return normalized

291 

292 

class CostCurveEquipmentCategorySerializer(serializers.Serializer):
    """Read-only shape of an equipment category in the cost-curve catalog.

    Each category carries:
    - its subtypes;
    - its templates;
    - the common driver unit options, which are identical for every category.
    """

    # NOTE: declaration order is kept as-is; DRF uses it when ordering output keys.
    value = serializers.CharField()
    label = serializers.CharField()
    subtypes = serializers.ListField(child=serializers.CharField())
    driver_unit_options = serializers.SerializerMethodField()
    templates = CostCurveTemplateSerializer(many=True)

    @extend_schema_field(UnitOptionSerializer(many=True))
    def get_driver_unit_options(self, _instance) -> list[dict[str, str]]:
        # The category itself is not consulted; options are shared by all categories.
        options = cost_curve_common_driver_unit_options()
        return options