From abaadd9b3e746fc91fcd19dbb79ff4ed615580e3 Mon Sep 17 00:00:00 2001
From: saro1993 <116103507+saro1993@users.noreply.github.com>
Date: Sat, 5 Apr 2025 21:33:46 +0200
Subject: [PATCH] Fix: safer and cleaner forward() in distributed embedding
 layer

---
 inference/model.py | 23 +++++++++++++++---------
 1 file changed, 14 insertions(+), 9 deletions(-)

diff --git a/inference/model.py b/inference/model.py
index 8f1ab81..47cd0eb 100644
--- a/inference/model.py
+++ b/inference/model.py
@@ -115,16 +115,21 @@ class ParallelEmbedding(nn.Module):
         Raises:
             ValueError: If `world_size` is not defined.
         """
-        if world_size > 1:
-            mask = (x < self.vocab_start_idx) | (x >= self.vocab_end_idx)
-            x = x - self.vocab_start_idx
-            x[mask] = 0
-        y = F.embedding(x, self.weight)
-        if world_size > 1:
-            y[mask] = 0
-            dist.all_reduce(y)
-        return y
+        if world_size < 1:
+            raise ValueError("world_size must be >= 1")
+        if world_size == 1:
+            # Single process: plain embedding lookup, no masking or reduce.
+            return F.embedding(x, self.weight)
+        # world_size > 1: zero out token ids outside this rank's vocab shard,
+        # embed the local shard, then all-reduce so every rank gets the full result.
+        mask = (x < self.vocab_start_idx) | (x >= self.vocab_end_idx)
+        x = x - self.vocab_start_idx
+        x[mask] = 0
+        y = F.embedding(x, self.weight)
+        y[mask] = 0
+        dist.all_reduce(y)
+        return y
 
 
 def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
     """