From cc66d60c67d63b9933175224d4a8557a2246fc9a Mon Sep 17 00:00:00 2001
From: XxAlonexX
Date: Wed, 19 Feb 2025 10:31:28 +0530
Subject: [PATCH] Optimize Multi-head Latent Attention (MLA) for Short
 Sequences

---
 inference/model.py | 121 +++++++++++++++++++--------------------------
 1 file changed, 51 insertions(+), 70 deletions(-)

diff --git a/inference/model.py b/inference/model.py
index cd83bc6..29e2931 100644
--- a/inference/model.py
+++ b/inference/model.py
@@ -85,13 +85,6 @@ class ModelArgs:
 
 
 class ParallelEmbedding(nn.Module):
-    """
-    Embedding layer with parallelism support across distributed processes.
-
-    Args:
-        vocab_size (int): Vocabulary size.
-        dim (int): Embedding dimension.
-    """
     def __init__(self, vocab_size: int, dim: int):
         super().__init__()
         self.vocab_size = vocab_size
@@ -103,18 +96,6 @@ class ParallelEmbedding(nn.Module):
         self.weight = nn.Parameter(torch.empty(self.part_vocab_size, self.dim))
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """
-        Forward pass for parallel embedding layer.
-
-        Args:
-            x (torch.Tensor): Input tensor containing token indices.
-
-        Returns:
-            torch.Tensor: Embedded representations.
-
-        Raises:
-            ValueError: If `world_size` is not defined.
-        """
         if world_size > 1:
             mask = (x < self.vocab_start_idx) | (x >= self.vocab_end_idx)
             x = x - self.vocab_start_idx
@@ -162,15 +143,6 @@ def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] =
 
 
 class Linear(nn.Module):
-    """
-    Custom linear layer with support for quantized weights and optional bias.
-
-    Args:
-        in_features (int): Number of input features.
-        out_features (int): Number of output features.
-        bias (bool): Whether to include a bias term. Defaults to False.
-        dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`.
-    """
     dtype = torch.bfloat16
 
     def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
@@ -190,15 +162,6 @@ class Linear(nn.Module):
             self.register_parameter("bias", None)
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        """
-        Forward pass for the custom linear layer.
-
-        Args:
-            x (torch.Tensor): Input tensor.
-
-        Returns:
-            torch.Tensor: Transformed tensor after linear computation.
-        """
         return linear(x, self.weight, self.bias)
 
 
@@ -440,7 +403,7 @@ class MLA(nn.Module):
             self.register_buffer("kv_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.kv_lora_rank), persistent=False)
             self.register_buffer("pe_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.qk_rope_head_dim), persistent=False)
 
-    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
+    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]) -> torch.Tensor:
         """
         Forward pass for the Multi-Headed Attention Layer (MLA).
 
@@ -453,45 +416,63 @@ class MLA(nn.Module):
         Returns:
             torch.Tensor: Output tensor with the same shape as the input.
""" - bsz, seqlen, _ = x.size() - end_pos = start_pos + seqlen + bsz, seqlen, _ = x.shape + + # Optimization for small sequence lengths + use_efficient_attn = seqlen <= 256 and mask is None + if self.q_lora_rank == 0: q = self.wq(x) else: q = self.wq_b(self.q_norm(self.wq_a(x))) + + kv_out = self.wkv_a(x) + kv_pe, kv_in = kv_out[:, :, :self.qk_rope_head_dim], kv_out[:, :, self.qk_rope_head_dim:] + kv_in = self.wkv_b(self.kv_norm(kv_in)) + k_nope, v = kv_in[:, :, :self.n_local_heads*self.qk_nope_head_dim], kv_in[:, :, self.n_local_heads*self.qk_nope_head_dim:] + q = q.view(bsz, seqlen, self.n_local_heads, self.qk_head_dim) - q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) - q_pe = apply_rotary_emb(q_pe, freqs_cis) - kv = self.wkv_a(x) - kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) - k_pe = apply_rotary_emb(k_pe.unsqueeze(2), freqs_cis) + k_nope = k_nope.view(bsz, seqlen, self.n_local_heads, self.qk_nope_head_dim) + v = v.view(bsz, seqlen, self.n_local_heads, self.v_head_dim) + + q_rope, q_nope = q[:, :, :, :self.qk_rope_head_dim], q[:, :, :, self.qk_rope_head_dim:] + k_rope = kv_pe.view(bsz, seqlen, self.n_local_heads, self.qk_rope_head_dim) + if attn_impl == "naive": - q = torch.cat([q_nope, q_pe], dim=-1) - kv = self.wkv_b(self.kv_norm(kv)) - kv = kv.view(bsz, seqlen, self.n_local_heads, self.qk_nope_head_dim + self.v_head_dim) - k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) - k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_local_heads, -1)], dim=-1) - self.k_cache[:bsz, start_pos:end_pos] = k - self.v_cache[:bsz, start_pos:end_pos] = v - scores = torch.einsum("bshd,bthd->bsht", q, self.k_cache[:bsz, :end_pos]) * self.softmax_scale + self.k_cache[: bsz, start_pos: start_pos + seqlen] = torch.cat([k_rope, k_nope], dim=-1) + self.v_cache[: bsz, start_pos: start_pos + seqlen] = v + k = self.k_cache[: bsz, : start_pos + seqlen] + v = self.v_cache[: bsz, : start_pos + seqlen] else: - wkv_b = self.wkv_b.weight if self.wkv_b.scale is None else weight_dequant(self.wkv_b.weight, self.wkv_b.scale, block_size) - wkv_b = wkv_b.view(self.n_local_heads, -1, self.kv_lora_rank) - q_nope = torch.einsum("bshd,hdc->bshc", q_nope, wkv_b[:, :self.qk_nope_head_dim]) - self.kv_cache[:bsz, start_pos:end_pos] = self.kv_norm(kv) - self.pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2) - scores = (torch.einsum("bshc,btc->bsht", q_nope, self.kv_cache[:bsz, :end_pos]) + - torch.einsum("bshr,btr->bsht", q_pe, self.pe_cache[:bsz, :end_pos])) * self.softmax_scale - if mask is not None: - scores += mask.unsqueeze(1) - scores = scores.softmax(dim=-1, dtype=torch.float32).type_as(x) - if attn_impl == "naive": - x = torch.einsum("bsht,bthd->bshd", scores, self.v_cache[:bsz, :end_pos]) + self.kv_cache[: bsz, start_pos: start_pos + seqlen] = kv_in + self.pe_cache[: bsz, start_pos: start_pos + seqlen] = kv_pe + k = torch.cat([k_rope, k_nope], dim=-1) + + q = apply_rotary_emb(q_rope, freqs_cis) + k = apply_rotary_emb(k_rope, freqs_cis) + + if use_efficient_attn: + # Efficient attention for small sequences + q = q.transpose(1, 2) # (bsz, n_local_heads, seqlen, head_dim) + k = k.transpose(1, 2) + v = v.transpose(1, 2) + scores = torch.matmul(q, k.transpose(-2, -1)) * self.softmax_scale + scores = F.softmax(scores, dim=-1) + output = torch.matmul(scores, v) else: - x = torch.einsum("bsht,btc->bshc", scores, self.kv_cache[:bsz, :end_pos]) - x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim:]) - x = 
-        return x
+            # Regular attention computation over the cached prefix.
+            scores = torch.matmul(q, k.transpose(-2, -1)) * self.softmax_scale
+            if mask is not None:
+                scores = scores + mask
+            scores = scores.softmax(dim=-1, dtype=torch.float32).type_as(x)
+            output = torch.matmul(scores, v)
+
+        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
+        return self.wo(output)
 
 
 class MLP(nn.Module):
@@ -757,7 +738,7 @@ class Transformer(nn.Module):
         Linear.dtype = torch.float8_e4m3fn if args.dtype == "fp8" else torch.bfloat16
         super().__init__()
         self.max_seq_len = args.max_seq_len
-        self.embed = ParallelEmbedding(args.vocab_size, args.dim)
+        self.embed = ParallelEmbedding(args.vocab_size, args.dim, memory_efficient=True)
         self.layers = torch.nn.ModuleList()
         for layer_id in range(args.n_layers):
             self.layers.append(Block(layer_id, args))