# DeepSeek-Prover-V1.5/prover/workers/search.py

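"""Search worker: each SearchProcess pulls problems from a shared data
loader, samples candidate proofs with the configured sampler, and submits
them to the verification scheduler."""
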
import os
import time
import copy
import json
import pickle
from pathlib import Path

import torch
import torch.multiprocessing as mp
import numpy as np

from prover.utils import AttrDict, get_datetime


class SearchProcess(mp.Process):
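    """Worker process that repeatedly pulls a problem from the data loader,
    samples proof candidates with the configured sampler, and submits each
    candidate to the verification server via the scheduler."""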
    def __init__(self, idx, log_dir, tokenizer_path, scheduler, data_loader, cfg):
        self.idx = idx
        self.log_dir = Path(log_dir)
        self.scheduler = scheduler
        self.data_loader = data_loader
        super().__init__()

        # current problem tag; set per-problem in run() and read by process_print()
        self._current_prob = None
        # the sampling algorithm class is taken directly from the config
        sampler_cls = cfg.sampler['algorithm']
        self.sampler = sampler_cls(
            scheduler=self.scheduler,
            tokenizer_path=tokenizer_path,
            process_print=self.process_print,
            cfg=AttrDict({
                **cfg.sampler,
                'mode': cfg.model_args.mode,
                'max_tokens': cfg.model_args.max_tokens,
            })
        )
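    # Illustrative config shape, inferred from the fields read above:
    #   cfg.sampler -- dict holding an 'algorithm' sampler class plus sampler settings
    #   cfg.model_args.mode / cfg.model_args.max_tokens -- generation settings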

    def _post_process(self, data: dict, proof_code: str):
        # wrap the sampled proof body with the problem's header and tailer to
        # form a complete source file that the verifier can check
        header = data.get('header', str())
        tailer = data.get('tailer', str())
        formal_statement = data['formal_statement']
        return dict(
            statement_proposal=f'{header}{formal_statement}{proof_code}{tailer}',
            proof_code=proof_code,
        )
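
    # Hypothetical example of the assembly above (values invented for illustration):
    #   header           = 'import Mathlib\n'
    #   formal_statement = 'theorem two_add_two : 2 + 2 = 4 := by\n'
    #   proof_code       = '  norm_num\n'
    #   tailer           = ''
    # => statement_proposal is their concatenation, a self-contained source file.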

    def process_print(self, logs, **kwargs):
        # prefix every log line with this worker's index and the current problem tag
        print('Process ID: {:3d} Problem ID: {} {}'.format(self.idx, self._current_prob, logs), **kwargs)

    def run(self):
        while True:
            prob_idx, prob_runname, data = self.data_loader.get()
            if prob_idx is None:
                break
            sample_start_time = time.time()
            # build a yield-iterator object to generate samples
            self._current_prob = f'{prob_idx}_{prob_runname}'
            prob_log_dir = self.log_dir / self._current_prob
            os.makedirs(prob_log_dir, exist_ok=True)
            sample_generator = self.sampler.sample(
                data=data,
                prob_log_dir=prob_log_dir,
            )
            # submit requests to the verification server when receiving from the generator
            candidate_list, info_list, request_id_list = [], [], []
            for sample, info in sample_generator:
                candidate = self._post_process(data, sample)
                candidate_list.append(candidate)
                info_list.append(copy.deepcopy(info))
                request_id = self.scheduler.verifier_submit_request(candidate['statement_proposal'])
                request_id_list.append(request_id)
            sample_timecost = time.time() - sample_start_time
            # block until the verifier has processed every submitted candidate
            verification_start_wait_time = time.time()
            result_list = self.scheduler.verifier_get_all_request_outputs(request_id_list)
            verification_timecost = time.time() - verification_start_wait_time
            success_count = sum([int(result['complete']) for result in result_list])
            self.process_print('Success: {} / {} Generation: {:.2f} secs Verification: {:.2f} secs'.format(
                success_count, len(candidate_list), sample_timecost, verification_timecost,
            ))
            # bucket each candidate by verification outcome for logging
            summary_dict = dict(success=[], failure=[])
            for _idx, (candidate, result, info) in enumerate(zip(candidate_list, result_list, info_list)):
                success_flag = 'success' if result['complete'] else 'failure'
                summary_dict[success_flag].append(dict(
                    problem_name=data['name'],
                    sample_info=info,
                    formal_statement=data['formal_statement'],
                    proof_code=candidate['proof_code'],
                    result=result,
                ))
            prob_name, run_id = prob_runname.split('/')
            prob_log_basedir = self.log_dir / f'{prob_idx}_{data["name"]}'
            log_tag = f'{self.sampler.algorithm_name}-{run_id}'
            # separately save success and failure results
            for success_flag, summary_list in summary_dict.items():
                if len(summary_list) > 0:
                    with open(prob_log_basedir / f'{success_flag}-{log_tag}-{get_datetime()}.pkl', 'wb') as pkl_f:
                        pickle.dump(summary_list, pkl_f)
            # create a 'finished' placeholder file to mark this problem as done
            with open(prob_log_dir / self.data_loader.finished_flag_filename, 'w') as f:
                print('finished', file=f)
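

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file). It assumes a scheduler that
# exposes verifier_submit_request / verifier_get_all_request_outputs and a
# data loader that exposes get() -> (prob_idx, prob_runname, data) plus
# finished_flag_filename, matching how this class uses them above; the
# concrete objects, paths, and config come from the repo's own launcher.
# ---------------------------------------------------------------------------
# if __name__ == '__main__':
#     workers = [
#         SearchProcess(
#             idx=rank,
#             log_dir='logs/demo-run',             # hypothetical path
#             tokenizer_path='path/to/tokenizer',  # hypothetical path
#             scheduler=scheduler,                 # assumed, see note above
#             data_loader=data_loader,             # assumed, see note above
#             cfg=cfg,                             # experiment config (AttrDict-like)
#         )
#         for rank in range(4)
#     ]
#     for w in workers:
#         w.start()
#     for w in workers:
#         w.join()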