Changes:

init_distributed function: Extracted the distributed setup logic into a separate function.
sample function: Modified it to use torch.multinomial instead of an exponentiation-based approach for sampling.
Argument Validation: Replaced the assert with a more user-friendly validation in main to ensure that at least one of the parameters (input-file or interactive) is provided.
Interactive Code Refactoring: The user interaction logic was kept, but the init_distributed function is now called separately at the beginning of main.
This commit is contained in:
Gabriel Caetano 2025-01-30 22:47:39 -03:00
parent b5d872ead0
commit 89882a94f6

View File

@ -24,7 +24,27 @@ def sample(logits, temperature: float = 1.0):
"""
logits = logits / max(temperature, 1e-5)
probs = torch.softmax(logits, dim=-1)
return probs.div_(torch.empty_like(probs).exponential_(1)).argmax(dim=-1)
return torch.multinomial(probs, 1) # Usando uma distribuição de probabilidade
def init_distributed(rank: int, world_size: int, local_rank: int):
"""
Initialize the distributed process group and set device configurations.
Args:
rank (int): The rank of the current process.
world_size (int): Total number of processes.
local_rank (int): The local rank for multi-GPU configurations.
"""
if world_size > 1:
dist.init_process_group("nccl")
global print
if rank != 0:
print = lambda *_, **__: None
torch.cuda.set_device(local_rank)
torch.set_default_dtype(torch.bfloat16)
torch.set_num_threads(8)
torch.manual_seed(965)
@torch.inference_mode()
@ -100,20 +120,17 @@ def main(
world_size = int(os.getenv("WORLD_SIZE", "1"))
rank = int(os.getenv("RANK", "0"))
local_rank = int(os.getenv("LOCAL_RANK", "0"))
if world_size > 1:
dist.init_process_group("nccl")
global print
if rank != 0:
print = lambda *_, **__: None
torch.cuda.set_device(local_rank)
torch.set_default_dtype(torch.bfloat16)
torch.set_num_threads(8)
torch.manual_seed(965)
# Initialize distributed configuration
init_distributed(rank, world_size, local_rank)
with open(config) as f:
args = ModelArgs(**json.load(f))
print(args)
with torch.device("cuda"):
model = Transformer(args)
tokenizer = AutoTokenizer.from_pretrained(ckpt_path)
tokenizer.decode(generate(model, [tokenizer.encode("DeepSeek")], 2, -1, 1.)[0])
load_model(model, os.path.join(ckpt_path, f"model{rank}-mp{world_size}.safetensors"))
@ -181,5 +198,10 @@ if __name__ == "__main__":
parser.add_argument("--max-new-tokens", type=int, default=200)
parser.add_argument("--temperature", type=float, default=0.2)
args = parser.parse_args()
assert args.input_file or args.interactive
# Validate input
if not (args.input_file or args.interactive):
print("Erro: É necessário especificar --input-file ou ativar --interactive.")
exit(1)
main(args.ckpt_path, args.config, args.input_file, args.interactive, args.max_new_tokens, args.temperature)