Coverage for src\cognitivefactory\interactive_clustering\clustering\hierarchical.py: 100.00%

154 statements  

« prev     ^ index     » next       coverage.py v7.3.2, created at 2023-11-17 13:31 +0100

1# -*- coding: utf-8 -*- 

2 

3""" 

4* Name: cognitivefactory.interactive_clustering.clustering.hierarchical 

5* Description: Implementation of constrained hierarchical clustering algorithms. 

6* Author: Erwan SCHILD 

7* Created: 17/03/2021 

8* Licence: CeCILL-C License v1.0 (https://cecill.info/licences.fr.html) 

9""" 

10 

11# ============================================================================== 

12# IMPORT PYTHON DEPENDENCIES 

13# ============================================================================== 

14 

15from datetime import datetime 

16from typing import Any, Dict, List, Optional, Tuple 

17 

18from scipy.sparse import csr_matrix, vstack 

19from sklearn.metrics import pairwise_distances 

20 

21from cognitivefactory.interactive_clustering.clustering.abstract import ( 

22 AbstractConstrainedClustering, 

23 rename_clusters_by_order, 

24) 

25from cognitivefactory.interactive_clustering.constraints.abstract import AbstractConstraintsManager 

26 

27 

# ==============================================================================
# HIERARCHICAL CONSTRAINED CLUSTERING
# ==============================================================================
class HierarchicalConstrainedClustering(AbstractConstrainedClustering):
    """
    This class implements the hierarchical constrained clustering.
    It inherits from `AbstractConstrainedClustering`.

    References:
        - Hierarchical Clustering: `Murtagh, F. et P. Contreras (2012). Algorithms for hierarchical clustering : An overview. Wiley Interdisc. Rew.: Data Mining and Knowledge Discovery 2, 86–97.`
        - Constrained Hierarchical Clustering: `Davidson, I. et S. S. Ravi (2005). Agglomerative Hierarchical Clustering with Constraints : Theoretical and Empirical Results. Springer, Berlin, Heidelberg 3721, 12.`

    Example:
        ```python
        # Import.
        from scipy.sparse import csr_matrix
        from cognitivefactory.interactive_clustering.clustering.hierarchical import HierarchicalConstrainedClustering
        from cognitivefactory.interactive_clustering.constraints.binary import BinaryConstraintsManager

        # Create an instance of hierarchical clustering.
        clustering_model = HierarchicalConstrainedClustering(
            linkage="ward",
            random_seed=2,
        )

        # Define vectors.
        # NB : use cognitivefactory.interactive_clustering.utils to preprocess and vectorize texts.
        vectors = {
            "0": csr_matrix([1.00, 0.00, 0.00]),
            "1": csr_matrix([0.95, 0.02, 0.01]),
            "2": csr_matrix([0.98, 0.00, 0.00]),
            "3": csr_matrix([0.99, 0.00, 0.00]),
            "4": csr_matrix([0.01, 0.99, 0.07]),
            "5": csr_matrix([0.02, 0.99, 0.07]),
            "6": csr_matrix([0.01, 0.99, 0.02]),
            "7": csr_matrix([0.01, 0.01, 0.97]),
            "8": csr_matrix([0.00, 0.01, 0.99]),
            "9": csr_matrix([0.00, 0.00, 1.00]),
        }

        # Define constraints manager.
        constraints_manager = BinaryConstraintsManager(list_of_data_IDs=list(vectors.keys()))

        # Run clustering.
        dict_of_predicted_clusters = clustering_model.cluster(
            constraints_manager=constraints_manager,
            vectors=vectors,
            nb_clusters=3,
        )

        # Print results.
        print("Expected results", ":", {"0": 0, "1": 0, "2": 0, "3": 0, "4": 1, "5": 1, "6": 1, "7": 2, "8": 2, "9": 2,})
        print("Computed results", ":", dict_of_predicted_clusters)
        ```
    """

    # ==============================================================================
    # INITIALIZATION
    # ==============================================================================
    def __init__(self, linkage: str = "ward", random_seed: Optional[int] = None, **kargs) -> None:
        """
        The constructor for Hierarchical Constrained Clustering class.

        Args:
            linkage (str, optional): The metric used to merge clusters. Several types are implemented :
                - `"ward"`: Merge the two clusters for which the merged cluster from these clusters have the lowest intra-class distance.
                - `"average"`: Merge the two clusters that have the closest barycenters.
                - `"complete"`: Merge the two clusters for which the maximum distance between two data of these clusters is the lowest.
                - `"single"`: Merge the two clusters for which the minimum distance between two data of these clusters is the lowest.
                Defaults to `"ward"`.
            random_seed (Optional[int], optional): The random seed to use to redo the same clustering. Defaults to `None`.
            **kargs (dict): Other parameters that can be used in the instantiation.

        Raises:
            ValueError: if some parameters are incorrectly set.
        """

        # Store `self.linkage` after validating it against the implemented linkage strategies.
        if linkage not in {"ward", "average", "complete", "single"}:
            raise ValueError("The `linkage` '" + str(linkage) + "' is not implemented.")
        self.linkage: str = linkage

        # Store `self.random_seed`.
        # NOTE(review): `random_seed` is stored but not read anywhere in this file — presumably kept for API consistency with other clustering models; confirm before removing.
        self.random_seed: Optional[int] = random_seed

        # Store `self.kargs` for hierarchical clustering.
        self.kargs = kargs

        # Initialize `self.clustering_root` and `self.dict_of_predicted_clusters`; both are filled by `cluster()`.
        self.clustering_root: Optional[Cluster] = None
        self.dict_of_predicted_clusters: Optional[Dict[str, int]] = None

118 

    # ==============================================================================
    # MAIN - CLUSTER DATA
    # ==============================================================================
    def cluster(
        self,
        constraints_manager: AbstractConstraintsManager,
        vectors: Dict[str, csr_matrix],
        nb_clusters: Optional[int],
        verbose: bool = False,
        **kargs,
    ) -> Dict[str, int]:
        """
        The main method used to cluster data with the Hierarchical model.

        Args:
            constraints_manager (AbstractConstraintsManager): A constraints manager over data IDs that will force clustering to respect some conditions during computation.
            vectors (Dict[str, csr_matrix]): The representation of data vectors. The keys of the dictionary represent the data IDs. These keys have to refer to the list of data IDs managed by the `constraints_manager`. The values of the dictionary represent the vector of each data.
            nb_clusters (Optional[int]): The number of clusters to compute.
            verbose (bool, optional): Enable verbose output. Defaults to `False`.
            **kargs (dict): Other parameters that can be used in the clustering.

        Raises:
            ValueError: If some parameters are incorrectly set.

        Returns:
            Dict[str,int]: A dictionary that contains the predicted cluster for each data ID.
        """

        ###
        ### GET PARAMETERS
        ###

        # Store `self.constraints_manager` and `self.list_of_data_IDs`.
        if not isinstance(constraints_manager, AbstractConstraintsManager):
            raise ValueError("The `constraints_manager` parameter has to be a `AbstractConstraintsManager` type.")
        self.constraints_manager: AbstractConstraintsManager = constraints_manager
        self.list_of_data_IDs: List[str] = self.constraints_manager.get_list_of_managed_data_IDs()

        # Store `self.vectors`.
        if not isinstance(vectors, dict):
            raise ValueError("The `vectors` parameter has to be a `dict` type.")
        self.vectors: Dict[str, csr_matrix] = vectors

        # Store `self.nb_clusters`, capped by the number of managed data IDs (cannot have more clusters than data).
        if (nb_clusters is None) or (nb_clusters < 2):
            raise ValueError("The `nb_clusters` '" + str(nb_clusters) + "' must be greater than or equal to 2.")
        self.nb_clusters: int = min(nb_clusters, len(self.list_of_data_IDs))

        # Compute pairwise distances between all managed data vectors.
        # NOTE(review): `sklearn.metrics.pairwise_distances` returns a dense numpy array, not a `csr_matrix` — the annotation below looks inaccurate; confirm before changing it.
        matrix_of_pairwise_distances: csr_matrix = pairwise_distances(
            X=vstack(self.vectors[data_ID] for data_ID in self.constraints_manager.get_list_of_managed_data_IDs()),
            metric="euclidean",  # TODO get different pairwise_distances config in **kargs
        )

        # Format pairwise distances in a dictionary (data_ID -> data_ID -> distance) and store `self.dict_of_pairwise_distances`.
        self.dict_of_pairwise_distances: Dict[str, Dict[str, float]] = {
            vector_ID1: {
                vector_ID2: float(matrix_of_pairwise_distances[i1, i2])
                for i2, vector_ID2 in enumerate(self.constraints_manager.get_list_of_managed_data_IDs())
            }
            for i1, vector_ID1 in enumerate(self.constraints_manager.get_list_of_managed_data_IDs())
        }

        ###
        ### INITIALIZE HIERARCHICAL CONSTRAINED CLUSTERING
        ###

        # Verbose
        if verbose:  # pragma: no cover
            # Verbose - Print progression status.
            # `TIME_start` is only defined on the verbose path; later verbose blocks rely on it (safe because they are also guarded by `verbose`).
            TIME_start: datetime = datetime.now()
            print(
                " ",
                "CLUSTERING_ITERATION=" + "INITIALIZATION",
                "(current_time = " + str(TIME_start - TIME_start).split(".")[0] + ")",
            )

        # Initialize `self.clustering_root` and `self.dict_of_predicted_clusters`.
        self.clustering_root = None
        self.dict_of_predicted_clusters = None

        # Initialize iteration counter (0 = initialization of the leaf clusters).
        self.clustering_iteration: int = 0

        # Initialize `current_clusters` (IDs of clusters not yet merged) and `self.clusters_storage` (all clusters ever created).
        self.current_clusters: List[int] = []
        self.clusters_storage: Dict[int, Cluster] = {}

        # Get the list of possible lists of MUST_LINK data for initialization.
        list_of_possible_lists_of_MUST_LINK_data: List[List[str]] = self.constraints_manager.get_connected_components()

        # Estimation of max number of iterations: each merge reduces the cluster count by one.
        max_clustering_iteration: int = len(list_of_possible_lists_of_MUST_LINK_data) - 1

        # For each list of same data (MUST_LINK constraints).
        for MUST_LINK_data in list_of_possible_lists_of_MUST_LINK_data:
            # Create an initial leaf cluster with data that MUST be LINKed.
            self._add_new_cluster_by_setting_members(
                members=MUST_LINK_data,
            )

        # Initialize distance between every unordered pair of initial clusters.
        self.clusters_distance: Dict[int, Dict[int, float]] = {}
        for cluster_IDi in self.current_clusters:
            for cluster_IDj in self.current_clusters:
                if cluster_IDi < cluster_IDj:
                    # Compute distance between cluster i and cluster j.
                    distance: float = self._compute_distance(cluster_IDi=cluster_IDi, cluster_IDj=cluster_IDj)
                    # Store distance between cluster i and cluster j.
                    self._set_distance(cluster_IDi=cluster_IDi, cluster_IDj=cluster_IDj, distance=distance)

        # Initialize iterations at first iteration.
        self.clustering_iteration = 1

        ###
        ### RUN ITERATIONS OF HIERARCHICAL CONSTRAINED CLUSTERING UNTIL CONVERGENCE
        ###

        # Iterate until convergence of clustering (a single cluster remains, or no merge is allowed).
        while len(self.current_clusters) > 1:
            # Verbose
            if verbose:  # pragma: no cover
                # Verbose - Print progression status.
                TIME_current: datetime = datetime.now()
                print(
                    " ",
                    "CLUSTERING_ITERATION="
                    + str(self.clustering_iteration).zfill(6)
                    + "/"
                    + str(max_clustering_iteration).zfill(6),
                    "(current_time = " + str(TIME_current - TIME_start).split(".")[0] + ")",
                    end="\r",
                )

            # Get closest clusters to merge.
            clostest_clusters: Optional[Tuple[int, int]] = self._get_the_two_clostest_clusters()

            # If no clusters can be merged (all remaining pairs violate CANNOT_LINK constraints), then stop iterations.
            if clostest_clusters is None:
                break

            # Merge the two closest clusters and add the merged cluster to the storage.
            merged_cluster_ID: int = self._add_new_cluster_by_merging_clusters(
                children=[
                    clostest_clusters[0],
                    clostest_clusters[1],
                ]
            )

            # Update distances between the merged cluster and every other current cluster.
            for cluster_ID in self.current_clusters:
                if cluster_ID != merged_cluster_ID:
                    # Compute distance between cluster and merged cluster.
                    distance = self._compute_distance(cluster_IDi=cluster_ID, cluster_IDj=merged_cluster_ID)
                    # Store distance between cluster and merged cluster.
                    self._set_distance(cluster_IDi=cluster_ID, cluster_IDj=merged_cluster_ID, distance=distance)

            # Update self.clustering_iteration.
            self.clustering_iteration += 1

        ###
        ### END HIERARCHICAL CONSTRAINED CLUSTERING
        ###

        # Verbose
        if verbose:  # pragma: no cover
            # Verbose - Print progression status.
            TIME_current = datetime.now()

            # Case of clustering not completed (merging stopped because of constraints).
            if len(self.current_clusters) > 1:
                print(
                    " ",
                    "CLUSTERING_ITERATION=" + str(self.clustering_iteration).zfill(5),
                    "-",
                    "End : No more cluster to merge",
                    "(current_time = " + str(TIME_current - TIME_start).split(".")[0] + ")",
                )
            else:
                print(
                    " ",
                    "CLUSTERING_ITERATION=" + str(self.clustering_iteration).zfill(5),
                    "-",
                    "End : Full clustering done",
                    "(current_time = " + str(TIME_current - TIME_start).split(".")[0] + ")",
                )

        # If several clusters remain, then merge them in a cluster root so the tree always has a single root.
        if len(self.current_clusters) > 1:
            # Merge all remaining clusters (`.copy()` because the helper mutates `self.current_clusters` while iterating the argument).
            self._add_new_cluster_by_merging_clusters(children=self.current_clusters.copy())

        # Get clustering root (the only remaining current cluster).
        root_ID: int = self.current_clusters[0]
        self.clustering_root = self.clusters_storage[root_ID]

        ###
        ### GET PREDICTED CLUSTERS
        ###

        # Compute predicted clusters by cutting the tree into `self.nb_clusters` clusters.
        self.dict_of_predicted_clusters = self.compute_predicted_clusters(
            nb_clusters=self.nb_clusters,
        )

        return self.dict_of_predicted_clusters

329 

330 # ============================================================================== 

331 # ADD CLUSTER BY SETTING MEMBERS : 

332 # ============================================================================== 

333 def _add_new_cluster_by_setting_members( 

334 self, 

335 members: List[str], 

336 ) -> int: 

337 """ 

338 Create or Update a cluster by setting its members, and add it to the storage and current clusters. 

339 

340 Args: 

341 members (List[str]): A list of data IDs to define the new cluster by the data it contains. 

342 

343 Returns: 

344 int : ID of the merged cluster. 

345 """ 

346 # Get the ID of the new cluster. 

347 new_cluster_ID: int = max(self.clusters_storage.keys()) + 1 if (self.clusters_storage) else 0 

348 

349 # Create the cluster. 

350 new_cluster = Cluster( 

351 vectors=self.vectors, 

352 cluster_ID=new_cluster_ID, 

353 clustering_iteration=self.clustering_iteration, 

354 members=members, 

355 ) 

356 

357 # Add new_cluster to `self.current_clusters` and `self.clusters_storage`. 

358 self.current_clusters.append(new_cluster_ID) 

359 self.clusters_storage[new_cluster_ID] = new_cluster 

360 

361 return new_cluster_ID 

362 

363 # ============================================================================== 

364 # ADD CLUSTER BY MERGING CLUSTERS : 

365 # ============================================================================== 

366 def _add_new_cluster_by_merging_clusters( 

367 self, 

368 children: List[int], 

369 ) -> int: 

370 """ 

371 Create or Update a cluster by setting its children, and add it to the storage and current clusters. 

372 

373 Args: 

374 children (List[int]): A list of cluster IDs to define the new cluster by its children. 

375 

376 Returns: 

377 int : ID of the merged cluster. 

378 """ 

379 

380 # Remove all leaves children clusters from `self.current_clusters`. 

381 for child_ID_to_remove in children: 

382 self.current_clusters.remove(child_ID_to_remove) 

383 

384 """ 

385 ### 

386 ### Tree optimization : if only one node, then update this node as parent of all leaves. 

387 ### TODO : test of check if relevant to use. pros = smarter tree visualisation ; cons = cluster number more difficult to choose. 

388 ### 

389 

390 # List of children nodes. 

391 list_of_children_nodes: List[int] = [ 

392 child_ID 

393 for child_ID in children 

394 if len(self.clusters_storage[child_ID].children) > 0 

395 ] 

396 

397 if len(list_of_children_nodes) == 1: 

398 

399 # Get the ID of the cluster to update 

400 parent_cluster_ID: int = list_of_children_nodes[0] 

401 parent_cluster: Cluster = self.clusters_storage[parent_cluster_ID] 

402 

403 # Add all leaves 

404 parent_cluster.add_new_children( 

405 new_children=[ 

406 self.clusters_storage[child_ID] 

407 for child_ID in children 

408 if child_ID != parent_cluster_ID 

409 ], 

410 new_clustering_iteration=self.clustering_iteration 

411 ) 

412 

413 # Add new_cluster to `self.current_clusters` and `self.clusters_storage`. 

414 self.current_clusters.append(parent_cluster_ID) 

415 self.clusters_storage[parent_cluster_ID] = parent_cluster 

416 

417 # Return the cluster_ID of the created cluster. 

418 return parent_cluster_ID 

419 

420 

421 """ 

422 

423 ### 

424 ### Default case : Create a new node as parent of all children to merge. 

425 ### 

426 

427 # Get the ID of the new cluster. 

428 parent_cluster_ID: int = max(self.clusters_storage) + 1 

429 

430 # Create the cluster 

431 parent_cluster = Cluster( 

432 vectors=self.vectors, 

433 cluster_ID=parent_cluster_ID, 

434 clustering_iteration=self.clustering_iteration, 

435 children=[self.clusters_storage[child_ID] for child_ID in children], 

436 ) 

437 

438 # Add new_cluster to `self.current_clusters` and `self.clusters_storage`. 

439 self.current_clusters.append(parent_cluster_ID) 

440 self.clusters_storage[parent_cluster_ID] = parent_cluster 

441 

442 # Return the cluster_ID of the created cluster. 

443 return parent_cluster_ID 

444 

445 # ============================================================================== 

446 # COMPUTE DISTANCE BETWEEN CLUSTERING NEW ITERATION OF CLUSTERING : 

447 # ============================================================================== 

448 def _compute_distance(self, cluster_IDi: int, cluster_IDj: int) -> float: 

449 """ 

450 Compute distance between two clusters. 

451 

452 Args: 

453 cluster_IDi (int): ID of the first cluster. 

454 cluster_IDj (int): ID of the second cluster. 

455 

456 Returns: 

457 float : Distance between the two clusters. 

458 """ 

459 

460 # Check `"CANNOT_LINK"` constraints. 

461 for data_ID1 in self.clusters_storage[cluster_IDi].members: 

462 for data_ID2 in self.clusters_storage[cluster_IDj].members: 

463 if ( 

464 self.constraints_manager.get_inferred_constraint( 

465 data_ID1=data_ID1, 

466 data_ID2=data_ID2, 

467 ) 

468 == "CANNOT_LINK" 

469 ): 

470 return float("Inf") 

471 

472 # Case 1 : `self.linkage` is "complete". 

473 if self.linkage == "complete": 

474 return max( 

475 [ 

476 self.dict_of_pairwise_distances[data_ID_in_cluster_IDi][data_ID_in_cluster_IDj] 

477 for data_ID_in_cluster_IDi in self.clusters_storage[cluster_IDi].members 

478 for data_ID_in_cluster_IDj in self.clusters_storage[cluster_IDj].members 

479 ] 

480 ) 

481 

482 # Case 2 : `self.linkage` is "average". 

483 if self.linkage == "average": 

484 return pairwise_distances( 

485 X=self.clusters_storage[cluster_IDi].centroid, 

486 Y=self.clusters_storage[cluster_IDj].centroid, 

487 metric="euclidean", # TODO: Load different parameters for distance computing ? 

488 )[0][0] 

489 

490 # Case 3 : `self.linkage` is "single". 

491 if self.linkage == "single": 

492 return min( 

493 [ 

494 self.dict_of_pairwise_distances[data_ID_in_cluster_IDi][data_ID_in_cluster_IDj] 

495 for data_ID_in_cluster_IDi in self.clusters_storage[cluster_IDi].members 

496 for data_ID_in_cluster_IDj in self.clusters_storage[cluster_IDj].members 

497 ] 

498 ) 

499 

500 # Case 4 : `self.linkage` is "ward". 

501 ##if self.linkage == "ward": ## DEFAULTS 

502 # Compute distance 

503 merged_members: List[str] = ( 

504 self.clusters_storage[cluster_IDi].members + self.clusters_storage[cluster_IDj].members 

505 ) 

506 return sum( 

507 [ 

508 self.dict_of_pairwise_distances[data_IDi][data_IDj] 

509 for i, data_IDi in enumerate(merged_members) 

510 for j, data_IDj in enumerate(merged_members) 

511 if i < j 

512 ] 

513 ) / (len(self.clusters_storage[cluster_IDi].members) * len(self.clusters_storage[cluster_IDj].members)) 

514 

515 # ============================================================================== 

516 # DISTANCE : GETTER 

517 # ============================================================================== 

518 def _get_distance(self, cluster_IDi: int, cluster_IDj: int) -> float: 

519 """ 

520 Get the distance between two clusters. 

521 

522 Args: 

523 cluster_IDi (int): ID of the first cluster. 

524 cluster_IDj (int): ID of the second cluster. 

525 

526 Returns: 

527 float : Distance between the two clusters. 

528 """ 

529 

530 # Sort IDs of cluster. 

531 min_cluster_ID: int = min(cluster_IDi, cluster_IDj) 

532 max_cluster_ID: int = max(cluster_IDi, cluster_IDj) 

533 

534 # Return the distance. 

535 return self.clusters_distance[min_cluster_ID][max_cluster_ID] 

536 

537 # ============================================================================== 

538 # DISTANCE : SETTER 

539 # ============================================================================== 

540 def _set_distance( 

541 self, 

542 distance: float, 

543 cluster_IDi: int, 

544 cluster_IDj: int, 

545 ) -> None: 

546 """ 

547 Set the distance between two clusters. 

548 

549 Args: 

550 distance (float): The distance between the two clusters. 

551 cluster_IDi (int): ID of the first cluster. 

552 cluster_IDj (int): ID of the second cluster. 

553 """ 

554 

555 # Sort IDs of cluster. 

556 min_cluster_ID: int = min(cluster_IDi, cluster_IDj) 

557 max_cluster_ID: int = max(cluster_IDi, cluster_IDj) 

558 

559 # Add distance to the dictionary of distance. 

560 if min_cluster_ID not in self.clusters_distance: 

561 self.clusters_distance[min_cluster_ID] = {} 

562 self.clusters_distance[min_cluster_ID][max_cluster_ID] = distance 

563 

564 # ============================================================================== 

565 # GET THE TWO CLOSEST CLUSTERS 

566 # ============================================================================== 

567 def _get_the_two_clostest_clusters(self) -> Optional[Tuple[int, int]]: 

568 """ 

569 Get the two clusters which are the two closest clusters. 

570 

571 Returns: 

572 Optional(Tuple[int, int]) : The IDs of the two closest clusters to merge. Return None if no cluster is suitable. 

573 """ 

574 

575 # Compute the two clostest clusters to merge. take the closest distance, then the closest cluster size. 

576 clostest_clusters = min( 

577 [ 

578 { 

579 "cluster_ID1": cluster_ID1, 

580 "cluster_ID2": cluster_ID2, 

581 "distance": self._get_distance(cluster_IDi=cluster_ID1, cluster_IDj=cluster_ID2), 

582 "merged_size": len(self.clusters_storage[cluster_ID1].members) 

583 + len(self.clusters_storage[cluster_ID2].members) 

584 # TODO : Choose between "distance then size(count)" and "size_type(boolean) then distance" 

585 } 

586 for cluster_ID1 in self.current_clusters 

587 for cluster_ID2 in self.current_clusters 

588 if cluster_ID1 < cluster_ID2 

589 ], 

590 key=lambda dst: (dst["distance"], dst["merged_size"]), 

591 ) 

592 

593 # Get clusters and distance. 

594 cluster_ID1: int = int(clostest_clusters["cluster_ID1"]) 

595 cluster_ID2: int = int(clostest_clusters["cluster_ID2"]) 

596 distance: float = clostest_clusters["distance"] 

597 

598 # Check distance. 

599 if distance == float("Inf"): 

600 return None 

601 

602 # Return the tow closest clusters. 

603 return cluster_ID1, cluster_ID2 

604 

605 # ============================================================================== 

606 # COMPUTE PREDICTED CLUSTERS 

607 # ============================================================================== 

608 def compute_predicted_clusters(self, nb_clusters: int, by: str = "size") -> Dict[str, int]: 

609 """ 

610 Compute the predicted clusters based on clustering tree and estimation of number of clusters. 

611 

612 Args: 

613 nb_clusters (int): The number of clusters to compute. 

614 by (str, optional): A string to identifies the criteria used to explore `HierarchicalConstrainedClustering` tree. Can be `"size"` or `"iteration"`. Defaults to `"size"`. 

615 

616 Raises: 

617 ValueError: if `clustering_root` was not set. 

618 

619 Returns: 

620 Dict[str,int] : A dictionary that contains the predicted cluster for each data ID. 

621 """ 

622 

623 # Check that the clustering has been made. 

624 if self.clustering_root is None: 

625 raise ValueError("The `clustering_root` is not set, probably because clustering was not run.") 

626 

627 ### 

628 ### EXPLORE CLUSTER TREE 

629 ### 

630 

631 # Define the resulted list of children as the children of `HierarchicalConstrainedClustering` root. 

632 list_of_clusters: List[Cluster] = [self.clustering_root] 

633 

634 # Explore `HierarchicalConstrainedClustering` children until dict_of_predicted_clusters has the right number of children. 

635 while len(list_of_clusters) < nb_clusters: 

636 if by == "size": 

637 # Get the biggest cluster in current children from `HierarchicalConstrainedClustering` exploration. 

638 # i.e. it's the cluster that has the more data to split. 

639 cluster_to_split = max(list_of_clusters, key=lambda c: len(c.members)) 

640 

641 else: # if by == "iteration": 

642 # Get the most recent cluster in current children from `HierarchicalConstrainedClustering` exploration. 

643 # i.e. it's the cluster that was last merged. 

644 cluster_to_split = max(list_of_clusters, key=lambda c: c.clustering_iteration) 

645 

646 # If the chosen cluster is a leaf : break the `HierarchicalConstrainedClustering` exploration. 

647 if cluster_to_split.children == []: # noqa: WPS520 

648 break 

649 

650 # Otherwise: The chosen cluster is a node, so split it and get its children. 

651 else: 

652 # ... remove the cluster obtained ... 

653 list_of_clusters.remove(cluster_to_split) 

654 

655 # ... and add all its children. 

656 for child in cluster_to_split.children: 

657 list_of_clusters.append(child) 

658 

659 ### 

660 ### GET PREDICTED CLUSTERS 

661 ### 

662 

663 # Initialize the dictionary of predicted clusters. 

664 predicted_clusters: Dict[str, int] = {data_ID: -1 for data_ID in self.list_of_data_IDs} 

665 

666 # For all cluster... 

667 for cluster in list_of_clusters: 

668 # ... and for all member in each cluster... 

669 for data_ID in cluster.members: 

670 # ... affect the predicted cluster (cluster ID) to the data. 

671 predicted_clusters[data_ID] = cluster.cluster_ID 

672 

673 # Rename cluster IDs by order. 

674 predicted_clusters = rename_clusters_by_order(clusters=predicted_clusters) 

675 

676 # Return predicted clusters 

677 return predicted_clusters 

678 

679 

# ==============================================================================
# CLUSTER
# ==============================================================================
class Cluster:
    """
    This class represents a cluster as a node of the hierarchical clustering tree.
    """

    # ==============================================================================
    # INITIALIZATION
    # ==============================================================================
    def __init__(
        self,
        vectors: Dict[str, csr_matrix],
        cluster_ID: int,
        clustering_iteration: int,
        children: Optional[List["Cluster"]] = None,
        members: Optional[List[str]] = None,
    ) -> None:
        """
        The constructor for Cluster class.

        Args:
            vectors (Dict[str, csr_matrix]): The representation of data vectors. The keys of the dictionary represent the data IDs, the values the vector of each data.
            cluster_ID (int): The cluster ID that is defined during `HierarchicalConstrainedClustering.cluster` running.
            clustering_iteration (int): The cluster iteration that is defined during `HierarchicalConstrainedClustering.cluster` running.
            children (Optional[List["Cluster"]], optional): A list of clusters children for cluster initialization. Incompatible with `members` parameter. Defaults to `None`.
            members (Optional[List[str]], optional): A list of data IDs for cluster initialization. Incompatible with `children` parameter. Defaults to `None`.

        Raises:
            ValueError: if `children` and `members` are both set or both unset.
        """

        # Keep a reference to the shared data vectors.
        self.vectors: Dict[str, csr_matrix] = vectors

        # Identification of the cluster within the clustering run.
        self.cluster_ID: int = cluster_ID
        self.clustering_iteration: int = clustering_iteration

        # Exactly one of `children` / `members` must be provided.
        if (children is None) == (members is None):
            raise ValueError(
                "Cluster initialization must be by `children` setting or by `members` setting, but not by both or none of them."
            )

        # A node keeps its children; a leaf has an empty children list.
        self.children: List["Cluster"] = [] if children is None else children

        # Inverse depth: 0 for a leaf, 1 + the deepest child otherwise.
        if self.children:
            self.cluster_inverse_depth: int = 1 + max(child.cluster_inverse_depth for child in self.children)
        else:
            self.cluster_inverse_depth = 0

        # Members: given explicitly for a leaf, collected from the children for a node.
        if members is not None:
            self.members: List[str] = members
        else:
            self.members = [data_ID for child in self.children for data_ID in child.members]

        # Compute the centroid from the members.
        self.update_centroid()

    # ==============================================================================
    # ADD NEW CHILDREN :
    # ==============================================================================
    def add_new_children(
        self,
        new_children: List["Cluster"],
        new_clustering_iteration: int,
    ) -> None:
        """
        Add new children to the cluster and refresh its derived attributes.

        Args:
            new_children (List["Cluster"]): The list of new clusters children to add.
            new_clustering_iteration (int): The new cluster iteration that is defined during `HierarchicalConstrainedClustering.cluster` running.
        """

        # Update clustering iteration.
        self.clustering_iteration = new_clustering_iteration

        # Attach only the children that are not already attached.
        additions: List["Cluster"] = [child for child in new_children if child not in self.children]
        self.children.extend(additions)

        # Refresh inverse depth, members and centroid from the new children set.
        self.cluster_inverse_depth = 1 + max(child.cluster_inverse_depth for child in self.children)
        self.members = [data_ID for child in self.children for data_ID in child.members]
        self.update_centroid()

    # ==============================================================================
    # UPDATE CENTROIDS :
    # ==============================================================================
    def update_centroid(self) -> None:
        """
        Update `self.centroid` as the mean of the member vectors.
        """

        member_vectors = [self.vectors[data_ID] for data_ID in self.members]
        self.centroid: csr_matrix = sum(member_vectors) / self.get_cluster_size()

    # ==============================================================================
    # GET CLUSTER SIZE :
    # ==============================================================================
    def get_cluster_size(self) -> int:
        """
        Get cluster size.

        Returns:
            int: The cluster size, i.e. the number of members in the cluster.
        """

        return len(self.members)

    # ==============================================================================
    # TO DICTIONARY :
    # ==============================================================================
    def to_dict(self) -> Dict[str, Any]:
        """
        Transform the Cluster object into a dictionary (e.g. before serializing it in JSON). Children are converted recursively.

        Returns:
            Dict[str, Any]: A dictionary that represents the Cluster object.
        """

        return {
            "cluster_ID": self.cluster_ID,
            "clustering_iteration": self.clustering_iteration,
            "children": [child.to_dict() for child in self.children],
            "cluster_inverse_depth": self.cluster_inverse_depth,
            "members": self.members,
        }