From 70758e711db00326b41998cb98eb5a8c5bb69fb8 Mon Sep 17 00:00:00 2001 From: Kasperi Apell Date: Tue, 31 Dec 2024 20:43:51 +0200 Subject: [PATCH] [flake8-bugbear] Catch yield in subexpressions (B901) (#14453) Currently, the B901 rule misses yield expressions that are not top-of-tree, for example as in def f(): x = yield print(x) return 42 This commit refactors the rule to find such yield expressions. Assignments are traversed and identifiers bound to yield or yield from expressions are tracked, so that if those variables are later returned (which is valid), the rule is not triggered. The assignment traversal part is inspired by the match_value and match_target functions from src/analyze/typing.rs in the ruff_python_semantic crate. The relevant issue is #14453. --- .../test/fixtures/flake8_bugbear/B901.py | 86 ++++++++++- .../rules/return_in_generator.rs | 142 ++++++++++++++++-- ...__flake8_bugbear__tests__B901_B901.py.snap | 58 ++++++- 3 files changed, 264 insertions(+), 22 deletions(-) diff --git a/crates/ruff_linter/resources/test/fixtures/flake8_bugbear/B901.py b/crates/ruff_linter/resources/test/fixtures/flake8_bugbear/B901.py index 42fdda60d7c25e..5a7c79afa45990 100644 --- a/crates/ruff_linter/resources/test/fixtures/flake8_bugbear/B901.py +++ b/crates/ruff_linter/resources/test/fixtures/flake8_bugbear/B901.py @@ -1,6 +1,6 @@ """ Should emit: -B901 - on lines 9, 36 +B901 - on lines 9, 17, 25, 30, 35, 42, 48, 53 """ @@ -13,6 +13,46 @@ def broken(): yield 1 +def broken2(): + return [3, 2, 1] + + yield from not_broken() + + +def broken3(): + x = yield + print(x) + return 42 + + +def broken4(): + (yield from range(5)) + return 10 + + +def broken5(): + x, y = ((yield from []), 7) + return y + + +def broken6(): + x = y = z = yield from [] + w, z = ("a", 10) + x + return z + + +def broken7(): + x = yield from [] + x = 5 + return x + + +def broken8(): + ((x, y), z) = ((a, b), c) = (((yield 2), 3), 4) + return b + + def not_broken(): if True: return @@ -32,12 +72,6 @@ 
def not_broken3(): yield from not_broken() -def broken2(): - return [3, 2, 1] - - yield from not_broken() - - async def not_broken4(): import asyncio @@ -72,7 +106,43 @@ def inner(ex): return x +def not_broken9(): + x = None + + def inner(): + return (yield from []) + + x = inner() + return x + + +def not_broken10(): + x, y = ((yield from []), 7) + return x + + +def not_broken11(): + x = y = z = yield from [] + return z + + +def not_broken12(): + x = yield + print(x) + return x + + +def not_broken13(): + (x, y), z, w = ((0, (yield)), 1, 2) + return y + + +def not_broken14(): + (x, y) = (z, w) = ((yield 5), 7) + return z + + class NotBroken9(object): def __await__(self): yield from function() - return 42 + return 42 \ No newline at end of file diff --git a/crates/ruff_linter/src/rules/flake8_bugbear/rules/return_in_generator.rs b/crates/ruff_linter/src/rules/flake8_bugbear/rules/return_in_generator.rs index e04ce11eef2520..24164dd046775a 100644 --- a/crates/ruff_linter/src/rules/flake8_bugbear/rules/return_in_generator.rs +++ b/crates/ruff_linter/src/rules/flake8_bugbear/rules/return_in_generator.rs @@ -1,9 +1,9 @@ +use std::collections::HashMap; + use ruff_diagnostics::Diagnostic; use ruff_diagnostics::Violation; use ruff_macros::{derive_message_formats, ViolationMetadata}; -use ruff_python_ast::statement_visitor; -use ruff_python_ast::statement_visitor::StatementVisitor; -use ruff_python_ast::{self as ast, Expr, Stmt, StmtFunctionDef}; +use ruff_python_ast::{self as ast, visitor::Visitor, Expr, Stmt}; use ruff_text_size::TextRange; use crate::checkers::ast::Checker; @@ -91,13 +91,13 @@ impl Violation for ReturnInGenerator { } /// B901 -pub(crate) fn return_in_generator(checker: &mut Checker, function_def: &StmtFunctionDef) { +pub(crate) fn return_in_generator(checker: &mut Checker, function_def: &ast::StmtFunctionDef) { if function_def.name.id == "__await__" { return; } let mut visitor = ReturnInGeneratorVisitor::default(); - 
visitor.visit_body(&function_def.body); + ast::statement_visitor::StatementVisitor::visit_body(&mut visitor, &function_def.body); if visitor.has_yield { if let Some(return_) = visitor.return_ { @@ -108,31 +108,155 @@ pub(crate) fn return_in_generator(checker: &mut Checker, function_def: &StmtFunc } } +enum BindState { + Stored, + Reassigned, +} + #[derive(Default)] struct ReturnInGeneratorVisitor { return_: Option<TextRange>, has_yield: bool, + yield_expr_names: HashMap<String, BindState>, + yield_on_last_visit: bool, } -impl StatementVisitor<'_> for ReturnInGeneratorVisitor { +impl ast::statement_visitor::StatementVisitor<'_> for ReturnInGeneratorVisitor { fn visit_stmt(&mut self, stmt: &Stmt) { match stmt { Stmt::Expr(ast::StmtExpr { value, .. }) => match **value { Expr::Yield(_) | Expr::YieldFrom(_) => { self.has_yield = true; } - _ => {} + _ => { + self.visit_expr(value); + } }, Stmt::FunctionDef(_) => { // Do not recurse into nested functions; they're evaluated separately. } + Stmt::Assign(ast::StmtAssign { targets, value, .. }) => { + for target in targets { + self.discover_yield_assignments(target, value); + } + } + Stmt::AnnAssign(ast::StmtAnnAssign { + target, + value: Some(value), + .. + }) => { + self.yield_on_last_visit = false; + self.visit_expr(value); + self.evaluate_target(target); + } Stmt::Return(ast::StmtReturn { - value: Some(_), + value: Some(value), range, }) => { + if let Expr::Name(ast::ExprName { ref id, ..
}) = **value { + if !matches!( + self.yield_expr_names.get(id.as_str()), + Some(BindState::Reassigned) | None + ) { + return; + } + } self.return_ = Some(*range); } - _ => statement_visitor::walk_stmt(self, stmt), + _ => ast::statement_visitor::walk_stmt(self, stmt), + } + } +} + +impl Visitor<'_> for ReturnInGeneratorVisitor { + fn visit_expr(&mut self, expr: &Expr) { + match expr { + Expr::Yield(_) | Expr::YieldFrom(_) => { + self.has_yield = true; + self.yield_on_last_visit = true; + } + Expr::Lambda(_) | Expr::Call(_) => {} + _ => ast::visitor::walk_expr(self, expr), + } + } +} + +impl ReturnInGeneratorVisitor { + /// Determine if a target is bound to a yield or a yield from expression and, + /// if so, track that target + fn evaluate_target(&mut self, target: &Expr) { + if let Expr::Name(ast::ExprName { ref id, .. }) = *target { + if self.yield_on_last_visit { + match self.yield_expr_names.get(id.as_str()) { + Some(BindState::Reassigned) => {} + _ => { + self.yield_expr_names + .insert(id.to_string(), BindState::Stored); + } + } + } else { + if let Some(BindState::Stored) = self.yield_expr_names.get(id.as_str()) { + self.yield_expr_names + .insert(id.to_string(), BindState::Reassigned); + } + } + } + } + + /// Given a target and a value, track any identifiers that are bound to + /// yield or yield from expressions + fn discover_yield_assignments(&mut self, target: &Expr, value: &Expr) { + match target { + Expr::Name(_) => { + self.yield_on_last_visit = false; + self.visit_expr(value); + self.evaluate_target(target); + } + Expr::Tuple(ast::ExprTuple { elts: tar_elts, .. }) + | Expr::List(ast::ExprList { elts: tar_elts, .. }) => match value { + Expr::Tuple(ast::ExprTuple { elts: val_elts, .. }) + | Expr::List(ast::ExprList { elts: val_elts, .. }) + | Expr::Set(ast::ExprSet { elts: val_elts, .. 
}) => { + self.discover_yield_container_assignments(tar_elts, val_elts); + } + Expr::Yield(_) | Expr::YieldFrom(_) => { + self.has_yield = true; + self.yield_on_last_visit = true; + self.evaluate_target(target); + } + _ => {} + }, + _ => {} + } + } + + fn discover_yield_container_assignments(&mut self, targets: &[Expr], values: &[Expr]) { + for (target, value) in targets.iter().zip(values) { + match target { + Expr::Tuple(ast::ExprTuple { elts: tar_elts, .. }) + | Expr::List(ast::ExprList { elts: tar_elts, .. }) + | Expr::Set(ast::ExprSet { elts: tar_elts, .. }) => { + match value { + Expr::Tuple(ast::ExprTuple { elts: val_elts, .. }) + | Expr::List(ast::ExprList { elts: val_elts, .. }) + | Expr::Set(ast::ExprSet { elts: val_elts, .. }) => { + self.discover_yield_container_assignments(tar_elts, val_elts); + } + Expr::Yield(_) | Expr::YieldFrom(_) => { + self.has_yield = true; + self.yield_on_last_visit = true; + self.evaluate_target(target); + } + _ => {} + }; + } + Expr::Name(_) => { + self.yield_on_last_visit = false; + self.visit_expr(value); + self.evaluate_target(target); + } + _ => {} + } } } } diff --git a/crates/ruff_linter/src/rules/flake8_bugbear/snapshots/ruff_linter__rules__flake8_bugbear__tests__B901_B901.py.snap b/crates/ruff_linter/src/rules/flake8_bugbear/snapshots/ruff_linter__rules__flake8_bugbear__tests__B901_B901.py.snap index 7538d0c1a405af..52da58d6f74aa3 100644 --- a/crates/ruff_linter/src/rules/flake8_bugbear/snapshots/ruff_linter__rules__flake8_bugbear__tests__B901_B901.py.snap +++ b/crates/ruff_linter/src/rules/flake8_bugbear/snapshots/ruff_linter__rules__flake8_bugbear__tests__B901_B901.py.snap @@ -12,11 +12,59 @@ B901.py:9:9: B901 Using `yield` and `return {value}` in a generator function can 11 | yield 3 | -B901.py:36:5: B901 Using `yield` and `return {value}` in a generator function can lead to confusing behavior +B901.py:17:5: B901 Using `yield` and `return {value}` in a generator function can lead to confusing behavior | -35 | def 
broken2(): -36 | return [3, 2, 1] +16 | def broken2(): +17 | return [3, 2, 1] | ^^^^^^^^^^^^^^^^ B901 -37 | -38 | yield from not_broken() +18 | +19 | yield from not_broken() + | + +B901.py:25:5: B901 Using `yield` and `return {value}` in a generator function can lead to confusing behavior + | +23 | x = yield +24 | print(x) +25 | return 42 + | ^^^^^^^^^ B901 + | + +B901.py:30:5: B901 Using `yield` and `return {value}` in a generator function can lead to confusing behavior + | +28 | def broken4(): +29 | (yield from range(5)) +30 | return 10 + | ^^^^^^^^^ B901 + | + +B901.py:35:5: B901 Using `yield` and `return {value}` in a generator function can lead to confusing behavior + | +33 | def broken5(): +34 | x, y = ((yield from []), 7) +35 | return y + | ^^^^^^^^ B901 + | + +B901.py:42:5: B901 Using `yield` and `return {value}` in a generator function can lead to confusing behavior + | +40 | w, z = ("a", 10) +41 | x +42 | return z + | ^^^^^^^^ B901 + | + +B901.py:48:5: B901 Using `yield` and `return {value}` in a generator function can lead to confusing behavior + | +46 | x = yield from [] +47 | x = 5 +48 | return x + | ^^^^^^^^ B901 + | + +B901.py:53:5: B901 Using `yield` and `return {value}` in a generator function can lead to confusing behavior + | +51 | def broken8(): +52 | ((x, y), z) = ((a, b), c) = (((yield 2), 3), 4) +53 | return b + | ^^^^^^^^ B901 |