style: Ruff format for PEP 8 guidelines

Harlan D Heilman 2025-01-27 19:31:11 -08:00
parent cb7d4d7e62
commit 43367a81e1
5 changed files with 262 additions and 84 deletions
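
All of the hunks below are formatting-only changes (line wrapping, double-quoted strings, explicit float literals such as 1.0 instead of 1., and PEP 8 blank-line spacing); no runtime behavior changes. The commit does not record the exact command used; a change like this is typically produced by running Ruff's formatter over the repository, e.g. (an assumed invocation, not taken from the commit):

    ruff format .

Ruff's formatter wraps at 88 characters by default, which is consistent with the wrapping seen in the hunks below.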

View File

@@ -54,7 +54,7 @@ def main(hf_ckpt_path, save_path, n_experts, mp):
                     continue
                 param: torch.Tensor = f.get_tensor(name)
                 if name.startswith("model."):
-                    name = name[len("model."):]
+                    name = name[len("model.") :]
                 name = name.replace("self_attn", "attn")
                 name = name.replace("mlp", "ffn")
                 name = name.replace("weight_scale_inv", "scale")
@@ -67,18 +67,25 @@ def main(hf_ckpt_path, save_path, n_experts, mp):
                     new_param = param
                     if "experts" in name and "shared_experts" not in name:
                         idx = int(name.split(".")[-3])
-                        if idx < i * n_local_experts or idx >= (i + 1) * n_local_experts:
+                        if (
+                            idx < i * n_local_experts
+                            or idx >= (i + 1) * n_local_experts
+                        ):
                             continue
                     elif dim is not None:
                         assert param.size(dim) % mp == 0
                         shard_size = param.size(dim) // mp
-                        new_param = param.narrow(dim, i * shard_size, shard_size).contiguous()
+                        new_param = param.narrow(
+                            dim, i * shard_size, shard_size
+                        ).contiguous()
                     state_dicts[i][name] = new_param

     os.makedirs(save_path, exist_ok=True)

     for i in trange(mp):
-        save_file(state_dicts[i], os.path.join(save_path, f"model{i}-mp{mp}.safetensors"))
+        save_file(
+            state_dicts[i], os.path.join(save_path, f"model{i}-mp{mp}.safetensors")
+        )

     for file_path in glob(os.path.join(hf_ckpt_path, "*token*")):
         new_file_path = os.path.join(save_path, os.path.basename(file_path))

View File

@@ -9,6 +9,7 @@ from safetensors.torch import load_file, save_file

 from kernel import weight_dequant

+
 def main(fp8_path, bf16_path):
     """
     Converts FP8 weights to BF16 and saves the converted weights.
@@ -79,7 +80,9 @@ def main(fp8_path, bf16_path):
                     fp8_weight_names.append(weight_name)
                     new_state_dict[weight_name] = weight_dequant(weight, scale_inv)
                 except KeyError:
-                    print(f"Warning: Missing scale_inv tensor for {weight_name}, skipping conversion")
+                    print(
+                        f"Warning: Missing scale_inv tensor for {weight_name}, skipping conversion"
+                    )
                     new_state_dict[weight_name] = weight
             else:
                 new_state_dict[weight_name] = weight
@@ -109,4 +112,3 @@ if __name__ == "__main__":
     parser.add_argument("--output-bf16-hf-path", type=str, required=True)
     args = parser.parse_args()
     main(args.input_fp8_hf_path, args.output_bf16_hf_path)
-

View File

@@ -33,7 +33,7 @@ def generate(
     prompt_tokens: List[List[int]],
     max_new_tokens: int,
     eos_id: int,
-    temperature: float = 1.0
+    temperature: float = 1.0,
 ) -> List[List[int]]:
     """
     Generates new tokens based on the given prompt tokens using the specified model.
@@ -51,9 +51,11 @@ def generate(
     prompt_lens = [len(t) for t in prompt_tokens]
     assert max(prompt_lens) <= model.max_seq_len
     total_len = min(model.max_seq_len, max_new_tokens + max(prompt_lens))
-    tokens = torch.full((len(prompt_tokens), total_len), -1, dtype=torch.long, device="cuda")
+    tokens = torch.full(
+        (len(prompt_tokens), total_len), -1, dtype=torch.long, device="cuda"
+    )
     for i, t in enumerate(prompt_tokens):
-        tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
+        tokens[i, : len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
     prev_pos = 0
     finished = torch.tensor([False] * len(prompt_tokens), device="cuda")
     prompt_mask = tokens != -1
@@ -63,7 +65,9 @@ def generate(
             next_token = sample(logits, temperature)
         else:
             next_token = logits.argmax(dim=-1)
-        next_token = torch.where(prompt_mask[:, cur_pos], tokens[:, cur_pos], next_token)
+        next_token = torch.where(
+            prompt_mask[:, cur_pos], tokens[:, cur_pos], next_token
+        )
         tokens[:, cur_pos] = next_token
         finished |= torch.logical_and(~prompt_mask[:, cur_pos], next_token == eos_id)
         prev_pos = cur_pos
@@ -71,9 +75,9 @@ def generate(
             break
     completion_tokens = []
     for i, toks in enumerate(tokens.tolist()):
-        toks = toks[prompt_lens[i]:prompt_lens[i]+max_new_tokens]
+        toks = toks[prompt_lens[i] : prompt_lens[i] + max_new_tokens]
         if eos_id in toks:
-            toks = toks[:toks.index(eos_id)]
+            toks = toks[: toks.index(eos_id)]
         completion_tokens.append(toks)
     return completion_tokens

@@ -115,8 +119,10 @@ def main(
     with torch.device("cuda"):
         model = Transformer(args)
     tokenizer = AutoTokenizer.from_pretrained(ckpt_path)
-    tokenizer.decode(generate(model, [tokenizer.encode("DeepSeek")], 2, -1, 1.)[0])
-    load_model(model, os.path.join(ckpt_path, f"model{rank}-mp{world_size}.safetensors"))
+    tokenizer.decode(generate(model, [tokenizer.encode("DeepSeek")], 2, -1, 1.0)[0])
+    load_model(
+        model, os.path.join(ckpt_path, f"model{rank}-mp{world_size}.safetensors")
+    )

     if interactive:
         messages = []
@@ -137,18 +143,37 @@ def main(
                 messages.clear()
                 continue
             messages.append({"role": "user", "content": prompt})
-            prompt_tokens = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
-            completion_tokens = generate(model, [prompt_tokens], max_new_tokens, tokenizer.eos_token_id, temperature)
-            completion = tokenizer.decode(completion_tokens[0], skip_special_tokens=True)
+            prompt_tokens = tokenizer.apply_chat_template(
+                messages, add_generation_prompt=True
+            )
+            completion_tokens = generate(
+                model,
+                [prompt_tokens],
+                max_new_tokens,
+                tokenizer.eos_token_id,
+                temperature,
+            )
+            completion = tokenizer.decode(
+                completion_tokens[0], skip_special_tokens=True
+            )
             print(completion)
             messages.append({"role": "assistant", "content": completion})
     else:
         with open(input_file) as f:
             prompts = [line.strip() for line in f.readlines()]
         assert len(prompts) <= args.max_batch_size
-        prompt_tokens = [tokenizer.apply_chat_template([{"role": "user", "content": prompt}], add_generation_prompt=True) for prompt in prompts]
-        completion_tokens = generate(model, prompt_tokens, max_new_tokens, tokenizer.eos_token_id, temperature)
-        completions = tokenizer.batch_decode(completion_tokens, skip_special_tokens=True)
+        prompt_tokens = [
+            tokenizer.apply_chat_template(
+                [{"role": "user", "content": prompt}], add_generation_prompt=True
+            )
+            for prompt in prompts
+        ]
+        completion_tokens = generate(
+            model, prompt_tokens, max_new_tokens, tokenizer.eos_token_id, temperature
+        )
+        completions = tokenizer.batch_decode(
+            completion_tokens, skip_special_tokens=True
+        )
         for prompt, completion in zip(prompts, completions):
             print("Prompt:", prompt)
             print("Completion:", completion)
@@ -182,4 +207,11 @@ if __name__ == "__main__":
     parser.add_argument("--temperature", type=float, default=0.2)
     args = parser.parse_args()
     assert args.input_file or args.interactive
-    main(args.ckpt_path, args.config, args.input_file, args.interactive, args.max_new_tokens, args.temperature)
+    main(
+        args.ckpt_path,
+        args.config,
+        args.input_file,
+        args.interactive,
+        args.max_new_tokens,
+        args.temperature,
+    )

View File

@@ -23,14 +23,16 @@ def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
     pid = tl.program_id(axis=0)
     offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
     x = tl.load(x_ptr + offs).to(tl.float32)
-    s = tl.max(tl.abs(x)) / 448.
+    s = tl.max(tl.abs(x)) / 448.0
     y = x / s
     y = y.to(y_ptr.dtype.element_ty)
     tl.store(y_ptr + offs, y)
     tl.store(s_ptr + pid, s)


-def act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
+def act_quant(
+    x: torch.Tensor, block_size: int = 128
+) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     Quantizes the input tensor `x` using block-wise quantization.

@@ -47,7 +49,7 @@ def act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, tor
     assert x.size(-1) % block_size == 0
     y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
     s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype=torch.float32)
-    grid = lambda meta: (triton.cdiv(x.numel(), meta['BLOCK_SIZE']), )
+    grid = lambda meta: (triton.cdiv(x.numel(), meta["BLOCK_SIZE"]),)
     act_quant_kernel[grid](x, y, s, BLOCK_SIZE=block_size)
     return y, s

@@ -81,7 +83,9 @@ def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
     tl.store(y_ptr + offs, y, mask=mask)


-def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor:
+def weight_dequant(
+    x: torch.Tensor, s: torch.Tensor, block_size: int = 128
+) -> torch.Tensor:
     """
     Dequantizes the given weight tensor using the provided scale tensor.

@@ -100,24 +104,41 @@ def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> t
     assert x.dim() == 2 and s.dim() == 2
     M, N = x.size()
     y = torch.empty_like(x, dtype=torch.get_default_dtype())
-    grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE']), triton.cdiv(N, meta['BLOCK_SIZE']))
+    grid = lambda meta: (
+        triton.cdiv(M, meta["BLOCK_SIZE"]),
+        triton.cdiv(N, meta["BLOCK_SIZE"]),
+    )
     weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size)
     return y


 fp8_gemm_configs = [
-    Config({'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': 128}, num_stages=num_stages, num_warps=8)
-    for block_m in [16, 32, 64] for block_n in [32, 64, 128] for num_stages in [3, 4, 5, 6]
+    Config(
+        {"BLOCK_SIZE_M": block_m, "BLOCK_SIZE_N": block_n, "BLOCK_SIZE_K": 128},
+        num_stages=num_stages,
+        num_warps=8,
+    )
+    for block_m in [16, 32, 64]
+    for block_n in [32, 64, 128]
+    for num_stages in [3, 4, 5, 6]
 ]

-@triton.autotune(configs=fp8_gemm_configs, key=['N', 'K'])
+
+@triton.autotune(configs=fp8_gemm_configs, key=["N", "K"])
 @triton.jit
-def fp8_gemm_kernel(a_ptr, b_ptr, c_ptr,
-                    a_s_ptr, b_s_ptr,
-                    M, N: tl.constexpr, K: tl.constexpr,
-                    BLOCK_SIZE_M: tl.constexpr,
-                    BLOCK_SIZE_N: tl.constexpr,
-                    BLOCK_SIZE_K: tl.constexpr):
+def fp8_gemm_kernel(
+    a_ptr,
+    b_ptr,
+    c_ptr,
+    a_s_ptr,
+    b_s_ptr,
+    M,
+    N: tl.constexpr,
+    K: tl.constexpr,
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
+):
     """
     Performs a matrix multiplication operation on FP8 matrices with scaling factors.

@@ -186,6 +207,9 @@ def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Ten
     M = a.numel() // K
     N = b.size(0)
     c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype())
-    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']), triton.cdiv(N, META['BLOCK_SIZE_N']))
+    grid = lambda META: (
+        triton.cdiv(M, META["BLOCK_SIZE_M"]),
+        triton.cdiv(N, META["BLOCK_SIZE_N"]),
+    )
     fp8_gemm_kernel[grid](a, b, c, a_s, b_s, M, N, K)
     return c

View File

@@ -16,6 +16,7 @@ block_size = 128
 gemm_impl: Literal["bf16", "fp8"] = "bf16"
 attn_impl: Literal["naive", "absorb"] = "absorb"

+
 @dataclass
 class ModelArgs:
     """
@@ -51,6 +52,7 @@ class ModelArgs:
         beta_slow (int): Slow beta correction factor.
         mscale (float): Scaling factor for extended attention.
     """
+
     max_batch_size: int = 8
     max_seq_len: int = 4096 * 4
     dtype: Literal["bf16", "fp8"] = "bf16"
@@ -68,7 +70,7 @@ class ModelArgs:
     n_expert_groups: int = 1
     n_limited_groups: int = 1
     score_func: Literal["softmax", "sigmoid"] = "softmax"
-    route_scale: float = 1.
+    route_scale: float = 1.0
     # mla
     q_lora_rank: int = 0
     kv_lora_rank: int = 512
@@ -81,7 +83,7 @@ class ModelArgs:
     rope_factor: float = 40
     beta_fast: int = 32
     beta_slow: int = 1
-    mscale: float = 1.
+    mscale: float = 1.0


 class ParallelEmbedding(nn.Module):
@@ -92,12 +94,13 @@ class ParallelEmbedding(nn.Module):
         vocab_size (int): Vocabulary size.
         dim (int): Embedding dimension.
     """
+
     def __init__(self, vocab_size: int, dim: int):
         super().__init__()
         self.vocab_size = vocab_size
         self.dim = dim
         assert vocab_size % world_size == 0
-        self.part_vocab_size = (vocab_size // world_size)
+        self.part_vocab_size = vocab_size // world_size
         self.vocab_start_idx = rank * self.part_vocab_size
         self.vocab_end_idx = self.vocab_start_idx + self.part_vocab_size
         self.weight = nn.Parameter(torch.empty(self.part_vocab_size, self.dim))
@@ -126,7 +129,9 @@ class ParallelEmbedding(nn.Module):
         return y


-def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+def linear(
+    x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None
+) -> torch.Tensor:
     """
     Applies a linear transformation to the incoming data: y = xA^T + b.
     This function supports specialized implementations based on quantization
@@ -171,17 +176,24 @@ class Linear(nn.Module):
         bias (bool): Whether to include a bias term. Defaults to False.
         dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`.
     """
+
     dtype = torch.bfloat16

-    def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
+    def __init__(
+        self, in_features: int, out_features: int, bias: bool = False, dtype=None
+    ):
         super().__init__()
         self.in_features = in_features
         self.out_features = out_features
-        self.weight = nn.Parameter(torch.empty(out_features, in_features, dtype=dtype or Linear.dtype))
+        self.weight = nn.Parameter(
+            torch.empty(out_features, in_features, dtype=dtype or Linear.dtype)
+        )
         if self.weight.element_size() == 1:
             scale_out_features = (out_features + block_size - 1) // block_size
             scale_in_features = (in_features + block_size - 1) // block_size
-            self.weight.scale = self.scale = nn.Parameter(torch.empty(scale_out_features, scale_in_features, dtype=torch.float32))
+            self.weight.scale = self.scale = nn.Parameter(
+                torch.empty(scale_out_features, scale_in_features, dtype=torch.float32)
+            )
         else:
             self.register_parameter("scale", None)
         if bias:
@@ -212,7 +224,10 @@ class ColumnParallelLinear(Linear):
         bias (bool): Whether to include a bias term. Defaults to False.
         dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`.
     """
-    def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
+
+    def __init__(
+        self, in_features: int, out_features: int, bias: bool = False, dtype=None
+    ):
         assert out_features % world_size == 0
         self.part_out_features = out_features // world_size
         super().__init__(in_features, self.part_out_features, bias, dtype)
@@ -241,7 +256,10 @@ class RowParallelLinear(Linear):
         bias (bool): Whether to include a bias term. Defaults to False.
         dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`.
     """
-    def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
+
+    def __init__(
+        self, in_features: int, out_features: int, bias: bool = False, dtype=None
+    ):
         assert in_features % world_size == 0
         self.part_in_features = in_features // world_size
         super().__init__(self.part_in_features, out_features, bias, dtype)
@@ -272,6 +290,7 @@ class RMSNorm(nn.Module):
         dim (int): Dimension of the input tensor.
         eps (float): Epsilon value for numerical stability. Defaults to 1e-6.
     """
+
     def __init__(self, dim: int, eps: float = 1e-6):
         super().__init__()
         self.dim = dim
@@ -321,7 +340,11 @@ def precompute_freqs_cis(args: ModelArgs) -> torch.Tensor:
         Returns:
            float: The correction dimension based on the input parameters.
         """
-        return dim * math.log(max_seq_len / (num_rotations * 2 * math.pi)) / (2 * math.log(base))
+        return (
+            dim
+            * math.log(max_seq_len / (num_rotations * 2 * math.pi))
+            / (2 * math.log(base))
+        )

     def find_correction_range(low_rot, high_rot, dim, base, max_seq_len):
         """
@@ -339,7 +362,7 @@ def precompute_freqs_cis(args: ModelArgs) -> torch.Tensor:
         """
         low = math.floor(find_correction_dim(low_rot, dim, base, max_seq_len))
         high = math.ceil(find_correction_dim(high_rot, dim, base, max_seq_len))
-        return max(low, 0), min(high, dim-1)
+        return max(low, 0), min(high, dim - 1)

     def linear_ramp_factor(min, max, dim):
         """
@@ -362,7 +385,9 @@ def precompute_freqs_cis(args: ModelArgs) -> torch.Tensor:

     freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
     if seqlen > args.original_seq_len:
-        low, high = find_correction_range(beta_fast, beta_slow, dim, base, args.original_seq_len)
+        low, high = find_correction_range(
+            beta_fast, beta_slow, dim, base, args.original_seq_len
+        )
         smooth = 1 - linear_ramp_factor(low, high, dim // 2)
         freqs = freqs / factor * (1 - smooth) + freqs * smooth

@@ -406,6 +431,7 @@ class MLA(nn.Module):
         v_head_dim (int): Dimensionality of value projections.
         softmax_scale (float): Scaling factor for softmax in attention computation.
     """
+
     def __init__(self, args: ModelArgs):
         super().__init__()
         self.dim = args.dim
@@ -423,24 +449,62 @@ class MLA(nn.Module):
         else:
             self.wq_a = Linear(self.dim, self.q_lora_rank)
             self.q_norm = RMSNorm(self.q_lora_rank)
-            self.wq_b = ColumnParallelLinear(self.q_lora_rank, self.n_heads * self.qk_head_dim)
+            self.wq_b = ColumnParallelLinear(
+                self.q_lora_rank, self.n_heads * self.qk_head_dim
+            )
         self.wkv_a = Linear(self.dim, self.kv_lora_rank + self.qk_rope_head_dim)
         self.kv_norm = RMSNorm(self.kv_lora_rank)
-        self.wkv_b = ColumnParallelLinear(self.kv_lora_rank, self.n_heads * (self.qk_nope_head_dim + self.v_head_dim))
+        self.wkv_b = ColumnParallelLinear(
+            self.kv_lora_rank, self.n_heads * (self.qk_nope_head_dim + self.v_head_dim)
+        )
         self.wo = RowParallelLinear(self.n_heads * self.v_head_dim, self.dim)
-        self.softmax_scale = self.qk_head_dim ** -0.5
+        self.softmax_scale = self.qk_head_dim**-0.5
         if args.max_seq_len > args.original_seq_len:
             mscale = 0.1 * args.mscale * math.log(args.rope_factor) + 1.0
             self.softmax_scale = self.softmax_scale * mscale * mscale

         if attn_impl == "naive":
-            self.register_buffer("k_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.n_local_heads, self.qk_head_dim), persistent=False)
-            self.register_buffer("v_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.n_local_heads, self.v_head_dim), persistent=False)
+            self.register_buffer(
+                "k_cache",
+                torch.zeros(
+                    args.max_batch_size,
+                    args.max_seq_len,
+                    self.n_local_heads,
+                    self.qk_head_dim,
+                ),
+                persistent=False,
+            )
+            self.register_buffer(
+                "v_cache",
+                torch.zeros(
+                    args.max_batch_size,
+                    args.max_seq_len,
+                    self.n_local_heads,
+                    self.v_head_dim,
+                ),
+                persistent=False,
+            )
         else:
-            self.register_buffer("kv_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.kv_lora_rank), persistent=False)
-            self.register_buffer("pe_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.qk_rope_head_dim), persistent=False)
+            self.register_buffer(
+                "kv_cache",
+                torch.zeros(args.max_batch_size, args.max_seq_len, self.kv_lora_rank),
+                persistent=False,
+            )
+            self.register_buffer(
+                "pe_cache",
+                torch.zeros(
+                    args.max_batch_size, args.max_seq_len, self.qk_rope_head_dim
+                ),
+                persistent=False,
+            )

-    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
+    def forward(
+        self,
+        x: torch.Tensor,
+        start_pos: int,
+        freqs_cis: torch.Tensor,
+        mask: Optional[torch.Tensor],
+    ):
         """
         Forward pass for the Multi-Headed Attention Layer (MLA).

@@ -460,7 +524,9 @@ class MLA(nn.Module):
         else:
             q = self.wq_b(self.q_norm(self.wq_a(x)))
         q = q.view(bsz, seqlen, self.n_local_heads, self.qk_head_dim)
-        q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+        q_nope, q_pe = torch.split(
+            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
+        )
         q_pe = apply_rotary_emb(q_pe, freqs_cis)
         kv = self.wkv_a(x)
         kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
@@ -468,20 +534,35 @@ class MLA(nn.Module):
         if attn_impl == "naive":
             q = torch.cat([q_nope, q_pe], dim=-1)
             kv = self.wkv_b(self.kv_norm(kv))
-            kv = kv.view(bsz, seqlen, self.n_local_heads, self.qk_nope_head_dim + self.v_head_dim)
-            k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
+            kv = kv.view(
+                bsz, seqlen, self.n_local_heads, self.qk_nope_head_dim + self.v_head_dim
+            )
+            k_nope, v = torch.split(
+                kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1
+            )
             k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_local_heads, -1)], dim=-1)
             self.k_cache[:bsz, start_pos:end_pos] = k
             self.v_cache[:bsz, start_pos:end_pos] = v
-            scores = torch.einsum("bshd,bthd->bsht", q, self.k_cache[:bsz, :end_pos]) * self.softmax_scale
+            scores = (
+                torch.einsum("bshd,bthd->bsht", q, self.k_cache[:bsz, :end_pos])
+                * self.softmax_scale
+            )
         else:
-            wkv_b = self.wkv_b.weight if self.wkv_b.scale is None else weight_dequant(self.wkv_b.weight, self.wkv_b.scale, block_size)
+            wkv_b = (
+                self.wkv_b.weight
+                if self.wkv_b.scale is None
+                else weight_dequant(self.wkv_b.weight, self.wkv_b.scale, block_size)
+            )
             wkv_b = wkv_b.view(self.n_local_heads, -1, self.kv_lora_rank)
-            q_nope = torch.einsum("bshd,hdc->bshc", q_nope, wkv_b[:, :self.qk_nope_head_dim])
+            q_nope = torch.einsum(
+                "bshd,hdc->bshc", q_nope, wkv_b[:, : self.qk_nope_head_dim]
+            )
             self.kv_cache[:bsz, start_pos:end_pos] = self.kv_norm(kv)
             self.pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2)
-            scores = (torch.einsum("bshc,btc->bsht", q_nope, self.kv_cache[:bsz, :end_pos]) +
-                      torch.einsum("bshr,btr->bsht", q_pe, self.pe_cache[:bsz, :end_pos])) * self.softmax_scale
+            scores = (
+                torch.einsum("bshc,btc->bsht", q_nope, self.kv_cache[:bsz, :end_pos])
+                + torch.einsum("bshr,btr->bsht", q_pe, self.pe_cache[:bsz, :end_pos])
+            ) * self.softmax_scale
         if mask is not None:
             scores += mask.unsqueeze(1)
         scores = scores.softmax(dim=-1, dtype=torch.float32).type_as(x)
@@ -489,7 +570,7 @@ class MLA(nn.Module):
             x = torch.einsum("bsht,bthd->bshd", scores, self.v_cache[:bsz, :end_pos])
         else:
             x = torch.einsum("bsht,btc->bshc", scores, self.kv_cache[:bsz, :end_pos])
-        x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim:])
+        x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim :])
         x = self.wo(x.flatten(2))
         return x

@@ -503,6 +584,7 @@ class MLP(nn.Module):
         w2 (nn.Module): Linear layer for hidden-to-output transformation.
         w3 (nn.Module): Additional linear layer for feature transformation.
     """
+
     def __init__(self, dim: int, inter_dim: int):
         """
         Initializes the MLP layer.
@@ -543,6 +625,7 @@ class Gate(nn.Module):
         weight (torch.nn.Parameter): Learnable weights for the gate.
         bias (Optional[torch.nn.Parameter]): Optional bias term for the gate.
     """
+
     def __init__(self, args: ModelArgs):
         """
         Initializes the Gate module.
@@ -558,7 +641,11 @@ class Gate(nn.Module):
         self.score_func = args.score_func
         self.route_scale = args.route_scale
         self.weight = nn.Parameter(torch.empty(args.n_routed_experts, args.dim))
-        self.bias = nn.Parameter(torch.empty(args.n_routed_experts)) if self.dim == 7168 else None
+        self.bias = (
+            nn.Parameter(torch.empty(args.n_routed_experts))
+            if self.dim == 7168
+            else None
+        )

     def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         """
@@ -604,6 +691,7 @@ class Expert(nn.Module):
         w2 (nn.Module): Linear layer for hidden-to-output transformation.
         w3 (nn.Module): Additional linear layer for feature transformation.
     """
+
     def __init__(self, dim: int, inter_dim: int):
         """
         Initializes the Expert layer.
@@ -643,6 +731,7 @@ class MoE(nn.Module):
         experts (nn.ModuleList): List of expert modules.
         shared_experts (nn.Module): Shared experts applied to all inputs.
     """
+
     def __init__(self, args: ModelArgs):
         """
         Initializes the MoE module.
@@ -659,8 +748,14 @@ class MoE(nn.Module):
         self.experts_start_idx = rank * self.n_local_experts
         self.experts_end_idx = self.experts_start_idx + self.n_local_experts
         self.gate = Gate(args)
-        self.experts = nn.ModuleList([Expert(args.dim, args.moe_inter_dim) if self.experts_start_idx <= i < self.experts_end_idx else None
-                                      for i in range(self.n_routed_experts)])
+        self.experts = nn.ModuleList(
+            [
+                Expert(args.dim, args.moe_inter_dim)
+                if self.experts_start_idx <= i < self.experts_end_idx
+                else None
+                for i in range(self.n_routed_experts)
+            ]
+        )
         self.shared_experts = MLP(args.dim, args.n_shared_experts * args.moe_inter_dim)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -677,7 +772,9 @@ class MoE(nn.Module):
         x = x.view(-1, self.dim)
         weights, indices = self.gate(x)
         y = torch.zeros_like(x)
-        counts = torch.bincount(indices.flatten(), minlength=self.n_routed_experts).tolist()
+        counts = torch.bincount(
+            indices.flatten(), minlength=self.n_routed_experts
+        ).tolist()
         for i in range(self.experts_start_idx, self.experts_end_idx):
             if counts[i] == 0:
                 continue
@@ -700,6 +797,7 @@ class Block(nn.Module):
         attn_norm (nn.Module): Layer normalization for attention.
         ffn_norm (nn.Module): Layer normalization for feed-forward network.
     """
+
     def __init__(self, layer_id: int, args: ModelArgs):
         """
         Initializes the Transformer block.
@@ -710,11 +808,21 @@ class Block(nn.Module):
         """
         super().__init__()
         self.attn = MLA(args)
-        self.ffn = MLP(args.dim, args.inter_dim) if layer_id < args.n_dense_layers else MoE(args)
+        self.ffn = (
+            MLP(args.dim, args.inter_dim)
+            if layer_id < args.n_dense_layers
+            else MoE(args)
+        )
         self.attn_norm = RMSNorm(args.dim)
         self.ffn_norm = RMSNorm(args.dim)

-    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]) -> torch.Tensor:
+    def forward(
+        self,
+        x: torch.Tensor,
+        start_pos: int,
+        freqs_cis: torch.Tensor,
+        mask: Optional[torch.Tensor],
+    ) -> torch.Tensor:
         """
         Forward pass for the Transformer block.

@@ -744,6 +852,7 @@ class Transformer(nn.Module):
         head (nn.Module): Output projection layer mapping to vocabulary size.
         freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings.
     """
+
     def __init__(self, args: ModelArgs):
         """
         Initializes the Transformer model.
@@ -762,7 +871,9 @@ class Transformer(nn.Module):
         for layer_id in range(args.n_layers):
             self.layers.append(Block(layer_id, args))
         self.norm = RMSNorm(args.dim)
-        self.head = ColumnParallelLinear(args.dim, args.vocab_size, dtype=torch.get_default_dtype())
+        self.head = ColumnParallelLinear(
+            args.dim, args.vocab_size, dtype=torch.get_default_dtype()
+        )
         self.register_buffer("freqs_cis", precompute_freqs_cis(args), persistent=False)

     @torch.inference_mode()
@@ -779,10 +890,12 @@ class Transformer(nn.Module):
         """
         seqlen = tokens.size(1)
         h = self.embed(tokens)
-        freqs_cis = self.freqs_cis[start_pos:start_pos+seqlen]
+        freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
         mask = None
         if seqlen > 1:
-            mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device).triu_(1)
+            mask = torch.full(
+                (seqlen, seqlen), float("-inf"), device=tokens.device
+            ).triu_(1)
         for layer in self.layers:
             h = layer(h, start_pos, freqs_cis, mask)
         h = self.norm(h)[:, -1]