8989 "TrinoQuery" ,
9090 "TrinoRequest" ,
9191 "PROXIES" ,
92- "SpooledData " ,
92+ "DecodableSegment " ,
9393 "SpooledSegment" ,
9494 "InlineSegment" ,
9595 "Segment"
@@ -920,16 +920,16 @@ def fetch(self) -> List[Union[List[Any]], Any]:
920920 if isinstance (status .rows , dict ):
921921 # spooling protocol
922922 rows = cast (_SpooledProtocolResponseTO , rows )
923- segments = self ._to_segments (rows )
923+ spooled = self ._to_segments (rows )
924924 if self ._fetch_mode == "segments" :
925- return segments
926- return list (SegmentIterator (segments , self ._row_mapper ))
925+ return spooled
926+ return list (SegmentIterator (spooled , self ._row_mapper ))
927927 elif isinstance (status .rows , list ):
928928 return self ._row_mapper .map (rows )
929929 else :
930930 raise ValueError (f"Unexpected type: { type (status .rows )} " )
931931
932- def _to_segments (self , rows : _SpooledProtocolResponseTO ) -> SpooledData :
932+ def _to_segments (self , rows : _SpooledProtocolResponseTO ) -> List [ DecodableSegment ] :
933933 encoding = rows ["encoding" ]
934934 metadata = rows ["metadata" ] if "metadata" in rows else None
935935 segments = []
@@ -944,7 +944,7 @@ def _to_segments(self, rows: _SpooledProtocolResponseTO) -> SpooledData:
944944 else :
945945 raise ValueError (f"Unsupported segment type: { segment_type } " )
946946
947- return SpooledData ( encoding , metadata , segments )
947+ return list ( map ( lambda segment : DecodableSegment ( encoding , metadata , segment ), segments ) )
948948
949949 def cancel (self ) -> None :
950950 """Cancel the current query"""
@@ -1164,46 +1164,44 @@ def __repr__(self):
11641164 )
11651165
11661166
1167- class SpooledData :
1167+ class DecodableSegment :
11681168 """
11691169 Represents a collection of spooled segments of data, with an encoding format.
11701170
11711171 Attributes:
11721172 encoding (str): The encoding format of the spooled data.
1173- metadata (_SegmentMetadataTO): Metadata for all segments
1174- segments (List[ Segment] ): The list of segments in the spooled data.
1173+ metadata (_SegmentMetadataTO): Metadata for all segments in the query
1174+ segment ( Segment): The spooled segment data
11751175 """
1176- def __init__ (self , encoding : str , metadata : _SegmentMetadataTO , segments : List [ Segment ] ) -> None :
1176+ def __init__ (self , encoding : str , metadata : _SegmentMetadataTO , segment : Segment ) -> None :
11771177 self ._encoding = encoding
11781178 self ._metadata = metadata
1179- self ._segments = segments
1180- self ._segments_iterator = iter (segments )
1179+ self ._segment = segment
11811180
11821181 @property
11831182 def encoding (self ):
11841183 return self ._encoding
11851184
11861185 @property
1187- def segments (self ):
1188- return self ._segments
1189-
1190- def __iter__ (self ) -> Iterator [Tuple ["SpooledData" , "Segment" ]]:
1191- return self
1186+ def segment (self ):
1187+ return self ._segment
11921188
1193- def __next__ (self ) -> Tuple ["SpooledData" , "Segment" ]:
1194- return self , next (self ._segments_iterator )
1189+ @property
1190+ def metadata (self ):
1191+ return self ._metadata
11951192
11961193 def __repr__ (self ):
1197- return (f"SpooledData (encoding={ self ._encoding } , metadata={ self ._metadata } , segments= { list ( self ._segments ) } )" )
1194+ return (f"DecodableSegment (encoding={ self ._encoding } , metadata={ self ._metadata } , segment= { self ._segment } )" )
11981195
11991196
12001197class SegmentIterator :
1201- def __init__ (self , spooled_data : SpooledData , mapper : RowMapper ) -> None :
1202- self ._segments = iter (spooled_data ._segments )
1203- self ._decoder = SegmentDecoder (CompressedQueryDataDecoderFactory (mapper ).create (spooled_data .encoding ))
1198+ def __init__ (self , segments : Union [DecodableSegment , List [DecodableSegment ]], mapper : RowMapper ) -> None :
1199+ self ._segments = iter (segments if isinstance (segments , List ) else [segments ])
1200+ self ._mapper = mapper
1201+ self ._decoder = None
12041202 self ._rows : Iterator [List [List [Any ]]] = iter ([])
12051203 self ._finished = False
1206- self ._current_segment : Optional [Segment ] = None
1204+ self ._current_segment : Optional [DecodableSegment ] = None
12071205
12081206 def __iter__ (self ) -> Iterator [List [Any ]]:
12091207 return self
@@ -1214,16 +1212,22 @@ def __next__(self) -> List[Any]:
12141212 try :
12151213 return next (self ._rows )
12161214 except StopIteration :
1217- if self ._current_segment and isinstance (self ._current_segment , SpooledSegment ):
1218- self ._current_segment .acknowledge ()
12191215 if self ._finished :
12201216 raise StopIteration
12211217 self ._load_next_segment ()
12221218
12231219 def _load_next_segment (self ):
12241220 try :
1225- self ._current_segment = segment = next (self ._segments )
1226- self ._rows = iter (self ._decoder .decode (segment ))
1221+ if self ._current_segment :
1222+ segment = self ._current_segment .segment
1223+ if isinstance (segment , SpooledSegment ):
1224+ segment .acknowledge ()
1225+
1226+ self ._current_segment = next (self ._segments )
1227+ if self ._decoder is None :
1228+ self ._decoder = SegmentDecoder (CompressedQueryDataDecoderFactory (self ._mapper )
1229+ .create (self ._current_segment .encoding ))
1230+ self ._rows = iter (self ._decoder .decode (self ._current_segment .segment ))
12271231 except StopIteration :
12281232 self ._finished = True
12291233
0 commit comments