GPU array scalar indexing edge cases #48

Open
ExpandingMan opened this issue Feb 4, 2025 · 6 comments

@ExpandingMan

This issue came up during #47.

There are some rather awkward cases where it is not clear how to avoid scalar indexing into GPU arrays. For example:

  using OneHotArrays, CUDA, LinearAlgebra, Test

  y = onehotbatch(ones(3), 1:2) |> cu   # 2×3 OneHotMatrix backed by a CuArray
  y = reshape(y, 3, 2)                  # now wrapped again in a Base.ReshapedArray
  gA = rand(2, 3) |> cu

  @test LinearAlgebra.mul!(similar(gA, 2, 2), gA, y) ≈ gA*y

Both sides of this test are currently broken. The failure on current main is a method ambiguity, but it's not entirely clear how to fix it, since there don't seem to be easy answers about what should happen in this case. I doubt it is possible to cover every conceivable edge case like this; at some point users should have to materialize the array.

I think what probably needs to happen here is to add documentation, and possibly convenience methods, describing the circumstances in which arrays should be materialized. Even though there's no clear answer to this, I'm opening the issue because the current handling is extremely wonky, and any users not intimately familiar with Julia array packages will be justifiably confused.
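
For concreteness, here is a minimal sketch of what "materialize the array" could look like for the example above. The Array{Float32} conversion step is my assumption of what a documented workaround might recommend, not an existing convenience method:

  using OneHotArrays, CUDA, LinearAlgebra

  # Build and reshape the one-hot wrapper on the CPU, where scalar indexing is cheap.
  y_cpu = reshape(onehotbatch(ones(3), 1:2), 3, 2)

  # Materialize to a dense matrix, then move it to the GPU as a plain CuArray.
  y_gpu = cu(Array{Float32}(y_cpu))

  gA = rand(2, 3) |> cu
  C  = LinearAlgebra.mul!(similar(gA, 2, 2), gA, y_gpu)   # ordinary dense GPU matmul, no wrapper left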

@mcabbott
Member

mcabbott commented Feb 4, 2025

Please include the error messages for both failures.

Do they fail on CPU too?

@ExpandingMan
Author

This is only a problem on GPU because of scalar indexing.

On current main it should give:

ERROR: MethodError: getindex(::OneHotMatrix{UInt32, CuArray{UInt32, 1, CUDA.DeviceMemory}}, ::Int64, ::Int64) is ambiguous.

Candidates:
  getindex(x::OneHotArray{var"#s3", N, var"N+1", I} where {var"#s3", var"N+1", I<:Union{AbstractArray{var"#s3", N}, var"#s3"}}, i::Int64, I::Vararg{Int64, N}) where N
    @ OneHotArrays ~/.julia/dev/OneHotArrays/src/array.jl:65
  getindex(x::OneHotArray{<:Any, N, <:Any, <:GPUArraysCore.AbstractGPUArray}, i::Int64, I::Vararg{Any, N}) where N
    @ OneHotArrays ~/.julia/dev/OneHotArrays/src/array.jl:71

Possible fix, define
  getindex(::OneHotArray{var"#s3", N, <:Any, <:Union{…}} where var"#s3", ::Int64, ::Vararg{Int64, N}) where N

Stacktrace:
  [1] _unsafe_getindex_rs
    @ ./reshapedarray.jl:276 [inlined]
  [2] _unsafe_getindex
    @ ./reshapedarray.jl:273 [inlined]
  [3] getindex
    @ ./reshapedarray.jl:261 [inlined]
  [4] _generic_matmatmul!(C::Matrix{…}, A::CuArray{…}, B::Base.ReshapedArray{…}, _add::LinearAlgebra.MulAddMul{…})
    @ LinearAlgebra ~/.julia/juliaup/julia-1.11.3+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:894
  [5] generic_matmatmul!
    @ ~/.julia/juliaup/julia-1.11.3+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:868 [inlined]
  [6] _mul!
    @ ~/.julia/juliaup/julia-1.11.3+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:287 [inlined]
  [7] mul!
    @ ~/.julia/juliaup/julia-1.11.3+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:285 [inlined]
  [8] mul!
    @ ~/.julia/juliaup/julia-1.11.3+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:253 [inlined]
  [9] mul!(Y::Matrix{Float32}, A::CuArray{Float32, 2, CUDA.DeviceMemory}, B::Base.ReshapedArray{Bool, 2, OneHotMatrix{…}, Tuple{…}})
    @ OneHotArrays ~/.julia/dev/OneHotArrays/src/linalg.jl:38
 [10] *
    @ ~/.julia/juliaup/julia-1.11.3+0.x64.linux.gnu/share/julia/stdlib/v1.11/LinearAlgebra/src/matmul.jl:114 [inlined]
 [11] *(A::CuArray{Float32, 2, CUDA.DeviceMemory}, B::Base.ReshapedArray{Bool, 2, OneHotMatrix{…}, Tuple{…}})
    @ OneHotArrays ~/.julia/dev/OneHotArrays/src/linalg.jl:8

Resolving this method ambiguity is easy of course, but it's not clear what the alternative would be. In my opinion it wasn't really a good idea for this package to start down the road of supporting ReshapedArray.
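
For reference, here is a minimal sketch of the disambiguating method Julia suggests above. The field name x.indices and the forwarding logic are assumptions about OneHotArrays internals, and on a GPU-backed array this read is exactly the scalar indexing that allowscalar guards against:

  # Hypothetical disambiguation: more specific than both candidates
  # (GPU-backed indices AND all-Int64 trailing indices). It forwards to the
  # stored index array, so on the GPU it still performs a scalar read.
  function Base.getindex(x::OneHotArray{<:Any, N, <:Any, <:GPUArraysCore.AbstractGPUArray},
                         i::Int64, I::Vararg{Int64, N}) where {N}
      return x.indices[I...] == i
  end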

@ToucheSir
Member

FWIW, this is the historical context that led to using the wrapper: FluxML/Flux.jl#1459 (comment).

One continuing tension in this package is that the code paths needed for a memory-efficient type can conflict with the ones needed for GPU support. Generally the idea was to have people use it outside of AD/GPU (i.e. during preprocessing and data loading) and materialize before using AD/GPU. That's not to say the status quo is ideal, but it sheds some light on why GPU support hasn't historically been the No. 1 priority.

@mcabbott
Member

mcabbott commented Feb 5, 2025

I think the getindex error is a dup of #28. Close if you agree?

@ExpandingMan
Author

I think there is a bigger issue here of what exactly to do in the edge cases this package allows, where you would seem to have to scalar-index a GPU array. But there is redundancy between this issue and #28, so feel free to close this one if you don't think it's adding anything.

@mcabbott
Member

mcabbott commented Feb 6, 2025

Isn't this just how arrays work?

E.g. one role of Transpose is to dispatch to the BLAS with a 'T', easy. The second role is to plug into generic code which doesn't know about it at all, so that reverse(transpose(M); dims=1) just works. This second role is the only reason it has supertype AbstractMatrix.

When it wraps a GPU array, the first role is no harder, but almost none of the AbstractMatrix fallbacks will work. You get scalar indexing, and then you either fix it (supply a routine) or avoid it. Having no supertype would perhaps give a more obvious error message, like "MethodError: no method matching reverse(A::Transpose{Float64, JLArray{Float64, 2}})".
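
A small sketch of those two roles, using JLArrays (the CPU reference implementation of the GPU array interface) so it runs without a real GPU; which calls succeed depends on the specialized methods the array packages happen to provide:

  using JLArrays, LinearAlgebra, GPUArraysCore

  GPUArraysCore.allowscalar(false)   # make stray scalar indexing an error, as on a real GPU

  A  = JLArray(rand(Float32, 4, 4))
  At = transpose(A)

  A * At                 # role 1: a GPU-aware matmul sees through the wrapper
  reverse(At; dims = 1)  # role 2: generic AbstractMatrix fallback, expected to hit scalar indexing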
