@@ -296,11 +296,17 @@ class ExtractInlineTest(ast.NodeTransformer):
296296 arg_timeout_str = "timeout"
297297
298298 assume = "assume"
299+ inline_module_imported = False
300+ import_str = "import"
301+ from_str = "from"
302+ as_str = "as"
303+
299304 inline_module_imported = False
300305
301306 def __init__ (self ):
302307 self .cur_inline_test = InlineTest ()
303308 self .inline_test_list = []
309+ self .import_list = []
304310
305311 def is_inline_test_class (self , node ):
306312 if isinstance (node , ast .Call ):
@@ -350,16 +356,39 @@ def find_previous_stmt(self, node):
350356 return prev_stmt_node
351357 return self .find_condition_stmt (prev_stmt_node )
352358
353- def collect_inline_test_calls (self , node , inline_test_calls : List [ast .Call ]):
359+ def collect_inline_test_calls (self , node , inline_test_calls : List [ast .Call ], import_calls : List [ ast . Import ], import_from_calls : List [ ast . ImportFrom ] ):
354360 """
355361 collect all function calls in the node
356362 """
357363 if isinstance (node , ast .Attribute ):
358- self .collect_inline_test_calls (node .value , inline_test_calls )
364+ self .collect_inline_test_calls (node .value , inline_test_calls , import_calls , import_from_calls )
359365 elif isinstance (node , ast .Call ):
360366 inline_test_calls .append (node )
361- self .collect_inline_test_calls (node .func , inline_test_calls )
367+ self .collect_inline_test_calls (node .func , inline_test_calls , import_calls , import_from_calls )
368+ elif isinstance (node , ast .Import ):
369+ import_calls .append (node )
370+ self .collect_inline_test_calls (node .func , inline_test_calls , import_calls , import_from_calls )
371+ elif isinstance (node , ast .ImportFrom ):
372+ import_from_calls .append (node )
373+ self .collect_inline_test_calls (node .func , inline_test_calls , import_calls , import_from_calls )
374+
375+ def collect_import_calls (self , node , import_calls : List [ast .Import ], import_from_calls : List [ast .ImportFrom ]):
376+ """
377+ collect all import calls in the node (should be done first)
378+ """
362379
380+ while not isinstance (node , ast .Module ) and node .parent != None :
381+ node = node .parent
382+
383+ if not isinstance (node , ast .Module ):
384+ return
385+
386+ for child in node .children :
387+ if isinstance (child , ast .Import ):
388+ import_calls .append (child )
389+ elif isinstance (child , ast .ImportFrom ):
390+ import_from_calls .append (child )
391+
363392 def parse_constructor (self , node ):
364393 """
365394 Parse a constructor call.
@@ -931,8 +960,13 @@ def parse_parameterized_test(self):
931960 parameterized_test .test_name = self .cur_inline_test .test_name + "_" + str (index )
932961
933962 def parse_inline_test (self , node ):
934- inline_test_calls = []
935- self .collect_inline_test_calls (node , inline_test_calls )
963+ import_calls = []
964+ import_from_calls = []
965+ inline_test_calls = []
966+
967+ self .collect_inline_test_calls (node , inline_test_calls , import_calls , import_from_calls )
968+ self .collect_import_calls (node , import_calls , import_from_calls )
969+
936970 inline_test_calls .reverse ()
937971
938972 if len (inline_test_calls ) <= 1 :
@@ -953,14 +987,32 @@ def parse_inline_test(self, node):
953987 self .parse_assume (call )
954988 inline_test_call_index += 1
955989
956- # "given(a, 1)"
957990 for call in inline_test_calls [inline_test_call_index :]:
958- if isinstance (call .func , ast .Attribute ) and call .func .attr == self .given_str :
959- self .parse_given (call )
960- inline_test_call_index += 1
991+ if isinstance (call .func , ast .Attribute ):
992+ if call .func .attr == self .given_str :
993+ self .parse_given (call )
994+ inline_test_call_index += 1
995+ elif call .func .attr == self .diff_given_str :
996+ self .parse_diff_given (call )
997+ inline_test_call_index += 1
998+
999+ # match call.func.attr:
1000+ # # "given(a, 1)"
1001+ # case self.given_str:
1002+ # self.parse_given(call)
1003+ # inline_test_call_index += 1
1004+ # # "diff_given(devices, ["cpu", "cuda"])"
1005+ # case self.diff_given_str:
1006+ # self.parse_diff_given(call)
1007+ # inline_test_call_index += 1
9611008 else :
9621009 break
9631010
1011+ for import_stmt in import_calls :
1012+ self .cur_inline_test .import_stmts .append (import_stmt )
1013+ for import_stmt in import_from_calls :
1014+ self .cur_inline_test .import_stmts .append (import_stmt )
1015+
9641016 # "check_eq" or "check_true" or "check_false" or "check_neq"
9651017 for call in inline_test_calls [inline_test_call_index :]:
9661018 # "check_eq(a, 1)"
0 commit comments