Fix: safer and cleaner forward() in distributed embedding layer

saro1993 2025-04-05 21:33:46 +02:00
parent a878eada08
commit abaadd9b3e


@@ -115,17 +115,24 @@ class ParallelEmbedding(nn.Module):
         Raises:
             ValueError: If `world_size` is not defined.
         """
-        if world_size > 1:
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if self.world_size < 1:
+            raise ValueError("world_size must be >= 1")
+        if self.world_size == 1:
+            return F.embedding(x, self.weight)
+        # For world_size > 1
         mask = (x < self.vocab_start_idx) | (x >= self.vocab_end_idx)
         x = x - self.vocab_start_idx
         x[mask] = 0
         y = F.embedding(x, self.weight)
-        if world_size > 1:
         y[mask] = 0
         dist.all_reduce(y)
         return y
 
 
 def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
     """
     Applies a linear transformation to the incoming data: y = xA^T + b.
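
For context on what the patched forward() is doing: this is the usual vocabulary-parallel embedding pattern, where each rank stores one slice of the embedding table, tokens owned by other ranks are masked to a local index of 0, the corresponding output rows are zeroed, and an all_reduce sums the partial results across ranks. The sketch below illustrates that pattern under stated assumptions; the constructor, the rank/world_size arguments, and the weight initialization are illustrative and not part of this commit.

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F


class ParallelEmbedding(nn.Module):
    """Vocabulary-parallel embedding: each rank stores one slice of the table.

    Assumes torch.distributed has been initialized when world_size > 1.
    Constructor arguments are illustrative, not taken from the patched file.
    """

    def __init__(self, vocab_size: int, dim: int, rank: int, world_size: int):
        super().__init__()
        if world_size < 1:
            raise ValueError("world_size must be >= 1")
        self.world_size = world_size
        part_size = vocab_size // world_size  # this rank's slice (assumes an even split)
        self.vocab_start_idx = rank * part_size
        self.vocab_end_idx = self.vocab_start_idx + part_size
        self.weight = nn.Parameter(torch.randn(part_size, dim) * 0.02)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.world_size == 1:
            # Single process: plain lookup, no masking or communication needed.
            return F.embedding(x, self.weight)
        # Tokens owned by other ranks are flagged and temporarily mapped to local index 0.
        mask = (x < self.vocab_start_idx) | (x >= self.vocab_end_idx)
        x_local = (x - self.vocab_start_idx).masked_fill(mask, 0)
        y = F.embedding(x_local, self.weight)
        # Zero the rows that belong to other ranks, then sum the partial results
        # so every rank ends up with the complete embedding for every token.
        y = y.masked_fill(mask.unsqueeze(-1), 0.0)
        dist.all_reduce(y)
        return y

The early return for world_size == 1 is what lets the same module run unchanged on a single process, and the explicit check for world_size < 1 turns a misconfiguration into an immediate, readable error instead of a downstream indexing failure.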