Mirror of https://github.com/deepseek-ai/DeepSeek-V3.git (synced 2025-02-23 14:18:57 -05:00)

Merge branch 'main' of github.com:XxAlonexX/DeepSeek-V3

Commit f8b7c3b6e7

.github/workflows/stale.yml (vendored, new file, 30 lines)
@@ -0,0 +1,30 @@

```yaml
name: "Mark and close stale issues"
on:
  workflow_dispatch:
  schedule:
    - cron: "0 0 * * *"

jobs:
  stale:
    if: ${{ github.repository == 'deepseek-ai/DeepSeek-V3' }}
    runs-on: ubuntu-latest
    steps:
      - name: "Mark and close stale issues"
        uses: actions/stale@v9
        with:
          days-before-issue-stale: 30
          days-before-issue-close: 14
          stale-issue-label: "stale"
          close-issue-label: "closed-as-stale"
          exempt-issue-labels: |
            pinned
            security
          stale-issue-message: >
            This issue has been automatically marked as stale because it has not had
            recent activity. It will be closed if no further activity occurs. If you
            believe this issue is still relevant, please leave a comment to keep it open.
            Thank you for your contributions!
          close-issue-message: false
          days-before-pr-stale: -1
          days-before-pr-close: -1
          repo-token: ${{ secrets.GITHUB_TOKEN }}
```
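Taken together, the workflow runs daily at midnight UTC (and on manual dispatch): issues inactive for 30 days are labeled "stale" and closed 14 days later under "closed-as-stale", issues labeled pinned or security are exempt, and the -1 values leave pull requests untouched.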
README.md (75 changed lines)
@@ -7,42 +7,39 @@

The badge block at the top of the README is restructured: previously each badge was a two-line `<a href="..."><img alt="..." src="..."/></a>` pair, with rows separated by `<br>` inside a single centered `<div>` that also wrapped the paper link; the badges now sit in three centered `<div>` groups, each link carrying `target="_blank" style="margin: 2px;"` and each image `style="display: inline-block; vertical-align: middle;"`:

```html
</div>
<hr>
<div align="center" style="line-height: 1;">
  <a href="https://www.deepseek.com/" target="_blank" style="margin: 2px;">
    <img alt="Homepage" src="https://github.com/deepseek-ai/DeepSeek-V2/blob/main/figures/badge.svg?raw=true" style="display: inline-block; vertical-align: middle;"/>
  </a>
  <a href="https://chat.deepseek.com/" target="_blank" style="margin: 2px;">
    <img alt="Chat" src="https://img.shields.io/badge/🤖%20Chat-DeepSeek%20V3-536af5?color=536af5&logoColor=white" style="display: inline-block; vertical-align: middle;"/>
  </a>
  <a href="https://huggingface.co/deepseek-ai" target="_blank" style="margin: 2px;">
    <img alt="Hugging Face" src="https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-DeepSeek%20AI-ffc107?color=ffc107&logoColor=white" style="display: inline-block; vertical-align: middle;"/>
  </a>
</div>

<div align="center" style="line-height: 1;">
  <a href="https://discord.gg/Tc7c45Zzu5" target="_blank" style="margin: 2px;">
    <img alt="Discord" src="https://img.shields.io/badge/Discord-DeepSeek%20AI-7289da?logo=discord&logoColor=white&color=7289da" style="display: inline-block; vertical-align: middle;"/>
  </a>
  <a href="https://github.com/deepseek-ai/DeepSeek-V2/blob/main/figures/qr.jpeg?raw=true" target="_blank" style="margin: 2px;">
    <img alt="Wechat" src="https://img.shields.io/badge/WeChat-DeepSeek%20AI-brightgreen?logo=wechat&logoColor=white" style="display: inline-block; vertical-align: middle;"/>
  </a>
  <a href="https://twitter.com/deepseek_ai" target="_blank" style="margin: 2px;">
    <img alt="Twitter Follow" src="https://img.shields.io/badge/Twitter-deepseek_ai-white?logo=x&logoColor=white" style="display: inline-block; vertical-align: middle;"/>
  </a>
</div>

<div align="center" style="line-height: 1;">
  <a href="https://github.com/deepseek-ai/DeepSeek-V3/blob/main/LICENSE-CODE" style="margin: 2px;">
    <img alt="Code License" src="https://img.shields.io/badge/Code_License-MIT-f5de53?&color=f5de53" style="display: inline-block; vertical-align: middle;"/>
  </a>
  <a href="https://github.com/deepseek-ai/DeepSeek-V3/blob/main/LICENSE-MODEL" style="margin: 2px;">
    <img alt="Model License" src="https://img.shields.io/badge/Model_License-Model_Agreement-f5de53?&color=f5de53" style="display: inline-block; vertical-align: middle;"/>
  </a>
</div>

<p align="center">
  <a href="DeepSeek_V3.pdf"><b>Paper Link</b>👁️</a>
</p>
```

The hunk also spans the Table of Contents that sits between the badges and the `## 1. Introduction` heading:

```markdown
## Table of Contents

1. [Introduction](#1-introduction)
2. [Model Summary](#2-model-summary)
3. [Model Downloads](#3-model-downloads)
4. [Evaluation Results](#4-evaluation-results)
5. [Chat Website & API Platform](#5-chat-website--api-platform)
6. [How to Run Locally](#6-how-to-run-locally)
7. [License](#7-license)
8. [Citation](#8-citation)
9. [Contact](#9-contact)

## 1. Introduction
```
```diff
@@ -99,7 +96,7 @@ Throughout the entire training process, we did not experience any irrecoverable
 </div>

 > [!NOTE]
-> The total size of DeepSeek-V3 models on Hugging Face is 685B, which includes 671B of the Main Model weights and 14B of the Multi-Token Prediction (MTP) Module weights.**
+> The total size of DeepSeek-V3 models on Hugging Face is 685B, which includes 671B of the Main Model weights and 14B of the Multi-Token Prediction (MTP) Module weights.

 To ensure optimal performance and flexibility, we have partnered with open-source communities and hardware vendors to provide multiple ways to run the model locally. For step-by-step guidance, check out Section 6: [How_to Run_Locally](#6-how-to-run-locally).

```
```diff
@@ -130,7 +127,7 @@ For developers looking to dive deeper, we recommend exploring [README_WEIGHTS.md
 | | WinoGrande (Acc.) | 5-shot | **86.3** | 82.3 | 85.2 | 84.9 |
 | | RACE-Middle (Acc.) | 5-shot | 73.1 | 68.1 | **74.2** | 67.1 |
 | | RACE-High (Acc.) | 5-shot | 52.6 | 50.3 | **56.8** | 51.3 |
-| | TriviaQA (EM) | 5-shot | 80.0 | 71.9 | **82.7** | **82.9** |
+| | TriviaQA (EM) | 5-shot | 80.0 | 71.9 | 82.7 | **82.9** |
 | | NaturalQuestions (EM) | 5-shot | 38.6 | 33.2 | **41.5** | 40.0 |
 | | AGIEval (Acc.) | 0-shot | 57.5 | 75.8 | 60.6 | **79.6** |
 | Code | HumanEval (Pass@1) | 0-shot | 43.3 | 53.0 | 54.9 | **65.2** |
```
````diff
@@ -249,7 +246,7 @@ python fp8_cast_bf16.py --input-fp8-hf-path /path/to/fp8_weights --output-bf16-h
 ```

 > [!NOTE]
-> Hugging Face's Transformers has not been directly supported yet.**
+> Hugging Face's Transformers has not been directly supported yet.

 ### 6.1 Inference with DeepSeek-Infer Demo (example only)

````
````diff
@@ -259,7 +256,7 @@ python fp8_cast_bf16.py --input-fp8-hf-path /path/to/fp8_weights --output-bf16-h
 > Linux with Python 3.10 only. Mac and Windows are not supported.

 Dependencies:
-```
+```pip-requirements
 torch==2.4.1
 triton==3.0.0
 transformers==4.46.3
````
inference/convert.py

```diff
@@ -60,7 +60,7 @@ def main(hf_ckpt_path, save_path, n_experts, mp):
                 name = name.replace("weight_scale_inv", "scale")
                 name = name.replace("e_score_correction_bias", "bias")
                 key = name.split(".")[-2]
-                assert key in mapping
+                assert key in mapping, f"Key {key} not found in mapping"
                 new_key, dim = mapping[key]
                 name = name.replace(key, new_key)
                 for i in range(mp):
```
```diff
@@ -70,7 +70,7 @@ def main(hf_ckpt_path, save_path, n_experts, mp):
                         if idx < i * n_local_experts or idx >= (i + 1) * n_local_experts:
                             continue
                     elif dim is not None:
-                        assert param.size(dim) % mp == 0
+                        assert param.size(dim) % mp == 0, f"Dimension {dim} must be divisible by {mp}"
                         shard_size = param.size(dim) // mp
                         new_param = param.narrow(dim, i * shard_size, shard_size).contiguous()
                     state_dicts[i][name] = new_param
```
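For intuition, a minimal standalone sketch (toy shapes, not code from the repo) of how `narrow` produces the per-rank shards that the loop above writes into `state_dicts`:

```python
import torch

def shard_tensor(param: torch.Tensor, dim: int, mp: int) -> list[torch.Tensor]:
    # Each of the mp model-parallel ranks keeps one contiguous slice along `dim`.
    assert param.size(dim) % mp == 0, f"Dimension {dim} must be divisible by {mp}"
    shard_size = param.size(dim) // mp
    return [param.narrow(dim, i * shard_size, shard_size).contiguous() for i in range(mp)]

# A 4x8 weight split along dim=1 across mp=2 ranks -> two 4x4 shards.
shards = shard_tensor(torch.randn(4, 8), dim=1, mp=2)
print([tuple(s.shape) for s in shards])  # [(4, 4), (4, 4)]
```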
```diff
@@ -92,5 +92,5 @@ if __name__ == "__main__":
     parser.add_argument("--n-experts", type=int, required=True)
     parser.add_argument("--model-parallel", type=int, required=True)
     args = parser.parse_args()
-    assert args.n_experts % args.model_parallel == 0
+    assert args.n_experts % args.model_parallel == 0, "Number of experts must be divisible by model parallelism"
     main(args.hf_ckpt_path, args.save_path, args.n_experts, args.model_parallel)
```
inference/generate.py

```diff
@@ -49,7 +49,7 @@ def generate(
         List[List[int]]: A list of lists containing the generated tokens for each sequence.
     """
     prompt_lens = [len(t) for t in prompt_tokens]
-    assert max(prompt_lens) <= model.max_seq_len
+    assert max(prompt_lens) <= model.max_seq_len, f"Prompt length exceeds model maximum sequence length (max_seq_len={model.max_seq_len})"
     total_len = min(model.max_seq_len, max_new_tokens + max(prompt_lens))
     tokens = torch.full((len(prompt_tokens), total_len), -1, dtype=torch.long, device="cuda")
     for i, t in enumerate(prompt_tokens):
```
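As a rough, CPU-only illustration (toy values, not from the repo) of the prompt buffer this function sets up: prompts are left-aligned into a `-1`-padded matrix of width `total_len`, and generation then fills the remaining columns:

```python
import torch

prompt_tokens = [[11, 12, 13], [21, 22]]      # toy token ids for two prompts
max_new_tokens, max_seq_len = 4, 16           # assumed limits
prompt_lens = [len(t) for t in prompt_tokens]
total_len = min(max_seq_len, max_new_tokens + max(prompt_lens))
tokens = torch.full((len(prompt_tokens), total_len), -1, dtype=torch.long)
for i, t in enumerate(prompt_tokens):
    tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long)
print(tokens)
# tensor([[11, 12, 13, -1, -1, -1, -1],
#         [21, 22, -1, -1, -1, -1, -1]])
```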
```diff
@@ -145,7 +145,7 @@ def main(
     else:
         with open(input_file) as f:
             prompts = [line.strip() for line in f.readlines()]
-        assert len(prompts) <= args.max_batch_size
+        assert len(prompts) <= args.max_batch_size, f"Number of prompts exceeds maximum batch size ({args.max_batch_size})"
         prompt_tokens = [tokenizer.apply_chat_template([{"role": "user", "content": prompt}], add_generation_prompt=True) for prompt in prompts]
         completion_tokens = generate(model, prompt_tokens, max_new_tokens, tokenizer.eos_token_id, temperature)
         completions = tokenizer.batch_decode(completion_tokens, skip_special_tokens=True)
```
```diff
@@ -181,5 +181,5 @@ if __name__ == "__main__":
     parser.add_argument("--max-new-tokens", type=int, default=200)
     parser.add_argument("--temperature", type=float, default=0.2)
     args = parser.parse_args()
-    assert args.input_file or args.interactive
+    assert args.input_file or args.interactive, "Either input-file or interactive mode must be specified"
     main(args.ckpt_path, args.config, args.input_file, args.interactive, args.max_new_tokens, args.temperature)
```
inference/kernel.py

```diff
@@ -43,8 +43,8 @@ def act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, tor
             - The quantized tensor with dtype `torch.float8_e4m3fn`.
             - A tensor of scaling factors with dtype `torch.float32`.
     """
-    assert x.is_contiguous()
-    assert x.size(-1) % block_size == 0
+    assert x.is_contiguous(), 'Input tensor must be contiguous'
+    assert x.size(-1) % block_size == 0, f'Last dimension size must be divisible by block_size (block_size={block_size})'
     y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
     s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype=torch.float32)
     grid = lambda meta: (triton.cdiv(x.numel(), meta['BLOCK_SIZE']), )
```
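For intuition about the shapes involved, a plain-PyTorch reference sketch of blockwise activation quantization (an assumption about the behavior; the repo's actual implementation is the Triton kernel above, and 448 is the float8_e4m3fn maximum):

```python
import torch

def act_quant_ref(x: torch.Tensor, block_size: int = 128):
    # Every contiguous group of `block_size` elements in the last dimension shares one fp32 scale.
    assert x.size(-1) % block_size == 0
    blocks = x.view(*x.shape[:-1], -1, block_size)
    s = blocks.abs().amax(dim=-1).float() / 448.0
    y = (blocks / s.unsqueeze(-1)).to(torch.float8_e4m3fn).view_as(x)
    return y, s

x = torch.randn(2, 256)
y, s = act_quant_ref(x)
print(y.shape, y.dtype, s.shape)  # torch.Size([2, 256]) torch.float8_e4m3fn torch.Size([2, 2])
```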
```diff
@@ -96,8 +96,8 @@ def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> t
     Raises:
         AssertionError: If `x` or `s` are not contiguous or if their dimensions are not 2.
     """
-    assert x.is_contiguous() and s.is_contiguous()
-    assert x.dim() == 2 and s.dim() == 2
+    assert x.is_contiguous() and s.is_contiguous(), 'Input tensors must be contiguous'
+    assert x.dim() == 2 and s.dim() == 2, 'Input tensors must have 2 dimensions'
     M, N = x.size()
     y = torch.empty_like(x, dtype=torch.get_default_dtype())
     grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE']), triton.cdiv(N, meta['BLOCK_SIZE']))
```
```diff
@@ -180,8 +180,8 @@ def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Ten
     Returns:
         torch.Tensor: The result of the matrix multiplication.
     """
-    assert a.is_contiguous() and b.is_contiguous()
-    assert a_s.is_contiguous() and b_s.is_contiguous()
+    assert a.is_contiguous() and b.is_contiguous(), 'Input tensors must be contiguous'
+    assert a_s.is_contiguous() and b_s.is_contiguous(), 'Scaling factor tensors must be contiguous'
     K = a.size(-1)
     M = a.numel() // K
     N = b.size(0)
```
inference/model.py

```diff
@@ -89,7 +89,7 @@ class ParallelEmbedding(nn.Module):
         super().__init__()
         self.vocab_size = vocab_size
         self.dim = dim
-        assert vocab_size % world_size == 0
+        assert vocab_size % world_size == 0, f"Vocabulary size must be divisible by world size (world_size={world_size})"
         self.part_vocab_size = (vocab_size // world_size)
         self.vocab_start_idx = rank * self.part_vocab_size
         self.vocab_end_idx = self.vocab_start_idx + self.part_vocab_size
```
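A tiny worked example (toy numbers, not from the repo) of the contiguous vocabulary slice each rank ends up owning under this scheme:

```python
vocab_size, world_size = 16, 4          # toy sizes
part_vocab_size = vocab_size // world_size
for rank in range(world_size):
    start = rank * part_vocab_size
    end = start + part_vocab_size
    print(f"rank {rank}: token ids [{start}, {end})")
# rank 0: token ids [0, 4)
# rank 1: token ids [4, 8)
# rank 2: token ids [8, 12)
# rank 3: token ids [12, 16)
```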
```diff
@@ -124,7 +124,7 @@ def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] =
     quantization-aware computations depending on the input parameters.

     Notes:
-        - If `weight` is quantized (e.g., `element_size() > 1`), a dequantized version
+        - If `weight` is quantized (e.g., `element_size() == 1`), a dequantized version
           is used for computation.
         - If `gemm_impl == "bf16"`, dequantization and a `bf16` GEMM operation are applied.
         - For other cases, the function applies quantization to `x` and uses `fp8_gemm` for computation.
```
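The docstring now matches the actual check: FP8 weights store one byte per element, so `element_size()` distinguishes the quantized path from the bf16 path. A quick illustration in plain PyTorch:

```python
import torch

w_bf16 = torch.zeros(4, 4, dtype=torch.bfloat16)
w_fp8 = torch.zeros(4, 4, dtype=torch.float8_e4m3fn)
print(w_bf16.element_size())  # 2 -> already dequantized, used directly
print(w_fp8.element_size())   # 1 -> quantized weight, goes through dequant or the fp8 GEMM path
```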
```diff
@@ -176,7 +176,7 @@ class ColumnParallelLinear(Linear):
         dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`.
     """
     def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
-        assert out_features % world_size == 0
+        assert out_features % world_size == 0, f"Output features must be divisible by world size (world_size={world_size})"
         self.part_out_features = out_features // world_size
         super().__init__(in_features, self.part_out_features, bias, dtype)

```
```diff
@@ -205,7 +205,7 @@ class RowParallelLinear(Linear):
         dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`.
     """
     def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
-        assert in_features % world_size == 0
+        assert in_features % world_size == 0, f"Input features must be divisible by world size (world_size={world_size})"
         self.part_in_features = in_features // world_size
         super().__init__(self.part_in_features, out_features, bias, dtype)

```
```diff
@@ -566,8 +566,8 @@ class Gate(nn.Module):
             else:
                 group_scores = scores.topk(2, dim=-1)[0].sum(dim=-1)
             indices = group_scores.topk(self.topk_groups, dim=-1)[1]
-            mask = torch.zeros_like(scores[..., 0]).scatter_(1, indices, True)
-            scores = (scores * mask.unsqueeze(-1)).flatten(1)
+            mask = scores.new_ones(x.size(0), self.n_groups, dtype=bool).scatter_(1, indices, False)
+            scores = scores.masked_fill_(mask.unsqueeze(-1), float("-inf")).flatten(1)
         indices = torch.topk(scores, self.topk, dim=-1)[1]
         weights = original_scores.gather(1, indices)
         if self.score_func == "sigmoid":
```
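The routing change swaps a multiplicative 0/1 group mask for an in-place `-inf` fill. The practical difference: with zeroing, experts from a dropped group keep a score of 0 and can still win the final top-k whenever the retained groups' scores are negative (which can happen once the correction bias is added), whereas `-inf` keeps them out unconditionally. A minimal sketch with toy numbers (not repo code):

```python
import torch

# 1 token, 3 groups of 2 experts; every score negative so the difference shows up
scores = torch.tensor([[[-0.2, -0.3], [-0.5, -0.6], [-0.1, -0.4]]])
topk_groups, topk, n_groups = 2, 2, 3
group_scores = scores.topk(2, dim=-1)[0].sum(dim=-1)                    # per-group score
keep = group_scores.topk(topk_groups, dim=-1)[1]                        # best 2 groups (drops group 1)
mask = scores.new_ones(1, n_groups, dtype=torch.bool).scatter_(1, keep, False)

old = (scores * (~mask).unsqueeze(-1)).flatten(1)                       # previous behavior: zero out dropped group
new = scores.masked_fill(mask.unsqueeze(-1), float("-inf")).flatten(1)  # new behavior
print(torch.topk(old, topk, dim=-1)[1])  # tensor([[2, 3]]) -> picks the dropped group's experts at 0.0
print(torch.topk(new, topk, dim=-1)[1])  # tensor([[4, 0]]) -> picks experts from the kept groups
```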
```diff
@@ -633,7 +633,7 @@ class MoE(nn.Module):
         """
         super().__init__()
         self.dim = args.dim
-        assert args.n_routed_experts % world_size == 0
+        assert args.n_routed_experts % world_size == 0, f"Number of experts must be divisible by world size (world_size={world_size})"
         self.n_routed_experts = args.n_routed_experts
         self.n_local_experts = args.n_routed_experts // world_size
         self.n_activated_experts = args.n_activated_experts
```