Added various error handlers and Issue templates.

This commit is contained in:
Ankur Halder 2025-01-29 12:57:39 +05:30
parent b5d872ead0
commit 9d5abb844f
17 changed files with 829 additions and 247 deletions

25
.github/ISSUE_TEMPLATE/Improvement.md vendored Normal file
View File

@ -0,0 +1,25 @@
# Improvement / Enhancement
name: Improvement / Enhancement
about: Suggest an improvement or optimization for an existing feature
title: "[IMPROVEMENT] Short Description"
labels: "enhancement"
assignees: ''
---
## 💡 Improvement Description
A clear and concise description of the improvement you're suggesting.
## ✅ Why this improvement is needed
Explain the benefit of this improvement, and how it will enhance the project.
## 🔄 Describe alternatives you've considered
If applicable, describe any alternative approaches or solutions.
## 📄 Additional Context
Any other relevant details or context for this improvement.

View File

@ -1,23 +1,38 @@
---
name: Bug report
about: Create a report to help us improve
title: "[BUG]"
labels: ''
# Bug Report
name: Bug Report
about: Report a bug to help us improve the project
title: "[BUG] Short Description"
labels: "bug"
assignees: ''
---
**Describe the bug**
A clear and concise description of what the bug is.
## 🐞 Bug Description
**To Reproduce**
Steps to reproduce the behavior.
A clear and concise description of the issue.
## ✅ Steps to Reproduce
1. Go to '...'
2. Click on '...'
3. Observe the error
## 🎯 Expected Behavior
**Expected behavior**
A clear and concise description of what you expected to happen.
**Screenshots**
If applicable, add screenshots to help explain your problem.
## 📸 Screenshots (if applicable)
**Additional context**
Add any other context about the problem here.
Attach screenshots or screen recordings to illustrate the problem.
## 🛠️ Environment (please complete the following information)
- OS: [e.g., Windows, macOS, Linux]
- Browser: [e.g., Chrome, Firefox, Edge]
- Node.js Version: [e.g., 18.x]
- Other relevant environment details
## 📄 Additional Context
Add any other relevant details, logs, or error messages here.

21
.github/ISSUE_TEMPLATE/chore.md vendored Normal file
View File

@ -0,0 +1,21 @@
# Chore Template
name: Chore
about: Routine tasks like dependency updates or code cleanup
title: "[CHORE] Short Description"
labels: "chore"
assignees: ''
---
## 🧹 Task Description
Describe the task that needs to be done (e.g., update dependencies, clean up code).
## ✅ Expected Outcome
Describe the desired result of completing this task.
## 📄 Additional Context
Any other relevant details or instructions.

21
.github/ISSUE_TEMPLATE/discussion.md vendored Normal file
View File

@ -0,0 +1,21 @@
# Discussion Template
name: Discussion
about: Initiate a discussion or brainstorm ideas for the project
title: "[DISCUSSION] Short Description"
labels: "discussion"
assignees: ''
---
## 🧠 Discussion Topic
Describe the topic or idea you'd like to discuss or brainstorm.
## 🔄 Related Ideas
Provide any relevant context, ideas, or previous discussions related to the topic.
## 📄 Additional Context
Any other information that may help with the discussion.

View File

@ -0,0 +1,21 @@
# Documentation Update
name: Documentation Update
about: Suggest an update or correction to the project documentation
title: "[DOC] Short Description"
labels: "documentation"
assignees: ''
---
## 📚 Documentation Issue
Describe the part of the documentation that needs updating or correcting.
## ✅ Suggested Changes
Provide a clear description of what needs to be added, changed, or fixed.
## 📄 Additional Context
Any other relevant details or links to the documentation.

View File

@ -1,20 +1,34 @@
---
name: Feature request
about: Suggest an idea for this project
title: ''
labels: ''
# Feature Request
about: Suggest a new idea or improvement for this project
title: "[FEATURE] Short Description"
labels: "enhancement"
assignees: ''
---
**Is your feature request related to a problem? Please describe.**
A clear and concise description of what the problem is. Ex. I'm always frustrated when [...]
## 💡 Feature Request Description
**Describe the solution you'd like**
A clear and concise description of what you want to happen.
A clear and concise description of the feature or improvement you're suggesting.
**Describe alternatives you've considered**
A clear and concise description of any alternative solutions or features you've considered.
## 🚧 Is your feature request related to a problem? Please describe
**Additional context**
Add any other context or screenshots about the feature request here.
Describe the problem you're facing or the limitation that this feature will address.
Example: "I find it difficult to... because..."
## ✅ Describe the solution you'd like
A clear and concise description of what you'd like to see happen.
Example: "It would be great if we could..."
## 🔄 Describe alternatives you've considered
If applicable, describe any other solutions or features you've considered as alternatives to your request.
## 📸 Screenshots or Mockups (if applicable)
Provide any visuals to help explain your feature request.
## 📄 Additional Context
Add any other relevant details, context, or use cases for this feature.

View File

@ -0,0 +1,25 @@
# Performance Issue Template
name: Performance Issue
about: Report a performance-related issue or bottleneck
title: "[PERFORMANCE] Short Description"
labels: "performance"
assignees: ''
---
## 🚀 Performance Issue Description
Describe the performance problem you're facing (e.g., slow load times, high CPU usage).
## 🕹️ Steps to Reproduce
Steps to reproduce the performance issue, if applicable.
## 🎯 Expected Performance
Describe the expected performance or benchmark for comparison.
## 📄 Additional Context
Provide any logs, metrics, or other relevant data.

21
.github/ISSUE_TEMPLATE/question.md vendored Normal file
View File

@ -0,0 +1,21 @@
# Question Template
name: Question
about: Ask a question related to the project
title: "[QUESTION] Short Description"
labels: "question"
assignees: ''
---
## ❓ Question Description
Clearly state your question or request for clarification.
## 🔄 Related Information
Provide any relevant context, such as code snippets, error messages, or links.
## 📄 Additional Context
Any other information that may help answer your question.

View File

@ -0,0 +1,21 @@
# Security Issue Report
name: Security Issue
about: Report a potential security vulnerability
title: "[SECURITY] Short Description"
labels: "security"
assignees: ''
---
## 🔐 Security Issue Description
Describe the potential security vulnerability or concern.
## 🛠️ Steps to Reproduce
Steps to reproduce the issue, if applicable.
## 📄 Additional Context
Provide any logs, error messages, or other relevant data.

21
.github/ISSUE_TEMPLATE/test_failure.md vendored Normal file
View File

@ -0,0 +1,21 @@
# Test Failure Report
name: Test Failure
about: Report a failed test case or issue with automated tests
title: "[TEST FAILURE] Short Description"
labels: "test"
assignees: ''
---
## 🧪 Test Failure Description
Describe the failed test case and the issue you're encountering.
## 🛠️ Steps to Reproduce
Steps to reproduce the test failure, including any relevant test case details.
## 📄 Additional Context
Provide any logs, error messages, or relevant data.

View File

@ -1,4 +1,4 @@
DEEPSEEK LICENSE AGREEMENT
# DEEPSEEK LICENSE AGREEMENT
Version 1.0, 23 October 2023
@ -19,16 +19,16 @@ This License governs the use of the model (and its derivatives) and is informed
NOW THEREFORE, You and DeepSeek agree as follows:
1. Definitions
"License" means the terms and conditions for use, reproduction, and Distribution as defined in this document.
"Data" means a collection of information and/or content extracted from the dataset used with the Model, including to train, pretrain, or otherwise evaluate the Model. The Data is not licensed under this License.
"Output" means the results of operating a Model as embodied in informational content resulting therefrom.
"Model" means any accompanying machine-learning based assemblies (including checkpoints), consisting of learnt weights, parameters (including optimizer states), corresponding to the model architecture as embodied in the Complementary Material, that have been trained or tuned, in whole or in part on the Data, using the Complementary Material.
"Derivatives of the Model" means all modifications to the Model, works based on the Model, or any other model which is created or initialized by transfer of patterns of the weights, parameters, activations or output of the Model, to the other model, in order to cause the other model to perform similarly to the Model, including - but not limited to - distillation methods entailing the use of intermediate data representations or methods based on the generation of synthetic data by the Model for training the other model.
"Complementary Material" means the accompanying source code and scripts used to define, run, load, benchmark or evaluate the Model, and used to prepare data for training or evaluation, if any. This includes any accompanying documentation, tutorials, examples, etc, if any.
"Distribution" means any transmission, reproduction, publication or other sharing of the Model or Derivatives of the Model to a third party, including providing the Model as a hosted service made available by electronic or other remote means - e.g. API-based or web access.
"DeepSeek" (or "we") means Beijing DeepSeek Artificial Intelligence Fundamental Technology Research Co., Ltd., Hangzhou DeepSeek Artificial Intelligence Fundamental Technology Research Co., Ltd. and/or any of their affiliates.
"You" (or "Your") means an individual or Legal Entity exercising permissions granted by this License and/or making use of the Model for whichever purpose and in any field of use, including usage of the Model in an end-use application - e.g. chatbot, translator, etc.
"Third Parties" means individuals or legal entities that are not under common control with DeepSeek or You.
"License" means the terms and conditions for use, reproduction, and Distribution as defined in this document.
"Data" means a collection of information and/or content extracted from the dataset used with the Model, including to train, pretrain, or otherwise evaluate the Model. The Data is not licensed under this License.
"Output" means the results of operating a Model as embodied in informational content resulting therefrom.
"Model" means any accompanying machine-learning based assemblies (including checkpoints), consisting of learnt weights, parameters (including optimizer states), corresponding to the model architecture as embodied in the Complementary Material, that have been trained or tuned, in whole or in part on the Data, using the Complementary Material.
"Derivatives of the Model" means all modifications to the Model, works based on the Model, or any other model which is created or initialized by transfer of patterns of the weights, parameters, activations or output of the Model, to the other model, in order to cause the other model to perform similarly to the Model, including - but not limited to - distillation methods entailing the use of intermediate data representations or methods based on the generation of synthetic data by the Model for training the other model.
"Complementary Material" means the accompanying source code and scripts used to define, run, load, benchmark or evaluate the Model, and used to prepare data for training or evaluation, if any. This includes any accompanying documentation, tutorials, examples, etc, if any.
"Distribution" means any transmission, reproduction, publication or other sharing of the Model or Derivatives of the Model to a third party, including providing the Model as a hosted service made available by electronic or other remote means - e.g. API-based or web access.
"DeepSeek" (or "we") means Beijing DeepSeek Artificial Intelligence Fundamental Technology Research Co., Ltd., Hangzhou DeepSeek Artificial Intelligence Fundamental Technology Research Co., Ltd. and/or any of their affiliates.
"You" (or "Your") means an individual or Legal Entity exercising permissions granted by this License and/or making use of the Model for whichever purpose and in any field of use, including usage of the Model in an end-use application - e.g. chatbot, translator, etc.
"Third Parties" means individuals or legal entities that are not under common control with DeepSeek or You.
Section II: INTELLECTUAL PROPERTY RIGHTS
@ -38,15 +38,14 @@ Both copyright and patent grants apply to the Model, Derivatives of the Model an
3. Grant of Patent License. Subject to the terms and conditions of this License and where and as applicable, DeepSeek hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this paragraph) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Model and the Complementary Material, where such license applies only to those patent claims licensable by DeepSeek that are necessarily infringed by its contribution(s). If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Model and/or Complementary Material constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for the Model and/or works shall terminate as of the date such litigation is asserted or filed.
Section III: CONDITIONS OF USAGE, DISTRIBUTION AND REDISTRIBUTION
4. Distribution and Redistribution. You may host for Third Party remote access purposes (e.g. software-as-a-service), reproduce and distribute copies of the Model or Derivatives of the Model thereof in any medium, with or without modifications, provided that You meet the following conditions:
a. Use-based restrictions as referenced in paragraph 5 MUST be included as an enforceable provision by You in any type of legal agreement (e.g. a license) governing the use and/or distribution of the Model or Derivatives of the Model, and You shall give notice to subsequent users You Distribute to, that the Model or Derivatives of the Model are subject to paragraph 5. This provision does not apply to the use of Complementary Material.
b. You must give any Third Party recipients of the Model or Derivatives of the Model a copy of this License;
c. You must cause any modified files to carry prominent notices stating that You changed the files;
d. You must retain all copyright, patent, trademark, and attribution notices excluding those notices that do not pertain to any part of the Model, Derivatives of the Model.
e. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions - respecting paragraph 4.a. for use, reproduction, or Distribution of Your modifications, or for any such Derivatives of the Model as a whole, provided Your use, reproduction, and Distribution of the Model otherwise complies with the conditions stated in this License.
a. Use-based restrictions as referenced in paragraph 5 MUST be included as an enforceable provision by You in any type of legal agreement (e.g. a license) governing the use and/or distribution of the Model or Derivatives of the Model, and You shall give notice to subsequent users You Distribute to, that the Model or Derivatives of the Model are subject to paragraph 5. This provision does not apply to the use of Complementary Material.
b. You must give any Third Party recipients of the Model or Derivatives of the Model a copy of this License;
c. You must cause any modified files to carry prominent notices stating that You changed the files;
d. You must retain all copyright, patent, trademark, and attribution notices excluding those notices that do not pertain to any part of the Model, Derivatives of the Model.
e. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions - respecting paragraph 4.a. for use, reproduction, or Distribution of Your modifications, or for any such Derivatives of the Model as a whole, provided Your use, reproduction, and Distribution of the Model otherwise complies with the conditions stated in this License.
5. Use-based restrictions. The restrictions set forth in Attachment A are considered Use-based restrictions. Therefore You cannot use the Model and the Derivatives of the Model for the specified restricted uses. You may use the Model subject to this License, including only for lawful purposes and in accordance with the License. Use may include creating any content with, finetuning, updating, running, training, evaluating and/or reparametrizing the Model. You shall require all of Your users who use the Model or a Derivative of the Model to comply with the terms of this paragraph (paragraph 5).

View File

@ -20,7 +20,7 @@ The DeepSeek-V3 weight file consists of two main components: **Main Model Weight
- Total parameters: **671B**
- Activation parameters: **36.7B** (including 0.9B for Embedding and 0.9B for the output Head).
#### Structural Details
#### Main Model Structural Details
- **Embedding Layer**:
- `model.embed_tokens.weight`

View File

@ -1,6 +1,6 @@
import os
import shutil
from argparse import ArgumentParser
from argparse import ArgumentParser, ArgumentTypeError
from glob import glob
from tqdm import tqdm, trange
@ -30,6 +30,19 @@ mapping = {
}
def validate_positive_integer(value):
"""
Helper function to validate that a value is a positive integer.
"""
try:
ivalue = int(value)
if ivalue <= 0:
raise ArgumentTypeError(f"{value} is not a positive integer")
return ivalue
except ValueError:
raise ArgumentTypeError(f"{value} is not a valid integer")
def main(hf_ckpt_path, save_path, n_experts, mp):
"""
Converts and saves model checkpoint files into a specified format.
@ -43,54 +56,109 @@ def main(hf_ckpt_path, save_path, n_experts, mp):
Returns:
None
"""
try:
torch.set_num_threads(8)
n_local_experts = n_experts // mp
state_dicts = [{} for _ in range(mp)]
if not os.path.exists(hf_ckpt_path):
raise FileNotFoundError(f"Checkpoint path '{hf_ckpt_path}' does not exist.")
for file_path in tqdm(glob(os.path.join(hf_ckpt_path, "*.safetensors"))):
if not os.path.isfile(file_path):
continue
try:
with safe_open(file_path, framework="pt", device="cpu") as f:
for name in f.keys():
if "model.layers.61" in name:
continue
param: torch.Tensor = f.get_tensor(name)
if name.startswith("model."):
name = name[len("model."):]
name = name[len("model.") :]
name = name.replace("self_attn", "attn")
name = name.replace("mlp", "ffn")
name = name.replace("weight_scale_inv", "scale")
name = name.replace("e_score_correction_bias", "bias")
key = name.split(".")[-2]
assert key in mapping
if key not in mapping:
raise KeyError(
f"Unexpected key '{key}' in tensor name '{name}'."
)
new_key, dim = mapping[key]
name = name.replace(key, new_key)
for i in range(mp):
new_param = param
if "experts" in name and "shared_experts" not in name:
idx = int(name.split(".")[-3])
if idx < i * n_local_experts or idx >= (i + 1) * n_local_experts:
if (
idx < i * n_local_experts
or idx >= (i + 1) * n_local_experts
):
continue
elif dim is not None:
assert param.size(dim) % mp == 0
if param.size(dim) % mp != 0:
raise ValueError(
f"Tensor dimension mismatch for '{name}' (size {param.size(dim)})."
)
shard_size = param.size(dim) // mp
new_param = param.narrow(dim, i * shard_size, shard_size).contiguous()
new_param = param.narrow(
dim, i * shard_size, shard_size
).contiguous()
state_dicts[i][name] = new_param
except Exception as e:
print(f"Error processing file {file_path}: {e}")
continue
os.makedirs(save_path, exist_ok=True)
for i in trange(mp):
save_file(state_dicts[i], os.path.join(save_path, f"model{i}-mp{mp}.safetensors"))
try:
save_file(
state_dicts[i],
os.path.join(save_path, f"model{i}-mp{mp}.safetensors"),
)
except Exception as e:
print(f"Error saving file for model {i}: {e}")
continue
for file_path in glob(os.path.join(hf_ckpt_path, "*token*")):
try:
new_file_path = os.path.join(save_path, os.path.basename(file_path))
shutil.copyfile(file_path, new_file_path)
except Exception as e:
print(f"Error copying token file {file_path}: {e}")
continue
except Exception as e:
print(f"An unexpected error occurred: {e}")
if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("--hf-ckpt-path", type=str, required=True)
parser.add_argument("--save-path", type=str, required=True)
parser.add_argument("--n-experts", type=int, required=True)
parser.add_argument("--model-parallel", type=int, required=True)
parser.add_argument(
"--hf-ckpt-path", type=str, required=True, help="Path to the checkpoint files."
)
parser.add_argument(
"--save-path", type=str, required=True, help="Path to save the converted files."
)
parser.add_argument(
"--n-experts",
type=validate_positive_integer,
required=True,
help="Total number of experts in the model.",
)
parser.add_argument(
"--model-parallel",
type=validate_positive_integer,
required=True,
help="Model parallelism factor.",
)
args = parser.parse_args()
assert args.n_experts % args.model_parallel == 0
if args.n_experts % args.model_parallel != 0:
raise ValueError(
f"Number of experts ({args.n_experts}) must be divisible by model parallelism factor ({args.model_parallel})."
)
main(args.hf_ckpt_path, args.save_path, args.n_experts, args.model_parallel)

View File

@ -9,6 +9,7 @@ from safetensors.torch import load_file, save_file
from kernel import weight_dequant
def main(fp8_path, bf16_path):
"""
Converts FP8 weights to BF16 and saves the converted weights.
@ -30,17 +31,31 @@ def main(fp8_path, bf16_path):
- The function updates the model index file to remove references to scale_inv tensors.
"""
torch.set_default_dtype(torch.bfloat16)
try:
os.makedirs(bf16_path, exist_ok=True)
except OSError as e:
print(f"Error creating directory {bf16_path}: {e}")
return
model_index_file = os.path.join(fp8_path, "model.safetensors.index.json")
if not os.path.isfile(model_index_file):
print(f"Error: Model index file '{model_index_file}' does not exist.")
return
try:
with open(model_index_file, "r") as f:
model_index = json.load(f)
weight_map = model_index["weight_map"]
except (json.JSONDecodeError, OSError) as e:
print(f"Error reading model index file '{model_index_file}': {e}")
return
weight_map = model_index.get("weight_map", {})
# Cache for loaded safetensor files
loaded_files = {}
fp8_weight_names = []
# Helper function to get tensor from the correct file
def get_tensor(tensor_name):
"""
Retrieves a tensor from the cached safetensor files or loads it from disk if not cached.
@ -54,53 +69,78 @@ def main(fp8_path, bf16_path):
Raises:
KeyError: If the tensor does not exist in the safetensor file.
"""
try:
file_name = weight_map[tensor_name]
except KeyError:
raise KeyError(f"Tensor '{tensor_name}' not found in weight map.")
if file_name not in loaded_files:
file_path = os.path.join(fp8_path, file_name)
try:
loaded_files[file_name] = load_file(file_path, device="cuda")
except (FileNotFoundError, OSError) as e:
raise FileNotFoundError(f"Error loading file '{file_path}': {e}")
return loaded_files[file_name][tensor_name]
safetensor_files = list(glob(os.path.join(fp8_path, "*.safetensors")))
safetensor_files.sort()
for safetensor_file in tqdm(safetensor_files):
file_name = os.path.basename(safetensor_file)
try:
current_state_dict = load_file(safetensor_file, device="cuda")
except (FileNotFoundError, OSError) as e:
print(f"Error loading safetensor file '{safetensor_file}': {e}")
continue
loaded_files[file_name] = current_state_dict
new_state_dict = {}
for weight_name, weight in current_state_dict.items():
if weight_name.endswith("_scale_inv"):
continue
elif weight.element_size() == 1: # FP8 weight
elif weight.element_size() == 1:
scale_inv_name = f"{weight_name}_scale_inv"
try:
# Get scale_inv from the correct file
scale_inv = get_tensor(scale_inv_name)
fp8_weight_names.append(weight_name)
new_state_dict[weight_name] = weight_dequant(weight, scale_inv)
except KeyError:
print(f"Warning: Missing scale_inv tensor for {weight_name}, skipping conversion")
print(
f"Warning: Missing scale_inv tensor for {weight_name}, skipping conversion"
)
new_state_dict[weight_name] = weight
else:
new_state_dict[weight_name] = weight
new_safetensor_file = os.path.join(bf16_path, file_name)
save_file(new_state_dict, new_safetensor_file)
# Memory management: keep only the 2 most recently used files
try:
save_file(new_state_dict, new_safetensor_file)
except (OSError, RuntimeError) as e:
print(f"Error saving safetensor file '{new_safetensor_file}': {e}")
continue
if len(loaded_files) > 2:
oldest_file = next(iter(loaded_files))
del loaded_files[oldest_file]
torch.cuda.empty_cache()
# Update model index
new_model_index_file = os.path.join(bf16_path, "model.safetensors.index.json")
for weight_name in fp8_weight_names:
scale_inv_name = f"{weight_name}_scale_inv"
if scale_inv_name in weight_map:
weight_map.pop(scale_inv_name)
try:
with open(new_model_index_file, "w") as f:
json.dump({"metadata": {}, "weight_map": weight_map}, f, indent=2)
except (OSError, json.JSONDecodeError) as e:
print(f"Error writing new model index file '{new_model_index_file}': {e}")
return
if __name__ == "__main__":
@ -108,5 +148,14 @@ if __name__ == "__main__":
parser.add_argument("--input-fp8-hf-path", type=str, required=True)
parser.add_argument("--output-bf16-hf-path", type=str, required=True)
args = parser.parse_args()
main(args.input_fp8_hf_path, args.output_bf16_hf_path)
if not os.path.isdir(args.input_fp8_hf_path):
print(
f"Error: Input FP8 path '{args.input_fp8_hf_path}' is not a valid directory."
)
elif not os.path.isdir(args.output_bf16_hf_path):
print(
f"Error: Output BF16 path '{args.output_bf16_hf_path}' is not a valid directory."
)
else:
main(args.input_fp8_hf_path, args.output_bf16_hf_path)

View File

@ -33,7 +33,7 @@ def generate(
prompt_tokens: List[List[int]],
max_new_tokens: int,
eos_id: int,
temperature: float = 1.0
temperature: float = 1.0,
) -> List[List[int]]:
"""
Generates new tokens based on the given prompt tokens using the specified model.
@ -51,9 +51,11 @@ def generate(
prompt_lens = [len(t) for t in prompt_tokens]
assert max(prompt_lens) <= model.max_seq_len
total_len = min(model.max_seq_len, max_new_tokens + max(prompt_lens))
tokens = torch.full((len(prompt_tokens), total_len), -1, dtype=torch.long, device="cuda")
tokens = torch.full(
(len(prompt_tokens), total_len), -1, dtype=torch.long, device="cuda"
)
for i, t in enumerate(prompt_tokens):
tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
tokens[i, : len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
prev_pos = 0
finished = torch.tensor([False] * len(prompt_tokens), device="cuda")
prompt_mask = tokens != -1
@ -63,7 +65,9 @@ def generate(
next_token = sample(logits, temperature)
else:
next_token = logits.argmax(dim=-1)
next_token = torch.where(prompt_mask[:, cur_pos], tokens[:, cur_pos], next_token)
next_token = torch.where(
prompt_mask[:, cur_pos], tokens[:, cur_pos], next_token
)
tokens[:, cur_pos] = next_token
finished |= torch.logical_and(~prompt_mask[:, cur_pos], next_token == eos_id)
prev_pos = cur_pos
@ -71,9 +75,9 @@ def generate(
break
completion_tokens = []
for i, toks in enumerate(tokens.tolist()):
toks = toks[prompt_lens[i]:prompt_lens[i]+max_new_tokens]
toks = toks[prompt_lens[i] : prompt_lens[i] + max_new_tokens]
if eos_id in toks:
toks = toks[:toks.index(eos_id)]
toks = toks[: toks.index(eos_id)]
completion_tokens.append(toks)
return completion_tokens
@ -97,26 +101,75 @@ def main(
max_new_tokens (int, optional): Maximum number of new tokens to generate. Defaults to 100.
temperature (float, optional): Temperature for sampling. Defaults to 1.0.
"""
try:
world_size = int(os.getenv("WORLD_SIZE", "1"))
rank = int(os.getenv("RANK", "0"))
local_rank = int(os.getenv("LOCAL_RANK", "0"))
except ValueError as e:
raise ValueError(
"Environment variables WORLD_SIZE, RANK, or LOCAL_RANK are not set correctly."
) from e
if world_size > 1:
try:
dist.init_process_group("nccl")
except Exception as e:
raise RuntimeError(
"Failed to initialize the process group for distributed training."
) from e
global print
if rank != 0:
print = lambda *_, **__: None
torch.cuda.set_device(local_rank)
torch.set_default_dtype(torch.bfloat16)
torch.set_num_threads(8)
torch.manual_seed(965)
if not os.path.isfile(config):
raise FileNotFoundError(f"Configuration file {config} not found.")
try:
with open(config) as f:
args = ModelArgs(**json.load(f))
except json.JSONDecodeError as e:
raise ValueError(
f"Failed to parse JSON from the configuration file {config}."
) from e
print(args)
try:
with torch.device("cuda"):
model = Transformer(args)
except Exception as e:
raise RuntimeError("Failed to load the model.") from e
if not os.path.isdir(ckpt_path):
raise FileNotFoundError(f"Checkpoint directory {ckpt_path} not found.")
try:
tokenizer = AutoTokenizer.from_pretrained(ckpt_path)
tokenizer.decode(generate(model, [tokenizer.encode("DeepSeek")], 2, -1, 1.)[0])
load_model(model, os.path.join(ckpt_path, f"model{rank}-mp{world_size}.safetensors"))
except Exception as e:
raise RuntimeError(f"Failed to load tokenizer from {ckpt_path}.") from e
try:
tokenizer.decode(generate(model, [tokenizer.encode("DeepSeek")], 2, -1, 1.0)[0])
except Exception as e:
raise RuntimeError(
"Failed to generate tokens using the model and tokenizer."
) from e
try:
load_model(
model, os.path.join(ckpt_path, f"model{rank}-mp{world_size}.safetensors")
)
except FileNotFoundError as e:
raise FileNotFoundError(
f"Model file not found at {os.path.join(ckpt_path, f'model{rank}-mp{world_size}.safetensors')}"
) from e
except Exception as e:
raise RuntimeError("Failed to load the model checkpoint.") from e
if interactive:
messages = []
@ -131,24 +184,52 @@ def main(
objects = [None]
dist.broadcast_object_list(objects, 0)
prompt = objects[0]
if prompt == "/exit":
break
elif prompt == "/clear":
messages.clear()
continue
messages.append({"role": "user", "content": prompt})
prompt_tokens = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
completion_tokens = generate(model, [prompt_tokens], max_new_tokens, tokenizer.eos_token_id, temperature)
completion = tokenizer.decode(completion_tokens[0], skip_special_tokens=True)
prompt_tokens = tokenizer.apply_chat_template(
messages, add_generation_prompt=True
)
completion_tokens = generate(
model,
[prompt_tokens],
max_new_tokens,
tokenizer.eos_token_id,
temperature,
)
completion = tokenizer.decode(
completion_tokens[0], skip_special_tokens=True
)
print(completion)
messages.append({"role": "assistant", "content": completion})
else:
if not input_file:
raise ValueError("You must specify an input file for batch processing.")
if not os.path.isfile(input_file):
raise FileNotFoundError(f"Input file {input_file} not found.")
with open(input_file) as f:
prompts = [line.strip() for line in f.readlines()]
assert len(prompts) <= args.max_batch_size
prompt_tokens = [tokenizer.apply_chat_template([{"role": "user", "content": prompt}], add_generation_prompt=True) for prompt in prompts]
completion_tokens = generate(model, prompt_tokens, max_new_tokens, tokenizer.eos_token_id, temperature)
completions = tokenizer.batch_decode(completion_tokens, skip_special_tokens=True)
prompt_tokens = [
tokenizer.apply_chat_template(
[{"role": "user", "content": prompt}], add_generation_prompt=True
)
for prompt in prompts
]
completion_tokens = generate(
model, prompt_tokens, max_new_tokens, tokenizer.eos_token_id, temperature
)
completions = tokenizer.batch_decode(
completion_tokens, skip_special_tokens=True
)
for prompt, completion in zip(prompts, completions):
print("Prompt:", prompt)
print("Completion:", completion)
@ -181,5 +262,17 @@ if __name__ == "__main__":
parser.add_argument("--max-new-tokens", type=int, default=200)
parser.add_argument("--temperature", type=float, default=0.2)
args = parser.parse_args()
assert args.input_file or args.interactive
main(args.ckpt_path, args.config, args.input_file, args.interactive, args.max_new_tokens, args.temperature)
assert (
args.input_file or args.interactive
), "You must specify either an input file or enable interactive mode."
try:
main(
args.ckpt_path,
args.config,
args.input_file,
args.interactive,
args.max_new_tokens,
args.temperature,
)
except Exception as e:
print(f"Error: {e}")

View File

@ -23,14 +23,16 @@ def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
pid = tl.program_id(axis=0)
offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
x = tl.load(x_ptr + offs).to(tl.float32)
s = tl.max(tl.abs(x)) / 448.
s = tl.max(tl.abs(x)) / 448.0
y = x / s
y = y.to(y_ptr.dtype.element_ty)
tl.store(y_ptr + offs, y)
tl.store(s_ptr + pid, s)
def act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
def act_quant(
x: torch.Tensor, block_size: int = 128
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Quantizes the input tensor `x` using block-wise quantization.
@ -42,12 +44,20 @@ def act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, tor
Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
- The quantized tensor with dtype `torch.float8_e4m3fn`.
- A tensor of scaling factors with dtype `torch.float32`.
Raises:
ValueError: If the tensor is not contiguous or the last dimension is not divisible by `block_size`.
"""
assert x.is_contiguous()
assert x.size(-1) % block_size == 0
if not x.is_contiguous():
raise ValueError("Input tensor must be contiguous in memory.")
if x.size(-1) % block_size != 0:
raise ValueError(
f"The last dimension of the tensor must be divisible by {block_size}."
)
y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype=torch.float32)
grid = lambda meta: (triton.cdiv(x.numel(), meta['BLOCK_SIZE']), )
grid = lambda meta: (triton.cdiv(x.numel(), meta["BLOCK_SIZE"]),)
act_quant_kernel[grid](x, y, s, BLOCK_SIZE=block_size)
return y, s
@ -81,7 +91,9 @@ def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
tl.store(y_ptr + offs, y, mask=mask)
def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor:
def weight_dequant(
x: torch.Tensor, s: torch.Tensor, block_size: int = 128
) -> torch.Tensor:
"""
Dequantizes the given weight tensor using the provided scale tensor.
@ -94,30 +106,52 @@ def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> t
torch.Tensor: The dequantized weight tensor of the same shape as `x`.
Raises:
AssertionError: If `x` or `s` are not contiguous or if their dimensions are not 2.
ValueError: If `x` or `s` are not contiguous or if their dimensions are not 2.
"""
assert x.is_contiguous() and s.is_contiguous()
assert x.dim() == 2 and s.dim() == 2
if not x.is_contiguous():
raise ValueError("Input tensor `x` must be contiguous.")
if not s.is_contiguous():
raise ValueError("Input tensor `s` must be contiguous.")
if x.dim() != 2 or s.dim() != 2:
raise ValueError("Both `x` and `s` must be 2D tensors.")
M, N = x.size()
y = torch.empty_like(x, dtype=torch.get_default_dtype())
grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE']), triton.cdiv(N, meta['BLOCK_SIZE']))
grid = lambda meta: (
triton.cdiv(M, meta["BLOCK_SIZE_M"]),
triton.cdiv(N, meta["BLOCK_SIZE_N"]),
)
weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size)
return y
fp8_gemm_configs = [
Config({'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': 128}, num_stages=num_stages, num_warps=8)
for block_m in [16, 32, 64] for block_n in [32, 64, 128] for num_stages in [3, 4, 5, 6]
Config(
{"BLOCK_SIZE_M": block_m, "BLOCK_SIZE_N": block_n, "BLOCK_SIZE_K": 128},
num_stages=num_stages,
num_warps=8,
)
for block_m in [16, 32, 64]
for block_n in [32, 64, 128]
for num_stages in [3, 4, 5, 6]
]
@triton.autotune(configs=fp8_gemm_configs, key=['N', 'K'])
@triton.autotune(configs=fp8_gemm_configs, key=["N", "K"])
@triton.jit
def fp8_gemm_kernel(a_ptr, b_ptr, c_ptr,
a_s_ptr, b_s_ptr,
M, N: tl.constexpr, K: tl.constexpr,
def fp8_gemm_kernel(
a_ptr,
b_ptr,
c_ptr,
a_s_ptr,
b_s_ptr,
M,
N: tl.constexpr,
K: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr):
BLOCK_SIZE_K: tl.constexpr,
):
"""
Performs a matrix multiplication operation on FP8 matrices with scaling factors.
@ -173,19 +207,35 @@ def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Ten
Args:
a (torch.Tensor): The first input matrix, must be contiguous.
a_s (torch.Tensor): The scaling factor for the first input matrix, must be contiguous.
a_s (torch.Tensor): The scaling factors for matrix `a`.
b (torch.Tensor): The second input matrix, must be contiguous.
b_s (torch.Tensor): The scaling factor for the second input matrix, must be contiguous.
b_s (torch.Tensor): The scaling factors for matrix `b`.
Returns:
torch.Tensor: The result of the matrix multiplication.
torch.Tensor: The resulting matrix after multiplication.
Raises:
ValueError: If `a`, `b`, `a_s`, or `b_s` are not contiguous.
"""
assert a.is_contiguous() and b.is_contiguous()
assert a_s.is_contiguous() and b_s.is_contiguous()
K = a.size(-1)
M = a.numel() // K
N = b.size(0)
c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype())
grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']), triton.cdiv(N, META['BLOCK_SIZE_N']))
fp8_gemm_kernel[grid](a, b, c, a_s, b_s, M, N, K)
if not a.is_contiguous():
raise ValueError("Matrix `a` must be contiguous.")
if not b.is_contiguous():
raise ValueError("Matrix `b` must be contiguous.")
if not a_s.is_contiguous():
raise ValueError("Scaling factors `a_s` must be contiguous.")
if not b_s.is_contiguous():
raise ValueError("Scaling factors `b_s` must be contiguous.")
M, K = a.size()
K_, N = b.size()
assert K == K_, "Matrix dimensions do not match for multiplication."
c = torch.empty((M, N), dtype=torch.float32)
grid = lambda meta: (
triton.cdiv(M, meta["BLOCK_SIZE_M"]),
triton.cdiv(N, meta["BLOCK_SIZE_N"]),
)
fp8_gemm_kernel[grid](
a, b, c, a_s, b_s, M, N, K, BLOCK_SIZE_M=128, BLOCK_SIZE_N=128, BLOCK_SIZE_K=128
)
return c

View File

@ -16,6 +16,7 @@ block_size = 128
gemm_impl: Literal["bf16", "fp8"] = "bf16"
attn_impl: Literal["naive", "absorb"] = "absorb"
@dataclass
class ModelArgs:
"""
@ -51,6 +52,7 @@ class ModelArgs:
beta_slow (int): Slow beta correction factor.
mscale (float): Scaling factor for extended attention.
"""
max_batch_size: int = 8
max_seq_len: int = 4096 * 4
dtype: Literal["bf16", "fp8"] = "bf16"
@ -68,7 +70,7 @@ class ModelArgs:
n_expert_groups: int = 1
n_limited_groups: int = 1
score_func: Literal["softmax", "sigmoid"] = "softmax"
route_scale: float = 1.
route_scale: float = 1.0
# mla
q_lora_rank: int = 0
kv_lora_rank: int = 512
@ -81,7 +83,7 @@ class ModelArgs:
rope_factor: float = 40
beta_fast: int = 32
beta_slow: int = 1
mscale: float = 1.
mscale: float = 1.0
class ParallelEmbedding(nn.Module):
@ -92,12 +94,13 @@ class ParallelEmbedding(nn.Module):
vocab_size (int): Vocabulary size.
dim (int): Embedding dimension.
"""
def __init__(self, vocab_size: int, dim: int):
super().__init__()
self.vocab_size = vocab_size
self.dim = dim
assert vocab_size % world_size == 0
self.part_vocab_size = (vocab_size // world_size)
self.part_vocab_size = vocab_size // world_size
self.vocab_start_idx = rank * self.part_vocab_size
self.vocab_end_idx = self.vocab_start_idx + self.part_vocab_size
self.weight = nn.Parameter(torch.empty(self.part_vocab_size, self.dim))
@ -126,7 +129,9 @@ class ParallelEmbedding(nn.Module):
return y
def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
def linear(
x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None
) -> torch.Tensor:
"""
Applies a linear transformation to the incoming data: y = xA^T + b.
This function supports specialized implementations based on quantization
@ -171,17 +176,24 @@ class Linear(nn.Module):
bias (bool): Whether to include a bias term. Defaults to False.
dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`.
"""
dtype = torch.bfloat16
def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
def __init__(
self, in_features: int, out_features: int, bias: bool = False, dtype=None
):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.weight = nn.Parameter(torch.empty(out_features, in_features, dtype=dtype or Linear.dtype))
self.weight = nn.Parameter(
torch.empty(out_features, in_features, dtype=dtype or Linear.dtype)
)
if self.weight.element_size() == 1:
scale_out_features = (out_features + block_size - 1) // block_size
scale_in_features = (in_features + block_size - 1) // block_size
self.weight.scale = self.scale = nn.Parameter(torch.empty(scale_out_features, scale_in_features, dtype=torch.float32))
self.weight.scale = self.scale = nn.Parameter(
torch.empty(scale_out_features, scale_in_features, dtype=torch.float32)
)
else:
self.register_parameter("scale", None)
if bias:
@ -212,7 +224,10 @@ class ColumnParallelLinear(Linear):
bias (bool): Whether to include a bias term. Defaults to False.
dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`.
"""
def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
def __init__(
self, in_features: int, out_features: int, bias: bool = False, dtype=None
):
assert out_features % world_size == 0
self.part_out_features = out_features // world_size
super().__init__(in_features, self.part_out_features, bias, dtype)
@ -241,7 +256,10 @@ class RowParallelLinear(Linear):
bias (bool): Whether to include a bias term. Defaults to False.
dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`.
"""
def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
def __init__(
self, in_features: int, out_features: int, bias: bool = False, dtype=None
):
assert in_features % world_size == 0
self.part_in_features = in_features // world_size
super().__init__(self.part_in_features, out_features, bias, dtype)
@ -272,6 +290,7 @@ class RMSNorm(nn.Module):
dim (int): Dimension of the input tensor.
eps (float): Epsilon value for numerical stability. Defaults to 1e-6.
"""
def __init__(self, dim: int, eps: float = 1e-6):
super().__init__()
self.dim = dim
@ -321,7 +340,11 @@ def precompute_freqs_cis(args: ModelArgs) -> torch.Tensor:
Returns:
float: The correction dimension based on the input parameters.
"""
return dim * math.log(max_seq_len / (num_rotations * 2 * math.pi)) / (2 * math.log(base))
return (
dim
* math.log(max_seq_len / (num_rotations * 2 * math.pi))
/ (2 * math.log(base))
)
def find_correction_range(low_rot, high_rot, dim, base, max_seq_len):
"""
@ -339,7 +362,7 @@ def precompute_freqs_cis(args: ModelArgs) -> torch.Tensor:
"""
low = math.floor(find_correction_dim(low_rot, dim, base, max_seq_len))
high = math.ceil(find_correction_dim(high_rot, dim, base, max_seq_len))
return max(low, 0), min(high, dim-1)
return max(low, 0), min(high, dim - 1)
def linear_ramp_factor(min, max, dim):
"""
@ -362,7 +385,9 @@ def precompute_freqs_cis(args: ModelArgs) -> torch.Tensor:
freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
if seqlen > args.original_seq_len:
low, high = find_correction_range(beta_fast, beta_slow, dim, base, args.original_seq_len)
low, high = find_correction_range(
beta_fast, beta_slow, dim, base, args.original_seq_len
)
smooth = 1 - linear_ramp_factor(low, high, dim // 2)
freqs = freqs / factor * (1 - smooth) + freqs * smooth
@ -406,6 +431,7 @@ class MLA(nn.Module):
v_head_dim (int): Dimensionality of value projections.
softmax_scale (float): Scaling factor for softmax in attention computation.
"""
def __init__(self, args: ModelArgs):
super().__init__()
self.dim = args.dim
@ -423,24 +449,62 @@ class MLA(nn.Module):
else:
self.wq_a = Linear(self.dim, self.q_lora_rank)
self.q_norm = RMSNorm(self.q_lora_rank)
self.wq_b = ColumnParallelLinear(self.q_lora_rank, self.n_heads * self.qk_head_dim)
self.wq_b = ColumnParallelLinear(
self.q_lora_rank, self.n_heads * self.qk_head_dim
)
self.wkv_a = Linear(self.dim, self.kv_lora_rank + self.qk_rope_head_dim)
self.kv_norm = RMSNorm(self.kv_lora_rank)
self.wkv_b = ColumnParallelLinear(self.kv_lora_rank, self.n_heads * (self.qk_nope_head_dim + self.v_head_dim))
self.wkv_b = ColumnParallelLinear(
self.kv_lora_rank, self.n_heads * (self.qk_nope_head_dim + self.v_head_dim)
)
self.wo = RowParallelLinear(self.n_heads * self.v_head_dim, self.dim)
self.softmax_scale = self.qk_head_dim ** -0.5
self.softmax_scale = self.qk_head_dim**-0.5
if args.max_seq_len > args.original_seq_len:
mscale = 0.1 * args.mscale * math.log(args.rope_factor) + 1.0
self.softmax_scale = self.softmax_scale * mscale * mscale
if attn_impl == "naive":
self.register_buffer("k_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.n_local_heads, self.qk_head_dim), persistent=False)
self.register_buffer("v_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.n_local_heads, self.v_head_dim), persistent=False)
self.register_buffer(
"k_cache",
torch.zeros(
args.max_batch_size,
args.max_seq_len,
self.n_local_heads,
self.qk_head_dim,
),
persistent=False,
)
self.register_buffer(
"v_cache",
torch.zeros(
args.max_batch_size,
args.max_seq_len,
self.n_local_heads,
self.v_head_dim,
),
persistent=False,
)
else:
self.register_buffer("kv_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.kv_lora_rank), persistent=False)
self.register_buffer("pe_cache", torch.zeros(args.max_batch_size, args.max_seq_len, self.qk_rope_head_dim), persistent=False)
self.register_buffer(
"kv_cache",
torch.zeros(args.max_batch_size, args.max_seq_len, self.kv_lora_rank),
persistent=False,
)
self.register_buffer(
"pe_cache",
torch.zeros(
args.max_batch_size, args.max_seq_len, self.qk_rope_head_dim
),
persistent=False,
)
def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]):
def forward(
self,
x: torch.Tensor,
start_pos: int,
freqs_cis: torch.Tensor,
mask: Optional[torch.Tensor],
):
"""
Forward pass for the Multi-Headed Attention Layer (MLA).
@ -460,7 +524,9 @@ class MLA(nn.Module):
else:
q = self.wq_b(self.q_norm(self.wq_a(x)))
q = q.view(bsz, seqlen, self.n_local_heads, self.qk_head_dim)
q_nope, q_pe = torch.split(q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
q_nope, q_pe = torch.split(
q, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
)
q_pe = apply_rotary_emb(q_pe, freqs_cis)
kv = self.wkv_a(x)
kv, k_pe = torch.split(kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
@ -468,20 +534,35 @@ class MLA(nn.Module):
if attn_impl == "naive":
q = torch.cat([q_nope, q_pe], dim=-1)
kv = self.wkv_b(self.kv_norm(kv))
kv = kv.view(bsz, seqlen, self.n_local_heads, self.qk_nope_head_dim + self.v_head_dim)
k_nope, v = torch.split(kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
kv = kv.view(
bsz, seqlen, self.n_local_heads, self.qk_nope_head_dim + self.v_head_dim
)
k_nope, v = torch.split(
kv, [self.qk_nope_head_dim, self.v_head_dim], dim=-1
)
k = torch.cat([k_nope, k_pe.expand(-1, -1, self.n_local_heads, -1)], dim=-1)
self.k_cache[:bsz, start_pos:end_pos] = k
self.v_cache[:bsz, start_pos:end_pos] = v
scores = torch.einsum("bshd,bthd->bsht", q, self.k_cache[:bsz, :end_pos]) * self.softmax_scale
scores = (
torch.einsum("bshd,bthd->bsht", q, self.k_cache[:bsz, :end_pos])
* self.softmax_scale
)
else:
wkv_b = self.wkv_b.weight if self.wkv_b.scale is None else weight_dequant(self.wkv_b.weight, self.wkv_b.scale, block_size)
wkv_b = (
self.wkv_b.weight
if self.wkv_b.scale is None
else weight_dequant(self.wkv_b.weight, self.wkv_b.scale, block_size)
)
wkv_b = wkv_b.view(self.n_local_heads, -1, self.kv_lora_rank)
q_nope = torch.einsum("bshd,hdc->bshc", q_nope, wkv_b[:, :self.qk_nope_head_dim])
q_nope = torch.einsum(
"bshd,hdc->bshc", q_nope, wkv_b[:, : self.qk_nope_head_dim]
)
self.kv_cache[:bsz, start_pos:end_pos] = self.kv_norm(kv)
self.pe_cache[:bsz, start_pos:end_pos] = k_pe.squeeze(2)
scores = (torch.einsum("bshc,btc->bsht", q_nope, self.kv_cache[:bsz, :end_pos]) +
torch.einsum("bshr,btr->bsht", q_pe, self.pe_cache[:bsz, :end_pos])) * self.softmax_scale
scores = (
torch.einsum("bshc,btc->bsht", q_nope, self.kv_cache[:bsz, :end_pos])
+ torch.einsum("bshr,btr->bsht", q_pe, self.pe_cache[:bsz, :end_pos])
) * self.softmax_scale
if mask is not None:
scores += mask.unsqueeze(1)
scores = scores.softmax(dim=-1, dtype=torch.float32).type_as(x)
@ -489,7 +570,7 @@ class MLA(nn.Module):
x = torch.einsum("bsht,bthd->bshd", scores, self.v_cache[:bsz, :end_pos])
else:
x = torch.einsum("bsht,btc->bshc", scores, self.kv_cache[:bsz, :end_pos])
x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim:])
x = torch.einsum("bshc,hdc->bshd", x, wkv_b[:, -self.v_head_dim :])
x = self.wo(x.flatten(2))
return x
@ -503,6 +584,7 @@ class MLP(nn.Module):
w2 (nn.Module): Linear layer for hidden-to-output transformation.
w3 (nn.Module): Additional linear layer for feature transformation.
"""
def __init__(self, dim: int, inter_dim: int):
"""
Initializes the MLP layer.
@ -543,6 +625,7 @@ class Gate(nn.Module):
weight (torch.nn.Parameter): Learnable weights for the gate.
bias (Optional[torch.nn.Parameter]): Optional bias term for the gate.
"""
def __init__(self, args: ModelArgs):
"""
Initializes the Gate module.
@ -558,7 +641,11 @@ class Gate(nn.Module):
self.score_func = args.score_func
self.route_scale = args.route_scale
self.weight = nn.Parameter(torch.empty(args.n_routed_experts, args.dim))
self.bias = nn.Parameter(torch.empty(args.n_routed_experts)) if self.dim == 7168 else None
self.bias = (
nn.Parameter(torch.empty(args.n_routed_experts))
if self.dim == 7168
else None
)
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
"""
@ -604,6 +691,7 @@ class Expert(nn.Module):
w2 (nn.Module): Linear layer for hidden-to-output transformation.
w3 (nn.Module): Additional linear layer for feature transformation.
"""
def __init__(self, dim: int, inter_dim: int):
"""
Initializes the Expert layer.
@ -643,6 +731,7 @@ class MoE(nn.Module):
experts (nn.ModuleList): List of expert modules.
shared_experts (nn.Module): Shared experts applied to all inputs.
"""
def __init__(self, args: ModelArgs):
"""
Initializes the MoE module.
@ -659,8 +748,16 @@ class MoE(nn.Module):
self.experts_start_idx = rank * self.n_local_experts
self.experts_end_idx = self.experts_start_idx + self.n_local_experts
self.gate = Gate(args)
self.experts = nn.ModuleList([Expert(args.dim, args.moe_inter_dim) if self.experts_start_idx <= i < self.experts_end_idx else None
for i in range(self.n_routed_experts)])
self.experts = nn.ModuleList(
[
(
Expert(args.dim, args.moe_inter_dim)
if self.experts_start_idx <= i < self.experts_end_idx
else None
)
for i in range(self.n_routed_experts)
]
)
self.shared_experts = MLP(args.dim, args.n_shared_experts * args.moe_inter_dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
@ -677,7 +774,9 @@ class MoE(nn.Module):
x = x.view(-1, self.dim)
weights, indices = self.gate(x)
y = torch.zeros_like(x)
counts = torch.bincount(indices.flatten(), minlength=self.n_routed_experts).tolist()
counts = torch.bincount(
indices.flatten(), minlength=self.n_routed_experts
).tolist()
for i in range(self.experts_start_idx, self.experts_end_idx):
if counts[i] == 0:
continue
@ -700,6 +799,7 @@ class Block(nn.Module):
attn_norm (nn.Module): Layer normalization for attention.
ffn_norm (nn.Module): Layer normalization for feed-forward network.
"""
def __init__(self, layer_id: int, args: ModelArgs):
"""
Initializes the Transformer block.
@ -710,11 +810,21 @@ class Block(nn.Module):
"""
super().__init__()
self.attn = MLA(args)
self.ffn = MLP(args.dim, args.inter_dim) if layer_id < args.n_dense_layers else MoE(args)
self.ffn = (
MLP(args.dim, args.inter_dim)
if layer_id < args.n_dense_layers
else MoE(args)
)
self.attn_norm = RMSNorm(args.dim)
self.ffn_norm = RMSNorm(args.dim)
def forward(self, x: torch.Tensor, start_pos: int, freqs_cis: torch.Tensor, mask: Optional[torch.Tensor]) -> torch.Tensor:
def forward(
self,
x: torch.Tensor,
start_pos: int,
freqs_cis: torch.Tensor,
mask: Optional[torch.Tensor],
) -> torch.Tensor:
"""
Forward pass for the Transformer block.
@ -744,6 +854,7 @@ class Transformer(nn.Module):
head (nn.Module): Output projection layer mapping to vocabulary size.
freqs_cis (torch.Tensor): Precomputed complex exponential values for rotary embeddings.
"""
def __init__(self, args: ModelArgs):
"""
Initializes the Transformer model.
@ -762,7 +873,9 @@ class Transformer(nn.Module):
for layer_id in range(args.n_layers):
self.layers.append(Block(layer_id, args))
self.norm = RMSNorm(args.dim)
self.head = ColumnParallelLinear(args.dim, args.vocab_size, dtype=torch.get_default_dtype())
self.head = ColumnParallelLinear(
args.dim, args.vocab_size, dtype=torch.get_default_dtype()
)
self.register_buffer("freqs_cis", precompute_freqs_cis(args), persistent=False)
@torch.inference_mode()
@ -779,10 +892,12 @@ class Transformer(nn.Module):
"""
seqlen = tokens.size(1)
h = self.embed(tokens)
freqs_cis = self.freqs_cis[start_pos:start_pos+seqlen]
freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
mask = None
if seqlen > 1:
mask = torch.full((seqlen, seqlen), float("-inf"), device=tokens.device).triu_(1)
mask = torch.full(
(seqlen, seqlen), float("-inf"), device=tokens.device
).triu_(1)
for layer in self.layers:
h = layer(h, start_pos, freqs_cis, mask)
h = self.norm(h)[:, -1]
@ -795,6 +910,7 @@ class Transformer(nn.Module):
if __name__ == "__main__":
try:
torch.set_default_dtype(torch.bfloat16)
torch.set_default_device("cuda")
torch.manual_seed(0)
@ -802,3 +918,5 @@ if __name__ == "__main__":
x = torch.randint(0, args.vocab_size, (2, 128))
model = Transformer(args)
print(model(x).size())
except Exception as e:
print(f"Error during model execution: {str(e)}")