mask标记声明

This commit is contained in:
wanglei 2025-01-14 15:29:06 +08:00
parent ee4c4ea32b
commit 7117f260e9

View File

@ -115,6 +115,7 @@ class ParallelEmbedding(nn.Module):
Raises:
ValueError: If `world_size` is not defined.
"""
mask = torch.empty()
if world_size > 1:
mask = (x < self.vocab_start_idx) | (x >= self.vocab_end_idx)
x = x - self.vocab_start_idx