Coverage for backend/django/core/auxiliary/viewsets/MLViewSet.py: 81%
214 statements
« prev ^ index » next coverage.py v7.10.7, created at 2026-06-23 21:51 +0000
« prev ^ index » next coverage.py v7.10.7, created at 2026-06-23 21:51 +0000
1import json
2import logging
3from typing import Any
5from django.http import HttpResponse
6from django.shortcuts import get_object_or_404
7from django.views.decorators.csrf import csrf_exempt
8from django.db import transaction
9from drf_spectacular.utils import OpenApiParameter, OpenApiTypes, extend_schema
10from pydantic import ValidationError as PydanticValidationError
11from rest_framework import serializers
12from rest_framework.decorators import action, api_view, authentication_classes, permission_classes
13from authentication.custom_drf_authentication import DaprApiTokenAuthentication
14from rest_framework.exceptions import NotFound, ValidationError
15from rest_framework.parsers import JSONParser
16from rest_framework.permissions import IsAuthenticated
17from rest_framework.response import Response
19from common.models.idaes.payloads.ml_request_schema import MLTrainingCompletionEvent
20from core.auxiliary.models.MLModel import MLModel
21from core.auxiliary.models.MLWizard import train
22from core.auxiliary.serializers.MLModelSerializer import MLModelSerializer
23from core.auxiliary.services.object_storage.s3 import presign_download_url
24from core.auxiliary.services.uploads import attach_upload_to_ml_model, inspect_upload_session
25from core.viewset import ModelViewSet
26from flowsheetInternals.unitops.models import SimulationObject
27from idaes_factory.endpoints import process_ml_training_response
# Module-scoped logger, named after this module's dotted import path.
logger = logging.getLogger(name=__name__)
def _parse_ml_training_completion_event(data) -> MLTrainingCompletionEvent | None:
    """Validate a raw completion-event payload as ``MLTrainingCompletionEvent``.

    Args:
        data: The raw request payload to validate.

    Returns:
        The validated event. Returns ``None`` when pydantic rejects the payload;
        in that case a warning (with traceback) is logged instead of raising.
    """
    parsed = None
    try:
        parsed = MLTrainingCompletionEvent.model_validate(data)
    except PydanticValidationError:
        logger.warning("Discarding malformed ML training completion event.", exc_info=True)
    return parsed
class GetCsvHeaderSerializer(serializers.Serializer):
    """Response payload of the ``get-csv-header`` action: a list of column names."""

    headers = serializers.ListField(
        child=serializers.CharField(),
    )
class OperationMessageSerializer(serializers.Serializer):
    """Generic response payload carrying a short human-readable status message."""

    message = serializers.CharField()
class CreateSurrogateModelFromColumnSerializer(serializers.Serializer):
    """Request payload of ``create-surrogate-model``: the id of the ML model to train."""

    model = serializers.IntegerField()
class DownloadTestResultsSerializer(serializers.Serializer):
    """Response payload of ``download-test-results``: a download URL for the CSV."""

    url = serializers.URLField()
class UploadModelSerializer(serializers.Serializer):
    """Request payload of ``upload-ml-model`` (importing a pre-built surrogate)."""

    # Serialized surrogate model; may be a JSON value or a JSON-encoded string.
    json_data = serializers.JSONField()
    # Primary key of the SimulationObject the imported model is attached to.
    simulationObject = serializers.IntegerField()
class CreateMLModelSerializer(serializers.Serializer):
    """Request payload for creating an ML model on a simulation object.

    ``surrogate_model`` is optional and falls back to an empty dict.
    """

    simulationObject = serializers.IntegerField()
    surrogate_model = serializers.JSONField(default=dict, required=False)
class UpdateMLModelSerializer(serializers.Serializer):
    """Request payload accepted by the ML model update endpoints.

    Every field is optional. ``csv_upload_session`` (an upload-session id) and
    ``surrogate_model`` may not both appear in the same request.
    """

    displayName = serializers.CharField(max_length=64, required=False)
    surrogate_model = serializers.JSONField(required=False)
    progress = serializers.IntegerField(required=False)
    csv_upload_session = serializers.IntegerField(required=False)

    def validate(self, attrs):
        """Reject requests that set both mutually exclusive fields at once."""
        mutually_exclusive = {"csv_upload_session", "surrogate_model"}
        if not mutually_exclusive.issubset(attrs):
            return attrs
        raise serializers.ValidationError(
            {
                "non_field_errors": [
                    "csv_upload_session and surrogate_model cannot be updated in the same request."
                ]
            }
        )
class MLViewSet(ModelViewSet):
    """Manage ML CSV uploads, header lookup, model import, and surrogate training."""

    serializer_class = MLModelSerializer
    parser_classes = [JSONParser]

    def get_queryset(self):
        return MLModel.objects.all()

    @staticmethod
    def _parse_int_query_param(raw_value: str | None, *, name: str, missing_message: str) -> int:
        """Convert a required integer query parameter, raising HTTP 400 on bad input.

        Args:
            raw_value: The raw query-string value (``None`` when absent).
            name: Parameter name used in the "must be an integer" message.
            missing_message: Error message used when the value is absent or empty.

        Returns:
            The parsed integer.

        Raises:
            ValidationError: If the value is missing/empty or not an integer.
                Django's ORM rejects a non-numeric id with ``ValueError``, which
                would otherwise surface as an unhandled HTTP 500.
        """
        if not raw_value:
            raise ValidationError({"error": missing_message})
        try:
            return int(raw_value)
        except (TypeError, ValueError) as exc:
            raise ValidationError({"error": f"{name} parameter must be an integer."}) from exc

    @staticmethod
    def _validate_single_model_rule(simulation_object: SimulationObject) -> None:
        """Reject a second ML model on a ``machineLearningBlock`` simulation object.

        Other object types are not restricted.

        Raises:
            ValidationError: If the block already owns an ML model.
        """
        if simulation_object.objectType != "machineLearningBlock":
            return
        if MLModel.objects.filter(simulationObject=simulation_object).exists():
            raise ValidationError(
                {"simulationObject": "machineLearningBlock can only have one ML model at a time."}
            )

    def _update_ml_model(self, request, *, partial: bool) -> Response:
        """Handle ML-model edits while preserving upload-attach side effects.

        A ``csv_upload_session`` id is attached through the upload service; any
        remaining fields go through the regular model serializer. Both happen in
        one transaction, and the refreshed instance is returned.
        """
        instance = self.get_object()
        request_serializer = UpdateMLModelSerializer(data=request.data, partial=partial)
        request_serializer.is_valid(raise_exception=True)
        validated_data = dict(request_serializer.validated_data)

        with transaction.atomic():
            upload_session_id = validated_data.pop("csv_upload_session", None)
            if upload_session_id is not None:
                instance = attach_upload_to_ml_model(
                    instance=instance,
                    upload_session_id=upload_session_id,
                    simulation_object_id=instance.simulationObject_id,
                    user_id=request.user.id,
                )

                # CSV attachment always derives progress from the upload workflow.
                validated_data.pop("progress", None)

            if validated_data:
                serializer = self.get_serializer(
                    instance,
                    data=validated_data,
                    partial=partial,
                )
                serializer.is_valid(raise_exception=True)
                self.perform_update(serializer)
                instance = serializer.instance

        instance.refresh_from_db()
        return Response(self.get_serializer(instance).data, status=200)

    @extend_schema(
        request=UpdateMLModelSerializer,
        responses=MLModelSerializer,
    )
    def update(self, request, *args, **kwargs) -> Response:
        return self._update_ml_model(request, partial=False)

    @extend_schema(
        request=UpdateMLModelSerializer,
        responses=MLModelSerializer,
    )
    def partial_update(self, request, *args, **kwargs) -> Response:
        return self._update_ml_model(request, partial=True)

    @extend_schema(
        parameters=[
            OpenApiParameter(
                name="simulationObject", required=True, type=OpenApiTypes.INT
            ),
        ]
    )
    def list(self, request, *args, **kwargs):
        """List the ML models attached to the ``simulationObject`` query parameter.

        Raises:
            ValidationError: If ``simulationObject`` is missing or not an integer.
        """
        simulation_object_id = self._parse_int_query_param(
            request.query_params.get("simulationObject"),
            name="simulationObject",
            missing_message="simulationObject is required.",
        )
        queryset = self.get_queryset().filter(simulationObject=simulation_object_id)
        serializer = self.get_serializer(queryset, many=True)
        return Response(serializer.data, status=200)

    @extend_schema(
        request=CreateMLModelSerializer,
        responses=MLModelSerializer,
    )
    def create(self, request, *args, **kwargs):
        """Create an ML model with simulationObject and optional surrogate_model.

        A model created with a non-empty surrogate starts at progress 1,
        otherwise at 0.

        Raises:
            NotFound: If the simulation object does not exist.
            ValidationError: If the single-model rule is violated.
        """
        serializer = CreateMLModelSerializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        validated_data = serializer.validated_data

        simulation_object_id = validated_data["simulationObject"]
        surrogate_model = validated_data.get("surrogate_model", {})

        try:
            simulation_object = SimulationObject.objects.get(id=simulation_object_id)
        except SimulationObject.DoesNotExist:
            raise NotFound({"error": "SimulationObject not found."})

        self._validate_single_model_rule(simulation_object)

        instance = MLModel.objects.create(
            flowsheet_id=simulation_object.flowsheet_id,
            simulationObject=simulation_object,
            surrogate_model=surrogate_model,
            progress=1 if surrogate_model else 0,
        )

        response_serializer = self.get_serializer(instance)
        return Response(response_serializer.data, status=201)

    @extend_schema(
        parameters=[
            OpenApiParameter(name="model", required=True,
                             type=OpenApiTypes.INT),
        ],
        responses=GetCsvHeaderSerializer,
    )
    @action(
        detail=False,
        methods=["get"],
        url_path="get-csv-header",
        url_name="get-csv-header",
    )
    def get_csv_header(self, request):
        """Return the authoritative CSV headers for the selected ML model.

        A model with a surrogate reports its ``input_labels`` followed by its
        ``output_labels``. Otherwise, the headers cached on the model are used.
        When none are cached, they are inspected from the attached upload
        session and saved back to the model.

        Raises:
            ValidationError: If ``model`` is missing/non-integer, or no headers exist.
            NotFound: If the ML model does not exist.
        """
        model_id = self._parse_int_query_param(
            request.query_params.get("model"),
            name="model",
            missing_message="model is required.",
        )

        try:
            ml_model = MLModel.objects.get(id=model_id)
        except MLModel.DoesNotExist:
            raise NotFound({"error": "MLModel not found."})

        surrogate_model = ml_model.surrogate_model
        # Truthiness (not `!= {}`) so a NULL surrogate falls through to the CSV path
        # instead of crashing on `None.get`.
        if surrogate_model:
            # A surrogate without label lists yields no headers rather than
            # raising TypeError on `None + list`.
            input_labels = surrogate_model.get("input_labels") or []
            output_labels = surrogate_model.get("output_labels") or []
            return Response(
                GetCsvHeaderSerializer(
                    {"headers": list(input_labels) + list(output_labels)}
                ).data,
                status=200,
            )

        headers = ml_model.csv_headers
        if not headers and ml_model.csv_upload_session_id:
            # Lazily populate the cached headers/delimiter from the upload.
            inspection = inspect_upload_session(ml_model.csv_upload_session)
            headers = inspection.headers
            ml_model.csv_headers = headers
            ml_model.csv_delimiter = inspection.delimiter
            ml_model.save(update_fields=["csv_headers", "csv_delimiter"])
        if not headers:
            raise ValidationError({"error": "No CSV headers found for this ML model."})
        return Response(GetCsvHeaderSerializer({"headers": headers}).data, status=200)

    @extend_schema(request=UploadModelSerializer, responses=None)
    @action(
        detail=False,
        methods=["post"],
        url_path="upload-ml-model",
        url_name="upload-ml-model",
    )
    def upload_model(self, request):
        """Import a serialized surrogate model instead of training from CSV data.

        ``json_data`` may be a JSON object or a JSON-encoded string of one.
        The created model starts at progress 1.

        Raises:
            NotFound: If the simulation object does not exist.
            ValidationError: If the single-model rule is violated, or
                ``json_data`` is not valid JSON / not a JSON object.
        """
        serializer = UploadModelSerializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        validated_data = serializer.validated_data
        json_data = validated_data.get("json_data")
        simulation_object_id = validated_data.get("simulationObject")

        try:
            simulation_object = SimulationObject.objects.get(id=simulation_object_id)
        except SimulationObject.DoesNotExist:
            raise NotFound({"error": "SimulationObject not found."})

        self._validate_single_model_rule(simulation_object)

        if isinstance(json_data, str):
            try:
                json_data = json.loads(json_data)
            except json.JSONDecodeError as exc:
                # Malformed client input is a 400, not an unhandled 500.
                raise ValidationError({"json_data": "json_data is not valid JSON."}) from exc
        # Other endpoints call `.get(...)` on the stored surrogate, so only
        # JSON objects are accepted.
        if not isinstance(json_data, dict):
            raise ValidationError({"json_data": "json_data must be a JSON object."})

        MLModel.objects.create(
            flowsheet_id=simulation_object.flowsheet_id,
            simulationObject=simulation_object,
            surrogate_model=json_data,
            progress=1,
        )

        return Response(
            OperationMessageSerializer({"message": "success"}).data,
            status=200,
        )

    @extend_schema(request=CreateSurrogateModelFromColumnSerializer, responses=None)
    @action(detail=False, methods=["post"], url_path="create-surrogate-model")
    def create_surrogate_model(self, request):
        """Start surrogate-model training once the column mappings are complete.

        Models that already carry a surrogate (e.g. imported ones) skip training;
        they are marked with progress 3 and a success message is returned.
        Otherwise the result of ``train`` is returned.
        """
        serializer = CreateSurrogateModelFromColumnSerializer(
            data=request.data)
        serializer.is_valid(raise_exception=True)
        validated_data = serializer.validated_data
        model = validated_data.get("model")
        model_instance = get_object_or_404(MLModel.objects, id=model)

        # Already have one (imported model): skip training. Truthiness, not
        # `!= {}`, so a NULL surrogate is not mistaken for a trained model.
        if model_instance.surrogate_model:
            model_instance.progress = 3
            model_instance.save(update_fields=["progress"])
            return Response(
                OperationMessageSerializer({"message": "successfully trained"}).data,
                status=200,
            )

        return train(request.user, model_instance)

    @extend_schema(
        parameters=[
            OpenApiParameter(name="model", required=True,
                             type=OpenApiTypes.INT),
        ]
    )
    @action(
        detail=False,
        methods=["get"],
        url_path="export-ml-model",
        url_name="export-ml-model",
    )
    def export_flowsheet(self, request):
        """Export the serialized surrogate model as a downloadable JSON file.

        Raises:
            ValidationError: If ``model`` is missing or not an integer.
            NotFound: If the ML model does not exist.
        """
        model_id = self._parse_int_query_param(
            request.query_params.get("model"),
            name="model",
            missing_message="model parameter is required.",
        )

        try:
            ml_model = MLModel.objects.get(id=model_id)
        except MLModel.DoesNotExist:
            raise NotFound({"error": "MLModel not found."})

        response = HttpResponse(
            json.dumps(ml_model.surrogate_model, indent=4), content_type="application/json"
        )
        response["Content-Disposition"] = 'attachment; filename="model.json"'

        return response

    @extend_schema(
        parameters=[
            OpenApiParameter(name="model", required=True, type=OpenApiTypes.INT),
        ],
        responses=DownloadTestResultsSerializer,
    )
    @action(
        detail=False,
        methods=["get"],
        url_path="download-test-results",
        url_name="download-test-results",
    )
    def download_test_results(self, request):
        """Return a presigned URL for downloading the full ML test-results CSV.

        The URL is requested with a 3600-second expiry.

        Raises:
            ValidationError: If ``model`` is missing/non-integer or the model has
                no stored test results.
            NotFound: If the ML model does not exist.
        """
        model_id = self._parse_int_query_param(
            request.query_params.get("model"),
            name="model",
            missing_message="model parameter is required.",
        )

        try:
            ml_model = MLModel.objects.get(id=model_id)
        except MLModel.DoesNotExist as exc:
            raise NotFound({"error": "MLModel not found."}) from exc

        if not ml_model.test_results_bucket or not ml_model.test_results_key:
            raise ValidationError({"error": "No test results available for this model."})

        url = presign_download_url(
            bucket=ml_model.test_results_bucket,
            key=ml_model.test_results_key,
            filename=f"test-results-{ml_model.id}.csv",
            expires_seconds=3600,
        )
        return Response(DownloadTestResultsSerializer({"url": url}).data, status=200)

    def destroy(self, request, *args, **kwargs):
        """Delete an ML model after removing the custom properties it created.

        Column mappings with ``portIndex == -1`` are custom properties. Their
        ``propertyInfo`` objects are deleted explicitly so they are not left
        orphaned, still referencing the deleted ML model.
        """
        try:
            ml_model: MLModel = self.get_object()
        except MLModel.DoesNotExist:
            raise NotFound({"error": "MLModel not found."})

        # Clean up and delete in one transaction, so a failed model delete does
        # not leave the model alive with its custom properties already removed.
        with transaction.atomic():
            for column_mapping in ml_model.MLColumnMappings.all():
                if column_mapping.portIndex == -1:  # custom property
                    column_mapping.propertyInfo.delete()
            return super().destroy(request, *args, **kwargs)
@extend_schema(exclude=True)
@api_view(["POST"])
@authentication_classes([DaprApiTokenAuthentication])
@permission_classes([IsAuthenticated])
@csrf_exempt
def process_ml_training_event(request) -> Response:
    """Consume a Dapr-delivered ML training completion event.

    A payload that fails validation is dropped without further processing. A
    valid event has its ``data`` handed to ``process_ml_training_response``.
    Either way the endpoint replies 200 with an empty body (presumably so the
    broker does not redeliver the event — confirm against the Dapr config).
    """
    event = _parse_ml_training_completion_event(request.data)
    if event is not None:
        process_ml_training_response(event.data)
    return Response(status=200)