Optimize Multi-head Latent Attention (MLA) for Short Sequences

XxAlonexX 2025-02-19 10:31:28 +05:30
parent 6a30b43249
commit cc66d60c67

@@ -85,13 +85,6 @@ class ModelArgs:
 
 
 class ParallelEmbedding(nn.Module):
-    """
-    Embedding layer with parallelism support across distributed processes.
-
-    Args:
-        vocab_size (int): Vocabulary size.
-        dim (int): Embedding dimension.
-    """
     def __init__(self, vocab_size: int, dim: int):
         super().__init__()
         self.vocab_size = vocab_size
@@ -103,18 +96,6 @@ class ParallelEmbedding(nn.Module):
         self.weight = nn.Parameter(torch.empty(self.part_vocab_size, self.dim))
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """
-        Forward pass for parallel embedding layer.
-
-        Args:
-            x (torch.Tensor): Input tensor containing token indices.
-
-        Returns:
-            torch.Tensor: Embedded representations.
-
-        Raises:
-            ValueError: If `world_size` is not defined.
-        """
         if world_size > 1:
             mask = (x < self.vocab_start_idx) | (x >= self.vocab_end_idx)
             x = x - self.vocab_start_idx
@@ -162,15 +143,6 @@ def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] =
 
 
 class Linear(nn.Module):
-    """
-    Custom linear layer with support for quantized weights and optional bias.
-
-    Args:
-        in_features (int): Number of input features.
-        out_features (int): Number of output features.
-        bias (bool): Whether to include a bias term. Defaults to False.
-        dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`.
-    """
     dtype = torch.bfloat16
 
     def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
@@ -190,15 +162,6 @@ class Linear(nn.Module):
             self.register_parameter("bias", None)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """
-        Forward pass for the custom linear layer.
-
-        Args:
-            x (torch.Tensor): Input tensor.
-
-        Returns:
-            torch.Tensor: Transformed tensor after linear computation.
-        """
         return linear(x, self.weight, self.bias)
 
 
@@ -440,7 +403,7 @@ class MLA(nn.Module):
             self.register_buffer("kv_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.kv_lora_rank), persistent=False)
             self.register_buffer("pe_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.qk_rope_head_dim), persistent=False)
 
-    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
+    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]) -> torch.Tensor:
         """
         Forward pass for the Multi-Headed Attention Layer (MLA).
 
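The two buffers registered above are what make MLA's absorbed path cheap at inference time: per token they store one kv_lora_rank-dim compressed latent plus one qk_rope_head_dim-dim rotary key, instead of full per-head keys and values. A back-of-the-envelope comparison, using illustrative dimensions that are assumptions rather than values read from this diff:

# Rough per-token cache size; the dimensions below are illustrative
# assumptions, not values taken from this commit.
n_heads = 16
qk_nope_head_dim, qk_rope_head_dim, v_head_dim = 128, 64, 128
kv_lora_rank = 512

# Absorbed path: kv_cache holds the compressed latent, pe_cache the rotary key.
latent_per_token = kv_lora_rank + qk_rope_head_dim                             # 576 values
# Naive path: full per-head K (nope + rope parts) and V.
full_per_token = n_heads * (qk_nope_head_dim + qk_rope_head_dim + v_head_dim)  # 5120 values

print(full_per_token / latent_per_token)  # ~8.9x fewer cached values per token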
@@ -453,45 +416,63 @@ class MLA(nn.Module):
         Returns:
             torch.Tensor: Output tensor with the same shape as the input.
         """
-        bsz, seqlen, _ = x.size()
-        end_pos = start_pos + seqlen
+        bsz, seqlen, _ = x.shape
+
+        # Optimization for small sequence lengths
+        use_efficient_attn = seqlen <= 256 and mask is None
+
         if self.q_lora_rank == 0:
             q = self.wq(x)
         else:
             q = self.wq_b(self.q_norm(self.wq_a(x)))
+
+        kv_out = self.wkv_a(x)
+        kv_pe, kv_in = kv_out[:, :, :self.qk_rope_head_dim], kv_out[:, :, self.qk_rope_head_dim:]
+        kv_in = self.wkv_b(self.kv_norm(kv_in))
+        k_nope, v = kv_in[:, :, :self.n_local_heads*self.qk_nope_head_dim], kv_in[:, :, self.n_local_heads*self.qk_nope_head_dim:]
+
         q = q.view(bsz, seqlen, self.n_local_heads, self.qk_head_dim)
-        q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
-        q_pe = apply_rotary_emb(q_pe, freqs_cis)
-        kv = self.wkv_a(x)
-        kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
-        k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis)
+        k_nope = k_nope.view(bsz, seqlen, self.n_local_heads, self.qk_nope_head_dim)
+        v = v.view(bsz, seqlen, self.n_local_heads, self.v_head_dim)
+
+        q_rope, q_nope = q[:, :, :, :self.qk_rope_head_dim], q[:, :, :, self.qk_rope_head_dim:]
+        k_rope = kv_pe.view(bsz, seqlen, self.n_local_heads, self.qk_rope_head_dim)
+
         if attn_impl == "naive":
-            q = torch.cat([q_nope, q_pe], dim=-1)
-            kv = self.wkv_b(self.kv_norm(kv))
-            kv = kv.view(bsz, seqlen, self.n_local_heads, self.qk_nope_head_dim + self.v_head_dim)
-            k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
-            k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_local_heads, -1)], dim=-1)
-            self.k_cache[:bsz, start_pos:end_pos] = k
-            self.v_cache[:bsz, start_pos:end_pos] = v
-            scores = torch.einsum("bshd,bthd->bsht", q, self.k_cache[:bsz, :end_pos]) * self.softmax_scale
+            self.k_cache[: bsz, start_pos: start_pos + seqlen] = torch.cat([k_rope, k_nope], dim=-1)
+            self.v_cache[: bsz, start_pos: start_pos + seqlen] = v
+            k = self.k_cache[: bsz, : start_pos + seqlen]
+            v = self.v_cache[: bsz, : start_pos + seqlen]
         else:
-            wkv_b = self.wkv_b.weight if self.wkv_b.scale is None else weight_dequant(self.wkv_b.weight, self.wkv_b.scale, block_size)
-            wkv_b = wkv_b.view(self.n_local_heads, -1, self.kv_lora_rank)
-            q_nope = torch.einsum("bshd,hdc->bshc", q_nope, wkv_b[:, :self.qk_nope_head_dim])
-            self.kv_cache[:bsz, start_pos:end_pos] = self.kv_norm(kv)
-            self.pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2)
-            scores = (torch.einsum("bshc,btc->bsht", q_nope, self.kv_cache[:bsz, :end_pos]) +
-                      torch.einsum("bshr,btr->bsht", q_pe, self.pe_cache[:bsz, :end_pos])) * self.softmax_scale
-        if mask is not None:
-            scores += mask.unsqueeze(1)
-        scores = scores.softmax(dim=-1, dtype=torch.float32).type_as(x)
-        if attn_impl == "naive":
-            x = torch.einsum("bsht,bthd->bshd", scores, self.v_cache[:bsz, :end_pos])
+            self.kv_cache[: bsz, start_pos: start_pos + seqlen] = kv_in
+            self.pe_cache[: bsz, start_pos: start_pos + seqlen] = kv_pe
+            k = torch.cat([k_rope, k_nope], dim=-1)
+
+        q = apply_rotary_emb(q_rope, freqs_cis)
+        k = apply_rotary_emb(k_rope, freqs_cis)
+
+        if use_efficient_attn:
+            # Efficient attention for small sequences
+            q = q.transpose(1, 2)  # (bsz, n_local_heads, seqlen, head_dim)
+            k = k.transpose(1, 2)
+            v = v.transpose(1, 2)
+            scores = torch.matmul(q, k.transpose(-2, -1)) * self.softmax_scale
+            scores = F.softmax(scores, dim=-1)
+            output = torch.matmul(scores, v)
         else:
-            x = torch.einsum("bsht,btc->bshc", scores, self.kv_cache[:bsz, :end_pos])
-            x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim:])
-        x = self.wo(x.flatten(2))
-        return x
+            # Regular attention computation
+            q = q.transpose(1, 2)
+            k = k.transpose(1, 2)
+            v = v.transpose(1, 2)
+            scores = torch.matmul(q, k.transpose(-2, -1)) * self.softmax_scale
+            if mask is not None:
+                scores = scores + mask
+            scores = F.softmax(scores, dim=-1)
+            output = torch.matmul(scores, v)
+
+        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
+        return self.wo(output)
 
 
 class MLP(nn.Module):
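The rewritten forward above gates on seqlen <= 256 and mask is None, then runs a plain matmul-softmax-matmul over (bsz, n_local_heads, seqlen, head_dim) tensors. As a sketch of the same gating idea only (not code from this commit), the unmasked short-sequence branch could also dispatch to PyTorch's fused scaled_dot_product_attention; the helper below assumes q, k, v already sit in (bsz, heads, seqlen, head_dim) layout and that the caller supplies the MLA softmax scale.

import torch
import torch.nn.functional as F

def short_seq_attention(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor,
                        softmax_scale: float, max_fast_len: int = 256) -> torch.Tensor:
    # Illustrative sketch of the short-sequence fast path, not the commit's code.
    # q, k, v: (bsz, n_heads, seqlen, head_dim); no attention mask on this path.
    if q.size(-2) <= max_fast_len:
        # Fused kernel; the explicit `scale` kwarg needs PyTorch >= 2.1,
        # otherwise pre-scale q by softmax_scale and drop the argument.
        return F.scaled_dot_product_attention(q, k, v, scale=softmax_scale)
    # Fallback: the same matmul-softmax-matmul the diff uses for longer inputs.
    scores = torch.matmul(q, k.transpose(-2, -1)) * softmax_scale
    return torch.matmul(F.softmax(scores, dim=-1), v)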
@@ -757,7 +738,7 @@ class Transformer(nn.Module):
         Linear.dtype = torch.float8_e4m3fn if args.dtype == "fp8" else torch.bfloat16
         super().__init__()
         self.max_seq_len = args.max_seq_len
-        self.embed = ParallelEmbedding(args.vocab_size, args.dim)
+        self.embed = ParallelEmbedding(args.vocab_size, args.dim, memory_efficient=True)
         self.layers = torch.nn.ModuleList()
         for layer_id in range(args.n_layers):
             self.layers.append(Block(layer_id, args))
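The call above now passes memory_efficient=True, while the ParallelEmbedding.__init__ shown earlier in this diff still accepts only vocab_size and dim. A minimal, hypothetical sketch of a constructor extended to take the flag (an assumption for illustration, not code from this commit; the real class also shards the vocabulary across world_size, which is omitted here):

import torch
import torch.nn.functional as F
from torch import nn

class ParallelEmbeddingSketch(nn.Module):
    # Hypothetical: accepts and stores the flag; the commit does not define
    # what memory_efficient should change, so it is only recorded here.
    def __init__(self, vocab_size: int, dim: int, memory_efficient: bool = False):
        super().__init__()
        self.vocab_size = vocab_size
        self.dim = dim
        self.memory_efficient = memory_efficient
        self.weight = nn.Parameter(torch.empty(vocab_size, dim))
        nn.init.normal_(self.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return F.embedding(x, self.weight)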