Skip to content

Commit

Permalink
Multi-argument support: basic infrastructure (#461)
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle authored Sep 14, 2024
1 parent 41bade9 commit 267023a
Show file tree
Hide file tree
Showing 31 changed files with 1,825 additions and 731 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ using Base: Fix1, Fix2
using Compat
import DifferentiationInterface as DI
using DifferentiationInterface:
Context,
DerivativeExtras,
GradientExtras,
HessianExtras,
Expand All @@ -16,7 +17,8 @@ using DifferentiationInterface:
SecondOrder,
Tangents,
inner,
outer
outer,
unwrap
using ForwardDiff.DiffResults:
DiffResults, DiffResult, GradientResult, HessianResult, MutableDiffResult
using ForwardDiff:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,36 @@ struct ForwardDiffOneArgPushforwardExtras{T,X} <: PushforwardExtras
xdual_tmp::X
end

function DI.prepare_pushforward(f::F, backend::AutoForwardDiff, x, tx::Tangents) where {F}
function DI.prepare_pushforward(
f::F, backend::AutoForwardDiff, x, tx::Tangents, contexts::Vararg{Context,C}
) where {F,C}
T = tag_type(f, backend, x)
xdual_tmp = make_dual_similar(T, x, tx)
return ForwardDiffOneArgPushforwardExtras{T,typeof(xdual_tmp)}(xdual_tmp)
end

function compute_ydual_onearg(
f::F, extras::ForwardDiffOneArgPushforwardExtras{T}, x::Number, tx::Tangents
) where {F,T}
f::F,
extras::ForwardDiffOneArgPushforwardExtras{T},
x::Number,
tx::Tangents,
contexts::Vararg{Context,C},
) where {F,T,C}
xdual_tmp = make_dual(T, x, tx)
ydual = f(xdual_tmp)
ydual = f(xdual_tmp, map(unwrap, contexts)...)
return ydual
end

function compute_ydual_onearg(
f::F, extras::ForwardDiffOneArgPushforwardExtras{T}, x, tx::Tangents
) where {F,T}
f::F,
extras::ForwardDiffOneArgPushforwardExtras{T},
x,
tx::Tangents,
contexts::Vararg{Context,C},
) where {F,T,C}
@compat (; xdual_tmp) = extras
make_dual!(T, xdual_tmp, x, tx)
ydual = f(xdual_tmp)
ydual = f(xdual_tmp, map(unwrap, contexts)...)
return ydual
end

Expand All @@ -33,8 +43,9 @@ function DI.value_and_pushforward(
::AutoForwardDiff,
x,
tx::Tangents{B},
) where {F,T,B}
ydual = compute_ydual_onearg(f, extras, x, tx)
contexts::Vararg{Context,C},
) where {F,T,B,C}
ydual = compute_ydual_onearg(f, extras, x, tx, contexts...)
y = myvalue(T, ydual)
ty = mypartials(T, Val(B), ydual)
return y, ty
Expand All @@ -47,8 +58,9 @@ function DI.value_and_pushforward!(
::AutoForwardDiff,
x,
tx::Tangents,
) where {F,T}
ydual = compute_ydual_onearg(f, extras, x, tx)
contexts::Vararg{Context,C},
) where {F,T,C}
ydual = compute_ydual_onearg(f, extras, x, tx, contexts...)
y = myvalue(T, ydual)
mypartials!(T, ty, ydual)
return y, ty
Expand All @@ -60,8 +72,9 @@ function DI.pushforward(
::AutoForwardDiff,
x,
tx::Tangents{B},
) where {F,T,B}
ydual = compute_ydual_onearg(f, extras, x, tx)
contexts::Vararg{Context,C},
) where {F,T,B,C}
ydual = compute_ydual_onearg(f, extras, x, tx, contexts...)
ty = mypartials(T, Val(B), ydual)
return ty
end
Expand All @@ -73,8 +86,9 @@ function DI.pushforward!(
::AutoForwardDiff,
x,
tx::Tangents,
) where {F,T}
ydual = compute_ydual_onearg(f, extras, x, tx)
contexts::Vararg{Context,C},
) where {F,T,C}
ydual = compute_ydual_onearg(f, extras, x, tx, contexts...)
mypartials!(T, ty, ydual)
return ty
end
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,45 +26,60 @@ struct ForwardDiffOverSomethingHVPExtras{B<:AutoForwardDiff,G,E<:PushforwardExtr
end

function DI.prepare_hvp(
f::F, backend::SecondOrder{<:AutoForwardDiff}, x, tx::Tangents
) where {F}
f::F,
backend::SecondOrder{<:AutoForwardDiff},
x,
tx::Tangents,
contexts::Vararg{Context,C},
) where {F,C}
tagged_outer_backend = tag_backend_hvp(f, outer(backend), x)
T = tag_type(f, tagged_outer_backend, x)
xdual = make_dual(T, x, tx)
gradient_extras = DI.prepare_gradient(f, inner(backend), xdual)
inner_gradient(x) = DI.gradient(f, gradient_extras, inner(backend), x)
gradient_extras = DI.prepare_gradient(f, inner(backend), xdual, contexts...)
function inner_gradient(x, unannotated_contexts...)
annotated_contexts = map.(typeof.(contexts), unannotated_contexts)
return DI.gradient(f, gradient_extras, inner(backend), x, unannotated_contexts...)
end
outer_pushforward_extras = DI.prepare_pushforward(
inner_gradient, tagged_outer_backend, x, tx
inner_gradient, tagged_outer_backend, x, tx, contexts...
)
return ForwardDiffOverSomethingHVPExtras(
tagged_outer_backend, inner_gradient, outer_pushforward_extras
)
end

function DI.hvp(
f,
f::F,
extras::ForwardDiffOverSomethingHVPExtras,
::SecondOrder{<:AutoForwardDiff},
x,
tx::Tangents,
)
contexts::Vararg{Context,C},
) where {F,C}
@compat (; tagged_outer_backend, inner_gradient, outer_pushforward_extras) = extras
return DI.pushforward(
inner_gradient, outer_pushforward_extras, tagged_outer_backend, x, tx
inner_gradient, outer_pushforward_extras, tagged_outer_backend, x, tx, contexts...
)
end

function DI.hvp!(
f,
f::F,
tg::Tangents,
extras::ForwardDiffOverSomethingHVPExtras,
::SecondOrder{<:AutoForwardDiff},
x,
tx::Tangents,
)
contexts::Vararg{Context,C},
) where {F,C}
@compat (; tagged_outer_backend, inner_gradient, outer_pushforward_extras) = extras
DI.pushforward!(
inner_gradient, tg, outer_pushforward_extras, tagged_outer_backend, x, tx
inner_gradient,
tg,
outer_pushforward_extras,
tagged_outer_backend,
x,
tx,
contexts...,
)
return tg
end
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ struct ForwardDiffTwoArgPushforwardExtras{T,X,Y} <: PushforwardExtras
end

function DI.prepare_pushforward(
f!::F, y, backend::AutoForwardDiff, x, tx::Tangents
) where {F}
f!::F, y, backend::AutoForwardDiff, x, tx::Tangents, contexts::Vararg{Context,C}
) where {F,C}
T = tag_type(f!, backend, x)
xdual_tmp = make_dual_similar(T, x, tx)
ydual_tmp = make_dual_similar(T, y, tx) # dx only for batch size
Expand All @@ -17,20 +17,30 @@ function DI.prepare_pushforward(
end

function compute_ydual_twoarg(
f!::F, y, extras::ForwardDiffTwoArgPushforwardExtras{T}, x::Number, tx::Tangents
) where {F,T}
f!::F,
y,
extras::ForwardDiffTwoArgPushforwardExtras{T},
x::Number,
tx::Tangents,
contexts::Vararg{Context,C},
) where {F,T,C}
@compat (; ydual_tmp) = extras
xdual_tmp = make_dual(T, x, tx)
f!(ydual_tmp, xdual_tmp)
f!(ydual_tmp, xdual_tmp, map(unwrap, contexts)...)
return ydual_tmp
end

function compute_ydual_twoarg(
f!::F, y, extras::ForwardDiffTwoArgPushforwardExtras{T}, x, tx::Tangents
) where {F,T}
f!::F,
y,
extras::ForwardDiffTwoArgPushforwardExtras{T},
x,
tx::Tangents,
contexts::Vararg{Context,C},
) where {F,T,C}
@compat (; xdual_tmp, ydual_tmp) = extras
make_dual!(T, xdual_tmp, x, tx)
f!(ydual_tmp, xdual_tmp)
f!(ydual_tmp, xdual_tmp, map(unwrap, contexts)...)
return ydual_tmp
end

Expand All @@ -41,8 +51,9 @@ function DI.value_and_pushforward(
::AutoForwardDiff,
x,
tx::Tangents{B},
) where {F,T,B}
ydual_tmp = compute_ydual_twoarg(f!, y, extras, x, tx)
contexts::Vararg{Context,C},
) where {F,T,B,C}
ydual_tmp = compute_ydual_twoarg(f!, y, extras, x, tx, contexts...)
myvalue!(T, y, ydual_tmp)
ty = mypartials(T, Val(B), ydual_tmp)
return y, ty
Expand All @@ -56,8 +67,9 @@ function DI.value_and_pushforward!(
::AutoForwardDiff,
x,
tx::Tangents,
) where {F,T}
ydual_tmp = compute_ydual_twoarg(f!, y, extras, x, tx)
contexts::Vararg{Context,C},
) where {F,T,C}
ydual_tmp = compute_ydual_twoarg(f!, y, extras, x, tx, contexts...)
myvalue!(T, y, ydual_tmp)
mypartials!(T, ty, ydual_tmp)
return y, ty
Expand All @@ -70,8 +82,9 @@ function DI.pushforward(
::AutoForwardDiff,
x,
tx::Tangents{B},
) where {F,T,B}
ydual_tmp = compute_ydual_twoarg(f!, y, extras, x, tx)
contexts::Vararg{Context,C},
) where {F,T,B,C}
ydual_tmp = compute_ydual_twoarg(f!, y, extras, x, tx, contexts...)
ty = mypartials(T, Val(B), ydual_tmp)
return ty
end
Expand All @@ -84,8 +97,9 @@ function DI.pushforward!(
::AutoForwardDiff,
x,
tx::Tangents,
) where {F,T}
ydual_tmp = compute_ydual_twoarg(f!, y, extras, x, tx)
contexts::Vararg{Context,C},
) where {F,T,C}
ydual_tmp = compute_ydual_twoarg(f!, y, extras, x, tx, contexts...)
mypartials!(T, ty, ydual_tmp)
return ty
end
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ using DifferentiationInterface:
maybe_outer,
multibasis,
pick_batchsize,
pushforward_performance
pushforward_performance,
unwrap
import DifferentiationInterface as DI
using SparseMatrixColorings:
AbstractColoringResult,
Expand Down
Loading

0 comments on commit 267023a

Please sign in to comment.