Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add SemIR Vtable instruction and usage #4732

Merged
merged 19 commits into from
Jan 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions toolchain/check/context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ Context::Context(DiagnosticEmitter* emitter,
args_type_info_stack_("args_type_info_stack_", *sem_ir, vlog_stream),
decl_name_stack_(this),
scope_stack_(sem_ir_->identifiers()),
vtable_stack_("vtable_stack_", *sem_ir, vlog_stream),
global_init_(this) {
// Prepare fields which relate to the number of IRs available for import.
import_irs().Reserve(imported_ir_count);
Expand Down
6 changes: 6 additions & 0 deletions toolchain/check/context.h
Original file line number Diff line number Diff line change
Expand Up @@ -601,6 +601,8 @@ class Context {
return generic_region_stack_;
}

auto vtable_stack() -> InstBlockStack& { return vtable_stack_; }

auto import_ir_constant_values()
-> llvm::SmallVector<SemIR::ConstantValueStore, 0>& {
return import_ir_constant_values_;
Expand Down Expand Up @@ -785,6 +787,10 @@ class Context {
// The stack of generic regions we are currently within.
GenericRegionStack generic_region_stack_;

// Contains a vtable block for each `class` scope which is currently being
// defined, regardless of whether the class can have virtual functions.
InstBlockStack vtable_stack_;
dwblaikie marked this conversation as resolved.
Show resolved Hide resolved

// Cache of reverse mapping from type constants to types.
//
// TODO: Instead of mapping to a dense `TypeId` space, we could make `TypeId`
Expand Down
3 changes: 3 additions & 0 deletions toolchain/check/eval.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1664,6 +1664,9 @@ static auto TryEvalInstInContext(EvalContext& eval_context,
case SemIR::TupleInit::Kind:
return RebuildInitAsValue(eval_context, inst, SemIR::TupleValue::Kind);

case SemIR::Vtable::Kind:
return RebuildIfFieldsAreConstant(eval_context, inst,
&SemIR::Vtable::virtual_functions_id);
case SemIR::AutoType::Kind:
case SemIR::BoolType::Kind:
case SemIR::BoundMethodType::Kind:
Expand Down
55 changes: 53 additions & 2 deletions toolchain/check/handle_class.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@
#include "toolchain/check/eval.h"
#include "toolchain/check/generic.h"
#include "toolchain/check/handle.h"
#include "toolchain/check/import_ref.h"
#include "toolchain/check/merge.h"
#include "toolchain/check/modifiers.h"
#include "toolchain/check/name_component.h"
#include "toolchain/sem_ir/function.h"
#include "toolchain/sem_ir/ids.h"
#include "toolchain/sem_ir/inst.h"
#include "toolchain/sem_ir/typed_insts.h"
Expand Down Expand Up @@ -299,6 +301,7 @@ auto HandleParseNode(Context& context, Parse::ClassDefinitionStartId node_id)
context.inst_block_stack().Push();
context.node_stack().Push(node_id, class_id);
context.field_decls_stack().PushArray();
context.vtable_stack().Push();

// TODO: Handle the case where there's control flow in the class body. For
// example:
Expand Down Expand Up @@ -663,11 +666,12 @@ static auto CheckCompleteClassType(Context& context, Parse::NodeId node_id,
bool defining_vptr = class_info.is_dynamic;
auto base_type_id =
class_info.GetBaseType(context.sem_ir(), SemIR::SpecificId::Invalid);
SemIR::Class* base_class_info = nullptr;
if (base_type_id.is_valid()) {
// TODO: If the base class is template dependent, we will need to decide
// whether to add a vptr as part of instantiation.
if (auto* base_class_info = TryGetAsClass(context, base_type_id);
base_class_info && base_class_info->is_dynamic) {
base_class_info = TryGetAsClass(context, base_type_id);
if (base_class_info && base_class_info->is_dynamic) {
defining_vptr = false;
}
}
Expand All @@ -691,6 +695,52 @@ static auto CheckCompleteClassType(Context& context, Parse::NodeId node_id,
{.name_id = SemIR::NameId::Base, .type_id = base_type_id});
}

if (class_info.is_dynamic) {
llvm::SmallVector<SemIR::InstId> vtable;
if (!defining_vptr) {
LoadImportRef(context, base_class_info->vtable_id);
auto base_vtable_id = context.constant_values().GetConstantInstId(
base_class_info->vtable_id);
auto base_vtable_inst_block =
context.inst_blocks().Get(context.insts()
.GetAs<SemIR::Vtable>(base_vtable_id)
.virtual_functions_id);
// TODO: Avoid quadratic search. Perhaps build a map from `NameId` to the
// elements of the top of `vtable_stack`.
for (auto fn_decl_id : base_vtable_inst_block) {
auto fn_decl = GetCalleeFunction(context.sem_ir(), fn_decl_id);
const auto& fn = context.functions().Get(fn_decl.function_id);
for (auto override_fn_decl_id :
context.vtable_stack().PeekCurrentBlockContents()) {
auto override_fn_decl =
context.insts().GetAs<SemIR::FunctionDecl>(override_fn_decl_id);
const auto& override_fn =
context.functions().Get(override_fn_decl.function_id);
// TODO: Validate that the overriding function's signature matches
// that of the overridden function.
if (override_fn.virtual_modifier ==
SemIR::FunctionFields::VirtualModifier::Impl &&
override_fn.name_id == fn.name_id) {
fn_decl_id = override_fn_decl_id;
}
}
vtable.push_back(fn_decl_id);
dwblaikie marked this conversation as resolved.
Show resolved Hide resolved
}
}

for (auto inst_id : context.vtable_stack().PeekCurrentBlockContents()) {
auto fn_decl = context.insts().GetAs<SemIR::FunctionDecl>(inst_id);
const auto& fn = context.functions().Get(fn_decl.function_id);
if (fn.virtual_modifier != SemIR::FunctionFields::VirtualModifier::Impl) {
vtable.push_back(inst_id);
}
}
class_info.vtable_id = context.AddInst<SemIR::Vtable>(
node_id, {.type_id = context.GetSingletonType(
SemIR::VtableType::SingletonInstId),
.virtual_functions_id = context.inst_blocks().Add(vtable)});
}

return context.AddInst<SemIR::CompleteTypeWitness>(
node_id,
{.type_id = context.GetSingletonType(SemIR::WitnessType::SingletonInstId),
Expand All @@ -711,6 +761,7 @@ auto HandleParseNode(Context& context, Parse::ClassDefinitionId node_id)

context.inst_block_stack().Pop();
context.field_decls_stack().PopArray();
context.vtable_stack().Pop();

FinishGenericDefinition(context, class_info.generic_id);

Expand Down
10 changes: 7 additions & 3 deletions toolchain/check/handle_function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -211,18 +211,19 @@ static auto BuildFunctionDecl(Context& context,
.Case(KeywordModifierSet::Impl,
SemIR::Function::VirtualModifier::Impl)
.Default(SemIR::Function::VirtualModifier::None);
SemIR::Class* virtual_class_info = nullptr;
if (virtual_modifier != SemIR::Function::VirtualModifier::None &&
parent_scope_inst) {
if (auto class_decl = parent_scope_inst->TryAs<SemIR::ClassDecl>()) {
auto& class_info = context.classes().Get(class_decl->class_id);
virtual_class_info = &context.classes().Get(class_decl->class_id);
if (virtual_modifier == SemIR::Function::VirtualModifier::Impl &&
!class_info.base_id.is_valid()) {
!virtual_class_info->base_id.is_valid()) {
CARBON_DIAGNOSTIC(ImplWithoutBase, Error, "impl without base class");
context.emitter().Build(node_id, ImplWithoutBase).Emit();
}
// TODO: If this is an `impl` function, check there's a matching base
// function that's impl or virtual.
class_info.is_dynamic = true;
virtual_class_info->is_dynamic = true;
}
}
if (introducer.modifier_set.HasAnyOf(KeywordModifierSet::Interface)) {
Expand All @@ -249,6 +250,9 @@ static auto BuildFunctionDecl(Context& context,
if (is_definition) {
function_info.definition_id = decl_id;
}
if (virtual_class_info) {
context.vtable_stack().AddInstId(decl_id);
}

if (name_context.state == DeclNameStack::NameContext::State::Poisoned) {
context.DiagnosePoisonedName(function_info.latest_decl_id());
Expand Down
34 changes: 31 additions & 3 deletions toolchain/check/import_ref.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1421,6 +1421,24 @@ static auto TryResolveTypedInst(ImportRefResolver& resolver,
inst_id);
}

static auto TryResolveTypedInst(ImportRefResolver& resolver, SemIR::Vtable inst,
SemIR::InstId /*import_inst_id*/)
-> ResolveResult {
auto type_const_id = GetLocalConstantId(resolver, inst.type_id);
auto virtual_functions =
GetLocalInstBlockContents(resolver, inst.virtual_functions_id);
if (resolver.HasNewWork()) {
return ResolveResult::Retry();
}

auto virtual_functions_id = GetLocalCanonicalInstBlockId(
resolver, inst.virtual_functions_id, virtual_functions);
return ResolveAs<SemIR::Vtable>(
resolver, {.type_id = resolver.local_context().GetTypeIdForTypeConstant(
type_const_id),
.virtual_functions_id = virtual_functions_id});
}

static auto TryResolveTypedInst(ImportRefResolver& resolver,
SemIR::BindAlias inst) -> ResolveResult {
auto value_id = GetLocalConstantId(resolver, inst.value_id);
Expand Down Expand Up @@ -1543,8 +1561,8 @@ static auto AddClassDefinition(ImportContext& context,
const SemIR::Class& import_class,
SemIR::Class& new_class,
SemIR::InstId complete_type_witness_id,
SemIR::InstId base_id, SemIR::InstId adapt_id)
-> void {
SemIR::InstId base_id, SemIR::InstId adapt_id,
SemIR::InstId vtable_id) -> void {
new_class.definition_id = new_class.first_owning_decl_id;

new_class.complete_type_witness_id = complete_type_witness_id;
Expand All @@ -1567,6 +1585,9 @@ static auto AddClassDefinition(ImportContext& context,
if (import_class.adapt_id.is_valid()) {
new_class.adapt_id = adapt_id;
}
if (import_class.vtable_id.is_valid()) {
new_class.vtable_id = vtable_id;
}
}

static auto TryResolveTypedInst(ImportRefResolver& resolver,
Expand Down Expand Up @@ -1638,6 +1659,10 @@ static auto TryResolveTypedInst(ImportRefResolver& resolver,
return ResolveResult::Retry(class_const_id, new_class.first_decl_id());
}

auto vtable_id = import_class.vtable_id.is_valid()
? AddImportRef(resolver, import_class.vtable_id)
: SemIR::InstId::Invalid;

new_class.parent_scope_id = parent_scope_id;
new_class.implicit_param_patterns_id = GetLocalParamPatternsId(
resolver, import_class.implicit_param_patterns_id);
Expand All @@ -1655,7 +1680,7 @@ static auto TryResolveTypedInst(ImportRefResolver& resolver,
SemIR::WitnessType::SingletonInstId),
import_class.complete_type_witness_id, complete_type_witness_const_id);
AddClassDefinition(resolver, import_class, new_class,
complete_type_witness_id, base_id, adapt_id);
complete_type_witness_id, base_id, adapt_id, vtable_id);
}

return ResolveResult::Done(class_const_id, new_class.first_decl_id());
Expand Down Expand Up @@ -2717,6 +2742,9 @@ static auto TryResolveInstCanonical(ImportRefResolver& resolver,
case CARBON_KIND(SemIR::UnboundElementType inst): {
return TryResolveTypedInst(resolver, inst);
}
case CARBON_KIND(SemIR::Vtable inst): {
return TryResolveTypedInst(resolver, inst, inst_id);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why pass inst_id if it isn't going to be used?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, thanks for the catch - removed the unused parameter in #4832

}
default: {
// This instruction might have a constant value of a different kind.
auto constant_inst_id =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,6 @@ let d: c = {};
// CHECK:STDOUT: class @C {
// CHECK:STDOUT: %complete_type: <witness> = complete_type_witness %empty_struct_type [template = constants.%complete_type]
// CHECK:STDOUT: complete_type_witness = %complete_type
// CHECK:STDOUT:
// CHECK:STDOUT: !members:
// CHECK:STDOUT: .Self = constants.%C
// CHECK:STDOUT: }
// CHECK:STDOUT:
// CHECK:STDOUT: fn @__global_init() {
Expand Down
19 changes: 0 additions & 19 deletions toolchain/check/testdata/alias/no_prelude/export_name.carbon
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,6 @@ var d: D* = &c;
// CHECK:STDOUT: class @C {
// CHECK:STDOUT: %complete_type: <witness> = complete_type_witness %empty_struct_type [template = constants.%complete_type]
// CHECK:STDOUT: complete_type_witness = %complete_type
// CHECK:STDOUT:
// CHECK:STDOUT: !members:
// CHECK:STDOUT: .Self = constants.%C
// CHECK:STDOUT: }
// CHECK:STDOUT:
// CHECK:STDOUT: --- export.carbon
Expand All @@ -106,7 +103,6 @@ var d: D* = &c;
// CHECK:STDOUT: %import_ref.f42 = import_ref Main//base, C, unloaded
// CHECK:STDOUT: %import_ref.05a: type = import_ref Main//base, D, loaded [template = constants.%C]
// CHECK:STDOUT: %import_ref.8f2: <witness> = import_ref Main//base, loc4_10, loaded [template = constants.%complete_type]
// CHECK:STDOUT: %import_ref.2c4 = import_ref Main//base, inst14 [no loc], unloaded
// CHECK:STDOUT: }
// CHECK:STDOUT:
// CHECK:STDOUT: file {
Expand All @@ -120,9 +116,6 @@ var d: D* = &c;
// CHECK:STDOUT:
// CHECK:STDOUT: class @C [from "base.carbon"] {
// CHECK:STDOUT: complete_type_witness = imports.%import_ref.8f2
// CHECK:STDOUT:
// CHECK:STDOUT: !members:
// CHECK:STDOUT: .Self = imports.%import_ref.2c4
// CHECK:STDOUT: }
// CHECK:STDOUT:
// CHECK:STDOUT: --- export_orig.carbon
Expand All @@ -137,7 +130,6 @@ var d: D* = &c;
// CHECK:STDOUT: %import_ref.3b0: type = import_ref Main//base, C, loaded [template = constants.%C]
// CHECK:STDOUT: %import_ref.909 = import_ref Main//base, D, unloaded
// CHECK:STDOUT: %import_ref.8f2: <witness> = import_ref Main//base, loc4_10, loaded [template = constants.%complete_type]
// CHECK:STDOUT: %import_ref.2c4 = import_ref Main//base, inst14 [no loc], unloaded
// CHECK:STDOUT: }
// CHECK:STDOUT:
// CHECK:STDOUT: file {
Expand All @@ -151,9 +143,6 @@ var d: D* = &c;
// CHECK:STDOUT:
// CHECK:STDOUT: class @C [from "base.carbon"] {
// CHECK:STDOUT: complete_type_witness = imports.%import_ref.8f2
// CHECK:STDOUT:
// CHECK:STDOUT: !members:
// CHECK:STDOUT: .Self = imports.%import_ref.2c4
// CHECK:STDOUT: }
// CHECK:STDOUT:
// CHECK:STDOUT: --- use_export.carbon
Expand All @@ -168,7 +157,6 @@ var d: D* = &c;
// CHECK:STDOUT: imports {
// CHECK:STDOUT: %import_ref.c3f: type = import_ref Main//export, D, loaded [template = constants.%C]
// CHECK:STDOUT: %import_ref.8db: <witness> = import_ref Main//export, inst20 [indirect], loaded [template = constants.%complete_type]
// CHECK:STDOUT: %import_ref.6a9 = import_ref Main//export, inst21 [indirect], unloaded
// CHECK:STDOUT: }
// CHECK:STDOUT:
// CHECK:STDOUT: file {
Expand All @@ -183,9 +171,6 @@ var d: D* = &c;
// CHECK:STDOUT:
// CHECK:STDOUT: class @C [from "export.carbon"] {
// CHECK:STDOUT: complete_type_witness = imports.%import_ref.8db
// CHECK:STDOUT:
// CHECK:STDOUT: !members:
// CHECK:STDOUT: .Self = imports.%import_ref.6a9
// CHECK:STDOUT: }
// CHECK:STDOUT:
// CHECK:STDOUT: fn @__global_init() {
Expand Down Expand Up @@ -238,7 +223,6 @@ var d: D* = &c;
// CHECK:STDOUT: %import_ref.c3f: type = import_ref Main//export, D, loaded [template = constants.%C]
// CHECK:STDOUT: %import_ref.06e: type = import_ref Main//export_orig, C, loaded [template = constants.%C]
// CHECK:STDOUT: %import_ref.8db: <witness> = import_ref Main//export_orig, inst20 [indirect], loaded [template = constants.%complete_type]
// CHECK:STDOUT: %import_ref.6a9 = import_ref Main//export_orig, inst21 [indirect], unloaded
// CHECK:STDOUT: }
// CHECK:STDOUT:
// CHECK:STDOUT: file {
Expand All @@ -257,9 +241,6 @@ var d: D* = &c;
// CHECK:STDOUT:
// CHECK:STDOUT: class @C [from "export_orig.carbon"] {
// CHECK:STDOUT: complete_type_witness = imports.%import_ref.8db
// CHECK:STDOUT:
// CHECK:STDOUT: !members:
// CHECK:STDOUT: .Self = imports.%import_ref.6a9
// CHECK:STDOUT: }
// CHECK:STDOUT:
// CHECK:STDOUT: fn @__global_init() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,17 +48,11 @@ let c_var: c = d;
// CHECK:STDOUT: class @C {
// CHECK:STDOUT: %complete_type: <witness> = complete_type_witness %empty_struct_type [template = constants.%complete_type]
// CHECK:STDOUT: complete_type_witness = %complete_type
// CHECK:STDOUT:
// CHECK:STDOUT: !members:
// CHECK:STDOUT: .Self = constants.%C
// CHECK:STDOUT: }
// CHECK:STDOUT:
// CHECK:STDOUT: class @D {
// CHECK:STDOUT: %complete_type: <witness> = complete_type_witness %empty_struct_type [template = constants.%complete_type]
// CHECK:STDOUT: complete_type_witness = %complete_type
// CHECK:STDOUT:
// CHECK:STDOUT: !members:
// CHECK:STDOUT: .Self = constants.%D
// CHECK:STDOUT: }
// CHECK:STDOUT:
// CHECK:STDOUT: fn @__global_init() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,5 @@ extern alias C = Class;
// CHECK:STDOUT: class @Class {
// CHECK:STDOUT: %complete_type: <witness> = complete_type_witness %empty_struct_type [template = constants.%complete_type]
// CHECK:STDOUT: complete_type_witness = %complete_type
// CHECK:STDOUT:
// CHECK:STDOUT: !members:
// CHECK:STDOUT: .Self = constants.%Class
// CHECK:STDOUT: }
// CHECK:STDOUT:
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,6 @@ alias b = C;
// CHECK:STDOUT: class @C {
// CHECK:STDOUT: %complete_type: <witness> = complete_type_witness %empty_struct_type [template = constants.%complete_type]
// CHECK:STDOUT: complete_type_witness = %complete_type
// CHECK:STDOUT:
// CHECK:STDOUT: !members:
// CHECK:STDOUT: .Self = constants.%C
// CHECK:STDOUT: }
// CHECK:STDOUT:
// CHECK:STDOUT: fn @__global_init() {
Expand Down
Loading
Loading