Skip to content

Commit

Permalink
feat: extended pystac client to support aggregations stac-api extensi…
Browse files Browse the repository at this point in the history
  • Loading branch information
jverrydt committed Oct 31, 2024
1 parent 7214394 commit 51af1fe
Show file tree
Hide file tree
Showing 5 changed files with 247 additions and 2 deletions.
17 changes: 17 additions & 0 deletions tests/fixtures/catalog.json
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,23 @@
"rel": "self",
"type": "application/json",
"href": "https://stac.endpoint.io/collections"
},
{
"rel": "data",
"type": "application/json",
"href": "https://stac.endpoint.io/collections"
},
{
"rel": "aggregate",
"type": "application/json",
"title": "Aggregate",
"href": "https://stac.endpoint.io/aggregate"
},
{
"rel": "aggregations",
"type": "application/json",
"title": "Aggregations",
"href": "https://stac.endpoint.io/aggregations"
}
]
}
119 changes: 119 additions & 0 deletions tests/test_advanced_pystac_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
"""Test Advanced PySTAC client."""
import json
import os
from unittest.mock import MagicMock, patch

import pytest

from titiler.pystac import AdvancedClient

# Path to the catalog fixture stored in the "fixtures" directory next to this test module.
catalog_json = os.path.join(os.path.dirname(__file__), "fixtures", "catalog.json")


@pytest.fixture
def mock_stac_io():
    """Provide a ``MagicMock`` that stands in for the client's STAC IO object."""
    stac_io = MagicMock()
    return stac_io


@pytest.fixture
def client(mock_stac_io):
    """Build an ``AdvancedClient`` whose I/O-facing members are replaced by mocks."""
    stac_client = AdvancedClient(id="pystac-client", description="pystac-client")

    # Load the catalog fixture; it becomes the canned return value of ``open``.
    with open(catalog_json, "r") as fixture_file:
        catalog = json.load(fixture_file)

    stac_client.open = MagicMock(return_value=catalog)
    stac_client._collections_href = MagicMock(
        return_value="http://example.com/collections"
    )
    # Every read performed by the client goes through the shared IO mock.
    stac_client._stac_io = mock_stac_io
    return stac_client


def test_get_supported_aggregations(client, mock_stac_io):
    """The names listed in the aggregations response are returned in order."""
    expected_names = ["aggregation1", "aggregation2"]
    mock_stac_io.read_json.return_value = {
        "aggregations": [{"name": name} for name in expected_names]
    }

    assert client.get_supported_aggregations() == expected_names


@patch(
    "titiler.pystac.advanced_client.AdvancedClient.get_supported_aggregations",
    return_value=["datetime_frequency"],
)
def test_get_aggregation_unsupported(supported_aggregations, client):
    """An unsupported aggregation type emits a warning and yields no buckets."""
    expected_warning = "Aggregation type unsupported-aggregation is not supported"

    with pytest.warns(UserWarning, match=expected_warning):
        result = client.get_aggregation("sentinel-2-l2a", "unsupported-aggregation")

    assert result == []


@patch(
    "titiler.pystac.advanced_client.AdvancedClient.get_supported_aggregations",
    return_value=["datetime_frequency"],
)
def test_get_aggregation(supported_aggregations, client, mock_stac_io):
    """Only the buckets of the requested aggregation are returned."""
    timestamp = "2023-12-11T00:00:00.000Z"
    requested = {
        "name": "datetime_frequency",
        "buckets": [
            {
                "key": timestamp,
                "data_type": "frequency_distribution",
                "frequency": 1,
                "to": None,
                "from": None,
            }
        ],
    }
    # A second aggregation in the same response that must be ignored.
    ignored = {"name": "unusable_aggregation", "buckets": [{"key": timestamp}]}
    mock_stac_io.read_json.return_value = {"aggregations": [requested, ignored]}

    buckets = client.get_aggregation(
        "sentinel-2-l2a",
        "datetime_frequency",
        {"datetime_frequency_interval": "day"},
    )

    first = buckets[0]
    assert first["key"] == timestamp
    assert first["data_type"] == "frequency_distribution"
    assert first["frequency"] == 1
    assert len(buckets) == 1


@patch(
    "titiler.pystac.advanced_client.AdvancedClient.get_supported_aggregations",
    return_value=["datetime_frequency"],
)
def test_get_aggregation_no_response(supported_aggregations, client, mock_stac_io):
    """An empty aggregation response results in an empty list of buckets."""
    mock_stac_io.read_json.return_value = []

    result = client.get_aggregation(
        "sentinel-2-l2a",
        "datetime_frequency",
        {"datetime_frequency_interval": "day"},
    )

    assert result == []
7 changes: 7 additions & 0 deletions titiler/pystac/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
"""titiler.pystac"""

# Names re-exported as the public API of this package.
__all__ = [
    "AdvancedClient",
]

from titiler.pystac.advanced_client import AdvancedClient
87 changes: 87 additions & 0 deletions titiler/pystac/advanced_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
"""
This module provides an advanced client for interacting with STAC (SpatioTemporal Asset Catalog) APIs.
The `AdvancedClient` class extends `pystac_client.Client` with methods for listing the
aggregations a STAC API supports and for retrieving aggregation results for a collection.
"""

import warnings
from typing import Optional
from urllib.parse import urlencode

import pystac
from pystac_client import Client


class AdvancedClient(Client):
    """A ``pystac_client.Client`` with helpers for STAC API aggregations.

    Adds methods to discover the aggregations advertised through the
    catalog's ``aggregations`` link and to request one aggregation for a
    given collection.
    """

    def get_aggregation(
        self,
        collection_id: str,
        aggregation: str,
        aggregation_params: Optional[dict] = None,
    ) -> list[dict]:
        """Perform an aggregation on a STAC collection.

        Args:
            collection_id (str): The ID of the collection to aggregate.
            aggregation (str): The aggregation type to perform.
            aggregation_params (Optional[dict], optional): Additional query
                parameters for the aggregation. Defaults to None.

        Returns:
            list[dict]: The buckets of the requested aggregation. Empty when
            the aggregation type is not supported (a warning is emitted), when
            the response is empty, or when the response does not contain the
            requested aggregation.
        """
        # Narrows the Optional type of ``_stac_io`` for type checkers.
        assert self._stac_io is not None

        if aggregation not in self.get_supported_aggregations():
            # stacklevel=2 attributes the warning to the caller of this method
            # rather than to this line.
            warnings.warn(
                f"Aggregation type {aggregation} is not supported", stacklevel=2
            )
            return []

        # Build the whole query string through urlencode so that both the
        # aggregation name and the extra parameters are properly escaped.
        query = [("aggregations", aggregation)]
        if aggregation_params:
            query.extend(aggregation_params.items())
        url = f"{self._collections_href(collection_id)}/aggregate?{urlencode(query)}"

        aggregation_response = self._stac_io.read_json(url)
        if not aggregation_response:
            return []

        # Tolerate responses/entries with missing keys instead of raising
        # KeyError; as before, the last matching entry wins.
        aggregation_data: list[dict] = []
        for agg in aggregation_response.get("aggregations", []):
            if agg.get("name") == aggregation:
                aggregation_data = agg.get("buckets", [])

        return aggregation_data

    def get_supported_aggregations(self) -> list[str]:
        """Get the supported aggregation types.

        Reads the document behind the ``aggregations`` link and collects the
        ``name`` of every entry listed under its ``aggregations`` key.

        Returns:
            list[str]: The supported aggregation names (empty if the document
            has no ``aggregations`` key).
        """
        # Same type narrowing as in get_aggregation.
        assert self._stac_io is not None

        # NOTE(review): when no aggregations link exists, None is handed to
        # read_json unchanged -- confirm this is the intended behaviour.
        response = self._stac_io.read_json(self.get_aggregations_link())
        aggregations = response.get("aggregations", [])
        return [agg["name"] for agg in aggregations]

    def get_aggregations_link(self) -> Optional[pystac.Link]:
        """Return this client's aggregations link.

        Returns:
            Optional[pystac.Link]: The first link whose ``rel`` is
            ``"aggregations"`` and whose media type is JSON, or None if no
            such link is found.
        """
        return next(
            (
                link
                for link in self.links
                if link.rel == "aggregations"
                and link.media_type == pystac.MediaType.JSON
            ),
            None,
        )
19 changes: 17 additions & 2 deletions titiler/stacapi/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from morecantile import tms as morecantile_tms
from morecantile.defaults import TileMatrixSets
from pydantic import conint
from pystac_client import Client
from pystac_client.stac_api_io import StacApiIO
from rasterio.transform import xy as rowcol_to_coords
from rasterio.warp import transform as transform_points
Expand All @@ -45,6 +44,7 @@
from titiler.core.resources.responses import GeoJSONResponse, XMLResponse
from titiler.core.utils import render_image
from titiler.mosaic.factory import PixelSelectionParams
from titiler.pystac import AdvancedClient
from titiler.stacapi.backend import STACAPIBackend
from titiler.stacapi.dependencies import APIParams, STACApiParams, STACSearchParams
from titiler.stacapi.models import FeatureInfo, LayerDict
Expand Down Expand Up @@ -568,7 +568,7 @@ def get_layer_from_collections( # noqa: C901
),
headers=headers,
)
catalog = Client.open(url, stac_io=stac_api_io)
catalog = AdvancedClient.open(url, stac_io=stac_api_io)

layers: Dict[str, LayerDict] = {}
for collection in catalog.get_collections():
Expand All @@ -580,6 +580,7 @@ def get_layer_from_collections( # noqa: C901

tilematrixsets = render.pop("tilematrixsets", None)
output_format = render.pop("format", None)
aggregation = render.pop("aggregation", None)

_ = render.pop("minmax_zoom", None) # Not Used
_ = render.pop("title", None) # Not Used
Expand Down Expand Up @@ -643,6 +644,20 @@ def get_layer_from_collections( # noqa: C901
"values"
]
]
elif aggregation and aggregation["name"] == "datetime_frequency":
datetime_aggregation = catalog.get_aggregation(
collection_id=collection.id,
aggregation="datetime_frequency",
aggregation_params=aggregation["params"],
)
layer["time"] = [
python_datetime.datetime.strptime(
t["key"],
"%Y-%m-%dT%H:%M:%S.000Z",
).strftime("%Y-%m-%d")
for t in datetime_aggregation
if t["frequency"] > 0
]
elif intervals := temporal_extent.intervals:
start_date = intervals[0][0]
end_date = (
Expand Down

0 comments on commit 51af1fe

Please sign in to comment.