Skip to content
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions toolchain/check/call.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -293,8 +293,6 @@ auto PerformCallToFunction(Context& context, SemIR::LocId loc_id,
}

case SemIR::Function::SpecialFunctionKind::HasCppThunk: {
// This recurses back into `PerformCall`. However, we never form a C++
// thunk to a C++ thunk, so we only recurse once.
return PerformCppThunkCall(context, loc_id, callee_function.function_id,
context.inst_blocks().Get(converted_args_id),
callee.cpp_thunk_decl_id());
Expand Down Expand Up @@ -352,13 +350,15 @@ auto PerformCall(Context& context, SemIR::LocId loc_id, SemIR::InstId callee_id,
return SemIR::ErrorInst::InstId;
}
case CARBON_KIND(SemIR::CalleeFunction fn): {
context.ref_tags().Insert(fn.self_id, Context::RefTag::NotRequired);
return PerformCallToFunction(context, loc_id, callee_id, fn, arg_ids);
}
case CARBON_KIND(SemIR::CalleeNonFunction _): {
return PerformCallToNonFunction(context, loc_id, callee_id, arg_ids);
}

case CARBON_KIND(SemIR::CalleeCppOverloadSet overload): {
context.ref_tags().Insert(overload.self_id, Context::RefTag::NotRequired);
return PerformCallToCppFunction(context, loc_id,
overload.cpp_overload_set_id,
overload.self_id, arg_ids);
Expand Down
14 changes: 14 additions & 0 deletions toolchain/check/context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <utility>

#include "common/check.h"
#include "toolchain/base/kind_switch.h"
#include "toolchain/check/deferred_definition_worklist.h"
#include "toolchain/sem_ir/ids.h"

Expand Down Expand Up @@ -77,6 +78,19 @@ auto Context::VerifyOnFinish() const -> void {
CARBON_FATAL("{0}Built invalid semantics IR: {1}\n", sem_ir_,
verify.error());
}

if (!sem_ir_->has_errors()) {
auto ref_tags_needed = sem_ir_->CollectRefTagsNeeded();

ref_tags_.ForEach([&ref_tags_needed](SemIR::InstId inst_id, RefTag kind) {
CARBON_CHECK(
ref_tags_needed.Erase(inst_id) || kind == RefTag::NotRequired,
"Inst has unnecessary `ref` tag: {0}", inst_id);
});
ref_tags_needed.ForEach([this](SemIR::InstId inst_id) {
CARBON_FATAL("Inst missing `ref` tag: {0}", insts().Get(inst_id));
});
}
#endif
}

Expand Down
14 changes: 14 additions & 0 deletions toolchain/check/context.h
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,13 @@ class Context {
return var_storage_map_;
}

// Describes why an inst is syntactically permitted to be bound to a reference
// parameter: `Present` means it was explicitly tagged with `ref` in the source
// code; `NotRequired` means it appears in a position where that tag is not
// required, such as an operator operand.
enum class RefTag { Present, NotRequired };

// Returns the map from insts to their `RefTag` (mutable and const overloads).
auto ref_tags() -> Map<SemIR::InstId, RefTag>& { return ref_tags_; }
auto ref_tags() const -> const Map<SemIR::InstId, RefTag>& {
return ref_tags_;
}

// During Choice typechecking, each alternative turns into a name binding on
// the Choice type, but this can't be done until the full Choice type is
// known. This represents each binding to be done at the end of checking the
Expand Down Expand Up @@ -430,6 +437,13 @@ class Context {
// processing the enclosing full-pattern.
Map<SemIR::InstId, SemIR::InstId> var_storage_map_;

// Insts in this map are syntactically permitted to be bound to a reference
// parameter, either because they've been explicitly tagged with `ref` in the
// source code, or because they appear in a position where that tag is not
// required, such as an operator operand (the RefTag value indicates which
// of those is the case).
Map<SemIR::InstId, RefTag> ref_tags_;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(@zygoloid I'm replying to your top-level comment here, for threading purposes)

How likely do you think it is that we retain the ref_tags map in the longer term? If we do, I think we should be thinking about ways to remove elements from it so it doesn't grow without bound, and also maybe thinking about avoiding adding all operands of operators to it -- currently I expect it'll get pretty large due to the number of operator invocations in a typical source file.

We could probably track whether we're in an "operand of an operator" context from call handling through into pattern matching and conversion in order to suppress the diagnostics for a missing `ref` tag.

We'd have to plumb it through a lot of layers, but yes, that's an option. The main problem I see is that an operator usage and an explicit call to the corresponding interface method will emit identical SemIR, but the former cannot use a ref tag, whereas the latter might sometimes be required to use a ref tag (unless we're willing to categorically forbid ref parameters in the explicit parameter lists of operator methods, or say that explicit calls to operator methods don't need to (can't?) use ref tags).

So if ref_tags only tracks the places where ref tags are actually written in the code, as you suggest, we'd have to stop using the SemIR to validate ref_tags, as @chandlerc suggested, because it won't be possible to use the SemIR to reconstruct where the ref tags must have been.

I think this just brings us back to the question of whether we're modeling a syntactic property ("the user wrote a ref tag here", as you've been advocating for) or a semantic one ("this inst is used as the argument to a ref parameter", as @chandlerc has been advocating for), because the operator problem seems to rule out trying to finesse the difference. As you point out, the syntactic property would probably be substantially cheaper to represent, since it's much sparser in the code, which seems like evidence in favor of the syntactic approach, but I'm not sure I really understand the arguments in the other direction.


// Each alternative in a Choice gets an entry here, they are stored in
// declaration order. The vector is consumed and emptied at the end of the
// Choice definition.
Expand Down
45 changes: 44 additions & 1 deletion toolchain/check/convert.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -735,6 +735,10 @@ static auto IsValidExprCategoryForConversionTarget(
category == SemIR::ExprCategory::DurableRef ||
category == SemIR::ExprCategory::EphemeralRef ||
category == SemIR::ExprCategory::Initializing;
case ConversionTarget::RefParam:
return category == SemIR::ExprCategory::DurableRef ||
category == SemIR::ExprCategory::EphemeralRef ||
category == SemIR::ExprCategory::Initializing;
case ConversionTarget::DurableRef:
return category == SemIR::ExprCategory::DurableRef;
case ConversionTarget::CppThunkRef:
Expand Down Expand Up @@ -1365,6 +1369,28 @@ auto PerformAction(Context& context, SemIR::LocId loc_id,
action.target_type_inst_id)});
}

// Checks the `ref`-tag state recorded for `expr_id` against the conversion
// `target` and emits a diagnostic on a mismatch:
//  - no entry in `ref_tags`, but the target is a `ref` parameter;
//  - an explicit (`Present`) tag, but the target is not a `ref` parameter.
// Returns true if `expr_id` has any entry in `ref_tags` (`Present` or
// `NotRequired`), and false otherwise.
static auto CheckRefTag(Context& context, SemIR::InstId expr_id,
                        ConversionTarget target) -> bool {
  const bool binds_ref_param = target.kind == ConversionTarget::RefParam;
  auto tag = context.ref_tags().Lookup(expr_id);

  // Untagged expression: only a problem when a `ref` parameter is the target.
  if (!tag) {
    if (binds_ref_param) {
      CARBON_DIAGNOSTIC(RefParamNoRefTag, Error,
                        "argument to `ref` parameter not marked with `ref`");
      context.emitter().Emit(expr_id, RefParamNoRefTag);
    }
    return false;
  }

  // Tagged expression: a user-written tag must feed a `ref` parameter. A
  // `NotRequired` entry is never diagnosed here.
  if (!binds_ref_param && tag.value() == Context::RefTag::Present) {
    CARBON_DIAGNOSTIC(RefTagNoRefParam, Error,
                      "`ref` tag is not an argument to a `ref` parameter");
    context.emitter().Emit(expr_id, RefTagNoRefParam);
  }
  return true;
}

auto Convert(Context& context, SemIR::LocId loc_id, SemIR::InstId expr_id,
ConversionTarget target, SemIR::ClassType* vtable_class_type)
-> SemIR::InstId {
Expand All @@ -1390,6 +1416,8 @@ auto Convert(Context& context, SemIR::LocId loc_id, SemIR::InstId expr_id,
return SemIR::ErrorInst::InstId;
}

bool has_ref_tag = CheckRefTag(context, expr_id, target);

// We can only perform initialization for complete, non-abstract types. Note
// that `RequireConcreteType` returns true for facet types, since their
// representation is fixed. This allows us to support using the `Self` of an
Expand Down Expand Up @@ -1520,6 +1548,9 @@ auto Convert(Context& context, SemIR::LocId loc_id, SemIR::InstId expr_id,
{.type_id = target.type_id,
.original_id = orig_expr_id,
.result_id = expr_id});
if (has_ref_tag) {
context.ref_tags().Insert(expr_id, Context::RefTag::NotRequired);
}
}

// For `as`, don't perform any value category conversions. In particular, an
Expand Down Expand Up @@ -1574,7 +1605,8 @@ auto Convert(Context& context, SemIR::LocId loc_id, SemIR::InstId expr_id,
// If a reference expression is an acceptable result, we're done.
if (target.kind == ConversionTarget::ValueOrRef ||
target.kind == ConversionTarget::Discarded ||
target.kind == ConversionTarget::CppThunkRef) {
target.kind == ConversionTarget::CppThunkRef ||
target.kind == ConversionTarget::RefParam) {
break;
}

Expand All @@ -1599,6 +1631,17 @@ auto Convert(Context& context, SemIR::LocId loc_id, SemIR::InstId expr_id,
}
return SemIR::ErrorInst::InstId;
}
if (target.kind == ConversionTarget::RefParam) {
// Don't diagnose a non-reference scrutinee if it has a user-written
// `ref` tag, because that's diagnosed in `Convert`.
if (auto lookup_result = context.ref_tags().Lookup(expr_id);
!lookup_result ||
lookup_result.value() != Context::RefTag::Present) {
CARBON_DIAGNOSTIC(ValueForRefParam, Error,
"value expression passed to reference parameter");
context.emitter().Emit(loc_id, ValueForRefParam);
}
}

// When initializing from a value, perform a copy.
if (target.is_initializer()) {
Expand Down
2 changes: 2 additions & 0 deletions toolchain/check/convert.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ struct ConversionTarget {
ValueOrRef,
// Convert to a durable reference of type `type_id`.
DurableRef,
// Convert to a reference, suitable for binding to a reference parameter.
RefParam,
// Convert to a reference of type `type_id`, for use as the argument to a
// C++ thunk.
CppThunkRef,
Expand Down
22 changes: 16 additions & 6 deletions toolchain/check/handle_call_expr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
#include "toolchain/check/call.h"
#include "toolchain/check/context.h"
#include "toolchain/check/handle.h"
#include "toolchain/check/inst.h"
#include "toolchain/sem_ir/expr_info.h"
#include "toolchain/sem_ir/inst.h"

namespace Carbon::Check {
Expand All @@ -16,12 +18,6 @@ auto HandleParseNode(Context& context, Parse::CallExprStartId node_id) -> bool {
return true;
}

// Handles a comma in a call expression's argument list by forwarding to
// `ApplyComma` on the context's param-and-arg-refs stack. The parse node
// itself is unused.
auto HandleParseNode(Context& context, Parse::CallExprCommaId /*node_id*/)
    -> bool {
  auto& arg_refs_stack = context.param_and_arg_refs_stack();
  arg_refs_stack.ApplyComma();
  return true;
}

auto HandleParseNode(Context& context, Parse::CallExprId node_id) -> bool {
// Process the final explicit call argument now, but leave the arguments
// block on the stack until the end of this function.
Expand All @@ -37,4 +33,18 @@ auto HandleParseNode(Context& context, Parse::CallExprId node_id) -> bool {
return true;
}

// Handles a `ref` tag applied to the expression currently on top of the node
// stack. The expression is left on the stack (it is only peeked). A diagnostic
// is emitted if the expression is not a durable reference; in either case the
// expression is recorded in `ref_tags` as `Present`.
auto HandleParseNode(Context& context, Parse::RefTagId node_id) -> bool {
  auto tagged_expr_id = context.node_stack().Peek<Parse::NodeCategory::Expr>();

  const bool is_durable_ref =
      SemIR::GetExprCategory(context.sem_ir(), tagged_expr_id) ==
      SemIR::ExprCategory::DurableRef;
  if (!is_durable_ref) {
    CARBON_DIAGNOSTIC(
        RefTagNotDurableRef, Error,
        "expression tagged with `ref` is not a durable reference");
    context.emitter().Emit(node_id, RefTagNotDurableRef);
  }

  // Record the tag even after an error.
  context.ref_tags().Insert(tagged_expr_id, Context::RefTag::Present);
  return true;
}

} // namespace Carbon::Check
2 changes: 1 addition & 1 deletion toolchain/check/handle_let_and_var.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ auto HandleParseNode(Context& context, Parse::VariablePatternId node_id)
switch (context.full_pattern_stack().CurrentKind()) {
case FullPatternStack::Kind::ExplicitParamList:
case FullPatternStack::Kind::ImplicitParamList:
subpattern_id = AddPatternInst<SemIR::RefParamPattern>(
subpattern_id = AddPatternInst<SemIR::VarParamPattern>(
context, node_id,
{.type_id = type_id,
.subpattern_id = subpattern_id,
Expand Down
3 changes: 3 additions & 0 deletions toolchain/check/import_ref.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3402,6 +3402,9 @@ static auto TryResolveInstCanonical(ImportRefResolver& resolver,
case CARBON_KIND(SemIR::ValueParamPattern inst): {
return TryResolveTypedInst(resolver, inst, constant_inst_id);
}
case CARBON_KIND(SemIR::VarParamPattern inst): {
return TryResolveTypedInst(resolver, inst, constant_inst_id);
}
case CARBON_KIND(SemIR::VarPattern inst): {
return TryResolveTypedInst(resolver, inst, constant_inst_id);
}
Expand Down
3 changes: 2 additions & 1 deletion toolchain/check/merge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,8 @@ static auto CheckRedeclParam(Context& context, bool is_implicit_param,
}
case SemIR::OutParamPattern::Kind:
case SemIR::RefParamPattern::Kind:
case SemIR::ValueParamPattern::Kind: {
case SemIR::ValueParamPattern::Kind:
case SemIR::VarParamPattern::Kind: {
pattern_stack.push_back(
{.prev_id =
prev_param_pattern.As<SemIR::AnyParamPattern>().subpattern_id,
Expand Down
1 change: 0 additions & 1 deletion toolchain/check/node_stack.h
Original file line number Diff line number Diff line change
Expand Up @@ -482,7 +482,6 @@ class NodeStack {
case Parse::NodeKind::BaseColon:
case Parse::NodeKind::BaseIntroducer:
case Parse::NodeKind::BreakStatementStart:
case Parse::NodeKind::CallExprComma:
case Parse::NodeKind::ChoiceAlternativeListComma:
case Parse::NodeKind::CodeBlock:
case Parse::NodeKind::CompileTimeBindingPatternStart:
Expand Down
7 changes: 7 additions & 0 deletions toolchain/check/operator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ auto BuildUnaryOperator(Context& context, SemIR::LocId loc_id, Operator op,
return SemIR::ErrorInst::InstId;
}

// Operator operands don't require `ref` tags.
context.ref_tags().Insert(operand_id, Context::RefTag::NotRequired);

// For unary operators with a C++ class as the operand, try to import and call
// the C++ operator.
// TODO: Change impl lookup instead. See
Expand Down Expand Up @@ -102,6 +105,10 @@ auto BuildBinaryOperator(Context& context, SemIR::LocId loc_id, Operator op,
return SemIR::ErrorInst::InstId;
}

// Operator operands don't require `ref` tags.
context.ref_tags().Insert(lhs_id, Context::RefTag::NotRequired);
context.ref_tags().Insert(rhs_id, Context::RefTag::NotRequired);

// For binary operators with a C++ class as at least one of the operands, try
// to import and call the C++ operator.
// TODO: Instead of hooking this here, change impl lookup, so that a generic
Expand Down
42 changes: 23 additions & 19 deletions toolchain/check/pattern_match.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,10 @@ class MatchContext {
SemIR::ValueParamPattern param_pattern,
SemIR::InstId pattern_inst_id, WorkItem entry)
-> void;
auto DoEmitPatternMatch(Context& context,
SemIR::RefParamPattern param_pattern,
template <typename RefParamPatternT>
requires std::is_same_v<RefParamPatternT, SemIR::RefParamPattern> ||
std::is_same_v<RefParamPatternT, SemIR::VarParamPattern>
auto DoEmitPatternMatch(Context& context, RefParamPatternT param_pattern,
SemIR::InstId pattern_inst_id, WorkItem entry)
-> void;
auto DoEmitPatternMatch(Context& context,
Expand Down Expand Up @@ -350,8 +352,11 @@ auto MatchContext::DoEmitPatternMatch(Context& context,
}
}

template <typename RefParamPatternT>
requires std::is_same_v<RefParamPatternT, SemIR::RefParamPattern> ||
std::is_same_v<RefParamPatternT, SemIR::VarParamPattern>
auto MatchContext::DoEmitPatternMatch(Context& context,
SemIR::RefParamPattern param_pattern,
RefParamPatternT param_pattern,
SemIR::InstId pattern_inst_id,
WorkItem entry) -> void {
switch (kind_) {
Expand All @@ -362,29 +367,24 @@ auto MatchContext::DoEmitPatternMatch(Context& context,
param_pattern.index.index);
CARBON_CHECK(entry.scrutinee_id.has_value());

// TODO: If this is a `ref` pattern and !entry.is_self, require the
// scrutinee to have a `ref` tag.

auto scrutinee_ref_id = ConvertToValueOrRefOfType(
if (std::is_same_v<RefParamPatternT, SemIR::VarParamPattern>) {
results_.push_back(entry.scrutinee_id);
break;
}
auto scrutinee_type_id = ExtractScrutineeType(
context.sem_ir(),
SemIR::GetTypeOfInstInSpecific(context.sem_ir(), callee_specific_id_,
pattern_inst_id));
auto scrutinee_ref_id = Convert(
context, SemIR::LocId(entry.scrutinee_id), entry.scrutinee_id,
ExtractScrutineeType(
context.sem_ir(),
SemIR::GetTypeOfInstInSpecific(
context.sem_ir(), callee_specific_id_, pattern_inst_id)));

{.kind = ConversionTarget::RefParam, .type_id = scrutinee_type_id});
switch (SemIR::GetExprCategory(context.sem_ir(), scrutinee_ref_id)) {
case SemIR::ExprCategory::Error:
case SemIR::ExprCategory::DurableRef:
case SemIR::ExprCategory::EphemeralRef:
break;
default:
CARBON_DIAGNOSTIC(ValueForRefParam, Error,
"value expression passed to reference parameter");
context.emitter().Emit(entry.scrutinee_id, ValueForRefParam);
// Add fake reference expression to preserve invariants.
auto scrutinee = context.insts().GetWithLocId(entry.scrutinee_id);
scrutinee_ref_id = AddInst<SemIR::TemporaryStorage>(
context, scrutinee.loc_id, {.type_id = scrutinee.inst.type_id()});
scrutinee_ref_id = SemIR::ErrorInst::InstId;
}
results_.push_back(scrutinee_ref_id);
// Do not traverse farther, because the caller side of the pattern
Expand Down Expand Up @@ -644,6 +644,10 @@ auto MatchContext::EmitPatternMatch(Context& context,
DoEmitPatternMatch(context, param_pattern, entry.pattern_id, entry);
break;
}
case CARBON_KIND(SemIR::VarParamPattern param_pattern): {
DoEmitPatternMatch(context, param_pattern, entry.pattern_id, entry);
break;
}
case CARBON_KIND(SemIR::OutParamPattern param_pattern): {
DoEmitPatternMatch(context, param_pattern, entry.pattern_id, entry);
break;
Expand Down
3 changes: 1 addition & 2 deletions toolchain/check/testdata/class/fail_ref_self.carbon
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,7 @@ fn F(c: Class, p: Class*) {
// CHECK:STDOUT: %c.ref: %Class = name_ref c, %c
// CHECK:STDOUT: %F.ref.loc29: %Class.F.type = name_ref F, @Class.%Class.F.decl [concrete = constants.%Class.F]
// CHECK:STDOUT: %Class.F.bound.loc29: <bound method> = bound_method %c.ref, %F.ref.loc29
// CHECK:STDOUT: %.loc29: ref %Class = temporary_storage
// CHECK:STDOUT: %Class.F.call.loc29: init %empty_tuple.type = call %Class.F.bound.loc29(%.loc29)
// CHECK:STDOUT: %Class.F.call.loc29: init %empty_tuple.type = call %Class.F.bound.loc29(<error>)
// CHECK:STDOUT: %p.ref: %ptr.e71 = name_ref p, %p
// CHECK:STDOUT: %.loc32: ref %Class = deref %p.ref
// CHECK:STDOUT: %F.ref.loc32: %Class.F.type = name_ref F, @Class.%Class.F.decl [concrete = constants.%Class.F]
Expand Down
Loading