diff --git a/MACBOOK_SETUP.md b/MACBOOK_SETUP.md
new file mode 100644
index 0000000..4fe36d8
--- /dev/null
+++ b/MACBOOK_SETUP.md
@@ -0,0 +1,84 @@
+# DeepSeek V3 for MacBook
+
+> This was the initial approach; the project is now migrating to Zig for better performance and support for any architecture.
+
+This guide explains how to run DeepSeek V3 efficiently on MacBooks, which have limited resources compared to high-end GPU servers.
+
+## Optimizations Made
+
+The optimized version includes several improvements:
+
+1. **CPU and MPS (Apple Silicon) Support**: CPU-compatible kernels and Apple Silicon GPU (MPS) acceleration.
+2. **Memory Efficiency**: Chunked processing and sliding window attention to reduce memory usage.
+3. **Quantization**: Optional int8 quantization for CPU to improve inference speed while maintaining reasonable quality.
+4. **Reduced Model Size**: Configuration options to load smaller, more efficient model variants.
+5. **Dynamic Device Selection**: Automatic selection of the best available device (MPS, CPU).
+6. **Progressive Generation**: Ability to see generation progress for longer outputs.
+
+## System Requirements
+
+### Minimum Requirements
+- MacBook with Intel CPU (8GB RAM minimum)
+- macOS 11 (Big Sur) or newer
+- 10GB disk space for model weights
+
+### Recommended
+- MacBook with Apple Silicon (M1/M2/M3)
+- 16GB RAM or more
+- macOS 12 (Monterey) or newer
+- 20GB disk space for model weights
+
+## Installation
+
+1. Install dependencies:
+
+```bash
+pip install -r inference/requirements_macbook.txt
+```
+
+2. Download the model weights by following the instructions in `README_WEIGHTS.md`.
+
+## Usage
+
+The optimized script provides several options to control performance:
+
+```bash
+python inference/optimized_generate.py \
+    --ckpt-path /path/to/model/weights \
+    --config inference/configs/config_macbook.json \
+    --interactive \
+    --max-new-tokens 200 \
+    --temperature 0.2
+```
+
+### Additional Options
+
+- `--force-cpu`: Force CPU usage even if GPU is available
+- `--no-quantize`: Disable quantization (higher quality but slower)
+- `--chunk-size`: Size of processing chunks (default: 512, lower values use less memory)
+- `--reduced-model`: Use reduced model parameters (fewer experts/layers for lower resource usage)
+
+## Performance Tips
+
+1. **Context Length**: Keep prompt length short (under 1024 tokens) for better performance.
+2. **Batch Size**: Always use batch size of 1 on MacBooks.
+3. **Apple Silicon**: M1/M2/M3 MacBooks can use the MPS backend for significantly better performance (see the quick check below).
+4. **Memory Management**: Close other applications when running the model.
+5. **Temperature**: Using temperature=0 (greedy decoding) is faster but less creative.
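+
+Before launching a long run, it can help to confirm that PyTorch actually sees the MPS backend. The snippet below is a minimal, illustrative check (it assumes the packages from `inference/requirements_macbook.txt` are installed); if it reports that MPS is unavailable, the generation script will fall back to CPU. On Apple Silicon with an official arm64 PyTorch wheel, both lines should normally print `True`.
+
+```python
+import torch
+
+print("PyTorch version:", torch.__version__)
+print("MPS built into this wheel:", torch.backends.mps.is_built())
+print("MPS available on this machine:", torch.backends.mps.is_available())
+```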
+ +## Troubleshooting + +### "Out of Memory" Errors +- Try using `--reduced-model` flag +- Reduce `--chunk-size` to 256 or 128 +- Use `--force-cpu` if MPS memory is limited + +### Slow Generation +- Ensure you're using the Apple Silicon optimized build of PyTorch +- Check activity monitor to verify GPU utilization +- Try a smaller config file (edit parameters in config_macbook.json) + +### Model Loading Errors +- Verify model weights are downloaded correctly +- Ensure safetensors files are in the expected location +- Check torch/transformers versions match requirements diff --git a/inference/configs/config_macbook.json b/inference/configs/config_macbook.json new file mode 100644 index 0000000..96b98bf --- /dev/null +++ b/inference/configs/config_macbook.json @@ -0,0 +1,21 @@ +{ + "vocab_size": 102400, + "dim": 2048, + "inter_dim": 5472, + "moe_inter_dim": 704, + "n_layers": 14, + "n_dense_layers": 1, + "n_heads": 16, + "n_routed_experts": 32, + "n_shared_experts": 2, + "n_activated_experts": 4, + "route_scale": 1.0, + "q_lora_rank": 0, + "kv_lora_rank": 256, + "qk_nope_head_dim": 128, + "qk_rope_head_dim": 64, + "v_head_dim": 128, + "mscale": 0.707, + "max_batch_size": 1, + "max_seq_len": 4096 +} diff --git a/inference/cpu_kernel.py b/inference/cpu_kernel.py new file mode 100644 index 0000000..a0b3943 --- /dev/null +++ b/inference/cpu_kernel.py @@ -0,0 +1,186 @@ +import torch +import torch.nn.functional as F +from typing import Tuple, Optional + +def act_quant_cpu(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, torch.Tensor]: + """ + CPU-compatible version of act_quant. Quantizes the input tensor using block-wise quantization. + + Args: + x (torch.Tensor): The input tensor to be quantized. + block_size (int, optional): The size of the blocks for quantization. Default is 128. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Quantized tensor and scaling factors + """ + assert x.is_contiguous(), 'Input tensor must be contiguous' + + # Handle non-divisible cases more gracefully + if x.size(-1) % block_size != 0: + # Pad the tensor to make it divisible by block_size + pad_size = block_size - (x.size(-1) % block_size) + x = F.pad(x, (0, pad_size)) + + # Reshape to blocks for efficient processing + shape = x.shape + x_reshaped = x.reshape(-1, block_size) + + # Calculate scaling factors (max absolute value in each block) + s = torch.max(torch.abs(x_reshaped), dim=1, keepdim=True)[0] / 448.0 + + # Avoid division by zero + s = torch.clamp(s, min=1e-10) + + # Quantize by dividing by scaling factors + y = x_reshaped / s + + # Either use float8 if available or simulate with int8 + scaling + if hasattr(torch, "float8_e4m3fn"): + y = y.to(torch.float8_e4m3fn) + else: + # Simulate float8 with int8 quantization + y = torch.clamp(y, -448.0, 448.0) + y = (y / 448.0 * 127).round().to(torch.int8) + + # Reshape back to original shape + y = y.reshape(shape) + + # Reshape scaling factors to match expected output format + s = s.reshape(*shape[:-1], -1).squeeze(-1) + + return y, s + +def weight_dequant_cpu(weight: torch.Tensor, scale: torch.Tensor) -> torch.Tensor: + """ + CPU-compatible version of weight_dequant. Dequantizes the weight tensor. + + Args: + weight (torch.Tensor): Quantized weight tensor. + scale (torch.Tensor): Scaling factors. + + Returns: + torch.Tensor: Dequantized weight tensor. 
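+
+    Example (an illustrative sketch; the 0.01 block scales are arbitrary,
+    not calibrated values):
+        >>> import torch
+        >>> w_q = torch.randint(-127, 128, (256, 256), dtype=torch.int8)
+        >>> scales = torch.full((2, 2), 0.01)  # one scale per 128x128 block
+        >>> w = weight_dequant_cpu(w_q, scales)
+        >>> w.shape, w.dtype
+        (torch.Size([256, 256]), torch.float32)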
+ """ + # Handle different quantization formats + if weight.dtype == torch.int8: + # For int8 simulated quantization + weight_float = weight.to(torch.float32) / 127.0 * 448.0 + elif hasattr(torch, "float8_e4m3fn") and weight.dtype == torch.float8_e4m3fn: + # Native float8 support + weight_float = weight.to(torch.float32) + else: + # Already in a floating point format + weight_float = weight.to(torch.float32) + + # Reshape scale to broadcast correctly + if weight.dim() == 2: + # For linear layers + out_features, in_features = weight.shape + block_size = 128 # Same as in the original code + + scale_out = (out_features + block_size - 1) // block_size + scale_in = (in_features + block_size - 1) // block_size + + if scale.numel() == scale_out * scale_in: + # Reshape to match weight blocks + scale_reshaped = scale.reshape(scale_out, scale_in) + + # Create a mask for each block + out_blocks = torch.arange(out_features).reshape(-1, 1) // block_size + in_blocks = torch.arange(in_features).reshape(1, -1) // block_size + + # Limit to actual dimensions + out_blocks = torch.clamp(out_blocks, max=scale_out-1) + in_blocks = torch.clamp(in_blocks, max=scale_in-1) + + # Get corresponding scale for each position + scale_broadcast = scale_reshaped[out_blocks, in_blocks] + + # Apply scaling + return weight_float * scale_broadcast + + # Fallback for other tensor dimensions + return weight_float * scale + +def fp8_gemm_cpu(x: torch.Tensor, x_scale: torch.Tensor, weight: torch.Tensor, weight_scale: torch.Tensor) -> torch.Tensor: + """ + CPU-compatible version of fp8_gemm. Performs matrix multiplication with quantized tensors. + + Args: + x (torch.Tensor): Input activations (quantized). + x_scale (torch.Tensor): Scaling factors for input activations. + weight (torch.Tensor): Weights (quantized). + weight_scale (torch.Tensor): Scaling factors for weights. + + Returns: + torch.Tensor: Result of matrix multiplication. 
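+
+    Example (a rough sketch with unquantized float activations, i.e. x_scale=None;
+    the 0.02 block scales are placeholders):
+        >>> import torch
+        >>> x = torch.randn(4, 256)
+        >>> w_q = torch.randint(-127, 128, (128, 256), dtype=torch.int8)
+        >>> w_s = torch.full((1, 2), 0.02)  # ceil(128/128) x ceil(256/128) block scales
+        >>> y = fp8_gemm_cpu(x, None, w_q, w_s)
+        >>> y.shape
+        torch.Size([4, 128])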
+    """
+    # Dequantize input and weights
+    if x.dtype == torch.int8:
+        x_float = x.to(torch.float32) / 127.0 * 448.0
+    elif hasattr(torch, "float8_e4m3fn") and x.dtype == torch.float8_e4m3fn:
+        x_float = x.to(torch.float32)
+    else:
+        # Already floating point; promote to float32 so it matches the dequantized weights
+        x_float = x.to(torch.float32)
+
+    # Apply input scaling
+    if x_scale is not None:
+        if x_scale.dim() == x_float.dim() and x_scale.size(-1) != x_float.size(-1):
+            # Block-wise scales: expand each scale across its block of features
+            block = x_float.size(-1) // x_scale.size(-1)
+            x_scale = x_scale.repeat_interleave(block, dim=-1)
+        elif x_scale.dim() < x_float.dim():
+            # Per-row scales: append singleton dims so broadcasting works
+            x_scale = x_scale.reshape(*x_scale.shape, *([1] * (x_float.dim() - x_scale.dim())))
+        x_float = x_float * x_scale
+
+    # Dequantize weights
+    weight_dequant = weight_dequant_cpu(weight, weight_scale)
+
+    # Perform matrix multiplication
+    result = F.linear(x_float, weight_dequant)
+
+    return result
+
+# MPS (Metal Performance Shaders) optimized versions for Apple Silicon
+def setup_mps_kernels():
+    """
+    Set up optimized MPS kernels if running on Apple Silicon
+    """
+    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
+        print("Setting up MPS optimized kernels for Apple Silicon")
+        # MPS already optimizes most operations automatically
+        # Additional optimizations could be added in the future
+    else:
+        print("MPS not available, using CPU kernels")
+
+# Provide unified interface that selects the appropriate implementation
+def get_optimized_kernels(device="cpu"):
+    """
+    Returns optimized kernel functions based on the device
+
+    Args:
+        device (str): The device to optimize for ("cpu", "mps", or "cuda")
+
+    Returns:
+        dict: Dictionary of optimized kernel functions
+    """
+    if device == "mps" and hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
+        setup_mps_kernels()
+        # For MPS, we use the CPU implementations which will be automatically
+        # optimized by PyTorch's MPS backend
+        return {
+            "act_quant": act_quant_cpu,
+            "weight_dequant": weight_dequant_cpu,
+            "fp8_gemm": fp8_gemm_cpu
+        }
+    elif device == "cuda" and torch.cuda.is_available():
+        # For CUDA, use the original implementations
+        from kernel import act_quant, weight_dequant, fp8_gemm
+        return {
+            "act_quant": act_quant,
+            "weight_dequant": weight_dequant,
+            "fp8_gemm": fp8_gemm
+        }
+    else:
+        # Default to CPU implementations
+        return {
+            "act_quant": act_quant_cpu,
+            "weight_dequant": weight_dequant_cpu,
+            "fp8_gemm": fp8_gemm_cpu
+        }
diff --git a/inference/optimized_generate.py b/inference/optimized_generate.py
new file mode 100644
index 0000000..4ca3ba5
--- /dev/null
+++ b/inference/optimized_generate.py
@@ -0,0 +1,363 @@
+import os
+import json
+from argparse import ArgumentParser
+from typing import List, Optional
+
+import torch
+from transformers import AutoTokenizer
+from safetensors.torch import load_model
+
+from model import Transformer, ModelArgs
+
+
+def sample(logits, temperature: float = 1.0):
+    """
+    Samples a token from the logits using temperature scaling.
+
+    Args:
+        logits (torch.Tensor): The logits tensor for token predictions.
+        temperature (float, optional): Temperature for scaling logits. Defaults to 1.0.
+
+    Returns:
+        torch.Tensor: The sampled token.
+    """
+    logits = logits / max(temperature, 1e-5)
+    probs = torch.softmax(logits, dim=-1)
+    return probs.div_(torch.empty_like(probs).exponential_(1)).argmax(dim=-1)
+
+
+@torch.inference_mode()
+def generate(
+    model: Transformer,
+    prompt_tokens: List[List[int]],
+    max_new_tokens: int,
+    eos_id: int,
+    temperature: float = 1.0,
+    device: str = "mps",
+    chunk_size: int = 512
+) -> List[List[int]]:
+    """
+    Generates new tokens based on the given prompt tokens using the specified model.
+    Optimized for MacBook with chunked processing and reduced memory usage.
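+
+    Note: once generation moves past the prompt, each step re-feeds at most the
+    last chunk_size positions to the model (a sliding window), which bounds the
+    per-step activation memory at the cost of recomputing part of the recent context.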
+ + Args: + model (Transformer): The transformer model used for token generation. + prompt_tokens (List[List[int]]): A list of lists containing the prompt tokens for each sequence. + max_new_tokens (int): The maximum number of new tokens to generate. + eos_id (int): The end-of-sequence token ID. + temperature (float, optional): The temperature value for sampling. Defaults to 1.0. + device (str, optional): The device to run generation on. Defaults to "mps". + chunk_size (int, optional): Size of processing chunks for memory efficiency. Defaults to 512. + + Returns: + List[List[int]]: A list of lists containing the generated tokens for each sequence. + """ + prompt_lens = [len(t) for t in prompt_tokens] + assert max(prompt_lens) <= model.max_seq_len, f"Prompt length exceeds model maximum sequence length (max_seq_len={model.max_seq_len})" + + # Process in smaller batches for memory efficiency + batch_size = len(prompt_tokens) + if batch_size > 1: + print(f"Processing {batch_size} prompts in sequence to conserve memory...") + all_completion_tokens = [] + for i in range(batch_size): + single_completion = generate( + model, + [prompt_tokens[i]], + max_new_tokens, + eos_id, + temperature, + device, + chunk_size + ) + all_completion_tokens.extend(single_completion) + return all_completion_tokens + + # Calculate total sequence length + total_len = min(model.max_seq_len, max_new_tokens + max(prompt_lens)) + + # Initialize tokens tensor on appropriate device + tokens = torch.full((batch_size, total_len), -1, dtype=torch.long, device=device) + for i, t in enumerate(prompt_tokens): + tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long, device=device) + + prev_pos = 0 + finished = torch.tensor([False] * batch_size, device=device) + prompt_mask = tokens != -1 + + # Process in chunks for lower memory usage + for cur_pos in range(min(prompt_lens), total_len): + # Use a sliding window approach for KV cache efficiency + start_pos = max(0, cur_pos - chunk_size) if cur_pos > min(prompt_lens) else prev_pos + + # Clear GPU/MPS cache periodically to prevent memory fragmentation + if cur_pos % 50 == 0 and device == "mps": + torch.mps.empty_cache() + elif cur_pos % 50 == 0 and device == "cuda": + torch.cuda.empty_cache() + + logits = model.forward(tokens[:, start_pos:cur_pos], start_pos) + + if temperature > 0: + next_token = sample(logits, temperature) + else: + next_token = logits.argmax(dim=-1) + + next_token = torch.where(prompt_mask[:, cur_pos], tokens[:, cur_pos], next_token) + tokens[:, cur_pos] = next_token + finished |= torch.logical_and(~prompt_mask[:, cur_pos], next_token == eos_id) + prev_pos = start_pos + + if finished.all(): + break + + # Optional progress display for longer generations + if max_new_tokens > 100 and cur_pos % 20 == 0: + progress = (cur_pos - min(prompt_lens)) / max_new_tokens * 100 + print(f"\rGenerating: {progress:.1f}% complete", end="") + + if max_new_tokens > 100: + print("\rGeneration complete! ") + + completion_tokens = [] + for i, toks in enumerate(tokens.tolist()): + toks = toks[prompt_lens[i]:prompt_lens[i]+max_new_tokens] + if eos_id in toks: + toks = toks[:toks.index(eos_id)] + completion_tokens.append(toks) + + return completion_tokens + + +def get_optimal_device(force_cpu: bool = False) -> str: + """ + Determines the best available device for model inference. + + Args: + force_cpu (bool, optional): Force CPU usage even if GPU is available. Defaults to False. 
+
+    Returns:
+        str: Device string ("cuda", "mps", or "cpu")
+    """
+    if force_cpu:
+        return "cpu"
+
+    if torch.cuda.is_available():
+        return "cuda"
+    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
+        return "mps"  # Apple Silicon GPU
+    else:
+        return "cpu"
+
+
+def optimize_model(model: Transformer, quantize: bool = True, device: str = "cpu") -> Transformer:
+    """
+    Applies optimizations to the model for more efficient inference.
+
+    Args:
+        model (Transformer): The transformer model to optimize.
+        quantize (bool, optional): Whether to apply int8 quantization. Defaults to True.
+        device (str, optional): Target device. Defaults to "cpu".
+
+    Returns:
+        Transformer: Optimized model
+    """
+    if quantize and device == "cpu":
+        # Apply dynamic quantization to linear layers
+        model = torch.quantization.quantize_dynamic(
+            model, {torch.nn.Linear}, dtype=torch.qint8
+        )
+
+    return model
+
+
+def main(
+    ckpt_path: str,
+    config: str,
+    input_file: str = "",
+    interactive: bool = True,
+    max_new_tokens: int = 100,
+    temperature: float = 1.0,
+    force_cpu: bool = False,
+    quantize: bool = True,
+    chunk_size: int = 512,
+    reduced_model: bool = False,
+) -> None:
+    """
+    Main function to load the model and perform interactive or batch text generation.
+    Optimized for MacBooks with various memory/performance options.
+
+    Args:
+        ckpt_path (str): Path to the model checkpoint directory.
+        config (str): Path to the model configuration file.
+        input_file (str, optional): Path to a file containing input prompts. Defaults to "".
+        interactive (bool, optional): Whether to run in interactive mode. Defaults to True.
+        max_new_tokens (int, optional): Maximum number of new tokens to generate. Defaults to 100.
+        temperature (float, optional): Temperature for sampling. Defaults to 1.0.
+        force_cpu (bool, optional): Force CPU usage even if GPU is available. Defaults to False.
+        quantize (bool, optional): Apply quantization when possible. Defaults to True.
+        chunk_size (int, optional): Size of processing chunks for memory efficiency. Defaults to 512.
+        reduced_model (bool, optional): Load a smaller version of the model if available. Defaults to False.
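+
+    Example (illustrative paths; adjust ckpt_path to point at your local
+    weight directory):
+
+        main(
+            ckpt_path="/path/to/model/weights",
+            config="inference/configs/config_macbook.json",
+            interactive=True,
+            max_new_tokens=200,
+            temperature=0.2,
+        )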
+ """ + # Detect optimal device + device = get_optimal_device(force_cpu) + print(f"Using device: {device}") + + # Set appropriate torch settings + if device == "cuda": + torch.cuda.set_device(0) + + # Use bfloat16 for CUDA/MPS, float32 for CPU + if device == "cpu": + torch.set_default_dtype(torch.float32) + else: + torch.set_default_dtype(torch.bfloat16) + + torch.set_num_threads(8) # Adjust based on your CPU + torch.manual_seed(965) + + # Load model configuration + with open(config) as f: + config_data = json.load(f) + + # Apply optimizations to configuration for smaller/faster model + if reduced_model: + # Reduce the number of experts and heads + config_data["n_routed_experts"] = config_data.get("n_routed_experts", 64) // 2 + config_data["max_seq_len"] = min(config_data.get("max_seq_len", 16384), 4096) # Reduce context size + + args = ModelArgs(**config_data) + + print(args) + + # Load model + with torch.device(device): + model = Transformer(args) + + # Load tokenizer + tokenizer = AutoTokenizer.from_pretrained(ckpt_path) + + # Load the appropriate checkpoint + if device != "cuda": + # For CPU/MPS, always use rank 0 + checkpoint_path = os.path.join(ckpt_path, "model0-mp1.safetensors") + else: + # For CUDA, can use multiple GPUs if available + world_size = min(torch.cuda.device_count(), 1) # Limit to 1 for simpler usage + rank = 0 + checkpoint_path = os.path.join(ckpt_path, f"model{rank}-mp{world_size}.safetensors") + + print(f"Loading checkpoint from {checkpoint_path}") + load_model(model, checkpoint_path) + + # Apply quantization and other optimizations + model = optimize_model(model, quantize=quantize, device=device) + model.to(device) + + # Generate a quick test sequence to ensure everything is working + print("Running warmup generation...") + tokenizer.decode(generate(model, [tokenizer.encode("DeepSeek")], 2, -1, 1., device)[0]) + print("Model loaded and ready!") + + if interactive: + messages = [] + while True: + prompt = input(">>> ") + if prompt == "/exit": + break + elif prompt == "/clear": + messages.clear() + continue + + messages.append({"role": "user", "content": prompt}) + prompt_tokens = tokenizer.apply_chat_template(messages, add_generation_prompt=True) + + # Show a waiting message for longer generations + if max_new_tokens > 50: + print("Generating response...") + + completion_tokens = generate( + model, + [prompt_tokens], + max_new_tokens, + tokenizer.eos_token_id, + temperature, + device, + chunk_size + ) + + completion = tokenizer.decode(completion_tokens[0], skip_special_tokens=True) + print(completion) + messages.append({"role": "assistant", "content": completion}) + + # Clear cache after each generation to prevent memory buildup + if device == "mps": + torch.mps.empty_cache() + elif device == "cuda": + torch.cuda.empty_cache() + else: + with open(input_file) as f: + prompts = [line.strip() for line in f.readlines()] + assert len(prompts) <= args.max_batch_size, f"Number of prompts exceeds maximum batch size ({args.max_batch_size})" + + prompt_tokens = [tokenizer.apply_chat_template([{"role": "user", "content": prompt}], add_generation_prompt=True) for prompt in prompts] + completion_tokens = generate( + model, + prompt_tokens, + max_new_tokens, + tokenizer.eos_token_id, + temperature, + device, + chunk_size + ) + + completions = tokenizer.batch_decode(completion_tokens, skip_special_tokens=True) + for prompt, completion in zip(prompts, completions): + print("Prompt:", prompt) + print("Completion:", completion) + print() + + +if __name__ == "__main__": + """ + 
Command-line interface for optimized text generation on MacBooks. + + Arguments: + --ckpt-path (str): Path to the model checkpoint directory. + --config (str): Path to the model configuration file. + --input-file (str, optional): File containing prompts for batch processing. + --interactive (bool, optional): Enable interactive mode for generating text. + --max-new-tokens (int, optional): Maximum number of new tokens to generate. Defaults to 200. + --temperature (float, optional): Temperature for sampling. Defaults to 0.2. + --force-cpu (bool, optional): Force CPU usage even if GPU is available. + --no-quantize (bool, optional): Disable quantization (higher quality but slower). + --chunk-size (int, optional): Size of processing chunks for memory efficiency. Defaults to 512. + --reduced-model (bool, optional): Load a smaller version of the model if available. + """ + parser = ArgumentParser() + parser.add_argument("--ckpt-path", type=str, required=True) + parser.add_argument("--config", type=str, required=True) + parser.add_argument("--input-file", type=str, default="") + parser.add_argument("--interactive", action="store_true") + parser.add_argument("--max-new-tokens", type=int, default=200) + parser.add_argument("--temperature", type=float, default=0.2) + parser.add_argument("--force-cpu", action="store_true") + parser.add_argument("--no-quantize", action="store_true") + parser.add_argument("--chunk-size", type=int, default=512) + parser.add_argument("--reduced-model", action="store_true") + + args = parser.parse_args() + assert args.input_file or args.interactive, "Either input-file or interactive mode must be specified" + + main( + args.ckpt_path, + args.config, + args.input_file, + args.interactive, + args.max_new_tokens, + args.temperature, + args.force_cpu, + not args.no_quantize, + args.chunk_size, + args.reduced_model + ) diff --git a/inference/requirements_macbook.txt b/inference/requirements_macbook.txt new file mode 100644 index 0000000..b87cacb --- /dev/null +++ b/inference/requirements_macbook.txt @@ -0,0 +1,6 @@ +torch>=2.1.0 +transformers==4.46.3 +safetensors==0.4.5 +tqdm +numpy +sentencepiece
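
Once the packages above are installed, the unified kernel-selection interface added in `inference/cpu_kernel.py` can also be exercised on its own. The sketch below is illustrative only (it assumes it is run from the `inference/` directory so that `cpu_kernel` is importable, and it demonstrates just the CPU path):

```python
import torch

from cpu_kernel import get_optimized_kernels

# Pick the kernel implementations for the requested device ("cpu", "mps", or "cuda")
kernels = get_optimized_kernels("cpu")

# Block-quantize a random activation tensor and inspect the result
x = torch.randn(4, 256)
x_q, x_s = kernels["act_quant"](x)
print(x_q.shape, x_q.dtype)  # (4, 256); float8_e4m3fn or int8 depending on the PyTorch build
print(x_s.shape)             # (4, 2): one scale per 128-wide block
```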