Masking: avoid modifying tensor in-place to improve performance

Christopher Harrison 2025-01-28 01:35:50 +00:00
parent b5d872ead0
commit f3a55f92c2


@@ -782,7 +782,7 @@ class Transformer(nn.Module):
         freqs_cis = self.freqs_cis[start_pos:start_pos+seqlen]
         mask = None
         if seqlen > 1:
-            mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device).triu_(1)
+            mask = torch.triu(torch.full((seqlen, seqlen), float("-inf"), device=tokens.device), diagonal=1)
         for layer in self.layers:
             h = layer(h, start_pos, freqs_cis, mask)
         h = self.norm(h)[:, -1]
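
For reference, a minimal standalone sketch (not part of this commit; seqlen and the CPU device are placeholders for illustration) showing that the new out-of-place torch.triu construction yields the same causal mask as the previous in-place .triu_ call:

import torch

seqlen = 4  # placeholder sequence length

# Previous approach: allocate a full -inf matrix, then keep only the part
# strictly above the main diagonal in place (everything else becomes 0).
mask_inplace = torch.full((seqlen, seqlen), float("-inf")).triu_(1)

# New approach: the same mask, built out-of-place with torch.triu.
mask_outofplace = torch.triu(torch.full((seqlen, seqlen), float("-inf")), diagonal=1)

# Both yield -inf strictly above the diagonal and 0 on and below it.
assert torch.equal(mask_inplace, mask_outofplace)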