Mirror of https://github.com/deepseek-ai/DeepSeek-V3.git, synced 2025-04-20 02:28:57 -04:00
style: Ruff format for PEP guidelines
This commit is contained in:
parent cb7d4d7e62
commit 43367a81e1
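
This commit applies Ruff's auto-formatter to the repository's Python sources; the hunks below show the resulting changes (long lines wrapped, quotes normalized to double quotes, trailing commas and PEP 8 slice spacing added). A minimal sketch of how such a pass could be reproduced is shown here; the exact invocation is not recorded in the commit, and the inference/ target path is an assumption:

    # Hypothetical reproduction of this formatting pass, assuming Ruff is installed
    # (pip install ruff). The target directory is an assumption; the commit does not
    # record the actual command that was run.
    import subprocess

    # Run Ruff's formatter in place over the Python sources.
    subprocess.run(["ruff", "format", "inference/"], check=True)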
@@ -54,7 +54,7 @@ def main(hf_ckpt_path, save_path, n_experts, mp):
                     continue
                 param: torch.Tensor = f.get_tensor(name)
                 if name.startswith("model."):
-                    name = name[len("model."):]
+                    name = name[len("model.") :]
                 name = name.replace("self_attn", "attn")
                 name = name.replace("mlp", "ffn")
                 name = name.replace("weight_scale_inv", "scale")
@@ -67,18 +67,25 @@ def main(hf_ckpt_path, save_path, n_experts, mp):
                     new_param = param
                     if "experts" in name and "shared_experts" not in name:
                         idx = int(name.split(".")[-3])
-                        if idx < i * n_local_experts or idx >= (i + 1) * n_local_experts:
+                        if (
+                            idx < i * n_local_experts
+                            or idx >= (i + 1) * n_local_experts
+                        ):
                             continue
                     elif dim is not None:
                         assert param.size(dim) % mp == 0
                         shard_size = param.size(dim) // mp
-                        new_param = param.narrow(dim, i * shard_size, shard_size).contiguous()
+                        new_param = param.narrow(
+                            dim, i * shard_size, shard_size
+                        ).contiguous()
                     state_dicts[i][name] = new_param

     os.makedirs(save_path, exist_ok=True)

     for i in trange(mp):
-        save_file(state_dicts[i], os.path.join(save_path, f"model{i}-mp{mp}.safetensors"))
+        save_file(
+            state_dicts[i], os.path.join(save_path, f"model{i}-mp{mp}.safetensors")
+        )

     for file_path in glob(os.path.join(hf_ckpt_path, "*token*")):
         new_file_path = os.path.join(save_path, os.path.basename(file_path))
@@ -9,6 +9,7 @@ from safetensors.torch import load_file, save_file

 from kernel import weight_dequant

+
 def main(fp8_path, bf16_path):
     """
     Converts FP8 weights to BF16 and saves the converted weights.
@@ -79,7 +80,9 @@ def main(fp8_path, bf16_path):
                         fp8_weight_names.append(weight_name)
                         new_state_dict[weight_name] = weight_dequant(weight, scale_inv)
                     except KeyError:
-                        print(f"Warning: Missing scale_inv tensor for {weight_name}, skipping conversion")
+                        print(
+                            f"Warning: Missing scale_inv tensor for {weight_name}, skipping conversion"
+                        )
                         new_state_dict[weight_name] = weight
                 else:
                     new_state_dict[weight_name] = weight
@@ -109,4 +112,3 @@ if __name__ == "__main__":
     parser.add_argument("--output-bf16-hf-path", type=str, required=True)
     args = parser.parse_args()
     main(args.input_fp8_hf_path, args.output_bf16_hf_path)
-
@@ -33,7 +33,7 @@ def generate(
     prompt_tokens: List[List[int]],
     max_new_tokens: int,
     eos_id: int,
-    temperature: float = 1.0
+    temperature: float = 1.0,
 ) -> List[List[int]]:
     """
     Generates new tokens based on the given prompt tokens using the specified model.
@@ -51,9 +51,11 @@ def generate(
     prompt_lens = [len(t) for t in prompt_tokens]
     assert max(prompt_lens) <= model.max_seq_len
     total_len = min(model.max_seq_len, max_new_tokens + max(prompt_lens))
-    tokens = torch.full((len(prompt_tokens), total_len), -1, dtype=torch.long, device="cuda")
+    tokens = torch.full(
+        (len(prompt_tokens), total_len), -1, dtype=torch.long, device="cuda"
+    )
     for i, t in enumerate(prompt_tokens):
-        tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
+        tokens[i, : len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
     prev_pos = 0
     finished = torch.tensor([False] * len(prompt_tokens), device="cuda")
     prompt_mask = tokens != -1
@@ -63,7 +65,9 @@ def generate(
             next_token = sample(logits, temperature)
         else:
             next_token = logits.argmax(dim=-1)
-        next_token = torch.where(prompt_mask[:, cur_pos], tokens[:, cur_pos], next_token)
+        next_token = torch.where(
+            prompt_mask[:, cur_pos], tokens[:, cur_pos], next_token
+        )
         tokens[:, cur_pos] = next_token
         finished |= torch.logical_and(~prompt_mask[:, cur_pos], next_token == eos_id)
         prev_pos = cur_pos
@@ -71,9 +75,9 @@ def generate(
             break
     completion_tokens = []
     for i, toks in enumerate(tokens.tolist()):
-        toks = toks[prompt_lens[i]:prompt_lens[i]+max_new_tokens]
+        toks = toks[prompt_lens[i] : prompt_lens[i] + max_new_tokens]
         if eos_id in toks:
-            toks = toks[:toks.index(eos_id)]
+            toks = toks[: toks.index(eos_id)]
         completion_tokens.append(toks)
     return completion_tokens

@@ -115,8 +119,10 @@ def main(
     with torch.device("cuda"):
         model = Transformer(args)
     tokenizer = AutoTokenizer.from_pretrained(ckpt_path)
-    tokenizer.decode(generate(model, [tokenizer.encode("DeepSeek")], 2, -1, 1.)[0])
-    load_model(model, os.path.join(ckpt_path, f"model{rank}-mp{world_size}.safetensors"))
+    tokenizer.decode(generate(model, [tokenizer.encode("DeepSeek")], 2, -1, 1.0)[0])
+    load_model(
+        model, os.path.join(ckpt_path, f"model{rank}-mp{world_size}.safetensors")
+    )

     if interactive:
         messages = []
@@ -137,18 +143,37 @@ def main(
                 messages.clear()
                 continue
             messages.append({"role": "user", "content": prompt})
-            prompt_tokens = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
-            completion_tokens = generate(model, [prompt_tokens], max_new_tokens, tokenizer.eos_token_id, temperature)
-            completion = tokenizer.decode(completion_tokens[0], skip_special_tokens=True)
+            prompt_tokens = tokenizer.apply_chat_template(
+                messages, add_generation_prompt=True
+            )
+            completion_tokens = generate(
+                model,
+                [prompt_tokens],
+                max_new_tokens,
+                tokenizer.eos_token_id,
+                temperature,
+            )
+            completion = tokenizer.decode(
+                completion_tokens[0], skip_special_tokens=True
+            )
             print(completion)
             messages.append({"role": "assistant", "content": completion})
     else:
         with open(input_file) as f:
             prompts = [line.strip() for line in f.readlines()]
         assert len(prompts) <= args.max_batch_size
-        prompt_tokens = [tokenizer.apply_chat_template([{"role": "user", "content": prompt}], add_generation_prompt=True) for prompt in prompts]
-        completion_tokens = generate(model, prompt_tokens, max_new_tokens, tokenizer.eos_token_id, temperature)
-        completions = tokenizer.batch_decode(completion_tokens, skip_special_tokens=True)
+        prompt_tokens = [
+            tokenizer.apply_chat_template(
+                [{"role": "user", "content": prompt}], add_generation_prompt=True
+            )
+            for prompt in prompts
+        ]
+        completion_tokens = generate(
+            model, prompt_tokens, max_new_tokens, tokenizer.eos_token_id, temperature
+        )
+        completions = tokenizer.batch_decode(
+            completion_tokens, skip_special_tokens=True
+        )
         for prompt, completion in zip(prompts, completions):
             print("Prompt:", prompt)
             print("Completion:", completion)
@@ -182,4 +207,11 @@ if __name__ == "__main__":
     parser.add_argument("--temperature", type=float, default=0.2)
     args = parser.parse_args()
     assert args.input_file or args.interactive
-    main(args.ckpt_path, args.config, args.input_file, args.interactive, args.max_new_tokens, args.temperature)
+    main(
+        args.ckpt_path,
+        args.config,
+        args.input_file,
+        args.interactive,
+        args.max_new_tokens,
+        args.temperature,
+    )
@@ -23,14 +23,16 @@ def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
     pid = tl.program_id(axis=0)
     offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
     x = tl.load(x_ptr + offs).to(tl.float32)
-    s = tl.max(tl.abs(x)) / 448.
+    s = tl.max(tl.abs(x)) / 448.0
     y = x / s
     y = y.to(y_ptr.dtype.element_ty)
     tl.store(y_ptr + offs, y)
     tl.store(s_ptr + pid, s)


-def act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
+def act_quant(
+    x: torch.Tensor, block_size: int = 128
+) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     Quantizes the input tensor `x` using block-wise quantization.

@@ -47,7 +49,7 @@ def act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, tor
     assert x.size(-1) % block_size == 0
     y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
     s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype=torch.float32)
-    grid = lambda meta: (triton.cdiv(x.numel(), meta['BLOCK_SIZE']), )
+    grid = lambda meta: (triton.cdiv(x.numel(), meta["BLOCK_SIZE"]),)
     act_quant_kernel[grid](x, y, s, BLOCK_SIZE=block_size)
     return y, s

@@ -81,7 +83,9 @@ def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
     tl.store(y_ptr + offs, y, mask=mask)


-def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor:
+def weight_dequant(
+    x: torch.Tensor, s: torch.Tensor, block_size: int = 128
+) -> torch.Tensor:
     """
     Dequantizes the given weight tensor using the provided scale tensor.

@@ -100,24 +104,41 @@ def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> t
     assert x.dim() == 2 and s.dim() == 2
     M, N = x.size()
     y = torch.empty_like(x, dtype=torch.get_default_dtype())
-    grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE']), triton.cdiv(N, meta['BLOCK_SIZE']))
+    grid = lambda meta: (
+        triton.cdiv(M, meta["BLOCK_SIZE"]),
+        triton.cdiv(N, meta["BLOCK_SIZE"]),
+    )
     weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size)
     return y


 fp8_gemm_configs = [
-    Config({'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': 128}, num_stages=num_stages, num_warps=8)
-    for block_m in [16, 32, 64] for block_n in [32, 64, 128] for num_stages in [3, 4, 5, 6]
+    Config(
+        {"BLOCK_SIZE_M": block_m, "BLOCK_SIZE_N": block_n, "BLOCK_SIZE_K": 128},
+        num_stages=num_stages,
+        num_warps=8,
+    )
+    for block_m in [16, 32, 64]
+    for block_n in [32, 64, 128]
+    for num_stages in [3, 4, 5, 6]
 ]

-@triton.autotune(configs=fp8_gemm_configs, key=['N', 'K'])
+
+@triton.autotune(configs=fp8_gemm_configs, key=["N", "K"])
 @triton.jit
-def fp8_gemm_kernel(a_ptr, b_ptr, c_ptr,
-                    a_s_ptr, b_s_ptr,
-                    M, N: tl.constexpr, K: tl.constexpr,
-                    BLOCK_SIZE_M: tl.constexpr,
-                    BLOCK_SIZE_N: tl.constexpr,
-                    BLOCK_SIZE_K: tl.constexpr):
+def fp8_gemm_kernel(
+    a_ptr,
+    b_ptr,
+    c_ptr,
+    a_s_ptr,
+    b_s_ptr,
+    M,
+    N: tl.constexpr,
+    K: tl.constexpr,
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
+):
     """
     Performs a matrix multiplication operation on FP8 matrices with scaling factors.

@@ -186,6 +207,9 @@ def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Ten
     M = a.numel() // K
     N = b.size(0)
     c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype())
-    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']), triton.cdiv(N, META['BLOCK_SIZE_N']))
+    grid = lambda META: (
+        triton.cdiv(M, META["BLOCK_SIZE_M"]),
+        triton.cdiv(N, META["BLOCK_SIZE_N"]),
+    )
     fp8_gemm_kernel[grid](a, b, c, a_s, b_s, M, N, K)
     return c
@@ -16,6 +16,7 @@ block_size = 128
 gemm_impl: Literal["bf16", "fp8"] = "bf16"
 attn_impl: Literal["naive", "absorb"] = "absorb"

+
 @dataclass
 class ModelArgs:
     """
@@ -51,6 +52,7 @@ class ModelArgs:
         beta_slow (int): Slow beta correction factor.
         mscale (float): Scaling factor for extended attention.
     """
+
     max_batch_size: int = 8
     max_seq_len: int = 4096 * 4
     dtype: Literal["bf16", "fp8"] = "bf16"
@@ -68,7 +70,7 @@ class ModelArgs:
     n_expert_groups: int = 1
     n_limited_groups: int = 1
     score_func: Literal["softmax", "sigmoid"] = "softmax"
-    route_scale: float = 1.
+    route_scale: float = 1.0
     # mla
     q_lora_rank: int = 0
     kv_lora_rank: int = 512
@@ -81,7 +83,7 @@ class ModelArgs:
     rope_factor: float = 40
     beta_fast: int = 32
     beta_slow: int = 1
-    mscale: float = 1.
+    mscale: float = 1.0


 class ParallelEmbedding(nn.Module):
@@ -92,12 +94,13 @@ class ParallelEmbedding(nn.Module):
         vocab_size (int): Vocabulary size.
         dim (int): Embedding dimension.
     """
+
     def __init__(self, vocab_size: int, dim: int):
         super().__init__()
         self.vocab_size = vocab_size
         self.dim = dim
         assert vocab_size % world_size == 0
-        self.part_vocab_size = (vocab_size // world_size)
+        self.part_vocab_size = vocab_size // world_size
         self.vocab_start_idx = rank * self.part_vocab_size
         self.vocab_end_idx = self.vocab_start_idx + self.part_vocab_size
         self.weight = nn.Parameter(torch.empty(self.part_vocab_size, self.dim))
@@ -126,7 +129,9 @@ class ParallelEmbedding(nn.Module):
         return y


-def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+def linear(
+    x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None
+) -> torch.Tensor:
     """
     Applies a linear transformation to the incoming data: y = xA^T + b.
     This function supports specialized implementations based on quantization
@@ -171,17 +176,24 @@ class Linear(nn.Module):
         bias (bool): Whether to include a bias term. Defaults to False.
         dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`.
     """
+
     dtype = torch.bfloat16

-    def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
+    def __init__(
+        self, in_features: int, out_features: int, bias: bool = False, dtype=None
+    ):
         super().__init__()
         self.in_features = in_features
         self.out_features = out_features
-        self.weight = nn.Parameter(torch.empty(out_features, in_features, dtype=dtype or Linear.dtype))
+        self.weight = nn.Parameter(
+            torch.empty(out_features, in_features, dtype=dtype or Linear.dtype)
+        )
         if self.weight.element_size() == 1:
             scale_out_features = (out_features + block_size - 1) // block_size
             scale_in_features = (in_features + block_size - 1) // block_size
-            self.weight.scale = self.scale = nn.Parameter(torch.empty(scale_out_features, scale_in_features, dtype=torch.float32))
+            self.weight.scale = self.scale = nn.Parameter(
+                torch.empty(scale_out_features, scale_in_features, dtype=torch.float32)
+            )
         else:
             self.register_parameter("scale", None)
         if bias:
@@ -212,7 +224,10 @@ class ColumnParallelLinear(Linear):
         bias (bool): Whether to include a bias term. Defaults to False.
         dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`.
     """
-    def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
+
+    def __init__(
+        self, in_features: int, out_features: int, bias: bool = False, dtype=None
+    ):
         assert out_features % world_size == 0
         self.part_out_features = out_features // world_size
         super().__init__(in_features, self.part_out_features, bias, dtype)
@@ -241,7 +256,10 @@ class RowParallelLinear(Linear):
         bias (bool): Whether to include a bias term. Defaults to False.
         dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`.
     """
-    def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
+
+    def __init__(
+        self, in_features: int, out_features: int, bias: bool = False, dtype=None
+    ):
         assert in_features % world_size == 0
         self.part_in_features = in_features // world_size
         super().__init__(self.part_in_features, out_features, bias, dtype)
@@ -272,6 +290,7 @@ class RMSNorm(nn.Module):
         dim (int): Dimension of the input tensor.
         eps (float): Epsilon value for numerical stability. Defaults to 1e-6.
     """
+
     def __init__(self, dim: int, eps: float = 1e-6):
         super().__init__()
         self.dim = dim
@@ -321,7 +340,11 @@ def precompute_freqs_cis(args: ModelArgs) -> torch.Tensor:
         Returns:
             float: The correction dimension based on the input parameters.
         """
-        return dim * math.log(max_seq_len / (num_rotations * 2 * math.pi)) / (2 * math.log(base))
+        return (
+            dim
+            * math.log(max_seq_len / (num_rotations * 2 * math.pi))
+            / (2 * math.log(base))
+        )

     def find_correction_range(low_rot, high_rot, dim, base, max_seq_len):
         """
@@ -339,7 +362,7 @@ def precompute_freqs_cis(args: ModelArgs) -> torch.Tensor:
         """
         low = math.floor(find_correction_dim(low_rot, dim, base, max_seq_len))
         high = math.ceil(find_correction_dim(high_rot, dim, base, max_seq_len))
-        return max(low, 0), min(high, dim-1)
+        return max(low, 0), min(high, dim - 1)

     def linear_ramp_factor(min, max, dim):
         """
@@ -362,7 +385,9 @@ def precompute_freqs_cis(args: ModelArgs) -> torch.Tensor:

     freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
     if seqlen > args.original_seq_len:
-        low, high = find_correction_range(beta_fast, beta_slow, dim, base, args.original_seq_len)
+        low, high = find_correction_range(
+            beta_fast, beta_slow, dim, base, args.original_seq_len
+        )
         smooth = 1 - linear_ramp_factor(low, high, dim // 2)
         freqs = freqs / factor * (1 - smooth) + freqs * smooth

@@ -406,6 +431,7 @@ class MLA(nn.Module):
         v_head_dim (int): Dimensionality of value projections.
         softmax_scale (float): Scaling factor for softmax in attention computation.
     """
+
     def __init__(self, args: ModelArgs):
         super().__init__()
         self.dim = args.dim
@@ -423,24 +449,62 @@ class MLA(nn.Module):
         else:
             self.wq_a = Linear(self.dim, self.q_lora_rank)
             self.q_norm = RMSNorm(self.q_lora_rank)
-            self.wq_b = ColumnParallelLinear(self.q_lora_rank, self.n_heads * self.qk_head_dim)
+            self.wq_b = ColumnParallelLinear(
+                self.q_lora_rank, self.n_heads * self.qk_head_dim
+            )
         self.wkv_a = Linear(self.dim, self.kv_lora_rank + self.qk_rope_head_dim)
         self.kv_norm = RMSNorm(self.kv_lora_rank)
-        self.wkv_b = ColumnParallelLinear(self.kv_lora_rank, self.n_heads * (self.qk_nope_head_dim + self.v_head_dim))
+        self.wkv_b = ColumnParallelLinear(
+            self.kv_lora_rank, self.n_heads * (self.qk_nope_head_dim + self.v_head_dim)
+        )
         self.wo = RowParallelLinear(self.n_heads * self.v_head_dim, self.dim)
-        self.softmax_scale = self.qk_head_dim ** -0.5
+        self.softmax_scale = self.qk_head_dim**-0.5
         if args.max_seq_len > args.original_seq_len:
             mscale = 0.1 * args.mscale * math.log(args.rope_factor) + 1.0
             self.softmax_scale = self.softmax_scale * mscale * mscale

         if attn_impl == "naive":
-            self.register_buffer("k_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.n_local_heads, self.qk_head_dim), persistent=False)
-            self.register_buffer("v_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.n_local_heads, self.v_head_dim), persistent=False)
+            self.register_buffer(
+                "k_cache",
+                torch.zeros(
+                    args.max_batch_size,
+                    args.max_seq_len,
+                    self.n_local_heads,
+                    self.qk_head_dim,
+                ),
+                persistent=False,
+            )
+            self.register_buffer(
+                "v_cache",
+                torch.zeros(
+                    args.max_batch_size,
+                    args.max_seq_len,
+                    self.n_local_heads,
+                    self.v_head_dim,
+                ),
+                persistent=False,
+            )
         else:
-            self.register_buffer("kv_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.kv_lora_rank), persistent=False)
-            self.register_buffer("pe_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.qk_rope_head_dim), persistent=False)
+            self.register_buffer(
+                "kv_cache",
+                torch.zeros(args.max_batch_size, args.max_seq_len, self.kv_lora_rank),
+                persistent=False,
+            )
+            self.register_buffer(
+                "pe_cache",
+                torch.zeros(
+                    args.max_batch_size, args.max_seq_len, self.qk_rope_head_dim
+                ),
+                persistent=False,
+            )

-    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
+    def forward(
+        self,
+        x: torch.Tensor,
+        start_pos: int,
+        freqs_cis: torch.Tensor,
+        mask: Optional[torch.Tensor],
+    ):
         """
         Forward pass for the Multi-Headed Attention Layer (MLA).

@@ -460,7 +524,9 @@ class MLA(nn.Module):
         else:
             q = self.wq_b(self.q_norm(self.wq_a(x)))
         q = q.view(bsz, seqlen, self.n_local_heads, self.qk_head_dim)
-        q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+        q_nope, q_pe = torch.split(
+            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
+        )
         q_pe = apply_rotary_emb(q_pe, freqs_cis)
         kv = self.wkv_a(x)
         kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
@@ -468,20 +534,35 @@ class MLA(nn.Module):
         if attn_impl == "naive":
             q = torch.cat([q_nope, q_pe], dim=-1)
             kv = self.wkv_b(self.kv_norm(kv))
-            kv = kv.view(bsz, seqlen, self.n_local_heads, self.qk_nope_head_dim + self.v_head_dim)
-            k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
+            kv = kv.view(
+                bsz, seqlen, self.n_local_heads, self.qk_nope_head_dim + self.v_head_dim
+            )
+            k_nope, v = torch.split(
+                kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1
+            )
             k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_local_heads, -1)], dim=-1)
             self.k_cache[:bsz, start_pos:end_pos] = k
             self.v_cache[:bsz, start_pos:end_pos] = v
-            scores = torch.einsum("bshd,bthd->bsht", q, self.k_cache[:bsz, :end_pos]) * self.softmax_scale
+            scores = (
+                torch.einsum("bshd,bthd->bsht", q, self.k_cache[:bsz, :end_pos])
+                * self.softmax_scale
+            )
         else:
-            wkv_b = self.wkv_b.weight if self.wkv_b.scale is None else weight_dequant(self.wkv_b.weight, self.wkv_b.scale, block_size)
+            wkv_b = (
+                self.wkv_b.weight
+                if self.wkv_b.scale is None
+                else weight_dequant(self.wkv_b.weight, self.wkv_b.scale, block_size)
+            )
             wkv_b = wkv_b.view(self.n_local_heads, -1, self.kv_lora_rank)
-            q_nope = torch.einsum("bshd,hdc->bshc", q_nope, wkv_b[:, :self.qk_nope_head_dim])
+            q_nope = torch.einsum(
+                "bshd,hdc->bshc", q_nope, wkv_b[:, : self.qk_nope_head_dim]
+            )
             self.kv_cache[:bsz, start_pos:end_pos] = self.kv_norm(kv)
             self.pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2)
-            scores = (torch.einsum("bshc,btc->bsht", q_nope, self.kv_cache[:bsz, :end_pos]) +
-                      torch.einsum("bshr,btr->bsht", q_pe, self.pe_cache[:bsz, :end_pos])) * self.softmax_scale
+            scores = (
+                torch.einsum("bshc,btc->bsht", q_nope, self.kv_cache[:bsz, :end_pos])
+                + torch.einsum("bshr,btr->bsht", q_pe, self.pe_cache[:bsz, :end_pos])
+            ) * self.softmax_scale
         if mask is not None:
             scores += mask.unsqueeze(1)
         scores = scores.softmax(dim=-1, dtype=torch.float32).type_as(x)
@@ -489,7 +570,7 @@ class MLA(nn.Module):
             x = torch.einsum("bsht,bthd->bshd", scores, self.v_cache[:bsz, :end_pos])
         else:
             x = torch.einsum("bsht,btc->bshc", scores, self.kv_cache[:bsz, :end_pos])
-            x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim:])
+            x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim :])
         x = self.wo(x.flatten(2))
         return x

@@ -503,6 +584,7 @@ class MLP(nn.Module):
         w2 (nn.Module): Linear layer for hidden-to-output transformation.
         w3 (nn.Module): Additional linear layer for feature transformation.
     """
+
     def __init__(self, dim: int, inter_dim: int):
         """
         Initializes the MLP layer.
@@ -543,6 +625,7 @@ class Gate(nn.Module):
         weight (torch.nn.Parameter): Learnable weights for the gate.
         bias (Optional[torch.nn.Parameter]): Optional bias term for the gate.
     """
+
     def __init__(self, args: ModelArgs):
         """
         Initializes the Gate module.
@@ -558,7 +641,11 @@ class Gate(nn.Module):
         self.score_func = args.score_func
         self.route_scale = args.route_scale
         self.weight = nn.Parameter(torch.empty(args.n_routed_experts, args.dim))
-        self.bias = nn.Parameter(torch.empty(args.n_routed_experts)) if self.dim == 7168 else None
+        self.bias = (
+            nn.Parameter(torch.empty(args.n_routed_experts))
+            if self.dim == 7168
+            else None
+        )

     def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         """
@@ -604,6 +691,7 @@ class Expert(nn.Module):
         w2 (nn.Module): Linear layer for hidden-to-output transformation.
         w3 (nn.Module): Additional linear layer for feature transformation.
     """
+
     def __init__(self, dim: int, inter_dim: int):
         """
         Initializes the Expert layer.
@@ -643,6 +731,7 @@ class MoE(nn.Module):
         experts (nn.ModuleList): List of expert modules.
         shared_experts (nn.Module): Shared experts applied to all inputs.
     """
+
     def __init__(self, args: ModelArgs):
         """
         Initializes the MoE module.
@@ -659,8 +748,14 @@ class MoE(nn.Module):
         self.experts_start_idx = rank * self.n_local_experts
         self.experts_end_idx = self.experts_start_idx + self.n_local_experts
         self.gate = Gate(args)
-        self.experts = nn.ModuleList([Expert(args.dim, args.moe_inter_dim) if self.experts_start_idx <= i < self.experts_end_idx else None
-                                      for i in range(self.n_routed_experts)])
+        self.experts = nn.ModuleList(
+            [
+                Expert(args.dim, args.moe_inter_dim)
+                if self.experts_start_idx <= i < self.experts_end_idx
+                else None
+                for i in range(self.n_routed_experts)
+            ]
+        )
         self.shared_experts = MLP(args.dim, args.n_shared_experts * args.moe_inter_dim)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -677,7 +772,9 @@ class MoE(nn.Module):
         x = x.view(-1, self.dim)
         weights, indices = self.gate(x)
         y = torch.zeros_like(x)
-        counts = torch.bincount(indices.flatten(), minlength=self.n_routed_experts).tolist()
+        counts = torch.bincount(
+            indices.flatten(), minlength=self.n_routed_experts
+        ).tolist()
         for i in range(self.experts_start_idx, self.experts_end_idx):
             if counts[i] == 0:
                 continue
@@ -700,6 +797,7 @@ class Block(nn.Module):
         attn_norm (nn.Module): Layer normalization for attention.
         ffn_norm (nn.Module): Layer normalization for feed-forward network.
     """
+
     def __init__(self, layer_id: int, args: ModelArgs):
         """
         Initializes the Transformer block.
@@ -710,11 +808,21 @@ class Block(nn.Module):
         """
         super().__init__()
         self.attn = MLA(args)
-        self.ffn = MLP(args.dim, args.inter_dim) if layer_id < args.n_dense_layers else MoE(args)
+        self.ffn = (
+            MLP(args.dim, args.inter_dim)
+            if layer_id < args.n_dense_layers
+            else MoE(args)
+        )
         self.attn_norm = RMSNorm(args.dim)
         self.ffn_norm = RMSNorm(args.dim)

-    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]) -> torch.Tensor:
+    def forward(
+        self,
+        x: torch.Tensor,
+        start_pos: int,
+        freqs_cis: torch.Tensor,
+        mask: Optional[torch.Tensor],
+    ) -> torch.Tensor:
         """
         Forward pass for the Transformer block.

@@ -744,6 +852,7 @@ class Transformer(nn.Module):
         head (nn.Module): Output projection layer mapping to vocabulary size.
         freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings.
     """
+
     def __init__(self, args: ModelArgs):
         """
         Initializes the Transformer model.
@@ -762,7 +871,9 @@ class Transformer(nn.Module):
         for layer_id in range(args.n_layers):
             self.layers.append(Block(layer_id, args))
         self.norm = RMSNorm(args.dim)
-        self.head = ColumnParallelLinear(args.dim, args.vocab_size, dtype=torch.get_default_dtype())
+        self.head = ColumnParallelLinear(
+            args.dim, args.vocab_size, dtype=torch.get_default_dtype()
+        )
         self.register_buffer("freqs_cis", precompute_freqs_cis(args), persistent=False)

     @torch.inference_mode()
@@ -779,10 +890,12 @@ class Transformer(nn.Module):
         """
         seqlen = tokens.size(1)
         h = self.embed(tokens)
-        freqs_cis = self.freqs_cis[start_pos:start_pos+seqlen]
+        freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
         mask = None
         if seqlen > 1:
-            mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device).triu_(1)
+            mask = torch.full(
+                (seqlen, seqlen), float("-inf"), device=tokens.device
+            ).triu_(1)
         for layer in self.layers:
             h = layer(h, start_pos, freqs_cis, mask)
         h = self.norm(h)[:, -1]