mirror of
https://github.com/deepseek-ai/DeepSeek-V3.git
synced 2025-05-03 00:49:01 -04:00
Update model.py
Enabling mixed precision training to reduce memory usage and potentially speed up training.
This commit is contained in:
parent
b5d872ead0
commit
9bf00671cf
@@ -4,6 +4,7 @@ from typing import Tuple, Optional, Literal
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
from torch.cuda.amp import autocast
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
|
||||||
@@ -777,6 +778,7 @@ class Transformer(nn.Module):
|
|||||||
Returns:
|
Returns:
|
||||||
torch.Tensor: Logits tensor of shape (batch_size, vocab_size).
|
torch.Tensor: Logits tensor of shape (batch_size, vocab_size).
|
||||||
"""
|
"""
|
||||||
|
with autocast():
|
||||||
seqlen = tokens.size(1)
|
seqlen = tokens.size(1)
|
||||||
h = self.embed(tokens)
|
h = self.embed(tokens)
|
||||||
freqs_cis = self.freqs_cis[start_pos:start_pos+seqlen]
|
freqs_cis = self.freqs_cis[start_pos:start_pos+seqlen]
|
||||||
|
Loading…
Reference in New Issue
Block a user