From 031e90f83fda90ad96fe0fea444a349d20af6b6e Mon Sep 17 00:00:00 2001 From: Daniel Thom Date: Sat, 24 Aug 2024 12:37:12 -0600 Subject: [PATCH 1/4] Issue #22: Store associations between components --- docs/explanation/system.md | 38 ++++++++++ src/infrasys/component_associations.py | 99 ++++++++++++++++++++++++++ src/infrasys/component_manager.py | 59 +++++++++++++-- src/infrasys/exceptions.py | 6 +- src/infrasys/system.py | 24 +++++++ tests/test_system.py | 50 +++++++++++++ 6 files changed, 268 insertions(+), 8 deletions(-) create mode 100644 src/infrasys/component_associations.py diff --git a/docs/explanation/system.md b/docs/explanation/system.md index ceddba3..283480d 100644 --- a/docs/explanation/system.md +++ b/docs/explanation/system.md @@ -58,3 +58,41 @@ Component.model_json_schema() - `infrasys` includes some basic quantities in [infrasys.quantities](#quantity-api). - Pint will automatically convert a list or list of lists of values into a `numpy.ndarray`. infrasys will handle serialization/de-serialization of these types. + + +### Component Associations +The system tracks associations between components in order to optimize lookups. + +For example, suppose a Generator class has a field for a Bus. It is trivial to find a generator's +bus. However, if you need to find all generators connected to specific bus, you would have to +traverse all generators in the system and check their bus values. + +Every time you add a component to a system, `infrasys` inspects the component type for composed +components. It checks for directly connected components, such as `Generator.bus`, and lists of +components. (It does not inspect other composite data structures like dictionaries.) + +`infrasys` stores these component associations in a SQLite table and so lookups are fast. + +Here is how to complete this example: + +```python +generators = system.list_parent_components(bus) +``` + +If you only want to find specific types, you can pass that type as well. +```python +generators = system.list_parent_components(bus, component_type=Generator) +``` + +**Warning**: There is one potentially problematic case. + +Suppose that you have a system with generators and buses and then reassign the buses, as in +``` +gen1.bus = other_bus +``` + +`infrasys` cannot detect such reassignments and so the component associations will be incorrect. +You must inform `infrasys` to rebuild its internal table. +``` +system.rebuild_component_associations() +``` diff --git a/src/infrasys/component_associations.py b/src/infrasys/component_associations.py new file mode 100644 index 0000000..c635ce6 --- /dev/null +++ b/src/infrasys/component_associations.py @@ -0,0 +1,99 @@ +import sqlite3 +from typing import Optional, Type +from uuid import UUID + +from loguru import logger + +from infrasys.component import Component +from infrasys.utils.sqlite import execute + + +class ComponentAssociations: + """Stores associations between components. Allows callers to quickly find components composed + by other components, such as the generator to which a bus is connected.""" + + TABLE_NAME = "component_associations" + + def __init__(self) -> None: + # This uses a different database because it is not persisted when the system + # is saved to files. It will be rebuilt during de-serialization. + self._con = sqlite3.connect(":memory:") + self._create_metadata_table() + + def _create_metadata_table(self): + schema = [ + "id INTEGER PRIMARY KEY", + "component_uuid TEXT", + "component_type TEXT", + "attached_component_uuid TEXT", + "attached_component_type TEXT", + ] + schema_text = ",".join(schema) + cur = self._con.cursor() + execute(cur, f"CREATE TABLE {self.TABLE_NAME}({schema_text})") + execute(cur, f"CREATE INDEX by_ac_uuid ON {self.TABLE_NAME}(attached_component_uuid)") + self._con.commit() + logger.debug("Created in-memory component associations table") + + def add(self, *components: Component): + """Store an association between each component and directly attached subcomponents. + + - Inspects the type of each field of each component's type. Looks for subtypes of + Component and lists of subtypes of Component. + - Does not consider component fields that are dictionaries or other data structures. + """ + rows = [] + for component in components: + for field in type(component).model_fields: + val = getattr(component, field) + if isinstance(val, Component): + rows.append(self._make_row(component, val)) + elif isinstance(val, list) and val and isinstance(val[0], Component): + for item in val: + rows.append(self._make_row(component, item)) + # FUTURE: consider supporting dictionaries like these examples: + # dict[str, Component] + # dict[str, [Component]] + + if rows: + self._insert_rows(rows) + + def clear(self) -> None: + """Clear all component associations.""" + execute(self._con.cursor(), f"DELETE FROM {self.TABLE_NAME}") + logger.info("Cleared all component associations.") + + def list_parent_components( + self, component: Component, component_type: Optional[Type[Component]] = None + ) -> list[UUID]: + """Return a list of all component UUIDS that compose this component. + For example, return all components connected to a bus. + """ + where_clause = "WHERE attached_component_uuid = ?" + if component_type is None: + params = [str(component.uuid)] + else: + params = [str(component.uuid), component_type.__name__] + where_clause += " AND component_type = ?" + query = f"SELECT component_uuid FROM {self.TABLE_NAME} {where_clause}" + cur = self._con.cursor() + return [UUID(x[0]) for x in execute(cur, query, params)] + + def _insert_rows(self, rows: list[tuple]) -> None: + cur = self._con.cursor() + placeholder = ",".join(["?"] * len(rows[0])) + query = f"INSERT INTO {self.TABLE_NAME} VALUES({placeholder})" + try: + cur.executemany(query, rows) + finally: + self._con.commit() + + @staticmethod + def _make_row(component: Component, attached_component: Component): + return ( + None, + str(component.uuid), + type(component).__name__, + str(attached_component.uuid), + type(attached_component).__name__, + ) diff --git a/src/infrasys/component_manager.py b/src/infrasys/component_manager.py index 88e7c2a..8e5edb1 100644 --- a/src/infrasys/component_manager.py +++ b/src/infrasys/component_manager.py @@ -1,13 +1,19 @@ """Manages components""" -from collections import defaultdict import itertools -from typing import Any, Callable, Iterable, Type +from collections import defaultdict +from typing import Any, Callable, Iterable, Optional, Type from uuid import UUID from loguru import logger from infrasys.component import Component -from infrasys.exceptions import ISAlreadyAttached, ISNotStored, ISOperationNotAllowed +from infrasys.component_associations import ComponentAssociations +from infrasys.exceptions import ( + ISAlreadyAttached, + ISNotStored, + ISOperationNotAllowed, + ISInvalidParameter, +) from infrasys.models import make_label, get_class_and_name_from_label @@ -23,6 +29,7 @@ def __init__( self._components_by_uuid: dict[UUID, Component] = {} self._uuid = uuid self._auto_add_composed_components = auto_add_composed_components + self._associations = ComponentAssociations() @property def auto_add_composed_components(self) -> bool: @@ -34,7 +41,7 @@ def auto_add_composed_components(self, val: bool) -> None: """Set auto_add_composed_components.""" self._auto_add_composed_components = val - def add(self, *args: Component, deserialization_in_progress=False) -> None: + def add(self, *components: Component, deserialization_in_progress=False) -> None: """Add one or more components to the system. Raises @@ -42,9 +49,15 @@ def add(self, *args: Component, deserialization_in_progress=False) -> None: ISAlreadyAttached Raised if a component is already attached to a system. """ - for component in args: + if not components: + msg = "add_associations requires at least one component" + raise ISInvalidParameter(msg) + + for component in components: self._add(component, deserialization_in_progress) + self._associations.add(*components) + def get(self, component_type: Type[Component], name: str) -> Any: """Return the component with the passed type and name. @@ -167,8 +180,22 @@ def iter_all(self) -> Iterable[Any]: """Return an iterator over all components.""" return self._components_by_uuid.values() + def list_parent_components( + self, component: Component, component_type: Optional[Type[Component]] = None + ) -> list[Component]: + """Return a list of all components that compose this component.""" + return [ + self.get_by_uuid(x) + for x in self._associations.list_parent_components( + component, component_type=component_type + ) + ] + def to_records( - self, component_type: Type[Component], filter_func: Callable | None = None, **kwargs + self, + component_type: Type[Component], + filter_func: Callable | None = None, + **kwargs, ) -> Iterable[dict]: """Return a dictionary representation of the requested components. @@ -207,6 +234,15 @@ def remove(self, component: Component) -> Any: msg = f"{component.label} is not stored" raise ISNotStored(msg) + attached_components = self.list_parent_components(component) + if attached_components: + label = ", ".join((x.label for x in attached_components)) + msg = ( + f"Cannot remove {component.label} because it is attached to these components: " + f"{label}" + ) + raise ISOperationNotAllowed(msg) + container = self._components[component_type][component.name] for i, comp in enumerate(container): if comp.uuid == component.uuid: @@ -259,6 +295,14 @@ def change_uuid(self, component: Component) -> None: msg = "change_component_uuid" raise NotImplementedError(msg) + def rebuild_component_associations(self) -> None: + """Clear the component associations and rebuild the table. This may be necessary + if a user reassigns connected components that are part of a system. + """ + self._associations.clear() + self._associations.add(*self.iter_all()) + logger.info("Rebuilt all component associations.") + def update( self, component_type: Type[Component], @@ -292,6 +336,7 @@ def _add(self, component: Component, deserialization_in_progress: bool) -> None: self._components[cls][name].append(component) self._components_by_uuid[component.uuid] = component + logger.debug("Added {} to the system", component.label) def _check_component_addition(self, component: Component) -> None: @@ -303,7 +348,7 @@ def _check_component_addition(self, component: Component) -> None: self._handle_composed_component(val) # Recurse. self._check_component_addition(val) - if isinstance(val, list) and val and isinstance(val[0], Component): + elif isinstance(val, list) and val and isinstance(val[0], Component): for item in val: self._handle_composed_component(item) # Recurse. diff --git a/src/infrasys/exceptions.py b/src/infrasys/exceptions.py index 8254ac2..0e0c120 100644 --- a/src/infrasys/exceptions.py +++ b/src/infrasys/exceptions.py @@ -18,13 +18,17 @@ class ISFileExists(ISBaseException): class ISConflictingArguments(ISBaseException): - """Raised if the arguments are conflict.""" + """Raised if the arguments conflict.""" class ISConflictingSystem(ISBaseException): """Raised if the system has conflicting values.""" +class ISInvalidParameter(ISBaseException): + """Raised if a parameter is invalid.""" + + class ISNotStored(ISBaseException): """Raised if the requested object is not stored.""" diff --git a/src/infrasys/system.py b/src/infrasys/system.py index f40437e..b595a46 100644 --- a/src/infrasys/system.py +++ b/src/infrasys/system.py @@ -597,6 +597,22 @@ def get_component_types(self) -> Iterable[Type[Component]]: """ return self._component_mgr.get_types() + def list_parent_components( + self, component: Component, component_type: Optional[Type[Component]] = None + ) -> list[Component]: + """Return a list of all components that compose this component. + + An example usage is where you need to find all components connected to a bus and the Bus + class does not contain that information. The system tracks these connections internally + and can find those components quickly. + + Examples + -------- + >>> components = system.list_parent_components(bus) + >>> print(f"These components are connected to {bus.label}: ", " ".join(components)) + """ + return self._component_mgr.list_parent_components(component, component_type=component_type) + def list_components_by_name(self, component_type: Type[Component], name: str) -> list[Any]: """Return all components that match component_type and name. @@ -625,6 +641,12 @@ def iter_all_components(self) -> Iterable[Any]: """ return self._component_mgr.iter_all() + def rebuild_component_associations(self) -> None: + """Clear the component associations and rebuild the table. This may be necessary + if a user reassigns connected components that are part of a system. + """ + self._component_mgr.rebuild_component_associations() + def remove_component(self, component: Component) -> Any: """Remove the component from the system and return it. @@ -636,6 +658,8 @@ def remove_component(self, component: Component) -> Any: ------ ISNotStored Raised if the component is not stored in the system. + ISOperationNotAllowed + Raised if the other components hold references to this component. Examples -------- diff --git a/tests/test_system.py b/tests/test_system.py index 8dbb96a..4a7cbfc 100644 --- a/tests/test_system.py +++ b/tests/test_system.py @@ -8,6 +8,7 @@ from infrasys.exceptions import ( ISAlreadyAttached, + ISInvalidParameter, ISNotStored, ISOperationNotAllowed, ISConflictingArguments, @@ -31,6 +32,8 @@ def test_system(): gen = SimpleGenerator(name="test-gen", active_power=1.0, rating=1.0, bus=bus, available=True) subsystem = SimpleSubsystem(name="test-subsystem", generators=[gen]) system.add_components(geo, bus, gen, subsystem) + with pytest.raises(ISInvalidParameter): + system.add_components() gen2 = system.get_component(SimpleGenerator, "test-gen") assert gen2 is gen @@ -141,6 +144,53 @@ def test_get_components_multiple_types(): assert len(selected_components) == 2 # 1 SimpleGenerator + 1 RenewableGenerator +def test_component_associations(tmp_path): + system = SimpleSystem() + for i in range(3): + geo = Location(x=i, y=i + 1) + bus = SimpleBus(name=f"bus{i}", voltage=1.1, coordinates=geo) + gen1 = SimpleGenerator( + name=f"gen{i}a", active_power=1.0, rating=1.0, bus=bus, available=True + ) + gen2 = SimpleGenerator( + name=f"gen{i}b", active_power=1.0, rating=1.0, bus=bus, available=True + ) + subsystem = SimpleSubsystem(name=f"test-subsystem{i}", generators=[gen1, gen2]) + system.add_components(geo, bus, gen1, gen2, subsystem) + + def check_attached_components(my_sys): + for i in range(3): + bus = my_sys.get_component(SimpleBus, f"bus{i}") + gen1 = my_sys.get_component(SimpleGenerator, f"gen{i}a") + gen2 = my_sys.get_component(SimpleGenerator, f"gen{i}b") + attached = my_sys.list_parent_components(bus, component_type=SimpleGenerator) + assert len(attached) == 2 + labels = {gen1.label, gen2.label} + for component in attached: + assert component.label in labels + attached_subsystems = my_sys.list_parent_components(component) + assert len(attached_subsystems) == 1 + assert attached_subsystems[0].name == f"test-subsystem{i}" + assert not my_sys.list_parent_components(attached_subsystems[0]) + + for component in (bus, gen1, gen2): + with pytest.raises(ISOperationNotAllowed): + my_sys.remove_component(component) + + check_attached_components(system) + system._component_mgr._associations.clear() + for component in system.iter_all_components(): + assert not system.list_parent_components(component) + + system.rebuild_component_associations() + check_attached_components(system) + + save_dir = tmp_path / "test_system" + system.save(save_dir) + system2 = SimpleSystem.from_json(save_dir / "system.json") + check_attached_components(system2) + + def test_time_series_attach_from_array(): system = SimpleSystem() bus = SimpleBus(name="test-bus", voltage=1.1) From 0a2a3a54fa28be1007b82f4289ee7c1375bd8395 Mon Sep 17 00:00:00 2001 From: Daniel Thom Date: Mon, 26 Aug 2024 18:19:52 -0600 Subject: [PATCH 2/4] Allow force-removal of components --- src/infrasys/component_associations.py | 3 --- src/infrasys/component_manager.py | 15 +++++++++------ src/infrasys/system.py | 15 ++++++++++++--- tests/test_system.py | 8 ++++++++ 4 files changed, 29 insertions(+), 12 deletions(-) diff --git a/src/infrasys/component_associations.py b/src/infrasys/component_associations.py index c635ce6..146bcaf 100644 --- a/src/infrasys/component_associations.py +++ b/src/infrasys/component_associations.py @@ -51,9 +51,6 @@ def add(self, *components: Component): elif isinstance(val, list) and val and isinstance(val[0], Component): for item in val: rows.append(self._make_row(component, item)) - # FUTURE: consider supporting dictionaries like these examples: - # dict[str, Component] - # dict[str, [Component]] if rows: self._insert_rows(rows) diff --git a/src/infrasys/component_manager.py b/src/infrasys/component_manager.py index 8e5edb1..9740985 100644 --- a/src/infrasys/component_manager.py +++ b/src/infrasys/component_manager.py @@ -216,7 +216,7 @@ def to_records( subcomponent[i] = sub_component_.label yield data - def remove(self, component: Component) -> Any: + def remove(self, component: Component, force: bool = False) -> Any: """Remove the component from the system and return it. Notes @@ -237,11 +237,14 @@ def remove(self, component: Component) -> Any: attached_components = self.list_parent_components(component) if attached_components: label = ", ".join((x.label for x in attached_components)) - msg = ( - f"Cannot remove {component.label} because it is attached to these components: " - f"{label}" - ) - raise ISOperationNotAllowed(msg) + if force: + logger.warning("Remove {} even though it is attached to these components: {label}") + else: + msg = ( + f"Cannot remove {component.label} because it is attached to these components: " + f"{label}" + ) + raise ISOperationNotAllowed(msg) container = self._components[component_type][component.name] for i, comp in enumerate(container): diff --git a/src/infrasys/system.py b/src/infrasys/system.py index b595a46..70bf600 100644 --- a/src/infrasys/system.py +++ b/src/infrasys/system.py @@ -606,6 +606,12 @@ def list_parent_components( class does not contain that information. The system tracks these connections internally and can find those components quickly. + Parameters + ---------- + component: Component + component_type: Optional[Type[Component]] + Filter the returned list to components of this type. + Examples -------- >>> components = system.list_parent_components(bus) @@ -647,19 +653,22 @@ def rebuild_component_associations(self) -> None: """ self._component_mgr.rebuild_component_associations() - def remove_component(self, component: Component) -> Any: + def remove_component(self, component: Component, force: bool = False) -> Any: """Remove the component from the system and return it. Parameters ---------- component : Component + force : bool + If True, remove the component even if other components hold references to this + component. Defaults to False. Raises ------ ISNotStored Raised if the component is not stored in the system. ISOperationNotAllowed - Raised if the other components hold references to this component. + Raised if the other components hold references to this component and force=False. Examples -------- @@ -675,7 +684,7 @@ def remove_component(self, component: Component) -> Any: variable_name=metadata.variable_name, **metadata.user_attributes, ) - component = self._component_mgr.remove(component) + component = self._component_mgr.remove(component, force=force) def remove_component_by_name(self, component_type: Type[Component], name: str) -> Any: """Remove the component with component_type and name from the system and return it. diff --git a/tests/test_system.py b/tests/test_system.py index 4a7cbfc..4d43f29 100644 --- a/tests/test_system.py +++ b/tests/test_system.py @@ -190,6 +190,14 @@ def check_attached_components(my_sys): system2 = SimpleSystem.from_json(save_dir / "system.json") check_attached_components(system2) + bus = system2.get_component(SimpleBus, "bus1") + with pytest.raises(ISOperationNotAllowed): + system2.remove_component(bus) + system2.remove_component(bus, force=True) + gen = system2.get_component(SimpleGenerator, "gen1a") + with pytest.raises(ISNotStored): + system2.get_component(SimpleBus, gen.bus.name) + def test_time_series_attach_from_array(): system = SimpleSystem() From d8c9997556dc54753d2791b22a4e5bff2667c70e Mon Sep 17 00:00:00 2001 From: Daniel Thom Date: Tue, 27 Aug 2024 07:56:50 -0600 Subject: [PATCH 3/4] Cascade down by default in remove_component --- src/infrasys/component_associations.py | 32 +++++++++++- src/infrasys/component_manager.py | 69 ++++++++++++++++++-------- src/infrasys/system.py | 59 +++++++++++++++++++--- tests/test_system.py | 14 ++++-- 4 files changed, 143 insertions(+), 31 deletions(-) diff --git a/src/infrasys/component_associations.py b/src/infrasys/component_associations.py index 146bcaf..433e7a8 100644 --- a/src/infrasys/component_associations.py +++ b/src/infrasys/component_associations.py @@ -31,7 +31,10 @@ def _create_metadata_table(self): schema_text = ",".join(schema) cur = self._con.cursor() execute(cur, f"CREATE TABLE {self.TABLE_NAME}({schema_text})") - execute(cur, f"CREATE INDEX by_ac_uuid ON {self.TABLE_NAME}(attached_component_uuid)") + execute( + cur, + f"CREATE INDEX by_c_uuid ON {self.TABLE_NAME}(component_uuid, attached_component_uuid)", + ) self._con.commit() logger.debug("Created in-memory component associations table") @@ -60,6 +63,22 @@ def clear(self) -> None: execute(self._con.cursor(), f"DELETE FROM {self.TABLE_NAME}") logger.info("Cleared all component associations.") + def list_child_components( + self, component: Component, component_type: Optional[Type[Component]] = None + ) -> list[UUID]: + """Return a list of all component UUIDS that this component composes. + For example, return the bus attached to a generator. + """ + where_clause = "WHERE component_uuid = ?" + if component_type is None: + params = [str(component.uuid)] + else: + params = [str(component.uuid), component_type.__name__] + where_clause += " AND attached_component_type = ?" + query = f"SELECT attached_component_uuid FROM {self.TABLE_NAME} {where_clause}" + cur = self._con.cursor() + return [UUID(x[0]) for x in execute(cur, query, params)] + def list_parent_components( self, component: Component, component_type: Optional[Type[Component]] = None ) -> list[UUID]: @@ -76,6 +95,17 @@ def list_parent_components( cur = self._con.cursor() return [UUID(x[0]) for x in execute(cur, query, params)] + def remove(self, component: Component) -> None: + """Delete all rows with this component.""" + query = f""" + DELETE + FROM {self.TABLE_NAME} + WHERE component_uuid = ? OR attached_component_uuid = ? + """ + params = [str(component.uuid), str(component.uuid)] + execute(self._con.cursor(), query, params) + logger.debug("Removed all associations with component {}", component.label) + def _insert_rows(self, rows: list[tuple]) -> None: cur = self._con.cursor() placeholder = ",".join(["?"] * len(rows[0])) diff --git a/src/infrasys/component_manager.py b/src/infrasys/component_manager.py index 9740985..36776b4 100644 --- a/src/infrasys/component_manager.py +++ b/src/infrasys/component_manager.py @@ -128,6 +128,10 @@ def get_types(self) -> Iterable[Type[Component]]: """Return an iterable of all stored types.""" return self._components.keys() + def has_component(self, component) -> bool: + """Return True if the component is attached.""" + return component.uuid in self._components_by_uuid + def iter( self, *component_types: Type[Component], filter_func: Callable | None = None ) -> Iterable[Any]: @@ -180,6 +184,17 @@ def iter_all(self) -> Iterable[Any]: """Return an iterator over all components.""" return self._components_by_uuid.values() + def list_child_components( + self, component: Component, component_type: Optional[Type[Component]] = None + ) -> list[Component]: + """Return a list of all components that this component composes.""" + return [ + self.get_by_uuid(x) + for x in self._associations.list_child_components( + component, component_type=component_type + ) + ] + def list_parent_components( self, component: Component, component_type: Optional[Type[Component]] = None ) -> list[Component]: @@ -216,7 +231,7 @@ def to_records( subcomponent[i] = sub_component_.label yield data - def remove(self, component: Component, force: bool = False) -> Any: + def remove(self, component: Component, cascade_down: bool = True, force: bool = False) -> Any: """Remove the component from the system and return it. Notes @@ -227,40 +242,54 @@ def remove(self, component: Component, force: bool = False) -> Any: component_type = type(component) # The system method should have already performed the check, but for completeness in case # someone calls it directly, check here. - if ( - component_type not in self._components - or component.name not in self._components[component_type] - ): + key = component.name or component.label + if component_type not in self._components or key not in self._components[component_type]: msg = f"{component.label} is not stored" raise ISNotStored(msg) - attached_components = self.list_parent_components(component) - if attached_components: - label = ", ".join((x.label for x in attached_components)) - if force: - logger.warning("Remove {} even though it is attached to these components: {label}") - else: - msg = ( - f"Cannot remove {component.label} because it is attached to these components: " - f"{label}" - ) - raise ISOperationNotAllowed(msg) - - container = self._components[component_type][component.name] + self._check_parent_components_for_remove(component, force) + container = self._components[component_type][key] for i, comp in enumerate(container): if comp.uuid == component.uuid: container.pop(i) - if not self._components[component_type][component.name]: - self._components[component_type].pop(component.name) + if not self._components[component_type][key]: + self._components[component_type].pop(key) self._components_by_uuid.pop(component.uuid) if not self._components[component_type]: self._components.pop(component_type) logger.debug("Removed component {}", component.label) + if cascade_down: + child_components = self._associations.list_child_components(component) + else: + child_components = [] + self._associations.remove(component) + for child_uuid in child_components: + child = self.get_by_uuid(child_uuid) + parent_components = self.list_parent_components(child) + if not parent_components: + self.remove(child, cascade_down=cascade_down, force=force) return msg = f"Component {component.label} is not stored" raise ISNotStored(msg) + def _check_parent_components_for_remove(self, component: Component, force: bool) -> None: + parent_components = self.list_parent_components(component) + if parent_components: + parent_labels = ", ".join((x.label for x in parent_components)) + if force: + logger.warning( + "Remove {} even though it is attached to these components: {}", + component.label, + parent_labels, + ) + else: + msg = ( + f"Cannot remove {component.label} because it is attached to these components: " + f"{parent_labels}" + ) + raise ISOperationNotAllowed(msg) + def copy( self, component: Component, diff --git a/src/infrasys/system.py b/src/infrasys/system.py index 70bf600..affef99 100644 --- a/src/infrasys/system.py +++ b/src/infrasys/system.py @@ -597,6 +597,27 @@ def get_component_types(self) -> Iterable[Type[Component]]: """ return self._component_mgr.get_types() + def has_component(self, component) -> bool: + """Return True if the component is attached.""" + return self._component_mgr.has_component(component) + + def list_child_components( + self, component: Component, component_type: Optional[Type[Component]] = None + ) -> list[Component]: + """Return a list of all components that this component composes. + + Parameters + ---------- + component: Component + component_type: Optional[Type[Component]] + Filter the returned list to components of this type. + + See Also + -------- + list_parent_components + """ + return self._component_mgr.list_child_components(component, component_type=component_type) + def list_parent_components( self, component: Component, component_type: Optional[Type[Component]] = None ) -> list[Component]: @@ -616,6 +637,10 @@ class does not contain that information. The system tracks these connections int -------- >>> components = system.list_parent_components(bus) >>> print(f"These components are connected to {bus.label}: ", " ".join(components)) + + See Also + -------- + list_child_components """ return self._component_mgr.list_parent_components(component, component_type=component_type) @@ -653,12 +678,18 @@ def rebuild_component_associations(self) -> None: """ self._component_mgr.rebuild_component_associations() - def remove_component(self, component: Component, force: bool = False) -> Any: + def remove_component( + self, component: Component, cascade_down: bool = True, force: bool = False + ) -> Any: """Remove the component from the system and return it. Parameters ---------- component : Component + cascade_down : bool + If True, remove all child components if they have no other parents. Defaults to True. + For example, if a generator has a bus, no other component holds a reference to that + bus, and you call remove_component on that generator, the bus will get removed as well. force : bool If True, remove the component even if other components hold references to this component. Defaults to False. @@ -684,15 +715,25 @@ def remove_component(self, component: Component, force: bool = False) -> Any: variable_name=metadata.variable_name, **metadata.user_attributes, ) - component = self._component_mgr.remove(component, force=force) + component = self._component_mgr.remove(component, cascade_down=cascade_down, force=force) - def remove_component_by_name(self, component_type: Type[Component], name: str) -> Any: + def remove_component_by_name( + self, + component_type: Type[Component], + name: str, + cascade_down: bool = True, + force: bool = False, + ) -> Any: """Remove the component with component_type and name from the system and return it. Parameters ---------- component_type : Type name : str + cascade_down : bool + Refer :meth:`remove_component`. + force : bool + Refer :meth:`remove_component`. Raises ------ @@ -706,14 +747,20 @@ def remove_component_by_name(self, component_type: Type[Component], name: str) - >>> generators = system.remove_by_name(Generator, "gen1") """ component = self.get_component(component_type, name) - return self.remove_component(component) + return self.remove_component(component, cascade_down=cascade_down, force=force) - def remove_component_by_uuid(self, uuid: UUID) -> Any: + def remove_component_by_uuid( + self, uuid: UUID, cascade_down: bool = True, force: bool = False + ) -> Any: """Remove the component with uuid from the system and return it. Parameters ---------- uuid : UUID + cascade_down : bool + Refer :meth:`remove_component`. + force : bool + Refer :meth:`remove_component`. Raises ------ @@ -726,7 +773,7 @@ def remove_component_by_uuid(self, uuid: UUID) -> Any: >>> generator = system.remove_component_by_uuid(uuid) """ component = self.get_component_by_uuid(uuid) - return self.remove_component(component) + return self.remove_component(component, cascade_down=cascade_down, force=force) def update_components( self, diff --git a/tests/test_system.py b/tests/test_system.py index 4d43f29..8fdd1c7 100644 --- a/tests/test_system.py +++ b/tests/test_system.py @@ -172,6 +172,7 @@ def check_attached_components(my_sys): assert len(attached_subsystems) == 1 assert attached_subsystems[0].name == f"test-subsystem{i}" assert not my_sys.list_parent_components(attached_subsystems[0]) + assert my_sys.list_child_components(component) == [bus] for component in (bus, gen1, gen2): with pytest.raises(ISOperationNotAllowed): @@ -511,16 +512,19 @@ def test_deepcopy_component(simple_system_with_time_series: SimpleSystem): assert gen2.bus is not gen1.bus -@pytest.mark.parametrize("in_memory", [True, False]) -def test_remove_component(in_memory): +@pytest.mark.parametrize("inputs", [(True, False), (True, False)]) +def test_remove_component(inputs): + in_memory, cascade_down = inputs system = SimpleSystem( name="test-system", auto_add_composed_components=True, time_series_in_memory=in_memory, ) gen1 = SimpleGenerator.example() + bus = gen1.bus system.add_components(gen1) gen2 = system.copy_component(gen1, name="gen2", attach=True) + assert gen2.bus is bus variable_name = "active_power" length = 8784 data = range(length) @@ -529,11 +533,13 @@ def test_remove_component(in_memory): ts = SingleTimeSeries.from_array(data, variable_name, start, resolution) system.add_time_series(ts, gen1, gen2) - system.remove_component_by_name(type(gen1), gen1.name) + system.remove_component_by_name(type(gen1), gen1.name, cascade_down=cascade_down) + assert system.has_component(bus) assert not system.has_time_series(gen1) assert system.has_time_series(gen2) - system.remove_component_by_uuid(gen2.uuid) + system.remove_component_by_uuid(gen2.uuid, cascade_down=cascade_down) + assert system.has_component(bus) != cascade_down assert not system.has_time_series(gen2) with pytest.raises(ISNotStored): From 85d33085102b5292abc0669964104d1232f5e10e Mon Sep 17 00:00:00 2001 From: Daniel Thom Date: Tue, 27 Aug 2024 13:12:29 -0600 Subject: [PATCH 4/4] Add test --- tests/test_system.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_system.py b/tests/test_system.py index 8fdd1c7..8c76b47 100644 --- a/tests/test_system.py +++ b/tests/test_system.py @@ -173,6 +173,7 @@ def check_attached_components(my_sys): assert attached_subsystems[0].name == f"test-subsystem{i}" assert not my_sys.list_parent_components(attached_subsystems[0]) assert my_sys.list_child_components(component) == [bus] + assert my_sys.list_child_components(component, component_type=SimpleBus) == [bus] for component in (bus, gen1, gen2): with pytest.raises(ISOperationNotAllowed):