Mirror of https://github.com/deepseek-ai/DeepSeek-V3.git
Update model.py

Fixed the indentation of the suggestion.

parent 9bf00671cf
commit c47ecaa800
@@ -779,20 +779,21 @@ class Transformer(nn.Module):
            torch.Tensor: Logits tensor of shape (batch_size, vocab_size).
        """
        with autocast():
            seqlen = tokens.size(1)
            h = self.embed(tokens)
            freqs_cis = self.freqs_cis[start_pos:start_pos+seqlen]
            mask = None
            if seqlen > 1:
                mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device).triu_(1)
            for layer in self.layers:
                h = layer(h, start_pos, freqs_cis, mask)
            h = self.norm(h)[:, -1]
            logits = self.head(h)
            if world_size > 1:
                all_logits = [torch.empty_like(logits) for _ in range(world_size)]
                dist.all_gather(all_logits, logits)
                logits = torch.cat(all_logits, dim=-1)

            return logits
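For reference, a minimal sketch of how the autocast context manager used in this hunk is typically imported and applied in PyTorch. The import path and the toy module/input below are assumptions for illustration only; they are not part of model.py.

# Assumption: autocast comes from torch.cuda.amp (newer code may use
# torch.autocast("cuda") instead); the module and input are stand-ins.
import torch
from torch.cuda.amp import autocast

model = torch.nn.Linear(16, 4).cuda()        # hypothetical stand-in for the Transformer
tokens = torch.randn(2, 16, device="cuda")   # hypothetical stand-in for the input batch

with autocast():
    # Ops inside this block run in reduced precision (float16) where safe,
    # cutting memory use and speeding up inference on supported GPUs.
    logits = model(tokens)

print(logits.dtype)  # torch.float16, because the Linear ran under autocast

This mirrors the structure of the updated hunk, where the body of the forward pass is indented under the same context manager.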