diff --git a/.github/workflows/pytest.yaml b/.github/workflows/pytest.yaml new file mode 100644 index 0000000..eb2eaba --- /dev/null +++ b/.github/workflows/pytest.yaml @@ -0,0 +1,27 @@ +name: Pytest + +on: [pull_request] + +jobs: + build: + + runs-on: ubuntu-latest + + strategy: + matrix: + python-version: ["3.11"] + + steps: + - uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v3 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies and test + run: | + python -m pip install --upgrade pip + pip install uv + uv venv + source .venv/bin/activate + uv pip install -e . + pytest diff --git a/README.md b/README.md index 0f07853..9897c36 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,5 @@ # Hydro Data (Dev) +[![Tests](https://github.com/mhpi/hydro_data_dev/actions/workflows/pytest.yml/badge.svg)](https://github.com/hydro_data_dev/actions/actions/workflows/tests.yml) [![Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json)](https://github.com/astral-sh/ruff) [![Python](https://img.shields.io/badge/python-3.10%20%7C%203.11-blue)]() @@ -16,7 +17,7 @@ pip install . ``` ### Developer Mode Installation -The same clone as above, but use hatch's developer mode setting +The same clone as above, but uses hatch's developer mode setting ```shell pip install -e . ``` diff --git a/example/camels/example.ipynb b/example/camels/example.ipynb new file mode 100644 index 0000000..002f4d5 --- /dev/null +++ b/example/camels/example.ipynb @@ -0,0 +1,523 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# CAMELS example:\n", + "### See below for an example of how to format data using hydro_data_dev for camels" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import hydro_data_dev as hdd" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Record(bucket='mhpi-spatial', dataset='camels', train_date_range=(datetime.datetime(1999, 10, 1, 0, 0), datetime.datetime(2008, 9, 30, 0, 0)), val_date_range=(datetime.datetime(2008, 10, 1, 0, 0), datetime.datetime(2014, 9, 30, 0, 0)), test_date_range=(datetime.datetime(1989, 10, 1, 0, 0), datetime.datetime(1999, 9, 30, 0, 0)), time_series_variables=['dayl_daymet', 'prcp_daymet', 'srad_daymet', 'tmean_daymet', 'vp_daymet'], target_variables=['runoff'], static_variables=['p_mean', 'pet_mean', 'p_seasonality', 'frac_snow', 'aridity', 'high_prec_freq', 'high_prec_dur', 'low_prec_freq', 'low_prec_dur', 'elev_mean', 'slope_mean', 'area_gages2', 'frac_forest', 'lai_max', 'lai_diff', 'gvf_max', 'gvf_diff', 'dom_land_cover_frac', 'dom_land_cover', 'root_depth_50', 'soil_depth_pelletier', 'soil_depth_statsgo', 'soil_porosity', 'soil_conductivity', 'max_water_content', 'sand_frac', 'silt_frac', 'clay_frac', 'geol_1st_class', 'glim_1st_class_frac', 'geol_2nd_class', 'glim_2nd_class_frac', 'carbonate_rocks_frac', 'geol_porostiy', 'geol_permeability'], station_ids=PosixPath('531_basin_list.txt'))" + ] + }, + "execution_count": 2, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "record = hdd.create_record(\"./example.yaml\")\n", + "record" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
<xarray.Dataset> Size: 326MB\n",
+       "Dimensions:               (station_ids: 531, time: 12784)\n",
+       "Coordinates:\n",
+       "  * station_ids           (station_ids) int64 4kB 1022500 1031500 ... 11532500\n",
+       "  * time                  (time) datetime64[ns] 102kB 1980-01-01 ... 2014-12-31\n",
+       "    lon                   (station_ids) float64 4kB ...\n",
+       "    lat                   (station_ids) float64 4kB ...\n",
+       "Data variables: (12/41)\n",
+       "    dayl_daymet           (station_ids, time) float64 54MB ...\n",
+       "    prcp_daymet           (station_ids, time) float64 54MB ...\n",
+       "    srad_daymet           (station_ids, time) float64 54MB ...\n",
+       "    tmean_daymet          (station_ids, time) float64 54MB ...\n",
+       "    vp_daymet             (station_ids, time) float64 54MB ...\n",
+       "    p_mean                (station_ids) float64 4kB ...\n",
+       "    ...                    ...\n",
+       "    geol_2nd_class        (station_ids) float64 4kB ...\n",
+       "    glim_2nd_class_frac   (station_ids) float64 4kB ...\n",
+       "    carbonate_rocks_frac  (station_ids) float64 4kB ...\n",
+       "    geol_porostiy         (station_ids) float64 4kB ...\n",
+       "    geol_permeability     (station_ids) float64 4kB ...\n",
+       "    runoff                (station_ids, time) float64 54MB ...\n",
+       "Attributes:\n",
+       "    metadata:                     {"short_name": "CAMELS", "long_name": "Catc...\n",
+       "    static_variables:             ['p_mean', 'pet_mean', 'p_seasonality', 'fr...\n",
+       "    static_variables_units:       []\n",
+       "    time_series_variables:        ['dayl_nldas', 'prcp_nldas', 'srad_nldas', ...\n",
+       "    time_series_variables_units:  ['s', 'mm/day', 'W/m^2', 'mm', 'degC', 'deg...
" + ], + "text/plain": [ + " Size: 326MB\n", + "Dimensions: (station_ids: 531, time: 12784)\n", + "Coordinates:\n", + " * station_ids (station_ids) int64 4kB 1022500 1031500 ... 11532500\n", + " * time (time) datetime64[ns] 102kB 1980-01-01 ... 2014-12-31\n", + " lon (station_ids) float64 4kB ...\n", + " lat (station_ids) float64 4kB ...\n", + "Data variables: (12/41)\n", + " dayl_daymet (station_ids, time) float64 54MB ...\n", + " prcp_daymet (station_ids, time) float64 54MB ...\n", + " srad_daymet (station_ids, time) float64 54MB ...\n", + " tmean_daymet (station_ids, time) float64 54MB ...\n", + " vp_daymet (station_ids, time) float64 54MB ...\n", + " p_mean (station_ids) float64 4kB ...\n", + " ... ...\n", + " geol_2nd_class (station_ids) float64 4kB ...\n", + " glim_2nd_class_frac (station_ids) float64 4kB ...\n", + " carbonate_rocks_frac (station_ids) float64 4kB ...\n", + " geol_porostiy (station_ids) float64 4kB ...\n", + " geol_permeability (station_ids) float64 4kB ...\n", + " runoff (station_ids, time) float64 54MB ...\n", + "Attributes:\n", + " metadata: {\"short_name\": \"CAMELS\", \"long_name\": \"Catc...\n", + " static_variables: ['p_mean', 'pet_mean', 'p_seasonality', 'fr...\n", + " static_variables_units: []\n", + " time_series_variables: ['dayl_nldas', 'prcp_nldas', 'srad_nldas', ...\n", + " time_series_variables_units: ['s', 'mm/day', 'W/m^2', 'mm', 'degC', 'deg..." + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ds = hdd.fetch_data(record=record)\n", + "ds" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "hydro_data", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.10" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/example/camels/example.yaml b/example/camels/example.yaml index 22c96f6..66913f4 100644 --- a/example/camels/example.yaml +++ b/example/camels/example.yaml @@ -1,15 +1,15 @@ # config.yaml -bucket: "hydro-data-dev" +bucket: "mhpi-spatial" dataset: "camels" # Date ranges for train/val/test splits -train_date_list: +train_date_range: - "1999-10-01" - "2008-09-30" -val_date_list: +val_date_range: - "2008-10-01" - "2014-09-30" -test_date_list: +test_date_range: - "1989-10-01" - "1999-09-30" diff --git a/example/example.ipynb b/example/example.ipynb deleted file mode 100644 index 5d58488..0000000 --- a/example/example.ipynb +++ /dev/null @@ -1,503 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [], - "source": [ - "import hydro_data_dev as hdd" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "
<xarray.Dataset> Size: 1GB\n",
-       "Dimensions:               (station_ids: 671, time: 12784)\n",
-       "Coordinates:\n",
-       "  * station_ids           (station_ids) int64 5kB 1013500 1022500 ... 14400000\n",
-       "  * time                  (time) datetime64[ns] 102kB 1980-01-01 ... 2014-12-31\n",
-       "    lat                   (station_ids) float64 5kB ...\n",
-       "    lon                   (station_ids) float64 5kB ...\n",
-       "Data variables: (12/52)\n",
-       "    p_mean                (station_ids) float64 5kB ...\n",
-       "    soil_depth_statsgo    (station_ids) float64 5kB ...\n",
-       "    lai_max               (station_ids) float64 5kB ...\n",
-       "    tmax_daymet           (station_ids, time) float64 69MB ...\n",
-       "    lai_diff              (station_ids) float64 5kB ...\n",
-       "    srad_nldas            (station_ids, time) float64 69MB ...\n",
-       "    ...                    ...\n",
-       "    silt_frac             (station_ids) float64 5kB ...\n",
-       "    glim_1st_class_frac   (station_ids) float64 5kB ...\n",
-       "    soil_depth_pelletier  (station_ids) float64 5kB ...\n",
-       "    carbonate_rocks_frac  (station_ids) float64 5kB ...\n",
-       "    dom_land_cover_frac   (station_ids) float64 5kB ...\n",
-       "    geol_1st_class        (station_ids) float64 5kB ...\n",
-       "Attributes:\n",
-       "    metadata:                     {"short_name": "CAMELS", "long_name": "Catc...\n",
-       "    static_variables:             ['p_mean', 'pet_mean', 'p_seasonality', 'fr...\n",
-       "    static_variables_units:       []\n",
-       "    time_series_variables:        ['dayl_nldas', 'prcp_nldas', 'srad_nldas', ...\n",
-       "    time_series_variables_units:  ['s', 'mm/day', 'W/m^2', 'mm', 'degC', 'deg...
" - ], - "text/plain": [ - " Size: 1GB\n", - "Dimensions: (station_ids: 671, time: 12784)\n", - "Coordinates:\n", - " * station_ids (station_ids) int64 5kB 1013500 1022500 ... 14400000\n", - " * time (time) datetime64[ns] 102kB 1980-01-01 ... 2014-12-31\n", - " lat (station_ids) float64 5kB ...\n", - " lon (station_ids) float64 5kB ...\n", - "Data variables: (12/52)\n", - " p_mean (station_ids) float64 5kB ...\n", - " soil_depth_statsgo (station_ids) float64 5kB ...\n", - " lai_max (station_ids) float64 5kB ...\n", - " tmax_daymet (station_ids, time) float64 69MB ...\n", - " lai_diff (station_ids) float64 5kB ...\n", - " srad_nldas (station_ids, time) float64 69MB ...\n", - " ... ...\n", - " silt_frac (station_ids) float64 5kB ...\n", - " glim_1st_class_frac (station_ids) float64 5kB ...\n", - " soil_depth_pelletier (station_ids) float64 5kB ...\n", - " carbonate_rocks_frac (station_ids) float64 5kB ...\n", - " dom_land_cover_frac (station_ids) float64 5kB ...\n", - " geol_1st_class (station_ids) float64 5kB ...\n", - "Attributes:\n", - " metadata: {\"short_name\": \"CAMELS\", \"long_name\": \"Catc...\n", - " static_variables: ['p_mean', 'pet_mean', 'p_seasonality', 'fr...\n", - " static_variables_units: []\n", - " time_series_variables: ['dayl_nldas', 'prcp_nldas', 'srad_nldas', ...\n", - " time_series_variables_units: ['s', 'mm/day', 'W/m^2', 'mm', 'degC', 'deg..." - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "hdd.fetch_data(\n", - " bucket=\"mhpi-spatial\",\n", - " dataset=\"camels\",\n", - ")" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "kernelspec": { - "display_name": "hydro_data", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.11.10" - } - }, - "nbformat": 4, - "nbformat_minor": 2 -} diff --git a/pyproject.toml b/pyproject.toml index ff4b582..4357a25 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,6 @@ dependencies = [ "netcdf4", "numpy", "xarray", - "zarr", "pandas", "geopandas", "pydantic>=2.0", diff --git a/src/hydro_data_dev/__init__.py b/src/hydro_data_dev/__init__.py index 5b2248f..a641db5 100644 --- a/src/hydro_data_dev/__init__.py +++ b/src/hydro_data_dev/__init__.py @@ -1,10 +1,12 @@ from hydro_data_dev._version import version as __version__ -from hydro_data_dev.api.methods import fetch_data +from hydro_data_dev.api.methods import create_record, fetch_data, generate_scaler +from hydro_data_dev.core.Record import Record # in case setuptools scm screw up and find version to be 0.0.0 assert not __version__.startswith("0.0.0") __all__ = [ "fetch_data", - "calculate_statistics", + "create_record", + "generate_scaler", ] diff --git a/src/hydro_data_dev/api/methods.py b/src/hydro_data_dev/api/methods.py index ff3d580..4e00dfa 100644 --- a/src/hydro_data_dev/api/methods.py +++ b/src/hydro_data_dev/api/methods.py @@ -3,30 +3,36 @@ import icechunk import numpy as np import xarray as xr +import yaml -from hydro_data_dev.core.Config import Config +from hydro_data_dev.core.methods import _calc_stats +from hydro_data_dev.core.Record import Record -__all__ = ["fetch_data"] +__all__ = [ + "create_record", + "fetch_data", + "generate_scaler", +] -def fetch_data(bucket: str, dataset: str) -> xr.Dataset: - """A function to fetch data from an s3 bucket using icechunk +def fetch_data(record: Record) -> xr.Dataset: + """A function to fetch data from an s3 bucket using icechunk, then subsetting for the data you want + + Note: Jiangtao Liu's load_nc() function is used partially in this function Parameters ---------- - bucket : str - The name of the s3 bucket - dataset : str - The name of the dataset + record : Record + A Record object representing the configuration parameters Returns ------- xr.Dataset - The dataset fetched from the s3 bucket + The sliced dataset fetched from the s3 bucket """ storage_config = icechunk.StorageConfig.s3_anonymous( - bucket=bucket, - prefix=dataset, + bucket=record.bucket, + prefix=record.dataset, region=None, endpoint_url=None, ) @@ -37,6 +43,64 @@ def fetch_data(bucket: str, dataset: str) -> xr.Dataset: mode="r", ) except ValueError as e: - raise ValueError(f"Error opening the dataset: {dataset} from {bucket}") from e + raise ValueError(f"Error opening the dataset: {record.dataset} from {record.bucket}") from e ds = xr.open_zarr(store, zarr_format=3, consolidated=False) + + station_ids = record.station_ids.read_text().splitlines() + if station_ids is not None: + ds = ds.sel(station_ids=[int(station_ids) for station_ids in station_ids]) + + selected_vars = [] + if record.time_series_variables: + selected_vars.extend(record.time_series_variables) + if record.static_variables: + selected_vars.extend(record.static_variables) + if record.target_variables: + selected_vars.extend(record.target_variables) + if selected_vars: + ds = ds[selected_vars] return ds + + +def create_record( + record: str | Path, +) -> Record: + """A function to create a Record object from a yaml file + + Parameters + ---------- + record : str | Path + A string or Path object representing the path to the yaml file + + Returns + ------- + Record + A Record object representing the configuration parameters + """ + if isinstance(record, str): + record = Path(record) + record_dict = yaml.safe_load(record.read_text()) + record_obj = Record(**record_dict) + return record_obj + + +def generate_scaler( + forcing_data: xr.Dataset | None, + static_data: xr.Dataset, + observational_data: xr.Dataset, +) -> dict[str, np.ndarray]: + """A function to generate a scaler dictionary for the data + + Parameters + ---------- + forcing_data : xr.Dataset + The forcing data + static_data : xr.Dataset + The static data + """ + scaler = {} + if forcing_data is not None: + scaler["x_mean"], scaler["x_std"] = _calc_stats(forcing_data.values, axis=(0, 1)) + scaler["y_mean"], scaler["y_std"] = _calc_stats(observational_data.values, axis=(0, 1)) + scaler["c_mean"], scaler["c_std"] = _calc_stats(static_data.values) + return scaler diff --git a/src/hydro_data_dev/core/Config.py b/src/hydro_data_dev/core/Config.py deleted file mode 100644 index ede6286..0000000 --- a/src/hydro_data_dev/core/Config.py +++ /dev/null @@ -1,74 +0,0 @@ -from pathlib import Path - -import yaml -from pydantic import BaseModel, computed_field - - -class Config(BaseModel): - """ - Configuration model running deep learning simulations. - - Attributes - ---------- - bucket : str - The name of the bucket. - dataset : str - The name of the dataset. - train_date_list : List[str] - List of training dates. - val_date_list : List[str] - List of validation dates. - test_date_list : List[str] - List of test dates. - time_series_variables : List[str] - List of time series variables. - target_variables : List[str] - List of target variables. - static_variables : List[str] - List of static variables. - station_ids : Path - Path to the file containing station IDs. - add_coords : bool - Flag to add coordinates. - group_mask_dict : Optional[Dict] - Dictionary for group masks. - data_type : str - Type of the data. - """ - - bucket: str - dataset: str - train_date_list: list[str] - val_date_list: list[str] - test_date_list: list[str] - time_series_variables: list[str] - target_variables: list[str] - static_variables: list[str] - station_ids: Path - add_coords: bool - group_mask_dict: dict | None - data_type: str - - @computed_field - @property - def station_list(self) -> list[str]: - """Read station IDs from file""" - return self.station_ids.read_text().splitlines() - - @classmethod - def from_yaml(cls, path: Path) -> "Config": - """ - Create a Config instance from a YAML file. - - Parameters - ---------- - path : str | Path - The path to the YAML file. - - Returns - ------- - Config - An instance of the Config class. - """ - data = yaml.safe_load(path.read_text()) - return cls(**data) diff --git a/src/hydro_data_dev/core/Record.py b/src/hydro_data_dev/core/Record.py new file mode 100644 index 0000000..382104d --- /dev/null +++ b/src/hydro_data_dev/core/Record.py @@ -0,0 +1,96 @@ +from datetime import datetime +from pathlib import Path +from typing import Any + +from pydantic import BaseModel, field_validator, model_validator + +DATE_FORMAT = "%Y-%m-%d" + + +def _check_path(v: str) -> Path: + """A function to determine if a path exists or not + + Parameters + ---------- + v : str + A string representing a path + + Returns + ------- + Path + A Path object representing the path + """ + path = Path(v) + if not path.exists(): + cwd = Path.cwd() + raise ValueError(f"Path {v} does not exist. CWD: {cwd}") + return path + + +class Record(BaseModel): + """A dataclass to represent config inputs for your data fetching""" + + bucket: str + dataset: str + train_date_range: tuple[str, str] + val_date_range: tuple[str, str] + test_date_range: tuple[str, str] + time_series_variables: list[str] + target_variables: list[str] + static_variables: list[str] + station_ids: str + + @model_validator(mode="after") + @classmethod + def validate_dates(cls, config: Any) -> Any: + """A function to format str dates into datetime objects + + Parameters + ---------- + config : Any + A dictionary of configuration parameters + + Returns + ------- + Any + A dictionary of configuration parameters with datetime objects + """ + try: + config.train_date_range = tuple( + datetime.strptime(date_string, DATE_FORMAT) for date_string in config.train_date_range + ) + except ValueError as e: + raise ValueError("Error converting train_date_range to datetime") from e + try: + config.val_date_range = tuple( + datetime.strptime(date_string, DATE_FORMAT) for date_string in config.val_date_range + ) + except ValueError as e: + raise ValueError("Error converting val_date_range to datetime") from e + try: + config.test_date_range = tuple( + datetime.strptime(date_string, DATE_FORMAT) for date_string in config.test_date_range + ) + except ValueError as e: + raise ValueError("Error converting test_date_range to datetime") from e + return config + + @field_validator( + "station_ids", + ) + @classmethod + def validate_data_dir(cls, v: str) -> Path: + """A function to validate the data directory + + Parameters + ---------- + v : str + A string representing a path + + Returns + ------- + Path + A Path object representing the path + + """ + return _check_path(v) diff --git a/src/hydro_data_dev/core/methods.py b/src/hydro_data_dev/core/methods.py new file mode 100644 index 0000000..2202a96 --- /dev/null +++ b/src/hydro_data_dev/core/methods.py @@ -0,0 +1,51 @@ +import numpy as np + + +def _transform(data: np.ndarray, mean: np.ndarray, std: np.ndarray, inverse: bool = False) -> np.ndarray: + """A transformation function to normalize or denormalize data + + Author: Jiangtao Liu + + Parameters + ---------- + data : np.ndarray + The data to be transformed + mean : np.ndarray + The mean of the data + std : np.ndarray + The standard deviation of the data + inverse : bool, optional + A flag to indicate if the transformation is inverse, by default False + + Returns + ------- + np.ndarray + The transformed data + """ + if inverse: + return data * std + mean + else: + return (data - mean) / std + + +def _calc_stats(data: np.ndarray, axis: int = 0) -> tuple[np.ndarray, np.ndarray]: + """A function to calculate the mean and standard deviation of the data + + Author: Jiangtao Liu + + Parameters + ---------- + data : np.ndarray + The data to calculate stats from + axis : int, optional + The axis to calculate the stats, by default 0 + + Returns + ------- + tuple[np.ndarray, np.ndarray] + The mean and standard deviation of the data + """ + mean = np.nanmean(data, axis=axis).astype(float) + std = np.nanstd(data, axis=axis).astype(float) + std = np.maximum(std, 0.001) # Ensuring std is not too small + return mean, std diff --git a/tests/conftest.py b/tests/conftest.py index 15ed75e..edaafb0 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,23 +1,11 @@ from pathlib import Path -from typing import Any import pytest -import xarray as xr -current_dir = Path.cwd() +import hydro_data_dev as hdd @pytest.fixture -def sample_json() -> dict[str, Any]: - record = { - "meta": { - "name": "sample_camels", - "root": (current_dir / "sample_data/sample_camels.nc").__str__(), - "version": "v1.0.0", - }, - "data": xr.open_dataset(current_dir / "sample_data/sample_camels.nc", engine="netcdf4"), - "start_time": "1980-01-01T00:00:00.000000000", - "end_time": "2014-12-31T00:00:00.000000000", - "format": "netcdf", - } - return record +def sample_record() -> hdd.Record: + path = Path(__file__).parent / "data" / "example_record.yaml" + return hdd.create_record(record=path) diff --git a/tests/data/example_record.yaml b/tests/data/example_record.yaml new file mode 100644 index 0000000..709539b --- /dev/null +++ b/tests/data/example_record.yaml @@ -0,0 +1,33 @@ +# config.yaml +bucket: "mhpi-spatial" +dataset: "camels" + +# Date ranges for train/val/test splits +train_date_range: + - "1999-10-01" + - "2008-09-30" +val_date_range: + - "2008-10-01" + - "2014-09-30" +test_date_range: + - "1989-10-01" + - "1999-09-30" + +# Variables to use in the model +time_series_variables: + - dayl_daymet + - prcp_daymet + - srad_daymet + - tmean_daymet + - vp_daymet + +target_variables: + - runoff + +static_variables: + - p_mean + +station_ids: ./tests/data/sample_basins.txt +add_coords: false +group_mask_dict: null +data_type: basin diff --git a/tests/data/sample_basins.txt b/tests/data/sample_basins.txt new file mode 100644 index 0000000..fcffd52 --- /dev/null +++ b/tests/data/sample_basins.txt @@ -0,0 +1,10 @@ +01022500 +01031500 +01047000 +01052500 +01054200 +01055000 +01057000 +01073000 +01078000 +01123000 diff --git a/tests/test_fetch_data.py b/tests/test_fetch_data.py index b06f0c1..c36d378 100644 --- a/tests/test_fetch_data.py +++ b/tests/test_fetch_data.py @@ -1,8 +1,7 @@ -import pytest - import hydro_data_dev as hdd -def test_fetch_data() -> None: - with pytest.raises(ValueError): - hdd.fetch_data(bucket="test", dataset="test") +def test_fetch_data(sample_record: hdd.Record) -> None: + ds = hdd.fetch_data(record=sample_record) + + assert ds.station_ids.values.shape[0] == 10