From e75ce46245798822a4042e59860dccf4e41bf3b9 Mon Sep 17 00:00:00 2001
From: Nripesh Niketan
Date: Thu, 30 Jan 2025 00:06:55 +0000
Subject: [PATCH 1/4] feat: Enhance device compatibility and update PyTorch version

---
 inference/fp8_cast_bf16.py | 10 ++++++++--
 inference/generate.py      | 24 +++++++++++++++++-------
 inference/model.py         |  8 +++++++-
 inference/requirements.txt |  2 +-
 4 files changed, 33 insertions(+), 11 deletions(-)

diff --git a/inference/fp8_cast_bf16.py b/inference/fp8_cast_bf16.py
index 4037342..c25cea1 100644
--- a/inference/fp8_cast_bf16.py
+++ b/inference/fp8_cast_bf16.py
@@ -30,6 +30,12 @@ def main(fp8_path, bf16_path):
         - The function updates the model index file to remove references to scale_inv tensors.
     """
     torch.set_default_dtype(torch.bfloat16)
+    if torch.cuda.is_available():
+        default_device = "cuda"
+    elif torch.mps.is_available():
+        default_device = "mps"
+    else:
+        default_device = "cpu"
     os.makedirs(bf16_path, exist_ok=True)
     model_index_file = os.path.join(fp8_path, "model.safetensors.index.json")
     with open(model_index_file, "r") as f:
@@ -57,14 +63,14 @@ def main(fp8_path, bf16_path):
         file_name = weight_map[tensor_name]
         if file_name not in loaded_files:
             file_path = os.path.join(fp8_path, file_name)
-            loaded_files[file_name] = load_file(file_path, device="cuda")
+            loaded_files[file_name] = load_file(file_path, device=default_device)
         return loaded_files[file_name][tensor_name]
 
     safetensor_files = list(glob(os.path.join(fp8_path, "*.safetensors")))
     safetensor_files.sort()
     for safetensor_file in tqdm(safetensor_files):
         file_name = os.path.basename(safetensor_file)
-        current_state_dict = load_file(safetensor_file, device="cuda")
+        current_state_dict = load_file(safetensor_file, device=default_device)
         loaded_files[file_name] = current_state_dict
 
         new_state_dict = {}
diff --git a/inference/generate.py b/inference/generate.py
index fbf3ab8..bdfd828 100644
--- a/inference/generate.py
+++ b/inference/generate.py
@@ -30,10 +30,11 @@ def sample(logits, temperature: float = 1.0):
 @torch.inference_mode()
 def generate(
     model: Transformer,
+    device: str,
     prompt_tokens: List[List[int]],
     max_new_tokens: int,
     eos_id: int,
-    temperature: float = 1.0
+    temperature: float = 1.0,
 ) -> List[List[int]]:
     """
     Generates new tokens based on the given prompt tokens using the specified model.
@@ -51,11 +52,11 @@
     prompt_lens = [len(t) for t in prompt_tokens]
     assert max(prompt_lens) <= model.max_seq_len
     total_len = min(model.max_seq_len, max_new_tokens + max(prompt_lens))
-    tokens = torch.full((len(prompt_tokens), total_len), -1, dtype=torch.long, device="cuda")
+    tokens = torch.full((len(prompt_tokens), total_len), -1, dtype=torch.long, device=device)
     for i, t in enumerate(prompt_tokens):
-        tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
+        tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long, device=device)
     prev_pos = 0
-    finished = torch.tensor([False] * len(prompt_tokens), device="cuda")
+    finished = torch.tensor([False] * len(prompt_tokens), device=device)
     prompt_mask = tokens != -1
     for cur_pos in range(min(prompt_lens), total_len):
         logits = model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
@@ -97,11 +98,20 @@ def main(
         max_new_tokens (int, optional): Maximum number of new tokens to generate. Defaults to 100.
         temperature (float, optional): Temperature for sampling. Defaults to 1.0.
""" + if torch.cuda.is_available(): + default_device = "cuda" + elif torch.mps.is_available(): + default_device = "mps" + else: + default_device = "cpu" world_size = int(os.getenv("WORLD_SIZE", "1")) rank = int(os.getenv("RANK", "0")) local_rank = int(os.getenv("LOCAL_RANK", "0")) if world_size > 1: - dist.init_process_group("nccl") + if torch.cuda.is_available(): + dist.init_process_group("nccl") + else: + dist.init_process_group("gloo") global print if rank != 0: print = lambda *_, **__: None @@ -112,10 +122,10 @@ def main( with open(config) as f: args = ModelArgs(**json.load(f)) print(args) - with torch.device("cuda"): + with torch.device(default_device): model = Transformer(args) tokenizer = AutoTokenizer.from_pretrained(ckpt_path) - tokenizer.decode(generate(model, [tokenizer.encode("DeepSeek")], 2, -1, 1.)[0]) + tokenizer.decode(generate(model, [tokenizer.encode("DeepSeek")], 2, -1, 1.)[0], default_device) load_model(model, os.path.join(ckpt_path, f"model{rank}-mp{world_size}.safetensors")) if interactive: diff --git a/inference/model.py b/inference/model.py index 9ea60c9..3d8bf2c 100644 --- a/inference/model.py +++ b/inference/model.py @@ -796,7 +796,13 @@ class Transformer(nn.Module): if __name__ == "__main__": torch.set_default_dtype(torch.bfloat16) - torch.set_default_device("cuda") + if torch.cuda.is_available(): + default_device = "cuda" + elif torch.mps.is_available(): + default_device = "mps" + else: + default_device = "cpu" + torch.set_default_device("default_device") torch.manual_seed(0) args = ModelArgs() x = torch.randint(0, args.vocab_size, (2, 128)) diff --git a/inference/requirements.txt b/inference/requirements.txt index 68a60d0..c82cd4a 100644 --- a/inference/requirements.txt +++ b/inference/requirements.txt @@ -1,4 +1,4 @@ -torch==2.4.1 +torch==2.6.0 triton==3.0.0 transformers==4.46.3 safetensors==0.4.5 \ No newline at end of file From b6e3910fd0f6835ef759e6bf3a96f50f58a19da2 Mon Sep 17 00:00:00 2001 From: Nripesh Niketan <86844847+NripeshN@users.noreply.github.com> Date: Thu, 30 Jan 2025 16:04:00 +0000 Subject: [PATCH 2/4] Fix small error --- inference/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/inference/model.py b/inference/model.py index 3d8bf2c..18586da 100644 --- a/inference/model.py +++ b/inference/model.py @@ -802,7 +802,7 @@ if __name__ == "__main__": default_device = "mps" else: default_device = "cpu" - torch.set_default_device("default_device") + torch.set_default_device(default_device) torch.manual_seed(0) args = ModelArgs() x = torch.randint(0, args.vocab_size, (2, 128)) From 73efe7c63175f677f111f1df8aa84128b805fb02 Mon Sep 17 00:00:00 2001 From: Nripesh Niketan Date: Sun, 2 Feb 2025 10:41:20 +0000 Subject: [PATCH 3/4] Memory management update --- inference/fp8_cast_bf16.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/inference/fp8_cast_bf16.py b/inference/fp8_cast_bf16.py index c25cea1..bce5784 100644 --- a/inference/fp8_cast_bf16.py +++ b/inference/fp8_cast_bf16.py @@ -3,6 +3,7 @@ import json from argparse import ArgumentParser from glob import glob from tqdm import tqdm +import gc import torch from safetensors.torch import load_file, save_file @@ -97,7 +98,12 @@ def main(fp8_path, bf16_path): if len(loaded_files) > 2: oldest_file = next(iter(loaded_files)) del loaded_files[oldest_file] - torch.cuda.empty_cache() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + elif torch.mps.is_available(): + torch.mps.empty_cache() + else: + gc.collect() # Update model index new_model_index_file 

From a5336884cfa952634f355549643dcb781eaec872 Mon Sep 17 00:00:00 2001
From: Nripesh Niketan
Date: Mon, 3 Feb 2025 21:50:02 +0000
Subject: [PATCH 4/4] fix: Update triton dependency to use the latest version from GitHub

---
 inference/requirements.txt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/inference/requirements.txt b/inference/requirements.txt
index c82cd4a..76c961f 100644
--- a/inference/requirements.txt
+++ b/inference/requirements.txt
@@ -1,4 +1,4 @@
 torch==2.6.0
-triton==3.0.0
+git+https://github.com/NripeshN/triton.git@main#subdirectory=python
 transformers==4.46.3
 safetensors==0.4.5
\ No newline at end of file
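
For reference, a minimal standalone sketch of the device-selection and cache-clearing fallbacks that patches 1 and 3 repeat inline (CUDA first, then Apple-silicon MPS, then CPU). The helper names get_default_device and free_cached_memory are hypothetical and not part of the patches, and the sketch checks torch.backends.mps.is_available() rather than the torch.mps.is_available() call used above.

# Illustrative sketch only; not part of the patch series.
import gc

import torch


def get_default_device() -> str:
    """Prefer CUDA, then Apple-silicon MPS, then fall back to CPU."""
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"
    return "cpu"


def free_cached_memory(device: str) -> None:
    """Release cached allocator memory on the active backend."""
    if device == "cuda":
        torch.cuda.empty_cache()
    elif device == "mps":
        torch.mps.empty_cache()
    else:
        gc.collect()


if __name__ == "__main__":
    device = get_default_device()
    torch.set_default_device(device)  # new tensors land on this device
    x = torch.ones(2, 2)
    print(f"default device: {x.device}")
    free_cached_memory(device)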