Coverage for backend/django/core/auxiliary/services/parameter_sweep.py: 90%
303 statements
« prev ^ index » next coverage.py v7.10.7, created at 2026-06-23 21:51 +0000
« prev ^ index » next coverage.py v7.10.7, created at 2026-06-23 21:51 +0000
1"""Parameter sweep domain models and row-generation service functions."""
3from __future__ import annotations
5from decimal import Decimal, ROUND_FLOOR
6from enum import StrEnum
7from itertools import product
8import random
9from typing import Any, Mapping
11from django.db.models import Q
12from django.db import transaction
13from pydantic import BaseModel, ConfigDict, Field, RootModel, StrictInt
14from pydantic import ValidationError as PydanticValidationError
15from pydantic import field_validator, model_validator
16from rest_framework.exceptions import ValidationError
18from core.auxiliary.enums.uiEnums import DisplayType
19from core.auxiliary.models.DataCell import DataCell
20from core.auxiliary.models.DataColumn import DataColumn
21from core.auxiliary.models.DataRow import DataRow
22from core.auxiliary.models.PropertyValue import PropertyValue
23from core.auxiliary.models.Scenario import (
24 ParameterSweepDefinition,
25 ParameterSweepParameter,
26 Scenario,
27 ScenarioInputModeEnum,
28)
# Row counts above this are flagged in previews and need
# confirm_large_generation before rows are generated.
ROW_WARNING_THRESHOLD = 1_000
# Absolute ceiling on the number of generated rows.
ROW_HARD_LIMIT = 50_000
# Absolute ceiling on rows multiplied by swept parameters.
CELL_HARD_LIMIT = 1_000_000
# batch_size handed to bulk_create when persisting rows and cells.
CELL_BULK_BATCH_SIZE = 5_000
# Quantum (12 decimal places) that generated values are rounded to.
DECIMAL_QUANT = Decimal("0.000000000001")
38class ParameterSweepRequestMethodEnum(StrEnum):
39 Grid = "grid"
40 MonteCarlo = "monte_carlo"
41 Hammersley = "hammersley"
42 HaltonZaremba = "halton_zaremba"
class ParameterSweepParameterRequest(BaseModel):
    """One free parameter and the numeric interval to explore."""

    # Unknown keys are rejected rather than ignored.
    model_config = ConfigDict(extra="forbid")

    # Id of the PropertyValue to sweep.
    property_value: StrictInt
    lower_bound: Decimal
    upper_bound: Decimal
    # Required by grid sweeps; sampling methods reset it to None.
    step: Decimal | None = None

    @field_validator("lower_bound", "upper_bound", "step")
    @classmethod
    def validate_decimal(cls, value: Decimal | None) -> Decimal | None:
        """Reject NaN and infinite values; ``None`` passes through unchanged."""
        if value is None or value.is_finite():
            return value
        raise ValueError("Numeric values must be finite.")
class ParameterSweepRequest(BaseModel):
    """Validated parameter sweep request shared by preview and generation."""

    # Unknown keys are rejected rather than ignored.
    model_config = ConfigDict(extra="forbid")

    method: ParameterSweepRequestMethodEnum = ParameterSweepRequestMethodEnum.Grid
    # At least one swept parameter is required.
    parameters: list[ParameterSweepParameterRequest] = Field(min_length=1)
    # Required by sampling methods; cleared for grid sweeps.
    sample_count: StrictInt | None = None
    # Kept for Monte Carlo only (generated when omitted); cleared otherwise.
    monte_carlo_seed: StrictInt | None = None
    confirm_replace: bool = False
    confirm_large_generation: bool = False

    @model_validator(mode="after")
    def validate_method_requirements(self) -> "ParameterSweepRequest":
        """Enforce per-method rules and reset the fields the method ignores."""
        methods = ParameterSweepRequestMethodEnum

        if self.method == methods.Grid:
            # Grid sweeps are described by bounds and step alone.
            self.sample_count = None
            self.monte_carlo_seed = None
            for grid_parameter in self.parameters:
                _validate_grid_parameter(grid_parameter)
            return self

        if self.sample_count is None:
            raise ValueError("Sample count is required for sampling methods.")
        if self.sample_count <= 0:
            raise ValueError("Sample count must be greater than zero.")

        seed = self.monte_carlo_seed
        if seed is not None and seed < 0:
            raise ValueError("Seed must be a non-negative integer.")

        if self.method != methods.MonteCarlo:
            self.monte_carlo_seed = None
        elif seed is None:
            # Fix a concrete seed so that the value carried on the validated
            # request is the one the random generator is built from.
            self.monte_carlo_seed = random.SystemRandom().randrange(1, 2**31)

        # Sampling methods never use a step.
        for sampled_parameter in self.parameters:
            sampled_parameter.step = None
        return self
class EligibleParameterSweepTarget(BaseModel):
    # Serialized description of one property value that may be swept.

    # Ids of the property value and of the simulation object that owns it.
    property_value: int
    simulation_object: int

    simulation_object_name: str
    property_name: str
    indexed_set_names: list[str]
    unit: str
    # Human-readable label combining object, property and index names.
    label: str
class EligibleParameterSweepTargetsResponse(
    RootModel[list[EligibleParameterSweepTarget]]
):
    """List response wrapper so OpenAPI keeps the target endpoint typed as an array."""

    # No fields of its own: the root value is the list of targets.
class ParameterSweepPreviewParameterResponse(BaseModel):
    # Per-parameter entry of a sweep preview.

    property_value: int
    label: str
    unit: str
    # Grid: number of stepped values; sampling methods: the sample count.
    value_count: int
class ParameterSweepPreviewResponse(BaseModel):
    # Size report returned by a preview; nothing is persisted to produce it.

    method: ParameterSweepRequestMethodEnum
    row_count: int
    # True when row_count exceeds warning_threshold.
    warns_above_threshold: bool
    hard_limit: int
    warning_threshold: int
    monte_carlo_seed: int | None = None
    parameters: list[ParameterSweepPreviewParameterResponse]
class ParameterSweepGenerateResponse(BaseModel):
    # Summary returned once sweep rows have been generated.

    method: ParameterSweepRequestMethodEnum
    row_count: int
    # Id of the saved ParameterSweepDefinition.
    definition: int
    monte_carlo_seed: int | None = None
def eligible_parameter_sweep_targets(flowsheet_id: int) -> list[EligibleParameterSweepTarget]:
    """Return directly settable numeric property values for a flowsheet sweep.

    Results are ordered by simulation object name, property display name and
    id.
    """
    has_no_formula = Q(formula__isnull=True) | Q(formula="")
    enabled_or_set_point = Q(enabled=True) | Q(controlSetPoint__isnull=False)

    candidates = PropertyValue.objects.filter(
        flowsheet_id=flowsheet_id,
        property__type=DisplayType.numeric,
        property__set__simulationObject__is_deleted=False,
    )
    candidates = candidates.filter(has_no_formula)
    candidates = candidates.filter(enabled_or_set_point)
    # Values that have a controlManipulated relation are excluded.
    candidates = candidates.filter(controlManipulated__isnull=True)
    candidates = candidates.select_related(
        "property",
        "property__set",
        "property__set__simulationObject",
    )
    candidates = candidates.prefetch_related("indexedItems")
    candidates = candidates.order_by(
        "property__set__simulationObject__componentName",
        "property__displayName",
        "id",
    )

    serialized: list[EligibleParameterSweepTarget] = []
    for candidate in candidates:
        serialized.append(_serialize_target(candidate))
    return serialized
def preview_parameter_sweep(
    scenario: Scenario,
    payload: ParameterSweepRequest | Mapping[str, Any],
) -> ParameterSweepPreviewResponse:
    """Validate a sweep request and report its size without persisting rows.

    Raises:
        ValidationError: if the payload, its targets, or the resulting row or
            cell counts are rejected by the validation helpers.
    """
    spec = _coerce_sweep_request(payload)
    targets = _validate_targets(scenario, spec)
    row_count = _calculate_row_count(spec)
    _validate_total_cell_count(row_count, len(targets))

    # One summary per target, in the order the request listed them.
    summaries: list[ParameterSweepPreviewParameterResponse] = []
    for position, target in enumerate(targets):
        summaries.append(
            ParameterSweepPreviewParameterResponse(
                property_value=target.id,
                label=_target_label(target),
                unit=target.property.unit,
                value_count=_parameter_value_count(spec, position),
            )
        )

    return ParameterSweepPreviewResponse(
        method=spec.method,
        row_count=row_count,
        warns_above_threshold=row_count > ROW_WARNING_THRESHOLD,
        hard_limit=ROW_HARD_LIMIT,
        warning_threshold=ROW_WARNING_THRESHOLD,
        monte_carlo_seed=spec.monte_carlo_seed,
        parameters=summaries,
    )
@transaction.atomic
def generate_parameter_sweep(
    scenario: Scenario,
    payload: ParameterSweepRequest | Mapping[str, Any],
) -> ParameterSweepGenerateResponse:
    """Replace a scenario's MSS input with generated sweep rows.

    The operation is deliberately all-or-nothing because the scenario mode,
    saved sweep definition, data columns, rows, and cells must remain in sync.

    Raises:
        ValidationError: if the request is invalid, if existing input data
            would be replaced without ``confirm_replace``, or if the sweep
            exceeds the warning threshold without ``confirm_large_generation``.
    """
    spec = _coerce_sweep_request(payload)
    targets = _validate_targets(scenario, spec)
    row_count = _calculate_row_count(spec)
    _validate_total_cell_count(row_count, len(targets))

    has_rows = scenario.dataRows.exists()
    has_columns = scenario.dataColumns.exists()
    if (has_rows or has_columns) and not spec.confirm_replace:
        raise ValidationError(
            {
                "confirm_replace": (
                    "Existing scenario input data will be replaced. "
                    "Set confirm_replace to true to continue."
                )
            }
        )
    if row_count > ROW_WARNING_THRESHOLD and not spec.confirm_large_generation:
        raise ValidationError(
            {
                "confirm_large_generation": (
                    f"This sweep will generate {row_count} rows. "
                    "Set confirm_large_generation to true to continue."
                )
            }
        )

    # Materialize the values before touching any stored data.
    rows = _generate_rows(spec)
    flowsheet_id = scenario.flowsheet_id

    # Remove the previous table and definition.
    DataRow.objects.filter(scenario=scenario).delete()
    DataColumn.objects.filter(scenario=scenario).delete()
    ParameterSweepDefinition.objects.filter(scenario=scenario).delete()

    definition = ParameterSweepDefinition.objects.create(
        flowsheet_id=flowsheet_id,
        scenario=scenario,
        method=spec.method,
        sample_count=spec.sample_count,
        monte_carlo_seed=spec.monte_carlo_seed,
    )

    saved_parameters: list[ParameterSweepParameter] = []
    for position, (param, target) in enumerate(zip(spec.parameters, targets)):
        saved_parameters.append(
            ParameterSweepParameter(
                flowsheet_id=flowsheet_id,
                definition=definition,
                property_value=target,
                order=position,
                lower_bound=param.lower_bound,
                upper_bound=param.upper_bound,
                step=param.step,
                unit=target.property.unit,
                target_label=_target_label(target),
            )
        )
    ParameterSweepParameter.objects.bulk_create(saved_parameters)

    new_columns: list[DataColumn] = []
    for position, target in enumerate(targets):
        new_columns.append(
            DataColumn(
                flowsheet_id=flowsheet_id,
                scenario=scenario,
                name=_unique_column_name(target, position, targets),
                property_value=target,
            )
        )
    DataColumn.objects.bulk_create(new_columns)
    # Re-read the columns in id order; they are paired with the generated
    # values by position below.
    columns = list(DataColumn.objects.filter(scenario=scenario).order_by("id"))

    new_rows: list[DataRow] = []
    for row_index in range(row_count):
        new_rows.append(
            DataRow(index=row_index, flowsheet_id=flowsheet_id, scenario=scenario)
        )
    DataRow.objects.bulk_create(new_rows, batch_size=CELL_BULK_BATCH_SIZE)
    # Re-read the rows in index order so row i receives generated row i.
    data_rows = list(DataRow.objects.filter(scenario=scenario).order_by("index"))

    # Write cells in bounded batches instead of holding every cell in memory.
    pending: list[DataCell] = []
    for data_row, row_values in zip(data_rows, rows):
        for column, cell_value in zip(columns, row_values):
            pending.append(
                DataCell(
                    flowsheet_id=flowsheet_id,
                    data_column=column,
                    data_row=data_row,
                    value=float(cell_value),
                )
            )
            if len(pending) >= CELL_BULK_BATCH_SIZE:
                DataCell.objects.bulk_create(pending, batch_size=CELL_BULK_BATCH_SIZE)
                pending.clear()
    if pending:
        DataCell.objects.bulk_create(pending, batch_size=CELL_BULK_BATCH_SIZE)

    scenario.mss_input_mode = ScenarioInputModeEnum.ParameterSweep
    scenario.Uploaded_fileName = ""
    scenario.save(update_fields=["mss_input_mode", "Uploaded_fileName"])

    return ParameterSweepGenerateResponse(
        method=spec.method,
        row_count=row_count,
        definition=definition.id,
        monte_carlo_seed=spec.monte_carlo_seed,
    )
def clear_parameter_sweep_definition(scenario: Scenario) -> None:
    """Delete the sweep definition saved for ``scenario``, if there is one."""
    saved_definitions = ParameterSweepDefinition.objects.filter(scenario=scenario)
    saved_definitions.delete()
def clear_mss_input_data(
    scenario: Scenario,
    *,
    clear_uploaded_filename: bool = True,
) -> None:
    """Remove all generated/uploaded MSS table data for a scenario.

    Deletes the scenario's rows, columns and sweep definition. Unless
    ``clear_uploaded_filename`` is false, ``Uploaded_fileName`` is also reset
    to an empty string and saved.
    """
    for model in (DataRow, DataColumn):
        model.objects.filter(scenario=scenario).delete()
    clear_parameter_sweep_definition(scenario)

    if not clear_uploaded_filename:
        return
    scenario.Uploaded_fileName = ""
    scenario.save(update_fields=["Uploaded_fileName"])
def clear_mss_input_data_after_mode_switch(
    scenario: Scenario,
    *,
    previous_mode: str,
    requested_mode: str,
    had_parameter_sweep_definition: bool,
) -> None:
    """Clear stale MSS data after the scenario's persisted input mode changes.

    Nothing is cleared when the requested mode is neither CSV nor parameter
    sweep. When the mode did not change, data is cleared only if the mode is
    CSV and a sweep definition had been saved.
    """
    modes_that_clear = {
        ScenarioInputModeEnum.Csv,
        ScenarioInputModeEnum.ParameterSweep,
    }
    if requested_mode not in modes_that_clear:
        return

    mode_unchanged = previous_mode == requested_mode
    csv_with_saved_definition = (
        requested_mode == ScenarioInputModeEnum.Csv and had_parameter_sweep_definition
    )
    if mode_unchanged and not csv_with_saved_definition:
        return

    clear_mss_input_data(scenario)
def validate_parameter_sweep_solve_ready(scenario: Scenario) -> None:
    """Check that a parameter-sweep scenario is complete enough to be solved.

    Scenarios in any other input mode are accepted without further checks.
    For parameter-sweep scenarios the saved definition is rebuilt into a
    ``ParameterSweepRequest`` and its targets are re-validated.

    Raises:
        ValidationError: if the definition, its parameters or the generated
            rows are missing, if the stored definition no longer forms a
            valid request, or if a target fails target validation.
    """
    if scenario.mss_input_mode != ScenarioInputModeEnum.ParameterSweep:
        return

    try:
        definition = scenario.parameterSweepDefinition
    except ParameterSweepDefinition.DoesNotExist as exc:
        raise ValidationError("Parameter sweep scenarios require a saved sweep definition.") from exc

    params = list(definition.parameters.select_related("property_value"))
    if not params:
        raise ValidationError("Parameter sweep scenarios require at least one parameter.")
    if not scenario.dataRows.exists():
        raise ValidationError("Parameter sweep scenarios require generated rows.")

    # Stored values go through the same strict model as incoming requests.
    # Convert a pydantic failure into a DRF ValidationError so that callers
    # see the same exception type as for every other failure in this function.
    try:
        spec = ParameterSweepRequest(
            method=definition.method,
            sample_count=definition.sample_count,
            monte_carlo_seed=definition.monte_carlo_seed,
            parameters=[
                ParameterSweepParameterRequest(
                    property_value=param.property_value_id,
                    lower_bound=param.lower_bound,
                    upper_bound=param.upper_bound,
                    step=param.step,
                )
                for param in params
            ],
        )
    except PydanticValidationError as exc:
        raise _drf_validation_error_from_pydantic(exc) from exc

    _validate_targets(scenario, spec)
def _generate_rows(spec: ParameterSweepRequest) -> list[list[Decimal]]:
    """Materialize the bounded samples in the column order chosen by the user.

    Raises:
        ValidationError: if ``spec.method`` is not a supported method.
    """
    if spec.method == ParameterSweepRequestMethodEnum.Grid:
        # Cartesian product of the per-parameter value lists.
        axes = [_grid_values(parameter) for parameter in spec.parameters]
        return [list(combination) for combination in product(*axes)]

    samplers = {
        ParameterSweepRequestMethodEnum.MonteCarlo: _monte_carlo_rows,
        ParameterSweepRequestMethodEnum.Hammersley: _hammersley_rows,
        ParameterSweepRequestMethodEnum.HaltonZaremba: _halton_rows,
    }
    sampler = samplers.get(spec.method)
    if sampler is None:
        raise ValidationError({"method": f"Unsupported parameter sweep method: {spec.method}"})
    return sampler(spec)
def _calculate_row_count(spec: ParameterSweepRequest) -> int:
    """Calculate sweep size before materializing rows so hard limits stay cheap.

    Raises:
        ValidationError: if the sweep would exceed ``ROW_HARD_LIMIT`` rows.
    """

    def too_many_rows() -> ValidationError:
        return ValidationError(
            {
                "row_count": (
                    f"Parameter sweep cannot generate more than {ROW_HARD_LIMIT} rows."
                )
            }
        )

    if spec.method == ParameterSweepRequestMethodEnum.Grid:
        running_total = 1
        for parameter in spec.parameters:
            running_total *= _grid_value_count(parameter)
            # Check after every factor so an oversized grid is rejected
            # before the remaining parameters are multiplied in.
            if running_total > ROW_HARD_LIMIT:
                raise too_many_rows()
        return running_total

    requested = spec.sample_count or 0
    if requested > ROW_HARD_LIMIT:
        raise too_many_rows()
    return requested
def _validate_total_cell_count(row_count: int, parameter_count: int) -> None:
    """Raise ``ValidationError`` when rows x parameters exceeds ``CELL_HARD_LIMIT``."""
    if row_count * parameter_count <= CELL_HARD_LIMIT:
        return
    raise ValidationError(
        {
            "cell_count": (
                f"Parameter sweep cannot generate more than {CELL_HARD_LIMIT} cells."
            )
        }
    )
def _grid_values(parameter: ParameterSweepParameterRequest) -> list[Decimal]:
    """List the grid points of one parameter, starting at its lower bound.

    Each point is the running sum of ``step`` quantized to ``DECIMAL_QUANT``;
    the number of points comes from ``_grid_value_count``.
    """
    increment = parameter.step or 0
    cursor = parameter.lower_bound
    grid: list[Decimal] = []
    for _ in range(_grid_value_count(parameter)):
        grid.append(cursor.quantize(DECIMAL_QUANT))
        cursor += increment
    return grid
def _grid_value_count(parameter: ParameterSweepParameterRequest) -> int:
    """Count the grid points between the bounds, the start point included.

    Raises:
        ValidationError: if the step is missing, zero, or points away from
            the upper bound.
    """
    step = parameter.step
    if step is None:
        raise ValidationError({"step": "Grid sweeps require a step for every parameter."})
    if step == 0:
        raise ValidationError({"step": "Step cannot be zero."})

    lower, upper = parameter.lower_bound, parameter.upper_bound
    if lower == upper:
        return 1

    # The bounds differ and the step is non-zero, so the step points the
    # wrong way exactly when "ascending" and "negative step" agree.
    ascending = upper > lower
    if ascending == (step < 0):
        raise ValidationError({"step": "Step must move from start toward end."})

    whole_steps = (abs(upper - lower) / abs(step)).to_integral_value(rounding=ROUND_FLOOR)
    return int(whole_steps) + 1
def _monte_carlo_rows(spec: ParameterSweepRequest) -> list[list[Decimal]]:
    """Draw ``sample_count`` rows from a ``random.Random`` seeded by the spec.

    Values are drawn row by row and, within a row, in parameter order.
    """
    generator = random.Random(spec.monte_carlo_seed)
    sampled: list[list[Decimal]] = []
    for _ in range(spec.sample_count or 0):
        row: list[Decimal] = []
        for parameter in spec.parameters:
            fraction = Decimal(str(generator.random()))
            row.append(_scale_unit_interval(fraction, parameter))
        sampled.append(row)
    return sampled
def _hammersley_rows(spec: ParameterSweepRequest) -> list[list[Decimal]]:
    """Generate rows with skopt's Hammersly sampler, scaled to each interval."""
    # Imported here rather than at module level, as in the original code.
    # NOTE(review): ``Hammersly`` is the spelling the import relies on;
    # check skopt before "correcting" it.
    from skopt.sampler import Hammersly
    from skopt.space import Real

    unit_dimensions = [Real(0.0, 1.0) for _ in spec.parameters]
    sample_total = spec.sample_count or 0
    unit_points = Hammersly().generate(unit_dimensions, sample_total)
    return _scale_points(unit_points, spec.parameters)
def _halton_rows(spec: ParameterSweepRequest) -> list[list[Decimal]]:
    """Generate rows from scipy's unscrambled Halton sampler, scaled to each interval."""
    # Imported here rather than at module level, as in the original code.
    from scipy.stats import qmc

    dimension_count = len(spec.parameters)
    halton = qmc.Halton(d=dimension_count, scramble=False)
    unit_points = halton.random(spec.sample_count or 0)
    return _scale_points(unit_points, spec.parameters)
def _scale_points(
    points,
    parameters: list[ParameterSweepParameterRequest],
) -> list[list[Decimal]]:
    """Scale each point's coordinates into the matching parameter's interval.

    Coordinate ``i`` of a point is paired with ``parameters[i]``.
    """
    scaled_rows: list[list[Decimal]] = []
    for point in points:
        scaled_row: list[Decimal] = []
        for position, parameter in enumerate(parameters):
            # Go through str() so the Decimal is built from the coordinate's
            # string form.
            fraction = Decimal(str(point[position]))
            scaled_row.append(_scale_unit_interval(fraction, parameter))
        scaled_rows.append(scaled_row)
    return scaled_rows
def _scale_unit_interval(
    value: Decimal,
    parameter: ParameterSweepParameterRequest,
) -> Decimal:
    """Return ``lower + (upper - lower) * value`` quantized to ``DECIMAL_QUANT``."""
    width = parameter.upper_bound - parameter.lower_bound
    scaled = parameter.lower_bound + width * value
    return scaled.quantize(DECIMAL_QUANT)
def _validate_targets(
    scenario: Scenario,
    spec: ParameterSweepRequest,
) -> list[PropertyValue]:
    """Ensure every requested target is still a settable numeric variable.

    Returns:
        The matching ``PropertyValue`` objects in the order of
        ``spec.parameters``.

    Raises:
        ValidationError: if a target is listed twice, cannot be found,
            belongs to another flowsheet, or is no longer eligible.
    """
    requested_ids = [parameter.property_value for parameter in spec.parameters]
    if len(requested_ids) != len(set(requested_ids)):
        raise ValidationError({"parameters": "Each parameter can only be selected once."})

    matches = (
        PropertyValue.objects.filter(id__in=requested_ids)
        .select_related("property", "property__set", "property__set__simulationObject")
        .prefetch_related("indexedItems")
    )
    found_by_id = {match.id: match for match in matches}

    # Walk the ids in request order so the result lines up with the request.
    validated: list[PropertyValue] = []
    for requested_id in requested_ids:
        match = found_by_id.get(requested_id)
        if match is None:
            raise ValidationError({"parameters": f"Property value {requested_id} was not found."})
        if match.flowsheet_id != scenario.flowsheet_id:
            raise ValidationError({"parameters": "Sweep targets must belong to the scenario flowsheet."})
        if not _is_eligible_target(match):
            raise ValidationError(
                {
                    "parameters": (
                        f"{_target_label(match)} is no longer eligible for parameter sweep."
                    )
                }
            )
        validated.append(match)

    return validated
def _is_eligible_target(value: PropertyValue) -> bool:
    """Tell whether a property value may be swept.

    It must be enabled, have no formula, be of numeric display type, have no
    ``controlManipulated`` relation, and belong to a simulation object that
    is not deleted.
    """
    if not value.is_enabled() or value.formula:
        return False

    prop = value.property
    if prop is None or prop.type != DisplayType.numeric:
        return False

    if hasattr(value, "controlManipulated"):
        return False

    owning_set = prop.set
    if owning_set is None:
        return False

    sim_object = owning_set.simulationObject
    return sim_object is not None and not sim_object.is_deleted
def _parameter_value_count(spec: ParameterSweepRequest, index: int) -> int:
    """Return how many values the parameter at ``index`` takes in the sweep.

    Grid sweeps count the parameter's grid points; every other method uses
    the sample count (0 when unset).
    """
    is_grid = spec.method == ParameterSweepRequestMethodEnum.Grid
    return _grid_value_count(spec.parameters[index]) if is_grid else (spec.sample_count or 0)
def _serialize_target(value: PropertyValue) -> EligibleParameterSweepTarget:
    """Build the API representation of one sweepable property value."""
    prop = value.property
    sim_object = prop.set.simulationObject
    return EligibleParameterSweepTarget(
        property_value=value.id,
        simulation_object=sim_object.id,
        simulation_object_name=sim_object.componentName,
        property_name=prop.displayName,
        indexed_set_names=value.get_index_names(),
        unit=prop.unit,
        label=_target_label(value),
    )
def _target_label(value: PropertyValue) -> str:
    """Format ``"<object> / <property>[ <index names>]"`` for a property value.

    ``"Object"`` stands in when the simulation object's name is empty.
    """
    heading = value.property.set.simulationObject.componentName or "Object"
    index_names = value.get_index_names()
    label = f"{heading} / {value.property.displayName}"
    if index_names:
        label += " " + " ".join(index_names)
    return label
def _unique_column_name(target: PropertyValue, index: int, targets: list[PropertyValue]) -> str:
    """Return the target's label, suffixed with its id when the label is shared.

    ``index`` is accepted but not used.
    """
    label = _target_label(target)
    same_label = [candidate for candidate in targets if _target_label(candidate) == label]
    if len(same_label) == 1:
        return label
    return f"{label} ({target.id})"
def _coerce_sweep_request(
    payload: ParameterSweepRequest | Mapping[str, Any],
) -> ParameterSweepRequest:
    """Turn a raw payload into a validated ``ParameterSweepRequest``.

    An existing ``ParameterSweepRequest`` is returned unchanged.

    Raises:
        ValidationError: if the payload is not a mapping or fails model
            validation.
    """
    if isinstance(payload, ParameterSweepRequest):
        return payload

    # dict() raises TypeError/ValueError for scalar or list bodies, and would
    # silently turn a list of pairs into a dict, so reject anything that is
    # not a mapping with a regular validation error.
    if not isinstance(payload, Mapping):
        raise ValidationError(
            {"non_field_errors": ["Parameter sweep request must be an object."]}
        )

    # The frontend base query injects the active flowsheet id into mutation
    # bodies as transport context. It is not part of the sweep contract, so
    # remove it before applying the strict Pydantic request model.
    payload_without_transport_context = dict(payload)
    payload_without_transport_context.pop("flowsheet", None)

    try:
        return ParameterSweepRequest.model_validate(payload_without_transport_context)
    except PydanticValidationError as exc:
        raise _drf_validation_error_from_pydantic(exc) from exc
def _validate_grid_parameter(parameter: ParameterSweepParameterRequest) -> None:
    """Check that a grid parameter has a usable step.

    Raises:
        ValueError: if the step is missing, zero, or points away from the
            upper bound while the bounds differ.
    """
    step = parameter.step
    if step is None:
        raise ValueError("Grid sweeps require a step for every parameter.")
    if step == 0:
        raise ValueError("Step cannot be zero.")

    lower, upper = parameter.lower_bound, parameter.upper_bound
    if lower == upper:
        # A single-point interval accepts any non-zero step.
        return

    ascending_with_negative_step = upper > lower and step < 0
    descending_with_positive_step = upper < lower and step > 0
    if ascending_with_negative_step or descending_with_positive_step:
        raise ValueError("Step must move from start toward end.")
def _drf_validation_error_from_pydantic(exc: PydanticValidationError) -> ValidationError:
    """Group pydantic error messages by top-level field into a DRF error.

    Errors whose location is empty are filed under ``"non_field_errors"``.
    """
    messages_by_field: dict[str, list[str]] = {}
    for problem in exc.errors():
        location = problem.get("loc") or ("non_field_errors",)
        field_name = str(location[0])
        message = str(problem.get("msg", "Invalid value."))
        if field_name not in messages_by_field:
            messages_by_field[field_name] = []
        messages_by_field[field_name].append(message)
    return ValidationError(messages_by_field)