diff --git a/inference/model.py b/inference/model.py
index 8f1ab81..47cd0eb 100644
--- a/inference/model.py
+++ b/inference/model.py
@@ -115,16 +115,23 @@ class ParallelEmbedding(nn.Module):
         Raises:
             ValueError: If `world_size` is not defined.
         """
-        if world_size > 1:
-            mask = (x < self.vocab_start_idx) | (x >= self.vocab_end_idx)
-            x = x - self.vocab_start_idx
-            x[mask] = 0
-        y = F.embedding(x, self.weight)
-        if world_size > 1:
-            y[mask] = 0
-            dist.all_reduce(y)
-        return y
+def forward(self, x: torch.Tensor) -> torch.Tensor:
+    if self.world_size < 1:
+        raise ValueError("world_size must be >= 1")
 
+    if self.world_size == 1:
+        return F.embedding(x, self.weight)
+
+    # For world_size > 1
+    mask = (x < self.vocab_start_idx) | (x >= self.vocab_end_idx)
+    x = x - self.vocab_start_idx
+    x[mask] = 0
+
+    y = F.embedding(x, self.weight)
+    y[mask] = 0
+
+    dist.all_reduce(y)
+    return y
 
 def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
     """