From 2bf4595d13207d1fb3506ac7c487f655a44545cf Mon Sep 17 00:00:00 2001 From: Pedro Dessanti <92925579+Dessantii@users.noreply.github.com> Date: Mon, 27 Jan 2025 08:58:59 -0300 Subject: [PATCH] Update model.py Enabling mixed precision training to reduce memory usage and potentially speed up training. --- inference/model.py | 1 + 1 file changed, 1 insertion(+) diff --git a/inference/model.py b/inference/model.py index 9ea60c9..86cbc35 100644 --- a/inference/model.py +++ b/inference/model.py @@ -777,6 +777,7 @@ class Transformer(nn.Module): Returns: torch.Tensor: Logits tensor of shape (batch_size, vocab_size). """ + with autocast(): seqlen = tokens.size(1) h = self.embed(tokens) freqs_cis = self.freqs_cis[start_pos:start_pos+seqlen]