This commit is contained in:
Wow Rakibul 2025-03-17 13:50:33 +11:00 committed by GitHub
commit 53a4b021d5
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 41 additions and 38 deletions

Binary file not shown.

Before

Width:  |  Height:  |  Size: 179 KiB

After

Width:  |  Height:  |  Size: 100 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 106 KiB

After

Width:  |  Height:  |  Size: 78 KiB

View File

@ -3,7 +3,6 @@ import shutil
from argparse import ArgumentParser from argparse import ArgumentParser
from glob import glob from glob import glob
from tqdm import tqdm, trange from tqdm import tqdm, trange
import torch import torch
from safetensors.torch import safe_open, save_file from safetensors.torch import safe_open, save_file
@ -30,7 +29,7 @@ mapping = {
} }
def main(hf_ckpt_path, save_path, n_experts, mp): def main(hf_ckpt_path: str, save_path: str, n_experts: int, mp: int) -> None:
""" """
Converts and saves model checkpoint files into a specified format. Converts and saves model checkpoint files into a specified format.
@ -43,6 +42,7 @@ def main(hf_ckpt_path, save_path, n_experts, mp):
Returns: Returns:
None None
""" """
try:
torch.set_num_threads(8) torch.set_num_threads(8)
n_local_experts = n_experts // mp n_local_experts = n_experts // mp
state_dicts = [{} for _ in range(mp)] state_dicts = [{} for _ in range(mp)]
@ -84,6 +84,9 @@ def main(hf_ckpt_path, save_path, n_experts, mp):
new_file_path = os.path.join(save_path, os.path.basename(file_path)) new_file_path = os.path.join(save_path, os.path.basename(file_path))
shutil.copyfile(file_path, new_file_path) shutil.copyfile(file_path, new_file_path)
except Exception as e:
print(f"An error occurred: {e}")
if __name__ == "__main__": if __name__ == "__main__":
parser = ArgumentParser() parser = ArgumentParser()