Migrate variance to the new solver

This commit is contained in:
Chayim Refael Friedman 2025-10-17 05:14:02 +03:00
parent 7c871c2a4b
commit 6232ba8d08
9 changed files with 227 additions and 423 deletions

View file

@ -1,55 +1,9 @@
//! The implementation of `RustIrDatabase` for Chalk, which provides information
//! about the code that Chalk needs.
use hir_def::{CallableDefId, GenericDefId};
use crate::{Interner, db::HirDatabase, mapping::from_chalk};
use crate::Interner;
pub(crate) type AssocTypeId = chalk_ir::AssocTypeId<Interner>;
pub(crate) type TraitId = chalk_ir::TraitId<Interner>;
pub(crate) type AdtId = chalk_ir::AdtId<Interner>;
pub(crate) type ImplId = chalk_ir::ImplId<Interner>;
pub(crate) type Variances = chalk_ir::Variances<Interner>;
impl chalk_ir::UnificationDatabase<Interner> for &dyn HirDatabase {
fn fn_def_variance(
&self,
fn_def_id: chalk_ir::FnDefId<Interner>,
) -> chalk_ir::Variances<Interner> {
HirDatabase::fn_def_variance(*self, from_chalk(*self, fn_def_id))
}
fn adt_variance(&self, adt_id: chalk_ir::AdtId<Interner>) -> chalk_ir::Variances<Interner> {
HirDatabase::adt_variance(*self, adt_id.0)
}
}
pub(crate) fn fn_def_variance_query(
db: &dyn HirDatabase,
callable_def: CallableDefId,
) -> Variances {
Variances::from_iter(
Interner,
db.variances_of(GenericDefId::from_callable(db, callable_def))
.as_deref()
.unwrap_or_default()
.iter()
.map(|v| match v {
crate::variance::Variance::Covariant => chalk_ir::Variance::Covariant,
crate::variance::Variance::Invariant => chalk_ir::Variance::Invariant,
crate::variance::Variance::Contravariant => chalk_ir::Variance::Contravariant,
crate::variance::Variance::Bivariant => chalk_ir::Variance::Invariant,
}),
)
}
pub(crate) fn adt_variance_query(db: &dyn HirDatabase, adt_id: hir_def::AdtId) -> Variances {
Variances::from_iter(
Interner,
db.variances_of(adt_id.into()).as_deref().unwrap_or_default().iter().map(|v| match v {
crate::variance::Variance::Covariant => chalk_ir::Variance::Covariant,
crate::variance::Variance::Invariant => chalk_ir::Variance::Invariant,
crate::variance::Variance::Contravariant => chalk_ir::Variance::Contravariant,
crate::variance::Variance::Bivariant => chalk_ir::Variance::Invariant,
}),
)
}

View file

@ -3,8 +3,8 @@
use hir_def::{ItemContainerId, Lookup, TraitId};
use crate::{
Binders, DynTy, Interner, ProjectionTy, Substitution, TraitRef, Ty, db::HirDatabase,
from_assoc_type_id, from_chalk_trait_id, generics::generics, to_chalk_trait_id,
Interner, ProjectionTy, Substitution, TraitRef, Ty, db::HirDatabase, from_assoc_type_id,
from_chalk_trait_id, generics::generics, to_chalk_trait_id,
};
pub(crate) trait ProjectionTyExt {
@ -35,23 +35,6 @@ impl ProjectionTyExt for ProjectionTy {
}
}
pub(crate) trait DynTyExt {
fn principal(&self) -> Option<Binders<Binders<&TraitRef>>>;
}
impl DynTyExt for DynTy {
fn principal(&self) -> Option<Binders<Binders<&TraitRef>>> {
self.bounds.as_ref().filter_map(|bounds| {
bounds.interned().first().and_then(|b| {
b.as_ref().filter_map(|b| match b {
crate::WhereClause::Implemented(trait_ref) => Some(trait_ref),
_ => None,
})
})
})
}
}
pub(crate) trait TraitRefExt {
fn hir_trait_id(&self) -> TraitId;
}

View file

@ -17,7 +17,6 @@ use triomphe::Arc;
use crate::{
Binders, ImplTraitId, ImplTraits, InferenceResult, TraitEnvironment, Ty, TyDefId, ValueTyDefId,
chalk_db,
consteval::ConstEvalError,
dyn_compatibility::DynCompatibilityViolation,
layout::{Layout, LayoutError},
@ -308,19 +307,13 @@ pub trait HirDatabase: DefDatabase + std::fmt::Debug {
#[salsa::interned]
fn intern_coroutine(&self, id: InternedCoroutine) -> InternedCoroutineId;
#[salsa::invoke(chalk_db::fn_def_variance_query)]
fn fn_def_variance(&self, fn_def_id: CallableDefId) -> chalk_db::Variances;
#[salsa::invoke(chalk_db::adt_variance_query)]
fn adt_variance(&self, adt_id: AdtId) -> chalk_db::Variances;
#[salsa::invoke(crate::variance::variances_of)]
#[salsa::cycle(
// cycle_fn = crate::variance::variances_of_cycle_fn,
// cycle_initial = crate::variance::variances_of_cycle_initial,
cycle_result = crate::variance::variances_of_cycle_initial,
)]
fn variances_of(&self, def: GenericDefId) -> Option<Arc<[crate::variance::Variance]>>;
fn variances_of(&self, def: GenericDefId) -> crate::next_solver::VariancesOf<'_>;
// next trait solver

View file

@ -117,7 +117,6 @@ pub use utils::{
TargetFeatureIsSafeInTarget, Unsafety, all_super_traits, direct_super_traits,
is_fn_unsafe_to_call, target_feature_is_safe_in_target,
};
pub use variance::Variance;
use chalk_ir::{BoundVar, DebruijnIndex, Safety, Scalar};

View file

@ -155,7 +155,7 @@ impl From<DefWithBodyId> for SolverDefId {
}
impl TryFrom<SolverDefId> for GenericDefId {
type Error = SolverDefId;
type Error = ();
fn try_from(value: SolverDefId) -> Result<Self, Self::Error> {
Ok(match value {
@ -170,7 +170,7 @@ impl TryFrom<SolverDefId> for GenericDefId {
| SolverDefId::InternedCoroutineId(_)
| SolverDefId::InternedOpaqueTyId(_)
| SolverDefId::EnumVariantId(_)
| SolverDefId::Ctor(_) => return Err(value),
| SolverDefId::Ctor(_) => return Err(()),
})
}
}

View file

@ -83,7 +83,7 @@ macro_rules! _interned_vec_nolifetime_salsa {
($name:ident, $ty:ty) => {
interned_vec_nolifetime_salsa!($name, $ty, nofold);
impl<'db> rustc_type_ir::TypeFoldable<DbInterner<'db>> for $name {
impl<'db> rustc_type_ir::TypeFoldable<DbInterner<'db>> for $name<'db> {
fn try_fold_with<F: rustc_type_ir::FallibleTypeFolder<DbInterner<'db>>>(
self,
folder: &mut F,
@ -104,7 +104,7 @@ macro_rules! _interned_vec_nolifetime_salsa {
}
}
impl<'db> rustc_type_ir::TypeVisitable<DbInterner<'db>> for $name {
impl<'db> rustc_type_ir::TypeVisitable<DbInterner<'db>> for $name<'db> {
fn visit_with<V: rustc_type_ir::TypeVisitor<DbInterner<'db>>>(
&self,
visitor: &mut V,
@ -117,14 +117,14 @@ macro_rules! _interned_vec_nolifetime_salsa {
}
};
($name:ident, $ty:ty, nofold) => {
#[salsa::interned(no_lifetime, constructor = new_, debug)]
#[salsa::interned(constructor = new_, debug)]
pub struct $name {
#[returns(ref)]
inner_: smallvec::SmallVec<[$ty; 2]>,
}
impl $name {
pub fn new_from_iter<'db>(
impl<'db> $name<'db> {
pub fn new_from_iter(
interner: DbInterner<'db>,
data: impl IntoIterator<Item = $ty>,
) -> Self {
@ -140,7 +140,7 @@ macro_rules! _interned_vec_nolifetime_salsa {
}
}
impl rustc_type_ir::inherent::SliceLike for $name {
impl<'db> rustc_type_ir::inherent::SliceLike for $name<'db> {
type Item = $ty;
type IntoIter = <smallvec::SmallVec<[$ty; 2]> as IntoIterator>::IntoIter;
@ -154,7 +154,7 @@ macro_rules! _interned_vec_nolifetime_salsa {
}
}
impl IntoIterator for $name {
impl<'db> IntoIterator for $name<'db> {
type Item = $ty;
type IntoIter = <Self as rustc_type_ir::inherent::SliceLike>::IntoIter;
@ -163,7 +163,7 @@ macro_rules! _interned_vec_nolifetime_salsa {
}
}
impl Default for $name {
impl<'db> Default for $name<'db> {
fn default() -> Self {
$name::new_from_iter(DbInterner::conjure(), [])
}
@ -887,7 +887,7 @@ macro_rules! as_lang_item {
impl<'db> rustc_type_ir::Interner for DbInterner<'db> {
type DefId = SolverDefId;
type LocalDefId = SolverDefId;
type LocalDefIds = SolverDefIds;
type LocalDefIds = SolverDefIds<'db>;
type TraitId = TraitIdWrapper;
type ForeignId = TypeAliasIdWrapper;
type FunctionId = CallableIdWrapper;
@ -904,7 +904,7 @@ impl<'db> rustc_type_ir::Interner for DbInterner<'db> {
type Term = Term<'db>;
type BoundVarKinds = BoundVarKinds;
type BoundVarKinds = BoundVarKinds<'db>;
type BoundVarKind = BoundVarKind;
type PredefinedOpaques = PredefinedOpaques<'db>;
@ -977,7 +977,7 @@ impl<'db> rustc_type_ir::Interner for DbInterner<'db> {
type GenericsOf = Generics;
type VariancesOf = VariancesOf;
type VariancesOf = VariancesOf<'db>;
type AdtDef = AdtDef;
@ -1045,10 +1045,9 @@ impl<'db> rustc_type_ir::Interner for DbInterner<'db> {
fn variances_of(self, def_id: Self::DefId) -> Self::VariancesOf {
let generic_def = match def_id {
SolverDefId::FunctionId(def_id) => def_id.into(),
SolverDefId::AdtId(def_id) => def_id.into(),
SolverDefId::Ctor(Ctor::Struct(def_id)) => def_id.into(),
SolverDefId::Ctor(Ctor::Enum(def_id)) => def_id.loc(self.db).parent.into(),
SolverDefId::Ctor(Ctor::Enum(def_id)) | SolverDefId::EnumVariantId(def_id) => {
def_id.loc(self.db).parent.into()
}
SolverDefId::InternedOpaqueTyId(_def_id) => {
// FIXME(next-solver): track variances
//
@ -1059,17 +1058,20 @@ impl<'db> rustc_type_ir::Interner for DbInterner<'db> {
(0..self.generics_of(def_id).count()).map(|_| Variance::Invariant),
);
}
_ => return VariancesOf::new_from_iter(self, []),
SolverDefId::Ctor(Ctor::Struct(def_id)) => def_id.into(),
SolverDefId::AdtId(def_id) => def_id.into(),
SolverDefId::FunctionId(def_id) => def_id.into(),
SolverDefId::ConstId(_)
| SolverDefId::StaticId(_)
| SolverDefId::TraitId(_)
| SolverDefId::TypeAliasId(_)
| SolverDefId::ImplId(_)
| SolverDefId::InternedClosureId(_)
| SolverDefId::InternedCoroutineId(_) => {
return VariancesOf::new_from_iter(self, []);
}
};
VariancesOf::new_from_iter(
self,
self.db()
.variances_of(generic_def)
.as_deref()
.unwrap_or_default()
.iter()
.map(|v| v.to_nextsolver(self)),
)
self.db.variances_of(generic_def)
}
fn type_of(self, def_id: Self::DefId) -> EarlyBinder<Self, Self::Ty> {

View file

@ -605,8 +605,8 @@ impl<'db, T: NextSolverToChalk<'db, U>, U: HasInterner<Interner = Interner>>
}
}
impl<'db> ChalkToNextSolver<'db, BoundVarKinds> for chalk_ir::VariableKinds<Interner> {
fn to_nextsolver(&self, interner: DbInterner<'db>) -> BoundVarKinds {
impl<'db> ChalkToNextSolver<'db, BoundVarKinds<'db>> for chalk_ir::VariableKinds<Interner> {
fn to_nextsolver(&self, interner: DbInterner<'db>) -> BoundVarKinds<'db> {
BoundVarKinds::new_from_iter(
interner,
self.iter(Interner).map(|v| v.to_nextsolver(interner)),
@ -614,7 +614,7 @@ impl<'db> ChalkToNextSolver<'db, BoundVarKinds> for chalk_ir::VariableKinds<Inte
}
}
impl<'db> NextSolverToChalk<'db, chalk_ir::VariableKinds<Interner>> for BoundVarKinds {
impl<'db> NextSolverToChalk<'db, chalk_ir::VariableKinds<Interner>> for BoundVarKinds<'db> {
fn to_chalk(self, interner: DbInterner<'db>) -> chalk_ir::VariableKinds<Interner> {
chalk_ir::VariableKinds::from_iter(Interner, self.iter().map(|v| v.to_chalk(interner)))
}
@ -763,36 +763,6 @@ impl<'db> ChalkToNextSolver<'db, rustc_ast_ir::Mutability> for chalk_ir::Mutabil
}
}
impl<'db> ChalkToNextSolver<'db, rustc_type_ir::Variance> for crate::Variance {
fn to_nextsolver(&self, interner: DbInterner<'db>) -> rustc_type_ir::Variance {
match self {
crate::Variance::Covariant => rustc_type_ir::Variance::Covariant,
crate::Variance::Invariant => rustc_type_ir::Variance::Invariant,
crate::Variance::Contravariant => rustc_type_ir::Variance::Contravariant,
crate::Variance::Bivariant => rustc_type_ir::Variance::Bivariant,
}
}
}
impl<'db> ChalkToNextSolver<'db, rustc_type_ir::Variance> for chalk_ir::Variance {
fn to_nextsolver(&self, interner: DbInterner<'db>) -> rustc_type_ir::Variance {
match self {
chalk_ir::Variance::Covariant => rustc_type_ir::Variance::Covariant,
chalk_ir::Variance::Invariant => rustc_type_ir::Variance::Invariant,
chalk_ir::Variance::Contravariant => rustc_type_ir::Variance::Contravariant,
}
}
}
impl<'db> ChalkToNextSolver<'db, VariancesOf> for chalk_ir::Variances<Interner> {
fn to_nextsolver(&self, interner: DbInterner<'db>) -> VariancesOf {
VariancesOf::new_from_iter(
interner,
self.as_slice(Interner).iter().map(|v| v.to_nextsolver(interner)),
)
}
}
impl<'db> ChalkToNextSolver<'db, Goal<DbInterner<'db>, Predicate<'db>>>
for chalk_ir::InEnvironment<chalk_ir::Goal<Interner>>
{

View file

@ -13,43 +13,45 @@
//! by the next salsa version. If not, we will likely have to adapt and go with the rustc approach
//! while installing firewall per item queries to prevent invalidation issues.
use crate::db::HirDatabase;
use crate::generics::{Generics, generics};
use crate::next_solver::DbInterner;
use crate::next_solver::mapping::{ChalkToNextSolver, NextSolverToChalk};
use crate::{
AliasTy, Const, ConstScalar, DynTyExt, GenericArg, GenericArgData, Interner, Lifetime,
LifetimeData, Ty, TyKind,
use hir_def::{AdtId, GenericDefId, GenericParamId, VariantId, signatures::StructFlags};
use rustc_ast_ir::Mutability;
use rustc_type_ir::{
Variance,
inherent::{AdtDef, IntoKind, SliceLike},
};
use chalk_ir::Mutability;
use hir_def::signatures::StructFlags;
use hir_def::{AdtId, GenericDefId, GenericParamId, VariantId};
use std::fmt;
use std::ops::Not;
use stdx::never;
use triomphe::Arc;
pub(crate) fn variances_of(db: &dyn HirDatabase, def: GenericDefId) -> Option<Arc<[Variance]>> {
use crate::{
db::HirDatabase,
generics::{Generics, generics},
next_solver::{
Const, ConstKind, DbInterner, ExistentialPredicate, GenericArg, GenericArgs, Region,
RegionKind, Term, Ty, TyKind, VariancesOf,
},
};
pub(crate) fn variances_of(db: &dyn HirDatabase, def: GenericDefId) -> VariancesOf<'_> {
tracing::debug!("variances_of(def={:?})", def);
let interner = DbInterner::new_with(db, None, None);
match def {
GenericDefId::FunctionId(_) => (),
GenericDefId::AdtId(adt) => {
if let AdtId::StructId(id) = adt {
let flags = &db.struct_signature(id).flags;
if flags.contains(StructFlags::IS_UNSAFE_CELL) {
return Some(Arc::from_iter(vec![Variance::Invariant; 1]));
return VariancesOf::new_from_iter(interner, [Variance::Invariant]);
} else if flags.contains(StructFlags::IS_PHANTOM_DATA) {
return Some(Arc::from_iter(vec![Variance::Covariant; 1]));
return VariancesOf::new_from_iter(interner, [Variance::Covariant]);
}
}
}
_ => return None,
_ => return VariancesOf::new_from_iter(interner, []),
}
let generics = generics(db, def);
let count = generics.len();
if count == 0 {
return None;
return VariancesOf::new_from_iter(interner, []);
}
let mut variances =
Context { generics, variances: vec![Variance::Bivariant; count], db }.solve();
@ -69,7 +71,7 @@ pub(crate) fn variances_of(db: &dyn HirDatabase, def: GenericDefId) -> Option<Ar
}
}
variances.is_empty().not().then(|| Arc::from_iter(variances))
VariancesOf::new_from_iter(interner, variances)
}
// pub(crate) fn variances_of_cycle_fn(
@ -81,130 +83,36 @@ pub(crate) fn variances_of(db: &dyn HirDatabase, def: GenericDefId) -> Option<Ar
// salsa::CycleRecoveryAction::Iterate
// }
fn glb(v1: Variance, v2: Variance) -> Variance {
// Greatest lower bound of the variance lattice as defined in The Paper:
//
// *
// - +
// o
match (v1, v2) {
(Variance::Invariant, _) | (_, Variance::Invariant) => Variance::Invariant,
(Variance::Covariant, Variance::Contravariant) => Variance::Invariant,
(Variance::Contravariant, Variance::Covariant) => Variance::Invariant,
(Variance::Covariant, Variance::Covariant) => Variance::Covariant,
(Variance::Contravariant, Variance::Contravariant) => Variance::Contravariant,
(x, Variance::Bivariant) | (Variance::Bivariant, x) => x,
}
}
pub(crate) fn variances_of_cycle_initial(
db: &dyn HirDatabase,
def: GenericDefId,
) -> Option<Arc<[Variance]>> {
) -> VariancesOf<'_> {
let interner = DbInterner::new_with(db, None, None);
let generics = generics(db, def);
let count = generics.len();
if count == 0 {
return None;
}
// FIXME(next-solver): Returns `Invariance` and not `Bivariance` here, see the comment in the main query.
Some(Arc::from(vec![Variance::Invariant; count]))
}
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum Variance {
Covariant, // T<A> <: T<B> iff A <: B -- e.g., function return type
Invariant, // T<A> <: T<B> iff B == A -- e.g., type of mutable cell
Contravariant, // T<A> <: T<B> iff B <: A -- e.g., function param type
Bivariant, // T<A> <: T<B> -- e.g., unused type parameter
}
impl fmt::Display for Variance {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Variance::Covariant => write!(f, "covariant"),
Variance::Invariant => write!(f, "invariant"),
Variance::Contravariant => write!(f, "contravariant"),
Variance::Bivariant => write!(f, "bivariant"),
}
}
}
impl Variance {
/// `a.xform(b)` combines the variance of a context with the
/// variance of a type with the following meaning. If we are in a
/// context with variance `a`, and we encounter a type argument in
/// a position with variance `b`, then `a.xform(b)` is the new
/// variance with which the argument appears.
///
/// Example 1:
/// ```ignore (illustrative)
/// *mut Vec<i32>
/// ```
/// Here, the "ambient" variance starts as covariant. `*mut T` is
/// invariant with respect to `T`, so the variance in which the
/// `Vec<i32>` appears is `Covariant.xform(Invariant)`, which
/// yields `Invariant`. Now, the type `Vec<T>` is covariant with
/// respect to its type argument `T`, and hence the variance of
/// the `i32` here is `Invariant.xform(Covariant)`, which results
/// (again) in `Invariant`.
///
/// Example 2:
/// ```ignore (illustrative)
/// fn(*const Vec<i32>, *mut Vec<i32)
/// ```
/// The ambient variance is covariant. A `fn` type is
/// contravariant with respect to its parameters, so the variance
/// within which both pointer types appear is
/// `Covariant.xform(Contravariant)`, or `Contravariant`. `*const
/// T` is covariant with respect to `T`, so the variance within
/// which the first `Vec<i32>` appears is
/// `Contravariant.xform(Covariant)` or `Contravariant`. The same
/// is true for its `i32` argument. In the `*mut T` case, the
/// variance of `Vec<i32>` is `Contravariant.xform(Invariant)`,
/// and hence the outermost type is `Invariant` with respect to
/// `Vec<i32>` (and its `i32` argument).
///
/// Source: Figure 1 of "Taming the Wildcards:
/// Combining Definition- and Use-Site Variance" published in PLDI'11.
fn xform(self, v: Variance) -> Variance {
match (self, v) {
// Figure 1, column 1.
(Variance::Covariant, Variance::Covariant) => Variance::Covariant,
(Variance::Covariant, Variance::Contravariant) => Variance::Contravariant,
(Variance::Covariant, Variance::Invariant) => Variance::Invariant,
(Variance::Covariant, Variance::Bivariant) => Variance::Bivariant,
// Figure 1, column 2.
(Variance::Contravariant, Variance::Covariant) => Variance::Contravariant,
(Variance::Contravariant, Variance::Contravariant) => Variance::Covariant,
(Variance::Contravariant, Variance::Invariant) => Variance::Invariant,
(Variance::Contravariant, Variance::Bivariant) => Variance::Bivariant,
// Figure 1, column 3.
(Variance::Invariant, _) => Variance::Invariant,
// Figure 1, column 4.
(Variance::Bivariant, _) => Variance::Bivariant,
}
}
fn glb(self, v: Variance) -> Variance {
// Greatest lower bound of the variance lattice as
// defined in The Paper:
//
// *
// - +
// o
match (self, v) {
(Variance::Invariant, _) | (_, Variance::Invariant) => Variance::Invariant,
(Variance::Covariant, Variance::Contravariant) => Variance::Invariant,
(Variance::Contravariant, Variance::Covariant) => Variance::Invariant,
(Variance::Covariant, Variance::Covariant) => Variance::Covariant,
(Variance::Contravariant, Variance::Contravariant) => Variance::Contravariant,
(x, Variance::Bivariant) | (Variance::Bivariant, x) => x,
}
}
pub fn invariant(self) -> Self {
self.xform(Variance::Invariant)
}
pub fn covariant(self) -> Self {
self.xform(Variance::Covariant)
}
pub fn contravariant(self) -> Self {
self.xform(Variance::Contravariant)
}
VariancesOf::new_from_iter(interner, std::iter::repeat_n(Variance::Invariant, count))
}
struct Context<'db> {
@ -213,17 +121,16 @@ struct Context<'db> {
variances: Vec<Variance>,
}
impl Context<'_> {
impl<'db> Context<'db> {
fn solve(mut self) -> Vec<Variance> {
tracing::debug!("solve(generics={:?})", self.generics);
match self.generics.def() {
GenericDefId::AdtId(adt) => {
let db = self.db;
let mut add_constraints_from_variant = |variant| {
let subst = self.generics.placeholder_subst(db);
for (_, field) in db.field_types(variant).iter() {
for (_, field) in db.field_types_ns(variant).iter() {
self.add_constraints_from_ty(
&field.clone().substitute(Interner, &subst),
field.instantiate_identity(),
Variance::Covariant,
);
}
@ -239,16 +146,9 @@ impl Context<'_> {
}
}
GenericDefId::FunctionId(f) => {
let subst = self.generics.placeholder_subst(self.db);
let interner = DbInterner::new_with(self.db, None, None);
let args: crate::next_solver::GenericArgs<'_> = subst.to_nextsolver(interner);
let sig = self
.db
.callable_item_signature(f.into())
.instantiate(interner, args)
.skip_binder()
.to_chalk(interner);
self.add_constraints_from_sig(sig.params_and_return.iter(), Variance::Covariant);
let sig =
self.db.callable_item_signature(f.into()).instantiate_identity().skip_binder();
self.add_constraints_from_sig(sig.inputs_and_output.iter(), Variance::Covariant);
}
_ => {}
}
@ -276,122 +176,102 @@ impl Context<'_> {
/// Adds constraints appropriate for an instance of `ty` appearing
/// in a context with the generics defined in `generics` and
/// ambient variance `variance`
fn add_constraints_from_ty(&mut self, ty: &Ty, variance: Variance) {
fn add_constraints_from_ty(&mut self, ty: Ty<'db>, variance: Variance) {
tracing::debug!("add_constraints_from_ty(ty={:?}, variance={:?})", ty, variance);
match ty.kind(Interner) {
TyKind::Scalar(_) | TyKind::Never | TyKind::Str | TyKind::Foreign(..) => {
match ty.kind() {
TyKind::Int(_)
| TyKind::Uint(_)
| TyKind::Float(_)
| TyKind::Char
| TyKind::Bool
| TyKind::Never
| TyKind::Str
| TyKind::Foreign(..) => {
// leaf type -- noop
}
TyKind::FnDef(..) | TyKind::Coroutine(..) | TyKind::Closure(..) => {
TyKind::FnDef(..)
| TyKind::Coroutine(..)
| TyKind::CoroutineClosure(..)
| TyKind::Closure(..) => {
never!("Unexpected unnameable type in variance computation: {:?}", ty);
}
TyKind::Ref(mutbl, lifetime, ty) => {
TyKind::Ref(lifetime, ty, mutbl) => {
self.add_constraints_from_region(lifetime, variance);
self.add_constraints_from_mt(ty, *mutbl, variance);
self.add_constraints_from_mt(ty, mutbl, variance);
}
TyKind::Array(typ, len) => {
self.add_constraints_from_const(len, variance);
self.add_constraints_from_const(len);
self.add_constraints_from_ty(typ, variance);
}
TyKind::Slice(typ) => {
self.add_constraints_from_ty(typ, variance);
}
TyKind::Raw(mutbl, ty) => {
self.add_constraints_from_mt(ty, *mutbl, variance);
TyKind::RawPtr(ty, mutbl) => {
self.add_constraints_from_mt(ty, mutbl, variance);
}
TyKind::Tuple(_, subtys) => {
for subty in subtys.type_parameters(Interner) {
self.add_constraints_from_ty(&subty, variance);
TyKind::Tuple(subtys) => {
for subty in subtys {
self.add_constraints_from_ty(subty, variance);
}
}
TyKind::Adt(def, args) => {
self.add_constraints_from_args(def.0.into(), args.as_slice(Interner), variance);
self.add_constraints_from_args(def.def_id().0.into(), args, variance);
}
TyKind::Alias(AliasTy::Opaque(opaque)) => {
self.add_constraints_from_invariant_args(
opaque.substitution.as_slice(Interner),
variance,
);
TyKind::Alias(_, alias) => {
// FIXME: Probably not correct wrt. opaques.
self.add_constraints_from_invariant_args(alias.args);
}
TyKind::Alias(AliasTy::Projection(proj)) => {
self.add_constraints_from_invariant_args(
proj.substitution.as_slice(Interner),
variance,
);
}
// FIXME: check this
TyKind::AssociatedType(_, subst) => {
self.add_constraints_from_invariant_args(subst.as_slice(Interner), variance);
}
// FIXME: check this
TyKind::OpaqueType(_, subst) => {
self.add_constraints_from_invariant_args(subst.as_slice(Interner), variance);
}
TyKind::Dyn(it) => {
TyKind::Dynamic(bounds, region) => {
// The type `dyn Trait<T> +'a` is covariant w/r/t `'a`:
self.add_constraints_from_region(&it.lifetime, variance);
self.add_constraints_from_region(region, variance);
if let Some(trait_ref) = it.principal() {
// Trait are always invariant so we can take advantage of that.
self.add_constraints_from_invariant_args(
trait_ref
.map(|it| it.map(|it| it.substitution.clone()))
.substitute(
Interner,
&[GenericArg::new(
Interner,
chalk_ir::GenericArgData::Ty(TyKind::Error.intern(Interner)),
)],
)
.skip_binders()
.as_slice(Interner),
variance,
);
for bound in bounds {
match bound.skip_binder() {
ExistentialPredicate::Trait(trait_ref) => {
self.add_constraints_from_invariant_args(trait_ref.args)
}
ExistentialPredicate::Projection(projection) => {
self.add_constraints_from_invariant_args(projection.args);
match projection.term {
Term::Ty(ty) => {
self.add_constraints_from_ty(ty, Variance::Invariant)
}
Term::Const(konst) => self.add_constraints_from_const(konst),
}
}
ExistentialPredicate::AutoTrait(_) => {}
}
}
// FIXME
// for projection in data.projection_bounds() {
// match projection.skip_binder().term.unpack() {
// TyKind::TermKind::Ty(ty) => {
// self.add_constraints_from_ty( ty, self.invariant);
// }
// TyKind::TermKind::Const(c) => {
// self.add_constraints_from_const( c, self.invariant)
// }
// }
// }
}
// Chalk has no params, so use placeholders for now?
TyKind::Placeholder(index) => {
let idx = crate::from_placeholder_idx(self.db, *index).0;
let index = self.generics.type_or_const_param_idx(idx).unwrap();
self.constrain(index, variance);
TyKind::Param(param) => self.constrain(param.index as usize, variance),
TyKind::FnPtr(sig, _) => {
self.add_constraints_from_sig(sig.skip_binder().inputs_and_output.iter(), variance);
}
TyKind::Function(f) => {
self.add_constraints_from_sig(
f.substitution.0.iter(Interner).filter_map(move |p| p.ty(Interner)),
variance,
);
}
TyKind::Error => {
TyKind::Error(_) => {
// we encounter this when walking the trait references for object
// types, where we use Error as the Self type
}
TyKind::CoroutineWitness(..) | TyKind::BoundVar(..) | TyKind::InferenceVar(..) => {
TyKind::Bound(..) => {}
TyKind::CoroutineWitness(..)
| TyKind::Placeholder(..)
| TyKind::Infer(..)
| TyKind::UnsafeBinder(..)
| TyKind::Pat(..) => {
never!("unexpected type encountered in variance inference: {:?}", ty)
}
}
}
fn add_constraints_from_invariant_args(&mut self, args: &[GenericArg], variance: Variance) {
let variance_i = variance.invariant();
for k in args {
match k.data(Interner) {
GenericArgData::Lifetime(lt) => self.add_constraints_from_region(lt, variance_i),
GenericArgData::Ty(ty) => self.add_constraints_from_ty(ty, variance_i),
GenericArgData::Const(val) => self.add_constraints_from_const(val, variance_i),
fn add_constraints_from_invariant_args(&mut self, args: GenericArgs<'db>) {
for k in args.iter() {
match k {
GenericArg::Lifetime(lt) => {
self.add_constraints_from_region(lt, Variance::Invariant)
}
GenericArg::Ty(ty) => self.add_constraints_from_ty(ty, Variance::Invariant),
GenericArg::Const(val) => self.add_constraints_from_const(val),
}
}
}
@ -401,51 +281,40 @@ impl Context<'_> {
fn add_constraints_from_args(
&mut self,
def_id: GenericDefId,
args: &[GenericArg],
args: GenericArgs<'db>,
variance: Variance,
) {
// We don't record `inferred_starts` entries for empty generics.
if args.is_empty() {
return;
}
let Some(variances) = self.db.variances_of(def_id) else {
return;
};
let variances = self.db.variances_of(def_id);
for (i, k) in args.iter().enumerate() {
match k.data(Interner) {
GenericArgData::Lifetime(lt) => {
self.add_constraints_from_region(lt, variance.xform(variances[i]))
}
GenericArgData::Ty(ty) => {
self.add_constraints_from_ty(ty, variance.xform(variances[i]))
}
GenericArgData::Const(val) => self.add_constraints_from_const(val, variance),
for (k, v) in args.iter().zip(variances) {
match k {
GenericArg::Lifetime(lt) => self.add_constraints_from_region(lt, variance.xform(v)),
GenericArg::Ty(ty) => self.add_constraints_from_ty(ty, variance.xform(v)),
GenericArg::Const(val) => self.add_constraints_from_const(val),
}
}
}
/// Adds constraints appropriate for a const expression `val`
/// in a context with ambient variance `variance`
fn add_constraints_from_const(&mut self, c: &Const, variance: Variance) {
match &c.data(Interner).value {
chalk_ir::ConstValue::Concrete(c) => {
if let ConstScalar::UnevaluatedConst(_, subst) = &c.interned {
self.add_constraints_from_invariant_args(subst.as_slice(Interner), variance);
}
}
fn add_constraints_from_const(&mut self, c: Const<'db>) {
match c.kind() {
ConstKind::Unevaluated(c) => self.add_constraints_from_invariant_args(c.args),
_ => {}
}
}
/// Adds constraints appropriate for a function with signature
/// `sig` appearing in a context with ambient variance `variance`
fn add_constraints_from_sig<'a>(
fn add_constraints_from_sig(
&mut self,
mut sig_tys: impl DoubleEndedIterator<Item = &'a Ty>,
mut sig_tys: impl DoubleEndedIterator<Item = Ty<'db>>,
variance: Variance,
) {
let contra = variance.contravariant();
let contra = variance.xform(Variance::Contravariant);
let Some(output) = sig_tys.next_back() else {
return never!("function signature has no return type");
};
@ -457,27 +326,26 @@ impl Context<'_> {
/// Adds constraints appropriate for a region appearing in a
/// context with ambient variance `variance`
fn add_constraints_from_region(&mut self, region: &Lifetime, variance: Variance) {
fn add_constraints_from_region(&mut self, region: Region<'db>, variance: Variance) {
tracing::debug!(
"add_constraints_from_region(region={:?}, variance={:?})",
region,
variance
);
match region.data(Interner) {
LifetimeData::Placeholder(index) => {
let idx = crate::lt_from_placeholder_idx(self.db, *index).0;
let inferred = self.generics.lifetime_idx(idx).unwrap();
self.constrain(inferred, variance);
}
LifetimeData::Static => {}
LifetimeData::BoundVar(..) => {
match region.kind() {
RegionKind::ReEarlyParam(param) => self.constrain(param.index as usize, variance),
RegionKind::ReStatic => {}
RegionKind::ReBound(..) => {
// Either a higher-ranked region inside of a type or a
// late-bound function parameter.
//
// We do not compute constraints for either of these.
}
LifetimeData::Error => {}
LifetimeData::Phantom(..) | LifetimeData::InferenceVar(..) | LifetimeData::Erased => {
RegionKind::ReError(_) => {}
RegionKind::ReLateParam(..)
| RegionKind::RePlaceholder(..)
| RegionKind::ReVar(..)
| RegionKind::ReErased => {
// We don't expect to see anything but 'static or bound
// regions when visiting member types or method types.
never!(
@ -491,11 +359,11 @@ impl Context<'_> {
/// Adds constraints appropriate for a mutability-type pair
/// appearing in a context with ambient variance `variance`
fn add_constraints_from_mt(&mut self, ty: &Ty, mt: Mutability, variance: Variance) {
fn add_constraints_from_mt(&mut self, ty: Ty<'db>, mt: Mutability, variance: Variance) {
self.add_constraints_from_ty(
ty,
match mt {
Mutability::Mut => variance.invariant(),
Mutability::Mut => Variance::Invariant,
Mutability::Not => variance,
},
);
@ -508,7 +376,7 @@ impl Context<'_> {
self.variances[index],
variance
);
self.variances[index] = self.variances[index].glb(variance);
self.variances[index] = glb(self.variances[index], variance);
}
}
@ -519,6 +387,7 @@ mod tests {
AdtId, GenericDefId, ModuleDefId, hir::generics::GenericParamDataRef, src::HasSource,
};
use itertools::Itertools;
use rustc_type_ir::{Variance, inherent::SliceLike};
use stdx::format_to;
use syntax::{AstNode, ast::HasName};
use test_fixture::WithFixture;
@ -1037,26 +906,21 @@ struct FixedPoint<T, U, V>(&'static FixedPoint<(), T, U>, V);
let loc = it.lookup(&db);
loc.source(&db).value.name().unwrap()
}
GenericDefId::TraitId(it) => {
let loc = it.lookup(&db);
loc.source(&db).value.name().unwrap()
}
GenericDefId::TypeAliasId(it) => {
let loc = it.lookup(&db);
loc.source(&db).value.name().unwrap()
}
GenericDefId::ImplId(_) => return None,
GenericDefId::ConstId(_) => return None,
GenericDefId::StaticId(_) => return None,
GenericDefId::TraitId(_)
| GenericDefId::TypeAliasId(_)
| GenericDefId::ImplId(_)
| GenericDefId::ConstId(_)
| GenericDefId::StaticId(_) => return None,
},
))
})
.sorted_by_key(|(_, n)| n.syntax().text_range().start());
let mut res = String::new();
for (def, name) in defs {
let Some(variances) = db.variances_of(def) else {
let variances = db.variances_of(def);
if variances.is_empty() {
continue;
};
}
format_to!(
res,
"{name}[{}]\n",
@ -1072,10 +936,16 @@ struct FixedPoint<T, U, V>(&'static FixedPoint<(), T, U>, V);
&lifetime_param_data.name
}
})
.zip_eq(&*variances)
.zip_eq(variances)
.format_with(", ", |(name, var), f| f(&format_args!(
"{}: {var}",
name.as_str()
"{}: {}",
name.as_str(),
match var {
Variance::Covariant => "covariant",
Variance::Invariant => "invariant",
Variance::Contravariant => "contravariant",
Variance::Bivariant => "bivariant",
},
)))
);
}

View file

@ -36,6 +36,7 @@ pub mod term_search;
mod display;
use std::{
fmt,
mem::discriminant,
ops::{ControlFlow, Not},
};
@ -160,7 +161,7 @@ pub use {
// FIXME: Properly encapsulate mir
hir_ty::mir,
hir_ty::{
CastError, FnAbi, PointerCast, Variance, attach_db, attach_db_allow_change,
CastError, FnAbi, PointerCast, attach_db, attach_db_allow_change,
consteval::ConstEvalError,
diagnostics::UnsafetyReason,
display::{ClosureStyle, DisplayTarget, HirDisplay, HirDisplayError, HirWrite},
@ -4110,7 +4111,39 @@ impl GenericParam {
GenericParam::ConstParam(_) => return None,
GenericParam::LifetimeParam(it) => generics.lifetime_idx(it.id)?,
};
db.variances_of(parent)?.get(index).copied()
db.variances_of(parent).get(index).map(Into::into)
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Variance {
Bivariant,
Covariant,
Contravariant,
Invariant,
}
impl From<rustc_type_ir::Variance> for Variance {
#[inline]
fn from(value: rustc_type_ir::Variance) -> Self {
match value {
rustc_type_ir::Variance::Covariant => Variance::Covariant,
rustc_type_ir::Variance::Invariant => Variance::Invariant,
rustc_type_ir::Variance::Contravariant => Variance::Contravariant,
rustc_type_ir::Variance::Bivariant => Variance::Bivariant,
}
}
}
impl fmt::Display for Variance {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let description = match self {
Variance::Bivariant => "bivariant",
Variance::Covariant => "covariant",
Variance::Contravariant => "contravariant",
Variance::Invariant => "invariant",
};
f.pad(description)
}
}