This commit is contained in:
Anand 2025-01-29 22:53:15 +05:30 committed by GitHub
commit c2ae9bae78
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 121 additions and 55 deletions

View File

@ -3,32 +3,45 @@ import json
from argparse import ArgumentParser
from glob import glob
from tqdm import tqdm
import logging
from concurrent.futures import ThreadPoolExecutor
import torch
from safetensors.torch import load_file, save_file
from kernel import weight_dequant
def main(fp8_path, bf16_path):
def setup_logging():
logging.basicConfig(
filename="conversion.log",
level=logging.INFO,
format="%(asctime)s - %(levelname)s - %(message)s"
)
def main(fp8_path, bf16_path, use_cpu):
"""
Converts FP8 weights to BF16 and saves the converted weights.
This function reads FP8 weights from the specified directory, converts them to BF16,
and saves the converted weights to another specified directory. It also updates the
model index file to reflect the changes.
Args:
fp8_path (str): The path to the directory containing the FP8 weights and model index file.
bf16_path (str): The path to the directory where the converted BF16 weights will be saved.
use_cpu (bool): Whether to use CPU instead of GPU.
Raises:
KeyError: If a required scale_inv tensor is missing for a weight.
Notes:
- The function assumes that the FP8 weights are stored in safetensor files.
- The function caches loaded safetensor files to optimize memory usage.
- The function updates the model index file to remove references to scale_inv tensors.
"""
setup_logging()
device = "cpu" if use_cpu else "cuda"
torch.set_default_dtype(torch.bfloat16)
os.makedirs(bf16_path, exist_ok=True)
model_index_file = os.path.join(fp8_path, "model.safetensors.index.json")
@ -39,7 +52,7 @@ def main(fp8_path, bf16_path):
# Cache for loaded safetensor files
loaded_files = {}
fp8_weight_names = []
# Helper function to get tensor from the correct file
def get_tensor(tensor_name):
"""
@ -57,18 +70,19 @@ def main(fp8_path, bf16_path):
file_name = weight_map[tensor_name]
if file_name not in loaded_files:
file_path = os.path.join(fp8_path, file_name)
loaded_files[file_name] = load_file(file_path, device="cuda")
loaded_files[file_name] = load_file(file_path, device=device)
return loaded_files[file_name][tensor_name]
safetensor_files = list(glob(os.path.join(fp8_path, "*.safetensors")))
safetensor_files.sort()
for safetensor_file in tqdm(safetensor_files):
def process_file(safetensor_file):
file_name = os.path.basename(safetensor_file)
current_state_dict = load_file(safetensor_file, device="cuda")
current_state_dict = load_file(safetensor_file, device=device)
loaded_files[file_name] = current_state_dict
new_state_dict = {}
for weight_name, weight in current_state_dict.items():
for weight_name, weight in tqdm(current_state_dict.items(), desc=f"Processing {file_name}"):
if weight_name.endswith("_scale_inv"):
continue
elif weight.element_size() == 1: # FP8 weight
@ -79,7 +93,7 @@ def main(fp8_path, bf16_path):
fp8_weight_names.append(weight_name)
new_state_dict[weight_name] = weight_dequant(weight, scale_inv)
except KeyError:
print(f"Warning: Missing scale_inv tensor for {weight_name}, skipping conversion")
logging.warning(f"Missing scale_inv tensor for {weight_name}, skipping conversion")
new_state_dict[weight_name] = weight
else:
new_state_dict[weight_name] = weight
@ -93,6 +107,9 @@ def main(fp8_path, bf16_path):
del loaded_files[oldest_file]
torch.cuda.empty_cache()
with ThreadPoolExecutor() as executor:
list(tqdm(executor.map(process_file, safetensor_files), total=len(safetensor_files), desc="Converting files"))
# Update model index
new_model_index_file = os.path.join(bf16_path, "model.safetensors.index.json")
for weight_name in fp8_weight_names:
@ -102,11 +119,13 @@ def main(fp8_path, bf16_path):
with open(new_model_index_file, "w") as f:
json.dump({"metadata": {}, "weight_map": weight_map}, f, indent=2)
logging.info("Conversion completed successfully.")
if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("--input-fp8-hf-path", type=str, required=True)
parser.add_argument("--output-bf16-hf-path", type=str, required=True)
parser.add_argument("--use-cpu", action="store_true", help="Use CPU for processing instead of GPU")
args = parser.parse_args()
main(args.input_fp8_hf_path, args.output_bf16_hf_path)
main(args.input_fp8_hf_path, args.output_bf16_hf_path, args.use_cpu)

View File

@ -1,5 +1,6 @@
import os
import json
import logging
from argparse import ArgumentParser
from typing import List
@ -10,6 +11,9 @@ from safetensors.torch import load_model
from model import Transformer, ModelArgs
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def sample(logits, temperature: float = 1.0):
"""
@ -49,32 +53,37 @@ def generate(
List[List[int]]: A list of lists containing the generated tokens for each sequence.
"""
prompt_lens = [len(t) for t in prompt_tokens]
assert max(prompt_lens) <= model.max_seq_len
assert max(prompt_lens) <= model.max_seq_len, "Prompt length exceeds model max sequence length"
total_len = min(model.max_seq_len, max_new_tokens + max(prompt_lens))
tokens = torch.full((len(prompt_tokens), total_len), -1, dtype=torch.long, device="cuda")
for i, t in enumerate(prompt_tokens):
tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
prev_pos = 0
finished = torch.tensor([False] * len(prompt_tokens), device="cuda")
prompt_mask = tokens != -1
for cur_pos in range(min(prompt_lens), total_len):
logits = model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
if temperature > 0:
next_token = sample(logits, temperature)
else:
next_token = logits.argmax(dim=-1)
next_token = sample(logits, temperature) if temperature > 0 else logits.argmax(dim=-1)
next_token = torch.where(prompt_mask[:, cur_pos], tokens[:, cur_pos], next_token)
tokens[:, cur_pos] = next_token
finished |= torch.logical_and(~prompt_mask[:, cur_pos], next_token == eos_id)
prev_pos = cur_pos
if finished.all():
break
completion_tokens = []
for i, toks in enumerate(tokens.tolist()):
toks = toks[prompt_lens[i]:prompt_lens[i]+max_new_tokens]
toks = toks[prompt_lens[i]:prompt_lens[i] + max_new_tokens]
if eos_id in toks:
toks = toks[:toks.index(eos_id)]
completion_tokens.append(toks)
return completion_tokens
@ -100,60 +109,96 @@ def main(
world_size = int(os.getenv("WORLD_SIZE", "1"))
rank = int(os.getenv("RANK", "0"))
local_rank = int(os.getenv("LOCAL_RANK", "0"))
if world_size > 1:
dist.init_process_group("nccl")
global print
if rank != 0:
print = lambda *_, **__: None
logger.setLevel(logging.WARNING)
torch.cuda.set_device(local_rank)
torch.set_default_dtype(torch.bfloat16)
torch.set_num_threads(8)
torch.manual_seed(965)
with open(config) as f:
args = ModelArgs(**json.load(f))
print(args)
# Load model args
try:
with open(config) as f:
args = ModelArgs(**json.load(f))
except FileNotFoundError as e:
logger.error(f"Config file not found: {e}")
return
except json.JSONDecodeError as e:
logger.error(f"Error parsing config file: {e}")
return
logger.info(f"Model args: {args}")
# Load the model on GPU
with torch.device("cuda"):
model = Transformer(args)
tokenizer = AutoTokenizer.from_pretrained(ckpt_path)
tokenizer.decode(generate(model, [tokenizer.encode("DeepSeek")], 2, -1, 1.)[0])
load_model(model, os.path.join(ckpt_path, f"model{rank}-mp{world_size}.safetensors"))
# Generate a test sequence to verify everything is working
test_prompt = "DeepSeek"
test_tokens = tokenizer.encode(test_prompt)
generated_tokens = generate(model, [test_tokens], 2, tokenizer.eos_token_id, 1.0)
logger.info(f"Generated test output: {tokenizer.decode(generated_tokens[0])}")
# Load model weights
try:
load_model(model, os.path.join(ckpt_path, f"model{rank}-mp{world_size}.safetensors"))
except Exception as e:
logger.error(f"Error loading model: {e}")
return
# Interactive mode or batch processing
if interactive:
messages = []
while True:
if world_size == 1:
if world_size == 1 or rank == 0:
prompt = input(">>> ")
elif rank == 0:
prompt = input(">>> ")
objects = [prompt]
dist.broadcast_object_list(objects, 0)
else:
if prompt == "/exit":
break
elif prompt == "/clear":
messages.clear()
continue
messages.append({"role": "user", "content": prompt})
prompt_tokens = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
completion_tokens = generate(model, [prompt_tokens], max_new_tokens, tokenizer.eos_token_id, temperature)
completion = tokenizer.decode(completion_tokens[0], skip_special_tokens=True)
logger.info(f"Generated completion: {completion}")
messages.append({"role": "assistant", "content": completion})
elif rank != 0:
# Synchronize input across multiple nodes
objects = [None]
dist.broadcast_object_list(objects, 0)
prompt = objects[0]
if prompt == "/exit":
break
elif prompt == "/clear":
messages.clear()
continue
messages.append({"role": "user", "content": prompt})
prompt_tokens = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
completion_tokens = generate(model, [prompt_tokens], max_new_tokens, tokenizer.eos_token_id, temperature)
completion = tokenizer.decode(completion_tokens[0], skip_special_tokens=True)
print(completion)
messages.append({"role": "assistant", "content": completion})
else:
with open(input_file) as f:
prompts = [line.strip() for line in f.readlines()]
assert len(prompts) <= args.max_batch_size
# Batch processing mode
if not input_file:
logger.error("Input file is required for batch processing mode")
return
try:
with open(input_file) as f:
prompts = [line.strip() for line in f.readlines()]
except FileNotFoundError as e:
logger.error(f"Input file not found: {e}")
return
assert len(prompts) <= args.max_batch_size, "Exceeds batch size limit"
prompt_tokens = [tokenizer.apply_chat_template([{"role": "user", "content": prompt}], add_generation_prompt=True) for prompt in prompts]
completion_tokens = generate(model, prompt_tokens, max_new_tokens, tokenizer.eos_token_id, temperature)
completions = tokenizer.batch_decode(completion_tokens, skip_special_tokens=True)
for prompt, completion in zip(prompts, completions):
print("Prompt:", prompt)
print("Completion:", completion)
print()
if world_size > 1:
dist.destroy_process_group()
@ -174,12 +219,14 @@ if __name__ == "__main__":
AssertionError: If neither input-file nor interactive mode is specified.
"""
parser = ArgumentParser()
parser.add_argument("--ckpt-path", type=str, required=True)
parser.add_argument("--config", type=str, required=True)
parser.add_argument("--input-file", type=str, default="")
parser.add_argument("--interactive", action="store_true")
parser.add_argument("--max-new-tokens", type=int, default=200)
parser.add_argument("--temperature", type=float, default=0.2)
parser.add_argument("--ckpt-path", type=str, required=True, help="Path to the model checkpoint directory.")
parser.add_argument("--config", type=str, required=True, help="Path to the model configuration file.")
parser.add_argument("--input-file", type=str, default="", help="File containing prompts for batch processing.")
parser.add_argument("--interactive", action="store_true", help="Enable interactive mode.")
parser.add_argument("--max-new-tokens", type=int, default=200, help="Maximum number of new tokens to generate.")
parser.add_argument("--temperature", type=float, default=0.2, help="Temperature for sampling.")
args = parser.parse_args()
assert args.input_file or args.interactive
assert args.input_file or args.interactive, "Either input-file or interactive mode must be specified"
main(args.ckpt_path, args.config, args.input_file, args.interactive, args.max_new_tokens, args.temperature)