fix: Removed units from TimeSeriesBase model and changed Pint behaviour.

List of changes:
- Changed data handling when creating an instance of `SingleTimeSeries` (see the
  sketch after this list).
- Updated tests to check that the data inside a `SingleTimeSeries` stays consistent,
  since serializing and deserializing could otherwise return two different instances.
- Updated arrow storage to convert to `pa.Array` only when serializing, and improved
  the type hints.
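As a rough sketch of the combined effect (plain pint and pyarrow, not infrasys code): units now stay on the in-memory data, and only the magnitude goes through Arrow when the series is written out.

```python
import pint
import pyarrow as pa

# In memory: the time series keeps the quantity exactly as supplied.
data = pint.Quantity([1.0, 2.0, 3.0], "kW")
assert isinstance(data, pint.Quantity)

# At serialization time: only the magnitude becomes an Arrow array.
stored = pa.array(data.magnitude)
assert isinstance(stored, pa.Array)
assert stored.to_pylist() == [1.0, 2.0, 3.0]
```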
pesap committed Oct 21, 2024
1 parent 2ba8324 commit 0e3c904
Showing 4 changed files with 37 additions and 38 deletions.
12 changes: 7 additions & 5 deletions src/infrasys/arrow_storage.py
@@ -123,13 +123,15 @@ def _get_single_time_series(
             normalization=metadata.normalization,
         )

-    def _convert_to_record_batch(self, array: SingleTimeSeries, variable_name: str):
+    def _convert_to_record_batch(
+        self, time_series: SingleTimeSeries, variable_name: str
+    ) -> pa.RecordBatch:
         """Create record batch to save array to disk."""
-        pa_array = array.data.magnitude if isinstance(array.data, BaseQuantity) else array.data
-        if not isinstance(array.data, pa.Array) and isinstance(
-            array.data, BaseQuantity | pint.Quantity
+        pa_array = time_series.data
+        if not isinstance(pa_array, pa.Array) and isinstance(
+            pa_array, BaseQuantity | pint.Quantity
         ):
-            pa_array = pa.array(array.data.magnitude)
+            pa_array = pa.array(pa_array.magnitude)
         assert isinstance(pa_array, pa.Array)
         schema = pa.schema([pa.field(variable_name, pa_array.type)])
         return pa.record_batch([pa_array], schema=schema)
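For reference, a standalone sketch of the serialization path when the stored data is a pint quantity: the record batch carries only the magnitude, typed from the resulting Arrow array. This mirrors the shape of `_convert_to_record_batch` but is hand-written, not a call into infrasys.

```python
import pint
import pyarrow as pa

data = pint.Quantity([1.0, 2.0, 3.0], "MW")  # what SingleTimeSeries.data may now hold
variable_name = "active_power"

# Unwrap the magnitude, build a single-field schema, emit a one-column record batch.
pa_array = pa.array(data.magnitude)
schema = pa.schema([pa.field(variable_name, pa_array.type)])
batch = pa.record_batch([pa_array], schema=schema)

assert batch.num_rows == 3
assert batch.schema.field(0).name == "active_power"
```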
10 changes: 4 additions & 6 deletions src/infrasys/time_series_models.py
@@ -45,7 +45,6 @@ class TimeSeriesStorageType(str, Enum):
 class TimeSeriesData(InfraSysBaseModelWithIdentifers, abc.ABC):
     """Base class for all time series models"""

-    units: Optional[str] = None
     variable_name: str
     normalization: NormalizationModel = None

@@ -74,16 +73,15 @@ def length(self) -> int:

     @field_validator("data", mode="before")
     @classmethod
-    def check_data(cls, data) -> pa.Array | BaseQuantity:  # Standarize what object we receive.
+    def check_data(
+        cls, data
+    ) -> pa.Array | pa.ChunkedArray | pint.Quantity:  # Standarize what object we receive.
         """Check time series data."""
         if len(data) < 2:
             msg = f"SingleTimeSeries length must be at least 2: {len(data)}"
             raise ValueError(msg)

-        if isinstance(data, BaseQuantity | pint.Quantity):
-            if not isinstance(data.magnitude, pa.Array):
-                cls = type(data)
-                return cls(data.magnitude, data.units)
+        if isinstance(data, pint.Quantity | BaseQuantity):
             return data

         if not isinstance(data, pa.Array):
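The pass-through happens in a Pydantic before-validator. The toy model below shows the same pattern outside infrasys (`TinySeries` and its fields are made up for illustration); the real `SingleTimeSeries` has more fields and also accepts `BaseQuantity`.

```python
from typing import Any

import pint
import pyarrow as pa
from pydantic import BaseModel, field_validator


class TinySeries(BaseModel):
    """Toy stand-in for SingleTimeSeries, only to show the validator pattern."""

    variable_name: str
    data: Any

    @field_validator("data", mode="before")
    @classmethod
    def check_data(cls, data):
        if len(data) < 2:
            raise ValueError(f"length must be at least 2: {len(data)}")
        if isinstance(data, pint.Quantity):
            return data  # quantities now pass through untouched
        if not isinstance(data, pa.Array):
            return pa.array(data)  # plain sequences still become Arrow arrays
        return data


ts = TinySeries(variable_name="active_power", data=pint.Quantity([1.0, 2.0], "kW"))
assert isinstance(ts.data, pint.Quantity)  # no re-wrapping of the magnitude
```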
4 changes: 2 additions & 2 deletions tests/test_single_time_series.py
@@ -82,8 +82,8 @@ def test_with_quantity():
     assert ts.length == length
     assert ts.resolution == resolution
     assert ts.initial_time == initial_time
-    assert isinstance(ts.data.magnitude, pa.Array)
-    assert ts.data[-1].as_py() == length - 1
+    assert isinstance(ts.data, ActivePower)
+    assert ts.data[-1].magnitude == length - 1


 def test_normalization():
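The assertion change follows from how the two containers index: a `pa.Array` element unwraps via `.as_py()`, while indexing a pint quantity returns a scalar quantity whose value is `.magnitude`. A quick standalone check of that behaviour (plain numpy/pint/pyarrow, no infrasys):

```python
import numpy as np
import pint
import pyarrow as pa

arrow = pa.array([0.0, 1.0, 2.0])
assert arrow[-1].as_py() == 2.0  # Arrow scalars unwrap via .as_py()

quantity = pint.Quantity(np.array([0.0, 1.0, 2.0]), "kW")
last = quantity[-1]  # indexing a quantity yields a scalar quantity
assert isinstance(last, pint.Quantity)
assert last.magnitude == 2.0  # hence the .magnitude comparison in the updated test
```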
49 changes: 24 additions & 25 deletions tests/test_system.py
@@ -249,10 +249,11 @@ def test_time_series():


 @pytest.mark.parametrize(
-    "params", list(itertools.product([True, False], [True, False], [True, False]))
+    "in_memory,use_quantity,sql_json",
+    list(itertools.product([True, False], [True, False], [True, False])),
 )
-def test_time_series_retrieval(params):
-    in_memory, use_quantity, sql_json = params
+def test_time_series_retrieval(in_memory, use_quantity, sql_json):
+    # in_memory, use_quantity, sql_json = params
     try:
         if not sql_json:
             os.environ["__INFRASYS_NON_JSON_SQLITE__"] = "1"
@@ -290,29 +291,33 @@ def test_time_series_retrieval(params):
         for metadata in system.list_time_series_metadata(gen, scenario="high"):
             assert metadata.user_attributes["scenario"] == "high"

-        assert (
-            system.get_time_series(
-                gen, variable_name=variable_name, scenario="high", model_year="2030"
+        assert all(
+            np.equal(
+                system.get_time_series(
+                    gen, variable_name, scenario="high", model_year="2030"
+                ).data,
+                ts1.data,
             )
-            == ts1
         )
-        assert (
-            system.get_time_series(
-                gen, variable_name=variable_name, scenario="high", model_year="2035"
+        assert all(
+            np.equal(
+                system.get_time_series(
+                    gen, variable_name, scenario="high", model_year="2035"
+                ).data,
+                ts2.data,
             )
-            == ts2
         )
-        assert (
-            system.get_time_series(
-                gen, variable_name=variable_name, scenario="low", model_year="2030"
+        assert all(
+            np.equal(
+                system.get_time_series(gen, variable_name, scenario="low", model_year="2030").data,
+                ts3.data,
             )
-            == ts3
         )
-        assert (
-            system.get_time_series(
-                gen, variable_name=variable_name, scenario="low", model_year="2035"
+        assert all(
+            np.equal(
+                system.get_time_series(gen, variable_name, scenario="low", model_year="2035").data,
+                ts4.data,
             )
-            == ts4
         )

         with pytest.raises(ISAlreadyAttached):
@@ -324,12 +324,6 @@ def test_time_series_retrieval(params):
             gen, variable_name=variable_name, scenario="high", model_year="2030"
         )
         assert not system.has_time_series(gen, variable_name=variable_name, model_year="2036")
-        assert (
-            system.get_time_series(
-                gen, variable_name=variable_name, scenario="high", model_year="2030"
-            )
-            == ts1
-        )
         with pytest.raises(ISOperationNotAllowed):
             system.get_time_series(gen, variable_name=variable_name, scenario="high")
         with pytest.raises(ISNotStored):
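Why the assertions moved from `== ts1` to `all(np.equal(...))`: after a round trip through storage, `get_time_series` can return a different instance than the one that was added, so the test compares the underlying data element-wise rather than relying on object equality. A minimal standalone version of that comparison, using two distinct `pa.Array` instances with equal values:

```python
import numpy as np
import pyarrow as pa

stored = pa.array([1.0, 2.0, 3.0])
retrieved = pa.array([1.0, 2.0, 3.0])  # a distinct instance with identical values

assert stored is not retrieved
assert all(np.equal(stored, retrieved))  # element-wise check, as in the updated tests
```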
