From 0e3c904d94142a50f1af31d54485feeb91509814 Mon Sep 17 00:00:00 2001 From: pesap Date: Mon, 21 Oct 2024 15:21:31 -0600 Subject: [PATCH] fix: Removed `units` from the TimeSeriesData model and changed Pint behaviour. List of changes: - Changed data handling when we create an instance of `SingleTimeSeries`. - Updated tests to check that the data inside a `SingleTimeSeries` is consistent, since serializing and deserializing can return two different instances. - Updated arrow storage to convert to `pa.Array` only when serializing, and improved its type hints. --- src/infrasys/arrow_storage.py | 12 +++++--- src/infrasys/time_series_models.py | 10 +++--- tests/test_single_time_series.py | 4 +-- tests/test_system.py | 49 +++++++++++++++--------------- 4 files changed, 37 insertions(+), 38 deletions(-) diff --git a/src/infrasys/arrow_storage.py b/src/infrasys/arrow_storage.py index 62de451..73a7d83 100644 --- a/src/infrasys/arrow_storage.py +++ b/src/infrasys/arrow_storage.py @@ -123,13 +123,15 @@ def _get_single_time_series( normalization=metadata.normalization, ) - def _convert_to_record_batch(self, array: SingleTimeSeries, variable_name: str): + def _convert_to_record_batch( + self, time_series: SingleTimeSeries, variable_name: str + ) -> pa.RecordBatch: """Create record batch to save array to disk.""" - pa_array = array.data.magnitude if isinstance(array.data, BaseQuantity) else array.data - if not isinstance(array.data, pa.Array) and isinstance( - array.data, BaseQuantity | pint.Quantity + pa_array = time_series.data + if not isinstance(pa_array, pa.Array) and isinstance( + pa_array, BaseQuantity | pint.Quantity ): - pa_array = pa.array(array.data.magnitude) + pa_array = pa.array(pa_array.magnitude) assert isinstance(pa_array, pa.Array) schema = pa.schema([pa.field(variable_name, pa_array.type)]) return pa.record_batch([pa_array], schema=schema) diff --git a/src/infrasys/time_series_models.py b/src/infrasys/time_series_models.py index 54d334e..0753b6c 100644 --- 
a/src/infrasys/time_series_models.py +++ b/src/infrasys/time_series_models.py @@ -45,7 +45,6 @@ class TimeSeriesStorageType(str, Enum): class TimeSeriesData(InfraSysBaseModelWithIdentifers, abc.ABC): """Base class for all time series models""" - units: Optional[str] = None variable_name: str normalization: NormalizationModel = None @@ -74,16 +73,15 @@ def length(self) -> int: @field_validator("data", mode="before") @classmethod - def check_data(cls, data) -> pa.Array | BaseQuantity: # Standarize what object we receive. + def check_data( + cls, data + ) -> pa.Array | pa.ChunkedArray | pint.Quantity: # Standarize what object we receive. """Check time series data.""" if len(data) < 2: msg = f"SingleTimeSeries length must be at least 2: {len(data)}" raise ValueError(msg) - if isinstance(data, BaseQuantity | pint.Quantity): - if not isinstance(data.magnitude, pa.Array): - cls = type(data) - return cls(data.magnitude, data.units) + if isinstance(data, pint.Quantity | BaseQuantity): return data if not isinstance(data, pa.Array): diff --git a/tests/test_single_time_series.py b/tests/test_single_time_series.py index b2a94f5..9620c5d 100644 --- a/tests/test_single_time_series.py +++ b/tests/test_single_time_series.py @@ -82,8 +82,8 @@ def test_with_quantity(): assert ts.length == length assert ts.resolution == resolution assert ts.initial_time == initial_time - assert isinstance(ts.data.magnitude, pa.Array) - assert ts.data[-1].as_py() == length - 1 + assert isinstance(ts.data, ActivePower) + assert ts.data[-1].magnitude == length - 1 def test_normalization(): diff --git a/tests/test_system.py b/tests/test_system.py index 8c76b47..d2e27ed 100644 --- a/tests/test_system.py +++ b/tests/test_system.py @@ -249,10 +249,11 @@ def test_time_series(): @pytest.mark.parametrize( - "params", list(itertools.product([True, False], [True, False], [True, False])) + "in_memory,use_quantity,sql_json", + list(itertools.product([True, False], [True, False], [True, False])), ) -def 
test_time_series_retrieval(params): - in_memory, use_quantity, sql_json = params +def test_time_series_retrieval(in_memory, use_quantity, sql_json): + # in_memory, use_quantity, sql_json = params try: if not sql_json: os.environ["__INFRASYS_NON_JSON_SQLITE__"] = "1" @@ -290,29 +291,33 @@ def test_time_series_retrieval(params): for metadata in system.list_time_series_metadata(gen, scenario="high"): assert metadata.user_attributes["scenario"] == "high" - assert ( - system.get_time_series( - gen, variable_name=variable_name, scenario="high", model_year="2030" + assert all( + np.equal( + system.get_time_series( + gen, variable_name, scenario="high", model_year="2030" + ).data, + ts1.data, ) - == ts1 ) - assert ( - system.get_time_series( - gen, variable_name=variable_name, scenario="high", model_year="2035" + assert all( + np.equal( + system.get_time_series( + gen, variable_name, scenario="high", model_year="2035" + ).data, + ts2.data, ) - == ts2 ) - assert ( - system.get_time_series( - gen, variable_name=variable_name, scenario="low", model_year="2030" + assert all( + np.equal( + system.get_time_series(gen, variable_name, scenario="low", model_year="2030").data, + ts3.data, ) - == ts3 ) - assert ( - system.get_time_series( - gen, variable_name=variable_name, scenario="low", model_year="2035" + assert all( + np.equal( + system.get_time_series(gen, variable_name, scenario="low", model_year="2035").data, + ts4.data, ) - == ts4 ) with pytest.raises(ISAlreadyAttached): @@ -324,12 +329,6 @@ def test_time_series_retrieval(params): gen, variable_name=variable_name, scenario="high", model_year="2030" ) assert not system.has_time_series(gen, variable_name=variable_name, model_year="2036") - assert ( - system.get_time_series( - gen, variable_name=variable_name, scenario="high", model_year="2030" - ) - == ts1 - ) with pytest.raises(ISOperationNotAllowed): system.get_time_series(gen, variable_name=variable_name, scenario="high") with pytest.raises(ISNotStored):