Coverage for backend/core/auxiliary/viewsets/MLViewSet.py: 84%

102 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-11-06 23:27 +0000

1from core.viewset import ModelViewSet 

2from core.auxiliary.models.MLModel import MLModel 

3from core.auxiliary.serializers.MLModelSerializer import MLModelSerializer 

4from rest_framework.response import Response 

5from drf_spectacular.utils import extend_schema, OpenApiParameter, OpenApiTypes 

6from rest_framework.exceptions import NotFound, ValidationError 

7from rest_framework.decorators import action 

8import json 

9from rest_framework import serializers 

10from core.auxiliary.models.MLWizard import train 

11from flowsheetInternals.unitops.models import SimulationObject 

12from django.http import HttpResponse 

13 

14from django.views.decorators.csrf import csrf_exempt 

15from authentication.custom_drf_authentication import DaprApiTokenAuthentication 

16from common.models.idaes.payloads.ml_request_schema import MLTrainingCompletionEvent 

17from rest_framework.decorators import api_view, authentication_classes 

18from idaes_factory.endpoints import process_ml_training_response 

19 

class GetCsvHeaderSerializer(serializers.Serializer):
    """Response schema of the ``get-csv-header`` action: the list of column names."""

    headers = serializers.ListField(
        child=serializers.CharField(),
    )

22 

23 

class CreateSurrogateModelFromColumnSerializer(serializers.Serializer):
    """Request body of the ``create-surrogate-model`` action."""

    # Primary key of the MLModel to train.
    model = serializers.IntegerField()

26 

27 

class UploadModelSerializer(serializers.Serializer):
    """Request body of the ``upload-ml-model`` action."""

    # The surrogate model to import (stored on the new MLModel's ``surrogate_model``).
    json_data = serializers.JSONField()
    # Primary key of the SimulationObject the imported model is attached to.
    simulationObject = serializers.IntegerField()

31 

32 

class MLViewSet(ModelViewSet):
    """CRUD endpoints for ``MLModel`` plus surrogate-model helper actions.

    Custom collection-level actions:

    * ``get-csv-header``         -- column headers of a model's data.
    * ``upload-ml-model``        -- import a surrogate model from JSON.
    * ``create-surrogate-model`` -- train a model (skipped for imported ones).
    * ``export-ml-model``        -- download a model's surrogate as a JSON file.
    """

    serializer_class = MLModelSerializer

    @staticmethod
    def _get_ml_model(model_id):
        """Return the MLModel whose primary key is *model_id*.

        Raises:
            NotFound: no such model, or *model_id* cannot be converted to the
                key type (Django raises ``ValueError`` for that, e.g. for a
                non-numeric id when the pk is an integer).
        """
        try:
            return MLModel.objects.get(id=model_id)
        except (MLModel.DoesNotExist, ValueError):
            raise NotFound({"error": "MLModel not found."}) from None

    def get_queryset(self):
        """Return the MLModels attached to the ``simulationObject`` query param.

        NOTE(review): if the base viewset resolves detail routes through
        ``get_queryset`` (standard DRF behaviour), retrieve/update/destroy also
        need ``?simulationObject=``; without it the filter becomes
        ``simulationObject=None``. Confirm with callers.
        """
        simulation_object_id = self.request.query_params.get("simulationObject", None)
        return MLModel.objects.filter(simulationObject=simulation_object_id)

    @extend_schema(
        parameters=[
            OpenApiParameter(name="simulationObject",
                             required=True, type=OpenApiTypes.INT),
        ]
    )
    def list(self, request, *args, **kwargs):
        """List MLModels for the ``simulationObject`` given in the query string."""
        return super().list(request, *args, **kwargs)

    def create(self, request, *args, **kwargs):
        """Create an MLModel, storing ``csv_data`` as a JSON-encoded string.

        NOTE(review): this mutates ``request.data`` in place. That works for JSON
        bodies (a plain dict) but would fail for form/multipart bodies, whose
        ``QueryDict`` is immutable -- confirm only JSON is posted here.
        """
        request.data["csv_data"] = json.dumps(request.data.get("csv_data"))
        return super().create(request, *args, **kwargs)

    @extend_schema(
        parameters=[
            OpenApiParameter(name="model", required=True,
                             type=OpenApiTypes.INT),
        ],
        responses=GetCsvHeaderSerializer
    )
    @action(detail=False, methods=['get'], url_path='get-csv-header', url_name='get-csv-header')
    def get_csv_header(self, request):
        """Return the column headers of an MLModel.

        A model with a non-empty ``surrogate_model`` reports its input labels
        followed by its output labels. Otherwise the keys of the first row of the
        stored ``csv_data`` are returned, skipping empty keys; no rows yields an
        empty list.

        Raises:
            ValidationError: the ``model`` query parameter is missing.
            NotFound: no MLModel matches ``model``.
        """
        model_id = self.request.query_params.get("model")
        if not model_id:
            raise ValidationError({"error": "model is required."})

        model = self._get_ml_model(model_id)

        if model.surrogate_model:
            # Either label list may be absent from an uploaded JSON; treat a
            # missing list as empty instead of failing on ``None + list``.
            input_labels = model.surrogate_model.get("input_labels") or []
            output_labels = model.surrogate_model.get("output_labels") or []
            return Response({"headers": input_labels + output_labels}, status=200)

        # ``create`` stores csv_data via json.dumps, so a missing payload is the
        # string "null"; guard against that and an empty row list.
        rows = json.loads(model.csv_data) if model.csv_data else None
        if not rows:
            return Response({"headers": []}, status=200)

        first_row = [key for key in rows[0] if key]
        return Response({"headers": first_row}, status=200)

    @extend_schema(
        request=UploadModelSerializer,
        responses=None
    )
    @action(detail=False, methods=['post'], url_path='upload-ml-model', url_name='upload-ml-model')
    def upload_model(self, request):
        """Import a surrogate model and attach it to a SimulationObject.

        ``json_data`` may be a JSON-encoded string or an already-parsed value.
        The new MLModel is created with ``progress=1`` and the ``flowsheet``
        query parameter as its flowsheet id.

        Raises:
            ValidationError: the payload is invalid, or ``json_data`` is a
                string that is not valid JSON.
            NotFound: no SimulationObject matches ``simulationObject``.
        """
        serializer = UploadModelSerializer(data=request.data)
        serializer.is_valid(raise_exception=True)
        validated_data = serializer.validated_data
        json_data = validated_data.get("json_data")
        simulation_object_id = validated_data.get("simulationObject")
        flowsheet = request.query_params.get("flowsheet")

        try:
            simulation_object = SimulationObject.objects.get(id=simulation_object_id)
        except SimulationObject.DoesNotExist:
            raise NotFound({"error": "SimulationObject not found."})

        if isinstance(json_data, str):
            try:
                surrogate_model = json.loads(json_data)
            except json.JSONDecodeError:
                raise ValidationError({"error": "json_data is not valid JSON."}) from None
        else:
            # The client sent a JSON object/array rather than a string; DRF's
            # JSONField hands it over already parsed, so use it as-is.
            surrogate_model = json_data

        MLModel.objects.create(
            simulationObject=simulation_object,
            surrogate_model=surrogate_model,
            progress=1,
            flowsheet_id=flowsheet,
        )

        return Response({"message": "success"}, status=200)

    @extend_schema(request=CreateSurrogateModelFromColumnSerializer, responses=None)
    @action(detail=False, methods=['post'], url_path='create-surrogate-model')
    def create_surrogate_model(self, request):
        """Train the surrogate for an MLModel unless it already has one.

        A model that already carries a ``surrogate_model`` (e.g. one imported via
        ``upload-ml-model``) is only marked ``progress=3``. Otherwise training is
        delegated to ``train`` and its return value is used as the response.

        Raises:
            ValidationError: the payload is invalid.
            NotFound: no MLModel matches ``model``.
        """
        serializer = CreateSurrogateModelFromColumnSerializer(
            data=request.data)
        serializer.is_valid(raise_exception=True)
        model_instance = self._get_ml_model(serializer.validated_data.get("model"))

        # If it already has one (imported model), skip training.
        if model_instance.surrogate_model:
            # Update only this model's row; an unfiltered
            # ``MLModel.objects.update`` would touch every model in the table.
            MLModel.objects.filter(pk=model_instance.pk).update(progress=3)
            return Response({"message": "successfully trained"}, status=200)

        return train(request.user, model_instance)

    @extend_schema(parameters=[
        OpenApiParameter(name="model", required=True,
                         type=OpenApiTypes.INT),
    ])
    @action(detail=False, methods=['get'], url_path='export-ml-model', url_name='export-ml-model')
    def export_flowsheet(self, request):
        """Download an MLModel's ``surrogate_model`` as an attachment named ``model.json``.

        Returns 400 if the ``model`` query parameter is missing and 404 if no
        MLModel matches it.
        """
        model_id = request.query_params.get("model")
        if not model_id:
            return Response(
                {"error": "model_id parameter is required."},
                status=400
            )

        try:
            ml_model = MLModel.objects.get(id=model_id)
        except (MLModel.DoesNotExist, ValueError):
            # ValueError: Django could not convert model_id to the key type
            # (e.g. a non-numeric id), so no model can match it.
            return Response(
                {"error": "MLModel not found."},
                status=404
            )

        response = HttpResponse(
            json.dumps(ml_model.surrogate_model, indent=4),
            content_type='application/json',
        )
        response['Content-Disposition'] = 'attachment; filename="model.json"'
        return response

159 

@extend_schema(exclude=True)
@api_view(['POST'])
@authentication_classes([DaprApiTokenAuthentication])
@csrf_exempt
def process_ml_training_event(request) -> Response:
    """Handle an ML-training completion event posted to this endpoint.

    The request body is validated as an ``MLTrainingCompletionEvent``. Its
    ``data`` field is then passed to ``process_ml_training_response``. The
    endpoint is hidden from the API schema, CSRF-exempt, and authenticated with
    ``DaprApiTokenAuthentication``. On success it replies 200 with an empty body.
    A payload that fails validation raises from ``model_validate``.
    """
    event = MLTrainingCompletionEvent.model_validate(request.data)
    payload = event.data
    process_ml_training_response(payload)
    return Response(status=200)