Commit 96aafc0

add custom tokenizer

1 parent b2a14b6 commit 96aafc0

bigcodebench/model.py

Lines changed: 16 additions & 6 deletions

@@ -92,6 +92,7 @@ def __init__(
         max_new_tokens: int = 1280,
         dtype: str = "bfloat16", # default
         trust_remote_code: bool = False,
+        tokenizer_name: str = None,
     ) -> None:
         print("Initializing a decoder model: {} ...".format(name))
         self.name = name

@@ -102,6 +103,7 @@ def __init__(
         self.max_new_tokens = max_new_tokens
         self.dtype = dtype
         self.trust_remote_code = trust_remote_code
+        self.tokenizer_name = tokenizer_name
 
     @abstractmethod
     def codegen(

@@ -129,11 +131,13 @@ def __init__(self, name: str, dataset: str, tp: int, **kwargs) -> None:
             "dtype": self.dtype,
             "trust_remote_code": self.trust_remote_code,
         }
-
-        self.tokenizer = AutoTokenizer.from_pretrained(self.name, **kwargs)
+        if self.tokenizer_name is None:
+            self.tokenizer_name = self.name
+
+        self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name, **kwargs)
         if self.tokenizer.chat_template is None:
             self.eos += extra_eos_for_direct_completion(dataset)
-        self.llm = LLM(model=name, max_model_len=2048, **kwargs)
+        self.llm = LLM(model=name, max_model_len=2048, tokenizer=self.tokenizer_name, **kwargs)
 
     def is_direct_completion(self) -> bool:
         return self.tokenizer.chat_template is None
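
The hunk above also forwards the custom tokenizer to the vLLM engine itself via the existing `tokenizer` argument of `vllm.LLM`, so the engine and the prompt-side `AutoTokenizer` stay consistent. A minimal sketch of that call, assuming vLLM is installed; the model and tokenizer ids are placeholders, not from this commit:

from vllm import LLM

# `tokenizer` lets vLLM load its tokenizer from a different repo than the
# model weights; both ids below are illustrative.
llm = LLM(
    model="org/some-model",
    tokenizer="org/some-custom-tokenizer",
    max_model_len=2048,
)
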
@@ -185,9 +189,12 @@ def __init__(self, name: str, dataset: str, **kwargs):
         kwargs["torch_dtype"] = getattr(torch, self.dtype)
         self.skip_special_tokens = True
 
-        print(f"{kwargs = }")
+        print(f"{kwargs = }", self.tokenizer_name)
+
+        if self.tokenizer_name is None:
+            self.tokenizer_name = self.name
 
-        self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs)
+        self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name, **kwargs)
         if self.tokenizer.chat_template is None:
             self.eos += extra_eos_for_direct_completion(dataset)
 

@@ -253,7 +260,7 @@ def __init__(self, name: str, **kwargs):
         super().__init__(name=name, **kwargs)
         self.eos += ["\n```\n"]
         print(f"EOS strings: {self.eos}")
-        self.tokenizer = AutoTokenizer.from_pretrained(self.name, **kwargs)
+        self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name if self.tokenizer_name else self.name, **kwargs)
 
     def codegen(
         self, prompt: str, do_sample: bool = True, num_samples: int = 200
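
The recurring change across the three decoder constructors above is the same fallback: when `tokenizer_name` is not supplied, the model name is used for the tokenizer, preserving the old behavior. A standalone sketch of that pattern; the `load_tokenizer` helper and the repo ids are illustrative, not part of the commit:

from transformers import AutoTokenizer

def load_tokenizer(name: str, tokenizer_name: str = None):
    # Mirror the `if self.tokenizer_name is None` branches in the diff:
    # fall back to the model name when no custom tokenizer is given.
    if tokenizer_name is None:
        tokenizer_name = name
    return AutoTokenizer.from_pretrained(tokenizer_name)

# Old behavior: the tokenizer comes from the model repo itself.
tok = load_tokenizer("org/some-model")
# New option: same weights, a different repo's tokenizer.
tok = load_tokenizer("org/some-model", tokenizer_name="org/some-custom-tokenizer")
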
@@ -486,6 +493,7 @@ def make_model(
     tp=1,
     base_url=None,
     trust_remote_code=False,
+    tokenizer_name=None,
 ):
     if backend == "vllm":
         return GeneralVllmDecoder(

@@ -495,6 +503,7 @@
             dataset=dataset,
             tp=tp,
             trust_remote_code=trust_remote_code,
+            tokenizer_name=tokenizer_name,
         )
     elif backend == "hf":
         return GenenralHfTorchDecoder(

@@ -503,6 +512,7 @@
             temperature=temperature,
             dataset=dataset,
             trust_remote_code=trust_remote_code,
+            tokenizer_name=tokenizer_name,
         )
     elif backend == "openai":
         return OpenAIChatDecoder(
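
Callers opt in through the new `make_model(..., tokenizer_name=...)` keyword, which is threaded into both the vLLM and HF backends. A hedged usage sketch; the `model`, `backend`, `dataset`, and `temperature` parameter names are assumptions inferred from the hunks above, since the full signature is not shown:

from bigcodebench.model import make_model

# Parameter names other than those visible in the diff are assumptions;
# model and tokenizer ids are placeholders.
llm = make_model(
    model="org/some-model",
    backend="vllm",
    dataset="bigcodebench",
    temperature=0.0,
    tp=1,
    trust_remote_code=False,
    tokenizer_name="org/some-custom-tokenizer",  # new; None keeps old behavior
)
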
