mirror of
https://github.com/deepseek-ai/DeepSeek-V3.git
synced 2025-07-06 00:11:56 -04:00
feat: Initial MacBook optimisation draft for DeepSeek V3 inference > moving to Zig instead
This commit is contained in:
parent
4cc6253d5c
commit
a1895012dd
84  MACBOOK_SETUP.md  Normal file
@@ -0,0 +1,84 @@
# DeepSeek V3 for MacBook

> This was the initial idea, but the project is now migrating to Zig for better performance and support for any architecture.

This guide explains how to run DeepSeek V3 efficiently on MacBooks, which have limited resources compared to high-end GPU servers.

## Optimizations Made

The optimized version includes several improvements (a sketch of the device-selection logic follows this list):

1. **CPU and MPS (Apple Silicon) Support**: CPU-compatible kernels plus Apple Silicon GPU acceleration.
2. **Memory Efficiency**: Chunked processing and sliding-window attention to reduce memory usage.
3. **Quantization**: Optional int8 quantization on CPU to improve inference speed while maintaining reasonable quality.
4. **Reduced Model Size**: Configuration options to load smaller, more efficient model variants.
5. **Dynamic Device Selection**: Automatic selection of the best available device (MPS, CPU).
6. **Progressive Generation**: A progress display for longer outputs.

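
A minimal sketch of the device-selection logic, mirroring `get_optimal_device()` in `inference/optimized_generate.py` (the helper name `pick_device` is illustrative only):

```python
import torch

def pick_device(force_cpu: bool = False) -> str:
    """Pick the best available backend: CUDA, then MPS (Apple Silicon), then CPU."""
    if force_cpu:
        return "cpu"
    if torch.cuda.is_available():
        return "cuda"
    if torch.backends.mps.is_available():
        return "mps"  # Apple Silicon GPU
    return "cpu"

print(f"Selected device: {pick_device()}")
```
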

## System Requirements

### Minimum Requirements
- MacBook with an Intel CPU (8GB RAM minimum)
- macOS 11 (Big Sur) or newer
- 10GB disk space for model weights

### Recommended
- MacBook with Apple Silicon (M1/M2/M3)
- 16GB RAM or more
- macOS 12 (Monterey) or newer
- 20GB disk space for model weights

## Installation

1. Install dependencies:

```bash
pip install -r inference/requirements_macbook.txt
```

2. Download model weights following the instructions in README_WEIGHTS.md.

## Usage

The optimized script provides several options to control performance:

```bash
python inference/optimized_generate.py \
  --ckpt-path /path/to/model/weights \
  --config inference/configs/config_macbook.json \
  --interactive \
  --max-new-tokens 200 \
  --temperature 0.2
```

### Additional Options

- `--force-cpu`: Force CPU usage even if a GPU is available
- `--no-quantize`: Disable quantization (higher quality but slower)
- `--chunk-size`: Size of processing chunks (default: 512; lower values use less memory, as illustrated in the sketch after this list)
- `--reduced-model`: Use reduced model parameters (fewer experts/layers for lower resource usage)

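
A minimal sketch of how `--chunk-size` bounds the attention window during generation, based on the sliding-window slicing in `generate()` from `inference/optimized_generate.py` (the lengths and tensor here are illustrative):

```python
import torch

chunk_size = 512                  # value passed via --chunk-size
prompt_len, total_len = 64, 2048  # illustrative lengths
tokens = torch.zeros(1, total_len, dtype=torch.long)

max_window = 0
for cur_pos in range(prompt_len, total_len):
    # Only the most recent `chunk_size` tokens are fed back to the model,
    # so peak memory grows with chunk_size rather than with total_len.
    start_pos = max(0, cur_pos - chunk_size)
    window = tokens[:, start_pos:cur_pos]
    # logits = model.forward(window, start_pos)  # as done in generate()
    max_window = max(max_window, window.size(1))

print(f"Largest window fed to the model: {max_window} tokens")  # 512
```
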

## Performance Tips

1. **Context Length**: Keep prompts short (under 1024 tokens) for better performance.
2. **Batch Size**: Always use a batch size of 1 on MacBooks.
3. **Apple Silicon**: M1/M2/M3 MacBooks can use the MPS backend for significantly better performance.
4. **Memory Management**: Close other applications when running the model.
5. **Temperature**: Using temperature=0 (greedy decoding) is faster but less creative; see the sampling sketch after this list.

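
A minimal sketch of how the temperature setting changes decoding, mirroring `sample()` in `inference/optimized_generate.py` (the logits tensor here is illustrative):

```python
import torch

logits = torch.randn(1, 32000)  # illustrative vocabulary logits
temperature = 0.2

if temperature > 0:
    scaled = logits / max(temperature, 1e-5)
    probs = torch.softmax(scaled, dim=-1)
    # Exponential-noise (Gumbel-max style) sampling, as used by sample()
    next_token = probs.div_(torch.empty_like(probs).exponential_(1)).argmax(dim=-1)
else:
    next_token = logits.argmax(dim=-1)  # temperature=0 -> greedy decoding

print("chosen token id:", next_token.item())
```
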

## Troubleshooting

### "Out of Memory" Errors
- Try the `--reduced-model` flag
- Reduce `--chunk-size` to 256 or 128
- Use `--force-cpu` if MPS memory is limited

### Slow Generation
- Ensure you're using an Apple Silicon (MPS-enabled) build of PyTorch; a quick check is sketched below
- Check Activity Monitor to verify GPU utilization
- Try a smaller configuration (edit parameters in config_macbook.json)

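
A quick way to confirm that your PyTorch build supports MPS (these are standard PyTorch APIs):

```python
import torch

print("MPS built:    ", torch.backends.mps.is_built())
print("MPS available:", torch.backends.mps.is_available())
# Both should print True on an Apple Silicon Mac with a recent PyTorch build.
```
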

### Model Loading Errors
- Verify that the model weights downloaded correctly
- Ensure the safetensors files are in the expected location
- Check that the torch/transformers versions match the requirements

21  inference/configs/config_macbook.json  Normal file
@@ -0,0 +1,21 @@
{
  "vocab_size": 102400,
  "dim": 2048,
  "inter_dim": 5472,
  "moe_inter_dim": 704,
  "n_layers": 14,
  "n_dense_layers": 1,
  "n_heads": 16,
  "n_routed_experts": 32,
  "n_shared_experts": 2,
  "n_activated_experts": 4,
  "route_scale": 1.0,
  "q_lora_rank": 0,
  "kv_lora_rank": 256,
  "qk_nope_head_dim": 128,
  "qk_rope_head_dim": 64,
  "v_head_dim": 128,
  "mscale": 0.707,
  "max_batch_size": 1,
  "max_seq_len": 4096
}
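
A minimal sketch of how this config is consumed, following the loading code in `inference/optimized_generate.py` (assumes `inference/` is on the Python path so that `model.py` is importable, and that it is run from the repository root):

```python
import json
from model import ModelArgs  # provided by the existing inference code

with open("inference/configs/config_macbook.json") as f:
    args = ModelArgs(**json.load(f))

print(args.n_routed_experts, args.max_seq_len)  # 32, 4096
```
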
186  inference/cpu_kernel.py  Normal file
@@ -0,0 +1,186 @@
import torch
import torch.nn.functional as F
from typing import Tuple, Optional


def act_quant_cpu(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    CPU-compatible version of act_quant. Quantizes the input tensor using block-wise quantization.

    Args:
        x (torch.Tensor): The input tensor to be quantized.
        block_size (int, optional): The size of the blocks for quantization. Default is 128.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: Quantized tensor and scaling factors.
    """
    assert x.is_contiguous(), 'Input tensor must be contiguous'

    # Handle non-divisible cases more gracefully
    if x.size(-1) % block_size != 0:
        # Pad the tensor to make it divisible by block_size
        pad_size = block_size - (x.size(-1) % block_size)
        x = F.pad(x, (0, pad_size))

    # Reshape to blocks for efficient processing
    shape = x.shape
    x_reshaped = x.reshape(-1, block_size)

    # Calculate scaling factors (max absolute value in each block)
    s = torch.max(torch.abs(x_reshaped), dim=1, keepdim=True)[0] / 448.0

    # Avoid division by zero
    s = torch.clamp(s, min=1e-10)

    # Quantize by dividing by scaling factors
    y = x_reshaped / s

    # Either use float8 if available or simulate with int8 + scaling
    if hasattr(torch, "float8_e4m3fn"):
        y = y.to(torch.float8_e4m3fn)
    else:
        # Simulate float8 with int8 quantization
        y = torch.clamp(y, -448.0, 448.0)
        y = (y / 448.0 * 127).round().to(torch.int8)

    # Reshape back to the (possibly padded) original shape
    y = y.reshape(shape)

    # Reshape scaling factors to match the expected output format
    s = s.reshape(*shape[:-1], -1).squeeze(-1)

    return y, s


def weight_dequant_cpu(weight: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """
    CPU-compatible version of weight_dequant. Dequantizes the weight tensor.

    Args:
        weight (torch.Tensor): Quantized weight tensor.
        scale (torch.Tensor): Scaling factors.

    Returns:
        torch.Tensor: Dequantized weight tensor.
    """
    # Handle different quantization formats
    if weight.dtype == torch.int8:
        # For int8 simulated quantization
        weight_float = weight.to(torch.float32) / 127.0 * 448.0
    elif hasattr(torch, "float8_e4m3fn") and weight.dtype == torch.float8_e4m3fn:
        # Native float8 support
        weight_float = weight.to(torch.float32)
    else:
        # Already in a floating-point format
        weight_float = weight.to(torch.float32)

    # Reshape scale to broadcast correctly
    if weight.dim() == 2:
        # For linear layers
        out_features, in_features = weight.shape
        block_size = 128  # Same as in the original code

        scale_out = (out_features + block_size - 1) // block_size
        scale_in = (in_features + block_size - 1) // block_size

        if scale.numel() == scale_out * scale_in:
            # Reshape to match weight blocks
            scale_reshaped = scale.reshape(scale_out, scale_in)

            # Create a block index for each position
            out_blocks = torch.arange(out_features).reshape(-1, 1) // block_size
            in_blocks = torch.arange(in_features).reshape(1, -1) // block_size

            # Limit to actual dimensions
            out_blocks = torch.clamp(out_blocks, max=scale_out - 1)
            in_blocks = torch.clamp(in_blocks, max=scale_in - 1)

            # Get the corresponding scale for each position
            scale_broadcast = scale_reshaped[out_blocks, in_blocks]

            # Apply scaling
            return weight_float * scale_broadcast

    # Fallback for other tensor dimensions
    return weight_float * scale


def fp8_gemm_cpu(x: torch.Tensor, x_scale: torch.Tensor, weight: torch.Tensor, weight_scale: torch.Tensor) -> torch.Tensor:
    """
    CPU-compatible version of fp8_gemm. Performs matrix multiplication with quantized tensors.

    Args:
        x (torch.Tensor): Input activations (quantized).
        x_scale (torch.Tensor): Scaling factors for input activations.
        weight (torch.Tensor): Weights (quantized).
        weight_scale (torch.Tensor): Scaling factors for weights.

    Returns:
        torch.Tensor: Result of the matrix multiplication.
    """
    # Dequantize the input activations
    if x.dtype == torch.int8:
        x_float = x.to(torch.float32) / 127.0 * 448.0
    elif hasattr(torch, "float8_e4m3fn") and x.dtype == torch.float8_e4m3fn:
        x_float = x.to(torch.float32)
    else:
        x_float = x

    # Apply input scaling
    if x_scale is not None:
        # Reshape x_scale for broadcasting
        new_shape = list(x_scale.shape) + [1] * (x_float.dim() - x_scale.dim())
        x_float = x_float * x_scale.reshape(*new_shape)

    # Dequantize the weights
    weight_dequant = weight_dequant_cpu(weight, weight_scale)

    # Perform the matrix multiplication
    result = F.linear(x_float, weight_dequant)

    return result


# MPS (Metal Performance Shaders) optimized versions for Apple Silicon
def setup_mps_kernels():
    """
    Set up optimized MPS kernels if running on Apple Silicon.
    """
    if torch.backends.mps.is_available():
        print("Setting up MPS optimized kernels for Apple Silicon")
        # MPS already optimizes most operations automatically;
        # additional optimizations could be added in the future.
    else:
        print("MPS not available, using CPU kernels")


# Provide a unified interface that selects the appropriate implementation
def get_optimized_kernels(device="cpu"):
    """
    Returns optimized kernel functions based on the device.

    Args:
        device (str): The device to optimize for ("cpu", "mps", or "cuda").

    Returns:
        dict: Dictionary of optimized kernel functions.
    """
    if device == "mps" and torch.backends.mps.is_available():
        setup_mps_kernels()
        # For MPS, we use the CPU implementations, which are automatically
        # optimized by PyTorch's MPS backend.
        return {
            "act_quant": act_quant_cpu,
            "weight_dequant": weight_dequant_cpu,
            "fp8_gemm": fp8_gemm_cpu
        }
    elif device == "cuda" and torch.cuda.is_available():
        # For CUDA, use the original implementations
        from kernel import act_quant, weight_dequant, fp8_gemm
        return {
            "act_quant": act_quant,
            "weight_dequant": weight_dequant,
            "fp8_gemm": fp8_gemm
        }
    else:
        # Default to the CPU implementations
        return {
            "act_quant": act_quant_cpu,
            "weight_dequant": weight_dequant_cpu,
            "fp8_gemm": fp8_gemm_cpu
        }
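

# --- Hedged usage sketch (not part of the original kernels) ------------------
# Minimal smoke test for the CPU kernels above; shapes are illustrative and
# chosen so that each activation row is exactly one 128-wide quantization block.
# Run directly with: python inference/cpu_kernel.py
if __name__ == "__main__":
    torch.manual_seed(0)
    x = torch.randn(4, 128)    # activations: one block of 128 per row
    w = torch.randn(64, 128)   # linear weight: (out_features, in_features)

    # Quantize the activations block-wise, then run the dequantizing GEMM.
    x_q, x_s = act_quant_cpu(x)
    w_scale = torch.ones(1)    # trivial per-block scale for an unquantized weight
    y_quant = fp8_gemm_cpu(x_q, x_s, w, w_scale)

    # Compare against the full-precision reference.
    y_ref = F.linear(x, w)
    print("max abs error vs. float32 reference:",
          (y_quant - y_ref).abs().max().item())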
363  inference/optimized_generate.py  Normal file
@@ -0,0 +1,363 @@
import os
import json
from argparse import ArgumentParser
from typing import List, Optional

import torch
from transformers import AutoTokenizer
from safetensors.torch import load_model

from model import Transformer, ModelArgs


def sample(logits, temperature: float = 1.0):
    """
    Samples a token from the logits using temperature scaling.

    Args:
        logits (torch.Tensor): The logits tensor for token predictions.
        temperature (float, optional): Temperature for scaling logits. Defaults to 1.0.

    Returns:
        torch.Tensor: The sampled token.
    """
    logits = logits / max(temperature, 1e-5)
    probs = torch.softmax(logits, dim=-1)
    # Dividing the probabilities by Exponential(1) noise and taking the argmax
    # draws a sample from the categorical distribution (Gumbel-max style trick).
    return probs.div_(torch.empty_like(probs).exponential_(1)).argmax(dim=-1)


@torch.inference_mode()
def generate(
    model: Transformer,
    prompt_tokens: List[List[int]],
    max_new_tokens: int,
    eos_id: int,
    temperature: float = 1.0,
    device: str = "mps",
    chunk_size: int = 512
) -> List[List[int]]:
    """
    Generates new tokens based on the given prompt tokens using the specified model.
    Optimized for MacBook with chunked processing and reduced memory usage.

    Args:
        model (Transformer): The transformer model used for token generation.
        prompt_tokens (List[List[int]]): A list of lists containing the prompt tokens for each sequence.
        max_new_tokens (int): The maximum number of new tokens to generate.
        eos_id (int): The end-of-sequence token ID.
        temperature (float, optional): The temperature value for sampling. Defaults to 1.0.
        device (str, optional): The device to run generation on. Defaults to "mps".
        chunk_size (int, optional): Size of processing chunks for memory efficiency. Defaults to 512.

    Returns:
        List[List[int]]: A list of lists containing the generated tokens for each sequence.
    """
    prompt_lens = [len(t) for t in prompt_tokens]
    assert max(prompt_lens) <= model.max_seq_len, f"Prompt length exceeds model maximum sequence length (max_seq_len={model.max_seq_len})"

    # Process in smaller batches for memory efficiency
    batch_size = len(prompt_tokens)
    if batch_size > 1:
        print(f"Processing {batch_size} prompts in sequence to conserve memory...")
        all_completion_tokens = []
        for i in range(batch_size):
            single_completion = generate(
                model,
                [prompt_tokens[i]],
                max_new_tokens,
                eos_id,
                temperature,
                device,
                chunk_size
            )
            all_completion_tokens.extend(single_completion)
        return all_completion_tokens

    # Calculate total sequence length
    total_len = min(model.max_seq_len, max_new_tokens + max(prompt_lens))

    # Initialize tokens tensor on appropriate device
    tokens = torch.full((batch_size, total_len), -1, dtype=torch.long, device=device)
    for i, t in enumerate(prompt_tokens):
        tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long, device=device)

    prev_pos = 0
    finished = torch.tensor([False] * batch_size, device=device)
    prompt_mask = tokens != -1

    # Process in chunks for lower memory usage
    for cur_pos in range(min(prompt_lens), total_len):
        # Use a sliding window approach for KV cache efficiency
        start_pos = max(0, cur_pos - chunk_size) if cur_pos > min(prompt_lens) else prev_pos

        # Clear GPU/MPS cache periodically to prevent memory fragmentation
        if cur_pos % 50 == 0 and device == "mps":
            torch.mps.empty_cache()
        elif cur_pos % 50 == 0 and device == "cuda":
            torch.cuda.empty_cache()

        logits = model.forward(tokens[:, start_pos:cur_pos], start_pos)

        if temperature > 0:
            next_token = sample(logits, temperature)
        else:
            next_token = logits.argmax(dim=-1)

        next_token = torch.where(prompt_mask[:, cur_pos], tokens[:, cur_pos], next_token)
        tokens[:, cur_pos] = next_token
        finished |= torch.logical_and(~prompt_mask[:, cur_pos], next_token == eos_id)
        prev_pos = start_pos

        if finished.all():
            break

        # Optional progress display for longer generations
        if max_new_tokens > 100 and cur_pos % 20 == 0:
            progress = (cur_pos - min(prompt_lens)) / max_new_tokens * 100
            print(f"\rGenerating: {progress:.1f}% complete", end="")

    if max_new_tokens > 100:
        print("\rGeneration complete! ")

    completion_tokens = []
    for i, toks in enumerate(tokens.tolist()):
        toks = toks[prompt_lens[i]:prompt_lens[i] + max_new_tokens]
        if eos_id in toks:
            toks = toks[:toks.index(eos_id)]
        completion_tokens.append(toks)

    return completion_tokens


def get_optimal_device(force_cpu: bool = False) -> str:
    """
    Determines the best available device for model inference.

    Args:
        force_cpu (bool, optional): Force CPU usage even if a GPU is available. Defaults to False.

    Returns:
        str: Device string ("cuda", "mps", or "cpu").
    """
    if force_cpu:
        return "cpu"

    if torch.cuda.is_available():
        return "cuda"
    elif torch.backends.mps.is_available():
        return "mps"  # Apple Silicon GPU
    else:
        return "cpu"


def optimize_model(model: Transformer, quantize: bool = True, device: str = "cpu") -> Transformer:
    """
    Applies optimizations to the model for more efficient inference.

    Args:
        model (Transformer): The transformer model to optimize.
        quantize (bool, optional): Whether to apply int8 quantization. Defaults to True.
        device (str, optional): Target device. Defaults to "cpu".

    Returns:
        Transformer: Optimized model.
    """
    if quantize and device == "cpu":
        # Apply dynamic quantization to linear layers
        model = torch.quantization.quantize_dynamic(
            model, {torch.nn.Linear}, dtype=torch.qint8
        )

    return model


def main(
    ckpt_path: str,
    config: str,
    input_file: str = "",
    interactive: bool = True,
    max_new_tokens: int = 100,
    temperature: float = 1.0,
    force_cpu: bool = False,
    quantize: bool = True,
    chunk_size: int = 512,
    reduced_model: bool = False,
) -> None:
    """
    Main function to load the model and perform interactive or batch text generation.
    Optimized for MacBooks with various memory/performance options.

    Args:
        ckpt_path (str): Path to the model checkpoint directory.
        config (str): Path to the model configuration file.
        input_file (str, optional): Path to a file containing input prompts. Defaults to "".
        interactive (bool, optional): Whether to run in interactive mode. Defaults to True.
        max_new_tokens (int, optional): Maximum number of new tokens to generate. Defaults to 100.
        temperature (float, optional): Temperature for sampling. Defaults to 1.0.
        force_cpu (bool, optional): Force CPU usage even if GPU is available. Defaults to False.
        quantize (bool, optional): Apply quantization when possible. Defaults to True.
        chunk_size (int, optional): Size of processing chunks for memory efficiency. Defaults to 512.
        reduced_model (bool, optional): Load a smaller version of the model if available. Defaults to False.
    """
    # Detect optimal device
    device = get_optimal_device(force_cpu)
    print(f"Using device: {device}")

    # Set appropriate torch settings
    if device == "cuda":
        torch.cuda.set_device(0)

    # Use bfloat16 for CUDA/MPS, float32 for CPU
    if device == "cpu":
        torch.set_default_dtype(torch.float32)
    else:
        torch.set_default_dtype(torch.bfloat16)

    torch.set_num_threads(8)  # Adjust based on your CPU
    torch.manual_seed(965)

    # Load model configuration
    with open(config) as f:
        config_data = json.load(f)

    # Apply optimizations to configuration for smaller/faster model
    if reduced_model:
        # Reduce the number of experts and the context size
        config_data["n_routed_experts"] = config_data.get("n_routed_experts", 64) // 2
        config_data["max_seq_len"] = min(config_data.get("max_seq_len", 16384), 4096)  # Reduce context size

    args = ModelArgs(**config_data)
    print(args)

    # Load model
    with torch.device(device):
        model = Transformer(args)

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(ckpt_path)

    # Load the appropriate checkpoint
    if device != "cuda":
        # For CPU/MPS, always use rank 0
        checkpoint_path = os.path.join(ckpt_path, "model0-mp1.safetensors")
    else:
        # For CUDA, can use multiple GPUs if available
        world_size = min(torch.cuda.device_count(), 1)  # Limit to 1 for simpler usage
        rank = 0
        checkpoint_path = os.path.join(ckpt_path, f"model{rank}-mp{world_size}.safetensors")

    print(f"Loading checkpoint from {checkpoint_path}")
    load_model(model, checkpoint_path)

    # Apply quantization and other optimizations
    model = optimize_model(model, quantize=quantize, device=device)
    model.to(device)

    # Generate a quick test sequence to ensure everything is working
    print("Running warmup generation...")
    tokenizer.decode(generate(model, [tokenizer.encode("DeepSeek")], 2, -1, 1., device)[0])
    print("Model loaded and ready!")

    if interactive:
        messages = []
        while True:
            prompt = input(">>> ")
            if prompt == "/exit":
                break
            elif prompt == "/clear":
                messages.clear()
                continue

            messages.append({"role": "user", "content": prompt})
            prompt_tokens = tokenizer.apply_chat_template(messages, add_generation_prompt=True)

            # Show a waiting message for longer generations
            if max_new_tokens > 50:
                print("Generating response...")

            completion_tokens = generate(
                model,
                [prompt_tokens],
                max_new_tokens,
                tokenizer.eos_token_id,
                temperature,
                device,
                chunk_size
            )

            completion = tokenizer.decode(completion_tokens[0], skip_special_tokens=True)
            print(completion)
            messages.append({"role": "assistant", "content": completion})

            # Clear cache after each generation to prevent memory buildup
            if device == "mps":
                torch.mps.empty_cache()
            elif device == "cuda":
                torch.cuda.empty_cache()
    else:
        with open(input_file) as f:
            prompts = [line.strip() for line in f.readlines()]
        assert len(prompts) <= args.max_batch_size, f"Number of prompts exceeds maximum batch size ({args.max_batch_size})"

        prompt_tokens = [tokenizer.apply_chat_template([{"role": "user", "content": prompt}], add_generation_prompt=True) for prompt in prompts]
        completion_tokens = generate(
            model,
            prompt_tokens,
            max_new_tokens,
            tokenizer.eos_token_id,
            temperature,
            device,
            chunk_size
        )

        completions = tokenizer.batch_decode(completion_tokens, skip_special_tokens=True)
        for prompt, completion in zip(prompts, completions):
            print("Prompt:", prompt)
            print("Completion:", completion)
            print()


if __name__ == "__main__":
    """
    Command-line interface for optimized text generation on MacBooks.

    Arguments:
        --ckpt-path (str): Path to the model checkpoint directory.
        --config (str): Path to the model configuration file.
        --input-file (str, optional): File containing prompts for batch processing.
        --interactive (bool, optional): Enable interactive mode for generating text.
        --max-new-tokens (int, optional): Maximum number of new tokens to generate. Defaults to 200.
        --temperature (float, optional): Temperature for sampling. Defaults to 0.2.
        --force-cpu (bool, optional): Force CPU usage even if GPU is available.
        --no-quantize (bool, optional): Disable quantization (higher quality but slower).
        --chunk-size (int, optional): Size of processing chunks for memory efficiency. Defaults to 512.
        --reduced-model (bool, optional): Load a smaller version of the model if available.
    """
    parser = ArgumentParser()
    parser.add_argument("--ckpt-path", type=str, required=True)
    parser.add_argument("--config", type=str, required=True)
    parser.add_argument("--input-file", type=str, default="")
    parser.add_argument("--interactive", action="store_true")
    parser.add_argument("--max-new-tokens", type=int, default=200)
    parser.add_argument("--temperature", type=float, default=0.2)
    parser.add_argument("--force-cpu", action="store_true")
    parser.add_argument("--no-quantize", action="store_true")
    parser.add_argument("--chunk-size", type=int, default=512)
    parser.add_argument("--reduced-model", action="store_true")

    args = parser.parse_args()
    assert args.input_file or args.interactive, "Either input-file or interactive mode must be specified"

    main(
        args.ckpt_path,
        args.config,
        args.input_file,
        args.interactive,
        args.max_new_tokens,
        args.temperature,
        args.force_cpu,
        not args.no_quantize,
        args.chunk_size,
        args.reduced_model
    )
6  inference/requirements_macbook.txt  Normal file
@@ -0,0 +1,6 @@
torch>=2.1.0
transformers==4.46.3
safetensors==0.4.5
tqdm
numpy
sentencepiece