Coverage for backend/django/core/auxiliary/viewsets/MLViewSet.py: 84%

156 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2026-05-13 02:47 +0000

1import json 

2import logging 

3from typing import Any 

4 

5from django.http import HttpResponse 

6from django.shortcuts import get_object_or_404 

7from django.views.decorators.csrf import csrf_exempt 

8from drf_spectacular.utils import OpenApiParameter, OpenApiTypes, extend_schema 

9from pydantic import ValidationError as PydanticValidationError 

10from rest_framework import serializers 

11from rest_framework.decorators import action, api_view, authentication_classes, permission_classes 

12from rest_framework.permissions import IsAuthenticated 

13from authentication.custom_drf_authentication import DaprApiTokenAuthentication 

14from rest_framework.exceptions import NotFound, ValidationError 

15from rest_framework.parsers import JSONParser 

16from rest_framework.response import Response 

17 

18from common.models.idaes.payloads.ml_request_schema import MLTrainingCompletionEvent 

19from core.auxiliary.models.MLModel import MLModel 

20from core.auxiliary.models.MLWizard import train 

21from core.auxiliary.serializers.MLModelSerializer import MLModelSerializer 

22from core.auxiliary.services.object_storage.s3 import presign_download_url 

23from core.auxiliary.services.uploads import attach_upload_to_ml_model, inspect_upload_session 

24from core.viewset import ModelViewSet 

25from flowsheetInternals.unitops.models import SimulationObject 

26from idaes_factory.endpoints import process_ml_training_response 

27 

28 

# Module-level logger scoped to this module's dotted import path.
logger = logging.getLogger(__name__)

30 

31 

def _parse_ml_training_completion_event(data) -> MLTrainingCompletionEvent | None:
    """Validate raw event payload; return the parsed event or None if malformed.

    Malformed payloads are logged (with traceback) and dropped rather than
    raised, so the Dapr delivery can still be acknowledged.
    """
    try:
        event = MLTrainingCompletionEvent.model_validate(data)
    except PydanticValidationError:
        logger.warning(
            "Discarding malformed ML training completion event.",
            exc_info=True,
        )
        return None
    return event

41 

42 

class GetCsvHeaderSerializer(serializers.Serializer):
    """Response body for the get-csv-header action: the CSV column names."""

    headers = serializers.ListField(child=serializers.CharField())

45 

46 

class OperationMessageSerializer(serializers.Serializer):
    """Generic response body carrying a human-readable status message."""

    message = serializers.CharField()

49 

50 

class CreateSurrogateModelFromColumnSerializer(serializers.Serializer):
    """Request body for create-surrogate-model: the MLModel primary key."""

    model = serializers.IntegerField()

53 

54 

class DownloadTestResultsSerializer(serializers.Serializer):
    """Response body for download-test-results: a presigned download URL."""

    url = serializers.URLField()

57 

58 

class UploadModelSerializer(serializers.Serializer):
    """Request body for upload-ml-model: a serialized surrogate model plus
    the id of the SimulationObject it belongs to."""

    json_data = serializers.JSONField()
    simulationObject = serializers.IntegerField()

62 

63 

class UploadSessionSerializer(serializers.Serializer):
    """Request body for create: a completed object-storage upload session
    tied to a SimulationObject."""

    simulationObject = serializers.IntegerField()
    upload_session_id = serializers.IntegerField()

67 

68 

class MLViewSet(ModelViewSet):
    """Manage ML CSV uploads, header lookup, model import, and surrogate training."""

    serializer_class = MLModelSerializer
    parser_classes = [JSONParser]

    def get_queryset(self):
        """Return the ML models belonging to the requested simulation object.

        NOTE(review): when the ``simulationObject`` query param is absent this
        filters on ``None`` and yields an empty queryset, never all models.
        """
        simulation_object_id = self.request.query_params.get(
            "simulationObject", None)
        # .filter() already starts from the full manager; the extra .all() was redundant.
        return MLModel.objects.filter(simulationObject=simulation_object_id)

    @extend_schema(
        parameters=[
            OpenApiParameter(
                name="simulationObject", required=True, type=OpenApiTypes.INT
            ),
        ]
    )
    def list(self, request, *args, **kwargs):
        """List ML models filtered by the ``simulationObject`` query parameter."""
        return super().list(request, *args, **kwargs)

    @extend_schema(
        request=UploadSessionSerializer,
        responses=MLModelSerializer,
    )
    def create(self, request, *args, **kwargs):
        """Create an ML model from a completed object-storage upload session.

        Returns the serialized new MLModel with HTTP 201.
        """
        serializer = UploadSessionSerializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        validated_data = serializer.validated_data
        instance = attach_upload_to_ml_model(
            upload_session_id=validated_data["upload_session_id"],
            simulation_object_id=validated_data["simulationObject"],
            user_id=request.user.id,
        )

        response_serializer = self.get_serializer(instance)
        return Response(response_serializer.data, status=201)

    @extend_schema(
        parameters=[
            OpenApiParameter(name="model", required=True,
                             type=OpenApiTypes.INT),
        ],
        responses=GetCsvHeaderSerializer,
    )
    @action(
        detail=False,
        methods=["get"],
        url_path="get-csv-header",
        url_name="get-csv-header",
    )
    def get_csv_header(self, request):
        """Return the authoritative CSV headers for the selected ML model.

        Imported models answer from their surrogate's labels; CSV-backed
        models answer from stored headers, inspecting the upload lazily once.
        Raises ValidationError (400) when no headers can be determined and
        NotFound (404) for an unknown model id.
        """
        model_id = self.request.query_params.get("model")
        if not model_id:
            raise ValidationError({"error": "model is required."})

        try:
            model = MLModel.objects.get(id=model_id)
        except MLModel.DoesNotExist:
            raise NotFound({"error": "MLModel not found."})

        if model.surrogate_model != {}:
            # Default each label list to [] — a surrogate missing either key
            # previously produced None + None -> TypeError (HTTP 500).
            input_labels = model.surrogate_model.get("input_labels") or []
            output_labels = model.surrogate_model.get("output_labels") or []
            return Response(
                GetCsvHeaderSerializer({"headers": input_labels + output_labels}).data,
                status=200,
            )

        headers = model.csv_headers
        if not headers and model.csv_upload_session_id:
            # Inspect the uploaded CSV once and persist the result so
            # subsequent calls hit the cached columns.
            inspection = inspect_upload_session(model.csv_upload_session)
            headers = inspection.headers
            model.csv_headers = headers
            model.csv_delimiter = inspection.delimiter
            model.save(update_fields=["csv_headers", "csv_delimiter"])
        if not headers:
            raise ValidationError({"error": "No CSV headers found for this ML model."})
        return Response(GetCsvHeaderSerializer({"headers": headers}).data, status=200)

    @extend_schema(request=UploadModelSerializer, responses=None)
    @action(
        detail=False,
        methods=["post"],
        url_path="upload-ml-model",
        url_name="upload-ml-model",
    )
    def upload_model(self, request):
        """Import a serialized surrogate model instead of training from CSV data.

        Accepts the surrogate either as a JSON string or an already-parsed
        object; stores it with progress=1.
        """
        serializer = UploadModelSerializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        validated_data = serializer.validated_data
        json_data = validated_data.get("json_data")
        simulation_object_id = validated_data.get("simulationObject")

        try:
            simulation_object = SimulationObject.objects.get(
                id=simulation_object_id)
        except SimulationObject.DoesNotExist:
            raise NotFound({"error": "SimulationObject not found."})

        if isinstance(json_data, str):
            # Reject malformed JSON with a 400 instead of letting
            # json.JSONDecodeError bubble up as an unhandled 500.
            try:
                json_data = json.loads(json_data)
            except json.JSONDecodeError as exc:
                raise ValidationError({"error": "json_data is not valid JSON."}) from exc

        MLModel.objects.create(
            flowsheet_id=simulation_object.flowsheet_id,
            simulationObject=simulation_object,
            surrogate_model=json_data,
            progress=1,
        )

        return Response(
            OperationMessageSerializer({"message": "success"}).data,
            status=200,
        )

    @extend_schema(request=CreateSurrogateModelFromColumnSerializer, responses=None)
    @action(detail=False, methods=["post"], url_path="create-surrogate-model")
    def create_surrogate_model(self, request):
        """Start surrogate-model training once the column mappings are complete."""
        serializer = CreateSurrogateModelFromColumnSerializer(
            data=request.data)
        serializer.is_valid(raise_exception=True)
        model_id = serializer.validated_data.get("model")
        model_instance = get_object_or_404(MLModel.objects, id=model_id)

        # Imported models already carry a trained surrogate: mark them
        # complete (progress=3) and skip training entirely.
        if model_instance.surrogate_model != {}:
            model_instance.progress = 3
            model_instance.save(update_fields=["progress"])
            return Response(
                OperationMessageSerializer({"message": "successfully trained"}).data,
                status=200,
            )

        return train(request.user, model_instance)

    @extend_schema(
        parameters=[
            OpenApiParameter(name="model", required=True,
                             type=OpenApiTypes.INT),
        ]
    )
    @action(
        detail=False,
        methods=["get"],
        url_path="export-ml-model",
        url_name="export-ml-model",
    )
    def export_flowsheet(self, request):
        """Export the serialized surrogate model as a downloadable JSON file."""
        model_id = request.query_params.get("model")
        if not model_id:
            raise ValidationError({"error": "model_id parameter is required."})

        try:
            ml_model = MLModel.objects.get(id=model_id)
        except MLModel.DoesNotExist:
            raise NotFound({"error": "MLModel not found."})

        response = HttpResponse(
            json.dumps(ml_model.surrogate_model, indent=4),
            content_type="application/json",
        )
        # Plain string: the previous f-string had no placeholders (lint F541).
        response["Content-Disposition"] = 'attachment; filename="model.json"'

        return response

    @extend_schema(
        parameters=[
            OpenApiParameter(name="model", required=True, type=OpenApiTypes.INT),
        ],
        responses=DownloadTestResultsSerializer,
    )
    @action(
        detail=False,
        methods=["get"],
        url_path="download-test-results",
        url_name="download-test-results",
    )
    def download_test_results(self, request):
        """Return a presigned URL for downloading the full ML test-results CSV.

        Raises ValidationError (400) for a missing/non-integer ``model`` param
        or when the model has no stored test results; NotFound (404) for an
        unknown model id.
        """
        raw_model_id = request.query_params.get("model")
        if not raw_model_id:
            raise ValidationError({"error": "model parameter is required."})

        try:
            model_id = int(raw_model_id)
        except (TypeError, ValueError) as exc:
            raise ValidationError({"error": "model parameter must be an integer."}) from exc

        try:
            ml_model = MLModel.objects.get(id=model_id)
        except MLModel.DoesNotExist as exc:
            raise NotFound({"error": "MLModel not found."}) from exc

        if not ml_model.test_results_bucket or not ml_model.test_results_key:
            raise ValidationError({"error": "No test results available for this model."})

        # Presigned URL keeps object-storage credentials server-side; the
        # link expires after one hour.
        url = presign_download_url(
            bucket=ml_model.test_results_bucket,
            key=ml_model.test_results_key,
            filename=f"test-results-{ml_model.id}.csv",
            expires_seconds=3600,
        )
        return Response(DownloadTestResultsSerializer({"url": url}).data, status=200)

275 

276 

@extend_schema(exclude=True)
@api_view(["POST"])
@authentication_classes([DaprApiTokenAuthentication])
@permission_classes([IsAuthenticated])
@csrf_exempt
def process_ml_training_event(request) -> Response:
    """Handle Dapr-delivered ML completion events and update the stored task/model."""
    event = _parse_ml_training_completion_event(request.data)
    if event is not None:
        process_ml_training_response(event.data)
    # Always acknowledge with 200 — malformed events are logged and dropped
    # by the parser so Dapr does not redeliver them.
    return Response(status=200)