Skip to content

Commit ceedfa9

Browse files
authored
REF: Share Block.setitem (#45348)
1 parent 7abddcd commit ceedfa9

File tree

1 file changed

+91
-65
lines changed

1 file changed

+91
-65
lines changed

pandas/core/internals/blocks.py

Lines changed: 91 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -885,6 +885,12 @@ def _maybe_squeeze_arg(self, arg: np.ndarray) -> np.ndarray:
885885
"""
886886
return arg
887887

888+
def _unwrap_setitem_indexer(self, indexer):
889+
"""
890+
For compatibility with 1D-only ExtensionArrays.
891+
"""
892+
return indexer
893+
888894
def setitem(self, indexer, value):
889895
"""
890896
Attempt self.values[indexer] = value, possibly creating a new array.
@@ -1357,6 +1363,45 @@ class EABackedBlock(Block):
13571363

13581364
values: ExtensionArray
13591365

1366+
def setitem(self, indexer, value):
1367+
"""
1368+
Attempt self.values[indexer] = value, possibly creating a new array.
1369+
1370+
This differs from Block.setitem by not allowing setitem to change
1371+
the dtype of the Block.
1372+
1373+
Parameters
1374+
----------
1375+
indexer : tuple, list-like, array-like, slice, int
1376+
The subset of self.values to set
1377+
value : object
1378+
The value being set
1379+
1380+
Returns
1381+
-------
1382+
Block
1383+
1384+
Notes
1385+
-----
1386+
`indexer` is a direct slice/positional indexer. `value` must
1387+
be a compatible shape.
1388+
"""
1389+
if not self._can_hold_element(value):
1390+
# see TestSetitemFloatIntervalWithIntIntervalValues
1391+
nb = self.coerce_to_target_dtype(value)
1392+
return nb.setitem(indexer, value)
1393+
1394+
indexer = self._unwrap_setitem_indexer(indexer)
1395+
value = self._maybe_squeeze_arg(value)
1396+
1397+
values = self.values
1398+
if values.ndim == 2:
1399+
# TODO: string[pyarrow] tests break if we transpose unconditionally
1400+
values = values.T
1401+
check_setitem_lengths(indexer, value, values)
1402+
values[indexer] = value
1403+
return self
1404+
13601405
def where(self, other, cond) -> list[Block]:
13611406
arr = self.values.T
13621407

@@ -1556,75 +1601,68 @@ def _maybe_squeeze_arg(self, arg):
15561601
If necessary, squeeze a (N, 1) ndarray to (N,)
15571602
"""
15581603
# e.g. if we are passed a 2D mask for putmask
1559-
if isinstance(arg, np.ndarray) and arg.ndim == self.values.ndim + 1:
1604+
if (
1605+
isinstance(arg, (np.ndarray, ExtensionArray))
1606+
and arg.ndim == self.values.ndim + 1
1607+
):
15601608
# TODO(EA2D): unnecessary with 2D EAs
15611609
assert arg.shape[1] == 1
1562-
arg = arg[:, 0]
1563-
return arg
1564-
1565-
@property
1566-
def is_view(self) -> bool:
1567-
"""Extension arrays are never treated as views."""
1568-
return False
1610+
# error: No overload variant of "__getitem__" of "ExtensionArray"
1611+
# matches argument type "Tuple[slice, int]"
1612+
arg = arg[:, 0] # type:ignore[call-overload]
1613+
elif isinstance(arg, ABCDataFrame):
1614+
# 2022-01-06 only reached for setitem
1615+
# TODO: should we avoid getting here with DataFrame?
1616+
assert arg.shape[1] == 1
1617+
arg = arg._ixs(0, axis=1)._values
15691618

1570-
@cache_readonly
1571-
def is_numeric(self):
1572-
return self.values.dtype._is_numeric
1619+
return arg
15731620

1574-
def setitem(self, indexer, value):
1621+
def _unwrap_setitem_indexer(self, indexer):
15751622
"""
1576-
Attempt self.values[indexer] = value, possibly creating a new array.
1577-
1578-
This differs from Block.setitem by not allowing setitem to change
1579-
the dtype of the Block.
1580-
1581-
Parameters
1582-
----------
1583-
indexer : tuple, list-like, array-like, slice, int
1584-
The subset of self.values to set
1585-
value : object
1586-
The value being set
1587-
1588-
Returns
1589-
-------
1590-
Block
1623+
Adapt a 2D-indexer to our 1D values.
15911624
1592-
Notes
1593-
-----
1594-
`indexer` is a direct slice/positional indexer. `value` must
1595-
be a compatible shape.
1625+
This is intended for 'setitem', not 'iget' or '_slice'.
15961626
"""
1597-
if not self._can_hold_element(value):
1598-
# see TestSetitemFloatIntervalWithIntIntervalValues
1599-
return self.coerce_to_target_dtype(value).setitem(indexer, value)
1627+
# TODO: ATM this doesn't work for iget/_slice, can we change that?
16001628

16011629
if isinstance(indexer, tuple):
16021630
# TODO(EA2D): not needed with 2D EAs
1603-
# we are always 1-D
1604-
indexer = indexer[0]
1605-
if isinstance(indexer, np.ndarray) and indexer.ndim == 2:
1606-
# GH#44703
1607-
if indexer.shape[1] != 1:
1631+
# Should never have length > 2. Caller is responsible for checking.
1632+
# Length 1 is reached vis setitem_single_block and setitem_single_column
1633+
# each of which pass indexer=(pi,)
1634+
if len(indexer) == 2:
1635+
1636+
if all(isinstance(x, np.ndarray) and x.ndim == 2 for x in indexer):
1637+
# GH#44703 went through indexing.maybe_convert_ix
1638+
first, second = indexer
1639+
if not (
1640+
second.size == 1 and (second == 0).all() and first.shape[1] == 1
1641+
):
1642+
raise NotImplementedError(
1643+
"This should not be reached. Please report a bug at "
1644+
"github.com/pandas-dev/pandas/"
1645+
)
1646+
indexer = first[:, 0]
1647+
1648+
elif lib.is_integer(indexer[1]) and indexer[1] == 0:
1649+
# reached via setitem_single_block passing the whole indexer
1650+
indexer = indexer[0]
1651+
else:
16081652
raise NotImplementedError(
16091653
"This should not be reached. Please report a bug at "
16101654
"github.com/pandas-dev/pandas/"
16111655
)
1612-
indexer = indexer[:, 0]
1656+
return indexer
16131657

1614-
# TODO(EA2D): not needed with 2D EAS
1615-
if isinstance(value, (np.ndarray, ExtensionArray)) and value.ndim == 2:
1616-
assert value.shape[1] == 1
1617-
# error: No overload variant of "__getitem__" of "ExtensionArray"
1618-
# matches argument type "Tuple[slice, int]"
1619-
value = value[:, 0] # type: ignore[call-overload]
1620-
elif isinstance(value, ABCDataFrame):
1621-
# TODO: should we avoid getting here with DataFrame?
1622-
assert value.shape[1] == 1
1623-
value = value._ixs(0, axis=1)._values
1658+
@property
1659+
def is_view(self) -> bool:
1660+
"""Extension arrays are never treated as views."""
1661+
return False
16241662

1625-
check_setitem_lengths(indexer, value, self.values)
1626-
self.values[indexer] = value
1627-
return self
1663+
@cache_readonly
1664+
def is_numeric(self):
1665+
return self.values.dtype._is_numeric
16281666

16291667
def take_nd(
16301668
self,
@@ -1802,18 +1840,6 @@ def is_view(self) -> bool:
18021840
# check the ndarray values of the DatetimeIndex values
18031841
return self.values._ndarray.base is not None
18041842

1805-
def setitem(self, indexer, value):
1806-
if not self._can_hold_element(value):
1807-
return self.coerce_to_target_dtype(value).setitem(indexer, value)
1808-
1809-
values = self.values
1810-
if self.ndim > 1:
1811-
# Dont transpose with ndim=1 bc we would fail to invalidate
1812-
# arr.freq
1813-
values = values.T
1814-
values[indexer] = value
1815-
return self
1816-
18171843
def diff(self, n: int, axis: int = 0) -> list[Block]:
18181844
"""
18191845
1st discrete difference.

0 commit comments

Comments
 (0)