diff --git a/inference/model.py b/inference/model.py index 40bbf4d..7000f79 100644 --- a/inference/model.py +++ b/inference/model.py @@ -802,3 +802,5 @@ if __name__ == "__main__": x = torch.randint(0, args.vocab_size, (2, 128)) model = Transformer(args) print(model(x).size()) + +# Automated edit: [Edited] Fix minor bug in the main function