Skip to content

Commit

Permalink
replace ImageData.render with custom render function (#690)
Browse files Browse the repository at this point in the history
  • Loading branch information
vincentsarago authored Aug 30, 2023
1 parent f894c88 commit 7c9f899
Show file tree
Hide file tree
Showing 10 changed files with 309 additions and 103 deletions.
29 changes: 29 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,35 @@

* replace `buffer` and `color_formula` endpoint parameters by external dependencies (`BufferParams` and `ColorFormulaParams`)

* add `titiler.core.utils.render_image` which allow non-binary alpha band created with custom colormap. `render_image` replace `ImageData.render` method.

```python
# before
if cmap := colormap or dst_colormap:
image = image.apply_colormap(cmap)

if not format:
format = ImageType.jpeg if image.mask.all() else ImageType.png

content = image.render(
img_format=format.driver,
**format.profile,
**render_params,
)

# now
# render_image will:
# - apply the colormap
# - choose the right output format if `None`
# - create the binary data
content, media_type = render_image(
image,
output_format=format,
colormap=colormap or dst_colormap,
**render_params,
)
```

### titiler.extension

* rename `geom-densify-pts` to `geometry_densify` **breaking change**
Expand Down
5 changes: 5 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -127,3 +127,8 @@ no_implicit_optional = true
strict_optional = true
namespace_packages = true
explicit_package_bases = true

[tool.pytest.ini_options]
filterwarnings = [
"ignore::rasterio.errors.NotGeoreferencedWarning",
]
14 changes: 3 additions & 11 deletions src/titiler/core/tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
"""``pytest`` configuration."""

import os
import warnings
from typing import Any, Dict

import pytest
import rasterio
from rasterio.errors import NotGeoreferencedWarning
from rasterio.io import MemoryFile

DATA_DIR = os.path.join(os.path.dirname(__file__), "fixtures")
Expand All @@ -25,15 +23,9 @@ def set_env(monkeypatch):

def parse_img(content: bytes) -> Dict[Any, Any]:
"""Read tile image and return metadata."""
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
category=NotGeoreferencedWarning,
module="rasterio",
)
with MemoryFile(content) as mem:
with mem.open() as dst:
return dst.profile
with MemoryFile(content) as mem:
with mem.open() as dst:
return dst.profile


def mock_rasterio_open(asset):
Expand Down
28 changes: 6 additions & 22 deletions src/titiler/core/tests/test_algorithms.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
"""Test the Algorithms class."""

import json
import warnings

import numpy
from fastapi import Depends, FastAPI
from rasterio.errors import NotGeoreferencedWarning
from rasterio.io import MemoryFile
from rio_tiler.models import ImageData
from starlette.responses import Response
Expand Down Expand Up @@ -89,16 +87,9 @@ def main(algorithm=Depends(default_algorithms.dependency)):
# MAPBOX Terrain RGB
response = client.get("/", params={"algorithm": "terrainrgb"})
assert response.status_code == 200

with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
category=NotGeoreferencedWarning,
module="rasterio",
)
with MemoryFile(response.content) as mem:
with mem.open() as dst:
data = dst.read().astype(numpy.float64)
with MemoryFile(response.content) as mem:
with mem.open() as dst:
data = dst.read().astype(numpy.float64)

# https://docs.mapbox.com/data/tilesets/guides/access-elevation-data/
elevation = -10000 + (((data[0] * 256 * 256) + (data[1] * 256) + data[2]) * 0.1)
Expand All @@ -107,16 +98,9 @@ def main(algorithm=Depends(default_algorithms.dependency)):
# TILEZEN Terrarium
response = client.get("/", params={"algorithm": "terrarium"})
assert response.status_code == 200

with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
category=NotGeoreferencedWarning,
module="rasterio",
)
with MemoryFile(response.content) as mem:
with mem.open() as dst:
data = dst.read().astype(numpy.float64)
with MemoryFile(response.content) as mem:
with mem.open() as dst:
data = dst.read().astype(numpy.float64)

# https://github.com/tilezen/joerd/blob/master/docs/formats.md#terrarium
elevation = (data[0] * 256 + data[1] + data[2] / 256) - 32768
Expand Down
28 changes: 10 additions & 18 deletions src/titiler/core/tests/test_factories.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from fastapi import Depends, FastAPI, HTTPException, Path, Query, security, status
from morecantile.defaults import TileMatrixSets
from rasterio.crs import CRS
from rasterio.errors import NotGeoreferencedWarning
from rasterio.io import MemoryFile
from rio_tiler.errors import NoOverviewWarning
from rio_tiler.io import BaseReader, MultiBandReader, Reader, STACReader
Expand Down Expand Up @@ -1540,23 +1539,16 @@ def test_AutoFormat_Colormap():
response = client.get(f"/preview?url={DATA_DIR}/cog.tif&bidx=1&{cmap}")
assert response.status_code == 200
assert response.headers["content-type"] == "image/png"

with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
category=NotGeoreferencedWarning,
module="rasterio",
)
with MemoryFile(response.content) as mem:
with mem.open() as dst:
img = dst.read()
assert img[:, 0, 0].tolist() == [
0,
0,
0,
0,
] # when creating a PNG, GDAL will set masked value to 0
assert img[:, 500, 500].tolist() == [255, 0, 0, 255]
with MemoryFile(response.content) as mem:
with mem.open() as dst:
img = dst.read()
assert img[:, 0, 0].tolist() == [
0,
0,
0,
0,
] # when creating a PNG, GDAL will set masked value to 0
assert img[:, 500, 500].tolist() == [255, 0, 0, 255]


def test_rescale_dependency():
Expand Down
105 changes: 105 additions & 0 deletions src/titiler/core/tests/test_rendering.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
"""test titiler rendering function."""

import warnings

import numpy
import pytest
from rasterio.io import MemoryFile
from rio_tiler.errors import InvalidDatatypeWarning
from rio_tiler.models import ImageData

from titiler.core.resources.enums import ImageType
from titiler.core.utils import render_image


def test_rendering():
"""test rendering."""
im = ImageData(numpy.zeros((1, 256, 256), dtype="uint8"))

# Should render as JPEG
content, media = render_image(im)
assert media == "image/jpeg"
with MemoryFile(content) as mem:
with mem.open() as dst:
assert dst.profile["driver"] == "JPEG"
assert dst.count == 1
assert dst.width == 256
assert dst.height == 256
arr = dst.read()
assert numpy.unique(arr).tolist() == [0]

# Should render as PNG
content, media = render_image(im, output_format=ImageType.png)
assert media == "image/png"
with MemoryFile(content) as mem:
with mem.open() as dst:
assert dst.profile["driver"] == "PNG"
assert dst.count == 2
arr = dst.read()
assert numpy.unique(arr[0]).tolist() == [0]

with pytest.warns(InvalidDatatypeWarning):
_, media = render_image(
ImageData(numpy.zeros((1, 256, 256), dtype="uint16")),
output_format=ImageType.jpeg,
)
assert media == "image/jpeg"

with pytest.warns(InvalidDatatypeWarning):
_, media = render_image(
ImageData(numpy.zeros((1, 256, 256), dtype="float32")),
output_format=ImageType.png,
)
assert media == "image/png"

with pytest.warns(InvalidDatatypeWarning):
_, media = render_image(
ImageData(numpy.zeros((1, 256, 256), dtype="float32")),
output_format=ImageType.jp2,
)
assert media == "image/jp2"

# Make sure that we do not rescale uint16 data when there is a colormap
# Because the colormap will result in data between 0 and 255 it should be of type uint8
with warnings.catch_warnings():
warnings.simplefilter("error")
cm = {1: (0, 0, 0, 255), 1000: (255, 255, 255, 255)}
d = numpy.zeros((1, 256, 256), dtype="float32") + 1
d[0, 0:10, 0:10] = 1000
content, media = render_image(
ImageData(d),
output_format=ImageType.jpeg,
colormap=cm,
)
assert media == "image/jpeg"

with MemoryFile(content) as mem:
with mem.open() as dst:
assert dst.count == 3
assert dst.dtypes == ("uint8", "uint8", "uint8")
assert dst.read()[:, 0, 0].tolist() == [255, 255, 255]
assert dst.read()[:, 11, 11].tolist() == [0, 0, 0]

# Partial alpha values
cm = {
1: (0, 0, 0, 0),
500: (100, 100, 100, 50),
1000: (255, 255, 255, 255),
}
d = numpy.ma.zeros((1, 256, 256), dtype="float32") + 1
d[0, 0:10, 0:10] = 500
d[0, 10:20, 10:20] = 1000
content, media = render_image(
ImageData(d),
output_format=ImageType.png,
colormap=cm,
)
assert media == "image/png"

with MemoryFile(content) as mem:
with mem.open() as dst:
assert dst.count == 4
assert dst.dtypes == ("uint8", "uint8", "uint8", "uint8")
assert dst.read()[:, 0, 0].tolist() == [100, 100, 100, 50]
assert dst.read()[:, 11, 11].tolist() == [255, 255, 255, 255]
assert dst.read()[:, 30, 30].tolist() == [0, 0, 0, 0]
58 changes: 21 additions & 37 deletions src/titiler/core/titiler/core/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@
from titiler.core.resources.enums import ImageType, MediaType, OptionalHeader
from titiler.core.resources.responses import GeoJSONResponse, JSONResponse, XMLResponse
from titiler.core.routing import EndpointScope
from titiler.core.utils import render_image

DEFAULT_TEMPLATES = Jinja2Templates(
directory="",
Expand Down Expand Up @@ -556,19 +557,14 @@ def tile(
if color_formula:
image.apply_color_formula(color_formula)

if cmap := colormap or dst_colormap:
image = image.apply_colormap(cmap)

if not format:
format = ImageType.jpeg if image.mask.all() else ImageType.png

content = image.render(
img_format=format.driver,
**format.profile,
content, media_type = render_image(
image,
output_format=format,
colormap=colormap or dst_colormap,
**render_params,
)

return Response(content, media_type=format.mediatype)
return Response(content, media_type=media_type)

def tilejson(self): # noqa: C901
"""Register /tilejson.json endpoint."""
Expand Down Expand Up @@ -916,19 +912,14 @@ def preview(
if color_formula:
image.apply_color_formula(color_formula)

if cmap := colormap or dst_colormap:
image = image.apply_colormap(cmap)

if not format:
format = ImageType.jpeg if image.mask.all() else ImageType.png

content = image.render(
img_format=format.driver,
**format.profile,
content, media_type = render_image(
image,
output_format=format,
colormap=colormap or dst_colormap,
**render_params,
)

return Response(content, media_type=format.mediatype)
return Response(content, media_type=media_type)

############################################################################
# /crop (Optional)
Expand Down Expand Up @@ -990,16 +981,14 @@ def part(
if color_formula:
image.apply_color_formula(color_formula)

if cmap := colormap or dst_colormap:
image = image.apply_colormap(cmap)

content = image.render(
img_format=format.driver,
**format.profile,
content, media_type = render_image(
image,
output_format=format,
colormap=colormap or dst_colormap,
**render_params,
)

return Response(content, media_type=format.mediatype)
return Response(content, media_type=media_type)

# POST endpoints
@self.router.post(
Expand Down Expand Up @@ -1054,19 +1043,14 @@ def geojson_crop(
if color_formula:
image.apply_color_formula(color_formula)

if cmap := colormap or dst_colormap:
image = image.apply_colormap(cmap)

if not format:
format = ImageType.jpeg if image.mask.all() else ImageType.png

content = image.render(
img_format=format.driver,
**format.profile,
content, media_type = render_image(
image,
output_format=format,
colormap=colormap or dst_colormap,
**render_params,
)

return Response(content, media_type=format.mediatype)
return Response(content, media_type=media_type)


@dataclass
Expand Down
Loading

0 comments on commit 7c9f899

Please sign in to comment.