From c47ecaa800fe3c7283dea10e747402b16606442c Mon Sep 17 00:00:00 2001
From: Pedro Dessanti <92925579+Dessantii@users.noreply.github.com>
Date: Mon, 27 Jan 2025 10:36:57 -0300
Subject: [PATCH] Update model.py

Fixed the indentation of the suggestion.
---
 inference/model.py | 29 +++++++++++++++--------------
 1 file changed, 15 insertions(+), 14 deletions(-)

diff --git a/inference/model.py b/inference/model.py
index 472d740..e6acf1a 100644
--- a/inference/model.py
+++ b/inference/model.py
@@ -779,20 +779,21 @@ class Transformer(nn.Module):
             torch.Tensor: Logits tensor of shape (batch_size, vocab_size).
         """
         with autocast():
-        seqlen = tokens.size(1)
-        h = self.embed(tokens)
-        freqs_cis = self.freqs_cis[start_pos:start_pos+seqlen]
-        mask = None
-        if seqlen > 1:
-            mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device).triu_(1)
-        for layer in self.layers:
-            h = layer(h, start_pos, freqs_cis, mask)
-        h = self.norm(h)[:, -1]
-        logits = self.head(h)
-        if world_size > 1:
-            all_logits = [torch.empty_like(logits) for _ in range(world_size)]
-            dist.all_gather(all_logits, logits)
-            logits = torch.cat(all_logits, dim=-1)
+            seqlen = tokens.size(1)
+            h = self.embed(tokens)
+            freqs_cis = self.freqs_cis[start_pos:start_pos+seqlen]
+            mask = None
+            if seqlen > 1:
+                mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device).triu_(1)
+            for layer in self.layers:
+                h = layer(h, start_pos, freqs_cis, mask)
+            h = self.norm(h)[:, -1]
+            logits = self.head(h)
+            if world_size > 1:
+                all_logits = [torch.empty_like(logits) for _ in range(world_size)]
+                dist.all_gather(all_logits, logits)
+                logits = torch.cat(all_logits, dim=-1)
+            return logits
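
Note for reviewers (not part of the patch): a minimal sketch of the pattern the re-indented block implements. With the fix, the whole forward computation sits inside the `autocast()` context, and when the output head is vocabulary-parallel the per-rank logit shards are all-gathered and concatenated along the last dimension to form full-vocabulary logits. The helper name and standalone structure below are illustrative assumptions, not code taken from `inference/model.py`; running the `world_size > 1` path also assumes `torch.distributed` has already been initialized.

```python
# Illustrative sketch only -- mirrors the gather/concat pattern in the patched forward().
import torch
import torch.distributed as dist
from torch.cuda.amp import autocast


def gather_vocab_parallel_logits(local_logits: torch.Tensor, world_size: int) -> torch.Tensor:
    """Combine per-rank vocabulary shards into full-vocabulary logits (hypothetical helper)."""
    if world_size == 1:
        return local_logits
    shards = [torch.empty_like(local_logits) for _ in range(world_size)]
    dist.all_gather(shards, local_logits)   # every rank receives every rank's shard
    return torch.cat(shards, dim=-1)        # (batch_size, full_vocab_size)


# The indentation fix means the model math runs *inside* autocast(), so mixed precision
# actually applies; the gather/concat afterwards assembles the final logits tensor.
with autocast():
    local_logits = torch.randn(2, 8)  # stand-in for self.head(self.norm(h)[:, -1])
    logits = gather_vocab_parallel_logits(local_logits, world_size=1)
print(logits.shape)  # torch.Size([2, 8])
```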