
Commit d746a00

Fix train (#55)

Authored by Kira Selby <kaselby@uwaterloo.ca>

* Update to log each separate LoRA size independently and correctly identify which checkpoints to load based on LoRA size.
* Store best F1 scores for each layer in a central KV store and save the best-performing model per layer. Add a flag to resume training only from the best-performing LoRA sizes for each layer.
* Bugfixes, and move saving of the final predictor into the layerwise trainer to avoid saving after an early exit due to errors.
* Remove restart_if_missing and add documentation for load_best_only.

Signed-off-by: Kira Selby <kaselby@uwaterloo.ca>
1 parent af8d8c5 commit d746a00
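
The "central KV store" mentioned in the commit message is, in this commit, just a dict saved to f1_store.pt under save_dir and keyed by layer index (see the diff below). A minimal sketch of the pattern, assuming that file layout; the helper names here are illustrative and not part of the repository:

import os

import torch

def init_f1_store(save_dir, layer_indices):
    # Create the per-layer best-F1 store once, with a 0.0 entry per layer.
    f1_path = os.path.join(save_dir, "f1_store.pt")
    if not os.path.exists(f1_path):
        torch.save({idx: 0.0 for idx in layer_indices}, f1_path)
    return f1_path

def record_if_best(save_dir, layer_idx, f1):
    # Compare this run's F1 against the best recorded for the layer across
    # all LoRA sizes; if it wins, persist it and tell the caller to save
    # the model as the layer's best predictor.
    f1_path = os.path.join(save_dir, "f1_store.pt")
    f1_store = torch.load(f1_path)
    if f1 > f1_store.get(layer_idx, 0.0):
        f1_store[layer_idx] = f1
        torch.save(f1_store, f1_path)
        return True
    return False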

File tree

2 files changed: +193 -84 lines changed

src/trainer.py

Lines changed: 179 additions & 84 deletions
@@ -248,12 +248,14 @@ def __init__(
         hidden_size: int,
         intermediate_size: int,
         lora_size: int,
+        lora_pct: float,
         device: torch.device,
     ):
         self.device = device
         self.layer_idx = layer_idx
         self.hidden_size = hidden_size
         self.intermediate_size = intermediate_size
+        self.lora_pct = lora_pct
 
         # Initialize predictors for each layer
         self.predictor = FastLoRAProjection(
@@ -400,7 +402,7 @@ def train_layer(
             num_training_steps=total_steps,
         )
 
-        best_f1 = 0.0
+
         global_step = 0
         start_epoch = 0
 
@@ -417,17 +419,22 @@ def train_layer(
                 )
                 global_step = checkpoint_data["global_step"]
                 start_epoch = checkpoint_data["epoch"]
-                best_f1 = checkpoint_data["best_f1"]
+                #best_f1 = checkpoint_data["best_f1"]
                 logger.info(
-                    f"Resumed training from step {global_step}, epoch {start_epoch}, best_f1: {best_f1:.4f}"
+                    f"Resumed training from step {global_step}, epoch {start_epoch}"
                 )
-            except Exception as e:
+            except FileNotFoundError as e:
                 logger.warning(
                     f"Failed to load checkpoint: {e}. Starting fresh training."
                 )
                 global_step = 0
                 start_epoch = 0
-                best_f1 = 0.0
+                #best_f1 = 0.0
+            except AssertionError as e:
+                logger.warning(
+                    f"Failed to load checkpoint: {e}. Skipping training."
+                )
+                return None
         else:
             logger.info("No checkpoint found. Starting fresh training.")
 
@@ -482,11 +489,11 @@ def train_layer(
 
                 wandb.log(
                     {
-                        f"layer_{self.layer_idx}/train_loss": loss.item(),
-                        f"layer_{self.layer_idx}/learning_rate": scheduler.get_last_lr()[
+                        f"layer_{self.layer_idx}_lora_{self.lora_pct:.1f}%/train_loss": loss.item(),
+                        f"layer_{self.layer_idx}_lora_{self.lora_pct:.1f}%/learning_rate": scheduler.get_last_lr()[
                             0
                         ],
-                        f"layer_{self.layer_idx}/gradient_norm": grad_norm,
+                        f"layer_{self.layer_idx}_lora_{self.lora_pct:.1f}%/gradient_norm": grad_norm,
                         "step": global_step,
                     }
                 )
@@ -499,7 +506,6 @@ def train_layer(
                     epoch,
                     optimizer,
                     scheduler,
-                    best_f1,
                     loss.item(),
                 )
 
@@ -509,16 +515,16 @@ def train_layer(
             if use_wandb:
                 wandb.log(
                     {
-                        f"layer_{self.layer_idx}/eval_gt_sparsity": eval_metrics[
+                        f"layer_{self.layer_idx}_lora_{self.lora_pct:.1f}%/eval_gt_sparsity": eval_metrics[
                             "gt_sparsity"
                         ],
-                        f"layer_{self.layer_idx}/eval_pred_sparsity": eval_metrics[
+                        f"layer_{self.layer_idx}_lora_{self.lora_pct:.1f}%/eval_pred_sparsity": eval_metrics[
                             "pred_sparsity"
                         ],
-                        f"layer_{self.layer_idx}/eval_accuracy": eval_metrics[
+                        f"layer_{self.layer_idx}_lora_{self.lora_pct:.1f}%/eval_accuracy": eval_metrics[
                             "accuracy"
                         ],
-                        f"layer_{self.layer_idx}/eval_precision": eval_metrics[
+                        f"layer_{self.layer_idx}_lora_{self.lora_pct:.1f}%/eval_precision": eval_metrics[
                             "precision"
                         ],
                         f"layer_{self.layer_idx}/eval_recall": eval_metrics["recall"],
@@ -527,11 +533,19 @@ def train_layer(
                     }
                 )
 
-            # Save best model
-            if eval_metrics["f1"] > best_f1:
-                best_f1 = eval_metrics["f1"]
-                # Save the best model immediately
-                if save_dir:
+            # Save best model among all lora models for the given layer
+            if save_dir:
+                f1_path = os.path.join(save_dir, f"f1_store.pt")
+                f1_store = torch.load(f1_path)
+
+                best_f1 = f1_store[self.layer_idx] if self.layer_idx in f1_store else 0.0
+
+                if eval_metrics["f1"] > best_f1:
+                    best_f1 = eval_metrics["f1"]
+                    f1_store[self.layer_idx] = best_f1
+                    torch.save(f1_store, f1_path)
+
+                    # Save best model
                     best_model_name = f"best_predictor_layer_{self.layer_idx}"
                     self.save_predictor(save_dir, name=best_model_name)
                     logger.info(f"Saved new best model: {best_model_name}")
@@ -548,6 +562,14 @@ def train_layer(
         # Close progress bar
         progress_bar.close()
 
+        # Save final predictor for this layer and LoRA size
+        if save_dir:
+            model_name = (
+                f"final_predictor_layer_{self.layer_idx}_lora_{self.lora_pct:.1f}pct"
+            )
+            self.save_predictor(save_dir, name=model_name)
+            logger.info(f"Saved final predictor: {model_name}")
+
         return self.predictor  # type: ignore
 
     def save_predictor(self, save_dir: str, name: str = "predictor"):
@@ -566,7 +588,6 @@ def save_checkpoint(
         epoch: int,
         optimizer: torch.optim.Optimizer,
         scheduler,
-        best_f1: float,
         loss: float,
     ):
         """Save training checkpoint with full state."""
@@ -578,9 +599,9 @@ def save_checkpoint(
             "predictor_state_dict": self.predictor.state_dict(),
             "optimizer_state_dict": optimizer.state_dict(),
             "scheduler_state_dict": scheduler.state_dict(),
-            "best_f1": best_f1,
             "loss": loss,
             "layer_idx": self.layer_idx,
+            "lora_pct": self.lora_pct,
             "hidden_size": self.hidden_size,
             "intermediate_size": self.intermediate_size,
         }
@@ -608,6 +629,9 @@ def load_checkpoint(
         logger.info(f"Loading checkpoint from {checkpoint_path}")
         checkpoint = torch.load(checkpoint_path, map_location=self.device)
 
+        if checkpoint["lora_pct"] != self.lora_pct:
+            raise AssertionError(f"Mismatched LoRA size found: expected {self.lora_pct}% but found {checkpoint['lora_pct']}%.")
+
         # Load predictor state
         self.predictor.load_state_dict(checkpoint["predictor_state_dict"])
 
@@ -619,7 +643,6 @@ def load_checkpoint(
         return {
             "global_step": checkpoint["global_step"],
             "epoch": checkpoint["epoch"],
-            "best_f1": checkpoint["best_f1"],
             "loss": checkpoint["loss"],
         }
 
@@ -771,6 +794,7 @@ def train_all_layers(
         save_interval: int = 1000,
         resume_from_checkpoint: bool = False,
         checkpoint_path: Optional[str] = None,
+        load_best_only: bool = False,
         seed: int = 42,
     ):
         """Train predictors for all specified layers."""
@@ -795,60 +819,32 @@ def train_all_layers(
             [train_size, val_size],
             generator=torch.Generator().manual_seed(seed),
         )
-        # Train each layer with each LoRA size (hyperparameter grid)
-        for lora_size, lora_pct in zip(self.lora_sizes, self.lora_size_percentages):
-            logger.info(f"Training with LoRA size {lora_size} ({lora_pct:.1f}%)")
-
-            for layer_idx in self.layer_indices:
-                logger.info(
-                    f"Starting training for layer {layer_idx} with LoRA size {lora_size}"
-                )
-
-                # Get or create trainer for this layer and LoRA size
-                trainer_key = (layer_idx, lora_size)
-                if trainer_key not in self.layer_trainers:
-                    self.layer_trainers[trainer_key] = LayerwisePredictorTrainer(
-                        layer_idx=layer_idx,
-                        hidden_size=self.hidden_size,
-                        intermediate_size=self.intermediate_size,
-                        lora_size=lora_size,
-                        device=self.device,
-                    )
-
-                trainer = self.layer_trainers[trainer_key]
-
-                # Switch shared dataset to current layer
-                logger.info(f"Switching shared dataset to layer {layer_idx}")
-                self.shared_dataset.set_layer(layer_idx)
-
-                logger.info(
-                    f"Layer {layer_idx}, LoRA {lora_pct:.1f}%: Using {len(train_dataset)} training samples, {len(val_dataset)} validation samples"
-                )
 
-                # Determine checkpoint path for this layer if resuming
-                layer_checkpoint_path = None
-                if resume_from_checkpoint:
-                    if checkpoint_path:
-                        # If specific checkpoint path provided, use it only for the matching layer
-                        if f"layer_{layer_idx}" in checkpoint_path:
-                            layer_checkpoint_path = checkpoint_path
-                    else:
-                        # Look for latest checkpoint for this layer
-                        layer_checkpoint_path = (
-                            None  # Let trainer find latest automatically
-                        )
-
-                # Update wandb to include LoRA size if using wandb
-                if use_wandb:
-                    wandb.log(
-                        {
-                            f"layer_{layer_idx}_lora_{lora_pct:.1f}%/lora_size": lora_size,
-                            f"layer_{layer_idx}_lora_{lora_pct:.1f}%/lora_pct": lora_pct,
-                        }
-                    )
-
-                # Train predictor for this layer
-                trainer.train_layer(
+        f1_path = os.path.join(save_dir, "f1_store.pt") # This should be replaced by something more sophisticated like an LMDB
+        if not os.path.exists(f1_path):
+            torch.save({
+                layer_idx: 0.0 for layer_idx in self.layer_indices
+            }, f1_path)
+
+        for layer_idx in self.layer_indices:
+            logger.info(f"Training layer {layer_idx}...")
+
+            # Train only from best lora predictor
+            if load_best_only:
+                best_checkpoint_path = "best_predictor_layer_{self.layer_idx}.pt"
+                if not os.path.exists(best_checkpoint_path):
+                    logger.warning(f"Best checkpoint for layer {layer_idx} not found. Skipping to next layer.")
+                    continue
+                best_checkpoint = torch.load(best_checkpoint_path)
+                lora_pct = best_checkpoint["lora_pct"]
+                del best_checkpoint
+
+                lora_size = self.intermediate_size * lora_pct / 100
+
+                self._train_layer(
+                    layer_idx=layer_idx,
+                    lora_pct=lora_pct,
+                    lora_size=lora_size,
                     train_dataset=train_dataset,
                     val_dataset=val_dataset,
                     num_epochs=num_epochs,
@@ -858,23 +854,122 @@ def train_all_layers(
                    save_dir=save_dir,
                    save_interval=save_interval,
                    resume_from_checkpoint=resume_from_checkpoint,
-                    checkpoint_path=layer_checkpoint_path,
+                    checkpoint_path=checkpoint_path
                )
+            else:
+                # Train each layer with each LoRA size (hyperparameter grid)
+                for lora_size, lora_pct in zip(self.lora_sizes, self.lora_size_percentages):
+                    logger.info(f"Training with LoRA size {lora_size} ({lora_pct:.1f}%)")
 
-                # Save final predictor for this layer and LoRA size
-                if save_dir:
-                    model_name = (
-                        f"final_predictor_layer_{layer_idx}_lora_{lora_pct:.1f}pct"
+                    self._train_layer(
+                        layer_idx=layer_idx,
+                        lora_pct=lora_pct,
+                        lora_size=lora_size,
+                        train_dataset=train_dataset,
+                        val_dataset=val_dataset,
+                        num_epochs=num_epochs,
+                        batch_size=batch_size,
+                        learning_rate=learning_rate,
+                        use_wandb=use_wandb,
+                        save_dir=save_dir,
+                        save_interval=save_interval,
+                        resume_from_checkpoint=resume_from_checkpoint,
+                        checkpoint_path=checkpoint_path
                    )
-                    trainer.save_predictor(save_dir, name=model_name)
-                    logger.info(f"Saved final predictor: {model_name}")
 
-                logger.info(
-                    f"Completed training for layer {layer_idx} with LoRA size {lora_size}"
+        logger.info(
+            f"Completed all training - {len(self.layer_indices)} layers × {len(self.lora_sizes)} LoRA sizes = {len(self.layer_indices) * len(self.lora_sizes)} total experiments"
+        )
+
+    def _train_layer(
+        self,
+        layer_idx: int,
+        lora_pct: float,
+        lora_size: int,
+        train_dataset,
+        val_dataset,
+        num_epochs: int,
+        batch_size: int,
+        learning_rate: float,
+        use_wandb: bool = False,
+        save_dir: Optional[str] = None,
+        save_interval: int = 1000,
+        resume_from_checkpoint: bool = False,
+        checkpoint_path: Optional[str] = None
+    ):
+        final_checkpoint = (
+            f"final_predictor_layer_{layer_idx}_lora_{lora_pct:.1f}pct"
+        )
+        if os.path.exists(final_checkpoint):
+            logger.info(
+                f"Final checkpoint for layer {layer_idx} with LoRA size {lora_size} found. Skipping training..."
+            )
+            return
+
+        logger.info(
+            f"Starting training for layer {layer_idx} with LoRA size {lora_size}"
+        )
+
+        # Get or create trainer for this layer and LoRA size
+        trainer_key = (layer_idx, lora_size)
+        if trainer_key not in self.layer_trainers:
+            self.layer_trainers[trainer_key] = LayerwisePredictorTrainer(
+                layer_idx=layer_idx,
+                hidden_size=self.hidden_size,
+                intermediate_size=self.intermediate_size,
+                lora_size=lora_size,
+                lora_pct=lora_pct,
+                device=self.device,
+            )
+
+        trainer = self.layer_trainers[trainer_key]
+
+        # Switch shared dataset to current layer
+        logger.info(f"Switching shared dataset to layer {layer_idx}")
+        self.shared_dataset.set_layer(layer_idx)
+
+        logger.info(
+            f"Layer {layer_idx}, LoRA {lora_pct:.1f}%: Using {len(train_dataset)} training samples, {len(val_dataset)} validation samples"
+        )
+
+        # Determine checkpoint path for this layer if resuming
+        layer_checkpoint_path = None
+        if resume_from_checkpoint:
+            if checkpoint_path:
+                # If specific checkpoint path provided, use it only for the matching layer
+                if f"layer_{layer_idx}" in checkpoint_path:
+                    layer_checkpoint_path = checkpoint_path
+            else:
+                # Look for latest checkpoint for this layer
+                layer_checkpoint_path = (
+                    None  # Let trainer find latest automatically
                )
 
+        # Update wandb to include LoRA size if using wandb
+        if use_wandb:
+            wandb.log(
+                {
+                    f"layer_{layer_idx}_lora_{lora_pct:.1f}%/lora_size": lora_size,
+                    f"layer_{layer_idx}_lora_{lora_pct:.1f}%/lora_pct": lora_pct,
+                }
+            )
+
+        # Train predictor for this layer
+        trainer.train_layer(
+            train_dataset=train_dataset,
+            val_dataset=val_dataset,
+            num_epochs=num_epochs,
+            batch_size=batch_size,
+            learning_rate=learning_rate,
+            use_wandb=use_wandb,
+            save_dir=save_dir,
+            save_interval=save_interval,
+            resume_from_checkpoint=resume_from_checkpoint,
+            checkpoint_path=layer_checkpoint_path
+        )
+
        logger.info(
-            f"Completed all training - {len(self.layer_indices)} layers × {len(self.lora_sizes)} LoRA sizes = {len(self.layer_indices) * len(self.lora_sizes)} total experiments"
+            f"Completed training for layer {layer_idx} with LoRA size {lora_size}"
        )
 
    def save_all_predictors(self, save_dir: str):
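
Taken together, the resume path added here works roughly as follows (a simplified, self-contained sketch assuming the checkpoint naming and keys shown in the diff; pick_resume_config is an illustrative helper, not code from the repository):

import os

import torch

def pick_resume_config(save_dir, layer_idx, intermediate_size):
    # With load_best_only, a layer is retrained only from the LoRA size whose
    # best checkpoint won the per-layer F1 comparison; that size is recovered
    # from the "lora_pct" field stored in the checkpoint.
    path = os.path.join(save_dir, f"best_predictor_layer_{layer_idx}.pt")
    if not os.path.exists(path):
        return None  # no best checkpoint yet; train_all_layers skips the layer
    checkpoint = torch.load(path, map_location="cpu")
    lora_pct = checkpoint["lora_pct"]
    lora_size = int(intermediate_size * lora_pct / 100)  # int cast is this sketch's choice
    return lora_pct, lora_size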
