11# pyright: reportMissingTypeArgument=true, reportMissingParameterType=true
22import dataclasses
33import enum
4+ import functools
45import inspect
56import itertools
67import json
78import re
89from typing import (
910 Any ,
10- Callable ,
1111 Dict ,
1212 Iterable ,
1313 List ,
1414 Literal ,
1515 Mapping ,
1616 Optional ,
17- Protocol ,
1817 Sequence ,
1918 Set ,
2019 Tuple ,
2524 get_args ,
2625 get_origin ,
2726 get_type_hints ,
28- runtime_checkable ,
2927)
3028
3129__all__ = [
4846__to_snake_case_cache : Dict [str , str ] = {}
4947
5048
49+ @functools .lru_cache (maxsize = 2048 )
5150def to_snake_case (s : str ) -> str :
5251 result = __to_snake_case_cache .get (s , __not_valid )
5352 if result is __not_valid :
@@ -66,6 +65,7 @@ def to_snake_case(s: str) -> str:
6665__to_snake_camel_cache : Dict [str , str ] = {}
6766
6867
68+ @functools .lru_cache (maxsize = 2048 )
6969def to_camel_case (s : str ) -> str :
7070 result = __to_snake_camel_cache .get (s , __not_valid )
7171 if result is __not_valid :
@@ -91,20 +91,6 @@ def _decode_case(cls, s: str) -> str:
9191 return to_snake_case (s )
9292
9393
94- @runtime_checkable
95- class HasCaseEncoder (Protocol ):
96- @classmethod
97- def _encode_case (cls , s : str ) -> str : # pragma: no cover
98- ...
99-
100-
101- @runtime_checkable
102- class HasCaseDecoder (Protocol ):
103- @classmethod
104- def _decode_case (cls , s : str ) -> str : # pragma: no cover
105- ...
106-
107-
10894_T = TypeVar ("_T" )
10995
11096
@@ -118,21 +104,13 @@ def _decode_case(cls, s: str) -> str:
118104 return s
119105
120106
121- __default_config = DefaultConfig ()
122-
123-
124- def __get_config (obj : Any , entry_protocol : Type [_T ]) -> _T :
125- if isinstance (obj , entry_protocol ):
126- return obj
127- return cast (_T , __default_config )
128-
129-
130107def encode_case (obj : Any , field : dataclasses .Field ) -> str : # type: ignore
131108 alias = field .metadata .get ("alias" , None )
132109 if alias :
133110 return str (alias )
134-
135- return __get_config (obj , HasCaseEncoder )._encode_case (field .name ) # type: ignore
111+ if hasattr (obj , "_encode_case" ):
112+ return str (obj ._encode_case (field .name ))
113+ return field .name
136114
137115
138116def decode_case (type : Type [_T ], name : str ) -> str :
@@ -144,7 +122,10 @@ def decode_case(type: Type[_T], name: str) -> str:
144122 if field :
145123 return field .name
146124
147- return __get_config (type , HasCaseDecoder )._decode_case (name ) # type: ignore
125+ if hasattr (type , "_decode_case" ):
126+ return str (type ._decode_case (name )) # type: ignore[attr-defined]
127+
128+ return name
148129
149130
150131def __default (o : Any ) -> Any :
@@ -365,42 +346,34 @@ def as_dict(
365346 value : Any ,
366347 * ,
367348 remove_defaults : bool = False ,
368- dict_factory : Callable [[Any ], Dict [str , Any ]] = dict ,
369349 encode : bool = True ,
370350) -> Dict [str , Any ]:
371351 if not dataclasses .is_dataclass (value ):
372352 raise TypeError ("as_dict() should be called on dataclass instances" )
373353
374- return cast (Dict [str , Any ], _as_dict_inner (value , remove_defaults , dict_factory , encode ))
354+ return cast (Dict [str , Any ], _as_dict_inner (value , remove_defaults , encode ))
375355
376356
377357def _as_dict_inner (
378358 value : Any ,
379359 remove_defaults : bool ,
380- dict_factory : Callable [[Any ], Dict [str , Any ]],
381360 encode : bool = True ,
382361) -> Any :
383362 if dataclasses .is_dataclass (value ):
384- result = []
385- for f in dataclasses .fields (value ):
386- v = _as_dict_inner (getattr (value , f .name ), remove_defaults , dict_factory )
387-
388- if remove_defaults and v == f .default :
389- continue
390- result .append ((encode_case (value , f ) if encode else f .name , v ))
391- return dict_factory (result )
363+ return {
364+ encode_case (value , f ) if encode else f .name : _as_dict_inner (getattr (value , f .name ), remove_defaults )
365+ for f in dataclasses .fields (value )
366+ if not remove_defaults or getattr (value , f .name ) != f .default
367+ }
392368
393369 if isinstance (value , tuple ) and hasattr (value , "_fields" ):
394- return type ( value )( * [_as_dict_inner (v , remove_defaults , dict_factory ) for v in value ])
370+ return [_as_dict_inner (v , remove_defaults ) for v in value ]
395371
396372 if isinstance (value , (list , tuple )):
397- return type ( value )( _as_dict_inner (v , remove_defaults , dict_factory ) for v in value )
373+ return [ _as_dict_inner (v , remove_defaults ) for v in value ]
398374
399375 if isinstance (value , dict ):
400- return type (value )(
401- (_as_dict_inner (k , remove_defaults , dict_factory ), _as_dict_inner (v , remove_defaults , dict_factory ))
402- for k , v in value .items ()
403- )
376+ return {_as_dict_inner (k , remove_defaults ): _as_dict_inner (v , remove_defaults ) for k , v in value .items ()}
404377
405378 return value
406379
0 commit comments