Skip to content

Commit

Permalink
Add SemIR Vtable instruction and usage (#4732)
Browse files Browse the repository at this point in the history
Add a Vtable typed inst with a type_id (of the type this vtable applies
to) and list of virtual function decls (or import refs to function
object constants).

This doesn't add lowering/emission of the vtable, or usage when
initializing objects of the type.

Some questions in case they're interesting to discuss:
* is it right/worth having the type_id in the vtable? (probably makes it
easier to emit - using the type to get the class name to figure out the
mangled name for the vtable) perhaps it should be a ClassId?
* I'm thinking the logic in CheckCompleteClassType could be the place we
handle diagnostics for mismatched keywords (virtual/abstract for a
function that's already virtual/abstract, maybe checking for non-virtual
functions with the same name in a base class, or derived class functions
without `impl`, etc) - but we could move some of that to the moment we
walk the function decl, and record our findings in the function decl
(record the base function it overrides, or the index of the vtable to
slot to use when building the vtable at the end of the class)
* the Vtable typed inst has `constant_kind = InstConstantKind::Always`
and `is_lowered = false`, I think I added that in to workaround/address
some failures in lowering. And seems correct for this intermediate step
- I'll add lowering in a follow-up patch. But the constant_kind - what
should this be? We can just say all vtables are of VtableType (in which
case the `Always` constant kind sounds right to me) or we could have
them introduce a type with each virtual function as a named member,
even?

---------

Co-authored-by: Richard Smith <[email protected]>
Co-authored-by: Jon Ross-Perkins <[email protected]>
  • Loading branch information
3 people authored Jan 16, 2025
1 parent e348119 commit a8b46cf
Show file tree
Hide file tree
Showing 247 changed files with 391 additions and 3,458 deletions.
1 change: 1 addition & 0 deletions toolchain/check/context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ Context::Context(DiagnosticEmitter* emitter,
args_type_info_stack_("args_type_info_stack_", *sem_ir, vlog_stream),
decl_name_stack_(this),
scope_stack_(sem_ir_->identifiers()),
vtable_stack_("vtable_stack_", *sem_ir, vlog_stream),
global_init_(this) {
// Prepare fields which relate to the number of IRs available for import.
import_irs().Reserve(imported_ir_count);
Expand Down
6 changes: 6 additions & 0 deletions toolchain/check/context.h
Original file line number Diff line number Diff line change
Expand Up @@ -601,6 +601,8 @@ class Context {
return generic_region_stack_;
}

auto vtable_stack() -> InstBlockStack& { return vtable_stack_; }

auto import_ir_constant_values()
-> llvm::SmallVector<SemIR::ConstantValueStore, 0>& {
return import_ir_constant_values_;
Expand Down Expand Up @@ -785,6 +787,10 @@ class Context {
// The stack of generic regions we are currently within.
GenericRegionStack generic_region_stack_;

// Contains a vtable block for each `class` scope which is currently being
// defined, regardless of whether the class can have virtual functions.
InstBlockStack vtable_stack_;

// Cache of reverse mapping from type constants to types.
//
// TODO: Instead of mapping to a dense `TypeId` space, we could make `TypeId`
Expand Down
3 changes: 3 additions & 0 deletions toolchain/check/eval.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1664,6 +1664,9 @@ static auto TryEvalInstInContext(EvalContext& eval_context,
case SemIR::TupleInit::Kind:
return RebuildInitAsValue(eval_context, inst, SemIR::TupleValue::Kind);

case SemIR::Vtable::Kind:
return RebuildIfFieldsAreConstant(eval_context, inst,
&SemIR::Vtable::virtual_functions_id);
case SemIR::AutoType::Kind:
case SemIR::BoolType::Kind:
case SemIR::BoundMethodType::Kind:
Expand Down
55 changes: 53 additions & 2 deletions toolchain/check/handle_class.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@
#include "toolchain/check/eval.h"
#include "toolchain/check/generic.h"
#include "toolchain/check/handle.h"
#include "toolchain/check/import_ref.h"
#include "toolchain/check/merge.h"
#include "toolchain/check/modifiers.h"
#include "toolchain/check/name_component.h"
#include "toolchain/sem_ir/function.h"
#include "toolchain/sem_ir/ids.h"
#include "toolchain/sem_ir/inst.h"
#include "toolchain/sem_ir/typed_insts.h"
Expand Down Expand Up @@ -299,6 +301,7 @@ auto HandleParseNode(Context& context, Parse::ClassDefinitionStartId node_id)
context.inst_block_stack().Push();
context.node_stack().Push(node_id, class_id);
context.field_decls_stack().PushArray();
context.vtable_stack().Push();

// TODO: Handle the case where there's control flow in the class body. For
// example:
Expand Down Expand Up @@ -663,11 +666,12 @@ static auto CheckCompleteClassType(Context& context, Parse::NodeId node_id,
bool defining_vptr = class_info.is_dynamic;
auto base_type_id =
class_info.GetBaseType(context.sem_ir(), SemIR::SpecificId::Invalid);
SemIR::Class* base_class_info = nullptr;
if (base_type_id.is_valid()) {
// TODO: If the base class is template dependent, we will need to decide
// whether to add a vptr as part of instantiation.
if (auto* base_class_info = TryGetAsClass(context, base_type_id);
base_class_info && base_class_info->is_dynamic) {
base_class_info = TryGetAsClass(context, base_type_id);
if (base_class_info && base_class_info->is_dynamic) {
defining_vptr = false;
}
}
Expand All @@ -691,6 +695,52 @@ static auto CheckCompleteClassType(Context& context, Parse::NodeId node_id,
{.name_id = SemIR::NameId::Base, .type_id = base_type_id});
}

if (class_info.is_dynamic) {
llvm::SmallVector<SemIR::InstId> vtable;
if (!defining_vptr) {
LoadImportRef(context, base_class_info->vtable_id);
auto base_vtable_id = context.constant_values().GetConstantInstId(
base_class_info->vtable_id);
auto base_vtable_inst_block =
context.inst_blocks().Get(context.insts()
.GetAs<SemIR::Vtable>(base_vtable_id)
.virtual_functions_id);
// TODO: Avoid quadratic search. Perhaps build a map from `NameId` to the
// elements of the top of `vtable_stack`.
for (auto fn_decl_id : base_vtable_inst_block) {
auto fn_decl = GetCalleeFunction(context.sem_ir(), fn_decl_id);
const auto& fn = context.functions().Get(fn_decl.function_id);
for (auto override_fn_decl_id :
context.vtable_stack().PeekCurrentBlockContents()) {
auto override_fn_decl =
context.insts().GetAs<SemIR::FunctionDecl>(override_fn_decl_id);
const auto& override_fn =
context.functions().Get(override_fn_decl.function_id);
// TODO: Validate that the overriding function's signature matches
// that of the overridden function.
if (override_fn.virtual_modifier ==
SemIR::FunctionFields::VirtualModifier::Impl &&
override_fn.name_id == fn.name_id) {
fn_decl_id = override_fn_decl_id;
}
}
vtable.push_back(fn_decl_id);
}
}

for (auto inst_id : context.vtable_stack().PeekCurrentBlockContents()) {
auto fn_decl = context.insts().GetAs<SemIR::FunctionDecl>(inst_id);
const auto& fn = context.functions().Get(fn_decl.function_id);
if (fn.virtual_modifier != SemIR::FunctionFields::VirtualModifier::Impl) {
vtable.push_back(inst_id);
}
}
class_info.vtable_id = context.AddInst<SemIR::Vtable>(
node_id, {.type_id = context.GetSingletonType(
SemIR::VtableType::SingletonInstId),
.virtual_functions_id = context.inst_blocks().Add(vtable)});
}

return context.AddInst<SemIR::CompleteTypeWitness>(
node_id,
{.type_id = context.GetSingletonType(SemIR::WitnessType::SingletonInstId),
Expand All @@ -711,6 +761,7 @@ auto HandleParseNode(Context& context, Parse::ClassDefinitionId node_id)

context.inst_block_stack().Pop();
context.field_decls_stack().PopArray();
context.vtable_stack().Pop();

FinishGenericDefinition(context, class_info.generic_id);

Expand Down
10 changes: 7 additions & 3 deletions toolchain/check/handle_function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -211,18 +211,19 @@ static auto BuildFunctionDecl(Context& context,
.Case(KeywordModifierSet::Impl,
SemIR::Function::VirtualModifier::Impl)
.Default(SemIR::Function::VirtualModifier::None);
SemIR::Class* virtual_class_info = nullptr;
if (virtual_modifier != SemIR::Function::VirtualModifier::None &&
parent_scope_inst) {
if (auto class_decl = parent_scope_inst->TryAs<SemIR::ClassDecl>()) {
auto& class_info = context.classes().Get(class_decl->class_id);
virtual_class_info = &context.classes().Get(class_decl->class_id);
if (virtual_modifier == SemIR::Function::VirtualModifier::Impl &&
!class_info.base_id.is_valid()) {
!virtual_class_info->base_id.is_valid()) {
CARBON_DIAGNOSTIC(ImplWithoutBase, Error, "impl without base class");
context.emitter().Build(node_id, ImplWithoutBase).Emit();
}
// TODO: If this is an `impl` function, check there's a matching base
// function that's impl or virtual.
class_info.is_dynamic = true;
virtual_class_info->is_dynamic = true;
}
}
if (introducer.modifier_set.HasAnyOf(KeywordModifierSet::Interface)) {
Expand All @@ -249,6 +250,9 @@ static auto BuildFunctionDecl(Context& context,
if (is_definition) {
function_info.definition_id = decl_id;
}
if (virtual_class_info) {
context.vtable_stack().AddInstId(decl_id);
}

if (name_context.state == DeclNameStack::NameContext::State::Poisoned) {
context.DiagnosePoisonedName(function_info.latest_decl_id());
Expand Down
34 changes: 31 additions & 3 deletions toolchain/check/import_ref.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1421,6 +1421,24 @@ static auto TryResolveTypedInst(ImportRefResolver& resolver,
inst_id);
}

static auto TryResolveTypedInst(ImportRefResolver& resolver, SemIR::Vtable inst,
SemIR::InstId /*import_inst_id*/)
-> ResolveResult {
auto type_const_id = GetLocalConstantId(resolver, inst.type_id);
auto virtual_functions =
GetLocalInstBlockContents(resolver, inst.virtual_functions_id);
if (resolver.HasNewWork()) {
return ResolveResult::Retry();
}

auto virtual_functions_id = GetLocalCanonicalInstBlockId(
resolver, inst.virtual_functions_id, virtual_functions);
return ResolveAs<SemIR::Vtable>(
resolver, {.type_id = resolver.local_context().GetTypeIdForTypeConstant(
type_const_id),
.virtual_functions_id = virtual_functions_id});
}

static auto TryResolveTypedInst(ImportRefResolver& resolver,
SemIR::BindAlias inst) -> ResolveResult {
auto value_id = GetLocalConstantId(resolver, inst.value_id);
Expand Down Expand Up @@ -1543,8 +1561,8 @@ static auto AddClassDefinition(ImportContext& context,
const SemIR::Class& import_class,
SemIR::Class& new_class,
SemIR::InstId complete_type_witness_id,
SemIR::InstId base_id, SemIR::InstId adapt_id)
-> void {
SemIR::InstId base_id, SemIR::InstId adapt_id,
SemIR::InstId vtable_id) -> void {
new_class.definition_id = new_class.first_owning_decl_id;

new_class.complete_type_witness_id = complete_type_witness_id;
Expand All @@ -1567,6 +1585,9 @@ static auto AddClassDefinition(ImportContext& context,
if (import_class.adapt_id.is_valid()) {
new_class.adapt_id = adapt_id;
}
if (import_class.vtable_id.is_valid()) {
new_class.vtable_id = vtable_id;
}
}

static auto TryResolveTypedInst(ImportRefResolver& resolver,
Expand Down Expand Up @@ -1638,6 +1659,10 @@ static auto TryResolveTypedInst(ImportRefResolver& resolver,
return ResolveResult::Retry(class_const_id, new_class.first_decl_id());
}

auto vtable_id = import_class.vtable_id.is_valid()
? AddImportRef(resolver, import_class.vtable_id)
: SemIR::InstId::Invalid;

new_class.parent_scope_id = parent_scope_id;
new_class.implicit_param_patterns_id = GetLocalParamPatternsId(
resolver, import_class.implicit_param_patterns_id);
Expand All @@ -1655,7 +1680,7 @@ static auto TryResolveTypedInst(ImportRefResolver& resolver,
SemIR::WitnessType::SingletonInstId),
import_class.complete_type_witness_id, complete_type_witness_const_id);
AddClassDefinition(resolver, import_class, new_class,
complete_type_witness_id, base_id, adapt_id);
complete_type_witness_id, base_id, adapt_id, vtable_id);
}

return ResolveResult::Done(class_const_id, new_class.first_decl_id());
Expand Down Expand Up @@ -2717,6 +2742,9 @@ static auto TryResolveInstCanonical(ImportRefResolver& resolver,
case CARBON_KIND(SemIR::UnboundElementType inst): {
return TryResolveTypedInst(resolver, inst);
}
case CARBON_KIND(SemIR::Vtable inst): {
return TryResolveTypedInst(resolver, inst, inst_id);
}
default: {
// This instruction might have a constant value of a different kind.
auto constant_inst_id =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,6 @@ let d: c = {};
// CHECK:STDOUT: class @C {
// CHECK:STDOUT: %complete_type: <witness> = complete_type_witness %empty_struct_type [template = constants.%complete_type]
// CHECK:STDOUT: complete_type_witness = %complete_type
// CHECK:STDOUT:
// CHECK:STDOUT: !members:
// CHECK:STDOUT: .Self = constants.%C
// CHECK:STDOUT: }
// CHECK:STDOUT:
// CHECK:STDOUT: fn @__global_init() {
Expand Down
19 changes: 0 additions & 19 deletions toolchain/check/testdata/alias/no_prelude/export_name.carbon
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,6 @@ var d: D* = &c;
// CHECK:STDOUT: class @C {
// CHECK:STDOUT: %complete_type: <witness> = complete_type_witness %empty_struct_type [template = constants.%complete_type]
// CHECK:STDOUT: complete_type_witness = %complete_type
// CHECK:STDOUT:
// CHECK:STDOUT: !members:
// CHECK:STDOUT: .Self = constants.%C
// CHECK:STDOUT: }
// CHECK:STDOUT:
// CHECK:STDOUT: --- export.carbon
Expand All @@ -106,7 +103,6 @@ var d: D* = &c;
// CHECK:STDOUT: %import_ref.f42 = import_ref Main//base, C, unloaded
// CHECK:STDOUT: %import_ref.05a: type = import_ref Main//base, D, loaded [template = constants.%C]
// CHECK:STDOUT: %import_ref.8f2: <witness> = import_ref Main//base, loc4_10, loaded [template = constants.%complete_type]
// CHECK:STDOUT: %import_ref.2c4 = import_ref Main//base, inst14 [no loc], unloaded
// CHECK:STDOUT: }
// CHECK:STDOUT:
// CHECK:STDOUT: file {
Expand All @@ -120,9 +116,6 @@ var d: D* = &c;
// CHECK:STDOUT:
// CHECK:STDOUT: class @C [from "base.carbon"] {
// CHECK:STDOUT: complete_type_witness = imports.%import_ref.8f2
// CHECK:STDOUT:
// CHECK:STDOUT: !members:
// CHECK:STDOUT: .Self = imports.%import_ref.2c4
// CHECK:STDOUT: }
// CHECK:STDOUT:
// CHECK:STDOUT: --- export_orig.carbon
Expand All @@ -137,7 +130,6 @@ var d: D* = &c;
// CHECK:STDOUT: %import_ref.3b0: type = import_ref Main//base, C, loaded [template = constants.%C]
// CHECK:STDOUT: %import_ref.909 = import_ref Main//base, D, unloaded
// CHECK:STDOUT: %import_ref.8f2: <witness> = import_ref Main//base, loc4_10, loaded [template = constants.%complete_type]
// CHECK:STDOUT: %import_ref.2c4 = import_ref Main//base, inst14 [no loc], unloaded
// CHECK:STDOUT: }
// CHECK:STDOUT:
// CHECK:STDOUT: file {
Expand All @@ -151,9 +143,6 @@ var d: D* = &c;
// CHECK:STDOUT:
// CHECK:STDOUT: class @C [from "base.carbon"] {
// CHECK:STDOUT: complete_type_witness = imports.%import_ref.8f2
// CHECK:STDOUT:
// CHECK:STDOUT: !members:
// CHECK:STDOUT: .Self = imports.%import_ref.2c4
// CHECK:STDOUT: }
// CHECK:STDOUT:
// CHECK:STDOUT: --- use_export.carbon
Expand All @@ -168,7 +157,6 @@ var d: D* = &c;
// CHECK:STDOUT: imports {
// CHECK:STDOUT: %import_ref.c3f: type = import_ref Main//export, D, loaded [template = constants.%C]
// CHECK:STDOUT: %import_ref.8db: <witness> = import_ref Main//export, inst20 [indirect], loaded [template = constants.%complete_type]
// CHECK:STDOUT: %import_ref.6a9 = import_ref Main//export, inst21 [indirect], unloaded
// CHECK:STDOUT: }
// CHECK:STDOUT:
// CHECK:STDOUT: file {
Expand All @@ -183,9 +171,6 @@ var d: D* = &c;
// CHECK:STDOUT:
// CHECK:STDOUT: class @C [from "export.carbon"] {
// CHECK:STDOUT: complete_type_witness = imports.%import_ref.8db
// CHECK:STDOUT:
// CHECK:STDOUT: !members:
// CHECK:STDOUT: .Self = imports.%import_ref.6a9
// CHECK:STDOUT: }
// CHECK:STDOUT:
// CHECK:STDOUT: fn @__global_init() {
Expand Down Expand Up @@ -238,7 +223,6 @@ var d: D* = &c;
// CHECK:STDOUT: %import_ref.c3f: type = import_ref Main//export, D, loaded [template = constants.%C]
// CHECK:STDOUT: %import_ref.06e: type = import_ref Main//export_orig, C, loaded [template = constants.%C]
// CHECK:STDOUT: %import_ref.8db: <witness> = import_ref Main//export_orig, inst20 [indirect], loaded [template = constants.%complete_type]
// CHECK:STDOUT: %import_ref.6a9 = import_ref Main//export_orig, inst21 [indirect], unloaded
// CHECK:STDOUT: }
// CHECK:STDOUT:
// CHECK:STDOUT: file {
Expand All @@ -257,9 +241,6 @@ var d: D* = &c;
// CHECK:STDOUT:
// CHECK:STDOUT: class @C [from "export_orig.carbon"] {
// CHECK:STDOUT: complete_type_witness = imports.%import_ref.8db
// CHECK:STDOUT:
// CHECK:STDOUT: !members:
// CHECK:STDOUT: .Self = imports.%import_ref.6a9
// CHECK:STDOUT: }
// CHECK:STDOUT:
// CHECK:STDOUT: fn @__global_init() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,17 +48,11 @@ let c_var: c = d;
// CHECK:STDOUT: class @C {
// CHECK:STDOUT: %complete_type: <witness> = complete_type_witness %empty_struct_type [template = constants.%complete_type]
// CHECK:STDOUT: complete_type_witness = %complete_type
// CHECK:STDOUT:
// CHECK:STDOUT: !members:
// CHECK:STDOUT: .Self = constants.%C
// CHECK:STDOUT: }
// CHECK:STDOUT:
// CHECK:STDOUT: class @D {
// CHECK:STDOUT: %complete_type: <witness> = complete_type_witness %empty_struct_type [template = constants.%complete_type]
// CHECK:STDOUT: complete_type_witness = %complete_type
// CHECK:STDOUT:
// CHECK:STDOUT: !members:
// CHECK:STDOUT: .Self = constants.%D
// CHECK:STDOUT: }
// CHECK:STDOUT:
// CHECK:STDOUT: fn @__global_init() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,5 @@ extern alias C = Class;
// CHECK:STDOUT: class @Class {
// CHECK:STDOUT: %complete_type: <witness> = complete_type_witness %empty_struct_type [template = constants.%complete_type]
// CHECK:STDOUT: complete_type_witness = %complete_type
// CHECK:STDOUT:
// CHECK:STDOUT: !members:
// CHECK:STDOUT: .Self = constants.%Class
// CHECK:STDOUT: }
// CHECK:STDOUT:
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,6 @@ alias b = C;
// CHECK:STDOUT: class @C {
// CHECK:STDOUT: %complete_type: <witness> = complete_type_witness %empty_struct_type [template = constants.%complete_type]
// CHECK:STDOUT: complete_type_witness = %complete_type
// CHECK:STDOUT:
// CHECK:STDOUT: !members:
// CHECK:STDOUT: .Self = constants.%C
// CHECK:STDOUT: }
// CHECK:STDOUT:
// CHECK:STDOUT: fn @__global_init() {
Expand Down
Loading

0 comments on commit a8b46cf

Please sign in to comment.