mirror of
https://github.com/deepseek-ai/DeepSeek-V3.git
synced 2025-02-23 06:08:58 -05:00
small patch
applied the Parser class further
This commit is contained in:
parent
267e7ba685
commit
c8146ec360
@@ -1,6 +1,6 @@
 import os
 import shutil
-from argparse import ArgumentParser
+from parser import Parser
 from glob import glob
 from tqdm import tqdm, trange
 import asyncio as sync
@@ -89,11 +89,12 @@ def main(hf_ckpt_path, save_path, n_experts, mp):
 
 
 if __name__ == "__main__":
-    parser = ArgumentParser()
-    parser.add_argument("--hf-ckpt-path", type=str, required=True)
-    parser.add_argument("--save-path", type=str, required=True)
-    parser.add_argument("--n-experts", type=int, required=True)
-    parser.add_argument("--model-parallel", type=int, required=True)
-    args = parser.parse_args()
+    arg_list = [
+        ("--hf-ckpt-path", type:=str, required:=True),
+        ("--save-path", type:=str, required:=True),
+        ("--n-experts", type:=int, required:=True),
+        ("--model-parallel", type:=int, required:=True)
+    ]
+    args = Parser(arg_list).apply_args().return_args()
     assert args.n_experts % args.model_parallel == 0
     sync.run(main(args.hf_ckpt_path, args.save_path, args.n_experts, args.model_parallel))
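
For reference (standalone illustration, not part of the commit): the tuple entries above use assignment expressions rather than keyword arguments, so type:=str and required:=True simply evaluate to their right-hand values while also binding module-level names. What one entry evaluates to:

# Illustration only: shows the runtime value of a single arg_list entry.
entry = ("--hf-ckpt-path", type:=str, required:=True)
print(entry)          # ('--hf-ckpt-path', <class 'str'>, True)
print(type is str)    # True -- the builtin type() is now shadowed by this binding
print(required)       # True
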
@@ -6,7 +6,6 @@ from tqdm import tqdm
 from asyncio import gather, to_thread, run
 import torch
 from safetensors.torch import load_file, save_file
-
 from kernel import weight_dequant
 
 def inner_tensor_file(safetensor_file):
@@ -61,8 +60,7 @@ async def main(fp8_path, bf16_path):
     torch.set_default_dtype(torch.bfloat16)
     os.makedirs(bf16_path, exist_ok=True)
     model_index_file = os.path.join(fp8_path, "model.safetensors.index.json")
-    with open(model_index_file, "r") as f:
-        model_index = json.load(f)
+    with open(model_index_file, "r") as f: model_index = json.load(f)
     weight_map = model_index["weight_map"]
 
     # Cache for loaded safetensor files
@@ -168,7 +168,7 @@ if __name__ == "__main__":
     Raises:
         AssertionError: If neither input-file nor interactive mode is specified.
     """
-    arg_variables = [
+    arg_list = [
        ("--ckpt-path", type:=str, required:=True),
        ("--config", type:=str, required:=True),
        ("--input-file", type:=str, default:=""),
@@ -176,6 +176,6 @@ if __name__ == "__main__":
        ("--max-new-tokens", type:=int, default:=200),
        ("--temperature", type:=float, default:=0.2)
     ]
-    args = Parser(arg_list=arg_variables).apply_args().return_args()
+    args = Parser(arg_list).apply_args().return_args()
     assert args.input_file or args.interactive
     run(main(args.ckpt_path, args.config, args.input_file, args.interactive, args.max_new_tokens, args.temperature))
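
parser.py itself is not touched by this commit, so the Parser implementation is not shown here. A minimal sketch of a wrapper that would support the Parser(arg_list).apply_args().return_args() chain used above, assuming each tuple is (flag, type, value) where a boolean value means required and anything else is a default; the actual class in the repository may differ:

from argparse import ArgumentParser, Namespace


class Parser:
    """Thin wrapper that registers (flag, type, value) tuples with argparse.

    Hypothetical sketch only; the real parser.Parser may differ.
    """

    def __init__(self, arg_list):
        self.arg_list = arg_list
        self.parser = ArgumentParser()

    def apply_args(self):
        # A boolean third element is read as `required`,
        # anything else as a default value for the flag.
        for flag, value_type, value in self.arg_list:
            if isinstance(value, bool):
                self.parser.add_argument(flag, type=value_type, required=value)
            else:
                self.parser.add_argument(flag, type=value_type, default=value)
        return self

    def return_args(self) -> Namespace:
        # Parse sys.argv and return the usual argparse Namespace,
        # so callers can keep using args.save_path, args.n_experts, etc.
        return self.parser.parse_args()

Used exactly as in the diff: args = Parser(arg_list).apply_args().return_args()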