small patch

Further applied the Parser class.
This commit is contained in:
CodingParadigm1 2025-02-03 10:18:58 -07:00
parent 267e7ba685
commit c8146ec360
3 changed files with 11 additions and 12 deletions

View File

@@ -1,6 +1,6 @@
import os import os
import shutil import shutil
from argparse import ArgumentParser from parser import Parser
from glob import glob from glob import glob
from tqdm import tqdm, trange from tqdm import tqdm, trange
import asyncio as sync import asyncio as sync
@@ -89,11 +89,12 @@ def main(hf_ckpt_path, save_path, n_experts, mp):
if __name__ == "__main__": if __name__ == "__main__":
parser = ArgumentParser() arg_list = [
parser.add_argument("--hf-ckpt-path", type=str, required=True) ("--hf-ckpt-path", type:=str, required:=True),
parser.add_argument("--save-path", type=str, required=True) ("--save-path", type:=str, required:=True),
parser.add_argument("--n-experts", type=int, required=True) ("--n-experts", type:=int, required:=True),
parser.add_argument("--model-parallel", type=int, required=True) ("--model-parallel", type:=int, required:=True)
args = parser.parse_args() ]
args = Parser(arg_list).apply_args().return_args()
assert args.n_experts % args.model_parallel == 0 assert args.n_experts % args.model_parallel == 0
sync.run(main(args.hf_ckpt_path, args.save_path, args.n_experts, args.model_parallel)) sync.run(main(args.hf_ckpt_path, args.save_path, args.n_experts, args.model_parallel))

View File

@@ -6,7 +6,6 @@ from tqdm import tqdm
from asyncio import gather, to_thread, run from asyncio import gather, to_thread, run
import torch import torch
from safetensors.torch import load_file, save_file from safetensors.torch import load_file, save_file
from kernel import weight_dequant from kernel import weight_dequant
def inner_tensor_file(safetensor_file): def inner_tensor_file(safetensor_file):
@@ -61,8 +60,7 @@ async def main(fp8_path, bf16_path):
torch.set_default_dtype(torch.bfloat16) torch.set_default_dtype(torch.bfloat16)
os.makedirs(bf16_path, exist_ok=True) os.makedirs(bf16_path, exist_ok=True)
model_index_file = os.path.join(fp8_path, "model.safetensors.index.json") model_index_file = os.path.join(fp8_path, "model.safetensors.index.json")
with open(model_index_file, "r") as f: with open(model_index_file, "r") as f: model_index = json.load(f)
model_index = json.load(f)
weight_map = model_index["weight_map"] weight_map = model_index["weight_map"]
# Cache for loaded safetensor files # Cache for loaded safetensor files

View File

@@ -168,7 +168,7 @@ if __name__ == "__main__":
Raises: Raises:
AssertionError: If neither input-file nor interactive mode is specified. AssertionError: If neither input-file nor interactive mode is specified.
""" """
arg_variables = [ arg_list = [
("--ckpt-path", type:=str, required:=True), ("--ckpt-path", type:=str, required:=True),
("--config", type:=str, required:=True), ("--config", type:=str, required:=True),
("--input-file", type:=str, default:=""), ("--input-file", type:=str, default:=""),
@@ -176,6 +176,6 @@ if __name__ == "__main__":
("--max-new-tokens", type:=int, default:=200), ("--max-new-tokens", type:=int, default:=200),
("--temperature", type:=float, default:=0.2) ("--temperature", type:=float, default:=0.2)
] ]
args = Parser(arg_list=arg_variables).apply_args().return_args() args = Parser(arg_list).apply_args().return_args()
assert args.input_file or args.interactive assert args.input_file or args.interactive
run(main(args.ckpt_path, args.config, args.input_file, args.interactive, args.max_new_tokens, args.temperature)) run(main(args.ckpt_path, args.config, args.input_file, args.interactive, args.max_new_tokens, args.temperature))