Update model.py

Enable mixed-precision training to reduce memory usage and potentially speed up training.
This commit is contained in:
Pedro Dessanti 2025-01-27 08:58:59 -03:00 committed by GitHub
parent b5d872ead0
commit 2bf4595d13
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -777,6 +777,7 @@ class Transformer(nn.Module):
Returns:
torch.Tensor: Logits tensor of shape (batch_size, vocab_size).
"""
with autocast():
seqlen = tokens.size(1)
h = self.embed(tokens)
freqs_cis = self.freqs_cis[start_pos:start_pos+seqlen]