Skip to content

Commit b77797c

Browse files
authored
ENH: Index[complex] (#45256)
1 parent e681fcd commit b77797c

File tree

16 files changed

+99
-36
lines changed

16 files changed

+99
-36
lines changed

pandas/_libs/index.pyi

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ class IndexEngine:
2929

3030
class Float64Engine(IndexEngine): ...
3131
class Float32Engine(IndexEngine): ...
32+
class Complex128Engine(IndexEngine): ...
33+
class Complex64Engine(IndexEngine): ...
3234
class Int64Engine(IndexEngine): ...
3335
class Int32Engine(IndexEngine): ...
3436
class Int16Engine(IndexEngine): ...

pandas/_libs/index_class_helper.pxi.in

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ dtypes = [('Float64', 'float64'),
2121
('UInt32', 'uint32'),
2222
('UInt16', 'uint16'),
2323
('UInt8', 'uint8'),
24+
('Complex64', 'complex64'),
25+
('Complex128', 'complex128'),
2426
]
2527
}}
2628

@@ -33,18 +35,25 @@ cdef class {{name}}Engine(IndexEngine):
3335
return _hash.{{name}}HashTable(n)
3436

3537
cdef _check_type(self, object val):
36-
{{if name not in {'Float64', 'Float32'} }}
38+
{{if name not in {'Float64', 'Float32', 'Complex64', 'Complex128'} }}
3739
if not util.is_integer_object(val):
3840
raise KeyError(val)
3941
{{if name.startswith("U")}}
4042
if val < 0:
4143
# cannot have negative values with unsigned int dtype
4244
raise KeyError(val)
4345
{{endif}}
44-
{{else}}
46+
{{elif name not in {'Complex64', 'Complex128'} }}
4547
if not util.is_integer_object(val) and not util.is_float_object(val):
4648
# in particular catch bool and avoid casting True -> 1.0
4749
raise KeyError(val)
50+
{{else}}
51+
if (not util.is_integer_object(val)
52+
and not util.is_float_object(val)
53+
and not util.is_complex_object(val)
54+
):
55+
# in particular catch bool and avoid casting True -> 1+0j
56+
raise KeyError(val)
4857
{{endif}}
4958

5059

pandas/conftest.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -539,6 +539,8 @@ def _create_mi_with_dt64tz_level():
539539
"uint": tm.makeUIntIndex(100),
540540
"range": tm.makeRangeIndex(100),
541541
"float": tm.makeFloatIndex(100),
542+
"complex64": tm.makeFloatIndex(100).astype("complex64"),
543+
"complex128": tm.makeFloatIndex(100).astype("complex128"),
542544
"num_int64": tm.makeNumericIndex(100, dtype="int64"),
543545
"num_int32": tm.makeNumericIndex(100, dtype="int32"),
544546
"num_int16": tm.makeNumericIndex(100, dtype="int16"),

pandas/core/indexes/base.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -487,6 +487,8 @@ def __new__(
487487
if data.dtype.kind in ["i", "u", "f"]:
488488
# maybe coerce to a sub-class
489489
arr = data
490+
elif data.dtype.kind in ["c"]:
491+
arr = np.asarray(data)
490492
else:
491493
arr = com.asarray_tuplesafe(data, dtype=_dtype_obj)
492494

@@ -614,7 +616,9 @@ def _dtype_to_subclass(cls, dtype: DtypeObj):
614616
# NB: assuming away MultiIndex
615617
return Index
616618

617-
elif issubclass(dtype.type, (str, bool, np.bool_)):
619+
elif issubclass(
620+
dtype.type, (str, bool, np.bool_, complex, np.complex64, np.complex128)
621+
):
618622
return Index
619623

620624
raise NotImplementedError(dtype)
@@ -858,6 +862,11 @@ def _engine(
858862
# TODO(ExtensionIndex): use libindex.ExtensionEngine(self._values)
859863
return libindex.ObjectEngine(self._get_engine_target())
860864

865+
elif self.values.dtype == np.complex64:
866+
return libindex.Complex64Engine(self._get_engine_target())
867+
elif self.values.dtype == np.complex128:
868+
return libindex.Complex128Engine(self._get_engine_target())
869+
861870
# to avoid a reference cycle, bind `target_values` to a local variable, so
862871
# `self` is not passed into the lambda.
863872
target_values = self._get_engine_target()
@@ -5980,8 +5989,6 @@ def _find_common_type_compat(self, target) -> DtypeObj:
59805989
# FIXME: find_common_type incorrect with Categorical GH#38240
59815990
# FIXME: some cases where float64 cast can be lossy?
59825991
dtype = np.dtype(np.float64)
5983-
if dtype.kind == "c":
5984-
dtype = _dtype_obj
59855992
return dtype
59865993

59875994
@final
@@ -7120,7 +7127,7 @@ def _maybe_cast_data_without_dtype(
71207127
FutureWarning,
71217128
stacklevel=3,
71227129
)
7123-
if result.dtype.kind in ["b", "c"]:
7130+
if result.dtype.kind in ["b"]:
71247131
return subarr
71257132
result = ensure_wrapped_if_datetimelike(result)
71267133
return result

pandas/core/indexes/numeric.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,8 @@ def _can_hold_na(self) -> bool: # type: ignore[override]
114114
np.dtype(np.uint64): libindex.UInt64Engine,
115115
np.dtype(np.float32): libindex.Float32Engine,
116116
np.dtype(np.float64): libindex.Float64Engine,
117+
np.dtype(np.complex64): libindex.Complex64Engine,
118+
np.dtype(np.complex128): libindex.Complex128Engine,
117119
}
118120

119121
@property
@@ -128,6 +130,7 @@ def inferred_type(self) -> str:
128130
"i": "integer",
129131
"u": "integer",
130132
"f": "floating",
133+
"c": "complex",
131134
}[self.dtype.kind]
132135

133136
def __new__(cls, data=None, dtype: Dtype | None = None, copy=False, name=None):

pandas/tests/arrays/categorical/test_constructors.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -676,7 +676,6 @@ def test_construction_with_ordered(self, ordered):
676676
cat = Categorical([0, 1, 2], ordered=ordered)
677677
assert cat.ordered == bool(ordered)
678678

679-
@pytest.mark.xfail(reason="Imaginary values not supported in Categorical")
680679
def test_constructor_imaginary(self):
681680
values = [1, 2, 3 + 1j]
682681
c1 = Categorical(values)

pandas/tests/base/test_misc.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -137,14 +137,19 @@ def test_memory_usage_components_narrow_series(dtype):
137137
assert total_usage == non_index_usage + index_usage
138138

139139

140-
def test_searchsorted(index_or_series_obj):
140+
def test_searchsorted(index_or_series_obj, request):
141141
# numpy.searchsorted calls obj.searchsorted under the hood.
142142
# See gh-12238
143143
obj = index_or_series_obj
144144

145145
if isinstance(obj, pd.MultiIndex):
146146
# See gh-14833
147147
pytest.skip("np.searchsorted doesn't work on pd.MultiIndex")
148+
if obj.dtype.kind == "c" and isinstance(obj, Index):
149+
# TODO: Should Series cases also raise? Looks like they use numpy
150+
# comparison semantics https://github.com/numpy/numpy/issues/15981
151+
mark = pytest.mark.xfail(reason="complex objects are not comparable")
152+
request.node.add_marker(mark)
148153

149154
max_obj = max(obj, default=0)
150155
index = np.searchsorted(obj, max_obj)

pandas/tests/groupby/test_groupby.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1054,15 +1054,14 @@ def test_groupby_complex_numbers():
10541054
)
10551055
expected = DataFrame(
10561056
np.array([1, 1, 1], dtype=np.int64),
1057-
index=Index([(1 + 1j), (1 + 2j), (1 + 0j)], dtype="object", name="b"),
1057+
index=Index([(1 + 1j), (1 + 2j), (1 + 0j)], name="b"),
10581058
columns=Index(["a"], dtype="object"),
10591059
)
10601060
result = df.groupby("b", sort=False).count()
10611061
tm.assert_frame_equal(result, expected)
10621062

10631063
# Sorted by the magnitude of the complex numbers
1064-
# Complex Index dtype is cast to object
1065-
expected.index = Index([(1 + 0j), (1 + 1j), (1 + 2j)], dtype="object", name="b")
1064+
expected.index = Index([(1 + 0j), (1 + 1j), (1 + 2j)], name="b")
10661065
result = df.groupby("b", sort=True).count()
10671066
tm.assert_frame_equal(result, expected)
10681067

pandas/tests/indexes/multi/test_setops.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -525,11 +525,17 @@ def test_union_nan_got_duplicated():
525525
tm.assert_index_equal(result, mi2)
526526

527527

528-
def test_union_duplicates(index):
528+
def test_union_duplicates(index, request):
529529
# GH#38977
530530
if index.empty or isinstance(index, (IntervalIndex, CategoricalIndex)):
531531
# No duplicates in empty indexes
532532
return
533+
if index.dtype.kind == "c":
534+
mark = pytest.mark.xfail(
535+
reason="sort_values() call raises bc complex objects are not comparable"
536+
)
537+
request.node.add_marker(mark)
538+
533539
values = index.unique().values.tolist()
534540
mi1 = MultiIndex.from_arrays([values, [1] * len(values)])
535541
mi2 = MultiIndex.from_arrays([[values[0]] + values, [1] * (len(values) + 1)])

pandas/tests/indexes/test_any_index.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,14 @@ def test_mutability(index):
4646
index[0] = index[0]
4747

4848

49-
def test_map_identity_mapping(index):
49+
def test_map_identity_mapping(index, request):
5050
# GH#12766
51+
if index.dtype == np.complex64:
52+
mark = pytest.mark.xfail(
53+
reason="maybe_downcast_to_dtype doesn't handle complex"
54+
)
55+
request.node.add_marker(mark)
56+
5157
result = index.map(lambda x: x)
5258
tm.assert_index_equal(result, index, exact="equiv")
5359

0 commit comments

Comments
 (0)