diff --git a/inference/model.py b/inference/model.py index 9ea60c9..cc74d26 100644 --- a/inference/model.py +++ b/inference/model.py @@ -782,7 +782,7 @@ class Transformer(nn.Module): freqs_cis = self.freqs_cis[start_pos:start_pos+seqlen] mask = None if seqlen > 1: - mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device).triu_(1) + mask = torch.triu(torch.full((seqlen, seqlen), float("-inf"), device=tokens.device), diagonal=1) for layer in self.layers: h = layer(h, start_pos, freqs_cis, mask) h = self.norm(h)[:, -1]