Update model.py

Enable mixed-precision training with `torch.cuda.amp.autocast` to reduce memory usage and potentially speed up training.
This commit is contained in:
Pedro Dessanti 2025-01-27 09:04:34 -03:00 committed by GitHub
parent b5d872ead0
commit 9bf00671cf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -4,6 +4,7 @@ from typing import Tuple, Optional, Literal
import torch
from torch import nn
from torch.cuda.amp import autocast
import torch.nn.functional as F
import torch.distributed as dist
@@ -777,6 +778,7 @@ class Transformer(nn.Module):
Returns:
torch.Tensor: Logits tensor of shape (batch_size, vocab_size).
"""
with autocast():
seqlen = tokens.size(1)
h = self.embed(tokens)
freqs_cis = self.freqs_cis[start_pos:start_pos+seqlen]