From 7117f260e9c4098ccd128e869d53d278ac3216f2 Mon Sep 17 00:00:00 2001 From: wanglei <1105865632@qq.com> Date: Tue, 14 Jan 2025 15:29:06 +0800 Subject: [PATCH] =?UTF-8?q?mask=E6=A0=87=E8=AE=B0=E5=A3=B0=E6=98=8E?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- inference/model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/inference/model.py b/inference/model.py index 9ea60c9..51bd544 100644 --- a/inference/model.py +++ b/inference/model.py @@ -115,6 +115,7 @@ class ParallelEmbedding(nn.Module): Raises: ValueError: If `world_size` is not defined. """ + mask = torch.empty() if world_size > 1: mask = (x < self.vocab_start_idx) | (x >= self.vocab_end_idx) x = x - self.vocab_start_idx