diff --git a/inference/model.py b/inference/model.py index 9ea60c9..51bd544 100644 --- a/inference/model.py +++ b/inference/model.py @@ -115,6 +115,7 @@ class ParallelEmbedding(nn.Module): Raises: ValueError: If `world_size` is not defined. """ + mask = torch.empty() if world_size > 1: mask = (x < self.vocab_start_idx) | (x >= self.vocab_end_idx) x = x - self.vocab_start_idx