Mirror of https://github.com/deepseek-ai/DeepSeek-V3.git, synced 2025-04-19 18:18:57 -04:00

commit 43367a81e1 (parent cb7d4d7e62)

    style: Ruff format for PEP guidelines
diff --git a/inference/convert.py b/inference/convert.py

@@ -39,7 +39,7 @@ def main(hf_ckpt_path, save_path, n_experts, mp):
         save_path (str): Path to the directory where the converted checkpoint files will be saved.
         n_experts (int): Total number of experts in the model.
         mp (int): Model parallelism factor.

     Returns:
         None
     """
@@ -54,7 +54,7 @@ def main(hf_ckpt_path, save_path, n_experts, mp):
                 continue
             param: torch.Tensor = f.get_tensor(name)
             if name.startswith("model."):
-                name = name[len("model."):]
+                name = name[len("model.") :]
             name = name.replace("self_attn", "attn")
             name = name.replace("mlp", "ffn")
             name = name.replace("weight_scale_inv", "scale")
@@ -67,18 +67,25 @@ def main(hf_ckpt_path, save_path, n_experts, mp):
                 new_param = param
                 if "experts" in name and "shared_experts" not in name:
                     idx = int(name.split(".")[-3])
-                    if idx < i * n_local_experts or idx >= (i + 1) * n_local_experts:
+                    if (
+                        idx < i * n_local_experts
+                        or idx >= (i + 1) * n_local_experts
+                    ):
                         continue
                 elif dim is not None:
                     assert param.size(dim) % mp == 0
                     shard_size = param.size(dim) // mp
-                    new_param = param.narrow(dim, i * shard_size, shard_size).contiguous()
+                    new_param = param.narrow(
+                        dim, i * shard_size, shard_size
+                    ).contiguous()
                 state_dicts[i][name] = new_param

     os.makedirs(save_path, exist_ok=True)

     for i in trange(mp):
-        save_file(state_dicts[i], os.path.join(save_path, f"model{i}-mp{mp}.safetensors"))
+        save_file(
+            state_dicts[i], os.path.join(save_path, f"model{i}-mp{mp}.safetensors")
+        )

     for file_path in glob(os.path.join(hf_ckpt_path, "*token*")):
         new_file_path = os.path.join(save_path, os.path.basename(file_path))
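Note: the expert filter above keeps a tensor on rank i only when its expert index falls in that rank's contiguous slice. A self-contained sketch of the same arithmetic (the tensor name and sizes are made up):

    n_experts, mp = 256, 8
    n_local_experts = n_experts // mp          # 32 experts per rank
    name = "layers.3.ffn.experts.17.w1.weight"
    idx = int(name.split(".")[-3])             # expert index -> 17
    for i in range(mp):
        if i * n_local_experts <= idx < (i + 1) * n_local_experts:
            print(f"expert {idx} is stored on rank {i}")   # rank 0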
diff --git a/inference/fp8_cast_bf16.py b/inference/fp8_cast_bf16.py

@@ -9,6 +9,7 @@ from safetensors.torch import load_file, save_file
 from kernel import weight_dequant

+
 def main(fp8_path, bf16_path):
     """
     Converts FP8 weights to BF16 and saves the converted weights.
@@ -35,7 +36,7 @@ def main(fp8_path, bf16_path):
     with open(model_index_file, "r") as f:
         model_index = json.load(f)
     weight_map = model_index["weight_map"]

     # Cache for loaded safetensor files
     loaded_files = {}
     fp8_weight_names = []
@@ -66,7 +67,7 @@ def main(fp8_path, bf16_path):
         file_name = os.path.basename(safetensor_file)
         current_state_dict = load_file(safetensor_file, device="cuda")
         loaded_files[file_name] = current_state_dict

         new_state_dict = {}
         for weight_name, weight in current_state_dict.items():
             if weight_name.endswith("_scale_inv"):
@@ -79,20 +80,22 @@ def main(fp8_path, bf16_path):
                     fp8_weight_names.append(weight_name)
                     new_state_dict[weight_name] = weight_dequant(weight, scale_inv)
                 except KeyError:
-                    print(f"Warning: Missing scale_inv tensor for {weight_name}, skipping conversion")
+                    print(
+                        f"Warning: Missing scale_inv tensor for {weight_name}, skipping conversion"
+                    )
                     new_state_dict[weight_name] = weight
             else:
                 new_state_dict[weight_name] = weight

         new_safetensor_file = os.path.join(bf16_path, file_name)
         save_file(new_state_dict, new_safetensor_file)

         # Memory management: keep only the 2 most recently used files
         if len(loaded_files) > 2:
             oldest_file = next(iter(loaded_files))
             del loaded_files[oldest_file]
             torch.cuda.empty_cache()

     # Update model index
     new_model_index_file = os.path.join(bf16_path, "model.safetensors.index.json")
     for weight_name in fp8_weight_names:
@@ -101,7 +104,7 @@ def main(fp8_path, bf16_path):
         weight_map.pop(scale_inv_name)
     with open(new_model_index_file, "w") as f:
         json.dump({"metadata": {}, "weight_map": weight_map}, f, indent=2)


 if __name__ == "__main__":
     parser = ArgumentParser()
@@ -109,4 +112,3 @@ if __name__ == "__main__":
     parser.add_argument("--output-bf16-hf-path", type=str, required=True)
     args = parser.parse_args()
     main(args.input_fp8_hf_path, args.output_bf16_hf_path)
-
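Note: the two-file cache above works because Python dicts preserve insertion order, so next(iter(loaded_files)) always names the oldest shard. A tiny sketch of the eviction:

    loaded_files = {}
    for name in ["shard0", "shard1", "shard2"]:
        loaded_files[name] = {}                          # stand-in for a state dict
        if len(loaded_files) > 2:
            del loaded_files[next(iter(loaded_files))]   # drop the oldest
    print(list(loaded_files))                            # ['shard1', 'shard2']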
diff --git a/inference/generate.py b/inference/generate.py

@@ -33,7 +33,7 @@ def generate(
     prompt_tokens: List[List[int]],
     max_new_tokens: int,
     eos_id: int,
-    temperature: float = 1.0
+    temperature: float = 1.0,
 ) -> List[List[int]]:
     """
     Generates new tokens based on the given prompt tokens using the specified model.
@@ -51,9 +51,11 @@ def generate(
     prompt_lens = [len(t) for t in prompt_tokens]
     assert max(prompt_lens) <= model.max_seq_len
     total_len = min(model.max_seq_len, max_new_tokens + max(prompt_lens))
-    tokens = torch.full((len(prompt_tokens), total_len), -1, dtype=torch.long, device="cuda")
+    tokens = torch.full(
+        (len(prompt_tokens), total_len), -1, dtype=torch.long, device="cuda"
+    )
     for i, t in enumerate(prompt_tokens):
-        tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
+        tokens[i, : len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
     prev_pos = 0
     finished = torch.tensor([False] * len(prompt_tokens), device="cuda")
     prompt_mask = tokens != -1
@@ -63,7 +65,9 @@ def generate(
             next_token = sample(logits, temperature)
         else:
             next_token = logits.argmax(dim=-1)
-        next_token = torch.where(prompt_mask[:, cur_pos], tokens[:, cur_pos], next_token)
+        next_token = torch.where(
+            prompt_mask[:, cur_pos], tokens[:, cur_pos], next_token
+        )
         tokens[:, cur_pos] = next_token
         finished |= torch.logical_and(~prompt_mask[:, cur_pos], next_token == eos_id)
         prev_pos = cur_pos
@@ -71,9 +75,9 @@ def generate(
             break
     completion_tokens = []
     for i, toks in enumerate(tokens.tolist()):
-        toks = toks[prompt_lens[i]:prompt_lens[i]+max_new_tokens]
+        toks = toks[prompt_lens[i] : prompt_lens[i] + max_new_tokens]
         if eos_id in toks:
-            toks = toks[:toks.index(eos_id)]
+            toks = toks[: toks.index(eos_id)]
         completion_tokens.append(toks)
     return completion_tokens
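Note: the torch.where above implements batched decoding with unequal prompt lengths. While cur_pos is still inside some row's prompt, that row keeps its ground-truth token; only genuinely open positions take the sampled one. A toy illustration:

    import torch

    tokens = torch.tensor([[5, 7, -1, -1],
                           [9, -1, -1, -1]])   # -1 marks positions to generate
    prompt_mask = tokens != -1
    cur_pos = 1
    next_token = torch.tensor([11, 12])        # hypothetical samples at cur_pos
    out = torch.where(prompt_mask[:, cur_pos], tokens[:, cur_pos], next_token)
    print(out)  # tensor([ 7, 12]) -> row 0 keeps its prompt token 7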
@@ -115,8 +119,10 @@ def main(
     with torch.device("cuda"):
         model = Transformer(args)
     tokenizer = AutoTokenizer.from_pretrained(ckpt_path)
-    tokenizer.decode(generate(model, [tokenizer.encode("DeepSeek")], 2, -1, 1.)[0])
-    load_model(model, os.path.join(ckpt_path, f"model{rank}-mp{world_size}.safetensors"))
+    tokenizer.decode(generate(model, [tokenizer.encode("DeepSeek")], 2, -1, 1.0)[0])
+    load_model(
+        model, os.path.join(ckpt_path, f"model{rank}-mp{world_size}.safetensors")
+    )

     if interactive:
         messages = []
@@ -137,18 +143,37 @@ def main(
                 messages.clear()
                 continue
             messages.append({"role": "user", "content": prompt})
-            prompt_tokens = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
-            completion_tokens = generate(model, [prompt_tokens], max_new_tokens, tokenizer.eos_token_id, temperature)
-            completion = tokenizer.decode(completion_tokens[0], skip_special_tokens=True)
+            prompt_tokens = tokenizer.apply_chat_template(
+                messages, add_generation_prompt=True
+            )
+            completion_tokens = generate(
+                model,
+                [prompt_tokens],
+                max_new_tokens,
+                tokenizer.eos_token_id,
+                temperature,
+            )
+            completion = tokenizer.decode(
+                completion_tokens[0], skip_special_tokens=True
+            )
             print(completion)
             messages.append({"role": "assistant", "content": completion})
     else:
         with open(input_file) as f:
             prompts = [line.strip() for line in f.readlines()]
         assert len(prompts) <= args.max_batch_size
-        prompt_tokens = [tokenizer.apply_chat_template([{"role": "user", "content": prompt}], add_generation_prompt=True) for prompt in prompts]
-        completion_tokens = generate(model, prompt_tokens, max_new_tokens, tokenizer.eos_token_id, temperature)
-        completions = tokenizer.batch_decode(completion_tokens, skip_special_tokens=True)
+        prompt_tokens = [
+            tokenizer.apply_chat_template(
+                [{"role": "user", "content": prompt}], add_generation_prompt=True
+            )
+            for prompt in prompts
+        ]
+        completion_tokens = generate(
+            model, prompt_tokens, max_new_tokens, tokenizer.eos_token_id, temperature
+        )
+        completions = tokenizer.batch_decode(
+            completion_tokens, skip_special_tokens=True
+        )
         for prompt, completion in zip(prompts, completions):
             print("Prompt:", prompt)
             print("Completion:", completion)
@@ -182,4 +207,11 @@ if __name__ == "__main__":
     parser.add_argument("--temperature", type=float, default=0.2)
     args = parser.parse_args()
     assert args.input_file or args.interactive
-    main(args.ckpt_path, args.config, args.input_file, args.interactive, args.max_new_tokens, args.temperature)
+    main(
+        args.ckpt_path,
+        args.config,
+        args.input_file,
+        args.interactive,
+        args.max_new_tokens,
+        args.temperature,
+    )
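Note: a typical single-node launch of this script looks like the following (the checkpoint path is a placeholder and the config name assumes the repo's configs/ directory; the flags match the argparse setup above):

    torchrun --nproc-per-node 8 generate.py \
        --ckpt-path /path/to/converted-ckpt \
        --config configs/config_671B.json \
        --interactive --temperature 0.2 --max-new-tokens 200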
diff --git a/inference/kernel.py b/inference/kernel.py

@@ -23,14 +23,16 @@ def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
     pid = tl.program_id(axis=0)
     offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
     x = tl.load(x_ptr + offs).to(tl.float32)
-    s = tl.max(tl.abs(x)) / 448.
+    s = tl.max(tl.abs(x)) / 448.0
     y = x / s
     y = y.to(y_ptr.dtype.element_ty)
     tl.store(y_ptr + offs, y)
     tl.store(s_ptr + pid, s)


-def act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
+def act_quant(
+    x: torch.Tensor, block_size: int = 128
+) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     Quantizes the input tensor `x` using block-wise quantization.

@@ -47,7 +49,7 @@ def act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
     assert x.size(-1) % block_size == 0
     y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
     s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype=torch.float32)
-    grid = lambda meta: (triton.cdiv(x.numel(), meta['BLOCK_SIZE']), )
+    grid = lambda meta: (triton.cdiv(x.numel(), meta["BLOCK_SIZE"]),)
     act_quant_kernel[grid](x, y, s, BLOCK_SIZE=block_size)
     return y, s
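Note: 448.0 is the largest finite value of torch.float8_e4m3fn, so dividing each block by its absmax/448 scale maps values exactly into FP8 range. A plain-PyTorch sketch of the per-block scaling the kernel performs:

    import torch

    x = torch.randn(4, 256)
    xb = x.view(-1, 128)                              # one row per 128-wide block
    s = xb.abs().amax(dim=1, keepdim=True) / 448.0    # one scale per block
    y = (xb / s).to(torch.float8_e4m3fn).view_as(x)   # quantized copy; keep s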
@@ -81,7 +83,9 @@ def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
     tl.store(y_ptr + offs, y, mask=mask)


-def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor:
+def weight_dequant(
+    x: torch.Tensor, s: torch.Tensor, block_size: int = 128
+) -> torch.Tensor:
     """
     Dequantizes the given weight tensor using the provided scale tensor.

@@ -100,24 +104,41 @@ def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor:
     assert x.dim() == 2 and s.dim() == 2
     M, N = x.size()
     y = torch.empty_like(x, dtype=torch.get_default_dtype())
-    grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE']), triton.cdiv(N, meta['BLOCK_SIZE']))
+    grid = lambda meta: (
+        triton.cdiv(M, meta["BLOCK_SIZE"]),
+        triton.cdiv(N, meta["BLOCK_SIZE"]),
+    )
     weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size)
     return y


 fp8_gemm_configs = [
-    Config({'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': 128}, num_stages=num_stages, num_warps=8)
-    for block_m in [16, 32, 64] for block_n in [32, 64, 128] for num_stages in [3, 4, 5, 6]
+    Config(
+        {"BLOCK_SIZE_M": block_m, "BLOCK_SIZE_N": block_n, "BLOCK_SIZE_K": 128},
+        num_stages=num_stages,
+        num_warps=8,
+    )
+    for block_m in [16, 32, 64]
+    for block_n in [32, 64, 128]
+    for num_stages in [3, 4, 5, 6]
 ]

-@triton.autotune(configs=fp8_gemm_configs, key=['N', 'K'])
+
+@triton.autotune(configs=fp8_gemm_configs, key=["N", "K"])
 @triton.jit
-def fp8_gemm_kernel(a_ptr, b_ptr, c_ptr,
-                    a_s_ptr, b_s_ptr,
-                    M, N: tl.constexpr, K: tl.constexpr,
-                    BLOCK_SIZE_M: tl.constexpr,
-                    BLOCK_SIZE_N: tl.constexpr,
-                    BLOCK_SIZE_K: tl.constexpr):
+def fp8_gemm_kernel(
+    a_ptr,
+    b_ptr,
+    c_ptr,
+    a_s_ptr,
+    b_s_ptr,
+    M,
+    N: tl.constexpr,
+    K: tl.constexpr,
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
+):
     """
     Performs a matrix multiplication operation on FP8 matrices with scaling factors.

@@ -186,6 +207,9 @@ def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor):
     M = a.numel() // K
     N = b.size(0)
     c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype())
-    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']), triton.cdiv(N, META['BLOCK_SIZE_N']))
+    grid = lambda META: (
+        triton.cdiv(M, META["BLOCK_SIZE_M"]),
+        triton.cdiv(N, META["BLOCK_SIZE_N"]),
+    )
     fp8_gemm_kernel[grid](a, b, c, a_s, b_s, M, N, K)
     return c
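Note: fp8_gemm consumes an FP8 tensor plus its per-block scales for each operand. A minimal usage sketch (b_q and b_s stand for a weight already quantized at conversion time; they are placeholders here):

    import torch

    a = torch.randn(4, 1024, device="cuda", dtype=torch.bfloat16)
    a_q, a_s = act_quant(a)           # FP8 activation + float32 block scales
    c = fp8_gemm(a_q, a_s, b_q, b_s)  # result in the default dtype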
diff --git a/inference/model.py b/inference/model.py

@@ -16,6 +16,7 @@ block_size = 128
 gemm_impl: Literal["bf16", "fp8"] = "bf16"
 attn_impl: Literal["naive", "absorb"] = "absorb"

+
 @dataclass
 class ModelArgs:
     """
@@ -51,6 +52,7 @@ class ModelArgs:
         beta_slow (int): Slow beta correction factor.
         mscale (float): Scaling factor for extended attention.
     """
+
     max_batch_size: int = 8
     max_seq_len: int = 4096 * 4
     dtype: Literal["bf16", "fp8"] = "bf16"
@@ -68,7 +70,7 @@ class ModelArgs:
     n_expert_groups: int = 1
     n_limited_groups: int = 1
     score_func: Literal["softmax", "sigmoid"] = "softmax"
-    route_scale: float = 1.
+    route_scale: float = 1.0
     # mla
     q_lora_rank: int = 0
     kv_lora_rank: int = 512
@@ -81,7 +83,7 @@ class ModelArgs:
     rope_factor: float = 40
     beta_fast: int = 32
     beta_slow: int = 1
-    mscale: float = 1.
+    mscale: float = 1.0


 class ParallelEmbedding(nn.Module):
@@ -92,12 +94,13 @@ class ParallelEmbedding(nn.Module):
         vocab_size (int): Vocabulary size.
         dim (int): Embedding dimension.
     """
+
     def __init__(self, vocab_size: int, dim: int):
         super().__init__()
         self.vocab_size = vocab_size
         self.dim = dim
         assert vocab_size % world_size == 0
-        self.part_vocab_size = (vocab_size // world_size)
+        self.part_vocab_size = vocab_size // world_size
         self.vocab_start_idx = rank * self.part_vocab_size
         self.vocab_end_idx = self.vocab_start_idx + self.part_vocab_size
         self.weight = nn.Parameter(torch.empty(self.part_vocab_size, self.dim))
@@ -126,7 +129,9 @@ class ParallelEmbedding(nn.Module):
         return y


-def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+def linear(
+    x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None
+) -> torch.Tensor:
     """
     Applies a linear transformation to the incoming data: y = xA^T + b.
     This function supports specialized implementations based on quantization
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): The input tensor.
|
||||
weight (torch.Tensor): The weight tensor. It may be quantized and
|
||||
weight (torch.Tensor): The weight tensor. It may be quantized and
|
||||
requires dequantization for certain cases.
|
||||
bias (Optional[torch.Tensor]): The bias tensor to be added. Default is None.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The result of the linear transformation, which may involve
|
||||
torch.Tensor: The result of the linear transformation, which may involve
|
||||
quantization-aware computations depending on the input parameters.
|
||||
|
||||
Notes:
|
||||
- If `weight` is quantized (e.g., `element_size() > 1`), a dequantized version
|
||||
- If `weight` is quantized (e.g., `element_size() > 1`), a dequantized version
|
||||
is used for computation.
|
||||
- If `gemm_impl == "bf16"`, dequantization and a `bf16` GEMM operation are applied.
|
||||
- For other cases, the function applies quantization to `x` and uses `fp8_gemm` for computation.
|
||||
@@ -171,17 +176,24 @@ class Linear(nn.Module):
         bias (bool): Whether to include a bias term. Defaults to False.
         dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`.
     """

     dtype = torch.bfloat16

-    def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
+    def __init__(
+        self, in_features: int, out_features: int, bias: bool = False, dtype=None
+    ):
         super().__init__()
         self.in_features = in_features
         self.out_features = out_features
-        self.weight = nn.Parameter(torch.empty(out_features, in_features, dtype=dtype or Linear.dtype))
+        self.weight = nn.Parameter(
+            torch.empty(out_features, in_features, dtype=dtype or Linear.dtype)
+        )
         if self.weight.element_size() == 1:
             scale_out_features = (out_features + block_size - 1) // block_size
             scale_in_features = (in_features + block_size - 1) // block_size
-            self.weight.scale = self.scale = nn.Parameter(torch.empty(scale_out_features, scale_in_features, dtype=torch.float32))
+            self.weight.scale = self.scale = nn.Parameter(
+                torch.empty(scale_out_features, scale_in_features, dtype=torch.float32)
+            )
         else:
             self.register_parameter("scale", None)
         if bias:
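Note: element_size() is what distinguishes FP8 (1 byte) from BF16 (2 bytes) here, deciding whether a block-wise scale parameter gets allocated. Quick check of the mechanism:

    import torch

    w_fp8 = torch.empty(128, 128, dtype=torch.float8_e4m3fn)
    w_bf16 = torch.empty(128, 128, dtype=torch.bfloat16)
    print(w_fp8.element_size(), w_bf16.element_size())  # 1 2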
@@ -212,7 +224,10 @@ class ColumnParallelLinear(Linear):
         bias (bool): Whether to include a bias term. Defaults to False.
         dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`.
     """
-    def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
+
+    def __init__(
+        self, in_features: int, out_features: int, bias: bool = False, dtype=None
+    ):
         assert out_features % world_size == 0
         self.part_out_features = out_features // world_size
         super().__init__(in_features, self.part_out_features, bias, dtype)
@@ -241,7 +256,10 @@ class RowParallelLinear(Linear):
         bias (bool): Whether to include a bias term. Defaults to False.
         dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`.
     """
-    def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
+
+    def __init__(
+        self, in_features: int, out_features: int, bias: bool = False, dtype=None
+    ):
         assert in_features % world_size == 0
         self.part_in_features = in_features // world_size
         super().__init__(self.part_in_features, out_features, bias, dtype)
@@ -272,6 +290,7 @@ class RMSNorm(nn.Module):
         dim (int): Dimension of the input tensor.
         eps (float): Epsilon value for numerical stability. Defaults to 1e-6.
     """
+
     def __init__(self, dim: int, eps: float = 1e-6):
         super().__init__()
         self.dim = dim
@@ -321,7 +340,11 @@ def precompute_freqs_cis(args: ModelArgs) -> torch.Tensor:
         Returns:
             float: The correction dimension based on the input parameters.
         """
-        return dim * math.log(max_seq_len / (num_rotations * 2 * math.pi)) / (2 * math.log(base))
+        return (
+            dim
+            * math.log(max_seq_len / (num_rotations * 2 * math.pi))
+            / (2 * math.log(base))
+        )

     def find_correction_range(low_rot, high_rot, dim, base, max_seq_len):
         """
@@ -339,7 +362,7 @@ def precompute_freqs_cis(args: ModelArgs) -> torch.Tensor:
         """
         low = math.floor(find_correction_dim(low_rot, dim, base, max_seq_len))
         high = math.ceil(find_correction_dim(high_rot, dim, base, max_seq_len))
-        return max(low, 0), min(high, dim-1)
+        return max(low, 0), min(high, dim - 1)

     def linear_ramp_factor(min, max, dim):
         """
@@ -362,7 +385,9 @@ def precompute_freqs_cis(args: ModelArgs) -> torch.Tensor:

     freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
     if seqlen > args.original_seq_len:
-        low, high = find_correction_range(beta_fast, beta_slow, dim, base, args.original_seq_len)
+        low, high = find_correction_range(
+            beta_fast, beta_slow, dim, base, args.original_seq_len
+        )
         smooth = 1 - linear_ramp_factor(low, high, dim // 2)
         freqs = freqs / factor * (1 - smooth) + freqs * smooth
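Note: find_correction_dim inverts the rotation count used by this YaRN-style scheme. Channel pair d has wavelength 2*pi*base**(2d/dim), so it completes max_seq_len / (2*pi*base**(2d/dim)) turns over the original context; solving that for d gives the expression above. A quick numerical check (sizes are illustrative):

    import math

    dim, base, max_seq_len, num_rotations = 128, 10000.0, 4096, 32
    d = dim * math.log(max_seq_len / (num_rotations * 2 * math.pi)) / (2 * math.log(base))
    turns = max_seq_len / (2 * math.pi * base ** (2 * d / dim))
    print(round(turns, 6))  # 32.0 -> the inversion checks out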
@@ -406,6 +431,7 @@ class MLA(nn.Module):
         v_head_dim (int): Dimensionality of value projections.
         softmax_scale (float): Scaling factor for softmax in attention computation.
     """
+
     def __init__(self, args: ModelArgs):
         super().__init__()
         self.dim = args.dim
@@ -423,24 +449,62 @@ class MLA(nn.Module):
         else:
             self.wq_a = Linear(self.dim, self.q_lora_rank)
             self.q_norm = RMSNorm(self.q_lora_rank)
-            self.wq_b = ColumnParallelLinear(self.q_lora_rank, self.n_heads * self.qk_head_dim)
+            self.wq_b = ColumnParallelLinear(
+                self.q_lora_rank, self.n_heads * self.qk_head_dim
+            )
         self.wkv_a = Linear(self.dim, self.kv_lora_rank + self.qk_rope_head_dim)
         self.kv_norm = RMSNorm(self.kv_lora_rank)
-        self.wkv_b = ColumnParallelLinear(self.kv_lora_rank, self.n_heads * (self.qk_nope_head_dim + self.v_head_dim))
+        self.wkv_b = ColumnParallelLinear(
+            self.kv_lora_rank, self.n_heads * (self.qk_nope_head_dim + self.v_head_dim)
+        )
         self.wo = RowParallelLinear(self.n_heads * self.v_head_dim, self.dim)
-        self.softmax_scale = self.qk_head_dim ** -0.5
+        self.softmax_scale = self.qk_head_dim**-0.5
         if args.max_seq_len > args.original_seq_len:
             mscale = 0.1 * args.mscale * math.log(args.rope_factor) + 1.0
             self.softmax_scale = self.softmax_scale * mscale * mscale

         if attn_impl == "naive":
-            self.register_buffer("k_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.n_local_heads, self.qk_head_dim), persistent=False)
-            self.register_buffer("v_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.n_local_heads, self.v_head_dim), persistent=False)
+            self.register_buffer(
+                "k_cache",
+                torch.zeros(
+                    args.max_batch_size,
+                    args.max_seq_len,
+                    self.n_local_heads,
+                    self.qk_head_dim,
+                ),
+                persistent=False,
+            )
+            self.register_buffer(
+                "v_cache",
+                torch.zeros(
+                    args.max_batch_size,
+                    args.max_seq_len,
+                    self.n_local_heads,
+                    self.v_head_dim,
+                ),
+                persistent=False,
+            )
         else:
-            self.register_buffer("kv_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.kv_lora_rank), persistent=False)
-            self.register_buffer("pe_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.qk_rope_head_dim), persistent=False)
+            self.register_buffer(
+                "kv_cache",
+                torch.zeros(args.max_batch_size, args.max_seq_len, self.kv_lora_rank),
+                persistent=False,
+            )
+            self.register_buffer(
+                "pe_cache",
+                torch.zeros(
+                    args.max_batch_size, args.max_seq_len, self.qk_rope_head_dim
+                ),
+                persistent=False,
+            )

-    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
+    def forward(
+        self,
+        x: torch.Tensor,
+        start_pos: int,
+        freqs_cis: torch.Tensor,
+        mask: Optional[torch.Tensor],
+    ):
         """
         Forward pass for the Multi-Headed Attention Layer (MLA).
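Note: the two cache layouts trade compute for memory. "naive" stores full per-head K and V; "absorb" stores one shared latent vector plus the rope key per token, independent of head count. Back-of-envelope with illustrative sizes (128 heads, qk_head_dim=192, v_head_dim=128, kv_lora_rank=512, qk_rope_head_dim=64):

    naive_per_token = 128 * (192 + 128)  # 40960 cached values per token
    absorb_per_token = 512 + 64          # 576 cached values per token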
@@ -460,7 +524,9 @@ class MLA(nn.Module):
         else:
             q = self.wq_b(self.q_norm(self.wq_a(x)))
         q = q.view(bsz, seqlen, self.n_local_heads, self.qk_head_dim)
-        q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+        q_nope, q_pe = torch.split(
+            q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
+        )
         q_pe = apply_rotary_emb(q_pe, freqs_cis)
         kv = self.wkv_a(x)
         kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
@@ -468,20 +534,35 @@ class MLA(nn.Module):
         if attn_impl == "naive":
             q = torch.cat([q_nope, q_pe], dim=-1)
             kv = self.wkv_b(self.kv_norm(kv))
-            kv = kv.view(bsz, seqlen, self.n_local_heads, self.qk_nope_head_dim + self.v_head_dim)
-            k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
+            kv = kv.view(
+                bsz, seqlen, self.n_local_heads, self.qk_nope_head_dim + self.v_head_dim
+            )
+            k_nope, v = torch.split(
+                kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1
+            )
             k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_local_heads, -1)], dim=-1)
             self.k_cache[:bsz, start_pos:end_pos] = k
             self.v_cache[:bsz, start_pos:end_pos] = v
-            scores = torch.einsum("bshd,bthd->bsht", q, self.k_cache[:bsz, :end_pos]) * self.softmax_scale
+            scores = (
+                torch.einsum("bshd,bthd->bsht", q, self.k_cache[:bsz, :end_pos])
+                * self.softmax_scale
+            )
         else:
-            wkv_b = self.wkv_b.weight if self.wkv_b.scale is None else weight_dequant(self.wkv_b.weight, self.wkv_b.scale, block_size)
+            wkv_b = (
+                self.wkv_b.weight
+                if self.wkv_b.scale is None
+                else weight_dequant(self.wkv_b.weight, self.wkv_b.scale, block_size)
+            )
             wkv_b = wkv_b.view(self.n_local_heads, -1, self.kv_lora_rank)
-            q_nope = torch.einsum("bshd,hdc->bshc", q_nope, wkv_b[:, :self.qk_nope_head_dim])
+            q_nope = torch.einsum(
+                "bshd,hdc->bshc", q_nope, wkv_b[:, : self.qk_nope_head_dim]
+            )
             self.kv_cache[:bsz, start_pos:end_pos] = self.kv_norm(kv)
             self.pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2)
-            scores = (torch.einsum("bshc,btc->bsht", q_nope, self.kv_cache[:bsz, :end_pos]) +
-                      torch.einsum("bshr,btr->bsht", q_pe, self.pe_cache[:bsz, :end_pos])) * self.softmax_scale
+            scores = (
+                torch.einsum("bshc,btc->bsht", q_nope, self.kv_cache[:bsz, :end_pos])
+                + torch.einsum("bshr,btr->bsht", q_pe, self.pe_cache[:bsz, :end_pos])
+            ) * self.softmax_scale
         if mask is not None:
             scores += mask.unsqueeze(1)
         scores = scores.softmax(dim=-1, dtype=torch.float32).type_as(x)
@@ -489,7 +570,7 @@ class MLA(nn.Module):
             x = torch.einsum("bsht,bthd->bshd", scores, self.v_cache[:bsz, :end_pos])
         else:
             x = torch.einsum("bsht,btc->bshc", scores, self.kv_cache[:bsz, :end_pos])
-            x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim:])
+            x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim :])
         x = self.wo(x.flatten(2))
         return x
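Note: in the absorb path, per-head keys are never materialized during decoding; q_nope is pushed through the K half of wkv_b into the latent space, and scores are taken directly against the cached latents. Shape check with toy sizes:

    import torch

    b, s, t, h, d, c = 1, 1, 8, 4, 16, 32   # batch, q-len, cache-len, heads, head dim, latent dim
    q_nope = torch.randn(b, s, h, d)
    wk = torch.randn(h, d, c)               # K rows of wkv_b
    kv_cache = torch.randn(b, t, c)         # cached latent vectors
    q_lat = torch.einsum("bshd,hdc->bshc", q_nope, wk)
    scores = torch.einsum("bshc,btc->bsht", q_lat, kv_cache)
    print(scores.shape)                     # torch.Size([1, 1, 4, 8])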
@@ -503,6 +584,7 @@ class MLP(nn.Module):
         w2 (nn.Module): Linear layer for hidden-to-output transformation.
         w3 (nn.Module): Additional linear layer for feature transformation.
     """
+
     def __init__(self, dim: int, inter_dim: int):
         """
         Initializes the MLP layer.
@@ -543,6 +625,7 @@ class Gate(nn.Module):
         weight (torch.nn.Parameter): Learnable weights for the gate.
         bias (Optional[torch.nn.Parameter]): Optional bias term for the gate.
     """
+
     def __init__(self, args: ModelArgs):
         """
         Initializes the Gate module.
@@ -558,7 +641,11 @@ class Gate(nn.Module):
         self.score_func = args.score_func
         self.route_scale = args.route_scale
         self.weight = nn.Parameter(torch.empty(args.n_routed_experts, args.dim))
-        self.bias = nn.Parameter(torch.empty(args.n_routed_experts)) if self.dim == 7168 else None
+        self.bias = (
+            nn.Parameter(torch.empty(args.n_routed_experts))
+            if self.dim == 7168
+            else None
+        )

     def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         """
@@ -604,6 +691,7 @@ class Expert(nn.Module):
         w2 (nn.Module): Linear layer for hidden-to-output transformation.
         w3 (nn.Module): Additional linear layer for feature transformation.
     """
+
     def __init__(self, dim: int, inter_dim: int):
         """
         Initializes the Expert layer.
@@ -643,6 +731,7 @@ class MoE(nn.Module):
         experts (nn.ModuleList): List of expert modules.
         shared_experts (nn.Module): Shared experts applied to all inputs.
     """
+
     def __init__(self, args: ModelArgs):
         """
         Initializes the MoE module.
@@ -659,8 +748,14 @@ class MoE(nn.Module):
         self.experts_start_idx = rank * self.n_local_experts
         self.experts_end_idx = self.experts_start_idx + self.n_local_experts
         self.gate = Gate(args)
-        self.experts = nn.ModuleList([Expert(args.dim, args.moe_inter_dim) if self.experts_start_idx <= i < self.experts_end_idx else None
-                                      for i in range(self.n_routed_experts)])
+        self.experts = nn.ModuleList(
+            [
+                Expert(args.dim, args.moe_inter_dim)
+                if self.experts_start_idx <= i < self.experts_end_idx
+                else None
+                for i in range(self.n_routed_experts)
+            ]
+        )
         self.shared_experts = MLP(args.dim, args.n_shared_experts * args.moe_inter_dim)

     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -677,7 +772,9 @@ class MoE(nn.Module):
         x = x.view(-1, self.dim)
         weights, indices = self.gate(x)
         y = torch.zeros_like(x)
-        counts = torch.bincount(indices.flatten(), minlength=self.n_routed_experts).tolist()
+        counts = torch.bincount(
+            indices.flatten(), minlength=self.n_routed_experts
+        ).tolist()
         for i in range(self.experts_start_idx, self.experts_end_idx):
             if counts[i] == 0:
                 continue
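Note: bincount produces a per-expert token count so each rank can skip local experts that received no tokens. Toy example:

    import torch

    indices = torch.tensor([[0, 2], [2, 3]])   # top-k expert ids per token
    counts = torch.bincount(indices.flatten(), minlength=4).tolist()
    print(counts)  # [1, 0, 2, 1] -> expert 1 gets skipped entirely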
@@ -700,6 +797,7 @@ class Block(nn.Module):
         attn_norm (nn.Module): Layer normalization for attention.
         ffn_norm (nn.Module): Layer normalization for feed-forward network.
     """
+
     def __init__(self, layer_id: int, args: ModelArgs):
         """
         Initializes the Transformer block.
@@ -710,11 +808,21 @@ class Block(nn.Module):
         """
         super().__init__()
         self.attn = MLA(args)
-        self.ffn = MLP(args.dim, args.inter_dim) if layer_id < args.n_dense_layers else MoE(args)
+        self.ffn = (
+            MLP(args.dim, args.inter_dim)
+            if layer_id < args.n_dense_layers
+            else MoE(args)
+        )
         self.attn_norm = RMSNorm(args.dim)
         self.ffn_norm = RMSNorm(args.dim)

-    def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]) -> torch.Tensor:
+    def forward(
+        self,
+        x: torch.Tensor,
+        start_pos: int,
+        freqs_cis: torch.Tensor,
+        mask: Optional[torch.Tensor],
+    ) -> torch.Tensor:
         """
         Forward pass for the Transformer block.
@@ -744,6 +852,7 @@ class Transformer(nn.Module):
         head (nn.Module): Output projection layer mapping to vocabulary size.
         freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings.
     """
+
     def __init__(self, args: ModelArgs):
         """
         Initializes the Transformer model.
@@ -762,7 +871,9 @@ class Transformer(nn.Module):
         for layer_id in range(args.n_layers):
             self.layers.append(Block(layer_id, args))
         self.norm = RMSNorm(args.dim)
-        self.head = ColumnParallelLinear(args.dim, args.vocab_size, dtype=torch.get_default_dtype())
+        self.head = ColumnParallelLinear(
+            args.dim, args.vocab_size, dtype=torch.get_default_dtype()
+        )
         self.register_buffer("freqs_cis", precompute_freqs_cis(args), persistent=False)

     @torch.inference_mode()
@@ -779,10 +890,12 @@ class Transformer(nn.Module):
         """
         seqlen = tokens.size(1)
         h = self.embed(tokens)
-        freqs_cis = self.freqs_cis[start_pos:start_pos+seqlen]
+        freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
         mask = None
         if seqlen > 1:
-            mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device).triu_(1)
+            mask = torch.full(
+                (seqlen, seqlen), float("-inf"), device=tokens.device
+            ).triu_(1)
         for layer in self.layers:
             h = layer(h, start_pos, freqs_cis, mask)
         h = self.norm(h)[:, -1]
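Note: the causal mask is built only for prefill (seqlen > 1); single-token decode steps pass mask=None. What the mask looks like:

    import torch

    seqlen = 4
    mask = torch.full((seqlen, seqlen), float("-inf")).triu_(1)
    # zeros on and below the diagonal, -inf above: position i attends to 0..i
    print(mask)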