diff --git a/inference/generate.py b/inference/generate.py
index fbf3ab8..f8b630a 100644
--- a/inference/generate.py
+++ b/inference/generate.py
@@ -24,7 +24,32 @@ def sample(logits, temperature: float = 1.0):
     """
     logits = logits / max(temperature, 1e-5)
     probs = torch.softmax(logits, dim=-1)
-    return probs.div_(torch.empty_like(probs).exponential_(1)).argmax(dim=-1)
+    # Draw one token id per row; squeeze the sample dim so the result keeps
+    # the same shape the previous argmax(dim=-1) returned.
+    return torch.multinomial(probs, 1).squeeze(-1)
+
+
+def init_distributed(rank: int, world_size: int, local_rank: int) -> None:
+    """
+    Initialize the distributed process group and set device configurations.
+
+    Also silences `print` on non-zero ranks and fixes the default dtype,
+    CPU thread count and random seed used by every process.
+
+    Args:
+        rank (int): The rank of the current process.
+        world_size (int): Total number of processes.
+        local_rank (int): The local rank for multi-GPU configurations.
+    """
+    if world_size > 1:
+        dist.init_process_group("nccl")
+    global print
+    if rank != 0:
+        print = lambda *_, **__: None
+    torch.cuda.set_device(local_rank)
+    torch.set_default_dtype(torch.bfloat16)
+    torch.set_num_threads(8)
+    torch.manual_seed(965)
 
 
 @torch.inference_mode()
@@ -100,20 +125,17 @@ def main(
     world_size = int(os.getenv("WORLD_SIZE", "1"))
     rank = int(os.getenv("RANK", "0"))
     local_rank = int(os.getenv("LOCAL_RANK", "0"))
-    if world_size > 1:
-        dist.init_process_group("nccl")
-    global print
-    if rank != 0:
-        print = lambda *_, **__: None
-    torch.cuda.set_device(local_rank)
-    torch.set_default_dtype(torch.bfloat16)
-    torch.set_num_threads(8)
-    torch.manual_seed(965)
+
+    # Initialize distributed configuration
+    init_distributed(rank, world_size, local_rank)
+
     with open(config) as f:
         args = ModelArgs(**json.load(f))
     print(args)
+
     with torch.device("cuda"):
         model = Transformer(args)
+
     tokenizer = AutoTokenizer.from_pretrained(ckpt_path)
     tokenizer.decode(generate(model, [tokenizer.encode("DeepSeek")], 2, -1, 1.)[0])
     load_model(model, os.path.join(ckpt_path, f"model{rank}-mp{world_size}.safetensors"))
@@ -181,5 +203,10 @@ if __name__ == "__main__":
     parser.add_argument("--max-new-tokens", type=int, default=200)
     parser.add_argument("--temperature", type=float, default=0.2)
     args = parser.parse_args()
-    assert args.input_file or args.interactive
+
+    # Validate input
+    if not (args.input_file or args.interactive):
+        print("Erro: É necessário especificar --input-file ou ativar --interactive.")
+        raise SystemExit(1)
+
     main(args.ckpt_path, args.config, args.input_file, args.interactive, args.max_new_tokens, args.temperature)