Skip to content

fix: Properly handle local trait impls #14424

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 28, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions crates/hir-def/src/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,8 @@ pub trait DefDatabase: InternDatabase + ExpandDatabase + Upcast<dyn ExpandDataba
///
/// The `block_def_map` for block 0 would return `None`, while `block_def_map` of block 1 would
/// return a `DefMap` containing `inner`.
// FIXME: This actually can't return None anymore as we no longer allocate block scopes for
// non item declaring blocks
#[salsa::invoke(DefMap::block_def_map_query)]
fn block_def_map(&self, block: BlockId) -> Option<Arc<DefMap>>;

Expand Down
83 changes: 35 additions & 48 deletions crates/hir-ty/src/chalk_db.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
//! The implementation of `RustIrDatabase` for Chalk, which provides information
//! about the code that Chalk needs.
use std::sync::Arc;
use std::{iter, sync::Arc};

use cov_mark::hit;
use tracing::debug;

use chalk_ir::{cast::Cast, fold::shift::Shift, CanonicalVarKinds};
Expand All @@ -12,17 +11,16 @@ use base_db::CrateId;
use hir_def::{
expr::Movability,
lang_item::{lang_attr, LangItem, LangItemTarget},
AssocItemId, GenericDefId, HasModule, ItemContainerId, Lookup, ModuleId, TypeAliasId,
AssocItemId, BlockId, GenericDefId, HasModule, ItemContainerId, Lookup, TypeAliasId,
};
use hir_expand::name::name;

use crate::{
db::HirDatabase,
display::HirDisplay,
from_assoc_type_id, from_chalk_trait_id, from_foreign_def_id, make_binders,
make_single_type_binders,
from_assoc_type_id, from_chalk_trait_id, make_binders, make_single_type_binders,
mapping::{from_chalk, ToChalk, TypeAliasAsValue},
method_resolution::{TraitImpls, TyFingerprint, ALL_FLOAT_FPS, ALL_INT_FPS},
method_resolution::{TyFingerprint, ALL_FLOAT_FPS, ALL_INT_FPS},
to_assoc_type_id, to_chalk_trait_id,
traits::ChalkContext,
utils::generics,
Expand Down Expand Up @@ -108,53 +106,41 @@ impl<'a> chalk_solve::RustIrDatabase<Interner> for ChalkContext<'a> {
_ => self_ty_fp.as_ref().map(std::slice::from_ref).unwrap_or(&[]),
};

fn local_impls(db: &dyn HirDatabase, module: ModuleId) -> Option<Arc<TraitImpls>> {
let block = module.containing_block()?;
hit!(block_local_impls);
db.trait_impls_in_block(block)
}

// Note: Since we're using impls_for_trait, only impls where the trait
// can be resolved should ever reach Chalk. impl_datum relies on that
// and will panic if the trait can't be resolved.
let in_deps = self.db.trait_impls_in_deps(self.krate);
let in_self = self.db.trait_impls_in_crate(self.krate);
let trait_module = trait_.module(self.db.upcast());
let type_module = match self_ty_fp {
Some(TyFingerprint::Adt(adt_id)) => Some(adt_id.module(self.db.upcast())),
Some(TyFingerprint::ForeignType(type_id)) => {
Some(from_foreign_def_id(type_id).module(self.db.upcast()))
}
Some(TyFingerprint::Dyn(trait_id)) => Some(trait_id.module(self.db.upcast())),
_ => None,
};
let impl_maps = [
Some(in_deps),
Some(in_self),
local_impls(self.db, trait_module),
type_module.and_then(|m| local_impls(self.db, m)),
];

let id_to_chalk = |id: hir_def::ImplId| id.to_chalk(self.db);
let impl_maps = [in_deps, in_self];
let block_impls = iter::successors(self.block, |&block_id| {
cov_mark::hit!(block_local_impls);
self.db
.block_def_map(block_id)
.and_then(|map| map.parent())
.and_then(|module| module.containing_block())
})
.filter_map(|block_id| self.db.trait_impls_in_block(block_id));

let result: Vec<_> = if fps.is_empty() {
debug!("Unrestricted search for {:?} impls...", trait_);
impl_maps
.iter()
.filter_map(|o| o.as_ref())
.flat_map(|impls| impls.for_trait(trait_).map(id_to_chalk))
.collect()
} else {
impl_maps
.iter()
.filter_map(|o| o.as_ref())
.flat_map(|impls| {
fps.iter().flat_map(move |fp| {
impls.for_trait_and_self_ty(trait_, *fp).map(id_to_chalk)
})
})
.collect()
};
let id_to_chalk = |id: hir_def::ImplId| id.to_chalk(self.db);
let mut result = vec![];
match fps {
[] => {
debug!("Unrestricted search for {:?} impls...", trait_);
impl_maps.into_iter().chain(block_impls).for_each(|impls| {
result.extend(impls.for_trait(trait_).map(id_to_chalk));
});
}
fps => {
impl_maps.into_iter().chain(block_impls).for_each(|impls| {
result.extend(
fps.iter().flat_map(|fp| {
impls.for_trait_and_self_ty(trait_, *fp).map(id_to_chalk)
}),
);
});
}
}

debug!("impls_for_trait returned {} impls", result.len());
result
Expand Down Expand Up @@ -193,7 +179,7 @@ impl<'a> chalk_solve::RustIrDatabase<Interner> for ChalkContext<'a> {
&self,
environment: &chalk_ir::Environment<Interner>,
) -> chalk_ir::ProgramClauses<Interner> {
self.db.program_clauses_for_chalk_env(self.krate, environment.clone())
self.db.program_clauses_for_chalk_env(self.krate, self.block, environment.clone())
}

fn opaque_ty_data(&self, id: chalk_ir::OpaqueTyId<Interner>) -> Arc<OpaqueTyDatum> {
Expand Down Expand Up @@ -451,9 +437,10 @@ impl<'a> chalk_ir::UnificationDatabase<Interner> for &'a dyn HirDatabase {
pub(crate) fn program_clauses_for_chalk_env_query(
db: &dyn HirDatabase,
krate: CrateId,
block: Option<BlockId>,
environment: chalk_ir::Environment<Interner>,
) -> chalk_ir::ProgramClauses<Interner> {
chalk_solve::program_clauses_for_env(&ChalkContext { db, krate }, &environment)
chalk_solve::program_clauses_for_env(&ChalkContext { db, krate, block }, &environment)
}

pub(crate) fn associated_ty_data_query(
Expand Down
8 changes: 6 additions & 2 deletions crates/hir-ty/src/db.rs
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ pub trait HirDatabase: DefDatabase + Upcast<dyn DefDatabase> {
fn trait_impls_in_crate(&self, krate: CrateId) -> Arc<TraitImpls>;

#[salsa::invoke(TraitImpls::trait_impls_in_block_query)]
fn trait_impls_in_block(&self, krate: BlockId) -> Option<Arc<TraitImpls>>;
fn trait_impls_in_block(&self, block: BlockId) -> Option<Arc<TraitImpls>>;

#[salsa::invoke(TraitImpls::trait_impls_in_deps_query)]
fn trait_impls_in_deps(&self, krate: CrateId) -> Arc<TraitImpls>;
Expand Down Expand Up @@ -197,20 +197,23 @@ pub trait HirDatabase: DefDatabase + Upcast<dyn DefDatabase> {
fn trait_solve(
&self,
krate: CrateId,
block: Option<BlockId>,
goal: crate::Canonical<crate::InEnvironment<crate::Goal>>,
) -> Option<crate::Solution>;

#[salsa::invoke(crate::traits::trait_solve_query)]
fn trait_solve_query(
&self,
krate: CrateId,
block: Option<BlockId>,
goal: crate::Canonical<crate::InEnvironment<crate::Goal>>,
) -> Option<crate::Solution>;

#[salsa::invoke(chalk_db::program_clauses_for_chalk_env_query)]
fn program_clauses_for_chalk_env(
&self,
krate: CrateId,
block: Option<BlockId>,
env: chalk_ir::Environment<Interner>,
) -> chalk_ir::ProgramClauses<Interner>;
}
Expand All @@ -232,10 +235,11 @@ fn infer_wait(db: &dyn HirDatabase, def: DefWithBodyId) -> Arc<InferenceResult>
fn trait_solve_wait(
db: &dyn HirDatabase,
krate: CrateId,
block: Option<BlockId>,
goal: crate::Canonical<crate::InEnvironment<crate::Goal>>,
) -> Option<crate::Solution> {
let _p = profile::span("trait_solve::wait");
db.trait_solve_query(krate, goal)
db.trait_solve_query(krate, block, goal)
}

#[test]
Expand Down
6 changes: 2 additions & 4 deletions crates/hir-ty/src/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ use crate::{
db::HirDatabase, fold_tys, fold_tys_and_consts, infer::coerce::CoerceMany,
lower::ImplTraitLoweringMode, static_lifetime, to_assoc_type_id, AliasEq, AliasTy, Const,
DomainGoal, GenericArg, Goal, ImplTraitId, InEnvironment, Interner, ProjectionTy, RpitId,
Substitution, TraitEnvironment, TraitRef, Ty, TyBuilder, TyExt, TyKind,
Substitution, TraitRef, Ty, TyBuilder, TyExt, TyKind,
};

// This lint has a false positive here. See the link below for details.
Expand Down Expand Up @@ -442,7 +442,6 @@ pub(crate) struct InferenceContext<'a> {
pub(crate) body: &'a Body,
pub(crate) resolver: Resolver,
table: unify::InferenceTable<'a>,
trait_env: Arc<TraitEnvironment>,
/// The traits in scope, disregarding block modules. This is used for caching purposes.
traits_in_scope: FxHashSet<TraitId>,
pub(crate) result: InferenceResult,
Expand Down Expand Up @@ -516,8 +515,7 @@ impl<'a> InferenceContext<'a> {
let trait_env = db.trait_environment_for_body(owner);
InferenceContext {
result: InferenceResult::default(),
table: unify::InferenceTable::new(db, trait_env.clone()),
trait_env,
table: unify::InferenceTable::new(db, trait_env),
return_ty: TyKind::Error.intern(Interner), // set in collect_* calls
resume_yield_tys: None,
return_coercion: None,
Expand Down
2 changes: 1 addition & 1 deletion crates/hir-ty/src/infer/coerce.rs
Original file line number Diff line number Diff line change
Expand Up @@ -636,7 +636,7 @@ impl<'a> InferenceTable<'a> {
// Need to find out in what cases this is necessary
let solution = self
.db
.trait_solve(krate, canonicalized.value.clone().cast(Interner))
.trait_solve(krate, self.trait_env.block, canonicalized.value.clone().cast(Interner))
.ok_or(TypeError)?;

match solution {
Expand Down
35 changes: 23 additions & 12 deletions crates/hir-ty/src/infer/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
use std::{
iter::{repeat, repeat_with},
mem,
sync::Arc,
};

use chalk_ir::{
Expand All @@ -15,7 +16,7 @@ use hir_def::{
generics::TypeOrConstParamData,
lang_item::LangItem,
path::{GenericArg, GenericArgs},
ConstParamId, FieldId, ItemContainerId, Lookup,
BlockId, ConstParamId, FieldId, ItemContainerId, Lookup,
};
use hir_expand::name::{name, Name};
use stdx::always;
Expand Down Expand Up @@ -147,19 +148,19 @@ impl<'a> InferenceContext<'a> {
self.infer_top_pat(pat, &input_ty);
self.result.standard_types.bool_.clone()
}
Expr::Block { statements, tail, label, id: _ } => {
self.infer_block(tgt_expr, statements, *tail, *label, expected)
Expr::Block { statements, tail, label, id } => {
self.infer_block(tgt_expr, *id, statements, *tail, *label, expected)
}
Expr::Unsafe { id: _, statements, tail } => {
self.infer_block(tgt_expr, statements, *tail, None, expected)
Expr::Unsafe { id, statements, tail } => {
self.infer_block(tgt_expr, *id, statements, *tail, None, expected)
}
Expr::Const { id: _, statements, tail } => {
Expr::Const { id, statements, tail } => {
self.with_breakable_ctx(BreakableKind::Border, None, None, |this| {
this.infer_block(tgt_expr, statements, *tail, None, expected)
this.infer_block(tgt_expr, *id, statements, *tail, None, expected)
})
.1
}
Expr::Async { id: _, statements, tail } => {
Expr::Async { id, statements, tail } => {
let ret_ty = self.table.new_type_var();
let prev_diverges = mem::replace(&mut self.diverges, Diverges::Maybe);
let prev_ret_ty = mem::replace(&mut self.return_ty, ret_ty.clone());
Expand All @@ -170,6 +171,7 @@ impl<'a> InferenceContext<'a> {
self.with_breakable_ctx(BreakableKind::Border, None, None, |this| {
this.infer_block(
tgt_expr,
*id,
statements,
*tail,
None,
Expand Down Expand Up @@ -394,7 +396,7 @@ impl<'a> InferenceContext<'a> {
}
}
let trait_ = fn_x
.get_id(self.db, self.trait_env.krate)
.get_id(self.db, self.table.trait_env.krate)
.expect("We just used it");
let trait_data = self.db.trait_data(trait_);
if let Some(func) = trait_data.method_by_name(&fn_x.method_name()) {
Expand Down Expand Up @@ -787,7 +789,7 @@ impl<'a> InferenceContext<'a> {
let canonicalized = self.canonicalize(base_ty.clone());
let receiver_adjustments = method_resolution::resolve_indexing_op(
self.db,
self.trait_env.clone(),
self.table.trait_env.clone(),
canonicalized.value,
index_trait,
);
Expand Down Expand Up @@ -1205,13 +1207,19 @@ impl<'a> InferenceContext<'a> {
fn infer_block(
&mut self,
expr: ExprId,
block_id: Option<BlockId>,
statements: &[Statement],
tail: Option<ExprId>,
label: Option<LabelId>,
expected: &Expectation,
) -> Ty {
let coerce_ty = expected.coercion_target_type(&mut self.table);
let g = self.resolver.update_to_inner_scope(self.db.upcast(), self.owner, expr);
let prev_env = block_id.map(|block_id| {
let prev_env = self.table.trait_env.clone();
Arc::make_mut(&mut self.table.trait_env).block = Some(block_id);
prev_env
});

let (break_ty, ty) =
self.with_breakable_ctx(BreakableKind::Block, Some(coerce_ty.clone()), label, |this| {
Expand Down Expand Up @@ -1300,6 +1308,9 @@ impl<'a> InferenceContext<'a> {
}
});
self.resolver.reset_to_guard(g);
if let Some(prev_env) = prev_env {
self.table.trait_env = prev_env;
}

break_ty.unwrap_or(ty)
}
Expand Down Expand Up @@ -1398,7 +1409,7 @@ impl<'a> InferenceContext<'a> {
method_resolution::lookup_method(
self.db,
&canonicalized_receiver.value,
self.trait_env.clone(),
self.table.trait_env.clone(),
self.get_traits_in_scope().as_ref().left_or_else(|&it| it),
VisibleFromModule::Filter(self.resolver.module()),
name,
Expand Down Expand Up @@ -1431,7 +1442,7 @@ impl<'a> InferenceContext<'a> {
let resolved = method_resolution::lookup_method(
self.db,
&canonicalized_receiver.value,
self.trait_env.clone(),
self.table.trait_env.clone(),
self.get_traits_in_scope().as_ref().left_or_else(|&it| it),
VisibleFromModule::Filter(self.resolver.module()),
method_name,
Expand Down
21 changes: 17 additions & 4 deletions crates/hir-ty/src/infer/unify.rs
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,8 @@ impl<'a> InferenceTable<'a> {
pub(crate) fn try_obligation(&mut self, goal: Goal) -> Option<Solution> {
let in_env = InEnvironment::new(&self.trait_env.env, goal);
let canonicalized = self.canonicalize(in_env);
let solution = self.db.trait_solve(self.trait_env.krate, canonicalized.value);
let solution =
self.db.trait_solve(self.trait_env.krate, self.trait_env.block, canonicalized.value);
solution
}

Expand Down Expand Up @@ -597,7 +598,11 @@ impl<'a> InferenceTable<'a> {
&mut self,
canonicalized: &Canonicalized<InEnvironment<Goal>>,
) -> bool {
let solution = self.db.trait_solve(self.trait_env.krate, canonicalized.value.clone());
let solution = self.db.trait_solve(
self.trait_env.krate,
self.trait_env.block,
canonicalized.value.clone(),
);

match solution {
Some(Solution::Unique(canonical_subst)) => {
Expand Down Expand Up @@ -684,7 +689,11 @@ impl<'a> InferenceTable<'a> {
environment: trait_env.clone(),
};
let canonical = self.canonicalize(obligation.clone());
if self.db.trait_solve(krate, canonical.value.cast(Interner)).is_some() {
if self
.db
.trait_solve(krate, self.trait_env.block, canonical.value.cast(Interner))
.is_some()
{
self.register_obligation(obligation.goal);
let return_ty = self.normalize_projection_ty(projection);
for fn_x in [FnTrait::Fn, FnTrait::FnMut, FnTrait::FnOnce] {
Expand All @@ -695,7 +704,11 @@ impl<'a> InferenceTable<'a> {
environment: trait_env.clone(),
};
let canonical = self.canonicalize(obligation.clone());
if self.db.trait_solve(krate, canonical.value.cast(Interner)).is_some() {
if self
.db
.trait_solve(krate, self.trait_env.block, canonical.value.cast(Interner))
.is_some()
{
return Some((fn_x, arg_tys, return_ty));
}
}
Expand Down
1 change: 0 additions & 1 deletion crates/hir-ty/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
//! The type system. We currently use this to infer types for completion, hover
//! information and various assists.

#![warn(rust_2018_idioms, unused_lifetimes, semicolon_in_expressions_from_macros)]

#[allow(unused)]
Expand Down
Loading