diff --git a/pysteps/xarray_helpers.py b/pysteps/xarray_helpers.py index a1049f32..33ec2f40 100644 --- a/pysteps/xarray_helpers.py +++ b/pysteps/xarray_helpers.py @@ -83,6 +83,7 @@ def convert_input_to_xarray_dataset( precip: np.ndarray, quality: np.ndarray | None, metadata: dict[str, str | float | None], + startdate: datetime | None = None, ) -> xr.Dataset: """ Read a precip, quality, metadata tuple as returned by the importers @@ -99,6 +100,8 @@ def convert_input_to_xarray_dataset( metadata: dict Metadata dictionary containing the attributes described in the documentation of :py:mod:`pysteps.io.importers`. + startdate: datetime, None + Datetime object containing the start date and time for the nowcast Returns ------- @@ -107,7 +110,31 @@ def convert_input_to_xarray_dataset( """ var_name, attrs = cf_parameters_from_unit(metadata["unit"]) - h, w = precip.shape + + dims = None + timesteps = None + ens_number = None + + if precip.ndim == 4: + ens_number, timesteps, h, w = precip.shape + dims = ["ens_number", "time", "y", "x"] + + if startdate is None: + raise Exception("startdate missing") + + elif precip.ndim == 3: + timesteps, h, w = precip.shape + dims = ["time", "y", "x"] + + if startdate is None: + raise Exception("startdate missing") + + elif precip.ndim == 2: + h, w = precip.shape + dims = ["y", "x"] + else: + raise Exception(f"Precip field shape: {precip.shape} not supported") + x_r = np.linspace(metadata["x1"], metadata["x2"], w + 1)[:-1] x_r += 0.5 * (x_r[1] - x_r[0]) y_r = np.linspace(metadata["y1"], metadata["y2"], h + 1)[:-1] @@ -142,25 +169,33 @@ def convert_input_to_xarray_dataset( data_vars = { var_name: ( - ["y", "x"], + dims, precip, { "units": attrs["units"], "standard_name": attrs["standard_name"], "long_name": attrs["long_name"], "grid_mapping": "projection", - "transform": metadata["transform"], - "accutime": metadata["accutime"], - "threshold": metadata["threshold"], - "zerovalue": metadata["zerovalue"], - "zr_a": metadata["zr_a"], - "zr_b": metadata["zr_b"], }, ) } + + metadata_keys = [ + "transform", + "accutime", + "threshold", + "zerovalue", + "zr_a", + "zr_b", + ] + + for metadata_field in metadata_keys: + if metadata_field in metadata: + data_vars[var_name][2][metadata_field] = metadata[metadata_field] + if quality is not None: data_vars["quality"] = ( - ["y", "x"], + dims, quality, { "units": "1", @@ -210,6 +245,26 @@ def convert_input_to_xarray_dataset( }, ), } + + if ens_number is not None: + coords["ens_number"] = ( + ["ens_number"], + list(range(1, ens_number + 1, 1)), + { + "long_name": "ensemble member", + "standard_name": "realization", + "units": "", + }, + ) + + if timesteps is not None: + startdate_str = datetime.strftime(startdate, "%Y-%m-%d %H:%M:%S") + + coords["time"] = ( + ["time"], + list(range(1, timesteps + 1, 1)), + {"long_name": "forecast time", "units": "seconds since %s" % startdate_str}, + ) if grid_mapping_var_name is not None: coords[grid_mapping_name] = ( [], @@ -223,7 +278,7 @@ def convert_input_to_xarray_dataset( "precip_var": var_name, } dataset = xr.Dataset(data_vars=data_vars, coords=coords, attrs=attrs) - return dataset.sortby(["y", "x"]) + return dataset.sortby(dims) def convert_output_to_xarray_dataset(