Skip to content

Commit 7ca7c58

Browse files
author
jetstream authors
committed
Merge pull request #249 from AI-Hypercomputer:amangu-lora
PiperOrigin-RevId: 748716588
2 parents 5020585 + bf3483f commit 7ca7c58

File tree

5 files changed

+157
-77
lines changed

5 files changed

+157
-77
lines changed

jetstream/core/lora/__init__.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.

jetstream/core/lora/adapter_tensorstore.py

Lines changed: 45 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,15 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
"""Manages the list of fine-tuned adapters loaded on top of the base model for serving.
15+
"""Manages the list of fine-tuned adapters loaded on top of the base model
16+
for serving.
1617
"""
1718

1819
import logging
1920
import dataclasses
2021

2122
import jax
2223
import jax.numpy as jnp
23-
from flax import struct
2424
import time
2525
import asyncio
2626
import functools
@@ -123,7 +123,10 @@ def __init__(
123123

124124
# --- Unsafe Internal methods which assumes that lock is held ---
125125
def _unsafe_transfer_to_hbm(self, adapter_id: str):
126-
"""Internal: Transfers an adapter from CPU RAM to HBM. Assumes lock is held."""
126+
"""
127+
Internal: Transfers an adapter from CPU RAM to HBM.
128+
Assumes lock is held.
129+
"""
127130
if adapter_id not in self.loaded_adapters_cpu:
128131
raise ValueError(f"Adapter '{adapter_id}' not loaded in CPU RAM.")
129132

@@ -137,7 +140,7 @@ def _unsafe_transfer_to_hbm(self, adapter_id: str):
137140
)
138141

139142
# Move from CPU RAM to HBM
140-
logging.info(f"Transferring {adapter_id} from CPU to HBM.")
143+
logging.info("Transferring %s from CPU to HBM.", adapter_id)
141144
self.loaded_adapters_hbm[adapter_id] = _as_jnp_array(
142145
self.loaded_adapters_cpu[adapter_id]
143146
) # Convert to JAX array
@@ -152,7 +155,9 @@ def _unsafe_transfer_to_hbm(self, adapter_id: str):
152155
metadata.last_accessed = time.time() # Update time on transfer
153156

154157
def _unsafe_transfer_to_cpu(self, adapter_id: str):
155-
"""Internal: Transfers an adapter from HBM to CPU RAM. Assumes lock is held."""
158+
"""
159+
Internal: Transfers an adapter from HBM to CPU RAM. Assumes lock is held.
160+
"""
156161

157162
if adapter_id not in self.loaded_adapters_hbm:
158163
raise ValueError(f"Adapter '{adapter_id}' not loaded in HBM.")
@@ -167,7 +172,7 @@ def _unsafe_transfer_to_cpu(self, adapter_id: str):
167172
)
168173

169174
# Move from HBM to CPU RAM
170-
logging.info(f"Transferring {adapter_id} from HBM to CPU.")
175+
logging.info("Transferring %s from HBM to CPU.", adapter_id)
171176
self.loaded_adapters_cpu[adapter_id] = _as_np_array(
172177
self.loaded_adapters_hbm[adapter_id]
173178
)
@@ -189,7 +194,7 @@ def _unsafe_unload_adapter(self, adapter_id: str):
189194
if metadata.status == AdapterStatus.UNLOADED:
190195
return
191196

192-
logging.info(f"Unloading adapter {adapter_id}.")
197+
logging.info("Unloading adapter %s.", adapter_id)
193198
if metadata.status == AdapterStatus.LOADED_HBM:
194199
del self.loaded_adapters_hbm[adapter_id]
195200
self.current_hbm_usage -= metadata.size_hbm
@@ -210,14 +215,13 @@ async def register_adapter(
210215
adapter_path: str | None = None,
211216
adapter_config: Dict[str, Any] | None = None,
212217
):
213-
"""Registers a new LoRA adatper."""
214218
"""
215-
Registers a LoRA adapter with the TensorStore. This also loads the adapter;
219+
Registers a LoRA adapter with the TensorStore. This also loads the adapter;
216220
IF called without adapter_config. Because in this case, it needs
217221
to get adapter_config from the engine's load_single_adapter() call, which
218222
also provides the adapter_params. So in that case it is beneficial to load
219-
the adapter to HBM. This call path is expected only from the direct inference
220-
request.
223+
the adapter to HBM. This call path is expected only from the direct
224+
inference request.
221225
OTHERWISE, it simply adds metadata about the adapter to the registry.
222226
223227
Args:
@@ -229,7 +233,7 @@ async def register_adapter(
229233
ValueError: If an adapter with the same ID is already registered.
230234
"""
231235
if adapter_id in self.adapter_registry:
232-
logging.warning(f"Adapter with ID '{adapter_id}' already registered.")
236+
logging.warning("Adapter with ID '%s' already registered.", adapter_id)
233237
return
234238

235239
if adapter_path is None:
@@ -249,7 +253,7 @@ async def register_adapter(
249253
async with self.lock:
250254
# Double check registration inside lock
251255
if adapter_id in self.adapter_registry:
252-
logging.warning(f"Adapter '{adapter_id}' registered concurrently.")
256+
logging.warning("Adapter '%s' registered concurrently.", adapter_id)
253257
return
254258

255259
self.adapter_registry[adapter_id] = AdapterMetadata(
@@ -331,21 +335,23 @@ async def load_adapter(
331335
if metadata.status == AdapterStatus.LOADING:
332336
# Wait untill loading is done.
333337
logging.info(
334-
f"Adapter {adapter_id} is already loading by another task, waiting..."
338+
"Adapter %s is already loading by another task, waiting...",
339+
adapter_id,
335340
)
336341

337342
# Get the event created by the first loading task
338343
event_to_wait_on = metadata.loading_event
339344
if event_to_wait_on is None:
340345
# Should not happen if status is LOADING, indicates inconsistency
341346
raise RuntimeError(
342-
f"Inconsistent state: Adapter {adapter_id} is LOADING but has no event."
347+
f"Inconsistent state: Adapter {adapter_id} is LOADING "
348+
f"but has no event."
343349
)
344350

345-
logging.info(f"Adapter {adapter_id} is loading, will wait.")
351+
logging.info("Adapter %s is loading, will wait.", adapter_id)
346352

347353
if metadata.status == AdapterStatus.UNLOADED: # Check if it was UNLOADED
348-
logging.info(f"Beginning load for adapter {adapter_id}...")
354+
logging.info("Beginning load for adapter %s...", adapter_id)
349355

350356
metadata.loading_event = (
351357
asyncio.Event()
@@ -357,8 +363,9 @@ async def load_adapter(
357363
if event_to_wait_on:
358364
await event_to_wait_on.wait()
359365
# After waiting, the original loader finished (or failed).
360-
# Re-call load_adapter to ensure desired state (HBM/CPU) and update timestamp.
361-
logging.info(f"Finished waiting for {adapter_id}. Re-checking state.")
366+
# Re-call load_adapter to ensure desired state (HBM/CPU) and
367+
# update timestamp.
368+
logging.info("Finished waiting for %s. Re-checking state.", adapter_id)
362369
await self.load_adapter(adapter_id, adapter_weights, to_hbm)
363370
return # Recursive call handled the final state
364371

@@ -370,7 +377,8 @@ async def load_adapter(
370377

371378
# TODO: Compare performance improvements
372379
# Option 1: Low performant (Run blocking I/O on main thread)
373-
# adapter_weights, adapter_config = self.engine.load_single_adapter(adapter_path)
380+
# adapter_weights, adapter_config = self.engine.load_single_adapter(
381+
# adapter_path)
374382

375383
# Option 2: Better performant
376384
# Run blocking I/O in executor
@@ -379,6 +387,7 @@ async def load_adapter(
379387
None,
380388
functools.partial(self.engine.load_single_adapter, adapter_path),
381389
)
390+
del adapter_config
382391

383392
if adapter_weights is None:
384393
raise ValueError(f"Failed to load adapter_weights from {adapter_path}.")
@@ -395,14 +404,18 @@ async def load_adapter(
395404
# If status changed while loading (e.g., unloaded), abort
396405
if metadata.status != AdapterStatus.LOADING:
397406
logging.warning(
398-
f"Load cancelled for {adapter_id}, status changed to {metadata.status}"
407+
"Load cancelled for %s, status changed to %s",
408+
adapter_id,
409+
metadata.status,
399410
)
400411
return
401412

402-
# Get size of unified_lora_params when they are saved in HBM as JAX array
413+
# Get size of unified_lora_params when they are saved in
414+
# HBM as JAX array
403415
adapter_size_hbm = _get_size_of_pytree(adapter_weights_as_jnp_array)
404416

405-
# Get size of unified_lora_params when they are saved in CPU RAM as NumPy array
417+
# Get size of unified_lora_params when they are saved in
418+
# CPU RAM as NumPy array
406419
adapter_size_cpu = _get_size_of_pytree(adapter_weights_as_np_array)
407420

408421
metadata.size_hbm = adapter_size_hbm
@@ -445,7 +458,7 @@ async def load_adapter(
445458
metadata.last_accessed = time.time()
446459
load_successful = True
447460

448-
except Exception as e:
461+
except Exception as e: # pylint: disable=broad-exception-caught
449462
async with self.lock:
450463
metadata = self.adapter_registry[adapter_id]
451464
metadata.status = AdapterStatus.UNLOADED # Mark as unloaded on error
@@ -515,18 +528,21 @@ async def get_lora_weights(
515528
async with self.lock:
516529
self._unsafe_transfer_to_cpu(adapter_id)
517530

518-
# Now all required adapters should be loaded in correct memory (HBM or CPU), get them
531+
# Now all required adapters should be loaded in correct memory (HBM or CPU),
532+
# so get them
519533
adapter_params = None
520534
if to_hbm:
521535
if adapter_id not in self.loaded_adapters_hbm:
522536
raise RuntimeError(
523-
f"Adapter {adapter_id} should be in HBM but wasn't found after loading."
537+
f"Adapter {adapter_id} should be in HBM "
538+
f"but wasn't found after loading."
524539
)
525540
adapter_params = self.loaded_adapters_hbm[adapter_id]
526541
else:
527542
if adapter_id not in self.loaded_adapters_cpu:
528543
raise RuntimeError(
529-
f"Adapter {adapter_id} should be in CPU but wasn't found after loading."
544+
f"Adapter {adapter_id} should be in CPU "
545+
f"but wasn't found after loading."
530546
)
531547
adapter_params = self.loaded_adapters_cpu[adapter_id]
532548

@@ -550,7 +566,8 @@ async def unload_adapter(self, adapter_id: str):
550566
metadata = self.adapter_registry[adapter_id]
551567
if metadata.status == AdapterStatus.LOADING:
552568
raise RuntimeError(
553-
f"Inconsistent state: Adapter {adapter_id} is LOADING after just finishing one."
569+
f"Inconsistent state: Adapter {adapter_id} is LOADING after "
570+
f"just finishing one."
554571
)
555572

556573
self._unsafe_unload_adapter(adapter_id)
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright 2024 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.

0 commit comments

Comments
 (0)