Skip to content

Commit d6802e0

Browse files
committed
Added Import Changes
1 parent 59c2ff1 commit d6802e0

File tree

2 files changed

+101
-24
lines changed

2 files changed

+101
-24
lines changed

src/inline/plugin.py

Lines changed: 61 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -296,11 +296,17 @@ class ExtractInlineTest(ast.NodeTransformer):
296296
arg_timeout_str = "timeout"
297297

298298
assume = "assume"
299+
inline_module_imported = False
300+
import_str = "import"
301+
from_str = "from"
302+
as_str = "as"
303+
299304
inline_module_imported = False
300305

301306
def __init__(self):
302307
self.cur_inline_test = InlineTest()
303308
self.inline_test_list = []
309+
self.import_list = []
304310

305311
def is_inline_test_class(self, node):
306312
if isinstance(node, ast.Call):
@@ -350,16 +356,39 @@ def find_previous_stmt(self, node):
350356
return prev_stmt_node
351357
return self.find_condition_stmt(prev_stmt_node)
352358

353-
def collect_inline_test_calls(self, node, inline_test_calls: List[ast.Call]):
359+
def collect_inline_test_calls(self, node, inline_test_calls: List[ast.Call], import_calls: List[ast.Import], import_from_calls: List[ast.ImportFrom]):
354360
"""
355361
collect all function calls in the node
356362
"""
357363
if isinstance(node, ast.Attribute):
358-
self.collect_inline_test_calls(node.value, inline_test_calls)
364+
self.collect_inline_test_calls(node.value, inline_test_calls, import_calls, import_from_calls)
359365
elif isinstance(node, ast.Call):
360366
inline_test_calls.append(node)
361-
self.collect_inline_test_calls(node.func, inline_test_calls)
367+
self.collect_inline_test_calls(node.func, inline_test_calls, import_calls, import_from_calls)
368+
elif isinstance(node, ast.Import):
369+
import_calls.append(node)
370+
self.collect_inline_test_calls(node.func, inline_test_calls, import_calls, import_from_calls)
371+
elif isinstance(node, ast.ImportFrom):
372+
import_from_calls.append(node)
373+
self.collect_inline_test_calls(node.func, inline_test_calls, import_calls, import_from_calls)
374+
375+
def collect_import_calls(self, node, import_calls: List[ast.Import], import_from_calls: List[ast.ImportFrom]):
376+
"""
377+
collect all import calls in the node (should be done first)
378+
"""
362379

380+
while not isinstance(node, ast.Module) and node.parent != None:
381+
node = node.parent
382+
383+
if not isinstance(node, ast.Module):
384+
return
385+
386+
for child in node.children:
387+
if isinstance(child, ast.Import):
388+
import_calls.append(child)
389+
elif isinstance(child, ast.ImportFrom):
390+
import_from_calls.append(child)
391+
363392
def parse_constructor(self, node):
364393
"""
365394
Parse a constructor call.
@@ -931,8 +960,13 @@ def parse_parameterized_test(self):
931960
parameterized_test.test_name = self.cur_inline_test.test_name + "_" + str(index)
932961

933962
def parse_inline_test(self, node):
934-
inline_test_calls = []
935-
self.collect_inline_test_calls(node, inline_test_calls)
963+
import_calls = []
964+
import_from_calls = []
965+
inline_test_calls = []
966+
967+
self.collect_inline_test_calls(node, inline_test_calls, import_calls, import_from_calls)
968+
self.collect_import_calls(node, import_calls, import_from_calls)
969+
936970
inline_test_calls.reverse()
937971

938972
if len(inline_test_calls) <= 1:
@@ -953,14 +987,32 @@ def parse_inline_test(self, node):
953987
self.parse_assume(call)
954988
inline_test_call_index += 1
955989

956-
# "given(a, 1)"
957990
for call in inline_test_calls[inline_test_call_index:]:
958-
if isinstance(call.func, ast.Attribute) and call.func.attr == self.given_str:
959-
self.parse_given(call)
960-
inline_test_call_index += 1
991+
if isinstance(call.func, ast.Attribute):
992+
if call.func.attr == self.given_str:
993+
self.parse_given(call)
994+
inline_test_call_index += 1
995+
elif call.func.attr == self.diff_given_str:
996+
self.parse_diff_given(call)
997+
inline_test_call_index += 1
998+
999+
# match call.func.attr:
1000+
# # "given(a, 1)"
1001+
# case self.given_str:
1002+
# self.parse_given(call)
1003+
# inline_test_call_index += 1
1004+
# # "diff_given(devices, ["cpu", "cuda"])"
1005+
# case self.diff_given_str:
1006+
# self.parse_diff_given(call)
1007+
# inline_test_call_index += 1
9611008
else:
9621009
break
9631010

1011+
for import_stmt in import_calls:
1012+
self.cur_inline_test.import_stmts.append(import_stmt)
1013+
for import_stmt in import_from_calls:
1014+
self.cur_inline_test.import_stmts.append(import_stmt)
1015+
9641016
# "check_eq" or "check_true" or "check_false" or "check_neq"
9651017
for call in inline_test_calls[inline_test_call_index:]:
9661018
# "check_eq(a, 1)"

tests/test_plugin.py

Lines changed: 40 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,48 @@
22
from _pytest.pytester import Pytester
33
import pytest
44

5+
# For testing in Spyder only
6+
if __name__ == "__main__":
7+
pytest.main(['-v', '-s'])
8+
59

610
# pytest -p pytester
711
class TestInlinetests:
12+
def test_inline_detects_imports(self, pytester: Pytester):
13+
checkfile = pytester.makepyfile(
14+
"""
15+
from inline import itest
16+
import datetime
17+
18+
def m(a):
19+
b = a + datetime.timedelta(days=365)
20+
itest().given(a, datetime.timedelta(days=1)).check_eq(b, datetime.timedelta(days=366))
21+
"""
22+
)
23+
for x in (pytester.path, checkfile):
24+
items, reprec = pytester.inline_genitems(x)
25+
assert len(items) == 1
26+
res = pytester.runpytest()
27+
assert res.ret != 1
28+
29+
# def test_inline_detects_from_imports(self, pytester: Pytester):
30+
# checkfile = pytester.makepyfile(
31+
# """
32+
# from inline import itest
33+
# import numpy as np
34+
# from scipy import stats as st
35+
36+
# def m(n, p):
37+
# b = st.binom(n, p)
38+
# itest().given(n, 100).given(p, 0.5).check_eq(b.mean(), n * p)
39+
# """
40+
# )
41+
# for x in (pytester.path, checkfile):
42+
# items, reprec = pytester.inline_genitems(x)
43+
# assert len(items) == 1
44+
# res = pytester.runpytest()
45+
# assert res.ret == 0
46+
847
def test_inline_parser(self, pytester: Pytester):
948
checkfile = pytester.makepyfile(
1049
"""
@@ -31,6 +70,7 @@ def m(a):
3170
items, reprec = pytester.inline_genitems(x)
3271
assert len(items) == 0
3372

73+
3474
def test_inline_malformed_given(self, pytester: Pytester):
3575
checkfile = pytester.makepyfile(
3676
"""
@@ -118,21 +158,6 @@ def m(a):
118158
res = pytester.runpytest()
119159
assert res.ret == 0
120160

121-
def test_check_eq_parameterized_tests(self, pytester: Pytester):
122-
checkfile = pytester.makepyfile(
123-
"""
124-
from inline import itest
125-
def m(a):
126-
a = a + 1
127-
itest(parameterized=True).given(a, [2, 3]).check_eq(a, [3, 4])
128-
"""
129-
)
130-
for x in (pytester.path, checkfile):
131-
items, reprec = pytester.inline_genitems(x)
132-
assert len(items) == 2
133-
res = pytester.runpytest()
134-
assert res.ret == 0
135-
136161
def test_malformed_check_eq_parameterized_tests(self, pytester: Pytester):
137162
checkfile = pytester.makepyfile(
138163
"""

0 commit comments

Comments
 (0)