Dimension order and coordinate clean-up (#83)
* Add tidal metadata and force uname and vname inputs to process_velocities

* Add uname and vname in process_velocities

* Refactor TidalForcing to inherit from ROMSMixinTools

* Drop coordinates in tidal forcing

* Test whether plotting on coarse grid works

* Update tidal notebook

* Define dimension order as (time, s_rho, eta_rho, xi_rho) (illustrated in the sketch below)

* Transpose regression data for initial conditions

* Drop coordinates, rename time -> ocean_time, and introduce abs_time

* Transpose regression test data

* Fix plotting by adding coordinates to field

* Adjust plotting routine to the new coordinate handling

* Update grid notebook

* Format correctly

* Complete docstrings for initial conditions plotting function

* Get rid of month coordinate

* Remove coordinates and renaming from boundary forcing

* Transpose regression test data

* Add post-process method to CESMBGCSurfaceForcingDataset to keep it clean

* Assign lat/lon in object-specific .plot method

* Update surface forcing notebook

* Add lat/lon coords before plotting

* Update notebooks

* Correct description of methods
NoraLoose authored Aug 22, 2024
1 parent 552e49d commit caaba66
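
The dimension-order and time-handling conventions listed above can be summarized in a short, hedged sketch. The dataset below is invented; only the target order (time, s_rho, eta_rho, xi_rho) and the ocean_time/abs_time naming come from the commit message.

```python
# Invented example data; only the target dimension order and the
# ocean_time/abs_time naming follow the commit message above.
import numpy as np
import pandas as pd
import xarray as xr

ds = xr.Dataset(
    {"temp": (("eta_rho", "xi_rho", "s_rho", "time"), np.zeros((4, 5, 3, 2)))},
    coords={"time": pd.date_range("2012-01-01", periods=2)},
)

# Enforce the (time, s_rho, eta_rho, xi_rho) dimension order.
ds = ds.transpose("time", "s_rho", "eta_rho", "xi_rho")

# Keep the actual timestamps as "abs_time" and rename the model time axis.
ds = ds.assign_coords(abs_time=("time", ds["time"].values))
ds = ds.rename({"time": "ocean_time"})
```
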
Showing 19 changed files with 7,379 additions and 13,075 deletions.
11,796 changes: 4,129 additions & 7,667 deletions docs/boundary_forcing.ipynb

Large diffs are not rendered by default.

595 changes: 302 additions & 293 deletions docs/grid.ipynb

Large diffs are not rendered by default.

5,181 changes: 2,055 additions & 3,126 deletions docs/initial_conditions.ipynb

Large diffs are not rendered by default.

8 changes: 4 additions & 4 deletions docs/methods.rst
@@ -54,14 +54,14 @@ During the grid generation process, ``ROMS-Tools`` also creates a topography field
hmin=5, # Minimum ocean depth in meters (default: 5)
)
This functionality is executed through the :meth:`roms_tools.Grid.add_topography_and_mask` method, which is automatically called when an instance of the :class:`roms_tools.Grid` class is created.
This functionality is executed through the :meth:`roms_tools.Grid.update_topography_and_mask` method, which is automatically called when an instance of the :class:`roms_tools.Grid` class is created.

Users can also directly apply the :meth:`roms_tools.Grid.add_topography_and_mask` method if they wish to overwrite an existing topography or if a grid has been loaded from a file that lacks a topography field. For more detailed information and examples, please refer to `this example <grid.ipynb>`_.
Users can also directly apply the :meth:`roms_tools.Grid.update_topography_and_mask` method if they wish to overwrite an existing topography or if a grid has been loaded from a file that lacks a topography field. For more detailed information and examples, please refer to `this example <grid.ipynb>`_.

The :meth:`roms_tools.Grid.add_topography_and_mask` method completes five steps:
The :meth:`roms_tools.Grid.update_topography_and_mask` method completes five steps:

0. The topography from the specified ``topography_source`` is interpolated onto the ROMS grid.
1. The mask is defined using a dealiased version of the interpolated topography from step 0. In this step, the topography is evaluated at each grid point: values smaller than 0.11 meters are classified as land, and values larger than 0.11 meters are classified as ocean.
1. The mask is defined using a dealiased version of the interpolated topography from step 0. In this step, the topography is evaluated at each grid point: values smaller than 0.0 meters are classified as land, and values larger than 0.0 meters are classified as ocean.
2. The interpolated topography from step 0 is smoothed over the entire domain with a smoothing factor of 8. This step ensures that the topography is smooth at the grid scale, a prerequisite for avoiding grid-scale instabilities at runtime.
3. The mask is modified by filling enclosed basins with land.
4. Regions where the ocean depth is shallower than ``hmin`` are set to ``hmin``. The topography is then smoothed locally in such a way that the maximum slope parameter ``r`` is smaller than 0.2. The maximum slope parameter is given by
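
A rough, standalone illustration of steps 1 and 4 above, written with plain xarray. This is not the library's implementation (the library applies step 1 to a dealiased topography and smooths between the steps); it is just the stated rules on invented numbers.

```python
import numpy as np
import xarray as xr

hmin = 5.0  # minimum ocean depth in meters

# Invented interpolated topography (positive values = ocean depth in meters).
h = xr.DataArray(
    np.array([[120.0, 3.0, -15.0], [80.0, 0.5, -2.0]]),
    dims=("eta_rho", "xi_rho"),
)

# Step 1: values larger than 0.0 m are classified as ocean, smaller as land.
mask = xr.where(h > 0.0, 1, 0)

# Step 4: ocean points shallower than hmin are set to hmin.
h = xr.where((mask == 1) & (h < hmin), hmin, h)
```
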
1,288 changes: 126 additions & 1,162 deletions docs/surface_forcing.ipynb

Large diffs are not rendered by default.

144 changes: 49 additions & 95 deletions docs/tides.ipynb

Large diffs are not rendered by default.

178 changes: 52 additions & 126 deletions roms_tools/setup/boundary_forcing.py
@@ -101,7 +101,7 @@ def __post_init__(self):
vars_2d = ["zeta"]
vars_3d = ["temp", "salt", "u", "v"]
data_vars = super().regrid_data(data, vars_2d, vars_3d, lon, lat)
data_vars = super().process_velocities(data_vars, angle)
data_vars = super().process_velocities(data_vars, angle, "u", "v")
object.__setattr__(data, "data_vars", data_vars)

if self.bgc_source is not None:
@@ -121,9 +121,9 @@ def __post_init__(self):
bgc_data = None

d_meta = super().get_variable_metadata()
bdry_coords, rename = super().get_boundary_info()
bdry_coords = super().get_boundary_info()

ds = self._write_into_datatree(data, bgc_data, d_meta, bdry_coords, rename)
ds = self._write_into_datatree(data, bgc_data, d_meta, bdry_coords)

for direction in ["south", "east", "north", "west"]:
if self.boundaries[direction]:
@@ -202,7 +202,7 @@ def _get_bgc_data(self):

return data

def _write_into_dataset(self, data, d_meta, bdry_coords, rename):
def _write_into_dataset(self, data, d_meta, bdry_coords):

# save in new dataset
ds = xr.Dataset()
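
With the rename mapping gone, the slicing in this method relies only on get_boundary_info() returning, per grid-point type and boundary direction, a keyword dict that is splatted straight into .isel() (see the hunk below). A hedged sketch of that idiom; the indexer contents here are assumptions, not the actual return value of get_boundary_info().

```python
# The indexer dict below is an assumed stand-in for what get_boundary_info()
# returns; only the .isel(**indexer) idiom itself is the point.
import numpy as np
import xarray as xr

da = xr.DataArray(np.zeros((3, 4, 5)), dims=("s_rho", "eta_rho", "xi_rho"))
bdry_coords = {"rho": {"south": {"eta_rho": 0}, "north": {"eta_rho": -1}}}

south = da.isel(**bdry_coords["rho"]["south"])  # remaining dims: ("s_rho", "xi_rho")
```
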
@@ -215,21 +215,18 @@ def _write_into_dataset(data, d_meta, bdry_coords, rename):
ds[f"{var}_{direction}"] = (
data.data_vars[var]
.isel(**bdry_coords["u"][direction])
.rename(**rename["u"][direction])
.astype(np.float32)
)
elif var in ["v", "vbar"]:
ds[f"{var}_{direction}"] = (
data.data_vars[var]
.isel(**bdry_coords["v"][direction])
.rename(**rename["v"][direction])
.astype(np.float32)
)
else:
ds[f"{var}_{direction}"] = (
data.data_vars[var]
.isel(**bdry_coords["rho"][direction])
.rename(**rename["rho"][direction])
.astype(np.float32)
)
ds[f"{var}_{direction}"].attrs[
@@ -288,127 +285,63 @@ def _write_into_dataset(data, d_meta, bdry_coords, rename):

return ds

def _write_into_datatree(self, data, bgc_data, d_meta, bdry_coords, rename):
def _write_into_datatree(self, data, bgc_data, d_meta, bdry_coords):

ds = self._add_global_metadata()
ds["sc_r"] = self.grid.ds["sc_r"]
ds["Cs_r"] = self.grid.ds["Cs_r"]

ds = DataTree(name="root", data=ds)

ds_physics = self._write_into_dataset(data, d_meta, bdry_coords, rename)
ds_physics = self._add_coordinates(bdry_coords, rename, ds_physics)
ds_physics = self._write_into_dataset(data, d_meta, bdry_coords)
ds_physics = self._add_global_metadata(ds_physics)
ds_physics.attrs["physics_source"] = self.physics_source["name"]

ds_physics = DataTree(name="physics", parent=ds, data=ds_physics)

if bgc_data:
ds_bgc = self._write_into_dataset(bgc_data, d_meta, bdry_coords, rename)
ds_bgc = self._add_coordinates(bdry_coords, rename, ds_bgc)
ds_bgc = self._write_into_dataset(bgc_data, d_meta, bdry_coords)
ds_bgc = self._add_global_metadata(ds_bgc)
ds_bgc.attrs["bgc_source"] = self.bgc_source["name"]
ds_bgc = DataTree(name="bgc", parent=ds, data=ds_bgc)

return ds

-def _add_coordinates(self, bdry_coords, rename, ds=None):
-
-    if ds is None:
-        ds = xr.Dataset()
-
-    for direction in ["south", "east", "north", "west"]:
-
-        if self.boundaries[direction]:
-
-            lat_rho = self.grid.ds.lat_rho.isel(
-                **bdry_coords["rho"][direction]
-            ).rename(**rename["rho"][direction])
-            lon_rho = self.grid.ds.lon_rho.isel(
-                **bdry_coords["rho"][direction]
-            ).rename(**rename["rho"][direction])
-            layer_depth_rho = (
-                self.grid.ds["layer_depth_rho"]
-                .isel(**bdry_coords["rho"][direction])
-                .rename(**rename["rho"][direction])
-            )
-            interface_depth_rho = (
-                self.grid.ds["interface_depth_rho"]
-                .isel(**bdry_coords["rho"][direction])
-                .rename(**rename["rho"][direction])
-            )
-
-            lat_u = self.grid.ds.lat_u.isel(**bdry_coords["u"][direction]).rename(
-                **rename["u"][direction]
-            )
-            lon_u = self.grid.ds.lon_u.isel(**bdry_coords["u"][direction]).rename(
-                **rename["u"][direction]
-            )
-            layer_depth_u = (
-                self.grid.ds["layer_depth_u"]
-                .isel(**bdry_coords["u"][direction])
-                .rename(**rename["u"][direction])
-            )
-            interface_depth_u = (
-                self.grid.ds["interface_depth_u"]
-                .isel(**bdry_coords["u"][direction])
-                .rename(**rename["u"][direction])
-            )
-
-            lat_v = self.grid.ds.lat_v.isel(**bdry_coords["v"][direction]).rename(
-                **rename["v"][direction]
-            )
-            lon_v = self.grid.ds.lon_v.isel(**bdry_coords["v"][direction]).rename(
-                **rename["v"][direction]
-            )
-            layer_depth_v = (
-                self.grid.ds["layer_depth_v"]
-                .isel(**bdry_coords["v"][direction])
-                .rename(**rename["v"][direction])
-            )
-            interface_depth_v = (
-                self.grid.ds["interface_depth_v"]
-                .isel(**bdry_coords["v"][direction])
-                .rename(**rename["v"][direction])
-            )
-
-            ds = ds.assign_coords(
-                {
-                    f"layer_depth_rho_{direction}": layer_depth_rho,
-                    f"layer_depth_u_{direction}": layer_depth_u,
-                    f"layer_depth_v_{direction}": layer_depth_v,
-                    f"interface_depth_rho_{direction}": interface_depth_rho,
-                    f"interface_depth_u_{direction}": interface_depth_u,
-                    f"interface_depth_v_{direction}": interface_depth_v,
-                    f"lat_rho_{direction}": lat_rho,
-                    f"lat_u_{direction}": lat_u,
-                    f"lat_v_{direction}": lat_v,
-                    f"lon_rho_{direction}": lon_rho,
-                    f"lon_u_{direction}": lon_u,
-                    f"lon_v_{direction}": lon_v,
-                }
-            )
-
-    # Gracefully handle dropping variables that might not be present
-    variables_to_drop = [
-        "s_rho",
-        "layer_depth_rho",
-        "layer_depth_u",
-        "layer_depth_v",
-        "interface_depth_rho",
-        "interface_depth_u",
-        "interface_depth_v",
-        "lat_rho",
-        "lon_rho",
-        "lat_u",
-        "lon_u",
-        "lat_v",
-        "lon_v",
-    ]
-    existing_vars = [var for var in variables_to_drop if var in ds]
-    ds = ds.drop_vars(existing_vars)
-
-    return ds
+def _get_coordinates(self, direction, point):
+    """
+    Retrieve layer and interface depth coordinates for a specified grid boundary.
+
+    This method extracts the layer depth and interface depth coordinates along
+    a specified boundary (north, south, east, or west) and for a specified point
+    type (rho, u, or v) from the grid dataset.
+
+    Parameters
+    ----------
+    direction : str
+        The direction of the boundary to retrieve coordinates for. Valid options
+        are "north", "south", "east", and "west".
+    point : str
+        The type of grid point to retrieve coordinates for. Valid options are
+        "rho" for the grid's central points, "u" for the u-flux points, and "v"
+        for the v-flux points.
+
+    Returns
+    -------
+    xarray.DataArray, xarray.DataArray
+        The layer depth and interface depth coordinates for the specified grid
+        boundary and point type.
+    """
+
+    bdry_coords = super().get_boundary_info()
+
+    layer_depth = self.grid.ds[f"layer_depth_{point}"].isel(
+        **bdry_coords[point][direction]
+    )
+    interface_depth = self.grid.ds[f"interface_depth_{point}"].isel(
+        **bdry_coords[point][direction]
+    )
+
+    return layer_depth, interface_depth

def _add_global_metadata(self, ds=None):

@@ -486,8 +419,9 @@ def plot(
time : int, optional
The time index to plot. Default is 0.
layer_contours : bool, optional
Whether to include layer contours in the plot. This can help visualize the depth levels
of the field. Default is False.
If True, contour lines representing the boundaries between vertical layers will
be added to the plot. For clarity, the number of layer
contours displayed is limited to a maximum of 10. Default is False.
Returns
-------
@@ -513,6 +447,19 @@
field = ds[varname].isel(bry_time=time).load()
title = field.long_name

if "s_rho" in field.dims:
if varname.startswith(("u_", "ubar_")):
point = "u"
elif varname.startswith(("v_", "vbar_")):
point = "v"
else:
point = "rho"
direction = varname.split("_")[-1]

layer_depth, interface_depth = self._get_coordinates(direction, point)

field = field.assign_coords({"layer_depth": layer_depth})

# choose colorbar
if varname.startswith(("u", "v", "ubar", "vbar", "zeta")):
vmax = max(field.max().values, -field.min().values)
@@ -530,27 +477,6 @@

if len(field.dims) == 2:
if layer_contours:
depths_to_check = [
"interface_depth_rho",
"interface_depth_u",
"interface_depth_v",
]
try:
interface_depth = next(
ds[depth_label]
for depth_label in ds.coords
if any(
depth_label.startswith(prefix) for prefix in depths_to_check
)
and (
set(ds[depth_label].dims) - {"s_w"}
== set(field.dims) - {"s_rho"}
)
)
except StopIteration:
raise ValueError(
f"None of the expected depths ({', '.join(depths_to_check)}) have dimensions matching field.dims"
)
# restrict number of layer_contours to 10 for the sake of plot clarity
nr_layers = len(interface_depth["s_w"])
selected_layers = np.linspace(
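
The updated plot() above dispatches on the boundary variable name: the prefix picks the grid-point type and the suffix picks the boundary direction. The same logic as a standalone snippet (parse_boundary_varname is a hypothetical helper, not part of ROMS-Tools):

```python
def parse_boundary_varname(varname: str) -> tuple[str, str]:
    # Mirrors the prefix/suffix logic in plot() above; hypothetical helper.
    if varname.startswith(("u_", "ubar_")):
        point = "u"
    elif varname.startswith(("v_", "vbar_")):
        point = "v"
    else:
        point = "rho"
    direction = varname.split("_")[-1]  # e.g. "temp_south" -> "south"
    return point, direction


assert parse_boundary_varname("ubar_east") == ("u", "east")
assert parse_boundary_varname("temp_south") == ("rho", "south")
```
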
24 changes: 22 additions & 2 deletions roms_tools/setup/datasets.py
@@ -584,9 +584,14 @@ def __post_init__(self):
"ntides": self.dim_names["ntides"],
},
)
self.check_dataset(ds)

# Select relevant fields
ds = super().select_relevant_fields(ds)

# Make sure that latitude is ascending
ds = super().ensure_latitude_ascending(ds)

# Check whether the data covers the entire globe
object.__setattr__(self, "is_global", super().check_if_global(ds))
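
The ensure_latitude_ascending call above is not shown in this diff; a generic version of such a step, under the assumption of a 1-D latitude dimension coordinate, looks roughly like this:

```python
# Generic sketch only; the dimension name "latitude" and the flip-by-slice
# approach are assumptions, not the ROMS-Tools implementation.
import xarray as xr

ds = xr.Dataset(coords={"latitude": [60.0, 30.0, 0.0, -30.0]})

if ds["latitude"].values[0] > ds["latitude"].values[-1]:
    ds = ds.isel(latitude=slice(None, None, -1))  # reverse to ascending order
```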

@@ -769,6 +774,8 @@ def add_time_info(self, ds: xr.Dataset) -> xr.Dataset:
ds = assign_dates_to_climatology(ds, time_dim)
# rename dimension
ds = ds.swap_dims({time_dim: "time"})
if time_dim in ds.variables:
ds = ds.drop_vars(time_dim)
# Update dimension names
updated_dim_names = self.dim_names.copy()
updated_dim_names["time"] = "time"
@@ -872,9 +879,9 @@ def post_process(self):
ds["depth"].attrs["long_name"] = "Depth"
ds["depth"].attrs["units"] = "m"
ds = ds.swap_dims({"z_t": "depth"})
if "z_t" in ds:
if "z_t" in ds.variables:
ds = ds.drop_vars("z_t")
if "z_t_150m" in ds:
if "z_t_150m" in ds.variables:
ds = ds.drop_vars("z_t_150m")
# update dataset
object.__setattr__(self, "ds", ds)
@@ -932,6 +939,19 @@ class CESMBGCSurfaceForcingDataset(CESMDataset):

climatology: Optional[bool] = False

def post_process(self):
"""
Perform post-processing on the dataset to remove specific variables.
This method checks if the variable "z_t" exists in the dataset. If it does,
the variable is removed from the dataset. The modified dataset is then
reassigned to the `ds` attribute of the object.
"""

if "z_t" in self.ds.variables:
ds = self.ds.drop_vars("z_t")
object.__setattr__(self, "ds", ds)


@dataclass(frozen=True, kw_only=True)
class ERA5Dataset(Dataset):
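
Several of the datasets.py changes above follow the same pattern: after swapping to a new dimension, drop the leftover coordinate only if it is actually present (ds.variables includes both data variables and coordinates). A minimal generic example on invented data:

```python
# Invented data; only the swap-and-conditionally-drop pattern is the point.
import numpy as np
import xarray as xr

ds = xr.Dataset(
    {"ALK": (("z_t",), np.zeros(3))},
    coords={"z_t": [5.0, 15.0, 25.0], "depth": ("z_t", [0.05, 0.15, 0.25])},
)

ds = ds.swap_dims({"z_t": "depth"})   # "depth" becomes the dimension

if "z_t" in ds.variables:             # True here; a no-op when already absent
    ds = ds.drop_vars("z_t")
```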