Wrong gradient return for a function that is a double getindex of jagged arrays
#1557
Comments
Similar spirit and results:
julia> Zygote.gradient(first, [1.2])
([1.0],)
julia> g = Zygote.gradient(first ∘ first, [[1.2]])[1]
1-element Vector{ChainRules.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}}}:
 [1.0]
julia> dump(g)
Array{ChainRules.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}}}((1,))
  1: ChainRules.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}}
    val: Float64 1.0
    ind: Tuple{Int64}
      1: Int64 1
    axes: Tuple{Base.OneTo{Int64}}
      1: Base.OneTo{Int64}
        stop: Int64 1
julia> g == [[1.0]]
true
I'm curious why you think the gradient is wrong?

I assumed that the returned gradient should ideally be of type
should not have had any error if the returned value from

Generally speaking, Zygote does not promise equivalence between input types and gradient types. In this case, there may be an argument for converting
I'm not quite sure what your example is meant to show?
julia> (1.)^[[1.]]
ERROR: MethodError: no method matching ^(::Float64, ::Vector{Vector{Float64}})
The function `^` exists, but no method is defined for this combination of argument types.
Closest candidates are:
  ^(::Float64, ::Integer)
   @ Base math.jl:1197
  ^(::T, ::Rational) where T<:AbstractFloat
   @ Base rational.jl:538
  ^(::T, ::Complex{S}) where {T<:Real, S<:Real}
   @ Base complex.jl:889
  ...

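On the point that gradient types need not match input types: if downstream code needs plain nested `Vector`s, one possible workaround is to materialize the inner arrays yourself. This is only a minimal sketch, assuming Zygote is installed and returns a nested gradient as in the session earlier in this thread. The name `g_dense` is made up for illustration, and this is not presented as Zygote's recommended approach.

```julia
using Zygote

# Gradient of a double getindex on a nested (jagged) array.
g = Zygote.gradient(first ∘ first, [[1.2]])[1]

# The inner elements may be a structured AbstractVector rather than a
# Vector{Float64}. `collect` copies each inner array into a dense Vector.
g_dense = map(collect, g)
```

Since `collect` on an array that is already a `Vector` just returns a copy, this should also be harmless when the gradient is already dense.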
I apologize for the bad example. I was trying to simplify a piece of code from my project where the downstream function (not precisely
Thanks for the response.

MWE (tested on v0.7.4)