Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support adding CodeInstances to JIT for interpreters defining a codegen cache #57272

Draft
wants to merge 4 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions Compiler/extras/CompilerDevTools/src/CompilerDevTools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,13 @@ struct SplitCacheInterp <: Compiler.AbstractInterpreter
inf_params::Compiler.InferenceParams
opt_params::Compiler.OptimizationParams
inf_cache::Vector{Compiler.InferenceResult}
codegen_cache::IdDict{CodeInstance,CodeInfo}
function SplitCacheInterp(;
world::UInt = Base.get_world_counter(),
inf_params::Compiler.InferenceParams = Compiler.InferenceParams(),
opt_params::Compiler.OptimizationParams = Compiler.OptimizationParams(),
inf_cache::Vector{Compiler.InferenceResult} = Compiler.InferenceResult[])
new(world, inf_params, opt_params, inf_cache)
new(world, inf_params, opt_params, inf_cache, IdDict{CodeInstance,CodeInfo}())
end
end

Expand All @@ -23,10 +24,11 @@ Compiler.OptimizationParams(interp::SplitCacheInterp) = interp.opt_params
Compiler.get_inference_world(interp::SplitCacheInterp) = interp.world
Compiler.get_inference_cache(interp::SplitCacheInterp) = interp.inf_cache
Compiler.cache_owner(::SplitCacheInterp) = SplitCacheOwner()
Compiler.codegen_cache(interp::SplitCacheInterp) = interp.codegen_cache

import Core.OptimizedGenerics.CompilerPlugins: typeinf, typeinf_edge
@eval @noinline typeinf(::SplitCacheOwner, mi::MethodInstance, source_mode::UInt8) =
Base.invoke_in_world(which(typeinf, Tuple{SplitCacheOwner, MethodInstance, UInt8}).primary_world, Compiler.typeinf_ext, SplitCacheInterp(; world=Base.tls_world_age()), mi, source_mode)
Base.invoke_in_world(which(typeinf, Tuple{SplitCacheOwner, MethodInstance, UInt8}).primary_world, Compiler.typeinf_ext_toplevel, SplitCacheInterp(; world=Base.tls_world_age()), mi, source_mode)

@eval @noinline function typeinf_edge(::SplitCacheOwner, mi::MethodInstance, parent_frame::Compiler.InferenceState, world::UInt, source_mode::UInt8)
# TODO: This isn't quite right, we're just sketching things for now
Expand Down
85 changes: 50 additions & 35 deletions Compiler/src/typeinfer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -142,9 +142,10 @@ function finish!(interp::AbstractInterpreter, caller::InferenceState, validation
ci, inferred_result, const_flag, first(result.valid_worlds), last(result.valid_worlds), encode_effects(result.ipo_effects),
result.analysis_results, di, edges)
engine_reject(interp, ci)
if !discard_src && isdefined(interp, :codegen) && uncompressed isa CodeInfo
codegen = codegen_cache(interp)
if !discard_src && codegen !== nothing && uncompressed isa CodeInfo
# record that the caller could use this result to generate code when required, if desired, to avoid repeating n^2 work
interp.codegen[ci] = uncompressed
codegen[ci] = uncompressed
if bootstrapping_compiler && inferred_result == nothing
# This is necessary to get decent bootstrapping performance
# when compiling the compiler to inject everything eagerly
Expand Down Expand Up @@ -184,8 +185,9 @@ function finish!(interp::AbstractInterpreter, mi::MethodInstance, ci::CodeInstan
ccall(:jl_update_codeinst, Cvoid, (Any, Any, Int32, UInt, UInt, UInt32, Any, Any, Any),
ci, nothing, const_flag, min_world, max_world, ipo_effects, nothing, di, edges)
code_cache(interp)[mi] = ci
if isdefined(interp, :codegen)
interp.codegen[ci] = src
codegen = codegen_cache(interp)
if codegen !== nothing
codegen[ci] = src
end
engine_reject(interp, ci)
return nothing
Expand Down Expand Up @@ -1167,7 +1169,10 @@ function typeinf_ext(interp::AbstractInterpreter, mi::MethodInstance, source_mod

ci = result.ci # reload from result in case it changed
@assert frame.cache_mode != CACHE_MODE_NULL
@assert is_result_constabi_eligible(result) || (!isdefined(interp, :codegen) || haskey(interp.codegen, ci))
@assert is_result_constabi_eligible(result) || begin
codegen = codegen_cache(interp)
codegen === nothing || haskey(codegen, ci)
end
@assert is_result_constabi_eligible(result) == use_const_api(ci)
@assert isdefined(ci, :inferred) "interpreter did not fulfill our expectations"
if !is_cached(frame) && source_mode == SOURCE_MODE_ABI
Expand Down Expand Up @@ -1233,44 +1238,54 @@ function collectinvokes!(wq::Vector{CodeInstance}, ci::CodeInfo)
end
end

# This is a bridge for the C code calling `jl_typeinf_func()` on a single Method match
function typeinf_ext_toplevel(mi::MethodInstance, world::UInt, source_mode::UInt8)
interp = NativeInterpreter(world)
ci = typeinf_ext(interp, mi, source_mode)
if source_mode == SOURCE_MODE_ABI && ci isa CodeInstance && !ci_has_invoke(ci)
inspected = IdSet{CodeInstance}()
tocompile = Vector{CodeInstance}()
push!(tocompile, ci)
while !isempty(tocompile)
# ci_has_real_invoke(ci) && return ci # optimization: cease looping if ci happens to get compiled (not just jl_fptr_wait_for_compiled, but fully jl_is_compiled_codeinst)
callee = pop!(tocompile)
ci_has_invoke(callee) && continue
callee in inspected && continue
src = get(interp.codegen, callee, nothing)
function add_codeinsts_to_jit!(interp::AbstractInterpreter, ci, source_mode::UInt8)
source_mode == SOURCE_MODE_ABI || return
ci isa CodeInstance && !ci_has_invoke(ci) || return
codegen = codegen_cache(interp)
codegen === nothing && return
inspected = IdSet{CodeInstance}()
tocompile = Vector{CodeInstance}()
push!(tocompile, ci)
while !isempty(tocompile)
# ci_has_real_invoke(ci) && return ci # optimization: cease looping if ci happens to get compiled (not just jl_fptr_wait_for_compiled, but fully jl_is_compiled_codeinst)
callee = pop!(tocompile)
ci_has_invoke(callee) && continue
callee in inspected && continue
src = get(codegen, callee, nothing)
if !isa(src, CodeInfo)
src = @atomic :monotonic callee.inferred
if isa(src, String)
src = _uncompressed_ir(callee, src)
end
if !isa(src, CodeInfo)
src = @atomic :monotonic callee.inferred
if isa(src, String)
src = _uncompressed_ir(callee, src)
end
if !isa(src, CodeInfo)
newcallee = typeinf_ext(interp, callee.def, source_mode)
if newcallee isa CodeInstance
callee === ci && (ci = newcallee) # ci stopped meeting the requirements after typeinf_ext last checked, try again with newcallee
push!(tocompile, newcallee)
#else
# println("warning: could not get source code for ", callee.def)
end
continue
newcallee = typeinf_ext(interp, callee.def, source_mode)
if newcallee isa CodeInstance
callee === ci && (ci = newcallee) # ci stopped meeting the requirements after typeinf_ext last checked, try again with newcallee
push!(tocompile, newcallee)
#else
# println("warning: could not get source code for ", callee.def)
end
continue
end
push!(inspected, callee)
collectinvokes!(tocompile, src)
ccall(:jl_add_codeinst_to_jit, Cvoid, (Any, Any), callee, src)
end
push!(inspected, callee)
collectinvokes!(tocompile, src)
ccall(:jl_add_codeinst_to_jit, Cvoid, (Any, Any), callee, src)
end
end

function typeinf_ext_toplevel(interp::AbstractInterpreter, mi::MethodInstance, source_mode::UInt8)
ci = typeinf_ext(interp, mi, source_mode)
add_codeinsts_to_jit!(interp, ci, source_mode)
return ci
end

# This is a bridge for the C code calling `jl_typeinf_func()` on a single Method match
function typeinf_ext_toplevel(mi::MethodInstance, world::UInt, source_mode::UInt8)
interp = NativeInterpreter(world)
return typeinf_ext_toplevel(interp, mi, source_mode)
end

# This is a bridge for the C code calling `jl_typeinf_func()` on set of Method matches
function typeinf_ext_toplevel(methods::Vector{Any}, worlds::Vector{UInt}, trim::Bool)
inspected = IdSet{CodeInstance}()
Expand Down
17 changes: 17 additions & 0 deletions Compiler/src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ the following methods to satisfy the `AbstractInterpreter` API requirement:
- `get_inference_world(interp::NewInterpreter)` - return the world age for this interpreter
- `get_inference_cache(interp::NewInterpreter)` - return the local inference cache
- `cache_owner(interp::NewInterpreter)` - return the owner of any new cache entries

If `CodeInstance`s compiled using `interp::NewInterpreter` are meant to be executed with `invoke`,
a method `codegen_cache(interp::NewInterpreter) -> IdDict{CodeInstance, CodeInfo}` must be defined,
and inference must be triggered via `typeinf_ext_toplevel` with source mode `SOURCE_MODE_ABI`.
"""
abstract type AbstractInterpreter end

Expand Down Expand Up @@ -430,6 +434,19 @@ to incorporate customized dispatches for the overridden methods.
method_table(interp::AbstractInterpreter) = InternalMethodTable(get_inference_world(interp))
method_table(interp::NativeInterpreter) = interp.method_table

"""
codegen_cache(interp::AbstractInterpreter) -> Union{Nothing, IdDict{CodeInstance, CodeInfo}}

Optionally return a cache associating a `CodeInfo` to a `CodeInstance` that should be added to the JIT
for future execution via `invoke(f, ::CodeInstance, args...)`. This cache is used during `typeinf_ext_toplevel`,
and may be safely discarded between calls to this function.

By default, a value of `nothing` is returned indicating that `CodeInstance`s should not be added to the JIT.
Attempting to execute them via `invoke` will result in an error.
"""
codegen_cache(interp::AbstractInterpreter) = nothing
codegen_cache(interp::NativeInterpreter) = interp.codegen

"""
By default `AbstractInterpreter` implements the following inference bail out logic:
- `bail_out_toplevel_call(::AbstractInterpreter, sig, ::InferenceState)`: bail out from
Expand Down
14 changes: 14 additions & 0 deletions Compiler/test/AbstractInterpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -534,3 +534,17 @@ let interp = DebugInterp()
end
@test found
end

@newinterp InvokeInterp
struct InvokeOwner end
codegen = IdDict{CodeInstance, CodeInfo}()
Compiler.cache_owner(::InvokeInterp) = InvokeOwner()
Compiler.codegen_cache(::InvokeInterp) = codegen
let interp = InvokeInterp()
source_mode = Compiler.SOURCE_MODE_ABI
f = (+)
args = (1, 1)
mi = @ccall jl_method_lookup(Any[f, args...]::Ptr{Any}, (1+length(args))::Csize_t, Base.tls_world_age()::Csize_t)::Ref{Core.MethodInstance}
ci = Compiler.typeinf_ext_toplevel(interp, mi, source_mode)
@test invoke(f, ci, args...) == 2
end