Mirror of https://github.com/deepseek-ai/DeepSeek-V3.git (synced 2025-04-19 18:18:57 -04:00)

Merge e965eec9c0 into b5d872ead0
Commit c2ae9bae78
inference/fp8_cast_bf16.py
@@ -3,32 +3,45 @@ import json
 from argparse import ArgumentParser
 from glob import glob
 from tqdm import tqdm
+import logging
+from concurrent.futures import ThreadPoolExecutor
 
 import torch
 from safetensors.torch import load_file, save_file
 
 from kernel import weight_dequant
 
-def main(fp8_path, bf16_path):
+def setup_logging():
+    logging.basicConfig(
+        filename="conversion.log",
+        level=logging.INFO,
+        format="%(asctime)s - %(levelname)s - %(message)s"
+    )
+
+def main(fp8_path, bf16_path, use_cpu):
     """
     Converts FP8 weights to BF16 and saves the converted weights.
 
     This function reads FP8 weights from the specified directory, converts them to BF16,
     and saves the converted weights to another specified directory. It also updates the
     model index file to reflect the changes.
 
     Args:
         fp8_path (str): The path to the directory containing the FP8 weights and model index file.
         bf16_path (str): The path to the directory where the converted BF16 weights will be saved.
+        use_cpu (bool): Whether to use CPU instead of GPU.
 
     Raises:
         KeyError: If a required scale_inv tensor is missing for a weight.
 
     Notes:
         - The function assumes that the FP8 weights are stored in safetensor files.
         - The function caches loaded safetensor files to optimize memory usage.
         - The function updates the model index file to remove references to scale_inv tensors.
     """
+    setup_logging()
+
+    device = "cpu" if use_cpu else "cuda"
     torch.set_default_dtype(torch.bfloat16)
     os.makedirs(bf16_path, exist_ok=True)
     model_index_file = os.path.join(fp8_path, "model.safetensors.index.json")
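A minimal sketch of the device selection the patched main() performs; note that the CUDA availability check below is only an illustration, since the diff itself switches devices solely on the explicit --use-cpu flag:

# Illustrative sketch of the converter's device choice; auto-detecting a missing GPU
# is an assumption for this example, not behavior added by the diff.
import torch

use_cpu = not torch.cuda.is_available()
device = "cpu" if use_cpu else "cuda"
torch.set_default_dtype(torch.bfloat16)   # matches the converter's default dtype
print(f"weights would be loaded on: {device}")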
@@ -39,7 +52,7 @@ def main(fp8_path, bf16_path):
     # Cache for loaded safetensor files
     loaded_files = {}
     fp8_weight_names = []
 
     # Helper function to get tensor from the correct file
     def get_tensor(tensor_name):
         """
@@ -57,18 +70,19 @@ def main(fp8_path, bf16_path):
         file_name = weight_map[tensor_name]
         if file_name not in loaded_files:
             file_path = os.path.join(fp8_path, file_name)
-            loaded_files[file_name] = load_file(file_path, device="cuda")
+            loaded_files[file_name] = load_file(file_path, device=device)
         return loaded_files[file_name][tensor_name]
 
     safetensor_files = list(glob(os.path.join(fp8_path, "*.safetensors")))
     safetensor_files.sort()
-    for safetensor_file in tqdm(safetensor_files):
+
+    def process_file(safetensor_file):
         file_name = os.path.basename(safetensor_file)
-        current_state_dict = load_file(safetensor_file, device="cuda")
+        current_state_dict = load_file(safetensor_file, device=device)
         loaded_files[file_name] = current_state_dict
 
         new_state_dict = {}
-        for weight_name, weight in current_state_dict.items():
+        for weight_name, weight in tqdm(current_state_dict.items(), desc=f"Processing {file_name}"):
             if weight_name.endswith("_scale_inv"):
                 continue
             elif weight.element_size() == 1: # FP8 weight
@@ -79,7 +93,7 @@ def main(fp8_path, bf16_path):
                     fp8_weight_names.append(weight_name)
                     new_state_dict[weight_name] = weight_dequant(weight, scale_inv)
                 except KeyError:
-                    print(f"Warning: Missing scale_inv tensor for {weight_name}, skipping conversion")
+                    logging.warning(f"Missing scale_inv tensor for {weight_name}, skipping conversion")
                     new_state_dict[weight_name] = weight
             else:
                 new_state_dict[weight_name] = weight
@@ -93,6 +107,9 @@ def main(fp8_path, bf16_path):
             del loaded_files[oldest_file]
             torch.cuda.empty_cache()
 
+    with ThreadPoolExecutor() as executor:
+        list(tqdm(executor.map(process_file, safetensor_files), total=len(safetensor_files), desc="Converting files"))
+
     # Update model index
     new_model_index_file = os.path.join(bf16_path, "model.safetensors.index.json")
     for weight_name in fp8_weight_names:
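The new process_file worker mutates the shared loaded_files cache from several ThreadPoolExecutor threads; a minimal, self-contained sketch of one way to serialize access to such a shared cache is shown below (the lock and helper names are illustrative only and are not part of this diff):

# Illustrative only: serializing access to a cache shared by ThreadPoolExecutor workers.
# `cache_lock` and `evict_oldest` are hypothetical names, not taken from the diff.
from concurrent.futures import ThreadPoolExecutor
from threading import Lock

cache = {}
cache_lock = Lock()

def evict_oldest(limit=2):
    # drop the oldest entries once the cache grows past `limit`
    while len(cache) > limit:
        cache.pop(next(iter(cache)))

def worker(key):
    with cache_lock:          # only one thread mutates the cache at a time
        cache[key] = f"data-{key}"
        evict_oldest()
    return key

with ThreadPoolExecutor() as executor:
    print(list(executor.map(worker, range(5))))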
@@ -102,11 +119,13 @@ def main(fp8_path, bf16_path):
     with open(new_model_index_file, "w") as f:
         json.dump({"metadata": {}, "weight_map": weight_map}, f, indent=2)
 
+    logging.info("Conversion completed successfully.")
 
 if __name__ == "__main__":
     parser = ArgumentParser()
     parser.add_argument("--input-fp8-hf-path", type=str, required=True)
     parser.add_argument("--output-bf16-hf-path", type=str, required=True)
+    parser.add_argument("--use-cpu", action="store_true", help="Use CPU for processing instead of GPU")
     args = parser.parse_args()
-    main(args.input_fp8_hf_path, args.output_bf16_hf_path)
+    main(args.input_fp8_hf_path, args.output_bf16_hf_path, args.use_cpu)
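For reference, an illustrative way to drive the patched converter with the new flag; the paths are placeholders and the script location (inference/fp8_cast_bf16.py) is assumed from the repository layout rather than stated in the diff:

# Hedged usage sketch: invoking the converter as the diff's CLI exposes it.
import subprocess

subprocess.run(
    [
        "python", "inference/fp8_cast_bf16.py",
        "--input-fp8-hf-path", "/models/DeepSeek-V3",          # placeholder input path
        "--output-bf16-hf-path", "/models/DeepSeek-V3-bf16",   # placeholder output path
        "--use-cpu",                                           # flag introduced by this diff
    ],
    check=True,
)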
inference/generate.py
@@ -1,5 +1,6 @@
 import os
 import json
+import logging
 from argparse import ArgumentParser
 from typing import List
 
@@ -10,6 +11,9 @@ from safetensors.torch import load_model
 
 from model import Transformer, ModelArgs
 
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
 
 def sample(logits, temperature: float = 1.0):
     """
@@ -49,32 +53,37 @@ def generate(
         List[List[int]]: A list of lists containing the generated tokens for each sequence.
     """
     prompt_lens = [len(t) for t in prompt_tokens]
-    assert max(prompt_lens) <= model.max_seq_len
+    assert max(prompt_lens) <= model.max_seq_len, "Prompt length exceeds model max sequence length"
+
     total_len = min(model.max_seq_len, max_new_tokens + max(prompt_lens))
     tokens = torch.full((len(prompt_tokens), total_len), -1, dtype=torch.long, device="cuda")
+
     for i, t in enumerate(prompt_tokens):
         tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
+
     prev_pos = 0
     finished = torch.tensor([False] * len(prompt_tokens), device="cuda")
     prompt_mask = tokens != -1
+
     for cur_pos in range(min(prompt_lens), total_len):
         logits = model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
-        if temperature > 0:
-            next_token = sample(logits, temperature)
-        else:
-            next_token = logits.argmax(dim=-1)
+        next_token = sample(logits, temperature) if temperature > 0 else logits.argmax(dim=-1)
+
         next_token = torch.where(prompt_mask[:, cur_pos], tokens[:, cur_pos], next_token)
         tokens[:, cur_pos] = next_token
         finished |= torch.logical_and(~prompt_mask[:, cur_pos], next_token == eos_id)
+
         prev_pos = cur_pos
         if finished.all():
             break
+
     completion_tokens = []
     for i, toks in enumerate(tokens.tolist()):
-        toks = toks[prompt_lens[i]:prompt_lens[i]+max_new_tokens]
+        toks = toks[prompt_lens[i]:prompt_lens[i] + max_new_tokens]
         if eos_id in toks:
             toks = toks[:toks.index(eos_id)]
+
         completion_tokens.append(toks)
     return completion_tokens
 
 
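A toy illustration of the prompt-preserving update in generate(): at each position, torch.where keeps the original token wherever prompt_mask is set and takes the freshly sampled token elsewhere. The tensor values below are invented purely for the example:

# Toy example of the prompt-preserving update used in generate(); token ids are made up.
import torch

tokens = torch.tensor([[5, 7, -1, -1],
                       [9, -1, -1, -1]])   # -1 marks positions still to be generated
prompt_mask = tokens != -1
next_token = torch.tensor([11, 12])        # "sampled" tokens for column 1
cur_pos = 1
tokens[:, cur_pos] = torch.where(prompt_mask[:, cur_pos], tokens[:, cur_pos], next_token)
print(tokens)  # row 0 keeps its prompt token 7; row 1 receives the sampled 12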
@@ -100,60 +109,96 @@ def main(
     world_size = int(os.getenv("WORLD_SIZE", "1"))
     rank = int(os.getenv("RANK", "0"))
     local_rank = int(os.getenv("LOCAL_RANK", "0"))
     if world_size > 1:
         dist.init_process_group("nccl")
     global print
     if rank != 0:
         print = lambda *_, **__: None
+        logger.setLevel(logging.WARNING)
     torch.cuda.set_device(local_rank)
     torch.set_default_dtype(torch.bfloat16)
     torch.set_num_threads(8)
     torch.manual_seed(965)
-    with open(config) as f:
-        args = ModelArgs(**json.load(f))
-    print(args)
+    # Load model args
+    try:
+        with open(config) as f:
+            args = ModelArgs(**json.load(f))
+    except FileNotFoundError as e:
+        logger.error(f"Config file not found: {e}")
+        return
+    except json.JSONDecodeError as e:
+        logger.error(f"Error parsing config file: {e}")
+        return
+    logger.info(f"Model args: {args}")
+    # Load the model on GPU
     with torch.device("cuda"):
         model = Transformer(args)
     tokenizer = AutoTokenizer.from_pretrained(ckpt_path)
-    tokenizer.decode(generate(model, [tokenizer.encode("DeepSeek")], 2, -1, 1.)[0])
-    load_model(model, os.path.join(ckpt_path, f"model{rank}-mp{world_size}.safetensors"))
+    # Generate a test sequence to verify everything is working
+    test_prompt = "DeepSeek"
+    test_tokens = tokenizer.encode(test_prompt)
+    generated_tokens = generate(model, [test_tokens], 2, tokenizer.eos_token_id, 1.0)
+    logger.info(f"Generated test output: {tokenizer.decode(generated_tokens[0])}")
+    # Load model weights
+    try:
+        load_model(model, os.path.join(ckpt_path, f"model{rank}-mp{world_size}.safetensors"))
+    except Exception as e:
+        logger.error(f"Error loading model: {e}")
+        return
 
+    # Interactive mode or batch processing
     if interactive:
         messages = []
         while True:
-            if world_size == 1:
+            if world_size == 1 or rank == 0:
                 prompt = input(">>> ")
-            elif rank == 0:
-                prompt = input(">>> ")
-                objects = [prompt]
-                dist.broadcast_object_list(objects, 0)
-            else:
+                if prompt == "/exit":
+                    break
+                elif prompt == "/clear":
+                    messages.clear()
+                    continue
+                messages.append({"role": "user", "content": prompt})
+                prompt_tokens = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
+                completion_tokens = generate(model, [prompt_tokens], max_new_tokens, tokenizer.eos_token_id, temperature)
+                completion = tokenizer.decode(completion_tokens[0], skip_special_tokens=True)
+                logger.info(f"Generated completion: {completion}")
+                messages.append({"role": "assistant", "content": completion})
+            elif rank != 0:
+                # Synchronize input across multiple nodes
                 objects = [None]
                 dist.broadcast_object_list(objects, 0)
                 prompt = objects[0]
             if prompt == "/exit":
                 break
             elif prompt == "/clear":
                 messages.clear()
                 continue
             messages.append({"role": "user", "content": prompt})
             prompt_tokens = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
             completion_tokens = generate(model, [prompt_tokens], max_new_tokens, tokenizer.eos_token_id, temperature)
             completion = tokenizer.decode(completion_tokens[0], skip_special_tokens=True)
             print(completion)
             messages.append({"role": "assistant", "content": completion})
     else:
-        with open(input_file) as f:
-            prompts = [line.strip() for line in f.readlines()]
-        assert len(prompts) <= args.max_batch_size
+        # Batch processing mode
+        if not input_file:
+            logger.error("Input file is required for batch processing mode")
+            return
+        try:
+            with open(input_file) as f:
+                prompts = [line.strip() for line in f.readlines()]
+        except FileNotFoundError as e:
+            logger.error(f"Input file not found: {e}")
+            return
+        assert len(prompts) <= args.max_batch_size, "Exceeds batch size limit"
         prompt_tokens = [tokenizer.apply_chat_template([{"role": "user", "content": prompt}], add_generation_prompt=True) for prompt in prompts]
        completion_tokens = generate(model, prompt_tokens, max_new_tokens, tokenizer.eos_token_id, temperature)
         completions = tokenizer.batch_decode(completion_tokens, skip_special_tokens=True)
         for prompt, completion in zip(prompts, completions):
             print("Prompt:", prompt)
             print("Completion:", completion)
             print()
 
     if world_size > 1:
         dist.destroy_process_group()
 
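The interactive branch relies on dist.broadcast_object_list to ship the rank-0 prompt to the other ranks; a stripped-down, runnable sketch of that pattern follows (the gloo backend and the torchrun launch line are illustrative choices for a GPU-free demo, whereas the script itself uses nccl):

# Minimal sketch of the rank-0 prompt broadcast used in interactive mode.
# Launch with e.g. `torchrun --nproc-per-node 2 this_file.py` (launcher line is illustrative).
import torch.distributed as dist

def main():
    dist.init_process_group("gloo")          # gloo so the sketch also runs without GPUs
    rank = dist.get_rank()
    objects = ["hello from rank 0"] if rank == 0 else [None]
    dist.broadcast_object_list(objects, src=0)   # every rank now holds the rank-0 prompt
    print(f"rank {rank}: {objects[0]}")
    dist.destroy_process_group()

if __name__ == "__main__":
    main()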
@@ -174,12 +219,14 @@ if __name__ == "__main__":
         AssertionError: If neither input-file nor interactive mode is specified.
     """
     parser = ArgumentParser()
-    parser.add_argument("--ckpt-path", type=str, required=True)
-    parser.add_argument("--config", type=str, required=True)
-    parser.add_argument("--input-file", type=str, default="")
-    parser.add_argument("--interactive", action="store_true")
-    parser.add_argument("--max-new-tokens", type=int, default=200)
-    parser.add_argument("--temperature", type=float, default=0.2)
+    parser.add_argument("--ckpt-path", type=str, required=True, help="Path to the model checkpoint directory.")
+    parser.add_argument("--config", type=str, required=True, help="Path to the model configuration file.")
+    parser.add_argument("--input-file", type=str, default="", help="File containing prompts for batch processing.")
+    parser.add_argument("--interactive", action="store_true", help="Enable interactive mode.")
+    parser.add_argument("--max-new-tokens", type=int, default=200, help="Maximum number of new tokens to generate.")
+    parser.add_argument("--temperature", type=float, default=0.2, help="Temperature for sampling.")
+
     args = parser.parse_args()
-    assert args.input_file or args.interactive
+    assert args.input_file or args.interactive, "Either input-file or interactive mode must be specified"
+
     main(args.ckpt_path, args.config, args.input_file, args.interactive, args.max_new_tokens, args.temperature)