Coverage for backend/django/core/auxiliary/viewsets/MLViewSet.py: 84%
156 statements
« prev ^ index » next coverage.py v7.10.7, created at 2026-05-13 02:47 +0000
« prev ^ index » next coverage.py v7.10.7, created at 2026-05-13 02:47 +0000
1import json
2import logging
3from typing import Any
5from django.http import HttpResponse
6from django.shortcuts import get_object_or_404
7from django.views.decorators.csrf import csrf_exempt
8from drf_spectacular.utils import OpenApiParameter, OpenApiTypes, extend_schema
9from pydantic import ValidationError as PydanticValidationError
10from rest_framework import serializers
11from rest_framework.decorators import action, api_view, authentication_classes, permission_classes
12from rest_framework.permissions import IsAuthenticated
13from authentication.custom_drf_authentication import DaprApiTokenAuthentication
14from rest_framework.exceptions import NotFound, ValidationError
15from rest_framework.parsers import JSONParser
16from rest_framework.response import Response
18from common.models.idaes.payloads.ml_request_schema import MLTrainingCompletionEvent
19from core.auxiliary.models.MLModel import MLModel
20from core.auxiliary.models.MLWizard import train
21from core.auxiliary.serializers.MLModelSerializer import MLModelSerializer
22from core.auxiliary.services.object_storage.s3 import presign_download_url
23from core.auxiliary.services.uploads import attach_upload_to_ml_model, inspect_upload_session
24from core.viewset import ModelViewSet
25from flowsheetInternals.unitops.models import SimulationObject
26from idaes_factory.endpoints import process_ml_training_response
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)


def _parse_ml_training_completion_event(data) -> MLTrainingCompletionEvent | None:
    """Validate *data* as an ``MLTrainingCompletionEvent``.

    Returns the validated event on success. When pydantic rejects the
    payload, a warning (including the traceback) is logged and ``None`` is
    returned instead of propagating the error.
    """
    try:
        event = MLTrainingCompletionEvent.model_validate(data)
    except PydanticValidationError:
        # Best-effort parsing: the caller decides what to do with ``None``.
        logger.warning(
            "Discarding malformed ML training completion event.",
            exc_info=True,
        )
        return None
    return event
class GetCsvHeaderSerializer(serializers.Serializer):
    # Response shape of the ``get-csv-header`` action: a list of column
    # names, each serialized as a string.
    headers = serializers.ListField(
        child=serializers.CharField(),
    )
class OperationMessageSerializer(serializers.Serializer):
    # Minimal ``{"message": ...}`` body returned by the upload-ml-model and
    # create-surrogate-model actions.
    message = serializers.CharField()
class CreateSurrogateModelFromColumnSerializer(serializers.Serializer):
    # Request body for the create-surrogate-model action; ``model`` is the
    # primary key of the MLModel to train.
    model = serializers.IntegerField()
class DownloadTestResultsSerializer(serializers.Serializer):
    # Response body of the download-test-results action; ``url`` carries the
    # presigned download link.
    url = serializers.URLField()
class UploadModelSerializer(serializers.Serializer):
    # Request body for the upload-ml-model action.
    # ``json_data``: the serialized surrogate model; the view accepts either
    # an already-decoded JSON value or a JSON-encoded string.
    json_data = serializers.JSONField()
    # ``simulationObject``: primary key of the SimulationObject that will own
    # the imported model.
    simulationObject = serializers.IntegerField()
class UploadSessionSerializer(serializers.Serializer):
    # Request body for ``MLViewSet.create``.
    # ``simulationObject``: id forwarded as ``simulation_object_id``.
    simulationObject = serializers.IntegerField()
    # ``upload_session_id``: id of the upload session to attach to the new
    # ML model.
    upload_session_id = serializers.IntegerField()
class MLViewSet(ModelViewSet):
    """Manage ML CSV uploads, header lookup, model import, and surrogate training."""

    serializer_class = MLModelSerializer
    parser_classes = [JSONParser]

    @staticmethod
    def _load_ml_model_from_param(raw_model_id) -> MLModel:
        """Resolve a raw ``model`` query-parameter value to an ``MLModel``.

        Shared by the actions that take ``?model=<id>`` so that they all
        reject malformed ids the same way instead of letting the ORM raise
        ``ValueError`` (which would surface as a 500).

        Raises:
            ValidationError: if ``raw_model_id`` cannot be parsed as an integer.
            NotFound: if no ``MLModel`` with that id exists.
        """
        try:
            model_id = int(raw_model_id)
        except (TypeError, ValueError) as exc:
            raise ValidationError({"error": "model parameter must be an integer."}) from exc

        try:
            return MLModel.objects.get(id=model_id)
        except MLModel.DoesNotExist as exc:
            raise NotFound({"error": "MLModel not found."}) from exc

    def get_queryset(self):
        # The queryset is always scoped by the ``simulationObject`` query
        # parameter; when the parameter is absent the filter value is None.
        simulation_object_id = self.request.query_params.get("simulationObject", None)
        return MLModel.objects.all().filter(simulationObject=simulation_object_id)

    # Overridden only to attach the ``simulationObject`` query parameter to
    # the generated schema; the behaviour is the inherited ``list``.
    @extend_schema(
        parameters=[
            OpenApiParameter(
                name="simulationObject", required=True, type=OpenApiTypes.INT
            ),
        ]
    )
    def list(self, request, *args, **kwargs):
        return super().list(request, *args, **kwargs)

    @extend_schema(
        request=UploadSessionSerializer,
        responses=MLModelSerializer,
    )
    def create(self, request, *args, **kwargs):
        """Create an ML model from a completed object-storage upload session."""
        serializer = UploadSessionSerializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        validated_data = serializer.validated_data
        instance = attach_upload_to_ml_model(
            upload_session_id=validated_data["upload_session_id"],
            simulation_object_id=validated_data["simulationObject"],
            user_id=request.user.id,
        )

        response_serializer = self.get_serializer(instance)
        return Response(response_serializer.data, status=201)

    @extend_schema(
        parameters=[
            OpenApiParameter(name="model", required=True,
                             type=OpenApiTypes.INT),
        ],
        responses=GetCsvHeaderSerializer,
    )
    @action(
        detail=False,
        methods=["get"],
        url_path="get-csv-header",
        url_name="get-csv-header",
    )
    def get_csv_header(self, request):
        """Return the authoritative CSV headers for the selected ML model."""
        raw_model_id = request.query_params.get("model")
        if not raw_model_id:
            raise ValidationError({"error": "model is required."})
        ml_model = self._load_ml_model_from_param(raw_model_id)

        # A model that already carries a surrogate model reports that
        # model's input/output labels as its headers.
        if ml_model.surrogate_model != {}:
            # Either key may be missing or null in an imported model; treat
            # that as "no labels" rather than failing on ``None + list``.
            input_labels = ml_model.surrogate_model.get("input_labels") or []
            output_labels = ml_model.surrogate_model.get("output_labels") or []
            return Response(
                GetCsvHeaderSerializer({"headers": input_labels + output_labels}).data,
                status=200,
            )

        headers = ml_model.csv_headers
        if not headers and ml_model.csv_upload_session_id:
            # Headers are not stored yet: inspect the upload session and
            # persist the result on the model for subsequent requests.
            inspection = inspect_upload_session(ml_model.csv_upload_session)
            headers = inspection.headers
            ml_model.csv_headers = headers
            ml_model.csv_delimiter = inspection.delimiter
            ml_model.save(update_fields=["csv_headers", "csv_delimiter"])
        if not headers:
            raise ValidationError({"error": "No CSV headers found for this ML model."})
        return Response(GetCsvHeaderSerializer({"headers": headers}).data, status=200)

    @extend_schema(request=UploadModelSerializer, responses=None)
    @action(
        detail=False,
        methods=["post"],
        url_path="upload-ml-model",
        url_name="upload-ml-model",
    )
    def upload_model(self, request):
        """Import a serialized surrogate model instead of training from CSV data."""
        serializer = UploadModelSerializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        validated_data = serializer.validated_data
        json_data = validated_data.get("json_data")
        simulation_object_id = validated_data.get("simulationObject")

        try:
            simulation_object = SimulationObject.objects.get(id=simulation_object_id)
        except SimulationObject.DoesNotExist as exc:
            raise NotFound({"error": "SimulationObject not found."}) from exc

        # The payload may arrive either as a decoded JSON value or as a
        # JSON-encoded string; a string that is not valid JSON is a client
        # error, not a server error.
        if isinstance(json_data, str):
            try:
                surrogate_model = json.loads(json_data)
            except json.JSONDecodeError as exc:
                raise ValidationError({"error": "json_data is not valid JSON."}) from exc
        else:
            surrogate_model = json_data

        MLModel.objects.create(
            flowsheet_id=simulation_object.flowsheet_id,
            simulationObject=simulation_object,
            surrogate_model=surrogate_model,
            progress=1,
        )

        return Response(
            OperationMessageSerializer({"message": "success"}).data,
            status=200,
        )

    @extend_schema(request=CreateSurrogateModelFromColumnSerializer, responses=None)
    @action(detail=False, methods=["post"], url_path="create-surrogate-model")
    def create_surrogate_model(self, request):
        """Start surrogate-model training once the column mappings are complete."""
        serializer = CreateSurrogateModelFromColumnSerializer(
            data=request.data)
        serializer.is_valid(raise_exception=True)
        validated_data = serializer.validated_data
        model = validated_data.get("model")
        model_instance = get_object_or_404(MLModel.objects, id=model)

        # if already have one (imported model), skip training
        if model_instance.surrogate_model != {}:
            model_instance.progress = 3
            model_instance.save(update_fields=["progress"])
            return Response(
                OperationMessageSerializer({"message": "successfully trained"}).data,
                status=200,
            )

        # ``train`` produces the response that is returned to the client.
        return train(request.user, model_instance)

    @extend_schema(
        parameters=[
            OpenApiParameter(name="model", required=True,
                             type=OpenApiTypes.INT),
        ]
    )
    @action(
        detail=False,
        methods=["get"],
        url_path="export-ml-model",
        url_name="export-ml-model",
    )
    def export_flowsheet(self, request):
        """Export the serialized surrogate model as a downloadable JSON file."""
        raw_model_id = request.query_params.get("model")
        if not raw_model_id:
            # The query parameter is called ``model``; name it correctly.
            raise ValidationError({"error": "model parameter is required."})
        ml_model = self._load_ml_model_from_param(raw_model_id)

        data = ml_model.surrogate_model
        response = HttpResponse(
            json.dumps(data, indent=4), content_type="application/json"
        )
        # Served as an attachment under a fixed file name.
        response["Content-Disposition"] = 'attachment; filename="model.json"'

        return response

    @extend_schema(
        parameters=[
            OpenApiParameter(name="model", required=True, type=OpenApiTypes.INT),
        ],
        responses=DownloadTestResultsSerializer,
    )
    @action(
        detail=False,
        methods=["get"],
        url_path="download-test-results",
        url_name="download-test-results",
    )
    def download_test_results(self, request):
        """Return a presigned URL for downloading the full ML test-results CSV."""
        raw_model_id = request.query_params.get("model")
        if not raw_model_id:
            raise ValidationError({"error": "model parameter is required."})
        ml_model = self._load_ml_model_from_param(raw_model_id)

        # Both the bucket and the key must be recorded on the model before a
        # download link can be issued.
        if not ml_model.test_results_bucket or not ml_model.test_results_key:
            raise ValidationError({"error": "No test results available for this model."})

        url = presign_download_url(
            bucket=ml_model.test_results_bucket,
            key=ml_model.test_results_key,
            filename=f"test-results-{ml_model.id}.csv",
            expires_seconds=3600,
        )
        return Response(DownloadTestResultsSerializer({"url": url}).data, status=200)
@extend_schema(exclude=True)
@api_view(["POST"])
@authentication_classes([DaprApiTokenAuthentication])
@permission_classes([IsAuthenticated])
@csrf_exempt
def process_ml_training_event(request) -> Response:
    """Receive a Dapr-delivered ML training completion event.

    The request body is validated as an ``MLTrainingCompletionEvent``; when
    valid, its ``data`` is handed to ``process_ml_training_response``. The
    endpoint answers with an empty 200 response whether or not the payload
    was valid.
    """
    event = _parse_ml_training_completion_event(request.data)
    if event is not None:
        process_ml_training_response(event.data)

    # Malformed events are acknowledged too (they were already logged by the
    # parser) — presumably so the sender does not redeliver them; confirm.
    return Response(status=200)