diff --git a/demo/example.py b/demo/example.py index 63f1894..c43a024 100644 --- a/demo/example.py +++ b/demo/example.py @@ -17,7 +17,7 @@ def get_assignment_map_from_checkpoint(tvars, init_checkpoint): # inline test itest().given(name, "a:0").check_eq(m.group(1), "a") # a failing inline test - # itest().given(name, "a:0").check_eq(m.group(1), "aaa") + # itest().given(name, "a:0").check_eq(m.group(1), "aaa") if m is not None: name = m.group(1) name_to_variable[name] = var diff --git a/meeting-notes/10-31 Meeting Notes.txt b/meeting-notes/10-31 Meeting Notes.txt new file mode 100644 index 0000000..97140a4 --- /dev/null +++ b/meeting-notes/10-31 Meeting Notes.txt @@ -0,0 +1,27 @@ +- Are tests run directly after parsing, or are they amended to a file first? +-- Ad hoc file + +- Need more information on the control flow of the module? +-- Will ask +-- Entry point is VisitExpr +--- Traverses ast tree, if current node is expr, it triggers the function +--- If the expression is an inline test class, it starts with itest() and will trigger past inline test + +- Where does module determine that a line in the module is an inline test? +-- Will collect all calls and will keep collect all relevant st nodes in file; at some point, will try to pass nodes into inline tests + +- What is the root node of the ast in collect_inline_test_calls()? +-- Inline test expression +-- Inline tests are generated separately with each individual itest() call + +# Ideas +- Find should return list of tests AND imports +- Add import_list to ExtractInlineTest class; would search for imports in the module at the same time as inline tests +- Furthermore, collect import calls while collecting ineline_test_calls in "parse_inline_test" +- From research, found that you could get list of imports loaded by module by getting intersection between system modules and module imports, though that would be redundant since we can just access the source code directly +- Can define strings for "import", "from", and "as"; would like to know if ast's group code by line + + +Import statements may be in classes or functions + +Add references to imported libraries to ensure they are imported properly \ No newline at end of file diff --git a/meeting-notes/11-14 Meeting Notes.txt b/meeting-notes/11-14 Meeting Notes.txt new file mode 100644 index 0000000..74c122d --- /dev/null +++ b/meeting-notes/11-14 Meeting Notes.txt @@ -0,0 +1,16 @@ +Split progress so far, and push each as a separate request: +- Constructor +- Imports +- Diff Given + +-Start another fork on the original pytest-inline repo +-Put constructor changes onto new fork +-Push request +-Repeat for other two components + +Diff Given: +- Multiple scenes to worry about +- May need to up the best version, then develop off that + +Clean up as separate requests, push them, then add more concrete tests for more inline tests +- More tests from constructor to inline \ No newline at end of file diff --git a/meeting-notes/11-7 Meeting Notes.txt b/meeting-notes/11-7 Meeting Notes.txt new file mode 100644 index 0000000..9231a28 --- /dev/null +++ b/meeting-notes/11-7 Meeting Notes.txt @@ -0,0 +1,29 @@ +Cases where libraries are not installed: +- Throw bad error message +- Should be the developer's duty to set up their environment properly with the libraries they want to use +- Add to pytest configuration file + +Devices would eventually be moved out of constructor +For differential tests, we would introduce diff_input function + +diff_given(): +- Parameters: +- Specific +-- Passes in differential test inputs +-- Want to eliminate the devices constructor; devices would only support specific inputs +-- Supports any differential input variable to be given +--- Example: testing on devices, written as diff_given(devices, ["cuda", "cpu"]) +--- Reference how given is written right now +-- Would be from the inline test call itself +-- Devices is only accepted first argument in diff_given +--- Should be equivalent to what it was +--- Can try to extend for any first argument +-- Can be used together with given() +--- Some variables shouldn't be given multiple values; not changing the values of multiple inputs +--- Use given() to specific input value for some variables +--- Use diff_given() for inputs that need to be varied +-- Refer to project proposal for format +-- Order can be varied + +Fork original repo, then copy changes over + diff --git a/pyproject.toml b/pyproject.toml index 95f8b65..c59ccc3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -88,4 +88,4 @@ select = [ ] [tool.ruff.lint.isort] -known-first-party = ["inline"] +known-first-party = ["inline"] \ No newline at end of file diff --git a/src/inline/inline.py b/src/inline/inline.py index 5766542..17acf7b 100644 --- a/src/inline/inline.py +++ b/src/inline/inline.py @@ -10,6 +10,7 @@ def __init__( tag: List = [], disabled: bool = False, timeout: float = -1.0, + devices: List = None, ): """ Initialize Inline object with test name / parametrized flag @@ -20,6 +21,8 @@ def __init__( :param tag: tags to group tests :param disabled: whether the test is disabled :param timeout: seconds to timeout the test, must be a float + :param devices: list of devices to run differential testing on (e.g., ["cpu", "cuda", "mps"]) + if None, differential testing is disabled """ def given(self, variable, value): @@ -32,6 +35,16 @@ def given(self, variable, value): """ return self + def diff_given(self, variable, value): + """ + Set value to a variable for differential testing. + + :param variable: a variable name + :param value: a value that will be assigned to the variable + :returns: Inline object + """ + return self + def check_eq(self, actual_value, expected_value): """ Assert whether two values equal @@ -42,6 +55,19 @@ def check_eq(self, actual_value, expected_value): :raises: AssertionError """ return self + + def diff_test(self, outputs): + """ + Assert whether outputs are consistent across different devices. + This method compares the outputs from different devices specified in the constructor. + + :param outputs: a dictionary mapping device names to their outputs, or a single output value + if a single value is provided, the test will run the computation on all devices + and compare against this reference value + :returns: Inline object + :raises: AssertionError if outputs differ across devices + """ + return self def check_neq(self, actual_value, expected_value): """ diff --git a/src/inline/plugin.py b/src/inline/plugin.py index 11c0774..df98084 100644 --- a/src/inline/plugin.py +++ b/src/inline/plugin.py @@ -159,6 +159,7 @@ def __init__(self): self.check_stmts = [] self.given_stmts = [] self.previous_stmts = [] + self.import_stmts = [] self.prev_stmt_type = PrevStmtType.StmtExpr # the line number of test statement self.lineno = 0 @@ -171,12 +172,25 @@ def __init__(self): self.tag = [] self.disabled = False self.timeout = -1.0 + self.devices = None self.globs = {} + def write_imports(self): + import_str = "" + for n in self.import_stmts: + import_str += ExtractInlineTest.node_to_source_code(n) + "\n" + return import_str + def to_test(self): + prefix = "\n" + + # for n in self.import_stmts: + # import_str += ExtractInlineTest.node_to_source_code(n) + "\n" + + if self.prev_stmt_type == PrevStmtType.CondExpr: if self.assume_stmts == []: - return "\n".join( + return prefix.join( [ExtractInlineTest.node_to_source_code(n) for n in self.given_stmts] + [ExtractInlineTest.node_to_source_code(n) for n in self.check_stmts] ) @@ -186,11 +200,11 @@ def to_test(self): ) assume_statement = self.assume_stmts[0] assume_node = self.build_assume_node(assume_statement, body_nodes) - return "\n".join(ExtractInlineTest.node_to_source_code(assume_node)) + return prefix.join(ExtractInlineTest.node_to_source_code(assume_node)) else: if self.assume_stmts is None or self.assume_stmts == []: - return "\n".join( + return prefix.join( [ExtractInlineTest.node_to_source_code(n) for n in self.given_stmts] + [ExtractInlineTest.node_to_source_code(n) for n in self.previous_stmts] + [ExtractInlineTest.node_to_source_code(n) for n in self.check_stmts] @@ -201,7 +215,7 @@ def to_test(self): ) assume_statement = self.assume_stmts[0] assume_node = self.build_assume_node(assume_statement, body_nodes) - return "\n".join([ExtractInlineTest.node_to_source_code(assume_node)]) + return prefix.join([ExtractInlineTest.node_to_source_code(assume_node)]) def build_assume_node(self, assumption_node, body_nodes): return ast.If(assumption_node, body_nodes, []) @@ -251,7 +265,7 @@ class TimeoutException(Exception): ## InlineTest Parser ###################################################################### class InlinetestParser: - def parse(self, obj, globs: None): + def parse(self, obj, globs: None): # obj = open(self.file_path, "r").read(): if isinstance(obj, ModuleType): tree = ast.parse(open(obj.__file__, "r").read()) @@ -286,6 +300,7 @@ class ExtractInlineTest(ast.NodeTransformer): check_not_same = "check_not_same" fail_str = "fail" given_str = "given" + diff_given_str = "diff_given" group_str = "Group" arg_test_name_str = "test_name" arg_parameterized_str = "parameterized" @@ -293,12 +308,20 @@ class ExtractInlineTest(ast.NodeTransformer): arg_tag_str = "tag" arg_disabled_str = "disabled" arg_timeout_str = "timeout" + arg_devices_str = "devices" + diff_test_str = "diff_test" assume = "assume" + + import_str = "import" + from_str = "from" + as_str = "as" + inline_module_imported = False def __init__(self): self.cur_inline_test = InlineTest() self.inline_test_list = [] + self.import_list = [] def is_inline_test_class(self, node): if isinstance(node, ast.Call): @@ -348,206 +371,230 @@ def find_previous_stmt(self, node): return prev_stmt_node return self.find_condition_stmt(prev_stmt_node) - def collect_inline_test_calls(self, node, inline_test_calls: List[ast.Call]): + def collect_inline_test_calls(self, node, inline_test_calls: List[ast.Call], import_calls: List[ast.Import], import_from_calls: List[ast.ImportFrom]): """ collect all function calls in the node """ if isinstance(node, ast.Attribute): - self.collect_inline_test_calls(node.value, inline_test_calls) + self.collect_inline_test_calls(node.value, inline_test_calls, import_calls, import_from_calls) elif isinstance(node, ast.Call): inline_test_calls.append(node) - self.collect_inline_test_calls(node.func, inline_test_calls) + self.collect_inline_test_calls(node.func, inline_test_calls, import_calls, import_from_calls) + elif isinstance(node, ast.Import): + import_calls.append(node) + self.collect_inline_test_calls(node.func, inline_test_calls, import_calls, import_from_calls) + elif isinstance(node, ast.ImportFrom): + import_from_calls.append(node) + self.collect_inline_test_calls(node.func, inline_test_calls, import_calls, import_from_calls) + + def collect_import_calls(self, node, import_calls: List[ast.Import], import_from_calls: List[ast.ImportFrom]): + """ + collect all import calls in the node (should be done first) + """ + + while not isinstance(node, ast.Module) and node.parent != None: + node = node.parent + + if not isinstance(node, ast.Module): + return + + for child in node.children: + if isinstance(child, ast.Import): + import_calls.append(child) + elif isinstance(child, ast.ImportFrom): + import_from_calls.append(child) def parse_constructor(self, node): """ Parse a constructor call. """ - NUM_OF_ARGUMENTS = 6 + + # Argument Order: + # 0) test_name (str) + # 1) parameterized (bool) + # 2) repeated (positive integer) + # 3) tag (str) + # 4) disabled (bool) + # 5) timeout (positive float) + # 6) devices (str array) + + + + keyword_idxs = { + self.arg_test_name_str : 0, + self.arg_parameterized_str : 1, + self.arg_repeated_str : 2, + self.arg_tag_str : 3, + self.arg_disabled_str : 4, + self.arg_timeout_str : 5, + self.arg_devices_str : 6 + } + + NUM_OF_ARGUMENTS = 7 if len(node.args) + len(node.keywords) <= NUM_OF_ARGUMENTS: # positional arguments - if sys.version_info >= (3, 8, 0): - for index, arg in enumerate(node.args): - # check if "test_name" is a string - if index == 0 and isinstance(arg, ast.Constant) and isinstance(arg.value, str): - # get the test name if exists - self.cur_inline_test.test_name = arg.value - # check if "parameterized" is a boolean - elif index == 1 and isinstance(arg, ast.Constant) and isinstance(arg.value, bool): - self.cur_inline_test.parameterized = arg.value - # check if "repeated" is a positive integer - elif index == 2 and isinstance(arg, ast.Constant) and isinstance(arg.value, int): + self.parse_constructor_args(node.args) + + #keyword arguments + keyword_args = [] + + #create list with 7 null values (for each position) + for i in range(0, NUM_OF_ARGUMENTS): + keyword_args.append(None) + + for keyword in node.keywords: + keyword_args[keyword_idxs[keyword.arg]] = keyword.value + self.parse_constructor_args(keyword_args) + + + if not self.cur_inline_test.test_name: + # by default, use lineno as test name + self.cur_inline_test.test_name = f"line{node.lineno}" + # set the line number + self.cur_inline_test.lineno = node.lineno + + def parse_constructor_args(self, args): + class ConstrArgs(enum.Enum): + TEST_NAME = 0 + PARAMETERIZED = 1 + REPEATED = 2 + TAG_STR = 3 + DISABLED = 4 + TIMEOUT = 5 + DEVICES = 6 + + property_names = { + ConstrArgs.TEST_NAME : "test_name", + ConstrArgs.PARAMETERIZED : "parameterized", + ConstrArgs.REPEATED : "repeated", + ConstrArgs.TAG_STR : "tag", + ConstrArgs.DISABLED : "disabled", + ConstrArgs.TIMEOUT : "timeout", + ConstrArgs.DEVICES : "devices" + } + + pre_38_val_names = { + ConstrArgs.TEST_NAME : "s", + ConstrArgs.PARAMETERIZED : "value", + ConstrArgs.REPEATED : "n", + ConstrArgs.TAG_STR : "s", + ConstrArgs.DISABLED : "value", + ConstrArgs.TIMEOUT : "n", + ConstrArgs.DEVICES : "" + } + + pre_38_expec_ast_arg_type = { + ConstrArgs.TEST_NAME : ast.Str, + ConstrArgs.PARAMETERIZED : ast.NameConstant, + ConstrArgs.REPEATED : ast.Num, + ConstrArgs.TAG_STR : ast.List, + ConstrArgs.DISABLED : ast.NameConstant, + ConstrArgs.TIMEOUT : ast.Num, + } + + expected_ast_arg_type = { + ConstrArgs.TEST_NAME : ast.Constant, + ConstrArgs.PARAMETERIZED : ast.Constant, + ConstrArgs.REPEATED : ast.Constant, + ConstrArgs.TAG_STR : ast.List, + ConstrArgs.DISABLED : ast.Constant, + ConstrArgs.TIMEOUT : ast.Constant + } + + expected_ast_val_args = { + ConstrArgs.TEST_NAME : [str], + ConstrArgs.PARAMETERIZED : [bool], + ConstrArgs.REPEATED : [int], + ConstrArgs.TAG_STR : [None], + ConstrArgs.DISABLED : [bool], + ConstrArgs.TIMEOUT : [float, int], + ConstrArgs.DEVICES : [str] + } + + NUM_OF_ARGUMENTS = 7 + + # Arguments organized by expected ast type, value type, and index in that order + for index, arg in enumerate(args): + # Skips over null arguments; needed for keywords + if arg == None: + continue + + # Devices are not referenced in versions before 3.8; all other arguments can be from any version + if index == ConstrArgs.DEVICES and isinstance(arg, ast.List): + devices = [] + for elt in arg.elts: + if not (isinstance(elt, ast.Constant) and isinstance(elt.value, str)): + raise MalformedException("devices can only be List of string") + if elt.value not in {"cpu", "cuda", "mps"}: + raise MalformedException(f"Invalid device: {elt.value}. Must be one of ['cpu', 'cuda', 'mps']") + devices.append(elt.value) + self.cur_inline_test.devices = devices + # Assumes version is past 3.8, no explicit references to ast.Constant before 3.8 + else: + corr_arg_type = False + corr_val_type = False + value_prop_name = "" + arg_idx = ConstrArgs(index) + + if sys.version_info >= (3, 8, 0) and isinstance(arg, expected_ast_arg_type[arg_idx]): + corr_arg_type = True + value_prop_name = "value" + elif sys.version_info < (3, 8, 0) and isinstance(arg, pre_38_expec_ast_arg_type[arg_idx]): + corr_arg_type = True + value_prop_name = pre_38_val_names[arg_idx] + + # Verifies value types; skipped for ast node types with no nested values + for arg_type in expected_ast_val_args[arg_idx]: + if arg_type == None: + corr_val_type = True + break + if isinstance(arg.value, arg_type): + corr_val_type = True + break + + if corr_val_type and corr_arg_type: + # Accounts for additional checks for REPEATED and TAG_STR arguments + if arg_idx == ConstrArgs.REPEATED: if arg.value <= 0: raise MalformedException(f"inline test: {self.arg_repeated_str} must be greater than 0") - self.cur_inline_test.repeated = arg.value - elif index == 3 and isinstance(arg.value, ast.List): - tags = [] - for elt in arg.value.elts: - if not (isinstance(elt, ast.Constant) and isinstance(elt.value, str)): - raise MalformedException(f"tag can only be List of string") - tags.append(elt.value) - self.cur_inline_test.tag = tags - elif index == 4 and isinstance(arg, ast.Constant) and isinstance(arg.value, bool): - self.cur_inline_test.disabled = arg.value - elif ( - index == 5 - and isinstance(arg, ast.Constant) - and (isinstance(arg.value, float) or isinstance(arg.value, int)) - ): - self.cur_inline_test.timeout = arg.value - else: - raise MalformedException( - f"inline test: {self.class_name_str}() accepts {NUM_OF_ARGUMENTS} arguments. 'test_name' must be a string constant, 'parameterized' must be a boolean constant, 'repeated' must be a positive integer, 'tag' must be a list of string, 'timeout' must be a positive float" - ) - # keyword arguments - for keyword in node.keywords: - # check if "test_name" is a string - if ( - keyword.arg == self.arg_test_name_str - and isinstance(keyword.value, ast.Constant) - and isinstance(keyword.value.value, str) - ): - self.cur_inline_test.test_name = keyword.value.value - # check if "parameterized" is a boolean - elif ( - keyword.arg == self.arg_parameterized_str - and isinstance(keyword.value, ast.Constant) - and isinstance(keyword.value.value, bool) - ): - self.cur_inline_test.parameterized = keyword.value.value - # check if "repeated" is a positive integer - elif ( - keyword.arg == self.arg_repeated_str - and isinstance(keyword.value, ast.Constant) - and isinstance(keyword.value.value, int) - ): - if keyword.value.value <= 0: - raise MalformedException(f"inline test: {self.arg_repeated_str} must be greater than 0") - self.cur_inline_test.repeated = keyword.value.value - # check if "tag" is a list of string - elif keyword.arg == self.arg_tag_str and isinstance(keyword.value, ast.List): + self.cur_inline_test.repeated = getattr(arg, value_prop_name) + elif arg_idx == ConstrArgs.TAG_STR: tags = [] - for elt in keyword.value.elts: + for elt in arg.elts: if not (isinstance(elt, ast.Constant) and isinstance(elt.value, str)): raise MalformedException(f"tag can only be List of string") - tags.append(elt.value) + tags.append(getattr(elt, value_prop_name)) self.cur_inline_test.tag = tags - # check if "disabled" is a boolean - elif ( - keyword.arg == self.arg_disabled_str - and isinstance(keyword.value, ast.Constant) - and isinstance(keyword.value.value, bool) - ): - self.cur_inline_test.disabled = keyword.value.value - # check if "timeout" is a positive float - elif ( - keyword.arg == self.arg_timeout_str - and isinstance(keyword.value, ast.Constant) - and (isinstance(keyword.value.value, float) or isinstance(keyword.value.value, int)) - ): - if keyword.value.value <= 0.0: - raise MalformedException(f"inline test: {self.arg_timeout_str} must be greater than 0") - self.cur_inline_test.timeout = keyword.value.value + # For non-special cases, set the attribute defined by the dictionary else: - raise MalformedException( - f"inline test: {self.class_name_str}() accepts {NUM_OF_ARGUMENTS} arguments. 'test_name' must be a string constant, 'parameterized' must be a boolean constant, 'repeated' must be a positive integer, 'tag' must be a list of string, 'timeout' must be a positive float" - ) - else: - for index, arg in enumerate(node.args): - # check if "test_name" is a string - if index == 0 and isinstance(arg, ast.Str) and isinstance(arg.s, str): - # get the test name if exists - self.cur_inline_test.test_name = arg.s - # check if "parameterized" is a boolean - elif index == 1 and isinstance(arg, ast.NameConstant) and isinstance(arg.value, bool): - self.cur_inline_test.parameterized = arg.value - # check if "repeated" is a positive integer - elif index == 2 and isinstance(arg, ast.Num) and isinstance(arg.n, int): - if arg.n <= 0.0: - raise MalformedException(f"inline test: {self.arg_repeated_str} must be greater than 0") - self.cur_inline_test.repeated = arg.n - # check if "tag" is a list of string - elif index == 3 and isinstance(arg.value, ast.List): - tags = [] - for elt in arg.value.elts: - if not (isinstance(elt, ast.Str) and isinstance(elt.s, str)): - raise MalformedException(f"tag can only be List of string") - tags.append(elt.s) - self.cur_inline_test.tag = tags - # check if "disabled" is a boolean - elif index == 4 and isinstance(arg, ast.NameConstant) and isinstance(arg.value, bool): - self.cur_inline_test.disabled = arg.value - # check if "timeout" is a positive int - elif ( - index == 5 and isinstance(arg, ast.Num) and (isinstance(arg.n, float) or isinstance(arg.n, int)) - ): - if arg.n <= 0.0: - raise MalformedException(f"inline test: {self.arg_timeout_str} must be greater than 0") - self.cur_inline_test.timeout = arg.n - else: - raise MalformedException( - f"inline test: {self.class_name_str}() accepts {NUM_OF_ARGUMENTS} arguments. 'test_name' must be a string constant, 'parameterized' must be a boolean constant, 'repeated' must be a positive intege, 'tag' must be a list of string, 'timeout' must be a positive float" - ) - # keyword arguments - for keyword in node.keywords: - # check if "test_name" is a string - if ( - keyword.arg == self.arg_test_name_str - and isinstance(keyword.value, ast.Str) - and isinstance(keyword.value.s, str) - ): - self.cur_inline_test.test_name = keyword.value.s - # check if "parameterized" is a boolean - elif ( - keyword.arg == self.arg_parameterized_str - and isinstance(keyword.value, ast.NameConstant) - and isinstance(keyword.value.value, bool) - ): - self.cur_inline_test.parameterized = keyword.value.value - # check if "repeated" is a positive integer - elif ( - keyword.arg == self.arg_repeated_str - and isinstance(keyword.value, ast.Num) - and isinstance(keyword.value.n, int) - ): - if keyword.value.n <= 0.0: - raise MalformedException(f"inline test: {self.arg_repeated_str} must be greater than 0") - self.cur_inline_test.repeated = keyword.value.n - # check if "tag" is a list of string - elif keyword.arg == self.arg_tag_str and isinstance(keyword.value, ast.List): - tags = [] - for elt in keyword.value.elts: - if not (isinstance(elt, ast.Str) and isinstance(elt.s, str)): - raise MalformedException(f"tag can only be List of string") - tags.append(elt.s) - self.cur_inline_test.tag = tags - # check if "disabled" is a boolean - elif ( - keyword.arg == self.arg_disabled_str - and isinstance(keyword.value, ast.NameConstant) - and isinstance(keyword.value.value, bool) - ): - self.cur_inline_test.disabled = keyword.value.value - # check if "timeout" is a positive float - elif ( - keyword.arg == self.arg_timeout_str - and isinstance(keyword.value, ast.Num) - and (isinstance(keyword.value.n, float) or isinstance(keyword.value.n, int)) - ): - if keyword.value.n <= 0.0: - raise MalformedException(f"inline test: {self.arg_timeout_str} must be greater than 0") - self.cur_inline_test.timeout = keyword.value.n - else: - raise MalformedException( - f"inline test: {self.class_name_str}() accepts {NUM_OF_ARGUMENTS} arguments. 'test_name' must be a string constant, 'parameterized' must be a boolean constant, 'repeated' must be a positive integer, 'tag' must be a list of string, 'timeout' must be a positive float" - ) - else: - raise MalformedException(f"inline test: invalid {self.class_name_str}(), expected at most 3 args") - - if not self.cur_inline_test.test_name: - # by default, use lineno as test name - self.cur_inline_test.test_name = f"line{node.lineno}" - # set the line number - self.cur_inline_test.lineno = node.lineno + setattr(self.cur_inline_test, + property_names[arg_idx], + getattr(arg, value_prop_name)) + + + + # match arg_idx: + # case ConstrArgs.REPEATED: + # if arg.value <= 0: + # raise MalformedException(f"inline test: {self.arg_repeated_str} must be greater than 0") + # self.cur_inline_test.repeated = getattr(arg, value_prop_name) + # case ConstrArgs.TAG_STR: + # tags = [] + # for elt in arg.elts: + # if not (isinstance(elt, ast.Constant) and isinstance(elt.value, str)): + # raise MalformedException(f"tag can only be List of string") + # tags.append(getattr(elt, value_prop_name)) + # self.cur_inline_test.tag = tags + # # For non-special cases, set the attribute defined by the dictionary + # case _: + # setattr(self.cur_inline_test, + # property_names[arg_idx], + # getattr(arg, value_prop_name)) + else: + raise MalformedException( + f"inline test: {self.class_name_str}() accepts {NUM_OF_ARGUMENTS} arguments. 'test_name' must be a string constant, 'parameterized' must be a boolean constant, 'repeated' must be a positive integer, 'tag' must be a list of string, 'timeout' must be a positive float" + ) + #raise MalformedException("Argument " + str(index) + " incorrectly formatted. Argument should be a " + ConstrArgs.expected_ast_val_args[index].type()) def parameterized_inline_tests_init(self, node: ast.List): if not self.cur_inline_test.parameterized_inline_tests: @@ -568,6 +615,23 @@ def parse_given(self, node): else: raise MalformedException("inline test: invalid given(), expected 2 args") + def parse_diff_given(self, node): + PROPERTY = 0 + VALUES = 1 + + if len(node.args) == 2: + if self.cur_inline_test.parameterized: + raise MalformedException("inline test: diff_given() does not currently support parameterized inline tests.") + else: + devices = [] + for elt in node.args[VALUES].elts: + if elt.value not in {"cpu", "cuda", "mps"}: + raise MalformedException(f"Invalid device: {elt.value}. Must be one of ['cpu', 'cuda', 'mps']") + devices.append(elt.value) + setattr(self.cur_inline_test, node.args[PROPERTY].id, devices) + else: + raise MalformedException("inline test: invalid diff_given(), expected 2 args") + def parse_assume(self, node): if len(node.args) == 1: if self.cur_inline_test.parameterized: @@ -885,6 +949,240 @@ def parse_check_not_same(self, node): self.cur_inline_test.check_stmts.append(assert_node) else: raise MalformedException("inline test: invalid check_not_same(), expected 2 args") + + def parse_diff_test(self, node): + if not self.cur_inline_test.devices: + raise MalformedException("diff_test can only be used with the 'devices' parameter.") + + if len(node.args) != 1: + raise MalformedException("diff_test() requires exactly 1 argument.") + + output_node = self.parse_group(node.args[0]) + + # Get the original operation + original_op = None + for stmt in self.cur_inline_test.previous_stmts: + if isinstance(stmt, ast.Assign) and stmt.targets[0].id == output_node.id: + original_op = stmt.value + break + + if not original_op: + raise MalformedException("Could not find original operation for diff_test") + + # Create our new statements + new_statements = [] + device_outputs = [] + + # Import necessary modules for seed setting - Always add these + # Import random module + import_random = ast.ImportFrom( + module='random', + names=[ast.alias(name='seed', asname=None)], + level=0 + ) + new_statements.append(import_random) + + # Import numpy.random + import_np = ast.ImportFrom( + module='numpy', + names=[ast.alias(name='random', asname='np_random')], + level=0 + ) + new_statements.append(import_np) + + # Create seed function - Always add this + seed_func_def = ast.FunctionDef( + name='set_random_seed', + args=ast.arguments( + posonlyargs=[], + args=[ast.arg(arg='seed_value', annotation=None)], + kwonlyargs=[], + kw_defaults=[], + defaults=[] + ), + body=[ + ast.Expr( + value=ast.Call( + func=ast.Name(id='seed', ctx=ast.Load()), + args=[ast.Name(id='seed_value', ctx=ast.Load())], + keywords=[] + ) + ), + ast.Expr( + value=ast.Call( + func=ast.Attribute( + value=ast.Name(id='torch', ctx=ast.Load()), + attr='manual_seed' + ), + args=[ast.Name(id='seed_value', ctx=ast.Load())], + keywords=[] + ) + ), + ast.Expr( + value=ast.Call( + func=ast.Attribute( + value=ast.Name(id='np_random', ctx=ast.Load()), + attr='seed' + ), + args=[ast.Name(id='seed_value', ctx=ast.Load())], + keywords=[] + ) + ) + ], + decorator_list=[], + returns=None + ) + new_statements.append(seed_func_def) + + # Process input tensors + for given_stmt in self.cur_inline_test.given_stmts: + input_var = given_stmt.targets[0].id + ref_var = f"{input_var}_ref" + + # Always clone inputs for in-place operations + new_statements.append( + ast.Assign( + targets=[ast.Name(id=ref_var, ctx=ast.Store())], + value=ast.Call( + func=ast.Attribute( + value=given_stmt.value, + attr="clone" + ), + args=[], + keywords=[] + ) + ) + ) + + # Create device-specific versions + for device in self.cur_inline_test.devices: + device_var = f"{input_var}_{device}" + + new_statements.append( + ast.Assign( + targets=[ast.Name(id=device_var, ctx=ast.Store())], + value=ast.Call( + func=ast.Attribute( + value=ast.Name(id=ref_var, ctx=ast.Load()), + attr="to" + ), + args=[ast.Constant(value=device)], + keywords=[] + ) + ) + ) + + # Create device-specific operations + device_input_map = {device: {} for device in self.cur_inline_test.devices} + for device in self.cur_inline_test.devices: + for given_stmt in self.cur_inline_test.given_stmts: + input_var = given_stmt.targets[0].id + device_input_map[device][input_var] = f"{input_var}_{device}" + + # Always set seed before each device operation - no condition check + new_statements.append( + ast.Expr( + value=ast.Call( + func=ast.Name(id='set_random_seed', ctx=ast.Load()), + args=[ast.Constant(value=42)], # Use constant seed 42 + keywords=[] + ) + ) + ) + + device_op = copy.deepcopy(original_op) + + # Replace input references + class ReplaceInputs(ast.NodeTransformer): + def visit_Name(self, node): + if node.id in device_input_map[device]: + return ast.Name(id=device_input_map[device][node.id], ctx=node.ctx) + return node + + device_op = ReplaceInputs().visit(device_op) + device_output = f"output_{device}" + + new_statements.append( + ast.Assign( + targets=[ast.Name(id=device_output, ctx=ast.Store())], + value=device_op + ) + ) + device_outputs.append(device_output) + + # Standard comparison method for all operations - no condition check + comparisons = [] + for i in range(len(device_outputs) - 1): + dev1 = device_outputs[i] + dev2 = device_outputs[i + 1] + + dev1_cpu = f"{dev1}_cpu" + dev2_cpu = f"{dev2}_cpu" + + # Move outputs back to CPU for comparison + new_statements.append( + ast.Assign( + targets=[ast.Name(id=dev1_cpu, ctx=ast.Store())], + value=ast.Call( + func=ast.Attribute( + value=ast.Name(id=dev1, ctx=ast.Load()), + attr="to" + ), + args=[ast.Constant(value="cpu")], + keywords=[] + ) + ) + ) + + new_statements.append( + ast.Assign( + targets=[ast.Name(id=dev2_cpu, ctx=ast.Store())], + value=ast.Call( + func=ast.Attribute( + value=ast.Name(id=dev2, ctx=ast.Load()), + attr="to" + ), + args=[ast.Constant(value="cpu")], + keywords=[] + ) + ) + ) + + # Standard allclose comparison + comparison = self.build_assert_eq( + ast.Call( + func=ast.Attribute( + value=ast.Name(id=dev1_cpu, ctx=ast.Load()), + attr="allclose" + ), + args=[ + ast.Name(id=dev2_cpu, ctx=ast.Load()) + ], + keywords=[ + ast.keyword(arg="rtol", value=ast.Constant(value=1e-4)), + ast.keyword(arg="atol", value=ast.Constant(value=1e-4)), + ast.keyword(arg="equal_nan", value=ast.Constant(value=True)) + ] + ), + ast.Constant(value=True) + ) + comparisons.append(comparison) + + # Replace statements + self.cur_inline_test.previous_stmts = new_statements + self.cur_inline_test.check_stmts = comparisons + + def parse_import(self, node): + # TODO: Differentiate between import, from import, and import alias + import_node = ast.Import( + names=[ + ast.alias(name=node) + ] + ) + return import_node + + def parse_import_from(self, node): + pass def build_fail(self): equal_node = ast.Compare( @@ -925,6 +1223,7 @@ def parse_group(self, node): return stmt else: return node + def parse_parameterized_test(self): for index, parameterized_test in enumerate(self.cur_inline_test.parameterized_inline_tests): @@ -934,8 +1233,13 @@ def parse_parameterized_test(self): parameterized_test.test_name = self.cur_inline_test.test_name + "_" + str(index) def parse_inline_test(self, node): - inline_test_calls = [] - self.collect_inline_test_calls(node, inline_test_calls) + import_calls = [] + import_from_calls = [] + inline_test_calls = [] + + self.collect_inline_test_calls(node, inline_test_calls, import_calls, import_from_calls) + self.collect_import_calls(node, import_calls, import_from_calls) + inline_test_calls.reverse() if len(inline_test_calls) <= 1: @@ -956,14 +1260,33 @@ def parse_inline_test(self, node): self.parse_assume(call) inline_test_call_index += 1 - # "given(a, 1)" for call in inline_test_calls[inline_test_call_index:]: - if isinstance(call.func, ast.Attribute) and call.func.attr == self.given_str: - self.parse_given(call) - inline_test_call_index += 1 + if isinstance(call.func, ast.Attribute): + if call.func.attr == self.given_str: + self.parse_given(call) + inline_test_call_index += 1 + elif call.func.attr == self.diff_given_str: + self.parse_diff_given(call) + inline_test_call_index += 1 + + # match call.func.attr: + # # "given(a, 1)" + # case self.given_str: + # self.parse_given(call) + # inline_test_call_index += 1 + # # "diff_given(devices, ["cpu", "cuda"])" + # case self.diff_given_str: + # self.parse_diff_given(call) + # inline_test_call_index += 1 else: break + for import_stmt in import_calls: + self.cur_inline_test.import_stmts.append(import_stmt) + for import_stmt in import_from_calls: + self.cur_inline_test.import_stmts.append(import_stmt) + + # "check_eq" or "check_true" or "check_false" or "check_neq" for call in inline_test_calls[inline_test_call_index:]: # "check_eq(a, 1)" @@ -986,11 +1309,13 @@ def parse_inline_test(self, node): self.parse_check_same(call) elif call.func.attr == self.check_not_same: self.parse_check_not_same(call) + elif call.func.attr == self.diff_test_str: + self.parse_diff_test(call) elif call.func.attr == self.fail_str: self.parse_fail(call) elif call.func.attr == self.given_str: raise MalformedException( - f"inline test: given() must be called before check_eq()/check_true()/check_false()" + f"inline test: given() must be called before check_eq()/check_true()/check_false()/diff_test()" ) else: raise MalformedException(f"inline test: invalid function call {self.node_to_source_code(call.func)}") @@ -1024,6 +1349,7 @@ def node_to_source_code(node): ## InlineTest Finder ###################################################################### class InlineTestFinder: + # Finder should NOT store any global variables def __init__(self, parser=InlinetestParser(), recurse=True, exclude_empty=True): self._parser = parser self._recurse = recurse @@ -1068,7 +1394,14 @@ def _is_routine(self, obj): pass return inspect.isroutine(maybe_routine) - def find(self, obj, module=None, globs=None, extraglobs=None): + # def find_imports(self, obj, module=None): + # if module is False: + # module = None + # elif module is None: + # module = inspect.getmodule(obj) + + + def find(self, obj, module=None, globs=None, extraglobs=None, imports=None): # Find the module that contains the given object (if obj is # a module, then module=obj.). if module is False: @@ -1089,15 +1422,23 @@ def find(self, obj, module=None, globs=None, extraglobs=None): if "__name__" not in globs: globs["__name__"] = "__main__" # provide a default module name + # Find intersection between loaded modules and module imports + # if imports is None: + # imports = set(sys.modules) & set(globs) + # else: + # imports = imports.copy() + # Recursively explore `obj`, extracting InlineTests. tests = [] - self._find(tests, obj, module, globs, {}) + self._find(tests, obj, module, globs, imports, {}) return tests - def _find(self, tests, obj, module, globs, seen): + def _find(self, tests, obj, module, globs, imports, seen): if id(obj) in seen: return seen[id(obj)] = 1 + + # Find a test for this object, and add it to the list of tests. test = self._parser.parse(obj, globs) if test is not None: @@ -1109,7 +1450,7 @@ def _find(self, tests, obj, module, globs, seen): # Recurse to functions & classes. if (self._is_routine(val) or inspect.isclass(val)) and self._from_module(module, val): - self._find(tests, val, module, globs, seen) + self._find(tests, val, module, globs, imports, seen) # Look for tests in a class's contained objects. if inspect.isclass(obj) and self._recurse: @@ -1123,7 +1464,7 @@ def _find(self, tests, obj, module, globs, seen): module, val ): valname = "%s" % (valname) - self._find(tests, val, module, globs, seen) + self._find(tests, val, module, globs, imports, seen) ###################################################################### @@ -1131,7 +1472,10 @@ def _find(self, tests, obj, module, globs, seen): ###################################################################### class InlineTestRunner: def run(self, test: InlineTest, out: List) -> None: - tree = ast.parse(test.to_test()) + test_str = test.write_imports() + test_str += test.to_test() + print(test_str) + tree = ast.parse(test_str) codeobj = compile(tree, filename="", mode="exec") start_time = time.time() if test.timeout > 0: @@ -1268,6 +1612,10 @@ def collect(self) -> Iterable[InlinetestItem]: group_tags = self.config.getoption("inlinetest_group", default=None) order_tags = self.config.getoption("inlinetest_order", default=None) + # TODO: import all modules through the finder first before extracting inline tests + # - Create ast for all imports + # - If a function references an import, then include the imported library reference in the ast + for test_list in finder.find(module): # reorder the list if there are tests to be ordered ordered_list = InlinetestModule.order_tests(test_list, order_tags) diff --git a/tests/test_plugin.py b/tests/test_plugin.py index 40c3096..8fe88a6 100644 --- a/tests/test_plugin.py +++ b/tests/test_plugin.py @@ -2,9 +2,65 @@ from _pytest.pytester import Pytester import pytest +# For testing in Spyder only +if __name__ == "__main__": + pytest.main(['-v', '-s']) + # pytest -p pytester class TestInlinetests: + def test_inline_diff_given(self, pytester: Pytester): + checkfile = pytester.makepyfile( + """ + from inline import itest + + def m(a): + a = a + 1 + itest().diff_given(devices, ["cpu", "cuda"]).given(a, 1).check_eq(a, 2) + """ + ) + for x in (pytester.path, checkfile): + items, reprec = pytester.inline_genitems(x) + assert len(items) == 1 + res = pytester.runpytest() + assert res.ret != 1 + + + def test_inline_detects_imports(self, pytester: Pytester): + checkfile = pytester.makepyfile( + """ + from inline import itest + import datetime + + def m(a): + b = a + datetime.timedelta(days=365) + itest().given(a, datetime.timedelta(days=1)).check_eq(b, datetime.timedelta(days=366)) + """ + ) + for x in (pytester.path, checkfile): + items, reprec = pytester.inline_genitems(x) + assert len(items) == 1 + res = pytester.runpytest() + assert res.ret != 1 + + # def test_inline_detects_from_imports(self, pytester: Pytester): + # checkfile = pytester.makepyfile( + # """ + # from inline import itest + # import numpy as np + # from scipy import stats as st + + # def m(n, p): + # b = st.binom(n, p) + # itest().given(n, 100).given(p, 0.5).check_eq(b.mean(), n * p) + # """ + # ) + # for x in (pytester.path, checkfile): + # items, reprec = pytester.inline_genitems(x) + # assert len(items) == 1 + # res = pytester.runpytest() + # assert res.ret == 0 + def test_inline_parser(self, pytester: Pytester): checkfile = pytester.makepyfile( """ @@ -31,6 +87,7 @@ def m(a): items, reprec = pytester.inline_genitems(x) assert len(items) == 0 + def test_inline_malformed_given(self, pytester: Pytester): checkfile = pytester.makepyfile( """ @@ -118,21 +175,6 @@ def m(a): res = pytester.runpytest() assert res.ret == 0 - def test_check_eq_parameterized_tests(self, pytester: Pytester): - checkfile = pytester.makepyfile( - """ - from inline import itest - def m(a): - a = a + 1 - itest(parameterized=True).given(a, [2, 3]).check_eq(a, [3, 4]) - """ - ) - for x in (pytester.path, checkfile): - items, reprec = pytester.inline_genitems(x) - assert len(items) == 2 - res = pytester.runpytest() - assert res.ret == 0 - def test_malformed_check_eq_parameterized_tests(self, pytester: Pytester): checkfile = pytester.makepyfile( """