Coverage for backend/core/auxiliary/viewsets/MLViewSet.py: 84%
102 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-11-06 23:27 +0000
« prev ^ index » next coverage.py v7.10.7, created at 2025-11-06 23:27 +0000
1from core.viewset import ModelViewSet
2from core.auxiliary.models.MLModel import MLModel
3from core.auxiliary.serializers.MLModelSerializer import MLModelSerializer
4from rest_framework.response import Response
5from drf_spectacular.utils import extend_schema, OpenApiParameter, OpenApiTypes
6from rest_framework.exceptions import NotFound, ValidationError
7from rest_framework.decorators import action
8import json
9from rest_framework import serializers
10from core.auxiliary.models.MLWizard import train
11from flowsheetInternals.unitops.models import SimulationObject
12from django.http import HttpResponse
14from django.views.decorators.csrf import csrf_exempt
15from authentication.custom_drf_authentication import DaprApiTokenAuthentication
16from common.models.idaes.payloads.ml_request_schema import MLTrainingCompletionEvent
17from rest_framework.decorators import api_view, authentication_classes
18from idaes_factory.endpoints import process_ml_training_response
class GetCsvHeaderSerializer(serializers.Serializer):
    """Response schema of the ``get-csv-header`` action: a list of column headers."""

    headers = serializers.ListField(
        child=serializers.CharField(),
    )
class CreateSurrogateModelFromColumnSerializer(serializers.Serializer):
    """Request body of the ``create-surrogate-model`` action."""

    # Id of the MLModel to train.
    model = serializers.IntegerField()
class UploadModelSerializer(serializers.Serializer):
    """Request body of the ``upload-ml-model`` action."""

    # Serialized surrogate model to attach to the new MLModel.
    json_data = serializers.JSONField()
    # Id of the SimulationObject the uploaded model belongs to.
    simulationObject = serializers.IntegerField()
class MLViewSet(ModelViewSet):
    """Model-viewset routes for ``MLModel`` plus surrogate-model helper actions.

    Extra actions: read the column headers of a model, upload a surrogate as
    JSON, start training (or mark an imported surrogate as trained), and
    export a surrogate as a downloadable JSON file.
    """

    serializer_class = MLModelSerializer

    def get_queryset(self):
        """Return the MLModels attached to the ``simulationObject`` query parameter."""
        # NOTE(review): when the parameter is absent this filters on None
        # instead of rejecting the request; confirm that is intended.
        simulationObjectId = self.request.query_params.get(
            'simulationObject', None)
        return MLModel.objects.all().filter(simulationObject=simulationObjectId)

    @extend_schema(
        parameters=[
            OpenApiParameter(name="simulationObject",
                             required=True, type=OpenApiTypes.INT),
        ]
    )
    def list(self, request, *args, **kwargs):
        """List the MLModels of one simulation object (see ``get_queryset``)."""
        return super().list(request, *args, **kwargs)

    def create(self, request, *args, **kwargs):
        """Create an MLModel, storing ``csv_data`` as a JSON-encoded string."""
        # NOTE(review): this mutates request.data in place, which assumes a
        # mutable payload (e.g. a JSON body) — confirm for form submissions.
        request.data["csv_data"] = json.dumps(request.data.get("csv_data"))
        return super().create(request, *args, **kwargs)

    @staticmethod
    def _get_ml_model(model_id):
        """Fetch an MLModel by id, translating lookup failures into DRF errors.

        Raises:
            NotFound: no MLModel has this id.
            ValidationError: ``model_id`` is not a valid integer id.
        """
        try:
            return MLModel.objects.get(id=model_id)
        except MLModel.DoesNotExist:
            raise NotFound({"error": "MLModel not found."}) from None
        except (TypeError, ValueError):
            # Django raises these when the id cannot be coerced to an integer.
            raise ValidationError({"error": "model must be an integer."}) from None

    @extend_schema(
        parameters=[
            OpenApiParameter(name="model", required=True,
                             type=OpenApiTypes.INT),
        ],
        responses=GetCsvHeaderSerializer
    )
    @action(detail=False, methods=['get'], url_path='get-csv-header', url_name='get-csv-header')
    def get_csv_header(self, request):
        """Return the column headers of an MLModel.

        When the model carries a surrogate, the headers are its input labels
        followed by its output labels. Otherwise they are the non-empty keys of
        the first row of the stored ``csv_data``. An empty dataset yields an
        empty header list.

        Raises:
            ValidationError: ``model`` is missing or not an integer.
            NotFound: no MLModel has that id.
        """
        model_id = request.query_params.get("model")
        if not model_id:
            raise ValidationError({"error": "model is required."})

        model = self._get_ml_model(model_id)

        surrogate = model.surrogate_model
        if surrogate:
            # Missing label lists count as empty instead of failing on None + list.
            input_labels = surrogate.get("input_labels") or []
            output_labels = surrogate.get("output_labels") or []
            return Response({"headers": input_labels + output_labels}, status=200)

        # csv_data holds a JSON string (see create); presumably a list of row
        # dicts — TODO confirm. "null" or empty data yields no headers.
        rows = json.loads(model.csv_data) if model.csv_data else None
        if not rows:
            return Response({"headers": []}, status=200)
        first_row = [key for key in rows[0].keys() if key]
        return Response({"headers": first_row}, status=200)

    @extend_schema(
        request=UploadModelSerializer,
        responses=None
    )
    @action(detail=False, methods=['post'], url_path='upload-ml-model', url_name='upload-ml-model')
    def upload_model(self, request):
        """Create an MLModel from an uploaded surrogate for a simulation object.

        ``json_data`` may be a JSON-encoded string or an already-decoded JSON
        value. The optional ``flowsheet`` query parameter is stored as the
        model's flowsheet id.

        Raises:
            ValidationError: invalid body, or ``json_data`` is an undecodable string.
            NotFound: the SimulationObject does not exist.
        """
        serializer = UploadModelSerializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        validated_data = serializer.validated_data
        json_data = validated_data.get("json_data")
        simulation_object_id = validated_data.get("simulationObject")
        flowsheet = request.query_params.get("flowsheet")

        try:
            simulation_object = SimulationObject.objects.get(
                id=simulation_object_id)
        except SimulationObject.DoesNotExist:
            raise NotFound({"error": "SimulationObject not found."}) from None

        # JSONField already decodes JSON bodies, so a string here means the
        # client sent the model double-encoded; anything else is used as-is.
        if isinstance(json_data, str):
            try:
                surrogate_model = json.loads(json_data)
            except json.JSONDecodeError:
                raise ValidationError(
                    {"error": "json_data is not valid JSON."}) from None
        else:
            surrogate_model = json_data

        MLModel.objects.create(
            simulationObject=simulation_object,
            surrogate_model=surrogate_model,
            progress=1,
            flowsheet_id=flowsheet
        )

        return Response({"message": "success"}, status=200)

    @extend_schema(request=CreateSurrogateModelFromColumnSerializer, responses=None)
    @action(detail=False, methods=['post'], url_path='create-surrogate-model')
    def create_surrogate_model(self, request):
        """Train a surrogate for an MLModel, or skip training if it already has one.

        Raises:
            ValidationError: the request body is invalid.
            NotFound: no MLModel has the given id.
        """
        serializer = CreateSurrogateModelFromColumnSerializer(
            data=request.data)
        serializer.is_valid(raise_exception=True)
        model_instance = self._get_ml_model(serializer.validated_data.get("model"))

        # An imported model already carries a surrogate, so skip training.
        # Only this row is updated: an unfiltered MLModel.objects.update()
        # would overwrite the progress of every model in the database.
        # progress=3 presumably means "trained" — confirm against the frontend.
        if model_instance.surrogate_model:
            MLModel.objects.filter(pk=model_instance.pk).update(progress=3)
            return Response({"message": "successfully trained"}, status=200)

        return train(request.user, model_instance)

    @extend_schema(parameters=[
        OpenApiParameter(name="model", required=True,
                         type=OpenApiTypes.INT),
    ])
    @action(detail=False, methods=['get'], url_path='export-ml-model', url_name='export-ml-model')
    def export_flowsheet(self, request):
        """Download an MLModel's surrogate as an attachment named ``model.json``.

        Returns a 400 response when ``model`` is missing or not an integer, and
        a 404 response when no MLModel has that id.
        """
        model_id = request.query_params.get("model")
        if not model_id:
            return Response(
                {"error": "model_id parameter is required."},
                status=400
            )

        try:
            ml_model = MLModel.objects.get(id=model_id)
        except MLModel.DoesNotExist:
            return Response(
                {"error": "MLModel not found."},
                status=404
            )
        except (TypeError, ValueError):
            # Non-numeric ids would otherwise surface as a 500.
            return Response(
                {"error": "model must be an integer."},
                status=400
            )

        response = HttpResponse(
            json.dumps(ml_model.surrogate_model, indent=4),
            content_type='application/json',
        )
        response['Content-Disposition'] = 'attachment; filename="model.json"'

        return response
@extend_schema(exclude=True)
@api_view(['POST'])
@authentication_classes([DaprApiTokenAuthentication])
@csrf_exempt
def process_ml_training_event(request) -> Response:
    """Receive an ML training completion event and forward its payload.

    The request is authenticated with ``DaprApiTokenAuthentication``. The body
    is validated as an ``MLTrainingCompletionEvent``, and its ``data`` is
    passed to ``process_ml_training_response``. Always answers 200 on success.
    """
    event = MLTrainingCompletionEvent.model_validate(request.data)
    payload = event.data
    process_ml_training_response(payload)
    return Response(status=200)