Skip to content

Commit 76f5363

Browse files
authored
Merge pull request #149 from jakkdl/91x_autofix
implement autofix for TRIO100
2 parents 7c293f7 + 77c61dd commit 76f5363

File tree

13 files changed

+368
-32
lines changed

13 files changed

+368
-32
lines changed

flake8_trio/__init__.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,10 @@ def main():
101101
cwd=root,
102102
).stdout.splitlines()
103103
except (subprocess.SubprocessError, FileNotFoundError):
104-
print("Doesn't seem to be a git repo; pass filenames to format.")
104+
print(
105+
"Doesn't seem to be a git repo; pass filenames to format.",
106+
file=sys.stderr,
107+
)
105108
sys.exit(1)
106109
all_filenames = [
107110
os.path.join(root, f) for f in all_filenames if _should_format(f)
@@ -110,6 +113,9 @@ def main():
110113
plugin = Plugin.from_filename(file)
111114
for error in sorted(plugin.run()):
112115
print(f"{file}:{error}")
116+
if plugin.options.autofix:
117+
with open(file, "w") as file:
118+
file.write(plugin.module.code)
113119

114120

115121
class Plugin:
@@ -122,7 +128,7 @@ def __init__(self, tree: ast.AST, lines: Sequence[str]):
122128
self._tree = tree
123129
source = "".join(lines)
124130

125-
self._module: cst.Module = cst_parse_module_native(source)
131+
self.module: cst.Module = cst_parse_module_native(source)
126132

127133
@classmethod
128134
def from_filename(cls, filename: str | PathLike[str]) -> Plugin: # pragma: no cover
@@ -137,12 +143,14 @@ def from_source(cls, source: str) -> Plugin:
137143
plugin = Plugin.__new__(cls)
138144
super(Plugin, plugin).__init__()
139145
plugin._tree = ast.parse(source)
140-
plugin._module = cst_parse_module_native(source)
146+
plugin.module = cst_parse_module_native(source)
141147
return plugin
142148

143149
def run(self) -> Iterable[Error]:
144150
yield from Flake8TrioRunner.run(self._tree, self.options)
145-
yield from Flake8TrioRunner_cst(self.options).run(self._module)
151+
cst_runner = Flake8TrioRunner_cst(self.options, self.module)
152+
yield from cst_runner.run()
153+
self.module = cst_runner.module
146154

147155
@staticmethod
148156
def add_options(option_manager: OptionManager | ArgumentParser):
@@ -157,6 +165,7 @@ def add_options(option_manager: OptionManager | ArgumentParser):
157165
add_argument = functools.partial(
158166
option_manager.add_option, parse_from_config=True
159167
)
168+
add_argument("--autofix", action="store_true", required=False)
160169

161170
add_argument(
162171
"--no-checkpoint-warning-decorators",

flake8_trio/runner.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -100,20 +100,20 @@ def visit(self, node: ast.AST):
100100

101101

102102
class Flake8TrioRunner_cst:
103-
def __init__(self, options: Namespace):
103+
def __init__(self, options: Namespace, module: Module):
104104
super().__init__()
105105
self.state = SharedState(options)
106106
self.options = options
107107
self.visitors: tuple[Flake8TrioVisitor_cst, ...] = tuple(
108108
v(self.state) for v in ERROR_CLASSES_CST if self.selected(v.error_codes)
109109
)
110+
self.module = module
110111

111-
def run(self, module: Module) -> Iterable[Error]:
112+
def run(self) -> Iterable[Error]:
112113
if not self.visitors:
113114
return
114-
wrapper = cst.MetadataWrapper(module)
115115
for v in self.visitors:
116-
_ = wrapper.visit(v)
116+
self.module = cst.MetadataWrapper(self.module).visit(v)
117117
yield from self.state.problems
118118

119119
def selected(self, error_codes: dict[str, str]) -> bool:

flake8_trio/visitors/helpers.py

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
import ast
99
from fnmatch import fnmatch
10-
from typing import TYPE_CHECKING, NamedTuple, TypeVar
10+
from typing import TYPE_CHECKING, NamedTuple, TypeVar, cast
1111

1212
import libcst as cst
1313
import libcst.matchers as m
@@ -341,3 +341,62 @@ def func_has_decorator(func: cst.FunctionDef, *names: str) -> bool:
341341
),
342342
)
343343
)
344+
345+
346+
def get_comments(node: cst.CSTNode | Iterable[cst.CSTNode]) -> Iterator[cst.EmptyLine]:
347+
# pyright can't use hasattr to narrow the type, so need a bunch of casts
348+
if hasattr(node, "__iter__"):
349+
for n in cast("Iterable[cst.CSTNode]", node):
350+
yield from get_comments(n)
351+
return
352+
yield from (
353+
cst.EmptyLine(comment=ensure_type(c, cst.Comment))
354+
for c in m.findall(cast("cst.CSTNode", node), m.Comment())
355+
)
356+
return
357+
358+
359+
# used in TRIO100
360+
def flatten_preserving_comments(node: cst.BaseCompoundStatement):
361+
# add leading lines (comments and empty lines) for the node to be removed
362+
new_leading_lines = list(node.leading_lines)
363+
364+
# add other comments belonging to the node as empty lines with comments
365+
for attr in "lpar", "items", "rpar":
366+
# pragma, since this is currently only used to flatten `With` statements
367+
if comment_nodes := getattr(node, attr, None): # pragma: no cover
368+
new_leading_lines.extend(get_comments(comment_nodes))
369+
370+
# node.body is a BaseSuite, whose subclasses are SimpleStatementSuite
371+
# and IndentedBlock
372+
if isinstance(node.body, cst.SimpleStatementSuite):
373+
# `with ...: pass;pass;pass` -> pass;pass;pass
374+
return cst.SimpleStatementLine(node.body.body, leading_lines=new_leading_lines)
375+
376+
assert isinstance(node.body, cst.IndentedBlock)
377+
nodes = list(node.body.body)
378+
379+
# nodes[0] is a BaseStatement, whose subclasses are SimpleStatementLine
380+
# and BaseCompoundStatement - both of which has leading_lines
381+
assert isinstance(nodes[0], (cst.SimpleStatementLine, cst.BaseCompoundStatement))
382+
383+
# add body header comment - i.e. comments on the same/last line of the statement
384+
if node.body.header and node.body.header.comment:
385+
new_leading_lines.append(
386+
cst.EmptyLine(indent=True, comment=node.body.header.comment)
387+
)
388+
# add the leading lines of the first node
389+
new_leading_lines.extend(nodes[0].leading_lines)
390+
# update the first node with all the above constructed lines
391+
nodes[0] = nodes[0].with_changes(leading_lines=new_leading_lines)
392+
393+
# if there's comments in the footer of the indented block, add a pass
394+
# statement with the comments as leading lines
395+
if node.body.footer:
396+
nodes.append(
397+
cst.SimpleStatementLine(
398+
[cst.Pass()],
399+
node.body.footer,
400+
)
401+
)
402+
return cst.FlattenSentinel(nodes)

flake8_trio/visitors/visitor100.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,12 @@
1313
import libcst.matchers as m
1414

1515
from .flake8triovisitor import Flake8TrioVisitor_cst
16-
from .helpers import AttributeCall, error_class_cst, with_has_call
16+
from .helpers import (
17+
AttributeCall,
18+
error_class_cst,
19+
flatten_preserving_comments,
20+
with_has_call,
21+
)
1722

1823

1924
@error_class_cst
@@ -46,12 +51,16 @@ def visit_With(self, node: cst.With) -> None:
4651
else:
4752
self.has_checkpoint_stack.append(True)
4853

49-
def leave_With(self, original_node: cst.With, updated_node: cst.With) -> cst.With:
54+
def leave_With(
55+
self, original_node: cst.With, updated_node: cst.With
56+
) -> cst.BaseStatement | cst.FlattenSentinel[cst.BaseStatement]:
5057
if not self.has_checkpoint_stack.pop():
5158
for res in self.node_dict[original_node]:
5259
self.error(res.node, res.base, res.function)
53-
# if: autofixing is enabled for this code
54-
# then: remove the with and pop out it's body
60+
61+
if self.options.autofix and len(updated_node.items) == 1:
62+
return flatten_preserving_comments(updated_node)
63+
5564
return updated_node
5665

5766
def visit_For(self, node: cst.For):

flake8_trio/visitors/visitors.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def visit_With(self, node: ast.With | ast.AsyncWith):
9191
nursery = get_matching_call(item.context_expr, "open_nursery")
9292

9393
# `isinstance(..., ast.Call)` is done in get_matching_call
94-
body_call = cast(ast.Call, node.body[0].value)
94+
body_call = cast("ast.Call", node.body[0].value)
9595

9696
if (
9797
nursery is not None

tests/autofix_files/trio100.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
# type: ignore
2+
3+
import trio
4+
5+
# error: 5, "trio", "move_on_after"
6+
...
7+
8+
9+
async def function_name():
10+
# fmt: off
11+
...; ...; ...
12+
# fmt: on
13+
# error: 15, "trio", "fail_after"
14+
...
15+
# error: 15, "trio", "fail_at"
16+
...
17+
# error: 15, "trio", "move_on_after"
18+
...
19+
# error: 15, "trio", "move_on_at"
20+
...
21+
# error: 15, "trio", "CancelScope"
22+
...
23+
24+
with trio.move_on_after(10):
25+
await trio.sleep(1)
26+
27+
with trio.move_on_after(10):
28+
await trio.sleep(1)
29+
print("hello")
30+
31+
with trio.move_on_after(10):
32+
while True:
33+
await trio.sleep(1)
34+
print("hello")
35+
36+
with open("filename") as _:
37+
...
38+
39+
# error: 9, "trio", "fail_after"
40+
...
41+
42+
send_channel, receive_channel = trio.open_memory_channel(0)
43+
async with trio.fail_after(10):
44+
async with send_channel:
45+
...
46+
47+
async with trio.fail_after(10):
48+
async for _ in receive_channel:
49+
...
50+
51+
# error: 15, "trio", "fail_after"
52+
for _ in receive_channel:
53+
...
54+
55+
# fix missed alarm when function is defined inside the with scope
56+
# error: 9, "trio", "move_on_after"
57+
58+
async def foo():
59+
await trio.sleep(1)
60+
61+
# error: 9, "trio", "move_on_after"
62+
if ...:
63+
64+
async def foo():
65+
if ...:
66+
await trio.sleep(1)
67+
68+
async with random_ignored_library.fail_after(10):
69+
...
70+
71+
72+
async def function_name2():
73+
with (
74+
open("") as _,
75+
trio.fail_after(10), # error: 8, "trio", "fail_after"
76+
):
77+
...
78+
79+
with (
80+
trio.fail_after(5), # error: 8, "trio", "fail_after"
81+
open("") as _,
82+
trio.move_on_after(5), # error: 8, "trio", "move_on_after"
83+
):
84+
...
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
import trio
2+
3+
# a
4+
# b
5+
# error: 5, "trio", "move_on_after"
6+
# c
7+
# d
8+
print(1) # e
9+
# f
10+
# g
11+
print(2) # h
12+
# i
13+
# j
14+
print(3) # k
15+
# l
16+
# m
17+
pass
18+
# n
19+
20+
# error: 5, "trio", "move_on_after"
21+
...
22+
23+
24+
# a
25+
# b
26+
# fmt: off
27+
...;...;...
28+
# fmt: on
29+
# c
30+
# d
31+
32+
# Doesn't autofix With's with multiple withitems
33+
with (
34+
trio.move_on_after(10), # error: 4, "trio", "move_on_after"
35+
open("") as f,
36+
):
37+
...
38+
39+
40+
# multiline with, despite only being one statement
41+
# a
42+
# b
43+
# c
44+
# error: 4, "trio", "move_on_after"
45+
# d
46+
# e
47+
# f
48+
# g
49+
# h
50+
# this comment is kept
51+
...
52+
53+
# fmt: off
54+
# a
55+
# b
56+
# error: 4, "trio", "move_on_after"
57+
# c
58+
...; ...; ...
59+
# fmt: on

tests/conftest.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,12 @@ def pytest_addoption(parser: pytest.Parser):
99
parser.addoption(
1010
"--runfuzz", action="store_true", default=False, help="run fuzz tests"
1111
)
12+
parser.addoption(
13+
"--generate-autofix",
14+
action="store_true",
15+
default=False,
16+
help="generate autofix file content",
17+
)
1218
parser.addoption(
1319
"--enable-visitor-codes-regex",
1420
default=".*",
@@ -32,6 +38,11 @@ def pytest_collection_modifyitems(config: pytest.Config, items: list[pytest.Item
3238
item.add_marker(skip_fuzz)
3339

3440

41+
@pytest.fixture()
42+
def generate_autofix(request: pytest.FixtureRequest):
43+
return request.config.getoption("generate_autofix")
44+
45+
3546
@pytest.fixture()
3647
def enable_visitor_codes_regex(request: pytest.FixtureRequest):
3748
return request.config.getoption("--enable-visitor-codes-regex")

0 commit comments

Comments
 (0)