diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml new file mode 100644 index 0000000..22706e8 --- /dev/null +++ b/.github/workflows/stale.yml @@ -0,0 +1,30 @@ +name: "Mark and close stale issues" +on: + workflow_dispatch: + schedule: + - cron: "0 0 * * *" + +jobs: + stale: + if: ${{ github.repository == 'deepseek-ai/DeepSeek-V3' }} + runs-on: ubuntu-latest + steps: + - name: "Mark and close stale issues" + uses: actions/stale@v9 + with: + days-before-issue-stale: 30 + days-before-issue-close: 14 + stale-issue-label: "stale" + close-issue-label: "closed-as-stale" + exempt-issue-labels: >- + pinned, + security + stale-issue-message: > + This issue has been automatically marked as stale because it has not had + recent activity. It will be closed if no further activity occurs. If you + believe this issue is still relevant, please leave a comment to keep it open. + Thank you for your contributions! + close-issue-message: "" + days-before-pr-stale: -1 + days-before-pr-close: -1 + repo-token: ${{ secrets.GITHUB_TOKEN }} diff --git a/CITATION.cff b/CITATION.cff deleted file mode 100644 index c5fbc0e..0000000 --- a/CITATION.cff +++ /dev/null @@ -1,215 +0,0 @@ -cff-version: 1.2.0 -message: "If you use this work, please cite it using the following metadata." -title: "DeepSeek-V3 Technical Report" -authors: - - name: "DeepSeek-AI" - - name: "Aixin Liu" - - name: "Bei Feng" - - name: "Bing Xue" - - name: "Bingxuan Wang" - - name: "Bochao Wu" - - name: "Chengda Lu" - - name: "Chenggang Zhao" - - name: "Chengqi Deng" - - name: "Chenyu Zhang" - - name: "Chong Ruan" - - name: "Damai Dai" - - name: "Daya Guo" - - name: "Dejian Yang" - - name: "Deli Chen" - - name: "Dongjie Ji" - - name: "Erhang Li" - - name: "Fangyun Lin" - - name: "Fucong Dai" - - name: "Fuli Luo" - - name: "Guangbo Hao" - - name: "Guanting Chen" - - name: "Guowei Li" - - name: "H. 
Zhang" - - name: "Han Bao" - - name: "Hanwei Xu" - - name: "Haocheng Wang" - - name: "Haowei Zhang" - - name: "Honghui Ding" - - name: "Huajian Xin" - - name: "Huazuo Gao" - - name: "Hui Li" - - name: "Hui Qu" - - name: "J. L. Cai" - - name: "Jian Liang" - - name: "Jianzhong Guo" - - name: "Jiaqi Ni" - - name: "Jiashi Li" - - name: "Jiawei Wang" - - name: "Jin Chen" - - name: "Jingchang Chen" - - name: "Jingyang Yuan" - - name: "Junjie Qiu" - - name: "Junlong Li" - - name: "Junxiao Song" - - name: "Kai Dong" - - name: "Kai Hu" - - name: "Kaige Gao" - - name: "Kang Guan" - - name: "Kexin Huang" - - name: "Kuai Yu" - - name: "Lean Wang" - - name: "Lecong Zhang" - - name: "Lei Xu" - - name: "Leyi Xia" - - name: "Liang Zhao" - - name: "Litong Wang" - - name: "Liyue Zhang" - - name: "Meng Li" - - name: "Miaojun Wang" - - name: "Mingchuan Zhang" - - name: "Minghua Zhang" - - name: "Minghui Tang" - - name: "Mingming Li" - - name: "Ning Tian" - - name: "Panpan Huang" - - name: "Peiyi Wang" - - name: "Peng Zhang" - - name: "Qiancheng Wang" - - name: "Qihao Zhu" - - name: "Qinyu Chen" - - name: "Qiushi Du" - - name: "R. J. Chen" - - name: "R. L. Jin" - - name: "Ruiqi Ge" - - name: "Ruisong Zhang" - - name: "Ruizhe Pan" - - name: "Runji Wang" - - name: "Runxin Xu" - - name: "Ruoyu Zhang" - - name: "Ruyi Chen" - - name: "S. S. Li" - - name: "Shanghao Lu" - - name: "Shangyan Zhou" - - name: "Shanhuang Chen" - - name: "Shaoqing Wu" - - name: "Shengfeng Ye" - - name: "Shirong Ma" - - name: "Shiyu Wang" - - name: "Shuang Zhou" - - name: "Shuiping Yu" - - name: "Shunfeng Zhou" - - name: "Shuting Pan" - - name: "T. Wang" - - name: "Tao Yun" - - name: "Tian Pei" - - name: "Tianyu Sun" - - name: "W. L. Xiao" - - name: "Wangding Zeng" - - name: "Wanjia Zhao" - - name: "Wei An" - - name: "Wen Liu" - - name: "Wenfeng Liang" - - name: "Wenjun Gao" - - name: "Wenqin Yu" - - name: "Wentao Zhang" - - name: "X. Q. 
Li" - - name: "Xiangyue Jin" - - name: "Xianzu Wang" - - name: "Xiao Bi" - - name: "Xiaodong Liu" - - name: "Xiaohan Wang" - - name: "Xiaojin Shen" - - name: "Xiaokang Chen" - - name: "Xiaokang Zhang" - - name: "Xiaosha Chen" - - name: "Xiaotao Nie" - - name: "Xiaowen Sun" - - name: "Xiaoxiang Wang" - - name: "Xin Cheng" - - name: "Xin Liu" - - name: "Xin Xie" - - name: "Xingchao Liu" - - name: "Xingkai Yu" - - name: "Xinnan Song" - - name: "Xinxia Shan" - - name: "Xinyi Zhou" - - name: "Xinyu Yang" - - name: "Xinyuan Li" - - name: "Xuecheng Su" - - name: "Xuheng Lin" - - name: "Y. K. Li" - - name: "Y. Q. Wang" - - name: "Y. X. Wei" - - name: "Y. X. Zhu" - - name: "Yang Zhang" - - name: "Yanhong Xu" - - name: "Yanping Huang" - - name: "Yao Li" - - name: "Yao Zhao" - - name: "Yaofeng Sun" - - name: "Yaohui Li" - - name: "Yaohui Wang" - - name: "Yi Yu" - - name: "Yi Zheng" - - name: "Yichao Zhang" - - name: "Yifan Shi" - - name: "Yiliang Xiong" - - name: "Ying He" - - name: "Ying Tang" - - name: "Yishi Piao" - - name: "Yisong Wang" - - name: "Yixuan Tan" - - name: "Yiyang Ma" - - name: "Yiyuan Liu" - - name: "Yongqiang Guo" - - name: "Yu Wu" - - name: "Yuan Ou" - - name: "Yuchen Zhu" - - name: "Yuduan Wang" - - name: "Yue Gong" - - name: "Yuheng Zou" - - name: "Yujia He" - - name: "Yukun Zha" - - name: "Yunfan Xiong" - - name: "Yunxian Ma" - - name: "Yuting Yan" - - name: "Yuxiang Luo" - - name: "Yuxiang You" - - name: "Yuxuan Liu" - - name: "Yuyang Zhou" - - name: "Z. F. Wu" - - name: "Z. Z. 
Ren" - - name: "Zehui Ren" - - name: "Zhangli Sha" - - name: "Zhe Fu" - - name: "Zhean Xu" - - name: "Zhen Huang" - - name: "Zhen Zhang" - - name: "Zhenda Xie" - - name: "Zhengyan Zhang" - - name: "Zhewen Hao" - - name: "Zhibin Gou" - - name: "Zhicheng Ma" - - name: "Zhigang Yan" - - name: "Zhihong Shao" - - name: "Zhipeng Xu" - - name: "Zhiyu Wu" - - name: "Zhongyu Zhang" - - name: "Zhuoshu Li" - - name: "Zihui Gu" - - name: "Zijia Zhu" - - name: "Zijun Liu" - - name: "Zilin Li" - - name: "Ziwei Xie" - - name: "Ziyang Song" - - name: "Ziyi Gao" - - name: "Zizheng Pan" -year: 2024 -identifiers: - - type: doi - value: 10.48550/arXiv.2412.19437 - - type: arXiv - value: 2412.19437 -url: "https://arxiv.org/abs/2412.19437" -categories: - - "cs.CL" -repository-code: "https://github.com/deepseek-ai/DeepSeek-V3" -license: "MIT" -abstract: > - We present DeepSeek-V3, a strong Mixture-of-Experts (MoE) language model with 671B total parameters with 37B activated for each token. To achieve efficient inference and cost-effective training, DeepSeek-V3 adopts Multi-head Latent Attention (MLA) and DeepSeekMoE architectures, which were thoroughly validated in DeepSeek-V2. Furthermore, DeepSeek-V3 pioneers an auxiliary-loss-free strategy for load balancing and sets a multi-token prediction training objective for stronger performance. We pre-train DeepSeek-V3 on 14.8 trillion diverse and high-quality tokens, followed by Supervised Fine-Tuning and Reinforcement Learning stages to fully harness its capabilities. Comprehensive evaluations reveal that DeepSeek-V3 outperforms other open-source models and achieves performance comparable to leading closed-source models. Despite its excellent performance, DeepSeek-V3 requires only 2.788M H800 GPU hours for its full training. In addition, its training process is remarkably stable. Throughout the entire training process, we did not experience any irrecoverable loss spikes or perform any rollbacks. 
\ No newline at end of file diff --git a/README.md b/README.md index 7ecf87e..b1fdbef 100644 --- a/README.md +++ b/README.md @@ -7,42 +7,39 @@
- - Homepage - - - Chat - - - Hugging Face - -
- -
- - Discord - - - Wechat - - - Twitter Follow - -
- -
- - Code License - - - Model License - -
- - -

+ Homepage + Chat + Hugging Face +
+ Discord + Wechat + Twitter Follow +
+ Code License + Model License +
Paper Link👁️ -

+ + +## Table of Contents + +1. [Introduction](#1-introduction) +2. [Model Summary](#2-model-summary) +3. [Model Downloads](#3-model-downloads) +4. [Evaluation Results](#4-evaluation-results) +5. [Chat Website & API Platform](#5-chat-website--api-platform) +6. [How to Run Locally](#6-how-to-run-locally) +7. [License](#7-license) +8. [Citation](#8-citation) +9. [Contact](#9-contact) ## 1. Introduction @@ -99,7 +96,7 @@ Throughout the entire training process, we did not experience any irrecoverable > [!NOTE] -> The total size of DeepSeek-V3 models on Hugging Face is 685B, which includes 671B of the Main Model weights and 14B of the Multi-Token Prediction (MTP) Module weights.** +> The total size of DeepSeek-V3 models on Hugging Face is 685B, which includes 671B of the Main Model weights and 14B of the Multi-Token Prediction (MTP) Module weights. To ensure optimal performance and flexibility, we have partnered with open-source communities and hardware vendors to provide multiple ways to run the model locally. For step-by-step guidance, check out Section 6: [How_to Run_Locally](#6-how-to-run-locally). @@ -130,7 +127,7 @@ For developers looking to dive deeper, we recommend exploring [README_WEIGHTS.md | | WinoGrande (Acc.) | 5-shot | **86.3** | 82.3 | 85.2 | 84.9 | | | RACE-Middle (Acc.) | 5-shot | 73.1 | 68.1 | **74.2** | 67.1 | | | RACE-High (Acc.) | 5-shot | 52.6 | 50.3 | **56.8** | 51.3 | -| | TriviaQA (EM) | 5-shot | 80.0 | 71.9 | **82.7** | **82.9** | +| | TriviaQA (EM) | 5-shot | 80.0 | 71.9 | 82.7 | **82.9** | | | NaturalQuestions (EM) | 5-shot | 38.6 | 33.2 | **41.5** | 40.0 | | | AGIEval (Acc.) | 0-shot | 57.5 | 75.8 | 60.6 | **79.6** | | Code | HumanEval (Pass@1) | 0-shot | 43.3 | 53.0 | 54.9 | **65.2** | @@ -249,7 +246,7 @@ python fp8_cast_bf16.py --input-fp8-hf-path /path/to/fp8_weights --output-bf16-h ``` > [!NOTE] -> Hugging Face's Transformers has not been directly supported yet.** +> Hugging Face's Transformers has not been directly supported yet. 
### 6.1 Inference with DeepSeek-Infer Demo (example only) @@ -259,7 +256,7 @@ python fp8_cast_bf16.py --input-fp8-hf-path /path/to/fp8_weights --output-bf16-h > Linux with Python 3.10 only. Mac and Windows are not supported. Dependencies: -``` +```pip-requirements torch==2.4.1 triton==3.0.0 transformers==4.46.3 @@ -346,7 +343,7 @@ This code repository is licensed under [the MIT License](LICENSE-CODE). The use ``` @misc{deepseekai2024deepseekv3technicalreport, title={DeepSeek-V3 Technical Report}, - author={DeepSeek-AI and Aixin Liu and Bei Feng and Bing Xue and Bingxuan Wang and Bochao Wu and Chengda Lu and Chenggang Zhao and Chengqi Deng and Chenyu Zhang and Chong Ruan and Damai Dai and Daya Guo and Dejian Yang and Deli Chen and Dongjie Ji and Erhang Li and Fangyun Lin and Fucong Dai and Fuli Luo and Guangbo Hao and Guanting Chen and Guowei Li and H. Zhang and Han Bao and Hanwei Xu and Haocheng Wang and Haowei Zhang and Honghui Ding and Huajian Xin and Huazuo Gao and Hui Li and Hui Qu and J. L. Cai and Jian Liang and Jianzhong Guo and Jiaqi Ni and Jiashi Li and Jiawei Wang and Jin Chen and Jingchang Chen and Jingyang Yuan and Junjie Qiu and Junlong Li and Junxiao Song and Kai Dong and Kai Hu and Kaige Gao and Kang Guan and Kexin Huang and Kuai Yu and Lean Wang and Lecong Zhang and Lei Xu and Leyi Xia and Liang Zhao and Litong Wang and Liyue Zhang and Meng Li and Miaojun Wang and Mingchuan Zhang and Minghua Zhang and Minghui Tang and Mingming Li and Ning Tian and Panpan Huang and Peiyi Wang and Peng Zhang and Qiancheng Wang and Qihao Zhu and Qinyu Chen and Qiushi Du and R. J. Chen and R. L. Jin and Ruiqi Ge and Ruisong Zhang and Ruizhe Pan and Runji Wang and Runxin Xu and Ruoyu Zhang and Ruyi Chen and S. S. Li and Shanghao Lu and Shangyan Zhou and Shanhuang Chen and Shaoqing Wu and Shengfeng Ye and Shengfeng Ye and Shirong Ma and Shiyu Wang and Shuang Zhou and Shuiping Yu and Shunfeng Zhou and Shuting Pan and T. Wang and Tao Yun and Tian Pei and Tianyu Sun and W. 
L. Xiao and Wangding Zeng and Wanjia Zhao and Wei An and Wen Liu and Wenfeng Liang and Wenjun Gao and Wenqin Yu and Wentao Zhang and X. Q. Li and Xiangyue Jin and Xianzu Wang and Xiao Bi and Xiaodong Liu and Xiaohan Wang and Xiaojin Shen and Xiaokang Chen and Xiaokang Zhang and Xiaosha Chen and Xiaotao Nie and Xiaowen Sun and Xiaoxiang Wang and Xin Cheng and Xin Liu and Xin Xie and Xingchao Liu and Xingkai Yu and Xinnan Song and Xinxia Shan and Xinyi Zhou and Xinyu Yang and Xinyuan Li and Xuecheng Su and Xuheng Lin and Y. K. Li and Y. Q. Wang and Y. X. Wei and Y. X. Zhu and Yang Zhang and Yanhong Xu and Yanhong Xu and Yanping Huang and Yao Li and Yao Zhao and Yaofeng Sun and Yaohui Li and Yaohui Wang and Yi Yu and Yi Zheng and Yichao Zhang and Yifan Shi and Yiliang Xiong and Ying He and Ying Tang and Yishi Piao and Yisong Wang and Yixuan Tan and Yiyang Ma and Yiyuan Liu and Yongqiang Guo and Yu Wu and Yuan Ou and Yuchen Zhu and Yuduan Wang and Yue Gong and Yuheng Zou and Yujia He and Yukun Zha and Yunfan Xiong and Yunxian Ma and Yuting Yan and Yuxiang Luo and Yuxiang You and Yuxuan Liu and Yuyang Zhou and Z. F. Wu and Z. Z. 
Ren and Zehui Ren and Zhangli Sha and Zhe Fu and Zhean Xu and Zhen Huang and Zhen Zhang and Zhenda Xie and Zhengyan Zhang and Zhewen Hao and Zhibin Gou and Zhicheng Ma and Zhigang Yan and Zhihong Shao and Zhipeng Xu and Zhiyu Wu and Zhongyu Zhang and Zhuoshu Li and Zihui Gu and Zijia Zhu and Zijun Liu and Zilin Li and Ziwei Xie and Ziyang Song and Ziyi Gao and Zizheng Pan}, + author={DeepSeek-AI}, year={2024}, eprint={2412.19437}, archivePrefix={arXiv}, diff --git a/inference/convert.py b/inference/convert.py index c606ce8..6d85ccc 100644 --- a/inference/convert.py +++ b/inference/convert.py @@ -60,7 +60,7 @@ def main(hf_ckpt_path, save_path, n_experts, mp): name = name.replace("weight_scale_inv", "scale") name = name.replace("e_score_correction_bias", "bias") key = name.split(".")[-2] - assert key in mapping + assert key in mapping, f"Key {key} not found in mapping" new_key, dim = mapping[key] name = name.replace(key, new_key) for i in range(mp): @@ -70,7 +70,7 @@ def main(hf_ckpt_path, save_path, n_experts, mp): if idx < i * n_local_experts or idx >= (i + 1) * n_local_experts: continue elif dim is not None: - assert param.size(dim) % mp == 0 + assert param.size(dim) % mp == 0, f"Dimension {dim} must be divisible by {mp}" shard_size = param.size(dim) // mp new_param = param.narrow(dim, i * shard_size, shard_size).contiguous() state_dicts[i][name] = new_param @@ -92,5 +92,5 @@ if __name__ == "__main__": parser.add_argument("--n-experts", type=int, required=True) parser.add_argument("--model-parallel", type=int, required=True) args = parser.parse_args() - assert args.n_experts % args.model_parallel == 0 + assert args.n_experts % args.model_parallel == 0, "Number of experts must be divisible by model parallelism" main(args.hf_ckpt_path, args.save_path, args.n_experts, args.model_parallel) diff --git a/inference/generate.py b/inference/generate.py index fbf3ab8..7e9bffe 100644 --- a/inference/generate.py +++ b/inference/generate.py @@ -49,7 +49,7 @@ def generate( 
List[List[int]]: A list of lists containing the generated tokens for each sequence. """ prompt_lens = [len(t) for t in prompt_tokens] - assert max(prompt_lens) <= model.max_seq_len + assert max(prompt_lens) <= model.max_seq_len, f"Prompt length exceeds model maximum sequence length (max_seq_len={model.max_seq_len})" total_len = min(model.max_seq_len, max_new_tokens + max(prompt_lens)) tokens = torch.full((len(prompt_tokens), total_len), -1, dtype=torch.long, device="cuda") for i, t in enumerate(prompt_tokens): @@ -145,7 +145,7 @@ def main( else: with open(input_file) as f: prompts = [line.strip() for line in f.readlines()] - assert len(prompts) <= args.max_batch_size + assert len(prompts) <= args.max_batch_size, f"Number of prompts exceeds maximum batch size ({args.max_batch_size})" prompt_tokens = [tokenizer.apply_chat_template([{"role": "user", "content": prompt}], add_generation_prompt=True) for prompt in prompts] completion_tokens = generate(model, prompt_tokens, max_new_tokens, tokenizer.eos_token_id, temperature) completions = tokenizer.batch_decode(completion_tokens, skip_special_tokens=True) @@ -181,5 +181,5 @@ if __name__ == "__main__": parser.add_argument("--max-new-tokens", type=int, default=200) parser.add_argument("--temperature", type=float, default=0.2) args = parser.parse_args() - assert args.input_file or args.interactive + assert args.input_file or args.interactive, "Either input-file or interactive mode must be specified" main(args.ckpt_path, args.config, args.input_file, args.interactive, args.max_new_tokens, args.temperature) diff --git a/inference/kernel.py b/inference/kernel.py index dec8639..ae907ad 100644 --- a/inference/kernel.py +++ b/inference/kernel.py @@ -43,8 +43,8 @@ def act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, tor - The quantized tensor with dtype `torch.float8_e4m3fn`. - A tensor of scaling factors with dtype `torch.float32`. 
""" - assert x.is_contiguous() - assert x.size(-1) % block_size == 0 + assert x.is_contiguous(), 'Input tensor must be contiguous' + assert x.size(-1) % block_size == 0, f'Last dimension size must be divisible by block_size (block_size={block_size})' y = torch.empty_like(x, dtype=torch.float8_e4m3fn) s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype=torch.float32) grid = lambda meta: (triton.cdiv(x.numel(), meta['BLOCK_SIZE']), ) @@ -96,8 +96,8 @@ def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> t Raises: AssertionError: If `x` or `s` are not contiguous or if their dimensions are not 2. """ - assert x.is_contiguous() and s.is_contiguous() - assert x.dim() == 2 and s.dim() == 2 + assert x.is_contiguous() and s.is_contiguous(), 'Input tensors must be contiguous' + assert x.dim() == 2 and s.dim() == 2, 'Input tensors must have 2 dimensions' M, N = x.size() y = torch.empty_like(x, dtype=torch.get_default_dtype()) grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE']), triton.cdiv(N, meta['BLOCK_SIZE'])) @@ -180,8 +180,8 @@ def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Ten Returns: torch.Tensor: The result of the matrix multiplication. 
""" - assert a.is_contiguous() and b.is_contiguous() - assert a_s.is_contiguous() and b_s.is_contiguous() + assert a.is_contiguous() and b.is_contiguous(), 'Input tensors must be contiguous' + assert a_s.is_contiguous() and b_s.is_contiguous(), 'Scaling factor tensors must be contiguous' K = a.size(-1) M = a.numel() // K N = b.size(0) diff --git a/inference/model.py b/inference/model.py index 9ea60c9..8f1ab81 100644 --- a/inference/model.py +++ b/inference/model.py @@ -96,7 +96,7 @@ class ParallelEmbedding(nn.Module): super().__init__() self.vocab_size = vocab_size self.dim = dim - assert vocab_size % world_size == 0 + assert vocab_size % world_size == 0, f"Vocabulary size must be divisible by world size (world_size={world_size})" self.part_vocab_size = (vocab_size // world_size) self.vocab_start_idx = rank * self.part_vocab_size self.vocab_end_idx = self.vocab_start_idx + self.part_vocab_size @@ -143,7 +143,7 @@ def linear(x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = quantization-aware computations depending on the input parameters. Notes: - - If `weight` is quantized (e.g., `element_size() > 1`), a dequantized version + - If `weight` is quantized (e.g., `element_size() == 1`), a dequantized version is used for computation. - If `gemm_impl == "bf16"`, dequantization and a `bf16` GEMM operation are applied. - For other cases, the function applies quantization to `x` and uses `fp8_gemm` for computation. @@ -185,7 +185,7 @@ class Linear(nn.Module): else: self.register_parameter("scale", None) if bias: - self.bias = nn.Parameter(torch.empty(self.part_out_features)) + self.bias = nn.Parameter(torch.empty(out_features)) else: self.register_parameter("bias", None) @@ -213,7 +213,7 @@ class ColumnParallelLinear(Linear): dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`. 
""" def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None): - assert out_features % world_size == 0 + assert out_features % world_size == 0, f"Output features must be divisible by world size (world_size={world_size})" self.part_out_features = out_features // world_size super().__init__(in_features, self.part_out_features, bias, dtype) @@ -242,7 +242,7 @@ class RowParallelLinear(Linear): dtype (optional): Data type for the layer. Defaults to `torch.bfloat16`. """ def __init__(self, in_features: int, out_features: int, bias: bool = False, dtype = None): - assert in_features % world_size == 0 + assert in_features % world_size == 0, f"Input features must be divisible by world size (world_size={world_size})" self.part_in_features = in_features // world_size super().__init__(self.part_in_features, out_features, bias, dtype) @@ -585,8 +585,8 @@ class Gate(nn.Module): else: group_scores = scores.topk(2, dim=-1)[0].sum(dim=-1) indices = group_scores.topk(self.topk_groups, dim=-1)[1] - mask = torch.zeros_like(scores[..., 0]).scatter_(1, indices, True) - scores = (scores * mask.unsqueeze(-1)).flatten(1) + mask = scores.new_ones(x.size(0), self.n_groups, dtype=bool).scatter_(1, indices, False) + scores = scores.masked_fill_(mask.unsqueeze(-1), float("-inf")).flatten(1) indices = torch.topk(scores, self.topk, dim=-1)[1] weights = original_scores.gather(1, indices) if self.score_func == "sigmoid": @@ -652,7 +652,7 @@ class MoE(nn.Module): """ super().__init__() self.dim = args.dim - assert args.n_routed_experts % world_size == 0 + assert args.n_routed_experts % world_size == 0, f"Number of experts must be divisible by world size (world_size={world_size})" self.n_routed_experts = args.n_routed_experts self.n_local_experts = args.n_routed_experts // world_size self.n_activated_experts = args.n_activated_experts