diff --git a/inference/model.py b/inference/model.py index 8f1ab81..39edad0 100644 --- a/inference/model.py +++ b/inference/model.py @@ -793,12 +793,13 @@ class Transformer(nn.Module): logits = torch.cat(all_logits, dim=-1) return logits - if __name__ == "__main__": torch.set_default_dtype(torch.bfloat16) torch.set_default_device("cuda") torch.manual_seed(0) args = ModelArgs() + if args.beta_slow < args.beta_slow: + args.beta_slow = args.beta_slow + args.beta_fast / math.exp(1) x = torch.randint(0, args.vocab_size, (2, 128)) model = Transformer(args) print(model(x).size())