Coverage for backend/pinch_service/OpenPinch/src/classes/stream_collection.py: 59%

82 statements  

« prev     ^ index     » next       coverage.py v7.10.7, created at 2025-11-06 23:27 +0000

1from typing import Callable, Dict, List, Union, TYPE_CHECKING 

2 

3if TYPE_CHECKING: 

4 from ..classes import Stream, Stream 

5 

6 

class StreamCollection:
    """A dynamic, ordered collection of streams.

    Features:
    - Add and remove streams by name.
    - Prevent overwriting existing streams by auto-renaming the storage key
      (``H1`` -> ``H1_1`` -> ``H1_2`` ...).
    - Set custom sort keys (attribute name, list/tuple of attribute names,
      or callable).
    - Lazy sorting: the sorted order is recomputed only after a mutation or a
      sort-key change, the next time iteration or ``get_index`` needs it.
    - Ascending or descending sorting.

    By default, iteration yields streams in descending order of ``t_supply``.

    Typical usage:
    - Store and manage process streams or utility streams.
    - Sort streams dynamically by attributes like temperature or flow.
    - Avoid duplicate names automatically.

    Example:
        zone = StreamCollection()
        zone.add(Stream("H1", 300, 400))
        zone.set_sort_key(["t_target", "t_supply"], reverse=True)
        for stream in zone:
            print(stream.name, stream.t_target)
    """

    def __init__(self):
        # Streams keyed by name (or caller-supplied key), in insertion order.
        self._streams: Dict[str, object] = {}
        # Default order: descending supply temperature (see _sort_reverse).
        self._sort_key: Callable = lambda s: s.t_supply
        self._sort_reverse: bool = True
        # Sorted snapshot of self._streams.values(); only valid while
        # self._needs_sort is False.
        self._sorted_cache: List[object] = []
        self._needs_sort: bool = True

    def add(self, stream, key: Union[str, None] = None, prevent_overwrite: bool = True):
        """Add ``stream`` under ``key`` (defaults to ``stream.name``).

        If ``prevent_overwrite`` is true and ``key`` is already taken, a numeric
        suffix is appended (``key_1``, ``key_2``, ...) until a free key is found.
        Only the storage key changes; the stream object itself is not renamed.
        If ``prevent_overwrite`` is false, an existing entry with the same key
        is replaced.
        """
        if key is None:
            key = stream.name
        original_key = key
        counter = 1
        while prevent_overwrite and key in self._streams:
            key = f"{original_key}_{counter}"
            counter += 1
        self._streams[key] = stream
        self._needs_sort = True

    def add_many(self, streams, keys=None, prevent_overwrite: bool = True):
        """Add several streams, optionally under explicit ``keys``.

        Raises:
            ValueError: if ``keys`` is given and its length differs from
                that of ``streams``.
        """
        # ``is None`` rather than ``== None``: array-like ``keys`` would make
        # ``==`` elementwise and its truth value ambiguous.
        if keys is None:
            for stream in streams:
                self.add(stream, prevent_overwrite=prevent_overwrite)
            return
        if len(streams) != len(keys):
            raise ValueError("Length of streams and keys must match.")
        for stream, key in zip(streams, keys):
            self.add(stream, key, prevent_overwrite)

    def replace(self, stream_dict: Dict[str, "Stream"]):
        """Replace the whole contents with the values of ``stream_dict``.

        Streams are keyed by ``stream.name``; the dict's own keys are ignored,
        and a later stream with the same name overwrites an earlier one
        (no auto-renaming here, unlike ``add``).
        """
        self._streams = {stream.name: stream for stream in stream_dict.values()}
        self._needs_sort = True

    def remove(self, stream_name: str):
        """Remove the stream stored under ``stream_name``.

        Raises:
            KeyError: if no stream is stored under that key.
        """
        if stream_name not in self._streams:
            raise KeyError(f"Stream '{stream_name}' not found.")
        del self._streams[stream_name]
        self._needs_sort = True

    def set_sort_key(self, key: Union[str, List[str], Callable], reverse: bool = False):
        """Set the sorting key and direction used for iteration.

        Args:
            key: an attribute name; a list/tuple of attribute names, compared
                lexicographically in the given order; or a callable that takes
                a stream and returns a sort key.
            reverse: if true, sort in descending order.
        """
        self._sort_reverse = reverse
        if isinstance(key, str):
            attr_name = key
            self._sort_key = lambda s: getattr(s, attr_name)
        elif isinstance(key, (list, tuple)):
            # Snapshot the names so later mutation of the caller's list cannot
            # change the ordering. Direction is handled solely by ``reverse``
            # in ``sorted``; also negating the values here would cancel it out
            # (and would fail for non-numeric attributes).
            attr_names = tuple(key)
            self._sort_key = lambda s: tuple(getattr(s, a) for a in attr_names)
        else:
            self._sort_key = key
        self._needs_sort = True

    def get_index(self, stream) -> int:
        """Return the position of ``stream`` in the sorted stream list.

        Matching uses ``==``. Raises ValueError if no stream matches.
        """
        self._ensure_sorted()
        for idx, s in enumerate(self._sorted_cache):
            if s == stream:
                return idx
        raise ValueError("Stream not found in collection.")

    def _ensure_sorted(self):
        """(Internal) Rebuild the sorted cache if a mutation invalidated it."""
        if self._needs_sort:
            self._sorted_cache = sorted(
                self._streams.values(), key=self._sort_key, reverse=self._sort_reverse
            )
            self._needs_sort = False

    def __iter__(self):
        """Iterate over streams in the current sort order."""
        self._ensure_sorted()
        return iter(self._sorted_cache)

    def __add__(self, other):
        """Return a new collection holding the streams of both operands.

        Streams are re-added by ``stream.name`` (auto-renamed on clashes). The
        result uses the default sort order, not either operand's sort key.
        """
        if not isinstance(other, StreamCollection):
            return NotImplemented
        combined = StreamCollection()
        for stream in self._streams.values():
            combined.add(stream)
        for stream in other._streams.values():
            combined.add(stream)
        return combined

    def __len__(self):
        return len(self._streams)

    def __getitem__(self, key):
        """Look up a stream by name (``str``) or by position (``int``).

        NOTE(review): integer positions follow *insertion* order, whereas
        iteration and ``get_index`` use *sorted* order, so
        ``self[self.get_index(s)]`` is not necessarily ``s``. Behavior kept
        as-is because callers may rely on it.
        """
        if isinstance(key, int):
            return list(self._streams.values())[key]
        if isinstance(key, str):
            return self._streams[key]
        raise TypeError(f"Invalid key type {type(key)}. Must be str (name) or int (index).")

    def __contains__(self, stream_name: str):
        """Return True if a stream is stored under ``stream_name``."""
        return stream_name in self._streams

    def __repr__(self):
        return f"StreamCollection({list(self._streams.keys())})"

    def __eq__(self, other):
        """Two collections are equal when their key -> stream mappings are equal.

        Defining ``__eq__`` without ``__hash__`` makes instances unhashable,
        which suits this mutable container.
        """
        if not isinstance(other, StreamCollection):
            return NotImplemented
        return self._streams == other._streams