RFC: Use non-blocking device side pointer mode in CUBLAS, with fallbacks #2616
base: master
Conversation
I can also add some more …
Sample speedup:

julia> using CUDA, CUDA.CUBLAS, LinearAlgebra;
julia> n = Int(2^26);
julia> X = CUDA.rand(Float64, n);
julia> res = CuRef{Float64}(0.0);
# do some precompilation runs first
julia> @time CUBLAS.nrm2(n, X, res);
0.000104 seconds (18 allocations: 288 bytes)
julia> @time CUBLAS.nrm2(n, X);
0.001564 seconds (73 allocations: 3.094 KiB)
CUDA.jl Benchmarks
| Benchmark suite | Current: ff453f2 | Previous: 24c236a | Ratio |
|---|---|---|---|
| latency/precompile | 46524357433.5 ns | 46329110276 ns | 1.00 |
| latency/ttfp | 7011623206 ns | 6970088608 ns | 1.01 |
| latency/import | 3658724318 ns | 3628117061 ns | 1.01 |
| integration/volumerhs | 9624184 ns | 9622596.5 ns | 1.00 |
| integration/byval/slices=1 | 146839 ns | 147068 ns | 1.00 |
| integration/byval/slices=3 | 425274 ns | 425521 ns | 1.00 |
| integration/byval/reference | 145055 ns | 144904 ns | 1.00 |
| integration/byval/slices=2 | 286073 ns | 286164 ns | 1.00 |
| integration/cudadevrt | 103512 ns | 103422 ns | 1.00 |
| kernel/indexing | 14218 ns | 14042 ns | 1.01 |
| kernel/indexing_checked | 14924 ns | 14520 ns | 1.03 |
| kernel/occupancy | 642.6130952380952 ns | 633.7076023391813 ns | 1.01 |
| kernel/launch | 2012.9 ns | 2101.9 ns | 0.96 |
| kernel/rand | 14695 ns | 14379 ns | 1.02 |
| array/reverse/1d | 19517 ns | 19719 ns | 0.99 |
| array/reverse/2d | 25068 ns | 25121 ns | 1.00 |
| array/reverse/1d_inplace | 10827 ns | 11160 ns | 0.97 |
| array/reverse/2d_inplace | 12470 ns | 13070 ns | 0.95 |
| array/copy | 21068 ns | 21024 ns | 1.00 |
| array/iteration/findall/int | 156220 ns | 155786 ns | 1.00 |
| array/iteration/findall/bool | 135407 ns | 134662.5 ns | 1.01 |
| array/iteration/findfirst/int | 147578.5 ns | 147148 ns | 1.00 |
| array/iteration/findfirst/bool | 153305 ns | 154167.5 ns | 0.99 |
| array/iteration/scalar | 62796 ns | 61499 ns | 1.02 |
| array/iteration/logical | 203505 ns | 203811.5 ns | 1.00 |
| array/iteration/findmin/1d | 39812.5 ns | 39639 ns | 1.00 |
| array/iteration/findmin/2d | 94104 ns | 94387 ns | 1.00 |
| array/reductions/reduce/1d | 30590 ns | 30507 ns | 1.00 |
| array/reductions/reduce/2d | 40842 ns | 51213.5 ns | 0.80 |
| array/reductions/mapreduce/1d | 30022 ns | 30459 ns | 0.99 |
| array/reductions/mapreduce/2d | 40787 ns | 51487 ns | 0.79 |
| array/broadcast | 21196.5 ns | 20729 ns | 1.02 |
| array/copyto!/gpu_to_gpu | 13753 ns | 11560 ns | 1.19 |
| array/copyto!/cpu_to_gpu | 208123 ns | 208930 ns | 1.00 |
| array/copyto!/gpu_to_cpu | 242297 ns | 241934 ns | 1.00 |
| array/accumulate/1d | 108470 ns | 108332 ns | 1.00 |
| array/accumulate/2d | 80095 ns | 80060 ns | 1.00 |
| array/construct | 1263.65 ns | 1258.3 ns | 1.00 |
| array/random/randn/Float32 | 44284 ns | 42993 ns | 1.03 |
| array/random/randn!/Float32 | 26690 ns | 26226 ns | 1.02 |
| array/random/rand!/Int64 | 27179 ns | 26878 ns | 1.01 |
| array/random/rand!/Float32 | 8710.333333333334 ns | 8569 ns | 1.02 |
| array/random/rand/Int64 | 29989 ns | 29957 ns | 1.00 |
| array/random/rand/Float32 | 13191 ns | 12962 ns | 1.02 |
| array/permutedims/4d | 60897 ns | 61080 ns | 1.00 |
| array/permutedims/2d | 54885 ns | 55180 ns | 0.99 |
| array/permutedims/3d | 55969 ns | 55780 ns | 1.00 |
| array/sorting/1d | 2776673.5 ns | 2774685 ns | 1.00 |
| array/sorting/by | 3368203.5 ns | 3367411.5 ns | 1.00 |
| array/sorting/2d | 1085210 ns | 1085055 ns | 1.00 |
| cuda/synchronization/stream/auto | 1084 ns | 1048.8 ns | 1.03 |
| cuda/synchronization/stream/nonblocking | 6476.8 ns | 6455.4 ns | 1.00 |
| cuda/synchronization/stream/blocking | 849.6619718309859 ns | 846.986301369863 ns | 1.00 |
| cuda/synchronization/context/auto | 1153.2 ns | 1230.1 ns | 0.94 |
| cuda/synchronization/context/nonblocking | 6683.5 ns | 6670.8 ns | 1.00 |
| cuda/synchronization/context/blocking | 903.3125 ns | 928 ns | 0.97 |
This comment was automatically generated by workflow using github-action-benchmark.
Yeah, should one of us open an issue?
…On Sat, Jan 11, 2025 at 2:48 AM Tim Besard wrote:
In lib/cublas/wrappers.jl:
> function scal!(n::Integer, alpha::Number, x::StridedCuVecOrDenseMat{Float16})
- α = convert(Float32, alpha)
- cublasScalEx(handle(), n, Ref{Float32}(α), Float32, x, Float16, stride(x, 1), Float32)
+ α = CuRef{Float32}( convert(Float32, alpha) )
We should improve CuRef so that it can be constructed identically to Ref.
Ref{T}(x) doing an implicit convert is pretty convenient.
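For context, a minimal sketch of the convenience being asked for. Only the first two forms are the current behaviour; the last line is the hypothetical improvement, not the existing CuRef API:

using CUDA

alpha = 2                                        # scalar as a user might pass it

# Base.Ref converts its argument to the target type implicitly:
r = Ref{Float32}(alpha)                          # stores 2.0f0

# Today (per the diff above) the conversion must be spelled out for CuRef:
a_dev = CuRef{Float32}(convert(Float32, alpha))

# The suggestion: let CuRef{Float32}(alpha) do the convert itself, like Ref
# (hypothetical; would make the explicit convert above unnecessary).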
Is the test failure something I've done? Seems GPUArrays related.
Force-pushed from a0829fa to 5d52d10
OK, I think this is ready for review!
I am not qualified to review, but certainly interested in the outcome. Will the non-blocking methods only accept CuRef objects for the scalar input or output quantities, or also zero-dimensional arrays (i.e. CuArray{T,0})?
For now only CuRef, but these are easy to create (it's exported by CUDA.jl).
I think one can also create them without a copy from a regular CuArray?
You can create a …
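A sketch of the CuRef route, using only calls that already appear in this thread; whether a no-copy reference into an existing CuArray also works is the open question above:

using CUDA, CUDA.CUBLAS

X = CUDA.rand(Float64, 2^20)

# Allocate a fresh device-resident scalar to receive the result:
res = CuRef{Float64}(0.0)

# Non-blocking: nrm2 returns immediately; the result lands in res on the GPU.
CUBLAS.nrm2(length(X), X, res)
synchronize()                      # wait before reading res back on the host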
I wonder if we should also improve CuRef to initialize its memory by calling fill instead of memcpy: when calling memcpy, the copy likely won't be truly asynchronous (that would require pinned memory). But if we call fill, which should be possible for most scalars, the argument is passed by value and I think the call will complete asynchronously.
Something to investigate!
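A sketch of the distinction, using plain CuArray storage to stand in for CuRef's backing memory (illustrative only, not CuRef's actual implementation):

using CUDA

buf = CuArray{Float64}(undef, 1)   # device memory backing a scalar "ref"

# memcpy-style initialization: the value is staged through host memory; with
# pageable (unpinned) host memory the upload can't be fully asynchronous.
copyto!(buf, [3.14])

# fill-style initialization: the value travels by value as a kernel argument,
# so the call can complete asynchronously on the stream.
fill!(buf, 3.14)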
Force-pushed from 829083e to fd59678
Your PR requires formatting changes to meet the project's style guidelines. Suggested changes:

diff --git a/lib/cublas/wrappers.jl b/lib/cublas/wrappers.jl
index 43ddfeaea..a5eb81c92 100644
--- a/lib/cublas/wrappers.jl
+++ b/lib/cublas/wrappers.jl
@@ -115,7 +115,7 @@ for (fname, fname_64, elty) in ((:cublasDscal_v2, :cublasDscal_v2_64, :Float64),
(:cublasCscal_v2, :cublasCscal_v2_64, :ComplexF32))
@eval begin
function scal!(n::Integer,
- alpha::Ref{$elty},
+ alpha::Ref{$elty},
x::StridedCuVecOrDenseMat{$elty})
if CUBLAS.version() >= v"12.0"
$fname_64(handle(), n, alpha, x, stride(x, 1))
@@ -147,7 +147,7 @@ for (fname, fname_64, elty, celty) in ((:cublasCsscal_v2, :cublasCsscal_v2_64, :
(:cublasZdscal_v2, :cublasZdscal_v2_64, :Float64, :ComplexF64))
@eval begin
function scal!(n::Integer,
- alpha::Ref{$elty},
+ alpha::Ref{$elty},
x::StridedCuVecOrDenseMat{$celty})
if CUBLAS.version() >= v"12.0"
$fname_64(handle(), n, alpha, x, stride(x, 1))
@@ -190,8 +190,8 @@ for (jname, fname, fname_64, elty) in ((:dot, :cublasDdot_v2, :cublasDdot_v2_64,
@eval begin
function $jname(n::Integer,
x::StridedCuVecOrDenseMat{$elty},
- y::StridedCuVecOrDenseMat{$elty},
- result::Ref{$elty},
+ y::StridedCuVecOrDenseMat{$elty},
+ result::Ref{$elty},
)
if CUBLAS.version() >= v"12.0"
$fname_64(handle(), n, x, stride(x, 1), y, stride(y, 1), result)
@@ -339,7 +339,7 @@ for (fname, fname_64, elty) in ((:cublasDaxpy_v2, :cublasDaxpy_v2_64, :Float64),
(:cublasCaxpy_v2, :cublasCaxpy_v2_64, :ComplexF32))
@eval begin
function axpy!(n::Integer,
- alpha::Ref{$elty},
+ alpha::Ref{$elty},
dx::StridedCuVecOrDenseMat{$elty},
dy::StridedCuVecOrDenseMat{$elty})
if CUBLAS.version() >= v"12.0"
@@ -400,9 +400,9 @@ for (fname, fname_64, elty, cty, sty) in (
function rot!(n::Integer,
x::StridedCuVecOrDenseMat{$elty},
y::StridedCuVecOrDenseMat{$elty},
- c::Ref{$cty},
- s::Ref{$sty},
- )
+ c::Ref{$cty},
+ s::Ref{$sty},
+ )
if CUBLAS.version() >= v"12.0"
$fname_64(handle(), n, x, stride(x, 1), y, stride(y, 1), c, s)
else
@@ -473,9 +473,9 @@ for (fname, fname_64, elty) in ((:cublasIdamax_v2, :cublasIdamax_v2_64, :Float64
(:cublasIcamax_v2, :cublasIcamax_v2_64, :ComplexF32))
@eval begin
function iamax(n::Integer,
- dx::StridedCuVecOrDenseMat{$elty},
- result::Ref{Ti},
- ) where {Ti <: Integer}
+ dx::StridedCuVecOrDenseMat{$elty},
+ result::Ref{Ti},
+ ) where {Ti <: Integer}
if CUBLAS.version() >= v"12.0"
$fname_64(handle(), n, dx, stride(dx, 1), result)
else
@@ -494,9 +494,9 @@ for (fname, fname_64, elty) in ((:cublasIdamin_v2, :cublasIdamin_v2_64, :Float64
(:cublasIcamin_v2, :cublasIcamin_v2_64, :ComplexF32))
@eval begin
function iamin(n::Integer,
- dx::StridedCuVecOrDenseMat{$elty},
- result::Ref{Ti},
- ) where {Ti <: Integer}
+ dx::StridedCuVecOrDenseMat{$elty},
+ result::Ref{Ti},
+ ) where {Ti <: Integer}
if CUBLAS.version() >= v"12.0"
$fname_64(handle(), n, dx, stride(dx, 1), result)
else
@@ -530,10 +530,10 @@ for (fname, fname_64, elty) in ((:cublasDgemv_v2, :cublasDgemv_v2_64, :Float64),
(:cublasCgemv_v2, :cublasCgemv_v2_64, :ComplexF32))
@eval begin
function gemv!(trans::Char,
- alpha::Ref{$elty},
+ alpha::Ref{$elty},
A::StridedCuMatrix{$elty},
x::StridedCuVector{$elty},
- beta::Ref{$elty},
+ beta::Ref{$elty},
y::StridedCuVector{$elty})
# handle trans
m,n = size(A)
@@ -562,7 +562,7 @@ end
function gemv(trans::Char, alpha::Ref{T}, A::StridedCuMatrix{T}, x::StridedCuVector{T}) where {T}
return gemv!(trans, alpha, A, x, CuRef{T}(zero(T)), similar(x, size(A, (trans == 'N' ? 1 : 2))))
end
-function gemv(trans::Char, alpha::Number, A::StridedCuMatrix{T}, x::StridedCuVector{T}) where T
+function gemv(trans::Char, alpha::Number, A::StridedCuMatrix{T}, x::StridedCuVector{T}) where {T}
gemv!(trans, alpha, A, x, zero(T), similar(x, size(A, (trans == 'N' ? 1 : 2))))
end
# should this be async?
@@ -580,12 +580,12 @@ for (fname, fname_64, eltyin, eltyout, eltyconst) in (
)
@eval begin
function gemv_batched!(trans::Char,
- alpha::Ref{$eltyconst},
- A::Vector{<:StridedCuMatrix{$eltyin}},
- x::Vector{<:StridedCuVector{$eltyin}},
- beta::Ref{$eltyconst},
- y::Vector{<:StridedCuVector{$eltyout}}
- )
+ alpha::Ref{$eltyconst},
+ A::Vector{<:StridedCuMatrix{$eltyin}},
+ x::Vector{<:StridedCuVector{$eltyin}},
+ beta::Ref{$eltyconst},
+ y::Vector{<:StridedCuVector{$eltyout}}
+ )
if length(A) != length(x) || length(A) != length(y)
throw(DimensionMismatch("Lengths of inputs must be the same"))
end
@@ -616,13 +616,13 @@ for (fname, fname_64, eltyin, eltyout, eltyconst) in (
y
end
function gemv_batched!(
- trans::Char,
- alpha::Number,
- A::Vector{<:StridedCuMatrix{$eltyin}},
- x::Vector{<:StridedCuVector{$eltyin}},
- beta::Number,
- y::Vector{<:StridedCuVector{$eltyout}}
- )
+ trans::Char,
+ alpha::Number,
+ A::Vector{<:StridedCuMatrix{$eltyin}},
+ x::Vector{<:StridedCuVector{$eltyin}},
+ beta::Number,
+ y::Vector{<:StridedCuVector{$eltyout}}
+ )
gpu_α = CuRef{$eltyconst}(alpha)
gpu_β = CuRef{$eltyconst}(beta)
y = gemv_batched!(trans, gpu_α, A, x, gpu_β, y)
@@ -642,12 +642,12 @@ for (fname, fname_64, eltyin, eltyout, eltyconst) in (
)
@eval begin
function gemv_strided_batched!(trans::Char,
- alpha::Ref{$eltyconst},
- A::AbstractArray{$eltyin, 3},
- x::AbstractArray{$eltyin, 2},
- beta::Ref{$eltyconst},
- y::AbstractArray{$eltyout, 2}
- )
+ alpha::Ref{$eltyconst},
+ A::AbstractArray{$eltyin, 3},
+ x::AbstractArray{$eltyin, 2},
+ beta::Ref{$eltyconst},
+ y::AbstractArray{$eltyout, 2}
+ )
if size(A, 3) != size(x, 2) || size(A, 3) != size(y, 2)
throw(DimensionMismatch("Batch sizes must be equal for all inputs"))
end
@@ -672,13 +672,13 @@ for (fname, fname_64, eltyin, eltyout, eltyconst) in (
y
end
function gemv_strided_batched!(
- trans::Char,
- alpha::Number,
- A::AbstractArray{$eltyin, 3},
- x::AbstractArray{$eltyin, 2},
- beta::Number,
- y::AbstractArray{$eltyout, 2}
- )
+ trans::Char,
+ alpha::Number,
+ A::AbstractArray{$eltyin, 3},
+ x::AbstractArray{$eltyin, 2},
+ beta::Number,
+ y::AbstractArray{$eltyout, 2}
+ )
gpu_α = CuRef{$eltyconst}(alpha)
gpu_β = CuRef{$eltyconst}(beta)
y = gemv_strided_batched!(trans, gpu_α, A, x, gpu_β, y)
@@ -698,10 +698,10 @@ for (fname, fname_64, elty) in ((:cublasDgbmv_v2, :cublasDgbmv_v2_64, :Float64),
m::Integer,
kl::Integer,
ku::Integer,
- alpha::Ref{$elty},
+ alpha::Ref{$elty},
A::StridedCuMatrix{$elty},
x::StridedCuVector{$elty},
- beta::Ref{$elty},
+ beta::Ref{$elty},
y::StridedCuVector{$elty})
n = size(A,2)
# check dimensions
@@ -717,16 +717,17 @@ for (fname, fname_64, elty) in ((:cublasDgbmv_v2, :cublasDgbmv_v2_64, :Float64),
end
y
end
- function gbmv!(trans::Char,
- m::Integer,
- kl::Integer,
- ku::Integer,
- alpha::Number,
- A::StridedCuMatrix{$elty},
- x::StridedCuVector{$elty},
- beta::Number,
- y::StridedCuVector{$elty}
- )
+ function gbmv!(
+ trans::Char,
+ m::Integer,
+ kl::Integer,
+ ku::Integer,
+ alpha::Number,
+ A::StridedCuMatrix{$elty},
+ x::StridedCuVector{$elty},
+ beta::Number,
+ y::StridedCuVector{$elty}
+ )
gpu_α = CuRef{$elty}(alpha)
gpu_β = CuRef{$elty}(beta)
@@ -736,8 +737,10 @@ for (fname, fname_64, elty) in ((:cublasDgbmv_v2, :cublasDgbmv_v2_64, :Float64),
end
end
end
-function gbmv(trans::Char, m::Integer, kl::Integer, ku::Integer, alpha::Ref{T},
- A::StridedCuMatrix{T}, x::StridedCuVector{T}) where {T}
+function gbmv(
+ trans::Char, m::Integer, kl::Integer, ku::Integer, alpha::Ref{T},
+ A::StridedCuMatrix{T}, x::StridedCuVector{T}
+ ) where {T}
# TODO: fix gbmv bug in julia
n = size(A, 2)
leny = trans == 'N' ? m : n
@@ -760,10 +763,10 @@ for (fname, fname_64, elty) in ((:cublasDspmv_v2, :cublasDspmv_v2_64, :Float64),
(:cublasSspmv_v2, :cublasSspmv_v2_64, :Float32))
@eval begin
function spmv!(uplo::Char,
- alpha::Ref{$elty},
+ alpha::Ref{$elty},
AP::StridedCuVector{$elty},
x::StridedCuVector{$elty},
- beta::Ref{$elty},
+ beta::Ref{$elty},
y::StridedCuVector{$elty})
n = round(Int, (sqrt(8*length(AP))-1)/2)
if n != length(x) || n != length(y) throw(DimensionMismatch("")) end
@@ -778,21 +781,24 @@ for (fname, fname_64, elty) in ((:cublasDspmv_v2, :cublasDspmv_v2_64, :Float64),
end
end
end
-function spmv!(uplo::Char,
- alpha::Number,
- AP::StridedCuVector{T},
- x::StridedCuVector{T},
- beta::Number,
- y::StridedCuVector{T}
- ) where {T}
+function spmv!(
+ uplo::Char,
+ alpha::Number,
+ AP::StridedCuVector{T},
+ x::StridedCuVector{T},
+ beta::Number,
+ y::StridedCuVector{T}
+ ) where {T}
gpu_α = CuRef{T}(alpha)
gpu_β = CuRef{T}(beta)
y = spmv!(uplo, gpu_α, AP, x, gpu_β, y)
synchronize()
return y
end
-function spmv(uplo::Char, alpha::Ref{T},
- AP::StridedCuVector{T}, x::StridedCuVector{T}) where {T}
+function spmv(
+ uplo::Char, alpha::Ref{T},
+ AP::StridedCuVector{T}, x::StridedCuVector{T}
+ ) where {T}
return spmv!(uplo, alpha, AP, x, CuRef{T}(zero(T)), similar(x))
end
function spmv(uplo::Char, alpha::Number,
@@ -811,10 +817,10 @@ for (fname, fname_64, elty) in ((:cublasDsymv_v2, :cublasDsymv_v2_64, :Float64),
# Note that the complex symv are not BLAS but auiliary functions in LAPACK
@eval begin
function symv!(uplo::Char,
- alpha::Ref{$elty},
+ alpha::Ref{$elty},
A::StridedCuMatrix{$elty},
x::StridedCuVector{$elty},
- beta::Ref{$elty},
+ beta::Ref{$elty},
y::StridedCuVector{$elty})
m, n = size(A)
if m != n throw(DimensionMismatch("Matrix A is $m by $n but must be square")) end
@@ -865,10 +871,10 @@ for (fname, fname_64, elty) in ((:cublasZhemv_v2, :cublasZhemv_v2_64, :ComplexF6
(:cublasChemv_v2, :cublasChemv_v2_64, :ComplexF32))
@eval begin
function hemv!(uplo::Char,
- alpha::Ref{$elty},
+ alpha::Ref{$elty},
A::StridedCuMatrix{$elty},
x::StridedCuVector{$elty},
- beta::Ref{$elty},
+ beta::Ref{$elty},
y::StridedCuVector{$elty})
# TODO: fix dimension check bug in julia
m, n = size(A)
@@ -923,10 +929,10 @@ for (fname, fname_64, elty) in ((:cublasDsbmv_v2, :cublasDsbmv_v2_64, :Float64),
@eval begin
function sbmv!(uplo::Char,
k::Integer,
- alpha::Ref{$elty},
+ alpha::Ref{$elty},
A::StridedCuMatrix{$elty},
x::StridedCuVector{$elty},
- beta::Ref{$elty},
+ beta::Ref{$elty},
y::StridedCuVector{$elty})
m, n = size(A)
#if m != n throw(DimensionMismatch("Matrix A is $m by $n but must be square")) end
@@ -982,10 +988,10 @@ for (fname, fname_64, elty) in ((:cublasZhbmv_v2, :cublasZhbmv_v2_64, :ComplexF6
@eval begin
function hbmv!(uplo::Char,
k::Integer,
- alpha::Ref{$elty},
+ alpha::Ref{$elty},
A::StridedCuMatrix{$elty},
x::StridedCuVector{$elty},
- beta::Ref{$elty},
+ beta::Ref{$elty},
y::StridedCuVector{$elty})
m, n = size(A)
if !(1<=(1+k)<=n) throw(DimensionMismatch("Incorrect number of bands")) end
@@ -1169,7 +1175,7 @@ for (fname, fname_64, elty) in ((:cublasDger_v2, :cublasDger_v2_64, :Float64),
(:cublasCgerc_v2, :cublasCgerc_v2_64, :ComplexF32))
@eval begin
function ger!(
- alpha::Ref{$elty},
+ alpha::Ref{$elty},
x::StridedCuVector{$elty},
y::StridedCuVector{$elty},
A::StridedCuMatrix{$elty})
@@ -1205,7 +1211,7 @@ for (fname, fname_64, elty) in ((:cublasDspr_v2, :cublasDspr_v2_64, :Float64),
(:cublasSspr_v2, :cublasSspr_v2_64, :Float32))
@eval begin
function spr!(uplo::Char,
- alpha::Ref{$elty},
+ alpha::Ref{$elty},
x::StridedCuVector{$elty},
AP::StridedCuVector{$elty})
n = round(Int, (sqrt(8*length(AP))-1)/2)
@@ -1239,7 +1245,7 @@ for (fname, fname_64, elty) in ((:cublasDsyr_v2, :cublasDsyr_v2_64, :Float64),
(:cublasCsyr_v2, :cublasCsyr_v2_64, :ComplexF32))
@eval begin
function syr!(uplo::Char,
- alpha::Ref{$elty},
+ alpha::Ref{$elty},
x::StridedCuVector{$elty},
A::StridedCuMatrix{$elty})
m, n = size(A)
@@ -1275,7 +1281,7 @@ for (fname, fname_64, elty, relty) in (
)
@eval begin
function her!(uplo::Char,
- alpha::Ref{$relty},
+ alpha::Ref{$relty},
x::StridedCuVector{$elty},
A::StridedCuMatrix{$elty})
m, n = size(A)
@@ -1309,11 +1315,11 @@ for (fname, fname_64, elty) in ((:cublasZher2_v2, :cublasZher2_v2_64, :ComplexF6
(:cublasCher2_v2, :cublasCher2_v2_64, :ComplexF32))
@eval begin
function her2!(uplo::Char,
- alpha::Ref{$elty},
- x::StridedCuVector{$elty},
- y::StridedCuVector{$elty},
- A::StridedCuMatrix{$elty}
- )
+ alpha::Ref{$elty},
+ x::StridedCuVector{$elty},
+ y::StridedCuVector{$elty},
+ A::StridedCuMatrix{$elty}
+ )
m, n = size(A)
m == n || throw(DimensionMismatch("Matrix A is $m by $n but must be square"))
length(x) == n || throw(DimensionMismatch("Length of vector must be the same as the matrix dimensions"))
@@ -1353,10 +1359,10 @@ for (fname, fname_64, elty) in ((:cublasDgemm_v2, :cublasDgemm_v2_64, :Float64),
@eval begin
function gemm!(transA::Char,
transB::Char,
- alpha::Ref{$elty},
+ alpha::Ref{$elty},
A::StridedCuVecOrMat{$elty},
B::StridedCuVecOrMat{$elty},
- beta::Ref{$elty},
+ beta::Ref{$elty},
C::StridedCuVecOrMat{$elty})
m = size(A, transA == 'N' ? 1 : 2)
k = size(A, transA == 'N' ? 2 : 1)
@@ -1494,10 +1500,10 @@ function gemmExComputeType(TA, TB, TC, m, k, n)
end
function gemmEx!(transA::Char, transB::Char,
- @nospecialize(alpha::Ref),
+ @nospecialize(alpha::Ref),
@nospecialize(A::StridedCuVecOrMat),
@nospecialize(B::StridedCuVecOrMat),
- @nospecialize(beta::Ref),
+ @nospecialize(beta::Ref),
@nospecialize(C::StridedCuVecOrMat);
algo::cublasGemmAlgo_t=CUBLAS_GEMM_DEFAULT)
m = size(A, transA == 'N' ? 1 : 2)
@@ -1552,10 +1558,10 @@ end
# TODO for device mode pointers
function gemmBatchedEx!(transA::Char, transB::Char,
- @nospecialize(alpha::Ref),
+ @nospecialize(alpha::Ref),
@nospecialize(A::Vector{<:StridedCuVecOrMat}),
@nospecialize(B::Vector{<:StridedCuVecOrMat}),
- @nospecialize(beta::Ref),
+ @nospecialize(beta::Ref),
@nospecialize(C::Vector{<:StridedCuVecOrMat});
algo::cublasGemmAlgo_t=CUBLAS_GEMM_DEFAULT)
if length(A) != length(B) || length(A) != length(C)
@@ -1623,11 +1629,11 @@ function gemmBatchedEx!(
end
function gemmStridedBatchedEx!(
- transA::Char, transB::Char,
- @nospecialize(alpha::Ref),
+ transA::Char, transB::Char,
+ @nospecialize(alpha::Ref),
@nospecialize(A::AbstractArray{Ta, 3}),
@nospecialize(B::AbstractArray{Tb, 3}),
- @nospecialize(beta::Ref),
+ @nospecialize(beta::Ref),
@nospecialize(C::AbstractArray{Tc, 3});
algo::cublasGemmAlgo_t=CUBLAS_GEMM_DEFAULT) where {Ta, Tb, Tc}
if size(A, 3) != size(B, 3) || size(A, 3) != size(C, 3)
@@ -1866,10 +1872,10 @@ for (fname, fname_64, elty) in ((:cublasDgemmBatched, :cublasDgemmBatched_64, :F
@eval begin
function gemm_batched!(transA::Char,
transB::Char,
- alpha::Ref{$elty},
+ alpha::Ref{$elty},
A::Vector{<:StridedCuMatrix{$elty}},
B::Vector{<:StridedCuMatrix{$elty}},
- beta::Ref{$elty},
+ beta::Ref{$elty},
C::Vector{<:StridedCuMatrix{$elty}})
if length(A) != length(B) || length(A) != length(C)
throw(DimensionMismatch(""))
@@ -1949,10 +1955,10 @@ for (fname, fname_64, elty) in ((:cublasDgemmStridedBatched, :cublasDgemmStrided
@eval begin
function gemm_strided_batched!(transA::Char,
transB::Char,
- alpha::Ref{$elty},
+ alpha::Ref{$elty},
A::AbstractArray{$elty, 3}, # allow PermutedDimsArray
B::AbstractArray{$elty, 3},
- beta::Ref{$elty},
+ beta::Ref{$elty},
C::AbstractArray{$elty, 3})
m = size(A, transA == 'N' ? 1 : 2)
k = size(A, transA == 'N' ? 2 : 1)
@@ -2032,10 +2038,10 @@ for (fname, fname_64, elty) in ((:cublasDsymm_v2, :cublasDsymm_v2_64, :Float64),
@eval begin
function symm!(side::Char,
uplo::Char,
- alpha::Ref{$elty},
+ alpha::Ref{$elty},
A::StridedCuMatrix{$elty},
B::StridedCuMatrix{$elty},
- beta::Ref{$elty},
+ beta::Ref{$elty},
C::StridedCuMatrix{$elty})
k, nA = size(A)
if k != nA throw(DimensionMismatch("Matrix A must be square")) end
@@ -2094,9 +2100,9 @@ for (fname, fname_64, elty) in ((:cublasDsyrk_v2, :cublasDsyrk_v2_64, :Float64),
@eval begin
function syrk!(uplo::Char,
trans::Char,
- alpha::Ref{$elty},
+ alpha::Ref{$elty},
A::StridedCuVecOrMat{$elty},
- beta::Ref{$elty},
+ beta::Ref{$elty},
C::StridedCuMatrix{$elty})
mC, n = size(C)
if mC != n throw(DimensionMismatch("C must be square")) end
@@ -2147,10 +2153,10 @@ for (fname, fname_64, elty) in ((:cublasDsyrkx, :cublasDsyrkx_64, :Float64),
@eval begin
function syrkx!(uplo::Char,
trans::Char,
- alpha::Ref{$elty},
+ alpha::Ref{$elty},
A::StridedCuVecOrMat{$elty},
B::StridedCuVecOrMat{$elty},
- beta::Ref{$elty},
+ beta::Ref{$elty},
C::StridedCuMatrix{$elty})
mC, n = size(C)
if mC != n throw(DimensionMismatch("C must be square")) end
@@ -2206,10 +2212,10 @@ for (fname, fname_64, elty) in ((:cublasZhemm_v2, :cublasZhemm_v2_64, :ComplexF6
@eval begin
function hemm!(side::Char,
uplo::Char,
- alpha::Ref{$elty},
+ alpha::Ref{$elty},
A::StridedCuMatrix{$elty},
B::StridedCuMatrix{$elty},
- beta::Ref{$elty},
+ beta::Ref{$elty},
C::StridedCuMatrix{$elty})
mA, nA = size(A)
m, n = size(B)
@@ -2269,9 +2275,9 @@ for (fname, fname_64, elty, relty) in (
@eval begin
function herk!(uplo::Char,
trans::Char,
- alpha::Ref{$relty},
+ alpha::Ref{$relty},
A::StridedCuVecOrMat{$elty},
- beta::Ref{$relty},
+ beta::Ref{$relty},
C::StridedCuMatrix{$elty})
mC, n = size(C)
if mC != n throw(DimensionMismatch("C must be square")) end
@@ -2328,10 +2334,10 @@ for (fname, fname_64, elty) in ((:cublasDsyr2k_v2, :cublasDsyr2k_v2_64, :Float64
@eval begin
function syr2k!(uplo::Char,
trans::Char,
- alpha::Ref{$elty},
+ alpha::Ref{$elty},
A::StridedCuVecOrMat{$elty},
B::StridedCuVecOrMat{$elty},
- beta::Ref{$elty},
+ beta::Ref{$elty},
C::StridedCuMatrix{$elty})
# TODO: check size of B in julia (syr2k!)
m, n = size(C)
@@ -2387,7 +2393,7 @@ function syr2k(uplo::Char,
B::StridedCuVecOrMat)
T = eltype(A)
n = size(A, trans == 'N' ? 1 : 2)
- syr2k!(uplo, trans, convert(T, alpha), A, B, zero(T), similar(A, T, (n, n)))
+ return syr2k!(uplo, trans, convert(T, alpha), A, B, zero(T), similar(A, T, (n, n)))
end
function syr2k(uplo::Char, trans::Char, A::StridedCuVecOrMat, B::StridedCuVecOrMat)
syr2k(uplo, trans, one(eltype(A)), A, B)
@@ -2401,10 +2407,10 @@ for (fname, fname_64, elty, relty) in (
@eval begin
function her2k!(uplo::Char,
trans::Char,
- alpha::Ref{$elty},
+ alpha::Ref{$elty},
A::StridedCuVecOrMat{$elty},
B::StridedCuVecOrMat{$elty},
- beta::Ref{$relty},
+ beta::Ref{$relty},
C::StridedCuMatrix{$elty})
# TODO: check size of B in julia (her2k!)
m, n = size(C)
@@ -2478,7 +2484,7 @@ for (mmname, smname, elty) in
uplo::Char,
transa::Char,
diag::Char,
- alpha::Ref{$elty},
+ alpha::Ref{$elty},
A::StridedCuMatrix{$elty},
B::StridedCuMatrix{$elty},
C::StridedCuMatrix{$elty})
@@ -2500,7 +2506,7 @@ for (mmname, smname, elty) in
uplo::Char,
transa::Char,
diag::Char,
- alpha::Ref{$elty},
+ alpha::Ref{$elty},
A::StridedCuMatrix{$elty},
B::StridedCuMatrix{$elty})
m, n = size(B)
@@ -2565,7 +2571,7 @@ for (fname, fname_64, elty) in ((:cublasDtrsmBatched, :cublasDtrsmBatched_64, :F
uplo::Char,
transa::Char,
diag::Char,
- alpha::Ref{$elty},
+ alpha::Ref{$elty},
A::Vector{<:StridedCuMatrix{$elty}},
B::Vector{<:StridedCuMatrix{$elty}})
if length(A) != length(B)
@@ -2621,9 +2627,9 @@ for (fname, fname_64, elty) in ((:cublasDgeam, :cublasDgeam_64, :Float64),
@eval begin
function geam!(transa::Char,
transb::Char,
- alpha::Ref{$elty},
+ alpha::Ref{$elty},
A::StridedCuMatrix{$elty},
- beta::Ref{$elty},
+ beta::Ref{$elty},
B::StridedCuMatrix{$elty},
C::StridedCuMatrix{$elty})
mA, nA = size(A)
@@ -2861,8 +2867,9 @@ for (fname, elty) in ((:cublasDgetriBatched, :Float64),
end
function getri_batched!(n, Aptrs::CuVector{CuPtr{$elty}},
- lda, Cptrs::CuVector{CuPtr{$elty}},ldc,
- pivotArray::CuArray{Cint})
+ lda, Cptrs::CuVector{CuPtr{$elty}}, ldc,
+ pivotArray::CuArray{Cint}
+ )
batchSize = length(Aptrs)
info = CuArray{Cint}(undef, batchSize)
$fname(handle(), n, Aptrs, lda, pivotArray, Cptrs, ldc, info, batchSize)
diff --git a/test/libraries/cublas/level3.jl b/test/libraries/cublas/level3.jl
index 65b71c6a4..caf245a5b 100644
--- a/test/libraries/cublas/level3.jl
+++ b/test/libraries/cublas/level3.jl
@@ -352,12 +352,12 @@ k = 13
@testset "herk!" begin
alpha = rand(elty)
beta = rand(elty)
- A = rand(elty,m,m)
+ A = rand(elty, m, m)
hA = A + A'
d_A = CuArray(A)
d_C = CuArray(hA)
- CUBLAS.herk!('U','N',real(alpha),d_A,real(beta),d_C)
- C = real(alpha)*(A*A') + real(beta)*hA
+ CUBLAS.herk!('U', 'N', real(alpha), d_A, real(beta), d_C)
+ C = real(alpha) * (A * A') + real(beta) * hA
C = triu(C)
# move to host and compare
h_C = Array(d_C)
@@ -365,10 +365,10 @@ k = 13
@test C ≈ h_C
end
@testset "herk" begin
- A = rand(elty,m,m)
+ A = rand(elty, m, m)
d_A = CuArray(A)
- d_C = CUBLAS.herk('U','N',d_A)
- C = A*A'
+ d_C = CUBLAS.herk('U', 'N', d_A)
+ C = A * A'
C = triu(C)
# move to host and compare
h_C = Array(d_C)
diff --git a/test/libraries/cublas/level3_gemm.jl b/test/libraries/cublas/level3_gemm.jl
index 6e04e8c42..6cccc2967 100644
--- a/test/libraries/cublas/level3_gemm.jl
+++ b/test/libraries/cublas/level3_gemm.jl
@@ -220,12 +220,12 @@ k = 13
sB = sB + transpose(sB)
for (TRa, ta, TRb, tb, TRc, a_func, b_func) in (
- (UpperTriangular, identity, LowerTriangular, identity, Matrix, triu, tril),
- (LowerTriangular, identity, UpperTriangular, identity, Matrix, tril, triu),
- (UpperTriangular, identity, UpperTriangular, transpose, Matrix, triu, triu),
- (UpperTriangular, transpose, UpperTriangular, identity, Matrix, triu, triu),
- (LowerTriangular, identity, LowerTriangular, transpose, Matrix, tril, tril),
- (LowerTriangular, transpose, LowerTriangular, identity, Matrix, tril, tril),
+ (UpperTriangular, identity, LowerTriangular, identity, Matrix, triu, tril),
+ (LowerTriangular, identity, UpperTriangular, identity, Matrix, tril, triu),
+ (UpperTriangular, identity, UpperTriangular, transpose, Matrix, triu, triu),
+ (UpperTriangular, transpose, UpperTriangular, identity, Matrix, triu, triu),
+ (LowerTriangular, identity, LowerTriangular, transpose, Matrix, tril, tril),
+ (LowerTriangular, transpose, LowerTriangular, identity, Matrix, tril, tril),
)
A = copy(sA) |> TRa
CI failures seem relevant. Feel free to ignore the formatter; I made it less spammy 😉
Force-pushed from fd59678 to a2dedad
I really do not know what is up with the 1.11 failure, it looks …
Rebase to get rid of CI failures? |
Yep, next on my to do list
Force-pushed from 804a967 to bcd41c1
Gotta admit I'm a bit mystified here, as I cannot reproduce these. If I run only the …
Force-pushed from bcd41c1 to abf0200
Force-pushed from abf0200 to 1f7fb89
We could simplify the implementation here by switching the CUBLAS signatures from alpha::Ref{$elty} to an untyped alpha:

--- a/lib/cublas/wrappers.jl
+++ b/lib/cublas/wrappers.jl
@@ -530,10 +530,10 @@ for (fname, fname_64, elty) in ((:cublasDgemv_v2, :cublasDgemv_v2_64, :Float64),
(:cublasCgemv_v2, :cublasCgemv_v2_64, :ComplexF32))
@eval begin
function gemv!(trans::Char,
- alpha::Ref{$elty},
+ alpha,
A::StridedCuMatrix{$elty},
x::StridedCuVector{$elty},
- beta::Ref{$elty},
+ beta,
y::StridedCuVector{$elty})
# handle trans
m,n = size(A)
@@ -552,23 +552,10 @@ for (fname, fname_64, elty) in ((:cublasDgemv_v2, :cublasDgemv_v2_64, :Float64),
end
end
end
-function gemv!(trans::Char, alpha::Number, A::StridedCuMatrix{T}, x::StridedCuVector{T}, beta::Number, y::StridedCuVector{T}) where {T}
- gpu_α = CuRef{T}(alpha)
- gpu_β = CuRef{T}(beta)
- y = gemv!(trans, gpu_α, A, x, gpu_β, y)
- synchronize()
- return y
-end
-function gemv(trans::Char, alpha::Ref{T}, A::StridedCuMatrix{T}, x::StridedCuVector{T}) where {T}
- return gemv!(trans, alpha, A, x, CuRef{T}(zero(T)), similar(x, size(A, (trans == 'N' ? 1 : 2))))
-end
-function gemv(trans::Char, alpha::Number, A::StridedCuMatrix{T}, x::StridedCuVector{T}) where T
+gemv(trans::Char, alpha, A::StridedCuMatrix{T}, x::StridedCuVector{T}) where {T} =
gemv!(trans, alpha, A, x, zero(T), similar(x, size(A, (trans == 'N' ? 1 : 2))))
-end
-# should this be async?
-function gemv(trans::Char, A::StridedCuMatrix{T}, x::StridedCuVector{T}) where T
+gemv(trans::Char, A::StridedCuMatrix{T}, x::StridedCuVector{T}) where {T} =
gemv!(trans, one(T), A, x, zero(T), similar(x, T, size(A, (trans == 'N' ? 1 : 2))))
-end
for (fname, fname_64, eltyin, eltyout, eltyconst) in (
(:cublasDgemvBatched, :cublasDgemvBatched_64, :Float64, :Float64, :Float64),

…etc. for the others. Thoughts?
Yeah, that makes sense. I was biased towards keeping the changes more "conservative" and perhaps changing …
In the latest commit, I applied @maleadt's suggestions widely. This really shortens the diff, and tests passed locally. @Jutho, this should also allow your use-case of passing …
Attempting to address #2571

I've set the pointer mode to "device side" during handle creation. Since gemmGroupedBatched doesn't support device side pointer mode, it won't be usable. One workaround for this would be to add a new function to create a handle with host side mode, or add the pointer mode as an optional kwarg to handle(). Very open to feedback on this.

I've set this up so that users can supply CuRefs of the appropriate result type to the level 1 functions for results. If that's not provided, the functions execute as they do today (synchronously). Similarly, for functions taking alpha or beta scalar arguments, if the user provides a CuRef (actually a CuRefArray), the functions will execute asynchronously and return instantly. If the user provides a Number, the behaviour is unchanged from today. I'm not married to this design and it can certainly be changed.

cc @Jutho
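A usage sketch of the proposed split behaviour (values here are arbitrary; the scal! signatures follow the wrappers in this PR's diff):

using CUDA, CUDA.CUBLAS

x = CUDA.rand(Float32, 1024)

# Number scalar: behaviour unchanged from today; the call synchronizes.
CUBLAS.scal!(length(x), 2.0f0, x)

# CuRef scalar: device-side pointer mode; the call returns immediately.
alpha = CuRef{Float32}(2.0f0)
CUBLAS.scal!(length(x), alpha, x)
synchronize()                      # wait only when the result is actually needed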