Coverage for src\cognitivefactory\interactive_clustering\sampling\clusters_based.py: 100.00%
57 statements
« prev ^ index » next coverage.py v7.3.2, created at 2023-11-17 13:31 +0100
# -*- coding: utf-8 -*-

"""
* Name: cognitivefactory.interactive_clustering.sampling.clusters_based
* Description: Implementation of constraints sampling based on clusters information.
* Author: Erwan SCHILD
* Created: 04/10/2021
* Licence: CeCILL (https://cecill.info/licences.fr.html)
"""
# ==============================================================================
# IMPORT PYTHON DEPENDENCIES
# ==============================================================================

import random
from typing import Dict, List, Optional, Tuple

from scipy.sparse import csr_matrix, vstack
from sklearn.metrics import pairwise_distances

from cognitivefactory.interactive_clustering.constraints.abstract import AbstractConstraintsManager
from cognitivefactory.interactive_clustering.sampling.abstract import AbstractConstraintsSampling
# ==============================================================================
# CLUSTERS BASED CONSTRAINTS SAMPLING
# ==============================================================================
class ClustersBasedConstraintsSampling(AbstractConstraintsSampling):
    """
    This class implements the sampling of data IDs based on clusters information in order to annotate constraints.
    It inherits from `AbstractConstraintsSampling`.

    Example:
        ```python
        # Import.
        from cognitivefactory.interactive_clustering.constraints.binary import BinaryConstraintsManager
        from cognitivefactory.interactive_clustering.sampling.clusters_based import ClustersBasedConstraintsSampling

        # Create an instance of random sampling.
        sampler = ClustersBasedConstraintsSampling(random_seed=1)

        # Define list of data IDs.
        list_of_data_IDs = ["bonjour", "salut", "coucou", "au revoir", "a bientôt",]

        # Define constraints manager.
        constraints_manager = BinaryConstraintsManager(
            list_of_data_IDs=list_of_data_IDs,
        )
        constraints_manager.add_constraint(data_ID1="bonjour", data_ID2="salut", constraint_type="MUST_LINK")
        constraints_manager.add_constraint(data_ID1="au revoir", data_ID2="a bientôt", constraint_type="MUST_LINK")

        # Run sampling.
        selection = sampler.sample(
            constraints_manager=constraints_manager,
            nb_to_select=3,
        )

        # Print results.
        print("Expected results", ";", [("au revoir", "bonjour"), ("bonjour", "coucou"), ("a bientôt", "coucou"),])
        print("Computed results", ":", selection)
        ```
    """

    # ==============================================================================
    # INITIALIZATION
    # ==============================================================================
    def __init__(
        self,
        random_seed: Optional[int] = None,
        clusters_restriction: Optional[str] = None,
        distance_restriction: Optional[str] = None,
        without_added_constraints: bool = True,
        without_inferred_constraints: bool = True,
        **kargs,
    ) -> None:
        """
        The constructor for Clusters Based Constraints Sampling class.

        Args:
            random_seed (Optional[int]): The random seed to use to redo the same sampling. Defaults to `None`.
            clusters_restriction (Optional[str]): Restrict the sampling with a cluster constraint. Can impose data IDs to be in `"same_cluster"` or `"different_clusters"`. Defaults to `None`.  # TODO: `"specific_clusters"`.
            distance_restriction (Optional[str]): Restrict the sampling with a distance constraint. Can impose data IDs to be `"closest_neighbors"` or `"farthest_neighbors"`. Defaults to `None`.
            without_added_constraints (bool): Option to not sample the already added constraints. Defaults to `True`.
            without_inferred_constraints (bool): Option to not sample the deduced constraints from already added ones. Defaults to `True`.
            **kargs (dict): Other parameters that can be used in the instantiation.

        Raises:
            ValueError: if some parameters are incorrectly set.
        """
        # Store `self.random_seed`.
        self.random_seed: Optional[int] = random_seed

        # Store clusters restriction.
        if clusters_restriction not in {None, "same_cluster", "different_clusters"}:
            raise ValueError("The `clusters_restriction` '" + str(clusters_restriction) + "' is not implemented.")
        self.clusters_restriction: Optional[str] = clusters_restriction

        # Store distance restriction.
        if distance_restriction not in {None, "closest_neighbors", "farthest_neighbors"}:
            raise ValueError("The `distance_restriction` '" + str(distance_restriction) + "' is not implemented.")
        self.distance_restriction: Optional[str] = distance_restriction

        # Store constraints restrictions.
        if not isinstance(without_added_constraints, bool):
            raise ValueError("The `without_added_constraints` must be boolean")
        self.without_added_constraints: bool = without_added_constraints
        if not isinstance(without_inferred_constraints, bool):
            raise ValueError("The `without_inferred_constraints` must be boolean")
        self.without_inferred_constraints: bool = without_inferred_constraints

    # ==============================================================================
    # MAIN - SAMPLE
    # ==============================================================================
    def sample(
        self,
        constraints_manager: AbstractConstraintsManager,
        nb_to_select: int,
        clustering_result: Optional[Dict[str, int]] = None,
        vectors: Optional[Dict[str, csr_matrix]] = None,
        **kargs,
    ) -> List[Tuple[str, str]]:
        """
        The main method used to sample pairs of data IDs for constraints annotation.

        Args:
            constraints_manager (AbstractConstraintsManager): A constraints manager over data IDs.
            nb_to_select (int): The number of pairs of data IDs to sample.
            clustering_result (Optional[Dict[str, int]], optional): A dictionary that represents the predicted cluster for each data ID. The keys of the dictionary represent the data IDs. If `None`, no clustering result is used during the sampling. Defaults to `None`.
            vectors (Optional[Dict[str, csr_matrix]], optional): The representation of data vectors. The keys of the dictionary represent the data IDs. These keys have to refer to the list of data IDs managed by the `constraints_manager`. The values of the dictionary represent the vector of each data. If `None`, no vectors are used during the sampling. Defaults to `None`.
            **kargs (dict): Other parameters that can be used in the sampling.

        Raises:
            ValueError: if some parameters are incorrectly set or incompatible.

        Returns:
            List[Tuple[str, str]]: A list of couples of data IDs.
        """

        ###
        ### GET PARAMETERS
        ###

        # Check `constraints_manager`.
        if not isinstance(constraints_manager, AbstractConstraintsManager):
            raise ValueError("The `constraints_manager` parameter has to be a `AbstractConstraintsManager` type.")
        self.constraints_manager: AbstractConstraintsManager = constraints_manager

        # Check `nb_to_select`.
        if not isinstance(nb_to_select, int) or (nb_to_select < 0):
            raise ValueError("The `nb_to_select` '" + str(nb_to_select) + "' must be greater than or equal to 0.")
        elif nb_to_select == 0:
            return []

        # If `self.clusters_restriction` is set, check the `clustering_result` parameter.
        if self.clusters_restriction is not None:
            if not isinstance(clustering_result, dict):
                raise ValueError("The `clustering_result` parameter has to be a `Dict[str, int]` type.")
            self.clustering_result: Dict[str, int] = clustering_result

        # If `self.distance_restriction` is set, check the `vectors` parameter.
        if self.distance_restriction is not None:
            if not isinstance(vectors, dict):
                raise ValueError("The `vectors` parameter has to be a `Dict[str, csr_matrix]` type.")
            self.vectors: Dict[str, csr_matrix] = vectors

        ###
        ### DEFINE POSSIBLE PAIRS OF DATA IDS
        ###

        # Initialize possible pairs of data IDs.
        list_of_possible_pairs_of_data_IDs: List[Tuple[str, str]] = []

        # Loop over pairs of data IDs.
        for data_ID1 in self.constraints_manager.get_list_of_managed_data_IDs():
            for data_ID2 in self.constraints_manager.get_list_of_managed_data_IDs():
                # Select ordered pairs only (`data_ID1 < data_ID2`), so each pair is considered once.
                if data_ID1 >= data_ID2:
                    continue

                # Check clusters restriction.
                if (
                    self.clusters_restriction == "same_cluster"
                    and self.clustering_result[data_ID1] != self.clustering_result[data_ID2]
                ) or (
                    self.clusters_restriction == "different_clusters"
                    and self.clustering_result[data_ID1] == self.clustering_result[data_ID2]
                ):
                    continue

                # Check known constraints.
                if (
                    self.without_added_constraints is True
                    and self.constraints_manager.get_added_constraint(data_ID1=data_ID1, data_ID2=data_ID2) is not None
                ) or (
                    self.without_inferred_constraints is True
                    and self.constraints_manager.get_inferred_constraint(data_ID1=data_ID1, data_ID2=data_ID2)
                    is not None
                ):
                    continue

                # Add the pair of data IDs.
                list_of_possible_pairs_of_data_IDs.append((data_ID1, data_ID2))

        ###
        ### SAMPLING
        ###

        # Precompute pairwise distances.
        if self.distance_restriction is not None:
            # Compute pairwise distances.
            # NOTE: `sklearn.metrics.pairwise_distances` returns a dense numpy array (not a sparse matrix),
            # and `scipy.sparse.vstack` expects a sequence of blocks, hence the list.
            matrix_of_pairwise_distances = pairwise_distances(
                X=vstack(
                    [self.vectors[data_ID] for data_ID in self.constraints_manager.get_list_of_managed_data_IDs()]
                ),
                metric="euclidean",  # TODO: get different pairwise_distances config in **kargs.
            )

            # Format pairwise distances in a dictionary.
            self.dict_of_pairwise_distances: Dict[str, Dict[str, float]] = {
                vector_ID1: {
                    vector_ID2: float(matrix_of_pairwise_distances[i1, i2])
                    for i2, vector_ID2 in enumerate(self.constraints_manager.get_list_of_managed_data_IDs())
                }
                for i1, vector_ID1 in enumerate(self.constraints_manager.get_list_of_managed_data_IDs())
            }

        # Set random seed.
        random.seed(self.random_seed)

        # Case of closest neighbors selection.
        if self.distance_restriction == "closest_neighbors":
            return sorted(
                list_of_possible_pairs_of_data_IDs,
                key=lambda combination: self.dict_of_pairwise_distances[combination[0]][combination[1]],
            )[:nb_to_select]

        # Case of farthest neighbors selection.
        if self.distance_restriction == "farthest_neighbors":
            return sorted(
                list_of_possible_pairs_of_data_IDs,
                key=lambda combination: self.dict_of_pairwise_distances[combination[0]][combination[1]],
                reverse=True,
            )[:nb_to_select]

        # (default) Case of random selection.
        return random.sample(
            list_of_possible_pairs_of_data_IDs, k=min(nb_to_select, len(list_of_possible_pairs_of_data_IDs))
        )