Commit 50692df

Downstream evaluation (#49)
* Initial config for evaluation.
* Basic working implementation of downstream eval using lm_eval_harness.
* Remove unnecessary calls to reset cache.

Signed-off-by: Kira Selby <kaselby@uwaterloo.ca>
1 parent 74586f7 commit 50692df

File tree

1 file changed: +86 -0 lines changed


evaluate.py

Lines changed: 86 additions & 0 deletions
@@ -0,0 +1,86 @@
import argparse
import os
import logging

import torch

from transformers.trainer_utils import set_seed
from transformers import AutoConfig, AutoModelForCausalLM
from lm_eval import simple_evaluate
from lm_eval.utils import make_table
from lm_eval.models.huggingface import HFLM


# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def parse_args():
    parser = argparse.ArgumentParser(description="Evaluate trained model on common LM datasets using LM Eval Harness.")
    parser.add_argument("--model_type", type=str, choices=["hf", "sparse"], default="hf")
    parser.add_argument("--model_name_or_config", type=str, required=True,
                        help="Name or path of the base model (e.g., meta-llama/Llama-2-7b-hf)")
    parser.add_argument("--sp_dir", type=str, default="",
                        help="Path to trained predictor dir for sparse model.")
    parser.add_argument("--tasks", nargs='+', default=["hellaswag"],
                        help="Tasks on which to evaluate")
    parser.add_argument("--batch_size", type=int, default=4,
                        help="Batch size for processing")
    parser.add_argument("--device", type=str, default="auto",
                        help="Device to use (auto, cpu, cuda)")
    # Seed argument so that set_seed(args.seed) in main() has a value; the default of 42 is an assumption
    parser.add_argument("--seed", type=int, default=42,
                        help="Random seed")
    return parser.parse_args()


def main():
    args = parse_args()

    # Set seed
    set_seed(args.seed)

    # Setup device
    if args.device == "auto":
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device(args.device)

    logger.info(f"Using device: {device}")

    # Load pretrained model
    logger.info("Loading pretrained model for evaluation...")
    config = AutoConfig.from_pretrained(args.model_name_or_config)
    if args.model_type == "hf":
        model = AutoModelForCausalLM.from_pretrained(args.model_name_or_config)
    elif args.model_type == "sparse":
        model = AutoModelForCausalLM.from_pretrained(config._name_or_path, config=config)
        # Load the trained sparsity-predictor weights into each decoder layer
        for layer_idx, layer in enumerate(model.get_decoder().layers):
            layer_path = os.path.join(args.sp_dir, f"final_predictor_layer_{layer_idx}")
            if not os.path.exists(layer_path):
                logger.error(f"Pretrained weights for sparse predictor at layer {layer_idx} do not exist.")
                return
            pretrained_dict = torch.load(layer_path)
            layer.mlp_lora_proj.load_state_dict(pretrained_dict)
        model.tie_weights()
        model.reset_cache()

    # Wrap the model for the LM Eval Harness
    wrapped_model = HFLM(
        pretrained=model,
        batch_size=args.batch_size,
        device=device
    )

    logger.info("Beginning evaluation...")
    results = simple_evaluate(
        wrapped_model,
        tasks=args.tasks,
        batch_size=args.batch_size,
        device=device
    )

    if results is not None:
        print(make_table(results))
        if "groups" in results:
            print(make_table(results, "groups"))


if __name__ == '__main__':
    main()
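As a usage note (the predictor directory and task list below are hypothetical, chosen only to illustrate the flags defined in parse_args), the script would be invoked along the lines of:

python evaluate.py --model_type sparse --model_name_or_config meta-llama/Llama-2-7b-hf --sp_dir ./predictors --tasks hellaswag arc_easy --batch_size 4 --device auto

For a plain Hugging Face baseline, pass --model_type hf and omit --sp_dir.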
