Skip to content

Commit 09cf067

Browse files
Record inferred type for type placeholders
This adds a type_of_type_placeholder arena to InferenceResult to record which type a given type placeholder gets inferred to.
1 parent 59dafb3 commit 09cf067

File tree

5 files changed

+128
-7
lines changed

5 files changed

+128
-7
lines changed

src/tools/rust-analyzer/crates/hir-def/src/hir/type_ref.rs

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -195,12 +195,16 @@ impl TypeRef {
195195
TypeRef::Tuple(ThinVec::new())
196196
}
197197

198-
pub fn walk(this: TypeRefId, map: &ExpressionStore, f: &mut impl FnMut(&TypeRef)) {
198+
pub fn walk(this: TypeRefId, map: &ExpressionStore, f: &mut impl FnMut(TypeRefId, &TypeRef)) {
199199
go(this, f, map);
200200

201-
fn go(type_ref: TypeRefId, f: &mut impl FnMut(&TypeRef), map: &ExpressionStore) {
202-
let type_ref = &map[type_ref];
203-
f(type_ref);
201+
fn go(
202+
type_ref_id: TypeRefId,
203+
f: &mut impl FnMut(TypeRefId, &TypeRef),
204+
map: &ExpressionStore,
205+
) {
206+
let type_ref = &map[type_ref_id];
207+
f(type_ref_id, type_ref);
204208
match type_ref {
205209
TypeRef::Fn(fn_) => {
206210
fn_.params.iter().for_each(|&(_, param_type)| go(param_type, f, map))
@@ -224,7 +228,7 @@ impl TypeRef {
224228
};
225229
}
226230

227-
fn go_path(path: &Path, f: &mut impl FnMut(&TypeRef), map: &ExpressionStore) {
231+
fn go_path(path: &Path, f: &mut impl FnMut(TypeRefId, &TypeRef), map: &ExpressionStore) {
228232
if let Some(type_ref) = path.type_anchor() {
229233
go(type_ref, f, map);
230234
}

src/tools/rust-analyzer/crates/hir-ty/src/infer.rs

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ use hir_def::{
4141
layout::Integer,
4242
resolver::{HasResolver, ResolveValueResult, Resolver, TypeNs, ValueNs},
4343
signatures::{ConstSignature, StaticSignature},
44-
type_ref::{ConstRef, LifetimeRefId, TypeRefId},
44+
type_ref::{ConstRef, LifetimeRefId, TypeRef, TypeRefId},
4545
};
4646
use hir_expand::{mod_path::ModPath, name::Name};
4747
use indexmap::IndexSet;
@@ -60,6 +60,7 @@ use triomphe::Arc;
6060

6161
use crate::{
6262
ImplTraitId, IncorrectGenericsLenKind, PathLoweringDiagnostic, TargetFeatures,
63+
collect_type_inference_vars,
6364
db::{HirDatabase, InternedClosureId, InternedOpaqueTyId},
6465
infer::{
6566
coerce::{CoerceMany, DynamicCoerceMany},
@@ -497,6 +498,7 @@ pub struct InferenceResult<'db> {
497498
/// unresolved or missing subpatterns or subpatterns of mismatched types.
498499
pub(crate) type_of_pat: ArenaMap<PatId, Ty<'db>>,
499500
pub(crate) type_of_binding: ArenaMap<BindingId, Ty<'db>>,
501+
pub(crate) type_of_type_placeholder: ArenaMap<TypeRefId, Ty<'db>>,
500502
pub(crate) type_of_opaque: FxHashMap<InternedOpaqueTyId, Ty<'db>>,
501503
pub(crate) type_mismatches: FxHashMap<ExprOrPatId, TypeMismatch<'db>>,
502504
/// Whether there are any type-mismatching errors in the result.
@@ -542,6 +544,7 @@ impl<'db> InferenceResult<'db> {
542544
type_of_expr: Default::default(),
543545
type_of_pat: Default::default(),
544546
type_of_binding: Default::default(),
547+
type_of_type_placeholder: Default::default(),
545548
type_of_opaque: Default::default(),
546549
type_mismatches: Default::default(),
547550
has_errors: Default::default(),
@@ -606,6 +609,9 @@ impl<'db> InferenceResult<'db> {
606609
_ => None,
607610
})
608611
}
612+
pub fn placeholder_types(&self) -> impl Iterator<Item = (TypeRefId, &Ty<'db>)> {
613+
self.type_of_type_placeholder.iter()
614+
}
609615
pub fn closure_info(&self, closure: InternedClosureId) -> &(Vec<CapturedItem<'db>>, FnTrait) {
610616
self.closure_info.get(&closure).unwrap()
611617
}
@@ -1014,6 +1020,7 @@ impl<'body, 'db> InferenceContext<'body, 'db> {
10141020
type_of_expr,
10151021
type_of_pat,
10161022
type_of_binding,
1023+
type_of_type_placeholder,
10171024
type_of_opaque,
10181025
type_mismatches,
10191026
has_errors,
@@ -1046,6 +1053,11 @@ impl<'body, 'db> InferenceContext<'body, 'db> {
10461053
*has_errors = *has_errors || ty.references_non_lt_error();
10471054
}
10481055
type_of_binding.shrink_to_fit();
1056+
for ty in type_of_type_placeholder.values_mut() {
1057+
*ty = table.resolve_completely(*ty);
1058+
*has_errors = *has_errors || ty.references_non_lt_error();
1059+
}
1060+
type_of_type_placeholder.shrink_to_fit();
10491061
type_of_opaque.shrink_to_fit();
10501062

10511063
*has_errors |= !type_mismatches.is_empty();
@@ -1285,6 +1297,10 @@ impl<'body, 'db> InferenceContext<'body, 'db> {
12851297
self.result.type_of_pat.insert(pat, ty);
12861298
}
12871299

1300+
fn write_type_placeholder_ty(&mut self, type_ref: TypeRefId, ty: Ty<'db>) {
1301+
self.result.type_of_type_placeholder.insert(type_ref, ty);
1302+
}
1303+
12881304
fn write_binding_ty(&mut self, id: BindingId, ty: Ty<'db>) {
12891305
self.result.type_of_binding.insert(id, ty);
12901306
}
@@ -1333,7 +1349,27 @@ impl<'body, 'db> InferenceContext<'body, 'db> {
13331349
) -> Ty<'db> {
13341350
let ty = self
13351351
.with_ty_lowering(store, type_source, lifetime_elision, |ctx| ctx.lower_ty(type_ref));
1336-
self.process_user_written_ty(ty)
1352+
let ty = self.process_user_written_ty(ty);
1353+
1354+
// Record the association from placeholders' TypeRefId to type variables.
1355+
// We only record them if their number matches. This assumes TypeRef::walk and TypeVisitable process the items in the same order.
1356+
let type_variables = collect_type_inference_vars(&ty);
1357+
let mut placeholder_ids = vec![];
1358+
TypeRef::walk(type_ref, store, &mut |type_ref_id, type_ref| {
1359+
if matches!(type_ref, TypeRef::Placeholder) {
1360+
placeholder_ids.push(type_ref_id);
1361+
}
1362+
});
1363+
1364+
if placeholder_ids.len() == type_variables.len() {
1365+
for (placeholder_id, type_variable) in
1366+
placeholder_ids.into_iter().zip(type_variables.into_iter())
1367+
{
1368+
self.write_type_placeholder_ty(placeholder_id, type_variable);
1369+
}
1370+
}
1371+
1372+
ty
13371373
}
13381374

13391375
pub(crate) fn make_body_ty(&mut self, type_ref: TypeRefId) -> Ty<'db> {

src/tools/rust-analyzer/crates/hir-ty/src/lib.rs

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -569,6 +569,35 @@ where
569569
Vec::from_iter(collector.params)
570570
}
571571

572+
struct TypeInferenceVarCollector<'db> {
573+
type_inference_vars: Vec<Ty<'db>>,
574+
}
575+
576+
impl<'db> rustc_type_ir::TypeVisitor<DbInterner<'db>> for TypeInferenceVarCollector<'db> {
577+
type Result = ();
578+
579+
fn visit_ty(&mut self, ty: Ty<'db>) -> Self::Result {
580+
use crate::rustc_type_ir::Flags;
581+
if ty.is_ty_var() {
582+
self.type_inference_vars.push(ty);
583+
} else if ty.flags().intersects(rustc_type_ir::TypeFlags::HAS_TY_INFER) {
584+
ty.super_visit_with(self);
585+
} else {
586+
// Fast path: don't visit inner types (e.g. generic arguments) when `flags` indicate
587+
// that there are no placeholders.
588+
}
589+
}
590+
}
591+
592+
pub fn collect_type_inference_vars<'db, T>(value: &T) -> Vec<Ty<'db>>
593+
where
594+
T: ?Sized + rustc_type_ir::TypeVisitable<DbInterner<'db>>,
595+
{
596+
let mut collector = TypeInferenceVarCollector { type_inference_vars: vec![] };
597+
value.visit_with(&mut collector);
598+
collector.type_inference_vars
599+
}
600+
572601
pub fn known_const_to_ast<'db>(
573602
konst: Const<'db>,
574603
db: &'db dyn HirDatabase,

src/tools/rust-analyzer/crates/hir-ty/src/tests.rs

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ use hir_def::{
2323
item_scope::ItemScope,
2424
nameres::DefMap,
2525
src::HasSource,
26+
type_ref::TypeRefId,
2627
};
2728
use hir_expand::{FileRange, InFile, db::ExpandDatabase};
2829
use itertools::Itertools;
@@ -219,6 +220,24 @@ fn check_impl(
219220
}
220221
}
221222
}
223+
224+
for (type_ref, ty) in inference_result.placeholder_types() {
225+
let node = match type_node(&body_source_map, type_ref, &db) {
226+
Some(value) => value,
227+
None => continue,
228+
};
229+
let range = node.as_ref().original_file_range_rooted(&db);
230+
if let Some(expected) = types.remove(&range) {
231+
let actual = salsa::attach(&db, || {
232+
if display_source {
233+
ty.display_source_code(&db, def.module(&db), true).unwrap()
234+
} else {
235+
ty.display_test(&db, display_target).to_string()
236+
}
237+
});
238+
assert_eq!(actual, expected, "type annotation differs at {:#?}", range.range);
239+
}
240+
}
222241
}
223242

224243
let mut buf = String::new();
@@ -275,6 +294,20 @@ fn pat_node(
275294
})
276295
}
277296

297+
fn type_node(
298+
body_source_map: &BodySourceMap,
299+
type_ref: TypeRefId,
300+
db: &TestDB,
301+
) -> Option<InFile<SyntaxNode>> {
302+
Some(match body_source_map.type_syntax(type_ref) {
303+
Ok(sp) => {
304+
let root = db.parse_or_expand(sp.file_id);
305+
sp.map(|ptr| ptr.to_node(&root).syntax().clone())
306+
}
307+
Err(SyntheticSyntax) => return None,
308+
})
309+
}
310+
278311
fn infer(#[rust_analyzer::rust_fixture] ra_fixture: &str) -> String {
279312
infer_with_mismatches(ra_fixture, false)
280313
}

src/tools/rust-analyzer/crates/hir-ty/src/tests/display_source_code.rs

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -246,3 +246,22 @@ fn test() {
246246
"#,
247247
);
248248
}
249+
250+
#[test]
251+
fn type_placeholder_type() {
252+
check_types_source_code(
253+
r#"
254+
struct S<T>(T);
255+
fn test() {
256+
let f: S<_> = S(3);
257+
//^ i32
258+
let f: [_; _] = [4_u32, 5, 6];
259+
//^ u32
260+
let f: (_, _, _) = (1_u32, 1_i32, false);
261+
//^ u32
262+
//^ i32
263+
//^ bool
264+
}
265+
"#,
266+
);
267+
}

0 commit comments

Comments
 (0)