@@ -2,7 +2,7 @@
 
 from pandas.compat import zip
 from pandas.core.common import (isnull, _values_from_object, is_bool_dtype, is_list_like,
-                                is_categorical_dtype, is_object_dtype)
+                                is_categorical_dtype, is_object_dtype, take_1d)
 import pandas.compat as compat
 from pandas.core.base import AccessorProperty, NoNewAttributesMixin
 from pandas.util.decorators import Appender, deprecate_kwarg
@@ -1003,7 +1003,7 @@ def str_encode(arr, encoding, errors="strict"):
 
 def _noarg_wrapper(f, docstring=None, **kargs):
     def wrapper(self):
-        result = _na_map(f, self.series, **kargs)
+        result = _na_map(f, self._data, **kargs)
         return self._wrap_result(result)
 
     wrapper.__name__ = f.__name__
@@ -1017,15 +1017,15 @@ def wrapper(self):
 
 def _pat_wrapper(f, flags=False, na=False, **kwargs):
     def wrapper1(self, pat):
-        result = f(self.series, pat)
+        result = f(self._data, pat)
         return self._wrap_result(result)
 
     def wrapper2(self, pat, flags=0, **kwargs):
-        result = f(self.series, pat, flags=flags, **kwargs)
+        result = f(self._data, pat, flags=flags, **kwargs)
         return self._wrap_result(result)
 
     def wrapper3(self, pat, na=np.nan):
-        result = f(self.series, pat, na=na)
+        result = f(self._data, pat, na=na)
         return self._wrap_result(result)
 
     wrapper = wrapper3 if na else wrapper2 if flags else wrapper1
@@ -1059,8 +1059,11 @@ class StringMethods(NoNewAttributesMixin):
     >>> s.str.replace('_', '')
     """
 
-    def __init__(self, series):
-        self.series = series
+    def __init__(self, data):
+        self._is_categorical = is_categorical_dtype(data)
+        self._data = data.cat.categories if self._is_categorical else data
+        # save orig to blow up categoricals to the right type
+        self._orig = data
         self._freeze()
 
     def __getitem__(self, key):
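With this change the accessor can be built from categorical data as well: `self._data` becomes the categories Index, so the string operation runs only once per distinct category, and `_wrap_result` later expands the per-category result back to the full length. A rough usage sketch, assuming a build that includes this patch (output is illustrative):

    >>> import pandas as pd
    >>> s = pd.Series(list("aabb"), dtype="category")
    >>> s.str.upper()   # computed on the two categories, then expanded via the codes
    0    A
    1    A
    2    B
    3    B
    dtype: object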
@@ -1078,7 +1081,15 @@ def __iter__(self):
             i += 1
             g = self.get(i)
 
-    def _wrap_result(self, result, **kwargs):
+    def _wrap_result(self, result, use_codes=True, name=None):
+
+        # for category, we do the work on the categories, so blow it up
+        # to the full series again.
+        # But for some operations we have to work on the full values, so make
+        # it possible to skip this step when the method has already done the
+        # transformation before calling _wrap_result...
+        if use_codes and self._is_categorical:
+            result = take_1d(result, self._orig.cat.codes)
 
         # leave as it is to keep extract and get_dummies results
         # can be merged to _wrap_result_expand in v0.17
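The expansion in `_wrap_result` is a plain integer take through the categorical's codes, so a missing value (code -1) comes back as the NaN fill value. A minimal sketch of that mechanism using `take_1d` as imported above (values illustrative):

    >>> import numpy as np
    >>> import pandas as pd
    >>> from pandas.core.common import take_1d
    >>> s = pd.Series(["b", "a", "b", None], dtype="category")   # categories: ['a', 'b']
    >>> per_category = np.array(["A", "B"], dtype=object)        # e.g. upper() applied to the categories
    >>> take_1d(per_category, s.cat.codes.values)                # codes: [1, 0, 1, -1]
    array(['B', 'A', 'B', nan], dtype=object)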
@@ -1088,29 +1099,34 @@ def _wrap_result(self, result, **kwargs):
 
         if not hasattr(result, 'ndim'):
             return result
-        name = kwargs.get('name') or getattr(result, 'name', None) or self.series.name
+        name = name or getattr(result, 'name', None) or self._orig.name
 
         if result.ndim == 1:
-            if isinstance(self.series, Index):
+            if isinstance(self._orig, Index):
                 # if result is a boolean np.array, return the np.array
                 # instead of wrapping it into a boolean Index (GH 8875)
                 if is_bool_dtype(result):
                     return result
                 return Index(result, name=name)
-            return Series(result, index=self.series.index, name=name)
+            return Series(result, index=self._orig.index, name=name)
         else:
             assert result.ndim < 3
-            return DataFrame(result, index=self.series.index)
+            return DataFrame(result, index=self._orig.index)
 
     def _wrap_result_expand(self, result, expand=False):
         if not isinstance(expand, bool):
             raise ValueError("expand must be True or False")
 
+        # for category, we do the work on the categories, so blow it up
+        # to the full series again
+        if self._is_categorical:
+            result = take_1d(result, self._orig.cat.codes)
+
         from pandas.core.index import Index, MultiIndex
         if not hasattr(result, 'ndim'):
             return result
 
-        if isinstance(self.series, Index):
+        if isinstance(self._orig, Index):
             name = getattr(result, 'name', None)
             # if result is a boolean np.array, return the np.array
             # instead of wrapping it into a boolean Index (GH 8875)
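`_wrap_result_expand` gets the same per-category treatment, so expanding methods such as `split` also work on the categories first and are then blown up to the full Series. Illustrative sketch, assuming this patch:

    >>> s = pd.Series(["a_b", "a_b", "c_d"], dtype="category")
    >>> s.str.split("_", expand=True)
       0  1
    0  a  b
    1  a  b
    2  c  d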
@@ -1123,36 +1139,38 @@ def _wrap_result_expand(self, result, expand=False):
             else:
                 return Index(result, name=name)
         else:
-            index = self.series.index
+            index = self._orig.index
             if expand:
                 def cons_row(x):
                     if is_list_like(x):
                         return x
                     else:
                         return [x]
-                cons = self.series._constructor_expanddim
+                cons = self._orig._constructor_expanddim
                 data = [cons_row(x) for x in result]
                 return cons(data, index=index)
             else:
                 name = getattr(result, 'name', None)
-                cons = self.series._constructor
+                cons = self._orig._constructor
                 return cons(result, name=name, index=index)
 
     @copy(str_cat)
     def cat(self, others=None, sep=None, na_rep=None):
-        result = str_cat(self.series, others=others, sep=sep, na_rep=na_rep)
-        return self._wrap_result(result)
+        data = self._orig if self._is_categorical else self._data
+        result = str_cat(data, others=others, sep=sep, na_rep=na_rep)
+        return self._wrap_result(result, use_codes=(not self._is_categorical))
+
 
     @deprecate_kwarg('return_type', 'expand',
                      mapping={'series': False, 'frame': True})
     @copy(str_split)
     def split(self, pat=None, n=-1, expand=False):
-        result = str_split(self.series, pat, n=n)
+        result = str_split(self._data, pat, n=n)
         return self._wrap_result_expand(result, expand=expand)
 
     @copy(str_rsplit)
     def rsplit(self, pat=None, n=-1, expand=False):
-        result = str_rsplit(self.series, pat, n=n)
+        result = str_rsplit(self._data, pat, n=n)
         return self._wrap_result_expand(result, expand=expand)
 
     _shared_docs['str_partition'] = ("""
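`cat` is the exception among the methods above: concatenation depends on the order and repetition of the full values (and on alignment with `others`), so it cannot be computed on the deduplicated categories. It therefore passes `self._orig` and sets `use_codes=False` so `_wrap_result` skips the expansion step. Roughly (illustrative):

    >>> s = pd.Series(["a", "b", "a"], dtype="category")
    >>> s.str.cat(sep="-")
    'a-b-a'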
@@ -1203,53 +1221,53 @@ def rsplit(self, pat=None, n=-1, expand=False):
         'also': 'rpartition : Split the string at the last occurrence of `sep`'})
     def partition(self, pat=' ', expand=True):
         f = lambda x: x.partition(pat)
-        result = _na_map(f, self.series)
+        result = _na_map(f, self._data)
         return self._wrap_result_expand(result, expand=expand)
 
     @Appender(_shared_docs['str_partition'] % {'side': 'last',
         'return': '3 elements containing two empty strings, followed by the string itself',
         'also': 'partition : Split the string at the first occurrence of `sep`'})
     def rpartition(self, pat=' ', expand=True):
         f = lambda x: x.rpartition(pat)
-        result = _na_map(f, self.series)
+        result = _na_map(f, self._data)
         return self._wrap_result_expand(result, expand=expand)
 
     @copy(str_get)
     def get(self, i):
-        result = str_get(self.series, i)
+        result = str_get(self._data, i)
         return self._wrap_result(result)
 
     @copy(str_join)
     def join(self, sep):
-        result = str_join(self.series, sep)
+        result = str_join(self._data, sep)
         return self._wrap_result(result)
 
     @copy(str_contains)
     def contains(self, pat, case=True, flags=0, na=np.nan, regex=True):
-        result = str_contains(self.series, pat, case=case, flags=flags,
+        result = str_contains(self._data, pat, case=case, flags=flags,
                               na=na, regex=regex)
         return self._wrap_result(result)
 
     @copy(str_match)
     def match(self, pat, case=True, flags=0, na=np.nan, as_indexer=False):
-        result = str_match(self.series, pat, case=case, flags=flags,
+        result = str_match(self._data, pat, case=case, flags=flags,
                            na=na, as_indexer=as_indexer)
         return self._wrap_result(result)
 
     @copy(str_replace)
     def replace(self, pat, repl, n=-1, case=True, flags=0):
-        result = str_replace(self.series, pat, repl, n=n, case=case,
+        result = str_replace(self._data, pat, repl, n=n, case=case,
                              flags=flags)
         return self._wrap_result(result)
 
     @copy(str_repeat)
     def repeat(self, repeats):
-        result = str_repeat(self.series, repeats)
+        result = str_repeat(self._data, repeats)
         return self._wrap_result(result)
 
     @copy(str_pad)
     def pad(self, width, side='left', fillchar=' '):
-        result = str_pad(self.series, width, side=side, fillchar=fillchar)
+        result = str_pad(self._data, width, side=side, fillchar=fillchar)
         return self._wrap_result(result)
 
     _shared_docs['str_pad'] = ("""
@@ -1297,27 +1315,27 @@ def zfill(self, width):
         -------
         filled : Series/Index of objects
         """
-        result = str_pad(self.series, width, side='left', fillchar='0')
+        result = str_pad(self._data, width, side='left', fillchar='0')
         return self._wrap_result(result)
 
     @copy(str_slice)
     def slice(self, start=None, stop=None, step=None):
-        result = str_slice(self.series, start, stop, step)
+        result = str_slice(self._data, start, stop, step)
         return self._wrap_result(result)
 
     @copy(str_slice_replace)
     def slice_replace(self, start=None, stop=None, repl=None):
-        result = str_slice_replace(self.series, start, stop, repl)
+        result = str_slice_replace(self._data, start, stop, repl)
         return self._wrap_result(result)
 
     @copy(str_decode)
     def decode(self, encoding, errors="strict"):
-        result = str_decode(self.series, encoding, errors)
+        result = str_decode(self._data, encoding, errors)
         return self._wrap_result(result)
 
     @copy(str_encode)
     def encode(self, encoding, errors="strict"):
-        result = str_encode(self.series, encoding, errors)
+        result = str_encode(self._data, encoding, errors)
         return self._wrap_result(result)
 
     _shared_docs['str_strip'] = ("""
@@ -1332,34 +1350,37 @@ def encode(self, encoding, errors="strict"):
     @Appender(_shared_docs['str_strip'] % dict(side='left and right sides',
                                                method='strip'))
     def strip(self, to_strip=None):
-        result = str_strip(self.series, to_strip, side='both')
+        result = str_strip(self._data, to_strip, side='both')
         return self._wrap_result(result)
 
     @Appender(_shared_docs['str_strip'] % dict(side='left side',
                                                method='lstrip'))
     def lstrip(self, to_strip=None):
-        result = str_strip(self.series, to_strip, side='left')
+        result = str_strip(self._data, to_strip, side='left')
         return self._wrap_result(result)
 
     @Appender(_shared_docs['str_strip'] % dict(side='right side',
                                                method='rstrip'))
     def rstrip(self, to_strip=None):
-        result = str_strip(self.series, to_strip, side='right')
+        result = str_strip(self._data, to_strip, side='right')
         return self._wrap_result(result)
 
     @copy(str_wrap)
     def wrap(self, width, **kwargs):
-        result = str_wrap(self.series, width, **kwargs)
+        result = str_wrap(self._data, width, **kwargs)
         return self._wrap_result(result)
 
     @copy(str_get_dummies)
     def get_dummies(self, sep='|'):
-        result = str_get_dummies(self.series, sep)
-        return self._wrap_result(result)
+        # we need to cast to Series of strings as only that has all
+        # methods available for making the dummies...
+        data = self._orig.astype(str) if self._is_categorical else self._data
+        result = str_get_dummies(data, sep)
+        return self._wrap_result(result, use_codes=(not self._is_categorical))
 
     @copy(str_translate)
     def translate(self, table, deletechars=None):
-        result = str_translate(self.series, table, deletechars)
+        result = str_translate(self._data, table, deletechars)
         return self._wrap_result(result)
 
     count = _pat_wrapper(str_count, flags=True)
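`get_dummies` also bypasses the per-category path: the categorical is cast back to a plain Series of strings because `str_get_dummies` relies on string methods the categories Index does not offer, and the codes-based expansion is skipped for the same reason as in `cat`. Expected behaviour, sketched (illustrative):

    >>> s = pd.Series(["a|b", "a", "b"], dtype="category")
    >>> s.str.get_dummies(sep="|")
       a  b
    0  1  1
    1  1  0
    2  0  1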
@@ -1369,7 +1390,7 @@ def translate(self, table, deletechars=None):
 
     @copy(str_extract)
     def extract(self, pat, flags=0):
-        result, name = str_extract(self.series, pat, flags=flags)
+        result, name = str_extract(self._data, pat, flags=flags)
         return self._wrap_result(result, name=name)
 
     _shared_docs['find'] = ("""
@@ -1398,13 +1419,13 @@ def extract(self, pat, flags=0):
     @Appender(_shared_docs['find'] % dict(side='lowest', method='find',
                                           also='rfind : Return highest indexes in each strings'))
     def find(self, sub, start=0, end=None):
-        result = str_find(self.series, sub, start=start, end=end, side='left')
+        result = str_find(self._data, sub, start=start, end=end, side='left')
         return self._wrap_result(result)
 
     @Appender(_shared_docs['find'] % dict(side='highest', method='rfind',
                                           also='find : Return lowest indexes in each strings'))
     def rfind(self, sub, start=0, end=None):
-        result = str_find(self.series, sub, start=start, end=end, side='right')
+        result = str_find(self._data, sub, start=start, end=end, side='right')
         return self._wrap_result(result)
 
     def normalize(self, form):
@@ -1423,7 +1444,7 @@ def normalize(self, form):
         """
         import unicodedata
         f = lambda x: unicodedata.normalize(form, compat.u_safe(x))
-        result = _na_map(f, self.series)
+        result = _na_map(f, self._data)
         return self._wrap_result(result)
 
     _shared_docs['index'] = ("""
@@ -1453,13 +1474,13 @@ def normalize(self, form):
     @Appender(_shared_docs['index'] % dict(side='lowest', similar='find', method='index',
                                            also='rindex : Return highest indexes in each strings'))
     def index(self, sub, start=0, end=None):
-        result = str_index(self.series, sub, start=start, end=end, side='left')
+        result = str_index(self._data, sub, start=start, end=end, side='left')
         return self._wrap_result(result)
 
     @Appender(_shared_docs['index'] % dict(side='highest', similar='rfind', method='rindex',
                                            also='index : Return lowest indexes in each strings'))
     def rindex(self, sub, start=0, end=None):
-        result = str_index(self.series, sub, start=start, end=end, side='right')
+        result = str_index(self._data, sub, start=start, end=end, side='right')
         return self._wrap_result(result)
 
     _shared_docs['len'] = ("""
@@ -1553,9 +1574,14 @@ class StringAccessorMixin(object):
     def _make_str_accessor(self):
        from pandas.core.series import Series
        from pandas.core.index import Index
-        if isinstance(self, Series) and not is_object_dtype(self.dtype):
-            # this really should exclude all series with any non-string values,
-            # but that isn't practical for performance reasons until we have a
+        if isinstance(self, Series) and not (
+                (is_categorical_dtype(self.dtype) and
+                 is_object_dtype(self.values.categories)) or
+                (is_object_dtype(self.dtype))):
+            # it's neither a string series nor a categorical series with strings
+            # inside the categories.
+            # this really should exclude all series with any non-string values (instead of
+            # testing for object dtype), but that isn't practical for performance reasons until we have a
             # str dtype (GH 9343)
             raise AttributeError("Can only use .str accessor with string "
                                  "values, which use np.object_ dtype in "
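The accessor guard now lets categorical data through only when the categories themselves are strings (object dtype); any other Series still raises. Roughly, with this patch applied (error message abbreviated):

    >>> pd.Series(["a", "b"], dtype="category").str.upper()
    0    A
    1    B
    dtype: object
    >>> pd.Series([1, 2], dtype="category").str.upper()
    Traceback (most recent call last):
        ...
    AttributeError: Can only use .str accessor with string values, ...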