Merge pull request #2 from wowrakibul/fix/convert-py-improvements

Improve convert.py with error handling and code optimization
This commit is contained in:
Wow Rakibul 2025-02-09 02:02:17 +06:00 committed by GitHub
commit 361d0bcc1c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -3,7 +3,6 @@ import shutil
from argparse import ArgumentParser
from glob import glob
from tqdm import tqdm, trange
import torch
from safetensors.torch import safe_open, save_file
@ -30,7 +29,7 @@ mapping = {
}
def main(hf_ckpt_path, save_path, n_experts, mp):
def main(hf_ckpt_path: str, save_path: str, n_experts: int, mp: int) -> None:
"""
Converts and saves model checkpoint files into a specified format.
@ -43,6 +42,7 @@ def main(hf_ckpt_path, save_path, n_experts, mp):
Returns:
None
"""
try:
torch.set_num_threads(8)
n_local_experts = n_experts // mp
state_dicts = [{} for _ in range(mp)]
@ -84,6 +84,9 @@ def main(hf_ckpt_path, save_path, n_experts, mp):
new_file_path = os.path.join(save_path, os.path.basename(file_path))
shutil.copyfile(file_path, new_file_path)
except Exception as e:
print(f"An error occurred: {e}")
if __name__ == "__main__":
parser = ArgumentParser()