Coverage for backend/django/core/auxiliary/viewsets/MLViewSet.py: 88%

102 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2026-03-26 20:57 +0000

1from core.viewset import ModelViewSet 

2from core.auxiliary.models.MLModel import MLModel 

3from core.auxiliary.serializers.MLModelSerializer import MLModelSerializer 

4from rest_framework.response import Response 

5from drf_spectacular.utils import extend_schema, OpenApiParameter, OpenApiTypes 

6from rest_framework.exceptions import NotFound, ValidationError 

7from rest_framework.decorators import action 

8import json 

9from rest_framework import serializers 

10from core.auxiliary.models.MLWizard import train 

11from flowsheetInternals.unitops.models import SimulationObject 

12from django.http import HttpResponse 

13 

14from django.views.decorators.csrf import csrf_exempt 

15from authentication.custom_drf_authentication import DaprApiTokenAuthentication 

16from common.models.idaes.payloads.ml_request_schema import MLTrainingCompletionEvent 

17from rest_framework.decorators import api_view, authentication_classes 

18from idaes_factory.endpoints import process_ml_training_response 

19 

20 

class GetCsvHeaderSerializer(serializers.Serializer):
    """Response body of the ``get-csv-header`` action: a list of column names."""

    headers = serializers.ListField(
        child=serializers.CharField(),
    )

23 

24 

class CreateSurrogateModelFromColumnSerializer(serializers.Serializer):
    """Request body of the ``create-surrogate-model`` action.

    ``model`` is the primary key of the MLModel to train.
    """

    model = serializers.IntegerField()

27 

28 

class UploadModelSerializer(serializers.Serializer):
    """Request body of the ``upload-ml-model`` action.

    ``json_data`` carries the surrogate model to store and ``simulationObject``
    is the primary key of the SimulationObject it is attached to.
    """

    json_data = serializers.JSONField()
    simulationObject = serializers.IntegerField()

32 

33 

class MLViewSet(ModelViewSet):
    """CRUD endpoints plus helper actions for surrogate ``MLModel`` records.

    Besides the standard ModelViewSet routes, this viewset exposes actions to
    read a model's dataset headers, upload an already-trained surrogate model,
    trigger training, and download a surrogate model as a JSON file.
    """

    serializer_class = MLModelSerializer

    def get_queryset(self):
        """Return the MLModels attached to the ``simulationObject`` query param.

        When the parameter is absent the filter becomes
        ``simulationObject=None``, which Django translates to ``IS NULL``.
        """
        simulationObjectId = self.request.query_params.get("simulationObject", None)
        return MLModel.objects.all().filter(simulationObject=simulationObjectId)

    @extend_schema(
        parameters=[
            OpenApiParameter(
                name="simulationObject", required=True, type=OpenApiTypes.INT
            ),
        ]
    )
    def list(self, request, *args, **kwargs):
        """List MLModels; overridden only to document the required query param."""
        return super().list(request, *args, **kwargs)

    def create(self, request, *args, **kwargs):
        """Create an MLModel, storing ``csv_data`` as a JSON-encoded string.

        ``get_csv_header`` decodes the stored value again with ``json.loads``.
        NOTE(review): this mutates ``request.data`` in place, which assumes a
        mutable (e.g. JSON-parsed) body rather than an immutable QueryDict.
        """
        request.data["csv_data"] = json.dumps(request.data.get("csv_data"))
        return super().create(request, *args, **kwargs)

    @extend_schema(
        parameters=[
            OpenApiParameter(name="model", required=True, type=OpenApiTypes.INT),
        ],
        responses=GetCsvHeaderSerializer,
    )
    @action(
        detail=False,
        methods=["get"],
        url_path="get-csv-header",
        url_name="get-csv-header",
    )
    def get_csv_header(self, request):
        """Return the column headers of an MLModel's data.

        If the model already has a surrogate model (e.g. an uploaded one), the
        headers are its ``input_labels`` followed by its ``output_labels``.
        Otherwise they are the non-empty keys of the first row of the stored
        CSV data; an empty or null dataset yields an empty list.

        Raises:
            ValidationError: ``model`` is missing or is not a valid id.
            NotFound: no MLModel has that id.
        """
        model_id = request.query_params.get("model")
        if not model_id:
            raise ValidationError({"error": "model is required."})

        try:
            model = MLModel.objects.get(id=model_id)
        except MLModel.DoesNotExist:
            raise NotFound({"error": "MLModel not found."})
        except ValueError:
            # Django raises ValueError for ids that can't be cast to the pk type.
            raise ValidationError({"error": "model must be an integer id."})

        surrogate = model.surrogate_model
        if surrogate:
            # A missing label list counts as empty instead of failing on None + list.
            input_labels = surrogate.get("input_labels") or []
            output_labels = surrogate.get("output_labels") or []
            return Response({"headers": input_labels + output_labels}, status=200)

        rows = json.loads(model.csv_data) if model.csv_data else None
        if not rows:
            # No rows (empty list, "null", or no stored data): nothing to read.
            return Response({"headers": []}, status=200)

        first_row = [key for key in rows[0].keys() if key]
        return Response({"headers": first_row}, status=200)

    @extend_schema(request=UploadModelSerializer, responses=None)
    @action(
        detail=False,
        methods=["post"],
        url_path="upload-ml-model",
        url_name="upload-ml-model",
    )
    def upload_model(self, request):
        """Create an MLModel from an uploaded surrogate model.

        ``json_data`` may be either a JSON-encoded string (decoded here) or an
        already-structured JSON value (stored as-is). The ``flowsheet`` query
        parameter is stored as the model's ``flowsheet_id``.

        Raises:
            NotFound: the referenced SimulationObject does not exist.
            ValidationError: ``json_data`` is a string that is not valid JSON.
        """
        serializer = UploadModelSerializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        validated_data = serializer.validated_data
        json_data = validated_data.get("json_data")
        simulation_object_id = validated_data.get("simulationObject")
        flowsheet = request.query_params.get("flowsheet")

        try:
            simulation_object = SimulationObject.objects.get(id=simulation_object_id)
        except SimulationObject.DoesNotExist:
            raise NotFound({"error": "SimulationObject not found."})

        # JSONField already parses the request body, so only a string payload
        # still needs decoding; json.loads on a dict would raise TypeError.
        if isinstance(json_data, str):
            try:
                surrogate_model = json.loads(json_data)
            except ValueError:  # json.JSONDecodeError subclasses ValueError
                raise ValidationError({"json_data": "json_data is not valid JSON."})
        else:
            surrogate_model = json_data

        MLModel.objects.create(
            simulationObject=simulation_object,
            surrogate_model=surrogate_model,
            progress=1,
            flowsheet_id=flowsheet,
        )

        return Response({"message": "success"}, status=200)

    @extend_schema(request=CreateSurrogateModelFromColumnSerializer, responses=None)
    @action(detail=False, methods=["post"], url_path="create-surrogate-model")
    def create_surrogate_model(self, request):
        """Train the surrogate model for the MLModel identified by ``model``.

        A model that already carries surrogate data (e.g. an imported one) is
        not retrained. Its ``progress`` is set to 3 and a success message is
        returned. Otherwise the result of ``train(request.user, model)`` is
        returned.

        Raises:
            NotFound: no MLModel has that id.
        """
        serializer = CreateSurrogateModelFromColumnSerializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        model_id = serializer.validated_data.get("model")

        try:
            model_instance = MLModel.objects.get(id=model_id)
        except MLModel.DoesNotExist:
            raise NotFound({"error": "MLModel not found."})

        # If already have one (imported model), skip training.
        if model_instance.surrogate_model:
            # Scope the update to this row: a manager-level
            # MLModel.objects.update() would rewrite every MLModel in the table.
            MLModel.objects.filter(pk=model_instance.pk).update(progress=3)
            return Response({"message": "successfully trained"}, status=200)

        return train(request.user, model_instance)

    @extend_schema(
        parameters=[
            OpenApiParameter(name="model", required=True, type=OpenApiTypes.INT),
        ]
    )
    @action(
        detail=False,
        methods=["get"],
        url_path="export-ml-model",
        url_name="export-ml-model",
    )
    def export_flowsheet(self, request):
        """Download an MLModel's surrogate model as a ``model.json`` attachment.

        Despite its name, this exports an ML model, not a flowsheet. Errors are
        returned as plain 400/404 responses rather than raised exceptions.
        """
        model_id = request.query_params.get("model")
        if not model_id:
            return Response({"error": "model_id parameter is required."}, status=400)

        try:
            ml_model = MLModel.objects.get(id=model_id)
        except MLModel.DoesNotExist:
            return Response({"error": "MLModel not found."}, status=404)
        except ValueError:
            # Non-integer id: report a client error instead of a 500.
            return Response({"error": "model must be an integer id."}, status=400)

        response = HttpResponse(
            json.dumps(ml_model.surrogate_model, indent=4),
            content_type="application/json",
        )
        response["Content-Disposition"] = 'attachment; filename="model.json"'

        return response

160 

161 

@extend_schema(exclude=True)
@api_view(["POST"])
@authentication_classes([DaprApiTokenAuthentication])
@csrf_exempt
def process_ml_training_event(request) -> Response:
    """Receive an ML-training-completion event and hand off its payload.

    Authenticated with ``DaprApiTokenAuthentication`` and excluded from the
    OpenAPI schema. The body is validated as an ``MLTrainingCompletionEvent``
    (presumably raising on an invalid body — confirm against that model). Its
    ``data`` is passed to ``process_ml_training_response``, and an empty 200
    response acknowledges the event.
    """
    event = MLTrainingCompletionEvent.model_validate(request.data)
    payload = event.data
    process_ml_training_response(payload)
    return Response(status=200)