diff options
| -rw-r--r-- | Cargo.lock | 1 | ||||
| -rw-r--r-- | crates/hir-ty/Cargo.toml | 1 | ||||
| -rw-r--r-- | crates/hir-ty/src/infer.rs | 2 | ||||
| -rw-r--r-- | crates/hir-ty/src/infer/unify.rs | 85 | ||||
| -rw-r--r-- | crates/hir-ty/src/tests/traits.rs | 65 |
5 files changed, 139 insertions, 15 deletions
diff --git a/Cargo.lock b/Cargo.lock index 4c255b12b7f..59cd66756cd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -557,6 +557,7 @@ version = "0.0.0" dependencies = [ "arrayvec", "base-db", + "bitflags", "chalk-derive", "chalk-ir", "chalk-recursive", diff --git a/crates/hir-ty/Cargo.toml b/crates/hir-ty/Cargo.toml index c72199c37fe..ae837ac6dce 100644 --- a/crates/hir-ty/Cargo.toml +++ b/crates/hir-ty/Cargo.toml @@ -13,6 +13,7 @@ doctest = false cov-mark = "2.0.0-pre.1" itertools = "0.10.5" arrayvec = "0.7.2" +bitflags = "1.3.2" smallvec = "1.10.0" ena = "0.14.0" tracing = "0.1.35" diff --git a/crates/hir-ty/src/infer.rs b/crates/hir-ty/src/infer.rs index 7b54886d53f..6b59f1c20da 100644 --- a/crates/hir-ty/src/infer.rs +++ b/crates/hir-ty/src/infer.rs @@ -512,6 +512,8 @@ impl<'a> InferenceContext<'a> { fn resolve_all(self) -> InferenceResult { let InferenceContext { mut table, mut result, .. } = self; + table.fallback_if_possible(); + // FIXME resolve obligations as well (use Guidance if necessary) table.resolve_obligations_as_possible(); diff --git a/crates/hir-ty/src/infer/unify.rs b/crates/hir-ty/src/infer/unify.rs index 12f45f00f9c..e7ddd1591fe 100644 --- a/crates/hir-ty/src/infer/unify.rs +++ b/crates/hir-ty/src/infer/unify.rs @@ -1,6 +1,6 @@ //! Unification and canonicalization logic. -use std::{fmt, mem, sync::Arc}; +use std::{fmt, iter, mem, sync::Arc}; use chalk_ir::{ cast::Cast, fold::TypeFoldable, interner::HasInterner, zip::Zip, CanonicalVarKind, FloatTy, @@ -128,9 +128,13 @@ pub(crate) fn unify( )) } -#[derive(Copy, Clone, Debug)] -pub(crate) struct TypeVariableData { - diverging: bool, +bitflags::bitflags! { + #[derive(Default)] + pub(crate) struct TypeVariableFlags: u8 { + const DIVERGING = 1 << 0; + const INTEGER = 1 << 1; + const FLOAT = 1 << 2; + } } type ChalkInferenceTable = chalk_solve::infer::InferenceTable<Interner>; @@ -140,14 +144,14 @@ pub(crate) struct InferenceTable<'a> { pub(crate) db: &'a dyn HirDatabase, pub(crate) trait_env: Arc<TraitEnvironment>, var_unification_table: ChalkInferenceTable, - type_variable_table: Vec<TypeVariableData>, + type_variable_table: Vec<TypeVariableFlags>, pending_obligations: Vec<Canonicalized<InEnvironment<Goal>>>, } pub(crate) struct InferenceTableSnapshot { var_table_snapshot: chalk_solve::infer::InferenceSnapshot<Interner>, pending_obligations: Vec<Canonicalized<InEnvironment<Goal>>>, - type_variable_table_snapshot: Vec<TypeVariableData>, + type_variable_table_snapshot: Vec<TypeVariableFlags>, } impl<'a> InferenceTable<'a> { @@ -169,19 +173,19 @@ impl<'a> InferenceTable<'a> { /// result. pub(super) fn propagate_diverging_flag(&mut self) { for i in 0..self.type_variable_table.len() { - if !self.type_variable_table[i].diverging { + if !self.type_variable_table[i].contains(TypeVariableFlags::DIVERGING) { continue; } let v = InferenceVar::from(i as u32); let root = self.var_unification_table.inference_var_root(v); if let Some(data) = self.type_variable_table.get_mut(root.index() as usize) { - data.diverging = true; + *data |= TypeVariableFlags::DIVERGING; } } } pub(super) fn set_diverging(&mut self, iv: InferenceVar, diverging: bool) { - self.type_variable_table[iv.index() as usize].diverging = diverging; + self.type_variable_table[iv.index() as usize].set(TypeVariableFlags::DIVERGING, diverging); } fn fallback_value(&self, iv: InferenceVar, kind: TyVariableKind) -> Ty { @@ -189,7 +193,7 @@ impl<'a> InferenceTable<'a> { _ if self .type_variable_table .get(iv.index() as usize) - .map_or(false, |data| data.diverging) => + .map_or(false, |data| data.contains(TypeVariableFlags::DIVERGING)) => { TyKind::Never } @@ -247,10 +251,8 @@ impl<'a> InferenceTable<'a> { } fn extend_type_variable_table(&mut self, to_index: usize) { - self.type_variable_table.extend( - (0..1 + to_index - self.type_variable_table.len()) - .map(|_| TypeVariableData { diverging: false }), - ); + let count = to_index - self.type_variable_table.len() + 1; + self.type_variable_table.extend(iter::repeat(TypeVariableFlags::default()).take(count)); } fn new_var(&mut self, kind: TyVariableKind, diverging: bool) -> Ty { @@ -258,7 +260,15 @@ impl<'a> InferenceTable<'a> { // Chalk might have created some type variables for its own purposes that we don't know about... self.extend_type_variable_table(var.index() as usize); assert_eq!(var.index() as usize, self.type_variable_table.len() - 1); - self.type_variable_table[var.index() as usize].diverging = diverging; + let flags = self.type_variable_table.get_mut(var.index() as usize).unwrap(); + if diverging { + *flags |= TypeVariableFlags::DIVERGING; + } + if matches!(kind, TyVariableKind::Integer) { + *flags |= TypeVariableFlags::INTEGER; + } else if matches!(kind, TyVariableKind::Float) { + *flags |= TypeVariableFlags::FLOAT; + } var.to_ty_with_kind(Interner, kind) } @@ -340,6 +350,51 @@ impl<'a> InferenceTable<'a> { self.resolve_with_fallback(t, &|_, _, d, _| d) } + /// Apply a fallback to unresolved scalar types. Integer type variables and float type + /// variables are replaced with i32 and f64, respectively. + /// + /// This method is only intended to be called just before returning inference results (i.e. in + /// `InferenceContext::resolve_all()`). + /// + /// FIXME: This method currently doesn't apply fallback to unconstrained general type variables + /// whereas rustc replaces them with `()` or `!`. + pub(super) fn fallback_if_possible(&mut self) { + let int_fallback = TyKind::Scalar(Scalar::Int(IntTy::I32)).intern(Interner); + let float_fallback = TyKind::Scalar(Scalar::Float(FloatTy::F64)).intern(Interner); + + let scalar_vars: Vec<_> = self + .type_variable_table + .iter() + .enumerate() + .filter_map(|(index, flags)| { + let kind = if flags.contains(TypeVariableFlags::INTEGER) { + TyVariableKind::Integer + } else if flags.contains(TypeVariableFlags::FLOAT) { + TyVariableKind::Float + } else { + return None; + }; + + // FIXME: This is not really the nicest way to get `InferenceVar`s. Can we get them + // without directly constructing them from `index`? + let var = InferenceVar::from(index as u32).to_ty(Interner, kind); + Some(var) + }) + .collect(); + + for var in scalar_vars { + let maybe_resolved = self.resolve_ty_shallow(&var); + if let TyKind::InferenceVar(_, kind) = maybe_resolved.kind(Interner) { + let fallback = match kind { + TyVariableKind::Integer => &int_fallback, + TyVariableKind::Float => &float_fallback, + TyVariableKind::General => unreachable!(), + }; + self.unify(&var, fallback); + } + } + } + /// Unify two relatable values (e.g. `Ty`) and register new trait goals that arise from that. pub(crate) fn unify<T: ?Sized + Zip<Interner>>(&mut self, ty1: &T, ty2: &T) -> bool { let result = match self.try_unify(ty1, ty2) { diff --git a/crates/hir-ty/src/tests/traits.rs b/crates/hir-ty/src/tests/traits.rs index a9fd01ee011..d01fe063285 100644 --- a/crates/hir-ty/src/tests/traits.rs +++ b/crates/hir-ty/src/tests/traits.rs @@ -4100,3 +4100,68 @@ where "#, ); } + +#[test] +fn bin_op_with_scalar_fallback() { + // Extra impls are significant so that chalk doesn't give us definite guidances. + check_types( + r#" +//- minicore: add +use core::ops::Add; + +struct Vec2<T>(T, T); + +impl Add for Vec2<i32> { + type Output = Self; + fn add(self, rhs: Self) -> Self::Output { loop {} } +} +impl Add for Vec2<u32> { + type Output = Self; + fn add(self, rhs: Self) -> Self::Output { loop {} } +} +impl Add for Vec2<f32> { + type Output = Self; + fn add(self, rhs: Self) -> Self::Output { loop {} } +} +impl Add for Vec2<f64> { + type Output = Self; + fn add(self, rhs: Self) -> Self::Output { loop {} } +} + +fn test() { + let a = Vec2(1, 2); + let b = Vec2(3, 4); + let c = a + b; + //^ Vec2<i32> + let a = Vec2(1., 2.); + let b = Vec2(3., 4.); + let c = a + b; + //^ Vec2<f64> +} +"#, + ); +} + +#[test] +fn trait_method_with_scalar_fallback() { + check_types( + r#" +trait Trait { + type Output; + fn foo(&self) -> Self::Output; +} +impl<T> Trait for T { + type Output = T; + fn foo(&self) -> Self::Output { loop {} } +} +fn test() { + let a = 42; + let b = a.foo(); + //^ i32 + let a = 3.14; + let b = a.foo(); + //^ f64 +} +"#, + ); +} |
