⚡ use dask to improve memory and compute scalability
* also remove the skip_download option from the gdp_source adapter and dataset function
kevinsantana11 committed Sep 29, 2024
1 parent 33ac0d5 commit 99d166c
Showing 5 changed files with 66 additions and 204 deletions.
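
For context, the core pattern this commit introduces is to decompress each .dat.gz source file once and then read it lazily with dask.dataframe, so every file is split into bounded partitions rather than loaded whole or chunked by hand. Below is a minimal sketch of that pattern; the file name, column list, and dtypes are placeholders for illustration, not the adapter's real inputs.

import gzip
import os

import dask.dataframe as dd
import numpy as np

# Placeholder inputs for illustration; the adapter derives these from the GDP file list.
compressed_files = ["buoydata_1_5000_rawfiles.dat.gz"]
columns = ["id", "latitude", "longitude", "drogue", "sst"]
dtypes = {"id": np.int64, "latitude": np.float32, "longitude": np.float32}

# Decompress once so dask can split each file into byte-range partitions
# (gzip streams cannot be read from arbitrary offsets).
data_files = []
for fp in compressed_files:
    decompressed = fp[:-3]  # strip the ".gz" suffix
    if not os.path.exists(decompressed):
        with gzip.open(fp, "rb") as src, open(decompressed, "wb") as dst:
            dst.write(src.read())
    data_files.append(decompressed)

# Lazily read the whitespace-delimited records; blocksize bounds each partition
# and assume_missing=True lets integer-like columns carry NaN values.
df = dd.read_csv(
    data_files,
    sep=r"\s+",
    header=None,
    names=columns,
    dtype=dtypes,
    blocksize="1GB",
    assume_missing=True,
)

In the adapter itself, the resulting lazy dataframe is handed to _process, which materializes it with .compute() before applying the row filters and building the ragged array.
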
256 changes: 60 additions & 196 deletions clouddrift/adapters/gdp/gdpsource.py
@@ -1,19 +1,18 @@
from __future__ import annotations

import asyncio
import datetime
import logging
import os
import tempfile
import warnings
from collections import defaultdict
from concurrent.futures import Future, ProcessPoolExecutor, as_completed
from typing import Callable

import dask.dataframe as dd
import numpy as np
import pandas as pd
import xarray as xr
from tqdm.asyncio import tqdm
from dask import delayed
from tqdm import tqdm

from clouddrift.adapters.gdp import get_gdp_metadata
from clouddrift.adapters.utils import download_with_progress
@@ -25,12 +24,11 @@
_FILENAME_TEMPLATE = "buoydata_{start}_{end}_{suffix}.dat.gz"
_SECONDS_IN_DAY = 86_400

_COORDS = ["id", "obs_index"]
_COORDS = ["id", "position_datetime"]

_DATA_VARS = [
"latitude",
"longitude",
"position_datetime",
"sensor_datetime",
"drogue",
"sst",
@@ -108,15 +106,15 @@

_INPUT_COLS_DTYPES = {
"id": np.int64,
"posObsMonth": np.int8,
"posObsMonth": np.float32,
"posObsDay": np.float64,
"posObsYear": np.int16,
"posObsYear": np.float32,
"latitude": np.float32,
"longitude": np.float32,
"qualityIndex": np.float32,
"senObsMonth": np.int8,
"senObsMonth": np.float32,
"senObsDay": np.float64,
"senObsYear": np.int16,
"senObsYear": np.float32,
"drogue": np.float32,
"sst": np.float32,
"voltage": np.float32,
@@ -125,6 +123,14 @@
"sensor6": np.float32,
}

_INPUT_COLS_PREFILTER_DTYPES: dict[str, type[object]] = {
"posObsMonth": np.str_,
"posObsYear": np.float64,
"senObsMonth": np.str_,
"senObsYear": np.float64,
"drogue": np.str_,
}


VARS_ATTRS: dict = {
"id": {"long_name": "Global Drifter Program Buoy ID", "units": "-"},
@@ -323,9 +329,9 @@ def _preprocess(id_, **kwargs) -> xr.Dataset:

coords = {
"id": (["traj"], np.array([id_]).astype(np.int64)),
"obs_index": (
"position_datetime": (
["obs"],
traj_data_df[["obs_index"]].values.flatten().astype(np.int32),
traj_data_df[["position_datetime"]].values.flatten().astype(np.datetime64),
),
}

@@ -374,20 +380,18 @@ def _parse_datetime_with_day_ratio(
return np.array(values).astype("datetime64[ns]")


def _process_chunk(
df_chunk: pd.DataFrame,
start_idx: int,
end_idx: int,
def _process(
df: dd.DataFrame,
gdp_metadata_df: pd.DataFrame,
use_fill_values: bool,
) -> dict[int, xr.Dataset]:
) -> xr.Dataset:
"""Process each dataframe chunk. Return a dictionary mapping each drifter to a unique xarray Dataset."""

# Transform the initial dataframe, filtering out rows with clearly anomalous values;
# examples include years in the future, years predating the GDP program, etc.
preremove_df_chunk = df_chunk.assign(obs_index=range(start_idx, end_idx))
preremove_df = df.compute()
df_chunk = _apply_remove(
preremove_df_chunk,
preremove_df,
filters=[
# Filter out year values that are in the future or predating the GDP program
lambda df: (df["posObsYear"] > datetime.datetime.now().year)
@@ -405,7 +409,7 @@ def _process_chunk(

drifter_ds_map = dict[int, xr.Dataset]()

preremove_len = len(preremove_df_chunk)
preremove_len = len(preremove_df)
postremove_len = len(df_chunk)

if preremove_len != postremove_len:
@@ -455,153 +459,16 @@ def _process_chunk(
md_df=gdp_metadata_df,
data_df=df_chunk,
use_fill_values=use_fill_values,
tqdm=dict(disable=True),
)
ds = ra.to_xarray()

for id_ in ids_with_md:
id_f_ds = subset(ds, dict(id=id_), row_dim_name="traj")
drifter_ds_map[id_] = id_f_ds
return drifter_ds_map


def _combine_chunked_drifter_datasets(datasets: list[xr.Dataset]) -> xr.Dataset:
"""Combines several drifter observations found in separate chunks, ordering them
by the observations row index.
"""
traj_dataset = xr.concat(
datasets, dim="obs", coords="minimal", data_vars=_DATA_VARS, compat="override"
tqdm={"disable": True},
)
return ra.to_xarray()

new_rowsize = sum([ds.rowsize.values[0] for ds in datasets])
traj_dataset["rowsize"] = xr.DataArray(
np.array([new_rowsize], dtype=np.int64), coords=traj_dataset["rowsize"].coords
)

sort_coord = traj_dataset.coords["obs_index"]
vals: np.ndarray = sort_coord.data
sort_coord_dim = sort_coord.dims[-1]
sort_key = vals.argsort()

for coord_name in _COORDS:
coord = traj_dataset.coords[coord_name]
dim = coord.dims[-1]

if dim == sort_coord_dim:
sorted_coord = coord.isel({dim: sort_key})
traj_dataset.coords[coord_name] = sorted_coord

for varname in _DATA_VARS:
var = traj_dataset[varname]
dim = var.dims[-1]
sorted_var = var.isel({dim: sort_key})
traj_dataset[varname] = sorted_var

return traj_dataset


async def _parallel_get(
sources: list[str],
gdp_metadata_df: pd.DataFrame,
chunk_size: int,
tmp_path: str,
use_fill_values: bool,
max_chunks: int | None,
) -> list[xr.Dataset]:
"""Parallel process dataset in chunks leveraging multiprocessing."""
max_workers = (os.cpu_count() or 0) // 2
with ProcessPoolExecutor(max_workers=max_workers) as ppe:
drifter_chunked_datasets: dict[int, list[xr.Dataset]] = defaultdict(list)
start_idx = 0
for fp in tqdm(
sources,
desc="Loading files",
unit="file",
ncols=80,
total=len(sources),
position=0,
):
file_chunks = pd.read_csv(
fp,
sep=r"\s+",
header=None,
names=_INPUT_COLS,
engine="c",
compression="gzip",
chunksize=chunk_size,
)

joblist = list[Future]()
jobmap = dict[Future, pd.DataFrame]()
for idx, chunk in enumerate(file_chunks):
if max_chunks is not None and idx >= max_chunks:
break
ajob = ppe.submit(
_process_chunk,
chunk,
start_idx,
start_idx + len(chunk),
gdp_metadata_df,
use_fill_values,
)
start_idx += len(chunk)
jobmap[ajob] = chunk
joblist.append(ajob)

bar = tqdm(
desc="Processing file chunks",
unit="chunk",
ncols=80,
total=len(joblist),
position=1,
)

for ajob in as_completed(jobmap.keys()):
if (exc := ajob.exception()) is not None:
chunk = jobmap[ajob]
_logger.warn(f"bad chunk detected, exception: {ajob.exception()}")
raise exc

job_drifter_ds_map: dict[int, xr.Dataset] = ajob.result()
for id_ in job_drifter_ds_map.keys():
drifter_ds = job_drifter_ds_map[id_]
drifter_chunked_datasets[id_].append(drifter_ds)
bar.update()

combine_jobmap = dict[Future, int]()
for id_ in drifter_chunked_datasets.keys():
datasets = drifter_chunked_datasets[id_]

combine_job = ppe.submit(_combine_chunked_drifter_datasets, datasets)
combine_jobmap[combine_job] = id_

bar.close()
bar = tqdm(
desc="merging drifter chunks",
unit="drifter",
ncols=80,
total=len(drifter_chunked_datasets.keys()),
position=2,
)

os.makedirs(os.path.join(tmp_path, "drifters"), exist_ok=True)

drifter_datasets = list[xr.Dataset]()
for combine_job in as_completed(combine_jobmap.keys()):
dataset: xr.Dataset = combine_job.result()
drifter_datasets.append(dataset)
bar.update()
bar.close()
return drifter_datasets


def to_raggedarray(
tmp_path: str = _TMP_PATH,
skip_download: bool = False,
max: int | None = None,
chunk_size: int = 100_000,
use_fill_values: bool = True,
max_chunks: int | None = None,
) -> xr.Dataset:
"""Get the GDP source dataset."""

@@ -611,52 +478,49 @@ def to_raggedarray(

# Filter down for testing purposes.
if max:
requests = [requests[max]]
requests = requests[:max]

# Download necessary data and metadata files.
if not skip_download:
download_with_progress(requests)
download_with_progress(requests)

gdp_metadata_df = get_gdp_metadata(tmp_path)

# Run async process to parallelize data processing.
drifter_datasets = asyncio.run(
_parallel_get(
[dst for (_, dst) in requests],
gdp_metadata_df,
chunk_size,
tmp_path,
use_fill_values,
max_chunks,
)
import gzip

data_files = list()
for compressed_data_file in tqdm(
[dst for (_, dst) in requests], desc="Decompressing files", unit="file"
):
decompressed_fp = compressed_data_file[:-3]
data_files.append(decompressed_fp)
if not os.path.exists(decompressed_fp):
with (
gzip.open(compressed_data_file, "rb") as compr,
open(decompressed_fp, "wb") as decompr,
):
decompr.write(compr.read())

wanted_dtypes = dict()
wanted_dtypes.update(_INPUT_COLS_DTYPES)
wanted_dtypes.update(_INPUT_COLS_PREFILTER_DTYPES)

df: dd.DataFrame = dd.read_csv(
data_files,
sep=r"\s+",
header=None,
names=_INPUT_COLS,
dtype=wanted_dtypes,
engine="c",
blocksize="1GB",
assume_missing=True,
)
ds = _process(df, gdp_metadata_df, use_fill_values)

# Sort the drifters by their start date.
deploy_date_id_map = {
ds["id"].data[0]: ds["start_date"].data[0] for ds in drifter_datasets
}
deploy_date_sort_key = np.argsort(list(deploy_date_id_map.values()))
sorted_drifter_datasets = [drifter_datasets[idx] for idx in deploy_date_sort_key]

# Concatenate drifter data and metadata variables separately.
obs_ds = xr.concat(
[ds.drop_dims("traj") for ds in sorted_drifter_datasets],
dim="obs",
data_vars=_DATA_VARS,
)
traj_ds = xr.concat(
[ds.drop_dims("obs") for ds in sorted_drifter_datasets],
dim="traj",
data_vars=_METADATA_VARS,
)

# Merge the separate datasets.
agg_ds = xr.merge([obs_ds, traj_ds])

# Add variable metadata.
for var_name in _DATA_VARS + _METADATA_VARS:
if var_name in VARS_ATTRS.keys():
agg_ds[var_name].attrs = VARS_ATTRS[var_name]
agg_ds.attrs = ATTRS
ds[var_name].attrs = VARS_ATTRS[var_name]
ds.attrs = ATTRS

return agg_ds
return ds
6 changes: 1 addition & 5 deletions clouddrift/datasets.py
@@ -156,7 +156,6 @@ def gdp6h(decode_times: bool = True) -> xr.Dataset:
def gdp_source(
tmp_path: str = adapters.gdp_source._TMP_PATH,
max: int | None = None,
skip_download: bool = False,
use_fill_values: bool = True,
decode_times: bool = True,
) -> xr.Dataset:
@@ -178,9 +177,6 @@ def gdp_source(
max: int, optional
Maximum number of files to retrieve and parse to generate the aggregate file. Mainly used
for testing purposes.
skip_download: bool, False (default)
If True, skips downloading the data files and the code assumes the files have already been downloaded.
This is mainly used to skip downloading files if the remote doesn't provide the HTTP Last-Modified header.
use_fill_values: bool, True (default)
When True, missing metadata fields are replaced with fill values. When False and no metadata
is found for a given drifter its observations are ignored.
@@ -225,7 +221,7 @@ def gdp_source(
f"gdpsource_agg_{file_selection_label}.zarr",
decode_times,
lambda: adapters.gdp_source.to_raggedarray(
tmp_path, skip_download, max, use_fill_values=use_fill_values
tmp_path, max, use_fill_values=use_fill_values
),
)
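
With skip_download gone from the public signature, a minimal usage sketch of the updated accessor looks like the following (the small max value only keeps a test run short and is not part of the change):

from clouddrift.datasets import gdp_source

# Build (or load a previously cached) aggregate GDP source dataset; max=1
# restricts parsing to a single source file, which is useful for quick tests.
ds = gdp_source(max=1, use_fill_values=True, decode_times=True)
print(ds)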
