Mirror of https://github.com/deepseek-ai/DeepSeek-V3.git, synced 2025-02-23 14:18:57 -05:00
feat: Enhance device compatibility and update PyTorch version
Commit: e75ce46245
Parent: b5d872ead0
@@ -30,6 +30,12 @@ def main(fp8_path, bf16_path):
         - The function updates the model index file to remove references to scale_inv tensors.
     """
     torch.set_default_dtype(torch.bfloat16)
+    if torch.cuda.is_available():
+        default_device = "cuda"
+    elif torch.mps.is_available():
+        default_device = "mps"
+    else:
+        default_device = "cpu"
     os.makedirs(bf16_path, exist_ok=True)
     model_index_file = os.path.join(fp8_path, "model.safetensors.index.json")
     with open(model_index_file, "r") as f:
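The cuda/mps/cpu fallback above is repeated verbatim in several files by this commit. A minimal sketch of factoring it into a shared helper; the function name pick_default_device and the use of torch.backends.mps.is_available() are illustrative choices, not part of the commit:

    import torch

    def pick_default_device() -> str:
        """Return the best available device string: CUDA, Apple MPS, otherwise CPU."""
        if torch.cuda.is_available():
            return "cuda"
        if torch.backends.mps.is_available():
            return "mps"
        return "cpu"

    default_device = pick_default_device()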
@@ -57,14 +63,14 @@ def main(fp8_path, bf16_path):
         file_name = weight_map[tensor_name]
         if file_name not in loaded_files:
             file_path = os.path.join(fp8_path, file_name)
-            loaded_files[file_name] = load_file(file_path, device="cuda")
+            loaded_files[file_name] = load_file(file_path, device=default_device)
         return loaded_files[file_name][tensor_name]
 
     safetensor_files = list(glob(os.path.join(fp8_path, "*.safetensors")))
     safetensor_files.sort()
     for safetensor_file in tqdm(safetensor_files):
         file_name = os.path.basename(safetensor_file)
-        current_state_dict = load_file(safetensor_file, device="cuda")
+        current_state_dict = load_file(safetensor_file, device=default_device)
         loaded_files[file_name] = current_state_dict
 
         new_state_dict = {}
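For context, safetensors.torch.load_file accepts a device argument, so each shard is deserialized straight onto the selected device instead of defaulting to a hard-coded "cuda". A small sketch, assuming default_device was computed as above and using a placeholder file name:

    from safetensors.torch import load_file

    shard_path = "model-shard.safetensors"  # placeholder; real names come from the index file
    state_dict = load_file(shard_path, device=default_device)
    print(next(iter(state_dict.values())).device)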
@@ -30,10 +30,11 @@ def sample(logits, temperature: float = 1.0):
 @torch.inference_mode()
 def generate(
     model: Transformer,
+    device: str,
     prompt_tokens: List[List[int]],
     max_new_tokens: int,
     eos_id: int,
-    temperature: float = 1.0
+    temperature: float = 1.0,
 ) -> List[List[int]]:
     """
     Generates new tokens based on the given prompt tokens using the specified model.
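Because device is now the second positional parameter of generate, existing callers must pass it explicitly. A minimal usage sketch with placeholder prompt and sampling values:

    completion = generate(
        model,                             # Transformer already placed on default_device
        default_device,                    # new positional device argument
        [tokenizer.encode("DeepSeek")],    # prompt_tokens
        max_new_tokens=2,
        eos_id=-1,
        temperature=1.0,
    )

An alternative, not used by this commit, would be to infer the device inside generate from next(model.parameters()).device and drop the extra argument.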
@@ -51,11 +52,11 @@ def generate(
     prompt_lens = [len(t) for t in prompt_tokens]
     assert max(prompt_lens) <= model.max_seq_len
     total_len = min(model.max_seq_len, max_new_tokens + max(prompt_lens))
-    tokens = torch.full((len(prompt_tokens), total_len), -1, dtype=torch.long, device="cuda")
+    tokens = torch.full((len(prompt_tokens), total_len), -1, dtype=torch.long, device=device)
     for i, t in enumerate(prompt_tokens):
-        tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
+        tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long, device=device)
     prev_pos = 0
-    finished = torch.tensor([False] * len(prompt_tokens), device="cuda")
+    finished = torch.tensor([False] * len(prompt_tokens), device=device)
     prompt_mask = tokens != -1
     for cur_pos in range(min(prompt_lens), total_len):
         logits = model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
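The token and flag buffers above now follow the caller-supplied device. An alternative sketch, an assumption rather than what this commit does, derives the device from the model's own parameters so that the buffers always match wherever the weights live:

    import torch

    device = next(model.parameters()).device   # device the weights already occupy
    batch_size, total_len = 2, 128             # illustrative sizes
    tokens = torch.full((batch_size, total_len), -1, dtype=torch.long, device=device)
    finished = torch.zeros(batch_size, dtype=torch.bool, device=device)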
@@ -97,11 +98,20 @@ def main(
         max_new_tokens (int, optional): Maximum number of new tokens to generate. Defaults to 100.
         temperature (float, optional): Temperature for sampling. Defaults to 1.0.
     """
+    if torch.cuda.is_available():
+        default_device = "cuda"
+    elif torch.mps.is_available():
+        default_device = "mps"
+    else:
+        default_device = "cpu"
     world_size = int(os.getenv("WORLD_SIZE", "1"))
     rank = int(os.getenv("RANK", "0"))
     local_rank = int(os.getenv("LOCAL_RANK", "0"))
     if world_size > 1:
-        dist.init_process_group("nccl")
+        if torch.cuda.is_available():
+            dist.init_process_group("nccl")
+        else:
+            dist.init_process_group("gloo")
     global print
     if rank != 0:
         print = lambda *_, **__: None
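The backend switch matters because NCCL requires CUDA, while Gloo runs on CPU and is the usual fallback on machines without NVIDIA GPUs. A compact sketch of the same selection as a reusable helper; the function name is illustrative:

    import torch
    import torch.distributed as dist

    def init_distributed(world_size: int) -> None:
        """Initialize the default process group: NCCL when CUDA is available, Gloo otherwise."""
        if world_size > 1:
            backend = "nccl" if torch.cuda.is_available() else "gloo"
            dist.init_process_group(backend)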
@@ -112,10 +122,10 @@ def main(
     with open(config) as f:
         args = ModelArgs(**json.load(f))
     print(args)
-    with torch.device("cuda"):
+    with torch.device(default_device):
         model = Transformer(args)
     tokenizer = AutoTokenizer.from_pretrained(ckpt_path)
-    tokenizer.decode(generate(model, [tokenizer.encode("DeepSeek")], 2, -1, 1.)[0])
+    tokenizer.decode(generate(model, default_device, [tokenizer.encode("DeepSeek")], 2, -1, 1.)[0])
     load_model(model, os.path.join(ckpt_path, f"model{rank}-mp{world_size}.safetensors"))
 
     if interactive:
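For reference, entering torch.device(...) as a context manager makes it the default device for tensor and parameter creation, so Transformer(args) is materialized directly on default_device. A small self-contained sketch of the effect:

    import torch
    import torch.nn as nn

    with torch.device(default_device):
        layer = nn.Linear(16, 16)       # parameters are created on default_device
    print(layer.weight.device)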
@@ -796,7 +796,13 @@ class Transformer(nn.Module):
 
 if __name__ == "__main__":
     torch.set_default_dtype(torch.bfloat16)
-    torch.set_default_device("cuda")
+    if torch.cuda.is_available():
+        default_device = "cuda"
+    elif torch.mps.is_available():
+        default_device = "mps"
+    else:
+        default_device = "cpu"
+    torch.set_default_device(default_device)
     torch.manual_seed(0)
     args = ModelArgs()
     x = torch.randint(0, args.vocab_size, (2, 128))
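After torch.set_default_device(default_device), factory functions create tensors on that device without an explicit device= argument, which is what lets the test input below stay unchanged. A quick sketch:

    import torch

    torch.set_default_device(default_device)
    x = torch.randint(0, 100, (2, 8))   # lands on default_device; no device= needed
    print(x.device)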
@@ -1,4 +1,4 @@
-torch==2.4.1
+torch==2.6.0
 triton==3.0.0
 transformers==4.46.3
 safetensors==0.4.5