Skip to content

Commit 0177dc5

Browse files
authored
Import contained RequireImpls when importing an Interface or NamedConstraint (#6344)
When importing an Interface or NamedConstraint, walk the block of `RequireImplsId`s, and for each one: - Import the RequireImplsDecl from it, which also imports the `RequireImpls` structure and its id. - Collect those decls and build a block of `RequireImplsId`s for the local SemIR to reference from the Interface or NamedConstraint. The import of RequireImplsDecl is done in a single phase instead of three, unlike other decls. This is possible since require declarations have no name, so they can't be referenced by instructions inside them, thus there's no cycles to concern ourselves with.
1 parent b300f36 commit 0177dc5

File tree

16 files changed

+429
-152
lines changed

16 files changed

+429
-152
lines changed

toolchain/check/handle_require.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -253,7 +253,6 @@ auto HandleParseNode(Context& context, Parse::RequireDeclId node_id) -> bool {
253253
.facet_type_id = constraint_facet_type.facet_type_id,
254254
.decl_id = decl_id,
255255
.parent_scope_id = context.scope_stack().PeekNameScopeId(),
256-
.body_block_id = decl_block_id,
257256
.generic_id = BuildGenericDecl(context, decl_id)});
258257

259258
require_impls_decl.require_impls_id = require_impls_id;

toolchain/check/import_ref.cpp

Lines changed: 166 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
#include "toolchain/sem_ir/inst_kind.h"
3333
#include "toolchain/sem_ir/name_scope.h"
3434
#include "toolchain/sem_ir/specific_interface.h"
35+
#include "toolchain/sem_ir/specific_named_constraint.h"
3536
#include "toolchain/sem_ir/type_info.h"
3637
#include "toolchain/sem_ir/typed_insts.h"
3738

@@ -212,6 +213,12 @@ class ImportContext {
212213
auto import_name_scopes() -> const SemIR::NameScopeStore& {
213214
return import_ir().name_scopes();
214215
}
216+
auto import_require_impls() -> const SemIR::RequireImplsStore& {
217+
return import_ir().require_impls();
218+
}
219+
auto import_require_impls_blocks() -> const SemIR::RequireImplsBlockStore& {
220+
return import_ir().require_impls_blocks();
221+
}
215222
auto import_specifics() -> const SemIR::SpecificStore& {
216223
return import_ir().specifics();
217224
}
@@ -288,6 +295,12 @@ class ImportContext {
288295
auto local_name_scopes() -> SemIR::NameScopeStore& {
289296
return local_ir().name_scopes();
290297
}
298+
auto local_require_impls() -> SemIR::RequireImplsStore& {
299+
return local_ir().require_impls();
300+
}
301+
auto local_require_impls_blocks() -> SemIR::RequireImplsBlockStore& {
302+
return local_ir().require_impls_blocks();
303+
}
291304
auto local_specifics() -> SemIR::SpecificStore& {
292305
return local_ir().specifics();
293306
}
@@ -663,9 +676,7 @@ static auto GetLocalInstBlockContents(ImportRefResolver& resolver,
663676
auto import_block = resolver.import_inst_blocks().Get(import_block_id);
664677
inst_ids.reserve(import_block.size());
665678
for (auto import_inst_id : import_block) {
666-
auto const_id = GetLocalConstantId(resolver, import_inst_id);
667-
inst_ids.push_back(
668-
resolver.local_constant_values().GetInstIdIfValid(const_id));
679+
inst_ids.push_back(GetLocalConstantInstId(resolver, import_inst_id));
669680
}
670681

671682
return inst_ids;
@@ -684,6 +695,56 @@ static auto GetLocalCanonicalInstBlockId(ImportContext& context,
684695
return context.local_inst_blocks().AddCanonical(contents);
685696
}
686697

698+
// Imports the RequireImplsDecl instructions for each RequireImplsId in the
699+
// block, and gets the local RequireImplsIds from them. The returned vector is
700+
// only complete if there is no more work to do in the resolver on return.
701+
static auto GetLocalRequireImplsBlockContents(
702+
ImportRefResolver& resolver, SemIR::RequireImplsBlockId import_block_id)
703+
-> llvm::SmallVector<SemIR::RequireImplsId> {
704+
llvm::SmallVector<SemIR::RequireImplsId> require_decl_ids;
705+
if (!import_block_id.has_value() ||
706+
import_block_id == SemIR::RequireImplsBlockId::Empty) {
707+
return require_decl_ids;
708+
}
709+
710+
// Import the RequireImplsDecl for each RequireImpls in the block.
711+
auto import_block =
712+
resolver.import_require_impls_blocks().Get(import_block_id);
713+
require_decl_ids.reserve(import_block.size());
714+
for (auto import_require_impls_id : import_block) {
715+
const auto& import_require =
716+
resolver.import_require_impls().Get(import_require_impls_id);
717+
auto local_decl_id =
718+
GetLocalConstantInstId(resolver, import_require.decl_id);
719+
// If `local_decl_id` is None, the resolver will have more work to do, and
720+
// we will call this function to try get all the decl instructions again.
721+
if (local_decl_id.has_value()) {
722+
// Importing the RequireImplsDecl instruction in `local_decl_id` also
723+
// imported the RequireImpls structure that it points to through the
724+
// RequireImplsId.
725+
require_decl_ids.push_back(
726+
resolver.local_insts()
727+
.GetAs<SemIR::RequireImplsDecl>(local_decl_id)
728+
.require_impls_id);
729+
}
730+
}
731+
732+
return require_decl_ids;
733+
}
734+
735+
// Gets the local block of RequireImplsIds from the imported block. Only valid
736+
// to call once there is no more work to do after the call to
737+
// GetLocalRequireImplsBlockContents().
738+
static auto GetLocalCanonicalRequireImplsBlockId(
739+
ImportContext& context, SemIR::RequireImplsBlockId import_block_id,
740+
llvm::ArrayRef<SemIR::RequireImplsId> contents)
741+
-> SemIR::RequireImplsBlockId {
742+
if (!import_block_id.has_value()) {
743+
return SemIR::RequireImplsBlockId::None;
744+
}
745+
return context.local_require_impls_blocks().Add(contents);
746+
}
747+
687748
// Gets a local instruction block containing ImportRefs referring to the
688749
// instructions in the specified imported instruction block.
689750
static auto GetLocalImportRefInstBlock(ImportContext& context,
@@ -892,7 +953,8 @@ static auto GetLocalSpecificInterface(
892953
if (auto facet_type = interface_const_inst.TryAs<SemIR::FacetType>()) {
893954
const SemIR::FacetTypeInfo& new_facet_type_info =
894955
context.local_facet_types().Get(facet_type->facet_type_id);
895-
return *new_facet_type_info.TryAsSingleInterface();
956+
return std::get<SemIR::SpecificInterface>(
957+
*new_facet_type_info.TryAsSingleExtend());
896958
} else {
897959
auto generic_interface_type =
898960
context.local_types().GetAs<SemIR::GenericInterfaceType>(
@@ -923,13 +985,23 @@ static auto GetLocalNameScopeIdImpl(ImportRefResolver& resolver,
923985
case CARBON_KIND(SemIR::FacetType inst): {
924986
const SemIR::FacetTypeInfo& facet_type_info =
925987
resolver.local_facet_types().Get(inst.facet_type_id);
926-
// This is specifically the facet type produced by an interface
927-
// declaration, and so should consist of a single interface.
928-
// TODO: Will also have to handle named constraints here, once those are
929-
// implemented.
930-
auto interface = facet_type_info.TryAsSingleInterface();
931-
CARBON_CHECK(interface);
932-
return resolver.local_interfaces().Get(interface->interface_id).scope_id;
988+
if (auto single = facet_type_info.TryAsSingleExtend()) {
989+
// This is the facet type produced by an interface or named constraint
990+
// declaration.
991+
CARBON_KIND_SWITCH(*single) {
992+
case CARBON_KIND(SemIR::SpecificInterface interface): {
993+
return resolver.local_interfaces()
994+
.Get(interface.interface_id)
995+
.scope_id;
996+
}
997+
case CARBON_KIND(SemIR::SpecificNamedConstraint constraint): {
998+
return resolver.local_named_constraints()
999+
.Get(constraint.named_constraint_id)
1000+
.scope_id;
1001+
}
1002+
}
1003+
}
1004+
break;
9331005
}
9341006
case CARBON_KIND(SemIR::ImplDecl inst): {
9351007
return resolver.local_impls().Get(inst.impl_id).scope_id;
@@ -2292,6 +2364,64 @@ static auto TryResolveTypedInst(ImportRefResolver& resolver,
22922364
return ResolveResult::Done(impl_const_id, new_impl.first_decl_id());
22932365
}
22942366

2367+
static auto TryResolveTypedInst(ImportRefResolver& resolver,
2368+
SemIR::RequireImplsDecl inst) -> ResolveResult {
2369+
const auto& import_require =
2370+
resolver.import_require_impls().Get(inst.require_impls_id);
2371+
2372+
// Load dependent constants.
2373+
auto parent_scope_id =
2374+
GetLocalNameScopeId(resolver, import_require.parent_scope_id);
2375+
auto generic_data = GetLocalGenericData(resolver, import_require.generic_id);
2376+
auto self_const_id = GetLocalConstantId(resolver, import_require.self_id);
2377+
auto facet_type_const_id =
2378+
GetLocalConstantId(resolver, import_require.facet_type_inst_id);
2379+
2380+
if (resolver.HasNewWork()) {
2381+
return ResolveResult::Retry();
2382+
}
2383+
2384+
// Make the decl and structure with placeholder values to be filled in.
2385+
SemIR::RequireImplsDecl require_decl = {
2386+
.require_impls_id = SemIR::RequireImplsId::None,
2387+
.decl_block_id = SemIR::InstBlockId::Empty};
2388+
auto require_decl_id = AddPlaceholderImportedInst(
2389+
resolver, import_require.decl_id, require_decl);
2390+
auto require_impls_id = resolver.local_require_impls().Add(
2391+
{.self_id = SemIR::TypeInstId::None,
2392+
.facet_type_inst_id = SemIR::TypeInstId::None,
2393+
.facet_type_id = SemIR::FacetTypeId::None,
2394+
.decl_id = require_decl_id,
2395+
.parent_scope_id = SemIR::NameScopeId::None,
2396+
.generic_id = MakeIncompleteGeneric(resolver, require_decl_id,
2397+
import_require.generic_id)});
2398+
2399+
// Write the RequireImplsId into the RequireImplsDecl.
2400+
require_decl.require_impls_id = require_impls_id;
2401+
auto require_decl_const_id =
2402+
ReplacePlaceholderImportedInst(resolver, require_decl_id, require_decl);
2403+
2404+
// Fill in the RequireImpls structure.
2405+
auto& new_require = resolver.local_require_impls().Get(require_impls_id);
2406+
new_require.self_id = AddLoadedImportRefForType(
2407+
resolver, import_require.self_id, self_const_id);
2408+
new_require.facet_type_inst_id = AddLoadedImportRefForType(
2409+
resolver, import_require.facet_type_inst_id, facet_type_const_id);
2410+
auto new_canonical_facet_type_inst_id =
2411+
resolver.local_constant_values().GetConstantInstId(
2412+
new_require.facet_type_inst_id);
2413+
auto new_canonical_facet_type =
2414+
resolver.local_insts().GetAs<SemIR::FacetType>(
2415+
new_canonical_facet_type_inst_id);
2416+
new_require.facet_type_id = new_canonical_facet_type.facet_type_id;
2417+
new_require.parent_scope_id = parent_scope_id;
2418+
2419+
SetGenericData(resolver, import_require.generic_id, new_require.generic_id,
2420+
generic_data);
2421+
2422+
return ResolveResult::Done(require_decl_const_id, require_decl_id);
2423+
}
2424+
22952425
static auto TryResolveTypedInst(ImportRefResolver& resolver,
22962426
SemIR::ImportRefLoaded /*inst*/,
22972427
SemIR::InstId inst_id) -> ResolveResult {
@@ -2403,9 +2533,9 @@ static auto TryResolveTypedInst(ImportRefResolver& resolver,
24032533
if (auto facet_type = interface_const_inst.TryAs<SemIR::FacetType>()) {
24042534
const SemIR::FacetTypeInfo& facet_type_info =
24052535
resolver.local_facet_types().Get(facet_type->facet_type_id);
2406-
auto interface_type = facet_type_info.TryAsSingleInterface();
2407-
CARBON_CHECK(interface_type);
2408-
interface_id = interface_type->interface_id;
2536+
auto single = facet_type_info.TryAsSingleExtend();
2537+
CARBON_CHECK(single);
2538+
interface_id = std::get<SemIR::SpecificInterface>(*single).interface_id;
24092539
} else {
24102540
auto generic_interface_type =
24112541
resolver.local_types().GetAs<SemIR::GenericInterfaceType>(
@@ -2422,6 +2552,8 @@ static auto TryResolveTypedInst(ImportRefResolver& resolver,
24222552
GetLocalInstBlockContents(resolver, import_interface.param_patterns_id);
24232553
auto generic_data =
24242554
GetLocalGenericData(resolver, import_interface.generic_id);
2555+
auto require_impls = GetLocalRequireImplsBlockContents(
2556+
resolver, import_interface.require_impls_block_id);
24252557

24262558
std::optional<SemIR::InstId> self_param_id;
24272559
if (import_interface.is_complete()) {
@@ -2441,7 +2573,8 @@ static auto TryResolveTypedInst(ImportRefResolver& resolver,
24412573
implicit_param_patterns);
24422574
new_interface.param_patterns_id = GetLocalCanonicalInstBlockId(
24432575
resolver, import_interface.param_patterns_id, param_patterns);
2444-
// TODO: Import require_impls_block_id.
2576+
new_interface.require_impls_block_id = GetLocalCanonicalRequireImplsBlockId(
2577+
resolver, import_interface.require_impls_block_id, require_impls);
24452578
SetGenericData(resolver, import_interface.generic_id,
24462579
new_interface.generic_id, generic_data);
24472580

@@ -2555,12 +2688,10 @@ static auto TryResolveTypedInst(ImportRefResolver& resolver,
25552688
named_constraint_const_inst.TryAs<SemIR::FacetType>()) {
25562689
const SemIR::FacetTypeInfo& facet_type_info =
25572690
resolver.local_facet_types().Get(facet_type->facet_type_id);
2558-
CARBON_CHECK(facet_type_info.extend_named_constraints.size() == 1);
2559-
CARBON_CHECK(facet_type_info.extend_constraints.empty());
2560-
CARBON_CHECK(facet_type_info.self_impls_constraints.empty());
2561-
CARBON_CHECK(facet_type_info.self_impls_named_constraints.empty());
2691+
auto single = facet_type_info.TryAsSingleExtend();
2692+
CARBON_CHECK(single);
25622693
named_constraint_id =
2563-
facet_type_info.extend_named_constraints.front().named_constraint_id;
2694+
std::get<SemIR::SpecificNamedConstraint>(*single).named_constraint_id;
25642695
} else {
25652696
auto generic_named_constraint_type =
25662697
resolver.local_types().GetAs<SemIR::GenericNamedConstraintType>(
@@ -2577,6 +2708,8 @@ static auto TryResolveTypedInst(ImportRefResolver& resolver,
25772708
resolver, import_named_constraint.param_patterns_id);
25782709
auto generic_data =
25792710
GetLocalGenericData(resolver, import_named_constraint.generic_id);
2711+
auto require_impls = GetLocalRequireImplsBlockContents(
2712+
resolver, import_named_constraint.require_impls_block_id);
25802713

25812714
std::optional<SemIR::InstId> self_param_id;
25822715
if (import_named_constraint.is_complete()) {
@@ -2598,7 +2731,10 @@ static auto TryResolveTypedInst(ImportRefResolver& resolver,
25982731
implicit_param_patterns);
25992732
new_named_constraint.param_patterns_id = GetLocalCanonicalInstBlockId(
26002733
resolver, import_named_constraint.param_patterns_id, param_patterns);
2601-
// TODO: Import require_impls_block_id.
2734+
new_named_constraint.require_impls_block_id =
2735+
GetLocalCanonicalRequireImplsBlockId(
2736+
resolver, import_named_constraint.require_impls_block_id,
2737+
require_impls);
26022738
SetGenericData(resolver, import_named_constraint.generic_id,
26032739
import_named_constraint.generic_id, generic_data);
26042740

@@ -2644,6 +2780,7 @@ static auto TryResolveTypedInst(ImportRefResolver& resolver,
26442780
GetLocalConstantInstId(resolver, rewrite.lhs_id);
26452781
GetLocalConstantInstId(resolver, rewrite.rhs_id);
26462782
}
2783+
// TODO: Import named constraints in the facet type.
26472784
if (resolver.HasNewWork()) {
26482785
return ResolveResult::Retry();
26492786
}
@@ -3373,6 +3510,9 @@ static auto TryResolveInstCanonical(ImportRefResolver& resolver,
33733510
case CARBON_KIND(SemIR::RequireCompleteType inst): {
33743511
return TryResolveTypedInst(resolver, inst);
33753512
}
3513+
case CARBON_KIND(SemIR::RequireImplsDecl inst): {
3514+
return TryResolveTypedInst(resolver, inst);
3515+
}
33763516
case CARBON_KIND(SemIR::ReturnSlotPattern inst): {
33773517
return TryResolveTypedInst(resolver, inst, constant_inst_id);
33783518
}
@@ -3918,12 +4058,12 @@ auto ImportInterface(Context& context, SemIR::ImportIRId import_ir_id,
39184058
// A non-generic interface will import as a facet type for that single
39194059
// interface.
39204060
if (auto facet_type = local_inst.TryAs<SemIR::FacetType>()) {
3921-
auto interface = context.facet_types()
3922-
.Get(facet_type->facet_type_id)
3923-
.TryAsSingleInterface();
3924-
CARBON_CHECK(interface,
4061+
auto single = context.facet_types()
4062+
.Get(facet_type->facet_type_id)
4063+
.TryAsSingleExtend();
4064+
CARBON_CHECK(single,
39254065
"Importing an interface didn't produce a single interface");
3926-
return interface->interface_id;
4066+
return std::get<SemIR::SpecificInterface>(*single).interface_id;
39274067
}
39284068

39294069
// A generic interface will import as a constant of generic interface type.

0 commit comments

Comments
 (0)