Skip to content

Commit

Permalink
WIP:Try out gpuc.deferred.with
Browse files Browse the repository at this point in the history
  • Loading branch information
vchuravy committed Sep 26, 2024
1 parent 60fa61f commit 4f14a5e
Show file tree
Hide file tree
Showing 2 changed files with 51 additions and 26 deletions.
65 changes: 39 additions & 26 deletions examples/jit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,31 +116,39 @@ function get_trampoline(job)
return addr
end

# import GPUCompiler: deferred_codegen_jobs
# @generated function deferred_codegen(f::F, ::Val{tt}, ::Val{world}) where {F,tt,world}
# # manual version of native_job because we have a function type
# source = methodinstance(F, Base.to_tuple_type(tt), world)
# target = NativeCompilerTarget(; jlruntime=true, llvm_always_inline=true)
# # XXX: do we actually require the Julia runtime?
# # with jlruntime=false, we reach an unreachable.
# params = TestCompilerParams()
# config = CompilerConfig(target, params; kernel=false)
# job = CompilerJob(source, config, world)
# # XXX: invoking GPUCompiler from a generated function is not allowed!
# # for things to work, we need to forward the correct world, at least.

# addr = get_trampoline(job)
# trampoline = pointer(addr)
# id = Base.reinterpret(Int, trampoline)

# deferred_codegen_jobs[id] = job

# quote
# ptr = ccall("extern deferred_codegen", llvmcall, Ptr{Cvoid}, (Ptr{Cvoid},), $trampoline)
# assume(ptr != C_NULL)
# return ptr
# end
# end
# Cache handed to `GPUCompiler.cached_compilation` by the `gpuc.deferred.with`
# method in this file; its values are the `Ptr{Cvoid}` results returned from there.
const runtime_cache = Dict{Any, Ptr{Cvoid}}()

"""
    compiler(job)

Compile `job` to LLVM IR via `GPUCompiler.compile(:llvm, job; validate=false)` inside
a `JuliaContext`, convert the resulting module to a `MemoryBuffer`, and return the
tuple `(buffer, entry_name)`, where `entry_name` is `LLVM.name` of the compiled entry.

Passed as the `compiler` callback to `GPUCompiler.cached_compilation`.
"""
function compiler(job)
    return JuliaContext() do _ctx
        llvm_mod, compile_meta = GPUCompiler.compile(:llvm, job; validate=false)
        # Step 1: serialize the module into a memory buffer; `linker` parses it
        # again inside its own context.
        serialized = convert(MemoryBuffer, llvm_mod)
        return (serialized, LLVM.name(compile_meta.entry))
    end
end

"""
    linker(_, (buf, entry_fn))

Take the `(buf, entry_fn)` pair produced by `compiler`, add the module contained in
`buf` to the LLJIT stored in the global `jit[]`, and return the pointer obtained from
looking up `entry_fn` in that JIT. The first argument is ignored.

Passed as the `linker` callback to `GPUCompiler.cached_compilation`.
"""
function linker(_, (buf, entry_fn))
    # NOTE(review): `jit` is defined elsewhere in this file; it is dereferenced here
    # and is expected to expose the LLJIT instance as its `jit` field.
    lljit = jit[].jit
    dylib = JITDylib(lljit)

    # Step 2: deserialize the buffer and wrap the module in a ThreadSafeModule
    # before handing it to the JIT.
    ThreadSafeContext() do ts_ctx
        ts_module = context!(context(ts_ctx)) do
            ThreadSafeModule(parse(LLVM.Module, buf))
        end

        LLVM.add!(lljit, dylib, ts_module)
    end

    entry_addr = LLVM.lookup(lljit, entry_fn)
    return pointer(entry_addr)
end

# Method of `gpuc.deferred.with` for configs targeting `NativeCompilerTarget`:
# look up the method instance for calling an `F` with the types of `args`, then run
# it through `GPUCompiler.cached_compilation` with this file's `compiler`/`linker`
# callbacks and `runtime_cache`. The result is asserted to be a `Ptr{Cvoid}`.
function GPUCompiler.var"gpuc.deferred.with"(
    config::GPUCompiler.CompilerConfig{<:NativeCompilerTarget}, f::F, args...
) where {F}
    argtypes = Base.to_tuple_type(typeof(args))
    mi = methodinstance(F, argtypes)
    fptr = GPUCompiler.cached_compilation(runtime_cache, mi, config, compiler, linker)
    return fptr::Ptr{Cvoid}
end

@generated function abi_call(f::Ptr{Cvoid}, rt::Type{RT}, tt::Type{T}, func::F, args::Vararg{Any, N}) where {T, RT, F, N}
argtt = tt.parameters[1]
Expand Down Expand Up @@ -226,7 +234,12 @@ end
rt = Core.Compiler.return_type(f, tt)
# FIXME: Horrible idea, have `var"gpuc.deferred"` actually do the work
# But that will only be needed here, and in Enzyme...
ptr = GPUCompiler.var"gpuc.deferred"(f, args...)
target = NativeCompilerTarget(; jlruntime=true, llvm_always_inline=true)
# XXX: do we actually require the Julia runtime?
# with jlruntime=false, we reach an unreachable.
params = TestCompilerParams()
config = CompilerConfig(target, params; kernel=false)
ptr = GPUCompiler.var"gpuc.deferred.with"(config, f, args...)
abi_call(ptr, rt, tt, f, args...)
end

Expand Down
12 changes: 12 additions & 0 deletions src/driver.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,20 @@ end

## deferred compilation

"""
    var"gpuc.deferred"(f, args...)::Ptr{Cvoid}

As if we were to call `f(args...)`, but instead of performing the call we put
down a marker and return a function pointer that can be called later.
"""
function var"gpuc.deferred" end

"""
    var"gpuc.deferred.with"(config::CompilerConfig, f, args...)::Ptr{Cvoid}

Variant of `var"gpuc.deferred"` that additionally takes an explicit `config`.
NOTE(review): presumably `config` is the compiler configuration used for the
deferred function; confirm against the methods that implement this.
"""
function var"gpuc.deferred.with" end

## compiler entrypoint

export compile
Expand Down

0 comments on commit 4f14a5e

Please sign in to comment.