Skip to content

Commit

Permalink
Merge pull request #1358 from FluxML/ox/cr_ambib
Browse files Browse the repository at this point in the history
MethodError if configured rrule is ambiguous
  • Loading branch information
oxinabox authored Jan 17, 2023
2 parents 084ea4b + b3dea49 commit e2f0b5f
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 11 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Zygote"
uuid = "e88e6eb3-aa80-5325-afca-941959d7151f"
version = "0.6.54"
version = "0.6.55"

[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
Expand Down
15 changes: 10 additions & 5 deletions src/compiler/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@ such that if a suitable rule is defined later, the generated function will recom
function has_chain_rrule(T)
config_T, arg_Ts = Iterators.peel(T.parameters)
configured_rrule_m = meta(Tuple{typeof(rrule), config_T, arg_Ts...})
if _is_rrule_redispatcher(configured_rrule_m.method)
is_ambig = configured_rrule_m === nothing # this means there was an ambiguity error, on configured_rrule


if !is_ambig && _is_rrule_redispatcher(configured_rrule_m.method)
# The config is not being used:
# it is being redispatched without config, so we need the method it redispatches to
rrule_m = meta(Tuple{typeof(rrule), arg_Ts...})
Expand All @@ -33,6 +36,8 @@ function has_chain_rrule(T)
no_rrule_m = meta(Tuple{typeof(ChainRulesCore.no_rrule), config_T, arg_Ts...})
end

is_ambig |= rrule_m === nothing # this means there was an ambiguity error on unconfigured rrule

# To understand why we only need to check if the sigs match between no_rrule_m and rrule_m
# in order to decide if to use, one must consider the following facts:
# - for every method in `no_rrule` there is a identical one in `rrule` that returns nothing
Expand All @@ -51,16 +56,16 @@ function has_chain_rrule(T)
# It can be seen that checking if it matches is the correct way to decide if we should use the rrule or not.


do_not_use_rrule = matching_cr_sig(no_rrule_m, rrule_m)
if do_not_use_rrule
if !is_ambig && matching_cr_sig(no_rrule_m, rrule_m) # Not ambigious, and opted-out.
# Return instance for configured_rrule_m as that will be invalidated
# directly if configured rule added, or indirectly if unconfigured rule added
# Do not need an edge for `no_rrule` as no addition of methods to that can cause this
# decision to need to be revisited (only changes to `rrule`), since we are already not
# using the rrule, so not using more rules wouldn't change anything.
return false, configured_rrule_m.instance
else
# Otherwise found a rrule, no need to add any edges for `rrule`, as it will generate
# Either is ambigious, and we should try to use it, and then error
# or we are uses a rrule, no need to add any edges for `rrule`, as it will generate
# code with natural edges if a new method is defined there.
# We also do not need an edge to `no_rrule`, as any time a method is added to `no_rrule`
# a corresponding method is added to `rrule` (to return `nothing`), thus we will already
Expand All @@ -73,7 +78,7 @@ matching_cr_sig(t, s) = matching_cr_sig(t.method.sig, s.method.sig)
matching_cr_sig(::DataType, ::UnionAll) = false
matching_cr_sig(::UnionAll, ::DataType) = false
matching_cr_sig(t::Type, s::Type) = type_tuple_tail(t) == type_tuple_tail(s)
matching_cr_sig(::Any, ::Nothing) = false # https://github.com/FluxML/Zygote.jl/issues/1234
matching_cr_sig(::Any, ::Nothing) = false # ambigious https://github.com/FluxML/Zygote.jl/issues/1234

type_tuple_tail(d::DataType) = Tuple{d.parameters[2:end]...}
function type_tuple_tail(d::UnionAll)
Expand Down
19 changes: 14 additions & 5 deletions test/chainrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -278,11 +278,20 @@ using Zygote: ZygoteRuleConfig

# https://github.com/FluxML/Zygote.jl/issues/1234
@testset "rrule lookup ambiguities" begin
f_ambig(x, y) = x + y
ChainRulesCore.rrule(::typeof(f_ambig), x::Int, y) = x + y, _ -> (0, 0)
ChainRulesCore.rrule(::typeof(f_ambig), x, y::Int) = x + y, _ -> (0, 0)

@test_throws MethodError pullback(f_ambig, 1, 2)
@testset "unconfigured" begin
f_ambig(x, y) = x + y
ChainRulesCore.rrule(::typeof(f_ambig), x::Int, y) = x + y, _ -> (0, 0)
ChainRulesCore.rrule(::typeof(f_ambig), x, y::Int) = x + y, _ -> (0, 0)

@test_throws MethodError pullback(f_ambig, 1, 2)
end
@testset "configured" begin
h_ambig(x, y) = x + y
ChainRulesCore.rrule(::ZygoteRuleConfig, ::typeof(h_ambig), x, y) = x + y, _ -> (0, 0)
ChainRulesCore.rrule(::RuleConfig, ::typeof(h_ambig), x::Int, y::Int) = x + y, _ -> (0, 0)

@test_throws MethodError pullback(h_ambig, 1, 2)
end
end
end

Expand Down

2 comments on commit e2f0b5f

@oxinabox
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/75840

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.6.55 -m "<description of version>" e2f0b5ff4dffc4d7737fe289c79b7edb07784d3e
git push origin v0.6.55

Please sign in to comment.