@@ -19,13 +19,13 @@ def codegen(
1919 model : DecoderBase ,
2020 target_path : str ,
2121 split : str ,
22- subset = "full" ,
23- greedy = False ,
24- strip_newlines = False ,
25- n_samples = 1 ,
26- id_range = None ,
27- resume = True ,
28- batch_size : int = - 1 ,
22+ subset : str ,
23+ greedy : bool = False ,
24+ strip_newlines : bool = False ,
25+ n_samples : int = 1 ,
26+ id_range : Tuple [ int , int ] = None ,
27+ resume : bool = True ,
28+ batch_size : int = - 1 ,
2929):
3030 with Progress (
3131 TextColumn (f"BigCodeBench--{ split .capitalize ()} ({ subset .capitalize ()} ) •" + "[progress.percentage]{task.percentage:>3.0f}%" ),
@@ -51,12 +51,12 @@ def codegen(
5151 batch_entry_points = []
5252
5353 # Read existing data once if resuming
54- existing_data = {}
54+ task2nexist = {}
5555 if resume and os .path .exists (target_path ):
5656 with open (target_path , "r" ) as f :
5757 for line in f :
5858 item = json .loads (line )
59- existing_data [item ["task_id" ]] = existing_data .get (item ["task_id" ], 0 ) + 1
59+ task2nexist [item ["task_id" ]] = task2nexist .get (item ["task_id" ], 0 ) + 1
6060
6161 for id_num , (task_id , task ) in enumerate (p .track (dataset .items ())):
6262 if id_range is not None :
@@ -69,7 +69,7 @@ def codegen(
6969
7070 p_name = task_id .replace ("/" , "_" )
7171
72- n_existing = existing_data .get (task_id , 0 )
72+ n_existing = task2nexist .get (task_id , 0 )
7373 nsamples = n_samples - n_existing
7474
7575 try :
@@ -91,7 +91,7 @@ def codegen(
9191 p .console .print (log )
9292
9393 if (batch_size and len (batch_prompts ) == batch_size ) or id_num == len (dataset ) - 1 or (id_range and id_num == id_range [1 ] - 1 ):
94- if not batch_prompts and id_num == len (dataset ) - 1 :
94+ if not batch_prompts and ( id_num == len (dataset ) - 1 or ( id_range and id_num == id_range [ 1 ] - 1 )) :
9595 break
9696 outputs = model .codegen (
9797 batch_prompts ,
@@ -130,6 +130,7 @@ def run_codegen(
130130 bs : Optional [int ] = None ,
131131 n_samples : int = 1 ,
132132 temperature : float = 0.0 ,
133+ max_new_tokens : int = 1280 ,
133134 greedy : bool = False ,
134135 strip_newlines : bool = False ,
135136 direct_completion : bool = False ,
@@ -147,7 +148,7 @@ def run_codegen(
147148 temperature = 0
148149 n_samples = 1
149150 greedy = True
150- print ("Greedy decoding ON (--greedy): setting bs=1, n_samples=1, temperature=0" )
151+ print ("Greedy decoding ON (--greedy): setting n_samples=1, temperature=0" )
151152
152153 if id_range is not None :
153154 assert len (id_range ) == 2 , "id_range must be a list of length 2"
@@ -167,6 +168,7 @@ def run_codegen(
167168 subset = subset ,
168169 split = split ,
169170 temperature = temperature ,
171+ max_new_tokens = max_new_tokens ,
170172 instruction_prefix = instruction_prefix ,
171173 response_prefix = response_prefix ,
172174 base_url = base_url ,
@@ -181,7 +183,10 @@ def run_codegen(
181183 identifier = model .replace ("/" , "--" ) + f"--bigcodebench{ extra } -{ split } --{ backend } -{ temperature } -{ n_samples } -sanitized_calibrated.jsonl"
182184
183185 target_path = os .path .join (root , identifier )
184-
186+
187+ if not resume :
188+ os .remove (target_path )
189+
185190 codegen (
186191 model = model_runner ,
187192 target_path = target_path ,
0 commit comments