Coverage for src\cognitivefactory\interactive_clustering\clustering\hierarchical.py: 100.00%
154 statements
« prev ^ index » next — coverage.py v7.3.2, created at 2023-11-17 13:31 +0100
1# -*- coding: utf-8 -*-
3"""
4* Name: cognitivefactory.interactive_clustering.clustering.hierarchical
5* Description: Implementation of constrained hierarchical clustering algorithms.
6* Author: Erwan SCHILD
7* Created: 17/03/2021
8* Licence: CeCILL-C License v1.0 (https://cecill.info/licences.fr.html)
9"""
11# ==============================================================================
12# IMPORT PYTHON DEPENDENCIES
13# ==============================================================================
15from datetime import datetime
16from typing import Any, Dict, List, Optional, Tuple
18from scipy.sparse import csr_matrix, vstack
19from sklearn.metrics import pairwise_distances
21from cognitivefactory.interactive_clustering.clustering.abstract import (
22 AbstractConstrainedClustering,
23 rename_clusters_by_order,
24)
25from cognitivefactory.interactive_clustering.constraints.abstract import AbstractConstraintsManager
28# ==============================================================================
29# HIERARCHICAL CONSTRAINED CLUSTERING
30# ==============================================================================
class HierarchicalConstrainedClustering(AbstractConstrainedClustering):
    """
    This class implements the hierarchical constrained clustering.
    It inherits from `AbstractConstrainedClustering`.

    The clustering is agglomerative: data IDs linked by `"MUST_LINK"` constraints are first
    merged into initial clusters, then at each iteration the two closest mergeable clusters
    are merged; a `"CANNOT_LINK"` constraint between two clusters forbids their merge by
    forcing their distance to infinity.

    References:
        - Hierarchical Clustering: `Murtagh, F. et P. Contreras (2012). Algorithms for hierarchical clustering : An overview. Wiley Interdisc. Rew.: Data Mining and Knowledge Discovery 2, 86–97.`
        - Constrained Hierarchical Clustering: `Davidson, I. et S. S. Ravi (2005). Agglomerative Hierarchical Clustering with Constraints : Theoretical and Empirical Results. Springer, Berlin, Heidelberg 3721, 12.`

    Example:
        ```python
        # Import.
        from scipy.sparse import csr_matrix
        from cognitivefactory.interactive_clustering.clustering.hierarchical import HierarchicalConstrainedClustering

        # Create an instance of hierarchical clustering.
        clustering_model = HierarchicalConstrainedClustering(
            linkage="ward",
            random_seed=2,
        )

        # Define vectors.
        # NB : use cognitivefactory.interactive_clustering.utils to preprocess and vectorize texts.
        vectors = {
            "0": csr_matrix([1.00, 0.00, 0.00]),
            "1": csr_matrix([0.95, 0.02, 0.01]),
            "2": csr_matrix([0.98, 0.00, 0.00]),
            "3": csr_matrix([0.99, 0.00, 0.00]),
            "4": csr_matrix([0.01, 0.99, 0.07]),
            "5": csr_matrix([0.02, 0.99, 0.07]),
            "6": csr_matrix([0.01, 0.99, 0.02]),
            "7": csr_matrix([0.01, 0.01, 0.97]),
            "8": csr_matrix([0.00, 0.01, 0.99]),
            "9": csr_matrix([0.00, 0.00, 1.00]),
        }

        # Define constraints manager.
        constraints_manager = BinaryConstraintsManager(list_of_data_IDs=list(vectors.keys()))

        # Run clustering.
        dict_of_predicted_clusters = clustering_model.cluster(
            constraints_manager=constraints_manager,
            vectors=vectors,
            nb_clusters=3,
        )

        # Print results.
        print("Expected results", ";", {"0": 0, "1": 0, "2": 0, "3": 0, "4": 1, "5": 1, "6": 1, "7": 2, "8": 2, "9": 2,})
        print("Computed results", ":", dict_of_predicted_clusters)
        ```
    """

    # ==============================================================================
    # INITIALIZATION
    # ==============================================================================
    def __init__(self, linkage: str = "ward", random_seed: Optional[int] = None, **kargs) -> None:
        """
        The constructor for Hierarchical Constrainted Clustering class.

        Args:
            linkage (str, optional): The metric used to merge clusters. Several type are implemented :
                - `"ward"`: Merge the two clusters for which the merged cluster from these clusters have the lowest intra-class distance.
                - `"average"`: Merge the two clusters that have the closest barycenters.
                - `"complete"`: Merge the two clusters for which the maximum distance between two data of these clusters is the lowest.
                - `"single"`: Merge the two clusters for which the minimum distance between two data of these clusters is the lowest.
                Defaults to `"ward"`.
            random_seed (Optional[int], optional): The random seed to use to redo the same clustering. Defaults to `None`.
            **kargs (dict): Other parameters that can be used in the instantiation.

        Raises:
            ValueError: if some parameters are incorrectly set.
        """

        # Store `self.linkage`. Only the four supported linkage strategies are accepted.
        if linkage not in {"ward", "average", "complete", "single"}:
            raise ValueError("The `linkage` '" + str(linkage) + "' is not implemented.")
        self.linkage: str = linkage

        # Store `self.random_seed`.
        # NOTE(review): `random_seed` is stored but not read anywhere in this class — the
        # agglomerative procedure below is deterministic; kept for interface consistency.
        self.random_seed: Optional[int] = random_seed

        # Store `self.kargs` for hierarchical clustering.
        self.kargs = kargs

        # Initialize `self.clustering_root` and `self.dict_of_predicted_clusters`.
        # They are only set after a successful `cluster()` run.
        self.clustering_root: Optional[Cluster] = None
        self.dict_of_predicted_clusters: Optional[Dict[str, int]] = None

    # ==============================================================================
    # MAIN - CLUSTER DATA
    # ==============================================================================
    def cluster(
        self,
        constraints_manager: AbstractConstraintsManager,
        vectors: Dict[str, csr_matrix],
        nb_clusters: Optional[int],
        verbose: bool = False,
        **kargs,
    ) -> Dict[str, int]:
        """
        The main method used to cluster data with the Hierarchical model.

        Args:
            constraints_manager (AbstractConstraintsManager): A constraints manager over data IDs that will force clustering to respect some conditions during computation.
            vectors (Dict[str, csr_matrix]): The representation of data vectors. The keys of the dictionary represents the data IDs. This keys have to refer to the list of data IDs managed by the `constraints_manager`. The value of the dictionary represent the vector of each data.
            nb_clusters (Optional[int]): The number of clusters to compute.
            verbose (bool, optional): Enable verbose output. Defaults to `False`.
            **kargs (dict): Other parameters that can be used in the clustering.

        Raises:
            ValueError: If some parameters are incorrectly set.

        Returns:
            Dict[str,int]: A dictionary that contains the predicted cluster for each data ID.
        """

        ###
        ### GET PARAMETERS
        ###

        # Store `self.constraints_manager` and `self.list_of_data_IDs`.
        if not isinstance(constraints_manager, AbstractConstraintsManager):
            raise ValueError("The `constraints_manager` parameter has to be a `AbstractConstraintsManager` type.")
        self.constraints_manager: AbstractConstraintsManager = constraints_manager
        self.list_of_data_IDs: List[str] = self.constraints_manager.get_list_of_managed_data_IDs()

        # Store `self.vectors`.
        if not isinstance(vectors, dict):
            raise ValueError("The `vectors` parameter has to be a `dict` type.")
        self.vectors: Dict[str, csr_matrix] = vectors

        # Store `self.nb_clusters` (capped by the number of managed data IDs).
        if (nb_clusters is None) or (nb_clusters < 2):
            raise ValueError("The `nb_clusters` '" + str(nb_clusters) + "' must be greater than or equal to 2.")
        self.nb_clusters: int = min(nb_clusters, len(self.list_of_data_IDs))

        # Compute pairwise distances on the stacked vectors, in managed-data-IDs order.
        matrix_of_pairwise_distances: csr_matrix = pairwise_distances(
            X=vstack(self.vectors[data_ID] for data_ID in self.constraints_manager.get_list_of_managed_data_IDs()),
            metric="euclidean",  # TODO get different pairwise_distances config in **kargs
        )

        # Format pairwise distances in a dictionary and store `self.dict_of_pairwise_distances`.
        # NOTE(review): this materializes an O(n^2) dict of floats keyed by data IDs.
        self.dict_of_pairwise_distances: Dict[str, Dict[str, float]] = {
            vector_ID1: {
                vector_ID2: float(matrix_of_pairwise_distances[i1, i2])
                for i2, vector_ID2 in enumerate(self.constraints_manager.get_list_of_managed_data_IDs())
            }
            for i1, vector_ID1 in enumerate(self.constraints_manager.get_list_of_managed_data_IDs())
        }

        ###
        ### INITIALIZE HIERARCHICAL CONSTRAINED CLUSTERING
        ###

        # Verbose
        if verbose:  # pragma: no cover
            # Verbose - Print progression status.
            TIME_start: datetime = datetime.now()
            print(
                "    ",
                "CLUSTERING_ITERATION=" + "INITIALIZATION",
                "(current_time = " + str(TIME_start - TIME_start).split(".")[0] + ")",
            )

        # Initialize `self.clustering_root` and `self.dict_of_predicted_clusters`
        # (reset in case this model instance was already used for a previous run).
        self.clustering_root = None
        self.dict_of_predicted_clusters = None

        # Initialize iteration counter. Iteration 0 is the initialization step.
        self.clustering_iteration: int = 0

        # Initialize `current_clusters` (IDs of clusters still mergeable) and
        # `self.clusters_storage` (all clusters ever created, keyed by ID).
        self.current_clusters: List[int] = []
        self.clusters_storage: Dict[int, Cluster] = {}

        # Get the list of possible lists of MUST_LINK data for initialization.
        list_of_possible_lists_of_MUST_LINK_data: List[List[str]] = self.constraints_manager.get_connected_components()

        # Estimation of max number of iterations (each merge reduces the cluster count by one).
        max_clustering_iteration: int = len(list_of_possible_lists_of_MUST_LINK_data) - 1

        # For each list of same data (MUST_LINK constraints)...
        for MUST_LINK_data in list_of_possible_lists_of_MUST_LINK_data:
            # ... create an initial cluster with data that MUST be LINKed.
            self._add_new_cluster_by_setting_members(
                members=MUST_LINK_data,
            )

        # Initialize pairwise distance between all initial clusters.
        self.clusters_distance: Dict[int, Dict[int, float]] = {}
        for cluster_IDi in self.current_clusters:
            for cluster_IDj in self.current_clusters:
                if cluster_IDi < cluster_IDj:
                    # Compute distance between cluster i and cluster j.
                    distance: float = self._compute_distance(cluster_IDi=cluster_IDi, cluster_IDj=cluster_IDj)
                    # Store distance between cluster i and cluster j.
                    self._set_distance(cluster_IDi=cluster_IDi, cluster_IDj=cluster_IDj, distance=distance)

        # Initialize iterations at first iteration.
        self.clustering_iteration = 1

        ###
        ### RUN ITERATIONS OF HIERARCHICAL CONSTRAINED CLUSTERING UNTIL CONVERGENCE
        ###

        # Iterate until convergence of clustering (one cluster left, or no mergeable pair).
        while len(self.current_clusters) > 1:
            # Verbose
            if verbose:  # pragma: no cover
                # Verbose - Print progression status.
                TIME_current: datetime = datetime.now()
                print(
                    "    ",
                    "CLUSTERING_ITERATION="
                    + str(self.clustering_iteration).zfill(6)
                    + "/"
                    + str(max_clustering_iteration).zfill(6),
                    "(current_time = " + str(TIME_current - TIME_start).split(".")[0] + ")",
                    end="\r",
                )

            # Get closest clusters to merge.
            # NOTE(review): "clostest" is a pre-existing typo in this private method name; kept for compatibility.
            clostest_clusters: Optional[Tuple[int, int]] = self._get_the_two_clostest_clusters()

            # If no clusters can be merged (all remaining pairs are CANNOT_LINKed), then stop iterations.
            if clostest_clusters is None:
                break

            # Merge the two closest clusters and add the merged cluster to the storage.
            # If merge one cluster "node" with a cluster "leaf" : add the cluster "leaf" to the children of the cluster "node".
            # If merge two clusters "nodes" or two clusters "leaves" : create a new cluster "node".
            merged_cluster_ID: int = self._add_new_cluster_by_merging_clusters(
                children=[
                    clostest_clusters[0],
                    clostest_clusters[1],
                ]
            )

            # Update distances between the merged cluster and every other current cluster.
            for cluster_ID in self.current_clusters:
                if cluster_ID != merged_cluster_ID:
                    # Compute distance between cluster and merged cluster.
                    distance = self._compute_distance(cluster_IDi=cluster_ID, cluster_IDj=merged_cluster_ID)
                    # Store distance between cluster and merged cluster.
                    self._set_distance(cluster_IDi=cluster_ID, cluster_IDj=merged_cluster_ID, distance=distance)

            # Update self.clustering_iteration.
            self.clustering_iteration += 1

        ###
        ### END HIERARCHICAL CONSTRAINED CLUSTERING
        ###

        # Verbose
        if verbose:  # pragma: no cover
            # Verbose - Print progression status.
            TIME_current = datetime.now()

            # Case of clustering not completed.
            if len(self.current_clusters) > 1:
                print(
                    "    ",
                    "CLUSTERING_ITERATION=" + str(self.clustering_iteration).zfill(5),
                    "-",
                    "End : No more cluster to merge",
                    "(current_time = " + str(TIME_current - TIME_start).split(".")[0] + ")",
                )
            else:
                print(
                    "    ",
                    "CLUSTERING_ITERATION=" + str(self.clustering_iteration).zfill(5),
                    "-",
                    "End : Full clustering done",
                    "(current_time = " + str(TIME_current - TIME_start).split(".")[0] + ")",
                )

        # If several clusters remain (iterations were stopped early), then merge them in a cluster root.
        if len(self.current_clusters) > 1:
            # Merge all remaining clusters.
            # If merge one cluster "node" with many cluster "leaves" : add clusters "leaves" to the children of the cluster "node".
            # If merge many clusters "nodes" and/or many clusters "leaves" : create a new cluster "node".
            self._add_new_cluster_by_merging_clusters(children=self.current_clusters.copy())

        # Get clustering root (the single remaining current cluster).
        root_ID: int = self.current_clusters[0]
        self.clustering_root = self.clusters_storage[root_ID]

        ###
        ### GET PREDICTED CLUSTERS
        ###

        # Compute predicted clusters by cutting the tree at `self.nb_clusters`.
        self.dict_of_predicted_clusters = self.compute_predicted_clusters(
            nb_clusters=self.nb_clusters,
        )

        return self.dict_of_predicted_clusters

    # ==============================================================================
    # ADD CLUSTER BY SETTING MEMBERS :
    # ==============================================================================
    def _add_new_cluster_by_setting_members(
        self,
        members: List[str],
    ) -> int:
        """
        Create a cluster leaf by setting its members, and add it to the storage and current clusters.

        Args:
            members (List[str]): A list of data IDs to define the new cluster by the data it contains.

        Returns:
            int : ID of the created cluster.
        """

        # Get the ID of the new cluster (next free integer, or 0 when storage is empty).
        new_cluster_ID: int = max(self.clusters_storage.keys()) + 1 if (self.clusters_storage) else 0

        # Create the cluster.
        new_cluster = Cluster(
            vectors=self.vectors,
            cluster_ID=new_cluster_ID,
            clustering_iteration=self.clustering_iteration,
            members=members,
        )

        # Add new_cluster to `self.current_clusters` and `self.clusters_storage`.
        self.current_clusters.append(new_cluster_ID)
        self.clusters_storage[new_cluster_ID] = new_cluster

        return new_cluster_ID

    # ==============================================================================
    # ADD CLUSTER BY MERGING CLUSTERS :
    # ==============================================================================
    def _add_new_cluster_by_merging_clusters(
        self,
        children: List[int],
    ) -> int:
        """
        Create a cluster node by setting its children, and add it to the storage and current clusters.

        Args:
            children (List[int]): A list of cluster IDs to define the new cluster by its children.

        Returns:
            int : ID of the merged cluster.
        """

        # Remove all children clusters from `self.current_clusters` (they are no longer mergeable on their own).
        for child_ID_to_remove in children:
            self.current_clusters.remove(child_ID_to_remove)

        # NOTE(review): the string below is deliberately dead code kept by the author as a
        # possible future tree optimization (see its inner TODO); it is a no-op at runtime.
        """
        ###
        ### Tree optimization : if only one node, then update this node as parent of all leaves.
        ### TODO : test of check if relevant to use. pros = smarter tree visualisation ; cons = cluster number more difficult to choose.
        ###

        # List of children nodes.
        list_of_children_nodes: List[int] = [
            child_ID
            for child_ID in children
            if len(self.clusters_storage[child_ID].children) > 0
        ]

        if len(list_of_children_nodes) == 1:

            # Get the ID of the cluster to update
            parent_cluster_ID: int = list_of_children_nodes[0]
            parent_cluster: Cluster = self.clusters_storage[parent_cluster_ID]

            # Add all leaves
            parent_cluster.add_new_children(
                new_children=[
                    self.clusters_storage[child_ID]
                    for child_ID in children
                    if child_ID != parent_cluster_ID
                ],
                new_clustering_iteration=self.clustering_iteration
            )

            # Add new_cluster to `self.current_clusters` and `self.clusters_storage`.
            self.current_clusters.append(parent_cluster_ID)
            self.clusters_storage[parent_cluster_ID] = parent_cluster

            # Return the cluster_ID of the created cluster.
            return parent_cluster_ID
        """

        ###
        ### Default case : Create a new node as parent of all children to merge.
        ###

        # Get the ID of the new cluster.
        parent_cluster_ID: int = max(self.clusters_storage) + 1

        # Create the cluster.
        parent_cluster = Cluster(
            vectors=self.vectors,
            cluster_ID=parent_cluster_ID,
            clustering_iteration=self.clustering_iteration,
            children=[self.clusters_storage[child_ID] for child_ID in children],
        )

        # Add new_cluster to `self.current_clusters` and `self.clusters_storage`.
        self.current_clusters.append(parent_cluster_ID)
        self.clusters_storage[parent_cluster_ID] = parent_cluster

        # Return the cluster_ID of the created cluster.
        return parent_cluster_ID

    # ==============================================================================
    # COMPUTE DISTANCE BETWEEN CLUSTERS :
    # ==============================================================================
    def _compute_distance(self, cluster_IDi: int, cluster_IDj: int) -> float:
        """
        Compute distance between two clusters according to `self.linkage`.
        Returns infinity when any pair of data across the two clusters is CANNOT_LINKed,
        which forbids their merge.

        Args:
            cluster_IDi (int): ID of the first cluster.
            cluster_IDj (int): ID of the second cluster.

        Returns:
            float : Distance between the two clusters.
        """

        # Check `"CANNOT_LINK"` constraints: any forbidden pair makes the merge impossible.
        for data_ID1 in self.clusters_storage[cluster_IDi].members:
            for data_ID2 in self.clusters_storage[cluster_IDj].members:
                if (
                    self.constraints_manager.get_inferred_constraint(
                        data_ID1=data_ID1,
                        data_ID2=data_ID2,
                    )
                    == "CANNOT_LINK"
                ):
                    return float("Inf")

        # Case 1 : `self.linkage` is "complete" — maximum pairwise distance.
        if self.linkage == "complete":
            return max(
                [
                    self.dict_of_pairwise_distances[data_ID_in_cluster_IDi][data_ID_in_cluster_IDj]
                    for data_ID_in_cluster_IDi in self.clusters_storage[cluster_IDi].members
                    for data_ID_in_cluster_IDj in self.clusters_storage[cluster_IDj].members
                ]
            )

        # Case 2 : `self.linkage` is "average" — euclidean distance between centroids.
        if self.linkage == "average":
            return pairwise_distances(
                X=self.clusters_storage[cluster_IDi].centroid,
                Y=self.clusters_storage[cluster_IDj].centroid,
                metric="euclidean",  # TODO: Load different parameters for distance computing ?
            )[0][0]

        # Case 3 : `self.linkage` is "single" — minimum pairwise distance.
        if self.linkage == "single":
            return min(
                [
                    self.dict_of_pairwise_distances[data_ID_in_cluster_IDi][data_ID_in_cluster_IDj]
                    for data_ID_in_cluster_IDi in self.clusters_storage[cluster_IDi].members
                    for data_ID_in_cluster_IDj in self.clusters_storage[cluster_IDj].members
                ]
            )

        # Case 4 : `self.linkage` is "ward" (default case).
        ##if self.linkage == "ward": ## DEFAULTS
        # Compute distance: sum of all unordered pairwise distances within the merged cluster,
        # normalized by the product of the two cluster sizes.
        merged_members: List[str] = (
            self.clusters_storage[cluster_IDi].members + self.clusters_storage[cluster_IDj].members
        )
        return sum(
            [
                self.dict_of_pairwise_distances[data_IDi][data_IDj]
                for i, data_IDi in enumerate(merged_members)
                for j, data_IDj in enumerate(merged_members)
                if i < j
            ]
        ) / (len(self.clusters_storage[cluster_IDi].members) * len(self.clusters_storage[cluster_IDj].members))

    # ==============================================================================
    # DISTANCE : GETTER
    # ==============================================================================
    def _get_distance(self, cluster_IDi: int, cluster_IDj: int) -> float:
        """
        Get the stored distance between two clusters (symmetric: argument order does not matter).

        Args:
            cluster_IDi (int): ID of the first cluster.
            cluster_IDj (int): ID of the second cluster.

        Returns:
            float : Distance between the two clusters.
        """

        # Sort IDs of cluster: distances are stored under (min_ID, max_ID).
        min_cluster_ID: int = min(cluster_IDi, cluster_IDj)
        max_cluster_ID: int = max(cluster_IDi, cluster_IDj)

        # Return the distance.
        return self.clusters_distance[min_cluster_ID][max_cluster_ID]

    # ==============================================================================
    # DISTANCE : SETTER
    # ==============================================================================
    def _set_distance(
        self,
        distance: float,
        cluster_IDi: int,
        cluster_IDj: int,
    ) -> None:
        """
        Set the distance between two clusters (symmetric: argument order does not matter).

        Args:
            distance (float): The distance between the two clusters.
            cluster_IDi (int): ID of the first cluster.
            cluster_IDj (int): ID of the second cluster.
        """

        # Sort IDs of cluster: distances are stored under (min_ID, max_ID).
        min_cluster_ID: int = min(cluster_IDi, cluster_IDj)
        max_cluster_ID: int = max(cluster_IDi, cluster_IDj)

        # Add distance to the dictionary of distance.
        if min_cluster_ID not in self.clusters_distance:
            self.clusters_distance[min_cluster_ID] = {}
        self.clusters_distance[min_cluster_ID][max_cluster_ID] = distance

    # ==============================================================================
    # GET THE TWO CLOSEST CLUSTERS
    # ==============================================================================
    def _get_the_two_clostest_clusters(self) -> Optional[Tuple[int, int]]:
        """
        Get the two clusters which are the two closest clusters.
        Ties on distance are broken by the smallest merged size.

        NOTE(review): "clostest" is a pre-existing typo in the method name, kept for compatibility.

        Returns:
            Optional(Tuple[int, int]) : The IDs of the two closest clusters to merge. Return None if no cluster is suitable.
        """

        # Compute the two closest clusters to merge: take the closest distance, then the closest cluster size.
        clostest_clusters: Dict[str, Any] = min(
            [
                {
                    "cluster_ID1": cluster_ID1,
                    "cluster_ID2": cluster_ID2,
                    "distance": self._get_distance(cluster_IDi=cluster_ID1, cluster_IDj=cluster_ID2),
                    "merged_size": len(self.clusters_storage[cluster_ID1].members)
                    + len(self.clusters_storage[cluster_ID2].members)
                    # TODO : Choose between "distance then size(count)" and "size_type(boolean) then distance"
                }
                for cluster_ID1 in self.current_clusters
                for cluster_ID2 in self.current_clusters
                if cluster_ID1 < cluster_ID2
            ],
            key=lambda dst: (dst["distance"], dst["merged_size"]),
        )

        # Get clusters and distance.
        cluster_ID1: int = int(clostest_clusters["cluster_ID1"])
        cluster_ID2: int = int(clostest_clusters["cluster_ID2"])
        distance: float = clostest_clusters["distance"]

        # Check distance: infinity means every remaining pair is CANNOT_LINKed.
        if distance == float("Inf"):
            return None

        # Return the two closest clusters.
        return cluster_ID1, cluster_ID2

    # ==============================================================================
    # COMPUTE PREDICTED CLUSTERS
    # ==============================================================================
    def compute_predicted_clusters(self, nb_clusters: int, by: str = "size") -> Dict[str, int]:
        """
        Compute the predicted clusters based on clustering tree and estimation of number of clusters.

        Args:
            nb_clusters (int): The number of clusters to compute.
            by (str, optional): A string to identifies the criteria used to explore `HierarchicalConstrainedClustering` tree. Can be `"size"` or `"iteration"`. Any other value behaves like `"iteration"`. Defaults to `"size"`.

        Raises:
            ValueError: if `clustering_root` was not set.

        Returns:
            Dict[str,int] : A dictionary that contains the predicted cluster for each data ID.
        """

        # Check that the clustering has been made.
        if self.clustering_root is None:
            raise ValueError("The `clustering_root` is not set, probably because clustering was not run.")

        ###
        ### EXPLORE CLUSTER TREE
        ###

        # Define the resulted list of clusters as the `HierarchicalConstrainedClustering` root.
        list_of_clusters: List[Cluster] = [self.clustering_root]

        # Explore `HierarchicalConstrainedClustering` children until dict_of_predicted_clusters has the right number of children.
        while len(list_of_clusters) < nb_clusters:
            if by == "size":
                # Get the biggest cluster in current children from `HierarchicalConstrainedClustering` exploration.
                # i.e. it's the cluster that has the more data to split.
                cluster_to_split = max(list_of_clusters, key=lambda c: len(c.members))

            else:  # if by == "iteration":
                # Get the most recent cluster in current children from `HierarchicalConstrainedClustering` exploration.
                # i.e. it's the cluster that was last merged.
                cluster_to_split = max(list_of_clusters, key=lambda c: c.clustering_iteration)

            # If the chosen cluster is a leaf : break the `HierarchicalConstrainedClustering` exploration.
            if cluster_to_split.children == []:  # noqa: WPS520
                break

            # Otherwise: The chosen cluster is a node, so split it and get its children.
            else:
                # ... remove the cluster obtained ...
                list_of_clusters.remove(cluster_to_split)

                # ... and add all its children.
                for child in cluster_to_split.children:
                    list_of_clusters.append(child)

        ###
        ### GET PREDICTED CLUSTERS
        ###

        # Initialize the dictionary of predicted clusters (-1 marks "not assigned yet").
        predicted_clusters: Dict[str, int] = {data_ID: -1 for data_ID in self.list_of_data_IDs}

        # For all cluster...
        for cluster in list_of_clusters:
            # ... and for all member in each cluster...
            for data_ID in cluster.members:
                # ... affect the predicted cluster (cluster ID) to the data.
                predicted_clusters[data_ID] = cluster.cluster_ID

        # Rename cluster IDs by order.
        predicted_clusters = rename_clusters_by_order(clusters=predicted_clusters)

        # Return predicted clusters.
        return predicted_clusters
680# ==============================================================================
681# CLUSTER
682# ==============================================================================
class Cluster:
    """
    A node of the hierarchical clustering tree.

    A leaf is built from a list of data IDs (`members`); an internal node is built
    from a list of child clusters (`children`). Each cluster tracks its members,
    its centroid, the iteration at which it was (last) formed, and its inverse
    depth in the tree (0 for leaves).
    """

    # ==============================================================================
    # INITIALIZATION
    # ==============================================================================
    def __init__(
        self,
        vectors: Dict[str, csr_matrix],
        cluster_ID: int,
        clustering_iteration: int,
        children: Optional[List["Cluster"]] = None,
        members: Optional[List[str]] = None,
    ) -> None:
        """
        The constructor for Cluster class.

        Args:
            vectors (Dict[str, csr_matrix]): The representation of data vectors. Keys are data IDs, values are the corresponding vectors.
            cluster_ID (int): The cluster ID that is defined during `HierarchicalConstrainedClustering.cluster` running.
            clustering_iteration (int): The cluster iteration that is defined during `HierarchicalConstrainedClustering.cluster` running.
            children (Optional[List["Cluster"]], optional): A list of clusters children for cluster initialization. Incompatible with `members` parameter. Defaults to `None`.
            members (Optional[List[str]], optional): A list of data IDs for cluster initialization. Incompatible with `children` parameter. Defaults to `None`.

        Raises:
            ValueError: if `children` and `members` are both set or both unset.
        """

        # Keep a reference to the shared vectors dictionary.
        self.vectors: Dict[str, csr_matrix] = vectors

        # Identity of the cluster within the tree.
        self.cluster_ID: int = cluster_ID
        self.clustering_iteration: int = clustering_iteration

        # Exactly one of `children` / `members` must be provided.
        if (children is None) == (members is None):
            raise ValueError(
                "Cluster initialization must be by `children` setting or by `members` setting, but not by both or none of them."
            )

        # Children of the cluster (empty list for a leaf).
        if children is None:
            self.children: List["Cluster"] = []
        else:
            self.children = children

        # Inverse depth: 0 for a leaf, one more than the deepest child otherwise.
        if self.children:
            self.cluster_inverse_depth: int = max(child.cluster_inverse_depth for child in self.children) + 1
        else:
            self.cluster_inverse_depth = 0

        # Members: either given directly, or gathered from the children in order.
        if members is not None:
            self.members: List[str] = members
        else:
            gathered_members: List[str] = []
            for child in self.children:
                gathered_members.extend(child.members)
            self.members = gathered_members

        # Compute the initial centroid.
        self.update_centroid()

    # ==============================================================================
    # ADD NEW CHILDREN :
    # ==============================================================================
    def add_new_children(
        self,
        new_children: List["Cluster"],
        new_clustering_iteration: int,
    ) -> None:
        """
        Add new children to the cluster, then refresh its derived attributes.

        Args:
            new_children (List["Cluster"]): The list of new clusters children to add.
            new_clustering_iteration (int): The new cluster iteration that is defined during HierarchicalConstrainedClustering.clusterize running.
        """

        # Record the iteration at which the cluster was last modified.
        self.clustering_iteration = new_clustering_iteration

        # Append only candidates not already present; the membership tests are all
        # evaluated against the children list as it was before this call.
        additions: List["Cluster"] = [candidate for candidate in new_children if candidate not in self.children]
        self.children.extend(additions)

        # Refresh the inverse depth from the (now non-empty) children.
        self.cluster_inverse_depth = max(child.cluster_inverse_depth for child in self.children) + 1

        # Rebuild the members list from the children, in order.
        refreshed_members: List[str] = []
        for child in self.children:
            refreshed_members.extend(child.members)
        self.members = refreshed_members

        # Refresh the centroid to match the new members.
        self.update_centroid()

    # ==============================================================================
    # UPDATE CENTROIDS :
    # ==============================================================================
    def update_centroid(self) -> None:
        """
        Update the centroid of the cluster as the mean of its members' vectors.
        """

        # Sum the member vectors, then divide by the number of members.
        vectors_total = sum(self.vectors[data_ID] for data_ID in self.members)
        self.centroid: csr_matrix = vectors_total / self.get_cluster_size()

    # ==============================================================================
    # GET CLUSTER SIZE :
    # ==============================================================================
    def get_cluster_size(self) -> int:
        """
        Get cluster size.

        Returns:
            int: The cluster size, i.e. the number of members in the cluster.
        """
        return len(self.members)

    # ==============================================================================
    # TO DICTIONARY :
    # ==============================================================================
    def to_dict(self) -> Dict[str, Any]:
        """
        Transform the Cluster object into a dictionary. It can be used before serialize this object in JSON.

        Returns:
            Dict[str, Any]: A dictionary that represents the Cluster object.
        """

        # Children are serialized recursively.
        return {
            "cluster_ID": self.cluster_ID,
            "clustering_iteration": self.clustering_iteration,
            "children": [child.to_dict() for child in self.children],
            "cluster_inverse_depth": self.cluster_inverse_depth,
            "members": self.members,
        }