@@ -248,12 +248,14 @@ def __init__(
         hidden_size: int,
         intermediate_size: int,
         lora_size: int,
+        lora_pct: float,
         device: torch.device,
     ):
         self.device = device
         self.layer_idx = layer_idx
         self.hidden_size = hidden_size
         self.intermediate_size = intermediate_size
+        self.lora_pct = lora_pct

         # Initialize predictors for each layer
         self.predictor = FastLoRAProjection(
@@ -400,7 +402,7 @@ def train_layer(
             num_training_steps=total_steps,
         )

-        best_f1 = 0.0
+
        global_step = 0
        start_epoch = 0

@@ -417,17 +419,22 @@ def train_layer(
                )
                global_step = checkpoint_data["global_step"]
                start_epoch = checkpoint_data["epoch"]
-                best_f1 = checkpoint_data["best_f1"]
+                # best_f1 = checkpoint_data["best_f1"]
                logger.info(
-                    f"Resumed training from step {global_step}, epoch {start_epoch}, best_f1: {best_f1:.4f}"
+                    f"Resumed training from step {global_step}, epoch {start_epoch}"
                )
-            except Exception as e:
+            except FileNotFoundError as e:
                logger.warning(
                    f"Failed to load checkpoint: {e}. Starting fresh training."
                )
                global_step = 0
                start_epoch = 0
-                best_f1 = 0.0
+                # best_f1 = 0.0
+            except AssertionError as e:
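+                # load_checkpoint found a checkpoint trained at a different
+                # LoRA size for this layer; abort this run instead of
+                # overwriting it.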
+                logger.warning(
+                    f"Failed to load checkpoint: {e}. Skipping training."
+                )
+                return None
        else:
            logger.info("No checkpoint found. Starting fresh training.")

@@ -482,11 +489,11 @@ def train_layer(

                wandb.log(
                    {
-                        f"layer_{self.layer_idx}/train_loss": loss.item(),
-                        f"layer_{self.layer_idx}/learning_rate": scheduler.get_last_lr()[
+                        f"layer_{self.layer_idx}_lora_{self.lora_pct:.1f}%/train_loss": loss.item(),
+                        f"layer_{self.layer_idx}_lora_{self.lora_pct:.1f}%/learning_rate": scheduler.get_last_lr()[
                            0
                        ],
-                        f"layer_{self.layer_idx}/gradient_norm": grad_norm,
+                        f"layer_{self.layer_idx}_lora_{self.lora_pct:.1f}%/gradient_norm": grad_norm,
                        "step": global_step,
                    }
                )
@@ -499,7 +506,6 @@ def train_layer(
499506 epoch ,
500507 optimizer ,
501508 scheduler ,
502- best_f1 ,
503509 loss .item (),
504510 )
505511
@@ -509,16 +515,16 @@ def train_layer(
                if use_wandb:
                    wandb.log(
                        {
-                            f"layer_{self.layer_idx}/eval_gt_sparsity": eval_metrics[
+                            f"layer_{self.layer_idx}_lora_{self.lora_pct:.1f}%/eval_gt_sparsity": eval_metrics[
                                "gt_sparsity"
                            ],
-                            f"layer_{self.layer_idx}/eval_pred_sparsity": eval_metrics[
+                            f"layer_{self.layer_idx}_lora_{self.lora_pct:.1f}%/eval_pred_sparsity": eval_metrics[
                                "pred_sparsity"
                            ],
-                            f"layer_{self.layer_idx}/eval_accuracy": eval_metrics[
+                            f"layer_{self.layer_idx}_lora_{self.lora_pct:.1f}%/eval_accuracy": eval_metrics[
                                "accuracy"
                            ],
-                            f"layer_{self.layer_idx}/eval_precision": eval_metrics[
+                            f"layer_{self.layer_idx}_lora_{self.lora_pct:.1f}%/eval_precision": eval_metrics[
                                "precision"
                            ],
                            f"layer_{self.layer_idx}/eval_recall": eval_metrics["recall"],
@@ -527,11 +533,19 @@ def train_layer(
                    }
                )

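+                # The best F1 per layer lives in a small on-disk store shared
+                # by every LoRA size, so only the overall best predictor for
+                # each layer is kept on disk.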
-                # Save best model
-                if eval_metrics["f1"] > best_f1:
-                    best_f1 = eval_metrics["f1"]
-                    # Save the best model immediately
-                    if save_dir:
+                # Save best model among all LoRA sizes for the given layer
+                if save_dir:
+                    f1_path = os.path.join(save_dir, "f1_store.pt")
+                    f1_store = torch.load(f1_path)
+
+                    best_f1 = f1_store.get(self.layer_idx, 0.0)
+
+                    if eval_metrics["f1"] > best_f1:
+                        best_f1 = eval_metrics["f1"]
+                        f1_store[self.layer_idx] = best_f1
+                        torch.save(f1_store, f1_path)
+
+                        # Save best model
                        best_model_name = f"best_predictor_layer_{self.layer_idx}"
                        self.save_predictor(save_dir, name=best_model_name)
                        logger.info(f"Saved new best model: {best_model_name}")
@@ -548,6 +562,14 @@ def train_layer(
        # Close progress bar
        progress_bar.close()

+        # Save final predictor for this layer and LoRA size
+        if save_dir:
+            model_name = (
+                f"final_predictor_layer_{self.layer_idx}_lora_{self.lora_pct:.1f}pct"
+            )
+            self.save_predictor(save_dir, name=model_name)
+            logger.info(f"Saved final predictor: {model_name}")
+
        return self.predictor  # type: ignore

    def save_predictor(self, save_dir: str, name: str = "predictor"):
@@ -566,7 +588,6 @@ def save_checkpoint(
        epoch: int,
        optimizer: torch.optim.Optimizer,
        scheduler,
-        best_f1: float,
        loss: float,
    ):
        """Save training checkpoint with full state."""
@@ -578,9 +599,9 @@ def save_checkpoint(
            "predictor_state_dict": self.predictor.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "scheduler_state_dict": scheduler.state_dict(),
-            "best_f1": best_f1,
            "loss": loss,
            "layer_idx": self.layer_idx,
+            "lora_pct": self.lora_pct,
            "hidden_size": self.hidden_size,
            "intermediate_size": self.intermediate_size,
        }
@@ -608,6 +629,9 @@ def load_checkpoint(
        logger.info(f"Loading checkpoint from {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path, map_location=self.device)

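+        # Guard against resuming from a checkpoint trained at a different LoRA
+        # size; train_layer catches this AssertionError and skips the run.
+        # Using .get() also routes legacy checkpoints without a "lora_pct" key
+        # onto this path instead of raising KeyError.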
+        if checkpoint.get("lora_pct") != self.lora_pct:
+            raise AssertionError(f"Mismatched LoRA size found: expected {self.lora_pct}% but found {checkpoint.get('lora_pct')}%.")
+
        # Load predictor state
        self.predictor.load_state_dict(checkpoint["predictor_state_dict"])

@@ -619,7 +643,6 @@ def load_checkpoint(
        return {
            "global_step": checkpoint["global_step"],
            "epoch": checkpoint["epoch"],
-            "best_f1": checkpoint["best_f1"],
            "loss": checkpoint["loss"],
        }

@@ -771,6 +794,7 @@ def train_all_layers(
        save_interval: int = 1000,
        resume_from_checkpoint: bool = False,
        checkpoint_path: Optional[str] = None,
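+        # When True, each layer is retrained only at the LoRA size whose
+        # predictor currently holds the best F1 for that layer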
+        load_best_only: bool = False,
        seed: int = 42,
    ):
        """Train predictors for all specified layers."""
@@ -795,60 +819,32 @@ def train_all_layers(
            [train_size, val_size],
            generator=torch.Generator().manual_seed(seed),
        )
-        # Train each layer with each LoRA size (hyperparameter grid)
-        for lora_size, lora_pct in zip(self.lora_sizes, self.lora_size_percentages):
-            logger.info(f"Training with LoRA size {lora_size} ({lora_pct:.1f}%)")
-
-            for layer_idx in self.layer_indices:
-                logger.info(
-                    f"Starting training for layer {layer_idx} with LoRA size {lora_size}"
-                )
-
-                # Get or create trainer for this layer and LoRA size
-                trainer_key = (layer_idx, lora_size)
-                if trainer_key not in self.layer_trainers:
-                    self.layer_trainers[trainer_key] = LayerwisePredictorTrainer(
-                        layer_idx=layer_idx,
-                        hidden_size=self.hidden_size,
-                        intermediate_size=self.intermediate_size,
-                        lora_size=lora_size,
-                        device=self.device,
-                    )
-
-                trainer = self.layer_trainers[trainer_key]
-
-                # Switch shared dataset to current layer
-                logger.info(f"Switching shared dataset to layer {layer_idx}")
-                self.shared_dataset.set_layer(layer_idx)
-
-                logger.info(
-                    f"Layer {layer_idx}, LoRA {lora_pct:.1f}%: Using {len(train_dataset)} training samples, {len(val_dataset)} validation samples"
-                )

-                # Determine checkpoint path for this layer if resuming
-                layer_checkpoint_path = None
-                if resume_from_checkpoint:
-                    if checkpoint_path:
-                        # If specific checkpoint path provided, use it only for the matching layer
-                        if f"layer_{layer_idx}" in checkpoint_path:
-                            layer_checkpoint_path = checkpoint_path
-                    else:
-                        # Look for latest checkpoint for this layer
-                        layer_checkpoint_path = (
-                            None  # Let trainer find latest automatically
-                        )
-
-                # Update wandb to include LoRA size if using wandb
-                if use_wandb:
-                    wandb.log(
-                        {
-                            f"layer_{layer_idx}_lora_{lora_pct:.1f}%/lora_size": lora_size,
-                            f"layer_{layer_idx}_lora_{lora_pct:.1f}%/lora_pct": lora_pct,
-                        }
-                    )
-
-                # Train predictor for this layer
-                trainer.train_layer(
+        # This should be replaced by something more sophisticated, like an LMDB
+        f1_path = os.path.join(save_dir, "f1_store.pt")
+        if not os.path.exists(f1_path):
+            torch.save(
+                {layer_idx: 0.0 for layer_idx in self.layer_indices}, f1_path
+            )
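+        # The store maps layer_idx -> best F1 seen so far for that layer across
+        # all LoRA sizes; train_layer reads and updates it after each evaluation.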
+
+        for layer_idx in self.layer_indices:
+            logger.info(f"Training layer {layer_idx}...")
+
+            # Train only from the best LoRA predictor for this layer
+            if load_best_only:
+                best_checkpoint_path = os.path.join(
+                    save_dir, f"best_predictor_layer_{layer_idx}.pt"
+                )
+                if not os.path.exists(best_checkpoint_path):
+                    logger.warning(
+                        f"Best checkpoint for layer {layer_idx} not found. Skipping to next layer."
+                    )
+                    continue
+                best_checkpoint = torch.load(best_checkpoint_path)
+                lora_pct = best_checkpoint["lora_pct"]
+                del best_checkpoint
+
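+                # Recover the absolute LoRA size from the stored percentage of
+                # the intermediate size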
+                lora_size = int(self.intermediate_size * lora_pct / 100)
+
+                self._train_layer(
+                    layer_idx=layer_idx,
+                    lora_pct=lora_pct,
+                    lora_size=lora_size,
                    train_dataset=train_dataset,
                    val_dataset=val_dataset,
                    num_epochs=num_epochs,
@@ -858,23 +854,122 @@ def train_layer(
                    save_dir=save_dir,
                    save_interval=save_interval,
                    resume_from_checkpoint=resume_from_checkpoint,
-                    checkpoint_path=layer_checkpoint_path,
+                    checkpoint_path=checkpoint_path,
                )
+            else:
+                # Train each layer with each LoRA size (hyperparameter grid)
+                for lora_size, lora_pct in zip(self.lora_sizes, self.lora_size_percentages):
+                    logger.info(f"Training with LoRA size {lora_size} ({lora_pct:.1f}%)")

-                # Save final predictor for this layer and LoRA size
-                if save_dir:
-                    model_name = (
-                        f"final_predictor_layer_{layer_idx}_lora_{lora_pct:.1f}pct"
+                    self._train_layer(
+                        layer_idx=layer_idx,
+                        lora_pct=lora_pct,
+                        lora_size=lora_size,
+                        train_dataset=train_dataset,
+                        val_dataset=val_dataset,
+                        num_epochs=num_epochs,
+                        batch_size=batch_size,
+                        learning_rate=learning_rate,
+                        use_wandb=use_wandb,
+                        save_dir=save_dir,
+                        save_interval=save_interval,
+                        resume_from_checkpoint=resume_from_checkpoint,
+                        checkpoint_path=checkpoint_path,
                    )
-                trainer.save_predictor(save_dir, name=model_name)
-                logger.info(f"Saved final predictor: {model_name}")

-                logger.info(
-                    f"Completed training for layer {layer_idx} with LoRA size {lora_size}"
+        logger.info(
+            f"Completed all training - {len(self.layer_indices)} layers × {len(self.lora_sizes)} LoRA sizes = {len(self.layer_indices) * len(self.lora_sizes)} total experiments"
+        )
+
+    def _train_layer(
+        self,
+        layer_idx: int,
+        lora_pct: float,
+        lora_size: int,
+        train_dataset,
+        val_dataset,
+        num_epochs: int,
+        batch_size: int,
+        learning_rate: float,
+        use_wandb: bool = False,
+        save_dir: Optional[str] = None,
+        save_interval: int = 1000,
+        resume_from_checkpoint: bool = False,
+        checkpoint_path: Optional[str] = None,
+    ):
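+        # Idempotency guard: a final predictor already on disk means this
+        # (layer, LoRA size) pair finished in an earlier run, so it can be
+        # skipped when the job is restarted.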
+        if save_dir:
+            final_checkpoint = os.path.join(
+                save_dir, f"final_predictor_layer_{layer_idx}_lora_{lora_pct:.1f}pct.pt"
+            )
+            if os.path.exists(final_checkpoint):
+                logger.info(
+                    f"Final checkpoint for layer {layer_idx} with LoRA size {lora_size} found. Skipping training..."
+                )
+                return
+
+        logger.info(
+            f"Starting training for layer {layer_idx} with LoRA size {lora_size}"
+        )
+
+        # Get or create trainer for this layer and LoRA size
+        trainer_key = (layer_idx, lora_size)
+        if trainer_key not in self.layer_trainers:
+            self.layer_trainers[trainer_key] = LayerwisePredictorTrainer(
+                layer_idx=layer_idx,
+                hidden_size=self.hidden_size,
+                intermediate_size=self.intermediate_size,
+                lora_size=lora_size,
+                lora_pct=lora_pct,
+                device=self.device,
+            )
+
+        trainer = self.layer_trainers[trainer_key]
+
+        # Switch shared dataset to current layer
+        logger.info(f"Switching shared dataset to layer {layer_idx}")
+        self.shared_dataset.set_layer(layer_idx)
+
+        logger.info(
+            f"Layer {layer_idx}, LoRA {lora_pct:.1f}%: Using {len(train_dataset)} training samples, {len(val_dataset)} validation samples"
+        )
+
+        # Determine checkpoint path for this layer if resuming
+        layer_checkpoint_path = None
+        if resume_from_checkpoint:
+            if checkpoint_path:
+                # If specific checkpoint path provided, use it only for the matching layer
+                if f"layer_{layer_idx}" in checkpoint_path:
+                    layer_checkpoint_path = checkpoint_path
+            else:
+                # Look for latest checkpoint for this layer
+                layer_checkpoint_path = (
+                    None  # Let trainer find latest automatically
                )

+        # Update wandb to include LoRA size if using wandb
+        if use_wandb:
+            wandb.log(
+                {
+                    f"layer_{layer_idx}_lora_{lora_pct:.1f}%/lora_size": lora_size,
+                    f"layer_{layer_idx}_lora_{lora_pct:.1f}%/lora_pct": lora_pct,
+                }
+            )
+
+        # Train predictor for this layer
+        trainer.train_layer(
+            train_dataset=train_dataset,
+            val_dataset=val_dataset,
+            num_epochs=num_epochs,
+            batch_size=batch_size,
+            learning_rate=learning_rate,
+            use_wandb=use_wandb,
+            save_dir=save_dir,
+            save_interval=save_interval,
+            resume_from_checkpoint=resume_from_checkpoint,
+            checkpoint_path=layer_checkpoint_path,
+        )
+
        logger.info(
-            f"Completed all training - {len(self.layer_indices)} layers × {len(self.lora_sizes)} LoRA sizes = {len(self.layer_indices) * len(self.lora_sizes)} total experiments"
+            f"Completed training for layer {layer_idx} with LoRA size {lora_size}"
        )

    def save_all_predictors(self, save_dir: str):