From d77d3234ce861106cbf66738aafa4fafc6bf7db6 Mon Sep 17 00:00:00 2001 From: rainy-me Date: Sat, 25 Dec 2021 09:05:56 +0900 Subject: [PATCH] refactor: avoid filter map next with find map separate traversal --- crates/hir_ty/src/diagnostics/expr.rs | 143 +++++++++++++++----------- 1 file changed, 84 insertions(+), 59 deletions(-) diff --git a/crates/hir_ty/src/diagnostics/expr.rs b/crates/hir_ty/src/diagnostics/expr.rs index a8c4026e31..b7d765c59b 100644 --- a/crates/hir_ty/src/diagnostics/expr.rs +++ b/crates/hir_ty/src/diagnostics/expr.rs @@ -81,9 +81,8 @@ impl ExprValidator { } fn validate_body(&mut self, db: &dyn HirDatabase) { - self.check_for_filter_map_next(db); - let body = db.body(self.owner); + let mut filter_map_next_checker = None; for (id, expr) in body.exprs.iter() { if let Some((variant, missed_fields, true)) = @@ -101,7 +100,7 @@ impl ExprValidator { self.validate_match(id, *expr, arms, db, self.infer.clone()); } Expr::Call { .. } | Expr::MethodCall { .. } => { - self.validate_call(db, id, expr); + self.validate_call(db, id, expr, &mut filter_map_next_checker); } _ => {} } @@ -143,58 +142,13 @@ impl ExprValidator { }); } - fn check_for_filter_map_next(&mut self, db: &dyn HirDatabase) { - // Find the FunctionIds for Iterator::filter_map and Iterator::next - let iterator_path = path![core::iter::Iterator]; - let resolver = self.owner.resolver(db.upcast()); - let iterator_trait_id = match resolver.resolve_known_trait(db.upcast(), &iterator_path) { - Some(id) => id, - None => return, - }; - let iterator_trait_items = &db.trait_data(iterator_trait_id).items; - let filter_map_function_id = - match iterator_trait_items.iter().find(|item| item.0 == name![filter_map]) { - Some((_, AssocItemId::FunctionId(id))) => id, - _ => return, - }; - let next_function_id = match iterator_trait_items.iter().find(|item| item.0 == name![next]) - { - Some((_, AssocItemId::FunctionId(id))) => id, - _ => return, - }; - - // Search function body for instances of .filter_map(..).next() - let body = db.body(self.owner); - let mut prev = None; - for (id, expr) in body.exprs.iter() { - if let Expr::MethodCall { receiver, .. } = expr { - let function_id = match self.infer.method_resolution(id) { - Some((id, _)) => id, - None => continue, - }; - - if function_id == *filter_map_function_id { - prev = Some(id); - continue; - } - - if function_id == *next_function_id { - if let Some(filter_map_id) = prev { - if *receiver == filter_map_id { - self.diagnostics.push( - BodyValidationDiagnostic::ReplaceFilterMapNextWithFindMap { - method_call_expr: id, - }, - ); - } - } - } - } - prev = None; - } - } - - fn validate_call(&mut self, db: &dyn HirDatabase, call_id: ExprId, expr: &Expr) { + fn validate_call( + &mut self, + db: &dyn HirDatabase, + call_id: ExprId, + expr: &Expr, + filter_map_next_checker: &mut Option, + ) { // Check that the number of arguments matches the number of parameters. // FIXME: Due to shortcomings in the current type system implementation, only emit this @@ -214,6 +168,24 @@ impl ExprValidator { (sig, args.len()) } Expr::MethodCall { receiver, args, .. } => { + let (callee, subst) = match self.infer.method_resolution(call_id) { + Some(it) => it, + None => return, + }; + + if filter_map_next_checker + .get_or_insert_with(|| { + FilterMapNextChecker::new(&self.owner.resolver(db.upcast()), db) + }) + .check(call_id, receiver, &callee) + .is_some() + { + self.diagnostics.push( + BodyValidationDiagnostic::ReplaceFilterMapNextWithFindMap { + method_call_expr: call_id, + }, + ); + } let receiver = &self.infer.type_of_expr[*receiver]; if receiver.strip_references().is_unknown() { // if the receiver is of unknown type, it's very likely we @@ -222,10 +194,6 @@ impl ExprValidator { return; } - let (callee, subst) = match self.infer.method_resolution(call_id) { - Some(it) => it, - None => return, - }; let sig = db.callable_item_signature(callee.into()).substitute(Interner, &subst); (sig, args.len() + 1) @@ -424,6 +392,63 @@ impl ExprValidator { } } +struct FilterMapNextChecker { + filter_map_function_id: Option, + next_function_id: Option, + prev_filter_map_expr_id: Option, +} + +impl FilterMapNextChecker { + fn new(resolver: &hir_def::resolver::Resolver, db: &dyn HirDatabase) -> Self { + // Find and store the FunctionIds for Iterator::filter_map and Iterator::next + let iterator_path = path![core::iter::Iterator]; + let mut filter_map_function_id = None; + let mut next_function_id = None; + + if let Some(iterator_trait_id) = resolver.resolve_known_trait(db.upcast(), &iterator_path) { + let iterator_trait_items = &db.trait_data(iterator_trait_id).items; + for item in iterator_trait_items.iter() { + if let (name, AssocItemId::FunctionId(id)) = item { + if *name == name![filter_map] { + filter_map_function_id = Some(*id); + } + if *name == name![next] { + next_function_id = Some(*id); + } + } + if filter_map_function_id.is_some() && next_function_id.is_some() { + break; + } + } + } + Self { filter_map_function_id, next_function_id, prev_filter_map_expr_id: None } + } + + // check for instances of .filter_map(..).next() + fn check( + &mut self, + current_expr_id: ExprId, + receiver_expr_id: &ExprId, + function_id: &hir_def::FunctionId, + ) -> Option<()> { + if *function_id == self.filter_map_function_id? { + self.prev_filter_map_expr_id = Some(current_expr_id); + return None; + } + + if *function_id == self.next_function_id? { + if let Some(prev_filter_map_expr_id) = self.prev_filter_map_expr_id { + if *receiver_expr_id == prev_filter_map_expr_id { + return Some(()); + } + } + } + + self.prev_filter_map_expr_id = None; + None + } +} + pub fn record_literal_missing_fields( db: &dyn HirDatabase, infer: &InferenceResult,