Auto merge of #13894 - lowr:patch/fallback-before-final-obligation-resolution, r=lnicola
Apply fallback before final obligation resolution Fixes #13249 Fixes #13518 We've been applying fallback to type variables independently even when there are some unresolved obligations that associate them. This PR applies fallback to unresolved scalar type variables before the final attempt of resolving obligations, which enables us to infer more. Unlike rustc, which has separate storages for each kind of type variables, we currently don't have a way to retrieve only integer/float type variables without folding/visiting every single type we've inferred. I've repurposed `TypeVariableData` as bitflags that also hold the kind of the type variable it's referring to so that we can "reconstruct" scalar type variables from their indices. This PR increases the number of ??ty for rust-analyzer repo not because we regress and fail to infer the existing code but because we fail to infer the new code. It seems we have problems inferring some functions bitflags produces.
This commit is contained in:
commit
7449f9fa10
5 changed files with 139 additions and 15 deletions
1
Cargo.lock
generated
1
Cargo.lock
generated
|
|
@ -557,6 +557,7 @@ version = "0.0.0"
|
|||
dependencies = [
|
||||
"arrayvec",
|
||||
"base-db",
|
||||
"bitflags",
|
||||
"chalk-derive",
|
||||
"chalk-ir",
|
||||
"chalk-recursive",
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ doctest = false
|
|||
cov-mark = "2.0.0-pre.1"
|
||||
itertools = "0.10.5"
|
||||
arrayvec = "0.7.2"
|
||||
bitflags = "1.3.2"
|
||||
smallvec = "1.10.0"
|
||||
ena = "0.14.0"
|
||||
tracing = "0.1.35"
|
||||
|
|
|
|||
|
|
@ -512,6 +512,8 @@ impl<'a> InferenceContext<'a> {
|
|||
fn resolve_all(self) -> InferenceResult {
|
||||
let InferenceContext { mut table, mut result, .. } = self;
|
||||
|
||||
table.fallback_if_possible();
|
||||
|
||||
// FIXME resolve obligations as well (use Guidance if necessary)
|
||||
table.resolve_obligations_as_possible();
|
||||
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
//! Unification and canonicalization logic.
|
||||
|
||||
use std::{fmt, mem, sync::Arc};
|
||||
use std::{fmt, iter, mem, sync::Arc};
|
||||
|
||||
use chalk_ir::{
|
||||
cast::Cast, fold::TypeFoldable, interner::HasInterner, zip::Zip, CanonicalVarKind, FloatTy,
|
||||
|
|
@ -128,9 +128,13 @@ pub(crate) fn unify(
|
|||
))
|
||||
}
|
||||
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
pub(crate) struct TypeVariableData {
|
||||
diverging: bool,
|
||||
bitflags::bitflags! {
|
||||
#[derive(Default)]
|
||||
pub(crate) struct TypeVariableFlags: u8 {
|
||||
const DIVERGING = 1 << 0;
|
||||
const INTEGER = 1 << 1;
|
||||
const FLOAT = 1 << 2;
|
||||
}
|
||||
}
|
||||
|
||||
type ChalkInferenceTable = chalk_solve::infer::InferenceTable<Interner>;
|
||||
|
|
@ -140,14 +144,14 @@ pub(crate) struct InferenceTable<'a> {
|
|||
pub(crate) db: &'a dyn HirDatabase,
|
||||
pub(crate) trait_env: Arc<TraitEnvironment>,
|
||||
var_unification_table: ChalkInferenceTable,
|
||||
type_variable_table: Vec<TypeVariableData>,
|
||||
type_variable_table: Vec<TypeVariableFlags>,
|
||||
pending_obligations: Vec<Canonicalized<InEnvironment<Goal>>>,
|
||||
}
|
||||
|
||||
pub(crate) struct InferenceTableSnapshot {
|
||||
var_table_snapshot: chalk_solve::infer::InferenceSnapshot<Interner>,
|
||||
pending_obligations: Vec<Canonicalized<InEnvironment<Goal>>>,
|
||||
type_variable_table_snapshot: Vec<TypeVariableData>,
|
||||
type_variable_table_snapshot: Vec<TypeVariableFlags>,
|
||||
}
|
||||
|
||||
impl<'a> InferenceTable<'a> {
|
||||
|
|
@ -169,19 +173,19 @@ impl<'a> InferenceTable<'a> {
|
|||
/// result.
|
||||
pub(super) fn propagate_diverging_flag(&mut self) {
|
||||
for i in 0..self.type_variable_table.len() {
|
||||
if !self.type_variable_table[i].diverging {
|
||||
if !self.type_variable_table[i].contains(TypeVariableFlags::DIVERGING) {
|
||||
continue;
|
||||
}
|
||||
let v = InferenceVar::from(i as u32);
|
||||
let root = self.var_unification_table.inference_var_root(v);
|
||||
if let Some(data) = self.type_variable_table.get_mut(root.index() as usize) {
|
||||
data.diverging = true;
|
||||
*data |= TypeVariableFlags::DIVERGING;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn set_diverging(&mut self, iv: InferenceVar, diverging: bool) {
|
||||
self.type_variable_table[iv.index() as usize].diverging = diverging;
|
||||
self.type_variable_table[iv.index() as usize].set(TypeVariableFlags::DIVERGING, diverging);
|
||||
}
|
||||
|
||||
fn fallback_value(&self, iv: InferenceVar, kind: TyVariableKind) -> Ty {
|
||||
|
|
@ -189,7 +193,7 @@ impl<'a> InferenceTable<'a> {
|
|||
_ if self
|
||||
.type_variable_table
|
||||
.get(iv.index() as usize)
|
||||
.map_or(false, |data| data.diverging) =>
|
||||
.map_or(false, |data| data.contains(TypeVariableFlags::DIVERGING)) =>
|
||||
{
|
||||
TyKind::Never
|
||||
}
|
||||
|
|
@ -247,10 +251,8 @@ impl<'a> InferenceTable<'a> {
|
|||
}
|
||||
|
||||
fn extend_type_variable_table(&mut self, to_index: usize) {
|
||||
self.type_variable_table.extend(
|
||||
(0..1 + to_index - self.type_variable_table.len())
|
||||
.map(|_| TypeVariableData { diverging: false }),
|
||||
);
|
||||
let count = to_index - self.type_variable_table.len() + 1;
|
||||
self.type_variable_table.extend(iter::repeat(TypeVariableFlags::default()).take(count));
|
||||
}
|
||||
|
||||
fn new_var(&mut self, kind: TyVariableKind, diverging: bool) -> Ty {
|
||||
|
|
@ -258,7 +260,15 @@ impl<'a> InferenceTable<'a> {
|
|||
// Chalk might have created some type variables for its own purposes that we don't know about...
|
||||
self.extend_type_variable_table(var.index() as usize);
|
||||
assert_eq!(var.index() as usize, self.type_variable_table.len() - 1);
|
||||
self.type_variable_table[var.index() as usize].diverging = diverging;
|
||||
let flags = self.type_variable_table.get_mut(var.index() as usize).unwrap();
|
||||
if diverging {
|
||||
*flags |= TypeVariableFlags::DIVERGING;
|
||||
}
|
||||
if matches!(kind, TyVariableKind::Integer) {
|
||||
*flags |= TypeVariableFlags::INTEGER;
|
||||
} else if matches!(kind, TyVariableKind::Float) {
|
||||
*flags |= TypeVariableFlags::FLOAT;
|
||||
}
|
||||
var.to_ty_with_kind(Interner, kind)
|
||||
}
|
||||
|
||||
|
|
@ -340,6 +350,51 @@ impl<'a> InferenceTable<'a> {
|
|||
self.resolve_with_fallback(t, &|_, _, d, _| d)
|
||||
}
|
||||
|
||||
/// Apply a fallback to unresolved scalar types. Integer type variables and float type
|
||||
/// variables are replaced with i32 and f64, respectively.
|
||||
///
|
||||
/// This method is only intended to be called just before returning inference results (i.e. in
|
||||
/// `InferenceContext::resolve_all()`).
|
||||
///
|
||||
/// FIXME: This method currently doesn't apply fallback to unconstrained general type variables
|
||||
/// whereas rustc replaces them with `()` or `!`.
|
||||
pub(super) fn fallback_if_possible(&mut self) {
|
||||
let int_fallback = TyKind::Scalar(Scalar::Int(IntTy::I32)).intern(Interner);
|
||||
let float_fallback = TyKind::Scalar(Scalar::Float(FloatTy::F64)).intern(Interner);
|
||||
|
||||
let scalar_vars: Vec<_> = self
|
||||
.type_variable_table
|
||||
.iter()
|
||||
.enumerate()
|
||||
.filter_map(|(index, flags)| {
|
||||
let kind = if flags.contains(TypeVariableFlags::INTEGER) {
|
||||
TyVariableKind::Integer
|
||||
} else if flags.contains(TypeVariableFlags::FLOAT) {
|
||||
TyVariableKind::Float
|
||||
} else {
|
||||
return None;
|
||||
};
|
||||
|
||||
// FIXME: This is not really the nicest way to get `InferenceVar`s. Can we get them
|
||||
// without directly constructing them from `index`?
|
||||
let var = InferenceVar::from(index as u32).to_ty(Interner, kind);
|
||||
Some(var)
|
||||
})
|
||||
.collect();
|
||||
|
||||
for var in scalar_vars {
|
||||
let maybe_resolved = self.resolve_ty_shallow(&var);
|
||||
if let TyKind::InferenceVar(_, kind) = maybe_resolved.kind(Interner) {
|
||||
let fallback = match kind {
|
||||
TyVariableKind::Integer => &int_fallback,
|
||||
TyVariableKind::Float => &float_fallback,
|
||||
TyVariableKind::General => unreachable!(),
|
||||
};
|
||||
self.unify(&var, fallback);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Unify two relatable values (e.g. `Ty`) and register new trait goals that arise from that.
|
||||
pub(crate) fn unify<T: ?Sized + Zip<Interner>>(&mut self, ty1: &T, ty2: &T) -> bool {
|
||||
let result = match self.try_unify(ty1, ty2) {
|
||||
|
|
|
|||
|
|
@ -4100,3 +4100,68 @@ where
|
|||
"#,
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn bin_op_with_scalar_fallback() {
|
||||
// Extra impls are significant so that chalk doesn't give us definite guidances.
|
||||
check_types(
|
||||
r#"
|
||||
//- minicore: add
|
||||
use core::ops::Add;
|
||||
|
||||
struct Vec2<T>(T, T);
|
||||
|
||||
impl Add for Vec2<i32> {
|
||||
type Output = Self;
|
||||
fn add(self, rhs: Self) -> Self::Output { loop {} }
|
||||
}
|
||||
impl Add for Vec2<u32> {
|
||||
type Output = Self;
|
||||
fn add(self, rhs: Self) -> Self::Output { loop {} }
|
||||
}
|
||||
impl Add for Vec2<f32> {
|
||||
type Output = Self;
|
||||
fn add(self, rhs: Self) -> Self::Output { loop {} }
|
||||
}
|
||||
impl Add for Vec2<f64> {
|
||||
type Output = Self;
|
||||
fn add(self, rhs: Self) -> Self::Output { loop {} }
|
||||
}
|
||||
|
||||
fn test() {
|
||||
let a = Vec2(1, 2);
|
||||
let b = Vec2(3, 4);
|
||||
let c = a + b;
|
||||
//^ Vec2<i32>
|
||||
let a = Vec2(1., 2.);
|
||||
let b = Vec2(3., 4.);
|
||||
let c = a + b;
|
||||
//^ Vec2<f64>
|
||||
}
|
||||
"#,
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn trait_method_with_scalar_fallback() {
|
||||
check_types(
|
||||
r#"
|
||||
trait Trait {
|
||||
type Output;
|
||||
fn foo(&self) -> Self::Output;
|
||||
}
|
||||
impl<T> Trait for T {
|
||||
type Output = T;
|
||||
fn foo(&self) -> Self::Output { loop {} }
|
||||
}
|
||||
fn test() {
|
||||
let a = 42;
|
||||
let b = a.foo();
|
||||
//^ i32
|
||||
let a = 3.14;
|
||||
let b = a.foo();
|
||||
//^ f64
|
||||
}
|
||||
"#,
|
||||
);
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue