This commit is contained in:
Dylancer 2024-04-15 10:18:15 -05:00 committed by GitHub
commit c77291ddff
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -90,11 +90,17 @@ def do_parallel_sampling(args, task, answer_extraction_fn, eval_fn, input_dir, o
local_pids = [global_pid for (global_pid, _, _) in procs] local_pids = [global_pid for (global_pid, _, _) in procs]
if global_n_procs == 1:
agg_preds = read_data(os.path.join(output_dir, "predictions.json"))
else:
agg_preds = [] agg_preds = []
for fname in glob(os.path.join(output_dir, "predictions.*.json")): for fname in glob(os.path.join(output_dir, "predictions.*.json")):
if any(str(pid) in fname for pid in local_pids): if any(str(pid) in fname for pid in local_pids):
agg_preds.extend(read_data(fname)) agg_preds.extend(read_data(fname))
if global_n_procs == 1:
metrics = read_data(os.path.join(output_dir, "metrics.json"))
result_msg = f"n samples = {metrics['n_samples']}"
else:
metrics = {} metrics = {}
n_samples = 0 n_samples = 0
for fname in glob(os.path.join(output_dir, "metrics.*.json")): for fname in glob(os.path.join(output_dir, "metrics.*.json")):