added assert to Parser class

Parser now can apply assert to self
This commit is contained in:
CodingParadigm1 2025-02-03 10:24:23 -07:00
parent c8146ec360
commit 77c46698b9
3 changed files with 8 additions and 4 deletions

View File

@ -95,6 +95,5 @@ if __name__ == "__main__":
("--n-experts", type:=int, required:=True), ("--n-experts", type:=int, required:=True),
("--model-parallel", type:=int, required:=True) ("--model-parallel", type:=int, required:=True)
] ]
args = Parser(arg_list).apply_args().return_args() args = Parser(arg_list).apply_args().assert_model_parallel().return_args()
assert args.n_experts % args.model_parallel == 0
sync.run(main(args.hf_ckpt_path, args.save_path, args.n_experts, args.model_parallel)) sync.run(main(args.hf_ckpt_path, args.save_path, args.n_experts, args.model_parallel))

View File

@ -176,6 +176,5 @@ if __name__ == "__main__":
("--max-new-tokens", type:=int, default:=200), ("--max-new-tokens", type:=int, default:=200),
("--temperature", type:=float, default:=0.2) ("--temperature", type:=float, default:=0.2)
] ]
args = Parser(arg_list).apply_args().return_args() args = Parser(arg_list).apply_args().assert_interactive().return_args()
assert args.input_file or args.interactive
run(main(args.ckpt_path, args.config, args.input_file, args.interactive, args.max_new_tokens, args.temperature)) run(main(args.ckpt_path, args.config, args.input_file, args.interactive, args.max_new_tokens, args.temperature))

View File

@ -7,6 +7,12 @@ class Parser():
def apply_args(self): def apply_args(self):
for arg in self.arg_list: self.parser.add_argument(*arg) for arg in self.arg_list: self.parser.add_argument(*arg)
return self return self
def assert_model_parallel(self):
assert self.return_args.n_experts % self.return_args().model_parallel == 0
return self
def assert_interactive():
assert self.return_args().input_file or self.return_args().interactive
return self
def return_args(self): def return_args(self):
return self.parser.parse_args() return self.parser.parse_args()