11from __future__ import annotations
22
33import itertools
4- from typing import TYPE_CHECKING
4+ from typing import (
5+ TYPE_CHECKING ,
6+ cast ,
7+ )
58import warnings
69
710import numpy as np
@@ -452,7 +455,7 @@ def _unstack_multiple(data, clocs, fill_value=None):
452455 return unstacked
453456
454457
455- def unstack (obj , level , fill_value = None ):
458+ def unstack (obj : Series | DataFrame , level , fill_value = None ):
456459
457460 if isinstance (level , (tuple , list )):
458461 if len (level ) != 1 :
@@ -489,19 +492,20 @@ def unstack(obj, level, fill_value=None):
489492 )
490493
491494
492- def _unstack_frame (obj , level , fill_value = None ):
495+ def _unstack_frame (obj : DataFrame , level , fill_value = None ):
496+ assert isinstance (obj .index , MultiIndex ) # checked by caller
497+ unstacker = _Unstacker (obj .index , level = level , constructor = obj ._constructor )
498+
493499 if not obj ._can_fast_transpose :
494- unstacker = _Unstacker (obj .index , level = level )
495500 mgr = obj ._mgr .unstack (unstacker , fill_value = fill_value )
496501 return obj ._constructor (mgr )
497502 else :
498- unstacker = _Unstacker (obj .index , level = level , constructor = obj ._constructor )
499503 return unstacker .get_result (
500504 obj ._values , value_columns = obj .columns , fill_value = fill_value
501505 )
502506
503507
504- def _unstack_extension_series (series , level , fill_value ):
508+ def _unstack_extension_series (series : Series , level , fill_value ) -> DataFrame :
505509 """
506510 Unstack an ExtensionArray-backed Series.
507511
@@ -534,14 +538,14 @@ def _unstack_extension_series(series, level, fill_value):
534538 return result
535539
536540
537- def stack (frame , level = - 1 , dropna = True ):
541+ def stack (frame : DataFrame , level = - 1 , dropna : bool = True ):
538542 """
539543 Convert DataFrame to Series with multi-level Index. Columns become the
540544 second level of the resulting hierarchical index
541545
542546 Returns
543547 -------
544- stacked : Series
548+ stacked : Series or DataFrame
545549 """
546550
547551 def factorize (index ):
@@ -676,8 +680,10 @@ def _stack_multi_column_index(columns: MultiIndex) -> MultiIndex:
676680 )
677681
678682
679- def _stack_multi_columns (frame , level_num = - 1 , dropna = True ):
680- def _convert_level_number (level_num : int , columns ):
683+ def _stack_multi_columns (
684+ frame : DataFrame , level_num : int = - 1 , dropna : bool = True
685+ ) -> DataFrame :
686+ def _convert_level_number (level_num : int , columns : Index ):
681687 """
682688 Logic for converting the level number to something we can safely pass
683689 to swaplevel.
@@ -690,32 +696,36 @@ def _convert_level_number(level_num: int, columns):
690696
691697 return level_num
692698
693- this = frame .copy ()
699+ this = frame .copy (deep = False )
700+ mi_cols = this .columns # cast(MultiIndex, this.columns)
701+ assert isinstance (mi_cols , MultiIndex ) # caller is responsible
694702
695703 # this makes life much simpler
696- if level_num != frame . columns .nlevels - 1 :
704+ if level_num != mi_cols .nlevels - 1 :
697705 # roll levels to put selected level at end
698- roll_columns = this . columns
699- for i in range (level_num , frame . columns .nlevels - 1 ):
706+ roll_columns = mi_cols
707+ for i in range (level_num , mi_cols .nlevels - 1 ):
700708 # Need to check if the ints conflict with level names
701709 lev1 = _convert_level_number (i , roll_columns )
702710 lev2 = _convert_level_number (i + 1 , roll_columns )
703711 roll_columns = roll_columns .swaplevel (lev1 , lev2 )
704- this .columns = roll_columns
712+ this .columns = mi_cols = roll_columns
705713
706- if not this . columns ._is_lexsorted ():
714+ if not mi_cols ._is_lexsorted ():
707715 # Workaround the edge case where 0 is one of the column names,
708716 # which interferes with trying to sort based on the first
709717 # level
710- level_to_sort = _convert_level_number (0 , this . columns )
718+ level_to_sort = _convert_level_number (0 , mi_cols )
711719 this = this .sort_index (level = level_to_sort , axis = 1 )
720+ mi_cols = this .columns
712721
713- new_columns = _stack_multi_column_index (this .columns )
722+ mi_cols = cast (MultiIndex , mi_cols )
723+ new_columns = _stack_multi_column_index (mi_cols )
714724
715725 # time to ravel the values
716726 new_data = {}
717- level_vals = this . columns .levels [- 1 ]
718- level_codes = sorted (set (this . columns .codes [- 1 ]))
727+ level_vals = mi_cols .levels [- 1 ]
728+ level_codes = sorted (set (mi_cols .codes [- 1 ]))
719729 level_vals_nan = level_vals .insert (len (level_vals ), None )
720730
721731 level_vals_used = np .take (level_vals_nan , level_codes )
0 commit comments