Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix unwrap 1.11 regression + performance improvements #576

Merged
merged 8 commits into from
Nov 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
162 changes: 75 additions & 87 deletions src/unwrap.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
module Unwrap
using Random: GLOBAL_RNG, AbstractRNG
using Random: AbstractRNG, default_rng
export unwrap, unwrap!

"""
Expand Down Expand Up @@ -62,7 +62,7 @@ of an image, as each pixel is wrapped to stay within (-pi, pi].
- `circular_dims=(false, ...)`: When an element of this tuple is `true`, the
unwrapping process will consider the edges along the corresponding axis
of the array to be connected.
- `rng=GLOBAL_RNG`: Unwrapping of arrays with dimension > 1 uses a random
- `rng=default_rng()`: Unwrapping of arrays with dimension > 1 uses a random
initialization. A user can pass their own RNG through this argument.
"""
unwrap(m::AbstractArray; kwargs...) = unwrap!(similar(m), m; kwargs...)
Expand All @@ -80,7 +80,7 @@ unwrap(m::AbstractArray; kwargs...) = unwrap!(similar(m), m; kwargs...)

mutable struct Pixel{T}
periods::Int
val::T
const val::T
reliability::Float64
groupsize::Int
head::Pixel{T}
Expand All @@ -97,38 +97,39 @@ end
Pixel(v, rng) = Pixel{typeof(v)}(0, v, rand(rng), 1)
@inline Base.length(p::Pixel) = p.head.groupsize

struct Edge{N}
struct Edge{T}
reliability::Float64
periods::Int
pixel_1::CartesianIndex{N}
pixel_2::CartesianIndex{N}
pixel_1::Pixel{T}
pixel_2::Pixel{T}
end
function Edge{N}(pixel_image::AbstractArray, ind1::CartesianIndex{N}, ind2::CartesianIndex{N}, range) where N
@inbounds rel = pixel_image[ind1].reliability + pixel_image[ind2].reliability
@inbounds periods = find_period(pixel_image[ind1].val, pixel_image[ind2].val, range)
return Edge{N}(rel, periods, ind1, ind2)
function Edge{T}(p1::Pixel{T}, p2::Pixel{T}, range) where {T}
rel = p1.reliability + p2.reliability
periods = find_period(p1.val, p2.val, range)
return Edge{T}(rel, periods, p1, p2)
end
@inline Base.isless(e1::Edge, e2::Edge) = isless(e1.reliability, e2.reliability)

function unwrap_nd!(dest::AbstractArray{T, N},
src::AbstractArray{T, N};
range::Number=2*convert(T, pi),
circular_dims::NTuple{N, Bool}=ntuple(_->false, Val(N)),
rng::AbstractRNG=GLOBAL_RNG) where {T, N}
rng::AbstractRNG=default_rng()) where {T, N}

range_T = convert(T, range)

pixel_image = init_pixels(src, rng)
calculate_reliability(pixel_image, circular_dims, range_T)
edges = Edge{N}[]
edges = Edge{T}[]
num_edges = _predict_num_edges(size(src), circular_dims)
sizehint!(edges, num_edges)
for idx_dim=1:N
for idx_dim = 1:N
populate_edges!(edges, pixel_image, idx_dim, circular_dims[idx_dim], range_T)
end

sort!(edges, alg=MergeSort)
gather_pixels!(pixel_image, edges)
perm = sortperm(map(x -> x.reliability, edges); alg=MergeSort)
edges = edges[perm]
Comment on lines -130 to +131
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this better? And why the map? Shouldn't the custom isless take care of that? And if not, wouldn't a by= be better by avoiding a temporary array?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I found it faster in empirical testing. I'm not sure why, some cache locality issues when sorting? Because Edges are 4 times the size of floats.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are also allocations elsewhere that could be reduced, although I feel that would obscure the code.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure why, some cache locality issues when sorting? Because Edges are 4 times the size of floats.

Ah yes, that makes sense. Less data movement by doing sortperm and then indexing.

There are also allocations elsewhere that could be reduced, although I feel that would obscure the code.

Right, but

perm = sortperm(edges; by=x -> x.reliability, alg=MergeSort)

isn't exactly more obscure? Or even perm = sortperm(edges; alg=MergeSort), but relying on the isless method might be considered a bit obscure. Anyway, my approval to this PR persists, so feel free to merge if you don't want change this; it's really not important.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right, but

perm = sortperm(edges; by=x -> x.reliability, alg=MergeSort)

isn't exactly more obscure?

Oh sorry, I was referring to something else. Missed the part about by=.
I tried it out just now, and using the by keyword with sortperm is even slower than the original sort!. I also tried changing the isless method to use < (which may not be correct in case of NaNs and Inf) instead of isless, with good results, so at this point I think it might also be an inlining issue, but investigating that probably involves looking at the sorting machinery.

julia> @btime unwrap(A; dims=1:2) setup=A=rand(500,500);
  153.364 ms (250102 allocations: 55.25 MiB)  # with by kw
  93.429 ms (250105 allocations: 59.06 MiB)   # PR
  94.014 ms (250096 allocations: 41.93 MiB)   # sort!, with <

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok, whatever... Merge as is, then?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure.

gather_pixels!(edges)
unwrap_image!(dest, pixel_image, range_T)

return dest
Expand All @@ -145,80 +146,80 @@ end
# function to broadcast
function init_pixels(wrapped_image::AbstractArray{T, N}, rng) where {T, N}
pixel_image = similar(wrapped_image, Pixel{T})
Threads.@threads for i in eachindex(wrapped_image)
@inbounds pixel_image[i] = Pixel(wrapped_image[i], rng)
Threads.@threads for i in eachindex(wrapped_image, pixel_image)
pixel_image[i] = Pixel(wrapped_image[i], rng)
end
return pixel_image
end

function gather_pixels!(pixel_image, edges)
function gather_pixels!(edges)
for edge in edges
@inbounds p1 = pixel_image[edge.pixel_1]
@inbounds p2 = pixel_image[edge.pixel_2]
merge_groups!(edge, p1, p2)
p1 = edge.pixel_1
p2 = edge.pixel_2
if is_differentgroup(p1, p2)
periods = edge.periods
merge_groups!(periods, p1, p2)
end
end
end

function unwrap_image!(dest, pixel_image, range)
Threads.@threads for i in eachindex(dest)
@inbounds dest[i] = muladd(range, pixel_image[i].periods, pixel_image[i].val)
Threads.@threads for i in eachindex(dest, pixel_image)
p = pixel_image[i]
dest[i] = muladd(range, p.periods, p.val)
end
end

function wrap_val(val, range)
wrapped_val = val
wrapped_val += ifelse(val > range/2, -range, zero(val))
wrapped_val += ifelse(val < -range/2, range, zero(val))
wrapped_val -= ifelse(val > range / 2, range, zero(val))
wrapped_val += ifelse(val < -range / 2, range, zero(val))
return wrapped_val
end

function find_period(val_left, val_right, range)
difference = val_left - val_right
period = 0
period += ifelse(difference > range/2, -1, 0)
period += ifelse(difference < -range/2, 1, 0)
period -= (difference > range / 2)
period += (difference < -range / 2)
return period
end

function merge_groups!(edge, pixel_1, pixel_2)
if is_differentgroup(pixel_1, pixel_2)
# pixel 2 is alone in group
if is_pixelalone(pixel_2)
merge_pixels!(pixel_1, pixel_2, -edge.periods)
elseif is_pixelalone(pixel_1)
merge_pixels!(pixel_2, pixel_1, edge.periods)
function merge_groups!(periods, base, target)
# target is alone in group
if is_pixelalone(target)
periods = -periods
elseif is_pixelalone(base)
base, target = target, base
else
if is_bigger(base, target)
periods = -periods
else
if is_bigger(pixel_1, pixel_2)
merge_into_group!(pixel_1, pixel_2, -edge.periods)
else
merge_into_group!(pixel_2, pixel_1, edge.periods)
end
base, target = target, base
end
merge_into_group!(base, target, periods)
return
end
merge_pixels!(base, target, periods)
end

@inline function is_differentgroup(p1::Pixel, p2::Pixel)
return p1.head !== p2.head
end
@inline function is_pixelalone(pixel::Pixel)
return pixel.head === pixel.last
end
@inline function is_bigger(p1::Pixel, p2::Pixel)
return length(p1) ≥ length(p2)
end
@inline is_differentgroup(p1::Pixel, p2::Pixel) = p1.head !== p2.head
@inline is_pixelalone(pixel::Pixel) = pixel.head === pixel.last
@inline is_bigger(p1::Pixel, p2::Pixel) = length(p1) ≥ length(p2)

function merge_pixels!(pixel_base::Pixel, pixel_target::Pixel, periods)
pixel_base.head.groupsize += pixel_target.head.groupsize
pixel_base.head.last.next = pixel_target.head
pixel_base.head.last = pixel_target.head.last
pixel_target.head = pixel_base.head
pixel_target.periods = pixel_base.periods + periods
return nothing
end

function merge_into_group!(pixel_base::Pixel, pixel_target::Pixel, periods)
add_periods = pixel_base.periods + periods - pixel_target.periods
pixel = pixel_target.head
while pixel ≠ nothing
while !isnothing(pixel)
# merge all pixels in pixel_target's group to pixel_base's group
if pixel !== pixel_target
pixel.periods += add_periods
Expand All @@ -230,37 +231,33 @@ function merge_into_group!(pixel_base::Pixel, pixel_target::Pixel, periods)
merge_pixels!(pixel_base, pixel_target, periods)
end

function populate_edges!(edges, pixel_image::Array{T, N}, dim, connected, range) where {T, N}
size_img = collect(size(pixel_image))
size_img[dim] -= 1
idx_step = fill(0, N)
idx_step[dim] += 1
idx_step_cart = CartesianIndex{N}(NTuple{N,Int}(idx_step))
idx_size = CartesianIndex{N}(NTuple{N,Int}(size_img))
for i in CartesianIndices(idx_size)
push!(edges, Edge{N}(pixel_image, i, i+idx_step_cart, range))
function populate_edges!(edges::Vector{Edge{T}}, pixel_image::AbstractArray{Pixel{T},N}, dim, connected, range) where {T,N}
idx_step = ntuple(i -> Int(i == dim), Val(N))
idx_step_cart = CartesianIndex{N}(idx_step)
image_inds = CartesianIndices(pixel_image)
fi, li = first(image_inds), last(image_inds)
for i in fi:li-idx_step_cart
push!(edges, Edge{T}(pixel_image[i], pixel_image[i + idx_step_cart], range))
end
if connected
idx_step = fill!(idx_step, 0)
idx_step[dim] = -size_img[dim]
idx_step_cart = CartesianIndex{N}(NTuple{N,Int}(idx_step))
edge_begin = ones(Int, N)
edge_begin[dim] = size(pixel_image)[dim]
edge_begin_cart = CartesianIndex{N}(NTuple{N,Int}(edge_begin))
for i in CartesianIndices(ntuple(dim_idx -> edge_begin_cart[dim_idx]:size(pixel_image, dim_idx), N))
push!(edges, Edge{N}(pixel_image, i, i+idx_step_cart, range))
idx_step_cart *= size(pixel_image, dim) - 1
for i in fi+idx_step_cart:li
push!(edges, Edge{T}(pixel_image[i], pixel_image[i - idx_step_cart], range))
end
end
end

function calculate_reliability(pixel_image::AbstractArray{T, N}, circular_dims, range) where {T, N}
# get the shifted pixel indices in CartesinanIndex form
# This gets all the nearest neighbors (CartesionIndex{N}() = one(CartesianIndex{N}))
pixel_shifts = CartesianIndices(ntuple(i -> -1:1, N))
one_cart = oneunit(CartesianIndex{N})
pixel_shifts = -one_cart:one_cart
image_inds = CartesianIndices(pixel_image)
fi, li = first(image_inds) + one_cart, last(image_inds) - one_cart
size_img = size(pixel_image)
# inner loop
for i in CartesianIndices(ntuple(dim -> 2:(size(pixel_image, dim)-1), N))
@inbounds pixel_image[i].reliability = calculate_pixel_reliability(pixel_image, i, pixel_shifts, range)
for i in fi:li
pixel_image[i].reliability = calculate_pixel_reliability(pixel_image, i, pixel_shifts, range)
end

if !(true in circular_dims)
Expand All @@ -276,54 +273,45 @@ function calculate_reliability(pixel_image::AbstractArray{T, N}, circular_dims,
for (idx_ps, ps) in enumerate(pixel_shifts_border)
# if the pixel shift goes out of bounds, we make the shift wrap
if ps[idx_dim] == 1
fill!(new_ps, 0)
new_ps[idx_dim] = -size_img[idx_dim]+1
pixel_shifts_border[idx_ps] = CartesianIndex{N}(NTuple{N,Int}(new_ps))
new_ps[idx_dim] = 0
end
end
border_range = get_border_range(size_img, idx_dim, size_img[idx_dim])
border_range = get_border_range(fi:li, idx_dim, li[idx_dim] + 1)
for i in CartesianIndices(border_range)
@inbounds pixel_image[i].reliability = calculate_pixel_reliability(pixel_image, i, pixel_shifts_border, range)
pixel_image[i].reliability = calculate_pixel_reliability(pixel_image, i, pixel_shifts_border, range)
end
# second border
pixel_shifts_border = copyto!(pixel_shifts_border, pixel_shifts)
for (idx_ps, ps) in enumerate(pixel_shifts_border)
# if the pixel shift goes out of bounds, we make the shift wrap, this time to the other side
if ps[idx_dim] == -1
fill!(new_ps, 0)
new_ps[idx_dim] = size_img[idx_dim]-1
pixel_shifts_border[idx_ps] = CartesianIndex{N}(NTuple{N,Int}(new_ps))
new_ps[idx_dim] = 0
end
end
border_range = get_border_range(size_img, idx_dim, 1)
border_range = get_border_range(fi:li, idx_dim, fi[idx_dim] - 1)
for i in CartesianIndices(border_range)
@inbounds pixel_image[i].reliability = calculate_pixel_reliability(pixel_image, i, pixel_shifts_border, range)
pixel_image[i].reliability = calculate_pixel_reliability(pixel_image, i, pixel_shifts_border, range)
end
end
end
end

function get_border_range(size_img::NTuple{N, T}, border_dim, border_idx) where {N, T}
border_range = [2:(size_img[dim]-1) for dim=1:N]
function get_border_range(C::CartesianIndices{N}, border_dim, border_idx) where {N}
border_range = [C.indices[dim] for dim=1:N]
border_range[border_dim] = border_idx:border_idx
return NTuple{N,UnitRange{Int}}(border_range)
end

function calculate_pixel_reliability(pixel_image::AbstractArray{Pixel{T},N}, pixel_index, pixel_shifts, range) where {T,N}
pix_val = pixel_image[pixel_index].val
rel_contrib(shift) = @inbounds wrap_val(pixel_image[pixel_index+shift].val - pix_val, range)^2
rel_contrib(shift) = wrap_val(pixel_image[pixel_index+shift].val - pix_val, range)^2
# for N=3, pixel_shifts[14] is null shift, can avoid if manually unrolling loop
sum_val = sum(rel_contrib, pixel_shifts)
return sum_val
end

# specialized pixel reliability calculations for different N
@inbounds function calculate_pixel_reliability(pixel_image::AbstractArray{Pixel{T}, 2}, pixel_index, pixel_shifts, range) where T
D1 = wrap_val(pixel_image[pixel_index+pixel_shifts[2]].val - pixel_image[pixel_index].val, range)
D2 = wrap_val(pixel_image[pixel_index+pixel_shifts[4]].val - pixel_image[pixel_index].val, range)
H = wrap_val(pixel_image[pixel_index+pixel_shifts[6]].val - pixel_image[pixel_index].val, range)
V = wrap_val(pixel_image[pixel_index+pixel_shifts[8]].val - pixel_image[pixel_index].val, range)
return H*H + V*V + D1*D1 + D2*D2
end

end
1 change: 1 addition & 0 deletions test/unwrap.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using DSP, Test
using Random: MersenneTwister

@testset "Unwrap 1D" begin
@test unwrap([0.1, 0.2, 0.3, 0.4]) ≈ [0.1, 0.2, 0.3, 0.4]
Expand Down
Loading