diff --git a/src/infrasys/system.py b/src/infrasys/system.py index deaa055..f40437e 100644 --- a/src/infrasys/system.py +++ b/src/infrasys/system.py @@ -2,6 +2,7 @@ import json import shutil +import sqlite3 from operator import itemgetter from collections import defaultdict from datetime import datetime @@ -33,16 +34,20 @@ ) from infrasys.time_series_manager import TimeSeriesManager, TIME_SERIES_KWARGS from infrasys.time_series_models import SingleTimeSeries, TimeSeriesData, TimeSeriesMetadata +from infrasys.utils.sqlite import backup, create_in_memory_db, restore class System: """Implements behavior for systems""" + DB_FILENAME = "time_series_metadata.db" + def __init__( self, name: Optional[str] = None, description: Optional[str] = None, auto_add_composed_components: bool = False, + con: Optional[sqlite3.Connection] = None, time_series_manager: Optional[TimeSeriesManager] = None, uuid: Optional[UUID] = None, **kwargs: Any, @@ -60,6 +65,8 @@ def __init__( The default behavior is to raise an ISOperationNotAllowed when this condition occurs. This handles values that are components, such as generator.bus, and lists of components, such as subsystem.generators, but not any other form of nested components. + con : None | sqlite3.Connection + Users should not pass this. De-serialization (from_json) will pass a Connection. time_series_manager : None | TimeSeriesManager Users should not pass this. De-serialization (from_json) will pass a constructed manager. @@ -79,8 +86,11 @@ def __init__( self._name = name self._description = description self._component_mgr = ComponentManager(self._uuid, auto_add_composed_components) + self._con = con or create_in_memory_db() time_series_kwargs = {k: v for k, v in kwargs.items() if k in TIME_SERIES_KWARGS} - self._time_series_mgr = time_series_manager or TimeSeriesManager(**time_series_kwargs) + self._time_series_mgr = time_series_manager or TimeSeriesManager( + self._con, **time_series_kwargs + ) self._data_format_version: Optional[str] = None # Note to devs: if you add new fields, add support in to_json/from_json as appropriate. @@ -127,10 +137,9 @@ def to_json(self, filename: Path | str, overwrite=False, indent=None, data=None) msg = f"{filename=} already exists. Choose a different path or set overwrite=True." raise ISFileExists(msg) - if not filename.parent.exists(): - filename.parent.mkdir() - + filename.parent.mkdir(exist_ok=True) time_series_dir = filename.parent / (filename.stem + "_time_series") + time_series_dir.mkdir(exist_ok=True) system_data = { "name": self.name, "description": self.description, @@ -161,7 +170,8 @@ def to_json(self, filename: Path | str, overwrite=False, indent=None, data=None) json.dump(data, f_out, indent=indent) logger.info("Wrote system data to {}", filename) - self._time_series_mgr.serialize(self._make_time_series_directory(filename)) + backup(self._con, time_series_dir / self.DB_FILENAME) + self._time_series_mgr.serialize(time_series_dir) @classmethod def from_json( @@ -257,12 +267,20 @@ def from_dict( """ system_data = data if "system" not in data else data["system"] ts_kwargs = {k: v for k, v in kwargs.items() if k in TIME_SERIES_KWARGS} + ts_path = ( + time_series_parent_dir + if isinstance(time_series_parent_dir, Path) + else Path(time_series_parent_dir) + ) + con = create_in_memory_db() + restore(con, ts_path / data["time_series"]["directory"] / System.DB_FILENAME) time_series_manager = TimeSeriesManager.deserialize( - data["time_series"], time_series_parent_dir, **ts_kwargs + con, data["time_series"], ts_path, **ts_kwargs ) system = cls( name=system_data.get("name"), description=system_data.get("description"), + con=con, time_series_manager=time_series_manager, uuid=UUID(system_data["uuid"]), **kwargs, diff --git a/src/infrasys/time_series_manager.py b/src/infrasys/time_series_manager.py index 7013387..a705ccc 100644 --- a/src/infrasys/time_series_manager.py +++ b/src/infrasys/time_series_manager.py @@ -1,5 +1,6 @@ """Manages time series arrays""" +import sqlite3 from datetime import datetime from pathlib import Path from typing import Any, Optional, Type @@ -32,7 +33,13 @@ def _process_time_series_kwarg(key: str, **kwargs: Any) -> Any: class TimeSeriesManager: """Manages time series for a system.""" - def __init__(self, storage: Optional[TimeSeriesStorageBase] = None, **kwargs) -> None: + def __init__( + self, + con: sqlite3.Connection, + storage: Optional[TimeSeriesStorageBase] = None, + initialize: bool = True, + **kwargs, + ) -> None: base_directory: Path | None = _process_time_series_kwarg("time_series_directory", **kwargs) self._read_only = _process_time_series_kwarg("time_series_read_only", **kwargs) self._storage = storage or ( @@ -40,7 +47,7 @@ def __init__(self, storage: Optional[TimeSeriesStorageBase] = None, **kwargs) -> if _process_time_series_kwarg("time_series_in_memory", **kwargs) else ArrowTimeSeriesStorage.create_with_temp_directory(base_directory=base_directory) ) - self._metadata_store = TimeSeriesMetadataStore() + self._metadata_store = TimeSeriesMetadataStore(con, initialize=initialize) # TODO: create parsing mechanism? CSV, CSV + JSON @@ -245,11 +252,11 @@ def _get_by_metadata( def serialize(self, dst: Path | str, src: Optional[Path | str] = None) -> None: """Serialize the time series data to dst.""" self._storage.serialize(dst, src) - self._metadata_store.backup(dst) @classmethod def deserialize( cls, + con: sqlite3.Connection, data: dict[str, Any], parent_dir: Path | str, **kwargs: Any, @@ -269,9 +276,7 @@ def deserialize( storage = ArrowTimeSeriesStorage.create_with_temp_directory() storage.serialize(src=time_series_dir, dst=storage.get_time_series_directory()) - mgr = cls(storage=storage, **kwargs) - mgr.metadata_store.restore(time_series_dir) - return mgr + return cls(con, storage=storage, initialize=False, **kwargs) def _handle_read_only(self) -> None: if self._read_only: diff --git a/src/infrasys/time_series_metadata_store.py b/src/infrasys/time_series_metadata_store.py index f469260..44cffcb 100644 --- a/src/infrasys/time_series_metadata_store.py +++ b/src/infrasys/time_series_metadata_store.py @@ -6,7 +6,6 @@ import os import sqlite3 from dataclasses import dataclass -from pathlib import Path from typing import Any, Optional, Sequence from uuid import UUID @@ -28,11 +27,11 @@ class TimeSeriesMetadataStore: """Stores time series metadata in a SQLite database.""" TABLE_NAME = "time_series_metadata" - DB_FILENAME = "time_series_metadata.db" - def __init__(self): - self._con = sqlite3.connect(":memory:") - self._create_metadata_table() + def __init__(self, con: sqlite3.Connection, initialize: bool = True): + self._con = con + if initialize: + self._create_metadata_table() self._supports_sqlite_json = _does_sqlite_support_json() if not self._supports_sqlite_json: # This is true on Ubuntu 22.04, which is used by GitHub runners as of March 2024. @@ -126,24 +125,6 @@ def add( ] self._insert_rows(rows) - def backup(self, directory: Path | str) -> None: - """Backup the database to a file in directory.""" - path = directory if isinstance(directory, Path) else Path(directory) - filename = path / self.DB_FILENAME - with sqlite3.connect(filename) as con: - self._con.backup(con) - con.close() - logger.info("Backed up the time series metadata to {}", filename) - - def restore(self, directory: Path | str) -> None: - """Restore the database from a file to memory.""" - path = directory if isinstance(directory, Path) else Path(directory) - filename = path / self.DB_FILENAME - with sqlite3.connect(filename) as con: - con.backup(self._con) - con.close() - logger.info("Restored the time series metadata to memory") - def get_time_series_counts(self) -> "TimeSeriesCounts": """Return summary counts of components and time series.""" query = f""" diff --git a/src/infrasys/utils/sqlite.py b/src/infrasys/utils/sqlite.py index daad819..6aeec82 100644 --- a/src/infrasys/utils/sqlite.py +++ b/src/infrasys/utils/sqlite.py @@ -1,11 +1,33 @@ """Utility functions for SQLite""" import sqlite3 +from pathlib import Path from typing import Any, Sequence from loguru import logger +def backup(src_con: sqlite3.Connection, filename: Path | str) -> None: + """Backup a database to a file.""" + with sqlite3.connect(filename) as dst_con: + src_con.backup(dst_con) + dst_con.close() + logger.info("Backed up the database to {}.", filename) + + +def restore(dst_con: sqlite3.Connection, filename: Path | str) -> None: + """Restore a database from a file.""" + with sqlite3.connect(filename) as src_con: + src_con.backup(dst_con) + src_con.close() + logger.info("Restored the database from {}.", filename) + + +def create_in_memory_db(database: str = ":memory:") -> sqlite3.Connection: + """Create an in-memory database.""" + return sqlite3.connect(database) + + def execute(cursor: sqlite3.Cursor, query: str, params: Sequence[str] = ()) -> Any: """Execute a SQL query.""" logger.trace("SQL query: {query} {params=}", query) diff --git a/tests/test_arrow_storage.py b/tests/test_arrow_storage.py index e61caf6..1529dec 100644 --- a/tests/test_arrow_storage.py +++ b/tests/test_arrow_storage.py @@ -56,7 +56,7 @@ def test_copy_files(tmp_path): system.to_json(filename) logger.info("Starting deserialization") - system2 = SimpleSystem.from_json(filename, base_directory=tmp_path) + system2 = SimpleSystem.from_json(filename) gen1b = system2.get_component(SimpleGenerator, gen1.name) time_series = system2.time_series.get(gen1b) time_series_fpath = ( diff --git a/tests/test_serialization.py b/tests/test_serialization.py index f382459..d696645 100644 --- a/tests/test_serialization.py +++ b/tests/test_serialization.py @@ -59,7 +59,7 @@ def test_serialization(tmp_path): system.to_json(filename, overwrite=True) system2 = SimpleSystem.from_json(filename) for key, val in system.__dict__.items(): - if key not in ("_component_mgr", "_time_series_mgr"): + if key not in ("_component_mgr", "_time_series_mgr", "_con"): assert getattr(system2, key) == val components2 = list(system2.iter_all_components()) @@ -195,16 +195,14 @@ def test_system_save(tmp_path, simple_system_with_time_series): simple_system = simple_system_with_time_series custom_folder = "my_system" fpath = tmp_path / custom_folder - fname = "test_system" + fname = "test_system.json" simple_system.save(fpath, filename=fname) assert os.path.exists(fpath), f"Folder {fpath} was not created successfully" assert os.path.exists(fpath / fname), f"Serialized system {fname} was not created successfully" - fname = "test_system" with pytest.raises(FileExistsError): simple_system.save(fpath, filename=fname) - fname = "test_system" simple_system.save(fpath, filename=fname, overwrite=True) assert os.path.exists(fpath), f"Folder {fpath} was not created successfully" assert os.path.exists(fpath / fname), f"Serialized system {fname} was not created successfully"