
RFC: Use non-blocking device side pointer mode in CUBLAS, with fallbacks #2616

Open · wants to merge 15 commits into base: master
Conversation

kshyatt
Contributor

@kshyatt kshyatt commented Jan 10, 2025

Attempting to address #2571

I've set the pointer mode to device-side during handle creation. Since gemmGroupedBatched doesn't support device-side pointer mode, it won't be usable. One workaround would be to add a new function that creates a handle in host-side mode, or to add the pointer mode as an optional kwarg to handle(). Very open to feedback on this.
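Another workaround could be to flip the mode per call rather than per handle. A sketch (hypothetical helper name; assuming the `cublasSetPointerMode_v2` binding and the pointer-mode enums from CUDA.jl's generated CUBLAS wrappers):

```julia
# Hypothetical sketch: run `f` with the handle temporarily in host-side
# pointer mode, restoring device-side mode afterwards, so functions like
# gemmGroupedBatched that require host-side mode remain usable.
function with_host_pointer_mode(f, h)
    CUBLAS.cublasSetPointerMode_v2(h, CUBLAS.CUBLAS_POINTER_MODE_HOST)
    try
        return f()
    finally
        CUBLAS.cublasSetPointerMode_v2(h, CUBLAS.CUBLAS_POINTER_MODE_DEVICE)
    end
end
```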

I've set this up so that users can supply CuRefs of the appropriate result type to the level 1 functions for their results. If no ref is provided, the functions execute as they do today (synchronously). Similarly, for functions taking alpha or beta scalar arguments: if the user provides a CuRef (actually a CuRefArray), the functions execute asynchronously and return immediately; if the user provides a Number, the behaviour is unchanged from today. I'm not married to this design and it can certainly be changed.
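To illustrate, a sketch of the two call styles based on the `scal!` signatures in this PR's diff (the sync/async behaviour is as described above, not something this snippet enforces):

```julia
using CUDA, CUDA.CUBLAS

x = CUDA.rand(Float32, 1024)

# Host-side Number argument: behaviour unchanged, synchronous.
CUBLAS.scal!(length(x), 2.0f0, x)

# Device-side CuRef argument: the call can return immediately,
# with alpha read on the device when the kernel actually runs.
α = CuRef{Float32}(2.0f0)
CUBLAS.scal!(length(x), α, x)
```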

cc @Jutho

@kshyatt kshyatt requested a review from maleadt January 10, 2025 21:03
@kshyatt kshyatt added the "cuda libraries" label (stuff about CUDA library wrappers) Jan 10, 2025
@kshyatt
Contributor Author

kshyatt commented Jan 10, 2025

I can also add some more @eval blocks to cut down on the repetitive fallback logic.

@kshyatt
Contributor Author

kshyatt commented Jan 10, 2025

Sample speedup:

julia> using CUDA, CUDA.CUBLAS, LinearAlgebra;

julia> n = Int(2^26);

julia> X = CUDA.rand(Float64, n);

julia> res = CuRef{Float64}(0.0);

# do some precompilation runs first

julia> @time CUBLAS.nrm2(n, X, res);
  0.000104 seconds (18 allocations: 288 bytes)

julia> @time CUBLAS.nrm2(n, X);
  0.001564 seconds (73 allocations: 3.094 KiB)

Contributor

@github-actions github-actions bot left a comment


CUDA.jl Benchmarks

| Benchmark suite | Current: ff453f2 | Previous: 24c236a | Ratio |
|---|---|---|---|
| latency/precompile | 46524357433.5 ns | 46329110276 ns | 1.00 |
| latency/ttfp | 7011623206 ns | 6970088608 ns | 1.01 |
| latency/import | 3658724318 ns | 3628117061 ns | 1.01 |
| integration/volumerhs | 9624184 ns | 9622596.5 ns | 1.00 |
| integration/byval/slices=1 | 146839 ns | 147068 ns | 1.00 |
| integration/byval/slices=3 | 425274 ns | 425521 ns | 1.00 |
| integration/byval/reference | 145055 ns | 144904 ns | 1.00 |
| integration/byval/slices=2 | 286073 ns | 286164 ns | 1.00 |
| integration/cudadevrt | 103512 ns | 103422 ns | 1.00 |
| kernel/indexing | 14218 ns | 14042 ns | 1.01 |
| kernel/indexing_checked | 14924 ns | 14520 ns | 1.03 |
| kernel/occupancy | 642.6130952380952 ns | 633.7076023391813 ns | 1.01 |
| kernel/launch | 2012.9 ns | 2101.9 ns | 0.96 |
| kernel/rand | 14695 ns | 14379 ns | 1.02 |
| array/reverse/1d | 19517 ns | 19719 ns | 0.99 |
| array/reverse/2d | 25068 ns | 25121 ns | 1.00 |
| array/reverse/1d_inplace | 10827 ns | 11160 ns | 0.97 |
| array/reverse/2d_inplace | 12470 ns | 13070 ns | 0.95 |
| array/copy | 21068 ns | 21024 ns | 1.00 |
| array/iteration/findall/int | 156220 ns | 155786 ns | 1.00 |
| array/iteration/findall/bool | 135407 ns | 134662.5 ns | 1.01 |
| array/iteration/findfirst/int | 147578.5 ns | 147148 ns | 1.00 |
| array/iteration/findfirst/bool | 153305 ns | 154167.5 ns | 0.99 |
| array/iteration/scalar | 62796 ns | 61499 ns | 1.02 |
| array/iteration/logical | 203505 ns | 203811.5 ns | 1.00 |
| array/iteration/findmin/1d | 39812.5 ns | 39639 ns | 1.00 |
| array/iteration/findmin/2d | 94104 ns | 94387 ns | 1.00 |
| array/reductions/reduce/1d | 30590 ns | 30507 ns | 1.00 |
| array/reductions/reduce/2d | 40842 ns | 51213.5 ns | 0.80 |
| array/reductions/mapreduce/1d | 30022 ns | 30459 ns | 0.99 |
| array/reductions/mapreduce/2d | 40787 ns | 51487 ns | 0.79 |
| array/broadcast | 21196.5 ns | 20729 ns | 1.02 |
| array/copyto!/gpu_to_gpu | 13753 ns | 11560 ns | 1.19 |
| array/copyto!/cpu_to_gpu | 208123 ns | 208930 ns | 1.00 |
| array/copyto!/gpu_to_cpu | 242297 ns | 241934 ns | 1.00 |
| array/accumulate/1d | 108470 ns | 108332 ns | 1.00 |
| array/accumulate/2d | 80095 ns | 80060 ns | 1.00 |
| array/construct | 1263.65 ns | 1258.3 ns | 1.00 |
| array/random/randn/Float32 | 44284 ns | 42993 ns | 1.03 |
| array/random/randn!/Float32 | 26690 ns | 26226 ns | 1.02 |
| array/random/rand!/Int64 | 27179 ns | 26878 ns | 1.01 |
| array/random/rand!/Float32 | 8710.333333333334 ns | 8569 ns | 1.02 |
| array/random/rand/Int64 | 29989 ns | 29957 ns | 1.00 |
| array/random/rand/Float32 | 13191 ns | 12962 ns | 1.02 |
| array/permutedims/4d | 60897 ns | 61080 ns | 1.00 |
| array/permutedims/2d | 54885 ns | 55180 ns | 0.99 |
| array/permutedims/3d | 55969 ns | 55780 ns | 1.00 |
| array/sorting/1d | 2776673.5 ns | 2774685 ns | 1.00 |
| array/sorting/by | 3368203.5 ns | 3367411.5 ns | 1.00 |
| array/sorting/2d | 1085210 ns | 1085055 ns | 1.00 |
| cuda/synchronization/stream/auto | 1084 ns | 1048.8 ns | 1.03 |
| cuda/synchronization/stream/nonblocking | 6476.8 ns | 6455.4 ns | 1.00 |
| cuda/synchronization/stream/blocking | 849.6619718309859 ns | 846.986301369863 ns | 1.00 |
| cuda/synchronization/context/auto | 1153.2 ns | 1230.1 ns | 0.94 |
| cuda/synchronization/context/nonblocking | 6683.5 ns | 6670.8 ns | 1.00 |
| cuda/synchronization/context/blocking | 903.3125 ns | 928 ns | 0.97 |

This comment was automatically generated by workflow using github-action-benchmark.

lib/cublas/wrappers.jl — outdated review thread (resolved)
@kshyatt
Contributor Author

kshyatt commented Jan 11, 2025 via email

@kshyatt
Contributor Author

kshyatt commented Jan 11, 2025

Is the test failure something I've done? It seems GPUArrays-related.

@kshyatt kshyatt force-pushed the ksh/device_side branch 2 times, most recently from a0829fa to 5d52d10 Compare January 16, 2025 16:05
@kshyatt
Contributor Author

kshyatt commented Jan 16, 2025

OK, I think this is ready for review!

@Jutho
Contributor

Jutho commented Jan 16, 2025

I am not qualified to review, but certainly interested in the outcome. Will the non-blocking methods only accept CuRef objects for the scalar input or output quantities, or also zero-dimensional arrays (i.e. CuArray{T,0})?

@kshyatt
Contributor Author

kshyatt commented Jan 16, 2025 via email

@kshyatt
Contributor Author

kshyatt commented Jan 16, 2025

You can create a CuRefArray{T}, where T is some element type, from a single-element CuVector. In fact, CuRef itself does this under the hood.
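For instance (a sketch assuming the `CuRefArray(array, index)` constructor from this PR is callable directly; it is what `CuRef` builds internally):

```julia
using CUDA

# A single-element CuVector can back a device-side reference,
# so zero-dimensional-array-style usage is covered by the same mechanism.
buf = CUDA.fill(0.0, 1)          # single-element CuVector{Float64}
res = CUDA.CuRefArray(buf, 1)    # behaves like a device-side Ref{Float64}
```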

Member

@maleadt maleadt left a comment


I wonder if we should also improve CuRef to initialize its memory by calling fill instead of memcpy: when calling memcpy, the copy likely won't be truly asynchronous (that would require pinned memory), but if we call fill, which should be possible for most scalars, the argument is passed by value and I think the call will complete asynchronously.
Something to investigate!
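A minimal sketch of that idea (hypothetical helper name; assuming `CUDA.fill` passes the value as a kernel argument and that `CuRefArray` wraps a one-element array, as above):

```julia
using CUDA

# Hypothetical: build a device-side scalar ref without a host->device memcpy.
# CUDA.fill launches a kernel with the value passed by value, so no pinned
# staging buffer is needed and the initialization can stay asynchronous.
device_scalar(val::T) where {T} = CUDA.CuRefArray(CUDA.fill(val, 1), 1)
```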

lib/cublas/wrappers.jl — three outdated review threads (resolved)
@maleadt
Member

maleadt commented Jan 17, 2025

Something to investigate!

#2625

github-actions[bot]

This comment was marked as off-topic.

Contributor

github-actions bot commented Jan 20, 2025

Your PR requires formatting changes to meet the project's style guidelines.
Please consider running Runic (git runic master) to apply these changes.

Suggested changes:
diff --git a/lib/cublas/wrappers.jl b/lib/cublas/wrappers.jl
index 43ddfeaea..a5eb81c92 100644
--- a/lib/cublas/wrappers.jl
+++ b/lib/cublas/wrappers.jl
@@ -115,7 +115,7 @@ for (fname, fname_64, elty) in ((:cublasDscal_v2, :cublasDscal_v2_64, :Float64),
                                 (:cublasCscal_v2, :cublasCscal_v2_64, :ComplexF32))
     @eval begin
         function scal!(n::Integer,
-                       alpha::Ref{$elty},
+                alpha::Ref{$elty},
                        x::StridedCuVecOrDenseMat{$elty})
             if CUBLAS.version() >= v"12.0"
                 $fname_64(handle(), n, alpha, x, stride(x, 1))
@@ -147,7 +147,7 @@ for (fname, fname_64, elty, celty) in ((:cublasCsscal_v2, :cublasCsscal_v2_64, :
                                        (:cublasZdscal_v2, :cublasZdscal_v2_64, :Float64, :ComplexF64))
     @eval begin
         function scal!(n::Integer,
-                       alpha::Ref{$elty},
+                alpha::Ref{$elty},
                        x::StridedCuVecOrDenseMat{$celty})
             if CUBLAS.version() >= v"12.0"
                 $fname_64(handle(), n, alpha, x, stride(x, 1))
@@ -190,8 +190,8 @@ for (jname, fname, fname_64, elty) in ((:dot, :cublasDdot_v2, :cublasDdot_v2_64,
     @eval begin
         function $jname(n::Integer,
                         x::StridedCuVecOrDenseMat{$elty},
-                        y::StridedCuVecOrDenseMat{$elty},
-                        result::Ref{$elty},
+                y::StridedCuVecOrDenseMat{$elty},
+                result::Ref{$elty},
             )
             if CUBLAS.version() >= v"12.0"
                 $fname_64(handle(), n, x, stride(x, 1), y, stride(y, 1), result)
@@ -339,7 +339,7 @@ for (fname, fname_64, elty) in ((:cublasDaxpy_v2, :cublasDaxpy_v2_64, :Float64),
                                 (:cublasCaxpy_v2, :cublasCaxpy_v2_64, :ComplexF32))
     @eval begin
         function axpy!(n::Integer,
-                       alpha::Ref{$elty},
+                alpha::Ref{$elty},
                        dx::StridedCuVecOrDenseMat{$elty},
                        dy::StridedCuVecOrDenseMat{$elty})
             if CUBLAS.version() >= v"12.0"
@@ -400,9 +400,9 @@ for (fname, fname_64, elty, cty, sty) in (
         function rot!(n::Integer,
                       x::StridedCuVecOrDenseMat{$elty},
                       y::StridedCuVecOrDenseMat{$elty},
-                      c::Ref{$cty},
-                      s::Ref{$sty},
-                     )
+                c::Ref{$cty},
+                s::Ref{$sty},
+            )
             if CUBLAS.version() >= v"12.0"
                 $fname_64(handle(), n, x, stride(x, 1), y, stride(y, 1), c, s)
             else
@@ -473,9 +473,9 @@ for (fname, fname_64, elty) in ((:cublasIdamax_v2, :cublasIdamax_v2_64, :Float64
                                 (:cublasIcamax_v2, :cublasIcamax_v2_64, :ComplexF32))
     @eval begin
         function iamax(n::Integer,
-                       dx::StridedCuVecOrDenseMat{$elty},
-                       result::Ref{Ti},
-                      ) where {Ti <: Integer}
+                dx::StridedCuVecOrDenseMat{$elty},
+                result::Ref{Ti},
+            ) where {Ti <: Integer}
             if CUBLAS.version() >= v"12.0"
                 $fname_64(handle(), n, dx, stride(dx, 1), result)
             else
@@ -494,9 +494,9 @@ for (fname, fname_64, elty) in ((:cublasIdamin_v2, :cublasIdamin_v2_64, :Float64
                                 (:cublasIcamin_v2, :cublasIcamin_v2_64, :ComplexF32))
     @eval begin
         function iamin(n::Integer,
-                       dx::StridedCuVecOrDenseMat{$elty},
-                       result::Ref{Ti},
-                      ) where {Ti <: Integer}
+                dx::StridedCuVecOrDenseMat{$elty},
+                result::Ref{Ti},
+            ) where {Ti <: Integer}
             if CUBLAS.version() >= v"12.0"
                 $fname_64(handle(), n, dx, stride(dx, 1), result)
             else
@@ -530,10 +530,10 @@ for (fname, fname_64, elty) in ((:cublasDgemv_v2, :cublasDgemv_v2_64, :Float64),
                                 (:cublasCgemv_v2, :cublasCgemv_v2_64, :ComplexF32))
     @eval begin
         function gemv!(trans::Char,
-                       alpha::Ref{$elty},
+                alpha::Ref{$elty},
                        A::StridedCuMatrix{$elty},
                        x::StridedCuVector{$elty},
-                       beta::Ref{$elty},
+                beta::Ref{$elty},
                        y::StridedCuVector{$elty})
             # handle trans
             m,n = size(A)
@@ -562,7 +562,7 @@ end
 function gemv(trans::Char, alpha::Ref{T}, A::StridedCuMatrix{T}, x::StridedCuVector{T}) where {T}
     return gemv!(trans, alpha, A, x, CuRef{T}(zero(T)), similar(x, size(A, (trans == 'N' ? 1 : 2))))
 end
-function gemv(trans::Char, alpha::Number, A::StridedCuMatrix{T}, x::StridedCuVector{T}) where T
+function gemv(trans::Char, alpha::Number, A::StridedCuMatrix{T}, x::StridedCuVector{T}) where {T}
     gemv!(trans, alpha, A, x, zero(T), similar(x, size(A, (trans == 'N' ? 1 : 2))))
 end
 # should this be async?
@@ -580,12 +580,12 @@ for (fname, fname_64, eltyin, eltyout, eltyconst) in (
     )
     @eval begin
         function gemv_batched!(trans::Char,
-                               alpha::Ref{$eltyconst},
-                               A::Vector{<:StridedCuMatrix{$eltyin}},
-                               x::Vector{<:StridedCuVector{$eltyin}},
-                               beta::Ref{$eltyconst},
-                               y::Vector{<:StridedCuVector{$eltyout}}
-                              )
+                alpha::Ref{$eltyconst},
+                A::Vector{<:StridedCuMatrix{$eltyin}},
+                x::Vector{<:StridedCuVector{$eltyin}},
+                beta::Ref{$eltyconst},
+                y::Vector{<:StridedCuVector{$eltyout}}
+            )
             if length(A) != length(x) || length(A) != length(y)
                 throw(DimensionMismatch("Lengths of inputs must be the same"))
             end
@@ -616,13 +616,13 @@ for (fname, fname_64, eltyin, eltyout, eltyconst) in (
             y
         end
         function gemv_batched!(
-                               trans::Char,
-                               alpha::Number,
-                               A::Vector{<:StridedCuMatrix{$eltyin}},
-                               x::Vector{<:StridedCuVector{$eltyin}},
-                               beta::Number,
-                               y::Vector{<:StridedCuVector{$eltyout}}
-                              )
+                trans::Char,
+                alpha::Number,
+                A::Vector{<:StridedCuMatrix{$eltyin}},
+                x::Vector{<:StridedCuVector{$eltyin}},
+                beta::Number,
+                y::Vector{<:StridedCuVector{$eltyout}}
+            )
             gpu_α = CuRef{$eltyconst}(alpha)
             gpu_β = CuRef{$eltyconst}(beta)
             y = gemv_batched!(trans, gpu_α, A, x, gpu_β, y)
@@ -642,12 +642,12 @@ for (fname, fname_64, eltyin, eltyout, eltyconst) in (
     )
     @eval begin
         function gemv_strided_batched!(trans::Char,
-                                       alpha::Ref{$eltyconst},
-                                       A::AbstractArray{$eltyin, 3},
-                                       x::AbstractArray{$eltyin, 2},
-                                       beta::Ref{$eltyconst},
-                                       y::AbstractArray{$eltyout, 2}
-                                      )
+                alpha::Ref{$eltyconst},
+                A::AbstractArray{$eltyin, 3},
+                x::AbstractArray{$eltyin, 2},
+                beta::Ref{$eltyconst},
+                y::AbstractArray{$eltyout, 2}
+            )
             if size(A, 3) != size(x, 2) || size(A, 3) != size(y, 2)
                 throw(DimensionMismatch("Batch sizes must be equal for all inputs"))
             end
@@ -672,13 +672,13 @@ for (fname, fname_64, eltyin, eltyout, eltyconst) in (
             y
         end
         function gemv_strided_batched!(
-                                       trans::Char,
-                                       alpha::Number,
-                                       A::AbstractArray{$eltyin, 3},
-                                       x::AbstractArray{$eltyin, 2},
-                                       beta::Number,
-                                       y::AbstractArray{$eltyout, 2}
-                                      )
+                trans::Char,
+                alpha::Number,
+                A::AbstractArray{$eltyin, 3},
+                x::AbstractArray{$eltyin, 2},
+                beta::Number,
+                y::AbstractArray{$eltyout, 2}
+            )
             gpu_α = CuRef{$eltyconst}(alpha)
             gpu_β = CuRef{$eltyconst}(beta)
             y = gemv_strided_batched!(trans, gpu_α, A, x, gpu_β, y)
@@ -698,10 +698,10 @@ for (fname, fname_64, elty) in ((:cublasDgbmv_v2, :cublasDgbmv_v2_64, :Float64),
                        m::Integer,
                        kl::Integer,
                        ku::Integer,
-                       alpha::Ref{$elty},
+                alpha::Ref{$elty},
                        A::StridedCuMatrix{$elty},
                        x::StridedCuVector{$elty},
-                       beta::Ref{$elty},
+                beta::Ref{$elty},
                        y::StridedCuVector{$elty})
             n = size(A,2)
             # check dimensions
@@ -717,16 +717,17 @@ for (fname, fname_64, elty) in ((:cublasDgbmv_v2, :cublasDgbmv_v2_64, :Float64),
             end
             y
         end
-        function gbmv!(trans::Char,
-                       m::Integer,
-                       kl::Integer,
-                       ku::Integer,
-                       alpha::Number,
-                       A::StridedCuMatrix{$elty},
-                       x::StridedCuVector{$elty},
-                       beta::Number,
-                       y::StridedCuVector{$elty}
-                      )
+        function gbmv!(
+                trans::Char,
+                m::Integer,
+                kl::Integer,
+                ku::Integer,
+                alpha::Number,
+                A::StridedCuMatrix{$elty},
+                x::StridedCuVector{$elty},
+                beta::Number,
+                y::StridedCuVector{$elty}
+            )
 
             gpu_α = CuRef{$elty}(alpha)
             gpu_β = CuRef{$elty}(beta)
@@ -736,8 +737,10 @@ for (fname, fname_64, elty) in ((:cublasDgbmv_v2, :cublasDgbmv_v2_64, :Float64),
         end
     end
 end
-function gbmv(trans::Char, m::Integer, kl::Integer, ku::Integer, alpha::Ref{T},
-              A::StridedCuMatrix{T}, x::StridedCuVector{T}) where {T}
+function gbmv(
+        trans::Char, m::Integer, kl::Integer, ku::Integer, alpha::Ref{T},
+        A::StridedCuMatrix{T}, x::StridedCuVector{T}
+    ) where {T}
     # TODO: fix gbmv bug in julia
     n = size(A, 2)
     leny = trans == 'N' ? m : n
@@ -760,10 +763,10 @@ for (fname, fname_64, elty) in ((:cublasDspmv_v2, :cublasDspmv_v2_64, :Float64),
                                 (:cublasSspmv_v2, :cublasSspmv_v2_64, :Float32))
     @eval begin
         function spmv!(uplo::Char,
-                       alpha::Ref{$elty},
+                alpha::Ref{$elty},
                        AP::StridedCuVector{$elty},
                        x::StridedCuVector{$elty},
-                       beta::Ref{$elty},
+                beta::Ref{$elty},
                        y::StridedCuVector{$elty})
             n = round(Int, (sqrt(8*length(AP))-1)/2)
             if n != length(x) || n != length(y) throw(DimensionMismatch("")) end
@@ -778,21 +781,24 @@ for (fname, fname_64, elty) in ((:cublasDspmv_v2, :cublasDspmv_v2_64, :Float64),
         end
     end
 end
-function spmv!(uplo::Char,
-               alpha::Number,
-               AP::StridedCuVector{T},
-               x::StridedCuVector{T},
-               beta::Number,
-               y::StridedCuVector{T}
-              ) where {T}
+function spmv!(
+        uplo::Char,
+        alpha::Number,
+        AP::StridedCuVector{T},
+        x::StridedCuVector{T},
+        beta::Number,
+        y::StridedCuVector{T}
+    ) where {T}
     gpu_α = CuRef{T}(alpha)
     gpu_β = CuRef{T}(beta)
     y = spmv!(uplo, gpu_α, AP, x, gpu_β, y)
     synchronize()
     return y
 end
-function spmv(uplo::Char, alpha::Ref{T},
-              AP::StridedCuVector{T}, x::StridedCuVector{T}) where {T}
+function spmv(
+        uplo::Char, alpha::Ref{T},
+        AP::StridedCuVector{T}, x::StridedCuVector{T}
+    ) where {T}
     return spmv!(uplo, alpha, AP, x, CuRef{T}(zero(T)), similar(x))
 end
 function spmv(uplo::Char, alpha::Number,
@@ -811,10 +817,10 @@ for (fname, fname_64, elty) in ((:cublasDsymv_v2, :cublasDsymv_v2_64, :Float64),
     # Note that the complex symv are not BLAS but auiliary functions in LAPACK
     @eval begin
         function symv!(uplo::Char,
-                       alpha::Ref{$elty},
+                alpha::Ref{$elty},
                        A::StridedCuMatrix{$elty},
                        x::StridedCuVector{$elty},
-                       beta::Ref{$elty},
+                beta::Ref{$elty},
                        y::StridedCuVector{$elty})
             m, n = size(A)
             if m != n throw(DimensionMismatch("Matrix A is $m by $n but must be square")) end
@@ -865,10 +871,10 @@ for (fname, fname_64, elty) in ((:cublasZhemv_v2, :cublasZhemv_v2_64, :ComplexF6
                                 (:cublasChemv_v2, :cublasChemv_v2_64, :ComplexF32))
     @eval begin
         function hemv!(uplo::Char,
-                       alpha::Ref{$elty},
+                alpha::Ref{$elty},
                        A::StridedCuMatrix{$elty},
                        x::StridedCuVector{$elty},
-                       beta::Ref{$elty},
+                beta::Ref{$elty},
                        y::StridedCuVector{$elty})
             # TODO: fix dimension check bug in julia
             m, n = size(A)
@@ -923,10 +929,10 @@ for (fname, fname_64, elty) in ((:cublasDsbmv_v2, :cublasDsbmv_v2_64, :Float64),
     @eval begin
         function sbmv!(uplo::Char,
                        k::Integer,
-                       alpha::Ref{$elty},
+                alpha::Ref{$elty},
                        A::StridedCuMatrix{$elty},
                        x::StridedCuVector{$elty},
-                       beta::Ref{$elty},
+                beta::Ref{$elty},
                        y::StridedCuVector{$elty})
             m, n = size(A)
             #if m != n throw(DimensionMismatch("Matrix A is $m by $n but must be square")) end
@@ -982,10 +988,10 @@ for (fname, fname_64, elty) in ((:cublasZhbmv_v2, :cublasZhbmv_v2_64, :ComplexF6
     @eval begin
         function hbmv!(uplo::Char,
                        k::Integer,
-                       alpha::Ref{$elty},
+                alpha::Ref{$elty},
                        A::StridedCuMatrix{$elty},
                        x::StridedCuVector{$elty},
-                       beta::Ref{$elty},
+                beta::Ref{$elty},
                        y::StridedCuVector{$elty})
             m, n = size(A)
             if !(1<=(1+k)<=n) throw(DimensionMismatch("Incorrect number of bands")) end
@@ -1169,7 +1175,7 @@ for (fname, fname_64, elty) in ((:cublasDger_v2, :cublasDger_v2_64, :Float64),
                                 (:cublasCgerc_v2, :cublasCgerc_v2_64, :ComplexF32))
     @eval begin
         function ger!(
-                      alpha::Ref{$elty},
+                alpha::Ref{$elty},
                       x::StridedCuVector{$elty},
                       y::StridedCuVector{$elty},
                       A::StridedCuMatrix{$elty})
@@ -1205,7 +1211,7 @@ for (fname, fname_64, elty) in ((:cublasDspr_v2, :cublasDspr_v2_64, :Float64),
                                 (:cublasSspr_v2, :cublasSspr_v2_64, :Float32))
     @eval begin
         function spr!(uplo::Char,
-                      alpha::Ref{$elty},
+                alpha::Ref{$elty},
                       x::StridedCuVector{$elty},
                       AP::StridedCuVector{$elty})
             n = round(Int, (sqrt(8*length(AP))-1)/2)
@@ -1239,7 +1245,7 @@ for (fname, fname_64, elty) in ((:cublasDsyr_v2, :cublasDsyr_v2_64, :Float64),
                                 (:cublasCsyr_v2, :cublasCsyr_v2_64, :ComplexF32))
     @eval begin
         function syr!(uplo::Char,
-                      alpha::Ref{$elty},
+                alpha::Ref{$elty},
                       x::StridedCuVector{$elty},
                       A::StridedCuMatrix{$elty})
             m, n = size(A)
@@ -1275,7 +1281,7 @@ for (fname, fname_64, elty, relty) in (
     )
     @eval begin
         function her!(uplo::Char,
-                      alpha::Ref{$relty},
+                alpha::Ref{$relty},
                       x::StridedCuVector{$elty},
                       A::StridedCuMatrix{$elty})
             m, n = size(A)
@@ -1309,11 +1315,11 @@ for (fname, fname_64, elty) in ((:cublasZher2_v2, :cublasZher2_v2_64, :ComplexF6
                                 (:cublasCher2_v2, :cublasCher2_v2_64, :ComplexF32))
     @eval begin
         function her2!(uplo::Char,
-                       alpha::Ref{$elty},
-                       x::StridedCuVector{$elty},
-                       y::StridedCuVector{$elty},
-                       A::StridedCuMatrix{$elty}
-                      )
+                alpha::Ref{$elty},
+                x::StridedCuVector{$elty},
+                y::StridedCuVector{$elty},
+                A::StridedCuMatrix{$elty}
+            )
             m, n = size(A)
             m == n || throw(DimensionMismatch("Matrix A is $m by $n but must be square"))
             length(x) == n || throw(DimensionMismatch("Length of vector must be the same as the matrix dimensions"))
@@ -1353,10 +1359,10 @@ for (fname, fname_64, elty) in ((:cublasDgemm_v2, :cublasDgemm_v2_64, :Float64),
     @eval begin
         function gemm!(transA::Char,
                        transB::Char,
-                       alpha::Ref{$elty},
+                alpha::Ref{$elty},
                        A::StridedCuVecOrMat{$elty},
                        B::StridedCuVecOrMat{$elty},
-                       beta::Ref{$elty},
+                beta::Ref{$elty},
                        C::StridedCuVecOrMat{$elty})
             m = size(A, transA == 'N' ? 1 : 2)
             k = size(A, transA == 'N' ? 2 : 1)
@@ -1494,10 +1500,10 @@ function gemmExComputeType(TA, TB, TC, m, k, n)
 end
 
 function gemmEx!(transA::Char, transB::Char,
-                 @nospecialize(alpha::Ref),
+        @nospecialize(alpha::Ref),
                  @nospecialize(A::StridedCuVecOrMat),
                  @nospecialize(B::StridedCuVecOrMat),
-                 @nospecialize(beta::Ref),
+        @nospecialize(beta::Ref),
                  @nospecialize(C::StridedCuVecOrMat);
                  algo::cublasGemmAlgo_t=CUBLAS_GEMM_DEFAULT)
     m = size(A, transA == 'N' ? 1 : 2)
@@ -1552,10 +1558,10 @@ end
 
 # TODO for device mode pointers
 function gemmBatchedEx!(transA::Char, transB::Char,
-                 @nospecialize(alpha::Ref),
+        @nospecialize(alpha::Ref),
                  @nospecialize(A::Vector{<:StridedCuVecOrMat}),
                  @nospecialize(B::Vector{<:StridedCuVecOrMat}),
-                 @nospecialize(beta::Ref),
+        @nospecialize(beta::Ref),
                  @nospecialize(C::Vector{<:StridedCuVecOrMat});
                  algo::cublasGemmAlgo_t=CUBLAS_GEMM_DEFAULT)
     if length(A) != length(B) || length(A) != length(C)
@@ -1623,11 +1629,11 @@ function gemmBatchedEx!(
 end
 
 function gemmStridedBatchedEx!(
-                 transA::Char, transB::Char,
-                 @nospecialize(alpha::Ref),
+        transA::Char, transB::Char,
+        @nospecialize(alpha::Ref),
                  @nospecialize(A::AbstractArray{Ta, 3}),
                  @nospecialize(B::AbstractArray{Tb, 3}),
-                 @nospecialize(beta::Ref),
+        @nospecialize(beta::Ref),
                  @nospecialize(C::AbstractArray{Tc, 3});
                  algo::cublasGemmAlgo_t=CUBLAS_GEMM_DEFAULT) where {Ta, Tb, Tc}
     if size(A, 3) != size(B, 3) || size(A, 3) != size(C, 3)
@@ -1866,10 +1872,10 @@ for (fname, fname_64, elty) in ((:cublasDgemmBatched, :cublasDgemmBatched_64, :F
     @eval begin
         function gemm_batched!(transA::Char,
                                transB::Char,
-                               alpha::Ref{$elty},
+                alpha::Ref{$elty},
                                A::Vector{<:StridedCuMatrix{$elty}},
                                B::Vector{<:StridedCuMatrix{$elty}},
-                               beta::Ref{$elty},
+                beta::Ref{$elty},
                                C::Vector{<:StridedCuMatrix{$elty}})
             if length(A) != length(B) || length(A) != length(C)
                 throw(DimensionMismatch(""))
@@ -1949,10 +1955,10 @@ for (fname, fname_64, elty) in ((:cublasDgemmStridedBatched, :cublasDgemmStrided
     @eval begin
         function gemm_strided_batched!(transA::Char,
                                transB::Char,
-                               alpha::Ref{$elty},
+                alpha::Ref{$elty},
                                A::AbstractArray{$elty, 3}, # allow PermutedDimsArray
                                B::AbstractArray{$elty, 3},
-                               beta::Ref{$elty},
+                beta::Ref{$elty},
                                C::AbstractArray{$elty, 3})
            m = size(A, transA == 'N' ? 1 : 2)
            k = size(A, transA == 'N' ? 2 : 1)
@@ -2032,10 +2038,10 @@ for (fname, fname_64, elty) in ((:cublasDsymm_v2, :cublasDsymm_v2_64, :Float64),
     @eval begin
         function symm!(side::Char,
                        uplo::Char,
-                       alpha::Ref{$elty},
+                alpha::Ref{$elty},
                        A::StridedCuMatrix{$elty},
                        B::StridedCuMatrix{$elty},
-                       beta::Ref{$elty},
+                beta::Ref{$elty},
                        C::StridedCuMatrix{$elty})
             k, nA = size(A)
             if k != nA throw(DimensionMismatch("Matrix A must be square")) end
@@ -2094,9 +2100,9 @@ for (fname, fname_64, elty) in ((:cublasDsyrk_v2, :cublasDsyrk_v2_64, :Float64),
     @eval begin
         function syrk!(uplo::Char,
                        trans::Char,
-                       alpha::Ref{$elty},
+                alpha::Ref{$elty},
                        A::StridedCuVecOrMat{$elty},
-                       beta::Ref{$elty},
+                beta::Ref{$elty},
                        C::StridedCuMatrix{$elty})
             mC, n = size(C)
             if mC != n throw(DimensionMismatch("C must be square")) end
@@ -2147,10 +2153,10 @@ for (fname, fname_64, elty) in ((:cublasDsyrkx, :cublasDsyrkx_64, :Float64),
     @eval begin
         function syrkx!(uplo::Char,
                        trans::Char,
-                       alpha::Ref{$elty},
+                alpha::Ref{$elty},
                        A::StridedCuVecOrMat{$elty},
                        B::StridedCuVecOrMat{$elty},
-                       beta::Ref{$elty},
+                beta::Ref{$elty},
                        C::StridedCuMatrix{$elty})
             mC, n = size(C)
             if mC != n throw(DimensionMismatch("C must be square")) end
@@ -2206,10 +2212,10 @@ for (fname, fname_64, elty) in ((:cublasZhemm_v2, :cublasZhemm_v2_64, :ComplexF6
     @eval begin
         function hemm!(side::Char,
                        uplo::Char,
-                       alpha::Ref{$elty},
+                alpha::Ref{$elty},
                        A::StridedCuMatrix{$elty},
                        B::StridedCuMatrix{$elty},
-                       beta::Ref{$elty},
+                beta::Ref{$elty},
                        C::StridedCuMatrix{$elty})
             mA, nA = size(A)
             m, n = size(B)
@@ -2269,9 +2275,9 @@ for (fname, fname_64, elty, relty) in (
     @eval begin
         function herk!(uplo::Char,
                        trans::Char,
-                       alpha::Ref{$relty},
+                alpha::Ref{$relty},
                        A::StridedCuVecOrMat{$elty},
-                       beta::Ref{$relty},
+                beta::Ref{$relty},
                        C::StridedCuMatrix{$elty})
             mC, n = size(C)
             if mC != n throw(DimensionMismatch("C must be square")) end
@@ -2328,10 +2334,10 @@ for (fname, fname_64, elty) in ((:cublasDsyr2k_v2, :cublasDsyr2k_v2_64, :Float64
     @eval begin
         function syr2k!(uplo::Char,
                         trans::Char,
-                        alpha::Ref{$elty},
+                alpha::Ref{$elty},
                         A::StridedCuVecOrMat{$elty},
                         B::StridedCuVecOrMat{$elty},
-                        beta::Ref{$elty},
+                beta::Ref{$elty},
                         C::StridedCuMatrix{$elty})
             # TODO: check size of B in julia (syr2k!)
             m, n = size(C)
@@ -2387,7 +2393,7 @@ function syr2k(uplo::Char,
                B::StridedCuVecOrMat)
     T = eltype(A)
     n = size(A, trans == 'N' ? 1 : 2)
-    syr2k!(uplo, trans, convert(T, alpha), A, B, zero(T), similar(A, T, (n, n)))
+    return syr2k!(uplo, trans, convert(T, alpha), A, B, zero(T), similar(A, T, (n, n)))
 end
 function syr2k(uplo::Char, trans::Char, A::StridedCuVecOrMat, B::StridedCuVecOrMat)
     syr2k(uplo, trans, one(eltype(A)), A, B)
@@ -2401,10 +2407,10 @@ for (fname, fname_64, elty, relty) in (
     @eval begin
         function her2k!(uplo::Char,
                         trans::Char,
-                        alpha::Ref{$elty},
+                alpha::Ref{$elty},
                         A::StridedCuVecOrMat{$elty},
                         B::StridedCuVecOrMat{$elty},
-                        beta::Ref{$relty},
+                beta::Ref{$relty},
                         C::StridedCuMatrix{$elty})
             # TODO: check size of B in julia (her2k!)
             m, n = size(C)
@@ -2478,7 +2484,7 @@ for (mmname, smname, elty) in
                        uplo::Char,
                        transa::Char,
                        diag::Char,
-                       alpha::Ref{$elty},
+                alpha::Ref{$elty},
                        A::StridedCuMatrix{$elty},
                        B::StridedCuMatrix{$elty},
                        C::StridedCuMatrix{$elty})
@@ -2500,7 +2506,7 @@ for (mmname, smname, elty) in
                        uplo::Char,
                        transa::Char,
                        diag::Char,
-                       alpha::Ref{$elty},
+                alpha::Ref{$elty},
                        A::StridedCuMatrix{$elty},
                        B::StridedCuMatrix{$elty})
             m, n = size(B)
@@ -2565,7 +2571,7 @@ for (fname, fname_64, elty) in ((:cublasDtrsmBatched, :cublasDtrsmBatched_64, :F
                                uplo::Char,
                                transa::Char,
                                diag::Char,
-                               alpha::Ref{$elty},
+                alpha::Ref{$elty},
                                A::Vector{<:StridedCuMatrix{$elty}},
                                B::Vector{<:StridedCuMatrix{$elty}})
             if length(A) != length(B)
@@ -2621,9 +2627,9 @@ for (fname, fname_64, elty) in ((:cublasDgeam, :cublasDgeam_64, :Float64),
     @eval begin
         function geam!(transa::Char,
                        transb::Char,
-                       alpha::Ref{$elty},
+                alpha::Ref{$elty},
                        A::StridedCuMatrix{$elty},
-                       beta::Ref{$elty},
+                beta::Ref{$elty},
                        B::StridedCuMatrix{$elty},
                        C::StridedCuMatrix{$elty})
             mA, nA = size(A)
@@ -2861,8 +2867,9 @@ for (fname, elty) in ((:cublasDgetriBatched, :Float64),
         end
 
         function getri_batched!(n, Aptrs::CuVector{CuPtr{$elty}},
-                                lda, Cptrs::CuVector{CuPtr{$elty}},ldc,
-                                pivotArray::CuArray{Cint})
+                lda, Cptrs::CuVector{CuPtr{$elty}}, ldc,
+                pivotArray::CuArray{Cint}
+            )
             batchSize = length(Aptrs)
             info = CuArray{Cint}(undef, batchSize)
             $fname(handle(), n, Aptrs, lda, pivotArray, Cptrs, ldc, info, batchSize)
diff --git a/test/libraries/cublas/level3.jl b/test/libraries/cublas/level3.jl
index 65b71c6a4..caf245a5b 100644
--- a/test/libraries/cublas/level3.jl
+++ b/test/libraries/cublas/level3.jl
@@ -352,12 +352,12 @@ k = 13
             @testset "herk!" begin
                 alpha = rand(elty)
                 beta = rand(elty)
-                A = rand(elty,m,m)
+                A = rand(elty, m, m)
                 hA = A + A'
                 d_A = CuArray(A)
                 d_C = CuArray(hA)
-                CUBLAS.herk!('U','N',real(alpha),d_A,real(beta),d_C)
-                C = real(alpha)*(A*A') + real(beta)*hA
+                CUBLAS.herk!('U', 'N', real(alpha), d_A, real(beta), d_C)
+                C = real(alpha) * (A * A') + real(beta) * hA
                 C = triu(C)
                 # move to host and compare
                 h_C = Array(d_C)
@@ -365,10 +365,10 @@ k = 13
                 @test C ≈ h_C
             end
             @testset "herk" begin
-                A = rand(elty,m,m)
+                A = rand(elty, m, m)
                 d_A = CuArray(A)
-                d_C = CUBLAS.herk('U','N',d_A)
-                C = A*A'
+                d_C = CUBLAS.herk('U', 'N', d_A)
+                C = A * A'
                 C = triu(C)
                 # move to host and compare
                 h_C = Array(d_C)
diff --git a/test/libraries/cublas/level3_gemm.jl b/test/libraries/cublas/level3_gemm.jl
index 6e04e8c42..6cccc2967 100644
--- a/test/libraries/cublas/level3_gemm.jl
+++ b/test/libraries/cublas/level3_gemm.jl
@@ -220,12 +220,12 @@ k = 13
             sB = sB + transpose(sB)
 
             for (TRa, ta, TRb, tb, TRc, a_func, b_func) in (
-                (UpperTriangular, identity,  LowerTriangular, identity,  Matrix, triu, tril),
-                (LowerTriangular, identity,  UpperTriangular, identity,  Matrix, tril, triu),
-                (UpperTriangular, identity,  UpperTriangular, transpose, Matrix, triu, triu),
-                (UpperTriangular, transpose, UpperTriangular, identity,  Matrix, triu, triu),
-                (LowerTriangular, identity,  LowerTriangular, transpose, Matrix, tril, tril),
-                (LowerTriangular, transpose, LowerTriangular, identity,  Matrix, tril, tril),
+                    (UpperTriangular, identity, LowerTriangular, identity, Matrix, triu, tril),
+                    (LowerTriangular, identity, UpperTriangular, identity, Matrix, tril, triu),
+                    (UpperTriangular, identity, UpperTriangular, transpose, Matrix, triu, triu),
+                    (UpperTriangular, transpose, UpperTriangular, identity, Matrix, triu, triu),
+                    (LowerTriangular, identity, LowerTriangular, transpose, Matrix, tril, tril),
+                    (LowerTriangular, transpose, LowerTriangular, identity, Matrix, tril, tril),
                 )
 
                 A = copy(sA) |> TRa

@maleadt
Member

maleadt commented Jan 20, 2025

CI failures seem relevant.

Feel free to ignore the formatter; I made it less spammy 😉

@kshyatt
Contributor Author

kshyatt commented Jan 23, 2025

I really do not know what is up with the 1.11 failure, it looks alloc_cache related?

@maleadt
Member

maleadt commented Jan 25, 2025

Rebase to get rid of CI failures?

@kshyatt
Contributor Author

kshyatt commented Jan 25, 2025 via email

@kshyatt kshyatt force-pushed the ksh/device_side branch 2 times, most recently from 804a967 to bcd41c1 on January 25, 2025 21:47
@kshyatt
Contributor Author

kshyatt commented Jan 25, 2025

Gotta admit I'm a bit mystified here as I cannot reproduce these trmm failures locally.

If I run only the libraries/cublas tests or even just libraries using the runtests.jl argument support, everything succeeds locally. If I run the full test suite, I start seeing intermittent illegal access errors/incorrect results in syr2k!. Weird!
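
For anyone trying to reproduce, the subset runs described above would be along these lines (a sketch; the exact suite names passed to runtests.jl may differ):

```shell
# Run only the CUBLAS suites via runtests.jl's argument support:
julia --project=test test/runtests.jl libraries/cublas
# or through Pkg, forwarding the same arguments:
julia -e 'using Pkg; Pkg.test("CUDA"; test_args=["libraries/cublas"])'
```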

Review comment on lib/cublas/wrappers.jl (outdated, resolved)
@maleadt
Member

maleadt commented Jan 28, 2025

We could simplify the implementation here by switching the CUBLAS signatures from RefOrCuRef to simply CuRef (which IIUC is the only legal option now that we switched to DEVICE_POINTER mode), so that passing a scalar is automatically converted to a CuRef:

--- a/lib/cublas/wrappers.jl
+++ b/lib/cublas/wrappers.jl
@@ -530,10 +530,10 @@ for (fname, fname_64, elty) in ((:cublasDgemv_v2, :cublasDgemv_v2_64, :Float64),
                                 (:cublasCgemv_v2, :cublasCgemv_v2_64, :ComplexF32))
     @eval begin
         function gemv!(trans::Char,
-                       alpha::Ref{$elty},
+                       alpha,
                        A::StridedCuMatrix{$elty},
                        x::StridedCuVector{$elty},
-                       beta::Ref{$elty},
+                       beta,
                        y::StridedCuVector{$elty})
             # handle trans
             m,n = size(A)
@@ -552,23 +552,10 @@ for (fname, fname_64, elty) in ((:cublasDgemv_v2, :cublasDgemv_v2_64, :Float64),
         end
     end
 end
-function gemv!(trans::Char, alpha::Number, A::StridedCuMatrix{T}, x::StridedCuVector{T}, beta::Number, y::StridedCuVector{T}) where {T}
-    gpu_α = CuRef{T}(alpha)
-    gpu_β = CuRef{T}(beta)
-    y = gemv!(trans, gpu_α, A, x, gpu_β, y)
-    synchronize()
-    return y
-end
-function gemv(trans::Char, alpha::Ref{T}, A::StridedCuMatrix{T}, x::StridedCuVector{T}) where {T}
-    return gemv!(trans, alpha, A, x, CuRef{T}(zero(T)), similar(x, size(A, (trans == 'N' ? 1 : 2))))
-end
-function gemv(trans::Char, alpha::Number, A::StridedCuMatrix{T}, x::StridedCuVector{T}) where T
+gemv(trans::Char, alpha, A::StridedCuMatrix{T}, x::StridedCuVector{T}) where {T} =
     gemv!(trans, alpha, A, x, zero(T), similar(x, size(A, (trans == 'N' ? 1 : 2))))
-end
-# should this be async?
-function gemv(trans::Char, A::StridedCuMatrix{T}, x::StridedCuVector{T}) where T
+gemv(trans::Char, A::StridedCuMatrix{T}, x::StridedCuVector{T}) where {T} =
     gemv!(trans, one(T), A, x, zero(T), similar(x, T, size(A, (trans == 'N' ? 1 : 2))))
-end

 for (fname, fname_64, eltyin, eltyout, eltyconst) in (
         (:cublasDgemvBatched, :cublasDgemvBatched_64, :Float64, :Float64, :Float64),

etc for the others. Thoughts?
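
For what it's worth, a call site under the simplified signatures might then look like this (a sketch assuming the automatic scalar-to-`CuRef` conversion lands as proposed; not tested):

```julia
using CUDA, CUDA.CUBLAS

A = CUDA.rand(Float64, 4, 4)
x = CUDA.rand(Float64, 4)
y = CUDA.zeros(Float64, 4)

# Plain scalars: converted to device-side refs at the ccall boundary,
# so existing call sites keep working unchanged (with a synchronizing upload).
CUBLAS.gemv!('N', 2.0, A, x, 0.0, y)

# Pre-built device refs: the call only enqueues work and returns,
# avoiding a host round-trip for alpha/beta.
α = CuRef{Float64}(2.0)
β = CuRef{Float64}(0.0)
CUBLAS.gemv!('N', α, A, x, β, y)
```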

@kshyatt
Contributor Author

kshyatt commented Jan 28, 2025

We could simplify the implementation here by switching the CUBLAS signatures from RefOrCuRef to simply CuRef (which IIUC is the only legal option now that we switched to DEVICE_POINTER mode), so that passing a scalar is automatically converted to a CuRef

Yeah, that makes sense. I was biased towards keeping the changes more "conservative" and perhaps changing handle() to allow users to set the pointer mode, but if we want to go "device side by default" then this would make life easier.

@kshyatt
Contributor Author

kshyatt commented Jan 29, 2025

In the latest commit, I applied @maleadt's suggestions widely. This really shortens the diff, and tests passed locally. @Jutho, this should also allow your use-case of passing CuArray{T, 0} for the scalars, as we have an auto-conversion of those to CuRef. If this run passes, I'll update the wrapper-generator signatures too.
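
For the zero-dimensional-array use case, that would look roughly like this (a sketch assuming the level 1 wrappers accept the converted ref as discussed; names here are illustrative):

```julia
using CUDA, CUDA.CUBLAS

x = CUDA.rand(Float64, 2^20)
y = CUDA.rand(Float64, 2^20)

# A scalar living on the device, e.g. produced by an earlier kernel.
# With the auto-conversion to CuRef it can feed CUBLAS directly,
# with no device-to-host copy or synchronization in between.
α = CUDA.fill(2.0)               # CuArray{Float64, 0}
CUBLAS.axpy!(length(x), α, x, y) # enqueues asynchronously and returns
```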
