Coverage for backend/pinch_service/OpenPinch/src/classes/problem_table.py: 51%
203 statements
« prev ^ index » next coverage.py v7.10.7, created at 2025-11-06 23:27 +0000
« prev ^ index » next coverage.py v7.10.7, created at 2025-11-06 23:27 +0000
1import pandas as pd
2import numpy as np
3from copy import deepcopy
4from ..lib.enums import ProblemTableLabel
6class ProblemTable:
7 def __init__(self, data_input: dict | list = None, add_default_labels: bool = True):
8 if add_default_labels:
9 self.columns = list([index.value for index in ProblemTableLabel])
10 else:
11 self.columns = list([key for key in data_input.keys()])
13 for key in self.columns:
14 if np.isnan(data_input[key]).all():
15 data_input.pop(key)
16 self.col_index = {col: idx for idx, col in enumerate(self.columns)}
18 if isinstance(data_input, dict): 18 ↛ 24line 18 didn't jump to line 24 because the condition on line 18 was always true
19 # Align data from dict into array using columns order
20 self.data = np.array([
21 data_input.get(col, [np.nan] * len(next(iter(data_input.values()))))
22 for col in self.columns
23 ]).T
24 elif isinstance(data_input, list):
25 data_input = self._pad_data_input(data_input, len(self.columns))
26 self.data = np.array(data_input).T
27 else:
28 self.data = None
30 class ColumnViewByIndex:
31 def __init__(self, parent: "ProblemTable"):
32 self.parent = parent
34 def __getitem__(self, idx):
35 return self.parent.data[:, idx]
37 def __setitem__(self, idx, values):
38 self.parent.data[:, idx] = values
40 @property
41 def icol(self):
42 return self.ColumnViewByIndex(self)
44 class ColumnViewByName:
45 def __init__(self, parent: "ProblemTable"):
46 self.parent = parent
48 def __getitem__(self, col_name):
49 idx = self.parent.col_index[col_name]
50 return self.parent.data[:, idx]
52 def __setitem__(self, col_name, values):
53 idx = self.parent.col_index[col_name]
54 if self.parent.data is not None: 54 ↛ 57line 54 didn't jump to line 57 because the condition on line 54 was always true
55 self.parent.data[:, idx] = values
56 else:
57 data_input = {col_name: values}
58 self.data = np.array([
59 data_input.get(col, [np.nan] * len(next(iter(data_input.values()))))
60 for col in self.parent.columns
61 ]).T
63 @property
64 def col(self):
65 return self.ColumnViewByName(self)
67 class ColumnsViewByName:
68 def __init__(self, parent: "ProblemTable"):
69 self.parent = parent
71 def __getitem__(self, col_names):
72 idxs = []
73 for col_name in col_names:
74 idxs.append(self.parent.col_index[col_name])
75 return self.parent.data[:, idxs]
77 def __setitem__(self, col_name, values):
78 idx = self.parent.col_index[col_name]
79 if self.parent.data is not None:
80 self.parent.data[:, idx] = values
81 else:
82 data_input = {col_name: values}
83 self.data = np.array([
84 data_input.get(col, [np.nan] * len(next(iter(data_input.values()))))
85 for col in self.parent.columns
86 ]).T
88 @property
89 def cols(self):
90 return self.ColumnsViewByName(self)
92 class LocationByRowByColName:
93 def __init__(self, parent: "ProblemTable"):
94 self.parent = parent
96 def __getitem__(self, key):
97 row_idx, col_key = key
98 col_idx = self.parent.col_index[col_key]
99 return self.parent.data[row_idx, col_idx]
101 def __setitem__(self, key, value):
102 row_idx, col_key = key
103 col_idx = self.parent.col_index[col_key]
104 self.parent.data[row_idx, col_idx] = value
106 @property
107 def loc(self):
108 return self.LocationByRowByColName(self)
110 class LocationByRowByCol:
111 def __init__(self, parent: "ProblemTable"):
112 self.parent = parent
114 def __getitem__(self, key):
115 row_idx, col_key = key
116 col_idx = self.parent.col_index[col_key]
117 return self.parent.data[row_idx, col_idx]
119 def __setitem__(self, key, value):
120 row_idx, col_key = key
121 col_idx = self.parent.col_index[col_key]
122 self.parent.data[row_idx, col_idx] = value
124 @property
125 def iloc(self):
126 return self.LocationByRowByCol(self)
128 def __len__(self):
129 if isinstance(self.data, np.ndarray): 129 ↛ 132line 129 didn't jump to line 132 because the condition on line 129 was always true
130 return self.data.shape[0]
131 else:
132 return 0
134 def __getitem__(self, keys):
135 data_input = {}
136 if isinstance(keys, str):
137 keys = [keys]
138 for key in keys:
139 data_input[key] = self.col[key]
140 return ProblemTable(
141 data_input, add_default_labels=False
142 )
144 def __eq__(self, other):
145 if not isinstance(other, ProblemTable):
146 return False
147 if self.columns != other.columns:
148 return False
149 if self.data.shape != other.data.shape:
150 return False
152 # NaN-safe elementwise comparison
153 a = self.data
154 b = other.data
155 nan_mask = np.isnan(a) & np.isnan(b)
156 close_mask = np.isclose(a, b, rtol=1e-5, atol=1e-8, equal_nan=False)
157 return np.all(nan_mask | close_mask)
159 def __ne__(self, other):
160 return not self.__eq__(other)
162 @property
163 def shape(self):
164 return self.data.shape
166 @property
167 def to_dataframe(self) -> pd.DataFrame:
168 """Convert the buffer into a pandas DataFrame."""
169 return pd.DataFrame(self.data.copy, columns=self.columns)
171 @property
172 def copy(self):
173 return deepcopy(self)
175 def _pad_data_input(self, data_input, n_cols):
176 current_cols = len(data_input)
177 if current_cols < n_cols:
178 n_rows = len(data_input[0]) # assume all rows are same length
179 padding = [[np.nan] * n_rows for _ in range(n_cols - current_cols)]
180 data_input += padding
181 return data_input
183 def to_list(self, col: str = None):
184 if isinstance(col, str): 184 ↛ 185line 184 didn't jump to line 185 because the condition on line 184 was never true
185 ls = self.col[col].T.tolist()
186 elif col == None: 186 ↛ 188line 186 didn't jump to line 188 because the condition on line 186 was always true
187 ls = self.data.T.tolist()
188 return ls[0] if len(ls) == 1 else ls
190 def delta_col(self, key, shift: int =1) -> np.ndarray:
191 idx = self.col_index[key]
192 col_values = self.data[:, idx]
193 delta = np.roll(col_values, shift) - col_values
194 delta[0] = 0.0
195 return delta
197 def shift(self, key, shift: int =1, filler_value: float = 0.0) -> np.ndarray:
198 idx = self.col_index[key]
199 col_values = self.data[:, idx]
200 values = np.roll(col_values, shift)
201 if len(values) > 0: 201 ↛ 208line 201 didn't jump to line 208 because the condition on line 201 was always true
202 if shift > 0:
203 for i in range(shift):
204 values[i] = filler_value
205 elif shift < 0: 205 ↛ 208line 205 didn't jump to line 208 because the condition on line 205 was always true
206 for i in range(shift, 0):
207 values[i] = filler_value
208 return values
210 def round(self, decimals):
211 self.data = np.round(self.data, decimals)
213 def insert(self, row_dict: dict, index: int):
214 """Insert a single row (dict of column: value) at the specified index."""
215 new_row = np.full(self.data.shape[1], np.nan)
216 for key, value in row_dict.items():
217 # if key in self.col_index:
218 new_row[self.col_index[key]] = value
219 self.data = np.insert(self.data, index, new_row, axis=0)
221 def update_row(self, index: int, row_dict: dict):
222 for key, value in row_dict.items():
223 if key in self.col_index:
224 self.data[index, self.col_index[key]] = value
226 def delete_row(self, index: int):
227 self.data = np.delete(self.data, index, axis=0)
229 def sort_by_column(self, column: str, ascending: bool = True):
230 if column not in self.col_index:
231 raise KeyError(f"Column {column} not found")
232 col_data = self.data[:, self.col_index[column]]
233 order = np.argsort(col_data)
234 if not ascending:
235 order = order[::-1]
236 self.data = self.data[order]
def compare_problem_tables(pt1: "ProblemTable", pt2: "ProblemTable", atol: float = 1e-6) -> bool:
    """Compare two tables cell by cell and print every difference found.

    Works for ``ProblemTable`` objects (cells read through ``loc[row, col]``)
    and for pandas DataFrames (cells read through ``iloc[row][col]``).

    Args:
        pt1: First table.
        pt2: Second table.
        atol: Absolute tolerance passed to ``np.isclose``. ``np.isclose``'s
            default relative tolerance also applies.

    Returns:
        ``True`` when shapes and columns agree and every pair of cells is
        either both missing (NaN) or close; ``False`` otherwise. A report is
        printed in both cases.
    """
    if pt1.shape != pt2.shape:
        print(f"❌ Shape mismatch: {pt1.shape} vs {pt2.shape}")
        return False

    if list(pt1.columns) != list(pt2.columns):
        print("❌ Column mismatch:")
        print(f"pt1 columns: {pt1.columns}")
        print(f"pt2 columns: {pt2.columns}")
        return False

    def cell(table, row, col):
        # DataFrames keep the positional-row lookup; any other table is read
        # through its (row, column-label) accessor, because indexing a
        # ProblemTable's ``iloc`` with a bare row number raises TypeError.
        if isinstance(table, pd.DataFrame):
            return table.iloc[row][col]
        return table.loc[row, col]

    mismatches = []
    for i in range(pt1.shape[0]):
        for col in pt1.columns:
            v1, v2 = cell(pt1, i, col), cell(pt2, i, col)
            try:
                if pd.isna(v1) and pd.isna(v2):
                    continue
                if not np.isclose(v1, v2, atol=atol):
                    mismatches.append((i, col, v1, v2))
            except TypeError:
                # Non-numeric cells: fall back to plain equality.
                if v1 != v2:
                    mismatches.append((i, col, v1, v2))

    if mismatches:
        print(f"⚠️ {len(mismatches)} mismatches found:")
        for i, col, v1, v2 in mismatches:
            print(f"Row {i}, Column '{col}': pt1={v1}, pt2={v2}")
        return False

    print("✅ All values match within tolerance.")
    return True