Skip to content

Commit 665b0cd

Browse files
authored
Merge pull request #154 from jakkdl/91x_comprehensions
2 parents b74a671 + a928956 commit 665b0cd

File tree

4 files changed

+171
-3
lines changed

4 files changed

+171
-3
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
# Changelog
22
*[CalVer, YY.month.patch](https://calver.org/)*
33

4+
## Future
5+
- TRIO91X now supports comprehensions
6+
47
## 23.2.5
58
- Fix false alarms for `@pytest.fixture`-decorated functions in TRIO101, TRIO910 and TRIO911
69

flake8_trio/visitors/visitor91x.py

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ def __init__(self, *args: Any, **kwargs: Any):
9797
self.safe_decorator = False
9898
self.async_function = False
9999
self.uncheckpointed_statements: set[Statement] = set()
100+
self.comp_unknown = False
100101

101102
self.loop_state = LoopState()
102103
self.try_state = TryState()
@@ -522,3 +523,70 @@ def leave_BooleanOperation_right(self, node: cst.BooleanOperation):
522523
self.uncheckpointed_statements.update(
523524
self.outer[node]["uncheckpointed_statements"]
524525
)
526+
527+
# comprehensions are simpler than loops, since they cannot contain yields
528+
# or many other complicated statements, but their subfields are not in the order
529+
# they're logically executed, so we manually visit each field in execution order,
530+
# as long as the effect of the statement is not known. Once we know the comprehension
531+
# will checkpoint, we stop visiting, or once we are no longer guaranteed to execute
532+
# code deeper in the comprehension.
533+
# Functions return `False` so libcst doesn't iterate subnodes [again].
534+
def visit_ListComp(self, node: cst.DictComp | cst.SetComp | cst.ListComp):
535+
if not self.async_function or not self.uncheckpointed_statements:
536+
return False
537+
538+
outer = self.comp_unknown
539+
self.comp_unknown = True
540+
541+
# visit `for` and `if`s
542+
node.for_in.visit(self)
543+
544+
# if still unknown, visit the expression
545+
if self.comp_unknown and self.uncheckpointed_statements:
546+
if isinstance(node, cst.DictComp):
547+
node.key.visit(self)
548+
node.value.visit(self)
549+
else:
550+
node.elt.visit(self)
551+
552+
self.comp_unknown = outer
553+
return False
554+
555+
visit_SetComp = visit_ListComp
556+
visit_DictComp = visit_ListComp
557+
558+
def visit_CompFor(self, node: cst.CompFor):
559+
# should only ever be visited manually, when inside an async function
560+
assert self.async_function
561+
562+
if not self.uncheckpointed_statements:
563+
return False
564+
565+
# if async comprehension, checkpoint
566+
if node.asynchronous:
567+
self.uncheckpointed_statements = set()
568+
self.comp_unknown = False
569+
return False
570+
571+
# visit the iter call, which might have await's
572+
node.iter.visit(self)
573+
574+
# stop checking if the loop is not guaranteed to execute
575+
if not iter_guaranteed_once_cst(node.iter):
576+
self.comp_unknown = False
577+
578+
# only the first `if` is guaranteed to execute
579+
# and if there's any, don't check inner loop
580+
elif node.ifs:
581+
self.comp_unknown = False
582+
node.ifs[0].visit(self)
583+
elif node.inner_for_in:
584+
# visit nested loops (and ifs), if any
585+
node.inner_for_in.visit(self)
586+
587+
return False
588+
589+
# We don't have any logic on if generators are guaranteed to unroll, so always
590+
# ignore their content
591+
def visit_GeneratorExp(self, node: cst.GeneratorExp):
592+
return False

tests/eval_files/trio910.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -543,10 +543,9 @@ async def foo_comprehension_1():
543543

544544

545545
# should error
546-
async def foo_comprehension_2():
546+
async def foo_comprehension_2(): # error: 0, "exit", Statement("function definition", lineno)
547547
[await foo() for x in range(10) if bar()]
548548

549549

550-
# should not error, see https://github.com/Zac-HD/flake8-trio/issues/144
551-
async def foo_comprehension_3(): # error: 0, "exit", Statement("function definition", lineno)
550+
async def foo_comprehension_3():
552551
[... async for x in bar()]

tests/eval_files/trio911.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@ async def foo() -> Any:
1212
await foo()
1313

1414

15+
def bar(*args) -> Any:
16+
...
17+
18+
1519
async def foo_yield_1():
1620
await foo()
1721
yield 5
@@ -886,3 +890,97 @@ async def foo_test():
886890
@pytest.fixture()
887891
async def foo_test2():
888892
yield
893+
894+
895+
async def comprehensions():
896+
# guaranteed iteration with await in test
897+
[... for x in range(10) if await foo()]
898+
yield # safe
899+
900+
# guaranteed iteration and await in value, but test is not guaranteed
901+
[await foo() for x in range(10) if bar()]
902+
yield # error: 4, "yield", Stmt("yield", line-4)
903+
904+
# guaranteed iteration and await in value
905+
[await foo() for x in range(10)]
906+
yield # safe
907+
908+
# not guaranteed to iter
909+
[await foo() for x in bar()]
910+
yield # error: 4, "yield", Stmt("yield", line-4)
911+
912+
# await statement in loop expression
913+
[... for x in bar(await foo())]
914+
yield
915+
916+
# set comprehensions use same logic as list
917+
{await foo() for x in range(10)}
918+
yield # safe
919+
920+
{await foo() for x in bar()}
921+
yield # error: 4, "yield", Stmt("yield", line-3)
922+
923+
# dict comprehensions use same logic as list
924+
{await foo(): 5 for x in bar()}
925+
yield # error: 4, "yield", Stmt("yield", line-4)
926+
927+
# other than `await` can be in both key&val
928+
{await foo(): 5 for x in range(10)}
929+
yield
930+
931+
{5: await foo() for x in range(10)}
932+
yield
933+
934+
# generator expressions are never treated as safe
935+
(await foo() for x in range(10))
936+
yield # error: 4, "yield", Stmt("yield", line-4)
937+
938+
(await foo() for x in bar() if await foo())
939+
yield # error: 4, "yield", Stmt("yield", line-3)
940+
941+
# async for always safe
942+
[... async for x in bar()]
943+
yield # safe
944+
{... async for x in bar()}
945+
yield # safe
946+
{...: ... async for x in bar()}
947+
yield # safe
948+
949+
# other than in generator expression
950+
(... async for x in bar())
951+
yield # error: 4, "yield", Stmt("yield", line-4)
952+
953+
# multiple loops
954+
[... for x in range(10) for y in range(10) if await foo()]
955+
yield
956+
[... for x in range(10) for y in bar() if await foo()]
957+
yield # error: 4, "yield", Stmt("yield", line-2)
958+
[... for x in bar() for y in range(10) if await foo()]
959+
yield # error: 4, "yield", Stmt("yield", line-2)
960+
961+
[await foo() for x in range(10) for y in range(10)]
962+
yield
963+
[await foo() for x in range(10) for y in bar()]
964+
yield # error: 4, "yield", Stmt("yield", line-2)
965+
[await foo() for x in bar() for y in range(10)]
966+
yield # error: 4, "yield", Stmt("yield", line-2)
967+
968+
# trip loops!
969+
[... for x in range(10) for y in range(10) async for z in bar()]
970+
yield
971+
[... for x in range(10) for y in range(10) for z in range(10)]
972+
yield # error: 4, "yield", Stmt("yield", line-2)
973+
974+
# multiple ifs
975+
[... for x in range(10) for y in range(10) if await foo() if await foo()]
976+
yield
977+
978+
[... for x in range(10) for y in bar() if await foo() if await foo()]
979+
yield # error: 4, "yield", Stmt("yield", line-3)
980+
981+
# nested comprehensions
982+
[[await foo() for x in range(10)] for y in range(10)]
983+
yield
984+
985+
# code coverage: inner comprehension with no checkpointed statements
986+
[... for x in [await foo()] for y in x]

0 commit comments

Comments
 (0)