Coverage for src\cognitivefactory\interactive_clustering\constraints\binary.py: 100.00%
145 statements
# -*- coding: utf-8 -*-

"""
* Name: cognitivefactory.interactive_clustering.constraints.binary
* Description: Implementation of binary constraints manager.
* Author: Erwan SCHILD
* Created: 17/03/2021
* Licence: CeCILL-C License v1.0 (https://cecill.info/licences.fr.html)
"""

# ==============================================================================
# IMPORT PYTHON DEPENDENCIES
# ==============================================================================

import json
from typing import Any, Dict, List, Optional, Set, Tuple

import networkx as nx

from cognitivefactory.interactive_clustering.constraints.abstract import AbstractConstraintsManager

# ==============================================================================
# BINARY CONSTRAINTS MANAGER
# ==============================================================================
class BinaryConstraintsManager(AbstractConstraintsManager):
    """
    This class implements the binary constraints management.
    It inherits from `AbstractConstraintsManager`, and it takes into account the strong transitivity of constraints.

    References:
        - Binary constraints in clustering: `Wagstaff, K. et C. Cardie (2000). Clustering with Instance-level Constraints. Proceedings of the Seventeenth International Conference on Machine Learning, 1103–1110.`

    Example:
        ```python
        # Import.
        from cognitivefactory.interactive_clustering.constraints.binary import BinaryConstraintsManager

        # Create an instance of binary constraints manager.
        constraints_manager = BinaryConstraintsManager(list_of_data_IDs=["0", "1", "2", "3", "4"])

        # Add a new data ID.
        constraints_manager.add_data_ID(data_ID="99")

        # Get the list of managed data IDs.
        constraints_manager.get_list_of_managed_data_IDs()

        # Delete an existing data ID.
        constraints_manager.delete_data_ID(data_ID="99")

        # Add constraints.
        constraints_manager.add_constraint(data_ID1="0", data_ID2="1", constraint_type="MUST_LINK")
        constraints_manager.add_constraint(data_ID1="1", data_ID2="2", constraint_type="MUST_LINK")
        constraints_manager.add_constraint(data_ID1="2", data_ID2="3", constraint_type="CANNOT_LINK")

        # Get an added constraint.
        constraints_manager.get_added_constraint(data_ID1="0", data_ID2="1")  # expected ("MUST_LINK", 1.0)
        constraints_manager.get_added_constraint(data_ID1="0", data_ID2="2")  # expected None

        # Get an inferred constraint.
        constraints_manager.get_inferred_constraint(data_ID1="0", data_ID2="2")  # expected "MUST_LINK"
        constraints_manager.get_inferred_constraint(data_ID1="0", data_ID2="3")  # expected "CANNOT_LINK"
        constraints_manager.get_inferred_constraint(data_ID1="0", data_ID2="4")  # expected None
        ```
    """

    # ==============================================================================
    # INITIALIZATION
    # ==============================================================================
    def __init__(self, list_of_data_IDs: List[str], **kargs) -> None:
        """
        The constructor for the Binary Constraints Manager class.
        This class uses the strong transitivity to infer constraints, so constraint values are not taken into account.

        Args:
            list_of_data_IDs (List[str]): The list of data IDs to manage.
            **kargs (dict): Other parameters that can be used in the instantiation.
        """

        # Define `self._allowed_constraint_types`.
        self._allowed_constraint_types: Set[str] = {
            "MUST_LINK",
            "CANNOT_LINK",
        }
        # Define `self._allowed_constraint_value_range`.
        self._allowed_constraint_value_range: Dict[str, float] = {
            "min": 1.0,
            "max": 1.0,
        }

        # Store `self.kargs` for binary constraints managing.
        self.kargs = kargs

        # Initialize `self._constraints_dictionary`.
        self._constraints_dictionary: Dict[str, Dict[str, Optional[Tuple[str, float]]]] = {
            data_ID1: {
                data_ID2: (
                    ("MUST_LINK", 1.0)
                    if (data_ID1 == data_ID2)
                    else None  # Unknown constraint if `data_ID1` != `data_ID2`.
                )
                for data_ID2 in list_of_data_IDs
                if (data_ID1 <= data_ID2)
            }
            for data_ID1 in list_of_data_IDs
        }

        # Define `self._constraints_transitivity`.
        # Equivalent to `self._generate_constraints_transitivity()`.
        self._constraints_transitivity: Dict[str, Dict[str, Dict[str, None]]] = {
            data_ID: {
                "MUST_LINK": {data_ID: None},  # Initialize MUST_LINK clusters constraints.
                "CANNOT_LINK": {},  # Initialize CANNOT_LINK clusters constraints.
            }
            for data_ID in list_of_data_IDs
        }
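
    # Illustrative note (not part of the original module): for
    # `BinaryConstraintsManager(list_of_data_IDs=["0", "1", "2"])`, the two internal
    # structures defined above would look like:
    #     self._constraints_dictionary == {
    #         "0": {"0": ("MUST_LINK", 1.0), "1": None, "2": None},
    #         "1": {"1": ("MUST_LINK", 1.0), "2": None},
    #         "2": {"2": ("MUST_LINK", 1.0)},
    #     }
    #     self._constraints_transitivity == {
    #         "0": {"MUST_LINK": {"0": None}, "CANNOT_LINK": {}},
    #         "1": {"MUST_LINK": {"1": None}, "CANNOT_LINK": {}},
    #         "2": {"MUST_LINK": {"2": None}, "CANNOT_LINK": {}},
    #     }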

    # ==============================================================================
    # DATA_ID MANAGEMENT - ADDITION
    # ==============================================================================
    def add_data_ID(
        self,
        data_ID: str,
    ) -> bool:
        """
        The main method used to add a new data ID to manage.

        Args:
            data_ID (str): The data ID to manage.

        Raises:
            ValueError: if `data_ID` is already managed.

        Returns:
            bool: `True` if the addition is done.
        """

        # If `data_ID` is already in the data IDs that are currently managed, then raise a `ValueError`.
        if data_ID in self._constraints_dictionary.keys():
            raise ValueError("The `data_ID` `'" + str(data_ID) + "'` is already managed.")

        # Add `data_ID` to `self._constraints_dictionary.keys()`.
        self._constraints_dictionary[data_ID] = {}

        # Define the constraint between `data_ID` and all other data IDs.
        for other_data_ID in self._constraints_dictionary.keys():
            if data_ID == other_data_ID:
                self._constraints_dictionary[data_ID][data_ID] = ("MUST_LINK", 1.0)
            elif data_ID < other_data_ID:
                self._constraints_dictionary[data_ID][other_data_ID] = None
            else:  # elif data_ID > other_data_ID:
                self._constraints_dictionary[other_data_ID][data_ID] = None

        # Update `self._constraints_transitivity`.
        # Equivalent to `self._generate_constraints_transitivity()`.
        self._constraints_transitivity[data_ID] = {
            "MUST_LINK": {data_ID: None},
            "CANNOT_LINK": {},
        }

        # Return `True`.
        return True

    # ==============================================================================
    # DATA_ID MANAGEMENT - DELETION
    # ==============================================================================
    def delete_data_ID(
        self,
        data_ID: str,
    ) -> bool:
        """
        The main method used to delete a data ID so that it is no longer managed.

        Args:
            data_ID (str): The data ID to no longer manage.

        Raises:
            ValueError: if `data_ID` is not managed.

        Returns:
            bool: `True` if the deletion is done.
        """

        # If `data_ID` is not in the data IDs that are currently managed, then raise a `ValueError`.
        if data_ID not in self._constraints_dictionary.keys():
            raise ValueError("The `data_ID` `'" + str(data_ID) + "'` is not managed.")

        # Remove `data_ID` from `self._constraints_dictionary.keys()`.
        self._constraints_dictionary.pop(data_ID, None)

        # Remove `data_ID` from all `self._constraints_dictionary[other_data_ID].keys()`.
        for other_data_ID in self._constraints_dictionary.keys():
            self._constraints_dictionary[other_data_ID].pop(data_ID, None)

        # Regenerate `self._constraints_transitivity`.
        self._generate_constraints_transitivity()

        # Return `True`.
        return True

    # ==============================================================================
    # DATA_ID MANAGEMENT - LISTING
    # ==============================================================================
    def get_list_of_managed_data_IDs(
        self,
    ) -> List[str]:
        """
        The main method used to get the list of data IDs that are managed.

        Returns:
            List[str]: The list of data IDs that are managed.
        """

        # Return the keys of `self._constraints_dictionary`.
        return list(self._constraints_dictionary.keys())

    # ==============================================================================
    # CONSTRAINTS MANAGEMENT - ADDITION
    # ==============================================================================
    def add_constraint(
        self,
        data_ID1: str,
        data_ID2: str,
        constraint_type: str,
        constraint_value: float = 1.0,
    ) -> bool:
        """
        The main method used to add a constraint between two data IDs.

        Args:
            data_ID1 (str): The first data ID that is concerned by this constraint addition.
            data_ID2 (str): The second data ID that is concerned by this constraint addition.
            constraint_type (str): The type of the constraint to add. The type has to be `"MUST_LINK"` or `"CANNOT_LINK"`.
            constraint_value (float, optional): The value of the constraint to add. The value has to be in the range `[0.0, 1.0]`. Defaults to `1.0`.

        Raises:
            ValueError: if `data_ID1`, `data_ID2` or `constraint_type` are not managed, or if a conflict is detected with constraints inference.

        Returns:
            bool: `True` if the addition is done, `False` if the constraint can't be added.
        """

        # If `data_ID1` is not in the data IDs that are currently managed, then raise a `ValueError`.
        if data_ID1 not in self._constraints_dictionary.keys():
            raise ValueError("The `data_ID1` `'" + str(data_ID1) + "'` is not managed.")

        # If `data_ID2` is not in the data IDs that are currently managed, then raise a `ValueError`.
        if data_ID2 not in self._constraints_dictionary.keys():
            raise ValueError("The `data_ID2` `'" + str(data_ID2) + "'` is not managed.")

        # If the `constraint_type` is not in `self._allowed_constraint_types`, then raise a `ValueError`.
        if constraint_type not in self._allowed_constraint_types:
            raise ValueError(
                "The `constraint_type` `'"
                + str(constraint_type)
                + "'` is not managed. Allowed constraints types are : `"
                + str(self._allowed_constraint_types)
                + "`."
            )

        # Get the current inferred constraint between `data_ID1` and `data_ID2`.
        inferred_constraint: Optional[str] = self.get_inferred_constraint(
            data_ID1=data_ID1,
            data_ID2=data_ID2,
        )

        # Case of conflict with constraints inference.
        if (inferred_constraint is not None) and (inferred_constraint != constraint_type):
            raise ValueError(
                "The `constraint_type` `'"
                + str(constraint_type)
                + "'` is incompatible with the inferred constraint `'"
                + str(inferred_constraint)
                + "'` between data IDs `'"
                + data_ID1
                + "'` and `'"
                + data_ID2
                + "'`."
            )

        # Get the current added constraint between `data_ID1` and `data_ID2`.
        added_constraint: Optional[Tuple[str, float]] = self.get_added_constraint(
            data_ID1=data_ID1,
            data_ID2=data_ID2,
        )

        # If the constraint has already been added, ...
        if added_constraint is not None:
            # ... do nothing.
            return True  # `added_constraint[0] == constraint_type`.
        # Otherwise, the constraint has to be added.

        # Add the direct constraint between `data_ID1` and `data_ID2`.
        if data_ID1 <= data_ID2:
            self._constraints_dictionary[data_ID1][data_ID2] = (constraint_type, 1.0)
        else:
            self._constraints_dictionary[data_ID2][data_ID1] = (constraint_type, 1.0)

        # Add the transitivity constraint between `data_ID1` and `data_ID2`.
        self._add_constraint_transitivity(
            data_ID1=data_ID1,
            data_ID2=data_ID2,
            constraint_type=constraint_type,
        )

        return True
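
    # Illustrative sketch (not part of the original module): with the constraints from the
    # class docstring example, adding a constraint that contradicts an inferred one raises:
    #     constraints_manager.add_constraint(data_ID1="0", data_ID2="2", constraint_type="CANNOT_LINK")
    #     # -> ValueError, because "MUST_LINK" is already inferred between "0" and "2" by transitivity.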

    # ==============================================================================
    # CONSTRAINTS MANAGEMENT - DELETION
    # ==============================================================================
    def delete_constraint(
        self,
        data_ID1: str,
        data_ID2: str,
    ) -> bool:
        """
        The main method used to delete a constraint between two data IDs.

        Args:
            data_ID1 (str): The first data ID that is concerned by this constraint deletion.
            data_ID2 (str): The second data ID that is concerned by this constraint deletion.

        Raises:
            ValueError: if `data_ID1` or `data_ID2` are not managed.

        Returns:
            bool: `True` if the deletion is done, `False` if the constraint can't be deleted.
        """

        # If `data_ID1` is not in the data IDs that are currently managed, then raise a `ValueError`.
        if data_ID1 not in self._constraints_dictionary.keys():
            raise ValueError("The `data_ID1` `'" + str(data_ID1) + "'` is not managed.")

        # If `data_ID2` is not in the data IDs that are currently managed, then raise a `ValueError`.
        if data_ID2 not in self._constraints_dictionary.keys():
            raise ValueError("The `data_ID2` `'" + str(data_ID2) + "'` is not managed.")

        # Delete the constraint between `data_ID1` and `data_ID2`.
        if data_ID1 <= data_ID2:
            self._constraints_dictionary[data_ID1][data_ID2] = None
        else:
            self._constraints_dictionary[data_ID2][data_ID1] = None

        # Regenerate `self._constraints_transitivity`.
        self._generate_constraints_transitivity()

        # Return `True`.
        return True

    # ==============================================================================
    # CONSTRAINTS MANAGEMENT - GETTER
    # ==============================================================================
    def get_added_constraint(
        self,
        data_ID1: str,
        data_ID2: str,
    ) -> Optional[Tuple[str, float]]:
        """
        The main method used to get the constraint added between the two data IDs.
        It does not take constraints transitivity into account; it only looks at constraints that have been explicitly added.

        Args:
            data_ID1 (str): The first data ID that is concerned by this constraint.
            data_ID2 (str): The second data ID that is concerned by this constraint.

        Raises:
            ValueError: if `data_ID1` or `data_ID2` are not managed.

        Returns:
            Optional[Tuple[str, float]]: `None` if no constraint, `(constraint_type, constraint_value)` otherwise.
        """

        # If `data_ID1` is not in the data IDs that are currently managed, then raise a `ValueError`.
        if data_ID1 not in self._constraints_dictionary.keys():
            raise ValueError("The `data_ID1` `'" + str(data_ID1) + "'` is not managed.")

        # If `data_ID2` is not in the data IDs that are currently managed, then raise a `ValueError`.
        if data_ID2 not in self._constraints_dictionary.keys():
            raise ValueError("The `data_ID2` `'" + str(data_ID2) + "'` is not managed.")

        # Return the current added constraint type and value.
        return (
            self._constraints_dictionary[data_ID1][data_ID2]
            if (data_ID1 <= data_ID2)
            else self._constraints_dictionary[data_ID2][data_ID1]
        )

    # ==============================================================================
    # CONSTRAINTS EXPLORATION - GETTER
    # ==============================================================================
    def get_inferred_constraint(
        self,
        data_ID1: str,
        data_ID2: str,
        threshold: float = 1.0,
    ) -> Optional[str]:
        """
        The main method used to get the constraint inferred by transitivity between the two data IDs.
        The transitivity is taken into account, and the `threshold` parameter is used to evaluate the impact of constraints transitivity.

        Args:
            data_ID1 (str): The first data ID that is concerned by this constraint.
            data_ID2 (str): The second data ID that is concerned by this constraint.
            threshold (float, optional): The threshold used to evaluate the impact of the constraints transitivity link. Defaults to `1.0`.

        Raises:
            ValueError: if `data_ID1`, `data_ID2` or `threshold` are not managed.

        Returns:
            Optional[str]: The type of the inferred constraint. The type can be `None`, `"MUST_LINK"` or `"CANNOT_LINK"`.
        """

        # If `data_ID1` is not in the data IDs that are currently managed, then raise a `ValueError`.
        if data_ID1 not in self._constraints_transitivity.keys():
            raise ValueError("The `data_ID1` `'" + str(data_ID1) + "'` is not managed.")

        # If `data_ID2` is not in the data IDs that are currently managed, then raise a `ValueError`.
        if data_ID2 not in self._constraints_transitivity.keys():
            raise ValueError("The `data_ID2` `'" + str(data_ID2) + "'` is not managed.")

        # Case of `"MUST_LINK"`.
        if data_ID1 in self._constraints_transitivity[data_ID2]["MUST_LINK"].keys():
            return "MUST_LINK"

        # Case of `"CANNOT_LINK"`.
        if data_ID1 in self._constraints_transitivity[data_ID2]["CANNOT_LINK"].keys():
            return "CANNOT_LINK"

        # Case of `None`.
        return None
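
    # Illustrative sketch (not part of the original module): after adding
    # MUST_LINK("0", "1"), MUST_LINK("1", "2") and CANNOT_LINK("2", "3"),
    # the transitivity structure answers pairs that were never annotated directly:
    #     constraints_manager.get_inferred_constraint(data_ID1="0", data_ID2="2")  # "MUST_LINK"
    #     constraints_manager.get_inferred_constraint(data_ID1="0", data_ID2="3")  # "CANNOT_LINK"
    #     constraints_manager.get_inferred_constraint(data_ID1="0", data_ID2="4")  # None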

    # ==============================================================================
    # CONSTRAINTS EXPLORATION - LIST OF COMPONENTS GETTER
    # ==============================================================================
    def get_connected_components(
        self,
        threshold: float = 1.0,
    ) -> List[List[str]]:
        """
        The main method used to get the lists of data IDs that are linked by `"MUST_LINK"` constraints.
        Each list forms a component of the constraints transitivity graph, and together they form a partition of the managed data IDs.
        The transitivity is taken into account, and the `threshold` parameter is used if constraint values are used in the constraints transitivity.

        Args:
            threshold (float, optional): The threshold used to define the transitivity link. Defaults to `1.0`.

        Returns:
            List[List[str]]: The list of lists of data IDs, where each inner list represents a component of the constraints transitivity graph.
        """

        # Initialize the list of connected components.
        list_of_connected_components: List[List[str]] = []

        # For each data ID...
        for data_ID in self._constraints_transitivity.keys():
            # ... get the list of `"MUST_LINK"` data IDs linked by transitivity with `data_ID` ...
            connected_component_of_a_data_ID = list(self._constraints_transitivity[data_ID]["MUST_LINK"].keys())

            # ... and if the connected component has not already been collected...
            if connected_component_of_a_data_ID not in list_of_connected_components:
                # ... then add it to the list of connected components.
                list_of_connected_components.append(connected_component_of_a_data_ID)

        # Return the list of connected components.
        return list_of_connected_components
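
    # Illustrative sketch (not part of the original module): with data IDs "0" to "4" and
    # the constraints MUST_LINK("0", "1"), MUST_LINK("1", "2"), CANNOT_LINK("2", "3"),
    # the `"MUST_LINK"` components form a partition of the managed data IDs:
    #     constraints_manager.get_connected_components()  # [["0", "1", "2"], ["3"], ["4"]]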

    # ==============================================================================
    # CONSTRAINTS EXPLORATION - CHECK COMPLETUDE OF CONSTRAINTS
    # ==============================================================================
    def check_completude_of_constraints(
        self,
        threshold: float = 1.0,
    ) -> bool:
        """
        The main method used to check if all possible constraints are known (not necessarily annotated, because of the transitivity).
        The transitivity is taken into account, and the `threshold` parameter is used if constraint values are used in the constraints transitivity.

        Args:
            threshold (float, optional): The threshold used to define the transitivity link. Defaults to `1.0`.

        Returns:
            bool: `True` if all constraints are known, `False` otherwise.
        """

        # For each data ID...
        for data_ID in self._constraints_transitivity.keys():
            # ... if some data IDs are not linked by transitivity to this `data_ID` with a `"MUST_LINK"` or `"CANNOT_LINK"` constraint...
            if (
                len(self._constraints_transitivity[data_ID]["MUST_LINK"].keys())
                + len(self._constraints_transitivity[data_ID]["CANNOT_LINK"].keys())
            ) != len(self._constraints_transitivity.keys()):
                # ... then return `False`.
                return False

        # Otherwise, return `True`.
        return True
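
    # Illustrative sketch (not part of the original module): with the constraints from the
    # class docstring example, "4" is not related to any other data ID, so the annotation
    # is not complete yet:
    #     constraints_manager.check_completude_of_constraints()  # False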

    # ==============================================================================
    # CONSTRAINTS EXPLORATION - GET MIN AND MAX NUMBER OF CLUSTERS
    # ==============================================================================
    def get_min_and_max_number_of_clusters(
        self,
        threshold: float = 1.0,
    ) -> Tuple[int, int]:
        """
        The main method used to determine, for a clustering model that would not violate any constraint, the range of possible cluster numbers.
        The minimum number of clusters is estimated by a coloring of the `"CANNOT_LINK"` constraints graph.
        The maximum number of clusters is defined by the number of `"MUST_LINK"` connected components.
        The transitivity is taken into account, and the `threshold` parameter is used if constraint values are used in the constraints transitivity.

        Args:
            threshold (float, optional): The threshold used to define the transitivity link. Defaults to `1.0`.

        Returns:
            Tuple[int, int]: The minimum and the maximum possible numbers of clusters (for a clustering model that would not violate any constraint).
        """

        # Get the `"MUST_LINK"` connected components.
        list_of_connected_components: List[List[str]] = self.get_connected_components()

        ###
        ### 1. Estimation of the minimum number of clusters.
        ###

        # Get connected component IDs (one representative data ID per component).
        list_of_connected_component_ids: List[str] = [component[0] for component in list_of_connected_components]

        # Keep only components that have more than one `"CANNOT_LINK"` constraint.
        list_of_linked_connected_components_ids: List[str] = [
            component_id
            for component_id in list_of_connected_component_ids
            if len(self._constraints_transitivity[component_id]["CANNOT_LINK"].keys()) > 1  # noqa: WPS507
        ]

        # Get the `"CANNOT_LINK"` constraints between these components.
        list_of_cannot_link_constraints: List[Tuple[int, int]] = [
            (i1, i2)
            for i1, data_ID1 in enumerate(list_of_linked_connected_components_ids)
            for i2, data_ID2 in enumerate(list_of_linked_connected_components_ids)
            if (i1 < i2)
            and (  # Keep only pairs of components that are linked by a `"CANNOT_LINK"` constraint.
                data_ID2 in self._constraints_transitivity[data_ID1]["CANNOT_LINK"].keys()
            )
        ]

        # Create a networkx graph.
        cannot_link_graph: nx.Graph = nx.Graph()
        cannot_link_graph.add_nodes_from(list_of_connected_component_ids)  # Add component IDs as nodes in the graph.
        cannot_link_graph.add_edges_from(
            list_of_cannot_link_constraints
        )  # Add cannot link constraints as edges in the graph.

        # Estimate the minimum number of clusters by trying to color the `"CANNOT_LINK"` constraints graph.
        # The lower bound has to be at least 2.
        estimation_of_minimum_clusters_number: int = max(
            2,
            1
            + min(
                max(nx.coloring.greedy_color(cannot_link_graph, strategy="largest_first").values()),
                max(nx.coloring.greedy_color(cannot_link_graph, strategy="smallest_last").values()),
                max(nx.coloring.greedy_color(cannot_link_graph, strategy="random_sequential").values()),
                max(nx.coloring.greedy_color(cannot_link_graph, strategy="random_sequential").values()),
                max(nx.coloring.greedy_color(cannot_link_graph, strategy="random_sequential").values()),
            ),
        )

        ###
        ### 2. Computation of the maximum number of clusters.
        ###

        # Determine the maximum number of clusters with the number of `"MUST_LINK"` connected components.
        maximum_clusters_number: int = len(list_of_connected_components)

        # Return minimum and maximum.
        return (estimation_of_minimum_clusters_number, maximum_clusters_number)
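
    # Illustrative sketch (not part of the original module): with the constraints from the
    # class docstring example, the `"MUST_LINK"` components are [["0", "1", "2"], ["3"], ["4"]],
    # so at most 3 clusters are possible, and the `"CANNOT_LINK"` graph coloring gives a lower
    # bound of 2:
    #     constraints_manager.get_min_and_max_number_of_clusters()  # (2, 3)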

    # ==============================================================================
    # CONSTRAINTS TRANSITIVITY MANAGEMENT - GENERATE CONSTRAINTS TRANSITIVITY GRAPH
    # ==============================================================================
    def _generate_constraints_transitivity(
        self,
    ) -> None:
        """
        Generate `self._constraints_transitivity`, a constraints dictionary that takes into account the transitivity of constraints.
        It supposes that there is no inconsistency in `self._constraints_dictionary`.
        It uses `Dict[str, None]` to simulate ordered sets.
        """

        # Reset constraints transitivity.
        self._constraints_transitivity = {
            data_ID: {
                "MUST_LINK": {data_ID: None},  # Initialize MUST_LINK clusters constraints.
                "CANNOT_LINK": {},  # Initialize CANNOT_LINK clusters constraints.
            }
            for data_ID in self._constraints_dictionary.keys()
        }

        for data_ID1 in self._constraints_dictionary.keys():
            for data_ID2 in self._constraints_dictionary[data_ID1].keys():
                # Get the constraint between `data_ID1` and `data_ID2`.
                constraint = self._constraints_dictionary[data_ID1][data_ID2]

                # Add the constraint transitivity if the constraint is not `None`.
                if constraint is not None:
                    self._add_constraint_transitivity(
                        data_ID1=data_ID1,
                        data_ID2=data_ID2,
                        constraint_type=constraint[0],
                    )

    # ==============================================================================
    # CONSTRAINTS TRANSITIVITY MANAGEMENT - ADD CONSTRAINT TRANSITIVITY
    # ==============================================================================
    def _add_constraint_transitivity(
        self,
        data_ID1: str,
        data_ID2: str,
        constraint_type: str,
    ) -> bool:
        """
        Add constraint transitivity in `self._constraints_transitivity` between `data_ID1` and `data_ID2` for the constraint type `constraint_type`.
        It supposes that there is no inconsistency in `self._constraints_dictionary`.

        Args:
            data_ID1 (str): The first data ID that is concerned by this constraint transitivity addition.
            data_ID2 (str): The second data ID that is concerned by this constraint transitivity addition.
            constraint_type (str): The type of the constraint to add. The type has to be `"MUST_LINK"` or `"CANNOT_LINK"`.

        Returns:
            bool: `True` when the transitivity addition is done.
        """

        ###
        ### Case 1 : `constraint_type` is `"MUST_LINK"`.
        ###
        if constraint_type == "MUST_LINK":
            # Define the new common set of `"MUST_LINK"` data IDs,
            # by merging the sets of `"MUST_LINK"` data IDs for `data_ID1` and `data_ID2`.
            new_MUST_LINK_common_set: Dict[str, None] = {
                **self._constraints_transitivity[data_ID1]["MUST_LINK"],
                **self._constraints_transitivity[data_ID2]["MUST_LINK"],
            }

            # Define the new common set of `"CANNOT_LINK"` data IDs,
            # by merging the sets of `"CANNOT_LINK"` data IDs for `data_ID1` and `data_ID2`.
            new_CANNOT_LINK_common_set: Dict[str, None] = {
                **self._constraints_transitivity[data_ID1]["CANNOT_LINK"],
                **self._constraints_transitivity[data_ID2]["CANNOT_LINK"],
            }

            # For each data ID that is now similar to `data_ID1` and `data_ID2`...
            for data_ID_ML in new_MUST_LINK_common_set.keys():
                # ... assign the new set of `"MUST_LINK"` constraints...
                self._constraints_transitivity[data_ID_ML]["MUST_LINK"] = new_MUST_LINK_common_set
                # ... and assign the new set of `"CANNOT_LINK"` constraints.
                self._constraints_transitivity[data_ID_ML]["CANNOT_LINK"] = new_CANNOT_LINK_common_set

            # For each data ID that is now different from `data_ID1` and `data_ID2`...
            for data_ID_CL in new_CANNOT_LINK_common_set.keys():
                # ... extend its set of `"CANNOT_LINK"` constraints with the merged `"MUST_LINK"` set.
                self._constraints_transitivity[data_ID_CL]["CANNOT_LINK"] = {
                    **self._constraints_transitivity[data_ID_CL]["CANNOT_LINK"],
                    **new_MUST_LINK_common_set,
                }

        ###
        ### Case 2 : `constraint_type` is `"CANNOT_LINK"`.
        ###
        else:  # if constraint_type == "CANNOT_LINK":
            # Define the new set of `"CANNOT_LINK"` data IDs for data IDs that are similar to `data_ID1`.
            new_CANNOT_LINK_set_for_data_ID1: Dict[str, None] = {
                **self._constraints_transitivity[data_ID1]["CANNOT_LINK"],
                **self._constraints_transitivity[data_ID2]["MUST_LINK"],
            }

            # Define the new set of `"CANNOT_LINK"` data IDs for data IDs that are similar to `data_ID2`.
            new_CANNOT_LINK_set_for_data_ID2: Dict[str, None] = {
                **self._constraints_transitivity[data_ID2]["CANNOT_LINK"],
                **self._constraints_transitivity[data_ID1]["MUST_LINK"],
            }

            # For each data ID that is similar to `data_ID1`...
            for data_ID_like_data_ID1 in self._constraints_transitivity[data_ID1]["MUST_LINK"].keys():
                # ... assign the new set of `"CANNOT_LINK"` constraints.
                self._constraints_transitivity[data_ID_like_data_ID1]["CANNOT_LINK"] = new_CANNOT_LINK_set_for_data_ID1

            # For each data ID that is similar to `data_ID2`...
            for data_ID_like_data_ID2 in self._constraints_transitivity[data_ID2]["MUST_LINK"].keys():
                # ... assign the new set of `"CANNOT_LINK"` constraints.
                self._constraints_transitivity[data_ID_like_data_ID2]["CANNOT_LINK"] = new_CANNOT_LINK_set_for_data_ID2

        # Return `True`.
        return True
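
    # Illustrative sketch (not part of the original module): starting from managed data IDs
    # "0", "1" and "2", adding MUST_LINK("0", "1") then CANNOT_LINK("1", "2") leaves
    # `self._constraints_transitivity` in the following state:
    #     {
    #         "0": {"MUST_LINK": {"0": None, "1": None}, "CANNOT_LINK": {"2": None}},
    #         "1": {"MUST_LINK": {"0": None, "1": None}, "CANNOT_LINK": {"2": None}},
    #         "2": {"MUST_LINK": {"2": None}, "CANNOT_LINK": {"0": None, "1": None}},
    #     }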

    # ==============================================================================
    # CONSTRAINTS CONFLICT - GET INVOLVED DATA IDS IN A CONFLICT
    # ==============================================================================
    def get_list_of_involved_data_IDs_in_a_constraint_conflict(
        self,
        data_ID1: str,
        data_ID2: str,
        constraint_type: str,
    ) -> Optional[List[str]]:
        """
        Get all data IDs involved in a constraints conflict.

        Args:
            data_ID1 (str): The first data ID involved in the constraint conflict.
            data_ID2 (str): The second data ID involved in the constraint conflict.
            constraint_type (str): The constraint that creates the conflict. The constraint can be `"MUST_LINK"` or `"CANNOT_LINK"`.

        Raises:
            ValueError: if `data_ID1`, `data_ID2` or `constraint_type` are not managed.

        Returns:
            Optional[List[str]]: The list of data IDs that are involved in the conflict. It matches the data IDs from the connected components of `data_ID1` and `data_ID2`.
        """

        # If `data_ID1` is not in the data IDs that are currently managed, then raise a `ValueError`.
        if data_ID1 not in self._constraints_dictionary.keys():
            raise ValueError("The `data_ID1` `'" + str(data_ID1) + "'` is not managed.")

        # If `data_ID2` is not in the data IDs that are currently managed, then raise a `ValueError`.
        if data_ID2 not in self._constraints_dictionary.keys():
            raise ValueError("The `data_ID2` `'" + str(data_ID2) + "'` is not managed.")

        # If the `constraint_type` is not in `self._allowed_constraint_types`, then raise a `ValueError`.
        if constraint_type not in self._allowed_constraint_types:
            raise ValueError(
                "The `constraint_type` `'"
                + str(constraint_type)
                + "'` is not managed. Allowed constraints types are : `"
                + str(self._allowed_constraint_types)
                + "`."
            )

        # Case of conflict (after trying to add a constraint different from the inferred constraint).
        if self.get_inferred_constraint(
            data_ID1=data_ID1, data_ID2=data_ID2
        ) is not None and constraint_type != self.get_inferred_constraint(data_ID1=data_ID1, data_ID2=data_ID2):
            return [
                data_ID
                for connected_component in self.get_connected_components()  # Get the involved components.
                for data_ID in connected_component  # Get data IDs from these components.
                if (data_ID1 in connected_component or data_ID2 in connected_component)
            ]

        # Case of no conflict.
        return None

    # ==============================================================================
    # SERIALIZATION - TO JSON
    # ==============================================================================
    def to_json(
        self,
        filepath: str = "./constraint_manager.json",
    ) -> bool:
        """
        The main method used to serialize the constraints manager object into a JSON file.

        Args:
            filepath (str): The path where to serialize the constraints manager object.

        Returns:
            bool: `True` if the serialization is done.
        """

        # Serialize constraints manager.
        with open(filepath, "w") as fileobject:
            json.dump(
                {
                    "list_of_managed_data_IDs": self.get_list_of_managed_data_IDs(),
                    "list_of_added_constraints": [
                        {
                            "data_ID1": data_ID1,
                            "data_ID2": data_ID2,
                            "constraint_type": constraint[0],
                            "constraints_value": 1.0,  # Binary constraints manager, so force 1.0.
                        }
                        for data_ID1 in self._constraints_dictionary.keys()
                        for data_ID2, constraint in self._constraints_dictionary[data_ID1].items()
                        if (constraint is not None)
                    ],
                },
                fileobject,
                indent=1,
            )

        # Return.
        return True
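
    # Illustrative sketch (not part of the original module): for a manager over "0" and "1"
    # with MUST_LINK("0", "1"), the serialized file would contain the reflexive `MUST_LINK`
    # entries of each data ID as well as the added constraint:
    #     {
    #         "list_of_managed_data_IDs": ["0", "1"],
    #         "list_of_added_constraints": [
    #             {"data_ID1": "0", "data_ID2": "0", "constraint_type": "MUST_LINK", "constraints_value": 1.0},
    #             {"data_ID1": "0", "data_ID2": "1", "constraint_type": "MUST_LINK", "constraints_value": 1.0},
    #             {"data_ID1": "1", "data_ID2": "1", "constraint_type": "MUST_LINK", "constraints_value": 1.0}
    #         ]
    #     }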

# ==============================================================================
# SERIALIZATION - FROM JSON
# ==============================================================================
def load_constraints_manager_from_json(
    filepath: str,
) -> BinaryConstraintsManager:
    """
    The main method used to initialize a constraints manager from a serialized one.

    Args:
        filepath (str): The path to the serialized constraints manager object.

    Returns:
        BinaryConstraintsManager: The deserialized constraints manager.
    """

    # Deserialize constraints manager attributes.
    with open(filepath, "r") as fileobject:
        attributes_from_json: Dict[str, Any] = json.load(fileobject)
        # list_of_managed_data_IDs: List[str] = attributes_from_json["list_of_managed_data_IDs"]
        # list_of_added_constraints: List[Dict[str, Any]] = attributes_from_json["list_of_added_constraints"]

    # Initialize a blank constraints manager.
    constraints_manager: BinaryConstraintsManager = BinaryConstraintsManager(
        list_of_data_IDs=attributes_from_json["list_of_managed_data_IDs"],
    )

    # Load the constraints from JSON.
    for constraint in attributes_from_json["list_of_added_constraints"]:
        constraints_manager.add_constraint(
            data_ID1=constraint["data_ID1"],
            data_ID2=constraint["data_ID2"],
            constraint_type=constraint["constraint_type"],
            # constraint_value=constraint["constraint_value"],  # Binary constraints manager, so force 1.0.
        )

    # Return the constraints manager.
    return constraints_manager
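
# Illustrative round-trip sketch (not part of the original module): serialize a manager to
# JSON, reload it with `load_constraints_manager_from_json`, and check that inferred
# constraints survive the round trip. The file name is only an example.
if __name__ == "__main__":
    manager = BinaryConstraintsManager(list_of_data_IDs=["0", "1", "2"])
    manager.add_constraint(data_ID1="0", data_ID2="1", constraint_type="MUST_LINK")
    manager.add_constraint(data_ID1="1", data_ID2="2", constraint_type="CANNOT_LINK")

    # Serialize, then reload from the same file.
    manager.to_json(filepath="./constraint_manager.json")
    reloaded_manager = load_constraints_manager_from_json(filepath="./constraint_manager.json")

    # The constraint inferred by transitivity is preserved.
    assert reloaded_manager.get_inferred_constraint(data_ID1="0", data_ID2="2") == "CANNOT_LINK"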