From c8146ec360514b893b1608643aa30351b5cbcb96 Mon Sep 17 00:00:00 2001 From: CodingParadigm1 <124928232+CodingParadigm1@users.noreply.github.com> Date: Mon, 3 Feb 2025 10:18:58 -0700 Subject: [PATCH] small patch applied further Parser class --- inference/convert.py | 15 ++++++++------- inference/fp8_cast_bf16.py | 4 +--- inference/generate.py | 4 ++-- 3 files changed, 11 insertions(+), 12 deletions(-) diff --git a/inference/convert.py b/inference/convert.py index 1073ffb..e83ba9a 100644 --- a/inference/convert.py +++ b/inference/convert.py @@ -1,6 +1,6 @@ import os import shutil -from argparse import ArgumentParser +from parser import Parser from glob import glob from tqdm import tqdm, trange import asyncio as sync @@ -89,11 +89,12 @@ def main(hf_ckpt_path, save_path, n_experts, mp): if __name__ == "__main__": - parser = ArgumentParser() - parser.add_argument("--hf-ckpt-path", type=str, required=True) - parser.add_argument("--save-path", type=str, required=True) - parser.add_argument("--n-experts", type=int, required=True) - parser.add_argument("--model-parallel", type=int, required=True) - args = parser.parse_args() + arg_list = [ + ("--hf-ckpt-path", type:=str, required:=True), + ("--save-path", type:=str, required:=True), + ("--n-experts", type:=int, required:=True), + ("--model-parallel", type:=int, required:=True) + ] + args = Parser(arg_list).apply_args().return_args() assert args.n_experts % args.model_parallel == 0 sync.run(main(args.hf_ckpt_path, args.save_path, args.n_experts, args.model_parallel)) diff --git a/inference/fp8_cast_bf16.py b/inference/fp8_cast_bf16.py index 3047edc..2dec503 100644 --- a/inference/fp8_cast_bf16.py +++ b/inference/fp8_cast_bf16.py @@ -6,7 +6,6 @@ from tqdm import tqdm from asyncio import gather, to_thread, run import torch from safetensors.torch import load_file, save_file - from kernel import weight_dequant def inner_tensor_file(safetensor_file): @@ -61,8 +60,7 @@ async def main(fp8_path, bf16_path): torch.set_default_dtype(torch.bfloat16) os.makedirs(bf16_path, exist_ok=True) model_index_file = os.path.join(fp8_path, "model.safetensors.index.json") - with open(model_index_file, "r") as f: - model_index = json.load(f) + with open(model_index_file, "r") as f: model_index = json.load(f) weight_map = model_index["weight_map"] # Cache for loaded safetensor files diff --git a/inference/generate.py b/inference/generate.py index 01f2cb3..c67599c 100644 --- a/inference/generate.py +++ b/inference/generate.py @@ -168,7 +168,7 @@ if __name__ == "__main__": Raises: AssertionError: If neither input-file nor interactive mode is specified. """ - arg_variables = [ + arg_list = [ ("--ckpt-path", type:=str, required:=True), ("--config", type:=str, required:=True), ("--input-file", type:=str, default:=""), @@ -176,6 +176,6 @@ if __name__ == "__main__": ("--max-new-tokens", type:=int, default:=200), ("--temperature", type:=float, default:=0.2) ] - args = Parser(arg_list=arg_variables).apply_args().return_args() + args = Parser(arg_list).apply_args().return_args() assert args.input_file or args.interactive run(main(args.ckpt_path, args.config, args.input_file, args.interactive, args.max_new_tokens, args.temperature))