diff --git a/inference/model.py b/inference/model.py index 9ea60c9..86cbc35 100644 --- a/inference/model.py +++ b/inference/model.py @@ -777,6 +777,7 @@ class Transformer(nn.Module): Returns: torch.Tensor: Logits tensor of shape (batch_size, vocab_size). """ + with autocast(): seqlen = tokens.size(1) h = self.embed(tokens) freqs_cis = self.freqs_cis[start_pos:start_pos+seqlen]