diff --git a/inference/model.py b/inference/model.py index 3d8bf2c..18586da 100644 --- a/inference/model.py +++ b/inference/model.py @@ -802,7 +802,7 @@ if __name__ == "__main__": default_device = "mps" else: default_device = "cpu" - torch.set_default_device("default_device") + torch.set_default_device(default_device) torch.manual_seed(0) args = ModelArgs() x = torch.randint(0, args.vocab_size, (2, 128))