from dataclasses import dataclass
from typing import Literal


@dataclass
class ModelArgs:
    """Bundle of model hyperparameters, each with a default value.

    Every field can be overridden through the dataclass-generated
    ``__init__`` (by keyword, or positionally in declaration order, so
    the field order below must not be changed). The fields shown here
    carry no validation; the ``Literal`` annotations are hints for type
    checkers and are not enforced at runtime.

    NOTE(review): the per-group descriptions in the comments below are
    inferred from the field names only -- confirm against the code that
    consumes this config.
    """

    # General sizing (presumably batch/sequence limits, numeric format,
    # and transformer dimensions -- TODO confirm).
    max_batch_size: int = 8
    max_seq_len: int = 4096 * 4
    dtype: Literal["bf16", "fp8"] = "bf16"
    vocab_size: int = 102400
    dim: int = 2048
    inter_dim: int = 10944
    moe_inter_dim: int = 1408
    n_layers: int = 27
    n_dense_layers: int = 1
    n_heads: int = 16

    # moe (presumably mixture-of-experts routing settings -- TODO confirm)
    n_routed_experts: int = 64
    n_shared_experts: int = 2
    n_activated_experts: int = 6
    n_expert_groups: int = 1
    n_limited_groups: int = 1
    score_func: Literal["softmax", "sigmoid"] = "softmax"
    route_scale: float = 1.0

    # mla (presumably attention projection ranks and per-head dims --
    # TODO confirm)
    q_lora_rank: int = 0
    kv_lora_rank: int = 512
    qk_nope_head_dim: int = 128
    qk_rope_head_dim: int = 64
    v_head_dim: int = 128

    # yarn (presumably rotary-embedding / context-extension settings --
    # TODO confirm)
    original_seq_len: int = 4096
    rope_theta: float = 10000.0
    # Default is deliberately left as the int literal 40 (annotation is
    # float) so its repr/serialized form is unchanged.
    rope_factor: float = 40
    beta_fast: int = 32
    beta_slow: int = 1
    mscale: float = 1.0