Skip to content

Commit a6f68fc

Browse files
committed
Added Import Changes
1 parent 2e15c5c commit a6f68fc

File tree

2 files changed

+101
-24
lines changed

2 files changed

+101
-24
lines changed

src/inline/plugin.py

Lines changed: 61 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -297,11 +297,17 @@ class ExtractInlineTest(ast.NodeTransformer):
297297
arg_devices_str = "devices"
298298
diff_test_str = "diff_test"
299299
assume = "assume"
300+
inline_module_imported = False
301+
import_str = "import"
302+
from_str = "from"
303+
as_str = "as"
304+
300305
inline_module_imported = False
301306

302307
def __init__(self):
303308
self.cur_inline_test = InlineTest()
304309
self.inline_test_list = []
310+
self.import_list = []
305311

306312
def is_inline_test_class(self, node):
307313
if isinstance(node, ast.Call):
@@ -351,16 +357,39 @@ def find_previous_stmt(self, node):
351357
return prev_stmt_node
352358
return self.find_condition_stmt(prev_stmt_node)
353359

354-
def collect_inline_test_calls(self, node, inline_test_calls: List[ast.Call]):
360+
def collect_inline_test_calls(self, node, inline_test_calls: List[ast.Call], import_calls: List[ast.Import], import_from_calls: List[ast.ImportFrom]):
355361
"""
356362
collect all function calls in the node
357363
"""
358364
if isinstance(node, ast.Attribute):
359-
self.collect_inline_test_calls(node.value, inline_test_calls)
365+
self.collect_inline_test_calls(node.value, inline_test_calls, import_calls, import_from_calls)
360366
elif isinstance(node, ast.Call):
361367
inline_test_calls.append(node)
362-
self.collect_inline_test_calls(node.func, inline_test_calls)
368+
self.collect_inline_test_calls(node.func, inline_test_calls, import_calls, import_from_calls)
369+
elif isinstance(node, ast.Import):
370+
import_calls.append(node)
371+
self.collect_inline_test_calls(node.func, inline_test_calls, import_calls, import_from_calls)
372+
elif isinstance(node, ast.ImportFrom):
373+
import_from_calls.append(node)
374+
self.collect_inline_test_calls(node.func, inline_test_calls, import_calls, import_from_calls)
375+
376+
def collect_import_calls(self, node, import_calls: List[ast.Import], import_from_calls: List[ast.ImportFrom]):
377+
"""
378+
collect all import calls in the node (should be done first)
379+
"""
363380

381+
while not isinstance(node, ast.Module) and node.parent != None:
382+
node = node.parent
383+
384+
if not isinstance(node, ast.Module):
385+
return
386+
387+
for child in node.children:
388+
if isinstance(child, ast.Import):
389+
import_calls.append(child)
390+
elif isinstance(child, ast.ImportFrom):
391+
import_from_calls.append(child)
392+
364393
def parse_constructor(self, node):
365394
"""
366395
Parse a constructor call.
@@ -1163,8 +1192,13 @@ def parse_parameterized_test(self):
11631192
parameterized_test.test_name = self.cur_inline_test.test_name + "_" + str(index)
11641193

11651194
def parse_inline_test(self, node):
1166-
inline_test_calls = []
1167-
self.collect_inline_test_calls(node, inline_test_calls)
1195+
import_calls = []
1196+
import_from_calls = []
1197+
inline_test_calls = []
1198+
1199+
self.collect_inline_test_calls(node, inline_test_calls, import_calls, import_from_calls)
1200+
self.collect_import_calls(node, import_calls, import_from_calls)
1201+
11681202
inline_test_calls.reverse()
11691203

11701204
if len(inline_test_calls) <= 1:
@@ -1185,14 +1219,32 @@ def parse_inline_test(self, node):
11851219
self.parse_assume(call)
11861220
inline_test_call_index += 1
11871221

1188-
# "given(a, 1)"
11891222
for call in inline_test_calls[inline_test_call_index:]:
1190-
if isinstance(call.func, ast.Attribute) and call.func.attr == self.given_str:
1191-
self.parse_given(call)
1192-
inline_test_call_index += 1
1223+
if isinstance(call.func, ast.Attribute):
1224+
if call.func.attr == self.given_str:
1225+
self.parse_given(call)
1226+
inline_test_call_index += 1
1227+
elif call.func.attr == self.diff_given_str:
1228+
self.parse_diff_given(call)
1229+
inline_test_call_index += 1
1230+
1231+
# match call.func.attr:
1232+
# # "given(a, 1)"
1233+
# case self.given_str:
1234+
# self.parse_given(call)
1235+
# inline_test_call_index += 1
1236+
# # "diff_given(devices, ["cpu", "cuda"])"
1237+
# case self.diff_given_str:
1238+
# self.parse_diff_given(call)
1239+
# inline_test_call_index += 1
11931240
else:
11941241
break
11951242

1243+
for import_stmt in import_calls:
1244+
self.cur_inline_test.import_stmts.append(import_stmt)
1245+
for import_stmt in import_from_calls:
1246+
self.cur_inline_test.import_stmts.append(import_stmt)
1247+
11961248
# "check_eq" or "check_true" or "check_false" or "check_neq"
11971249
for call in inline_test_calls[inline_test_call_index:]:
11981250
# "check_eq(a, 1)"

tests/test_plugin.py

Lines changed: 40 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,48 @@
22
from _pytest.pytester import Pytester
33
import pytest
44

5+
# For testing in Spyder only
6+
if __name__ == "__main__":
7+
pytest.main(['-v', '-s'])
8+
59

610
# pytest -p pytester
711
class TestInlinetests:
12+
def test_inline_detects_imports(self, pytester: Pytester):
13+
checkfile = pytester.makepyfile(
14+
"""
15+
from inline import itest
16+
import datetime
17+
18+
def m(a):
19+
b = a + datetime.timedelta(days=365)
20+
itest().given(a, datetime.timedelta(days=1)).check_eq(b, datetime.timedelta(days=366))
21+
"""
22+
)
23+
for x in (pytester.path, checkfile):
24+
items, reprec = pytester.inline_genitems(x)
25+
assert len(items) == 1
26+
res = pytester.runpytest()
27+
assert res.ret != 1
28+
29+
# def test_inline_detects_from_imports(self, pytester: Pytester):
30+
# checkfile = pytester.makepyfile(
31+
# """
32+
# from inline import itest
33+
# import numpy as np
34+
# from scipy import stats as st
35+
36+
# def m(n, p):
37+
# b = st.binom(n, p)
38+
# itest().given(n, 100).given(p, 0.5).check_eq(b.mean(), n * p)
39+
# """
40+
# )
41+
# for x in (pytester.path, checkfile):
42+
# items, reprec = pytester.inline_genitems(x)
43+
# assert len(items) == 1
44+
# res = pytester.runpytest()
45+
# assert res.ret == 0
46+
847
def test_inline_parser(self, pytester: Pytester):
948
checkfile = pytester.makepyfile(
1049
"""
@@ -31,6 +70,7 @@ def m(a):
3170
items, reprec = pytester.inline_genitems(x)
3271
assert len(items) == 0
3372

73+
3474
def test_inline_malformed_given(self, pytester: Pytester):
3575
checkfile = pytester.makepyfile(
3676
"""
@@ -118,21 +158,6 @@ def m(a):
118158
res = pytester.runpytest()
119159
assert res.ret == 0
120160

121-
def test_check_eq_parameterized_tests(self, pytester: Pytester):
122-
checkfile = pytester.makepyfile(
123-
"""
124-
from inline import itest
125-
def m(a):
126-
a = a + 1
127-
itest(parameterized=True).given(a, [2, 3]).check_eq(a, [3, 4])
128-
"""
129-
)
130-
for x in (pytester.path, checkfile):
131-
items, reprec = pytester.inline_genitems(x)
132-
assert len(items) == 2
133-
res = pytester.runpytest()
134-
assert res.ret == 0
135-
136161
def test_malformed_check_eq_parameterized_tests(self, pytester: Pytester):
137162
checkfile = pytester.makepyfile(
138163
"""

0 commit comments

Comments
 (0)