mask标记声明

This commit is contained in:
wanglei 2025-01-14 15:29:06 +08:00
parent ee4c4ea32b
commit 7117f260e9

View File

@ -115,6 +115,7 @@ class ParallelEmbedding(nn.Module):
Raises: Raises:
ValueError: If `world_size` is not defined. ValueError: If `world_size` is not defined.
""" """
mask = torch.empty()
if world_size > 1: if world_size > 1:
mask = (x < self.vocab_start_idx) | (x >= self.vocab_end_idx) mask = (x < self.vocab_start_idx) | (x >= self.vocab_end_idx)
x = x - self.vocab_start_idx x = x - self.vocab_start_idx