1+ import argparse
2+ import os
3+ import logging
4+
5+ import torch
6+
7+ from transformers .trainer_utils import set_seed
8+ from transformers import AutoConfig , AutoModelForCausalLM
9+ from lm_eval .models import HFLM
10+ from lm_eval import simple_evaluate
11+ from lm_eval .utils import make_table
12+ from lm_eval .models .huggingface import HFLM
13+
14+
# Configure root logging once at import time and grab a module-scoped logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
18+
def parse_args(argv=None):
    """Parse command-line options for evaluating a model with LM Eval Harness.

    Args:
        argv: Optional list of argument strings; defaults to ``sys.argv[1:]``
            (standard argparse behavior). Exposed mainly for testing.

    Returns:
        argparse.Namespace: The parsed arguments.
    """
    parser = argparse.ArgumentParser(
        description="Evaluate trained model on common LM datasets using LM Eval Harness."
    )
    parser.add_argument("--model_type", type=str, choices=["hf", "sparse"], default="hf")
    parser.add_argument("--model_name_or_config", type=str, required=True,
                        help="Name or path of the base model (e.g., meta-llama/Llama-2-7b-hf)")
    parser.add_argument("--sp_dir", type=str, default="",
                        help="Path to trained predictor dir for sparse model.")
    parser.add_argument("--tasks", nargs='+', default=["hellaswag"],
                        help="Tasks on which to evaluate")
    parser.add_argument("--batch_size", type=int, default=4,
                        help="Batch size for processing")
    parser.add_argument("--device", type=str, default="auto",
                        help="Device to use (auto, cpu, cuda)")
    # BUG FIX: main() reads args.seed, but no --seed option was ever defined.
    parser.add_argument("--seed", type=int, default=42,
                        help="Random seed for reproducibility")
    # BUG FIX: the original returned the parser object itself, but main()
    # consumes the return value as a parsed Namespace (args.seed, args.device, ...).
    return parser.parse_args(argv)
33+
34+
def main():
    """Load a pretrained (optionally sparsity-augmented) causal LM and run
    LM Eval Harness on the requested tasks, printing a results table."""
    args = parse_args()

    # Seed all RNGs for reproducible evaluation.
    set_seed(args.seed)

    # Resolve the evaluation device.
    if args.device == "auto":
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device(args.device)

    logger.info(f"Using device: {device}")

    # Load pretrained model.
    # (Consistency fix: use the module logger, not the root `logging` module.)
    logger.info("Loading pretrained model for evaluation...")
    config = AutoConfig.from_pretrained(args.model_name_or_config)
    if args.model_type == "hf":
        # BUG FIX: from_pretrained() expects a model name/path (or uses
        # config._name_or_path), not a bare config object as its first argument.
        model = AutoModelForCausalLM.from_pretrained(args.model_name_or_config)
    elif args.model_type == "sparse":
        model = AutoModelForCausalLM.from_pretrained(config._name_or_path, config=config)
        # Load the trained sparse-predictor weights, one checkpoint per decoder layer.
        for layer_idx, layer in enumerate(model.get_decoder().layers):
            layer_path = os.path.join(args.sp_dir, f"final_predictor_layer_{layer_idx}")
            if not os.path.exists(layer_path):
                logger.error(f"Pretrained weights for sparse predictor at layer {layer_idx} do not exist.")
                return
            # map_location avoids failures when the checkpoint was saved on a
            # device (e.g. a specific GPU) that is unavailable here; the state
            # dict is moved to the module's device by load_state_dict.
            pretrained_dict = torch.load(layer_path, map_location="cpu")
            layer.mlp_lora_proj.load_state_dict(pretrained_dict)
        model.tie_weights()
        # NOTE(review): reset_cache() is not a standard HF API — assumed to be a
        # method of the custom sparse model; confirm against the model class.
        model.reset_cache()

    # Wrap the model for the LM Eval Harness.
    wrapped_model = HFLM(
        pretrained=model,
        batch_size=args.batch_size,
        device=device,
    )

    logger.info("Beginning evaluation...")
    results = simple_evaluate(
        wrapped_model,
        tasks=args.tasks,
        batch_size=args.batch_size,
        device=device,
    )

    # simple_evaluate returns None on non-primary ranks in distributed runs.
    if results is not None:
        print(make_table(results))
        if "groups" in results:
            print(make_table(results, "groups"))


if __name__ == '__main__':
    main()