Skip to content

Commit

Permalink
Don't separate temporally into multiple files by default (#190)
Browse files Browse the repository at this point in the history
* Make temporal grouping optional, defaulting to false

* Adapt tests to non-default grouping

* Separate plot and save methods tests

* Don't order parameters alphabetically in YAML files

* Linting

* Trigger dask progress bar internally

* Improve notebook text
  • Loading branch information
NoraLoose authored Nov 5, 2024
1 parent 2e62bc6 commit b5cc382
Show file tree
Hide file tree
Showing 14 changed files with 2,848 additions and 3,533 deletions.
3,338 changes: 1,839 additions & 1,499 deletions docs/boundary_forcing.ipynb

Large diffs are not rendered by default.

72 changes: 36 additions & 36 deletions docs/grid.ipynb

Large diffs are not rendered by default.

511 changes: 252 additions & 259 deletions docs/initial_conditions.ipynb

Large diffs are not rendered by default.

1,663 changes: 257 additions & 1,406 deletions docs/surface_forcing.ipynb

Large diffs are not rendered by default.

490 changes: 249 additions & 241 deletions docs/tides.ipynb

Large diffs are not rendered by default.

54 changes: 38 additions & 16 deletions roms_tools/setup/boundary_forcing.py
Original file line number Diff line number Diff line change
Expand Up @@ -649,7 +649,14 @@ def plot(
if var_name not in self.ds:
raise ValueError(f"Variable '{var_name}' is not found in dataset.")

field = self.ds[var_name].isel(bry_time=time).load()
field = self.ds[var_name].isel(bry_time=time)

if self.use_dask:
from dask.diagnostics import ProgressBar

with ProgressBar():
field = field.load()

title = field.long_name

if "s_rho" in field.dims:
Expand Down Expand Up @@ -699,24 +706,26 @@ def plot(
_line_plot(field, title=title)

def save(
self, filepath: Union[str, Path], np_eta: int = None, np_xi: int = None
self,
filepath: Union[str, Path],
np_eta: int = None,
np_xi: int = None,
group: bool = False,
) -> None:
"""Save the boundary forcing fields to netCDF4 files.
This method saves the dataset by grouping it into subsets based on the data frequency. The subsets are then written
to one or more netCDF4 files. The filenames of the output files reflect the temporal coverage of the data.
"""Save the boundary forcing fields to one or more netCDF4 files.
There are two modes of saving the dataset:
This method saves the dataset either as a single file or as multiple files depending on the partitioning and grouping options.
The dataset can be saved in two modes:
1. **Single File Mode (default)**:
1. **Single File Mode (default)**:
- If both `np_eta` and `np_xi` are `None`, the entire dataset is saved as a single netCDF4 file.
- The file is named based on the `filepath`, with `.nc` automatically appended.
If both `np_eta` and `np_xi` are `None`, the entire dataset, divided by temporal subsets, is saved as a single netCDF4 file
with the base filename specified by `filepath.nc`.
2. **Partitioned Mode**:
- If either `np_eta` or `np_xi` is specified, the dataset is partitioned into spatial tiles along the `eta` and `xi` axes.
- Each tile is saved as a separate netCDF4 file, and filenames are modified with an index (e.g., `"filepath_YYYYMM.0.nc"`, `"filepath_YYYYMM.1.nc"`).
2. **Partitioned Mode**:
- If either `np_eta` or `np_xi` is specified, the dataset is divided into spatial tiles along the eta-axis and xi-axis.
- Each spatial tile is saved as a separate netCDF4 file.
Additionally, if `group` is set to `True`, the dataset is first grouped into temporal subsets, resulting in multiple grouped files before partitioning and saving.
Parameters
----------
Expand All @@ -728,6 +737,8 @@ def save(
The number of partitions along the `eta` direction. If `None`, no spatial partitioning is performed.
np_xi : int, optional
The number of partitions along the `xi` direction. If `None`, no spatial partitioning is performed.
group: bool, optional
If `True`, groups the dataset into multiple files based on temporal data frequency. Defaults to `False`.
Returns
-------
Expand All @@ -742,7 +753,18 @@ def save(
if filepath.suffix == ".nc":
filepath = filepath.with_suffix("")

dataset_list, output_filenames = group_dataset(self.ds.load(), str(filepath))
if self.use_dask:
from dask.diagnostics import ProgressBar

with ProgressBar():
self.ds.load()

if group:
dataset_list, output_filenames = group_dataset(self.ds, str(filepath))
else:
dataset_list = [self.ds]
output_filenames = [str(filepath)]

saved_filenames = save_datasets(
dataset_list, output_filenames, np_eta=np_eta, np_xi=np_xi
)
Expand Down Expand Up @@ -796,7 +818,7 @@ def to_yaml(self, filepath: Union[str, Path]) -> None:
# Write header
file.write(header)
# Write YAML data
yaml.dump(yaml_data, file, default_flow_style=False)
yaml.dump(yaml_data, file, default_flow_style=False, sort_keys=False)

@classmethod
def from_yaml(
Expand Down
2 changes: 1 addition & 1 deletion roms_tools/setup/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -679,7 +679,7 @@ def to_yaml(self, filepath: Union[str, Path]) -> None:
# Write header
file.write(header)
# Write YAML data
yaml.dump(yaml_data, file, default_flow_style=False)
yaml.dump(yaml_data, file, default_flow_style=False, sort_keys=False)

@classmethod
def from_yaml(cls, filepath: Union[str, Path]) -> "Grid":
Expand Down
24 changes: 19 additions & 5 deletions roms_tools/setup/initial_conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,7 +518,12 @@ def plot(
):
raise ValueError("For 2D fields, specify either eta or xi, not both.")

self.ds[var_name].load()
if self.use_dask:
from dask.diagnostics import ProgressBar

with ProgressBar():
self.ds[var_name].load()

field = self.ds[var_name].squeeze()

if all(dim in field.dims for dim in ["eta_rho", "xi_rho"]):
Expand Down Expand Up @@ -681,7 +686,13 @@ def save(
if filepath.suffix == ".nc":
filepath = filepath.with_suffix("")

dataset_list = [self.ds.load()]
if self.use_dask:
from dask.diagnostics import ProgressBar

with ProgressBar():
self.ds.load()

dataset_list = [self.ds]
output_filenames = [str(filepath)]

saved_filenames = save_datasets(
Expand Down Expand Up @@ -719,15 +730,18 @@ def to_yaml(self, filepath: Union[str, Path]) -> None:

initial_conditions_data = {
"InitialConditions": {
"source": self.source,
"ini_time": self.ini_time.isoformat(),
"model_reference_date": self.model_reference_date.isoformat(),
"source": self.source,
}
}
# Include bgc_source if it's not None
if self.bgc_source is not None:
initial_conditions_data["InitialConditions"]["bgc_source"] = self.bgc_source

initial_conditions_data["InitialConditions"][
"model_reference_date"
] = self.model_reference_date.isoformat()

yaml_data = {
**grid_yaml_data,
**initial_conditions_data,
Expand All @@ -737,7 +751,7 @@ def to_yaml(self, filepath: Union[str, Path]) -> None:
# Write header
file.write(header)
# Write YAML data
yaml.dump(yaml_data, file, default_flow_style=False)
yaml.dump(yaml_data, file, default_flow_style=False, sort_keys=False)

@classmethod
def from_yaml(
Expand Down
53 changes: 37 additions & 16 deletions roms_tools/setup/surface_forcing.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,13 @@ def plot(self, var_name, time=0) -> None:
if var_name not in self.ds:
raise ValueError(f"Variable '{var_name}' is not found in dataset.")

field = self.ds[var_name].isel(time=time).load()
field = self.ds[var_name].isel(time=time)
if self.use_dask:
from dask.diagnostics import ProgressBar

with ProgressBar():
field = field.load()

title = field.long_name

# assign lat / lon
Expand Down Expand Up @@ -502,24 +508,26 @@ def plot(self, var_name, time=0) -> None:
)

def save(
self, filepath: Union[str, Path], np_eta: int = None, np_xi: int = None
self,
filepath: Union[str, Path],
np_eta: int = None,
np_xi: int = None,
group: bool = False,
) -> None:
"""Save the surface forcing fields to netCDF4 files.
This method saves the dataset by grouping it into subsets based on the data frequency. The subsets are then written
to one or more netCDF4 files. The filenames of the output files reflect the temporal coverage of the data.
There are two modes of saving the dataset:
"""Save the surface forcing fields to one or more netCDF4 files.
1. **Single File Mode (default)**:
This method saves the dataset either as a single file or as multiple files depending on the partitioning and grouping options.
The dataset can be saved in two modes:
If both `np_eta` and `np_xi` are `None`, the entire dataset, divided by temporal subsets, is saved as a single netCDF4 file
with the base filename specified by `filepath.nc`.
1. **Single File Mode (default)**:
- If both `np_eta` and `np_xi` are `None`, the entire dataset is saved as a single netCDF4 file.
- The file is named based on the `filepath`, with `.nc` automatically appended.
2. **Partitioned Mode**:
2. **Partitioned Mode**:
- If either `np_eta` or `np_xi` is specified, the dataset is partitioned into spatial tiles along the `eta` and `xi` axes.
- Each tile is saved as a separate netCDF4 file, and filenames are modified with an index (e.g., `"filepath_YYYYMM.0.nc"`, `"filepath_YYYYMM.1.nc"`).
- If either `np_eta` or `np_xi` is specified, the dataset is divided into spatial tiles along the eta-axis and xi-axis.
- Each spatial tile is saved as a separate netCDF4 file.
Additionally, if `group` is set to `True`, the dataset is first grouped into temporal subsets, resulting in multiple grouped files before partitioning and saving.
Parameters
----------
Expand All @@ -531,6 +539,8 @@ def save(
The number of partitions along the `eta` direction. If `None`, no spatial partitioning is performed.
np_xi : int, optional
The number of partitions along the `xi` direction. If `None`, no spatial partitioning is performed.
group: bool, optional
If `True`, groups the dataset into multiple files based on temporal data frequency. Defaults to `False`.
Returns
-------
Expand All @@ -545,7 +555,18 @@ def save(
if filepath.suffix == ".nc":
filepath = filepath.with_suffix("")

dataset_list, output_filenames = group_dataset(self.ds.load(), str(filepath))
if self.use_dask:
from dask.diagnostics import ProgressBar

with ProgressBar():
self.ds.load()

if group:
dataset_list, output_filenames = group_dataset(self.ds, str(filepath))
else:
dataset_list = [self.ds]
output_filenames = [str(filepath)]

saved_filenames = save_datasets(
dataset_list, output_filenames, np_eta=np_eta, np_xi=np_xi
)
Expand Down Expand Up @@ -603,7 +624,7 @@ def to_yaml(self, filepath: Union[str, Path]) -> None:
# Write header
file.write(header)
# Write YAML data
yaml.dump(yaml_data, file, default_flow_style=False)
yaml.dump(yaml_data, file, default_flow_style=False, sort_keys=False)

@classmethod
def from_yaml(
Expand Down
21 changes: 17 additions & 4 deletions roms_tools/setup/tides.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,14 @@ def plot(self, var_name, ntides=0) -> None:
>>> tidal_forcing.plot("ssh_Re", nc=0)
"""

field = self.ds[var_name].isel(ntides=ntides).compute()
field = self.ds[var_name].isel(ntides=ntides)

if self.use_dask:
from dask.diagnostics import ProgressBar

with ProgressBar():
field = field.load()

if all(dim in field.dims for dim in ["eta_rho", "xi_rho"]):
field = field.where(self.grid.ds.mask_rho)
field = field.assign_coords(
Expand Down Expand Up @@ -378,7 +385,13 @@ def save(
if filepath.suffix == ".nc":
filepath = filepath.with_suffix("")

dataset_list = [self.ds.load()]
if self.use_dask:
from dask.diagnostics import ProgressBar

with ProgressBar():
self.ds.load()

dataset_list = [self.ds]
output_filenames = [str(filepath)]

saved_filenames = save_datasets(
Expand Down Expand Up @@ -419,8 +432,8 @@ def to_yaml(self, filepath: Union[str, Path]) -> None:
"TidalForcing": {
"source": self.source,
"ntides": self.ntides,
"model_reference_date": self.model_reference_date.isoformat(),
"allan_factor": self.allan_factor,
"model_reference_date": self.model_reference_date.isoformat(),
}
}

Expand All @@ -431,7 +444,7 @@ def to_yaml(self, filepath: Union[str, Path]) -> None:
# Write header
file.write(header)
# Write YAML data
yaml.dump(yaml_data, file, default_flow_style=False)
yaml.dump(yaml_data, file, default_flow_style=False, sort_keys=False)

@classmethod
def from_yaml(
Expand Down
Loading

0 comments on commit b5cc382

Please sign in to comment.