Commit a7d0553e80 by CodingParadigm1, 2025-02-03 18:49:12 +00:00 (committed by GitHub)
4 changed files with 129 additions and 117 deletions

inference/convert.py

@@ -1,9 +1,9 @@
import os
import shutil
from argparse import ArgumentParser
from parser import Parser
from glob import glob
from tqdm import tqdm, trange
import asyncio as sync
import torch
from safetensors.torch import safe_open, save_file
@@ -29,6 +29,32 @@ mapping = {
"scale": ("scale", None),
}
async def set_param(param, name, i, n_local_experts, mp, state_dicts, dim):
new_param = param
if "experts" in name and "shared_experts" not in name:
idx = int(name.split(".")[-3])
if idx < i * n_local_experts or idx >= (i + 1) * n_local_experts:
return
elif dim is not None:
assert param.size(dim) % mp == 0
shard_size = param.size(dim) // mp
new_param = param.narrow(dim, i * shard_size, shard_size).contiguous()
state_dicts[i][name] = new_param
async def inner_safe_open(name, f, state_dicts, mp, n_local_experts):
if "model.layers.61" not in name:
param: torch.Tensor = f.get_tensor(name)
if name.startswith("model."):
name = name[len("model."):]
name = name.replace("self_attn", "attn")
name = name.replace("mlp", "ffn")
name = name.replace("weight_scale_inv", "scale")
name = name.replace("e_score_correction_bias", "bias")
key = name.split(".")[-2]
assert key in mapping
new_key, dim = mapping[key]
name = name.replace(key, new_key)
await sync.gather(*(set_param(param, name, i, n_local_experts, mp, state_dicts, dim) for i in range(mp)))
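For intuition (not part of the commit), a small sketch of the partitioning rule set_param applies: routed experts go to the rank that owns their index range, while other sharded tensors are sliced evenly along dim with narrow:
    import torch

    n_experts, mp = 8, 2
    n_local_experts = n_experts // mp          # 4 experts per model-parallel rank

    # an expert with index idx belongs to rank idx // n_local_experts
    for idx in range(n_experts):
        owner = idx // n_local_experts
        assert owner * n_local_experts <= idx < (owner + 1) * n_local_experts

    # non-expert tensors are sliced evenly along the sharding dimension instead
    param = torch.randn(8, 4)
    shard_size = param.size(0) // mp
    shards = [param.narrow(0, i * shard_size, shard_size).contiguous() for i in range(mp)]
    assert all(s.shape == (4, 4) for s in shards)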
async def main(hf_ckpt_path, save_path, n_experts, mp):
"""
@@ -44,53 +70,30 @@ def main(hf_ckpt_path, save_path, n_experts, mp):
None
"""
torch.set_num_threads(8)
n_local_experts = n_experts // mp
state_dicts = [{} for _ in range(mp)]
for file_path in tqdm(glob(os.path.join(hf_ckpt_path, "*.safetensors"))):
with safe_open(file_path, framework="pt", device="cpu") as f:
for name in f.keys():
if "model.layers.61" in name:
continue
param: torch.Tensor = f.get_tensor(name)
if name.startswith("model."):
name = name[len("model."):]
name = name.replace("self_attn", "attn")
name = name.replace("mlp", "ffn")
name = name.replace("weight_scale_inv", "scale")
name = name.replace("e_score_correction_bias", "bias")
key = name.split(".")[-2]
assert key in mapping
new_key, dim = mapping[key]
name = name.replace(key, new_key)
for i in range(mp):
new_param = param
if "experts" in name and "shared_experts" not in name:
idx = int(name.split(".")[-3])
if idx < i * n_local_experts or idx >= (i + 1) * n_local_experts:
continue
elif dim is not None:
assert param.size(dim) % mp == 0
shard_size = param.size(dim) // mp
new_param = param.narrow(dim, i * shard_size, shard_size).contiguous()
state_dicts[i][name] = new_param
n_local_experts, state_dicts = n_experts // mp, [{} for _ in range(mp)]
tensor_dir, token_dir = list(glob(os.path.join(hf_ckpt_path, "*.safetensors"))), list(glob(os.path.join(hf_ckpt_path, "*token*")))
for file_path in tqdm(tensor_dir):
cm = await sync.to_thread(safe_open, file_path, framework="pt", device="cpu")
with cm as f:  # safe_open returns a regular (synchronous) context manager
await sync.gather(*(inner_safe_open(name, f, state_dicts, mp, n_local_experts) for name in f.keys()))
os.makedirs(save_path, exist_ok=True)
for i in trange(mp):
save_file(state_dicts[i], os.path.join(save_path, f"model{i}-mp{mp}.safetensors"))
await sync.gather(*(sync.to_thread(save_file, state_dicts[i], os.path.join(save_path, f"model{i}-mp{mp}.safetensors")) for i in trange(mp)))
async def set_file_path(file_path):
await sync.to_thread(shutil.copyfile, file_path, os.path.join(save_path, os.path.basename(file_path)))
await sync.gather(*(set_file_path(file_path) for file_path in token_dir))
for file_path in glob(os.path.join(hf_ckpt_path, "*token*")):
new_file_path = os.path.join(save_path, os.path.basename(file_path))
shutil.copyfile(file_path, new_file_path)
if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("--hf-ckpt-path", type=str, required=True)
parser.add_argument("--save-path", type=str, required=True)
parser.add_argument("--n-experts", type=int, required=True)
parser.add_argument("--model-parallel", type=int, required=True)
args = parser.parse_args()
assert args.n_experts % args.model_parallel == 0
main(args.hf_ckpt_path, args.save_path, args.n_experts, args.model_parallel)
arg_list = [
("--hf-ckpt-path", dict(type=str, required=True)),
("--save-path", dict(type=str, required=True)),
("--n-experts", dict(type=int, required=True)),
("--model-parallel", dict(type=int, required=True)),
]
args = Parser(arg_list).apply_args().assert_model_parallel().return_args()
sync.run(main(args.hf_ckpt_path, args.save_path, args.n_experts, args.model_parallel))
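As a side note (not part of the commit), a minimal sketch of the asyncio pattern the converted script relies on: asyncio.to_thread must be given the callable and its arguments, not the result of calling it, otherwise the blocking work runs on the event loop and to_thread only receives the return value. torch.save stands in for safetensors' save_file here.
    import asyncio
    import torch

    async def save_all(shards, paths):
        # to_thread(fn, *args) hands the call to a worker thread; to_thread(fn(*args))
        # would execute fn immediately on the event loop and pass its return value instead
        await asyncio.gather(*(asyncio.to_thread(torch.save, s, p) for s, p in zip(shards, paths)))

    asyncio.run(save_all([torch.zeros(2), torch.ones(2)], ["shard0.pt", "shard1.pt"]))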

inference/fp8_cast_bf16.py

@@ -3,13 +3,41 @@ import json
from argparse import ArgumentParser
from glob import glob
from tqdm import tqdm
from asyncio import gather, to_thread, run
import torch
from safetensors.torch import load_file, save_file
from kernel import weight_dequant
def main(fp8_path, bf16_path):
def inner_tensor_file(safetensor_file, loaded_files, fp8_weight_names, get_tensor, bf16_path):
# state that was previously closed over inside main() is passed in explicitly so this helper can run from a worker thread
file_name = os.path.basename(safetensor_file)
current_state_dict = load_file(safetensor_file, device="cuda")
loaded_files[file_name] = current_state_dict
new_state_dict = {}
for weight_name, weight in current_state_dict.items():
if weight_name.endswith("_scale_inv"):
continue
elif weight.element_size() == 1: # FP8 weight
scale_inv_name = f"{weight_name}_scale_inv"
try:
# Get scale_inv from the correct file
scale_inv = get_tensor(scale_inv_name)
fp8_weight_names.append(weight_name)
new_state_dict[weight_name] = weight_dequant(weight, scale_inv)
except KeyError:
print(f"Warning: Missing scale_inv tensor for {weight_name}, skipping conversion")
new_state_dict[weight_name] = weight
else:
new_state_dict[weight_name] = weight
new_safetensor_file = os.path.join(bf16_path, file_name)
save_file(new_state_dict, new_safetensor_file)
# Memory management: keep only the 2 most recently used files
if len(loaded_files) > 2:
oldest_file = next(iter(loaded_files))
del loaded_files[oldest_file]
torch.cuda.empty_cache()
async def main(fp8_path, bf16_path):
"""
Converts FP8 weights to BF16 and saves the converted weights.
@@ -32,13 +60,11 @@ def main(fp8_path, bf16_path):
torch.set_default_dtype(torch.bfloat16)
os.makedirs(bf16_path, exist_ok=True)
model_index_file = os.path.join(fp8_path, "model.safetensors.index.json")
with open(model_index_file, "r") as f:
model_index = json.load(f)
with open(model_index_file, "r") as f: model_index = json.load(f)
weight_map = model_index["weight_map"]
# Cache for loaded safetensor files
loaded_files = {}
fp8_weight_names = []
loaded_files, fp8_weight_names = {}, []
# Helper function to get tensor from the correct file
def get_tensor(tensor_name):
@@ -62,45 +88,15 @@ def main(fp8_path, bf16_path):
safetensor_files = list(glob(os.path.join(fp8_path, "*.safetensors")))
safetensor_files.sort()
for safetensor_file in tqdm(safetensor_files):
file_name = os.path.basename(safetensor_file)
current_state_dict = load_file(safetensor_file, device="cuda")
loaded_files[file_name] = current_state_dict
new_state_dict = {}
for weight_name, weight in current_state_dict.items():
if weight_name.endswith("_scale_inv"):
continue
elif weight.element_size() == 1: # FP8 weight
scale_inv_name = f"{weight_name}_scale_inv"
try:
# Get scale_inv from the correct file
scale_inv = get_tensor(scale_inv_name)
fp8_weight_names.append(weight_name)
new_state_dict[weight_name] = weight_dequant(weight, scale_inv)
except KeyError:
print(f"Warning: Missing scale_inv tensor for {weight_name}, skipping conversion")
new_state_dict[weight_name] = weight
else:
new_state_dict[weight_name] = weight
new_safetensor_file = os.path.join(bf16_path, file_name)
save_file(new_state_dict, new_safetensor_file)
# Memory management: keep only the 2 most recently used files
if len(loaded_files) > 2:
oldest_file = next(iter(loaded_files))
del loaded_files[oldest_file]
torch.cuda.empty_cache()
await gather(*(to_thread(inner_tensor_file, safetensor_file, loaded_files, fp8_weight_names, get_tensor, bf16_path) for safetensor_file in tqdm(safetensor_files)))
# Update model index
new_model_index_file = os.path.join(bf16_path, "model.safetensors.index.json")
for weight_name in fp8_weight_names:
scale_inv_name = f"{weight_name}_scale_inv"
if scale_inv_name in weight_map:
weight_map.pop(scale_inv_name)
with open(new_model_index_file, "w") as f:
json.dump({"metadata": {}, "weight_map": weight_map}, f, indent=2)
if scale_inv_name in weight_map: weight_map.pop(scale_inv_name)
with open(new_model_index_file, "w") as f: json.dump({"metadata": {}, "weight_map": weight_map}, f, indent=2)
if __name__ == "__main__":
@@ -108,5 +104,5 @@ if __name__ == "__main__":
parser.add_argument("--input-fp8-hf-path", type=str, required=True)
parser.add_argument("--output-bf16-hf-path", type=str, required=True)
args = parser.parse_args()
main(args.input_fp8_hf_path, args.output_bf16_hf_path)
run(main(args.input_fp8_hf_path, args.output_bf16_hf_path))
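For context on weight_dequant (imported from kernel and not shown in this diff), here is a plain-PyTorch sketch of block-wise dequantization under the assumption of one scale_inv entry per 128×128 weight block; the actual kernel may differ in details such as block size and padding.
    import torch

    def dequant_blockwise(weight: torch.Tensor, scale_inv: torch.Tensor, block: int = 128) -> torch.Tensor:
        # upcast the FP8 tensor, then multiply each (block x block) tile by its per-block scale
        w = weight.to(torch.bfloat16)
        out = torch.empty_like(w)
        for i in range(scale_inv.size(0)):
            for j in range(scale_inv.size(1)):
                rows = slice(i * block, min((i + 1) * block, w.size(0)))
                cols = slice(j * block, min((j + 1) * block, w.size(1)))
                out[rows, cols] = w[rows, cols] * scale_inv[i, j].to(torch.bfloat16)
        return out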

inference/generate.py

@@ -1,13 +1,13 @@
import os
import json
from parser import Parser
from argparse import ArgumentParser
from typing import List
import torch
import torch.distributed as dist
from transformers import AutoTokenizer
from safetensors.torch import load_model
from asyncio import gather, to_thread, run
from model import Transformer, ModelArgs
@@ -36,6 +36,7 @@ def generate(
temperature: float = 1.0
) -> List[List[int]]:
"""
Generates new tokens based on the given prompt tokens using the specified model.
Args:
@@ -47,38 +48,35 @@ def generate(
Returns:
List[List[int]]: A list of lists containing the generated tokens for each sequence.
"""
prompt_lens = [len(t) for t in prompt_tokens]
assert max(prompt_lens) <= model.max_seq_len
total_len = min(model.max_seq_len, max_new_tokens + max(prompt_lens))
tokens = torch.full((len(prompt_tokens), total_len), -1, dtype=torch.long, device="cuda")
for i, t in enumerate(prompt_tokens):
tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
for i, t in enumerate(prompt_tokens): tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
prev_pos = 0
finished = torch.tensor([False] * len(prompt_tokens), device="cuda")
prompt_mask = tokens != -1
for cur_pos in range(min(prompt_lens), total_len):
def inner_cur_pos(cur_pos):
nonlocal prev_pos, finished
logits = model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
if temperature > 0:
next_token = sample(logits, temperature)
else:
next_token = logits.argmax(dim=-1)
if temperature > 0: next_token = sample(logits, temperature)
else: next_token = logits.argmax(dim=-1)
next_token = torch.where(prompt_mask[:, cur_pos], tokens[:, cur_pos], next_token)
tokens[:, cur_pos] = next_token
finished |= torch.logical_and(~prompt_mask[:, cur_pos], next_token == eos_id)
prev_pos = cur_pos
if finished.all():
break
if finished.all(): return
# each decode step depends on the token produced at the previous position, so the positions are processed in order
for cur_pos in range(min(prompt_lens), total_len):
inner_cur_pos(cur_pos)
if finished.all(): break
completion_tokens = []
for i, toks in enumerate(tokens.tolist()):
toks = toks[prompt_lens[i]:prompt_lens[i]+max_new_tokens]
if eos_id in toks:
toks = toks[:toks.index(eos_id)]
if eos_id in toks: toks = toks[:toks.index(eos_id)]
completion_tokens.append(toks)
return completion_tokens
def main(
async def main(
ckpt_path: str,
config: str,
input_file: str = "",
@@ -131,8 +129,7 @@ def main(
objects = [None]
dist.broadcast_object_list(objects, 0)
prompt = objects[0]
if prompt == "/exit":
break
if prompt == "/exit": break
elif prompt == "/clear":
messages.clear()
continue
@@ -143,8 +140,7 @@ def main(
print(completion)
messages.append({"role": "assistant", "content": completion})
else:
with open(input_file) as f:
prompts = [line.strip() for line in f.readlines()]
with open(input_file) as f: prompts = [line.strip() for line in f.readlines()]
assert len(prompts) <= args.max_batch_size
prompt_tokens = [tokenizer.apply_chat_template([{"role": "user", "content": prompt}], add_generation_prompt=True) for prompt in prompts]
completion_tokens = generate(model, prompt_tokens, max_new_tokens, tokenizer.eos_token_id, temperature)
@@ -154,8 +150,7 @@ def main(
print("Completion:", completion)
print()
if world_size > 1:
dist.destroy_process_group()
if world_size > 1: dist.destroy_process_group()
if __name__ == "__main__":
@@ -173,13 +168,13 @@ if __name__ == "__main__":
Raises:
AssertionError: If neither input-file nor interactive mode is specified.
"""
parser = ArgumentParser()
parser.add_argument("--ckpt-path", type=str, required=True)
parser.add_argument("--config", type=str, required=True)
parser.add_argument("--input-file", type=str, default="")
parser.add_argument("--interactive", action="store_true")
parser.add_argument("--max-new-tokens", type=int, default=200)
parser.add_argument("--temperature", type=float, default=0.2)
args = parser.parse_args()
assert args.input_file or args.interactive
main(args.ckpt_path, args.config, args.input_file, args.interactive, args.max_new_tokens, args.temperature)
arg_list = [
("--ckpt-path", dict(type=str, required=True)),
("--config", dict(type=str, required=True)),
("--input-file", dict(type=str, default="")),
("--interactive", dict(action="store_true")),
("--max-new-tokens", dict(type=int, default=200)),
("--temperature", dict(type=float, default=0.2)),
]
args = Parser(arg_list).apply_args().assert_interactive().return_args()
run(main(args.ckpt_path, args.config, args.input_file, args.interactive, args.max_new_tokens, args.temperature))
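The sample(logits, temperature) call used in the decode step is defined in the model code and not shown in this diff; as an illustration only (an assumption, not necessarily the repository's implementation), temperature sampling is commonly written as:
    import torch

    def sample_illustrative(logits: torch.Tensor, temperature: float) -> torch.Tensor:
        # scale logits by the temperature, turn them into probabilities, draw one token id per row
        probs = torch.softmax(logits / max(temperature, 1e-5), dim=-1)
        return torch.multinomial(probs, num_samples=1).squeeze(-1)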

inference/parser.py (new file)

@@ -0,0 +1,18 @@
from argparse import ArgumentParser
class Parser():
def __init__(self, arg_list=None, parser=None):
# arg_list comes first so Parser(arg_list) works; avoid mutable/shared defaults
self.parser = parser if parser is not None else ArgumentParser()
self.arg_list = arg_list if arg_list is not None else []
def apply_args(self):
# each entry is a (flag, kwargs) pair, e.g. ("--n-experts", dict(type=int, required=True))
for flag, kwargs in self.arg_list: self.parser.add_argument(flag, **kwargs)
return self
def assert_model_parallel(self):
assert self.return_args().n_experts % self.return_args().model_parallel == 0
return self
def assert_interactive(self):
assert self.return_args().input_file or self.return_args().interactive
return self
def return_args(self):
return self.parser.parse_args()
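A short usage sketch of the Parser wrapper as corrected above, assuming the (flag, kwargs) entry format used in convert.py and generate.py:
    from parser import Parser

    arg_list = [
        ("--n-experts", dict(type=int, required=True)),
        ("--model-parallel", dict(type=int, required=True)),
    ]
    # apply_args registers every flag, assert_model_parallel checks divisibility, return_args runs parse_args
    args = Parser(arg_list).apply_args().assert_model_parallel().return_args()
    print(args.n_experts, args.model_parallel)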