@@ -159,6 +159,7 @@ def __init__(self):
159159 self .check_stmts = []
160160 self .given_stmts = []
161161 self .previous_stmts = []
162+ self .import_stmts = []
162163 self .prev_stmt_type = PrevStmtType .StmtExpr
163164 # the line number of test statement
164165 self .lineno = 0
@@ -174,11 +175,23 @@ def __init__(self):
174175 self .devices = None
175176 self .globs = {}
176177
178+ def write_imports (self ):
179+ import_str = ""
180+ for n in self .import_stmts :
181+ import_str += ExtractInlineTest .node_to_source_code (n ) + "\n "
182+ return import_str
183+
177184 def to_test (self ):
185+ prefix = "\n "
186+
187+ # for n in self.import_stmts:
188+ # import_str += ExtractInlineTest.node_to_source_code(n) + "\n"
189+
190+
178191 if self .prev_stmt_type == PrevStmtType .CondExpr :
179192 if self .assume_stmts == []:
180- return " \n " .join (
181- [ExtractInlineTest .node_to_source_code (n ) for n in self .given_stmts ]
193+ return prefix .join (
194+ + [ExtractInlineTest .node_to_source_code (n ) for n in self .given_stmts ]
182195 + [ExtractInlineTest .node_to_source_code (n ) for n in self .check_stmts ]
183196 )
184197 else :
@@ -187,11 +200,11 @@ def to_test(self):
187200 )
188201 assume_statement = self .assume_stmts [0 ]
189202 assume_node = self .build_assume_node (assume_statement , body_nodes )
190- return " \n " .join (ExtractInlineTest .node_to_source_code (assume_node ))
203+ return prefix .join (ExtractInlineTest .node_to_source_code (assume_node ))
191204
192205 else :
193206 if self .assume_stmts is None or self .assume_stmts == []:
194- return " \n " .join (
207+ return prefix .join (
195208 [ExtractInlineTest .node_to_source_code (n ) for n in self .given_stmts ]
196209 + [ExtractInlineTest .node_to_source_code (n ) for n in self .previous_stmts ]
197210 + [ExtractInlineTest .node_to_source_code (n ) for n in self .check_stmts ]
@@ -202,7 +215,7 @@ def to_test(self):
202215 )
203216 assume_statement = self .assume_stmts [0 ]
204217 assume_node = self .build_assume_node (assume_statement , body_nodes )
205- return " \n " .join ([ExtractInlineTest .node_to_source_code (assume_node )])
218+ return prefix .join ([ExtractInlineTest .node_to_source_code (assume_node )])
206219
207220 def build_assume_node (self , assumption_node , body_nodes ):
208221 return ast .If (assumption_node , body_nodes , [])
@@ -252,7 +265,7 @@ class TimeoutException(Exception):
252265## InlineTest Parser
253266######################################################################
254267class InlinetestParser :
255- def parse (self , obj , globs : None ):
268+ def parse (self , obj , globs : None ):
256269 # obj = open(self.file_path, "r").read():
257270 if isinstance (obj , ModuleType ):
258271 tree = ast .parse (open (obj .__file__ , "r" ).read ())
@@ -297,7 +310,7 @@ class ExtractInlineTest(ast.NodeTransformer):
297310 arg_devices_str = "devices"
298311 diff_test_str = "diff_test"
299312 assume = "assume"
300- inline_module_imported = False
313+
301314 import_str = "import"
302315 from_str = "from"
303316 as_str = "as"
@@ -389,7 +402,7 @@ def collect_import_calls(self, node, import_calls: List[ast.Import], import_from
389402 import_calls .append (child )
390403 elif isinstance (child , ast .ImportFrom ):
391404 import_from_calls .append (child )
392-
405+
393406 def parse_constructor (self , node ):
394407 """
395408 Parse a constructor call.
@@ -558,7 +571,6 @@ class ConstrArgs(enum.Enum):
558571 getattr (arg , value_prop_name ))
559572
560573
561- ## Match implementation of above conditional tree; commented since Python < 3.10 does not support match
562574
563575 # match arg_idx:
564576 # case ConstrArgs.REPEATED:
@@ -581,7 +593,9 @@ class ConstrArgs(enum.Enum):
581593 raise MalformedException (
582594 f"inline test: { self .class_name_str } () accepts { NUM_OF_ARGUMENTS } arguments. 'test_name' must be a string constant, 'parameterized' must be a boolean constant, 'repeated' must be a positive integer, 'tag' must be a list of string, 'timeout' must be a positive float"
583595 )
584-
596+ #raise MalformedException("Argument " + str(index) + " incorrectly formatted. Argument should be a " + ConstrArgs.expected_ast_val_args[index].type())
597+
598+
585599 def parameterized_inline_tests_init (self , node : ast .List ):
586600 if not self .cur_inline_test .parameterized_inline_tests :
587601 self .cur_inline_test .parameterized_inline_tests = [InlineTest () for _ in range (len (node .elts ))]
@@ -1141,8 +1155,17 @@ def visit_Name(self, node):
11411155 self .cur_inline_test .previous_stmts = new_statements
11421156 self .cur_inline_test .check_stmts = comparisons
11431157
1144-
1158+ def parse_import (self , node ):
1159+ # TODO: Differentiate between import, from import, and import alias
1160+ import_node = ast .Import (
1161+ names = [
1162+ ast .alias (name = node )
1163+ ]
1164+ )
1165+ return import_node
11451166
1167+ def parse_import_from (self , node ):
1168+ pass
11461169
11471170 def build_fail (self ):
11481171 equal_node = ast .Compare (
@@ -1183,6 +1206,7 @@ def parse_group(self, node):
11831206 return stmt
11841207 else :
11851208 return node
1209+
11861210
11871211 def parse_parameterized_test (self ):
11881212 for index , parameterized_test in enumerate (self .cur_inline_test .parameterized_inline_tests ):
@@ -1219,24 +1243,11 @@ def parse_inline_test(self, node):
12191243 self .parse_assume (call )
12201244 inline_test_call_index += 1
12211245
1246+ # "given(a, 1)"
12221247 for call in inline_test_calls [inline_test_call_index :]:
1223- if isinstance (call .func , ast .Attribute ):
1224- if call .func .attr == self .given_str :
1225- self .parse_given (call )
1226- inline_test_call_index += 1
1227- elif call .func .attr == self .diff_given_str :
1228- self .parse_diff_given (call )
1229- inline_test_call_index += 1
1230-
1231- # match call.func.attr:
1232- # # "given(a, 1)"
1233- # case self.given_str:
1234- # self.parse_given(call)
1235- # inline_test_call_index += 1
1236- # # "diff_given(devices, ["cpu", "cuda"])"
1237- # case self.diff_given_str:
1238- # self.parse_diff_given(call)
1239- # inline_test_call_index += 1
1248+ if isinstance (call .func , ast .Attribute ) and call .func .attr == self .given_str :
1249+ self .parse_given (call )
1250+ inline_test_call_index += 1
12401251 else :
12411252 break
12421253
@@ -1245,6 +1256,7 @@ def parse_inline_test(self, node):
12451256 for import_stmt in import_from_calls :
12461257 self .cur_inline_test .import_stmts .append (import_stmt )
12471258
1259+
12481260 # "check_eq" or "check_true" or "check_false" or "check_neq"
12491261 for call in inline_test_calls [inline_test_call_index :]:
12501262 # "check_eq(a, 1)"
@@ -1307,6 +1319,7 @@ def node_to_source_code(node):
13071319## InlineTest Finder
13081320######################################################################
13091321class InlineTestFinder :
1322+ # Finder should NOT store any global variables
13101323 def __init__ (self , parser = InlinetestParser (), recurse = True , exclude_empty = True ):
13111324 self ._parser = parser
13121325 self ._recurse = recurse
@@ -1351,7 +1364,14 @@ def _is_routine(self, obj):
13511364 pass
13521365 return inspect .isroutine (maybe_routine )
13531366
1354- def find (self , obj , module = None , globs = None , extraglobs = None ):
1367+ # def find_imports(self, obj, module=None):
1368+ # if module is False:
1369+ # module = None
1370+ # elif module is None:
1371+ # module = inspect.getmodule(obj)
1372+
1373+
1374+ def find (self , obj , module = None , globs = None , extraglobs = None , imports = None ):
13551375 # Find the module that contains the given object (if obj is
13561376 # a module, then module=obj.).
13571377 if module is False :
@@ -1372,15 +1392,23 @@ def find(self, obj, module=None, globs=None, extraglobs=None):
13721392 if "__name__" not in globs :
13731393 globs ["__name__" ] = "__main__" # provide a default module name
13741394
1395+ # Find intersection between loaded modules and module imports
1396+ # if imports is None:
1397+ # imports = set(sys.modules) & set(globs)
1398+ # else:
1399+ # imports = imports.copy()
1400+
13751401 # Recursively explore `obj`, extracting InlineTests.
13761402 tests = []
1377- self ._find (tests , obj , module , globs , {})
1403+ self ._find (tests , obj , module , globs , imports , {})
13781404 return tests
13791405
1380- def _find (self , tests , obj , module , globs , seen ):
1406+ def _find (self , tests , obj , module , globs , imports , seen ):
13811407 if id (obj ) in seen :
13821408 return
13831409 seen [id (obj )] = 1
1410+
1411+
13841412 # Find a test for this object, and add it to the list of tests.
13851413 test = self ._parser .parse (obj , globs )
13861414 if test is not None :
@@ -1392,7 +1420,7 @@ def _find(self, tests, obj, module, globs, seen):
13921420
13931421 # Recurse to functions & classes.
13941422 if (self ._is_routine (val ) or inspect .isclass (val )) and self ._from_module (module , val ):
1395- self ._find (tests , val , module , globs , seen )
1423+ self ._find (tests , val , module , globs , imports , seen )
13961424
13971425 # Look for tests in a class's contained objects.
13981426 if inspect .isclass (obj ) and self ._recurse :
@@ -1406,17 +1434,18 @@ def _find(self, tests, obj, module, globs, seen):
14061434 module , val
14071435 ):
14081436 valname = "%s" % (valname )
1409- self ._find (tests , val , module , globs , seen )
1437+ self ._find (tests , val , module , globs , imports , seen )
14101438
14111439
14121440######################################################################
14131441## InlineTest Runner
14141442######################################################################
14151443class InlineTestRunner :
14161444 def run (self , test : InlineTest , out : List ) -> None :
1417- test_str = test .to_test ()
1445+ test_str = test .write_imports ()
1446+ test_str += test .to_test ()
14181447 print (test_str )
1419- tree = ast .parse (test . to_test () )
1448+ tree = ast .parse (test_str )
14201449 codeobj = compile (tree , filename = "<ast>" , mode = "exec" )
14211450 start_time = time .time ()
14221451 if test .timeout > 0 :
@@ -1553,6 +1582,10 @@ def collect(self) -> Iterable[InlinetestItem]:
15531582 group_tags = self .config .getoption ("inlinetest_group" , default = None )
15541583 order_tags = self .config .getoption ("inlinetest_order" , default = None )
15551584
1585+ # TODO: import all modules through the finder first before extracting inline tests
1586+ # - Create ast for all imports
1587+ # - If a function references an import, then include the imported library reference in the ast
1588+
15561589 for test_list in finder .find (module ):
15571590 # reorder the list if there are tests to be ordered
15581591 ordered_list = InlinetestModule .order_tests (test_list , order_tags )
0 commit comments