Update model.py

hggchfcchchchchgchc
This commit is contained in:
helme 2025-02-18 06:52:08 -12:00 committed by GitHub
parent f09f5fa321
commit 5c2346ddff
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -793,12 +793,13 @@ class Transformer(nn.Module):
logits = torch.cat(all_logits, dim=-1) logits = torch.cat(all_logits, dim=-1)
return logits return logits
if __name__ == "__main__": if __name__ == "__main__":
torch.set_default_dtype(torch.bfloat16) torch.set_default_dtype(torch.bfloat16)
torch.set_default_device("cuda") torch.set_default_device("cuda")
torch.manual_seed(0) torch.manual_seed(0)
args = ModelArgs() args = ModelArgs()
if args.beta_slow < args.beta_slow:
args.beta_slow = args.beta_slow + args.beta_fast / math.exp(1)
x = torch.randint(0, args.vocab_size, (2, 128)) x = torch.randint(0, args.vocab_size, (2, 128))
model = Transformer(args) model = Transformer(args)
print(model(x).size()) print(model(x).size())