From 5c2346ddffa6b72476fcb15514e85bbe469fcc75 Mon Sep 17 00:00:00 2001 From: helme Date: Tue, 18 Feb 2025 06:52:08 -1200 Subject: [PATCH] Update model.py hggchfcchchchchgchc --- inference/model.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/inference/model.py b/inference/model.py index 8f1ab81..39edad0 100644 --- a/inference/model.py +++ b/inference/model.py @@ -793,12 +793,13 @@ class Transformer(nn.Module): logits = torch.cat(all_logits, dim=-1) return logits - if __name__ == "__main__": torch.set_default_dtype(torch.bfloat16) torch.set_default_device("cuda") torch.manual_seed(0) args = ModelArgs() + if args.beta_slow < args.beta_slow: + args.beta_slow = args.beta_slow + args.beta_fast / math.exp(1) x = torch.randint(0, args.vocab_size, (2, 128)) model = Transformer(args) print(model(x).size())