@@ -91,6 +91,7 @@ def __init__(
         max_new_tokens: int = 1280,
         dtype: str = "bfloat16",  # default
         trust_remote_code: bool = False,
+        tokenizer_name: str = None,
     ) -> None:
         print("Initializing a decoder model: {} ...".format(name))
         self.name = name
@@ -101,6 +102,7 @@ def __init__(
         self.max_new_tokens = max_new_tokens
         self.dtype = dtype
         self.trust_remote_code = trust_remote_code
+        self.tokenizer_name = tokenizer_name
 
     @abstractmethod
     def codegen(
@@ -128,8 +130,10 @@ def __init__(self, name: str, dataset: str, tp: int, **kwargs) -> None:
             "dtype": self.dtype,
             "trust_remote_code": self.trust_remote_code,
         }
-
-        self.tokenizer = AutoTokenizer.from_pretrained(self.name, legacy=False, **kwargs)
+        if self.tokenizer_name is None:
+            self.tokenizer_name = self.name
+
+        self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name, legacy=False, **kwargs)
         if self.tokenizer.chat_template is None:
             self.eos += extra_eos_for_direct_completion(dataset)
         self.llm = LLM(model=name, max_model_len=2048, **kwargs)
@@ -185,9 +189,11 @@ def __init__(self, name: str, dataset: str, **kwargs):
         kwargs["torch_dtype"] = getattr(torch, self.dtype)
         self.skip_special_tokens = True
 
-        print(f"{kwargs = }")
-
-        self.tokenizer = AutoTokenizer.from_pretrained(name, legacy=False, **kwargs)
+        print(f"{kwargs = }", self.tokenizer_name)
+        if self.tokenizer_name is None:
+            self.tokenizer_name = self.name
+
+        self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name, legacy=False, **kwargs)
         if self.tokenizer.chat_template is None:
             self.eos += extra_eos_for_direct_completion(dataset)
 
@@ -242,7 +248,7 @@ def __init__(self, name: str, **kwargs):
         super().__init__(name=name, **kwargs)
         self.eos += ["\n```\n"]
         print(f"EOS strings: {self.eos}")
-        self.tokenizer = AutoTokenizer.from_pretrained(self.name, legacy=False, **kwargs)
+        self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name if self.tokenizer_name else self.name, **kwargs)
 
     def codegen(
         self, prompt: str, do_sample: bool = True, num_samples: int = 200
@@ -475,6 +481,7 @@ def make_model(
     tp=1,
     base_url=None,
     trust_remote_code=False,
+    tokenizer_name=None,
 ):
     if backend == "vllm":
         return GeneralVllmDecoder(
@@ -484,6 +491,7 @@ def make_model(
             dataset=dataset,
             tp=tp,
             trust_remote_code=trust_remote_code,
+            tokenizer_name=tokenizer_name,
         )
     elif backend == "hf":
         return GenenralHfTorchDecoder(
@@ -492,6 +500,7 @@ def make_model(
             temperature=temperature,
             dataset=dataset,
             trust_remote_code=trust_remote_code,
+            tokenizer_name=tokenizer_name,
         )
     elif backend == "openai":
         return OpenAIChatDecoder(
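The net effect of the hunks above is that an optional `tokenizer_name` is threaded from `make_model` into the vLLM and HF decoders, so the tokenizer can be loaded from a different repo than the model weights, falling back to the model name when left unset. Below is a minimal usage sketch under that assumption; the `model` keyword and the repo IDs are placeholders, not taken from this diff.

```python
# Sketch only: the "model" kwarg name and both repo IDs are assumptions/placeholders.
decoder = make_model(
    model="my-org/finetuned-coder",           # weights repo (placeholder)
    backend="hf",
    dataset="humaneval",
    temperature=0.0,
    tokenizer_name="my-org/base-tokenizer",   # tokenizer loaded separately by the decoder (placeholder)
)

# With tokenizer_name=None (the default), the decoders fall back to self.name,
# preserving the previous behavior of loading the tokenizer from the model repo.
```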