From 2756e130c2430eedc916ca331f5e360b519ed7ab Mon Sep 17 00:00:00 2001 From: Roman Fitzjalen Date: Tue, 28 Jan 2025 13:16:54 +0100 Subject: [PATCH] clarify assertion error --- inference/convert.py | 6 +++--- inference/generate.py | 6 +++--- inference/kernel.py | 12 ++++++------ inference/model.py | 8 ++++---- 4 files changed, 16 insertions(+), 16 deletions(-) diff --git a/inference/convert.py b/inference/convert.py index c606ce8..6d85ccc 100644 --- a/inference/convert.py +++ b/inference/convert.py @@ -60,7 +60,7 @@ def main(hf_ckpt_path, save_path, n_experts, mp): name = name.replace("weight_scale_inv", "scale") name = name.replace("e_score_correction_bias", "bias") key = name.split(".")[-2] - assert key in mapping + assert key in mapping, f"Key {key} not found in mapping" new_key, dim = mapping[key] name = name.replace(key, new_key) for i in range(mp): @@ -70,7 +70,7 @@ def main(hf_ckpt_path, save_path, n_experts, mp): if idx < i * n_local_experts or idx >= (i + 1) * n_local_experts: continue elif dim is not None: - assert param.size(dim) % mp == 0 + assert param.size(dim) % mp == 0, f"Dimension {dim} must be divisible by {mp}" shard_size = param.size(dim) // mp new_param = param.narrow(dim, i * shard_size, shard_size).contiguous() state_dicts[i][name] = new_param @@ -92,5 +92,5 @@ if __name__ == "__main__": parser.add_argument("--n-experts", type=int, required=True) parser.add_argument("--model-parallel", type=int, required=True) args = parser.parse_args() - assert args.n_experts % args.model_parallel == 0 + assert args.n_experts % args.model_parallel == 0, "Number of experts must be divisible by model parallelism" main(args.hf_ckpt_path, args.save_path, args.n_experts, args.model_parallel) diff --git a/inference/generate.py b/inference/generate.py index fbf3ab8..7e9bffe 100644 --- a/inference/generate.py +++ b/inference/generate.py @@ -49,7 +49,7 @@ def generate( List[List[int]]: A list of lists containing the generated tokens for each sequence. """ prompt_lens = [len(t) for t in prompt_tokens] - assert max(prompt_lens) <= model.max_seq_len + assert max(prompt_lens) <= model.max_seq_len, f"Prompt length exceeds model maximum sequence length (max_seq_len={model.max_seq_len})" total_len = min(model.max_seq_len, max_new_tokens + max(prompt_lens)) tokens = torch.full((len(prompt_tokens), total_len), -1, dtype=torch.long, device="cuda") for i, t in enumerate(prompt_tokens): @@ -145,7 +145,7 @@ def main( else: with open(input_file) as f: prompts = [line.strip() for line in f.readlines()] - assert len(prompts) <= args.max_batch_size + assert len(prompts) <= args.max_batch_size, f"Number of prompts exceeds maximum batch size ({args.max_batch_size})" prompt_tokens = [tokenizer.apply_chat_template([{"role": "user", "content": prompt}], add_generation_prompt=True) for prompt in prompts] completion_tokens = generate(model, prompt_tokens, max_new_tokens, tokenizer.eos_token_id, temperature) completions = tokenizer.batch_decode(completion_tokens, skip_special_tokens=True) @@ -181,5 +181,5 @@ if __name__ == "__main__": parser.add_argument("--max-new-tokens", type=int, default=200) parser.add_argument("--temperature", type=float, default=0.2) args = parser.parse_args() - assert args.input_file or args.interactive + assert args.input_file or args.interactive, "Either input-file or interactive mode must be specified" main(args.ckpt_path, args.config, args.input_file, args.interactive, args.max_new_tokens, args.temperature) diff --git a/inference/kernel.py b/inference/kernel.py index dec8639..ae907ad 100644 --- a/inference/kernel.py +++ b/inference/kernel.py @@ -43,8 +43,8 @@ def act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, tor - The quantized tensor with dtype `torch.float8_e4m3fn`. - A tensor of scaling factors with dtype `torch.float32`. """ - assert x.is_contiguous() - assert x.size(-1) % block_size == 0 + assert x.is_contiguous(), 'Input tensor must be contiguous' + assert x.size(-1) % block_size == 0, f'Last dimension size must be divisible by block_size (block_size={block_size})' y = torch.empty_like(x, dtype=torch.float8_e4m3fn) s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype=torch.float32) grid = lambda meta: (triton.cdiv(x.numel(), meta['BLOCK_SIZE']), ) @@ -96,8 +96,8 @@ def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> t Raises: AssertionError: If `x` or `s` are not contiguous or if their dimensions are not 2. """ - assert x.is_contiguous() and s.is_contiguous() - assert x.dim() == 2 and s.dim() == 2 + assert x.is_contiguous() and s.is_contiguous(), 'Input tensors must be contiguous' + assert x.dim() == 2 and s.dim() == 2, 'Input tensors must have 2 dimensions' M, N = x.size() y = torch.empty_like(x, dtype=torch.get_default_dtype()) grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE']), triton.cdiv(N, meta['BLOCK_SIZE'])) @@ -180,8 +180,8 @@ def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Ten Returns: torch.Tensor: The result of the matrix multiplication. """ - assert a.is_contiguous() and b.is_contiguous() - assert a_s.is_contiguous() and b_s.is_contiguous() + assert a.is_contiguous() and b.is_contiguous(), 'Input tensors must be contiguous' + assert a_s.is_contiguous() and b_s.is_contiguous(), 'Scaling factor tensors must be contiguous' K = a.size(-1) M = a.numel() // K N = b.size(0) diff --git a/inference/model.py b/inference/model.py index 9ea60c9..2afaba9 100644 --- a/inference/model.py +++ b/inference/model.py @@ -96,7 +96,7 @@ class ParallelEmbedding(nn.Module): super().__init__() self.vocab_size = vocab_size self.dim = dim - assert vocab_size % world_size == 0 + assert vocab_size % world_size == 0, f"Vocabulary size must be divisible by world size (world_size={world_size})" self.part_vocab_size = (vocab_size // world_size) self.vocab_start_idx = rank * self.part_vocab_size self.vocab_end_idx = self.vocab_start_idx + self.part_vocab_size @@ -213,7 +213,7 @@ class ColumnParallelLinear(Linear): dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`. """ def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None): - assert out_features % world_size == 0 + assert out_features % world_size == 0, f"Output features must be divisible by world size (world_size={world_size})" self.part_out_features = out_features // world_size super().__init__(in_features, self.part_out_features, bias, dtype) @@ -242,7 +242,7 @@ class RowParallelLinear(Linear): dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`. """ def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None): - assert in_features % world_size == 0 + assert in_features % world_size == 0, f"Input features must be divisible by world size (world_size={world_size})" self.part_in_features = in_features // world_size super().__init__(self.part_in_features, out_features, bias, dtype) @@ -652,7 +652,7 @@ class MoE(nn.Module): """ super().__init__() self.dim = args.dim - assert args.n_routed_experts % world_size == 0 + assert args.n_routed_experts % world_size == 0, f"Number of experts must be divisible by world size (world_size={world_size})" self.n_routed_experts = args.n_routed_experts self.n_local_experts = args.n_routed_experts // world_size self.n_activated_experts = args.n_activated_experts