Mirror of https://github.com/deepseek-ai/DeepSeek-V3.git (synced 2025-04-20 02:28:57 -04:00)

commit c2ae9bae78: Merge e965eec9c0 into b5d872ead0
@@ -3,13 +3,22 @@ import json
 from argparse import ArgumentParser
 from glob import glob
 from tqdm import tqdm
+import logging
+from concurrent.futures import ThreadPoolExecutor
 
 import torch
 from safetensors.torch import load_file, save_file
 
 from kernel import weight_dequant
 
-def main(fp8_path, bf16_path):
+def setup_logging():
+    logging.basicConfig(
+        filename="conversion.log",
+        level=logging.INFO,
+        format="%(asctime)s - %(levelname)s - %(message)s"
+    )
+
+def main(fp8_path, bf16_path, use_cpu):
     """
     Converts FP8 weights to BF16 and saves the converted weights.
 
@@ -20,6 +29,7 @@ def main(fp8_path, bf16_path):
     Args:
         fp8_path (str): The path to the directory containing the FP8 weights and model index file.
         bf16_path (str): The path to the directory where the converted BF16 weights will be saved.
+        use_cpu (bool): Whether to use CPU instead of GPU.
 
     Raises:
         KeyError: If a required scale_inv tensor is missing for a weight.
@@ -29,6 +39,9 @@ def main(fp8_path, bf16_path):
     - The function caches loaded safetensor files to optimize memory usage.
     - The function updates the model index file to remove references to scale_inv tensors.
     """
+    setup_logging()
+
+    device = "cpu" if use_cpu else "cuda"
     torch.set_default_dtype(torch.bfloat16)
     os.makedirs(bf16_path, exist_ok=True)
     model_index_file = os.path.join(fp8_path, "model.safetensors.index.json")
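For illustration, a minimal standalone sketch of the device selection introduced above, extended with a CUDA-availability guard; the guard and the helper name are assumptions, not part of this patch:

    import torch

    def pick_device(use_cpu: bool) -> str:
        # Fall back to CPU when CUDA is unavailable, mirroring the patch's use_cpu flag.
        if use_cpu or not torch.cuda.is_available():
            return "cpu"
        return "cuda"

    device = pick_device(use_cpu=False)
    torch.set_default_dtype(torch.bfloat16)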
@@ -57,18 +70,19 @@ def main(fp8_path, bf16_path):
         file_name = weight_map[tensor_name]
         if file_name not in loaded_files:
             file_path = os.path.join(fp8_path, file_name)
-            loaded_files[file_name] = load_file(file_path, device="cuda")
+            loaded_files[file_name] = load_file(file_path, device=device)
         return loaded_files[file_name][tensor_name]
 
     safetensor_files = list(glob(os.path.join(fp8_path, "*.safetensors")))
     safetensor_files.sort()
-    for safetensor_file in tqdm(safetensor_files):
+
+    def process_file(safetensor_file):
         file_name = os.path.basename(safetensor_file)
-        current_state_dict = load_file(safetensor_file, device="cuda")
+        current_state_dict = load_file(safetensor_file, device=device)
         loaded_files[file_name] = current_state_dict
 
         new_state_dict = {}
-        for weight_name, weight in current_state_dict.items():
+        for weight_name, weight in tqdm(current_state_dict.items(), desc=f"Processing {file_name}"):
             if weight_name.endswith("_scale_inv"):
                 continue
             elif weight.element_size() == 1:  # FP8 weight
@@ -79,7 +93,7 @@ def main(fp8_path, bf16_path):
                     fp8_weight_names.append(weight_name)
                     new_state_dict[weight_name] = weight_dequant(weight, scale_inv)
                 except KeyError:
-                    print(f"Warning: Missing scale_inv tensor for {weight_name}, skipping conversion")
+                    logging.warning(f"Missing scale_inv tensor for {weight_name}, skipping conversion")
                     new_state_dict[weight_name] = weight
             else:
                 new_state_dict[weight_name] = weight
@@ -93,6 +107,9 @@ def main(fp8_path, bf16_path):
             del loaded_files[oldest_file]
             torch.cuda.empty_cache()
 
+    with ThreadPoolExecutor() as executor:
+        list(tqdm(executor.map(process_file, safetensor_files), total=len(safetensor_files), desc="Converting files"))
+
     # Update model index
     new_model_index_file = os.path.join(bf16_path, "model.safetensors.index.json")
     for weight_name in fp8_weight_names:
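For illustration, a standalone sketch of the ThreadPoolExecutor/tqdm pattern used above, with a trivial stand-in worker; the worker and file names here are hypothetical, not part of the patch:

    from concurrent.futures import ThreadPoolExecutor
    from tqdm import tqdm

    def process_file(path: str) -> str:
        # Stand-in for the real per-file conversion work.
        return path.upper()

    files = ["a.safetensors", "b.safetensors", "c.safetensors"]

    with ThreadPoolExecutor() as executor:
        # executor.map yields lazily; wrapping it in list() forces completion,
        # and total= gives tqdm the correct progress-bar length.
        results = list(tqdm(executor.map(process_file, files), total=len(files), desc="Converting files"))

Because executor.map preserves input order, the results line up with the input file list; the worker threads share the same process memory, so in the patch above they all see the same loaded_files cache.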
@@ -102,11 +119,13 @@ def main(fp8_path, bf16_path):
     with open(new_model_index_file, "w") as f:
         json.dump({"metadata": {}, "weight_map": weight_map}, f, indent=2)
 
+    logging.info("Conversion completed successfully.")
 
 if __name__ == "__main__":
     parser = ArgumentParser()
     parser.add_argument("--input-fp8-hf-path", type=str, required=True)
     parser.add_argument("--output-bf16-hf-path", type=str, required=True)
+    parser.add_argument("--use-cpu", action="store_true", help="Use CPU for processing instead of GPU")
     args = parser.parse_args()
-    main(args.input_fp8_hf_path, args.output_bf16_hf_path)
 
+    main(args.input_fp8_hf_path, args.output_bf16_hf_path, args.use_cpu)
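The hunks above modify the FP8-to-BF16 conversion script; the hunks below modify the generation script. For illustration, a minimal sketch of calling the converter's new entry-point signature directly from Python; the module name and paths are hypothetical:

    # Assumes the converter script is importable as a module named fp8_cast_bf16.
    from fp8_cast_bf16 import main

    # Equivalent to passing --input-fp8-hf-path, --output-bf16-hf-path and --use-cpu on the CLI.
    main("/path/to/fp8_weights", "/path/to/bf16_weights", use_cpu=True)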
@@ -1,5 +1,6 @@
 import os
 import json
+import logging
 from argparse import ArgumentParser
 from typing import List
 
@@ -10,6 +11,9 @@ from safetensors.torch import load_model
 
 from model import Transformer, ModelArgs
 
+# Configure logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
 
 def sample(logits, temperature: float = 1.0):
     """
@@ -49,32 +53,37 @@ def generate(
         List[List[int]]: A list of lists containing the generated tokens for each sequence.
     """
     prompt_lens = [len(t) for t in prompt_tokens]
-    assert max(prompt_lens) <= model.max_seq_len
+    assert max(prompt_lens) <= model.max_seq_len, "Prompt length exceeds model max sequence length"
+
     total_len = min(model.max_seq_len, max_new_tokens + max(prompt_lens))
     tokens = torch.full((len(prompt_tokens), total_len), -1, dtype=torch.long, device="cuda")
+
     for i, t in enumerate(prompt_tokens):
         tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
+
     prev_pos = 0
     finished = torch.tensor([False] * len(prompt_tokens), device="cuda")
     prompt_mask = tokens != -1
+
     for cur_pos in range(min(prompt_lens), total_len):
         logits = model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
-        if temperature > 0:
-            next_token = sample(logits, temperature)
-        else:
-            next_token = logits.argmax(dim=-1)
+        next_token = sample(logits, temperature) if temperature > 0 else logits.argmax(dim=-1)
+
         next_token = torch.where(prompt_mask[:, cur_pos], tokens[:, cur_pos], next_token)
         tokens[:, cur_pos] = next_token
         finished |= torch.logical_and(~prompt_mask[:, cur_pos], next_token == eos_id)
+
         prev_pos = cur_pos
         if finished.all():
             break
+
     completion_tokens = []
     for i, toks in enumerate(tokens.tolist()):
         toks = toks[prompt_lens[i]:prompt_lens[i] + max_new_tokens]
         if eos_id in toks:
             toks = toks[:toks.index(eos_id)]
         completion_tokens.append(toks)
+
     return completion_tokens
 
 
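The loop above calls sample(logits, temperature) when temperature > 0; that helper's body lies outside this diff. For illustration only, a minimal temperature-sampling sketch that is compatible with the call site; this is an assumption, not the repository's actual implementation:

    import torch

    def sample(logits: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
        # Scale logits by the temperature, then draw one token id per batch row.
        probs = torch.softmax(logits / max(temperature, 1e-5), dim=-1)
        return torch.multinomial(probs, num_samples=1).squeeze(-1)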
@@ -100,37 +109,56 @@ def main(
     world_size = int(os.getenv("WORLD_SIZE", "1"))
     rank = int(os.getenv("RANK", "0"))
     local_rank = int(os.getenv("LOCAL_RANK", "0"))
+
     if world_size > 1:
         dist.init_process_group("nccl")
-    global print
     if rank != 0:
-        print = lambda *_, **__: None
+        logger.setLevel(logging.WARNING)
+
     torch.cuda.set_device(local_rank)
     torch.set_default_dtype(torch.bfloat16)
     torch.set_num_threads(8)
     torch.manual_seed(965)
+
+    # Load model args
+    try:
         with open(config) as f:
             args = ModelArgs(**json.load(f))
-    print(args)
+    except FileNotFoundError as e:
+        logger.error(f"Config file not found: {e}")
+        return
+    except json.JSONDecodeError as e:
+        logger.error(f"Error parsing config file: {e}")
+        return
+
+    logger.info(f"Model args: {args}")
+
+    # Load the model on GPU
     with torch.device("cuda"):
         model = Transformer(args)
-    tokenizer = AutoTokenizer.from_pretrained(ckpt_path)
-    tokenizer.decode(generate(model, [tokenizer.encode("DeepSeek")], 2, -1, 1.)[0])
-    load_model(model, os.path.join(ckpt_path, f"model{rank}-mp{world_size}.safetensors"))
+
+    tokenizer = AutoTokenizer.from_pretrained(ckpt_path)
+
+    # Generate a test sequence to verify everything is working
+    test_prompt = "DeepSeek"
+    test_tokens = tokenizer.encode(test_prompt)
+    generated_tokens = generate(model, [test_tokens], 2, tokenizer.eos_token_id, 1.0)
+    logger.info(f"Generated test output: {tokenizer.decode(generated_tokens[0])}")
+
+    # Load model weights
+    try:
+        load_model(model, os.path.join(ckpt_path, f"model{rank}-mp{world_size}.safetensors"))
+    except Exception as e:
+        logger.error(f"Error loading model: {e}")
+        return
+
+    # Interactive mode or batch processing
     if interactive:
         messages = []
         while True:
-            if world_size == 1:
+            if world_size == 1 or rank == 0:
                 prompt = input(">>> ")
-            elif rank == 0:
-                prompt = input(">>> ")
-                objects = [prompt]
-                dist.broadcast_object_list(objects, 0)
-            else:
-                objects = [None]
-                dist.broadcast_object_list(objects, 0)
-                prompt = objects[0]
             if prompt == "/exit":
                 break
             elif prompt == "/clear":
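For reference, a standalone sketch of the rank-0 prompt broadcast that the removed lines performed with torch.distributed.broadcast_object_list; it assumes the process group has already been initialized, as in the surrounding code, and the helper name is hypothetical:

    import torch.distributed as dist

    def read_prompt_on_rank0(rank: int) -> str:
        # Rank 0 reads from stdin; every rank then receives the same string in place.
        objects = [input(">>> ")] if rank == 0 else [None]
        dist.broadcast_object_list(objects, src=0)
        return objects[0]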
@@ -140,15 +168,32 @@ def main(
             prompt_tokens = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
             completion_tokens = generate(model, [prompt_tokens], max_new_tokens, tokenizer.eos_token_id, temperature)
             completion = tokenizer.decode(completion_tokens[0], skip_special_tokens=True)
-            print(completion)
+            logger.info(f"Generated completion: {completion}")
             messages.append({"role": "assistant", "content": completion})
+        elif rank != 0:
+            # Synchronize input across multiple nodes
+            objects = [None]
+            dist.broadcast_object_list(objects, 0)
+            prompt = objects[0]
+
     else:
+        # Batch processing mode
+        if not input_file:
+            logger.error("Input file is required for batch processing mode")
+            return
+        try:
             with open(input_file) as f:
                 prompts = [line.strip() for line in f.readlines()]
-        assert len(prompts) <= args.max_batch_size
+        except FileNotFoundError as e:
+            logger.error(f"Input file not found: {e}")
+            return
+
+        assert len(prompts) <= args.max_batch_size, "Exceeds batch size limit"
+
         prompt_tokens = [tokenizer.apply_chat_template([{"role": "user", "content": prompt}], add_generation_prompt=True) for prompt in prompts]
         completion_tokens = generate(model, prompt_tokens, max_new_tokens, tokenizer.eos_token_id, temperature)
         completions = tokenizer.batch_decode(completion_tokens, skip_special_tokens=True)
+
         for prompt, completion in zip(prompts, completions):
             print("Prompt:", prompt)
             print("Completion:", completion)
@@ -174,12 +219,14 @@ if __name__ == "__main__":
         AssertionError: If neither input-file nor interactive mode is specified.
     """
     parser = ArgumentParser()
-    parser.add_argument("--ckpt-path", type=str, required=True)
-    parser.add_argument("--config", type=str, required=True)
-    parser.add_argument("--input-file", type=str, default="")
-    parser.add_argument("--interactive", action="store_true")
-    parser.add_argument("--max-new-tokens", type=int, default=200)
-    parser.add_argument("--temperature", type=float, default=0.2)
+    parser.add_argument("--ckpt-path", type=str, required=True, help="Path to the model checkpoint directory.")
+    parser.add_argument("--config", type=str, required=True, help="Path to the model configuration file.")
+    parser.add_argument("--input-file", type=str, default="", help="File containing prompts for batch processing.")
+    parser.add_argument("--interactive", action="store_true", help="Enable interactive mode.")
+    parser.add_argument("--max-new-tokens", type=int, default=200, help="Maximum number of new tokens to generate.")
+    parser.add_argument("--temperature", type=float, default=0.2, help="Temperature for sampling.")
+
     args = parser.parse_args()
-    assert args.input_file or args.interactive
+    assert args.input_file or args.interactive, "Either input-file or interactive mode must be specified"
+
     main(args.ckpt_path, args.config, args.input_file, args.interactive, args.max_new_tokens, args.temperature)
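For illustration, a minimal sketch of how the new help strings and flags surface through argparse; the argv list and path are hypothetical:

    from argparse import ArgumentParser

    parser = ArgumentParser()
    parser.add_argument("--ckpt-path", type=str, required=True, help="Path to the model checkpoint directory.")
    parser.add_argument("--interactive", action="store_true", help="Enable interactive mode.")

    # parse_args accepts an explicit argv list, which is handy for testing flag handling.
    args = parser.parse_args(["--ckpt-path", "/path/to/ckpt", "--interactive"])
    print(args.ckpt_path, args.interactive)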
Loading…
Reference in New Issue
Block a user