fix scores mask

This commit is contained in:
Xingkai Yu 2025-02-14 20:26:45 +08:00 committed by GitHub
parent 2f7b80eece
commit 1398800ebf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -585,8 +585,8 @@ class Gate(nn.Module):
else:
group_scores = scores.topk(2, dim=-1)[0].sum(dim=-1)
indices = group_scores.topk(self.topk_groups, dim=-1)[1]
mask = torch.zeros_like(scores[..., 0]).scatter_(1, indices, True)
scores = (scores * mask.unsqueeze(-1)).flatten(1)
mask = scores.new_ones(x.size(0), self.n_groups, dtype=bool).scatter_(1, indices, False)
scores = scores.masked_fill_(mask.unsqueeze(-1), float("-inf")).flatten(1)
indices = torch.topk(scores, self.topk, dim=-1)[1]
weights = original_scores.gather(1, indices)
if self.score_func == "sigmoid":