From 4f14a5e88d76784ebd2aa0b5beb9c6e33cd69c36 Mon Sep 17 00:00:00 2001 From: Valentin Churavy Date: Fri, 9 Aug 2024 16:43:13 +0200 Subject: [PATCH] WIP:Try out gpuc.deferred.with --- examples/jit.jl | 65 +++++++++++++++++++++++++++++-------------------- src/driver.jl | 12 +++++++++ 2 files changed, 51 insertions(+), 26 deletions(-) diff --git a/examples/jit.jl b/examples/jit.jl index 6f0a259c..86fd467c 100644 --- a/examples/jit.jl +++ b/examples/jit.jl @@ -116,31 +116,39 @@ function get_trampoline(job) return addr end -# import GPUCompiler: deferred_codegen_jobs -# @generated function deferred_codegen(f::F, ::Val{tt}, ::Val{world}) where {F,tt,world} -# # manual version of native_job because we have a function type -# source = methodinstance(F, Base.to_tuple_type(tt), world) -# target = NativeCompilerTarget(; jlruntime=true, llvm_always_inline=true) -# # XXX: do we actually require the Julia runtime? -# # with jlruntime=false, we reach an unreachable. -# params = TestCompilerParams() -# config = CompilerConfig(target, params; kernel=false) -# job = CompilerJob(source, config, world) -# # XXX: invoking GPUCompiler from a generated function is not allowed! -# # for things to work, we need to forward the correct world, at least. - -# addr = get_trampoline(job) -# trampoline = pointer(addr) -# id = Base.reinterpret(Int, trampoline) - -# deferred_codegen_jobs[id] = job - -# quote -# ptr = ccall("extern deferred_codegen", llvmcall, Ptr{Cvoid}, (Ptr{Cvoid},), $trampoline) -# assume(ptr != C_NULL) -# return ptr -# end -# end +const runtime_cache = Dict{Any, Ptr{Cvoid}}() + +function compiler(job) + JuliaContext() do _ + ir, meta = GPUCompiler.compile(:llvm, job; validate=false) + # So 1. serialize the module + buf = convert(MemoryBuffer, ir) + buf, LLVM.name(meta.entry) + end +end + +function linker(_, (buf, entry_fn)) + compiler = jit[] + lljit = compiler.jit + jd = JITDylib(lljit) + + # 2. 
deserialize and wrap by a ThreadSafeModule + ThreadSafeContext() do ts_ctx + tsm = context!(context(ts_ctx)) do + mod = parse(LLVM.Module, buf) + ThreadSafeModule(mod) + end + + LLVM.add!(lljit, jd, tsm) + end + addr = LLVM.lookup(lljit, entry_fn) + pointer(addr) +end + +function GPUCompiler.var"gpuc.deferred.with"(config::GPUCompiler.CompilerConfig{<:NativeCompilerTarget}, f::F, args...) where F + source = methodinstance(F, Base.to_tuple_type(typeof(args))) + GPUCompiler.cached_compilation(runtime_cache, source, config, compiler, linker)::Ptr{Cvoid} +end @generated function abi_call(f::Ptr{Cvoid}, rt::Type{RT}, tt::Type{T}, func::F, args::Vararg{Any, N}) where {T, RT, F, N} argtt = tt.parameters[1] @@ -226,7 +234,12 @@ end rt = Core.Compiler.return_type(f, tt) # FIXME: Horrible idea, have `var"gpuc.deferred"` actually do the work # But that will only be needed here, and in Enzyme... - ptr = GPUCompiler.var"gpuc.deferred"(f, args...) + target = NativeCompilerTarget(; jlruntime=true, llvm_always_inline=true) + # XXX: do we actually require the Julia runtime? + # with jlruntime=false, we reach an unreachable. + params = TestCompilerParams() + config = CompilerConfig(target, params; kernel=false) + ptr = GPUCompiler.var"gpuc.deferred.with"(config, f, args...) abi_call(ptr, rt, tt, f, args...) end diff --git a/src/driver.jl b/src/driver.jl index 21040745..de6a0810 100644 --- a/src/driver.jl +++ b/src/driver.jl @@ -41,8 +41,20 @@ end ## deferred compilation +""" + var"gpuc.deferred"(f, args...)::Ptr{Cvoid} + +As if we were to call `f(args...)` but instead we are +putting down a marker and returning a function pointer to call +later. +""" function var"gpuc.deferred" end +""" + var"gpuc.deferred.with"(config::CompilerConfig, f, args...)::Ptr{Cvoid} +""" +function var"gpuc.deferred.with" end + ## compiler entrypoint export compile