mirror of https://github.com/deepseek-ai/DeepSeek-V3.git
synced 2025-02-22 21:58:58 -05:00

Merge branch 'main' into refactor/codebase

This commit is contained in: commit 6bb22e0c15

README.md: 20 changed lines
@@ -44,6 +44,18 @@
   <a href="DeepSeek_V3.pdf"><b>Paper Link</b>👁️</a>
 </p>
 
+## Table of Contents
+
+1. [Introduction](#1-introduction)
+2. [Model Summary](#2-model-summary)
+3. [Model Downloads](#3-model-downloads)
+4. [Evaluation Results](#4-evaluation-results)
+5. [Chat Website & API Platform](#5-chat-website--api-platform)
+6. [How to Run Locally](#6-how-to-run-locally)
+7. [License](#7-license)
+8. [Citation](#8-citation)
+9. [Contact](#9-contact)
+
 ## 1. Introduction
 
 
@@ -99,7 +111,7 @@ Throughout the entire training process, we did not experience any irrecoverable
 </div>
 
 > [!NOTE]
-> The total size of DeepSeek-V3 models on Hugging Face is 685B, which includes 671B of the Main Model weights and 14B of the Multi-Token Prediction (MTP) Module weights.**
+> The total size of DeepSeek-V3 models on Hugging Face is 685B, which includes 671B of the Main Model weights and 14B of the Multi-Token Prediction (MTP) Module weights.
 
 To ensure optimal performance and flexibility, we have partnered with open-source communities and hardware vendors to provide multiple ways to run the model locally. For step-by-step guidance, check out Section 6: [How to Run Locally](#6-how-to-run-locally).
 
@@ -130,7 +142,7 @@ For developers looking to dive deeper, we recommend exploring [README_WEIGHTS.md
 | | WinoGrande (Acc.) | 5-shot | **86.3** | 82.3 | 85.2 | 84.9 |
 | | RACE-Middle (Acc.) | 5-shot | 73.1 | 68.1 | **74.2** | 67.1 |
 | | RACE-High (Acc.) | 5-shot | 52.6 | 50.3 | **56.8** | 51.3 |
-| | TriviaQA (EM) | 5-shot | 80.0 | 71.9 | **82.7** | **82.9** |
+| | TriviaQA (EM) | 5-shot | 80.0 | 71.9 | 82.7 | **82.9** |
 | | NaturalQuestions (EM) | 5-shot | 38.6 | 33.2 | **41.5** | 40.0 |
 | | AGIEval (Acc.) | 0-shot | 57.5 | 75.8 | 60.6 | **79.6** |
 | Code | HumanEval (Pass@1) | 0-shot | 43.3 | 53.0 | 54.9 | **65.2** |
@@ -249,7 +261,7 @@ python fp8_cast_bf16.py --input-fp8-hf-path /path/to/fp8_weights --output-bf16-h
 ```
 
 > [!NOTE]
-> Hugging Face's Transformers has not been directly supported yet.**
+> Hugging Face's Transformers has not been directly supported yet.
 
 ### 6.1 Inference with DeepSeek-Infer Demo (example only)
 
@@ -259,7 +271,7 @@
 > Linux with Python 3.10 only. Mac and Windows are not supported.
 
 Dependencies:
-```
+```pip-requirements
 torch==2.4.1
 triton==3.0.0
 transformers==4.46.3
@@ -59,7 +59,8 @@ class TextGenerator:
             List[List[int]]: Generated tokens for each sequence.
         """
         prompt_lens = [len(t) for t in prompt_tokens]
-        assert max(prompt_lens) <= self.model.max_seq_len
+        if max(prompt_lens) > self.model.max_seq_len:
+            raise ValueError(f"Prompt length exceeds model maximum sequence length (max_seq_len={self.model.max_seq_len})")
 
         total_len = min(self.model.max_seq_len, config.max_new_tokens + max(prompt_lens))
         tokens = self._initialize_tokens(prompt_tokens, total_len)
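The refactor replaces a bare `assert` with an explicit `ValueError`. One practical reason: assertions are stripped entirely when Python runs with optimizations enabled (`python -O`), so validation that must always fire belongs on a real exception path. Below is a minimal, self-contained sketch of the pattern; the class is a hypothetical stand-in for the diff's `TextGenerator`, reduced to the length guard only.

```python
from typing import List


class TextGeneratorSketch:
    """Hypothetical stand-in for TextGenerator, keeping only the length guard."""

    def __init__(self, max_seq_len: int):
        self.max_seq_len = max_seq_len

    def check_prompts(self, prompt_tokens: List[List[int]]) -> List[int]:
        prompt_lens = [len(t) for t in prompt_tokens]
        # Unlike `assert`, this check survives `python -O`.
        if max(prompt_lens) > self.max_seq_len:
            raise ValueError(
                f"Prompt length exceeds model maximum sequence length "
                f"(max_seq_len={self.max_seq_len})"
            )
        return prompt_lens


# A 5-token prompt against a 4-token limit raises ValueError.
gen = TextGeneratorSketch(max_seq_len=4)
try:
    gen.check_prompts([[1, 2, 3, 4, 5]])
except ValueError as err:
    print(err)
```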
@@ -193,7 +194,9 @@ class ChatSession:
     def run_batch(self, input_file: str):
         with open(input_file) as f:
             prompts = [line.strip() for line in f.readlines()]
 
-        assert len(prompts) <= self.generator.model.args.max_batch_size
+        if len(prompts) > self.generator.model.args.max_batch_size:
+            raise ValueError(f"Number of prompts exceeds maximum batch size ({self.generator.model.args.max_batch_size})")
+
         completions = self._process_batch(prompts)
         for prompt, completion in zip(prompts, completions):
@@ -302,7 +305,9 @@ if __name__ == "__main__":
     parser.add_argument("--temperature", type=float, default=0.2)
     args = parser.parse_args()
 
-    assert args.input_file or args.interactive
+    if not args.input_file and not args.interactive:
+        raise ValueError("Either input-file or interactive mode must be specified")
+
     main(
         args.ckpt_path,
         args.config,
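The same conversion applies at the CLI boundary, where `argparse` offers `parser.error()` as an alternative: it prints the usage string and exits with status 2 instead of showing a traceback. A sketch under assumptions (only `--temperature` and its default come from the diff; the other flags' types and defaults are guesses):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--input-file", type=str, default="")      # assumed signature
parser.add_argument("--interactive", action="store_true")      # assumed signature
parser.add_argument("--temperature", type=float, default=0.2)  # as in the diff
args = parser.parse_args()

if not args.input_file and not args.interactive:
    # parser.error() reports through the usage message and exits with code 2.
    parser.error("Either input-file or interactive mode must be specified")
```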
@@ -96,7 +96,7 @@ class ParallelEmbedding(nn.Module):
         super().__init__()
         self.vocab_size = vocab_size
         self.dim = dim
-        assert vocab_size % world_size == 0
+        assert vocab_size % world_size == 0, f"Vocabulary size must be divisible by world size (world_size={world_size})"
         self.part_vocab_size = (vocab_size // world_size)
         self.vocab_start_idx = rank * self.part_vocab_size
         self.vocab_end_idx = self.vocab_start_idx + self.part_vocab_size
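The divisibility assert in `ParallelEmbedding` exists because every rank owns an equal, contiguous slice of the vocabulary. A quick sketch of the slice arithmetic with assumed values (no distributed runtime required):

```python
vocab_size, world_size = 102400, 8  # example values, not taken from the diff

assert vocab_size % world_size == 0, \
    f"Vocabulary size must be divisible by world size (world_size={world_size})"

part_vocab_size = vocab_size // world_size  # 12800 embedding rows per rank
for rank in range(world_size):
    vocab_start_idx = rank * part_vocab_size
    vocab_end_idx = vocab_start_idx + part_vocab_size
    # Each rank embeds only token ids in [vocab_start_idx, vocab_end_idx);
    # ids outside the slice are typically masked locally and recombined
    # across ranks with an all-reduce.
    print(f"rank {rank}: ids {vocab_start_idx}..{vocab_end_idx - 1}")
```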
@@ -143,7 +143,7 @@ def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] =
         quantization-aware computations depending on the input parameters.
 
     Notes:
-        - If `weight` is quantized (e.g., `element_size() > 1`), a dequantized version
+        - If `weight` is quantized (e.g., `element_size() == 1`), a dequantized version
           is used for computation.
         - If `gemm_impl == "bf16"`, dequantization and a `bf16` GEMM operation are applied.
         - For other cases, the function applies quantization to `x` and uses `fp8_gemm` for computation.
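The corrected note reflects that a quantized (fp8) weight reports `element_size() == 1` byte per element, while bf16 and fp32 weights report 2 and 4. A hedged sketch of the dispatch the notes describe; `weight_dequant` here is a simplified stand-in for the repo's blockwise dequant kernel, and the fp8-GEMM branch is omitted:

```python
from typing import Optional

import torch
import torch.nn.functional as F

gemm_impl = "bf16"  # assumed module-level switch, per the docstring notes


def weight_dequant(weight: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Simplified stand-in for the repo's dequant kernel: upcast and rescale.
    return weight.to(torch.bfloat16) * scale


def linear(x: torch.Tensor, weight: torch.Tensor, scale: torch.Tensor,
           bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    if weight.element_size() > 1:
        # bf16/fp32 weight: plain GEMM, no dequantization needed.
        return F.linear(x, weight, bias)
    if gemm_impl == "bf16":
        # fp8 weight (element_size() == 1): dequantize, then a bf16 GEMM.
        return F.linear(x, weight_dequant(weight, scale), bias)
    raise NotImplementedError("fp8 GEMM path omitted in this sketch")


x = torch.randn(4, 8, dtype=torch.bfloat16)
w8 = torch.randn(16, 8).to(torch.float8_e4m3fn)  # element_size() == 1
print(linear(x, w8, scale=torch.ones(16, 1, dtype=torch.bfloat16)).shape)
```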
@@ -185,7 +185,7 @@ class Linear(nn.Module):
         else:
             self.register_parameter("scale", None)
         if bias:
-            self.bias = nn.Parameter(torch.empty(self.part_out_features))
+            self.bias = nn.Parameter(torch.empty(out_features))
         else:
             self.register_parameter("bias", None)
 
@@ -213,7 +213,7 @@ class ColumnParallelLinear(Linear):
         dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`.
     """
     def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
-        assert out_features % world_size == 0
+        assert out_features % world_size == 0, f"Output features must be divisible by world size (world_size={world_size})"
         self.part_out_features = out_features // world_size
         super().__init__(in_features, self.part_out_features, bias, dtype)
 
@@ -242,7 +242,7 @@ class RowParallelLinear(Linear):
         dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`.
     """
     def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None):
-        assert in_features % world_size == 0
+        assert in_features % world_size == 0, f"Input features must be divisible by world size (world_size={world_size})"
         self.part_in_features = in_features // world_size
         super().__init__(self.part_in_features, out_features, bias, dtype)
 
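Taken together, the two asserts encode the partitioning scheme: `ColumnParallelLinear` shards the output dimension and `RowParallelLinear` shards the input dimension, so each targets the matching divisibility constraint. A shape-only sketch with assumed sizes:

```python
in_features, out_features, world_size = 1024, 4096, 8  # example sizes

# Column parallelism: each rank holds a (out/ws, in) shard and produces a
# disjoint slice of the output features (slices are concatenated).
assert out_features % world_size == 0
col_shard = (out_features // world_size, in_features)  # (512, 1024)

# Row parallelism: each rank holds a (out, in/ws) shard, consumes a disjoint
# slice of the input features, and partial outputs are summed (all-reduce).
assert in_features % world_size == 0
row_shard = (out_features, in_features // world_size)  # (4096, 128)

print(col_shard, row_shard)
```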
@@ -652,7 +652,7 @@ class MoE(nn.Module):
         """
         super().__init__()
         self.dim = args.dim
-        assert args.n_routed_experts % world_size == 0
+        assert args.n_routed_experts % world_size == 0, f"Number of experts must be divisible by world size (world_size={world_size})"
         self.n_routed_experts = args.n_routed_experts
         self.n_local_experts = args.n_routed_experts // world_size
         self.n_activated_experts = args.n_activated_experts
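The `MoE` assert follows the same rule: routed experts are divided evenly, so each rank instantiates only `n_routed_experts // world_size` expert modules. A small sketch of the ownership arithmetic (the 256/8 split matches DeepSeek-V3's published configuration; the rank value is arbitrary):

```python
n_routed_experts, world_size, rank = 256, 8, 3  # 256 routed experts, as in V3

assert n_routed_experts % world_size == 0, \
    f"Number of experts must be divisible by world size (world_size={world_size})"

n_local_experts = n_routed_experts // world_size       # 32 experts per rank
experts_start_idx = rank * n_local_experts             # first local expert id
experts_end_idx = experts_start_idx + n_local_experts  # one past the last

# Tokens routed to expert ids outside [start, end) are processed on the
# ranks that own those experts.
print(n_local_experts, experts_start_idx, experts_end_idx)
```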