Update model.py

Fixed the indentation of the suggestion so the body of forward() runs inside the autocast() context.
Pedro Dessanti 2025-01-27 10:36:57 -03:00 committed by GitHub
parent 9bf00671cf
commit c47ecaa800

@@ -779,20 +779,21 @@ class Transformer(nn.Module):
            torch.Tensor: Logits tensor of shape (batch_size, vocab_size).
        """
        with autocast():
            seqlen = tokens.size(1)
            h = self.embed(tokens)
            freqs_cis = self.freqs_cis[start_pos:start_pos+seqlen]
            mask = None
            if seqlen > 1:
                mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device).triu_(1)
            for layer in self.layers:
                h = layer(h, start_pos, freqs_cis, mask)
            h = self.norm(h)[:, -1]
            logits = self.head(h)
            if world_size > 1:
                all_logits = [torch.empty_like(logits) for _ in range(world_size)]
                dist.all_gather(all_logits, logits)
                logits = torch.cat(all_logits, dim=-1)
            return logits
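
For context on the masking step in this hunk, here is a minimal standalone sketch of the causal mask the forward pass builds; the sequence length is an illustrative value, not from the commit:

    import torch

    # Same construction as the diff: fill with -inf, keep only the
    # strict upper triangle, zero elsewhere.
    seqlen = 4
    mask = torch.full((seqlen, seqlen), float("-inf")).triu_(1)
    print(mask)
    # tensor([[0., -inf, -inf, -inf],
    #         [0.,   0., -inf, -inf],
    #         [0.,   0.,   0., -inf],
    #         [0.,   0.,   0.,   0.]])
    # Added to attention scores, the -inf entries above the diagonal
    # zero out attention to future positions after softmax.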
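
The final branch recombines vocabulary-parallel logits across ranks: each rank's output head covers only a shard of the vocabulary, and dist.all_gather plus concatenation along the last dimension recovers the full distribution. A minimal single-process sketch of that recombination, with made-up sizes for world_size, batch_size, and vocab_size (none taken from the commit):

    import torch

    world_size = 4
    batch_size = 2
    vocab_size = 32

    h = torch.randn(batch_size, 16)                     # final hidden state
    heads = [torch.randn(16, vocab_size // world_size)  # one head shard per "rank"
             for _ in range(world_size)]

    # Each shard yields a (batch_size, vocab_size // world_size) slice;
    # in the real code dist.all_gather collects these across processes.
    all_logits = [h @ w for w in heads]
    logits = torch.cat(all_logits, dim=-1)
    assert logits.shape == (batch_size, vocab_size)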