diff --git a/src/qcodes/utils/function_helpers.py b/src/qcodes/utils/function_helpers.py index 1fca19de39ed..d754646a7ef6 100644 --- a/src/qcodes/utils/function_helpers.py +++ b/src/qcodes/utils/function_helpers.py @@ -1,5 +1,5 @@ from asyncio import iscoroutinefunction -from inspect import signature +from inspect import CO_VARARGS, signature def is_function(f: object, arg_count: int, coroutine: bool = False) -> bool: @@ -29,6 +29,26 @@ def is_function(f: object, arg_count: int, coroutine: bool = False) -> bool: # otherwise the user should make an explicit function. return arg_count == 1 + if func_code := getattr(f, "__code__", None): + # handle objects like functools.partial(f, ...) + func_defaults = getattr(f, "__defaults__", None) + number_of_defaults = len(func_defaults) if func_defaults is not None else 0 + + if getattr(f, "__self__", None) is not None: + # bound method + min_positional = func_code.co_argcount - 1 - number_of_defaults + max_positional = func_code.co_argcount - 1 + else: + min_positional = func_code.co_argcount - number_of_defaults + max_positional = func_code.co_argcount + + if func_code.co_flags & CO_VARARGS: + # we have *args + max_positional = 10e10 + + ev = min_positional <= arg_count <= max_positional + return ev + try: sig = signature(f) except ValueError: diff --git a/tests/utils/test_isfunction.py b/tests/utils/test_isfunction.py index fc2a3dc368be..8da545c8db79 100644 --- a/tests/utils/test_isfunction.py +++ b/tests/utils/test_isfunction.py @@ -1,3 +1,4 @@ +from functools import partial from typing import NoReturn import pytest @@ -36,6 +37,32 @@ def f2(a: object, b: object) -> NoReturn: is_function(f0, -1) +def test_function_partial() -> None: + def f0(one_arg: int) -> int: + return one_arg + + f = partial(f0, 1) + assert is_function(f, 0) + assert not is_function(f, 1) + + +def test_function_varargs() -> None: + def f(*args) -> None: + return None + + assert is_function(f, 0) + assert is_function(f, 1) + assert is_function(f, 100) + + def g(a, b=1, *args) -> None: + return None + + assert not is_function(g, 0) + assert is_function(g, 1) + assert is_function(g, 2) + assert is_function(g, 100) + + class AClass: def method_a(self) -> NoReturn: raise RuntimeError("function should not get called")