From f3a55f92c2e0f9b5c048980bbc8493e10869dbb2 Mon Sep 17 00:00:00 2001
From: Christopher Harrison
Date: Tue, 28 Jan 2025 01:35:50 +0000
Subject: [PATCH] Masking: avoid modifying tensor in-place to improve performance

---
 inference/model.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/inference/model.py b/inference/model.py
index 9ea60c9..cc74d26 100644
--- a/inference/model.py
+++ b/inference/model.py
@@ -782,7 +782,7 @@ class Transformer(nn.Module):
         freqs_cis = self.freqs_cis[start_pos:start_pos+seqlen]
         mask = None
         if seqlen > 1:
-            mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device).triu_(1)
+            mask = torch.triu(torch.full((seqlen, seqlen), float("-inf"), device=tokens.device), diagonal=1)
         for layer in self.layers:
             h = layer(h, start_pos, freqs_cis, mask)
         h = self.norm(h)[:, -1]
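
A minimal standalone sketch of the two mask constructions the patch swaps, assuming PyTorch on CPU and an illustrative seqlen of 8 (the real code derives seqlen and the device from the tokens tensor); both forms build the same causal mask with -inf strictly above the diagonal:

    import torch

    seqlen = 8  # illustrative; the real value comes from the tokens tensor

    # Old form: allocate the full -inf matrix, then zero the diagonal and
    # below with the in-place Tensor.triu_.
    mask_old = torch.full((seqlen, seqlen), float("-inf")).triu_(1)

    # New form: the same mask built with the out-of-place torch.triu.
    mask_new = torch.triu(torch.full((seqlen, seqlen), float("-inf")), diagonal=1)

    # Both constructions yield identical causal masks.
    assert torch.equal(mask_old, mask_new)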