CodingParadigm1 2025-02-03 18:49:12 +00:00 committed by GitHub
commit a7d0553e80
4 changed files with 129 additions and 117 deletions

inference/convert.py

@@ -1,9 +1,9 @@
 import os
 import shutil
-from argparse import ArgumentParser
+from parser import Parser
 from glob import glob
 from tqdm import tqdm, trange
+import asyncio as sync
 import torch
 from safetensors.torch import safe_open, save_file
@@ -29,6 +29,32 @@ mapping = {
     "scale": ("scale", None),
 }
+
+async def set_param(param, name, i, n_local_experts, mp, state_dicts, dim):
+    new_param = param
+    if "experts" in name and "shared_experts" not in name:
+        idx = int(name.split(".")[-3])
+        if idx < i * n_local_experts or idx >= (i + 1) * n_local_experts:
+            return
+    elif dim is not None:
+        assert param.size(dim) % mp == 0
+        shard_size = param.size(dim) // mp
+        new_param = param.narrow(dim, i * shard_size, shard_size).contiguous()
+    state_dicts[i][name] = new_param
+
+async def inner_safe_open(name, f, state_dicts, mp, n_local_experts):
+    if "model.layers.61" not in name:
+        param: torch.Tensor = f.get_tensor(name)
+        if name.startswith("model."):
+            name = name[len("model."):]
+        name = name.replace("self_attn", "attn")
+        name = name.replace("mlp", "ffn")
+        name = name.replace("weight_scale_inv", "scale")
+        name = name.replace("e_score_correction_bias", "bias")
+        key = name.split(".")[-2]
+        assert key in mapping
+        new_key, dim = mapping[key]
+        name = name.replace(key, new_key)
+        await sync.gather(*(set_param(param, name, i, n_local_experts, mp, state_dicts, dim) for i in range(mp)))
+
 def main(hf_ckpt_path, save_path, n_experts, mp):
     """
@@ -44,53 +70,30 @@ def main(hf_ckpt_path, save_path, n_experts, mp):
         None
     """
     torch.set_num_threads(8)
-    n_local_experts = n_experts // mp
-    state_dicts = [{} for _ in range(mp)]
-
-    for file_path in tqdm(glob(os.path.join(hf_ckpt_path, "*.safetensors"))):
-        with safe_open(file_path, framework="pt", device="cpu") as f:
-            for name in f.keys():
-                if "model.layers.61" in name:
-                    continue
-                param: torch.Tensor = f.get_tensor(name)
-                if name.startswith("model."):
-                    name = name[len("model."):]
-                name = name.replace("self_attn", "attn")
-                name = name.replace("mlp", "ffn")
-                name = name.replace("weight_scale_inv", "scale")
-                name = name.replace("e_score_correction_bias", "bias")
-                key = name.split(".")[-2]
-                assert key in mapping
-                new_key, dim = mapping[key]
-                name = name.replace(key, new_key)
-                for i in range(mp):
-                    new_param = param
-                    if "experts" in name and "shared_experts" not in name:
-                        idx = int(name.split(".")[-3])
-                        if idx < i * n_local_experts or idx >= (i + 1) * n_local_experts:
-                            continue
-                    elif dim is not None:
-                        assert param.size(dim) % mp == 0
-                        shard_size = param.size(dim) // mp
-                        new_param = param.narrow(dim, i * shard_size, shard_size).contiguous()
-                    state_dicts[i][name] = new_param
+    n_local_experts, state_dicts = n_experts // mp, [{} for _ in range(mp)]
+    tensor_dir, token_dir = list(glob(os.path.join(hf_ckpt_path, "*.safetensors"))), list(glob(os.path.join(hf_ckpt_path, "*token*")))
+    for file_path in tqdm(tensor_dir):
+        cm = await sync.to_thread(safe_open, file_path, framework="pt", device="cpu")
+        async with cm as f:
+            await sync.gather(*(inner_safe_open(name, f, state_dicts, mp, n_local_experts) for name in f.keys()))

     os.makedirs(save_path, exist_ok=True)
-    for i in trange(mp):
-        save_file(state_dicts[i], os.path.join(save_path, f"model{i}-mp{mp}.safetensors"))
-
-    for file_path in glob(os.path.join(hf_ckpt_path, "*token*")):
-        new_file_path = os.path.join(save_path, os.path.basename(file_path))
-        shutil.copyfile(file_path, new_file_path)
+    await sync.gather(*(sync.to_thread(save_file(state_dicts[i], os.path.join(save_path, f"model{i}-mp{mp}.safetensors"))) for i in trange(mp)))
+
+    async def set_file_path(file_path):
+        await sync.to_thread(shutil.copyfile, file_path, os.path.join(save_path, os.path.basename(file_path)))
+    await sync.gather(*(set_file_path(file_path) for file_path in token_dir))


 if __name__ == "__main__":
-    parser = ArgumentParser()
-    parser.add_argument("--hf-ckpt-path", type=str, required=True)
-    parser.add_argument("--save-path", type=str, required=True)
-    parser.add_argument("--n-experts", type=int, required=True)
-    parser.add_argument("--model-parallel", type=int, required=True)
-    args = parser.parse_args()
-    assert args.n_experts % args.model_parallel == 0
-    main(args.hf_ckpt_path, args.save_path, args.n_experts, args.model_parallel)
+    arg_list = [
+        ("--hf-ckpt-path", type:=str, required:=True),
+        ("--save-path", type:=str, required:=True),
+        ("--n-experts", type:=int, required:=True),
+        ("--model-parallel", type:=int, required:=True)
+    ]
+    args = Parser(arg_list).apply_args().assert_model_parallel().return_args()
+    sync.run(main(args.hf_ckpt_path, args.save_path, args.n_experts, args.model_parallel))
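A side note on the to_thread calls introduced above (a minimal sketch, not part of the commit's diff; the helper name save_shards is made up, while state_dicts, save_path and mp come from the script): asyncio.to_thread expects the callable plus its arguments and returns a coroutine, which is then awaited, typically via asyncio.gather inside an async function:

import asyncio
import os
from safetensors.torch import save_file

async def save_shards(state_dicts, save_path, mp):
    # Pass the function and its arguments separately; writing
    # save_file(...) inline would execute it synchronously and hand
    # to_thread its return value instead of a callable.
    await asyncio.gather(*(
        asyncio.to_thread(
            save_file,
            state_dicts[i],
            os.path.join(save_path, f"model{i}-mp{mp}.safetensors"),
        )
        for i in range(mp)
    ))

# Example entry point: asyncio.run(save_shards(state_dicts, save_path, mp))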

inference/fp8_cast_bf16.py

@@ -3,13 +3,41 @@ import json
 from argparse import ArgumentParser
 from glob import glob
 from tqdm import tqdm
+from asyncio import gather, to_thread, run
 import torch
 from safetensors.torch import load_file, save_file
 from kernel import weight_dequant

-def main(fp8_path, bf16_path):
+def inner_tensor_file(safetensor_file):
+    file_name = os.path.basename(safetensor_file)
+    current_state_dict = load_file(safetensor_file, device="cuda")
+    loaded_files[file_name] = current_state_dict
+    new_state_dict = {}
+    for weight_name, weight in current_state_dict.items():
+        if weight_name.endswith("_scale_inv"):
+            continue
+        elif weight.element_size() == 1:  # FP8 weight
+            scale_inv_name = f"{weight_name}_scale_inv"
+            try:
+                # Get scale_inv from the correct file
+                scale_inv = get_tensor(scale_inv_name)
+                fp8_weight_names.append(weight_name)
+                new_state_dict[weight_name] = weight_dequant(weight, scale_inv)
+            except KeyError:
+                print(f"Warning: Missing scale_inv tensor for {weight_name}, skipping conversion")
+                new_state_dict[weight_name] = weight
+        else:
+            new_state_dict[weight_name] = weight
+    new_safetensor_file = os.path.join(bf16_path, file_name)
+    save_file(new_state_dict, new_safetensor_file)
+
+    # Memory management: keep only the 2 most recently used files
+    if len(loaded_files) > 2:
+        oldest_file = next(iter(loaded_files))
+        del loaded_files[oldest_file]
+        torch.cuda.empty_cache()
+
+async def main(fp8_path, bf16_path):
     """
     Converts FP8 weights to BF16 and saves the converted weights.
@@ -32,13 +60,11 @@ def main(fp8_path, bf16_path):
     torch.set_default_dtype(torch.bfloat16)
     os.makedirs(bf16_path, exist_ok=True)
     model_index_file = os.path.join(fp8_path, "model.safetensors.index.json")
-    with open(model_index_file, "r") as f:
-        model_index = json.load(f)
+    with open(model_index_file, "r") as f: model_index = json.load(f)
     weight_map = model_index["weight_map"]

     # Cache for loaded safetensor files
-    loaded_files = {}
-    fp8_weight_names = []
+    loaded_files, fp8_weight_names = {}, []

     # Helper function to get tensor from the correct file
     def get_tensor(tensor_name):
@@ -62,45 +88,15 @@ def main(fp8_path, bf16_path):
     safetensor_files = list(glob(os.path.join(fp8_path, "*.safetensors")))
     safetensor_files.sort()
-    for safetensor_file in tqdm(safetensor_files):
-        file_name = os.path.basename(safetensor_file)
-        current_state_dict = load_file(safetensor_file, device="cuda")
-        loaded_files[file_name] = current_state_dict
-        new_state_dict = {}
-        for weight_name, weight in current_state_dict.items():
-            if weight_name.endswith("_scale_inv"):
-                continue
-            elif weight.element_size() == 1:  # FP8 weight
-                scale_inv_name = f"{weight_name}_scale_inv"
-                try:
-                    # Get scale_inv from the correct file
-                    scale_inv = get_tensor(scale_inv_name)
-                    fp8_weight_names.append(weight_name)
-                    new_state_dict[weight_name] = weight_dequant(weight, scale_inv)
-                except KeyError:
-                    print(f"Warning: Missing scale_inv tensor for {weight_name}, skipping conversion")
-                    new_state_dict[weight_name] = weight
-            else:
-                new_state_dict[weight_name] = weight
-        new_safetensor_file = os.path.join(bf16_path, file_name)
-        save_file(new_state_dict, new_safetensor_file)
-
-        # Memory management: keep only the 2 most recently used files
-        if len(loaded_files) > 2:
-            oldest_file = next(iter(loaded_files))
-            del loaded_files[oldest_file]
-            torch.cuda.empty_cache()
+    gather(*(to_thread(inner_tensor_file, safetensor_file) for safetensor_file in tqdm(safetensor_files)))

     # Update model index
     new_model_index_file = os.path.join(bf16_path, "model.safetensors.index.json")
     for weight_name in fp8_weight_names:
         scale_inv_name = f"{weight_name}_scale_inv"
-        if scale_inv_name in weight_map:
-            weight_map.pop(scale_inv_name)
-    with open(new_model_index_file, "w") as f:
-        json.dump({"metadata": {}, "weight_map": weight_map}, f, indent=2)
+        if scale_inv_name in weight_map: weight_map.pop(scale_inv_name)
+    with open(new_model_index_file, "w") as f: json.dump({"metadata": {}, "weight_map": weight_map}, f, indent=2)


 if __name__ == "__main__":
@@ -108,5 +104,5 @@ if __name__ == "__main__":
     parser.add_argument("--input-fp8-hf-path", type=str, required=True)
     parser.add_argument("--output-bf16-hf-path", type=str, required=True)
     args = parser.parse_args()
-    main(args.input_fp8_hf_path, args.output_bf16_hf_path)
+    run(main(args.input_fp8_hf_path, args.output_bf16_hf_path))
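For context on the gather/to_thread line above, a minimal sketch of driving a blocking per-file conversion from an async entry point (convert_one and convert_all are hypothetical names; fp8_path matches the script's argument): gather returns an awaitable, and awaiting it keeps the calling coroutine alive until the thread work finishes, since asyncio.run cancels whatever is still pending once the main coroutine returns:

import asyncio
import os
from glob import glob

def convert_one(path):
    # Placeholder for the blocking per-file work (load, dequantize, save).
    print("converting", path)

async def convert_all(fp8_path):
    files = sorted(glob(os.path.join(fp8_path, "*.safetensors")))
    # Awaiting the gather ensures every thread-offloaded call completes
    # before convert_all (and therefore asyncio.run) returns.
    await asyncio.gather(*(asyncio.to_thread(convert_one, f) for f in files))

# asyncio.run(convert_all("path/to/fp8"))  # hypothetical path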

inference/generate.py

@@ -1,13 +1,13 @@
 import os
 import json
+from parser import Parser
 from argparse import ArgumentParser
 from typing import List
 import torch
 import torch.distributed as dist
 from transformers import AutoTokenizer
 from safetensors.torch import load_model
+from asyncio import gather, to_thread, run
 from model import Transformer, ModelArgs
@@ -36,6 +36,7 @@ def generate(
     temperature: float = 1.0
 ) -> List[List[int]]:
     """
     Generates new tokens based on the given prompt tokens using the specified model.

     Args:
@@ -47,38 +48,35 @@ def generate(
     Returns:
         List[List[int]]: A list of lists containing the generated tokens for each sequence.
     """
     prompt_lens = [len(t) for t in prompt_tokens]
     assert max(prompt_lens) <= model.max_seq_len
     total_len = min(model.max_seq_len, max_new_tokens + max(prompt_lens))
     tokens = torch.full((len(prompt_tokens), total_len), -1, dtype=torch.long, device="cuda")
-    for i, t in enumerate(prompt_tokens):
-        tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
+    for i, t in enumerate(prompt_tokens): tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
     prev_pos = 0
     finished = torch.tensor([False] * len(prompt_tokens), device="cuda")
     prompt_mask = tokens != -1
-    for cur_pos in range(min(prompt_lens), total_len):
-        logits = model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
-        if temperature > 0:
-            next_token = sample(logits, temperature)
-        else:
-            next_token = logits.argmax(dim=-1)
-        next_token = torch.where(prompt_mask[:, cur_pos], tokens[:, cur_pos], next_token)
-        tokens[:, cur_pos] = next_token
-        finished |= torch.logical_and(~prompt_mask[:, cur_pos], next_token == eos_id)
-        prev_pos = cur_pos
-        if finished.all():
-            break
+    def inner_cur_pos():
+        logits = model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
+        if temperature > 0: next_token = sample(logits, temperature)
+        else: next_token = logits.argmax(dim=-1)
+        next_token = torch.where(prompt_mask[:, cur_pos], tokens[:, cur_pos], next_token)
+        tokens[:, cur_pos] = next_token
+        finished |= torch.logical_and(~prompt_mask[:, cur_pos], next_token == eos_id)
+        prev_pos = cur_pos
+        if finished.all(): return
+    gather(*(to_thread(cur_pos) for cur_pos in range(min(prompt_lens), total_len)))
     completion_tokens = []
     for i, toks in enumerate(tokens.tolist()):
         toks = toks[prompt_lens[i]:prompt_lens[i]+max_new_tokens]
-        if eos_id in toks:
-            toks = toks[:toks.index(eos_id)]
+        if eos_id in toks: toks = toks[:toks.index(eos_id)]
         completion_tokens.append(toks)
     return completion_tokens


-def main(
+async def main(
     ckpt_path: str,
     config: str,
     input_file: str = "",
@@ -131,8 +129,7 @@ def main(
                 objects = [None]
                 dist.broadcast_object_list(objects, 0)
                 prompt = objects[0]
-            if prompt == "/exit":
-                break
+            if prompt == "/exit": break
             elif prompt == "/clear":
                 messages.clear()
                 continue
@@ -143,8 +140,7 @@ def main(
             print(completion)
             messages.append({"role": "assistant", "content": completion})
     else:
-        with open(input_file) as f:
-            prompts = [line.strip() for line in f.readlines()]
+        with open(input_file) as f: prompts = [line.strip() for line in f.readlines()]
         assert len(prompts) <= args.max_batch_size
         prompt_tokens = [tokenizer.apply_chat_template([{"role": "user", "content": prompt}], add_generation_prompt=True) for prompt in prompts]
         completion_tokens = generate(model, prompt_tokens, max_new_tokens, tokenizer.eos_token_id, temperature)
@@ -154,8 +150,7 @@ def main(
             print("Completion:", completion)
             print()

-    if world_size > 1:
-        dist.destroy_process_group()
+    if world_size > 1: dist.destroy_process_group()
 if __name__ == "__main__":
@@ -173,13 +168,13 @@ if __name__ == "__main__":
     Raises:
         AssertionError: If neither input-file nor interactive mode is specified.
     """
-    parser = ArgumentParser()
-    parser.add_argument("--ckpt-path", type=str, required=True)
-    parser.add_argument("--config", type=str, required=True)
-    parser.add_argument("--input-file", type=str, default="")
-    parser.add_argument("--interactive", action="store_true")
-    parser.add_argument("--max-new-tokens", type=int, default=200)
-    parser.add_argument("--temperature", type=float, default=0.2)
-    args = parser.parse_args()
-    assert args.input_file or args.interactive
-    main(args.ckpt_path, args.config, args.input_file, args.interactive, args.max_new_tokens, args.temperature)
+    arg_list = [
+        ("--ckpt-path", type:=str, required:=True),
+        ("--config", type:=str, required:=True),
+        ("--input-file", type:=str, default:=""),
+        ("--interactive", action:="store_true"),
+        ("--max-new-tokens", type:=int, default:=200),
+        ("--temperature", type:=float, default:=0.2)
+    ]
+    args = Parser(arg_list).apply_args().assert_interactive().return_args()
+    run(main(args.ckpt_path, args.config, args.input_file, args.interactive, args.max_new_tokens, args.temperature))
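One detail worth spelling out for the arg_list entries used here and in convert.py: an assignment expression such as type:=str simply evaluates to str (binding a local name type as a side effect), so a tuple like ("--ckpt-path", type:=str, required:=True) carries plain positional values rather than type=/required= keyword arguments for add_argument. A keyword-preserving variant is sketched below; it is an illustrative alternative, not the commit's code:

from argparse import ArgumentParser

# Each entry keeps the flag plus an explicit dict of keyword arguments,
# so add_argument still receives type=..., required=..., default=... by name.
arg_list = [
    ("--ckpt-path", {"type": str, "required": True}),
    ("--max-new-tokens", {"type": int, "default": 200}),
]

parser = ArgumentParser()
for flag, kwargs in arg_list:
    parser.add_argument(flag, **kwargs)

args = parser.parse_args(["--ckpt-path", "ckpt"])  # example invocation
print(args.ckpt_path, args.max_new_tokens)         # -> ckpt 200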

inference/parser.py (new file, 18 lines)

@@ -0,0 +1,18 @@
+from argparse import ArgumentParser
+
+class Parser():
+    def __init__(self, parser = ArgumentParser(), arg_list = []):
+        self.parser = parser
+        self.arg_list = arg_list
+    def apply_args(self):
+        for arg in self.arg_list: self.parser.add_argument(*arg)
+        return self
+    def assert_model_parallel(self):
+        assert self.return_args.n_experts % self.return_args().model_parallel == 0
+        return self
+    def assert_interactive():
+        assert self.return_args().input_file or self.return_args().interactive
+        return self
+    def return_args(self):
+        return self.parser.parse_args()
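Two Python details matter when reading this class: default arguments such as parser=ArgumentParser() and arg_list=[] are created once at definition time and shared by every Parser() instance, and the assertion helpers reference self and return_args(), so they need a self parameter and call parentheses. Below is a minimal sketch of the same fluent-parser idea with those points handled, reusing the flag-plus-kwargs entry format from the earlier sketch; an illustrative variant, not the code added by this commit:

from argparse import ArgumentParser

class Parser:
    def __init__(self, arg_list=None, parser=None):
        # None sentinels: build a fresh ArgumentParser/list per instance instead
        # of sharing one default object created at function-definition time.
        self.parser = parser if parser is not None else ArgumentParser()
        self.arg_list = arg_list if arg_list is not None else []
        self._args = None

    def apply_args(self):
        for flag, kwargs in self.arg_list:   # e.g. ("--n-experts", {"type": int, "required": True})
            self.parser.add_argument(flag, **kwargs)
        return self

    def assert_model_parallel(self):
        args = self.return_args()
        assert args.n_experts % args.model_parallel == 0
        return self

    def assert_interactive(self):
        args = self.return_args()
        assert args.input_file or args.interactive
        return self

    def return_args(self):
        # Parse once and cache, so chained assertions reuse the same Namespace.
        if self._args is None:
            self._args = self.parser.parse_args()
        return self._args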