Wrong gradient return for a function that is a double getindex of jagged arrays #1557

Closed
frankwswang opened this issue Feb 11, 2025 · 4 comments

Comments

@frankwswang

MWE (tested on v0.7.4)

julia> using Zygote

julia> f1 = x->first(x)
#1 (generic function with 1 method)

julia> Zygote.gradient(f1, [1.2])
([1.0],)

julia> f2 = x->getindex(f1(x), 1)
#3 (generic function with 1 method)

julia> Zygote.gradient(f2, [[1.2]])
(ChainRules.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}}[[1.0]],)
@ToucheSir
Member

Similar spirit and results:

julia> Zygote.gradient(first, [1.2])
([1.0],)

julia> g = Zygote.gradient(first ∘ first, [[1.2]])[1]
1-element Vector{ChainRules.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}}}:
 [1.0]

julia> dump(g)
Array{ChainRules.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}}}((1,))
  1: ChainRules.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}}
    val: Float64 1.0
    ind: Tuple{Int64}
      1: Int64 1
    axes: Tuple{Base.OneTo{Int64}}
      1: Base.OneTo{Int64}
        stop: Int64 1

julia> g == [[1.0]]
true

I'm curious why you think the gradient is wrong? OneElement was designed for the purpose of cheaply representing the gradient of a scalar index into an array, and it seems to be doing its job here. In fact, OneElement used to live in Zygote before it was moved to ChainRules.

@frankwswang
Author

I assumed that the returned gradient should ideally be of type Vector{Vector{Float64}} (consistent with the type of the input argument) since OneElement is not interchangeable with its content for an arbitrary function call. For instance:

julia> f3 = (x, y) -> x^first(Zygote.gradient(f2, y));

julia> f3(1.1, [[1.2]])
ERROR: MethodError: no method matching ^(::Float64, ::Vector{ChainRules.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}}})

There should not have been any error if the value returned from Zygote.gradient(f2, y) had been ([[1.0]],).

@ToucheSir
Member

Generally speaking, Zygote does not promise equivalence between input types and gradient types. In this case, there may be an argument for converting OneElements in the top-level gradient output to full arrays. But there would need to be evidence that the performance hit is worth it.

I'm not quite sure what your example is meant to show? f3(1.1, [[1.2]]) still doesn't work if first(Zygote.gradient(f2, y)) returns [[1.0]]:

julia> (1.)^[[1.]]
ERROR: MethodError: no method matching ^(::Float64, ::Vector{Vector{Float64}})
The function `^` exists, but no method is defined for this combination of argument types.

Closest candidates are:
  ^(::Float64, ::Integer)
   @ Base math.jl:1197
  ^(::T, ::Rational) where T<:AbstractFloat
   @ Base rational.jl:538
  ^(::T, ::Complex{S}) where {T<:Real, S<:Real}
   @ Base complex.jl:889
  ...

@frankwswang
Author

I apologize for the bad example.

I was trying to simplify a piece of code from my project in which the downstream function (not precisely (x, grad_val)->x^grad_val) strictly requires the returned gradient grad_val to be a Vector{Vector{Float64}}. If Zygote.jl does not guarantee that the gradient keeps the container type of the input argument, then I'll think about other workarounds.

Thanks for the response.
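
One possible workaround is sketched below. It assumes the gradient comes back as a vector whose elements are AbstractVectors (such as OneElement) or `nothing`. The helper name `densify` is hypothetical and not part of Zygote or ChainRules; it only copies each inner gradient into a plain Vector{Float64} so the result matches the layout of the input.

```julia
using Zygote

# Hypothetical helper (not provided by Zygote): materialize a nested gradient
# into a Vector{Vector{Float64}} shaped like the input `x`.
# Assumption: each entry of `grad` is either `nothing` or an AbstractVector.
function densify(grad, x::Vector{Vector{Float64}})
    # No gradient at all: return zeros with the same shape as `x`.
    grad === nothing && return Vector{Float64}[zero(xi) for xi in x]
    # Copy lazy entries (e.g. OneElement) into dense vectors; fill gaps with zeros.
    return Vector{Float64}[gi === nothing ? zero(xi) : collect(Float64, gi)
                           for (gi, xi) in zip(grad, x)]
end

x = [[1.2]]
g = Zygote.gradient(first ∘ first, x)[1]  # Vector of OneElement, as shown above
dense = densify(g, x)                     # plain nested vectors
```

The conversion allocates a dense copy of every inner vector, so it gives up the memory savings OneElement was designed for. It is probably best applied only at the boundary where a downstream function insists on Vector{Vector{Float64}}.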
