Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: make sure netcdf exporters can handle list of timesteps #369

Merged
merged 3 commits into from
Jul 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 24 additions & 8 deletions pysteps/io/exporters.py
Original file line number Diff line number Diff line change
Expand Up @@ -392,9 +392,10 @@ def initialize_forecast_exporter_netcdf(
Start date of the forecast.
timestep: int
Time step of the forecast (minutes).
n_timesteps: int
Number of time steps in the forecast this argument is ignored if
incremental is set to 'timestep'.
n_timesteps: int or list of integers
Number of time steps to forecast or a list of time steps for which the
forecasts are computed (relative to the input time step). The elements of
the list are required to be in ascending order.
shape: tuple of int
Two-element tuple defining the shape (height,width) of the forecast
grids.
Expand Down Expand Up @@ -460,8 +461,14 @@ def initialize_forecast_exporter_netcdf(
+ "'timestep' or 'member'"
)

n_timesteps_is_list = isinstance(n_timesteps, list)
if n_timesteps_is_list:
num_timesteps = len(n_timesteps)
else:
num_timesteps = n_timesteps

if incremental == "timestep":
n_timesteps = None
num_timesteps = None
elif incremental == "member":
n_ens_members = None
elif incremental is not None:
Expand Down Expand Up @@ -498,7 +505,7 @@ def initialize_forecast_exporter_netcdf(
h, w = shape

ncf.createDimension("ens_number", size=n_ens_members)
ncf.createDimension("time", size=n_timesteps)
ncf.createDimension("time", size=num_timesteps)
ncf.createDimension("y", size=h)
ncf.createDimension("x", size=w)

Expand Down Expand Up @@ -585,7 +592,10 @@ def initialize_forecast_exporter_netcdf(

var_time = ncf.createVariable("time", int, dimensions=("time",))
if incremental != "timestep":
var_time[:] = [i * timestep * 60 for i in range(1, n_timesteps + 1)]
if n_timesteps_is_list:
var_time[:] = np.array(n_timesteps) * timestep * 60
else:
var_time[:] = [i * timestep * 60 for i in range(1, n_timesteps + 1)]
var_time.long_name = "forecast time"
startdate_str = datetime.strftime(startdate, "%Y-%m-%d %H:%M:%S")
var_time.units = "seconds since %s" % startdate_str
Expand Down Expand Up @@ -635,7 +645,8 @@ def initialize_forecast_exporter_netcdf(
exporter["timestep"] = timestep
exporter["metadata"] = metadata
exporter["incremental"] = incremental
exporter["num_timesteps"] = n_timesteps
exporter["num_timesteps"] = num_timesteps
exporter["timesteps"] = n_timesteps
dnerini marked this conversation as resolved.
Show resolved Hide resolved
exporter["num_ens_members"] = n_ens_members
exporter["shape"] = shape

Expand Down Expand Up @@ -853,7 +864,12 @@ def _export_netcdf(field, exporter):
else:
var_f[var_f.shape[0], :, :] = field
var_time = exporter["var_time"]
var_time[len(var_time) - 1] = len(var_time) * exporter["timestep"] * 60
if isinstance(exporter["timesteps"], list):
var_time[len(var_time) - 1] = (
exporter["timesteps"][len(var_time) - 1] * exporter["timestep"] * 60
)
else:
var_time[len(var_time) - 1] = len(var_time) * exporter["timestep"] * 60
else:
var_f[var_f.shape[0], :, :, :] = field
var_ens_num = exporter["var_ens_num"]
Expand Down
22 changes: 14 additions & 8 deletions pysteps/tests/test_exporters.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,17 @@
"fill_value",
"scale_factor",
"offset",
"n_timesteps",
)

exporter_arg_values = [
(1, None, np.float32, None, None, None),
(1, "timestep", np.float32, 65535, None, None),
(2, None, np.float32, 65535, None, None),
(2, "timestep", np.float32, None, None, None),
(2, "member", np.float64, None, 0.01, 1.0),
(1, None, np.float32, None, None, None, 3),
(1, "timestep", np.float32, 65535, None, None, 3),
(2, None, np.float32, 65535, None, None, 3),
(2, None, np.float32, 65535, None, None, [1, 2, 4]),
(2, "timestep", np.float32, None, None, None, 3),
(2, "timestep", np.float32, None, None, None, [1, 2, 4]),
(2, "member", np.float64, None, 0.01, 1.0, 3),
]


Expand All @@ -54,7 +57,7 @@ def test_get_geotiff_filename():

@pytest.mark.parametrize(exporter_arg_names, exporter_arg_values)
def test_io_export_netcdf_one_member_one_time_step(
n_ens_members, incremental, datatype, fill_value, scale_factor, offset
n_ens_members, incremental, datatype, fill_value, scale_factor, offset, n_timesteps
):
"""
Test the export netcdf.
Expand All @@ -75,7 +78,6 @@ def test_io_export_netcdf_one_member_one_time_step(
file_path = os.path.join(outpath, outfnprefix + ".nc")
startdate = metadata["timestamps"][0]
timestep = metadata["accutime"]
n_timesteps = 3
shape = precip.shape[1:]

exporter = initialize_forecast_exporter_netcdf(
Expand All @@ -100,7 +102,11 @@ def test_io_export_netcdf_one_member_one_time_step(
if incremental == None:
export_forecast_dataset(precip, exporter)
if incremental == "timestep":
for t in range(n_timesteps):
if isinstance(n_timesteps, list):
timesteps = len(n_timesteps)
else:
timesteps = n_timesteps
for t in range(timesteps):
if n_ens_members > 1:
export_forecast_dataset(precip[:, t, :, :], exporter)
else:
Expand Down