Mirror of https://github.com/deepseek-ai/DeepSeek-V3.git (synced 2025-02-23 14:18:57 -05:00)
feat: Enhance device compatibility and update PyTorch version
commit e75ce46245 (parent b5d872ead0)
@@ -30,6 +30,12 @@ def main(fp8_path, bf16_path):
         - The function updates the model index file to remove references to scale_inv tensors.
     """
     torch.set_default_dtype(torch.bfloat16)
+    if torch.cuda.is_available():
+        default_device = "cuda"
+    elif torch.mps.is_available():
+        default_device = "mps"
+    else:
+        default_device = "cpu"
     os.makedirs(bf16_path, exist_ok=True)
     model_index_file = os.path.join(fp8_path, "model.safetensors.index.json")
     with open(model_index_file, "r") as f:
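The same cuda/mps/cpu selection ladder recurs in several files touched by this commit. A minimal standalone sketch of the pattern, factored into a helper (the name get_default_device is illustrative, and torch.backends.mps.is_available() is used here as the widely documented MPS check):

import torch

def get_default_device() -> str:
    # Prefer CUDA, then Apple Silicon MPS, then fall back to CPU,
    # mirroring the selection added in this commit.
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"

if __name__ == "__main__":
    device = get_default_device()
    print("selected device:", device)
    print(torch.ones(2, 2, device=device))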
@@ -57,14 +63,14 @@ def main(fp8_path, bf16_path):
         file_name = weight_map[tensor_name]
         if file_name not in loaded_files:
             file_path = os.path.join(fp8_path, file_name)
-            loaded_files[file_name] = load_file(file_path, device="cuda")
+            loaded_files[file_name] = load_file(file_path, device=default_device)
         return loaded_files[file_name][tensor_name]
 
     safetensor_files = list(glob(os.path.join(fp8_path, "*.safetensors")))
     safetensor_files.sort()
     for safetensor_file in tqdm(safetensor_files):
         file_name = os.path.basename(safetensor_file)
-        current_state_dict = load_file(safetensor_file, device="cuda")
+        current_state_dict = load_file(safetensor_file, device=default_device)
         loaded_files[file_name] = current_state_dict
 
         new_state_dict = {}
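With a device chosen, each checkpoint shard is loaded directly onto it rather than onto a hard-coded "cuda". A small round-trip sketch of that usage, assuming the safetensors.torch API and a throwaway file name (demo.safetensors is hypothetical):

import torch
from safetensors.torch import load_file, save_file

default_device = "cuda" if torch.cuda.is_available() else "cpu"

# Write a tiny shard, then load it straight onto the selected device,
# the same way the conversion script now calls load_file(..., device=default_device).
save_file({"w": torch.randn(4, 4)}, "demo.safetensors")
state = load_file("demo.safetensors", device=default_device)
print(state["w"].device)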
@@ -30,10 +30,11 @@ def sample(logits, temperature: float = 1.0):
 @torch.inference_mode()
 def generate(
     model: Transformer,
+    device: str,
     prompt_tokens: List[List[int]],
     max_new_tokens: int,
     eos_id: int,
-    temperature: float = 1.0
+    temperature: float = 1.0,
 ) -> List[List[int]]:
     """
     Generates new tokens based on the given prompt tokens using the specified model.
@@ -51,11 +52,11 @@ def generate(
     prompt_lens = [len(t) for t in prompt_tokens]
     assert max(prompt_lens) <= model.max_seq_len
     total_len = min(model.max_seq_len, max_new_tokens + max(prompt_lens))
-    tokens = torch.full((len(prompt_tokens), total_len), -1, dtype=torch.long, device="cuda")
+    tokens = torch.full((len(prompt_tokens), total_len), -1, dtype=torch.long, device=device)
     for i, t in enumerate(prompt_tokens):
-        tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
+        tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long, device=device)
     prev_pos = 0
-    finished = torch.tensor([False] * len(prompt_tokens), device="cuda")
+    finished = torch.tensor([False] * len(prompt_tokens), device=device)
     prompt_mask = tokens != -1
     for cur_pos in range(min(prompt_lens), total_len):
         logits = model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
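Threading the device argument through generate() keeps the prompt buffer, the finished mask, and the model on the same device, which is what lets the decode loop run off-CUDA. A minimal sketch of the allocation pattern in isolation (shapes and token values are illustrative only):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
prompt_tokens = [[1, 2, 3], [4, 5]]
total_len = 6

# Pad with -1, copy each prompt in, and track which sequences have finished,
# all on the caller-supplied device as in the patched generate().
tokens = torch.full((len(prompt_tokens), total_len), -1, dtype=torch.long, device=device)
for i, t in enumerate(prompt_tokens):
    tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long, device=device)
finished = torch.tensor([False] * len(prompt_tokens), device=device)
print(tokens.device, finished.device)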
@@ -97,11 +98,20 @@ def main(
         max_new_tokens (int, optional): Maximum number of new tokens to generate. Defaults to 100.
         temperature (float, optional): Temperature for sampling. Defaults to 1.0.
     """
+    if torch.cuda.is_available():
+        default_device = "cuda"
+    elif torch.mps.is_available():
+        default_device = "mps"
+    else:
+        default_device = "cpu"
     world_size = int(os.getenv("WORLD_SIZE", "1"))
     rank = int(os.getenv("RANK", "0"))
     local_rank = int(os.getenv("LOCAL_RANK", "0"))
     if world_size > 1:
-        dist.init_process_group("nccl")
+        if torch.cuda.is_available():
+            dist.init_process_group("nccl")
+        else:
+            dist.init_process_group("gloo")
     global print
     if rank != 0:
         print = lambda *_, **__: None
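NCCL only supports CUDA devices, so when CUDA is absent the process group falls back to Gloo, which works on CPU. A self-contained sketch of that backend choice, assuming the launcher has already set the usual WORLD_SIZE/RANK environment variables (the helper name init_distributed is hypothetical):

import os
import torch
import torch.distributed as dist

def init_distributed() -> None:
    # Mirror the commit's choice: NCCL for CUDA runs, Gloo otherwise.
    world_size = int(os.getenv("WORLD_SIZE", "1"))
    if world_size > 1:
        backend = "nccl" if torch.cuda.is_available() else "gloo"
        dist.init_process_group(backend)

if __name__ == "__main__":
    init_distributed()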
@@ -112,10 +122,10 @@ def main(
     with open(config) as f:
         args = ModelArgs(**json.load(f))
     print(args)
-    with torch.device("cuda"):
+    with torch.device(default_device):
         model = Transformer(args)
     tokenizer = AutoTokenizer.from_pretrained(ckpt_path)
-    tokenizer.decode(generate(model, [tokenizer.encode("DeepSeek")], 2, -1, 1.)[0])
+    tokenizer.decode(generate(model, default_device, [tokenizer.encode("DeepSeek")], 2, -1, 1.)[0])
     load_model(model, os.path.join(ckpt_path, f"model{rank}-mp{world_size}.safetensors"))
 
     if interactive:
@@ -796,7 +796,13 @@ class Transformer(nn.Module):
 
 if __name__ == "__main__":
     torch.set_default_dtype(torch.bfloat16)
-    torch.set_default_device("cuda")
+    if torch.cuda.is_available():
+        default_device = "cuda"
+    elif torch.mps.is_available():
+        default_device = "mps"
+    else:
+        default_device = "cpu"
+    torch.set_default_device(default_device)
     torch.manual_seed(0)
     args = ModelArgs()
     x = torch.randint(0, args.vocab_size, (2, 128))
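Setting a default device and dtype affects every tensor created afterwards without an explicit device= or dtype=, which is what lets the __main__ smoke test above run unchanged on CUDA, MPS, or CPU. A short sketch of that effect (CPU fallback assumed on machines without CUDA):

import torch

default_device = "cuda" if torch.cuda.is_available() else "cpu"
torch.set_default_dtype(torch.bfloat16)
torch.set_default_device(default_device)

# Created with no explicit dtype or device, so both defaults apply.
x = torch.randn(2, 3)
print(x.dtype, x.device)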
@@ -1,4 +1,4 @@
-torch==2.4.1
+torch==2.6.0
 triton==3.0.0
 transformers==4.46.3
 safetensors==0.4.5