Coverage for backend/django/core/auxiliary/viewsets/MLViewSet.py: 88%
102 statements
« prev ^ index » next coverage.py v7.10.7, created at 2026-03-26 20:57 +0000
1from core.viewset import ModelViewSet
2from core.auxiliary.models.MLModel import MLModel
3from core.auxiliary.serializers.MLModelSerializer import MLModelSerializer
4from rest_framework.response import Response
5from drf_spectacular.utils import extend_schema, OpenApiParameter, OpenApiTypes
6from rest_framework.exceptions import NotFound, ValidationError
7from rest_framework.decorators import action
8import json
9from rest_framework import serializers
10from core.auxiliary.models.MLWizard import train
11from flowsheetInternals.unitops.models import SimulationObject
12from django.http import HttpResponse
14from django.views.decorators.csrf import csrf_exempt
15from authentication.custom_drf_authentication import DaprApiTokenAuthentication
16from common.models.idaes.payloads.ml_request_schema import MLTrainingCompletionEvent
17from rest_framework.decorators import api_view, authentication_classes
18from idaes_factory.endpoints import process_ml_training_response
class GetCsvHeaderSerializer(serializers.Serializer):
    """Response body of the ``get-csv-header`` action.

    ``headers`` is the list of column names reported for an ML model.
    """

    headers = serializers.ListField(
        child=serializers.CharField(),
    )
class CreateSurrogateModelFromColumnSerializer(serializers.Serializer):
    """Request body of the ``create-surrogate-model`` action."""

    # Primary key of the MLModel whose surrogate should be created/trained.
    model = serializers.IntegerField()
class UploadModelSerializer(serializers.Serializer):
    """Request body of the ``upload-ml-model`` action."""

    # Serialized surrogate model supplied by the client.
    json_data = serializers.JSONField()
    # Primary key of the SimulationObject the uploaded model is attached to.
    simulationObject = serializers.IntegerField()
class MLViewSet(ModelViewSet):
    """Endpoints for ML surrogate models (``MLModel``).

    On top of the standard model-viewset routes this exposes actions to read
    the column headers of a model's data, upload or export a surrogate model
    as JSON, and start training.
    """

    serializer_class = MLModelSerializer

    @staticmethod
    def _get_ml_model(model_id):
        """Return the ``MLModel`` whose primary key is ``model_id``.

        Raises:
            NotFound: no ``MLModel`` has that id.
            ValidationError: ``model_id`` is not usable as an integer id.
        """
        try:
            return MLModel.objects.get(id=model_id)
        except MLModel.DoesNotExist:
            raise NotFound({"error": "MLModel not found."})
        except (TypeError, ValueError):
            # Django raises ValueError/TypeError for non-numeric ids; without
            # this the client would get a 500 instead of a 400.
            raise ValidationError({"error": "model must be an integer id."})

    def get_queryset(self):
        """Return the ML models attached to the ``simulationObject`` query param.

        When the parameter is absent the filter value is ``None``, which
        matches models that have no simulation object.
        """
        simulation_object_id = self.request.query_params.get("simulationObject", None)
        return MLModel.objects.filter(simulationObject=simulation_object_id)

    @extend_schema(
        parameters=[
            OpenApiParameter(
                name="simulationObject", required=True, type=OpenApiTypes.INT
            ),
        ]
    )
    def list(self, request, *args, **kwargs):
        """List ML models; overridden only to document the required query param."""
        return super().list(request, *args, **kwargs)

    def create(self, request, *args, **kwargs):
        """Create an ``MLModel``, storing ``csv_data`` as a JSON-encoded string.

        The incoming ``csv_data`` value is passed through ``json.dumps`` before
        the default create runs (a missing value is stored as ``"null"``).
        """
        # NOTE(review): this mutates request.data in place, which only works
        # when it is a mutable mapping (e.g. a parsed JSON body). An immutable
        # QueryDict from form/multipart input would raise here; confirm that
        # clients always send JSON.
        request.data["csv_data"] = json.dumps(request.data.get("csv_data"))
        return super().create(request, *args, **kwargs)

    @extend_schema(
        parameters=[
            OpenApiParameter(name="model", required=True, type=OpenApiTypes.INT),
        ],
        responses=GetCsvHeaderSerializer,
    )
    @action(
        detail=False,
        methods=["get"],
        url_path="get-csv-header",
        url_name="get-csv-header",
    )
    def get_csv_header(self, request):
        """Return the column headers of an ML model.

        If the model already has a surrogate model, the headers are its
        ``input_labels`` followed by its ``output_labels``. Otherwise they are
        the non-empty keys of the first row of the stored ``csv_data``. An
        empty or missing data set yields an empty list.

        Raises:
            ValidationError: ``model`` query param missing or not an integer.
            NotFound: no such ``MLModel``.
        """
        model_id = request.query_params.get("model")
        if not model_id:
            raise ValidationError({"error": "model is required."})
        model = self._get_ml_model(model_id)

        surrogate = model.surrogate_model
        if surrogate:
            # Missing label lists are treated as empty instead of crashing
            # on ``None + list``.
            input_labels = surrogate.get("input_labels") or []
            output_labels = surrogate.get("output_labels") or []
            return Response({"headers": input_labels + output_labels}, status=200)

        # csv_data is stored JSON-encoded by ``create``; it may decode to
        # ``None``/``[]`` when no data was supplied.
        rows = json.loads(model.csv_data) if model.csv_data else []
        first_row = [key for key in rows[0].keys() if key] if rows else []
        return Response({"headers": first_row}, status=200)

    @extend_schema(request=UploadModelSerializer, responses=None)
    @action(
        detail=False,
        methods=["post"],
        url_path="upload-ml-model",
        url_name="upload-ml-model",
    )
    def upload_model(self, request):
        """Create an ``MLModel`` from an uploaded, already-trained surrogate model.

        ``json_data`` may be a JSON-encoded string or an already-parsed JSON
        value. The ``flowsheet`` query param is stored as the model's
        ``flowsheet_id``.

        Raises:
            ValidationError: invalid body or ``json_data`` is not valid JSON.
            NotFound: the referenced ``SimulationObject`` does not exist.
        """
        serializer = UploadModelSerializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        validated_data = serializer.validated_data
        json_data = validated_data.get("json_data")
        simulation_object_id = validated_data.get("simulationObject")
        flowsheet = request.query_params.get("flowsheet")

        try:
            simulation_object = SimulationObject.objects.get(id=simulation_object_id)
        except SimulationObject.DoesNotExist:
            raise NotFound({"error": "SimulationObject not found."})

        # JSONField passes a string value through unparsed, and parsed
        # objects as-is; accept both forms.
        if isinstance(json_data, (str, bytes, bytearray)):
            try:
                surrogate_model = json.loads(json_data)
            except ValueError:  # includes json.JSONDecodeError
                raise ValidationError({"json_data": "Invalid JSON."})
        else:
            surrogate_model = json_data

        MLModel.objects.create(
            simulationObject=simulation_object,
            surrogate_model=surrogate_model,
            # NOTE(review): progress=1 is the value this endpoint has always
            # stored for uploaded models; its meaning is defined elsewhere.
            progress=1,
            flowsheet_id=flowsheet,
        )

        return Response({"message": "success"}, status=200)

    @extend_schema(request=CreateSurrogateModelFromColumnSerializer, responses=None)
    @action(detail=False, methods=["post"], url_path="create-surrogate-model")
    def create_surrogate_model(self, request):
        """Train a surrogate for the given ML model, unless it already has one.

        Raises:
            ValidationError: invalid request body.
            NotFound: no such ``MLModel``.
        """
        serializer = CreateSurrogateModelFromColumnSerializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        model_instance = self._get_ml_model(serializer.validated_data.get("model"))

        # If it already has one (imported model), skip training.
        if model_instance.surrogate_model:
            # Update only this model's row; an unfiltered manager ``update``
            # would overwrite the progress of every MLModel in the database.
            MLModel.objects.filter(pk=model_instance.pk).update(progress=3)
            return Response({"message": "successfully trained"}, status=200)

        return train(request.user, model_instance)

    @extend_schema(
        parameters=[
            OpenApiParameter(name="model", required=True, type=OpenApiTypes.INT),
        ]
    )
    @action(
        detail=False,
        methods=["get"],
        url_path="export-ml-model",
        url_name="export-ml-model",
    )
    def export_flowsheet(self, request):
        """Download an ML model's ``surrogate_model`` as ``model.json``.

        Responds 400 when the ``model`` query param is missing or not an
        integer id, and 404 when no such model exists.
        """
        # NOTE(review): the method name says "flowsheet" but it exports an ML
        # model. It is kept unchanged because it may be referenced elsewhere
        # (e.g. a generated schema operation id). Confirm before renaming.
        model_id = request.query_params.get("model")
        if not model_id:
            return Response({"error": "model_id parameter is required."}, status=400)

        try:
            ml_model = MLModel.objects.get(id=model_id)
        except MLModel.DoesNotExist:
            return Response({"error": "MLModel not found."}, status=404)
        except (TypeError, ValueError):
            return Response({"error": "model must be an integer id."}, status=400)

        response = HttpResponse(
            json.dumps(ml_model.surrogate_model, indent=4),
            content_type="application/json",
        )
        response["Content-Disposition"] = 'attachment; filename="model.json"'
        return response
@extend_schema(exclude=True)
@api_view(["POST"])
@authentication_classes([DaprApiTokenAuthentication])
@csrf_exempt
def process_ml_training_event(request) -> Response:
    """Handle an ML-training-completed event delivered via Dapr.

    The body is validated as an ``MLTrainingCompletionEvent``. Its ``data``
    payload is handed to ``process_ml_training_response``, and the endpoint
    answers with an empty 200.
    """
    process_ml_training_response(
        MLTrainingCompletionEvent.model_validate(request.data).data
    )
    return Response(status=200)