Merge pull request #720 from xiaokongkong/main

modify the explanation of MLA
Xingkai Yu 2025-04-08 17:20:37 +08:00 committed by GitHub
commit 741b06ebca


@@ -392,7 +392,7 @@ def apply_rotary_emb(x: torch.Tensor, freqs_cis: torch.Tensor) -> torch.Tensor:
 class MLA(nn.Module):
     """
-    Multi-Headed Attention Layer (MLA).
+    Multi-Head Latent Attention (MLA) Layer.
 
     Attributes:
         dim (int): Dimensionality of the input features.
@@ -442,7 +442,7 @@ class MLA(nn.Module):
     def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
         """
-        Forward pass for the Multi-Headed Attention Layer (MLA).
+        Forward pass for the Multi-Head Latent Attention (MLA) Layer.
 
         Args:
             x (torch.Tensor): Input tensor of shape (batch_size, seq_len, dim).
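For context on the rename, the sketch below illustrates the general idea behind "latent" attention: keys and values are produced from a low-rank latent compression of the input rather than from full-width projections. It is a simplified illustration written for this note, not the repository's implementation; the class name TinyLatentAttention, all sizes, and the kv_lora_rank value are made-up assumptions, and the actual MLA layer's handling of start_pos, freqs_cis (rotary embeddings), mask, and KV caching is omitted.

import torch
import torch.nn as nn

class TinyLatentAttention(nn.Module):
    # Simplified, illustrative latent attention; not the repository's MLA.
    def __init__(self, dim: int = 256, n_heads: int = 4, kv_lora_rank: int = 32):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = dim // n_heads
        self.wq = nn.Linear(dim, dim, bias=False)                    # full-width query projection
        self.wkv_a = nn.Linear(dim, kv_lora_rank, bias=False)        # compress input to a low-rank latent
        self.wkv_b = nn.Linear(kv_lora_rank, 2 * dim, bias=False)    # expand latent back to keys and values
        self.wo = nn.Linear(dim, dim, bias=False)                    # output projection

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        b, s, d = x.shape
        q = self.wq(x).view(b, s, self.n_heads, self.head_dim)
        kv = self.wkv_b(self.wkv_a(x))                               # keys/values come from the latent
        k, v = kv.chunk(2, dim=-1)
        k = k.view(b, s, self.n_heads, self.head_dim)
        v = v.view(b, s, self.n_heads, self.head_dim)
        scores = torch.einsum("bqhd,bkhd->bhqk", q, k) / self.head_dim ** 0.5
        attn = scores.softmax(dim=-1)
        out = torch.einsum("bhqk,bkhd->bqhd", attn, v).reshape(b, s, d)
        return self.wo(out)

x = torch.randn(2, 16, 256)                                          # (batch_size, seq_len, dim)
print(TinyLatentAttention()(x).shape)                                # torch.Size([2, 16, 256])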