@@ -296,13 +296,16 @@ class ExtractInlineTest(ast.NodeTransformer):
296296 check_not_same = "check_not_same"
297297 fail_str = "fail"
298298 given_str = "given"
299+ diff_given_str = "diff_given"
299300 group_str = "Group"
300301 arg_test_name_str = "test_name"
301302 arg_parameterized_str = "parameterized"
302303 arg_repeated_str = "repeated"
303304 arg_tag_str = "tag"
304305 arg_disabled_str = "disabled"
305306 arg_timeout_str = "timeout"
307+ arg_devices_str = "devices"
308+ diff_test_str = "diff_test"
306309
307310 assume = "assume"
308311
@@ -596,6 +599,23 @@ def parse_given(self, node):
596599 else :
597600 raise MalformedException ("inline test: invalid given(), expected 2 args" )
598601
602+ def parse_diff_given (self , node ):
603+ PROPERTY = 0
604+ VALUES = 1
605+
606+ if len (node .args ) == 2 :
607+ if self .cur_inline_test .parameterized :
608+ raise MalformedException ("inline test: Parameterized inline tests currently do not support differential tests." )
609+ else :
610+ devices = []
611+ for elt in node .args [VALUES ].elts :
612+ if elt .value not in {"cpu" , "cuda" , "mps" }:
613+ raise MalformedException (f"Invalid device: { elt .value } . Must be one of ['cpu', 'cuda', 'mps']" )
614+ devices .append (elt .value )
615+ setattr (self .cur_inline_test , node .args [PROPERTY ].id , devices )
616+ else :
617+ raise MalformedException ("inline test: invalid diff_given(), expected 2 args" )
618+
599619 def parse_assume (self , node ):
600620 if len (node .args ) == 1 :
601621 if self .cur_inline_test .parameterized :
@@ -930,6 +950,229 @@ def parse_fail(self, node):
930950 else :
931951 raise MalformedException ("inline test: fail() does not expect any arguments" )
932952
953+ def parse_diff_test (self , node ):
954+ if not self .cur_inline_test .devices :
955+ raise MalformedException ("diff_test can only be used with the 'devices' parameter." )
956+
957+ if len (node .args ) != 1 :
958+ raise MalformedException ("diff_test() requires exactly 1 argument." )
959+
960+ output_node = self .parse_group (node .args [0 ])
961+
962+ # Get the original operation
963+ original_op = None
964+ for stmt in self .cur_inline_test .previous_stmts :
965+ if isinstance (stmt , ast .Assign ) and stmt .targets [0 ].id == output_node .id :
966+ original_op = stmt .value
967+ break
968+
969+ if not original_op :
970+ raise MalformedException ("Could not find original operation for diff_test" )
971+
972+ # Create our new statements
973+ new_statements = []
974+ device_outputs = []
975+
976+ # Import necessary modules for seed setting - Always add these
977+ # Import random module
978+ import_random = ast .ImportFrom (
979+ module = 'random' ,
980+ names = [ast .alias (name = 'seed' , asname = None )],
981+ level = 0
982+ )
983+ new_statements .append (import_random )
984+
985+ # Import numpy.random
986+ import_np = ast .ImportFrom (
987+ module = 'numpy' ,
988+ names = [ast .alias (name = 'random' , asname = 'np_random' )],
989+ level = 0
990+ )
991+ new_statements .append (import_np )
992+
993+ # Create seed function - Always add this
994+ seed_func_def = ast .FunctionDef (
995+ name = 'set_random_seed' ,
996+ args = ast .arguments (
997+ posonlyargs = [],
998+ args = [ast .arg (arg = 'seed_value' , annotation = None )],
999+ kwonlyargs = [],
1000+ kw_defaults = [],
1001+ defaults = []
1002+ ),
1003+ body = [
1004+ ast .Expr (
1005+ value = ast .Call (
1006+ func = ast .Name (id = 'seed' , ctx = ast .Load ()),
1007+ args = [ast .Name (id = 'seed_value' , ctx = ast .Load ())],
1008+ keywords = []
1009+ )
1010+ ),
1011+ ast .Expr (
1012+ value = ast .Call (
1013+ func = ast .Attribute (
1014+ value = ast .Name (id = 'torch' , ctx = ast .Load ()),
1015+ attr = 'manual_seed'
1016+ ),
1017+ args = [ast .Name (id = 'seed_value' , ctx = ast .Load ())],
1018+ keywords = []
1019+ )
1020+ ),
1021+ ast .Expr (
1022+ value = ast .Call (
1023+ func = ast .Attribute (
1024+ value = ast .Name (id = 'np_random' , ctx = ast .Load ()),
1025+ attr = 'seed'
1026+ ),
1027+ args = [ast .Name (id = 'seed_value' , ctx = ast .Load ())],
1028+ keywords = []
1029+ )
1030+ )
1031+ ],
1032+ decorator_list = [],
1033+ returns = None
1034+ )
1035+ new_statements .append (seed_func_def )
1036+
1037+ # Process input tensors
1038+ for given_stmt in self .cur_inline_test .given_stmts :
1039+ input_var = given_stmt .targets [0 ].id
1040+ ref_var = f"{ input_var } _ref"
1041+
1042+ # Always clone inputs for in-place operations
1043+ new_statements .append (
1044+ ast .Assign (
1045+ targets = [ast .Name (id = ref_var , ctx = ast .Store ())],
1046+ value = ast .Call (
1047+ func = ast .Attribute (
1048+ value = given_stmt .value ,
1049+ attr = "clone"
1050+ ),
1051+ args = [],
1052+ keywords = []
1053+ )
1054+ )
1055+ )
1056+
1057+ # Create device-specific versions
1058+ for device in self .cur_inline_test .devices :
1059+ device_var = f"{ input_var } _{ device } "
1060+
1061+ new_statements .append (
1062+ ast .Assign (
1063+ targets = [ast .Name (id = device_var , ctx = ast .Store ())],
1064+ value = ast .Call (
1065+ func = ast .Attribute (
1066+ value = ast .Name (id = ref_var , ctx = ast .Load ()),
1067+ attr = "to"
1068+ ),
1069+ args = [ast .Constant (value = device )],
1070+ keywords = []
1071+ )
1072+ )
1073+ )
1074+
1075+ # Create device-specific operations
1076+ device_input_map = {device : {} for device in self .cur_inline_test .devices }
1077+ for device in self .cur_inline_test .devices :
1078+ for given_stmt in self .cur_inline_test .given_stmts :
1079+ input_var = given_stmt .targets [0 ].id
1080+ device_input_map [device ][input_var ] = f"{ input_var } _{ device } "
1081+
1082+ # Always set seed before each device operation - no condition check
1083+ new_statements .append (
1084+ ast .Expr (
1085+ value = ast .Call (
1086+ func = ast .Name (id = 'set_random_seed' , ctx = ast .Load ()),
1087+ args = [ast .Constant (value = 42 )], # Use constant seed 42
1088+ keywords = []
1089+ )
1090+ )
1091+ )
1092+
1093+ device_op = copy .deepcopy (original_op )
1094+
1095+ # Replace input references
1096+ class ReplaceInputs (ast .NodeTransformer ):
1097+ def visit_Name (self , node ):
1098+ if node .id in device_input_map [device ]:
1099+ return ast .Name (id = device_input_map [device ][node .id ], ctx = node .ctx )
1100+ return node
1101+
1102+ device_op = ReplaceInputs ().visit (device_op )
1103+ device_output = f"output_{ device } "
1104+
1105+ new_statements .append (
1106+ ast .Assign (
1107+ targets = [ast .Name (id = device_output , ctx = ast .Store ())],
1108+ value = device_op
1109+ )
1110+ )
1111+ device_outputs .append (device_output )
1112+
1113+ # Standard comparison method for all operations - no condition check
1114+ comparisons = []
1115+ for i in range (len (device_outputs ) - 1 ):
1116+ dev1 = device_outputs [i ]
1117+ dev2 = device_outputs [i + 1 ]
1118+
1119+ dev1_cpu = f"{ dev1 } _cpu"
1120+ dev2_cpu = f"{ dev2 } _cpu"
1121+
1122+ # Move outputs back to CPU for comparison
1123+ new_statements .append (
1124+ ast .Assign (
1125+ targets = [ast .Name (id = dev1_cpu , ctx = ast .Store ())],
1126+ value = ast .Call (
1127+ func = ast .Attribute (
1128+ value = ast .Name (id = dev1 , ctx = ast .Load ()),
1129+ attr = "to"
1130+ ),
1131+ args = [ast .Constant (value = "cpu" )],
1132+ keywords = []
1133+ )
1134+ )
1135+ )
1136+
1137+ new_statements .append (
1138+ ast .Assign (
1139+ targets = [ast .Name (id = dev2_cpu , ctx = ast .Store ())],
1140+ value = ast .Call (
1141+ func = ast .Attribute (
1142+ value = ast .Name (id = dev2 , ctx = ast .Load ()),
1143+ attr = "to"
1144+ ),
1145+ args = [ast .Constant (value = "cpu" )],
1146+ keywords = []
1147+ )
1148+ )
1149+ )
1150+
1151+ # Standard allclose comparison
1152+ comparison = self .build_assert_eq (
1153+ ast .Call (
1154+ func = ast .Attribute (
1155+ value = ast .Name (id = dev1_cpu , ctx = ast .Load ()),
1156+ attr = "allclose"
1157+ ),
1158+ args = [
1159+ ast .Name (id = dev2_cpu , ctx = ast .Load ())
1160+ ],
1161+ keywords = [
1162+ ast .keyword (arg = "rtol" , value = ast .Constant (value = 1e-4 )),
1163+ ast .keyword (arg = "atol" , value = ast .Constant (value = 1e-4 )),
1164+ ast .keyword (arg = "equal_nan" , value = ast .Constant (value = True ))
1165+ ]
1166+ ),
1167+ ast .Constant (value = True )
1168+ )
1169+ comparisons .append (comparison )
1170+
1171+ # Replace statements
1172+ self .cur_inline_test .previous_stmts = new_statements
1173+ self .cur_inline_test .check_stmts = comparisons
1174+
1175+
9331176 def parse_group (self , node ):
9341177 if isinstance (node , ast .Call ) and isinstance (node .func , ast .Name ) and node .func .id == self .group_str :
9351178 # node type is ast.Call, node.func type is ast.Name
@@ -988,6 +1231,9 @@ def parse_inline_test(self, node):
9881231 if isinstance (call .func , ast .Attribute ) and call .func .attr == self .assume :
9891232 self .parse_assume (call )
9901233 inline_test_call_index += 1
1234+ elif call .func .attr == self .diff_given_str :
1235+ self .parse_diff_given (call )
1236+ inline_test_call_index += 1
9911237
9921238 for call in inline_test_calls [inline_test_call_index :]:
9931239 if isinstance (call .func , ast .Attribute ):
@@ -1027,9 +1273,15 @@ def parse_inline_test(self, node):
10271273 self .parse_check_not_same (call )
10281274 elif call .func .attr == self .fail_str :
10291275 self .parse_fail (call )
1276+ elif call .func .attr == self .diff_test_str :
1277+ self .parse_diff_test (call )
10301278 elif call .func .attr == self .given_str :
10311279 raise MalformedException (
1032- f"inline test: given() must be called before check_eq()/check_true()/check_false()"
1280+ f"inline test: given() must be called before check_eq()/check_true()/check_false()/diff_test()"
1281+ )
1282+ elif call .func .attr == self .diff_given_str :
1283+ raise MalformedException (
1284+ f"inline test: diff_given() must be called before check_eq()/check_true()/check_false()/diff_test()"
10331285 )
10341286 else :
10351287 raise MalformedException (f"inline test: invalid function call { self .node_to_source_code (call .func )} " )
0 commit comments