diff --git a/inference/model.py b/inference/model.py index c143e97..e92156d 100644 --- a/inference/model.py +++ b/inference/model.py @@ -115,16 +115,23 @@ class ParallelEmbedding(nn.Module): Raises: ValueError: If `world_size` is not defined. """ - if world_size > 1: - mask = (x < self.vocab_start_idx) | (x >= self.vocab_end_idx) - x = x - self.vocab_start_idx - x[mask] = 0 - y = F.embedding(x, self.weight) - if world_size > 1: - y[mask] = 0 - dist.all_reduce(y) - return y +def forward(self, x: torch.Tensor) -> torch.Tensor: + if self.world_size < 1: + raise ValueError("world_size must be >= 1") + if self.world_size == 1: + return F.embedding(x, self.weight) + + # For world_size > 1 + mask = (x < self.vocab_start_idx) | (x >= self.vocab_end_idx) + x = x - self.vocab_start_idx + x[mask] = 0 + + y = F.embedding(x, self.weight) + y[mask] = 0 + + dist.all_reduce(y) + return y def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor: """