gccrs: typecheck: Properly select methods when dealing with specialization

gcc/rust/ChangeLog:

	* typecheck/rust-hir-type-check-expr.cc (is_default_fn): New.
	(emit_ambiguous_resolution_error): New.
	(handle_multiple_candidates): Properly handle multiple candidates in
	the case of specialization.
	(TypeCheckExpr::visit): Call `handle_multiple_candidates`.

gcc/testsuite/ChangeLog:

	* rust/execute/torture/min_specialization2.rs: New test.
	* rust/execute/torture/min_specialization3.rs: New test.
This commit is contained in:
Arthur Cohen 2025-04-03 16:22:10 +02:00
parent 9e367223ce
commit eb5eee065b
3 changed files with 172 additions and 24 deletions

View file

@ -16,6 +16,8 @@
// along with GCC; see the file COPYING3. If not see
// <http://www.gnu.org/licenses/>.
#include "optional.h"
#include "rust-hir-expr.h"
#include "rust-system.h"
#include "rust-tyty-call.h"
#include "rust-hir-type-check-struct-field.h"
@ -1154,6 +1156,94 @@ TypeCheckExpr::visit (HIR::FieldAccessExpr &expr)
infered = lookup->get_field_type ();
}
bool
is_default_fn (const MethodCandidate &candidate)
{
if (candidate.candidate.is_impl_candidate ())
{
auto *item = candidate.candidate.item.impl.impl_item;
if (item->get_impl_item_type () == HIR::ImplItem::FUNCTION)
{
auto &fn = static_cast<HIR::Function &> (*item);
return fn.is_default ();
}
}
return false;
}
void
emit_ambiguous_resolution_error (HIR::MethodCallExpr &expr,
std::set<MethodCandidate> &candidates)
{
rich_location r (line_table, expr.get_method_name ().get_locus ());
std::string rich_msg = "multiple "
+ expr.get_method_name ().get_segment ().as_string ()
+ " found";
// We have to filter out default candidates
for (auto &c : candidates)
if (!is_default_fn (c))
r.add_range (c.candidate.locus);
r.add_fixit_replace (rich_msg.c_str ());
rust_error_at (r, ErrorCode::E0592, "duplicate definitions with name %qs",
expr.get_method_name ().get_segment ().as_string ().c_str ());
}
// We are allowed to have multiple candidates if they are all specializable
// functions or if all of them except one are specializable functions.
// In the later case, we just return a valid candidate without erroring out
// about ambiguity. If there are two or more specialized functions, then we
// error out.
//
// FIXME: The first case is not handled at the moment, so we error out
tl::optional<const MethodCandidate &>
handle_multiple_candidates (HIR::MethodCallExpr &expr,
std::set<MethodCandidate> &candidates)
{
auto all_default = true;
tl::optional<const MethodCandidate &> found = tl::nullopt;
for (auto &c : candidates)
{
if (!is_default_fn (c))
{
all_default = false;
// We haven't found a final candidate yet, so we can select
// this one. However, if we already have a candidate, then
// that means there are multiple non-default candidates - we
// must error out
if (!found)
{
found = c;
}
else
{
emit_ambiguous_resolution_error (expr, candidates);
return tl::nullopt;
}
}
}
// None of the candidates were a non-default (specialized) function, so we
// error out
if (all_default)
{
rust_sorry_at (expr.get_locus (),
"cannot resolve method calls to non-specialized methods "
"(all function candidates are %qs)",
"default");
return tl::nullopt;
}
return found;
}
void
TypeCheckExpr::visit (HIR::MethodCallExpr &expr)
{
@ -1181,34 +1271,25 @@ TypeCheckExpr::visit (HIR::MethodCallExpr &expr)
return;
}
tl::optional<const MethodCandidate &> candidate = *candidates.begin ();
if (candidates.size () > 1)
{
rich_location r (line_table, expr.get_method_name ().get_locus ());
std::string rich_msg
= "multiple " + expr.get_method_name ().get_segment ().as_string ()
+ " found";
candidate = handle_multiple_candidates (expr, candidates);
for (auto &c : candidates)
r.add_range (c.candidate.locus);
if (!candidate)
return;
r.add_fixit_replace (rich_msg.c_str ());
auto found_candidate = *candidate;
rust_error_at (
r, ErrorCode::E0592, "duplicate definitions with name %qs",
expr.get_method_name ().get_segment ().as_string ().c_str ());
return;
}
auto candidate = *candidates.begin ();
rust_debug_loc (expr.get_method_name ().get_locus (),
"resolved method to: {%u} {%s} with [%lu] adjustments",
candidate.candidate.ty->get_ref (),
candidate.candidate.ty->debug_str ().c_str (),
(unsigned long) candidate.adjustments.size ());
found_candidate.candidate.ty->get_ref (),
found_candidate.candidate.ty->debug_str ().c_str (),
(unsigned long) found_candidate.adjustments.size ());
// Get the adjusted self
Adjuster adj (receiver_tyty);
TyTy::BaseType *adjusted_self = adj.adjust_type (candidate.adjustments);
TyTy::BaseType *adjusted_self = adj.adjust_type (found_candidate.adjustments);
rust_debug ("receiver: %s adjusted self %s",
receiver_tyty->debug_str ().c_str (),
adjusted_self->debug_str ().c_str ());
@ -1219,10 +1300,10 @@ TypeCheckExpr::visit (HIR::MethodCallExpr &expr)
HirId autoderef_mappings_id
= expr.get_receiver ().get_mappings ().get_hirid ();
context->insert_autoderef_mappings (autoderef_mappings_id,
std::move (candidate.adjustments));
std::move (found_candidate.adjustments));
PathProbeCandidate &resolved_candidate = candidate.candidate;
TyTy::BaseType *lookup_tyty = candidate.candidate.ty;
PathProbeCandidate &resolved_candidate = found_candidate.candidate;
TyTy::BaseType *lookup_tyty = found_candidate.candidate.ty;
NodeId resolved_node_id
= resolved_candidate.is_impl_candidate ()
? resolved_candidate.item.impl.impl_item->get_impl_mappings ()
@ -1249,8 +1330,8 @@ TypeCheckExpr::visit (HIR::MethodCallExpr &expr)
fn->prepare_higher_ranked_bounds ();
rust_debug_loc (expr.get_locus (), "resolved method call to: {%u} {%s}",
candidate.candidate.ty->get_ref (),
candidate.candidate.ty->debug_str ().c_str ());
found_candidate.candidate.ty->get_ref (),
found_candidate.candidate.ty->debug_str ().c_str ());
if (resolved_candidate.is_impl_candidate ())
{

View file

@ -0,0 +1,31 @@
#![feature(min_specialization)]
#[lang = "sized"]
trait Sized {}
trait Foo {
fn foo(&self) -> i32;
}
impl<T> Foo for T {
default fn foo(&self) -> i32 { // { dg-warning "unused" }
15
}
}
impl Foo for bool {
fn foo(&self) -> i32 {
if *self {
1
} else {
0
}
}
}
fn main() -> i32 {
let a = 1.foo() - 15;
let b = true.foo() - 1;
a + b
}

View file

@ -0,0 +1,36 @@
#![feature(min_specialization)]
#[lang = "sized"]
trait Sized {}
trait Foo {
fn foo(&self) -> i32;
}
struct Wrap<T>(T);
impl<T> Foo for T {
default fn foo(&self) -> i32 {
15
}
}
impl<T> Foo for Wrap<T> {
default fn foo(&self) -> i32 {
16
}
}
impl Foo for Wrap<bool> {
fn foo(&self) -> i32 {
if self.0 {
1
} else {
0
}
}
}
fn main() -> i32 {
Wrap(true).foo() - 1
}