From ddc501b80e812f05a4237514c592f286554f2a11 Mon Sep 17 00:00:00 2001 From: Dhieu Date: Mon, 27 Jan 2025 14:18:17 +0300 Subject: [PATCH 01/12] Add table of contents to README --- README.md | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/README.md b/README.md index 7ecf87e..ccd62cc 100644 --- a/README.md +++ b/README.md @@ -44,6 +44,17 @@ Paper Link👁️

+## Table of Contents + +1. [Introduction](## 1. Introduction) +2. [Model Summary](## 2. Model Summary) +3. [Model Downloads](## 3. Model Downloads) +4. [Evaluation Results](## 4. Evaluation Results) +5. [Chat Website & API Platform](## 5. Chat Website & API Platform) +6. [How to Run Locally](## 6. How to Run Locally) +7. [License](## 7. License) +8. [Citation](## 8. Citation) +9. [Contact](## 9. Contact) ## 1. Introduction From 2756e130c2430eedc916ca331f5e360b519ed7ab Mon Sep 17 00:00:00 2001 From: Roman Fitzjalen Date: Tue, 28 Jan 2025 13:16:54 +0100 Subject: [PATCH 02/12] clarify assertion error --- inference/convert.py | 6 +++--- inference/generate.py | 6 +++--- inference/kernel.py | 12 ++++++------ inference/model.py | 8 ++++---- 4 files changed, 16 insertions(+), 16 deletions(-) diff --git a/inference/convert.py b/inference/convert.py index c606ce8..6d85ccc 100644 --- a/inference/convert.py +++ b/inference/convert.py @@ -60,7 +60,7 @@ def main(hf_ckpt_path, save_path, n_experts, mp): name = name.replace("weight_scale_inv", "scale") name = name.replace("e_score_correction_bias", "bias") key = name.split(".")[-2] - assert key in mapping + assert key in mapping, f"Key {key} not found in mapping" new_key, dim = mapping[key] name = name.replace(key, new_key) for i in range(mp): @@ -70,7 +70,7 @@ def main(hf_ckpt_path, save_path, n_experts, mp): if idx < i * n_local_experts or idx >= (i + 1) * n_local_experts: continue elif dim is not None: - assert param.size(dim) % mp == 0 + assert param.size(dim) % mp == 0, f"Dimension {dim} must be divisible by {mp}" shard_size = param.size(dim) // mp new_param = param.narrow(dim, i * shard_size, shard_size).contiguous() state_dicts[i][name] = new_param @@ -92,5 +92,5 @@ if __name__ == "__main__": parser.add_argument("--n-experts", type=int, required=True) parser.add_argument("--model-parallel", type=int, required=True) args = parser.parse_args() - assert args.n_experts % args.model_parallel == 0 + assert args.n_experts % args.model_parallel == 0, "Number of experts must be divisible by model parallelism" main(args.hf_ckpt_path, args.save_path, args.n_experts, args.model_parallel) diff --git a/inference/generate.py b/inference/generate.py index fbf3ab8..7e9bffe 100644 --- a/inference/generate.py +++ b/inference/generate.py @@ -49,7 +49,7 @@ def generate( List[List[int]]: A list of lists containing the generated tokens for each sequence. """ prompt_lens = [len(t) for t in prompt_tokens] - assert max(prompt_lens) <= model.max_seq_len + assert max(prompt_lens) <= model.max_seq_len, f"Prompt length exceeds model maximum sequence length (max_seq_len={model.max_seq_len})" total_len = min(model.max_seq_len, max_new_tokens + max(prompt_lens)) tokens = torch.full((len(prompt_tokens), total_len), -1, dtype=torch.long, device="cuda") for i, t in enumerate(prompt_tokens): @@ -145,7 +145,7 @@ def main( else: with open(input_file) as f: prompts = [line.strip() for line in f.readlines()] - assert len(prompts) <= args.max_batch_size + assert len(prompts) <= args.max_batch_size, f"Number of prompts exceeds maximum batch size ({args.max_batch_size})" prompt_tokens = [tokenizer.apply_chat_template([{"role": "user", "content": prompt}], add_generation_prompt=True) for prompt in prompts] completion_tokens = generate(model, prompt_tokens, max_new_tokens, tokenizer.eos_token_id, temperature) completions = tokenizer.batch_decode(completion_tokens, skip_special_tokens=True) @@ -181,5 +181,5 @@ if __name__ == "__main__": parser.add_argument("--max-new-tokens", type=int, default=200) parser.add_argument("--temperature", type=float, default=0.2) args = parser.parse_args() - assert args.input_file or args.interactive + assert args.input_file or args.interactive, "Either input-file or interactive mode must be specified" main(args.ckpt_path, args.config, args.input_file, args.interactive, args.max_new_tokens, args.temperature) diff --git a/inference/kernel.py b/inference/kernel.py index dec8639..ae907ad 100644 --- a/inference/kernel.py +++ b/inference/kernel.py @@ -43,8 +43,8 @@ def act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, tor - The quantized tensor with dtype `torch.float8_e4m3fn`. - A tensor of scaling factors with dtype `torch.float32`. """ - assert x.is_contiguous() - assert x.size(-1) % block_size == 0 + assert x.is_contiguous(), 'Input tensor must be contiguous' + assert x.size(-1) % block_size == 0, f'Last dimension size must be divisible by block_size (block_size={block_size})' y = torch.empty_like(x, dtype=torch.float8_e4m3fn) s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype=torch.float32) grid = lambda meta: (triton.cdiv(x.numel(), meta['BLOCK_SIZE']), ) @@ -96,8 +96,8 @@ def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> t Raises: AssertionError: If `x` or `s` are not contiguous or if their dimensions are not 2. """ - assert x.is_contiguous() and s.is_contiguous() - assert x.dim() == 2 and s.dim() == 2 + assert x.is_contiguous() and s.is_contiguous(), 'Input tensors must be contiguous' + assert x.dim() == 2 and s.dim() == 2, 'Input tensors must have 2 dimensions' M, N = x.size() y = torch.empty_like(x, dtype=torch.get_default_dtype()) grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE']), triton.cdiv(N, meta['BLOCK_SIZE'])) @@ -180,8 +180,8 @@ def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Ten Returns: torch.Tensor: The result of the matrix multiplication. """ - assert a.is_contiguous() and b.is_contiguous() - assert a_s.is_contiguous() and b_s.is_contiguous() + assert a.is_contiguous() and b.is_contiguous(), 'Input tensors must be contiguous' + assert a_s.is_contiguous() and b_s.is_contiguous(), 'Scaling factor tensors must be contiguous' K = a.size(-1) M = a.numel() // K N = b.size(0) diff --git a/inference/model.py b/inference/model.py index 9ea60c9..2afaba9 100644 --- a/inference/model.py +++ b/inference/model.py @@ -96,7 +96,7 @@ class ParallelEmbedding(nn.Module): super().__init__() self.vocab_size = vocab_size self.dim = dim - assert vocab_size % world_size == 0 + assert vocab_size % world_size == 0, f"Vocabulary size must be divisible by world size (world_size={world_size})" self.part_vocab_size = (vocab_size // world_size) self.vocab_start_idx = rank * self.part_vocab_size self.vocab_end_idx = self.vocab_start_idx + self.part_vocab_size @@ -213,7 +213,7 @@ class ColumnParallelLinear(Linear): dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`. """ def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None): - assert out_features % world_size == 0 + assert out_features % world_size == 0, f"Output features must be divisible by world size (world_size={world_size})" self.part_out_features = out_features // world_size super().__init__(in_features, self.part_out_features, bias, dtype) @@ -242,7 +242,7 @@ class RowParallelLinear(Linear): dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`. """ def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None): - assert in_features % world_size == 0 + assert in_features % world_size == 0, f"Input features must be divisible by world size (world_size={world_size})" self.part_in_features = in_features // world_size super().__init__(self.part_in_features, out_features, bias, dtype) @@ -652,7 +652,7 @@ class MoE(nn.Module): """ super().__init__() self.dim = args.dim - assert args.n_routed_experts % world_size == 0 + assert args.n_routed_experts % world_size == 0, f"Number of experts must be divisible by world size (world_size={world_size})" self.n_routed_experts = args.n_routed_experts self.n_local_experts = args.n_routed_experts // world_size self.n_activated_experts = args.n_activated_experts From 6784e1976df5e5ebdff031523384807b992992f4 Mon Sep 17 00:00:00 2001 From: Dhieu Date: Tue, 28 Jan 2025 17:14:35 +0300 Subject: [PATCH 03/12] Fix TOC links to correctly link to headings in Markdown --- README.md | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index ccd62cc..ba4f5e4 100644 --- a/README.md +++ b/README.md @@ -46,15 +46,16 @@ ## Table of Contents -1. [Introduction](## 1. Introduction) -2. [Model Summary](## 2. Model Summary) -3. [Model Downloads](## 3. Model Downloads) -4. [Evaluation Results](## 4. Evaluation Results) -5. [Chat Website & API Platform](## 5. Chat Website & API Platform) -6. [How to Run Locally](## 6. How to Run Locally) -7. [License](## 7. License) -8. [Citation](## 8. Citation) -9. [Contact](## 9. Contact) +1. [Introduction](#1-introduction) +2. [Model Summary](#2-model-summary) +3. [Model Downloads](#3-model-downloads) +4. [Evaluation Results](#4-evaluation-results) +5. [Chat Website & API Platform](#5-chat-website--api-platform) +6. [How to Run Locally](#6-how-to-run-locally) +7. [License](#7-license) +8. [Citation](#8-citation) +9. [Contact](#9-contact) + ## 1. Introduction From 760d22821fa8d019cd63fb4a986dc4c48f4bee49 Mon Sep 17 00:00:00 2001 From: Spenser Black Date: Tue, 28 Jan 2025 18:07:15 -0500 Subject: [PATCH 04/12] Add syntax highlighting to requirements code block --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 7ecf87e..3a5a7a2 100644 --- a/README.md +++ b/README.md @@ -259,7 +259,7 @@ python fp8_cast_bf16.py --input-fp8-hf-path /path/to/fp8_weights --output-bf16-h > Linux with Python 3.10 only. Mac and Windows are not supported. Dependencies: -``` +```pip-requirements torch==2.4.1 triton==3.0.0 transformers==4.46.3 From d5c08b384b4d4096e494d61ba7d329f593116872 Mon Sep 17 00:00:00 2001 From: wangsl <48207171+WSL0809@users.noreply.github.com> Date: Sun, 2 Feb 2025 02:34:59 +0800 Subject: [PATCH 05/12] Update README.md fix(table): correct bold formatting for TriviaQA EM comparison - Remove redundant bolding on LLaMA3.1 405B (82.7) - Retain single bold style for DeepSeek-V3's highest score (82.9) - Aligns with evaluation convention of highlighting only the best performance --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 7ecf87e..3353cbd 100644 --- a/README.md +++ b/README.md @@ -130,7 +130,7 @@ For developers looking to dive deeper, we recommend exploring [README_WEIGHTS.md | | WinoGrande (Acc.) | 5-shot | **86.3** | 82.3 | 85.2 | 84.9 | | | RACE-Middle (Acc.) | 5-shot | 73.1 | 68.1 | **74.2** | 67.1 | | | RACE-High (Acc.) | 5-shot | 52.6 | 50.3 | **56.8** | 51.3 | -| | TriviaQA (EM) | 5-shot | 80.0 | 71.9 | **82.7** | **82.9** | +| | TriviaQA (EM) | 5-shot | 80.0 | 71.9 | 82.7 | **82.9** | | | NaturalQuestions (EM) | 5-shot | 38.6 | 33.2 | **41.5** | 40.0 | | | AGIEval (Acc.) | 0-shot | 57.5 | 75.8 | 60.6 | **79.6** | | Code | HumanEval (Pass@1) | 0-shot | 43.3 | 53.0 | 54.9 | **65.2** | From 97b35f1fcadf435b41a835d5f4e86c8d9dc4497b Mon Sep 17 00:00:00 2001 From: luislopez-developer Date: Mon, 3 Feb 2025 15:02:04 -0500 Subject: [PATCH 06/12] docs: remove redundant asterisks in note --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 7ecf87e..0b452f5 100644 --- a/README.md +++ b/README.md @@ -99,7 +99,7 @@ Throughout the entire training process, we did not experience any irrecoverable > [!NOTE] -> The total size of DeepSeek-V3 models on Hugging Face is 685B, which includes 671B of the Main Model weights and 14B of the Multi-Token Prediction (MTP) Module weights.** +> The total size of DeepSeek-V3 models on Hugging Face is 685B, which includes 671B of the Main Model weights and 14B of the Multi-Token Prediction (MTP) Module weights. To ensure optimal performance and flexibility, we have partnered with open-source communities and hardware vendors to provide multiple ways to run the model locally. For step-by-step guidance, check out Section 6: [How_to Run_Locally](#6-how-to-run-locally). @@ -249,7 +249,7 @@ python fp8_cast_bf16.py --input-fp8-hf-path /path/to/fp8_weights --output-bf16-h ``` > [!NOTE] -> Hugging Face's Transformers has not been directly supported yet.** +> Hugging Face's Transformers has not been directly supported yet. ### 6.1 Inference with DeepSeek-Infer Demo (example only) From 5ee97a83f0457d0d805b862aeb387358e1801e6d Mon Sep 17 00:00:00 2001 From: Xingkai Yu <38156925+GeeeekExplorer@users.noreply.github.com> Date: Fri, 7 Feb 2025 16:42:55 +0800 Subject: [PATCH 07/12] fix comment --- inference/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/inference/model.py b/inference/model.py index 2ec1b20..40bbf4d 100644 --- a/inference/model.py +++ b/inference/model.py @@ -143,7 +143,7 @@ def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = quantization-aware computations depending on the input parameters. Notes: - - If `weight` is quantized (e.g., `element_size() > 1`), a dequantized version + - If `weight` is quantized (e.g., `element_size() == 1`), a dequantized version is used for computation. - If `gemm_impl == "bf16"`, dequantization and a `bf16` GEMM operation are applied. - For other cases, the function applies quantization to `x` and uses `fp8_gemm` for computation. From 76d8d39560032fc85c67950e29550868e128b6c5 Mon Sep 17 00:00:00 2001 From: Konano Date: Sat, 8 Feb 2025 15:12:09 +0800 Subject: [PATCH 08/12] chore: add stale issue management configuration --- .github/workflows/stale.yml | 30 ++++++++++++++++++++++++++++++ 1 file changed, 30 insertions(+) create mode 100644 .github/workflows/stale.yml diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml new file mode 100644 index 0000000..22706e8 --- /dev/null +++ b/.github/workflows/stale.yml @@ -0,0 +1,30 @@ +name: "Mark and close stale issues" +on: + workflow_dispatch: + schedule: + - cron: "0 0 * * *" + +jobs: + stale: + if: ${{ github.repository == 'deepseek-ai/DeepSeek-V3' }} + runs-on: ubuntu-latest + steps: + - name: "Mark and close stale issues" + uses: actions/stale@v9 + with: + days-before-issue-stale: 30 + days-before-issue-close: 14 + stale-issue-label: "stale" + close-issue-label: "closed-as-stale" + exempt-issue-labels: | + pinned + security + stale-issue-message: > + This issue has been automatically marked as stale because it has not had + recent activity. It will be closed if no further activity occurs. If you + believe this issue is still relevant, please leave a comment to keep it open. + Thank you for your contributions! + close-issue-message: false + days-before-pr-stale: -1 + days-before-pr-close: -1 + repo-token: ${{ secrets.GITHUB_TOKEN }} From e15f67af1ce54a8545b0a98cc43ba8e9faba647f Mon Sep 17 00:00:00 2001 From: Konano Date: Sat, 8 Feb 2025 18:28:40 +0800 Subject: [PATCH 09/12] chore: update README.md to improve layout and image attributes --- README.md | 64 ++++++++++++++++++++++++++++++++++--------------------- 1 file changed, 40 insertions(+), 24 deletions(-) diff --git a/README.md b/README.md index 318a40c..632b628 100644 --- a/README.md +++ b/README.md @@ -7,36 +7,52 @@
- - Homepage - - - Chat - - - Hugging Face - + Homepage + Chat + Hugging Face
- - Discord - - - Wechat - - - Twitter Follow - + Discord + Wechat + Twitter Follow
- - Code License - - - Model License - + Code License + Model License
From 0866cab5f9b7e26c8ad57077c3fcd16b50d855e1 Mon Sep 17 00:00:00 2001 From: Konano Date: Fri, 14 Feb 2025 12:02:10 +0800 Subject: [PATCH 10/12] chore: update README.md to improve layout and image attributes --- README.md | 73 ++++++++++++++++--------------------------------------- 1 file changed, 21 insertions(+), 52 deletions(-) diff --git a/README.md b/README.md index 632b628..6746781 100644 --- a/README.md +++ b/README.md @@ -6,59 +6,28 @@ DeepSeek-V3
-
- Homepage - Chat - Hugging Face -
- -
- Discord - Wechat - Twitter Follow -
- -
- Code License - Model License -
- - -

+

+ Homepage + Chat + Hugging Face +
+ Discord + Wechat + Twitter Follow +
+ Code License + Model License +
Paper Link👁️ -

+
## Table of Contents From f07bccc49e02a2ea9b80214a65d181958a8d554e Mon Sep 17 00:00:00 2001 From: Konano Date: Fri, 14 Feb 2025 12:12:16 +0800 Subject: [PATCH 11/12] fix: resolve center alignment issue in preview --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 6746781..9ba2346 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,7 @@ DeepSeek-V3
-
+
Homepage Chat Date: Fri, 14 Feb 2025 20:26:45 +0800 Subject: [PATCH 12/12] fix scores mask --- inference/model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/inference/model.py b/inference/model.py index 40bbf4d..8f1ab81 100644 --- a/inference/model.py +++ b/inference/model.py @@ -585,8 +585,8 @@ class Gate(nn.Module): else: group_scores = scores.topk(2, dim=-1)[0].sum(dim=-1) indices = group_scores.topk(self.topk_groups, dim=-1)[1] - mask = torch.zeros_like(scores[..., 0]).scatter_(1, indices, True) - scores = (scores * mask.unsqueeze(-1)).flatten(1) + mask = scores.new_ones(x.size(0), self.n_groups, dtype=bool).scatter_(1, indices, False) + scores = scores.masked_fill_(mask.unsqueeze(-1), float("-inf")).flatten(1) indices = torch.topk(scores, self.topk, dim=-1)[1] weights = original_scores.gather(1, indices) if self.score_func == "sigmoid":