Skip to content

Commit f6f4481

Browse files
committed
add args
1 parent 112b998 commit f6f4481

28 files changed

+358
-172
lines changed

flake8_trio/__init__.py

Lines changed: 98 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
import functools
1616
import keyword
1717
import os
18-
import re
1918
import subprocess
2019
import sys
2120
import tokenize
@@ -24,8 +23,9 @@
2423

2524
import libcst as cst
2625

26+
from .base import Options
2727
from .runner import Flake8TrioRunner, Flake8TrioRunner_cst
28-
from .visitors import default_disabled_error_codes
28+
from .visitors import ERROR_CLASSES, ERROR_CLASSES_CST, default_disabled_error_codes
2929

3030
if TYPE_CHECKING:
3131
from collections.abc import Iterable, Sequence
@@ -76,12 +76,6 @@ def cst_parse_module_native(source: str) -> cst.Module:
7676

7777
def main() -> int:
7878
parser = ArgumentParser(prog="flake8_trio")
79-
parser.add_argument(
80-
nargs="*",
81-
metavar="file",
82-
dest="files",
83-
help="Files(s) to format, instead of autodetection.",
84-
)
8579
Plugin.add_options(parser)
8680
args = parser.parse_args()
8781
Plugin.parse_options(args)
@@ -115,7 +109,7 @@ def main() -> int:
115109
for error in sorted(plugin.run()):
116110
print(f"{file}:{error}")
117111
any_error = True
118-
if plugin.options.autofix:
112+
if plugin.options.autofix_codes:
119113
with open(file, "w") as file:
120114
file.write(plugin.module.code)
121115
return 1 if any_error else 0
@@ -124,7 +118,13 @@ def main() -> int:
124118
class Plugin:
125119
name = __name__
126120
version = __version__
127-
options: Namespace = Namespace()
121+
standalone = True
122+
_options: Options | None = None
123+
124+
@property
125+
def options(self) -> Options:
126+
assert self._options is not None
127+
return self._options
128128

129129
def __init__(self, tree: ast.AST, lines: Sequence[str]):
130130
super().__init__()
@@ -158,18 +158,64 @@ def run(self) -> Iterable[Error]:
158158
@staticmethod
159159
def add_options(option_manager: OptionManager | ArgumentParser):
160160
if isinstance(option_manager, ArgumentParser):
161-
# TODO: disable TRIO9xx calls by default
162-
# if run as standalone
161+
Plugin.standalone = True
163162
add_argument = option_manager.add_argument
163+
add_argument(
164+
nargs="*",
165+
metavar="file",
166+
dest="files",
167+
help="Files(s) to format, instead of autodetection.",
168+
)
164169
else: # if run as a flake8 plugin
170+
Plugin.standalone = False
165171
# Disable TRIO9xx calls by default
166172
option_manager.extend_default_ignore(default_disabled_error_codes)
167173
# add parameter to parse from flake8 config
168174
add_argument = functools.partial( # type: ignore
169175
option_manager.add_option, parse_from_config=True
170176
)
171-
add_argument("--autofix", action="store_true", required=False)
172177

178+
add_argument(
179+
"--enable",
180+
type=comma_separated_list,
181+
default="TRIO",
182+
required=False,
183+
help=(
184+
"Comma-separated list of error codes to enable, similar to flake8"
185+
" --select but is additionally more performant as it will disable"
186+
" non-enabled visitors from running instead of just silencing their"
187+
" errors."
188+
),
189+
)
190+
add_argument(
191+
"--disable",
192+
type=comma_separated_list,
193+
default="TRIO9" if Plugin.standalone else "",
194+
required=False,
195+
help=(
196+
"Comma-separated list of error codes to disable, similar to flake8"
197+
" --ignore but is additionally more performant as it will disable"
198+
" non-enabled visitors from running instead of just silencing their"
199+
" errors."
200+
),
201+
)
202+
add_argument(
203+
"--autofix",
204+
type=comma_separated_list,
205+
default="",
206+
required=False,
207+
help=(
208+
"Comma-separated list of error-codes to enable autofixing for"
209+
"if implemented. Requires running as a standalone program."
210+
),
211+
)
212+
add_argument(
213+
"--error-on-autofix",
214+
action="store_true",
215+
required=False,
216+
default=False,
217+
help="Whether to also print an error message for autofixed errors",
218+
)
173219
add_argument(
174220
"--no-checkpoint-warning-decorators",
175221
default="asynccontextmanager",
@@ -208,19 +254,6 @@ def add_options(option_manager: OptionManager | ArgumentParser):
208254
"suggesting it be replaced with {value}"
209255
),
210256
)
211-
add_argument(
212-
"--enable-visitor-codes-regex",
213-
type=re.compile, # type: ignore[arg-type]
214-
default=".*",
215-
required=False,
216-
help=(
217-
"Regex string of visitors to enable. Can be used to disable broken "
218-
"visitors, or instead of --select/--disable to select error codes "
219-
"in a way that is more performant. If a visitor raises multiple codes "
220-
"it will not be disabled unless all codes are disabled, but it will "
221-
"not report codes matching this regex."
222-
),
223-
)
224257
add_argument(
225258
"--anyio",
226259
# action=store_true + parse_from_config does seem to work here, despite
@@ -237,7 +270,45 @@ def add_options(option_manager: OptionManager | ArgumentParser):
237270

238271
@staticmethod
239272
def parse_options(options: Namespace):
240-
Plugin.options = options
273+
def get_matching_codes(
274+
patterns: Iterable[str], codes: Iterable[str]
275+
) -> Iterable[str]:
276+
for pattern in patterns:
277+
for code in codes:
278+
if code.lower().startswith(pattern.lower()):
279+
yield code
280+
281+
all_codes: set[str] = {
282+
err_code
283+
for err_class in (*ERROR_CLASSES, *ERROR_CLASSES_CST)
284+
for err_code in err_class.error_codes.keys() # type: ignore[attr-defined]
285+
if len(err_code) == 7 # exclude e.g. TRIO103_anyio_trio
286+
}
287+
288+
if options.autofix and not Plugin.standalone:
289+
print("Cannot autofix when run as a flake8 plugin.", file=sys.stderr)
290+
sys.exit(1)
291+
autofix_codes = set(get_matching_codes(options.autofix, all_codes))
292+
293+
# enable codes
294+
enabled_codes = set(get_matching_codes(options.enable, all_codes))
295+
296+
# disable codes
297+
enabled_codes -= set(get_matching_codes(options.disable, enabled_codes))
298+
299+
# if disable has default value, re-enable explicitly enabled codes
300+
if options.disable == ["TRIO9"]:
301+
enabled_codes.update(code for code in options.enable if len(code) == 7)
302+
303+
Plugin._options = Options(
304+
enabled_codes=enabled_codes,
305+
autofix_codes=autofix_codes,
306+
error_on_autofix=options.error_on_autofix,
307+
no_checkpoint_warning_decorators=options.no_checkpoint_warning_decorators,
308+
startable_in_context_manager=options.startable_in_context_manager,
309+
trio200_blocking_calls=options.trio200_blocking_calls,
310+
anyio=options.anyio,
311+
)
241312

242313

243314
def comma_separated_list(raw_value: str) -> list[str]:

flake8_trio/base.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,25 @@
22

33
from __future__ import annotations
44

5-
from typing import Any, NamedTuple
5+
from dataclasses import dataclass
6+
from typing import TYPE_CHECKING, Any, NamedTuple
7+
8+
if TYPE_CHECKING:
9+
from collections.abc import Collection
10+
11+
12+
@dataclass
13+
class Options:
14+
# error codes to give errors for
15+
enabled_codes: set[str]
16+
# error codes to autofix
17+
autofix_codes: set[str]
18+
# whether to print an error message even when autofixed
19+
error_on_autofix: bool
20+
no_checkpoint_warning_decorators: Collection[str]
21+
startable_in_context_manager: Collection[str]
22+
trio200_blocking_calls: dict[str, str]
23+
anyio: bool
624

725

826
class Statement(NamedTuple):

flake8_trio/runner.py

Lines changed: 20 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from __future__ import annotations
88

99
import ast
10-
import re
1110
from dataclasses import dataclass, field
1211
from typing import TYPE_CHECKING
1312

@@ -21,44 +20,49 @@
2120
)
2221

2322
if TYPE_CHECKING:
24-
from argparse import Namespace
2523
from collections.abc import Iterable
2624

2725
from libcst import Module
2826

29-
from .base import Error
27+
from .base import Error, Options
3028
from .visitors.flake8triovisitor import Flake8TrioVisitor, Flake8TrioVisitor_cst
3129

3230

3331
@dataclass
3432
class SharedState:
35-
options: Namespace
33+
options: Options
3634
problems: list[Error] = field(default_factory=list)
3735
library: tuple[str, ...] = ()
3836
typed_calls: dict[str, str] = field(default_factory=dict)
3937
variables: dict[str, str] = field(default_factory=dict)
4038

4139

42-
class Flake8TrioRunner(ast.NodeVisitor):
43-
def __init__(self, options: Namespace):
40+
class __CommonRunner:
41+
"""Common functionality used in both runners."""
42+
43+
def __init__(self, options: Options):
4444
super().__init__()
4545
self.state = SharedState(options)
4646

47+
def selected(self, error_codes: dict[str, str]) -> bool:
48+
enabled_or_autofix = (
49+
self.state.options.enabled_codes | self.state.options.autofix_codes
50+
)
51+
return bool(set(error_codes) & enabled_or_autofix)
52+
53+
54+
class Flake8TrioRunner(ast.NodeVisitor, __CommonRunner):
55+
def __init__(self, options: Options):
56+
super().__init__(options)
4757
# utility visitors that need to run before the error-checking visitors
4858
self.utility_visitors = {v(self.state) for v in utility_visitors}
4959

5060
self.visitors = {
5161
v(self.state) for v in ERROR_CLASSES if self.selected(v.error_codes)
5262
}
5363

54-
def selected(self, error_codes: dict[str, str]) -> bool:
55-
return any(
56-
re.match(self.state.options.enable_visitor_codes_regex, code)
57-
for code in error_codes
58-
)
59-
6064
@classmethod
61-
def run(cls, tree: ast.AST, options: Namespace) -> Iterable[Error]:
65+
def run(cls, tree: ast.AST, options: Options) -> Iterable[Error]:
6266
runner = cls(options)
6367
runner.visit(tree)
6468
yield from runner.state.problems
@@ -104,10 +108,9 @@ def visit(self, node: ast.AST):
104108
subclass.set_state(subclass.outer.pop(node, {}))
105109

106110

107-
class Flake8TrioRunner_cst:
108-
def __init__(self, options: Namespace, module: Module):
109-
super().__init__()
110-
self.state = SharedState(options)
111+
class Flake8TrioRunner_cst(__CommonRunner):
112+
def __init__(self, options: Options, module: Module):
113+
super().__init__(options)
111114
self.options = options
112115

113116
# Could possibly enable/disable utility visitors here, if visitors declared
@@ -127,9 +130,3 @@ def run(self) -> Iterable[Error]:
127130
for v in (*self.utility_visitors, *self.visitors):
128131
self.module = cst.MetadataWrapper(self.module).visit(v)
129132
yield from self.state.problems
130-
131-
def selected(self, error_codes: dict[str, str]) -> bool:
132-
return any(
133-
re.match(self.state.options.enable_visitor_codes_regex, code)
134-
for code in error_codes
135-
)

flake8_trio/visitors/flake8triovisitor.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from __future__ import annotations
44

55
import ast
6-
import re
76
from abc import ABC
87
from typing import TYPE_CHECKING, Any, Union
98

@@ -99,7 +98,7 @@ def error(
9998
), "No error code defined, but class has multiple codes"
10099
error_code = next(iter(self.error_codes))
101100
# don't emit an error if this code is disabled in a multi-code visitor
102-
elif not re.match(self.options.enable_visitor_codes_regex, error_code):
101+
elif error_code[:7] not in self.options.enabled_codes:
103102
return
104103

105104
self.__state.problems.append(
@@ -190,9 +189,7 @@ def set_state(self, attrs: dict[str, Any], copy: bool = False):
190189
def save_state(self, node: cst.CSTNode, *attrs: str, copy: bool = False):
191190
state = self.get_state(*attrs, copy=copy)
192191
if node in self.outer:
193-
# not currently used, and not gonna bother adding dedicated test
194-
# visitors atm
195-
self.outer[node].update(state) # pragma: no cover
192+
self.outer[node].update(state)
196193
else:
197194
self.outer[node] = state
198195

@@ -211,10 +208,9 @@ def error(
211208
), "No error code defined, but class has multiple codes"
212209
error_code = next(iter(self.error_codes))
213210
# don't emit an error if this code is disabled in a multi-code visitor
214-
elif not re.match(
215-
self.options.enable_visitor_codes_regex, error_code
216-
): # pragma: no cover
217-
return
211+
# TODO: write test for only one of 910/911 enabled/autofixed
212+
elif error_code[:7] not in self.options.enabled_codes:
213+
return # pragma: no cover
218214
pos = self.get_metadata(PositionProvider, node).start
219215

220216
self.__state.problems.append(
@@ -228,6 +224,12 @@ def error(
228224
)
229225
)
230226

227+
def should_autofix(self, code: str | None = None):
228+
if code is None:
229+
assert len(self.error_codes) == 1
230+
code = next(iter(self.error_codes))
231+
return code in self.options.autofix_codes
232+
231233
@property
232234
def library(self) -> tuple[str, ...]:
233235
return self.__state.library if self.__state.library else ("trio",)

0 commit comments

Comments
 (0)