@@ -297,11 +297,17 @@ class ExtractInlineTest(ast.NodeTransformer):
297297 arg_devices_str = "devices"
298298 diff_test_str = "diff_test"
299299 assume = "assume"
300+ inline_module_imported = False
301+ import_str = "import"
302+ from_str = "from"
303+ as_str = "as"
304+
300305 inline_module_imported = False
301306
302307 def __init__ (self ):
303308 self .cur_inline_test = InlineTest ()
304309 self .inline_test_list = []
310+ self .import_list = []
305311
306312 def is_inline_test_class (self , node ):
307313 if isinstance (node , ast .Call ):
@@ -351,16 +357,39 @@ def find_previous_stmt(self, node):
351357 return prev_stmt_node
352358 return self .find_condition_stmt (prev_stmt_node )
353359
354- def collect_inline_test_calls (self , node , inline_test_calls : List [ast .Call ]):
360+ def collect_inline_test_calls (self , node , inline_test_calls : List [ast .Call ], import_calls : List [ ast . Import ], import_from_calls : List [ ast . ImportFrom ] ):
355361 """
356362 collect all function calls in the node
357363 """
358364 if isinstance (node , ast .Attribute ):
359- self .collect_inline_test_calls (node .value , inline_test_calls )
365+ self .collect_inline_test_calls (node .value , inline_test_calls , import_calls , import_from_calls )
360366 elif isinstance (node , ast .Call ):
361367 inline_test_calls .append (node )
362- self .collect_inline_test_calls (node .func , inline_test_calls )
368+ self .collect_inline_test_calls (node .func , inline_test_calls , import_calls , import_from_calls )
369+ elif isinstance (node , ast .Import ):
370+ import_calls .append (node )
371+ self .collect_inline_test_calls (node .func , inline_test_calls , import_calls , import_from_calls )
372+ elif isinstance (node , ast .ImportFrom ):
373+ import_from_calls .append (node )
374+ self .collect_inline_test_calls (node .func , inline_test_calls , import_calls , import_from_calls )
375+
376+ def collect_import_calls (self , node , import_calls : List [ast .Import ], import_from_calls : List [ast .ImportFrom ]):
377+ """
378+ collect all import calls in the node (should be done first)
379+ """
363380
381+ while not isinstance (node , ast .Module ) and node .parent != None :
382+ node = node .parent
383+
384+ if not isinstance (node , ast .Module ):
385+ return
386+
387+ for child in node .children :
388+ if isinstance (child , ast .Import ):
389+ import_calls .append (child )
390+ elif isinstance (child , ast .ImportFrom ):
391+ import_from_calls .append (child )
392+
364393 def parse_constructor (self , node ):
365394 """
366395 Parse a constructor call.
@@ -1163,8 +1192,13 @@ def parse_parameterized_test(self):
11631192 parameterized_test .test_name = self .cur_inline_test .test_name + "_" + str (index )
11641193
11651194 def parse_inline_test (self , node ):
1166- inline_test_calls = []
1167- self .collect_inline_test_calls (node , inline_test_calls )
1195+ import_calls = []
1196+ import_from_calls = []
1197+ inline_test_calls = []
1198+
1199+ self .collect_inline_test_calls (node , inline_test_calls , import_calls , import_from_calls )
1200+ self .collect_import_calls (node , import_calls , import_from_calls )
1201+
11681202 inline_test_calls .reverse ()
11691203
11701204 if len (inline_test_calls ) <= 1 :
@@ -1185,14 +1219,32 @@ def parse_inline_test(self, node):
11851219 self .parse_assume (call )
11861220 inline_test_call_index += 1
11871221
1188- # "given(a, 1)"
11891222 for call in inline_test_calls [inline_test_call_index :]:
1190- if isinstance (call .func , ast .Attribute ) and call .func .attr == self .given_str :
1191- self .parse_given (call )
1192- inline_test_call_index += 1
1223+ if isinstance (call .func , ast .Attribute ):
1224+ if call .func .attr == self .given_str :
1225+ self .parse_given (call )
1226+ inline_test_call_index += 1
1227+ elif call .func .attr == self .diff_given_str :
1228+ self .parse_diff_given (call )
1229+ inline_test_call_index += 1
1230+
1231+ # match call.func.attr:
1232+ # # "given(a, 1)"
1233+ # case self.given_str:
1234+ # self.parse_given(call)
1235+ # inline_test_call_index += 1
1236+ # # "diff_given(devices, ["cpu", "cuda"])"
1237+ # case self.diff_given_str:
1238+ # self.parse_diff_given(call)
1239+ # inline_test_call_index += 1
11931240 else :
11941241 break
11951242
1243+ for import_stmt in import_calls :
1244+ self .cur_inline_test .import_stmts .append (import_stmt )
1245+ for import_stmt in import_from_calls :
1246+ self .cur_inline_test .import_stmts .append (import_stmt )
1247+
11961248 # "check_eq" or "check_true" or "check_false" or "check_neq"
11971249 for call in inline_test_calls [inline_test_call_index :]:
11981250 # "check_eq(a, 1)"
0 commit comments