Skip to content

Commit

Permalink
BUG:merge: Fix merging masked & scaled data (#817)
Browse files Browse the repository at this point in the history
  • Loading branch information
snowman2 authored Nov 1, 2024
1 parent fca24e3 commit 3ca29fb
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 48 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ jobs:
gdal-version: ['3.8.2']
include:
- python-version: '3.10'
rasterio-version: ''
rasterio-version: '==1.3'
xarray-version: '==2024.7.0'
numpy-version: '<2'
run-with-scipy: 'YES'
Expand Down
1 change: 1 addition & 0 deletions docs/history.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ History

Latest
------
- BUG:merge: Fix merging masked and scaled data (issue #814)

0.17.0
------
Expand Down
2 changes: 1 addition & 1 deletion rioxarray/_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ def __init__(
)
if self._unsigned_dtype is not None and self._fill_value is not None:
self._fill_value = self._unsigned_dtype.type(self._fill_value)
if self._unsigned_dtype is None and dtype.kind not in ("i", "u"):
if self._unsigned_dtype is None:
warnings.warn(
f"variable {name!r} has _Unsigned attribute but is not "
"of integer type. Ignoring attribute.",
Expand Down
83 changes: 45 additions & 38 deletions rioxarray/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
This module allows you to merge xarray Datasets/DataArrays
geospatially with the `rasterio.merge` module.
"""

from collections.abc import Sequence
from typing import Callable, Optional, Union

Expand All @@ -12,6 +11,7 @@
from rasterio.merge import merge as _rio_merge
from xarray import DataArray, Dataset

from rioxarray._io import open_rasterio
from rioxarray.rioxarray import _get_nonspatial_coords, _make_coords


Expand All @@ -31,13 +31,25 @@ def __init__(self, xds: DataArray):
self.count = int(xds.rio.count)
self.dtypes = [xds.dtype]
self.name = xds.name
self.nodatavals = [xds.rio.nodata]
if xds.rio.encoded_nodata is not None:
self.nodatavals = [xds.rio.encoded_nodata]
else:
self.nodatavals = [xds.rio.nodata]
res = xds.rio.resolution(recalc=True)
self.res = (abs(res[0]), abs(res[1]))
self.transform = xds.rio.transform(recalc=True)
# profile is only used for writing to a file.
# This never happens with rioxarray merge.
self.profile: dict = {}
self.profile: dict = {
"crs": self.crs,
"nodata": self.nodatavals[0],
}
self._scale_factor = self._xds.encoding.get("scale_factor", 1.0)
self._add_offset = self._xds.encoding.get("add_offset", 0.0)
self._mask_and_scale = (
self._xds.rio.encoded_nodata is not None
or self._scale_factor != 1
or self._add_offset != 0
or self._xds.encoding.get("_Unsigned") is not None
)

def colormap(self, *args, **kwargs) -> None:
"""
Expand All @@ -54,7 +66,15 @@ def read(self, *args, **kwargs) -> numpy.ma.MaskedArray:
with MemoryFile() as memfile:
self._xds.rio.to_raster(memfile.name)
with memfile.open() as dataset:
return dataset.read(*args, **kwargs)
if self._mask_and_scale:
kwargs["masked"] = True
out = dataset.read(*args, **kwargs)
if self._mask_and_scale:
if self._scale_factor != 1:
out = out * self._scale_factor
if self._add_offset != 0:
out = out + self._add_offset
return out


def merge_arrays(
Expand Down Expand Up @@ -132,41 +152,28 @@ def merge_arrays(
rioduckarrays.append(RasterioDatasetDuck(dataarray))

# use rasterio to merge
merged_data, merged_transform = _rio_merge(
rioduckarrays,
**{key: val for key, val in input_kwargs.items() if val is not None},
)
# generate merged data array
representative_array = rioduckarrays[0]._xds
if parse_coordinates:
coords = _make_coords(
src_data_array=representative_array,
dst_affine=merged_transform,
dst_width=merged_data.shape[-1],
dst_height=merged_data.shape[-2],
with MemoryFile() as memfile:
_rio_merge(
rioduckarrays,
**{key: val for key, val in input_kwargs.items() if val is not None},
dst_path=memfile.name,
)
else:
coords = _get_nonspatial_coords(representative_array)

# Make sure the merged output data shape is 2D if the
# original data was 2D. This can happen if the
# xarray DataArray was squeezed.
if len(merged_data.shape) == 3 and len(representative_array.shape) == 2:
merged_data = merged_data.squeeze()

xda = DataArray(
name=representative_array.name,
data=merged_data,
coords=coords,
dims=tuple(representative_array.dims),
attrs=representative_array.attrs,
)
xda.rio.write_nodata(
nodata if nodata is not None else representative_array.rio.nodata, inplace=True
)
xda.rio.write_crs(representative_array.rio.crs, inplace=True)
xda.rio.write_transform(merged_transform, inplace=True)
return xda
with open_rasterio( # type: ignore
memfile.name,
parse_coordinates=parse_coordinates,
mask_and_scale=rioduckarrays[0]._mask_and_scale,
) as xda:
xda = xda.load()
xda.coords.update(
{
coord: value
for coord, value in _get_nonspatial_coords(representative_array).items()
if coord not in xda.coords
}
)
return xda # type: ignore


def merge_datasets(
Expand Down
37 changes: 29 additions & 8 deletions test/integration/test_integration_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,12 @@ def test_merge_arrays(squeeze):
assert merged.coords["band"].values == [1]
assert sorted(merged.coords) == ["band", "spatial_ref", "x", "y"]
assert merged.rio.crs == rds.rio.crs
assert merged.attrs == rds.attrs
assert merged.attrs == {
"AREA_OR_POINT": "Area",
"add_offset": 0.0,
"scale_factor": 1.0,
**rds.attrs,
}
assert merged.encoding["grid_mapping"] == "spatial_ref"


Expand Down Expand Up @@ -106,6 +111,7 @@ def test_merge__different_crs(dataset):
assert merged.rio.crs == rds.rio.crs
if not dataset:
assert merged.attrs == {
"AREA_OR_POINT": "Area",
"_FillValue": -28672,
"add_offset": 0.0,
"scale_factor": 1.0,
Expand All @@ -116,9 +122,6 @@ def test_merge__different_crs(dataset):
def test_merge_arrays__res():
dem_test = os.path.join(TEST_INPUT_DATA_DIR, "MODIS_ARRAY.nc")
with open_rasterio(dem_test, masked=True) as rds:
rds.attrs = {
"_FillValue": rds.rio.nodata,
}
arrays = [
rds.isel(x=slice(100), y=slice(100)),
rds.isel(x=slice(100, 200), y=slice(100, 200)),
Expand Down Expand Up @@ -151,9 +154,8 @@ def test_merge_arrays__res():
assert merged.coords["band"].values == [1]
assert sorted(merged.coords) == ["band", "spatial_ref", "x", "y"]
assert merged.rio.crs == rds.rio.crs
assert_almost_equal(merged.attrs.pop("_FillValue"), rds.attrs.pop("_FillValue"))
compare_attrs = dict(rds.attrs)
assert merged.attrs == compare_attrs
assert_almost_equal(merged.rio.nodata, rds.rio.nodata)
assert_almost_equal(merged.rio.encoded_nodata, rds.rio.encoded_nodata)
assert merged.encoding["grid_mapping"] == "spatial_ref"
assert_almost_equal(nansum(merged), 13760565)

Expand Down Expand Up @@ -191,7 +193,6 @@ def test_merge_datasets():
(-4447802.078667, -10007554.677, -3335851.559, -8895604.157333),
)
assert merged.rio.shape == (2400, 2400)
assert_almost_equal(merged[data_var].sum(), 4539666606551516)
assert_almost_equal(
tuple(merged[data_var].rio.transform()),
(
Expand All @@ -211,6 +212,7 @@ def test_merge_datasets():
assert merged.rio.crs == rds.rio.crs
assert merged.attrs == rds.attrs
assert merged.encoding["grid_mapping"] == "spatial_ref"
assert_almost_equal(merged[data_var].sum(), 4539666606551516)


@pytest.mark.xfail(os.name == "nt", reason="On windows the merged data is different.")
Expand Down Expand Up @@ -266,3 +268,22 @@ def test_merge_datasets__res():
assert merged.attrs == rds.attrs
assert merged.encoding["grid_mapping"] == "spatial_ref"
assert_almost_equal(merged[data_var].sum(), 974566547463955)


@pytest.mark.parametrize("mask_and_scale", [True, False])
def test_merge_datasets__mask_and_scale(mask_and_scale):
    """Merge four tiles of ``tmmx_20190121.nc`` and check the data sum.

    The file is opened with the parametrized ``mask_and_scale`` flag,
    converted to a Dataset, and cut into four tiles by splitting both the
    ``x`` and ``y`` axes at index 100. The tiles are passed to
    ``merge_datasets`` and the sum of ``air_temperature`` in the result is
    compared with the expected total for the given flag.
    """
    source_path = os.path.join(TEST_INPUT_DATA_DIR, "tmmx_20190121.nc")
    # The two index windows used on each axis: the first 100 cells and the rest.
    head, tail = slice(100), slice(100, None)
    with open_rasterio(source_path, mask_and_scale=mask_and_scale) as opened:
        full = opened.to_dataset()
        # Tile order is kept exactly as passed to the merge originally.
        tiles = [
            full.isel(x=x_window, y=y_window)
            for x_window, y_window in (
                (head, head),
                (tail, tail),
                (head, tail),
                (tail, head),
            )
        ]
        merged = merge_datasets(tiles)
    # Expected totals differ because the flag changes how values are decoded.
    expected_total = 133376696 if mask_and_scale else 10981781386
    assert_almost_equal(merged.air_temperature.sum(), expected_total)

0 comments on commit 3ca29fb

Please sign in to comment.