mirror of
https://github.com/deepseek-ai/DeepSeek-V3.git
synced 2025-02-22 05:38:59 -05:00
fix scores mask
This commit is contained in:
parent
2f7b80eece
commit
1398800ebf
@ -585,8 +585,8 @@ class Gate(nn.Module):
|
|||||||
else:
|
else:
|
||||||
group_scores = scores.topk(2, dim=-1)[0].sum(dim=-1)
|
group_scores = scores.topk(2, dim=-1)[0].sum(dim=-1)
|
||||||
indices = group_scores.topk(self.topk_groups, dim=-1)[1]
|
indices = group_scores.topk(self.topk_groups, dim=-1)[1]
|
||||||
mask = torch.zeros_like(scores[..., 0]).scatter_(1, indices, True)
|
mask = scores.new_ones(x.size(0), self.n_groups, dtype=bool).scatter_(1, indices, False)
|
||||||
scores = (scores * mask.unsqueeze(-1)).flatten(1)
|
scores = scores.masked_fill_(mask.unsqueeze(-1), float("-inf")).flatten(1)
|
||||||
indices = torch.topk(scores, self.topk, dim=-1)[1]
|
indices = torch.topk(scores, self.topk, dim=-1)[1]
|
||||||
weights = original_scores.gather(1, indices)
|
weights = original_scores.gather(1, indices)
|
||||||
if self.score_func == "sigmoid":
|
if self.score_func == "sigmoid":
|
||||||
|
Loading…
Reference in New Issue
Block a user