From 81380f0493581ef914cc5b2419885b2014c148c3 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Wed, 21 Sep 2022 18:05:17 +0200 Subject: [PATCH] Unify and simplify adjoints for `pairwise(::Euclidean, ...)` (#1310) * Unify and simplify adjoints for `pairwise(::Euclidean, ...)` * Base threshold on result instead of input arguments * Apply suggestions --- Project.toml | 2 +- src/lib/distances.jl | 35 ++++++++++++++--------------------- test/gradcheck.jl | 11 +++++++++++ 3 files changed, 26 insertions(+), 22 deletions(-) diff --git a/Project.toml b/Project.toml index 599325e12..7d277f688 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Zygote" uuid = "e88e6eb3-aa80-5325-afca-941959d7151f" -version = "0.6.48" +version = "0.6.49" [deps] AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" diff --git a/src/lib/distances.jl b/src/lib/distances.jl index ee39e9de6..1adf30d01 100644 --- a/src/lib/distances.jl +++ b/src/lib/distances.jl @@ -64,31 +64,24 @@ end end end -@adjoint function pairwise(::Euclidean, X::AbstractMatrix, Y::AbstractMatrix; dims=2) +_sqrt_if_positive(d, δ) = d > δ ? sqrt(d) : zero(d) +@adjoint function pairwise(dist::Euclidean, X::AbstractMatrix, Y::AbstractMatrix; dims=2) # Modify the forwards-pass slightly to ensure stability on the reverse. - function _pairwise_euclidean(X, Y) - δ = eps(promote_type(eltype(X), eltype(Y)))^2 - return sqrt.(max.(pairwise(SqEuclidean(), X, Y; dims=dims), δ)) - end - D, back = pullback(_pairwise_euclidean, X, Y) - - return D, function(Δ) - return (nothing, back(Δ)...) + function _pairwise_euclidean(sqdist::SqEuclidean, X, Y) + D2 = pairwise(sqdist, X, Y; dims=dims) + δ = eps(eltype(D2)) + return _sqrt_if_positive.(D2, δ) end + return pullback(_pairwise_euclidean, SqEuclidean(dist.thresh), X, Y) end -@adjoint function pairwise(::Euclidean, X::AbstractMatrix; dims=2) - - _conditional(d, δ) = d > δ ? sqrt(d) : zero(d) - - function _pairwise_euclidean(X) - δ = eps(eltype(X))^2 - D2 = pairwise(SqEuclidean(), X; dims=dims) - return _conditional.(D2, δ) +@adjoint function pairwise(dist::Euclidean, X::AbstractMatrix; dims=2) + # Modify the forwards-pass slightly to ensure stability on the reverse. + function _pairwise_euclidean(sqdist::SqEuclidean, X) + D2 = pairwise(sqdist, X; dims=dims) + δ = eps(eltype(D2)) + return _sqrt_if_positive.(D2, δ) end - D, back = pullback(_pairwise_euclidean, X) - - _pairwise_pullback(Δ) = (nothing, back(Δ)...) - return D, _pairwise_pullback + return pullback(_pairwise_euclidean, SqEuclidean(dist.thresh), X) end diff --git a/test/gradcheck.jl b/test/gradcheck.jl index 3330d5927..3cc10ce82 100644 --- a/test/gradcheck.jl +++ b/test/gradcheck.jl @@ -1197,6 +1197,7 @@ end Δ = randn(P, P) X = repeat(randn(rng, D), 1, P) + # Single input matrix Δ_fd = FiniteDifferences.j′vp( FiniteDifferences.central_fdm(5, 1), X -> pairwise(metric, X; dims=2), Δ, X ) @@ -1204,6 +1205,16 @@ end # This is impressively inaccurate, but at least it doesn't produce a NaN. @test first(Δ_fd) ≈ first(pb(Δ)) atol=1e-3 rtol=1e-3 + + # Two input matrices + Y = copy(X) + Δ_fd = FiniteDifferences.j′vp( + FiniteDifferences.central_fdm(5, 1), X -> pairwise(metric, X, Y; dims=2), Δ, X + ) + _, pb = Zygote.pullback(X -> pairwise(metric, X, Y; dims=2), X) + + # This is impressively inaccurate, but at least it doesn't produce a NaN. + @test first(Δ_fd) ≈ first(pb(Δ)) atol=1e-3 rtol=1e-3 end end