Skip to content

Commit af94e97

Browse files
committed
fix: use legacy=False
1 parent ef09c42 commit af94e97

File tree

1 file changed

+6
-17
lines changed

1 file changed

+6
-17
lines changed

bigcodebench/model.py

Lines changed: 6 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
warn("GoogleGenAI decoder will not work. Fix by `pip install google-generativeai`")
2727

2828
import torch
29-
from stop_sequencer import StopSequencer
3029
from transformers import AutoModelForCausalLM, AutoTokenizer
3130

3231
try:
@@ -130,10 +129,11 @@ def __init__(self, name: str, dataset: str, tp: int, **kwargs) -> None:
130129
"trust_remote_code": self.trust_remote_code,
131130
}
132131

133-
self.tokenizer = AutoTokenizer.from_pretrained(self.name, **kwargs)
132+
self.tokenizer = AutoTokenizer.from_pretrained(self.name, legacy=False, **kwargs)
134133
if self.tokenizer.chat_template is None:
135134
self.eos += extra_eos_for_direct_completion(dataset)
136135
self.llm = LLM(model=name, max_model_len=2048, **kwargs)
136+
self.llm.set_tokenizer(tokenizer=self.tokenizer)
137137

138138
def is_direct_completion(self) -> bool:
139139
return self.tokenizer.chat_template is None
@@ -179,15 +179,15 @@ def __init__(self, name: str, dataset: str, **kwargs):
179179
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
180180

181181
kwargs = {}
182-
kwargs["device_map"] = "auto"
182+
kwargs["device_map"] = "cuda:0"
183183
kwargs["trust_remote_code"] = self.trust_remote_code
184184
# string to torch dtype
185185
kwargs["torch_dtype"] = getattr(torch, self.dtype)
186186
self.skip_special_tokens = True
187187

188188
print(f"{kwargs = }")
189189

190-
self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs)
190+
self.tokenizer = AutoTokenizer.from_pretrained(name, legacy=False, **kwargs)
191191
if self.tokenizer.chat_template is None:
192192
self.eos += extra_eos_for_direct_completion(dataset)
193193

@@ -213,18 +213,7 @@ def codegen(
213213
kwargs["top_p"] = 0.95
214214
kwargs["temperature"] = self.temperature
215215

216-
stop_sequencer = StopSequencer(
217-
self.model,
218-
model_type="causal", # or seq2seq
219-
tokenizer=self.tokenizer,
220-
)
221-
222-
model = stop_sequencer.register_stop_texts(
223-
stop_texts=self.eos,
224-
input_length=input_tokens.size(-1),
225-
)
226-
227-
outputs = model.generate(
216+
outputs = self.model.generate(
228217
input_tokens,
229218
max_new_tokens=self.max_new_tokens,
230219
do_sample=do_sample,
@@ -253,7 +242,7 @@ def __init__(self, name: str, **kwargs):
253242
super().__init__(name=name, **kwargs)
254243
self.eos += ["\n```\n"]
255244
print(f"EOS strings: {self.eos}")
256-
self.tokenizer = AutoTokenizer.from_pretrained(self.name, **kwargs)
245+
self.tokenizer = AutoTokenizer.from_pretrained(self.name, legacy=False, **kwargs)
257246

258247
def codegen(
259248
self, prompt: str, do_sample: bool = True, num_samples: int = 200

0 commit comments

Comments (0)