@@ -96,8 +96,8 @@ class BaseXMLFormatter:
9696 def __init__ (
9797 self ,
9898 frame : DataFrame ,
99- path_or_buffer : FilePath | WriteBuffer [bytes ] | None = None ,
100- index : bool | None = True ,
99+ path_or_buffer : FilePath | WriteBuffer [bytes ] | WriteBuffer [ str ] | None = None ,
100+ index : bool = True ,
101101 root_name : str | None = "data" ,
102102 row_name : str | None = "row" ,
103103 na_rep : str | None = None ,
@@ -108,7 +108,7 @@ def __init__(
108108 encoding : str = "utf-8" ,
109109 xml_declaration : bool | None = True ,
110110 pretty_print : bool | None = True ,
111- stylesheet : FilePath | ReadBuffer [str ] | None = None ,
111+ stylesheet : FilePath | ReadBuffer [str ] | ReadBuffer [ bytes ] | None = None ,
112112 compression : CompressionOptions = "infer" ,
113113 storage_options : StorageOptions = None ,
114114 ) -> None :
@@ -132,6 +132,11 @@ def __init__(
132132 self .orig_cols = self .frame .columns .tolist ()
133133 self .frame_dicts = self .process_dataframe ()
134134
135+ self .validate_columns ()
136+ self .validate_encoding ()
137+ self .prefix_uri = self .get_prefix_uri ()
138+ self .handle_indexes ()
139+
135140 def build_tree (self ) -> bytes :
136141 """
137142 Build tree from data.
@@ -189,8 +194,8 @@ def process_dataframe(self) -> dict[int | str, dict[str, Any]]:
189194 if self .index :
190195 df = df .reset_index ()
191196
192- if self .na_rep :
193- df = df .replace ({ None : self .na_rep , float ( "nan" ): self . na_rep } )
197+ if self .na_rep is not None :
198+ df = df .fillna ( self .na_rep )
194199
195200 return df .to_dict (orient = "index" )
196201
@@ -247,17 +252,37 @@ def other_namespaces(self) -> dict:
247252
248253 return nmsp_dict
249254
250- def build_attribs (self ) -> None :
def build_attribs(self, d: dict[str, Any], elem_row: Any) -> Any:
    """
    Create attributes of row.

    This method adds attributes using attr_cols to row element and
    works with tuples for multiindex or hierarchical columns.

    Parameters
    ----------
    d : dict
        Mapping of column label to value for a single row.
    elem_row : element
        Row element whose ``attrib`` mapping receives the attributes.

    Returns
    -------
    element
        The same ``elem_row`` object that was passed in, updated in place.

    Raises
    ------
    KeyError
        If a label listed in ``attr_cols`` is not a key of ``d``.
    """

    if not self.attr_cols:
        return elem_row

    for col in self.attr_cols:
        attr_name = self._get_flat_col_name(col)

        # Only the lookup is guarded, so an unrelated KeyError raised while
        # writing the attribute is not misreported as a missing column.
        try:
            value = d[col]
        except KeyError as err:
            raise KeyError(f"no valid column, {col}") from err

        # Missing values are skipped: no attribute is written for them.
        if not isna(value):
            elem_row.attrib[attr_name] = str(value)
    return elem_row
274+
def _get_flat_col_name(self, col: str | tuple) -> str:
    """
    Return the prefixed name used for a column label.

    Tuple labels are flattened: their stringified levels are concatenated
    directly when any level equals "", otherwise joined with "_", and the
    result is stripped of surrounding whitespace. Any other label is used
    as-is. ``self.prefix_uri`` is prepended in every case.
    """
    if not isinstance(col, tuple):
        return f"{self.prefix_uri}{col}"

    separator = "" if "" in col else "_"
    flattened = separator.join(str(level) for level in col).strip()
    return f"{self.prefix_uri}{flattened}"
259284
260- def build_elems (self ) -> None :
285+ def build_elems (self , d : dict [ str , Any ], elem_row : Any ) -> None :
261286 """
262287 Create child elements of row.
263288
@@ -267,6 +292,19 @@ def build_elems(self) -> None:
267292
268293 raise AbstractMethodError (self )
269294
def _build_elems(self, sub_element_cls, d: dict[str, Any], elem_row: Any) -> None:
    """
    Append one child element per column in ``elem_cols`` to ``elem_row``.

    Parameters
    ----------
    sub_element_cls : callable
        Factory called as ``sub_element_cls(elem_row, name)``; the object it
        returns has its ``text`` attribute assigned.
    d : dict
        Mapping of column label to value for a single row.
    elem_row : element
        Row element that receives the child elements.

    Raises
    ------
    KeyError
        If a label listed in ``elem_cols`` is not a key of ``d``.
    """

    if not self.elem_cols:
        return

    for col in self.elem_cols:
        elem_name = self._get_flat_col_name(col)

        # Only the lookup is guarded, so an unrelated KeyError raised while
        # creating the child is not misreported as a missing column.
        try:
            value = d[col]
        except KeyError as err:
            raise KeyError(f"no valid column, {col}") from err

        # Missing values and empty strings produce a child with no text.
        text = None if isna(value) or value == "" else str(value)
        sub_element_cls(elem_row, elem_name).text = text
270308 def write_output (self ) -> str | None :
271309 xml_doc = self .build_tree ()
272310
@@ -291,14 +329,6 @@ class EtreeXMLFormatter(BaseXMLFormatter):
291329 modules: `xml.etree.ElementTree` and `xml.dom.minidom`.
292330 """
293331
294- def __init__ (self , * args , ** kwargs ) -> None :
295- super ().__init__ (* args , ** kwargs )
296-
297- self .validate_columns ()
298- self .validate_encoding ()
299- self .handle_indexes ()
300- self .prefix_uri = self .get_prefix_uri ()
301-
302332 def build_tree (self ) -> bytes :
303333 from xml .etree .ElementTree import (
304334 Element ,
@@ -311,16 +341,15 @@ def build_tree(self) -> bytes:
311341 )
312342
313343 for d in self .frame_dicts .values ():
314- self .d = d
315- self .elem_row = SubElement (self .root , f"{ self .prefix_uri } { self .row_name } " )
344+ elem_row = SubElement (self .root , f"{ self .prefix_uri } { self .row_name } " )
316345
317346 if not self .attr_cols and not self .elem_cols :
318- self .elem_cols = list (self . d .keys ())
319- self .build_elems ()
347+ self .elem_cols = list (d .keys ())
348+ self .build_elems (d , elem_row )
320349
321350 else :
322- self .build_attribs ()
323- self .build_elems ()
351+ elem_row = self .build_attribs (d , elem_row )
352+ self .build_elems (d , elem_row )
324353
325354 self .out_xml = tostring (self .root , method = "xml" , encoding = self .encoding )
326355
@@ -357,56 +386,10 @@ def get_prefix_uri(self) -> str:
357386
358387 return uri
359388
360- def build_attribs (self ) -> None :
361- if not self .attr_cols :
362- return
363-
364- for col in self .attr_cols :
365- flat_col = col
366- if isinstance (col , tuple ):
367- flat_col = (
368- "" .join ([str (c ) for c in col ]).strip ()
369- if "" in col
370- else "_" .join ([str (c ) for c in col ]).strip ()
371- )
372-
373- attr_name = f"{ self .prefix_uri } { flat_col } "
374- try :
375- val = (
376- None
377- if self .d [col ] is None or self .d [col ] != self .d [col ]
378- else str (self .d [col ])
379- )
380- if val is not None :
381- self .elem_row .attrib [attr_name ] = val
382- except KeyError :
383- raise KeyError (f"no valid column, { col } " )
384-
385- def build_elems (self ) -> None :
def build_elems(self, d: dict[str, Any], elem_row: Any) -> None:
    """
    Create child elements of row using the standard library element factory.

    Delegates to ``_build_elems``, passing
    ``xml.etree.ElementTree.SubElement`` as the child-element factory
    together with the row mapping ``d`` and the row element ``elem_row``.
    """
    from xml.etree import ElementTree

    self._build_elems(ElementTree.SubElement, d, elem_row)
410393
411394 def prettify_tree (self ) -> bytes :
412395 """
@@ -458,12 +441,7 @@ class LxmlXMLFormatter(BaseXMLFormatter):
def __init__(self, *args, **kwargs) -> None:
    """
    Initialize the formatter.

    All positional and keyword arguments are forwarded unchanged to the
    parent initializer, after which ``convert_empty_str_key`` is called
    on the new instance.
    """
    super().__init__(*args, **kwargs)
    self.convert_empty_str_key()
467445
468446 def build_tree (self ) -> bytes :
469447 """
@@ -481,16 +459,15 @@ def build_tree(self) -> bytes:
481459 self .root = Element (f"{ self .prefix_uri } { self .root_name } " , nsmap = self .namespaces )
482460
483461 for d in self .frame_dicts .values ():
484- self .d = d
485- self .elem_row = SubElement (self .root , f"{ self .prefix_uri } { self .row_name } " )
462+ elem_row = SubElement (self .root , f"{ self .prefix_uri } { self .row_name } " )
486463
487464 if not self .attr_cols and not self .elem_cols :
488- self .elem_cols = list (self . d .keys ())
489- self .build_elems ()
465+ self .elem_cols = list (d .keys ())
466+ self .build_elems (d , elem_row )
490467
491468 else :
492- self .build_attribs ()
493- self .build_elems ()
469+ elem_row = self .build_attribs (d , elem_row )
470+ self .build_elems (d , elem_row )
494471
495472 self .out_xml = tostring (
496473 self .root ,
@@ -529,54 +506,10 @@ def get_prefix_uri(self) -> str:
529506
530507 return uri
531508
532- def build_attribs (self ) -> None :
533- if not self .attr_cols :
534- return
535-
536- for col in self .attr_cols :
537- flat_col = col
538- if isinstance (col , tuple ):
539- flat_col = (
540- "" .join ([str (c ) for c in col ]).strip ()
541- if "" in col
542- else "_" .join ([str (c ) for c in col ]).strip ()
543- )
544-
545- attr_name = f"{ self .prefix_uri } { flat_col } "
546- try :
547- val = (
548- None
549- if self .d [col ] is None or self .d [col ] != self .d [col ]
550- else str (self .d [col ])
551- )
552- if val is not None :
553- self .elem_row .attrib [attr_name ] = val
554- except KeyError :
555- raise KeyError (f"no valid column, { col } " )
556-
557- def build_elems (self ) -> None :
def build_elems(self, d: dict[str, Any], elem_row: Any) -> None:
    """
    Create child elements of row using the lxml element factory.

    Delegates to ``_build_elems``, passing ``lxml.etree.SubElement`` as
    the child-element factory together with the row mapping ``d`` and the
    row element ``elem_row``.
    """
    from lxml import etree

    self._build_elems(etree.SubElement, d, elem_row)
580513
581514 def transform_doc (self ) -> bytes :
582515 """
0 commit comments