diff --git a/src/ast/helpers.rs b/src/ast/helpers.rs index bdb82800696ac..1d93c47519a7f 100644 --- a/src/ast/helpers.rs +++ b/src/ast/helpers.rs @@ -1,6 +1,6 @@ use once_cell::sync::Lazy; use regex::Regex; -use rustpython_ast::{Excepthandler, ExcepthandlerKind, Expr, ExprKind, StmtKind}; +use rustpython_ast::{Excepthandler, ExcepthandlerKind, Expr, ExprKind, Location, StmtKind}; use crate::python::typing; @@ -131,3 +131,15 @@ pub fn is_super_call_with_arguments(func: &Expr, args: &[Expr]) -> bool { false } } + +/// Convert a location within a file (relative to `base`) to an absolute position. +pub fn to_absolute(relative: &Location, base: &Location) -> Location { + if relative.row() == 1 { + Location::new( + relative.row() + base.row() - 1, + relative.column() + base.column() - 1, + ) + } else { + Location::new(relative.row() + base.row() - 1, relative.column()) + } +} diff --git a/src/autofix/fixes.rs b/src/autofix/fixes.rs deleted file mode 100644 index acf0245faf9b5..0000000000000 --- a/src/autofix/fixes.rs +++ /dev/null @@ -1,371 +0,0 @@ -use anyhow::Result; -use itertools::Itertools; -use libcst_native::ImportNames::Aliases; -use libcst_native::NameOrAttribute::N; -use libcst_native::{Codegen, Expression, SmallStatement, Statement}; -use rustpython_parser::ast::{ExcepthandlerKind, Expr, Keyword, Location, Stmt, StmtKind}; -use rustpython_parser::lexer; -use rustpython_parser::token::Tok; - -use crate::ast::operations::SourceCodeLocator; -use crate::ast::types::Range; -use crate::autofix::Fix; - -/// Convert a location within a file (relative to `base`) to an absolute position. -fn to_absolute(relative: &Location, base: &Location) -> Location { - if relative.row() == 1 { - Location::new( - relative.row() + base.row() - 1, - relative.column() + base.column() - 1, - ) - } else { - Location::new(relative.row() + base.row() - 1, relative.column()) - } -} - -/// Generate a fix to remove a base from a ClassDef statement. -pub fn remove_class_def_base( - locator: &mut SourceCodeLocator, - stmt_at: &Location, - expr_at: Location, - bases: &[Expr], - keywords: &[Keyword], -) -> Option { - let content = locator.slice_source_code_at(stmt_at); - - // Case 1: `object` is the only base. - if bases.len() == 1 && keywords.is_empty() { - let mut fix_start = None; - let mut fix_end = None; - let mut count: usize = 0; - for (start, tok, end) in lexer::make_tokenizer(content).flatten() { - if matches!(tok, Tok::Lpar) { - if count == 0 { - fix_start = Some(to_absolute(&start, stmt_at)); - } - count += 1; - } - - if matches!(tok, Tok::Rpar) { - count -= 1; - if count == 0 { - fix_end = Some(to_absolute(&end, stmt_at)); - break; - } - } - } - - return match (fix_start, fix_end) { - (Some(start), Some(end)) => Some(Fix::replacement("".to_string(), start, end)), - _ => None, - }; - } - - if bases - .iter() - .map(|node| node.location) - .chain(keywords.iter().map(|node| node.location)) - .any(|location| location > expr_at) - { - // Case 2: `object` is _not_ the last node. - let mut fix_start: Option = None; - let mut fix_end: Option = None; - let mut seen_comma = false; - for (start, tok, end) in lexer::make_tokenizer(content).flatten() { - let start = to_absolute(&start, stmt_at); - if seen_comma { - if matches!(tok, Tok::Newline) { - fix_end = Some(end); - } else { - fix_end = Some(start); - } - break; - } - if start == expr_at { - fix_start = Some(start); - } - if fix_start.is_some() && matches!(tok, Tok::Comma) { - seen_comma = true; - } - } - - match (fix_start, fix_end) { - (Some(start), Some(end)) => Some(Fix::replacement("".to_string(), start, end)), - _ => None, - } - } else { - // Case 3: `object` is the last node, so we have to find the last token that isn't a comma. - let mut fix_start: Option = None; - let mut fix_end: Option = None; - for (start, tok, end) in lexer::make_tokenizer(content).flatten() { - let start = to_absolute(&start, stmt_at); - let end = to_absolute(&end, stmt_at); - if start == expr_at { - fix_end = Some(end); - break; - } - if matches!(tok, Tok::Comma) { - fix_start = Some(start); - } - } - - match (fix_start, fix_end) { - (Some(start), Some(end)) => Some(Fix::replacement("".to_string(), start, end)), - _ => None, - } - } -} - -pub fn remove_super_arguments(locator: &mut SourceCodeLocator, expr: &Expr) -> Option { - let range = Range::from_located(expr); - let contents = locator.slice_source_code_range(&range); - - let mut tree = match libcst_native::parse_module(contents, None) { - Ok(m) => m, - Err(_) => return None, - }; - - if let Some(Statement::Simple(body)) = tree.body.first_mut() { - if let Some(SmallStatement::Expr(body)) = body.body.first_mut() { - if let Expression::Call(body) = &mut body.value { - body.args = vec![]; - body.whitespace_before_args = Default::default(); - body.whitespace_after_func = Default::default(); - - let mut state = Default::default(); - tree.codegen(&mut state); - - return Some(Fix::replacement( - state.to_string(), - range.location, - range.end_location, - )); - } - } - } - - None -} - -/// Determine if a body contains only a single statement, taking into account deleted. -fn has_single_child(body: &[Stmt], deleted: &[&Stmt]) -> bool { - body.iter().filter(|child| !deleted.contains(child)).count() == 1 -} - -/// Determine if a child is the only statement in its body. -fn is_lone_child(child: &Stmt, parent: &Stmt, deleted: &[&Stmt]) -> Result { - match &parent.node { - StmtKind::FunctionDef { body, .. } - | StmtKind::AsyncFunctionDef { body, .. } - | StmtKind::ClassDef { body, .. } - | StmtKind::With { body, .. } - | StmtKind::AsyncWith { body, .. } => { - if body.iter().contains(child) { - Ok(has_single_child(body, deleted)) - } else { - Err(anyhow::anyhow!("Unable to find child in parent body.")) - } - } - StmtKind::For { body, orelse, .. } - | StmtKind::AsyncFor { body, orelse, .. } - | StmtKind::While { body, orelse, .. } - | StmtKind::If { body, orelse, .. } => { - if body.iter().contains(child) { - Ok(has_single_child(body, deleted)) - } else if orelse.iter().contains(child) { - Ok(has_single_child(orelse, deleted)) - } else { - Err(anyhow::anyhow!("Unable to find child in parent body.")) - } - } - StmtKind::Try { - body, - handlers, - orelse, - finalbody, - } => { - if body.iter().contains(child) { - Ok(has_single_child(body, deleted)) - } else if orelse.iter().contains(child) { - Ok(has_single_child(orelse, deleted)) - } else if finalbody.iter().contains(child) { - Ok(has_single_child(finalbody, deleted)) - } else if let Some(body) = handlers.iter().find_map(|handler| match &handler.node { - ExcepthandlerKind::ExceptHandler { body, .. } => { - if body.iter().contains(child) { - Some(body) - } else { - None - } - } - }) { - Ok(has_single_child(body, deleted)) - } else { - Err(anyhow::anyhow!("Unable to find child in parent body.")) - } - } - _ => Err(anyhow::anyhow!("Unable to find child in parent body.")), - } -} - -pub fn remove_stmt(stmt: &Stmt, parent: Option<&Stmt>, deleted: &[&Stmt]) -> Result { - if parent - .map(|parent| is_lone_child(stmt, parent, deleted)) - .map_or(Ok(None), |v| v.map(Some))? - .unwrap_or_default() - { - // If removing this node would lead to an invalid syntax tree, replace - // it with a `pass`. - Ok(Fix::replacement( - "pass".to_string(), - stmt.location, - stmt.end_location.unwrap(), - )) - } else { - // Otherwise, nuke the entire line. - // TODO(charlie): This logic assumes that there are no multi-statement physical lines. - Ok(Fix::deletion( - Location::new(stmt.location.row(), 1), - Location::new(stmt.end_location.unwrap().row() + 1, 1), - )) - } -} - -/// Generate a Fix to remove any unused imports from an `import` statement. -pub fn remove_unused_imports( - locator: &mut SourceCodeLocator, - full_names: &[&str], - stmt: &Stmt, - parent: Option<&Stmt>, - deleted: &[&Stmt], -) -> Result { - let mut tree = match libcst_native::parse_module( - locator.slice_source_code_range(&Range::from_located(stmt)), - None, - ) { - Ok(m) => m, - Err(_) => return Err(anyhow::anyhow!("Failed to extract CST from source.")), - }; - - let body = if let Some(Statement::Simple(body)) = tree.body.first_mut() { - body - } else { - return Err(anyhow::anyhow!("Expected node to be: Statement::Simple.")); - }; - let body = if let Some(SmallStatement::Import(body)) = body.body.first_mut() { - body - } else { - return Err(anyhow::anyhow!( - "Expected node to be: SmallStatement::ImportFrom." - )); - }; - let aliases = &mut body.names; - - // Preserve the trailing comma (or not) from the last entry. - let trailing_comma = aliases.last().and_then(|alias| alias.comma.clone()); - - // Identify unused imports from within the `import from`. - let mut removable = vec![]; - for (index, alias) in aliases.iter().enumerate() { - if let N(import_name) = &alias.name { - if full_names.contains(&import_name.value) { - removable.push(index); - } - } - } - // TODO(charlie): This is quadratic. - for index in removable.iter().rev() { - aliases.remove(*index); - } - - if let Some(alias) = aliases.last_mut() { - alias.comma = trailing_comma; - } - - if aliases.is_empty() { - remove_stmt(stmt, parent, deleted) - } else { - let mut state = Default::default(); - tree.codegen(&mut state); - - Ok(Fix::replacement( - state.to_string(), - stmt.location, - stmt.end_location.unwrap(), - )) - } -} - -/// Generate a Fix to remove any unused imports from an `import from` statement. -pub fn remove_unused_import_froms( - locator: &mut SourceCodeLocator, - full_names: &[&str], - stmt: &Stmt, - parent: Option<&Stmt>, - deleted: &[&Stmt], -) -> Result { - let mut tree = match libcst_native::parse_module( - locator.slice_source_code_range(&Range::from_located(stmt)), - None, - ) { - Ok(m) => m, - Err(_) => return Err(anyhow::anyhow!("Failed to extract CST from source.")), - }; - - let body = if let Some(Statement::Simple(body)) = tree.body.first_mut() { - body - } else { - return Err(anyhow::anyhow!("Expected node to be: Statement::Simple.")); - }; - let body = if let Some(SmallStatement::ImportFrom(body)) = body.body.first_mut() { - body - } else { - return Err(anyhow::anyhow!( - "Expected node to be: SmallStatement::ImportFrom." - )); - }; - let aliases = if let Aliases(aliases) = &mut body.names { - aliases - } else { - return Err(anyhow::anyhow!("Expected node to be: Aliases.")); - }; - - // Preserve the trailing comma (or not) from the last entry. - let trailing_comma = aliases.last().and_then(|alias| alias.comma.clone()); - - // Identify unused imports from within the `import from`. - let mut removable = vec![]; - for (index, alias) in aliases.iter().enumerate() { - if let N(name) = &alias.name { - let import_name = if let Some(N(module_name)) = &body.module { - format!("{}.{}", module_name.value, name.value) - } else { - name.value.to_string() - }; - if full_names.contains(&import_name.as_str()) { - removable.push(index); - } - } - } - // TODO(charlie): This is quadratic. - for index in removable.iter().rev() { - aliases.remove(*index); - } - - if let Some(alias) = aliases.last_mut() { - alias.comma = trailing_comma; - } - - if aliases.is_empty() { - remove_stmt(stmt, parent, deleted) - } else { - let mut state = Default::default(); - tree.codegen(&mut state); - - Ok(Fix::replacement( - state.to_string(), - stmt.location, - stmt.end_location.unwrap(), - )) - } -} diff --git a/src/autofix/helpers.rs b/src/autofix/helpers.rs new file mode 100644 index 0000000000000..ab8f54daf842c --- /dev/null +++ b/src/autofix/helpers.rs @@ -0,0 +1,89 @@ +use anyhow::Result; +use itertools::Itertools; +use rustpython_parser::ast::{ExcepthandlerKind, Location, Stmt, StmtKind}; + +use crate::autofix::Fix; + +/// Determine if a body contains only a single statement, taking into account deleted. +fn has_single_child(body: &[Stmt], deleted: &[&Stmt]) -> bool { + body.iter().filter(|child| !deleted.contains(child)).count() == 1 +} + +/// Determine if a child is the only statement in its body. +fn is_lone_child(child: &Stmt, parent: &Stmt, deleted: &[&Stmt]) -> Result { + match &parent.node { + StmtKind::FunctionDef { body, .. } + | StmtKind::AsyncFunctionDef { body, .. } + | StmtKind::ClassDef { body, .. } + | StmtKind::With { body, .. } + | StmtKind::AsyncWith { body, .. } => { + if body.iter().contains(child) { + Ok(has_single_child(body, deleted)) + } else { + Err(anyhow::anyhow!("Unable to find child in parent body.")) + } + } + StmtKind::For { body, orelse, .. } + | StmtKind::AsyncFor { body, orelse, .. } + | StmtKind::While { body, orelse, .. } + | StmtKind::If { body, orelse, .. } => { + if body.iter().contains(child) { + Ok(has_single_child(body, deleted)) + } else if orelse.iter().contains(child) { + Ok(has_single_child(orelse, deleted)) + } else { + Err(anyhow::anyhow!("Unable to find child in parent body.")) + } + } + StmtKind::Try { + body, + handlers, + orelse, + finalbody, + } => { + if body.iter().contains(child) { + Ok(has_single_child(body, deleted)) + } else if orelse.iter().contains(child) { + Ok(has_single_child(orelse, deleted)) + } else if finalbody.iter().contains(child) { + Ok(has_single_child(finalbody, deleted)) + } else if let Some(body) = handlers.iter().find_map(|handler| match &handler.node { + ExcepthandlerKind::ExceptHandler { body, .. } => { + if body.iter().contains(child) { + Some(body) + } else { + None + } + } + }) { + Ok(has_single_child(body, deleted)) + } else { + Err(anyhow::anyhow!("Unable to find child in parent body.")) + } + } + _ => Err(anyhow::anyhow!("Unable to find child in parent body.")), + } +} + +pub fn remove_stmt(stmt: &Stmt, parent: Option<&Stmt>, deleted: &[&Stmt]) -> Result { + if parent + .map(|parent| is_lone_child(stmt, parent, deleted)) + .map_or(Ok(None), |v| v.map(Some))? + .unwrap_or_default() + { + // If removing this node would lead to an invalid syntax tree, replace + // it with a `pass`. + Ok(Fix::replacement( + "pass".to_string(), + stmt.location, + stmt.end_location.unwrap(), + )) + } else { + // Otherwise, nuke the entire line. + // TODO(charlie): This logic assumes that there are no multi-statement physical lines. + Ok(Fix::deletion( + Location::new(stmt.location.row(), 1), + Location::new(stmt.end_location.unwrap().row() + 1, 1), + )) + } +} diff --git a/src/autofix/mod.rs b/src/autofix/mod.rs index 9b6df783bbf80..b9d5a92efb2f8 100644 --- a/src/autofix/mod.rs +++ b/src/autofix/mod.rs @@ -2,7 +2,7 @@ use rustpython_ast::Location; use serde::{Deserialize, Serialize}; pub mod fixer; -pub mod fixes; +pub mod helpers; #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)] pub struct Patch { diff --git a/src/check_ast.rs b/src/check_ast.rs index 5a436472b6f96..5ad6d85837d02 100644 --- a/src/check_ast.rs +++ b/src/check_ast.rs @@ -18,7 +18,7 @@ use crate::ast::types::{ }; use crate::ast::visitor::{walk_excepthandler, Visitor}; use crate::ast::{helpers, operations, visitor}; -use crate::autofix::{fixer, fixes}; +use crate::autofix::fixer; use crate::checks::{Check, CheckCode, CheckKind}; use crate::docstrings::definition::{Definition, DefinitionKind, Documentable}; use crate::python::builtins::{BUILTINS, MAGIC_GLOBALS}; @@ -1955,8 +1955,8 @@ impl<'a> Checker<'a> { .collect(); let removal_fn = match kind { - ImportKind::Import => fixes::remove_unused_imports, - ImportKind::ImportFrom => fixes::remove_unused_import_froms, + ImportKind::Import => pyflakes::fixes::remove_unused_imports, + ImportKind::ImportFrom => pyflakes::fixes::remove_unused_import_froms, }; match removal_fn(&mut self.locator, &full_names, child, parent, &deleted) { diff --git a/src/flake8_print/plugins/print_call.rs b/src/flake8_print/plugins/print_call.rs index 691922d867f01..6f70d377e39b1 100644 --- a/src/flake8_print/plugins/print_call.rs +++ b/src/flake8_print/plugins/print_call.rs @@ -1,7 +1,7 @@ use log::error; use rustpython_ast::{Expr, Stmt, StmtKind}; -use crate::autofix::{fixer, fixes}; +use crate::autofix::{fixer, helpers}; use crate::check_ast::Checker; use crate::checks::CheckCode; use crate::flake8_print::checks; @@ -25,7 +25,7 @@ pub fn print_call(checker: &mut Checker, expr: &Expr, func: &Expr) { .map(|index| checker.parents[*index]) .collect(); - match fixes::remove_stmt( + match helpers::remove_stmt( checker.parents[context.defined_by], context.defined_in.map(|index| checker.parents[index]), &deleted, diff --git a/src/pyflakes/fixes.rs b/src/pyflakes/fixes.rs new file mode 100644 index 0000000000000..08af6293ecef6 --- /dev/null +++ b/src/pyflakes/fixes.rs @@ -0,0 +1,147 @@ +use libcst_native::ImportNames::Aliases; +use libcst_native::NameOrAttribute::N; +use libcst_native::{Codegen, SmallStatement, Statement}; +use rustpython_ast::Stmt; + +use crate::ast::operations::SourceCodeLocator; +use crate::ast::types::Range; +use crate::autofix::{helpers, Fix}; + +/// Generate a Fix to remove any unused imports from an `import` statement. +pub fn remove_unused_imports( + locator: &mut SourceCodeLocator, + full_names: &[&str], + stmt: &Stmt, + parent: Option<&Stmt>, + deleted: &[&Stmt], +) -> anyhow::Result { + let mut tree = match libcst_native::parse_module( + locator.slice_source_code_range(&Range::from_located(stmt)), + None, + ) { + Ok(m) => m, + Err(_) => return Err(anyhow::anyhow!("Failed to extract CST from source.")), + }; + + let body = if let Some(Statement::Simple(body)) = tree.body.first_mut() { + body + } else { + return Err(anyhow::anyhow!("Expected node to be: Statement::Simple.")); + }; + let body = if let Some(SmallStatement::Import(body)) = body.body.first_mut() { + body + } else { + return Err(anyhow::anyhow!( + "Expected node to be: SmallStatement::ImportFrom." + )); + }; + let aliases = &mut body.names; + + // Preserve the trailing comma (or not) from the last entry. + let trailing_comma = aliases.last().and_then(|alias| alias.comma.clone()); + + // Identify unused imports from within the `import from`. + let mut removable = vec![]; + for (index, alias) in aliases.iter().enumerate() { + if let N(import_name) = &alias.name { + if full_names.contains(&import_name.value) { + removable.push(index); + } + } + } + // TODO(charlie): This is quadratic. + for index in removable.iter().rev() { + aliases.remove(*index); + } + + if let Some(alias) = aliases.last_mut() { + alias.comma = trailing_comma; + } + + if aliases.is_empty() { + helpers::remove_stmt(stmt, parent, deleted) + } else { + let mut state = Default::default(); + tree.codegen(&mut state); + + Ok(Fix::replacement( + state.to_string(), + stmt.location, + stmt.end_location.unwrap(), + )) + } +} + +/// Generate a Fix to remove any unused imports from an `import from` statement. +pub fn remove_unused_import_froms( + locator: &mut SourceCodeLocator, + full_names: &[&str], + stmt: &Stmt, + parent: Option<&Stmt>, + deleted: &[&Stmt], +) -> anyhow::Result { + let mut tree = match libcst_native::parse_module( + locator.slice_source_code_range(&Range::from_located(stmt)), + None, + ) { + Ok(m) => m, + Err(_) => return Err(anyhow::anyhow!("Failed to extract CST from source.")), + }; + + let body = if let Some(Statement::Simple(body)) = tree.body.first_mut() { + body + } else { + return Err(anyhow::anyhow!("Expected node to be: Statement::Simple.")); + }; + let body = if let Some(SmallStatement::ImportFrom(body)) = body.body.first_mut() { + body + } else { + return Err(anyhow::anyhow!( + "Expected node to be: SmallStatement::ImportFrom." + )); + }; + let aliases = if let Aliases(aliases) = &mut body.names { + aliases + } else { + return Err(anyhow::anyhow!("Expected node to be: Aliases.")); + }; + + // Preserve the trailing comma (or not) from the last entry. + let trailing_comma = aliases.last().and_then(|alias| alias.comma.clone()); + + // Identify unused imports from within the `import from`. + let mut removable = vec![]; + for (index, alias) in aliases.iter().enumerate() { + if let N(name) = &alias.name { + let import_name = if let Some(N(module_name)) = &body.module { + format!("{}.{}", module_name.value, name.value) + } else { + name.value.to_string() + }; + if full_names.contains(&import_name.as_str()) { + removable.push(index); + } + } + } + // TODO(charlie): This is quadratic. + for index in removable.iter().rev() { + aliases.remove(*index); + } + + if let Some(alias) = aliases.last_mut() { + alias.comma = trailing_comma; + } + + if aliases.is_empty() { + helpers::remove_stmt(stmt, parent, deleted) + } else { + let mut state = Default::default(); + tree.codegen(&mut state); + + Ok(Fix::replacement( + state.to_string(), + stmt.location, + stmt.end_location.unwrap(), + )) + } +} diff --git a/src/pyflakes/mod.rs b/src/pyflakes/mod.rs index 6cab51a666c73..e96a4069b2a51 100644 --- a/src/pyflakes/mod.rs +++ b/src/pyflakes/mod.rs @@ -1,2 +1,3 @@ pub mod checks; +pub mod fixes; pub mod plugins; diff --git a/src/pyupgrade/fixes.rs b/src/pyupgrade/fixes.rs new file mode 100644 index 0000000000000..3c9922d210c07 --- /dev/null +++ b/src/pyupgrade/fixes.rs @@ -0,0 +1,133 @@ +use libcst_native::{Codegen, Expression, SmallStatement, Statement}; +use rustpython_ast::{Expr, Keyword, Location}; +use rustpython_parser::lexer; +use rustpython_parser::lexer::Tok; + +use crate::ast::helpers; +use crate::ast::operations::SourceCodeLocator; +use crate::ast::types::Range; +use crate::autofix::Fix; + +/// Generate a fix to remove a base from a ClassDef statement. +pub fn remove_class_def_base( + locator: &mut SourceCodeLocator, + stmt_at: &Location, + expr_at: Location, + bases: &[Expr], + keywords: &[Keyword], +) -> Option { + let content = locator.slice_source_code_at(stmt_at); + + // Case 1: `object` is the only base. + if bases.len() == 1 && keywords.is_empty() { + let mut fix_start = None; + let mut fix_end = None; + let mut count: usize = 0; + for (start, tok, end) in lexer::make_tokenizer(content).flatten() { + if matches!(tok, Tok::Lpar) { + if count == 0 { + fix_start = Some(helpers::to_absolute(&start, stmt_at)); + } + count += 1; + } + + if matches!(tok, Tok::Rpar) { + count -= 1; + if count == 0 { + fix_end = Some(helpers::to_absolute(&end, stmt_at)); + break; + } + } + } + + return match (fix_start, fix_end) { + (Some(start), Some(end)) => Some(Fix::replacement("".to_string(), start, end)), + _ => None, + }; + } + + if bases + .iter() + .map(|node| node.location) + .chain(keywords.iter().map(|node| node.location)) + .any(|location| location > expr_at) + { + // Case 2: `object` is _not_ the last node. + let mut fix_start: Option = None; + let mut fix_end: Option = None; + let mut seen_comma = false; + for (start, tok, end) in lexer::make_tokenizer(content).flatten() { + let start = helpers::to_absolute(&start, stmt_at); + if seen_comma { + if matches!(tok, Tok::Newline) { + fix_end = Some(end); + } else { + fix_end = Some(start); + } + break; + } + if start == expr_at { + fix_start = Some(start); + } + if fix_start.is_some() && matches!(tok, Tok::Comma) { + seen_comma = true; + } + } + + match (fix_start, fix_end) { + (Some(start), Some(end)) => Some(Fix::replacement("".to_string(), start, end)), + _ => None, + } + } else { + // Case 3: `object` is the last node, so we have to find the last token that isn't a comma. + let mut fix_start: Option = None; + let mut fix_end: Option = None; + for (start, tok, end) in lexer::make_tokenizer(content).flatten() { + let start = helpers::to_absolute(&start, stmt_at); + let end = helpers::to_absolute(&end, stmt_at); + if start == expr_at { + fix_end = Some(end); + break; + } + if matches!(tok, Tok::Comma) { + fix_start = Some(start); + } + } + + match (fix_start, fix_end) { + (Some(start), Some(end)) => Some(Fix::replacement("".to_string(), start, end)), + _ => None, + } + } +} + +pub fn remove_super_arguments(locator: &mut SourceCodeLocator, expr: &Expr) -> Option { + let range = Range::from_located(expr); + let contents = locator.slice_source_code_range(&range); + + let mut tree = match libcst_native::parse_module(contents, None) { + Ok(m) => m, + Err(_) => return None, + }; + + if let Some(Statement::Simple(body)) = tree.body.first_mut() { + if let Some(SmallStatement::Expr(body)) = body.body.first_mut() { + if let Expression::Call(body) = &mut body.value { + body.args = vec![]; + body.whitespace_before_args = Default::default(); + body.whitespace_after_func = Default::default(); + + let mut state = Default::default(); + tree.codegen(&mut state); + + return Some(Fix::replacement( + state.to_string(), + range.location, + range.end_location, + )); + } + } + } + + None +} diff --git a/src/pyupgrade/mod.rs b/src/pyupgrade/mod.rs index 320be4a3ab59d..1ed9ca7914bc8 100644 --- a/src/pyupgrade/mod.rs +++ b/src/pyupgrade/mod.rs @@ -1,3 +1,4 @@ mod checks; +pub mod fixes; pub mod plugins; pub mod types; diff --git a/src/pyupgrade/plugins/super_call_with_parameters.rs b/src/pyupgrade/plugins/super_call_with_parameters.rs index 29d584384c786..5f38e8a515c95 100644 --- a/src/pyupgrade/plugins/super_call_with_parameters.rs +++ b/src/pyupgrade/plugins/super_call_with_parameters.rs @@ -1,8 +1,9 @@ use rustpython_ast::{Expr, Stmt}; use crate::ast::helpers; -use crate::autofix::{fixer, fixes}; +use crate::autofix::fixer; use crate::check_ast::Checker; +use crate::pyupgrade; use crate::pyupgrade::checks; pub fn super_call_with_parameters(checker: &mut Checker, expr: &Expr, func: &Expr, args: &[Expr]) { @@ -17,7 +18,9 @@ pub fn super_call_with_parameters(checker: &mut Checker, expr: &Expr, func: &Exp .collect(); if let Some(mut check) = checks::super_args(scope, &parents, expr, func, args) { if matches!(checker.autofix, fixer::Mode::Generate | fixer::Mode::Apply) { - if let Some(fix) = fixes::remove_super_arguments(&mut checker.locator, expr) { + if let Some(fix) = + pyupgrade::fixes::remove_super_arguments(&mut checker.locator, expr) + { check.amend(fix); } } diff --git a/src/pyupgrade/plugins/useless_metaclass_type.rs b/src/pyupgrade/plugins/useless_metaclass_type.rs index 7354507e16d9a..88d992da40ace 100644 --- a/src/pyupgrade/plugins/useless_metaclass_type.rs +++ b/src/pyupgrade/plugins/useless_metaclass_type.rs @@ -2,7 +2,7 @@ use log::error; use rustpython_ast::{Expr, Stmt}; use crate::ast::types::{CheckLocator, Range}; -use crate::autofix::{fixer, fixes}; +use crate::autofix::{fixer, helpers}; use crate::check_ast::Checker; use crate::pyupgrade::checks; @@ -20,7 +20,7 @@ pub fn useless_metaclass_type(checker: &mut Checker, stmt: &Stmt, value: &Expr, .map(|index| checker.parents[*index]) .collect(); - match fixes::remove_stmt( + match helpers::remove_stmt( checker.parents[context.defined_by], context.defined_in.map(|index| checker.parents[index]), &deleted, diff --git a/src/pyupgrade/plugins/useless_object_inheritance.rs b/src/pyupgrade/plugins/useless_object_inheritance.rs index c8e928f2a0722..ed31fae3503a4 100644 --- a/src/pyupgrade/plugins/useless_object_inheritance.rs +++ b/src/pyupgrade/plugins/useless_object_inheritance.rs @@ -1,7 +1,8 @@ use rustpython_ast::{Expr, Keyword, Stmt}; -use crate::autofix::{fixer, fixes}; +use crate::autofix::fixer; use crate::check_ast::Checker; +use crate::pyupgrade; use crate::pyupgrade::checks; pub fn useless_object_inheritance( @@ -14,7 +15,7 @@ pub fn useless_object_inheritance( let scope = checker.current_scope(); if let Some(mut check) = checks::useless_object_inheritance(name, bases, scope) { if matches!(checker.autofix, fixer::Mode::Generate | fixer::Mode::Apply) { - if let Some(fix) = fixes::remove_class_def_base( + if let Some(fix) = pyupgrade::fixes::remove_class_def_base( &mut checker.locator, &stmt.location, check.location,