Skip to content

Commit

Permalink
Add option to control anti-aliasing in the resize layer (#555)
Browse files Browse the repository at this point in the history
  • Loading branch information
jonatanklosko authored Feb 21, 2024
1 parent cfa77a3 commit 40aede6
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 10 deletions.
8 changes: 6 additions & 2 deletions lib/axon.ex
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,7 @@ defmodule Axon do
You may specify the parameter shape as either a static shape or
as function of the inputs to the given layer. If you specify the
parameter shape as a function, it will be given the
parameter shape as a function, it will be given the
## Options
Expand Down Expand Up @@ -2122,18 +2122,22 @@ defmodule Axon do
* `:method` - resize method. Defaults to `:nearest`.
* `:antialias` - whether an anti-aliasing filter should be used
when downsampling. Defaults to `true`.
* `:channels` - channel configuration. One of `:first` or
`:last`. Defaults to `:last`.
"""
@doc type: :shape
def resize(%Axon{} = x, resize_shape, opts \\ []) do
opts = Keyword.validate!(opts, [:name, method: :nearest, channels: :last])
opts = Keyword.validate!(opts, [:name, method: :nearest, antialias: true, channels: :last])
channels = opts[:channels]

layer(:resize, [x],
name: opts[:name],
method: opts[:method],
antialias: opts[:antialias],
channels: channels,
size: resize_shape,
op_name: :resize
Expand Down
52 changes: 44 additions & 8 deletions lib/axon/layers.ex
Original file line number Diff line number Diff line change
Expand Up @@ -1915,8 +1915,21 @@ defmodule Axon.Layers do
must be at least rank 3, with fixed `batch` and `channel` dimensions.
Resizing will upsample or downsample using the given resize method.
Supported resize methods are `:nearest, :linear, :bilinear, :trilinear,
:cubic, :bicubic, :tricubic`.
## Options
* `:size` - a tuple specifying the resized spatial dimensions.
Required.
* `:method` - the resizing method to use, one of `:nearest`,
`:bilinear`, `:bicubic`, `:lanczos3`, `:lanczos5`. Defaults to
`:nearest`.
* `:antialias` - whether an anti-aliasing filter should be used
when downsampling. This has no effect with upsampling. Defaults
to `true`.
* `:channels` - channels location, either `:first` or `:last`.
Defaults to `:last`.
## Examples
Expand Down Expand Up @@ -1951,6 +1964,7 @@ defmodule Axon.Layers do
:size,
method: :nearest,
channels: :last,
antialias: true,
mode: :inference
])

Expand All @@ -1962,22 +1976,36 @@ defmodule Axon.Layers do
{axis, put_elem(out_shape, axis, out_size)}
end)

antialias = opts[:antialias]

resized_input =
case opts[:method] do
:nearest ->
resize_nearest(input, out_shape, spatial_axes)

:bilinear ->
resize_with_kernel(input, out_shape, spatial_axes, &fill_linear_kernel/1)
resize_with_kernel(input, out_shape, spatial_axes, antialias, &fill_linear_kernel/1)

:bicubic ->
resize_with_kernel(input, out_shape, spatial_axes, &fill_cubic_kernel/1)
resize_with_kernel(input, out_shape, spatial_axes, antialias, &fill_cubic_kernel/1)

:lanczos3 ->
resize_with_kernel(input, out_shape, spatial_axes, &fill_lanczos_kernel(3, &1))
resize_with_kernel(
input,
out_shape,
spatial_axes,
antialias,
&fill_lanczos_kernel(3, &1)
)

:lanczos5 ->
resize_with_kernel(input, out_shape, spatial_axes, &fill_lanczos_kernel(5, &1))
resize_with_kernel(
input,
out_shape,
spatial_axes,
antialias,
&fill_lanczos_kernel(5, &1)
)

method ->
raise ArgumentError,
Expand Down Expand Up @@ -2038,12 +2066,13 @@ defmodule Axon.Layers do

@f32_eps :math.pow(2, -23)

deftransformp resize_with_kernel(input, out_shape, spatial_axes, kernel_fun) do
deftransformp resize_with_kernel(input, out_shape, spatial_axes, antialias, kernel_fun) do
for axis <- spatial_axes, reduce: input do
input ->
resize_axis_with_kernel(input,
axis: axis,
output_size: elem(out_shape, axis),
antialias: antialias,
kernel_fun: kernel_fun
)
end
Expand All @@ -2052,12 +2081,19 @@ defmodule Axon.Layers do
defnp resize_axis_with_kernel(input, opts) do
axis = opts[:axis]
output_size = opts[:output_size]
antialias = opts[:antialias]
kernel_fun = opts[:kernel_fun]

input_size = Nx.axis_size(input, axis)

inv_scale = input_size / output_size
kernel_scale = max(1, inv_scale)

kernel_scale =
if antialias do
max(1, inv_scale)
else
1
end

sample_f = (Nx.iota({1, output_size}) + 0.5) * inv_scale - 0.5
x = Nx.abs(sample_f - Nx.iota({input_size, 1})) / kernel_scale
Expand Down
40 changes: 40 additions & 0 deletions test/axon/layers_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -1009,6 +1009,46 @@ defmodule Axon.LayersTest do
atol: 1.0e-4
)
end

test "without anti-aliasing" do
# Downscaling (4x4 -> 3x3 spatial dims): this is the case where
# `antialias: false` matters, since the kernel is no longer widened
# by the input/output size ratio.

image = Nx.iota({1, 4, 4, 3}, type: :f32)

assert_all_close(
Axon.Layers.resize(image, size: {3, 3}, method: :bicubic, antialias: false),
Nx.tensor([
[
[
[[1.5427, 2.5427, 3.5427], [5.7341, 6.7341, 7.7341], [9.9256, 10.9256, 11.9256]],
[[18.3085, 19.3085, 20.3085], [22.5, 23.5, 24.5], [26.6915, 27.6915, 28.6915]],
[
[35.0744, 36.0744, 37.0744],
[39.2659, 40.2659, 41.2659],
[43.4573, 44.4573, 45.4573]
]
]
]
]),
atol: 1.0e-4
)

# Upscaling (2x2 -> 3x3 spatial dims): the `:antialias` option has
# no effect here, because the kernel scale is 1 either way.

image = Nx.iota({1, 2, 2, 3}, type: :f32)

assert_all_close(
Axon.Layers.resize(image, size: {3, 3}, method: :bicubic, antialias: false),
Nx.tensor([
[
[[-0.5921, 0.4079, 1.4079], [1.1053, 2.1053, 3.1053], [2.8026, 3.8026, 4.8026]],
[[2.8026, 3.8026, 4.8026], [4.5, 5.5, 6.5], [6.1974, 7.1974, 8.1974]],
[[6.1974, 7.1974, 8.1974], [7.8947, 8.8947, 9.8947], [9.5921, 10.5921, 11.5921]]
]
]),
atol: 1.0e-4
)
end
end

describe "lstm_cell" do
Expand Down

0 comments on commit 40aede6

Please sign in to comment.