From 9bf00671cfe45264aed0b3ad23acef39e84afdcc Mon Sep 17 00:00:00 2001 From: Pedro Dessanti <92925579+Dessantii@users.noreply.github.com> Date: Mon, 27 Jan 2025 09:04:34 -0300 Subject: [PATCH] Update model.py Enabling mixed precision training to reduce memory usage and potentially speed up training. --- inference/model.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/inference/model.py b/inference/model.py index 9ea60c9..472d740 100644 --- a/inference/model.py +++ b/inference/model.py @@ -4,6 +4,7 @@ from typing import Tuple, Optional, Literal import torch from torch import nn +from torch.cuda.amp import autocast import torch.nn.functional as F import torch.distributed as dist @@ -777,6 +778,7 @@ class Transformer(nn.Module): Returns: torch.Tensor: Logits tensor of shape (batch_size, vocab_size). """ + with autocast(): seqlen = tokens.size(1) h = self.embed(tokens) freqs_cis = self.freqs_cis[start_pos:start_pos+seqlen]