mirror of
https://github.com/deepseek-ai/DeepSeek-V3.git
synced 2025-02-23 06:08:58 -05:00
small patch
applied the Parser class in further scripts
This commit is contained in:
parent
267e7ba685
commit
c8146ec360
@ -1,6 +1,6 @@
|
||||
import os
|
||||
import shutil
|
||||
from argparse import ArgumentParser
|
||||
from parser import Parser
|
||||
from glob import glob
|
||||
from tqdm import tqdm, trange
|
||||
import asyncio as sync
|
||||
@ -89,11 +89,12 @@ def main(hf_ckpt_path, save_path, n_experts, mp):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument("--hf-ckpt-path", type=str, required=True)
|
||||
parser.add_argument("--save-path", type=str, required=True)
|
||||
parser.add_argument("--n-experts", type=int, required=True)
|
||||
parser.add_argument("--model-parallel", type=int, required=True)
|
||||
args = parser.parse_args()
|
||||
arg_list = [
|
||||
("--hf-ckpt-path", type:=str, required:=True),
|
||||
("--save-path", type:=str, required:=True),
|
||||
("--n-experts", type:=int, required:=True),
|
||||
("--model-parallel", type:=int, required:=True)
|
||||
]
|
||||
args = Parser(arg_list).apply_args().return_args()
|
||||
assert args.n_experts % args.model_parallel == 0
|
||||
sync.run(main(args.hf_ckpt_path, args.save_path, args.n_experts, args.model_parallel))
|
||||
|
@ -6,7 +6,6 @@ from tqdm import tqdm
|
||||
from asyncio import gather, to_thread, run
|
||||
import torch
|
||||
from safetensors.torch import load_file, save_file
|
||||
|
||||
from kernel import weight_dequant
|
||||
|
||||
def inner_tensor_file(safetensor_file):
|
||||
@ -61,8 +60,7 @@ async def main(fp8_path, bf16_path):
|
||||
torch.set_default_dtype(torch.bfloat16)
|
||||
os.makedirs(bf16_path, exist_ok=True)
|
||||
model_index_file = os.path.join(fp8_path, "model.safetensors.index.json")
|
||||
with open(model_index_file, "r") as f:
|
||||
model_index = json.load(f)
|
||||
with open(model_index_file, "r") as f: model_index = json.load(f)
|
||||
weight_map = model_index["weight_map"]
|
||||
|
||||
# Cache for loaded safetensor files
|
||||
|
@ -168,7 +168,7 @@ if __name__ == "__main__":
|
||||
Raises:
|
||||
AssertionError: If neither input-file nor interactive mode is specified.
|
||||
"""
|
||||
arg_variables = [
|
||||
arg_list = [
|
||||
("--ckpt-path", type:=str, required:=True),
|
||||
("--config", type:=str, required:=True),
|
||||
("--input-file", type:=str, default:=""),
|
||||
@ -176,6 +176,6 @@ if __name__ == "__main__":
|
||||
("--max-new-tokens", type:=int, default:=200),
|
||||
("--temperature", type:=float, default:=0.2)
|
||||
]
|
||||
args = Parser(arg_list=arg_variables).apply_args().return_args()
|
||||
args = Parser(arg_list).apply_args().return_args()
|
||||
assert args.input_file or args.interactive
|
||||
run(main(args.ckpt_path, args.config, args.input_file, args.interactive, args.max_new_tokens, args.temperature))
|
||||
|
Loading…
Reference in New Issue
Block a user