@@ -1,12 +1,12 @@
 import os
 from typing import List
-from tqdm import tqdm
 
 import openai
 
-from bigcodebench.provider.base import DecoderBase
 from bigcodebench.gen.util.openai_request import make_auto_request
 from bigcodebench.provider.utility import make_raw_chat_prompt
+from bigcodebench.provider.base import DecoderBase
+from bigcodebench.provider.utility import concurrent_call
 
 class OpenAIChatDecoder(DecoderBase):
     def __init__(self, name: str, base_url=None, **kwargs) -> None:
@@ -15,34 +15,83 @@ def __init__(self, name: str, base_url=None, **kwargs) -> None:
             api_key=os.getenv("OPENAI_API_KEY", "none"), base_url=base_url
         )
 
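+    # previous sequential implementation (superseded by the batched version below)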
+    # def codegen(
+    #     self, prompts: List[str], do_sample: bool = True, num_samples: int = 200
+    # ) -> List[str]:
+    #     if do_sample:
+    #         assert self.temperature > 0, "Temperature must be positive for sampling"
+    #     all_outputs = []
+    #     for prompt in tqdm(prompts):
+    #         outputs = []
+    #         message = make_raw_chat_prompt(
+    #             task_prompt=prompt,
+    #             subset=self.subset,
+    #             split=self.split,
+    #             instruction_prefix=self.instruction_prefix,
+    #             response_prefix=self.response_prefix,
+    #             tokenizer=None,
+    #         )
+    #         ret = make_auto_request(
+    #             self.client,
+    #             message=message,
+    #             model=self.name,
+    #             max_tokens=self.max_new_tokens,
+    #             temperature=self.temperature,
+    #             n=num_samples,
+    #         )
+    #         for item in ret.choices:
+    #             outputs.append(item.message.content)
+    #         all_outputs.append(outputs)
+    #     return all_outputs
+
+    # def is_direct_completion(self) -> bool:
+    #     return False
+
     def codegen(
         self, prompts: List[str], do_sample: bool = True, num_samples: int = 200
     ) -> List[str]:
         if do_sample:
             assert self.temperature > 0, "Temperature must be positive for sampling"
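+        # render every prompt into the raw chat format once, then dispatch the whole batch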
+        messages = [make_raw_chat_prompt(
+            task_prompt=prompt,
+            subset=self.subset,
+            split=self.split,
+            instruction_prefix=self.instruction_prefix,
+            response_prefix=self.response_prefix,
+            tokenizer=None,
+        ) for prompt in prompts]
+        # use concurrency-based batching for o1 and deepseek models, which are sampled one completion per request
+        if self.name.startswith("o1-") or self.name == "deepseek-chat":
+            return self._codegen_batch_via_concurrency(messages, num_samples)
+
+        return self._codegen_api_batch(messages, num_samples)
+
+    def _codegen_api_batch(self, messages: List[str], num_samples: int) -> List[str]:
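+        # instantiate the client locally: _codegen_batch_via_concurrency may run this method from several workers at once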
+        client = openai.OpenAI(
+            api_key=os.getenv("OPENAI_API_KEY", "none"), base_url=self.base_url
+        )
+
         all_outputs = []
-        for prompt in tqdm(prompts):
-            outputs = []
-            message = make_raw_chat_prompt(
-                task_prompt=prompt,
-                subset=self.subset,
-                split=self.split,
-                instruction_prefix=self.instruction_prefix,
-                response_prefix=self.response_prefix,
-                tokenizer=None,
-            )
+        for message in messages:
             ret = make_auto_request(
-                self.client,
+                client,
                 message=message,
                 model=self.name,
                 max_tokens=self.max_new_tokens,
                 temperature=self.temperature,
                 n=num_samples,
             )
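+            # each of the n choices is one sampled completion for this message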
+            outputs = []
             for item in ret.choices:
                 outputs.append(item.message.content)
             all_outputs.append(outputs)
         return all_outputs
 
+    def _codegen_batch_via_concurrency(self, messages: List[str], num_samples: int) -> List[str]:
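+        # fire num_samples single-sample API batches concurrently, then keep the first entry of each batch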
+        batches = concurrent_call(
+            num_samples, self._codegen_api_batch, messages, num_samples=1
+        )
+        return [b[0] for b in batches]
+
     def is_direct_completion(self) -> bool:
         return False