diff --git a/pangeo_forge_recipes/aggregation.py b/pangeo_forge_recipes/aggregation.py
index 55086f7c..51a2bdeb 100644
--- a/pangeo_forge_recipes/aggregation.py
+++ b/pangeo_forge_recipes/aggregation.py
@@ -282,5 +282,6 @@ def schema_to_zarr(
     """Initialize a zarr group based on a schema."""
     ds = schema_to_template_ds(schema, specified_chunks=target_chunks, attrs=attrs)
     # using mode="w" makes this function idempotent
-    ds.to_zarr(target_store, mode="w", compute=False)
+    # NOTE: consolidated=False defers metadata consolidation so it can happen later in StoreToZarr
+    ds.to_zarr(target_store, mode="w", compute=False, consolidated=False)
     return target_store
diff --git a/pangeo_forge_recipes/rechunking.py b/pangeo_forge_recipes/rechunking.py
index 57bbd542..dfd2a96c 100644
--- a/pangeo_forge_recipes/rechunking.py
+++ b/pangeo_forge_recipes/rechunking.py
@@ -6,6 +6,7 @@
 import numpy as np
 import xarray as xr
+import zarr
 
 from .aggregation import XarraySchema, determine_target_chunks
 from .chunk_grid import ChunkGrid
@@ -238,3 +239,42 @@ def _sort_by_speed_of_varying(item):
     ds_combined = xr.combine_nested(dsets_to_concat, concat_dim=concat_dims_sorted)
 
     return first_index, ds_combined
+
+
+def _gather_coordinate_dimensions(group: zarr.Group) -> List[str]:
+    return list(
+        set(itertools.chain(*(group[var].attrs.get("_ARRAY_DIMENSIONS", []) for var in group)))
+    )
+
+
+def consolidate_dimension_coordinates(
+    singleton_target_store: zarr.storage.FSStore,
+) -> zarr.storage.FSStore:
+    """Rewrite each dimension coordinate variable as a single chunk."""
+    group = zarr.open_group(singleton_target_store)
+
+    dims = (dim for dim in _gather_coordinate_dimensions(group) if dim in group)
+    for dim in dims:
+        arr = group[dim]
+        attrs = dict(arr.attrs)
+        data = arr[:]
+
+        # This will generally use bulk-delete API calls
+        # config.storage_config.target.rm(dim, recursive=True)
+
+        singleton_target_store.fs.rm(singleton_target_store.path + "/" + dim, recursive=True)
+
+        new = group.array(
+            dim,
+            data,
+            chunks=arr.shape,
+            dtype=arr.dtype,
+            compressor=arr.compressor,
+            fill_value=arr.fill_value,
+            order=arr.order,
+            filters=arr.filters,
+            overwrite=True,
+        )
+
+        new.attrs.update(attrs)
+    return singleton_target_store
diff --git a/pangeo_forge_recipes/transforms.py b/pangeo_forge_recipes/transforms.py
index 8faf63c3..8359d151 100644
--- a/pangeo_forge_recipes/transforms.py
+++ b/pangeo_forge_recipes/transforms.py
@@ -21,7 +21,7 @@
 from .combiners import CombineMultiZarrToZarr, CombineXarraySchemas
 from .openers import open_url, open_with_kerchunk, open_with_xarray
 from .patterns import CombineOp, Dimension, FileType, Index, augment_index_with_start_stop
-from .rechunking import combine_fragments, split_fragment
+from .rechunking import combine_fragments, consolidate_dimension_coordinates, split_fragment
 from .storage import CacheFSSpecTarget, FSSpecTarget
 from .writers import ZarrWriterMixin, store_dataset_fragment, write_combined_reference

@@ -388,11 +388,6 @@ def expand(self, pcoll: beam.PCollection) -> beam.PCollection:
         )
 
 
-# TODO
-# - consolidate coords
-# - consolidate metadata
-
-
 @dataclass
 class Rechunk(beam.PTransform):
     target_chunks: Optional[Dict[str, int]]
@@ -412,6 +407,25 @@ def expand(self, pcoll: beam.PCollection) -> beam.PCollection:
         return new_fragments
 
 
+def _consolidate_zarr_metadata(store: zarr.storage.FSStore) -> zarr.storage.FSStore:
+    """Consolidate zarr metadata, passing the zarr store through once complete."""
+    zarr.consolidate_metadata(store)
+    return store
+
+
+class ConsolidateZarrMetadata(beam.PTransform):
+    def expand(
+        self,
+        pcoll: beam.PCollection[zarr.storage.FSStore],
+    ) -> beam.PCollection[zarr.storage.FSStore]:
+        return pcoll | beam.Map(_consolidate_zarr_metadata)
+
+
+class ConsolidateDimensionCoordinates(beam.PTransform):
+    def expand(self, pcoll: beam.PCollection[zarr.storage.FSStore]) -> beam.PCollection:
+        return pcoll | beam.Map(consolidate_dimension_coordinates)
+
+
 @dataclass
 class CombineReferences(beam.PTransform):
     """Combines Kerchunk references into a single reference dataset.
@@ -561,10 +575,22 @@ def expand(self, references: beam.PCollection) -> beam.PCollection[zarr.storage.FSStore]:
         )
 
 
+class SampleSingleton(beam.PTransform):
+    """Sample a single value from an input PCollection of any size and emit it
+    as a singleton PCollection.
+    """
+
+    def expand(self, pcoll: beam.PCollection) -> beam.PCollection:
+        return (
+            pcoll
+            | beam.combiners.Sample.FixedSizeGlobally(1)
+            | beam.FlatMap(lambda x: x)  # https://stackoverflow.com/a/47146582
+        )
+
+
 @dataclass
 class StoreToZarr(beam.PTransform, ZarrWriterMixin):
     """Store a PCollection of Xarray datasets to Zarr.
-
     :param combine_dims: The dimensions to combine
     :param store_name: Name for the Zarr store. It will be created with this name
         under `target_root`.
@@ -572,6 +598,11 @@ class StoreToZarr(beam.PTransform, ZarrWriterMixin):
         `store_name` will be appended to this prefix to create a full path.
     :param target_chunks: Dictionary mapping dimension names to chunks sizes.
         If a dimension is a not named, the chunks will be inferred from the data.
+    :param consolidate_coords: Whether to rewrite each dimension coordinate variable as a
+        single chunk. We recommend consolidating coordinate variables to avoid
+        many small read requests to get the coordinates in xarray. Defaults to ``True``.
+    :param consolidate_metadata: Whether to consolidate metadata in the resulting
+        Zarr dataset. Defaults to ``True``.
     :param dynamic_chunking_fn: Optionally provide a function that takes an ``xarray.Dataset``
         template dataset as its first argument and returns a dynamically generated chunking dict.
        If provided, ``target_chunks`` cannot also be passed. You can use this to determine chunking
@@ -590,6 +621,8 @@ class StoreToZarr(beam.PTransform, ZarrWriterMixin):
         default_factory=RequiredAtRuntimeDefault
     )
     target_chunks: Dict[str, int] = field(default_factory=dict)
+    consolidate_coords: bool = True
+    consolidate_metadata: bool = True
     dynamic_chunking_fn: Optional[Callable[[xr.Dataset], dict]] = None
     dynamic_chunking_fn_kwargs: Optional[dict] = field(default_factory=dict)
     attrs: Dict[str, str] = field(default_factory=dict)
@@ -620,11 +653,18 @@ def expand(
             attrs=self.attrs,
         )
         n_target_stores = rechunked_datasets | StoreDatasetFragments(target_store=target_store)
+
         singleton_target_store = (
-            n_target_stores
-            | beam.combiners.Sample.FixedSizeGlobally(1)
-            | beam.FlatMap(lambda x: x)  # https://stackoverflow.com/a/47146582
+            n_target_stores | SampleSingleton()
+            if not self.consolidate_coords
+            else n_target_stores | SampleSingleton() | ConsolidateDimensionCoordinates()
         )
+        return (
+            singleton_target_store
+            if not self.consolidate_metadata
+            else singleton_target_store | ConsolidateZarrMetadata()
+        )
+
         # TODO: optionally use `singleton_target_store` to
         # consolidate metadata and/or coordinate dims here
diff --git a/pangeo_forge_recipes/writers.py b/pangeo_forge_recipes/writers.py
index 3ca21ec5..0f43bbf5 100644
--- a/pangeo_forge_recipes/writers.py
+++ b/pangeo_forge_recipes/writers.py
@@ -90,7 +90,6 @@ def store_dataset_fragment(
         _store_data(vname, da.variable, index, zgroup)
     for vname, da in ds.data_vars.items():
         _store_data(vname, da.variable, index, zgroup)
-
     return target_store
diff --git a/tests/test_end_to_end.py b/tests/test_end_to_end.py
index 38c4a0de..25201424 100644
--- a/tests/test_end_to_end.py
+++ b/tests/test_end_to_end.py
@@ -8,6 +8,7 @@
 import numpy as np
 import pytest
 import xarray as xr
+import zarr
 from apache_beam.options.pipeline_options import PipelineOptions
 from apache_beam.testing.test_pipeline import TestPipeline
 from fsspec.implementations.reference import ReferenceFileSystem
@@ -58,11 +59,13 @@ def test_xarray_zarr(
     xr.testing.assert_equal(ds.load(), daily_xarray_dataset)
 
 
-def test_xarray_zarr_subpath(
+@pytest.mark.parametrize("consolidate_coords", [True, False])
+def test_xarray_zarr_consolidate_coords(
     daily_xarray_dataset,
     netcdf_local_file_pattern_sequential,
     pipeline,
     tmp_target_url,
+    consolidate_coords,
 ):
     pattern = netcdf_local_file_pattern_sequential
     with pipeline as p:
@@ -74,11 +77,16 @@ def test_xarray_zarr_subpath(
                 target_root=tmp_target_url,
                 store_name="subpath",
                 combine_dims=pattern.combine_dim_keys,
+                consolidate_coords=consolidate_coords,
             )
         )
+    # TODO: This test needs to check if the consolidate_coords transform
+    # within StoreToZarr is consolidating the chunks of the coordinates
 
-    ds = xr.open_dataset(os.path.join(tmp_target_url, "subpath"), engine="zarr")
-    xr.testing.assert_equal(ds.load(), daily_xarray_dataset)
+    store = zarr.open(os.path.join(tmp_target_url, "subpath"))
+
+    # fails
+    assert netcdf_local_file_pattern_sequential.dims["time"] == store.time.chunks[0]
 
 
 @pytest.mark.parametrize("output_file_name", ["reference.json", "reference.parquet"])
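
Usage sketch (not part of the diff): a minimal pipeline exercising the two new StoreToZarr options. OpenURLWithFSSpec and OpenWithXarray are existing pangeo_forge_recipes transforms; `pattern` and `tmp_target_url` are stand-ins for a FilePattern and a writable target root, as in the test fixtures above.

    import apache_beam as beam

    from pangeo_forge_recipes.transforms import OpenURLWithFSSpec, OpenWithXarray, StoreToZarr

    # `pattern` is an existing FilePattern; `tmp_target_url` is a writable root path (assumed).
    with beam.Pipeline() as p:
        (
            p
            | beam.Create(pattern.items())
            | OpenURLWithFSSpec()
            | OpenWithXarray(file_type=pattern.file_type)
            | StoreToZarr(
                target_root=tmp_target_url,
                store_name="example",
                combine_dims=pattern.combine_dim_keys,
                consolidate_coords=True,    # rewrite dimension coordinates as single chunks
                consolidate_metadata=True,  # write consolidated zarr metadata
            )
        )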
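And a hypothetical sketch of the check the test TODO describes: open the resulting store directly with zarr and assert that each dimension coordinate was rewritten as a single chunk. The `tmp_target_url` and "subpath" names come from the test above.

    import os

    import zarr

    group = zarr.open_group(os.path.join(tmp_target_url, "subpath"), mode="r")
    time_coord = group["time"]
    # after consolidate_coords, the coordinate should be stored as one chunk
    assert time_coord.chunks == time_coord.shape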