Leverage dask to simplify the adapter while preserving parallelization
kevinsantana11 committed Aug 24, 2024
1 parent 0e08777 commit 64bf984
Showing 2 changed files with 40 additions and 157 deletions.
194 changes: 38 additions & 156 deletions clouddrift/adapters/gdp/gdpsource.py
@@ -1,19 +1,17 @@
from __future__ import annotations

import asyncio
import datetime
import logging
import os
import tempfile
import warnings
from collections import defaultdict
from concurrent.futures import Future, ProcessPoolExecutor, as_completed
from typing import Callable

import dask.dataframe as dd
import numpy as np
import pandas as pd
import xarray as xr
from tqdm.asyncio import tqdm
from tqdm import tqdm

from clouddrift.adapters.gdp import get_gdp_metadata
from clouddrift.adapters.utils import download_with_progress
@@ -108,15 +106,15 @@

_INPUT_COLS_DTYPES = {
"id": np.int64,
"posObsMonth": np.int8,
"posObsMonth": np.float32,
"posObsDay": np.float64,
"posObsYear": np.int16,
"posObsYear": np.float32,
"latitude": np.float32,
"longitude": np.float32,
"qualityIndex": np.float32,
"senObsMonth": np.int8,
"senObsMonth": np.float32,
"senObsDay": np.float64,
"senObsYear": np.int16,
"senObsYear": np.float32,
"drogue": np.float32,
"sst": np.float32,
"voltage": np.float32,
@@ -125,6 +123,14 @@
"sensor6": np.float32,
}

_INPUT_COLS_PREFILTER_DTYPES = {
"posObsMonth": np.str_,
"posObsYear": np.float64,
"senObsMonth": np.str_,
"senObsYear": np.float64,
"drogue": np.str_,
}
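The integer month/year dtypes above are relaxed to float32, and a separate set of pre-filter dtypes (float64/str) is introduced for the same columns. A plausible reason is that rows read before filtering can carry missing or malformed tokens, and NumPy integer dtypes cannot represent NaN; the toy sketch below, which is not clouddrift code, illustrates that constraint.

```python
import numpy as np
import pandas as pd

# Toy illustration (not clouddrift code): an integer NumPy dtype cannot hold
# NaN, so date fields that may be missing or malformed are read as float or
# str first and only narrowed after the bad rows are filtered out.
raw = pd.Series(["7", "bad", None])             # tokens as they might arrive pre-filter
as_float = pd.to_numeric(raw, errors="coerce")  # float64 with NaN for "bad" and None
clean = as_float.dropna().astype(np.float32)    # safe to narrow once NaNs are gone
print(clean.dtype)                              # float32
```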


VARS_ATTRS: dict = {
"id": {"long_name": "Global Drifter Program Buoy ID", "units": "-"},
@@ -374,18 +380,16 @@ def _parse_datetime_with_day_ratio(
return np.array(values).astype("datetime64[ns]")


def _process_chunk(
df_chunk: pd.DataFrame,
start_idx: int,
end_idx: int,
def _process(
df: pd.DataFrame,
gdp_metadata_df: pd.DataFrame,
use_fill_values: bool,
) -> dict[int, xr.Dataset]:
"""Process each dataframe chunk. Return a dictionary mapping each drifter to a unique xarray Dataset."""

# Transform the initial dataframe, filtering out rows with clearly anomalous values
# examples include: years in the future, years way in the past before GDP program, etc...
preremove_df_chunk = df_chunk.assign(obs_index=range(start_idx, end_idx))
preremove_df_chunk = df
df_chunk = _apply_remove(
preremove_df_chunk,
filters=[
@@ -455,7 +459,7 @@ def _process_chunk(
md_df=gdp_metadata_df,
data_df=df_chunk,
use_fill_values=use_fill_values,
tqdm=dict(disable=True),
tqdm={"disable": True}
)
ds = ra.to_xarray()

@@ -465,143 +469,11 @@ def _process_chunk(
return drifter_ds_map
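`_process` (formerly `_process_chunk`) now receives the full dataframe rather than an index-bounded chunk, so the per-chunk `obs_index` bookkeeping disappears from its signature. The toy sketch below, with made-up values, shows the lazy-evaluation pattern a dask-backed dataframe provides here: row filters only build a task graph, and the work is spread across partitions when a concrete result is requested.

```python
import dask.dataframe as dd
import pandas as pd

# Toy illustration: filters on a dask DataFrame are lazy and only build a task
# graph; the work runs across partitions when .compute() is called.
pdf = pd.DataFrame({"id": [1, 1, 2], "posObsYear": [2005.0, 3000.0, 2010.0]})
ddf = dd.from_pandas(pdf, npartitions=2)

plausible = ddf[(ddf["posObsYear"] > 1970) & (ddf["posObsYear"] < 2025)]  # lazy
print(plausible.compute())  # executes the graph; the bogus year-3000 row is gone
```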


def _combine_chunked_drifter_datasets(datasets: list[xr.Dataset]) -> xr.Dataset:
"""Combines several drifter observations found in separate chunks, ordering them
by the observations row index.
"""
traj_dataset = xr.concat(
datasets, dim="obs", coords="minimal", data_vars=_DATA_VARS, compat="override"
)

new_rowsize = sum([ds.rowsize.values[0] for ds in datasets])
traj_dataset["rowsize"] = xr.DataArray(
np.array([new_rowsize], dtype=np.int64), coords=traj_dataset["rowsize"].coords
)

sort_coord = traj_dataset.coords["obs_index"]
vals: np.ndarray = sort_coord.data
sort_coord_dim = sort_coord.dims[-1]
sort_key = vals.argsort()

for coord_name in _COORDS:
coord = traj_dataset.coords[coord_name]
dim = coord.dims[-1]

if dim == sort_coord_dim:
sorted_coord = coord.isel({dim: sort_key})
traj_dataset.coords[coord_name] = sorted_coord

for varname in _DATA_VARS:
var = traj_dataset[varname]
dim = var.dims[-1]
sorted_var = var.isel({dim: sort_key})
traj_dataset[varname] = sorted_var

return traj_dataset
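The helper deleted above existed only because one drifter's observations could be split across several chunks and had to be restitched and re-ordered by `obs_index` after concatenation. Processing the dataframe in one pass removes that need; for reference, the reordering it performed is essentially what `xarray.Dataset.sortby` does, as in the toy sketch below (which omits the rowsize recomputation).

```python
import xarray as xr

# Toy illustration of the deleted pattern: concatenate per-chunk datasets along
# "obs", then restore the original observation order via the obs_index
# coordinate. sortby performs the same argsort-based reordering.
a = xr.Dataset({"sst": ("obs", [20.0, 21.0])}, coords={"obs_index": ("obs", [3, 1])})
b = xr.Dataset({"sst": ("obs", [19.0])}, coords={"obs_index": ("obs", [2])})

combined = xr.concat([a, b], dim="obs")
ordered = combined.sortby("obs_index")
print(ordered.obs_index.values)  # [1 2 3]
```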


async def _parallel_get(
sources: list[str],
gdp_metadata_df: pd.DataFrame,
chunk_size: int,
tmp_path: str,
use_fill_values: bool,
max_chunks: int | None,
) -> list[xr.Dataset]:
"""Parallel process dataset in chunks leveraging multiprocessing."""
max_workers = (os.cpu_count() or 0) // 2
with ProcessPoolExecutor(max_workers=max_workers) as ppe:
drifter_chunked_datasets: dict[int, list[xr.Dataset]] = defaultdict(list)
start_idx = 0
for fp in tqdm(
sources,
desc="Loading files",
unit="file",
ncols=80,
total=len(sources),
position=0,
):
file_chunks = pd.read_csv(
fp,
sep=r"\s+",
header=None,
names=_INPUT_COLS,
engine="c",
compression="gzip",
chunksize=chunk_size,
)

joblist = list[Future]()
jobmap = dict[Future, pd.DataFrame]()
for idx, chunk in enumerate(file_chunks):
if max_chunks is not None and idx >= max_chunks:
break
ajob = ppe.submit(
_process_chunk,
chunk,
start_idx,
start_idx + len(chunk),
gdp_metadata_df,
use_fill_values,
)
start_idx += len(chunk)
jobmap[ajob] = chunk
joblist.append(ajob)

bar = tqdm(
desc="Processing file chunks",
unit="chunk",
ncols=80,
total=len(joblist),
position=1,
)

for ajob in as_completed(jobmap.keys()):
if (exc := ajob.exception()) is not None:
chunk = jobmap[ajob]
_logger.warn(f"bad chunk detected, exception: {ajob.exception()}")
raise exc

job_drifter_ds_map: dict[int, xr.Dataset] = ajob.result()
for id_ in job_drifter_ds_map.keys():
drifter_ds = job_drifter_ds_map[id_]
drifter_chunked_datasets[id_].append(drifter_ds)
bar.update()

combine_jobmap = dict[Future, int]()
for id_ in drifter_chunked_datasets.keys():
datasets = drifter_chunked_datasets[id_]

combine_job = ppe.submit(_combine_chunked_drifter_datasets, datasets)
combine_jobmap[combine_job] = id_

bar.close()
bar = tqdm(
desc="merging drifter chunks",
unit="drifter",
ncols=80,
total=len(drifter_chunked_datasets.keys()),
position=2,
)

os.makedirs(os.path.join(tmp_path, "drifters"), exist_ok=True)

drifter_datasets = list[xr.Dataset]()
for combine_job in as_completed(combine_jobmap.keys()):
dataset: xr.Dataset = combine_job.result()
drifter_datasets.append(dataset)
bar.update()
bar.close()
return drifter_datasets
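The removed `_parallel_get` sized a `ProcessPoolExecutor` at half the CPU count and tracked chunk futures by hand. With dask that scheduling moves into the library; the toy sketch below shows how a process-based scheduler and worker count could still be selected explicitly per compute call, since the diff itself does not show whether the adapter keeps dask's default threaded scheduler.

```python
import dask.dataframe as dd
import pandas as pd

# Toy illustration: dask.dataframe uses a threaded scheduler by default; a
# process pool, closer to the deleted ProcessPoolExecutor setup, can be
# requested explicitly for a given compute call.
if __name__ == "__main__":  # guard needed for process-based execution on spawn platforms
    ddf = dd.from_pandas(pd.DataFrame({"x": range(8)}), npartitions=4)
    total = ddf["x"].sum().compute(scheduler="processes", num_workers=2)
    print(total)  # 28
```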


def to_raggedarray(
tmp_path: str = _TMP_PATH,
skip_download: bool = False,
max: int | None = None,
chunk_size: int = 100_000,
use_fill_values: bool = True,
max_chunks: int | None = None,
) -> xr.Dataset:
"""Get the GDP source dataset."""

@@ -619,17 +491,27 @@ def to_raggedarray(

gdp_metadata_df = get_gdp_metadata(tmp_path)

# Run async process to parallelize data processing.
drifter_datasets = asyncio.run(
_parallel_get(
[dst for (_, dst) in requests],
gdp_metadata_df,
chunk_size,
tmp_path,
use_fill_values,
max_chunks,
)
import gzip

data_files = list()
for compressed_data_file in tqdm([dst for (_, dst) in requests], desc="Decompressing files", unit="file"):
decompressed_fp = compressed_data_file[:-3]
data_files.append(decompressed_fp)
if not os.path.exists(decompressed_fp):
with gzip.open(compressed_data_file, "rb") as compr, open(decompressed_fp, "wb") as decompr:
decompr.write(compr.read())

df = dd.read_csv(
data_files,
sep=r"\s+",
header=None,
names=_INPUT_COLS,
engine="c",
dtype=_INPUT_COLS_PREFILTER_DTYPES,
blocksize="1GB",
assume_missing=True
)
drifter_datasets = _process(df, gdp_metadata_df, use_fill_values)

# Sort the drifters by their start date.
deploy_date_id_map = {
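The new read path decompresses each downloaded `.gz` file before handing plain-text paths to `dd.read_csv`. That extra step is presumably there because gzip streams cannot be split, so `blocksize="1GB"` can only yield multiple partitions per file once the files are uncompressed; `assume_missing=True` makes dask infer integer-looking columns as floats, which lines up with the float32 dtype relaxations earlier in the diff. A runnable sketch with throwaway files standing in for the real sources:

```python
import os
import tempfile
import dask.dataframe as dd

# Throwaway whitespace-separated files stand in for the decompressed GDP data.
tmp = tempfile.mkdtemp()
paths = []
for i in range(2):
    path = os.path.join(tmp, f"part{i}.dat")
    with open(path, "w") as f:
        f.write("1 7 2005.5\n2 8 2006.5\n")  # id, month, year
    paths.append(path)

ddf = dd.read_csv(
    paths,
    sep=r"\s+",
    header=None,
    names=["id", "posObsMonth", "posObsYear"],  # truncated column list for the sketch
    blocksize="1GB",        # partition size per (uncompressed) file
    assume_missing=True,    # integer-looking columns without a dtype become floats
)
print(ddf.npartitions, dict(ddf.dtypes))
```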
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -35,7 +35,8 @@ dependencies = [
"scipy>=1.11.2",
"xarray>=2023.5.0",
"zarr>=2.14.2",
"tenacity>=8.2.3"
"tenacity>=8.2.3",
"dask>=2024.5.0"
]

[project.optional-dependencies]
