33import os
44import pathlib
55from typing import Dict , Generator , List , Optional , Set , Tuple
6+ from pqdm .processes import pqdm
67
78from tqdm import tqdm
89from tree_sitter import Node
@@ -178,8 +179,48 @@ def sanitize(code: str, entrypoint: Optional[str] = None) -> str:
178179 return sanitized_output
179180
180181
def process_solution(
    sample_solution: Dict,
    dataset: Dict,
    entry_point: Dict,
    debug_task: Optional[str] = None,
    calibrate: bool = False,
    is_folder: bool = False,
    target_path: Optional[str] = None,
    samples: Optional[str] = None,
) -> Optional[Dict[str, str]]:
    """Sanitize one sample solution and return it as a new solution record.

    Args:
        sample_solution: Record with a "task_id", an "_identifier", and either
            a "solution" or a "completion" entry.
        dataset: Mapping of task_id -> task record; the "complete_prompt"
            entry of the task record is read when a prompt must be prepended.
        entry_point: Mapping of task_id -> entry point name handed to
            ``sanitize``; a missing task yields ``None``.
        debug_task: When given, every task other than this one is skipped.
        calibrate: When true and the sample carries a "solution", the task
            prompt is injected right after the opening python code fence.
        is_folder: When true, the log message also names the target location.
        target_path: Replacement for ``samples`` inside the identifier when
            building the target location for the log message.
        samples: Source path that is replaced by ``target_path`` in the log
            message. The original body read this name without defining it,
            which raised ``NameError``; callers that want the " -> target"
            suffix must pass it explicitly.

    Returns:
        ``{"task_id": ..., "solution": ...}`` holding the sanitized code, or
        ``None`` when the sample is skipped (missing/unknown task_id, or
        filtered out by ``debug_task``).

    Raises:
        KeyError: If the sample has a known task_id but no "_identifier".
        AssertionError: If the sample has neither "solution" nor "completion".
    """
    task_id = sample_solution.get("task_id")
    if not task_id or task_id not in dataset:
        return None

    dbg_identifier = sample_solution["_identifier"]
    if debug_task is not None and task_id != debug_task:
        return None

    function_name = entry_point.get(task_id)
    old_code = sample_solution.get("solution")

    if old_code is None:
        # Completion-style sample: rebuild the full program from the prompt.
        assert "completion" in sample_solution, sample_solution
        old_code = (
            dataset[task_id]["complete_prompt"]
            + "\n "
            + sample_solution.get("completion")
        )
    elif calibrate:
        old_code = old_code.replace(
            "```python\n ",
            "```python\n " + dataset[task_id]["complete_prompt"] + " ",
        )

    new_code = sanitize(code=old_code, entrypoint=function_name)

    # Only report samples whose code was actually changed by sanitization.
    if new_code != old_code:
        msg = "Sanitized: " + dbg_identifier
        # The target suffix needs both paths; without them just log the source.
        if is_folder and samples is not None and target_path is not None:
            msg += " -> " + dbg_identifier.replace(samples, target_path)
        print(msg)

    return {"task_id": task_id, "solution": new_code}
220+
221+
181222def script (
182- samples : str , inplace : bool = False , debug_task : str = None , calibrate : bool = False
223+ samples : str , inplace : bool = False , debug_task : str = None , calibrate : bool = False , parallel : int = 32
183224):
184225 # task_id -> entry_point
185226 entry_point = {}
@@ -211,38 +252,26 @@ def script(
211252
212253 new_solutions = []
213254
214- for solution in tqdm (load_solutions (samples )):
215- task_id = solution ["task_id" ]
216- if task_id not in dataset :
217- print (
218- f"Skiping { task_id } as it does not existing in the latest EvalPlus dataset."
219- )
220- continue
221-
222- function_name = entry_point [task_id ] if task_id in entry_point else None
223- dbg_identifier = solution ["_identifier" ]
224- if debug_task is not None and task_id != debug_task :
225- continue
226-
227- ntotal += 1
228- if "solution" in solution :
229- old_code = solution ["solution" ]
230- if calibrate :
231- old_code = solution ["solution" ].replace ("```python\n " , "```python\n " + dataset [task_id ]["complete_prompt" ]+ " " )
232- else :
233- assert "completion" in solution
234- old_code = dataset [task_id ]["complete_prompt" ] + "\n " + solution ["completion" ]
235-
236- new_code = sanitize (code = old_code , entrypoint = function_name )
237- # if changed, print the message
238- if new_code != old_code :
239- msg = "Sanitized: " + dbg_identifier
240- if is_folder :
241- msg += " -> " + dbg_identifier .replace (samples , target_path )
242- print (msg )
255+ parallel_arg_list = [
256+ {
257+ "sample_solution" : sample_solution ,
258+ "dataset" : dataset ,
259+ "entry_point" : entry_point ,
260+ "debug_task" : debug_task ,
261+ "calibrate" : calibrate ,
262+ "is_folder" : is_folder ,
263+ "target_path" : target_path
264+ }
265+ for sample_solution in load_solutions (samples )
266+ ]
267+
268+ results = pqdm (parallel_arg_list , process_solution , n_jobs = min (parallel , os .cpu_count ()), argument_type = "kwargs" )
269+
270+ for result in results :
271+ if result is not None :
272+ new_solutions .append (result )
243273 nsan += 1
244-
245- new_solutions .append ({"task_id" : task_id , "solution" : new_code })
274+ ntotal += 1
246275
247276 if is_folder :
248277 write_directory (target_path , new_solutions )
@@ -263,4 +292,4 @@ def main():
263292
264293
265294if __name__ == "__main__" :
266- main ()
295+ main ()
0 commit comments