Skip to content

Commit 29f5b55

Browse files
committed
Readded Additional Functionality for Imports
1 parent a6f68fc commit 29f5b55

File tree

1 file changed

+68
-35
lines changed

1 file changed

+68
-35
lines changed

src/inline/plugin.py

Lines changed: 68 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,7 @@ def __init__(self):
159159
self.check_stmts = []
160160
self.given_stmts = []
161161
self.previous_stmts = []
162+
self.import_stmts = []
162163
self.prev_stmt_type = PrevStmtType.StmtExpr
163164
# the line number of test statement
164165
self.lineno = 0
@@ -174,11 +175,23 @@ def __init__(self):
174175
self.devices = None
175176
self.globs = {}
176177

178+
def write_imports(self):
179+
import_str = ""
180+
for n in self.import_stmts:
181+
import_str += ExtractInlineTest.node_to_source_code(n) + "\n"
182+
return import_str
183+
177184
def to_test(self):
185+
prefix = "\n"
186+
187+
# for n in self.import_stmts:
188+
# import_str += ExtractInlineTest.node_to_source_code(n) + "\n"
189+
190+
178191
if self.prev_stmt_type == PrevStmtType.CondExpr:
179192
if self.assume_stmts == []:
180-
return "\n".join(
181-
[ExtractInlineTest.node_to_source_code(n) for n in self.given_stmts]
193+
return prefix.join(
194+
+ [ExtractInlineTest.node_to_source_code(n) for n in self.given_stmts]
182195
+ [ExtractInlineTest.node_to_source_code(n) for n in self.check_stmts]
183196
)
184197
else:
@@ -187,11 +200,11 @@ def to_test(self):
187200
)
188201
assume_statement = self.assume_stmts[0]
189202
assume_node = self.build_assume_node(assume_statement, body_nodes)
190-
return "\n".join(ExtractInlineTest.node_to_source_code(assume_node))
203+
return prefix.join(ExtractInlineTest.node_to_source_code(assume_node))
191204

192205
else:
193206
if self.assume_stmts is None or self.assume_stmts == []:
194-
return "\n".join(
207+
return prefix.join(
195208
[ExtractInlineTest.node_to_source_code(n) for n in self.given_stmts]
196209
+ [ExtractInlineTest.node_to_source_code(n) for n in self.previous_stmts]
197210
+ [ExtractInlineTest.node_to_source_code(n) for n in self.check_stmts]
@@ -202,7 +215,7 @@ def to_test(self):
202215
)
203216
assume_statement = self.assume_stmts[0]
204217
assume_node = self.build_assume_node(assume_statement, body_nodes)
205-
return "\n".join([ExtractInlineTest.node_to_source_code(assume_node)])
218+
return prefix.join([ExtractInlineTest.node_to_source_code(assume_node)])
206219

207220
def build_assume_node(self, assumption_node, body_nodes):
208221
return ast.If(assumption_node, body_nodes, [])
@@ -252,7 +265,7 @@ class TimeoutException(Exception):
252265
## InlineTest Parser
253266
######################################################################
254267
class InlinetestParser:
255-
def parse(self, obj, globs: None):
268+
def parse(self, obj, globs: None):
256269
# obj = open(self.file_path, "r").read():
257270
if isinstance(obj, ModuleType):
258271
tree = ast.parse(open(obj.__file__, "r").read())
@@ -297,7 +310,7 @@ class ExtractInlineTest(ast.NodeTransformer):
297310
arg_devices_str = "devices"
298311
diff_test_str = "diff_test"
299312
assume = "assume"
300-
inline_module_imported = False
313+
301314
import_str = "import"
302315
from_str = "from"
303316
as_str = "as"
@@ -389,7 +402,7 @@ def collect_import_calls(self, node, import_calls: List[ast.Import], import_from
389402
import_calls.append(child)
390403
elif isinstance(child, ast.ImportFrom):
391404
import_from_calls.append(child)
392-
405+
393406
def parse_constructor(self, node):
394407
"""
395408
Parse a constructor call.
@@ -558,7 +571,6 @@ class ConstrArgs(enum.Enum):
558571
getattr(arg, value_prop_name))
559572

560573

561-
## Match implementation of above conditional tree; commented since Python < 3.10 does not support match
562574

563575
# match arg_idx:
564576
# case ConstrArgs.REPEATED:
@@ -581,7 +593,9 @@ class ConstrArgs(enum.Enum):
581593
raise MalformedException(
582594
f"inline test: {self.class_name_str}() accepts {NUM_OF_ARGUMENTS} arguments. 'test_name' must be a string constant, 'parameterized' must be a boolean constant, 'repeated' must be a positive integer, 'tag' must be a list of string, 'timeout' must be a positive float"
583595
)
584-
596+
#raise MalformedException("Argument " + str(index) + " incorrectly formatted. Argument should be a " + ConstrArgs.expected_ast_val_args[index].type())
597+
598+
585599
def parameterized_inline_tests_init(self, node: ast.List):
586600
if not self.cur_inline_test.parameterized_inline_tests:
587601
self.cur_inline_test.parameterized_inline_tests = [InlineTest() for _ in range(len(node.elts))]
@@ -1141,8 +1155,17 @@ def visit_Name(self, node):
11411155
self.cur_inline_test.previous_stmts = new_statements
11421156
self.cur_inline_test.check_stmts = comparisons
11431157

1144-
1158+
def parse_import(self, node):
1159+
# TODO: Differentiate between import, from import, and import alias
1160+
import_node = ast.Import(
1161+
names=[
1162+
ast.alias(name=node)
1163+
]
1164+
)
1165+
return import_node
11451166

1167+
def parse_import_from(self, node):
1168+
pass
11461169

11471170
def build_fail(self):
11481171
equal_node = ast.Compare(
@@ -1183,6 +1206,7 @@ def parse_group(self, node):
11831206
return stmt
11841207
else:
11851208
return node
1209+
11861210

11871211
def parse_parameterized_test(self):
11881212
for index, parameterized_test in enumerate(self.cur_inline_test.parameterized_inline_tests):
@@ -1219,24 +1243,11 @@ def parse_inline_test(self, node):
12191243
self.parse_assume(call)
12201244
inline_test_call_index += 1
12211245

1246+
# "given(a, 1)"
12221247
for call in inline_test_calls[inline_test_call_index:]:
1223-
if isinstance(call.func, ast.Attribute):
1224-
if call.func.attr == self.given_str:
1225-
self.parse_given(call)
1226-
inline_test_call_index += 1
1227-
elif call.func.attr == self.diff_given_str:
1228-
self.parse_diff_given(call)
1229-
inline_test_call_index += 1
1230-
1231-
# match call.func.attr:
1232-
# # "given(a, 1)"
1233-
# case self.given_str:
1234-
# self.parse_given(call)
1235-
# inline_test_call_index += 1
1236-
# # "diff_given(devices, ["cpu", "cuda"])"
1237-
# case self.diff_given_str:
1238-
# self.parse_diff_given(call)
1239-
# inline_test_call_index += 1
1248+
if isinstance(call.func, ast.Attribute) and call.func.attr == self.given_str:
1249+
self.parse_given(call)
1250+
inline_test_call_index += 1
12401251
else:
12411252
break
12421253

@@ -1245,6 +1256,7 @@ def parse_inline_test(self, node):
12451256
for import_stmt in import_from_calls:
12461257
self.cur_inline_test.import_stmts.append(import_stmt)
12471258

1259+
12481260
# "check_eq" or "check_true" or "check_false" or "check_neq"
12491261
for call in inline_test_calls[inline_test_call_index:]:
12501262
# "check_eq(a, 1)"
@@ -1307,6 +1319,7 @@ def node_to_source_code(node):
13071319
## InlineTest Finder
13081320
######################################################################
13091321
class InlineTestFinder:
1322+
# Finder should NOT store any global variables
13101323
def __init__(self, parser=InlinetestParser(), recurse=True, exclude_empty=True):
13111324
self._parser = parser
13121325
self._recurse = recurse
@@ -1351,7 +1364,14 @@ def _is_routine(self, obj):
13511364
pass
13521365
return inspect.isroutine(maybe_routine)
13531366

1354-
def find(self, obj, module=None, globs=None, extraglobs=None):
1367+
# def find_imports(self, obj, module=None):
1368+
# if module is False:
1369+
# module = None
1370+
# elif module is None:
1371+
# module = inspect.getmodule(obj)
1372+
1373+
1374+
def find(self, obj, module=None, globs=None, extraglobs=None, imports=None):
13551375
# Find the module that contains the given object (if obj is
13561376
# a module, then module=obj.).
13571377
if module is False:
@@ -1372,15 +1392,23 @@ def find(self, obj, module=None, globs=None, extraglobs=None):
13721392
if "__name__" not in globs:
13731393
globs["__name__"] = "__main__" # provide a default module name
13741394

1395+
# Find intersection between loaded modules and module imports
1396+
# if imports is None:
1397+
# imports = set(sys.modules) & set(globs)
1398+
# else:
1399+
# imports = imports.copy()
1400+
13751401
# Recursively explore `obj`, extracting InlineTests.
13761402
tests = []
1377-
self._find(tests, obj, module, globs, {})
1403+
self._find(tests, obj, module, globs, imports, {})
13781404
return tests
13791405

1380-
def _find(self, tests, obj, module, globs, seen):
1406+
def _find(self, tests, obj, module, globs, imports, seen):
13811407
if id(obj) in seen:
13821408
return
13831409
seen[id(obj)] = 1
1410+
1411+
13841412
# Find a test for this object, and add it to the list of tests.
13851413
test = self._parser.parse(obj, globs)
13861414
if test is not None:
@@ -1392,7 +1420,7 @@ def _find(self, tests, obj, module, globs, seen):
13921420

13931421
# Recurse to functions & classes.
13941422
if (self._is_routine(val) or inspect.isclass(val)) and self._from_module(module, val):
1395-
self._find(tests, val, module, globs, seen)
1423+
self._find(tests, val, module, globs, imports, seen)
13961424

13971425
# Look for tests in a class's contained objects.
13981426
if inspect.isclass(obj) and self._recurse:
@@ -1406,17 +1434,18 @@ def _find(self, tests, obj, module, globs, seen):
14061434
module, val
14071435
):
14081436
valname = "%s" % (valname)
1409-
self._find(tests, val, module, globs, seen)
1437+
self._find(tests, val, module, globs, imports, seen)
14101438

14111439

14121440
######################################################################
14131441
## InlineTest Runner
14141442
######################################################################
14151443
class InlineTestRunner:
14161444
def run(self, test: InlineTest, out: List) -> None:
1417-
test_str = test.to_test()
1445+
test_str = test.write_imports()
1446+
test_str += test.to_test()
14181447
print(test_str)
1419-
tree = ast.parse(test.to_test())
1448+
tree = ast.parse(test_str)
14201449
codeobj = compile(tree, filename="<ast>", mode="exec")
14211450
start_time = time.time()
14221451
if test.timeout > 0:
@@ -1553,6 +1582,10 @@ def collect(self) -> Iterable[InlinetestItem]:
15531582
group_tags = self.config.getoption("inlinetest_group", default=None)
15541583
order_tags = self.config.getoption("inlinetest_order", default=None)
15551584

1585+
# TODO: import all modules through the finder first before extracting inline tests
1586+
# - Create ast for all imports
1587+
# - If a function references an import, then include the imported library reference in the ast
1588+
15561589
for test_list in finder.find(module):
15571590
# reorder the list if there are tests to be ordered
15581591
ordered_list = InlinetestModule.order_tests(test_list, order_tags)

0 commit comments

Comments
 (0)