Coverage for backend/django/core/auxiliary/viewsets/MLViewSet.py: 81%

214 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2026-06-23 21:51 +0000

1import json 

2import logging 

3from typing import Any 

4 

5from django.http import HttpResponse 

6from django.shortcuts import get_object_or_404 

7from django.views.decorators.csrf import csrf_exempt 

8from django.db import transaction 

9from drf_spectacular.utils import OpenApiParameter, OpenApiTypes, extend_schema 

10from pydantic import ValidationError as PydanticValidationError 

11from rest_framework import serializers 

12from rest_framework.decorators import action, api_view, authentication_classes, permission_classes 

13from authentication.custom_drf_authentication import DaprApiTokenAuthentication 

14from rest_framework.exceptions import NotFound, ValidationError 

15from rest_framework.parsers import JSONParser 

16from rest_framework.permissions import IsAuthenticated 

17from rest_framework.response import Response 

18 

19from common.models.idaes.payloads.ml_request_schema import MLTrainingCompletionEvent 

20from core.auxiliary.models.MLModel import MLModel 

21from core.auxiliary.models.MLWizard import train 

22from core.auxiliary.serializers.MLModelSerializer import MLModelSerializer 

23from core.auxiliary.services.object_storage.s3 import presign_download_url 

24from core.auxiliary.services.uploads import attach_upload_to_ml_model, inspect_upload_session 

25from core.viewset import ModelViewSet 

26from flowsheetInternals.unitops.models import SimulationObject 

27from idaes_factory.endpoints import process_ml_training_response 

28 

29 

30logger = logging.getLogger(__name__) 

31 

32 

def _parse_ml_training_completion_event(data) -> MLTrainingCompletionEvent | None:
    """Validate a raw completion-event payload.

    Returns the validated ``MLTrainingCompletionEvent``; when pydantic rejects
    the payload, a warning (with traceback) is logged and ``None`` is returned
    instead of propagating the error.
    """
    try:
        event = MLTrainingCompletionEvent.model_validate(data)
    except PydanticValidationError:
        logger.warning(
            "Discarding malformed ML training completion event.",
            exc_info=True,
        )
        return None
    return event

42 

43 

class GetCsvHeaderSerializer(serializers.Serializer):
    # Response shape of the "get-csv-header" action: a list of column-name strings.
    headers = serializers.ListField(child=serializers.CharField())

46 

47 

class OperationMessageSerializer(serializers.Serializer):
    # Generic single-field response used by actions that only report a status message.
    message = serializers.CharField()

50 

51 

class CreateSurrogateModelFromColumnSerializer(serializers.Serializer):
    # Request body of the "create-surrogate-model" action; `model` is the id
    # used to look up the MLModel to train.
    model = serializers.IntegerField()

54 

55 

class DownloadTestResultsSerializer(serializers.Serializer):
    # Response shape of the "download-test-results" action: the presigned download URL.
    url = serializers.URLField()

58 

59 

class UploadModelSerializer(serializers.Serializer):
    # Request body of the "upload-ml-model" action.
    # `json_data` carries the serialized surrogate model; `simulationObject`
    # is the id of the SimulationObject the new MLModel is attached to.
    json_data = serializers.JSONField()
    simulationObject = serializers.IntegerField()

63 

64 

class CreateMLModelSerializer(serializers.Serializer):
    # Request body for creating an ML model.
    # `surrogate_model` is optional and falls back to an empty dict when omitted.
    simulationObject = serializers.IntegerField()
    surrogate_model = serializers.JSONField(required=False, default=dict)

68 

69 

class UpdateMLModelSerializer(serializers.Serializer):
    # Request body for updating an ML model; every field is optional.
    displayName = serializers.CharField(required=False, max_length=64)
    surrogate_model = serializers.JSONField(required=False)
    progress = serializers.IntegerField(required=False)
    csv_upload_session = serializers.IntegerField(required=False)

    def validate(self, attrs):
        """Reject payloads that set both ``csv_upload_session`` and ``surrogate_model``."""
        mutually_exclusive = ("csv_upload_session", "surrogate_model")
        if all(field_name in attrs for field_name in mutually_exclusive):
            raise serializers.ValidationError(
                {
                    "non_field_errors": [
                        "csv_upload_session and surrogate_model cannot be updated in the same request."
                    ]
                }
            )
        return attrs

86 

87 

class MLViewSet(ModelViewSet):
    """Manage ML CSV uploads, header lookup, model import, and surrogate training."""

    serializer_class = MLModelSerializer
    parser_classes = [JSONParser]

    def get_queryset(self):
        """Return the unfiltered ``MLModel`` queryset used for detail lookups."""
        return MLModel.objects.all()

    @staticmethod
    def _parse_int_query_param(raw_value: Any, *, name: str) -> int:
        """Convert a raw query-string value to ``int``.

        Args:
            raw_value: The value read from ``request.query_params``.
            name: Parameter name, used only to build the error message.

        Raises:
            ValidationError: If ``raw_value`` is not an integer literal. Raising
                here keeps a malformed id from reaching the ORM, where it would
                surface as an unhandled ``ValueError``.
        """
        try:
            return int(raw_value)
        except (TypeError, ValueError) as exc:
            raise ValidationError(
                {"error": f"{name} parameter must be an integer."}
            ) from exc

    @staticmethod
    def _validate_single_model_rule(simulation_object: SimulationObject) -> None:
        """Enforce at most one ML model per ``machineLearningBlock`` object.

        Objects of any other ``objectType`` are not restricted.

        Raises:
            ValidationError: If the block already has an ``MLModel``.
        """
        if simulation_object.objectType != "machineLearningBlock":
            return
        if MLModel.objects.filter(simulationObject=simulation_object).exists():
            raise ValidationError(
                {"simulationObject": "machineLearningBlock can only have one ML model at a time."}
            )

    def _update_ml_model(self, request, *, partial: bool) -> Response:
        """Handle ML-model edits while preserving upload-attach side effects."""
        instance = self.get_object()
        request_serializer = UpdateMLModelSerializer(data=request.data, partial=partial)
        request_serializer.is_valid(raise_exception=True)
        validated_data = dict(request_serializer.validated_data)

        # Attaching the upload and applying the remaining field edits succeed
        # or fail together.
        with transaction.atomic():
            upload_session_id = validated_data.pop("csv_upload_session", None)
            if upload_session_id is not None:
                instance = attach_upload_to_ml_model(
                    instance=instance,
                    upload_session_id=upload_session_id,
                    simulation_object_id=instance.simulationObject_id,
                    user_id=request.user.id,
                )

                # CSV attachment always derives progress from the upload workflow.
                validated_data.pop("progress", None)

            if validated_data:
                serializer = self.get_serializer(
                    instance,
                    data=validated_data,
                    partial=partial,
                )
                serializer.is_valid(raise_exception=True)
                self.perform_update(serializer)
                instance = serializer.instance

        instance.refresh_from_db()
        return Response(self.get_serializer(instance).data, status=200)

    @extend_schema(
        request=UpdateMLModelSerializer,
        responses=MLModelSerializer,
    )
    def update(self, request, *args, **kwargs) -> Response:
        return self._update_ml_model(request, partial=False)

    @extend_schema(
        request=UpdateMLModelSerializer,
        responses=MLModelSerializer,
    )
    def partial_update(self, request, *args, **kwargs) -> Response:
        return self._update_ml_model(request, partial=True)

    @extend_schema(
        parameters=[
            OpenApiParameter(
                name="simulationObject", required=True, type=OpenApiTypes.INT
            ),
        ]
    )
    def list(self, request, *args, **kwargs):
        raw_simulation_object = self.request.query_params.get("simulationObject", None)
        # A missing parameter keeps the previous behaviour (filter on None);
        # a present but non-integer value is rejected with a 400 instead of
        # raising ValueError inside the ORM.
        simulation_object_id = (
            None
            if raw_simulation_object is None
            else self._parse_int_query_param(raw_simulation_object, name="simulationObject")
        )
        queryset = MLModel.objects.filter(simulationObject=simulation_object_id)
        serializer = self.get_serializer(queryset, many=True)
        return Response(serializer.data, status=200)

    @extend_schema(
        request=CreateMLModelSerializer,
        responses=MLModelSerializer,
    )
    def create(self, request, *args, **kwargs):
        """Create an ML model with simulationObject and optional surrogate_model."""
        serializer = CreateMLModelSerializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        validated_data = serializer.validated_data

        simulation_object_id = validated_data["simulationObject"]
        surrogate_model = validated_data.get("surrogate_model", {})

        try:
            simulation_object = SimulationObject.objects.get(id=simulation_object_id)
        except SimulationObject.DoesNotExist as exc:
            raise NotFound({"error": "SimulationObject not found."}) from exc

        self._validate_single_model_rule(simulation_object)

        instance = MLModel.objects.create(
            flowsheet_id=simulation_object.flowsheet_id,
            simulationObject=simulation_object,
            surrogate_model=surrogate_model,
            # A model created with a surrogate already supplied starts at progress 1.
            progress=1 if surrogate_model else 0,
        )

        response_serializer = self.get_serializer(instance)
        return Response(response_serializer.data, status=201)

    @extend_schema(
        parameters=[
            OpenApiParameter(name="model", required=True,
                             type=OpenApiTypes.INT),
        ],
        responses=GetCsvHeaderSerializer,
    )
    @action(
        detail=False,
        methods=["get"],
        url_path="get-csv-header",
        url_name="get-csv-header",
    )
    def get_csv_header(self, request):
        """Return the authoritative CSV headers for the selected ML model."""
        raw_model_id = self.request.query_params.get("model")
        if not raw_model_id:
            raise ValidationError({"error": "model is required."})
        model_id = self._parse_int_query_param(raw_model_id, name="model")

        try:
            model = MLModel.objects.get(id=model_id)
        except MLModel.DoesNotExist as exc:
            raise NotFound({"error": "MLModel not found."}) from exc

        if model.surrogate_model != {}:
            # A stored surrogate defines the headers through its labels. Either
            # label list may be absent, so fall back to an empty list rather
            # than concatenating None.
            input_labels = model.surrogate_model.get("input_labels") or []
            output_labels = model.surrogate_model.get("output_labels") or []
            return Response(
                GetCsvHeaderSerializer({"headers": input_labels + output_labels}).data,
                status=200,
            )

        headers = model.csv_headers
        if not headers and model.csv_upload_session_id:
            # No headers stored yet: inspect the attached upload once and
            # persist the result on the model.
            inspection = inspect_upload_session(model.csv_upload_session)
            headers = inspection.headers
            model.csv_headers = headers
            model.csv_delimiter = inspection.delimiter
            model.save(update_fields=["csv_headers", "csv_delimiter"])
        if not headers:
            raise ValidationError({"error": "No CSV headers found for this ML model."})
        return Response(GetCsvHeaderSerializer({"headers": headers}).data, status=200)

    @extend_schema(request=UploadModelSerializer, responses=None)
    @action(
        detail=False,
        methods=["post"],
        url_path="upload-ml-model",
        url_name="upload-ml-model",
    )
    def upload_model(self, request):
        """Import a serialized surrogate model instead of training from CSV data."""
        serializer = UploadModelSerializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        validated_data = serializer.validated_data
        json_data = validated_data.get("json_data")
        simulation_object_id = validated_data.get("simulationObject")

        try:
            simulation_object = SimulationObject.objects.get(id=simulation_object_id)
        except SimulationObject.DoesNotExist as exc:
            raise NotFound({"error": "SimulationObject not found."}) from exc

        self._validate_single_model_rule(simulation_object)

        if isinstance(json_data, str):
            # The payload may arrive as a JSON-encoded string; decode it and
            # report malformed content as a validation error instead of
            # letting JSONDecodeError escape.
            try:
                surrogate_model = json.loads(json_data)
            except json.JSONDecodeError as exc:
                raise ValidationError(
                    {"json_data": "json_data is not valid JSON."}
                ) from exc
        else:
            surrogate_model = json_data

        MLModel.objects.create(
            flowsheet_id=simulation_object.flowsheet_id,
            simulationObject=simulation_object,
            surrogate_model=surrogate_model,
            progress=1,
        )

        return Response(
            OperationMessageSerializer({"message": "success"}).data,
            status=200,
        )

    @extend_schema(request=CreateSurrogateModelFromColumnSerializer, responses=None)
    @action(detail=False, methods=["post"], url_path="create-surrogate-model")
    def create_surrogate_model(self, request):
        """Start surrogate-model training once the column mappings are complete."""
        serializer = CreateSurrogateModelFromColumnSerializer(
            data=request.data)
        serializer.is_valid(raise_exception=True)
        validated_data = serializer.validated_data
        model_id = validated_data.get("model")
        model_instance = get_object_or_404(MLModel.objects, id=model_id)

        # if already have one (imported model), skip training
        if model_instance.surrogate_model != {}:
            model_instance.progress = 3
            model_instance.save(update_fields=["progress"])
            return Response(
                OperationMessageSerializer({"message": "successfully trained"}).data,
                status=200,
            )

        return train(request.user, model_instance)

    @extend_schema(
        parameters=[
            OpenApiParameter(name="model", required=True,
                             type=OpenApiTypes.INT),
        ]
    )
    @action(
        detail=False,
        methods=["get"],
        url_path="export-ml-model",
        url_name="export-ml-model",
    )
    def export_flowsheet(self, request):
        """Export the serialized surrogate model as a downloadable JSON file."""
        raw_model_id = request.query_params.get("model")
        if not raw_model_id:
            # The query parameter is named "model"; the message now says so.
            raise ValidationError({"error": "model parameter is required."})
        model_id = self._parse_int_query_param(raw_model_id, name="model")

        try:
            ml_model = MLModel.objects.get(id=model_id)
        except MLModel.DoesNotExist as exc:
            raise NotFound({"error": "MLModel not found."}) from exc

        response = HttpResponse(
            json.dumps(ml_model.surrogate_model, indent=4),
            content_type="application/json",
        )
        response["Content-Disposition"] = 'attachment; filename="model.json"'

        return response

    @extend_schema(
        parameters=[
            OpenApiParameter(name="model", required=True, type=OpenApiTypes.INT),
        ],
        responses=DownloadTestResultsSerializer,
    )
    @action(
        detail=False,
        methods=["get"],
        url_path="download-test-results",
        url_name="download-test-results",
    )
    def download_test_results(self, request):
        """Return a presigned URL for downloading the full ML test-results CSV."""
        raw_model_id = request.query_params.get("model")
        if not raw_model_id:
            raise ValidationError({"error": "model parameter is required."})
        model_id = self._parse_int_query_param(raw_model_id, name="model")

        try:
            ml_model = MLModel.objects.get(id=model_id)
        except MLModel.DoesNotExist as exc:
            raise NotFound({"error": "MLModel not found."}) from exc

        if not ml_model.test_results_bucket or not ml_model.test_results_key:
            raise ValidationError({"error": "No test results available for this model."})

        url = presign_download_url(
            bucket=ml_model.test_results_bucket,
            key=ml_model.test_results_key,
            filename=f"test-results-{ml_model.id}.csv",
            expires_seconds=3600,
        )
        return Response(DownloadTestResultsSerializer({"url": url}).data, status=200)

    def destroy(self, request, *args, **kwargs):
        """Delete the ML model together with the custom properties it created."""
        try:
            ml_model: MLModel = self.get_object()
        except MLModel.DoesNotExist as exc:
            raise NotFound({"error": "MLModel not found."}) from exc

        # Delete custom properties associated with the ML model so none are
        # left orphaned. The cleanup and the model deletion share one
        # transaction: if the delete fails, the properties are restored.
        with transaction.atomic():
            for column_mapping in ml_model.MLColumnMappings.all():
                if column_mapping.portIndex != -1:
                    continue
                # portIndex == -1 marks a custom property.
                property_info = column_mapping.propertyInfo
                # NOTE(review): guard assumes propertyInfo may be null - confirm against the model.
                if property_info is not None:
                    property_info.delete()
            return super().destroy(request, *args, **kwargs)

381 

@extend_schema(exclude=True)
@api_view(["POST"])
@authentication_classes([DaprApiTokenAuthentication])
@permission_classes([IsAuthenticated])
@csrf_exempt
def process_ml_training_event(request) -> Response:
    """Handle Dapr-delivered ML completion events and update the stored task/model."""
    event = _parse_ml_training_completion_event(request.data)
    if event is not None:
        process_ml_training_response(event.data)
    # A payload that failed validation is still answered with 200; the parser
    # has already logged and discarded it.
    # NOTE(review): presumably this stops the sender retrying a bad event - confirm.
    return Response(status=200)