Merge pull request #2 from wowrakibul/fix/convert-py-improvements

Improve convert.py with error handling and code optimization
This commit is contained in:
Wow Rakibul 2025-02-09 02:02:17 +06:00 committed by GitHub
commit 361d0bcc1c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -3,7 +3,6 @@ import shutil
from argparse import ArgumentParser from argparse import ArgumentParser
from glob import glob from glob import glob
from tqdm import tqdm, trange from tqdm import tqdm, trange
import torch import torch
from safetensors.torch import safe_open, save_file from safetensors.torch import safe_open, save_file
@ -30,7 +29,7 @@ mapping = {
} }
def main(hf_ckpt_path, save_path, n_experts, mp): def main(hf_ckpt_path: str, save_path: str, n_experts: int, mp: int) -> None:
""" """
Converts and saves model checkpoint files into a specified format. Converts and saves model checkpoint files into a specified format.
@ -43,6 +42,7 @@ def main(hf_ckpt_path, save_path, n_experts, mp):
Returns: Returns:
None None
""" """
try:
torch.set_num_threads(8) torch.set_num_threads(8)
n_local_experts = n_experts // mp n_local_experts = n_experts // mp
state_dicts = [{} for _ in range(mp)] state_dicts = [{} for _ in range(mp)]
@ -84,6 +84,9 @@ def main(hf_ckpt_path, save_path, n_experts, mp):
new_file_path = os.path.join(save_path, os.path.basename(file_path)) new_file_path = os.path.join(save_path, os.path.basename(file_path))
shutil.copyfile(file_path, new_file_path) shutil.copyfile(file_path, new_file_path)
except Exception as e:
print(f"An error occurred: {e}")
if __name__ == "__main__": if __name__ == "__main__":
parser = ArgumentParser() parser = ArgumentParser()