diff --git a/inference/fp8_cast_bf16.py b/inference/fp8_cast_bf16.py
index 4037342..c25cea1 100644
--- a/inference/fp8_cast_bf16.py
+++ b/inference/fp8_cast_bf16.py
@@ -30,6 +30,12 @@ def main(fp8_path, bf16_path):
     - The function updates the model index file to remove references to scale_inv tensors.
     """
     torch.set_default_dtype(torch.bfloat16)
+    if torch.cuda.is_available():
+        default_device = "cuda"
+    elif torch.mps.is_available():
+        default_device = "mps"
+    else:
+        default_device = "cpu"
     os.makedirs(bf16_path, exist_ok=True)
     model_index_file = os.path.join(fp8_path, "model.safetensors.index.json")
     with open(model_index_file, "r") as f:
@@ -57,14 +63,14 @@ def main(fp8_path, bf16_path):
         file_name = weight_map[tensor_name]
         if file_name not in loaded_files:
             file_path = os.path.join(fp8_path, file_name)
-            loaded_files[file_name] = load_file(file_path, device="cuda")
+            loaded_files[file_name] = load_file(file_path, device=default_device)
         return loaded_files[file_name][tensor_name]
 
     safetensor_files = list(glob(os.path.join(fp8_path, "*.safetensors")))
     safetensor_files.sort()
     for safetensor_file in tqdm(safetensor_files):
         file_name = os.path.basename(safetensor_file)
-        current_state_dict = load_file(safetensor_file, device="cuda")
+        current_state_dict = load_file(safetensor_file, device=default_device)
         loaded_files[file_name] = current_state_dict
 
         new_state_dict = {}
diff --git a/inference/generate.py b/inference/generate.py
index fbf3ab8..bdfd828 100644
--- a/inference/generate.py
+++ b/inference/generate.py
@@ -30,10 +30,11 @@ def sample(logits, temperature: float = 1.0):
 @torch.inference_mode()
 def generate(
     model: Transformer,
+    device: str,
     prompt_tokens: List[List[int]],
     max_new_tokens: int,
     eos_id: int,
-    temperature: float = 1.0
+    temperature: float = 1.0,
 ) -> List[List[int]]:
     """
     Generates new tokens based on the given prompt tokens using the specified model.
@@ -51,11 +52,11 @@ def generate(
     prompt_lens = [len(t) for t in prompt_tokens]
     assert max(prompt_lens) <= model.max_seq_len
     total_len = min(model.max_seq_len, max_new_tokens + max(prompt_lens))
-    tokens = torch.full((len(prompt_tokens), total_len), -1, dtype=torch.long, device="cuda")
+    tokens = torch.full((len(prompt_tokens), total_len), -1, dtype=torch.long, device=device)
     for i, t in enumerate(prompt_tokens):
-        tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
+        tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long, device=device)
     prev_pos = 0
-    finished = torch.tensor([False] * len(prompt_tokens), device="cuda")
+    finished = torch.tensor([False] * len(prompt_tokens), device=device)
     prompt_mask = tokens != -1
     for cur_pos in range(min(prompt_lens), total_len):
         logits = model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
@@ -97,11 +98,20 @@ def main(
         max_new_tokens (int, optional): Maximum number of new tokens to generate. Defaults to 100.
         temperature (float, optional): Temperature for sampling. Defaults to 1.0.
""" + if torch.cuda.is_available(): + default_device = "cuda" + elif torch.mps.is_available(): + default_device = "mps" + else: + default_device = "cpu" world_size = int(os.getenv("WORLD_SIZE", "1")) rank = int(os.getenv("RANK", "0")) local_rank = int(os.getenv("LOCAL_RANK", "0")) if world_size > 1: - dist.init_process_group("nccl") + if torch.cuda.is_available(): + dist.init_process_group("nccl") + else: + dist.init_process_group("gloo") global print if rank != 0: print = lambda *_, **__: None @@ -112,10 +122,10 @@ def main( with open(config) as f: args = ModelArgs(**json.load(f)) print(args) - with torch.device("cuda"): + with torch.device(default_device): model = Transformer(args) tokenizer = AutoTokenizer.from_pretrained(ckpt_path) - tokenizer.decode(generate(model, [tokenizer.encode("DeepSeek")], 2, -1, 1.)[0]) + tokenizer.decode(generate(model, [tokenizer.encode("DeepSeek")], 2, -1, 1.)[0], default_device) load_model(model, os.path.join(ckpt_path, f"model{rank}-mp{world_size}.safetensors")) if interactive: diff --git a/inference/model.py b/inference/model.py index 9ea60c9..3d8bf2c 100644 --- a/inference/model.py +++ b/inference/model.py @@ -796,7 +796,13 @@ class Transformer(nn.Module): if __name__ == "__main__": torch.set_default_dtype(torch.bfloat16) - torch.set_default_device("cuda") + if torch.cuda.is_available(): + default_device = "cuda" + elif torch.mps.is_available(): + default_device = "mps" + else: + default_device = "cpu" + torch.set_default_device("default_device") torch.manual_seed(0) args = ModelArgs() x = torch.randint(0, args.vocab_size, (2, 128)) diff --git a/inference/requirements.txt b/inference/requirements.txt index 68a60d0..c82cd4a 100644 --- a/inference/requirements.txt +++ b/inference/requirements.txt @@ -1,4 +1,4 @@ -torch==2.4.1 +torch==2.6.0 triton==3.0.0 transformers==4.46.3 safetensors==0.4.5 \ No newline at end of file