feat: Enhance device compatibility and update PyTorch version

Nripesh Niketan 2025-01-30 00:06:55 +00:00
parent b5d872ead0
commit e75ce46245
4 changed files with 33 additions and 11 deletions

View File

@@ -30,6 +30,12 @@ def main(fp8_path, bf16_path):
     - The function updates the model index file to remove references to scale_inv tensors.
     """
     torch.set_default_dtype(torch.bfloat16)
+    if torch.cuda.is_available():
+        default_device = "cuda"
+    elif torch.backends.mps.is_available():
+        default_device = "mps"
+    else:
+        default_device = "cpu"
     os.makedirs(bf16_path, exist_ok=True)
     model_index_file = os.path.join(fp8_path, "model.safetensors.index.json")
     with open(model_index_file, "r") as f:
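
The same three-branch device probe is repeated in each file touched below. As a minimal standalone sketch of the pattern (the helper name select_device is illustrative, not part of this commit):

import torch

def select_device() -> str:
    # Prefer CUDA, then Apple-silicon MPS, then fall back to CPU.
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"

print(select_device())  # e.g. "cuda" on an NVIDIA box, "mps" on Apple silicon
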
@@ -57,14 +63,14 @@ def main(fp8_path, bf16_path):
         file_name = weight_map[tensor_name]
         if file_name not in loaded_files:
             file_path = os.path.join(fp8_path, file_name)
-            loaded_files[file_name] = load_file(file_path, device="cuda")
+            loaded_files[file_name] = load_file(file_path, device=default_device)
         return loaded_files[file_name][tensor_name]
 
     safetensor_files = list(glob(os.path.join(fp8_path, "*.safetensors")))
     safetensor_files.sort()
     for safetensor_file in tqdm(safetensor_files):
         file_name = os.path.basename(safetensor_file)
-        current_state_dict = load_file(safetensor_file, device="cuda")
+        current_state_dict = load_file(safetensor_file, device=default_device)
         loaded_files[file_name] = current_state_dict
 
         new_state_dict = {}
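
For context, safetensors' load_file accepts a torch device string, so each shard is materialized directly on the detected device. A hedged sketch (the shard filename is illustrative):

import torch
from safetensors.torch import load_file

device = "cuda" if torch.cuda.is_available() else "cpu"
state_dict = load_file("model-00001.safetensors", device=device)  # tensors land on `device`
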

View File

@@ -30,10 +30,11 @@ def sample(logits, temperature: float = 1.0):
 @torch.inference_mode()
 def generate(
     model: Transformer,
+    device: str,
     prompt_tokens: List[List[int]],
     max_new_tokens: int,
     eos_id: int,
-    temperature: float = 1.0
+    temperature: float = 1.0,
 ) -> List[List[int]]:
     """
     Generates new tokens based on the given prompt tokens using the specified model.
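
With device now the second positional parameter, call sites must thread it through right after the model. An illustrative invocation, assuming model and tokenizer are already constructed:

completion_ids = generate(
    model,                           # Transformer instance (assumed built)
    "cuda",                          # device string, the new second argument
    [tokenizer.encode("DeepSeek")],  # batch of prompt token lists
    max_new_tokens=2,
    eos_id=-1,
    temperature=1.0,
)
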
@@ -51,11 +52,11 @@ def generate(
     prompt_lens = [len(t) for t in prompt_tokens]
     assert max(prompt_lens) <= model.max_seq_len
     total_len = min(model.max_seq_len, max_new_tokens + max(prompt_lens))
-    tokens = torch.full((len(prompt_tokens), total_len), -1, dtype=torch.long, device="cuda")
+    tokens = torch.full((len(prompt_tokens), total_len), -1, dtype=torch.long, device=device)
     for i, t in enumerate(prompt_tokens):
-        tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
+        tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long, device=device)
     prev_pos = 0
-    finished = torch.tensor([False] * len(prompt_tokens), device="cuda")
+    finished = torch.tensor([False] * len(prompt_tokens), device=device)
     prompt_mask = tokens != -1
     for cur_pos in range(min(prompt_lens), total_len):
         logits = model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
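
Once device is parameterized, the buffer logic above runs unchanged on any backend. A tiny self-contained sketch of the -1-padded prompt buffer and its mask:

import torch

device = "cpu"  # any valid device string works the same way
prompts = [[1, 2, 3], [4, 5]]
total_len = 6
tokens = torch.full((len(prompts), total_len), -1, dtype=torch.long, device=device)
for i, t in enumerate(prompts):
    tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long, device=device)
prompt_mask = tokens != -1  # True at real prompt positions, False at padding
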
@@ -97,11 +98,20 @@ def main(
         max_new_tokens (int, optional): Maximum number of new tokens to generate. Defaults to 100.
         temperature (float, optional): Temperature for sampling. Defaults to 1.0.
     """
+    if torch.cuda.is_available():
+        default_device = "cuda"
+    elif torch.backends.mps.is_available():
+        default_device = "mps"
+    else:
+        default_device = "cpu"
     world_size = int(os.getenv("WORLD_SIZE", "1"))
     rank = int(os.getenv("RANK", "0"))
     local_rank = int(os.getenv("LOCAL_RANK", "0"))
     if world_size > 1:
-        dist.init_process_group("nccl")
+        if torch.cuda.is_available():
+            dist.init_process_group("nccl")
+        else:
+            dist.init_process_group("gloo")
     global print
     if rank != 0:
         print = lambda *_, **__: None
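
NCCL only runs on CUDA GPUs, so non-CUDA hosts need a CPU-capable backend such as Gloo. A hedged sketch of the same selection; the actual init call additionally needs the usual rendezvous environment variables:

import torch
import torch.distributed as dist

backend = "nccl" if torch.cuda.is_available() else "gloo"
# dist.init_process_group(backend)  # requires WORLD_SIZE, RANK, MASTER_ADDR, MASTER_PORT
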
@@ -112,10 +122,10 @@ def main(
     with open(config) as f:
         args = ModelArgs(**json.load(f))
     print(args)
-    with torch.device("cuda"):
+    with torch.device(default_device):
         model = Transformer(args)
     tokenizer = AutoTokenizer.from_pretrained(ckpt_path)
-    tokenizer.decode(generate(model, [tokenizer.encode("DeepSeek")], 2, -1, 1.)[0])
+    tokenizer.decode(generate(model, default_device, [tokenizer.encode("DeepSeek")], 2, -1, 1.)[0])
     load_model(model, os.path.join(ckpt_path, f"model{rank}-mp{world_size}.safetensors"))
     if interactive:
         messages = []
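
torch.device doubles as a context manager, which is what makes the with torch.device(default_device): form valid on any backend. A minimal illustration:

import torch

dev = "cuda" if torch.cuda.is_available() else "cpu"
with torch.device(dev):
    w = torch.empty(4, 4)  # allocated on `dev` without an explicit device= argument
print(w.device)
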

View File

@@ -796,7 +796,13 @@ class Transformer(nn.Module):
 
 if __name__ == "__main__":
     torch.set_default_dtype(torch.bfloat16)
-    torch.set_default_device("cuda")
+    if torch.cuda.is_available():
+        default_device = "cuda"
+    elif torch.backends.mps.is_available():
+        default_device = "mps"
+    else:
+        default_device = "cpu"
+    torch.set_default_device(default_device)
     torch.manual_seed(0)
     args = ModelArgs()
     x = torch.randint(0, args.vocab_size, (2, 128))
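
For reference, torch.set_default_device takes the device value itself; afterwards factory functions allocate there without an explicit device=. A short sketch:

import torch

default_device = "cuda" if torch.cuda.is_available() else "cpu"
torch.set_default_device(default_device)
x = torch.randint(0, 102400, (2, 128))  # created on default_device by default
print(x.device)
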

View File

@@ -1,4 +1,4 @@
-torch==2.4.1
+torch==2.6.0
 triton==3.0.0
 transformers==4.46.3
 safetensors==0.4.5
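
A quick post-install sanity check that the bumped pin is the version actually imported (assumes pip install -r requirements.txt has run):

import torch
assert torch.__version__.startswith("2.6"), torch.__version__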