[200~feat: Added logging, parallel processing, and CPU processing option for FP8 to BF16 conversion

This commit is contained in:
Anand 2025-01-29 22:32:59 +05:30
parent b5d872ead0
commit e965eec9c0
2 changed files with 121 additions and 55 deletions

View File

@ -3,32 +3,45 @@ import json
from argparse import ArgumentParser from argparse import ArgumentParser
from glob import glob from glob import glob
from tqdm import tqdm from tqdm import tqdm
import logging
from concurrent.futures import ThreadPoolExecutor
import torch import torch
from safetensors.torch import load_file, save_file from safetensors.torch import load_file, save_file
from kernel import weight_dequant from kernel import weight_dequant
def main(fp8_path, bf16_path): def setup_logging():
logging.basicConfig(
filename="conversion.log",
level=logging.INFO,
format="%(asctime)s - %(levelname)s - %(message)s"
)
def main(fp8_path, bf16_path, use_cpu):
""" """
Converts FP8 weights to BF16 and saves the converted weights. Converts FP8 weights to BF16 and saves the converted weights.
This function reads FP8 weights from the specified directory, converts them to BF16, This function reads FP8 weights from the specified directory, converts them to BF16,
and saves the converted weights to another specified directory. It also updates the and saves the converted weights to another specified directory. It also updates the
model index file to reflect the changes. model index file to reflect the changes.
Args: Args:
fp8_path (str): The path to the directory containing the FP8 weights and model index file. fp8_path (str): The path to the directory containing the FP8 weights and model index file.
bf16_path (str): The path to the directory where the converted BF16 weights will be saved. bf16_path (str): The path to the directory where the converted BF16 weights will be saved.
use_cpu (bool): Whether to use CPU instead of GPU.
Raises: Raises:
KeyError: If a required scale_inv tensor is missing for a weight. KeyError: If a required scale_inv tensor is missing for a weight.
Notes: Notes:
- The function assumes that the FP8 weights are stored in safetensor files. - The function assumes that the FP8 weights are stored in safetensor files.
- The function caches loaded safetensor files to optimize memory usage. - The function caches loaded safetensor files to optimize memory usage.
- The function updates the model index file to remove references to scale_inv tensors. - The function updates the model index file to remove references to scale_inv tensors.
""" """
setup_logging()
device = "cpu" if use_cpu else "cuda"
torch.set_default_dtype(torch.bfloat16) torch.set_default_dtype(torch.bfloat16)
os.makedirs(bf16_path, exist_ok=True) os.makedirs(bf16_path, exist_ok=True)
model_index_file = os.path.join(fp8_path, "model.safetensors.index.json") model_index_file = os.path.join(fp8_path, "model.safetensors.index.json")
@ -39,7 +52,7 @@ def main(fp8_path, bf16_path):
# Cache for loaded safetensor files # Cache for loaded safetensor files
loaded_files = {} loaded_files = {}
fp8_weight_names = [] fp8_weight_names = []
# Helper function to get tensor from the correct file # Helper function to get tensor from the correct file
def get_tensor(tensor_name): def get_tensor(tensor_name):
""" """
@ -57,18 +70,19 @@ def main(fp8_path, bf16_path):
file_name = weight_map[tensor_name] file_name = weight_map[tensor_name]
if file_name not in loaded_files: if file_name not in loaded_files:
file_path = os.path.join(fp8_path, file_name) file_path = os.path.join(fp8_path, file_name)
loaded_files[file_name] = load_file(file_path, device="cuda") loaded_files[file_name] = load_file(file_path, device=device)
return loaded_files[file_name][tensor_name] return loaded_files[file_name][tensor_name]
safetensor_files = list(glob(os.path.join(fp8_path, "*.safetensors"))) safetensor_files = list(glob(os.path.join(fp8_path, "*.safetensors")))
safetensor_files.sort() safetensor_files.sort()
for safetensor_file in tqdm(safetensor_files):
def process_file(safetensor_file):
file_name = os.path.basename(safetensor_file) file_name = os.path.basename(safetensor_file)
current_state_dict = load_file(safetensor_file, device="cuda") current_state_dict = load_file(safetensor_file, device=device)
loaded_files[file_name] = current_state_dict loaded_files[file_name] = current_state_dict
new_state_dict = {} new_state_dict = {}
for weight_name, weight in current_state_dict.items(): for weight_name, weight in tqdm(current_state_dict.items(), desc=f"Processing {file_name}"):
if weight_name.endswith("_scale_inv"): if weight_name.endswith("_scale_inv"):
continue continue
elif weight.element_size() == 1: # FP8 weight elif weight.element_size() == 1: # FP8 weight
@ -79,7 +93,7 @@ def main(fp8_path, bf16_path):
fp8_weight_names.append(weight_name) fp8_weight_names.append(weight_name)
new_state_dict[weight_name] = weight_dequant(weight, scale_inv) new_state_dict[weight_name] = weight_dequant(weight, scale_inv)
except KeyError: except KeyError:
print(f"Warning: Missing scale_inv tensor for {weight_name}, skipping conversion") logging.warning(f"Missing scale_inv tensor for {weight_name}, skipping conversion")
new_state_dict[weight_name] = weight new_state_dict[weight_name] = weight
else: else:
new_state_dict[weight_name] = weight new_state_dict[weight_name] = weight
@ -93,6 +107,9 @@ def main(fp8_path, bf16_path):
del loaded_files[oldest_file] del loaded_files[oldest_file]
torch.cuda.empty_cache() torch.cuda.empty_cache()
with ThreadPoolExecutor() as executor:
list(tqdm(executor.map(process_file, safetensor_files), total=len(safetensor_files), desc="Converting files"))
# Update model index # Update model index
new_model_index_file = os.path.join(bf16_path, "model.safetensors.index.json") new_model_index_file = os.path.join(bf16_path, "model.safetensors.index.json")
for weight_name in fp8_weight_names: for weight_name in fp8_weight_names:
@ -102,11 +119,13 @@ def main(fp8_path, bf16_path):
with open(new_model_index_file, "w") as f: with open(new_model_index_file, "w") as f:
json.dump({"metadata": {}, "weight_map": weight_map}, f, indent=2) json.dump({"metadata": {}, "weight_map": weight_map}, f, indent=2)
logging.info("Conversion completed successfully.")
if __name__ == "__main__": if __name__ == "__main__":
parser = ArgumentParser() parser = ArgumentParser()
parser.add_argument("--input-fp8-hf-path", type=str, required=True) parser.add_argument("--input-fp8-hf-path", type=str, required=True)
parser.add_argument("--output-bf16-hf-path", type=str, required=True) parser.add_argument("--output-bf16-hf-path", type=str, required=True)
parser.add_argument("--use-cpu", action="store_true", help="Use CPU for processing instead of GPU")
args = parser.parse_args() args = parser.parse_args()
main(args.input_fp8_hf_path, args.output_bf16_hf_path)
main(args.input_fp8_hf_path, args.output_bf16_hf_path, args.use_cpu)

View File

@ -1,5 +1,6 @@
import os import os
import json import json
import logging
from argparse import ArgumentParser from argparse import ArgumentParser
from typing import List from typing import List
@ -10,6 +11,9 @@ from safetensors.torch import load_model
from model import Transformer, ModelArgs from model import Transformer, ModelArgs
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def sample(logits, temperature: float = 1.0): def sample(logits, temperature: float = 1.0):
""" """
@ -49,32 +53,37 @@ def generate(
List[List[int]]: A list of lists containing the generated tokens for each sequence. List[List[int]]: A list of lists containing the generated tokens for each sequence.
""" """
prompt_lens = [len(t) for t in prompt_tokens] prompt_lens = [len(t) for t in prompt_tokens]
assert max(prompt_lens) <= model.max_seq_len assert max(prompt_lens) <= model.max_seq_len, "Prompt length exceeds model max sequence length"
total_len = min(model.max_seq_len, max_new_tokens + max(prompt_lens)) total_len = min(model.max_seq_len, max_new_tokens + max(prompt_lens))
tokens = torch.full((len(prompt_tokens), total_len), -1, dtype=torch.long, device="cuda") tokens = torch.full((len(prompt_tokens), total_len), -1, dtype=torch.long, device="cuda")
for i, t in enumerate(prompt_tokens): for i, t in enumerate(prompt_tokens):
tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long, device="cuda") tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
prev_pos = 0 prev_pos = 0
finished = torch.tensor([False] * len(prompt_tokens), device="cuda") finished = torch.tensor([False] * len(prompt_tokens), device="cuda")
prompt_mask = tokens != -1 prompt_mask = tokens != -1
for cur_pos in range(min(prompt_lens), total_len): for cur_pos in range(min(prompt_lens), total_len):
logits = model.forward(tokens[:, prev_pos:cur_pos], prev_pos) logits = model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
if temperature > 0: next_token = sample(logits, temperature) if temperature > 0 else logits.argmax(dim=-1)
next_token = sample(logits, temperature)
else:
next_token = logits.argmax(dim=-1)
next_token = torch.where(prompt_mask[:, cur_pos], tokens[:, cur_pos], next_token) next_token = torch.where(prompt_mask[:, cur_pos], tokens[:, cur_pos], next_token)
tokens[:, cur_pos] = next_token tokens[:, cur_pos] = next_token
finished |= torch.logical_and(~prompt_mask[:, cur_pos], next_token == eos_id) finished |= torch.logical_and(~prompt_mask[:, cur_pos], next_token == eos_id)
prev_pos = cur_pos prev_pos = cur_pos
if finished.all(): if finished.all():
break break
completion_tokens = [] completion_tokens = []
for i, toks in enumerate(tokens.tolist()): for i, toks in enumerate(tokens.tolist()):
toks = toks[prompt_lens[i]:prompt_lens[i]+max_new_tokens] toks = toks[prompt_lens[i]:prompt_lens[i] + max_new_tokens]
if eos_id in toks: if eos_id in toks:
toks = toks[:toks.index(eos_id)] toks = toks[:toks.index(eos_id)]
completion_tokens.append(toks) completion_tokens.append(toks)
return completion_tokens return completion_tokens
@ -100,60 +109,96 @@ def main(
world_size = int(os.getenv("WORLD_SIZE", "1")) world_size = int(os.getenv("WORLD_SIZE", "1"))
rank = int(os.getenv("RANK", "0")) rank = int(os.getenv("RANK", "0"))
local_rank = int(os.getenv("LOCAL_RANK", "0")) local_rank = int(os.getenv("LOCAL_RANK", "0"))
if world_size > 1: if world_size > 1:
dist.init_process_group("nccl") dist.init_process_group("nccl")
global print
if rank != 0: if rank != 0:
print = lambda *_, **__: None logger.setLevel(logging.WARNING)
torch.cuda.set_device(local_rank) torch.cuda.set_device(local_rank)
torch.set_default_dtype(torch.bfloat16) torch.set_default_dtype(torch.bfloat16)
torch.set_num_threads(8) torch.set_num_threads(8)
torch.manual_seed(965) torch.manual_seed(965)
with open(config) as f:
args = ModelArgs(**json.load(f)) # Load model args
print(args) try:
with open(config) as f:
args = ModelArgs(**json.load(f))
except FileNotFoundError as e:
logger.error(f"Config file not found: {e}")
return
except json.JSONDecodeError as e:
logger.error(f"Error parsing config file: {e}")
return
logger.info(f"Model args: {args}")
# Load the model on GPU
with torch.device("cuda"): with torch.device("cuda"):
model = Transformer(args) model = Transformer(args)
tokenizer = AutoTokenizer.from_pretrained(ckpt_path) tokenizer = AutoTokenizer.from_pretrained(ckpt_path)
tokenizer.decode(generate(model, [tokenizer.encode("DeepSeek")], 2, -1, 1.)[0])
load_model(model, os.path.join(ckpt_path, f"model{rank}-mp{world_size}.safetensors")) # Generate a test sequence to verify everything is working
test_prompt = "DeepSeek"
test_tokens = tokenizer.encode(test_prompt)
generated_tokens = generate(model, [test_tokens], 2, tokenizer.eos_token_id, 1.0)
logger.info(f"Generated test output: {tokenizer.decode(generated_tokens[0])}")
# Load model weights
try:
load_model(model, os.path.join(ckpt_path, f"model{rank}-mp{world_size}.safetensors"))
except Exception as e:
logger.error(f"Error loading model: {e}")
return
# Interactive mode or batch processing
if interactive: if interactive:
messages = [] messages = []
while True: while True:
if world_size == 1: if world_size == 1 or rank == 0:
prompt = input(">>> ") prompt = input(">>> ")
elif rank == 0: if prompt == "/exit":
prompt = input(">>> ") break
objects = [prompt] elif prompt == "/clear":
dist.broadcast_object_list(objects, 0) messages.clear()
else: continue
messages.append({"role": "user", "content": prompt})
prompt_tokens = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
completion_tokens = generate(model, [prompt_tokens], max_new_tokens, tokenizer.eos_token_id, temperature)
completion = tokenizer.decode(completion_tokens[0], skip_special_tokens=True)
logger.info(f"Generated completion: {completion}")
messages.append({"role": "assistant", "content": completion})
elif rank != 0:
# Synchronize input across multiple nodes
objects = [None] objects = [None]
dist.broadcast_object_list(objects, 0) dist.broadcast_object_list(objects, 0)
prompt = objects[0] prompt = objects[0]
if prompt == "/exit":
break
elif prompt == "/clear":
messages.clear()
continue
messages.append({"role": "user", "content": prompt})
prompt_tokens = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
completion_tokens = generate(model, [prompt_tokens], max_new_tokens, tokenizer.eos_token_id, temperature)
completion = tokenizer.decode(completion_tokens[0], skip_special_tokens=True)
print(completion)
messages.append({"role": "assistant", "content": completion})
else: else:
with open(input_file) as f: # Batch processing mode
prompts = [line.strip() for line in f.readlines()] if not input_file:
assert len(prompts) <= args.max_batch_size logger.error("Input file is required for batch processing mode")
return
try:
with open(input_file) as f:
prompts = [line.strip() for line in f.readlines()]
except FileNotFoundError as e:
logger.error(f"Input file not found: {e}")
return
assert len(prompts) <= args.max_batch_size, "Exceeds batch size limit"
prompt_tokens = [tokenizer.apply_chat_template([{"role": "user", "content": prompt}], add_generation_prompt=True) for prompt in prompts] prompt_tokens = [tokenizer.apply_chat_template([{"role": "user", "content": prompt}], add_generation_prompt=True) for prompt in prompts]
completion_tokens = generate(model, prompt_tokens, max_new_tokens, tokenizer.eos_token_id, temperature) completion_tokens = generate(model, prompt_tokens, max_new_tokens, tokenizer.eos_token_id, temperature)
completions = tokenizer.batch_decode(completion_tokens, skip_special_tokens=True) completions = tokenizer.batch_decode(completion_tokens, skip_special_tokens=True)
for prompt, completion in zip(prompts, completions): for prompt, completion in zip(prompts, completions):
print("Prompt:", prompt) print("Prompt:", prompt)
print("Completion:", completion) print("Completion:", completion)
print() print()
if world_size > 1: if world_size > 1:
dist.destroy_process_group() dist.destroy_process_group()
@ -174,12 +219,14 @@ if __name__ == "__main__":
AssertionError: If neither input-file nor interactive mode is specified. AssertionError: If neither input-file nor interactive mode is specified.
""" """
parser = ArgumentParser() parser = ArgumentParser()
parser.add_argument("--ckpt-path", type=str, required=True) parser.add_argument("--ckpt-path", type=str, required=True, help="Path to the model checkpoint directory.")
parser.add_argument("--config", type=str, required=True) parser.add_argument("--config", type=str, required=True, help="Path to the model configuration file.")
parser.add_argument("--input-file", type=str, default="") parser.add_argument("--input-file", type=str, default="", help="File containing prompts for batch processing.")
parser.add_argument("--interactive", action="store_true") parser.add_argument("--interactive", action="store_true", help="Enable interactive mode.")
parser.add_argument("--max-new-tokens", type=int, default=200) parser.add_argument("--max-new-tokens", type=int, default=200, help="Maximum number of new tokens to generate.")
parser.add_argument("--temperature", type=float, default=0.2) parser.add_argument("--temperature", type=float, default=0.2, help="Temperature for sampling.")
args = parser.parse_args() args = parser.parse_args()
assert args.input_file or args.interactive assert args.input_file or args.interactive, "Either input-file or interactive mode must be specified"
main(args.ckpt_path, args.config, args.input_file, args.interactive, args.max_new_tokens, args.temperature) main(args.ckpt_path, args.config, args.input_file, args.interactive, args.max_new_tokens, args.temperature)