Skip to content

Commit 164ecca

Browse files
authored
Fix LayerAdaptiveDataset to work with load_full_dataset (#57)
* Fix LayerAdaptiveDataset to properly reload the full dataset for the new layer when load_full_dataset=True
Signed-off-by: Kira Selby <kaselby@uwaterloo.ca>

* Included a few other small fixes to trainer.
Signed-off-by: Kira Selby <kaselby@uwaterloo.ca>

* Add check to ensure dataset isn't loaded twice for layer 0.
Signed-off-by: Kira Selby <kaselby@uwaterloo.ca>

---------

Signed-off-by: Kira Selby <kaselby@uwaterloo.ca>
1 parent d746a00 commit 164ecca

File tree

1 file changed

+14
-4
lines changed

1 file changed

+14
-4
lines changed

src/trainer.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,16 @@ def get_cache_stats(self) -> Dict[str, Any]:
237237
"max_cache_size": _chunk_cache.max_size,
238238
"cache_type": "chunk_cache",
239239
}
240+
241+
def set_layer_idx(self, layer_idx):
    """Point this dataset at a different layer index.

    A no-op when the index is unchanged (avoids reloading the data
    already held for the current layer). When `load_full_dataset` is
    True, the in-memory copy belongs to the old layer, so switching
    flushes the cache and loads the full dataset for the new layer.
    """
    if layer_idx != self.layer_idx:
        self.layer_idx = layer_idx
        if not self.load_full_dataset:
            return
        logger.info("Layer index changed with load_full_dataset=True. Reloading full dataset for new layer index...")
        self.clear_cache()
        self._load_full_data()
        logger.info(f"Full dataset loaded into memory for layer {layer_idx}")
240250

241251

242252
class LayerwisePredictorTrainer:
@@ -304,8 +314,8 @@ def evaluate_predictor(self, dataloader: DataLoader) -> Dict[str, float]:
304314
tp = (pred_mask * gt_mask).sum().item()
305315
fp = (pred_mask * (~gt_mask)).sum().item()
306316
fn = ((~pred_mask) * gt_mask).sum().item()
307-
total_gt_sparsity += gt_mask.sum() / gt_mask.numel()
308-
total_pred_sparsity += pred_mask.sum() / pred_mask.numel()
317+
total_gt_sparsity += 1 - (gt_mask.sum() / gt_mask.numel())
318+
total_pred_sparsity += 1 - (pred_mask.sum() / pred_mask.numel())
309319
precision = tp / (tp + fp)
310320
recall = tp / (tp + fn)
311321
f1 = 2 * precision * recall / (precision + recall)
@@ -710,7 +720,7 @@ def __init__(self, base_dataset: StreamingSparsityDataset):
710720
def set_layer(self, layer_idx: int):
    """Switch to a different layer for data access."""
    self.current_layer_idx = layer_idx
    # Delegate to the base dataset's setter so it can run any
    # layer-switch side effects (e.g. reloading a fully-loaded dataset).
    self.base_dataset.set_layer_idx(layer_idx)
714724

715725
def __len__(self):
716726
return len(self.base_dataset)
@@ -898,7 +908,7 @@ def _train_layer(
898908
checkpoint_path: Optional[str] = None
899909
):
900910
final_checkpoint = (
901-
f"final_predictor_layer_{layer_idx}_lora_{lora_pct:.1f}pct"
911+
f"final_predictor_layer_{layer_idx}_lora_{lora_pct:.1f}pct.pt"
902912
)
903913
if os.path.exists(final_checkpoint):
904914
logger.info(

0 commit comments

Comments
 (0)