Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
265 changes: 261 additions & 4 deletions src/inline/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,13 +296,16 @@ class ExtractInlineTest(ast.NodeTransformer):
check_not_same = "check_not_same"
fail_str = "fail"
given_str = "given"
diff_given_str = "diff_given"
group_str = "Group"
arg_test_name_str = "test_name"
arg_parameterized_str = "parameterized"
arg_repeated_str = "repeated"
arg_tag_str = "tag"
arg_disabled_str = "disabled"
arg_timeout_str = "timeout"
arg_devices_str = "devices"
diff_test_str = "diff_test"

assume = "assume"

Expand Down Expand Up @@ -596,6 +599,30 @@ def parse_given(self, node):
else:
raise MalformedException("inline test: invalid given(), expected 2 args")

def parse_diff_given(self, node):
PROPERTY = 0
VALUES = 1

if sys.version_info >= (3, 8, 0):
attr_name = "value"
else:
attr_name = "s"


if len(node.args) == 2:
if self.cur_inline_test.parameterized:
raise MalformedException("inline test: Parameterized inline tests currently do not support differential tests.")
else:
devices = []
for elt in node.args[VALUES].elts:
value = getattr(elt, attr_name)
if value not in {"cpu", "cuda", "mps"}:
raise MalformedException(f"Invalid device: {elt.value}. Must be one of ['cpu', 'cuda', 'mps']")
devices.append(value)
setattr(self.cur_inline_test, node.args[PROPERTY].id, devices)
else:
raise MalformedException("inline test: invalid diff_given(), expected 2 args")

def parse_assume(self, node):
if len(node.args) == 1:
if self.cur_inline_test.parameterized:
Expand Down Expand Up @@ -930,6 +957,229 @@ def parse_fail(self, node):
else:
raise MalformedException("inline test: fail() does not expect any arguments")

def parse_diff_test(self, node):
if not self.cur_inline_test.devices:
raise MalformedException("diff_test can only be used with the 'devices' parameter.")

if len(node.args) != 1:
raise MalformedException("diff_test() requires exactly 1 argument.")

output_node = self.parse_group(node.args[0])

# Get the original operation
original_op = None
for stmt in self.cur_inline_test.previous_stmts:
if isinstance(stmt, ast.Assign) and stmt.targets[0].id == output_node.id:
original_op = stmt.value
break

if not original_op:
raise MalformedException("Could not find original operation for diff_test")

# Create our new statements
new_statements = []
device_outputs = []

# Import necessary modules for seed setting - Always add these
# Import random module
import_random = ast.ImportFrom(
module='random',
names=[ast.alias(name='seed', asname=None)],
level=0
)
new_statements.append(import_random)

# Import numpy.random
import_np = ast.ImportFrom(
module='numpy',
names=[ast.alias(name='random', asname='np_random')],
level=0
)
new_statements.append(import_np)

# Create seed function - Always add this
seed_func_def = ast.FunctionDef(
name='set_random_seed',
args=ast.arguments(
posonlyargs=[],
args=[ast.arg(arg='seed_value', annotation=None)],
kwonlyargs=[],
kw_defaults=[],
defaults=[]
),
body=[
ast.Expr(
value=ast.Call(
func=ast.Name(id='seed', ctx=ast.Load()),
args=[ast.Name(id='seed_value', ctx=ast.Load())],
keywords=[]
)
),
ast.Expr(
value=ast.Call(
func=ast.Attribute(
value=ast.Name(id='torch', ctx=ast.Load()),
attr='manual_seed'
),
args=[ast.Name(id='seed_value', ctx=ast.Load())],
keywords=[]
)
),
ast.Expr(
value=ast.Call(
func=ast.Attribute(
value=ast.Name(id='np_random', ctx=ast.Load()),
attr='seed'
),
args=[ast.Name(id='seed_value', ctx=ast.Load())],
keywords=[]
)
)
],
decorator_list=[],
returns=None
)
new_statements.append(seed_func_def)

# Process input tensors
for given_stmt in self.cur_inline_test.given_stmts:
input_var = given_stmt.targets[0].id
ref_var = f"{input_var}_ref"

# Always clone inputs for in-place operations
new_statements.append(
ast.Assign(
targets=[ast.Name(id=ref_var, ctx=ast.Store())],
value=ast.Call(
func=ast.Attribute(
value=given_stmt.value,
attr="clone"
),
args=[],
keywords=[]
)
)
)

# Create device-specific versions
for device in self.cur_inline_test.devices:
device_var = f"{input_var}_{device}"

new_statements.append(
ast.Assign(
targets=[ast.Name(id=device_var, ctx=ast.Store())],
value=ast.Call(
func=ast.Attribute(
value=ast.Name(id=ref_var, ctx=ast.Load()),
attr="to"
),
args=[ast.Constant(value=device)],
keywords=[]
)
)
)

# Create device-specific operations
device_input_map = {device: {} for device in self.cur_inline_test.devices}
for device in self.cur_inline_test.devices:
for given_stmt in self.cur_inline_test.given_stmts:
input_var = given_stmt.targets[0].id
device_input_map[device][input_var] = f"{input_var}_{device}"

# Always set seed before each device operation - no condition check
new_statements.append(
ast.Expr(
value=ast.Call(
func=ast.Name(id='set_random_seed', ctx=ast.Load()),
args=[ast.Constant(value=42)], # Use constant seed 42
keywords=[]
)
)
)

device_op = copy.deepcopy(original_op)

# Replace input references
class ReplaceInputs(ast.NodeTransformer):
def visit_Name(self, node):
if node.id in device_input_map[device]:
return ast.Name(id=device_input_map[device][node.id], ctx=node.ctx)
return node

device_op = ReplaceInputs().visit(device_op)
device_output = f"output_{device}"

new_statements.append(
ast.Assign(
targets=[ast.Name(id=device_output, ctx=ast.Store())],
value=device_op
)
)
device_outputs.append(device_output)

# Standard comparison method for all operations - no condition check
comparisons = []
for i in range(len(device_outputs) - 1):
dev1 = device_outputs[i]
dev2 = device_outputs[i + 1]

dev1_cpu = f"{dev1}_cpu"
dev2_cpu = f"{dev2}_cpu"

# Move outputs back to CPU for comparison
new_statements.append(
ast.Assign(
targets=[ast.Name(id=dev1_cpu, ctx=ast.Store())],
value=ast.Call(
func=ast.Attribute(
value=ast.Name(id=dev1, ctx=ast.Load()),
attr="to"
),
args=[ast.Constant(value="cpu")],
keywords=[]
)
)
)

new_statements.append(
ast.Assign(
targets=[ast.Name(id=dev2_cpu, ctx=ast.Store())],
value=ast.Call(
func=ast.Attribute(
value=ast.Name(id=dev2, ctx=ast.Load()),
attr="to"
),
args=[ast.Constant(value="cpu")],
keywords=[]
)
)
)

# Standard allclose comparison
comparison = self.build_assert_eq(
ast.Call(
func=ast.Attribute(
value=ast.Name(id=dev1_cpu, ctx=ast.Load()),
attr="allclose"
),
args=[
ast.Name(id=dev2_cpu, ctx=ast.Load())
],
keywords=[
ast.keyword(arg="rtol", value=ast.Constant(value=1e-4)),
ast.keyword(arg="atol", value=ast.Constant(value=1e-4)),
ast.keyword(arg="equal_nan", value=ast.Constant(value=True))
]
),
ast.Constant(value=True)
)
comparisons.append(comparison)

# Replace statements
self.cur_inline_test.previous_stmts = new_statements
self.cur_inline_test.check_stmts = comparisons


def parse_group(self, node):
if isinstance(node, ast.Call) and isinstance(node.func, ast.Name) and node.func.id == self.group_str:
# node type is ast.Call, node.func type is ast.Name
Expand Down Expand Up @@ -994,9 +1244,10 @@ def parse_inline_test(self, node):
if call.func.attr == self.given_str:
self.parse_given(call)
inline_test_call_index += 1
else:
break

elif call.func.attr == self.diff_given_str:
self.parse_diff_given(call)
inline_test_call_index += 1

for import_stmt in import_calls:
self.cur_inline_test.import_stmts.append(import_stmt)
for import_stmt in import_from_calls:
Expand Down Expand Up @@ -1027,9 +1278,15 @@ def parse_inline_test(self, node):
self.parse_check_not_same(call)
elif call.func.attr == self.fail_str:
self.parse_fail(call)
elif call.func.attr == self.diff_test_str:
self.parse_diff_test(call)
elif call.func.attr == self.given_str:
raise MalformedException(
f"inline test: given() must be called before check_eq()/check_true()/check_false()"
f"inline test: given() must be called before check_eq()/check_true()/check_false()/diff_test()"
)
elif call.func.attr == self.diff_given_str:
raise MalformedException(
f"inline test: diff_given() must be called before check_eq()/check_true()/check_false()/diff_test()"
)
else:
raise MalformedException(f"inline test: invalid function call {self.node_to_source_code(call.func)}")
Expand Down
Loading
Loading