small patch

applied further Parser class
This commit is contained in:
CodingParadigm1 2025-02-03 10:18:58 -07:00
parent 267e7ba685
commit c8146ec360
3 changed files with 11 additions and 12 deletions

View File

@ -1,6 +1,6 @@
import os
import shutil
from argparse import ArgumentParser
from parser import Parser
from glob import glob
from tqdm import tqdm, trange
import asyncio as sync
@ -89,11 +89,12 @@ def main(hf_ckpt_path, save_path, n_experts, mp):
if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("--hf-ckpt-path", type=str, required=True)
parser.add_argument("--save-path", type=str, required=True)
parser.add_argument("--n-experts", type=int, required=True)
parser.add_argument("--model-parallel", type=int, required=True)
args = parser.parse_args()
arg_list = [
("--hf-ckpt-path", type:=str, required:=True),
("--save-path", type:=str, required:=True),
("--n-experts", type:=int, required:=True),
("--model-parallel", type:=int, required:=True)
]
args = Parser(arg_list).apply_args().return_args()
assert args.n_experts % args.model_parallel == 0
sync.run(main(args.hf_ckpt_path, args.save_path, args.n_experts, args.model_parallel))

View File

@ -6,7 +6,6 @@ from tqdm import tqdm
from asyncio import gather, to_thread, run
import torch
from safetensors.torch import load_file, save_file
from kernel import weight_dequant
def inner_tensor_file(safetensor_file):
@ -61,8 +60,7 @@ async def main(fp8_path, bf16_path):
torch.set_default_dtype(torch.bfloat16)
os.makedirs(bf16_path, exist_ok=True)
model_index_file = os.path.join(fp8_path, "model.safetensors.index.json")
with open(model_index_file, "r") as f:
model_index = json.load(f)
with open(model_index_file, "r") as f: model_index = json.load(f)
weight_map = model_index["weight_map"]
# Cache for loaded safetensor files

View File

@ -168,7 +168,7 @@ if __name__ == "__main__":
Raises:
AssertionError: If neither input-file nor interactive mode is specified.
"""
arg_variables = [
arg_list = [
("--ckpt-path", type:=str, required:=True),
("--config", type:=str, required:=True),
("--input-file", type:=str, default:=""),
@ -176,6 +176,6 @@ if __name__ == "__main__":
("--max-new-tokens", type:=int, default:=200),
("--temperature", type:=float, default:=0.2)
]
args = Parser(arg_list=arg_variables).apply_args().return_args()
args = Parser(arg_list).apply_args().return_args()
assert args.input_file or args.interactive
run(main(args.ckpt_path, args.config, args.input_file, args.interactive, args.max_new_tokens, args.temperature))