From 40aede683cc512a01dd2334d9de2ff5b17429c57 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonatan=20K=C5=82osko?= Date: Wed, 21 Feb 2024 13:11:08 +0100 Subject: [PATCH] Add option to control anti-aliasing in the resize layer (#555) --- lib/axon.ex | 8 ++++-- lib/axon/layers.ex | 52 +++++++++++++++++++++++++++++++++------ test/axon/layers_test.exs | 40 ++++++++++++++++++++++++++++++ 3 files changed, 90 insertions(+), 10 deletions(-) diff --git a/lib/axon.ex b/lib/axon.ex index 601b58ea..7a761fcb 100644 --- a/lib/axon.ex +++ b/lib/axon.ex @@ -390,7 +390,7 @@ defmodule Axon do You may specify the parameter shape as either a static shape or as function of the inputs to the given layer. If you specify the - parameter shape as a function, it will be given the + parameter shape as a function, it will be given the ## Options @@ -2122,18 +2122,22 @@ defmodule Axon do * `:method` - resize method. Defaults to `:nearest`. + * `:antialias` - whether an anti-aliasing filter should be used + when downsampling. Defaults to `true`. + * `:channels` - channel configuration. One of `:first` or `:last`. Defaults to `:last`. """ @doc type: :shape def resize(%Axon{} = x, resize_shape, opts \\ []) do - opts = Keyword.validate!(opts, [:name, method: :nearest, channels: :last]) + opts = Keyword.validate!(opts, [:name, method: :nearest, antialias: true, channels: :last]) channels = opts[:channels] layer(:resize, [x], name: opts[:name], method: opts[:method], + antialias: opts[:antialias], channels: channels, size: resize_shape, op_name: :resize diff --git a/lib/axon/layers.ex b/lib/axon/layers.ex index facec674..66a0bce5 100644 --- a/lib/axon/layers.ex +++ b/lib/axon/layers.ex @@ -1915,8 +1915,21 @@ defmodule Axon.Layers do must be at least rank 3, with fixed `batch` and `channel` dimensions. Resizing will upsample or downsample using the given resize method. - Supported resize methods are `:nearest, :linear, :bilinear, :trilinear, - :cubic, :bicubic, :tricubic`. 
+ ## Options + + * `:size` - a tuple specifying the resized spatial dimensions. + Required. + + * `:method` - the resizing method to use, either of `:nearest`, + `:bilinear`, `:bicubic`, `:lanczos3`, `:lanczos5`. Defaults to + `:nearest`. + + * `:antialias` - whether an anti-aliasing filter should be used + when downsampling. This has no effect with upsampling. Defaults + to `true`. + + * `:channels` - channels location, either `:first` or `:last`. + Defaults to `:last`. ## Examples @@ -1951,6 +1964,7 @@ defmodule Axon.Layers do :size, method: :nearest, channels: :last, + antialias: true, mode: :inference ]) @@ -1962,22 +1976,36 @@ defmodule Axon.Layers do {axis, put_elem(out_shape, axis, out_size)} end) + antialias = opts[:antialias] + resized_input = case opts[:method] do :nearest -> resize_nearest(input, out_shape, spatial_axes) :bilinear -> - resize_with_kernel(input, out_shape, spatial_axes, &fill_linear_kernel/1) + resize_with_kernel(input, out_shape, spatial_axes, antialias, &fill_linear_kernel/1) :bicubic -> - resize_with_kernel(input, out_shape, spatial_axes, &fill_cubic_kernel/1) + resize_with_kernel(input, out_shape, spatial_axes, antialias, &fill_cubic_kernel/1) :lanczos3 -> - resize_with_kernel(input, out_shape, spatial_axes, &fill_lanczos_kernel(3, &1)) + resize_with_kernel( + input, + out_shape, + spatial_axes, + antialias, + &fill_lanczos_kernel(3, &1) + ) :lanczos5 -> - resize_with_kernel(input, out_shape, spatial_axes, &fill_lanczos_kernel(5, &1)) + resize_with_kernel( + input, + out_shape, + spatial_axes, + antialias, + &fill_lanczos_kernel(5, &1) + ) method -> raise ArgumentError, @@ -2038,12 +2066,13 @@ defmodule Axon.Layers do @f32_eps :math.pow(2, -23) - deftransformp resize_with_kernel(input, out_shape, spatial_axes, kernel_fun) do + deftransformp resize_with_kernel(input, out_shape, spatial_axes, antialias, kernel_fun) do for axis <- spatial_axes, reduce: input do input -> resize_axis_with_kernel(input, axis: axis, output_size: 
elem(out_shape, axis), + antialias: antialias, kernel_fun: kernel_fun ) end @@ -2052,12 +2081,19 @@ defmodule Axon.Layers do defnp resize_axis_with_kernel(input, opts) do axis = opts[:axis] output_size = opts[:output_size] + antialias = opts[:antialias] kernel_fun = opts[:kernel_fun] input_size = Nx.axis_size(input, axis) inv_scale = input_size / output_size - kernel_scale = max(1, inv_scale) + + kernel_scale = + if antialias do + max(1, inv_scale) + else + 1 + end sample_f = (Nx.iota({1, output_size}) + 0.5) * inv_scale - 0.5 x = Nx.abs(sample_f - Nx.iota({input_size, 1})) / kernel_scale diff --git a/test/axon/layers_test.exs b/test/axon/layers_test.exs index 3c9db59a..7a6264f9 100644 --- a/test/axon/layers_test.exs +++ b/test/axon/layers_test.exs @@ -1009,6 +1009,46 @@ defmodule Axon.LayersTest do atol: 1.0e-4 ) end + + test "without anti-aliasing" do + # Downscaling + + image = Nx.iota({1, 4, 4, 3}, type: :f32) + + assert_all_close( + Axon.Layers.resize(image, size: {3, 3}, method: :bicubic, antialias: false), + Nx.tensor([ + [ + [ + [[1.5427, 2.5427, 3.5427], [5.7341, 6.7341, 7.7341], [9.9256, 10.9256, 11.9256]], + [[18.3085, 19.3085, 20.3085], [22.5, 23.5, 24.5], [26.6915, 27.6915, 28.6915]], + [ + [35.0744, 36.0744, 37.0744], + [39.2659, 40.2659, 41.2659], + [43.4573, 44.4573, 45.4573] + ] + ] + ] + ]), + atol: 1.0e-4 + ) + + # Upscaling (no effect) + + image = Nx.iota({1, 2, 2, 3}, type: :f32) + + assert_all_close( + Axon.Layers.resize(image, size: {3, 3}, method: :bicubic, antialias: false), + Nx.tensor([ + [ + [[-0.5921, 0.4079, 1.4079], [1.1053, 2.1053, 3.1053], [2.8026, 3.8026, 4.8026]], + [[2.8026, 3.8026, 4.8026], [4.5, 5.5, 6.5], [6.1974, 7.1974, 8.1974]], + [[6.1974, 7.1974, 8.1974], [7.8947, 8.8947, 9.8947], [9.5921, 10.5921, 11.5921]] + ] + ]), + atol: 1.0e-4 + ) + end end describe "lstm_cell" do