commit c2ae9bae78
Anand, 2025-01-29 22:53:15 +05:30, committed by GitHub
2 changed files with 121 additions and 55 deletions

Changed file 1 of 2 (FP8 to BF16 conversion script):

@@ -3,13 +3,22 @@ import json
 from argparse import ArgumentParser
 from glob import glob
 from tqdm import tqdm
+import logging
+from concurrent.futures import ThreadPoolExecutor
 import torch
 from safetensors.torch import load_file, save_file
 from kernel import weight_dequant
-def main(fp8_path, bf16_path):
+def setup_logging():
+    logging.basicConfig(
+        filename="conversion.log",
+        level=logging.INFO,
+        format="%(asctime)s - %(levelname)s - %(message)s"
+    )
+
+def main(fp8_path, bf16_path, use_cpu):
     """
     Converts FP8 weights to BF16 and saves the converted weights.
@@ -20,6 +29,7 @@ def main(fp8_path, bf16_path):
     Args:
         fp8_path (str): The path to the directory containing the FP8 weights and model index file.
        bf16_path (str): The path to the directory where the converted BF16 weights will be saved.
+        use_cpu (bool): Whether to use CPU instead of GPU.
     Raises:
         KeyError: If a required scale_inv tensor is missing for a weight.
@@ -29,6 +39,9 @@ def main(fp8_path, bf16_path):
     - The function caches loaded safetensor files to optimize memory usage.
     - The function updates the model index file to remove references to scale_inv tensors.
     """
+    setup_logging()
+    device = "cpu" if use_cpu else "cuda"
     torch.set_default_dtype(torch.bfloat16)
     os.makedirs(bf16_path, exist_ok=True)
     model_index_file = os.path.join(fp8_path, "model.safetensors.index.json")
@@ -57,18 +70,19 @@ def main(fp8_path, bf16_path):
         file_name = weight_map[tensor_name]
         if file_name not in loaded_files:
             file_path = os.path.join(fp8_path, file_name)
-            loaded_files[file_name] = load_file(file_path, device="cuda")
+            loaded_files[file_name] = load_file(file_path, device=device)
         return loaded_files[file_name][tensor_name]
     safetensor_files = list(glob(os.path.join(fp8_path, "*.safetensors")))
     safetensor_files.sort()
-    for safetensor_file in tqdm(safetensor_files):
+    def process_file(safetensor_file):
         file_name = os.path.basename(safetensor_file)
-        current_state_dict = load_file(safetensor_file, device="cuda")
+        current_state_dict = load_file(safetensor_file, device=device)
         loaded_files[file_name] = current_state_dict
         new_state_dict = {}
-        for weight_name, weight in current_state_dict.items():
+        for weight_name, weight in tqdm(current_state_dict.items(), desc=f"Processing {file_name}"):
             if weight_name.endswith("_scale_inv"):
                 continue
             elif weight.element_size() == 1:  # FP8 weight
@@ -79,7 +93,7 @@ def main(fp8_path, bf16_path):
                     fp8_weight_names.append(weight_name)
                     new_state_dict[weight_name] = weight_dequant(weight, scale_inv)
                 except KeyError:
-                    print(f"Warning: Missing scale_inv tensor for {weight_name}, skipping conversion")
+                    logging.warning(f"Missing scale_inv tensor for {weight_name}, skipping conversion")
                     new_state_dict[weight_name] = weight
             else:
                 new_state_dict[weight_name] = weight
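For reference, the weight_dequant call in this hunk (imported from kernel.py) rescales each FP8 weight by its per-block inverse scale before the BF16 cast. A rough pure-PyTorch sketch of that bookkeeping, assuming 2-D weights and one scale per 128x128 block; the block size, shapes, and function name here are assumptions for illustration, not taken from this diff:

import torch

def naive_weight_dequant(weight: torch.Tensor, scale_inv: torch.Tensor, block_size: int = 128) -> torch.Tensor:
    # Expand the per-block inverse scales to elementwise scales, then upcast and rescale.
    m, n = weight.shape
    scales = scale_inv.repeat_interleave(block_size, dim=0)[:m]
    scales = scales.repeat_interleave(block_size, dim=1)[:, :n]
    return (weight.to(torch.float32) * scales).to(torch.bfloat16)

The repository's own kernel does the equivalent work tile by tile on the GPU; the sketch only makes the pairing between a weight and its scale_inv tensor concrete.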
@@ -93,6 +107,9 @@ def main(fp8_path, bf16_path):
             del loaded_files[oldest_file]
             torch.cuda.empty_cache()
+    with ThreadPoolExecutor() as executor:
+        list(tqdm(executor.map(process_file, safetensor_files), total=len(safetensor_files), desc="Converting files"))
     # Update model index
     new_model_index_file = os.path.join(bf16_path, "model.safetensors.index.json")
     for weight_name in fp8_weight_names:
@@ -102,11 +119,13 @@ def main(fp8_path, bf16_path):
     with open(new_model_index_file, "w") as f:
         json.dump({"metadata": {}, "weight_map": weight_map}, f, indent=2)
+    logging.info("Conversion completed successfully.")
 if __name__ == "__main__":
     parser = ArgumentParser()
     parser.add_argument("--input-fp8-hf-path", type=str, required=True)
     parser.add_argument("--output-bf16-hf-path", type=str, required=True)
+    parser.add_argument("--use-cpu", action="store_true", help="Use CPU for processing instead of GPU")
     args = parser.parse_args()
-    main(args.input_fp8_hf_path, args.output_bf16_hf_path)
+    main(args.input_fp8_hf_path, args.output_bf16_hf_path, args.use_cpu)
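The other structural change in this file swaps the sequential loop over safetensor shards for a thread pool driven by executor.map. A minimal, self-contained sketch of that pattern follows; the worker body and the "checkpoints" directory are placeholders, not the script's real conversion logic:

import os
from concurrent.futures import ThreadPoolExecutor
from glob import glob

from tqdm import tqdm

def process_file(path: str) -> str:
    # Placeholder for the real per-shard work (load, dequantize, save).
    return os.path.basename(path)

files = sorted(glob(os.path.join("checkpoints", "*.safetensors")))  # assumed input directory
with ThreadPoolExecutor() as executor:
    # executor.map yields lazily; list() drives every task to completion
    # while tqdm reports how many shards have finished.
    results = list(tqdm(executor.map(process_file, files), total=len(files), desc="Converting files"))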

Changed file 2 of 2 (generation script):

@@ -1,5 +1,6 @@
 import os
 import json
+import logging
 from argparse import ArgumentParser
 from typing import List
@@ -10,6 +11,9 @@ from safetensors.torch import load_model
 from model import Transformer, ModelArgs
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
 def sample(logits, temperature: float = 1.0):
     """
@@ -49,32 +53,37 @@ def generate(
         List[List[int]]: A list of lists containing the generated tokens for each sequence.
     """
     prompt_lens = [len(t) for t in prompt_tokens]
-    assert max(prompt_lens) <= model.max_seq_len
+    assert max(prompt_lens) <= model.max_seq_len, "Prompt length exceeds model max sequence length"
     total_len = min(model.max_seq_len, max_new_tokens + max(prompt_lens))
     tokens = torch.full((len(prompt_tokens), total_len), -1, dtype=torch.long, device="cuda")
     for i, t in enumerate(prompt_tokens):
         tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
     prev_pos = 0
     finished = torch.tensor([False] * len(prompt_tokens), device="cuda")
     prompt_mask = tokens != -1
     for cur_pos in range(min(prompt_lens), total_len):
         logits = model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
-        if temperature > 0:
-            next_token = sample(logits, temperature)
-        else:
-            next_token = logits.argmax(dim=-1)
+        next_token = sample(logits, temperature) if temperature > 0 else logits.argmax(dim=-1)
         next_token = torch.where(prompt_mask[:, cur_pos], tokens[:, cur_pos], next_token)
         tokens[:, cur_pos] = next_token
         finished |= torch.logical_and(~prompt_mask[:, cur_pos], next_token == eos_id)
         prev_pos = cur_pos
         if finished.all():
             break
     completion_tokens = []
     for i, toks in enumerate(tokens.tolist()):
         toks = toks[prompt_lens[i]:prompt_lens[i] + max_new_tokens]
         if eos_id in toks:
             toks = toks[:toks.index(eos_id)]
         completion_tokens.append(toks)
     return completion_tokens
@@ -100,37 +109,56 @@ def main(
     world_size = int(os.getenv("WORLD_SIZE", "1"))
     rank = int(os.getenv("RANK", "0"))
     local_rank = int(os.getenv("LOCAL_RANK", "0"))
     if world_size > 1:
         dist.init_process_group("nccl")
-    global print
     if rank != 0:
-        print = lambda *_, **__: None
+        logger.setLevel(logging.WARNING)
     torch.cuda.set_device(local_rank)
     torch.set_default_dtype(torch.bfloat16)
     torch.set_num_threads(8)
     torch.manual_seed(965)
-    with open(config) as f:
-        args = ModelArgs(**json.load(f))
-    print(args)
+    # Load model args
+    try:
+        with open(config) as f:
+            args = ModelArgs(**json.load(f))
+    except FileNotFoundError as e:
+        logger.error(f"Config file not found: {e}")
+        return
+    except json.JSONDecodeError as e:
+        logger.error(f"Error parsing config file: {e}")
+        return
+    logger.info(f"Model args: {args}")
+    # Load the model on GPU
     with torch.device("cuda"):
         model = Transformer(args)
-    tokenizer = AutoTokenizer.from_pretrained(ckpt_path)
-    tokenizer.decode(generate(model, [tokenizer.encode("DeepSeek")], 2, -1, 1.)[0])
-    load_model(model, os.path.join(ckpt_path, f"model{rank}-mp{world_size}.safetensors"))
+    tokenizer = AutoTokenizer.from_pretrained(ckpt_path)
+    # Generate a test sequence to verify everything is working
+    test_prompt = "DeepSeek"
+    test_tokens = tokenizer.encode(test_prompt)
+    generated_tokens = generate(model, [test_tokens], 2, tokenizer.eos_token_id, 1.0)
+    logger.info(f"Generated test output: {tokenizer.decode(generated_tokens[0])}")
+    # Load model weights
+    try:
+        load_model(model, os.path.join(ckpt_path, f"model{rank}-mp{world_size}.safetensors"))
+    except Exception as e:
+        logger.error(f"Error loading model: {e}")
+        return
+    # Interactive mode or batch processing
     if interactive:
         messages = []
         while True:
-            if world_size == 1:
+            if world_size == 1 or rank == 0:
                 prompt = input(">>> ")
-            elif rank == 0:
-                prompt = input(">>> ")
-                objects = [prompt]
-                dist.broadcast_object_list(objects, 0)
-            else:
-                objects = [None]
-                dist.broadcast_object_list(objects, 0)
-                prompt = objects[0]
             if prompt == "/exit":
                 break
             elif prompt == "/clear":
@@ -140,15 +168,32 @@ def main(
             prompt_tokens = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
             completion_tokens = generate(model, [prompt_tokens], max_new_tokens, tokenizer.eos_token_id, temperature)
             completion = tokenizer.decode(completion_tokens[0], skip_special_tokens=True)
-            print(completion)
+            logger.info(f"Generated completion: {completion}")
             messages.append({"role": "assistant", "content": completion})
+            elif rank != 0:
+                # Synchronize input across multiple nodes
+                objects = [None]
+                dist.broadcast_object_list(objects, 0)
+                prompt = objects[0]
     else:
+        # Batch processing mode
+        if not input_file:
+            logger.error("Input file is required for batch processing mode")
+            return
-        with open(input_file) as f:
-            prompts = [line.strip() for line in f.readlines()]
-        assert len(prompts) <= args.max_batch_size
+        try:
+            with open(input_file) as f:
+                prompts = [line.strip() for line in f.readlines()]
+        except FileNotFoundError as e:
+            logger.error(f"Input file not found: {e}")
+            return
+        assert len(prompts) <= args.max_batch_size, "Exceeds batch size limit"
         prompt_tokens = [tokenizer.apply_chat_template([{"role": "user", "content": prompt}], add_generation_prompt=True) for prompt in prompts]
         completion_tokens = generate(model, prompt_tokens, max_new_tokens, tokenizer.eos_token_id, temperature)
         completions = tokenizer.batch_decode(completion_tokens, skip_special_tokens=True)
         for prompt, completion in zip(prompts, completions):
             print("Prompt:", prompt)
             print("Completion:", completion)
@@ -174,12 +219,14 @@ if __name__ == "__main__":
         AssertionError: If neither input-file nor interactive mode is specified.
     """
     parser = ArgumentParser()
-    parser.add_argument("--ckpt-path", type=str, required=True)
-    parser.add_argument("--config", type=str, required=True)
-    parser.add_argument("--input-file", type=str, default="")
-    parser.add_argument("--interactive", action="store_true")
-    parser.add_argument("--max-new-tokens", type=int, default=200)
-    parser.add_argument("--temperature", type=float, default=0.2)
+    parser.add_argument("--ckpt-path", type=str, required=True, help="Path to the model checkpoint directory.")
+    parser.add_argument("--config", type=str, required=True, help="Path to the model configuration file.")
+    parser.add_argument("--input-file", type=str, default="", help="File containing prompts for batch processing.")
+    parser.add_argument("--interactive", action="store_true", help="Enable interactive mode.")
+    parser.add_argument("--max-new-tokens", type=int, default=200, help="Maximum number of new tokens to generate.")
+    parser.add_argument("--temperature", type=float, default=0.2, help="Temperature for sampling.")
     args = parser.parse_args()
-    assert args.input_file or args.interactive
+    assert args.input_file or args.interactive, "Either input-file or interactive mode must be specified"
     main(args.ckpt_path, args.config, args.input_file, args.interactive, args.max_new_tokens, args.temperature)
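The interactive branch above keeps multi-process runs in sync by having rank 0 read the prompt and broadcasting it with dist.broadcast_object_list. A minimal sketch of that collective in isolation, assuming the process group is launched with torchrun; the helper name and the gloo backend are illustrative choices, not taken from the diff:

import torch.distributed as dist

def read_prompt_on_all_ranks() -> str:
    # Rank 0 reads from stdin; after the collective, every rank holds the same object.
    objects = [input(">>> ") if dist.get_rank() == 0 else None]
    dist.broadcast_object_list(objects, src=0)
    return objects[0]

if __name__ == "__main__":
    # Launch with e.g.: torchrun --nproc_per_node=2 this_script.py
    dist.init_process_group(backend="gloo")
    prompt = read_prompt_on_all_ranks()
    print(f"rank {dist.get_rank()} got: {prompt!r}")
    dist.destroy_process_group()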