Skip to content

Commit

Permalink
Add option to control anti-aliasing (#5)
Browse files Browse the repository at this point in the history
  • Loading branch information
jonatanklosko authored Feb 19, 2024
1 parent 4dc88d5 commit c63f31e
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 7 deletions.
40 changes: 33 additions & 7 deletions lib/nx_image.ex
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,10 @@ defmodule NxImage do
`:bilinear`, `:bicubic`, `:lanczos3`, `:lanczos5`. Defaults to
`:bilinear`
* `:antialias` - whether an anti-aliasing filter should be used
when downsampling. This has no effect with upsampling. Defaults
to `true`
* `:channels` - channels location, either `:first` or `:last`.
Defaults to `:last`
Expand Down Expand Up @@ -148,7 +152,7 @@ defmodule NxImage do
"""
@doc type: :transformation
deftransform resize(input, size, opts \\ []) when is_tuple(size) do
opts = Keyword.validate!(opts, channels: :last, method: :bilinear)
opts = Keyword.validate!(opts, channels: :last, method: :bilinear, antialias: true)
validate_image!(input)

{spatial_axes, out_shape} =
Expand All @@ -159,22 +163,36 @@ defmodule NxImage do
{axis, put_elem(out_shape, axis, out_size)}
end)

antialias = opts[:antialias]

resized_input =
case opts[:method] do
:nearest ->
resize_nearest(input, out_shape, spatial_axes)

:bilinear ->
resize_with_kernel(input, out_shape, spatial_axes, &fill_linear_kernel/1)
resize_with_kernel(input, out_shape, spatial_axes, antialias, &fill_linear_kernel/1)

:bicubic ->
resize_with_kernel(input, out_shape, spatial_axes, &fill_cubic_kernel/1)
resize_with_kernel(input, out_shape, spatial_axes, antialias, &fill_cubic_kernel/1)

:lanczos3 ->
resize_with_kernel(input, out_shape, spatial_axes, &fill_lanczos_kernel(3, &1))
resize_with_kernel(
input,
out_shape,
spatial_axes,
antialias,
&fill_lanczos_kernel(3, &1)
)

:lanczos5 ->
resize_with_kernel(input, out_shape, spatial_axes, &fill_lanczos_kernel(5, &1))
resize_with_kernel(
input,
out_shape,
spatial_axes,
antialias,
&fill_lanczos_kernel(5, &1)
)

method ->
raise ArgumentError,
Expand Down Expand Up @@ -236,12 +254,13 @@ defmodule NxImage do

@f32_eps :math.pow(2, -23)

deftransformp resize_with_kernel(input, out_shape, spatial_axes, kernel_fun) do
deftransformp resize_with_kernel(input, out_shape, spatial_axes, antialias, kernel_fun) do
for axis <- spatial_axes, reduce: input do
input ->
resize_axis_with_kernel(input,
axis: axis,
output_size: elem(out_shape, axis),
antialias: antialias,
kernel_fun: kernel_fun
)
end
Expand All @@ -250,12 +269,19 @@ defmodule NxImage do
defnp resize_axis_with_kernel(input, opts) do
axis = opts[:axis]
output_size = opts[:output_size]
antialias = opts[:antialias]
kernel_fun = opts[:kernel_fun]

input_size = Nx.axis_size(input, axis)

inv_scale = input_size / output_size
kernel_scale = max(1, inv_scale)

kernel_scale =
if antialias do
max(1, inv_scale)
else
1
end

sample_f = (Nx.iota({1, output_size}) + 0.5) * inv_scale - 0.5
x = Nx.abs(sample_f - Nx.iota({input_size, 1})) / kernel_scale
Expand Down
34 changes: 34 additions & 0 deletions test/nx_image_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,40 @@ defmodule NxImageTest do
)
end

test "without anti-aliasing" do
# Upscaling

image = Nx.iota({4, 4, 3}, type: :f32)

assert_all_close(
NxImage.resize(image, {3, 3}, method: :bicubic, antialias: false),
Nx.tensor([
[
[[1.5427, 2.5427, 3.5427], [5.7341, 6.7341, 7.7341], [9.9256, 10.9256, 11.9256]],
[[18.3085, 19.3085, 20.3085], [22.5, 23.5, 24.5], [26.6915, 27.6915, 28.6915]],
[
[35.0744, 36.0744, 37.0744],
[39.2659, 40.2659, 41.2659],
[43.4573, 44.4573, 45.4573]
]
]
])
)

# Downscaling (no effect)

image = Nx.iota({2, 2, 3}, type: :f32)

assert_all_close(
NxImage.resize(image, {3, 3}, method: :bicubic, antialias: false),
Nx.tensor([
[[-0.5921, 0.4079, 1.4079], [1.1053, 2.1053, 3.1053], [2.8026, 3.8026, 4.8026]],
[[2.8026, 3.8026, 4.8026], [4.5, 5.5, 6.5], [6.1974, 7.1974, 8.1974]],
[[6.1974, 7.1974, 8.1974], [7.8947, 8.8947, 9.8947], [9.5921, 10.5921, 11.5921]]
])
)
end

test "accepts a batch" do
image = Nx.iota({2, 2, 3}, type: :f32)
resized_image = NxImage.resize(image, {3, 3})
Expand Down

0 comments on commit c63f31e

Please sign in to comment.