Nripesh Niketan 2025-04-09 19:02:29 +08:00 committed by GitHub
commit 487f2395aa
4 changed files with 41 additions and 13 deletions

View File

@@ -3,6 +3,7 @@ import json
 from argparse import ArgumentParser
 from glob import glob
 from tqdm import tqdm
+import gc

 import torch
 from safetensors.torch import load_file, save_file
@@ -30,6 +31,12 @@ def main(fp8_path, bf16_path):
     - The function updates the model index file to remove references to scale_inv tensors.
     """
     torch.set_default_dtype(torch.bfloat16)
+    if torch.cuda.is_available():
+        default_device = "cuda"
+    elif torch.mps.is_available():
+        default_device = "mps"
+    else:
+        default_device = "cpu"
     os.makedirs(bf16_path, exist_ok=True)
     model_index_file = os.path.join(fp8_path, "model.safetensors.index.json")
     with open(model_index_file, "r") as f:
@@ -57,14 +64,14 @@ def main(fp8_path, bf16_path):
         file_name = weight_map[tensor_name]
         if file_name not in loaded_files:
             file_path = os.path.join(fp8_path, file_name)
-            loaded_files[file_name] = load_file(file_path, device="cuda")
+            loaded_files[file_name] = load_file(file_path, device=default_device)
         return loaded_files[file_name][tensor_name]

     safetensor_files = list(glob(os.path.join(fp8_path, "*.safetensors")))
     safetensor_files.sort()
     for safetensor_file in tqdm(safetensor_files):
         file_name = os.path.basename(safetensor_file)
-        current_state_dict = load_file(safetensor_file, device="cuda")
+        current_state_dict = load_file(safetensor_file, device=default_device)
         loaded_files[file_name] = current_state_dict

         new_state_dict = {}
@@ -91,7 +98,12 @@ def main(fp8_path, bf16_path):
         if len(loaded_files) > 2:
             oldest_file = next(iter(loaded_files))
             del loaded_files[oldest_file]
-            torch.cuda.empty_cache()
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+            elif torch.mps.is_available():
+                torch.mps.empty_cache()
+            else:
+                gc.collect()

     # Update model index
     new_model_index_file = os.path.join(bf16_path, "model.safetensors.index.json")
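
Note: the same CUDA → MPS → CPU fallback recurs in each of the three Python files touched by this commit. Below is a minimal sketch (not part of the commit) of how that selection and the matching cache clearing could be factored into one place; the helper names get_default_device and empty_device_cache are illustrative only.

import gc
import torch


def get_default_device() -> str:
    """Pick the best available backend: CUDA, then Apple MPS, then CPU."""
    if torch.cuda.is_available():
        return "cuda"
    # torch.backends.mps.is_available() is the longer-standing spelling of the
    # check this diff performs with torch.mps.is_available().
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"


def empty_device_cache(device: str) -> None:
    """Release cached allocator memory on the selected backend."""
    if device == "cuda":
        torch.cuda.empty_cache()
    elif device == "mps":
        torch.mps.empty_cache()
    else:
        gc.collect()  # CPU has no allocator cache to drop; just collect garbage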

View File

@@ -30,10 +30,11 @@ def sample(logits, temperature: float = 1.0):
 @torch.inference_mode()
 def generate(
     model: Transformer,
+    device: str,
     prompt_tokens: List[List[int]],
     max_new_tokens: int,
     eos_id: int,
-    temperature: float = 1.0
+    temperature: float = 1.0,
 ) -> List[List[int]]:
     """
     Generates new tokens based on the given prompt tokens using the specified model.
@@ -51,11 +52,11 @@ def generate(
     prompt_lens = [len(t) for t in prompt_tokens]
     assert max(prompt_lens) <= model.max_seq_len, f"Prompt length exceeds model maximum sequence length (max_seq_len={model.max_seq_len})"
     total_len = min(model.max_seq_len, max_new_tokens + max(prompt_lens))
-    tokens = torch.full((len(prompt_tokens), total_len), -1, dtype=torch.long, device="cuda")
+    tokens = torch.full((len(prompt_tokens), total_len), -1, dtype=torch.long, device=device)
     for i, t in enumerate(prompt_tokens):
-        tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
+        tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long, device=device)
     prev_pos = 0
-    finished = torch.tensor([False] * len(prompt_tokens), device="cuda")
+    finished = torch.tensor([False] * len(prompt_tokens), device=device)
     prompt_mask = tokens != -1
     for cur_pos in range(min(prompt_lens), total_len):
         logits = model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
@@ -97,11 +98,20 @@ def main(
         max_new_tokens (int, optional): Maximum number of new tokens to generate. Defaults to 100.
         temperature (float, optional): Temperature for sampling. Defaults to 1.0.
     """
+    if torch.cuda.is_available():
+        default_device = "cuda"
+    elif torch.mps.is_available():
+        default_device = "mps"
+    else:
+        default_device = "cpu"
     world_size = int(os.getenv("WORLD_SIZE", "1"))
     rank = int(os.getenv("RANK", "0"))
     local_rank = int(os.getenv("LOCAL_RANK", "0"))
     if world_size > 1:
-        dist.init_process_group("nccl")
+        if torch.cuda.is_available():
+            dist.init_process_group("nccl")
+        else:
+            dist.init_process_group("gloo")
     global print
     if rank != 0:
         print = lambda *_, **__: None
@@ -112,10 +122,10 @@ def main(
     with open(config) as f:
         args = ModelArgs(**json.load(f))
     print(args)
-    with torch.device("cuda"):
+    with torch.device(default_device):
         model = Transformer(args)
     tokenizer = AutoTokenizer.from_pretrained(ckpt_path)
-    tokenizer.decode(generate(model, [tokenizer.encode("DeepSeek")], 2, -1, 1.)[0])
+    tokenizer.decode(generate(model, [tokenizer.encode("DeepSeek")], 2, -1, 1.)[0], default_device)
     load_model(model, os.path.join(ckpt_path, f"model{rank}-mp{world_size}.safetensors"))

     if interactive:
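
With device declared as the second positional parameter of generate(), a warm-up call under the new signature would look roughly like the sketch below. This is illustrative only; model, tokenizer, and default_device are assumed to be set up as in main(). The nccl process-group backend requires CUDA, which is why the diff falls back to gloo when CUDA is unavailable.

# Sketch: invoking generate() with the device parameter added by this diff.
prompt = [tokenizer.encode("DeepSeek")]
completion = generate(
    model,           # Transformer built under torch.device(default_device)
    default_device,  # new second positional parameter: "cuda", "mps", or "cpu"
    prompt,          # List[List[int]] of prompt token ids
    2,               # max_new_tokens
    -1,              # eos_id
    1.0,             # temperature
)
print(tokenizer.decode(completion[0]))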

View File

@@ -796,7 +796,13 @@ class Transformer(nn.Module):

 if __name__ == "__main__":
     torch.set_default_dtype(torch.bfloat16)
-    torch.set_default_device("cuda")
+    if torch.cuda.is_available():
+        default_device = "cuda"
+    elif torch.mps.is_available():
+        default_device = "mps"
+    else:
+        default_device = "cpu"
+    torch.set_default_device(default_device)
     torch.manual_seed(0)
     args = ModelArgs()
     x = torch.randint(0, args.vocab_size, (2, 128))
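
The test harness relies on torch.set_default_device() so that factory calls such as torch.randint() allocate on the selected backend without an explicit device= argument. A minimal standalone check of that behavior (not part of the commit):

import torch

# Same CUDA -> MPS -> CPU fallback, using the longer-standing backends.mps check.
default_device = "cuda" if torch.cuda.is_available() else (
    "mps" if torch.backends.mps.is_available() else "cpu"
)
torch.set_default_device(default_device)

x = torch.randint(0, 1000, (2, 128))  # no device= argument needed
print(x.device)  # e.g. cuda:0, mps:0, or cpu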

View File

@@ -1,4 +1,4 @@
-torch==2.4.1
+torch==2.6.0
-triton==3.0.0
+git+https://github.com/NripeshN/triton.git@main#subdirectory=python
 transformers==4.46.3
 safetensors==0.4.5