1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15- """Manages the list of fine-tuned adapters loaded on top of the base model for serving.
15+ """Manages the list of fine-tuned adapters loaded on top of the base model
16+ for serving.
1617"""
1718
1819import logging
1920import dataclasses
2021
2122import jax
2223import jax .numpy as jnp
23- from flax import struct
2424import time
2525import asyncio
2626import functools
@@ -123,7 +123,10 @@ def __init__(
123123
124124 # --- Unsafe Internal methods which assumes that lock is held ---
125125 def _unsafe_transfer_to_hbm (self , adapter_id : str ):
126- """Internal: Transfers an adapter from CPU RAM to HBM. Assumes lock is held."""
126+ """
127+ Internal: Transfers an adapter from CPU RAM to HBM.
128+ Assumes lock is held.
129+ """
127130 if adapter_id not in self .loaded_adapters_cpu :
128131 raise ValueError (f"Adapter '{ adapter_id } ' not loaded in CPU RAM." )
129132
@@ -137,7 +140,7 @@ def _unsafe_transfer_to_hbm(self, adapter_id: str):
137140 )
138141
139142 # Move from CPU RAM to HBM
140- logging .info (f "Transferring { adapter_id } from CPU to HBM." )
143+ logging .info ("Transferring %s from CPU to HBM." , adapter_id )
141144 self .loaded_adapters_hbm [adapter_id ] = _as_jnp_array (
142145 self .loaded_adapters_cpu [adapter_id ]
143146 ) # Convert to JAX array
@@ -152,7 +155,9 @@ def _unsafe_transfer_to_hbm(self, adapter_id: str):
152155 metadata .last_accessed = time .time () # Update time on transfer
153156
154157 def _unsafe_transfer_to_cpu (self , adapter_id : str ):
155- """Internal: Transfers an adapter from HBM to CPU RAM. Assumes lock is held."""
158+ """
159+ Internal: Transfers an adapter from HBM to CPU RAM. Assumes lock is held.
160+ """
156161
157162 if adapter_id not in self .loaded_adapters_hbm :
158163 raise ValueError (f"Adapter '{ adapter_id } ' not loaded in HBM." )
@@ -167,7 +172,7 @@ def _unsafe_transfer_to_cpu(self, adapter_id: str):
167172 )
168173
169174 # Move from HBM to CPU RAM
170- logging .info (f "Transferring { adapter_id } from HBM to CPU." )
175+ logging .info ("Transferring %s from HBM to CPU." , adapter_id )
171176 self .loaded_adapters_cpu [adapter_id ] = _as_np_array (
172177 self .loaded_adapters_hbm [adapter_id ]
173178 )
@@ -189,7 +194,7 @@ def _unsafe_unload_adapter(self, adapter_id: str):
189194 if metadata .status == AdapterStatus .UNLOADED :
190195 return
191196
192- logging .info (f "Unloading adapter { adapter_id } ." )
197+ logging .info ("Unloading adapter %s." , adapter_id )
193198 if metadata .status == AdapterStatus .LOADED_HBM :
194199 del self .loaded_adapters_hbm [adapter_id ]
195200 self .current_hbm_usage -= metadata .size_hbm
@@ -210,14 +215,13 @@ async def register_adapter(
210215 adapter_path : str | None = None ,
211216 adapter_config : Dict [str , Any ] | None = None ,
212217 ):
213- """Registers a new LoRA adatper."""
214218 """
215- Registers a LoRA adapter with the TensorStore. This also loads the adapter;
219+ Registers a LoRA adapter with the TensorStore. This also loads the adapter;
216220 IF called without adapter_config. Because in this case, it needs
217221 to get adapter_config from the engine's load_single_adapter() call, which
218222 also provides the adapter_params. So in that case it is beneficial to load
219- the adapter to HBM. This call path is expected only from the direct inference
220- request.
223+ the adapter to HBM. This call path is expected only from the direct
224+ inference request.
221225 OTHERWISE, it simply adds metadata about the adapter to the registry.
222226
223227 Args:
@@ -229,7 +233,7 @@ async def register_adapter(
229233 ValueError: If an adapter with the same ID is already registered.
230234 """
231235 if adapter_id in self .adapter_registry :
232- logging .warning (f "Adapter with ID '{ adapter_id } ' already registered." )
236+ logging .warning ("Adapter with ID '%s ' already registered." , adapter_id )
233237 return
234238
235239 if adapter_path is None :
@@ -249,7 +253,7 @@ async def register_adapter(
249253 async with self .lock :
250254 # Double check registration inside lock
251255 if adapter_id in self .adapter_registry :
252- logging .warning (f "Adapter '{ adapter_id } ' registered concurrently." )
256+ logging .warning ("Adapter '%s ' registered concurrently." , adapter_id )
253257 return
254258
255259 self .adapter_registry [adapter_id ] = AdapterMetadata (
@@ -331,21 +335,23 @@ async def load_adapter(
331335 if metadata .status == AdapterStatus .LOADING :
332336 # Wait untill loading is done.
333337 logging .info (
334- f"Adapter { adapter_id } is already loading by another task, waiting..."
338+ "Adapter %s is already loading by another task, waiting..." ,
339+ adapter_id ,
335340 )
336341
337342 # Get the event created by the first loading task
338343 event_to_wait_on = metadata .loading_event
339344 if event_to_wait_on is None :
340345 # Should not happen if status is LOADING, indicates inconsistency
341346 raise RuntimeError (
342- f"Inconsistent state: Adapter { adapter_id } is LOADING but has no event."
347+ f"Inconsistent state: Adapter { adapter_id } is LOADING "
348+ f"but has no event."
343349 )
344350
345- logging .info (f "Adapter { adapter_id } is loading, will wait." )
351+ logging .info ("Adapter %s is loading, will wait." , adapter_id )
346352
347353 if metadata .status == AdapterStatus .UNLOADED : # Check if it was UNLOADED
348- logging .info (f "Beginning load for adapter { adapter_id } ..." )
354+ logging .info ("Beginning load for adapter %s ..." , adapter_id )
349355
350356 metadata .loading_event = (
351357 asyncio .Event ()
@@ -357,8 +363,9 @@ async def load_adapter(
357363 if event_to_wait_on :
358364 await event_to_wait_on .wait ()
359365 # After waiting, the original loader finished (or failed).
360- # Re-call load_adapter to ensure desired state (HBM/CPU) and update timestamp.
361- logging .info (f"Finished waiting for { adapter_id } . Re-checking state." )
366+ # Re-call load_adapter to ensure desired state (HBM/CPU) and
367+ # update timestamp.
368+ logging .info ("Finished waiting for %s. Re-checking state." , adapter_id )
362369 await self .load_adapter (adapter_id , adapter_weights , to_hbm )
363370 return # Recursive call handled the final state
364371
@@ -370,7 +377,8 @@ async def load_adapter(
370377
371378 # TODO: Compare performance improvements
372379 # Option 1: Low performant (Run blocking I/O on main thread)
373- # adapter_weights, adapter_config = self.engine.load_single_adapter(adapter_path)
380+ # adapter_weights, adapter_config = self.engine.load_single_adapter(
381+ # adapter_path)
374382
375383 # Option 2: Better performant
376384 # Run blocking I/O in executor
@@ -379,6 +387,7 @@ async def load_adapter(
379387 None ,
380388 functools .partial (self .engine .load_single_adapter , adapter_path ),
381389 )
390+ del adapter_config
382391
383392 if adapter_weights is None :
384393 raise ValueError (f"Failed to load adapter_weights from { adapter_path } ." )
@@ -395,14 +404,18 @@ async def load_adapter(
395404 # If status changed while loading (e.g., unloaded), abort
396405 if metadata .status != AdapterStatus .LOADING :
397406 logging .warning (
398- f"Load cancelled for { adapter_id } , status changed to { metadata .status } "
407+ "Load cancelled for %s, status changed to %s" ,
408+ adapter_id ,
409+ metadata .status ,
399410 )
400411 return
401412
402- # Get size of unified_lora_params when they are saved in HBM as JAX array
413+ # Get size of unified_lora_params when they are saved in
414+ # HBM as JAX array
403415 adapter_size_hbm = _get_size_of_pytree (adapter_weights_as_jnp_array )
404416
405- # Get size of unified_lora_params when they are saved in CPU RAM as NumPy array
417+ # Get size of unified_lora_params when they are saved in
418+ # CPU RAM as NumPy array
406419 adapter_size_cpu = _get_size_of_pytree (adapter_weights_as_np_array )
407420
408421 metadata .size_hbm = adapter_size_hbm
@@ -445,7 +458,7 @@ async def load_adapter(
445458 metadata .last_accessed = time .time ()
446459 load_successful = True
447460
448- except Exception as e :
461+ except Exception as e : # pylint: disable=broad-exception-caught
449462 async with self .lock :
450463 metadata = self .adapter_registry [adapter_id ]
451464 metadata .status = AdapterStatus .UNLOADED # Mark as unloaded on error
@@ -515,18 +528,21 @@ async def get_lora_weights(
515528 async with self .lock :
516529 self ._unsafe_transfer_to_cpu (adapter_id )
517530
518- # Now all required adapters should be loaded in correct memory (HBM or CPU), get them
531+ # Now all required adapters should be loaded in correct memory (HBM or CPU),
532+ # so get them
519533 adapter_params = None
520534 if to_hbm :
521535 if adapter_id not in self .loaded_adapters_hbm :
522536 raise RuntimeError (
523- f"Adapter { adapter_id } should be in HBM but wasn't found after loading."
537+ f"Adapter { adapter_id } should be in HBM "
538+ f"but wasn't found after loading."
524539 )
525540 adapter_params = self .loaded_adapters_hbm [adapter_id ]
526541 else :
527542 if adapter_id not in self .loaded_adapters_cpu :
528543 raise RuntimeError (
529- f"Adapter { adapter_id } should be in CPU but wasn't found after loading."
544+ f"Adapter { adapter_id } should be in CPU "
545+ f"but wasn't found after loading."
530546 )
531547 adapter_params = self .loaded_adapters_cpu [adapter_id ]
532548
@@ -550,7 +566,8 @@ async def unload_adapter(self, adapter_id: str):
550566 metadata = self .adapter_registry [adapter_id ]
551567 if metadata .status == AdapterStatus .LOADING :
552568 raise RuntimeError (
553- f"Inconsistent state: Adapter { adapter_id } is LOADING after just finishing one."
569+ f"Inconsistent state: Adapter { adapter_id } is LOADING after "
570+ f"just finishing one."
554571 )
555572
556573 self ._unsafe_unload_adapter (adapter_id )
0 commit comments