add functionality

- applied asyncio to more files
- added Parser class
 - made small changes
CodingParadigm1 2025-02-03 10:06:38 -07:00
parent 07de76f5ee
commit 267e7ba685
4 changed files with 76 additions and 72 deletions

inference/convert.py

@@ -70,10 +70,8 @@ def main(hf_ckpt_path, save_path, n_experts, mp):
         None
     """
     torch.set_num_threads(8)
-    n_local_experts = n_experts // mp
-    state_dicts = [{} for _ in range(mp)]
-    tensor_dir = glob(os.path.join(hf_ckpt_path, "*.safetensors"))
-    token_dir = glob(os.path.join(hf_ckpt_path, "*token*"))
+    n_local_experts, state_dicts = n_experts // mp, [{} for _ in range(mp)]
+    tensor_dir, token_dir = list(glob(os.path.join(hf_ckpt_path, "*.safetensors"))), list(glob(os.path.join(hf_ckpt_path, "*token*")))
     for file_path in tqdm(tensor_dir):
         cm = await sync.to_thread(safe_open, file_path, framework="pt", device="cpu")
         async with cm as f:
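Review note: two problems sit in the context lines of this hunk. First, no module named sync is imported anywhere shown; asyncio.to_thread appears to be the intent. Second, safe_open(...) returns a plain synchronous context manager, so "async with" on the object handed back by to_thread should fail with AttributeError (no __aenter__). Also, the hunk header still reads "def main(...)", and await inside a non-async function is a syntax error, so presumably main was made async in an earlier commit of this series. A minimal sketch of a pattern that should work, using a hypothetical helper (read_tensors is not part of the commit) to keep the blocking safetensors I/O off the event loop:

import asyncio
from safetensors import safe_open

async def read_tensors(file_path: str) -> dict:
    """Hypothetical helper: load every tensor of one file in a worker thread."""
    def _read() -> dict:
        tensors = {}
        # safe_open only supports the synchronous `with` protocol.
        with safe_open(file_path, framework="pt", device="cpu") as f:
            for name in f.keys():
                tensors[name] = f.get_tensor(name)
        return tensors
    # Run the whole blocking open/read/close cycle in one thread hop.
    return await asyncio.to_thread(_read)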

inference/fp8_cast_bf16.py

@@ -3,13 +3,42 @@ import json
 from argparse import ArgumentParser
 from glob import glob
 from tqdm import tqdm
+from asyncio import gather, to_thread, run

 import torch
 from safetensors.torch import load_file, save_file

 from kernel import weight_dequant

-def main(fp8_path, bf16_path):
+def inner_tensor_file(safetensor_file):
+    file_name = os.path.basename(safetensor_file)
+    current_state_dict = load_file(safetensor_file, device="cuda")
+    loaded_files[file_name] = current_state_dict
+    new_state_dict = {}
+    for weight_name, weight in current_state_dict.items():
+        if weight_name.endswith("_scale_inv"):
+            continue
+        elif weight.element_size() == 1:  # FP8 weight
+            scale_inv_name = f"{weight_name}_scale_inv"
+            try:
+                # Get scale_inv from the correct file
+                scale_inv = get_tensor(scale_inv_name)
+                fp8_weight_names.append(weight_name)
+                new_state_dict[weight_name] = weight_dequant(weight, scale_inv)
+            except KeyError:
+                print(f"Warning: Missing scale_inv tensor for {weight_name}, skipping conversion")
+                new_state_dict[weight_name] = weight
+        else:
+            new_state_dict[weight_name] = weight
+    new_safetensor_file = os.path.join(bf16_path, file_name)
+    save_file(new_state_dict, new_safetensor_file)
+    # Memory management: keep only the 2 most recently used files
+    if len(loaded_files) > 2:
+        oldest_file = next(iter(loaded_files))
+        del loaded_files[oldest_file]
+        torch.cuda.empty_cache()
+
+async def main(fp8_path, bf16_path):
     """
     Converts FP8 weights to BF16 and saves the converted weights.
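Review note: hoisting the loop body to module level as inner_tensor_file detaches it from main's locals. loaded_files, fp8_weight_names, get_tensor, and bf16_path are all defined inside main, so the worker should raise NameError on first use. A self-contained sketch of the principle that fixes this, assuming shared state is passed in explicitly rather than captured from an inaccessible scope:

import asyncio

async def main():
    shared = []  # state the workers must see

    def worker(item, out):
        # `out` arrives as a parameter, so this function keeps working
        # even if it is later moved out of main()'s body.
        out.append(item * 2)

    await asyncio.gather(*(asyncio.to_thread(worker, i, shared) for i in range(4)))
    print(sorted(shared))  # [0, 2, 4, 6]

asyncio.run(main())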
@@ -37,8 +66,7 @@ def main(fp8_path, bf16_path):
     weight_map = model_index["weight_map"]

     # Cache for loaded safetensor files
-    loaded_files = {}
-    fp8_weight_names = []
+    loaded_files, fp8_weight_names = {}, []

     # Helper function to get tensor from the correct file
     def get_tensor(tensor_name):
@@ -62,45 +90,15 @@ def main(fp8_path, bf16_path):
     safetensor_files = list(glob(os.path.join(fp8_path, "*.safetensors")))
     safetensor_files.sort()

-    for safetensor_file in tqdm(safetensor_files):
-        file_name = os.path.basename(safetensor_file)
-        current_state_dict = load_file(safetensor_file, device="cuda")
-        loaded_files[file_name] = current_state_dict
-        new_state_dict = {}
-        for weight_name, weight in current_state_dict.items():
-            if weight_name.endswith("_scale_inv"):
-                continue
-            elif weight.element_size() == 1:  # FP8 weight
-                scale_inv_name = f"{weight_name}_scale_inv"
-                try:
-                    # Get scale_inv from the correct file
-                    scale_inv = get_tensor(scale_inv_name)
-                    fp8_weight_names.append(weight_name)
-                    new_state_dict[weight_name] = weight_dequant(weight, scale_inv)
-                except KeyError:
-                    print(f"Warning: Missing scale_inv tensor for {weight_name}, skipping conversion")
-                    new_state_dict[weight_name] = weight
-            else:
-                new_state_dict[weight_name] = weight
-
-        new_safetensor_file = os.path.join(bf16_path, file_name)
-        save_file(new_state_dict, new_safetensor_file)
-
-        # Memory management: keep only the 2 most recently used files
-        if len(loaded_files) > 2:
-            oldest_file = next(iter(loaded_files))
-            del loaded_files[oldest_file]
-            torch.cuda.empty_cache()
+    gather(*(to_thread(inner_tensor_file, safetensor_file) for safetensor_file in tqdm(safetensor_files)))

     # Update model index
     new_model_index_file = os.path.join(bf16_path, "model.safetensors.index.json")
     for weight_name in fp8_weight_names:
         scale_inv_name = f"{weight_name}_scale_inv"
-        if scale_inv_name in weight_map:
-            weight_map.pop(scale_inv_name)
-    with open(new_model_index_file, "w") as f:
-        json.dump({"metadata": {}, "weight_map": weight_map}, f, indent=2)
+        if scale_inv_name in weight_map: weight_map.pop(scale_inv_name)
+    with open(new_model_index_file, "w") as f: json.dump({"metadata": {}, "weight_map": weight_map}, f, indent=2)

 if __name__ == "__main__":
@@ -108,5 +106,5 @@ if __name__ == "__main__":
     parser.add_argument("--input-fp8-hf-path", type=str, required=True)
    parser.add_argument("--output-bf16-hf-path", type=str, required=True)
     args = parser.parse_args()
-    main(args.input_fp8_hf_path, args.output_bf16_hf_path)
+    run(main(args.input_fp8_hf_path, args.output_bf16_hf_path))
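Review note: inside the new async main, gather(...) is called but never awaited. gather only schedules the to_thread coroutines and returns a future; since main never yields to the event loop afterwards, no conversion work is guaranteed to have run before the model index is rewritten, and run() then cancels the still-pending tasks. A sketch of the intended shape, assuming the scoping issue above is also fixed; convert_all and the bound of 2 (mirroring the old "keep 2 files" cache policy) are assumptions, not part of the commit:

import asyncio

async def convert_all(safetensor_files, inner_tensor_file, max_inflight=2):
    """Hypothetical driver: awaited gather with a concurrency bound."""
    sem = asyncio.Semaphore(max_inflight)

    async def convert_one(path):
        async with sem:
            # Each shard's blocking load/dequant/save runs in a worker thread.
            await asyncio.to_thread(inner_tensor_file, path)

    # Without this await, gather() would only create the future and the
    # model index would be rewritten before any shard was converted.
    await asyncio.gather(*(convert_one(p) for p in safetensor_files))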

inference/generate.py

@@ -1,13 +1,13 @@
 import os
 import json
+from parser import Parser
 from argparse import ArgumentParser
 from typing import List

 import torch
 import torch.distributed as dist
 from transformers import AutoTokenizer
 from safetensors.torch import load_model
+from asyncio import gather, to_thread, run

 from model import Transformer, ModelArgs
@@ -36,6 +36,7 @@ def generate(
     temperature: float = 1.0
 ) -> List[List[int]]:
     """
     Generates new tokens based on the given prompt tokens using the specified model.

     Args:
@@ -47,38 +48,35 @@ def generate(
     Returns:
         List[List[int]]: A list of lists containing the generated tokens for each sequence.
     """
     prompt_lens = [len(t) for t in prompt_tokens]
     assert max(prompt_lens) <= model.max_seq_len
     total_len = min(model.max_seq_len, max_new_tokens + max(prompt_lens))
     tokens = torch.full((len(prompt_tokens), total_len), -1, dtype=torch.long, device="cuda")
-    for i, t in enumerate(prompt_tokens):
-        tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
+    for i, t in enumerate(prompt_tokens): tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
     prev_pos = 0
     finished = torch.tensor([False] * len(prompt_tokens), device="cuda")
     prompt_mask = tokens != -1
-    for cur_pos in range(min(prompt_lens), total_len):
+    def inner_cur_pos():
         logits = model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
-        if temperature > 0:
-            next_token = sample(logits, temperature)
-        else:
-            next_token = logits.argmax(dim=-1)
+        if temperature > 0: next_token = sample(logits, temperature)
+        else: next_token = logits.argmax(dim=-1)
         next_token = torch.where(prompt_mask[:, cur_pos], tokens[:, cur_pos], next_token)
         tokens[:, cur_pos] = next_token
         finished |= torch.logical_and(~prompt_mask[:, cur_pos], next_token == eos_id)
         prev_pos = cur_pos
-        if finished.all():
-            break
+        if finished.all(): return
+    gather(*(to_thread(cur_pos) for cur_pos in range(min(prompt_lens), total_len)))
     completion_tokens = []
     for i, toks in enumerate(tokens.tolist()):
         toks = toks[prompt_lens[i]:prompt_lens[i]+max_new_tokens]
-        if eos_id in toks:
-            toks = toks[:toks.index(eos_id)]
+        if eos_id in toks: toks = toks[:toks.index(eos_id)]
         completion_tokens.append(toks)
     return completion_tokens

-def main(
+async def main(
     ckpt_path: str,
     config: str,
     input_file: str = "",
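Review note: this is the most questionable hunk of the commit. Autoregressive decoding is inherently sequential: step cur_pos reads the tokens written at cur_pos - 1, so fanning the positions out across threads cannot preserve the algorithm. Beyond that, to_thread(cur_pos) submits the loop index (an int) instead of the function inner_cur_pos; the closure reads cur_pos without receiving it as a parameter; prev_pos would need a nonlocal declaration to be rebound; and the gather(...) future is never awaited inside this still-synchronous function. If the goal is merely to keep an event loop responsive while decoding, a sketch that preserves the sequential semantics (hypothetical, not in the commit) would be:

import asyncio

async def generate_async(step_fn, start: int, total_len: int) -> None:
    """Hypothetical wrapper: run each sequential decode step in a thread.

    step_fn(cur_pos) performs one decoding step and returns True once every
    sequence has finished. Steps run strictly in order, so the
    autoregressive dependency between positions is respected.
    """
    for cur_pos in range(start, total_len):
        done = await asyncio.to_thread(step_fn, cur_pos)
        if done:
            break

Each await yields to the event loop between steps; this gives responsiveness, not a speedup, which is all to_thread can offer for GIL-bound GPU-dispatch code.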
@@ -131,8 +129,7 @@ def main(
             objects = [None]
             dist.broadcast_object_list(objects, 0)
             prompt = objects[0]
-            if prompt == "/exit":
-                break
+            if prompt == "/exit": break
             elif prompt == "/clear":
                 messages.clear()
                 continue
@@ -143,8 +140,7 @@ def main(
             print(completion)
             messages.append({"role": "assistant", "content": completion})
     else:
-        with open(input_file) as f:
-            prompts = [line.strip() for line in f.readlines()]
+        with open(input_file) as f: prompts = [line.strip() for line in f.readlines()]
         assert len(prompts) <= args.max_batch_size
         prompt_tokens = [tokenizer.apply_chat_template([{"role": "user", "content": prompt}], add_generation_prompt=True) for prompt in prompts]
         completion_tokens = generate(model, prompt_tokens, max_new_tokens, tokenizer.eos_token_id, temperature)
@@ -154,8 +150,7 @@ def main(
         print("Completion:", completion)
         print()

-    if world_size > 1:
-        dist.destroy_process_group()
+    if world_size > 1: dist.destroy_process_group()

 if __name__ == "__main__":
@@ -173,13 +168,14 @@ if __name__ == "__main__":
     Raises:
         AssertionError: If neither input-file nor interactive mode is specified.
     """
-    parser = ArgumentParser()
-    parser.add_argument("--ckpt-path", type=str, required=True)
-    parser.add_argument("--config", type=str, required=True)
-    parser.add_argument("--input-file", type=str, default="")
-    parser.add_argument("--interactive", action="store_true")
-    parser.add_argument("--max-new-tokens", type=int, default=200)
-    parser.add_argument("--temperature", type=float, default=0.2)
-    args = parser.parse_args()
+    arg_variables = [
+        ("--ckpt-path", type:=str, required:=True),
+        ("--config", type:=str, required:=True),
+        ("--input-file", type:=str, default:=""),
+        ("--interactive", action:="store_true"),
+        ("--max-new-tokens", type:=int, default:=200),
+        ("--temperature", type:=float, default:=0.2)
+    ]
+    args = Parser(arg_list=arg_variables).apply_args().return_args()
     assert args.input_file or args.interactive
-    main(args.ckpt_path, args.config, args.input_file, args.interactive, args.max_new_tokens, args.temperature)
+    run(main(args.ckpt_path, args.config, args.input_file, args.interactive, args.max_new_tokens, args.temperature))
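Review note: the walrus expressions here do not create keyword arguments. type:=str binds a module-level name type (shadowing the builtin) and the expression simply evaluates to str, so each tuple holds bare positional values and the keyword names are gone by the time Parser unpacks them:

# What one of these tuples actually evaluates to:
spec = ("--ckpt-path", type := str, required := True)
print(spec)  # ('--ckpt-path', <class 'str'>, True)

# add_argument(*spec) therefore receives str and True as extra *option
# strings*, not as type=/required= keywords, and argparse rejects them.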

inference/parser.py (new file, +12 lines)

@@ -0,0 +1,12 @@
+from argparse import ArgumentParser
+
+class Parser():
+    def __init__(self, parser = ArgumentParser(), arg_list = []):
+        self.parser = parser
+        self.arg_list = arg_list
+
+    def apply_args(self):
+        for arg in self.arg_list: self.parser.add_argument(*arg)
+        return self
+    def return_args(self):
+        return self.parser.parse_args()
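Review note: two classic pitfalls sit in this class. Default arguments are evaluated once at definition time, so the single ArgumentParser() and the single [] are shared by every Parser() constructed without arguments; and add_argument(*arg) can only forward positional values, which is what forces the walrus tuples at the call site in generate.py. A sketch of a revision that avoids both, using dict-based specs (hypothetical, not part of the commit):

from argparse import ArgumentParser, Namespace
from typing import Any, Dict, Optional, Sequence, Tuple

ArgSpec = Tuple[str, Dict[str, Any]]  # (flag, keyword arguments)

class Parser:
    def __init__(self, parser: Optional[ArgumentParser] = None,
                 arg_list: Optional[Sequence[ArgSpec]] = None):
        # Build fresh defaults per instance; mutable defaults in the
        # signature would be shared across every Parser() call.
        self.parser = parser if parser is not None else ArgumentParser()
        self.arg_list = list(arg_list) if arg_list is not None else []

    def apply_args(self) -> "Parser":
        for flag, kwargs in self.arg_list:
            # Keywords survive as a dict, so type=/default=/required=
            # reach add_argument intact.
            self.parser.add_argument(flag, **kwargs)
        return self

    def return_args(self) -> Namespace:
        return self.parser.parse_args()

The call site would then read:

args = Parser(arg_list=[
    ("--ckpt-path", {"type": str, "required": True}),
    ("--temperature", {"type": float, "default": 0.2}),
]).apply_args().return_args()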