-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: extended pystac client to support aggregations stac-api extensi…
- Loading branch information
Showing
5 changed files
with
247 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,119 @@ | ||
"""Test Advanced PySTAC client.""" | ||
import json | ||
import os | ||
from unittest.mock import MagicMock, patch | ||
|
||
import pytest | ||
|
||
from titiler.pystac import AdvancedClient | ||
|
||
catalog_json = os.path.join(os.path.dirname(__file__), "fixtures", "catalog.json") | ||
|
||
|
||
@pytest.fixture | ||
def mock_stac_io(): | ||
"""STAC IO mock""" | ||
return MagicMock() | ||
|
||
|
||
@pytest.fixture | ||
def client(mock_stac_io): | ||
"""STAC client mock""" | ||
client = AdvancedClient(id="pystac-client", description="pystac-client") | ||
|
||
with open(catalog_json, "r") as f: | ||
catalog = json.loads(f.read()) | ||
client.open = MagicMock() | ||
client.open.return_value = catalog | ||
client._collections_href = MagicMock() | ||
client._collections_href.return_value = "http://example.com/collections" | ||
|
||
client._stac_io = mock_stac_io | ||
return client | ||
|
||
|
||
def test_get_supported_aggregations(client, mock_stac_io): | ||
"""Test supported STAC aggregation methods""" | ||
mock_stac_io.read_json.return_value = { | ||
"aggregations": [{"name": "aggregation1"}, {"name": "aggregation2"}] | ||
} | ||
supported_aggregations = client.get_supported_aggregations() | ||
assert supported_aggregations == ["aggregation1", "aggregation2"] | ||
|
||
|
||
@patch( | ||
"titiler.pystac.advanced_client.AdvancedClient.get_supported_aggregations", | ||
return_value=["datetime_frequency"], | ||
) | ||
def test_get_aggregation_unsupported(supported_aggregations, client): | ||
"""Test handling of unsupported aggregation types""" | ||
collection_id = "sentinel-2-l2a" | ||
aggregation = "unsupported-aggregation" | ||
|
||
with pytest.warns( | ||
UserWarning, match="Aggregation type unsupported-aggregation is not supported" | ||
): | ||
aggregation_data = client.get_aggregation(collection_id, aggregation) | ||
assert aggregation_data == [] | ||
|
||
|
||
@patch( | ||
"titiler.pystac.advanced_client.AdvancedClient.get_supported_aggregations", | ||
return_value=["datetime_frequency"], | ||
) | ||
def test_get_aggregation(supported_aggregations, client, mock_stac_io): | ||
"""Test handling aggregation response""" | ||
collection_id = "sentinel-2-l2a" | ||
aggregation = "datetime_frequency" | ||
aggregation_params = {"datetime_frequency_interval": "day"} | ||
|
||
mock_stac_io.read_json.return_value = { | ||
"aggregations": [ | ||
{ | ||
"name": "datetime_frequency", | ||
"buckets": [ | ||
{ | ||
"key": "2023-12-11T00:00:00.000Z", | ||
"data_type": "frequency_distribution", | ||
"frequency": 1, | ||
"to": None, | ||
"from": None, | ||
} | ||
], | ||
}, | ||
{ | ||
"name": "unusable_aggregation", | ||
"buckets": [ | ||
{ | ||
"key": "2023-12-11T00:00:00.000Z", | ||
} | ||
], | ||
}, | ||
] | ||
} | ||
|
||
aggregation_data = client.get_aggregation( | ||
collection_id, aggregation, aggregation_params | ||
) | ||
assert aggregation_data[0]["key"] == "2023-12-11T00:00:00.000Z" | ||
assert aggregation_data[0]["data_type"] == "frequency_distribution" | ||
assert aggregation_data[0]["frequency"] == 1 | ||
assert len(aggregation_data) == 1 | ||
|
||
|
||
@patch( | ||
"titiler.pystac.advanced_client.AdvancedClient.get_supported_aggregations", | ||
return_value=["datetime_frequency"], | ||
) | ||
def test_get_aggregation_no_response(supported_aggregations, client, mock_stac_io): | ||
"""Test handling of no aggregation response""" | ||
collection_id = "sentinel-2-l2a" | ||
aggregation = "datetime_frequency" | ||
aggregation_params = {"datetime_frequency_interval": "day"} | ||
|
||
mock_stac_io.read_json.return_value = [] | ||
|
||
aggregation_data = client.get_aggregation( | ||
collection_id, aggregation, aggregation_params | ||
) | ||
assert aggregation_data == [] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
"""titiler.pystac""" | ||
|
||
__all__ = [ | ||
"AdvancedClient", | ||
] | ||
|
||
from titiler.pystac.advanced_client import AdvancedClient |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
""" | ||
This module provides an advanced client for interacting with STAC (SpatioTemporal Asset Catalog) APIs. | ||
The `AdvancedClient` class extends the basic functionality of the `pystac.Client` to include | ||
methods for retrieving and aggregating data from STAC collections. | ||
""" | ||
|
||
import warnings | ||
from typing import Optional | ||
from urllib.parse import urlencode | ||
|
||
import pystac | ||
from pystac_client import Client | ||
|
||
|
||
class AdvancedClient(Client): | ||
"""AdvancedClient extends the basic functionality of the pystac.Client class.""" | ||
|
||
def get_aggregation( | ||
self, | ||
collection_id: str, | ||
aggregation: str, | ||
aggregation_params: Optional[dict] = None, | ||
) -> list[dict]: | ||
"""Perform an aggregation on a STAC collection. | ||
Args: | ||
collection_id (str): The ID of the collection to aggregate. | ||
aggregation (str): The aggregation type to perform. | ||
aggregation_params (Optional[dict], optional): Additional parameters for the aggregation. Defaults to None. | ||
Returns: | ||
List[str]: The aggregation response. | ||
""" | ||
assert self._stac_io is not None | ||
|
||
if aggregation not in self.get_supported_aggregations(): | ||
warnings.warn( | ||
f"Aggregation type {aggregation} is not supported", stacklevel=1 | ||
) | ||
return [] | ||
|
||
# Construct the URL for aggregation | ||
url = ( | ||
self._collections_href(collection_id) | ||
+ f"/aggregate?aggregations={aggregation}" | ||
) | ||
if aggregation_params: | ||
params = urlencode(aggregation_params) | ||
url += f"&{params}" | ||
|
||
aggregation_response = self._stac_io.read_json(url) | ||
|
||
if not aggregation_response: | ||
return [] | ||
|
||
aggregation_data = [] | ||
for agg in aggregation_response["aggregations"]: | ||
if agg["name"] == aggregation: | ||
aggregation_data = agg["buckets"] | ||
|
||
return aggregation_data | ||
|
||
def get_supported_aggregations(self) -> list[str]: | ||
"""Get the supported aggregation types. | ||
Returns: | ||
List[str]: The supported aggregations. | ||
""" | ||
response = self._stac_io.read_json(self.get_aggregations_link()) | ||
aggregations = response.get("aggregations", []) | ||
return [agg["name"] for agg in aggregations] | ||
|
||
def get_aggregations_link(self) -> Optional[pystac.Link]: | ||
"""Returns this client's aggregations link. | ||
Returns: | ||
Optional[pystac.Link]: The aggregations link, or None if there is not one found. | ||
""" | ||
return next( | ||
( | ||
link | ||
for link in self.links | ||
if link.rel == "aggregations" | ||
and link.media_type == pystac.MediaType.JSON | ||
), | ||
None, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters