Coverage for backend/pinch_service/OpenPinch/src/classes/problem_table.py: 51%

203 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-11-06 23:27 +0000

1import pandas as pd 

2import numpy as np 

3from copy import deepcopy 

4from ..lib.enums import ProblemTableLabel 

5 

6class ProblemTable: 

7 def __init__(self, data_input: dict | list = None, add_default_labels: bool = True): 

8 if add_default_labels: 

9 self.columns = list([index.value for index in ProblemTableLabel]) 

10 else: 

11 self.columns = list([key for key in data_input.keys()]) 

12 

13 for key in self.columns: 

14 if np.isnan(data_input[key]).all(): 

15 data_input.pop(key) 

16 self.col_index = {col: idx for idx, col in enumerate(self.columns)} 

17 

18 if isinstance(data_input, dict): 18 ↛ 24line 18 didn't jump to line 24 because the condition on line 18 was always true

19 # Align data from dict into array using columns order 

20 self.data = np.array([ 

21 data_input.get(col, [np.nan] * len(next(iter(data_input.values())))) 

22 for col in self.columns 

23 ]).T 

24 elif isinstance(data_input, list): 

25 data_input = self._pad_data_input(data_input, len(self.columns)) 

26 self.data = np.array(data_input).T 

27 else: 

28 self.data = None 

29 

30 class ColumnViewByIndex: 

31 def __init__(self, parent: "ProblemTable"): 

32 self.parent = parent 

33 

34 def __getitem__(self, idx): 

35 return self.parent.data[:, idx] 

36 

37 def __setitem__(self, idx, values): 

38 self.parent.data[:, idx] = values 

39 

40 @property 

41 def icol(self): 

42 return self.ColumnViewByIndex(self) 

43 

44 class ColumnViewByName: 

45 def __init__(self, parent: "ProblemTable"): 

46 self.parent = parent 

47 

48 def __getitem__(self, col_name): 

49 idx = self.parent.col_index[col_name] 

50 return self.parent.data[:, idx] 

51 

52 def __setitem__(self, col_name, values): 

53 idx = self.parent.col_index[col_name] 

54 if self.parent.data is not None: 54 ↛ 57line 54 didn't jump to line 57 because the condition on line 54 was always true

55 self.parent.data[:, idx] = values 

56 else: 

57 data_input = {col_name: values} 

58 self.data = np.array([ 

59 data_input.get(col, [np.nan] * len(next(iter(data_input.values())))) 

60 for col in self.parent.columns 

61 ]).T 

62 

63 @property 

64 def col(self): 

65 return self.ColumnViewByName(self) 

66 

67 class ColumnsViewByName: 

68 def __init__(self, parent: "ProblemTable"): 

69 self.parent = parent 

70 

71 def __getitem__(self, col_names): 

72 idxs = [] 

73 for col_name in col_names: 

74 idxs.append(self.parent.col_index[col_name]) 

75 return self.parent.data[:, idxs] 

76 

77 def __setitem__(self, col_name, values): 

78 idx = self.parent.col_index[col_name] 

79 if self.parent.data is not None: 

80 self.parent.data[:, idx] = values 

81 else: 

82 data_input = {col_name: values} 

83 self.data = np.array([ 

84 data_input.get(col, [np.nan] * len(next(iter(data_input.values())))) 

85 for col in self.parent.columns 

86 ]).T 

87 

88 @property 

89 def cols(self): 

90 return self.ColumnsViewByName(self) 

91 

92 class LocationByRowByColName: 

93 def __init__(self, parent: "ProblemTable"): 

94 self.parent = parent 

95 

96 def __getitem__(self, key): 

97 row_idx, col_key = key 

98 col_idx = self.parent.col_index[col_key] 

99 return self.parent.data[row_idx, col_idx] 

100 

101 def __setitem__(self, key, value): 

102 row_idx, col_key = key 

103 col_idx = self.parent.col_index[col_key] 

104 self.parent.data[row_idx, col_idx] = value 

105 

106 @property 

107 def loc(self): 

108 return self.LocationByRowByColName(self) 

109 

110 class LocationByRowByCol: 

111 def __init__(self, parent: "ProblemTable"): 

112 self.parent = parent 

113 

114 def __getitem__(self, key): 

115 row_idx, col_key = key 

116 col_idx = self.parent.col_index[col_key] 

117 return self.parent.data[row_idx, col_idx] 

118 

119 def __setitem__(self, key, value): 

120 row_idx, col_key = key 

121 col_idx = self.parent.col_index[col_key] 

122 self.parent.data[row_idx, col_idx] = value 

123 

124 @property 

125 def iloc(self): 

126 return self.LocationByRowByCol(self) 

127 

128 def __len__(self): 

129 if isinstance(self.data, np.ndarray): 129 ↛ 132line 129 didn't jump to line 132 because the condition on line 129 was always true

130 return self.data.shape[0] 

131 else: 

132 return 0 

133 

134 def __getitem__(self, keys): 

135 data_input = {} 

136 if isinstance(keys, str): 

137 keys = [keys] 

138 for key in keys: 

139 data_input[key] = self.col[key] 

140 return ProblemTable( 

141 data_input, add_default_labels=False 

142 ) 

143 

144 def __eq__(self, other): 

145 if not isinstance(other, ProblemTable): 

146 return False 

147 if self.columns != other.columns: 

148 return False 

149 if self.data.shape != other.data.shape: 

150 return False 

151 

152 # NaN-safe elementwise comparison 

153 a = self.data 

154 b = other.data 

155 nan_mask = np.isnan(a) & np.isnan(b) 

156 close_mask = np.isclose(a, b, rtol=1e-5, atol=1e-8, equal_nan=False) 

157 return np.all(nan_mask | close_mask) 

158 

159 def __ne__(self, other): 

160 return not self.__eq__(other) 

161 

162 @property 

163 def shape(self): 

164 return self.data.shape 

165 

166 @property 

167 def to_dataframe(self) -> pd.DataFrame: 

168 """Convert the buffer into a pandas DataFrame.""" 

169 return pd.DataFrame(self.data.copy, columns=self.columns) 

170 

171 @property 

172 def copy(self): 

173 return deepcopy(self) 

174 

175 def _pad_data_input(self, data_input, n_cols): 

176 current_cols = len(data_input) 

177 if current_cols < n_cols: 

178 n_rows = len(data_input[0]) # assume all rows are same length 

179 padding = [[np.nan] * n_rows for _ in range(n_cols - current_cols)] 

180 data_input += padding 

181 return data_input 

182 

183 def to_list(self, col: str = None): 

184 if isinstance(col, str): 184 ↛ 185line 184 didn't jump to line 185 because the condition on line 184 was never true

185 ls = self.col[col].T.tolist() 

186 elif col == None: 186 ↛ 188line 186 didn't jump to line 188 because the condition on line 186 was always true

187 ls = self.data.T.tolist() 

188 return ls[0] if len(ls) == 1 else ls 

189 

190 def delta_col(self, key, shift: int =1) -> np.ndarray: 

191 idx = self.col_index[key] 

192 col_values = self.data[:, idx] 

193 delta = np.roll(col_values, shift) - col_values 

194 delta[0] = 0.0 

195 return delta 

196 

197 def shift(self, key, shift: int =1, filler_value: float = 0.0) -> np.ndarray: 

198 idx = self.col_index[key] 

199 col_values = self.data[:, idx] 

200 values = np.roll(col_values, shift) 

201 if len(values) > 0: 201 ↛ 208line 201 didn't jump to line 208 because the condition on line 201 was always true

202 if shift > 0: 

203 for i in range(shift): 

204 values[i] = filler_value 

205 elif shift < 0: 205 ↛ 208line 205 didn't jump to line 208 because the condition on line 205 was always true

206 for i in range(shift, 0): 

207 values[i] = filler_value 

208 return values 

209 

210 def round(self, decimals): 

211 self.data = np.round(self.data, decimals) 

212 

213 def insert(self, row_dict: dict, index: int): 

214 """Insert a single row (dict of column: value) at the specified index.""" 

215 new_row = np.full(self.data.shape[1], np.nan) 

216 for key, value in row_dict.items(): 

217 # if key in self.col_index: 

218 new_row[self.col_index[key]] = value 

219 self.data = np.insert(self.data, index, new_row, axis=0) 

220 

221 def update_row(self, index: int, row_dict: dict): 

222 for key, value in row_dict.items(): 

223 if key in self.col_index: 

224 self.data[index, self.col_index[key]] = value 

225 

226 def delete_row(self, index: int): 

227 self.data = np.delete(self.data, index, axis=0) 

228 

229 def sort_by_column(self, column: str, ascending: bool = True): 

230 if column not in self.col_index: 

231 raise KeyError(f"Column {column} not found") 

232 col_data = self.data[:, self.col_index[column]] 

233 order = np.argsort(col_data) 

234 if not ascending: 

235 order = order[::-1] 

236 self.data = self.data[order] 

237 

238 

def compare_problem_tables(pt1: "ProblemTable", pt2: "ProblemTable", atol: float = 1e-6) -> bool:
    """Compare two ProblemTables cell by cell and print a report of the differences.

    Args:
        pt1: First table.
        pt2: Second table.
        atol: Absolute tolerance passed to ``np.isclose`` (NumPy's default
            relative tolerance applies on top of it).

    Returns:
        True when shapes and columns agree and every pair of cells is either
        both NaN or close within tolerance; False otherwise. The outcome is
        also printed.
    """
    if pt1.shape != pt2.shape:
        print(f"❌ Shape mismatch: {pt1.shape} vs {pt2.shape}")
        return False

    if list(pt1.columns) != list(pt2.columns):
        print("❌ Column mismatch:")
        print(f"pt1 columns: {pt1.columns}")
        print(f"pt2 columns: {pt2.columns}")
        return False

    mismatches = []
    for i in range(len(pt1)):
        for col in pt1.columns:
            # ``loc`` takes a (row, column label) pair; ``iloc[i][col]`` is the
            # pandas spelling and fails on ProblemTable, whose accessors
            # require both parts of the key at once.
            v1, v2 = pt1.loc[i, col], pt2.loc[i, col]
            try:
                if pd.isna(v1) and pd.isna(v2):
                    continue
                if not np.isclose(v1, v2, atol=atol):
                    mismatches.append((i, col, v1, v2))
            except TypeError:
                # Non-numeric cells cannot go through isclose; compare directly.
                if v1 != v2:
                    mismatches.append((i, col, v1, v2))

    if mismatches:
        print(f"⚠️ {len(mismatches)} mismatches found:")
        for i, col, v1, v2 in mismatches:
            print(f"Row {i}, Column '{col}': pt1={v1}, pt2={v2}")
        return False

    print("✅ All values match within tolerance.")
    return True