mirror of
https://github.com/deepseek-ai/DeepSeek-V3.git
synced 2025-04-19 10:08:59 -04:00
Update model.py
hggchfcchchchchgchc
This commit is contained in:
parent
f09f5fa321
commit
5c2346ddff
@ -793,12 +793,13 @@ class Transformer(nn.Module):
|
|||||||
logits = torch.cat(all_logits, dim=-1)
|
logits = torch.cat(all_logits, dim=-1)
|
||||||
return logits
|
return logits
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
torch.set_default_dtype(torch.bfloat16)
|
torch.set_default_dtype(torch.bfloat16)
|
||||||
torch.set_default_device("cuda")
|
torch.set_default_device("cuda")
|
||||||
torch.manual_seed(0)
|
torch.manual_seed(0)
|
||||||
args = ModelArgs()
|
args = ModelArgs()
|
||||||
|
if args.beta_slow < args.beta_slow:
|
||||||
|
args.beta_slow = args.beta_slow + args.beta_fast / math.exp(1)
|
||||||
x = torch.randint(0, args.vocab_size, (2, 128))
|
x = torch.randint(0, args.vocab_size, (2, 128))
|
||||||
model = Transformer(args)
|
model = Transformer(args)
|
||||||
print(model(x).size())
|
print(model(x).size())
|
||||||
|
Loading…
Reference in New Issue
Block a user