Skip to content

Commit 7fbcdca

Browse files
committed
Added Diff Given and Added Back Diff Test Functionality
1 parent 19fcf7c commit 7fbcdca

File tree

1 file changed

+253
-1
lines changed

1 file changed

+253
-1
lines changed

src/inline/plugin.py

Lines changed: 253 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -296,13 +296,16 @@ class ExtractInlineTest(ast.NodeTransformer):
296296
check_not_same = "check_not_same"
297297
fail_str = "fail"
298298
given_str = "given"
299+
diff_given_str = "diff_given"
299300
group_str = "Group"
300301
arg_test_name_str = "test_name"
301302
arg_parameterized_str = "parameterized"
302303
arg_repeated_str = "repeated"
303304
arg_tag_str = "tag"
304305
arg_disabled_str = "disabled"
305306
arg_timeout_str = "timeout"
307+
arg_devices_str = "devices"
308+
diff_test_str = "diff_test"
306309

307310
assume = "assume"
308311

@@ -596,6 +599,23 @@ def parse_given(self, node):
596599
else:
597600
raise MalformedException("inline test: invalid given(), expected 2 args")
598601

602+
def parse_diff_given(self, node):
603+
PROPERTY = 0
604+
VALUES = 1
605+
606+
if len(node.args) == 2:
607+
if self.cur_inline_test.parameterized:
608+
raise MalformedException("inline test: Parameterized inline tests currently do not support differential tests.")
609+
else:
610+
devices = []
611+
for elt in node.args[VALUES].elts:
612+
if elt.value not in {"cpu", "cuda", "mps"}:
613+
raise MalformedException(f"Invalid device: {elt.value}. Must be one of ['cpu', 'cuda', 'mps']")
614+
devices.append(elt.value)
615+
setattr(self.cur_inline_test, node.args[PROPERTY].id, devices)
616+
else:
617+
raise MalformedException("inline test: invalid diff_given(), expected 2 args")
618+
599619
def parse_assume(self, node):
600620
if len(node.args) == 1:
601621
if self.cur_inline_test.parameterized:
@@ -930,6 +950,229 @@ def parse_fail(self, node):
930950
else:
931951
raise MalformedException("inline test: fail() does not expect any arguments")
932952

953+
def parse_diff_test(self, node):
954+
if not self.cur_inline_test.devices:
955+
raise MalformedException("diff_test can only be used with the 'devices' parameter.")
956+
957+
if len(node.args) != 1:
958+
raise MalformedException("diff_test() requires exactly 1 argument.")
959+
960+
output_node = self.parse_group(node.args[0])
961+
962+
# Get the original operation
963+
original_op = None
964+
for stmt in self.cur_inline_test.previous_stmts:
965+
if isinstance(stmt, ast.Assign) and stmt.targets[0].id == output_node.id:
966+
original_op = stmt.value
967+
break
968+
969+
if not original_op:
970+
raise MalformedException("Could not find original operation for diff_test")
971+
972+
# Create our new statements
973+
new_statements = []
974+
device_outputs = []
975+
976+
# Import necessary modules for seed setting - Always add these
977+
# Import random module
978+
import_random = ast.ImportFrom(
979+
module='random',
980+
names=[ast.alias(name='seed', asname=None)],
981+
level=0
982+
)
983+
new_statements.append(import_random)
984+
985+
# Import numpy.random
986+
import_np = ast.ImportFrom(
987+
module='numpy',
988+
names=[ast.alias(name='random', asname='np_random')],
989+
level=0
990+
)
991+
new_statements.append(import_np)
992+
993+
# Create seed function - Always add this
994+
seed_func_def = ast.FunctionDef(
995+
name='set_random_seed',
996+
args=ast.arguments(
997+
posonlyargs=[],
998+
args=[ast.arg(arg='seed_value', annotation=None)],
999+
kwonlyargs=[],
1000+
kw_defaults=[],
1001+
defaults=[]
1002+
),
1003+
body=[
1004+
ast.Expr(
1005+
value=ast.Call(
1006+
func=ast.Name(id='seed', ctx=ast.Load()),
1007+
args=[ast.Name(id='seed_value', ctx=ast.Load())],
1008+
keywords=[]
1009+
)
1010+
),
1011+
ast.Expr(
1012+
value=ast.Call(
1013+
func=ast.Attribute(
1014+
value=ast.Name(id='torch', ctx=ast.Load()),
1015+
attr='manual_seed'
1016+
),
1017+
args=[ast.Name(id='seed_value', ctx=ast.Load())],
1018+
keywords=[]
1019+
)
1020+
),
1021+
ast.Expr(
1022+
value=ast.Call(
1023+
func=ast.Attribute(
1024+
value=ast.Name(id='np_random', ctx=ast.Load()),
1025+
attr='seed'
1026+
),
1027+
args=[ast.Name(id='seed_value', ctx=ast.Load())],
1028+
keywords=[]
1029+
)
1030+
)
1031+
],
1032+
decorator_list=[],
1033+
returns=None
1034+
)
1035+
new_statements.append(seed_func_def)
1036+
1037+
# Process input tensors
1038+
for given_stmt in self.cur_inline_test.given_stmts:
1039+
input_var = given_stmt.targets[0].id
1040+
ref_var = f"{input_var}_ref"
1041+
1042+
# Always clone inputs for in-place operations
1043+
new_statements.append(
1044+
ast.Assign(
1045+
targets=[ast.Name(id=ref_var, ctx=ast.Store())],
1046+
value=ast.Call(
1047+
func=ast.Attribute(
1048+
value=given_stmt.value,
1049+
attr="clone"
1050+
),
1051+
args=[],
1052+
keywords=[]
1053+
)
1054+
)
1055+
)
1056+
1057+
# Create device-specific versions
1058+
for device in self.cur_inline_test.devices:
1059+
device_var = f"{input_var}_{device}"
1060+
1061+
new_statements.append(
1062+
ast.Assign(
1063+
targets=[ast.Name(id=device_var, ctx=ast.Store())],
1064+
value=ast.Call(
1065+
func=ast.Attribute(
1066+
value=ast.Name(id=ref_var, ctx=ast.Load()),
1067+
attr="to"
1068+
),
1069+
args=[ast.Constant(value=device)],
1070+
keywords=[]
1071+
)
1072+
)
1073+
)
1074+
1075+
# Create device-specific operations
1076+
device_input_map = {device: {} for device in self.cur_inline_test.devices}
1077+
for device in self.cur_inline_test.devices:
1078+
for given_stmt in self.cur_inline_test.given_stmts:
1079+
input_var = given_stmt.targets[0].id
1080+
device_input_map[device][input_var] = f"{input_var}_{device}"
1081+
1082+
# Always set seed before each device operation - no condition check
1083+
new_statements.append(
1084+
ast.Expr(
1085+
value=ast.Call(
1086+
func=ast.Name(id='set_random_seed', ctx=ast.Load()),
1087+
args=[ast.Constant(value=42)], # Use constant seed 42
1088+
keywords=[]
1089+
)
1090+
)
1091+
)
1092+
1093+
device_op = copy.deepcopy(original_op)
1094+
1095+
# Replace input references
1096+
class ReplaceInputs(ast.NodeTransformer):
1097+
def visit_Name(self, node):
1098+
if node.id in device_input_map[device]:
1099+
return ast.Name(id=device_input_map[device][node.id], ctx=node.ctx)
1100+
return node
1101+
1102+
device_op = ReplaceInputs().visit(device_op)
1103+
device_output = f"output_{device}"
1104+
1105+
new_statements.append(
1106+
ast.Assign(
1107+
targets=[ast.Name(id=device_output, ctx=ast.Store())],
1108+
value=device_op
1109+
)
1110+
)
1111+
device_outputs.append(device_output)
1112+
1113+
# Standard comparison method for all operations - no condition check
1114+
comparisons = []
1115+
for i in range(len(device_outputs) - 1):
1116+
dev1 = device_outputs[i]
1117+
dev2 = device_outputs[i + 1]
1118+
1119+
dev1_cpu = f"{dev1}_cpu"
1120+
dev2_cpu = f"{dev2}_cpu"
1121+
1122+
# Move outputs back to CPU for comparison
1123+
new_statements.append(
1124+
ast.Assign(
1125+
targets=[ast.Name(id=dev1_cpu, ctx=ast.Store())],
1126+
value=ast.Call(
1127+
func=ast.Attribute(
1128+
value=ast.Name(id=dev1, ctx=ast.Load()),
1129+
attr="to"
1130+
),
1131+
args=[ast.Constant(value="cpu")],
1132+
keywords=[]
1133+
)
1134+
)
1135+
)
1136+
1137+
new_statements.append(
1138+
ast.Assign(
1139+
targets=[ast.Name(id=dev2_cpu, ctx=ast.Store())],
1140+
value=ast.Call(
1141+
func=ast.Attribute(
1142+
value=ast.Name(id=dev2, ctx=ast.Load()),
1143+
attr="to"
1144+
),
1145+
args=[ast.Constant(value="cpu")],
1146+
keywords=[]
1147+
)
1148+
)
1149+
)
1150+
1151+
# Standard allclose comparison
1152+
comparison = self.build_assert_eq(
1153+
ast.Call(
1154+
func=ast.Attribute(
1155+
value=ast.Name(id=dev1_cpu, ctx=ast.Load()),
1156+
attr="allclose"
1157+
),
1158+
args=[
1159+
ast.Name(id=dev2_cpu, ctx=ast.Load())
1160+
],
1161+
keywords=[
1162+
ast.keyword(arg="rtol", value=ast.Constant(value=1e-4)),
1163+
ast.keyword(arg="atol", value=ast.Constant(value=1e-4)),
1164+
ast.keyword(arg="equal_nan", value=ast.Constant(value=True))
1165+
]
1166+
),
1167+
ast.Constant(value=True)
1168+
)
1169+
comparisons.append(comparison)
1170+
1171+
# Replace statements
1172+
self.cur_inline_test.previous_stmts = new_statements
1173+
self.cur_inline_test.check_stmts = comparisons
1174+
1175+
9331176
def parse_group(self, node):
9341177
if isinstance(node, ast.Call) and isinstance(node.func, ast.Name) and node.func.id == self.group_str:
9351178
# node type is ast.Call, node.func type is ast.Name
@@ -988,6 +1231,9 @@ def parse_inline_test(self, node):
9881231
if isinstance(call.func, ast.Attribute) and call.func.attr == self.assume:
9891232
self.parse_assume(call)
9901233
inline_test_call_index += 1
1234+
elif call.func.attr == self.diff_given_str:
1235+
self.parse_diff_given(call)
1236+
inline_test_call_index += 1
9911237

9921238
for call in inline_test_calls[inline_test_call_index:]:
9931239
if isinstance(call.func, ast.Attribute):
@@ -1027,9 +1273,15 @@ def parse_inline_test(self, node):
10271273
self.parse_check_not_same(call)
10281274
elif call.func.attr == self.fail_str:
10291275
self.parse_fail(call)
1276+
elif call.func.attr == self.diff_test_str:
1277+
self.parse_diff_test(call)
10301278
elif call.func.attr == self.given_str:
10311279
raise MalformedException(
1032-
f"inline test: given() must be called before check_eq()/check_true()/check_false()"
1280+
f"inline test: given() must be called before check_eq()/check_true()/check_false()/diff_test()"
1281+
)
1282+
elif call.func.attr == self.diff_given_str:
1283+
raise MalformedException(
1284+
f"inline test: diff_given() must be called before check_eq()/check_true()/check_false()/diff_test()"
10331285
)
10341286
else:
10351287
raise MalformedException(f"inline test: invalid function call {self.node_to_source_code(call.func)}")

0 commit comments

Comments
 (0)