diff --git a/inference/model.py b/inference/model.py
index 9ea60c9..472d740 100644
--- a/inference/model.py
+++ b/inference/model.py
@@ -4,6 +4,7 @@ from typing import Tuple, Optional, Literal
 
 import torch
 from torch import nn
+from torch.cuda.amp import autocast
 import torch.nn.functional as F
 import torch.distributed as dist
 
@@ -777,6 +778,7 @@ class Transformer(nn.Module):
         Returns:
             torch.Tensor: Logits tensor of shape (batch_size, vocab_size).
         """
+        with autocast():
         seqlen = tokens.size(1)
         h = self.embed(tokens)
         freqs_cis = self.freqs_cis[start_pos:start_pos+seqlen]