diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml index dc519e5b..dadac50d 100644 --- a/.github/workflows/pre-commit.yml +++ b/.github/workflows/pre-commit.yml @@ -1,4 +1,4 @@ -name: lint +name: Linting on: # trigger on pushes to any branch, but not main diff --git a/.github/workflows/run_tests.yml b/.github/workflows/run_tests.yml new file mode 100644 index 00000000..71bff3d3 --- /dev/null +++ b/.github/workflows/run_tests.yml @@ -0,0 +1,45 @@ +name: Unit Tests + +on: + # trigger on pushes to any branch, but not main + push: + branches-ignore: + - main + # and also on PRs to main + pull_request: + branches: + - main + +jobs: + build: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.9", "3.10", "3.11", "3.12"] + + steps: + - uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v4 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + if [ -f requirements.txt ]; then pip install -r requirements.txt; fi + pip install torch-geometric>=2.5.2 + - name: Load cache data + uses: actions/cache/restore@v4 + with: + path: data + key: ${{ runner.os }}-meps-reduced-example-data-v0.1.0 + restore-keys: | + ${{ runner.os }}-meps-reduced-example-data-v0.1.0 + - name: Test with pytest + run: | + pytest -v -s + - name: Save cache data + uses: actions/cache/save@v4 + with: + path: data + key: ${{ runner.os }}-meps-reduced-example-data-v0.1.0 diff --git a/CHANGELOG.md b/CHANGELOG.md index fd836c7a..3544b299 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [unreleased](https://github.com/joeloskarsson/neural-lam/compare/v0.1.0...HEAD) ### Added +- Added tests for loading dataset, creating graph, and training model based on reduced MEPS dataset stored on AWS S3, along with automatic running of tests on push/PR to GitHub. Added caching of test data tp speed up running tests. + [/#38](https://github.com/mllam/neural-lam/pull/38) + @SimonKamuk - Replaced `constants.py` with `data_config.yaml` for data configuration management [\#31](https://github.com/joeloskarsson/neural-lam/pull/31) diff --git a/README.md b/README.md index ba0bb3fe..1bdc6602 100644 --- a/README.md +++ b/README.md @@ -1,3 +1,6 @@ +![Linting](https://github.com/mllam/neural-lam/actions/workflows/pre-commit.yml/badge.svg) +![Automatic tests](https://github.com/mllam/neural-lam/actions/workflows/run_tests.yml/badge.svg) +

@@ -279,6 +282,8 @@ pre-commit run --all-files ``` from the root directory of the repository. +Furthermore, all tests in the ```tests``` directory will be run upon pushing changes by a github action. Failure in any of the tests will also reject the push/PR. + # Contact If you are interested in machine learning models for LAM, have questions about our implementation or ideas for extending it, feel free to get in touch. You can open a github issue on this page, or (if more suitable) send an email to [joel.oskarsson@liu.se](mailto:joel.oskarsson@liu.se). diff --git a/create_mesh.py b/create_mesh.py index f04b4d4b..41557a97 100644 --- a/create_mesh.py +++ b/create_mesh.py @@ -153,7 +153,7 @@ def prepend_node_index(graph, new_index): return networkx.relabel_nodes(graph, to_mapping, copy=True) -def main(): +def main(input_args=None): parser = ArgumentParser(description="Graph generation arguments") parser.add_argument( "--data_config", @@ -186,7 +186,7 @@ def main(): default=0, help="Generate hierarchical mesh graph (default: 0, no)", ) - args = parser.parse_args() + args = parser.parse_args(input_args) # Load grid positions config_loader = config.Config.from_file(args.data_config) diff --git a/docs/notebooks/create_reduced_meps_dataset.ipynb b/docs/notebooks/create_reduced_meps_dataset.ipynb new file mode 100644 index 00000000..daba23c4 --- /dev/null +++ b/docs/notebooks/create_reduced_meps_dataset.ipynb @@ -0,0 +1,239 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Creating meps_example_reduced\n", + "This notebook outlines how the small-size test dataset ```meps_example_reduced``` was created based on the slightly larger dataset ```meps_example```. The zipped up datasets are 263 MB and 2.6 GB, respectively. See [README.md](../../README.md) for info on how to download ```meps_example```.\n", + "\n", + "The dataset was reduced in size by reducing the number of grid points and variables.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# Standard library\n", + "import os\n", + "\n", + "# Third-party\n", + "import numpy as np\n", + "import torch" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "The number of grid points was reduced to 1/4 by halving the number of coordinates in both the x and y direction. This was done by removing a quarter of the grid points along each outer edge, so the center grid points would stay centered in the new set.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Load existing grid\n", + "grid_xy = np.load('data/meps_example/static/nwp_xy.npy')\n", + "# Get slices in each dimension by cutting off a quarter along each edge\n", + "num_x, num_y = grid_xy.shape[1:]\n", + "x_slice = slice(num_x//4, 3*num_x//4)\n", + "y_slice = slice(num_y//4, 3*num_y//4)\n", + "# Index and save reduced grid\n", + "grid_xy_reduced = grid_xy[:, x_slice, y_slice]\n", + "np.save('data/meps_example_reduced/static/nwp_xy.npy', grid_xy_reduced)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "This cut out the border, so a new perimeter of 10 grid points was established as border (10 was also the border size in the original \"meps_example\").\n" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "# Outer 10 grid points are border\n", + "old_border_mask = np.load('data/meps_example/static/border_mask.npy')\n", + "assert np.all(old_border_mask[10:-10, 10:-10] == False)\n", + "assert np.all(old_border_mask[:10, :] == True)\n", + "assert np.all(old_border_mask[:, :10] == True)\n", + "assert np.all(old_border_mask[-10:,:] == True)\n", + "assert np.all(old_border_mask[:,-10:] == True)\n", + "\n", + "# Create new array with False everywhere but the outer 10 grid points\n", + "border_mask = np.zeros_like(grid_xy_reduced[0,:,:], dtype=bool)\n", + "border_mask[:10] = True\n", + "border_mask[:,:10] = True\n", + "border_mask[-10:] = True\n", + "border_mask[:,-10:] = True\n", + "np.save('data/meps_example_reduced/static/border_mask.npy', border_mask)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "A few other files also needed to be copied using only the new reduced grid" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Load surface_geopotential.npy, index only values from the reduced grid, and save to new file\n", + "surface_geopotential = np.load('data/meps_example/static/surface_geopotential.npy')\n", + "surface_geopotential_reduced = surface_geopotential[x_slice, y_slice]\n", + "np.save('data/meps_example_reduced/static/surface_geopotential.npy', surface_geopotential_reduced)\n", + "\n", + "# Load pytorch file grid_features.pt\n", + "grid_features = torch.load('data/meps_example/static/grid_features.pt')\n", + "# Index only values from the reduced grid. \n", + "# First reshape from (num_grid_points_total, 4) to (num_grid_points_x, num_grid_points_y, 4), \n", + "# then index, then reshape back to new total number of grid points\n", + "print(grid_features.shape)\n", + "grid_features_new = grid_features.reshape(num_x, num_y, 4)[x_slice,y_slice,:].reshape((-1, 4))\n", + "# Save to new file\n", + "torch.save(grid_features_new, 'data/meps_example_reduced/static/grid_features.pt')\n", + "\n", + "# flux_stats.pt is just a vector of length 2, so the grid shape and variable changes does not change this file\n", + "torch.save(torch.load('data/meps_example/static/flux_stats.pt'), 'data/meps_example_reduced/static/flux_stats.pt')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "\n", + "The number of variables was reduced by truncating the variable list to the first 8." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "num_vars = 8\n", + "\n", + "# Load parameter_weights.npy, truncate to first 8 variables, and save to new file\n", + "parameter_weights = np.load('data/meps_example/static/parameter_weights.npy')\n", + "parameter_weights_reduced = parameter_weights[:num_vars]\n", + "np.save('data/meps_example_reduced/static/parameter_weights.npy', parameter_weights_reduced)\n", + "\n", + "# Do the same for following 4 pytorch files\n", + "for file in ['diff_mean', 'diff_std', 'parameter_mean', 'parameter_std']:\n", + " old_file = torch.load(f'data/meps_example/static/{file}.pt')\n", + " new_file = old_file[:num_vars]\n", + " torch.save(new_file, f'data/meps_example_reduced/static/{file}.pt')" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Lastly the files in each of the directories train, test, and val have to be reduced. The folders all have the same structure with files of the following types:\n", + "```\n", + "nwp_YYYYMMDDHH_mbrXXX.npy\n", + "wtr_YYYYMMDDHH.npy\n", + "nwp_toa_downwelling_shortwave_flux_YYYYMMDDHH.npy\n", + "```\n", + "with ```YYYYMMDDHH``` being some date with hours, and ```XXX``` being some 3-digit integer.\n", + "\n", + "The first type of file has x and y in dimensions 1 and 2, and variable index in dimension 3. Dimension 0 is unchanged.\n", + "The second type has has x and y in dimensions 1 and 2. Dimension 0 is unchanged.\n", + "The last type has just x and y as the only 2 dimensions.\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(65, 268, 238, 18)\n", + "(65, 268, 238)\n" + ] + } + ], + "source": [ + "print(np.load('data/meps_example/samples/train/nwp_2022040100_mbr000.npy').shape)\n", + "print(np.load('data/meps_example/samples/train/nwp_toa_downwelling_shortwave_flux_2022040112.npy').shape)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The following loop goes through each file in each sample folder and indexes them according to the dimensions given by the file name." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "for sample in ['train', 'test', 'val']:\n", + " files = os.listdir(f'data/meps_example/samples/{sample}')\n", + "\n", + " for f in files:\n", + " data = np.load(f'data/meps_example/samples/{sample}/{f}')\n", + " if 'mbr' in f:\n", + " data = data[:,x_slice,y_slice,:num_vars]\n", + " elif 'wtr' in f:\n", + " data = data[x_slice, y_slice]\n", + " else:\n", + " data = data[:,x_slice,y_slice]\n", + " np.save(f'data/meps_example_reduced/samples/{sample}/{f}', data)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Lastly, the file ```data_config.yaml``` is modified manually by truncating the variable units, long and short names, and setting the new grid shape. Also the unit descriptions containing ```^``` was automatically parsed using latex, and to avoid having to install latex in the GitHub CI/CD pipeline, this was changed to ```**```. \n", + "\n", + "This new config file was placed in ```data/meps_example_reduced```, and that directory was then zipped and placed in a European Weather Cloud S3 bucket." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/neural_lam/utils.py b/neural_lam/utils.py index 836b04ed..59a529eb 100644 --- a/neural_lam/utils.py +++ b/neural_lam/utils.py @@ -1,5 +1,6 @@ # Standard library import os +import shutil # Third-party import numpy as np @@ -250,7 +251,11 @@ def fractional_plot_bundle(fraction): Get the tueplots bundle, but with figure width as a fraction of the page width. """ - bundle = bundles.neurips2023(usetex=True, family="serif") + # If latex is not available, some visualizations might not render correctly, + # but will at least not raise an error. + # Alternatively, use unicode raised numbers. + usetex = True if shutil.which("latex") else False + bundle = bundles.neurips2023(usetex=usetex, family="serif") bundle.update(figsizes.neurips2023()) original_figsize = bundle["figure.figsize"] bundle["figure.figsize"] = ( diff --git a/requirements.txt b/requirements.txt index f381d54f..9309eea4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,3 +13,5 @@ plotly>=5.15.0 # for dev pre-commit>=2.15.0 +pytest>=8.1.1 +pooch>=1.8.1 diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_mllam_dataset.py b/tests/test_mllam_dataset.py new file mode 100644 index 00000000..f91170c9 --- /dev/null +++ b/tests/test_mllam_dataset.py @@ -0,0 +1,138 @@ +# Standard library +import os + +# Third-party +import pooch + +# First-party +from create_mesh import main as create_mesh +from neural_lam.config import Config +from neural_lam.utils import load_static_data +from neural_lam.weather_dataset import WeatherDataset +from train_model import main as train_model + +# Disable weights and biases to avoid unnecessary logging +# and to avoid having to deal with authentication +os.environ["WANDB_DISABLED"] = "true" + +# Initializing variables for the s3 client +S3_BUCKET_NAME = "mllam-testdata" +S3_ENDPOINT_URL = "https://object-store.os-api.cci1.ecmwf.int" +S3_FILE_PATH = "neural-lam/npy/meps_example_reduced.v0.1.0.zip" +S3_FULL_PATH = "/".join([S3_ENDPOINT_URL, S3_BUCKET_NAME, S3_FILE_PATH]) +TEST_DATA_KNOWN_HASH = ( + "98c7a2f442922de40c6891fe3e5d190346889d6e0e97550170a82a7ce58a72b7" +) + + +def test_retrieve_data_ewc(): + # Download and unzip test data into data/meps_example_reduced + pooch.retrieve( + url=S3_FULL_PATH, + known_hash=TEST_DATA_KNOWN_HASH, + processor=pooch.Unzip(extract_dir=""), + path="data", + fname="meps_example_reduced.zip", + ) + + +def test_load_reduced_meps_dataset(): + # The data_config.yaml file is downloaded and extracted in + # test_retrieve_data_ewc together with the dataset itself + data_config_file = "data/meps_example_reduced/data_config.yaml" + dataset_name = "meps_example_reduced" + + dataset = WeatherDataset(dataset_name="meps_example_reduced") + config = Config.from_file(data_config_file) + + var_names = config.values["dataset"]["var_names"] + var_units = config.values["dataset"]["var_units"] + var_longnames = config.values["dataset"]["var_longnames"] + + assert len(var_names) == len(var_longnames) + assert len(var_names) == len(var_units) + + # in future the number of grid static features + # will be provided by the Dataset class itself + n_grid_static_features = 4 + # Hardcoded in model + n_input_steps = 2 + + n_forcing_features = config.values["dataset"]["num_forcing_features"] + n_state_features = len(var_names) + n_prediction_timesteps = dataset.sample_length - n_input_steps + + nx, ny = config.values["grid_shape_state"] + n_grid = nx * ny + + # check that the dataset is not empty + assert len(dataset) > 0 + + # get the first item + init_states, target_states, forcing = dataset[0] + + # check that the shapes of the tensors are correct + assert init_states.shape == (n_input_steps, n_grid, n_state_features) + assert target_states.shape == ( + n_prediction_timesteps, + n_grid, + n_state_features, + ) + assert forcing.shape == ( + n_prediction_timesteps, + n_grid, + n_forcing_features, + ) + + static_data = load_static_data(dataset_name=dataset_name) + + required_props = { + "border_mask", + "grid_static_features", + "step_diff_mean", + "step_diff_std", + "data_mean", + "data_std", + "param_weights", + } + + # check the sizes of the props + assert static_data["border_mask"].shape == (n_grid, 1) + assert static_data["grid_static_features"].shape == ( + n_grid, + n_grid_static_features, + ) + assert static_data["step_diff_mean"].shape == (n_state_features,) + assert static_data["step_diff_std"].shape == (n_state_features,) + assert static_data["data_mean"].shape == (n_state_features,) + assert static_data["data_std"].shape == (n_state_features,) + assert static_data["param_weights"].shape == (n_state_features,) + + assert set(static_data.keys()) == required_props + + +def test_create_graph_reduced_meps_dataset(): + args = [ + "--graph=hierarchical", + "--hierarchical=1", + "--data_config=data/meps_example_reduced/data_config.yaml", + "--levels=2", + ] + create_mesh(args) + + +def test_train_model_reduced_meps_dataset(): + args = [ + "--model=hi_lam", + "--data_config=data/meps_example_reduced/data_config.yaml", + "--n_workers=4", + "--epochs=1", + "--graph=hierarchical", + "--hidden_dim=16", + "--hidden_layers=1", + "--processor_layers=1", + "--ar_steps=1", + "--eval=val", + "--n_example_pred=0", + ] + train_model(args) diff --git a/train_model.py b/train_model.py index cbd787f0..03863275 100644 --- a/train_model.py +++ b/train_model.py @@ -23,7 +23,7 @@ } -def main(): +def main(input_args=None): """ Main function for training and evaluating models """ @@ -208,11 +208,10 @@ def main(): help="""JSON string with variable-IDs and lead times to log watched metrics (e.g. '{"1": [1, 2], "3": [3, 4]}')""", ) - args = parser.parse_args() + args = parser.parse_args(input_args) args.var_leads_metrics_watch = { int(k): v for k, v in json.loads(args.var_leads_metrics_watch).items() } - config_loader = config.Config.from_file(args.data_config) # Asserts for arguments