mirror of
https://github.com/deepseek-ai/DeepSeek-V3.git
synced 2025-04-19 18:18:57 -04:00
Merge abaadd9b3e
into 4cc6253d5c
This commit is contained in:
commit
3b10d4539f
@ -115,16 +115,23 @@ class ParallelEmbedding(nn.Module):
|
||||
Raises:
|
||||
ValueError: If `world_size` is not defined.
|
||||
"""
|
||||
if world_size > 1:
|
||||
mask = (x < self.vocab_start_idx) | (x >= self.vocab_end_idx)
|
||||
x = x - self.vocab_start_idx
|
||||
x[mask] = 0
|
||||
y = F.embedding(x, self.weight)
|
||||
if world_size > 1:
|
||||
y[mask] = 0
|
||||
dist.all_reduce(y)
|
||||
return y
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """
    Embed token indices using this rank's slice of the embedding table.

    With a single rank the indices are looked up directly in `self.weight`.
    With several ranks, indices outside
    `[self.vocab_start_idx, self.vocab_end_idx)` are redirected to row 0 for
    the lookup, their output rows are zeroed, and the per-rank results are
    combined with `dist.all_reduce`.

    Args:
        x (torch.Tensor): Tensor of token indices.

    Returns:
        torch.Tensor: Embeddings for `x`, with one trailing embedding axis.

    Raises:
        ValueError: If `self.world_size` is smaller than 1.
    """
    # NOTE(review): this reads `world_size` from the instance; confirm that
    # the constructor actually stores `self.world_size`.
    num_ranks = self.world_size
    if num_ranks < 1:
        raise ValueError("world_size must be >= 1")
    if num_ranks == 1:
        return F.embedding(x, self.weight)

    # Indices owned by another rank: not inside this rank's [start, end) range.
    first, stop = self.vocab_start_idx, self.vocab_end_idx
    foreign = (x < first) | (x >= stop)

    # Shift to local row numbers; foreign ids are pointed at row 0 so the
    # lookup stays in bounds (their rows are discarded just below).
    local_ids = (x - first).masked_fill(foreign, 0)

    embedded = F.embedding(local_ids, self.weight)
    # Zero the rows looked up for foreign ids so they add nothing to the
    # cross-rank reduction.
    embedded.masked_fill_(foreign.unsqueeze(-1), 0)

    dist.all_reduce(embedded)
    return embedded
|
||||
|
||||
def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
"""
|
||||
|
Loading…
Reference in New Issue
Block a user