@@ -91,6 +91,7 @@ def __init__(
9191 max_new_tokens : int = 1280 ,
9292 dtype : str = "bfloat16" , # default
9393 trust_remote_code : bool = False ,
94+ tokenizer_name : str = None ,
9495 ) -> None :
9596 print ("Initializing a decoder model: {} ..." .format (name ))
9697 self .name = name
@@ -101,6 +102,7 @@ def __init__(
101102 self .max_new_tokens = max_new_tokens
102103 self .dtype = dtype
103104 self .trust_remote_code = trust_remote_code
105+ self .tokenizer_name = tokenizer_name
104106
105107 @abstractmethod
106108 def codegen (
@@ -185,7 +187,10 @@ def __init__(self, name: str, dataset: str, **kwargs):
185187 kwargs ["torch_dtype" ] = getattr (torch , self .dtype )
186188 self .skip_special_tokens = True
187189
188- print (f"{ kwargs = } " )
190+ print (f"{ kwargs = } " , self .tokenizer_name )
191+
192+ if self .tokenizer_name is None :
193+ self .tokenizer_name = self .name
189194
190195 self .tokenizer = AutoTokenizer .from_pretrained (name , legacy = False , ** kwargs )
191196 if self .tokenizer .chat_template is None :
@@ -475,6 +480,7 @@ def make_model(
475480 tp = 1 ,
476481 base_url = None ,
477482 trust_remote_code = False ,
483+ tokenizer_name = None ,
478484):
479485 if backend == "vllm" :
480486 return GeneralVllmDecoder (
@@ -484,6 +490,7 @@ def make_model(
484490 dataset = dataset ,
485491 tp = tp ,
486492 trust_remote_code = trust_remote_code ,
493+ tokenizer_name = tokenizer_name ,
487494 )
488495 elif backend == "hf" :
489496 return GenenralHfTorchDecoder (
@@ -492,6 +499,7 @@ def make_model(
492499 temperature = temperature ,
493500 dataset = dataset ,
494501 trust_remote_code = trust_remote_code ,
502+ tokenizer_name = tokenizer_name ,
495503 )
496504 elif backend == "openai" :
497505 return OpenAIChatDecoder (
0 commit comments