Coverage for src\cognitivefactory\interactive_clustering\sampling\clusters_based.py: 100.00%

57 statements  

coverage.py v7.3.2, created at 2023-11-17 13:31 +0100

# -*- coding: utf-8 -*-

"""
* Name: cognitivefactory.interactive_clustering.sampling.clusters_based
* Description: Implementation of constraints sampling based on clusters information.
* Author: Erwan SCHILD
* Created: 04/10/2021
* Licence: CeCILL (https://cecill.info/licences.fr.html)
"""

# ==============================================================================
# IMPORT PYTHON DEPENDENCIES
# ==============================================================================

import random
from typing import Dict, List, Optional, Tuple

from scipy.sparse import csr_matrix, vstack
from sklearn.metrics import pairwise_distances

from cognitivefactory.interactive_clustering.constraints.abstract import AbstractConstraintsManager
from cognitivefactory.interactive_clustering.sampling.abstract import AbstractConstraintsSampling


# ==============================================================================
# CLUSTERS BASED CONSTRAINTS SAMPLING
# ==============================================================================
class ClustersBasedConstraintsSampling(AbstractConstraintsSampling):
    """
    This class implements the sampling of pairs of data IDs based on clusters information in order to annotate constraints.
    It inherits from `AbstractConstraintsSampling`.

    Example:
        ```python
        # Import.
        from cognitivefactory.interactive_clustering.constraints.binary import BinaryConstraintsManager
        from cognitivefactory.interactive_clustering.sampling.clusters_based import ClustersBasedConstraintsSampling

        # Create an instance of clusters based sampling (random selection by default).
        sampler = ClustersBasedConstraintsSampling(random_seed=1)

        # Define the list of data IDs.
        list_of_data_IDs = ["bonjour", "salut", "coucou", "au revoir", "a bientôt"]

        # Define the constraints manager.
        constraints_manager = BinaryConstraintsManager(
            list_of_data_IDs=list_of_data_IDs,
        )
        constraints_manager.add_constraint(data_ID1="bonjour", data_ID2="salut", constraint_type="MUST_LINK")
        constraints_manager.add_constraint(data_ID1="au revoir", data_ID2="a bientôt", constraint_type="MUST_LINK")

        # Run sampling.
        selection = sampler.sample(
            constraints_manager=constraints_manager,
            nb_to_select=3,
        )

        # Print results.
        print("Expected results", ":", [("au revoir", "bonjour"), ("bonjour", "coucou"), ("a bientôt", "coucou")])
        print("Computed results", ":", selection)
        ```
    """

    # ==============================================================================
    # INITIALIZATION
    # ==============================================================================
    def __init__(
        self,
        random_seed: Optional[int] = None,
        clusters_restriction: Optional[str] = None,
        distance_restriction: Optional[str] = None,
        without_added_constraints: bool = True,
        without_inferred_constraints: bool = True,
        **kargs,
    ) -> None:
        """
        The constructor for the Clusters Based Constraints Sampling class.

        Args:
            random_seed (Optional[int]): The random seed to use to redo the same sampling. Defaults to `None`.
            clusters_restriction (Optional[str]): Restrict the sampling to pairs that satisfy a clusters condition. Can force the two data IDs of a pair to be in the `"same_cluster"` or in `"different_clusters"`. Defaults to `None`.  # TODO: `"specific_clusters"`
            distance_restriction (Optional[str]): Restrict the sampling to pairs that satisfy a distance condition. Can force the two data IDs of a pair to be `"closest_neighbors"` or `"farthest_neighbors"`. Defaults to `None`.
            without_added_constraints (bool): Option to exclude pairs whose constraint has already been added. Defaults to `True`.
            without_inferred_constraints (bool): Option to exclude pairs whose constraint can be inferred from the already added ones. Defaults to `True`.
            **kargs (dict): Other parameters that can be used in the instantiation.

        Raises:
            ValueError: if some parameters are incorrectly set.
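
        Example:
            A minimal sketch of a restricted sampler (the parameter values below are arbitrary):
            ```python
            # Sample pairs of closest neighbors that belong to the same predicted cluster.
            sampler = ClustersBasedConstraintsSampling(
                random_seed=42,
                clusters_restriction="same_cluster",
                distance_restriction="closest_neighbors",
            )
            ```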

        """

        # Store `self.random_seed`.
        self.random_seed: Optional[int] = random_seed

        # Store clusters restriction.
        if clusters_restriction not in {None, "same_cluster", "different_clusters"}:
            raise ValueError("The `clusters_restriction` '" + str(clusters_restriction) + "' is not implemented.")
        self.clusters_restriction: Optional[str] = clusters_restriction

        # Store distance restriction.
        if distance_restriction not in {None, "closest_neighbors", "farthest_neighbors"}:
            raise ValueError("The `distance_restriction` '" + str(distance_restriction) + "' is not implemented.")
        self.distance_restriction: Optional[str] = distance_restriction

        # Store constraints restrictions.
        if not isinstance(without_added_constraints, bool):
            raise ValueError("The `without_added_constraints` parameter must be a boolean.")
        self.without_added_constraints: bool = without_added_constraints
        if not isinstance(without_inferred_constraints, bool):
            raise ValueError("The `without_inferred_constraints` parameter must be a boolean.")
        self.without_inferred_constraints: bool = without_inferred_constraints

    # ==============================================================================
    # MAIN - SAMPLE
    # ==============================================================================
    def sample(
        self,
        constraints_manager: AbstractConstraintsManager,
        nb_to_select: int,
        clustering_result: Optional[Dict[str, int]] = None,
        vectors: Optional[Dict[str, csr_matrix]] = None,
        **kargs,
    ) -> List[Tuple[str, str]]:
        """
        The main method used to sample pairs of data IDs for constraints annotation.

        Args:
            constraints_manager (AbstractConstraintsManager): A constraints manager over data IDs.
            nb_to_select (int): The number of pairs of data IDs to sample.
            clustering_result (Optional[Dict[str, int]], optional): A dictionary that represents the predicted cluster of each data ID. The keys of the dictionary represent the data IDs. If `None`, no clustering result is used during the sampling. Defaults to `None`.
            vectors (Optional[Dict[str, csr_matrix]], optional): The representation of data vectors. The keys of the dictionary represent the data IDs and have to refer to the list of data IDs managed by the `constraints_manager`. The values of the dictionary represent the vector of each data point. If `None`, no vectors are used during the sampling. Defaults to `None`.
            **kargs (dict): Other parameters that can be used in the sampling.

        Raises:
            ValueError: if some parameters are incorrectly set or incompatible.

        Returns:
            List[Tuple[str, str]]: A list of pairs of data IDs.
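
        Example:
            A minimal sketch of a restricted call, reusing the `constraints_manager` and `list_of_data_IDs`
            from the class example above (the clustering result and the vectors are arbitrary):
            ```python
            from scipy.sparse import csr_matrix

            sampler = ClustersBasedConstraintsSampling(
                random_seed=1,
                clusters_restriction="different_clusters",
                distance_restriction="closest_neighbors",
            )
            selection = sampler.sample(
                constraints_manager=constraints_manager,
                nb_to_select=2,
                clustering_result={"bonjour": 0, "salut": 0, "coucou": 0, "au revoir": 1, "a bientôt": 1},
                vectors={data_ID: csr_matrix([[float(i), 1.0]]) for i, data_ID in enumerate(list_of_data_IDs)},
            )
            ```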

        """

        ###
        ### GET PARAMETERS
        ###

        # Check `constraints_manager`.
        if not isinstance(constraints_manager, AbstractConstraintsManager):
            raise ValueError("The `constraints_manager` parameter has to be a `AbstractConstraintsManager` type.")
        self.constraints_manager: AbstractConstraintsManager = constraints_manager

        # Check `nb_to_select`.
        if not isinstance(nb_to_select, int) or (nb_to_select < 0):
            raise ValueError("The `nb_to_select` '" + str(nb_to_select) + "' must be greater than or equal to 0.")
        elif nb_to_select == 0:
            return []

        # If `self.clusters_restriction` is set, check the `clustering_result` parameter.
        if self.clusters_restriction is not None:
            if not isinstance(clustering_result, dict):
                raise ValueError("The `clustering_result` parameter has to be a `Dict[str, int]` type.")
            self.clustering_result: Dict[str, int] = clustering_result

        # If `self.distance_restriction` is set, check the `vectors` parameter.
        if self.distance_restriction is not None:
            if not isinstance(vectors, dict):
                raise ValueError("The `vectors` parameter has to be a `Dict[str, csr_matrix]` type.")
            self.vectors: Dict[str, csr_matrix] = vectors

        ###
        ### DEFINE POSSIBLE PAIRS OF DATA IDS
        ###

        # Initialize the list of possible pairs of data IDs.
        list_of_possible_pairs_of_data_IDs: List[Tuple[str, str]] = []

        # Loop over pairs of data IDs.
        for data_ID1 in self.constraints_manager.get_list_of_managed_data_IDs():
            for data_ID2 in self.constraints_manager.get_list_of_managed_data_IDs():
                # Keep only ordered pairs (`data_ID1 < data_ID2`) to avoid duplicates and self-pairs.
                if data_ID1 >= data_ID2:
                    continue

                # Check clusters restriction.
                if (
                    self.clusters_restriction == "same_cluster"
                    and self.clustering_result[data_ID1] != self.clustering_result[data_ID2]
                ) or (
                    self.clusters_restriction == "different_clusters"
                    and self.clustering_result[data_ID1] == self.clustering_result[data_ID2]
                ):
                    continue

                # Check known constraints.
                if (
                    self.without_added_constraints is True
                    and self.constraints_manager.get_added_constraint(data_ID1=data_ID1, data_ID2=data_ID2) is not None
                ) or (
                    self.without_inferred_constraints is True
                    and self.constraints_manager.get_inferred_constraint(data_ID1=data_ID1, data_ID2=data_ID2)
                    is not None
                ):
                    continue

                # Add the pair of data IDs.
                list_of_possible_pairs_of_data_IDs.append((data_ID1, data_ID2))

        ###
        ### SAMPLING
        ###

        # Precompute pairwise distances.
        if self.distance_restriction is not None:
            # Compute pairwise distances.
            matrix_of_pairwise_distances: csr_matrix = pairwise_distances(
                X=vstack(self.vectors[data_ID] for data_ID in self.constraints_manager.get_list_of_managed_data_IDs()),
                metric="euclidean",  # TODO: get different pairwise_distances config in **kargs
            )

            # Format pairwise distances in a dictionary.
            self.dict_of_pairwise_distances: Dict[str, Dict[str, float]] = {
                vector_ID1: {
                    vector_ID2: float(matrix_of_pairwise_distances[i1, i2])
                    for i2, vector_ID2 in enumerate(self.constraints_manager.get_list_of_managed_data_IDs())
                }
                for i1, vector_ID1 in enumerate(self.constraints_manager.get_list_of_managed_data_IDs())
            }
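            # NB: this materializes the full pairwise distance table as a nested dictionary,
            # i.e. O(N**2) entries for N managed data IDs.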

        # Set random seed.
        random.seed(self.random_seed)
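        # NB: the seed only affects the default random selection below; the distance-restricted
        # branches sort deterministically by precomputed distance.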

        # Case of closest neighbors selection.
        if self.distance_restriction == "closest_neighbors":
            return sorted(
                list_of_possible_pairs_of_data_IDs,
                key=lambda combination: self.dict_of_pairwise_distances[combination[0]][combination[1]],
            )[:nb_to_select]

        # Case of farthest neighbors selection.
        if self.distance_restriction == "farthest_neighbors":
            return sorted(
                list_of_possible_pairs_of_data_IDs,
                key=lambda combination: self.dict_of_pairwise_distances[combination[0]][combination[1]],
                reverse=True,
            )[:nb_to_select]

        # (default) Case of random selection.
        return random.sample(
            list_of_possible_pairs_of_data_IDs, k=min(nb_to_select, len(list_of_possible_pairs_of_data_IDs))
        )