From 77c46698b9bd6c73c6371dab8ada0c2ad95f6369 Mon Sep 17 00:00:00 2001 From: CodingParadigm1 <124928232+CodingParadigm1@users.noreply.github.com> Date: Mon, 3 Feb 2025 10:24:23 -0700 Subject: [PATCH] added assert to Parser class Parser now can apply assert to self --- inference/convert.py | 3 +-- inference/generate.py | 3 +-- inference/parser.py | 6 ++++++ 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/inference/convert.py b/inference/convert.py index e83ba9a..017d4cc 100644 --- a/inference/convert.py +++ b/inference/convert.py @@ -95,6 +95,5 @@ if __name__ == "__main__": ("--n-experts", type:=int, required:=True), ("--model-parallel", type:=int, required:=True) ] - args = Parser(arg_list).apply_args().return_args() - assert args.n_experts % args.model_parallel == 0 + args = Parser(arg_list).apply_args().assert_model_parallel().return_args() sync.run(main(args.hf_ckpt_path, args.save_path, args.n_experts, args.model_parallel)) diff --git a/inference/generate.py b/inference/generate.py index c67599c..80dbd30 100644 --- a/inference/generate.py +++ b/inference/generate.py @@ -176,6 +176,5 @@ if __name__ == "__main__": ("--max-new-tokens", type:=int, default:=200), ("--temperature", type:=float, default:=0.2) ] - args = Parser(arg_list).apply_args().return_args() - assert args.input_file or args.interactive + args = Parser(arg_list).apply_args().assert_interactive().return_args() run(main(args.ckpt_path, args.config, args.input_file, args.interactive, args.max_new_tokens, args.temperature)) diff --git a/inference/parser.py b/inference/parser.py index 847ce1a..12fd64e 100644 --- a/inference/parser.py +++ b/inference/parser.py @@ -7,6 +7,12 @@ class Parser(): def apply_args(self): for arg in self.arg_list: self.parser.add_argument(*arg) return self + def assert_model_parallel(self): + assert self.return_args.n_experts % self.return_args().model_parallel == 0 + return self + def assert_interactive(): + assert self.return_args().input_file or self.return_args().interactive + return self def return_args(self): return self.parser.parse_args() \ No newline at end of file