From d29a967601cc772ede6c475870e3b591f2f89c45 Mon Sep 17 00:00:00 2001
From: huxuedan
Date: Wed, 26 Feb 2025 17:06:54 +0800
Subject: [PATCH] modify the explanation of MLA

---
 inference/model.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/inference/model.py b/inference/model.py
index 8f1ab81..c143e97 100644
--- a/inference/model.py
+++ b/inference/model.py
@@ -392,7 +392,7 @@ def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
 
 class MLA(nn.Module):
     """
-    Multi-Headed Attention Layer (MLA).
+    Multi-Head Latent Attention (MLA) Layer.
 
     Attributes:
         dim (int): Dimensionality of the input features.
@@ -442,7 +442,7 @@ class MLA(nn.Module):
 
     def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
         """
-        Forward pass for the Multi-Headed Attention Layer (MLA).
+        Forward pass for the Multi-Head Latent Attention (MLA) Layer.
 
         Args:
             x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim).
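
Note for reviewers unfamiliar with the term: "MLA" here stands for Multi-Head
Latent Attention, the attention variant that compresses keys and values into a
small low-rank latent vector before attending, which is why the old
"Multi-Headed Attention" wording in the docstring was misleading. Below is a
minimal, self-contained sketch of that idea. It is NOT the implementation in
inference/model.py (the real layer also applies decoupled rotary embeddings via
freqs_cis and caches state across start_pos steps); the class name, the
dimensions, and the kv_lora_rank argument are illustrative assumptions.

import torch
import torch.nn as nn
import torch.nn.functional as F

class LatentAttentionSketch(nn.Module):
    """Attention where per-head K/V are reconstructed from a low-rank latent.

    Instead of projecting x (dim) directly to per-head keys and values
    (n_heads * head_dim each), x is first compressed to a small latent of
    size kv_lora_rank; keys and values are then up-projected from it.
    """
    def __init__(self, dim: int, n_heads: int, head_dim: int, kv_lora_rank: int):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = head_dim
        self.wq = nn.Linear(dim, n_heads * head_dim)
        # Down-projection: compress the token to the latent.
        self.kv_down = nn.Linear(dim, kv_lora_rank)
        # Up-projection: expand the latent to full per-head K and V.
        self.kv_up = nn.Linear(kv_lora_rank, 2 * n_heads * head_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        bsz, seqlen, _ = x.shape
        q = self.wq(x).view(bsz, seqlen, self.n_heads, self.head_dim)
        latent = self.kv_down(x)  # (bsz, seqlen, kv_lora_rank)
        k, v = self.kv_up(latent).chunk(2, dim=-1)
        k = k.view(bsz, seqlen, self.n_heads, self.head_dim)
        v = v.view(bsz, seqlen, self.n_heads, self.head_dim)
        # Move heads to dim 1 for scaled_dot_product_attention.
        q, k, v = (t.transpose(1, 2) for t in (q, k, v))
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        return out.transpose(1, 2).reshape(bsz, seqlen, -1)

x = torch.randn(2, 16, 256)
print(LatentAttentionSketch(256, 8, 32, kv_lora_rank=64)(x).shape)
# torch.Size([2, 16, 256])

The payoff of the design is that a KV cache only has to store the small
per-token latent instead of full per-head keys and values, which is what
distinguishes MLA from standard multi-head attention.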