commit c2ae9bae78
Author: Anand, 2025-01-29 22:53:15 +05:30 (committed by GitHub)
2 changed files with 121 additions and 55 deletions

inference/fp8_cast_bf16.py

@@ -3,13 +3,22 @@ import json
from argparse import ArgumentParser
from glob import glob
from tqdm import tqdm
+import logging
+from concurrent.futures import ThreadPoolExecutor
import torch
from safetensors.torch import load_file, save_file
from kernel import weight_dequant
-def main(fp8_path, bf16_path):
+def setup_logging():
+    logging.basicConfig(
+        filename="conversion.log",
+        level=logging.INFO,
+        format="%(asctime)s - %(levelname)s - %(message)s"
+    )
+def main(fp8_path, bf16_path, use_cpu):
"""
Converts FP8 weights to BF16 and saves the converted weights.
@@ -20,6 +29,7 @@ def main(fp8_path, bf16_path):
Args:
fp8_path (str): The path to the directory containing the FP8 weights and model index file.
bf16_path (str): The path to the directory where the converted BF16 weights will be saved.
+        use_cpu (bool): Whether to use CPU instead of GPU.
Raises:
KeyError: If a required scale_inv tensor is missing for a weight.
@@ -29,6 +39,9 @@ def main(fp8_path, bf16_path):
- The function caches loaded safetensor files to optimize memory usage.
- The function updates the model index file to remove references to scale_inv tensors.
"""
+    setup_logging()
+    device = "cpu" if use_cpu else "cuda"
torch.set_default_dtype(torch.bfloat16)
os.makedirs(bf16_path, exist_ok=True)
model_index_file = os.path.join(fp8_path, "model.safetensors.index.json")
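For orientation, model.safetensors.index.json follows the standard Hugging Face sharded-checkpoint layout: a weight_map dict from tensor names to the shard file that stores each tensor. A minimal sketch of the lookup the script builds on (the example tensor name is hypothetical):

with open(model_index_file) as f:
    weight_map = json.load(f)["weight_map"]
# weight_map maps tensor names to shard files, e.g.
# {"layers.0.attn.wq.weight": "model-00001.safetensors", ...}  (names hypothetical)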
@@ -57,18 +70,19 @@ def main(fp8_path, bf16_path):
file_name = weight_map[tensor_name]
if file_name not in loaded_files:
file_path = os.path.join(fp8_path, file_name)
-            loaded_files[file_name] = load_file(file_path, device="cuda")
+            loaded_files[file_name] = load_file(file_path, device=device)
return loaded_files[file_name][tensor_name]
safetensor_files = list(glob(os.path.join(fp8_path, "*.safetensors")))
safetensor_files.sort()
-    for safetensor_file in tqdm(safetensor_files):
+    def process_file(safetensor_file):
file_name = os.path.basename(safetensor_file)
-        current_state_dict = load_file(safetensor_file, device="cuda")
+        current_state_dict = load_file(safetensor_file, device=device)
loaded_files[file_name] = current_state_dict
new_state_dict = {}
-        for weight_name, weight in current_state_dict.items():
+        for weight_name, weight in tqdm(current_state_dict.items(), desc=f"Processing {file_name}"):
if weight_name.endswith("_scale_inv"):
continue
elif weight.element_size() == 1: # FP8 weight
@@ -79,7 +93,7 @@ def main(fp8_path, bf16_path):
fp8_weight_names.append(weight_name)
new_state_dict[weight_name] = weight_dequant(weight, scale_inv)
except KeyError:
-                    print(f"Warning: Missing scale_inv tensor for {weight_name}, skipping conversion")
+                    logging.warning(f"Missing scale_inv tensor for {weight_name}, skipping conversion")
new_state_dict[weight_name] = weight
else:
new_state_dict[weight_name] = weight
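For context, weight_dequant (imported from kernel at the top of this file) reconstructs higher-precision values from the 1-byte FP8 weights detected above (element_size() == 1) and their per-block inverse scales. A pure-PyTorch sketch of the computation, assuming 128x128 scaling blocks; the block size and the kernel's exact semantics are assumptions here, not stated in this diff:

def weight_dequant_ref(weight: torch.Tensor, scale_inv: torch.Tensor, block_size: int = 128) -> torch.Tensor:
    # weight: 2-D FP8 tensor; scale_inv: one scale per (block_size x block_size) tile.
    M, N = weight.shape
    w = weight.to(torch.get_default_dtype())  # upcast FP8 to BF16 (default dtype set above)
    s = scale_inv.repeat_interleave(block_size, dim=0)[:M]
    s = s.repeat_interleave(block_size, dim=1)[:, :N]  # expand tile scales to element scales
    return (w * s).to(w.dtype)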
@@ -93,6 +107,9 @@ def main(fp8_path, bf16_path):
del loaded_files[oldest_file]
torch.cuda.empty_cache()
+    with ThreadPoolExecutor() as executor:
+        list(tqdm(executor.map(process_file, safetensor_files), total=len(safetensor_files), desc="Converting files"))
+    # Update model index
new_model_index_file = os.path.join(bf16_path, "model.safetensors.index.json")
for weight_name in fp8_weight_names:
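The loop starting on the last line above rewrites weight_map so the dequantized checkpoint no longer references the dropped scale tensors; its body is elided by the diff view. A plausible sketch of what it does, not the verbatim source:

for weight_name in fp8_weight_names:
    scale_inv_name = f"{weight_name}_scale_inv"
    if scale_inv_name in weight_map:
        weight_map.pop(scale_inv_name)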
@@ -102,11 +119,13 @@ def main(fp8_path, bf16_path):
with open(new_model_index_file, "w") as f:
json.dump({"metadata": {}, "weight_map": weight_map}, f, indent=2)
+    logging.info("Conversion completed successfully.")
if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("--input-fp8-hf-path", type=str, required=True)
parser.add_argument("--output-bf16-hf-path", type=str, required=True)
+    parser.add_argument("--use-cpu", action="store_true", help="Use CPU for processing instead of GPU")
args = parser.parse_args()
-    main(args.input_fp8_hf_path, args.output_bf16_hf_path)
+    main(args.input_fp8_hf_path, args.output_bf16_hf_path, args.use_cpu)
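Assuming this script is inference/fp8_cast_bf16.py (the file name is an assumption, not shown in the extracted page), a typical invocation of the new CPU path would be:

python fp8_cast_bf16.py --input-fp8-hf-path /path/to/fp8_model --output-bf16-hf-path /path/to/bf16_model --use-cpu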

inference/generate.py

@@ -1,5 +1,6 @@
import os
import json
+import logging
from argparse import ArgumentParser
from typing import List
@@ -10,6 +11,9 @@ from safetensors.torch import load_model
from model import Transformer, ModelArgs
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
def sample(logits, temperature: float = 1.0):
"""
@@ -49,32 +53,37 @@ def generate(
List[List[int]]: A list of lists containing the generated tokens for each sequence.
"""
prompt_lens = [len(t) for t in prompt_tokens]
-    assert max(prompt_lens) <= model.max_seq_len
+    assert max(prompt_lens) <= model.max_seq_len, "Prompt length exceeds model max sequence length"
total_len = min(model.max_seq_len, max_new_tokens + max(prompt_lens))
tokens = torch.full((len(prompt_tokens), total_len), -1, dtype=torch.long, device="cuda")
for i, t in enumerate(prompt_tokens):
tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
prev_pos = 0
+    finished = torch.tensor([False] * len(prompt_tokens), device="cuda")
prompt_mask = tokens != -1
for cur_pos in range(min(prompt_lens), total_len):
logits = model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
-        if temperature > 0:
-            next_token = sample(logits, temperature)
-        else:
-            next_token = logits.argmax(dim=-1)
+        next_token = sample(logits, temperature) if temperature > 0 else logits.argmax(dim=-1)
next_token = torch.where(prompt_mask[:, cur_pos], tokens[:, cur_pos], next_token)
tokens[:, cur_pos] = next_token
+        finished |= torch.logical_and(~prompt_mask[:, cur_pos], next_token == eos_id)
prev_pos = cur_pos
+        if finished.all():
+            break
completion_tokens = []
for i, toks in enumerate(tokens.tolist()):
toks = toks[prompt_lens[i]:prompt_lens[i] + max_new_tokens]
if eos_id in toks:
toks = toks[:toks.index(eos_id)]
completion_tokens.append(toks)
return completion_tokens
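Taken together, the finished bookkeeping lets generate stop as soon as every sequence in the batch has produced eos_id outside its prompt region, instead of always running to total_len. A hypothetical smoke test, assuming a constructed model and Hugging Face tokenizer already exist:

prompts = [tokenizer.encode(p) for p in ["Hello", "DeepSeek is"]]
outputs = generate(model, prompts, 32, tokenizer.eos_token_id, 0.7)  # 32 = max_new_tokens
for toks in outputs:
    print(tokenizer.decode(toks, skip_special_tokens=True))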
@@ -100,37 +109,56 @@ def main(
world_size = int(os.getenv("WORLD_SIZE", "1"))
rank = int(os.getenv("RANK", "0"))
local_rank = int(os.getenv("LOCAL_RANK", "0"))
if world_size > 1:
dist.init_process_group("nccl")
global print
if rank != 0:
print = lambda *_, **__: None
+        logger.setLevel(logging.WARNING)
torch.cuda.set_device(local_rank)
torch.set_default_dtype(torch.bfloat16)
torch.set_num_threads(8)
torch.manual_seed(965)
+    # Load model args
+    try:
with open(config) as f:
args = ModelArgs(**json.load(f))
-    print(args)
+    except FileNotFoundError as e:
+        logger.error(f"Config file not found: {e}")
+        return
+    except json.JSONDecodeError as e:
+        logger.error(f"Error parsing config file: {e}")
+        return
+    logger.info(f"Model args: {args}")
+    # Load the model on GPU
with torch.device("cuda"):
model = Transformer(args)
-    tokenizer = AutoTokenizer.from_pretrained(ckpt_path)
-    tokenizer.decode(generate(model, [tokenizer.encode("DeepSeek")], 2, -1, 1.)[0])
-    load_model(model, os.path.join(ckpt_path, f"model{rank}-mp{world_size}.safetensors"))
+    tokenizer = AutoTokenizer.from_pretrained(ckpt_path)
+    # Generate a test sequence to verify everything is working
+    test_prompt = "DeepSeek"
+    test_tokens = tokenizer.encode(test_prompt)
+    generated_tokens = generate(model, [test_tokens], 2, tokenizer.eos_token_id, 1.0)
+    logger.info(f"Generated test output: {tokenizer.decode(generated_tokens[0])}")
+    # Load model weights
+    try:
+        load_model(model, os.path.join(ckpt_path, f"model{rank}-mp{world_size}.safetensors"))
+    except Exception as e:
+        logger.error(f"Error loading model: {e}")
+        return
+    # Interactive mode or batch processing
if interactive:
messages = []
while True:
-            if world_size == 1:
+            if world_size == 1 or rank == 0:
prompt = input(">>> ")
elif rank == 0:
prompt = input(">>> ")
objects = [prompt]
dist.broadcast_object_list(objects, 0)
else:
objects = [None]
dist.broadcast_object_list(objects, 0)
prompt = objects[0]
if prompt == "/exit":
break
elif prompt == "/clear":
@@ -140,15 +168,32 @@ def main(
prompt_tokens = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
completion_tokens = generate(model, [prompt_tokens], max_new_tokens, tokenizer.eos_token_id, temperature)
completion = tokenizer.decode(completion_tokens[0], skip_special_tokens=True)
-            print(completion)
+            logger.info(f"Generated completion: {completion}")
messages.append({"role": "assistant", "content": completion})
+            elif rank != 0:
+                # Synchronize input across multiple nodes
+                objects = [None]
+                dist.broadcast_object_list(objects, 0)
+                prompt = objects[0]
else:
+        # Batch processing mode
+        if not input_file:
+            logger.error("Input file is required for batch processing mode")
+            return
+        try:
with open(input_file) as f:
prompts = [line.strip() for line in f.readlines()]
-        assert len(prompts) <= args.max_batch_size
+        except FileNotFoundError as e:
+            logger.error(f"Input file not found: {e}")
+            return
+        assert len(prompts) <= args.max_batch_size, "Exceeds batch size limit"
prompt_tokens = [tokenizer.apply_chat_template([{"role": "user", "content": prompt}], add_generation_prompt=True) for prompt in prompts]
completion_tokens = generate(model, prompt_tokens, max_new_tokens, tokenizer.eos_token_id, temperature)
completions = tokenizer.batch_decode(completion_tokens, skip_special_tokens=True)
for prompt, completion in zip(prompts, completions):
print("Prompt:", prompt)
print("Completion:", completion)
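Both branches rely on every rank calling generate with identical prompts; in interactive mode that is what the dist.broadcast_object_list calls above guarantee. The pattern isolated from the loop (broadcast_prompt is a name introduced here for illustration, not in the source):

import torch.distributed as dist

def broadcast_prompt(prompt, rank):
    # Rank 0 supplies the string; every other rank passes a placeholder
    # and receives rank 0's value in place.
    objects = [prompt if rank == 0 else None]
    dist.broadcast_object_list(objects, src=0)
    return objects[0]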
@@ -174,12 +219,14 @@ if __name__ == "__main__":
AssertionError: If neither input-file nor interactive mode is specified.
"""
parser = ArgumentParser()
-    parser.add_argument("--ckpt-path", type=str, required=True)
-    parser.add_argument("--config", type=str, required=True)
-    parser.add_argument("--input-file", type=str, default="")
-    parser.add_argument("--interactive", action="store_true")
-    parser.add_argument("--max-new-tokens", type=int, default=200)
-    parser.add_argument("--temperature", type=float, default=0.2)
+    parser.add_argument("--ckpt-path", type=str, required=True, help="Path to the model checkpoint directory.")
+    parser.add_argument("--config", type=str, required=True, help="Path to the model configuration file.")
+    parser.add_argument("--input-file", type=str, default="", help="File containing prompts for batch processing.")
+    parser.add_argument("--interactive", action="store_true", help="Enable interactive mode.")
+    parser.add_argument("--max-new-tokens", type=int, default=200, help="Maximum number of new tokens to generate.")
+    parser.add_argument("--temperature", type=float, default=0.2, help="Temperature for sampling.")
args = parser.parse_args()
-    assert args.input_file or args.interactive
+    assert args.input_file or args.interactive, "Either input-file or interactive mode must be specified"
main(args.ckpt_path, args.config, args.input_file, args.interactive, args.max_new_tokens, args.temperature)
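Since main reads WORLD_SIZE, RANK and LOCAL_RANK from the environment and initializes a NCCL process group when WORLD_SIZE > 1, a launcher such as torchrun is the natural entry point. A hypothetical single-node launch (file name, GPU count, and config path are assumptions):

torchrun --nproc-per-node 8 generate.py --ckpt-path /path/to/ckpt --config /path/to/config.json --interactive --temperature 0.7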