Skip to content

Commit 66b6974

Browse files
committed
fixed pandas backports tests
1 parent 25128e3 commit 66b6974

File tree

1 file changed

+26
-4
lines changed

1 file changed

+26
-4
lines changed

tests/unit/test_pandas_backports.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,25 +12,46 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import builtins
1516
import unittest.mock as mock
1617

1718
import db_dtypes.pandas_backports as pandas_backports
1819

20+
REAL_IMPORT = builtins.__import__
21+
22+
23+
def _import_side_effect(module_name, result_return=None, result_raise=None):
24+
"""
25+
Builds a side-effect for mocking the import function.
26+
If the imported package matches `name`, it will return or raise based on
27+
arguments. Otherwise, it will default to regular import behaviour
28+
"""
29+
def _impl(name, *args, **kwargs):
30+
if name == module_name:
31+
if result_raise:
32+
raise result_raise
33+
else:
34+
return result_return
35+
else:
36+
return REAL_IMPORT(name, *args, **kwargs)
37+
return _impl
38+
1939

2040
@mock.patch("builtins.__import__")
2141
def test_import_default_module_found(mock_import):
2242
mock_module = mock.MagicMock()
23-
mock_module.OpsMixin = "OpsMixin_from_module" # Simulate successful import
24-
mock_import.return_value = mock_module
43+
mock_module.OpsMixin = "OpsMixin_from_module"
44+
45+
mock_import.side_effect = _import_side_effect("module_name", mock_module)
2546

2647
default_class = type("OpsMixin", (), {}) # Dummy class
2748
result = pandas_backports.import_default("module_name", default=default_class)
2849
assert result == "OpsMixin_from_module"
2950

3051

3152
@mock.patch("builtins.__import__")
32-
def test_import_default_module_not_found(mock_import):
33-
mock_import.side_effect = ModuleNotFoundError
53+
def test_import_default_module_not_foundX(mock_import):
54+
mock_import.side_effect = _import_side_effect("module_name", result_raise=ModuleNotFoundError)
3455

3556
default_class = type("OpsMixin", (), {}) # Dummy class
3657
result = pandas_backports.import_default("module_name", default=default_class)
@@ -48,6 +69,7 @@ def test_import_default_force_true(mock_import):
4869
result = pandas_backports.import_default(
4970
"any_module_name", force=True, default=default_class
5071
)
72+
assert mock_import.call_count == 0
5173

5274
# Assert that the returned value is the default class itself
5375
assert result is default_class

0 commit comments

Comments
 (0)