Skip to content

Commit

Permalink
implement syncscopes everywhere (#20)
Browse files Browse the repository at this point in the history
  • Loading branch information
vchuravy authored Dec 10, 2024
1 parent b099010 commit 98bb4ab
Show file tree
Hide file tree
Showing 10 changed files with 297 additions and 352 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ LLVM = "929cbde3-209d-540e-8aea-75f648917ca0"
UnsafeAtomicsLLVM = ["LLVM"]

[compat]
LLVM = "8.1, 9"
LLVM = "9"
julia = "1.10"

[extras]
Expand Down
233 changes: 29 additions & 204 deletions ext/UnsafeAtomicsLLVM/atomics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,68 +89,12 @@ for (julia, llvm) in pairs(_llvm_from_julia_ordering)
@eval llvm_from_julia_ordering(::Val{$(QuoteNode(julia))}) = Val{$llvm}()
end

"""
@dynamic_order(order) do order
... use order ...
end
It is expanded to an expression similar to:
if order === :not_atomic
let order = Val(:not_atomic)
... use order ...
end
elseif order === :unordered
let order = Val(:unordered)
... use order ...
end
elseif ...
...
else
throw(ConcurrencyViolationError(...))
end
This is used for helping the compiler to optimize expressions such as
`atomic_pointerref(ptr, :monotonic)` and also to avoid abstract run-time dispatch.
"""
macro dynamic_order(thunk, order)
@assert Meta.isexpr(thunk, :->, 2) && Meta.isexpr(thunk.args[1], :tuple, 1)
ordervar = esc(thunk.args[1].args[1])
body = esc(thunk.args[2])
expr = foldr(
keys(_llvm_from_julia_ordering),
init = :(throw(ConcurrencyViolationError("invalid atomic ordering: ", order))),
) do key, r
quote
if order === $(QuoteNode(key))
let $ordervar = Val{$(QuoteNode(key))}()
$body
end
else
$r
end
end
end
quote
order = $(esc(order))
$expr
end
end

_valueof(::Val{x}) where {x} = x

@inline function atomic_pointerref(pointer, order::Symbol)
@dynamic_order(order) do order
atomic_pointerref(pointer, order)
end
end

@inline function atomic_pointerset(pointer, x, order::Symbol)
@dynamic_order(order) do order
atomic_pointerset(pointer, x, order)
end
end

@generated function atomic_pointerref(ptr::LLVMPtr{T,A}, order::AllOrdering) where {T,A}
@generated function atomic_pointerref(ptr::LLVMPtr{T,A}, order::AllOrdering, sync) where {T,A}
sizeof(T) == 0 && return T.instance
llvm_order = _valueof(llvm_from_julia_ordering(order()))
llvm_order = _valueof(llvm_from_julia_ordering(order()))
llvm_syncscope = _valueof(sync())
@dispose ctx = Context() begin
eltyp = convert(LLVMType, T)

Expand All @@ -170,6 +114,7 @@ end
typed_ptr = bitcast!(builder, parameters(llvm_f)[1], T_typed_ptr)
ld = load!(builder, eltyp, typed_ptr)
ordering!(ld, llvm_order)
syncscope!(ld, SyncScope(string(llvm_syncscope)))

if A != 0
metadata(ld)[LLVM.MD_tbaa] = tbaa_addrspace(A)
Expand All @@ -187,6 +132,7 @@ end
ptr::LLVMPtr{T,A},
x::T,
order::AllOrdering,
sync,
) where {T,A}
if sizeof(T) == 0
# Mimicking what `Core.Intrinsics.atomic_pointerset` generates.
Expand All @@ -197,7 +143,8 @@ end
return ptr
end
end
llvm_order = _valueof(llvm_from_julia_ordering(order()))
llvm_order = _valueof(llvm_from_julia_ordering(order()))
llvm_syncscope = _valueof(sync())
@dispose ctx = Context() begin
eltyp = convert(LLVMType, T)
T_ptr = convert(LLVMType, ptr)
Expand All @@ -216,6 +163,7 @@ end
val = parameters(llvm_f)[2]
st = store!(builder, val, typed_ptr)
ordering!(st, llvm_order)
syncscope!(st, SyncScope(string(llvm_syncscope)))

if A != 0
metadata(st)[LLVM.MD_tbaa] = tbaa_addrspace(A)
Expand Down Expand Up @@ -254,14 +202,12 @@ const binoptable = [

const AtomicRMWBinOpVal = Union{(Val{binop} for (_, _, binop) in binoptable)...}

# LLVM API accepts string literal as a syncscope argument.
@inline syncscope_to_string(::Type{Val{S}}) where {S} = string(S)

@generated function llvm_atomic_op(
binop::AtomicRMWBinOpVal,
ptr::LLVMPtr{T,A},
val::T,
order::LLVMOrderingVal,
sync,
) where {T,A}
@dispose ctx = Context() begin
T_val = convert(LLVMType, T)
Expand All @@ -270,21 +216,21 @@ const AtomicRMWBinOpVal = Union{(Val{binop} for (_, _, binop) in binoptable)...}
T_typed_ptr = LLVM.PointerType(T_val, A)

llvm_f, _ = create_function(T_val, [T_ptr, T_val])
llvm_syncscope = _valueof(sync())

@dispose builder = IRBuilder() begin
entry = BasicBlock(llvm_f, "entry")
position!(builder, entry)

typed_ptr = bitcast!(builder, parameters(llvm_f)[1], T_typed_ptr)

single_threaded = false
rv = atomic_rmw!(
builder,
_valueof(binop()),
typed_ptr,
parameters(llvm_f)[2],
_valueof(order()),
single_threaded,
SyncScope(string(llvm_syncscope))
)

ret!(builder, rv)
Expand All @@ -294,139 +240,40 @@ const AtomicRMWBinOpVal = Union{(Val{binop} for (_, _, binop) in binoptable)...}
end
end

@generated function llvm_atomic_op(
binop::AtomicRMWBinOpVal,
ptr::LLVMPtr{T,A},
val::T,
order::LLVMOrderingVal,
syncscope::Val{S},
) where {T,A,S}
@dispose ctx = Context() begin
T_val = convert(LLVMType, T)
T_ptr = convert(LLVMType, ptr)

T_typed_ptr = LLVM.PointerType(T_val, A)
llvm_f, _ = create_function(T_val, [T_ptr, T_val])

@dispose builder = IRBuilder() begin
entry = BasicBlock(llvm_f, "entry")
position!(builder, entry)

typed_ptr = bitcast!(builder, parameters(llvm_f)[1], T_typed_ptr)
rv = atomic_rmw!(
builder,
_valueof(binop()),
typed_ptr,
parameters(llvm_f)[2],
_valueof(order()),
syncscope_to_string(syncscope),
)

ret!(builder, rv)
end
call_function(llvm_f, T, Tuple{LLVMPtr{T,A},T}, :ptr, :val)
end
end

@inline function atomic_pointermodify(pointer, op::OP, x, order::Symbol) where {OP}
@dynamic_order(order) do order
atomic_pointermodify(pointer, op, x, order)
end
end

@inline function atomic_pointermodify(
ptr::LLVMPtr{T},
op,
x::T,
::Val{:not_atomic},
) where {T}
old = atomic_pointerref(ptr, Val(:not_atomic))
new = op(old, x)
atomic_pointerset(ptr, new, Val(:not_atomic))
return old => new
end

@inline function atomic_pointermodify(
ptr::LLVMPtr{T},
::typeof(right),
x::T,
order::AtomicOrdering,
) where {T}
sync::Val{S}
) where {T, S}
old = llvm_atomic_op(
Val(LLVM.API.LLVMAtomicRMWBinOpXchg),
ptr,
x,
llvm_from_julia_ordering(order),
sync
)
return old => x
end

const atomictypes = Any[
Int8,
Int16,
Int32,
Int64,
Int128,
UInt8,
UInt16,
UInt32,
UInt64,
UInt128,
Float16,
Float32,
Float64,
]

for (opname, op, llvmop) in binoptable
opname === :xchg && continue
types = if opname in (:min, :max)
filter(t -> t <: Signed, atomictypes)
elseif opname in (:umin, :umax)
filter(t -> t <: Unsigned, atomictypes)
elseif opname in (:fadd, :fsub, :fmin, :fmax)
filter(t -> t <: AbstractFloat, atomictypes)
else
filter(t -> t <: Integer, atomictypes)
end
for T in types
@eval @inline function atomic_pointermodify(
ptr::LLVMPtr{$T},
::$(typeof(op)),
x::$T,
order::AtomicOrdering,
syncscope::Val{S} = Val{:system}(),
) where {S}
old =
syncscope isa Val{:system} ?
llvm_atomic_op($(Val(llvmop)), ptr, x, llvm_from_julia_ordering(order)) :
llvm_atomic_op(
$(Val(llvmop)),
ptr,
x,
llvm_from_julia_ordering(order),
syncscope,
)
return old => $op(old, x)
end
end
end

@inline atomic_pointerswap(pointer, new) = first(atomic_pointermodify(pointer, right, new))
@inline atomic_pointerswap(pointer, new, order) =
first(atomic_pointermodify(pointer, right, new, order))
# @inline atomic_pointerswap(pointer, new) = first(atomic_pointermodify(pointer, right, new))
@inline atomic_pointerswap(pointer, new, order, sync) =
first(atomic_pointermodify(pointer, right, new, order, sync))

@inline function atomic_pointermodify(
ptr::LLVMPtr{T},
op,
x::T,
order::AllOrdering,
) where {T}
sync::S,
) where {T, S}
# Should `fail_order` be stronger? Ref: https://github.com/JuliaLang/julia/issues/45256
fail_order = Val(:monotonic)
old = atomic_pointerref(ptr, fail_order)
old = atomic_pointerref(ptr, fail_order, sync)
while true
new = op(old, x)
(old, success) = atomic_pointerreplace(ptr, old, new, order, fail_order)
(old, success) = atomic_pointerreplace(ptr, old, new, order, fail_order, sync)
success && return old => new
end
end
Expand All @@ -437,9 +284,11 @@ end
val::T,
success_order::LLVMOrderingVal,
fail_order::LLVMOrderingVal,
sync,
) where {T,A}
llvm_success = _valueof(success_order())
llvm_fail = _valueof(fail_order())
llvm_syncscope = _valueof(sync())
@dispose ctx = Context() begin
T_val = convert(LLVMType, T)
T_pointee = T_val
Expand Down Expand Up @@ -471,15 +320,14 @@ end
val_int = bitcast!(builder, val_int, T_pointee)
end

single_threaded = false
res = atomic_cmpxchg!(
builder,
typed_ptr,
cmp_int,
val_int,
llvm_success,
llvm_fail,
single_threaded,
SyncScope(string(llvm_syncscope)),
)

rv = extract_value!(builder, res, 0)
Expand Down Expand Up @@ -514,42 +362,17 @@ end
end
end

@inline function atomic_pointerreplace(
pointer,
expected,
desired,
success_order::Symbol,
fail_order::Symbol,
)
# This avoids abstract dispatch at run-time but probably too much codegen?
#=
@dynamic_order(success_order) do success_order
@dynamic_order(fail_order) do fail_order
atomic_pointerreplace(pointer, expected, desired, success_order, fail_order)
end
end
=#

# This avoids excessive codegen while hopefully imposes no cost when const-prop works:
so = @dynamic_order(success_order) do success_order
success_order
end
fo = @dynamic_order(fail_order) do fail_order
fail_order
end
return atomic_pointerreplace(pointer, expected, desired, so, fo)
end

@inline function atomic_pointerreplace(
ptr::LLVMPtr{T},
expected::T,
desired::T,
::Val{:not_atomic},
::Val{:not_atomic},
sync,
) where {T}
old = atomic_pointerref(ptr, Val(:not_atomic))
old = atomic_pointerref(ptr, Val(:not_atomic), sync)
if old === expected
atomic_pointerset(ptr, desired, Val(:not_atomic))
atomic_pointerset(ptr, desired, Val(:not_atomic), sync)
success = true
else
success = false
Expand All @@ -563,10 +386,12 @@ end
desired::T,
success_order::_julia_ordering(((:not_atomic, :unordered))),
fail_order::_julia_ordering(((:not_atomic, :unordered, :release, :acquire_release))),
sync
) where {T} = llvm_atomic_cas(
ptr,
expected,
desired,
llvm_from_julia_ordering(success_order),
llvm_from_julia_ordering(fail_order),
sync
)
Loading

0 comments on commit 98bb4ab

Please sign in to comment.