add missing mul! implementation #47
base: main
Conversation
Overloading mul! seems fine.
Re deleting *, note that @less *(rand(2,2), rand(2,2)) shows it calls similar(B), and that will give the wrong type here... although that could be fixed.
julia> using JLArrays
julia> onehotbatch([1,3,2], 0:5) |> jl
6×3 OneHotMatrix(::JLArray{UInt32, 1}) with eltype Bool:
⋅ ⋅ ⋅
1 ⋅ ⋅
⋅ ⋅ 1
⋅ 1 ⋅
⋅ ⋅ ⋅
⋅ ⋅ ⋅
julia> similar(ans, Float32)
6×3 Matrix{Float32}:
3.0f-44 3.66694f29 3.66699f29
0.0 1.0f-45 1.0f-45
2.7f-44 3.66696f29 3.66701f29
0.0 1.0f-45 1.0f-45
2.8f-44 3.66697f29 3.0f-45
0.0 1.0f-45 0.0
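For reference, here is a minimal sketch (my own, not the code in this PR) of what a mul! overload for a one-hot right operand could look like on the CPU. It assumes OneHotMatrix exposes its hot rows through an indices field, and a GPU version would need a vectorized gather rather than this scalar-indexing loop:

```julia
# Hedged sketch, not the PR's actual method: for A * B with a one-hot B,
# column j of the result is the column of A selected by B's stored index,
# so no similar(B) call is involved.
using LinearAlgebra, OneHotArrays

function LinearAlgebra.mul!(Y::AbstractMatrix, A::AbstractMatrix, B::OneHotMatrix)
    # assumes the internal `indices` field of OneHotMatrix (an assumption here);
    # CPU-style loop, unsuitable as-is for GPU arrays
    for (j, k) in enumerate(B.indices)
        Y[:, j] .= view(A, :, k)
    end
    return Y
end
```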
test/gpu.jl
Outdated
end

# some specialized implementations call only mul! and not *, so we must ensure this works
@test LinearAlgebra.mul!(similar(gA, 3, 3), gA, y) isa CuArray
This tests only one case. I think that it should test vector & matrix output, each with a OneHotArray, and some reshaped array... and maybe something for which _isonehot(B) === false to test the invoke path?
And it should check that the results are correct, not just the type.
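A hedged sketch of the kind of broader check being asked for, reusing the gA and y fixtures from the existing test file and comparing values against * rather than only checking the return type:

```julia
using LinearAlgebra, Test

# Sketch only: gA and y are the fixtures from test/gpu.jl; the vector-output,
# reshaped-array, and _isonehot(B) === false cases would follow the same pattern.
out = LinearAlgebra.mul!(similar(gA, 3, 3), gA, y)
@test out isa CuArray
@test collect(out) ≈ collect(gA * y)   # correct values, not just the right type
```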
Many of the tests in the GPU suite only check type, but I now have this comparing to * for the time being.
The branching here opens up a huge can of worms; fixing that is a non-goal of mine here, see #48.
Sure. But new code really has to test that it produces the correct value. (And enough cases that obvious method ambiguities would be found.)
test/gpu.jl
Outdated
@@ -26,8 +26,11 @@ end
 if VERSION >= v"1.9" && CUDA.functional()
     @test gradient(A -> sum(A * y), gA)[1] isa CuArray
 else
-    @test_broken gradient(A -> sum(A * y), gA)[1] isa CuArray # fails with JLArray, bug in Zygote?
+    @test gradient(A -> sum(A * y), gA)[1] isa CuArray # fails with JLArray, bug in Zygote?
Can you explain what's going on here?
No idea. I ran it, and CI/CD failed because that test passed. I assume whatever bug there was in Zygote got fixed.
The overall situation with matrix operations in this package is likely still incredibly wonky; ideally somebody would carefully think through exactly how these should be done and replace all of …
Also, am I correct in assuming that 1.6 is no longer relevant as it is no longer LTS? If so I can remove the 1.6 tests in favor of 1.10.
Alright, I've swapped out 1.6 for 1.10. That's as much as I was planning to fix here; it should no longer immediately fail for someone trying to use it with Lux.
Bumping
Codecov Report
All modified and coverable lines are covered by tests ✅

Additional details and impacted files:

@@            Coverage Diff             @@
##             main      #47      +/-   ##
==========================================
- Coverage   96.26%   95.07%   -1.20%
==========================================
  Files           3        3
  Lines         134      142       +8
==========================================
+ Hits          129      135       +6
- Misses          5        7       +2

☔ View full report in Codecov by Sentry.
Prior to this PR, this package did not define a method for LinearAlgebra.mul!. This low-level method is used by some procedures instead of *. Without the new method included in this PR, such low-level calls could break in some cases. In particular, some of the optimized matmul calls used by Lux.jl would error out with a method ambiguity or a scalar-indexing-of-GPU-array error when used with GPU arrays.

Ultimately the * methods should probably be removed, but I have left them alone for now to get this in without worrying too much about it breaking anything.

GPU tests pass for me on CUDA.
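As a rough CPU-side illustration (my own sketch, not code from the PR), the call pattern at issue looks like the following; on the GPU, the generic fallback that mul! otherwise hits is what produced the ambiguity and scalar-indexing errors:

```julia
# Hedged illustration: downstream code such as Lux.jl's optimized matmul
# paths may call mul! directly rather than going through *.
using LinearAlgebra, OneHotArrays

A = rand(Float32, 4, 3)
B = onehotbatch([1, 3, 2], 1:3)   # 3×3 OneHotMatrix
C = similar(A)                    # caller-preallocated output, as mul! expects
LinearAlgebra.mul!(C, A, B)       # without a dedicated method this goes to a generic fallback
C ≈ A * B                         # with a proper mul! method the two should agree
```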
PR Checklist