# DeepSeek-V3/inference/model.py
# Optimized MoE Transformer.


import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint


class MLP(nn.Module):
    """Gated (SwiGLU-style) feed-forward network used for dense layers."""

    def __init__(self, dim, inter_dim):
        super().__init__()
        self.w1 = nn.Linear(dim, inter_dim, bias=False)
        self.w2 = nn.Linear(inter_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, inter_dim, bias=False)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


class Gate(nn.Module):
    """Top-k softmax router over the routed experts."""

    def __init__(self, args):
        super().__init__()
        self.dim = args.dim
        self.topk = args.n_activated_experts
        self.n_experts = args.n_routed_experts
        self.weight = nn.Parameter(torch.empty(self.n_experts, self.dim))
        nn.init.xavier_uniform_(self.weight)

    def forward(self, x):
        # Softmax over expert logits, then keep the top-k experts per token.
        # topk returns both the selected probabilities and the expert indices,
        # so no separate gather is needed.
        scores = F.softmax(F.linear(x, self.weight), dim=-1)
        weights, indices = torch.topk(scores, self.topk, dim=-1)
        return weights, indices
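

# Hedged shape check (not part of the original file): a tiny standalone demo
# of the router's outputs. The `_Args` fields and the function name are
# illustrative assumptions matching the attributes Gate reads above.
def _gate_shape_demo():
    class _Args:
        dim = 8
        n_routed_experts = 4
        n_activated_experts = 2

    gate = Gate(_Args())
    tokens = torch.randn(5, _Args.dim)      # 5 flattened tokens
    weights, indices = gate(tokens)
    assert weights.shape == (5, _Args.n_activated_experts)
    assert indices.shape == (5, _Args.n_activated_experts)
    return weights, indices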


class Expert(nn.Module):
    """Routed expert; same gated MLP structure, kept as a separate class."""

    def __init__(self, dim, inter_dim):
        super().__init__()
        self.w1 = nn.Linear(dim, inter_dim, bias=False)
        self.w2 = nn.Linear(inter_dim, dim, bias=False)
        self.w3 = nn.Linear(dim, inter_dim, bias=False)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


class MoE(nn.Module):
    """Mixture-of-experts layer: each token is processed by its top-k experts."""

    def __init__(self, args):
        super().__init__()
        self.dim = args.dim
        self.n_experts = args.n_routed_experts
        self.gate = Gate(args)
        self.experts = nn.ModuleList(
            [Expert(args.dim, args.moe_inter_dim) for _ in range(self.n_experts)]
        )

    def forward(self, x):
        # Flatten (batch, seq, dim) to (n_tokens, dim) so routing is per token.
        shape = x.size()
        x = x.view(-1, self.dim)
        weights, indices = self.gate(x)          # both (n_tokens, topk)
        y = torch.zeros_like(x)
        for i in range(self.n_experts):
            # Tokens (and the top-k slot) that routed to expert i.
            token_idx, slot_idx = torch.where(indices == i)
            if token_idx.numel() == 0:
                continue
            y[token_idx] += self.experts[i](x[token_idx]) * weights[token_idx, slot_idx].unsqueeze(-1)
        return y.view(shape)
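

# Hedged sanity-check helper (not part of the original file): a slow, dense
# reference that evaluates each token's selected experts one at a time and
# mixes the outputs with the gate weights. Intended only for verifying
# MoE.forward on small shapes; the name and signature are illustrative.
def moe_dense_reference(moe: "MoE", x: torch.Tensor) -> torch.Tensor:
    bsz, seqlen, dim = x.shape
    flat = x.reshape(-1, dim)
    weights, indices = moe.gate(flat)            # (n_tokens, topk) each
    out = torch.zeros_like(flat)
    for tok in range(flat.size(0)):
        for slot in range(indices.size(-1)):
            expert = moe.experts[indices[tok, slot].item()]
            out[tok] += weights[tok, slot] * expert(flat[tok])
    return out.reshape(bsz, seqlen, dim)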


class TransformerBlock(nn.Module):
    """Pre-norm block: self-attention followed by an MoE or dense MLP."""

    def __init__(self, args):
        super().__init__()
        self.attn = nn.MultiheadAttention(args.dim, args.num_heads, batch_first=True)
        self.ffn = MoE(args) if args.use_moe else MLP(args.dim, args.inter_dim)
        self.norm1 = nn.LayerNorm(args.dim)
        self.norm2 = nn.LayerNorm(args.dim)

    def forward(self, x):
        # Normalize once and reuse for query, key and value.
        # Note: no attention mask is applied, so attention is bidirectional.
        h = self.norm1(x)
        x = x + self.attn(h, h, h, need_weights=False)[0]
        x = x + self.ffn(self.norm2(x))
        return x
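

# Hedged sketch (not in the original file): the block above applies no
# attention mask, so its self-attention is bidirectional. If causal masking
# is wanted, a float mask like the one below can be passed to
# nn.MultiheadAttention via `attn_mask` in TransformerBlock.forward.
# The helper name is an illustrative assumption.
def causal_attn_mask(seqlen: int, device: torch.device) -> torch.Tensor:
    # -inf above the diagonal: position i may only attend to positions <= i.
    return torch.triu(
        torch.full((seqlen, seqlen), float("-inf"), device=device), diagonal=1
    )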


class Transformer(nn.Module):
    """Token embedding (no positional encoding), a stack of TransformerBlocks, and an LM head."""

    def __init__(self, args):
        super().__init__()
        self.embed = nn.Embedding(args.vocab_size, args.dim)
        self.layers = nn.ModuleList([TransformerBlock(args) for _ in range(args.n_layers)])
        self.norm = nn.LayerNorm(args.dim)
        self.head = nn.Linear(args.dim, args.vocab_size, bias=False)

    def forward(self, x):
        x = self.embed(x)
        for layer in self.layers:
            if self.training:
                # Gradient checkpointing trades compute for memory during training.
                x = checkpoint(layer, x, use_reentrant=False)
            else:
                x = layer(x)
        return self.head(self.norm(x))
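

# Hedged usage sketch (not in the original file): greedy decoding with the
# Transformer above. Since the blocks use no causal mask and no KV cache,
# this only exercises the forward pass; it is not an efficient or strictly
# autoregressive sampler. The function name and default are illustrative.
@torch.no_grad()
def greedy_generate(model: "Transformer", tokens: torch.Tensor, max_new_tokens: int = 16) -> torch.Tensor:
    model.eval()
    for _ in range(max_new_tokens):
        logits = model(tokens)                   # (batch, seq, vocab_size)
        next_token = logits[:, -1].argmax(dim=-1, keepdim=True)
        tokens = torch.cat([tokens, next_token], dim=1)
    return tokens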


if __name__ == "__main__":
    class ModelArgs:
        vocab_size = 32000
        dim = 1024
        inter_dim = 4096
        n_layers = 12
        num_heads = 8
        use_moe = True
        n_routed_experts = 4
        n_activated_experts = 2
        moe_inter_dim = 4096

    args = ModelArgs()
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = Transformer(args).to(device)
    model.eval()
    x = torch.randint(0, args.vocab_size, (2, 128), device=device)
    with torch.no_grad():
        print(model(x).size())               # expected: torch.Size([2, 128, 32000])
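
    # Hedged extra check (not in the original file): compare the sparse MoE
    # forward against the dense reference defined earlier on a small input.
    if args.use_moe:
        moe = model.layers[0].ffn
        sample = torch.randn(2, 4, args.dim, device=device)
        with torch.no_grad():
            diff = (moe(sample) - moe_dense_reference(moe, sample)).abs().max()
        print("max |sparse - dense| =", diff.item())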