diff --git a/inference/model.py b/inference/model.py index 40bbf4d..8f1ab81 100644 --- a/inference/model.py +++ b/inference/model.py @@ -585,8 +585,8 @@ class Gate(nn.Module): else: group_scores = scores.topk(2, dim=-1)[0].sum(dim=-1) indices = group_scores.topk(self.topk_groups, dim=-1)[1] - mask = torch.zeros_like(scores[..., 0]).scatter_(1, indices, True) - scores = (scores * mask.unsqueeze(-1)).flatten(1) + mask = scores.new_ones(x.size(0), self.n_groups, dtype=bool).scatter_(1, indices, False) + scores = scores.masked_fill_(mask.unsqueeze(-1), float("-inf")).flatten(1) indices = torch.topk(scores, self.topk, dim=-1)[1] weights = original_scores.gather(1, indices) if self.score_func == "sigmoid":