@@ -92,6 +92,7 @@ def __init__(
         max_new_tokens: int = 1280,
         dtype: str = "bfloat16",  # default
         trust_remote_code: bool = False,
+        tokenizer_name: str = None,
     ) -> None:
         print("Initializing a decoder model: {} ...".format(name))
         self.name = name
@@ -102,6 +103,7 @@ def __init__(
         self.max_new_tokens = max_new_tokens
         self.dtype = dtype
         self.trust_remote_code = trust_remote_code
+        self.tokenizer_name = tokenizer_name
 
     @abstractmethod
     def codegen(
@@ -129,11 +131,13 @@ def __init__(self, name: str, dataset: str, tp: int, **kwargs) -> None:
             "dtype": self.dtype,
             "trust_remote_code": self.trust_remote_code,
         }
-
-        self.tokenizer = AutoTokenizer.from_pretrained(self.name, **kwargs)
+        if self.tokenizer_name is None:
+            self.tokenizer_name = self.name
+
+        self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name, **kwargs)
         if self.tokenizer.chat_template is None:
             self.eos += extra_eos_for_direct_completion(dataset)
-        self.llm = LLM(model=name, max_model_len=2048, **kwargs)
+        self.llm = LLM(model=name, max_model_len=2048, tokenizer=self.tokenizer_name, **kwargs)
 
     def is_direct_completion(self) -> bool:
         return self.tokenizer.chat_template is None
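
Note: the hunk above relies on `vllm.LLM` accepting a `tokenizer` argument that can differ from `model`, so a fine-tuned checkpoint can be served with its base model's tokenizer. A minimal standalone sketch of the same pattern, with illustrative (hypothetical) repo IDs:

    from vllm import LLM

    # Load weights from one repo while tokenizing with another (IDs are placeholders).
    llm = LLM(model="org/finetuned-model", tokenizer="org/base-model", max_model_len=2048)
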
@@ -185,9 +189,12 @@ def __init__(self, name: str, dataset: str, **kwargs):
         kwargs["torch_dtype"] = getattr(torch, self.dtype)
         self.skip_special_tokens = True
 
-        print(f"{kwargs = }")
+        print(f"{kwargs = }", self.tokenizer_name)
+
+        if self.tokenizer_name is None:
+            self.tokenizer_name = self.name
 
-        self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs)
+        self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name, **kwargs)
         if self.tokenizer.chat_template is None:
             self.eos += extra_eos_for_direct_completion(dataset)
 
@@ -253,7 +260,7 @@ def __init__(self, name: str, **kwargs):
         super().__init__(name=name, **kwargs)
         self.eos += ["\n```\n"]
         print(f"EOS strings: {self.eos}")
-        self.tokenizer = AutoTokenizer.from_pretrained(self.name, **kwargs)
+        self.tokenizer = AutoTokenizer.from_pretrained(self.tokenizer_name if self.tokenizer_name else self.name, **kwargs)
 
     def codegen(
         self, prompt: str, do_sample: bool = True, num_samples: int = 200
@@ -486,6 +493,7 @@ def make_model(
     tp=1,
     base_url=None,
     trust_remote_code=False,
+    tokenizer_name=None,
 ):
     if backend == "vllm":
         return GeneralVllmDecoder(
@@ -495,6 +503,7 @@ def make_model(
             dataset=dataset,
             tp=tp,
             trust_remote_code=trust_remote_code,
+            tokenizer_name=tokenizer_name,
         )
     elif backend == "hf":
         return GenenralHfTorchDecoder(
@@ -503,6 +512,7 @@ def make_model(
             temperature=temperature,
             dataset=dataset,
             trust_remote_code=trust_remote_code,
+            tokenizer_name=tokenizer_name,
         )
     elif backend == "openai":
         return OpenAIChatDecoder(
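
With the new parameter threaded through `make_model` into both decoders, a caller can pair one checkpoint's weights with another repo's tokenizer. A minimal usage sketch, assuming the signature shown above (the first argument's name and all repo IDs are illustrative):

    # Hypothetical call; only backend, dataset, and tokenizer_name are
    # confirmed by the diff, the rest are placeholders.
    decoder = make_model(
        "org/finetuned-model",
        backend="hf",
        dataset="humaneval",
        tokenizer_name="org/base-model",
    )

When `tokenizer_name` is left as `None`, every code path falls back to the model name, so existing callers are unaffected.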