@@ -26,7 +26,6 @@
     warn("GoogleGenAI decoder will not work. Fix by `pip install google-generativeai`")
 
 import torch
-from stop_sequencer import StopSequencer
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 try:
@@ -130,10 +129,11 @@ def __init__(self, name: str, dataset: str, tp: int, **kwargs) -> None:
             "trust_remote_code": self.trust_remote_code,
         }
 
-        self.tokenizer = AutoTokenizer.from_pretrained(self.name, **kwargs)
+        self.tokenizer = AutoTokenizer.from_pretrained(self.name, legacy=False, **kwargs)
         if self.tokenizer.chat_template is None:
             self.eos += extra_eos_for_direct_completion(dataset)
         self.llm = LLM(model=name, max_model_len=2048, **kwargs)
+        self.llm.set_tokenizer(tokenizer=self.tokenizer)
 
     def is_direct_completion(self) -> bool:
         return self.tokenizer.chat_template is None
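[Reviewer note, not part of the diff] The added `set_tokenizer` call keeps vLLM's decoding tokenizer in sync with the `legacy=False` tokenizer loaded just above, rather than letting vLLM construct its own. A minimal sketch of the resulting usage, assuming this vLLM API and a hypothetical checkpoint name:

    from transformers import AutoTokenizer
    from vllm import LLM, SamplingParams

    name = "bigcode/starcoder2-3b"  # hypothetical checkpoint, illustration only
    tokenizer = AutoTokenizer.from_pretrained(name, legacy=False)
    llm = LLM(model=name, max_model_len=2048)
    llm.set_tokenizer(tokenizer=tokenizer)  # the call this hunk adds
    outs = llm.generate(["def add(a, b):"], SamplingParams(temperature=0.0, max_tokens=32))
    print(outs[0].outputs[0].text)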
@@ -179,15 +179,15 @@ def __init__(self, name: str, dataset: str, **kwargs):
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
         kwargs = {}
-        kwargs["device_map"] = "auto"
+        kwargs["device_map"] = "cuda:0"
         kwargs["trust_remote_code"] = self.trust_remote_code
         # string to torch dtype
         kwargs["torch_dtype"] = getattr(torch, self.dtype)
         self.skip_special_tokens = True
 
         print(f"{kwargs = }")
 
-        self.tokenizer = AutoTokenizer.from_pretrained(name, **kwargs)
+        self.tokenizer = AutoTokenizer.from_pretrained(name, legacy=False, **kwargs)
         if self.tokenizer.chat_template is None:
             self.eos += extra_eos_for_direct_completion(dataset)
 
@@ -213,18 +213,7 @@ def codegen(
         kwargs["top_p"] = 0.95
         kwargs["temperature"] = self.temperature
 
-        stop_sequencer = StopSequencer(
-            self.model,
-            model_type="causal",  # or seq2seq
-            tokenizer=self.tokenizer,
-        )
-
-        model = stop_sequencer.register_stop_texts(
-            stop_texts=self.eos,
-            input_length=input_tokens.size(-1),
-        )
-
-        outputs = model.generate(
+        outputs = self.model.generate(
             input_tokens,
             max_new_tokens=self.max_new_tokens,
             do_sample=do_sample,
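[Reviewer note, not part of the diff] With StopSequencer gone, `self.model.generate` no longer halts early at the custom EOS strings, so they have to be enforced some other way. One common pattern (an assumption here, not shown in this hunk) is post-hoc truncation of each decoded sample at the earliest stop string:

    def truncate_at_stop(text: str, stop_strings: list[str]) -> str:
        # Cut the decoded completion at the first occurrence of any stop string.
        cut = len(text)
        for stop in stop_strings:
            idx = text.find(stop)
            if idx != -1:
                cut = min(cut, idx)
        return text[:cut]

    # e.g. truncate_at_stop(decoded_output, self.eos) for each sample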
@@ -253,7 +242,7 @@ def __init__(self, name: str, **kwargs):
         super().__init__(name=name, **kwargs)
         self.eos += ["\n```\n"]
         print(f"EOS strings: {self.eos}")
-        self.tokenizer = AutoTokenizer.from_pretrained(self.name, **kwargs)
+        self.tokenizer = AutoTokenizer.from_pretrained(self.name, legacy=False, **kwargs)
 
     def codegen(
         self, prompt: str, do_sample: bool = True, num_samples: int = 200
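[Reviewer note, not part of the diff] On the recurring `legacy=False`: for slow sentencepiece tokenizers (Llama/T5 families), recent transformers versions emit a legacy-behaviour warning unless the flag is set explicitly, and `legacy=False` opts into the fixed handling of text following special tokens. A hedged sketch with a hypothetical checkpoint:

    from transformers import AutoTokenizer

    # Hypothetical checkpoint; the kwarg is forwarded to the underlying tokenizer class.
    tok = AutoTokenizer.from_pretrained("codellama/CodeLlama-7b-hf", legacy=False)
    print(tok.tokenize("def f():\n    pass"))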