diff --git a/inference/convert.py b/inference/convert.py
index c606ce8..77655af 100644
--- a/inference/convert.py
+++ b/inference/convert.py
@@ -39,7 +39,7 @@ def main(hf_ckpt_path, save_path, n_experts, mp):
         save_path (str): Path to the directory where the converted checkpoint files will be saved.
         n_experts (int): Total number of experts in the model.
         mp (int): Model parallelism factor.
-        
+
     Returns:
         None
     """
@@ -54,7 +54,7 @@ def main(hf_ckpt_path, save_path, n_experts, mp):
                     continue
                 param: torch.Tensor = f.get_tensor(name)
                 if name.startswith("model."):
-                    name = name[len("model."):]
+                    name = name[len("model.") :]
                 name = name.replace("self_attn", "attn")
                 name = name.replace("mlp", "ffn")
                 name = name.replace("weight_scale_inv", "scale")
@@ -67,18 +67,25 @@ def main(hf_ckpt_path, save_path, n_experts, mp):
                     new_param = param
                     if "experts" in name and "shared_experts" not in name:
                         idx = int(name.split(".")[-3])
-                        if idx < i * n_local_experts or idx >= (i + 1) * n_local_experts:
+                        if (
+                            idx < i * n_local_experts
+                            or idx >= (i + 1) * n_local_experts
+                        ):
                             continue
                     elif dim is not None:
                         assert param.size(dim) % mp == 0
                         shard_size = param.size(dim) // mp
-                        new_param = param.narrow(dim, i * shard_size, shard_size).contiguous()
+                        new_param = param.narrow(
+                            dim, i * shard_size, shard_size
+                        ).contiguous()
                     state_dicts[i][name] = new_param

     os.makedirs(save_path, exist_ok=True)

     for i in trange(mp):
-        save_file(state_dicts[i], os.path.join(save_path, f"model{i}-mp{mp}.safetensors"))
+        save_file(
+            state_dicts[i], os.path.join(save_path, f"model{i}-mp{mp}.safetensors")
+        )

     for file_path in glob(os.path.join(hf_ckpt_path, "*token*")):
         new_file_path = os.path.join(save_path, os.path.basename(file_path))
diff --git a/inference/fp8_cast_bf16.py b/inference/fp8_cast_bf16.py
index 4037342..575c1ee 100644
--- a/inference/fp8_cast_bf16.py
+++ b/inference/fp8_cast_bf16.py
@@ -9,6 +9,7 @@ from safetensors.torch import load_file, save_file

 from kernel import weight_dequant

+
 def main(fp8_path, bf16_path):
     """
     Converts FP8 weights to BF16 and saves the converted weights.
@@ -35,7 +36,7 @@ def main(fp8_path, bf16_path):
     with open(model_index_file, "r") as f:
         model_index = json.load(f)
     weight_map = model_index["weight_map"]
-    
+
     # Cache for loaded safetensor files
     loaded_files = {}
     fp8_weight_names = []
@@ -66,7 +67,7 @@ def main(fp8_path, bf16_path):
         file_name = os.path.basename(safetensor_file)
         current_state_dict = load_file(safetensor_file, device="cuda")
         loaded_files[file_name] = current_state_dict
-        
+
         new_state_dict = {}
         for weight_name, weight in current_state_dict.items():
             if weight_name.endswith("_scale_inv"):
@@ -79,20 +80,22 @@ def main(fp8_path, bf16_path):
                     fp8_weight_names.append(weight_name)
                     new_state_dict[weight_name] = weight_dequant(weight, scale_inv)
                 except KeyError:
-                    print(f"Warning: Missing scale_inv tensor for {weight_name}, skipping conversion")
+                    print(
+                        f"Warning: Missing scale_inv tensor for {weight_name}, skipping conversion"
+                    )
                     new_state_dict[weight_name] = weight
             else:
                 new_state_dict[weight_name] = weight
-        
+
         new_safetensor_file = os.path.join(bf16_path, file_name)
         save_file(new_state_dict, new_safetensor_file)
-        
+
         # Memory management: keep only the 2 most recently used files
         if len(loaded_files) > 2:
             oldest_file = next(iter(loaded_files))
             del loaded_files[oldest_file]
             torch.cuda.empty_cache()
-    
+
     # Update model index
     new_model_index_file = os.path.join(bf16_path, "model.safetensors.index.json")
     for weight_name in fp8_weight_names:
@@ -101,7 +104,7 @@ def main(fp8_path, bf16_path):
             weight_map.pop(scale_inv_name)
     with open(new_model_index_file, "w") as f:
         json.dump({"metadata": {}, "weight_map": weight_map}, f, indent=2)
-        
+

 if __name__ == "__main__":
     parser = ArgumentParser()
@@ -109,4 +112,3 @@ if __name__ == "__main__":
     parser.add_argument("--output-bf16-hf-path", type=str, required=True)
     args = parser.parse_args()
     main(args.input_fp8_hf_path, args.output_bf16_hf_path)
-    
diff --git a/inference/generate.py b/inference/generate.py
index fbf3ab8..fb516f4 100644
--- a/inference/generate.py
+++ b/inference/generate.py
@@ -33,7 +33,7 @@ def generate(
     prompt_tokens: List[List[int]],
     max_new_tokens: int,
     eos_id: int,
-    temperature: float = 1.0
+    temperature: float = 1.0,
 ) -> List[List[int]]:
     """
     Generates new tokens based on the given prompt tokens using the specified model.
@@ -51,9 +51,11 @@ def generate(
     prompt_lens = [len(t) for t in prompt_tokens]
     assert max(prompt_lens) <= model.max_seq_len
     total_len = min(model.max_seq_len, max_new_tokens + max(prompt_lens))
-    tokens = torch.full((len(prompt_tokens), total_len), -1, dtype=torch.long, device="cuda")
+    tokens = torch.full(
+        (len(prompt_tokens), total_len), -1, dtype=torch.long, device="cuda"
+    )
     for i, t in enumerate(prompt_tokens):
-        tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
+        tokens[i, : len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
     prev_pos = 0
     finished = torch.tensor([False] * len(prompt_tokens), device="cuda")
     prompt_mask = tokens != -1
@@ -63,7 +65,9 @@ def generate(
             next_token = sample(logits, temperature)
         else:
             next_token = logits.argmax(dim=-1)
-        next_token = torch.where(prompt_mask[:, cur_pos], tokens[:, cur_pos], next_token)
+        next_token = torch.where(
+            prompt_mask[:, cur_pos], tokens[:, cur_pos], next_token
+        )
         tokens[:, cur_pos] = next_token
         finished |= torch.logical_and(~prompt_mask[:, cur_pos], next_token == eos_id)
         prev_pos = cur_pos
@@ -71,9 +75,9 @@ def generate(
             break
     completion_tokens = []
     for i, toks in enumerate(tokens.tolist()):
-        toks = toks[prompt_lens[i]:prompt_lens[i]+max_new_tokens]
+        toks = toks[prompt_lens[i] : prompt_lens[i] + max_new_tokens]
         if eos_id in toks:
-            toks = toks[:toks.index(eos_id)]
+            toks = toks[: toks.index(eos_id)]
         completion_tokens.append(toks)
     return completion_tokens

@@ -115,8 +119,10 @@ def main(
     with torch.device("cuda"):
         model = Transformer(args)
     tokenizer = AutoTokenizer.from_pretrained(ckpt_path)
-    tokenizer.decode(generate(model, [tokenizer.encode("DeepSeek")], 2, -1, 1.)[0])
-    load_model(model, os.path.join(ckpt_path, f"model{rank}-mp{world_size}.safetensors"))
+    tokenizer.decode(generate(model, [tokenizer.encode("DeepSeek")], 2, -1, 1.0)[0])
+    load_model(
+        model, os.path.join(ckpt_path, f"model{rank}-mp{world_size}.safetensors")
+    )

     if interactive:
         messages = []
@@ -137,18 +143,37 @@ def main(
                 messages.clear()
                 continue
             messages.append({"role": "user", "content": prompt})
-            prompt_tokens = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
-            completion_tokens = generate(model, [prompt_tokens], max_new_tokens, tokenizer.eos_token_id, temperature)
-            completion = tokenizer.decode(completion_tokens[0], skip_special_tokens=True)
+            prompt_tokens = tokenizer.apply_chat_template(
+                messages, add_generation_prompt=True
+            )
+            completion_tokens = generate(
+                model,
+                [prompt_tokens],
+                max_new_tokens,
+                tokenizer.eos_token_id,
+                temperature,
+            )
+            completion = tokenizer.decode(
+                completion_tokens[0], skip_special_tokens=True
+            )
             print(completion)
             messages.append({"role": "assistant", "content": completion})
     else:
         with open(input_file) as f:
             prompts = [line.strip() for line in f.readlines()]
         assert len(prompts) <= args.max_batch_size
-        prompt_tokens = [tokenizer.apply_chat_template([{"role": "user", "content": prompt}], add_generation_prompt=True) for prompt in prompts]
-        completion_tokens = generate(model, prompt_tokens, max_new_tokens, tokenizer.eos_token_id, temperature)
-        completions = tokenizer.batch_decode(completion_tokens, skip_special_tokens=True)
+        prompt_tokens = [
+            tokenizer.apply_chat_template(
+                [{"role": "user", "content": prompt}], add_generation_prompt=True
+            )
+            for prompt in prompts
+        ]
+        completion_tokens = generate(
+            model, prompt_tokens, max_new_tokens, tokenizer.eos_token_id, temperature
+        )
+        completions = tokenizer.batch_decode(
+            completion_tokens, skip_special_tokens=True
+        )
         for prompt, completion in zip(prompts, completions):
             print("Prompt:", prompt)
             print("Completion:", completion)
@@ -182,4 +207,11 @@ if __name__ == "__main__":
     parser.add_argument("--temperature", type=float, default=0.2)
     args = parser.parse_args()
     assert args.input_file or args.interactive
-    main(args.ckpt_path, args.config, args.input_file, args.interactive, args.max_new_tokens, args.temperature)
+    main(
+        args.ckpt_path,
+        args.config,
+        args.input_file,
+        args.interactive,
+        args.max_new_tokens,
+        args.temperature,
+    )
diff --git a/inference/kernel.py b/inference/kernel.py
index dec8639..a175ba8 100644
--- a/inference/kernel.py
+++ b/inference/kernel.py
@@ -23,14 +23,16 @@ def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
     pid = tl.program_id(axis=0)
     offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
     x = tl.load(x_ptr + offs).to(tl.float32)
-    s = tl.max(tl.abs(x)) / 448.
+    s = tl.max(tl.abs(x)) / 448.0
     y = x / s
     y = y.to(y_ptr.dtype.element_ty)
     tl.store(y_ptr + offs, y)
     tl.store(s_ptr + pid, s)


-def act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
+def act_quant(
+    x: torch.Tensor, block_size: int = 128
+) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     Quantizes the input tensor `x` using block-wise quantization.

@@ -47,7 +49,7 @@ def act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, tor
     assert x.size(-1) % block_size == 0
     y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
     s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype=torch.float32)
-    grid = lambda meta: (triton.cdiv(x.numel(), meta['BLOCK_SIZE']), )
+    grid = lambda meta: (triton.cdiv(x.numel(), meta["BLOCK_SIZE"]),)
     act_quant_kernel[grid](x, y, s, BLOCK_SIZE=block_size)
     return y, s

@@ -81,7 +83,9 @@ def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
     tl.store(y_ptr + offs, y, mask=mask)


-def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor:
+def weight_dequant(
+    x: torch.Tensor, s: torch.Tensor, block_size: int = 128
+) -> torch.Tensor:
     """
     Dequantizes the given weight tensor using the provided scale tensor.

@@ -100,24 +104,41 @@ def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> t
     assert x.dim() == 2 and s.dim() == 2
     M, N = x.size()
     y = torch.empty_like(x, dtype=torch.get_default_dtype())
-    grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE']), triton.cdiv(N, meta['BLOCK_SIZE']))
+    grid = lambda meta: (
+        triton.cdiv(M, meta["BLOCK_SIZE"]),
+        triton.cdiv(N, meta["BLOCK_SIZE"]),
+    )
     weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size)
     return y


 fp8_gemm_configs = [
-    Config({'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': 128}, num_stages=num_stages, num_warps=8)
-    for block_m in [16, 32, 64] for block_n in [32, 64, 128] for num_stages in [3, 4, 5, 6]
+    Config(
+        {"BLOCK_SIZE_M": block_m, "BLOCK_SIZE_N": block_n, "BLOCK_SIZE_K": 128},
+        num_stages=num_stages,
+        num_warps=8,
+    )
+    for block_m in [16, 32, 64]
+    for block_n in [32, 64, 128]
+    for num_stages in [3, 4, 5, 6]
 ]

-@triton.autotune(configs=fp8_gemm_configs, key=['N', 'K'])
+
+@triton.autotune(configs=fp8_gemm_configs, key=["N", "K"])
 @triton.jit
-def fp8_gemm_kernel(a_ptr, b_ptr, c_ptr,
-                    a_s_ptr, b_s_ptr,
-                    M, N: tl.constexpr, K: tl.constexpr,
-                    BLOCK_SIZE_M: tl.constexpr,
-                    BLOCK_SIZE_N: tl.constexpr,
-                    BLOCK_SIZE_K: tl.constexpr):
+def fp8_gemm_kernel(
+    a_ptr,
+    b_ptr,
+    c_ptr,
+    a_s_ptr,
+    b_s_ptr,
+    M,
+    N: tl.constexpr,
+    K: tl.constexpr,
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
+):
     """
     Performs a matrix multiplication operation on FP8 matrices with scaling factors.

@@ -186,6 +207,9 @@ def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Ten
     M = a.numel() // K
     N = b.size(0)
     c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype())
-    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']), triton.cdiv(N, META['BLOCK_SIZE_N']))
+    grid = lambda META: (
+        triton.cdiv(M, META["BLOCK_SIZE_M"]),
+        triton.cdiv(N, META["BLOCK_SIZE_N"]),
+    )
     fp8_gemm_kernel[grid](a, b, c, a_s, b_s, M, N, K)
     return c
diff --git a/inference/model.py b/inference/model.py
index 9ea60c9..1d3f2ff 100644
--- a/inference/model.py
+++ b/inference/model.py
@@ -16,6 +16,7 @@ block_size = 128
 gemm_impl: Literal["bf16", "fp8"] = "bf16"
 attn_impl: Literal["naive", "absorb"] = "absorb"

+
 @dataclass
 class ModelArgs:
     """
@@ -51,6 +52,7 @@ class ModelArgs:
         beta_slow (int): Slow beta correction factor.
         mscale (float): Scaling factor for extended attention.
     """
+
     max_batch_size: int = 8
     max_seq_len: int = 4096 * 4
     dtype: Literal["bf16", "fp8"] = "bf16"
@@ -68,7 +70,7 @@ class ModelArgs:
     n_expert_groups: int = 1
     n_limited_groups: int = 1
     score_func: Literal["softmax", "sigmoid"] = "softmax"
-    route_scale: float = 1.
+    route_scale: float = 1.0
     # mla
     q_lora_rank: int = 0
     kv_lora_rank: int = 512
@@ -81,7 +83,7 @@ class ModelArgs:
     rope_factor: float = 40
     beta_fast: int = 32
     beta_slow: int = 1
-    mscale: float = 1.
+    mscale: float = 1.0


 class ParallelEmbedding(nn.Module):
@@ -92,12 +94,13 @@ class ParallelEmbedding(nn.Module):
         vocab_size (int): Vocabulary size.
         dim (int): Embedding dimension.
     """
+
     def __init__(self, vocab_size: int, dim: int):
         super().__init__()
         self.vocab_size = vocab_size
         self.dim = dim
         assert vocab_size % world_size == 0
-        self.part_vocab_size = (vocab_size // world_size)
+        self.part_vocab_size = vocab_size // world_size
         self.vocab_start_idx = rank * self.part_vocab_size
         self.vocab_end_idx = self.vocab_start_idx + self.part_vocab_size
         self.weight = nn.Parameter(torch.empty(self.part_vocab_size, self.dim))
@@ -126,7 +129,9 @@ class ParallelEmbedding(nn.Module):
         return y


-def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+def linear(
+    x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None
+) -> torch.Tensor:
     """
     Applies a linear transformation to the incoming data: y = xA^T + b.
     This function supports specialized implementations based on quantization
@@ -134,16 +139,16 @@ def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] =

     Args:
         x (torch.Tensor): The input tensor.
-        weight (torch.Tensor): The weight tensor. It may be quantized and 
+        weight (torch.Tensor): The weight tensor. It may be quantized and
             requires dequantization for certain cases.
         bias (Optional[torch.Tensor]): The bias tensor to be added. Default is None.

     Returns:
-        torch.Tensor: The result of the linear transformation, which may involve 
+        torch.Tensor: The result of the linear transformation, which may involve
         quantization-aware computations depending on the input parameters.

     Notes:
-        - If `weight` is quantized (e.g., `element_size() > 1`), a dequantized version 
+        - If `weight` is quantized (e.g., `element_size() > 1`), a dequantized version
           is used for computation.
         - If `gemm_impl == "bf16"`, dequantization and a `bf16` GEMM operation are applied.
         - For other cases, the function applies quantization to `x` and uses `fp8_gemm` for computation.
@@ -171,17 +176,24 @@ class Linear(nn.Module):
         bias (bool): Whether to include a bias term. Defaults to False.
         dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`.
     """
+
     dtype = torch.bfloat16

-    def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
+    def __init__(
+        self, in_features: int, out_features: int, bias: bool = False, dtype=None
+    ):
         super().__init__()
         self.in_features = in_features
         self.out_features = out_features
-        self.weight = nn.Parameter(torch.empty(out_features, in_features, dtype=dtype or Linear.dtype))
+        self.weight = nn.Parameter(
+            torch.empty(out_features, in_features, dtype=dtype or Linear.dtype)
+        )
         if self.weight.element_size() == 1:
             scale_out_features = (out_features + block_size - 1) // block_size
             scale_in_features = (in_features + block_size - 1) // block_size
-            self.weight.scale = self.scale = nn.Parameter(torch.empty(scale_out_features, scale_in_features, dtype=torch.float32))
+            self.weight.scale = self.scale = nn.Parameter(
+                torch.empty(scale_out_features, scale_in_features, dtype=torch.float32)
+            )
         else:
             self.register_parameter("scale", None)
         if bias:
@@ -212,7 +224,10 @@ class ColumnParallelLinear(Linear):
         bias (bool): Whether to include a bias term. Defaults to False.
         dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`.
     """
-    def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
+
+    def __init__(
+        self, in_features: int, out_features: int, bias: bool = False, dtype=None
+    ):
         assert out_features % world_size == 0
         self.part_out_features = out_features // world_size
         super().__init__(in_features, self.part_out_features, bias, dtype)
@@ -241,7 +256,10 @@ class RowParallelLinear(Linear):
         bias (bool): Whether to include a bias term. Defaults to False.
         dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`.
     """
-    def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
+
+    def __init__(
+        self, in_features: int, out_features: int, bias: bool = False, dtype=None
+    ):
         assert in_features % world_size == 0
         self.part_in_features = in_features // world_size
         super().__init__(self.part_in_features, out_features, bias, dtype)
@@ -272,6 +290,7 @@ class RMSNorm(nn.Module):
         dim (int): Dimension of the input tensor.
         eps (float): Epsilon value for numerical stability. Defaults to 1e-6.
     """
+
     def __init__(self, dim: int, eps: float = 1e-6):
         super().__init__()
         self.dim = dim
@@ -321,7 +340,11 @@ def precompute_freqs_cis(args: ModelArgs) -> torch.Tensor:
         Returns:
             float: The correction dimension based on the input parameters.
         """
-        return dim * math.log(max_seq_len / (num_rotations * 2 * math.pi)) / (2 * math.log(base))
+        return (
+            dim
+            * math.log(max_seq_len / (num_rotations * 2 * math.pi))
+            / (2 * math.log(base))
+        )

     def find_correction_range(low_rot, high_rot, dim, base, max_seq_len):
         """
@@ -339,7 +362,7 @@ def precompute_freqs_cis(args: ModelArgs) -> torch.Tensor:
         """
         low = math.floor(find_correction_dim(low_rot, dim, base, max_seq_len))
         high = math.ceil(find_correction_dim(high_rot, dim, base, max_seq_len))
-        return max(low, 0), min(high, dim-1)
+        return max(low, 0), min(high, dim - 1)

     def linear_ramp_factor(min, max, dim):
         """
@@ -362,7 +385,9 @@ def precompute_freqs_cis(args: ModelArgs) -> torch.Tensor:

     freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
     if seqlen > args.original_seq_len:
-        low, high = find_correction_range(beta_fast, beta_slow, dim, base, args.original_seq_len)
+        low, high = find_correction_range(
+            beta_fast, beta_slow, dim, base, args.original_seq_len
+        )
         smooth = 1 - linear_ramp_factor(low, high, dim // 2)
         freqs = freqs / factor * (1 - smooth) + freqs * smooth

@@ -406,6 +431,7 @@ class MLA(nn.Module):
         v_head_dim (int): Dimensionality of value projections.
         softmax_scale (float): Scaling factor for softmax in attention computation.
     """
+
     def __init__(self, args: ModelArgs):
         super().__init__()
         self.dim = args.dim
@@ -423,24 +449,62 @@ class MLA(nn.Module):
         else:
             self.wq_a = Linear(self.dim, self.q_lora_rank)
             self.q_norm = RMSNorm(self.q_lora_rank)
-            self.wq_b = ColumnParallelLinear(self.q_lora_rank, self.n_heads * self.qk_head_dim)
+            self.wq_b = ColumnParallelLinear(
+                self.q_lora_rank, self.n_heads * self.qk_head_dim
+            )
         self.wkv_a = Linear(self.dim, self.kv_lora_rank + self.qk_rope_head_dim)
         self.kv_norm = RMSNorm(self.kv_lora_rank)
-        self.wkv_b = ColumnParallelLinear(self.kv_lora_rank, self.n_heads * (self.qk_nope_head_dim + self.v_head_dim))
+        self.wkv_b = ColumnParallelLinear(
+            self.kv_lora_rank, self.n_heads * (self.qk_nope_head_dim + self.v_head_dim)
+        )
         self.wo = RowParallelLinear(self.n_heads * self.v_head_dim, self.dim)
-        self.softmax_scale = self.qk_head_dim ** -0.5
+        self.softmax_scale = self.qk_head_dim**-0.5
         if args.max_seq_len > args.original_seq_len:
             mscale = 0.1 * args.mscale * math.log(args.rope_factor) + 1.0
             self.softmax_scale = self.softmax_scale * mscale * mscale

         if attn_impl == "naive":
-            self.register_buffer("k_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.n_local_heads, self.qk_head_dim), persistent=False)
-            self.register_buffer("v_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.n_local_heads, self.v_head_dim), persistent=False)
+            self.register_buffer(
+                "k_cache",
+                torch.zeros(
+                    args.max_batch_size,
+                    args.max_seq_len,
+                    self.n_local_heads,
+                    self.qk_head_dim,
+                ),
+                persistent=False,
+            )
+            self.register_buffer(
+                "v_cache",
+                torch.zeros(
+                    args.max_batch_size,
+                    args.max_seq_len,
+                    self.n_local_heads,
+                    self.v_head_dim,
+                ),
+                persistent=False,
+            )
         else:
-            self.register_buffer("kv_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.kv_lora_rank), persistent=False)
-            self.register_buffer("pe_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.qk_rope_head_dim), persistent=False)
+            self.register_buffer(
+                "kv_cache",
+                torch.zeros(args.max_batch_size, args.max_seq_len, self.kv_lora_rank),
+                persistent=False,
+            )
+            self.register_buffer(
+                "pe_cache",
+                torch.zeros(
+                    args.max_batch_size, args.max_seq_len, self.qk_rope_head_dim
+                ),
+                persistent=False,
+            )

-    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
+    def forward(
+        self,
+        x: torch.Tensor,
+        start_pos: int,
+        freqs_cis: torch.Tensor,
+        mask: Optional[torch.Tensor],
+    ):
         """
         Forward pass for the Multi-Headed Attention Layer (MLA).

@@ -460,7 +524,9 @@ class MLA(nn.Module):
         else:
             q = self.wq_b(self.q_norm(self.wq_a(x)))
         q = q.view(bsz, seqlen, self.n_local_heads, self.qk_head_dim)
-        q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+        q_nope, q_pe = torch.split(
+            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
+        )
         q_pe = apply_rotary_emb(q_pe, freqs_cis)
         kv = self.wkv_a(x)
         kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
@@ -468,20 +534,35 @@ class MLA(nn.Module):
         if attn_impl == "naive":
             q = torch.cat([q_nope, q_pe], dim=-1)
             kv = self.wkv_b(self.kv_norm(kv))
-            kv = kv.view(bsz, seqlen, self.n_local_heads, self.qk_nope_head_dim + self.v_head_dim)
-            k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
+            kv = kv.view(
+                bsz, seqlen, self.n_local_heads, self.qk_nope_head_dim + self.v_head_dim
+            )
+            k_nope, v = torch.split(
+                kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1
+            )
             k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_local_heads, -1)], dim=-1)
             self.k_cache[:bsz, start_pos:end_pos] = k
             self.v_cache[:bsz, start_pos:end_pos] = v
-            scores = torch.einsum("bshd,bthd->bsht", q, self.k_cache[:bsz, :end_pos]) * self.softmax_scale
+            scores = (
+                torch.einsum("bshd,bthd->bsht", q, self.k_cache[:bsz, :end_pos])
+                * self.softmax_scale
+            )
         else:
-            wkv_b = self.wkv_b.weight if self.wkv_b.scale is None else weight_dequant(self.wkv_b.weight, self.wkv_b.scale, block_size)
+            wkv_b = (
+                self.wkv_b.weight
+                if self.wkv_b.scale is None
+                else weight_dequant(self.wkv_b.weight, self.wkv_b.scale, block_size)
+            )
             wkv_b = wkv_b.view(self.n_local_heads, -1, self.kv_lora_rank)
-            q_nope = torch.einsum("bshd,hdc->bshc", q_nope, wkv_b[:, :self.qk_nope_head_dim])
+            q_nope = torch.einsum(
+                "bshd,hdc->bshc", q_nope, wkv_b[:, : self.qk_nope_head_dim]
+            )
             self.kv_cache[:bsz, start_pos:end_pos] = self.kv_norm(kv)
             self.pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2)
-            scores = (torch.einsum("bshc,btc->bsht", q_nope, self.kv_cache[:bsz, :end_pos]) +
-                      torch.einsum("bshr,btr->bsht", q_pe, self.pe_cache[:bsz, :end_pos])) * self.softmax_scale
+            scores = (
+                torch.einsum("bshc,btc->bsht", q_nope, self.kv_cache[:bsz, :end_pos])
+                + torch.einsum("bshr,btr->bsht", q_pe, self.pe_cache[:bsz, :end_pos])
+            ) * self.softmax_scale
         if mask is not None:
             scores += mask.unsqueeze(1)
         scores = scores.softmax(dim=-1, dtype=torch.float32).type_as(x)
@@ -489,7 +570,7 @@ class MLA(nn.Module):
             x = torch.einsum("bsht,bthd->bshd", scores, self.v_cache[:bsz, :end_pos])
         else:
             x = torch.einsum("bsht,btc->bshc", scores, self.kv_cache[:bsz, :end_pos])
-            x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim:])
+            x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim :])
         x = self.wo(x.flatten(2))
         return x

@@ -503,6 +584,7 @@ class MLP(nn.Module):
         w2 (nn.Module): Linear layer for hidden-to-output transformation.
         w3 (nn.Module): Additional linear layer for feature transformation.
     """
+
     def __init__(self, dim: int, inter_dim: int):
         """
         Initializes the MLP layer.
@@ -543,6 +625,7 @@ class Gate(nn.Module):
         weight (torch.nn.Parameter): Learnable weights for the gate.
         bias (Optional[torch.nn.Parameter]): Optional bias term for the gate.
     """
+
     def __init__(self, args: ModelArgs):
         """
         Initializes the Gate module.
@@ -558,7 +641,11 @@ class Gate(nn.Module):
         self.score_func = args.score_func
         self.route_scale = args.route_scale
         self.weight = nn.Parameter(torch.empty(args.n_routed_experts, args.dim))
-        self.bias = nn.Parameter(torch.empty(args.n_routed_experts)) if self.dim == 7168 else None
+        self.bias = (
+            nn.Parameter(torch.empty(args.n_routed_experts))
+            if self.dim == 7168
+            else None
+        )

     def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         """
@@ -604,6 +691,7 @@ class Expert(nn.Module):
         w2 (nn.Module): Linear layer for hidden-to-output transformation.
         w3 (nn.Module): Additional linear layer for feature transformation.
     """
+
     def __init__(self, dim: int, inter_dim: int):
         """
         Initializes the Expert layer.
@@ -643,6 +731,7 @@ class MoE(nn.Module):
         experts (nn.ModuleList): List of expert modules.
         shared_experts (nn.Module): Shared experts applied to all inputs.
     """
+
     def __init__(self, args: ModelArgs):
         """
         Initializes the MoE module.
@@ -659,8 +748,14 @@ class MoE(nn.Module):
         self.experts_start_idx = rank * self.n_local_experts
         self.experts_end_idx = self.experts_start_idx + self.n_local_experts
         self.gate = Gate(args)
-        self.experts = nn.ModuleList([Expert(args.dim, args.moe_inter_dim) if self.experts_start_idx <= i < self.experts_end_idx else None
-                                      for i in range(self.n_routed_experts)])
+        self.experts = nn.ModuleList(
+            [
+                Expert(args.dim, args.moe_inter_dim)
+                if self.experts_start_idx <= i < self.experts_end_idx
+                else None
+                for i in range(self.n_routed_experts)
+            ]
+        )
         self.shared_experts = MLP(args.dim, args.n_shared_experts * args.moe_inter_dim)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -677,7 +772,9 @@ class MoE(nn.Module):
         x = x.view(-1, self.dim)
         weights, indices = self.gate(x)
         y = torch.zeros_like(x)
-        counts = torch.bincount(indices.flatten(), minlength=self.n_routed_experts).tolist()
+        counts = torch.bincount(
+            indices.flatten(), minlength=self.n_routed_experts
+        ).tolist()
         for i in range(self.experts_start_idx, self.experts_end_idx):
             if counts[i] == 0:
                 continue
@@ -700,6 +797,7 @@ class Block(nn.Module):
         attn_norm (nn.Module): Layer normalization for attention.
         ffn_norm (nn.Module): Layer normalization for feed-forward network.
     """
+
     def __init__(self, layer_id: int, args: ModelArgs):
         """
         Initializes the Transformer block.
@@ -710,11 +808,21 @@ class Block(nn.Module):
         """
         super().__init__()
         self.attn = MLA(args)
-        self.ffn = MLP(args.dim, args.inter_dim) if layer_id < args.n_dense_layers else MoE(args)
+        self.ffn = (
+            MLP(args.dim, args.inter_dim)
+            if layer_id < args.n_dense_layers
+            else MoE(args)
+        )
         self.attn_norm = RMSNorm(args.dim)
         self.ffn_norm = RMSNorm(args.dim)

-    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]) -> torch.Tensor:
+    def forward(
+        self,
+        x: torch.Tensor,
+        start_pos: int,
+        freqs_cis: torch.Tensor,
+        mask: Optional[torch.Tensor],
+    ) -> torch.Tensor:
         """
         Forward pass for the Transformer block.

@@ -744,6 +852,7 @@ class Transformer(nn.Module):
         head (nn.Module): Output projection layer mapping to vocabulary size.
         freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings.
     """
+
     def __init__(self, args: ModelArgs):
         """
         Initializes the Transformer model.
@@ -762,7 +871,9 @@ class Transformer(nn.Module):
         for layer_id in range(args.n_layers):
             self.layers.append(Block(layer_id, args))
         self.norm = RMSNorm(args.dim)
-        self.head = ColumnParallelLinear(args.dim, args.vocab_size, dtype=torch.get_default_dtype())
+        self.head = ColumnParallelLinear(
+            args.dim, args.vocab_size, dtype=torch.get_default_dtype()
+        )
         self.register_buffer("freqs_cis", precompute_freqs_cis(args), persistent=False)

     @torch.inference_mode()
@@ -779,10 +890,12 @@ class Transformer(nn.Module):
         """
         seqlen = tokens.size(1)
         h = self.embed(tokens)
-        freqs_cis = self.freqs_cis[start_pos:start_pos+seqlen]
+        freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
         mask = None
         if seqlen > 1:
-            mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device).triu_(1)
+            mask = torch.full(
+                (seqlen, seqlen), float("-inf"), device=tokens.device
+            ).triu_(1)
         for layer in self.layers:
             h = layer(h, start_pos, freqs_cis, mask)
         h = self.norm(h)[:, -1]