Mirror of https://github.com/deepseek-ai/DeepSeek-V3.git (synced 2025-04-19 18:18:57 -04:00)

Commit 7c6911dce9: Merge 6bb22e0c15 into 88d6547df2
@@ -2,13 +2,19 @@ import os
 import shutil
 from argparse import ArgumentParser
 from glob import glob
-from tqdm import tqdm, trange
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple, Union

 import torch
 from safetensors.torch import safe_open, save_file
+from tqdm import tqdm, trange


+# Constants and type definitions
+TensorMapping = Dict[str, Tuple[str, Optional[int]]]
+StateDict = Dict[str, torch.Tensor]

-mapping = {
+# Define mapping as a constant at module level
+TENSOR_MAPPING: TensorMapping = {
     "embed_tokens": ("embed", 0),
     "input_layernorm": ("attn_norm", None),
     "post_attention_layernorm": ("ffn_norm", None),
@@ -29,68 +35,144 @@ mapping = {
     "scale": ("scale", None),
 }

-def main(hf_ckpt_path, save_path, n_experts, mp):
+def process_tensor_name(name: str) -> str:
     """
-    Converts and saves model checkpoint files into a specified format.
+    Process tensor name by removing prefixes and replacing common patterns.

     Args:
-        hf_ckpt_path (str): Path to the directory containing the input checkpoint files.
-        save_path (str): Path to the directory where the converted checkpoint files will be saved.
-        n_experts (int): Total number of experts in the model.
-        mp (int): Model parallelism factor.
+        name: Original tensor name

     Returns:
-        None
+        Processed tensor name
     """
+    if name.startswith("model."):
+        name = name[len("model."):]
+
+    replacements = {
+        "self_attn": "attn",
+        "mlp": "ffn",
+        "weight_scale_inv": "scale",
+        "e_score_correction_bias": "bias"
+    }
+
+    for old, new in replacements.items():
+        name = name.replace(old, new)
+
+    return name
+
+
+def shard_tensor(param: torch.Tensor, mp_idx: int, mp_count: int, dim: int) -> torch.Tensor:
+    """
+    Shard a tensor along specified dimension for model parallelism.
+
+    Args:
+        param: Input tensor to shard
+        mp_idx: Index of current model parallel rank
+        mp_count: Total number of model parallel ranks
+        dim: Dimension along which to shard
+
+    Returns:
+        Sharded tensor slice
+    """
+    if param.size(dim) % mp_count != 0:
+        raise ValueError(f"Tensor size {param.size(dim)} not divisible by mp_count {mp_count}")
+
+    shard_size = param.size(dim) // mp_count
+    return param.narrow(dim, mp_idx * shard_size, shard_size).contiguous()
+
+
+def convert_checkpoint(
+    hf_ckpt_path: Union[str, Path],
+    save_path: Union[str, Path],
+    n_experts: int,
+    mp: int
+) -> None:
+    """
+    Convert and save model checkpoint files into a specified format.
+
+    Args:
+        hf_ckpt_path: Path to input checkpoint directory
+        save_path: Path to output directory for converted checkpoints
+        n_experts: Total number of experts in model
+        mp: Model parallelism factor
+
+    Raises:
+        ValueError: If n_experts is not divisible by mp
+        FileNotFoundError: If input path doesn't exist or contain safetensors
+    """
+    if n_experts % mp != 0:
+        raise ValueError(f"Number of experts ({n_experts}) must be divisible by model parallel size ({mp})")
+
+    hf_ckpt_path = Path(hf_ckpt_path)
+    save_path = Path(save_path)
+
+    if not hf_ckpt_path.exists():
+        raise FileNotFoundError(f"Checkpoint path {hf_ckpt_path} does not exist")
+
+    safetensor_files = list(hf_ckpt_path.glob("*.safetensors"))
+    if not safetensor_files:
+        raise FileNotFoundError(f"No safetensor files found in {hf_ckpt_path}")
+
     torch.set_num_threads(8)
     n_local_experts = n_experts // mp
-    state_dicts = [{} for _ in range(mp)]
+    state_dicts: List[StateDict] = [{} for _ in range(mp)]

-    for file_path in tqdm(glob(os.path.join(hf_ckpt_path, "*.safetensors"))):
+    # Process each checkpoint file
+    for file_path in tqdm(safetensor_files, desc="Processing checkpoint files"):
         with safe_open(file_path, framework="pt", device="cpu") as f:
             for name in f.keys():
                 if "model.layers.61" in name:
                     continue
                 param: torch.Tensor = f.get_tensor(name)
-                if name.startswith("model."):
-                    name = name[len("model."):]
-                name = name.replace("self_attn", "attn")
-                name = name.replace("mlp", "ffn")
-                name = name.replace("weight_scale_inv", "scale")
-                name = name.replace("e_score_correction_bias", "bias")
+                name = process_tensor_name(name)
                 key = name.split(".")[-2]
-                assert key in mapping, f"Key {key} not found in mapping"
-                new_key, dim = mapping[key]
+                if key not in TENSOR_MAPPING:
+                    raise ValueError(f"Unknown tensor key: {key}")
+                new_key, dim = TENSOR_MAPPING[key]
                 name = name.replace(key, new_key)
+
+                # Distribute tensors across model parallel ranks
                 for i in range(mp):
                     new_param = param
                     if "experts" in name and "shared_experts" not in name:
                         idx = int(name.split(".")[-3])
-                        if idx < i * n_local_experts or idx >= (i + 1) * n_local_experts:
+                        if not (i * n_local_experts <= idx < (i + 1) * n_local_experts):
                             continue
                     elif dim is not None:
-                        assert param.size(dim) % mp == 0, f"Dimension {dim} must be divisible by {mp}"
-                        shard_size = param.size(dim) // mp
-                        new_param = param.narrow(dim, i * shard_size, shard_size).contiguous()
+                        new_param = shard_tensor(param, i, mp, dim)
                     state_dicts[i][name] = new_param

-    os.makedirs(save_path, exist_ok=True)
+    # Save converted checkpoints
+    save_path.mkdir(parents=True, exist_ok=True)

-    for i in trange(mp):
-        save_file(state_dicts[i], os.path.join(save_path, f"model{i}-mp{mp}.safetensors"))
+    for i in trange(mp, desc="Saving converted checkpoints"):
+        output_file = save_path / f"model{i}-mp{mp}.safetensors"
+        save_file(state_dicts[i], str(output_file))

-    for file_path in glob(os.path.join(hf_ckpt_path, "*token*")):
-        new_file_path = os.path.join(save_path, os.path.basename(file_path))
-        shutil.copyfile(file_path, new_file_path)
+    # Copy tokenizer files
+    for file_path in hf_ckpt_path.glob("*token*"):
+        shutil.copyfile(file_path, save_path / file_path.name)
+
+
+def main():
+    """Parse command line arguments and run the conversion."""
+    parser = ArgumentParser(description="Convert HuggingFace checkpoints to custom format")
+    parser.add_argument("--hf-ckpt-path", type=str, required=True,
+                        help="Path to input HuggingFace checkpoint directory")
+    parser.add_argument("--save-path", type=str, required=True,
+                        help="Path to output directory for converted checkpoints")
+    parser.add_argument("--n-experts", type=int, required=True,
+                        help="Total number of experts in the model")
+    parser.add_argument("--model-parallel", type=int, required=True,
+                        help="Model parallelism factor")
+
+    args = parser.parse_args()
+
+    try:
+        convert_checkpoint(args.hf_ckpt_path, args.save_path, args.n_experts, args.model_parallel)
+    except Exception as e:
+        print(f"Error during conversion: {str(e)}")
+        raise


 if __name__ == "__main__":
-    parser = ArgumentParser()
-    parser.add_argument("--hf-ckpt-path", type=str, required=True)
-    parser.add_argument("--save-path", type=str, required=True)
-    parser.add_argument("--n-experts", type=int, required=True)
-    parser.add_argument("--model-parallel", type=int, required=True)
-    args = parser.parse_args()
-    assert args.n_experts % args.model_parallel == 0, "Number of experts must be divisible by model parallelism"
-    main(args.hf_ckpt_path, args.save_path, args.n_experts, args.model_parallel)
+    main()
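The hunks above restructure the checkpoint-conversion script around process_tensor_name, shard_tensor, and convert_checkpoint. The expert-distribution rule is the part worth spelling out: each model-parallel rank keeps only a contiguous block of n_experts // mp experts, while non-expert tensors are sharded along their mapped dimension. A minimal sketch of that ownership rule follows, with illustrative values (256 experts over mp=8 is an assumption, not something fixed by the script):

# Sketch of the expert-to-rank assignment used in convert_checkpoint
# (values below are illustrative assumptions).
n_experts = 256
mp = 8
n_local_experts = n_experts // mp  # 32 experts per rank

def owning_rank(expert_idx: int) -> int:
    # Mirrors the check `i * n_local_experts <= idx < (i + 1) * n_local_experts`.
    return expert_idx // n_local_experts

assert owning_rank(0) == 0 and owning_rank(31) == 0
assert owning_rank(32) == 1
assert owning_rank(255) == 7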
@@ -2,6 +2,7 @@ import os
 import json
 from argparse import ArgumentParser
 from glob import glob
+from typing import Dict, Any
 from tqdm import tqdm

 import torch
@@ -9,98 +10,137 @@ from safetensors.torch import load_file, save_file

 from kernel import weight_dequant

-def main(fp8_path, bf16_path):
-    """
-    Converts FP8 weights to BF16 and saves the converted weights.
-
-    This function reads FP8 weights from the specified directory, converts them to BF16,
-    and saves the converted weights to another specified directory. It also updates the
-    model index file to reflect the changes.
-
-    Args:
-        fp8_path (str): The path to the directory containing the FP8 weights and model index file.
-        bf16_path (str): The path to the directory where the converted BF16 weights will be saved.
-
-    Raises:
-        KeyError: If a required scale_inv tensor is missing for a weight.
-
-    Notes:
-    - The function assumes that the FP8 weights are stored in safetensor files.
-    - The function caches loaded safetensor files to optimize memory usage.
-    - The function updates the model index file to remove references to scale_inv tensors.
-    """
-    torch.set_default_dtype(torch.bfloat16)
-    os.makedirs(bf16_path, exist_ok=True)
-    model_index_file = os.path.join(fp8_path, "model.safetensors.index.json")
-    with open(model_index_file, "r") as f:
-        model_index = json.load(f)
-    weight_map = model_index["weight_map"]
-
-    # Cache for loaded safetensor files
-    loaded_files = {}
-    fp8_weight_names = []
-
-    # Helper function to get tensor from the correct file
-    def get_tensor(tensor_name):
-        """
-        Retrieves a tensor from the cached safetensor files or loads it from disk if not cached.
-
-        Args:
-            tensor_name (str): The name of the tensor to retrieve.
-
-        Returns:
-            torch.Tensor: The retrieved tensor.
-
-        Raises:
-            KeyError: If the tensor does not exist in the safetensor file.
-        """
-        file_name = weight_map[tensor_name]
-        if file_name not in loaded_files:
-            file_path = os.path.join(fp8_path, file_name)
-            loaded_files[file_name] = load_file(file_path, device="cuda")
-        return loaded_files[file_name][tensor_name]
-
-    safetensor_files = list(glob(os.path.join(fp8_path, "*.safetensors")))
-    safetensor_files.sort()
-    for safetensor_file in tqdm(safetensor_files):
-        file_name = os.path.basename(safetensor_file)
-        current_state_dict = load_file(safetensor_file, device="cuda")
-        loaded_files[file_name] = current_state_dict
-
-        new_state_dict = {}
-        for weight_name, weight in current_state_dict.items():
-            if weight_name.endswith("_scale_inv"):
-                continue
-            elif weight.element_size() == 1:  # FP8 weight
-                scale_inv_name = f"{weight_name}_scale_inv"
-                try:
-                    # Get scale_inv from the correct file
-                    scale_inv = get_tensor(scale_inv_name)
-                    fp8_weight_names.append(weight_name)
-                    new_state_dict[weight_name] = weight_dequant(weight, scale_inv)
-                except KeyError:
-                    print(f"Warning: Missing scale_inv tensor for {weight_name}, skipping conversion")
-                    new_state_dict[weight_name] = weight
-            else:
-                new_state_dict[weight_name] = weight
-
-        new_safetensor_file = os.path.join(bf16_path, file_name)
-        save_file(new_state_dict, new_safetensor_file)
-
-        # Memory management: keep only the 2 most recently used files
-        if len(loaded_files) > 2:
-            oldest_file = next(iter(loaded_files))
-            del loaded_files[oldest_file]
-            torch.cuda.empty_cache()
-
-    # Update model index
-    new_model_index_file = os.path.join(bf16_path, "model.safetensors.index.json")
-    for weight_name in fp8_weight_names:
-        scale_inv_name = f"{weight_name}_scale_inv"
-        if scale_inv_name in weight_map:
-            weight_map.pop(scale_inv_name)
-    with open(new_model_index_file, "w") as f:
-        json.dump({"metadata": {}, "weight_map": weight_map}, f, indent=2)
+class WeightConverter:
+    def __init__(self, fp8_path: str, bf16_path: str):
+        """
+        Initialize the weight converter with input and output paths.
+
+        Args:
+            fp8_path (str): Path to the directory containing FP8 weights
+            bf16_path (str): Path to save the converted BF16 weights
+        """
+        self.fp8_path = fp8_path
+        self.bf16_path = bf16_path
+        self.loaded_files: Dict[str, Dict[str, torch.Tensor]] = {}
+        self.fp8_weight_names: list = []
+        self.weight_map: Dict[str, str] = self._load_model_index()
+
+    def _load_model_index(self) -> Dict[str, str]:
+        """
+        Load the model index file.
+
+        Returns:
+            Dict[str, str]: Weight mapping from the index file
+        """
+        model_index_file = os.path.join(self.fp8_path, "model.safetensors.index.json")
+        with open(model_index_file, "r") as f:
+            return json.load(f)["weight_map"]
+
+    def _get_tensor(self, tensor_name: str) -> torch.Tensor:
+        """
+        Get a tensor from cache or load it from disk.
+
+        Args:
+            tensor_name (str): Name of the tensor to retrieve
+
+        Returns:
+            torch.Tensor: The requested tensor
+
+        Raises:
+            KeyError: If tensor doesn't exist in the safetensor file
+        """
+        file_name = self.weight_map[tensor_name]
+        if file_name not in self.loaded_files:
+            file_path = os.path.join(self.fp8_path, file_name)
+            self.loaded_files[file_name] = load_file(file_path, device="cuda")
+        return self.loaded_files[file_name][tensor_name]
+
+    def _manage_memory(self):
+        """
+        Keep only the 2 most recently used files in memory.
+        """
+        if len(self.loaded_files) > 2:
+            oldest_file = next(iter(self.loaded_files))
+            del self.loaded_files[oldest_file]
+            torch.cuda.empty_cache()
+
+    def _process_weight(self, weight_name: str, weight: torch.Tensor) -> torch.Tensor:
+        """
+        Process a single weight tensor.
+
+        Args:
+            weight_name (str): Name of the weight tensor
+            weight (torch.Tensor): The weight tensor to process
+
+        Returns:
+            torch.Tensor: Processed weight tensor
+        """
+        if weight_name.endswith("_scale_inv"):
+            return None
+
+        if weight.element_size() == 1:  # FP8 weight
+            scale_inv_name = f"{weight_name}_scale_inv"
+            try:
+                scale_inv = self._get_tensor(scale_inv_name)
+                self.fp8_weight_names.append(weight_name)
+                return weight_dequant(weight, scale_inv)
+            except KeyError:
+                print(f"Warning: Missing scale_inv tensor for {weight_name}, skipping conversion")
+                return weight
+        return weight
+
+    def _save_model_index(self):
+        """
+        Save the updated model index file.
+        """
+        new_model_index_file = os.path.join(self.bf16_path, "model.safetensors.index.json")
+        for weight_name in self.fp8_weight_names:
+            scale_inv_name = f"{weight_name}_scale_inv"
+            if scale_inv_name in self.weight_map:
+                self.weight_map.pop(scale_inv_name)
+
+        with open(new_model_index_file, "w") as f:
+            json.dump({"metadata": {}, "weight_map": self.weight_map}, f, indent=2)
+
+    def convert(self):
+        """
+        Convert FP8 weights to BF16 format.
+        """
+        torch.set_default_dtype(torch.bfloat16)
+        os.makedirs(self.bf16_path, exist_ok=True)
+
+        safetensor_files = sorted(glob(os.path.join(self.fp8_path, "*.safetensors")))
+
+        for safetensor_file in tqdm(safetensor_files):
+            file_name = os.path.basename(safetensor_file)
+            current_state_dict = load_file(safetensor_file, device="cuda")
+            self.loaded_files[file_name] = current_state_dict
+
+            new_state_dict = {}
+            for weight_name, weight in current_state_dict.items():
+                processed_weight = self._process_weight(weight_name, weight)
+                if processed_weight is not None:
+                    new_state_dict[weight_name] = processed_weight
+
+            new_safetensor_file = os.path.join(self.bf16_path, file_name)
+            save_file(new_state_dict, new_safetensor_file)
+
+            self._manage_memory()
+
+        self._save_model_index()
+
+
+def main(fp8_path: str, bf16_path: str):
+    """
+    Main function to convert FP8 weights to BF16.
+
+    Args:
+        fp8_path (str): Input directory containing FP8 weights
+        bf16_path (str): Output directory for BF16 weights
+    """
+    converter = WeightConverter(fp8_path, bf16_path)
+    converter.convert()


 if __name__ == "__main__":
@@ -109,4 +149,3 @@ if __name__ == "__main__":
     parser.add_argument("--output-bf16-hf-path", type=str, required=True)
     args = parser.parse_args()
     main(args.input_fp8_hf_path, args.output_bf16_hf_path)
-
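For orientation, here is a sketch of how the refactored FP8-to-BF16 converter would be driven programmatically; the module name in the import is an assumption, and the CLI path via --input-fp8-hf-path / --output-bf16-hf-path shown above remains the documented entry point:

# Hypothetical programmatic use of the refactored converter; module name assumed.
from fp8_cast_bf16 import WeightConverter

converter = WeightConverter(
    fp8_path="/path/to/fp8/checkpoint",  # directory holding *.safetensors and the index file
    bf16_path="/path/to/bf16/output",    # created if it does not exist
)
converter.convert()  # dequantizes FP8 weights to BF16 and rewrites model.safetensors.index.json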
@@ -1,7 +1,8 @@
 import os
 import json
 from argparse import ArgumentParser
-from typing import List
+from typing import List, Optional, Dict, Any, Tuple
+from dataclasses import dataclass

 import torch
 import torch.distributed as dist
@@ -11,13 +12,22 @@ from safetensors.torch import load_model
 from model import Transformer, ModelArgs


-def sample(logits, temperature: float = 1.0):
-    """
-    Samples a token from the logits using temperature scaling.
-
-    Args:
-        logits (torch.Tensor): The logits tensor for token predictions.
-        temperature (float, optional): Temperature for scaling logits. Defaults to 1.0.
-
-    Returns:
-        torch.Tensor: The sampled token.
+@dataclass
+class GenerationConfig:
+    max_new_tokens: int
+    temperature: float
+    eos_id: int
+
+
+class TokenSampler:
+    @staticmethod
+    def sample(logits: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
+        """
+        Samples a token from the logits using temperature scaling.
+
+        Args:
+            logits (torch.Tensor): The logits tensor for token predictions.
+            temperature (float): Temperature for scaling logits.
+
+        Returns:
+            torch.Tensor: The sampled token.
@@ -27,57 +37,235 @@ def sample(logits, temperature: float = 1.0):
     return probs.div_(torch.empty_like(probs).exponential_(1)).argmax(dim=-1)


-@torch.inference_mode()
-def generate(
-    model: Transformer,
-    prompt_tokens: List[List[int]],
-    max_new_tokens: int,
-    eos_id: int,
-    temperature: float = 1.0
-) -> List[List[int]]:
-    """
-    Generates new tokens based on the given prompt tokens using the specified model.
-
-    Args:
-        model (Transformer): The transformer model used for token generation.
-        prompt_tokens (List[List[int]]): A list of lists containing the prompt tokens for each sequence.
-        max_new_tokens (int): The maximum number of new tokens to generate.
-        eos_id (int): The end-of-sequence token ID.
-        temperature (float, optional): The temperature value for sampling. Defaults to 1.0.
-
-    Returns:
-        List[List[int]]: A list of lists containing the generated tokens for each sequence.
-    """
-    prompt_lens = [len(t) for t in prompt_tokens]
-    assert max(prompt_lens) <= model.max_seq_len, f"Prompt length exceeds model maximum sequence length (max_seq_len={model.max_seq_len})"
-    total_len = min(model.max_seq_len, max_new_tokens + max(prompt_lens))
-    tokens = torch.full((len(prompt_tokens), total_len), -1, dtype=torch.long, device="cuda")
-    for i, t in enumerate(prompt_tokens):
-        tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
-    prev_pos = 0
-    finished = torch.tensor([False] * len(prompt_tokens), device="cuda")
-    prompt_mask = tokens != -1
-    for cur_pos in range(min(prompt_lens), total_len):
-        logits = model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
-        if temperature > 0:
-            next_token = sample(logits, temperature)
-        else:
-            next_token = logits.argmax(dim=-1)
-        next_token = torch.where(prompt_mask[:, cur_pos], tokens[:, cur_pos], next_token)
-        tokens[:, cur_pos] = next_token
-        finished |= torch.logical_and(~prompt_mask[:, cur_pos], next_token == eos_id)
-        prev_pos = cur_pos
-        if finished.all():
-            break
-    completion_tokens = []
-    for i, toks in enumerate(tokens.tolist()):
-        toks = toks[prompt_lens[i]:prompt_lens[i]+max_new_tokens]
-        if eos_id in toks:
-            toks = toks[:toks.index(eos_id)]
-        completion_tokens.append(toks)
-    return completion_tokens
+class TextGenerator:
+    def __init__(self, model: Transformer, tokenizer: Any):
+        self.model = model
+        self.tokenizer = tokenizer
+
+    @torch.inference_mode()
+    def generate(
+        self,
+        prompt_tokens: List[List[int]],
+        config: GenerationConfig
+    ) -> List[List[int]]:
+        """
+        Generates new tokens based on the given prompt tokens.
+
+        Args:
+            prompt_tokens: A list of lists containing the prompt tokens for each sequence.
+            config: Generation configuration parameters.
+
+        Returns:
+            List[List[int]]: Generated tokens for each sequence.
+        """
+        prompt_lens = [len(t) for t in prompt_tokens]
+        if max(prompt_lens) > self.model.max_seq_len:
+            raise ValueError(f"Prompt length exceeds model maximum sequence length (max_seq_len={self.model.max_seq_len})")
+
+        total_len = min(self.model.max_seq_len, config.max_new_tokens + max(prompt_lens))
+        tokens = self._initialize_tokens(prompt_tokens, total_len)
+
+        completion_tokens = self._generate_tokens(
+            tokens, prompt_lens, total_len, config
+        )
+        return completion_tokens
+
+    def _initialize_tokens(
+        self, prompt_tokens: List[List[int]], total_len: int
+    ) -> torch.Tensor:
+        tokens = torch.full(
+            (len(prompt_tokens), total_len), -1, dtype=torch.long, device="cuda"
+        )
+        for i, t in enumerate(prompt_tokens):
+            tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
+        return tokens
+
+    def _generate_tokens(
+        self,
+        tokens: torch.Tensor,
+        prompt_lens: List[int],
+        total_len: int,
+        config: GenerationConfig
+    ) -> List[List[int]]:
+        prev_pos = 0
+        finished = torch.tensor([False] * len(prompt_lens), device="cuda")
+        prompt_mask = tokens != -1
+
+        for cur_pos in range(min(prompt_lens), total_len):
+            logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
+            next_token = self._get_next_token(logits, config.temperature)
+            next_token = torch.where(
+                prompt_mask[:, cur_pos], tokens[:, cur_pos], next_token
+            )
+            tokens[:, cur_pos] = next_token
+            finished |= torch.logical_and(
+                ~prompt_mask[:, cur_pos], next_token == config.eos_id
+            )
+            prev_pos = cur_pos
+            if finished.all():
+                break
+
+        return self._process_completion_tokens(
+            tokens, prompt_lens, config.max_new_tokens, config.eos_id
+        )
+
+    def _get_next_token(
+        self, logits: torch.Tensor, temperature: float
+    ) -> torch.Tensor:
+        if temperature > 0:
+            return TokenSampler.sample(logits, temperature)
+        return logits.argmax(dim=-1)
+
+    def _process_completion_tokens(
+        self,
+        tokens: torch.Tensor,
+        prompt_lens: List[int],
+        max_new_tokens: int,
+        eos_id: int
+    ) -> List[List[int]]:
+        completion_tokens = []
+        for i, toks in enumerate(tokens.tolist()):
+            toks = toks[prompt_lens[i]:prompt_lens[i] + max_new_tokens]
+            if eos_id in toks:
+                toks = toks[:toks.index(eos_id)]
+            completion_tokens.append(toks)
+        return completion_tokens
+
+
+class DistributedEnvironment:
+    def __init__(self):
+        self.world_size = int(os.getenv("WORLD_SIZE", "1"))
+        self.rank = int(os.getenv("RANK", "0"))
+        self.local_rank = int(os.getenv("LOCAL_RANK", "0"))
+
+    def setup(self):
+        if self.world_size > 1:
+            dist.init_process_group("nccl")
+        if self.rank != 0:
+            global print
+            print = lambda *_, **__: None
+        torch.cuda.set_device(self.local_rank)
+
+    def cleanup(self):
+        if self.world_size > 1:
+            dist.destroy_process_group()
+
+    def broadcast_prompt(self, prompt: Optional[str] = None) -> str:
+        if self.world_size == 1:
+            return input(">>> ")
+        elif self.rank == 0:
+            prompt = input(">>> ")
+            objects = [prompt]
+            dist.broadcast_object_list(objects, 0)
+            return prompt
+        else:
+            objects = [None]
+            dist.broadcast_object_list(objects, 0)
+            return objects[0]
+
+
+class ChatSession:
+    def __init__(
+        self,
+        generator: TextGenerator,
+        config: GenerationConfig,
+        dist_env: DistributedEnvironment
+    ):
+        self.generator = generator
+        self.config = config
+        self.dist_env = dist_env
+        self.messages = []
+
+    def run_interactive(self):
+        while True:
+            prompt = self.dist_env.broadcast_prompt()
+            if prompt == "/exit":
+                break
+            elif prompt == "/clear":
+                self.messages.clear()
+                continue
+
+            completion = self._process_message(prompt)
+            print(completion)
+            self.messages.append({"role": "assistant", "content": completion})
+
+    def run_batch(self, input_file: str):
+        with open(input_file) as f:
+            prompts = [line.strip() for line in f.readlines()]
+
+        if len(prompts) > self.generator.model.args.max_batch_size:
+            raise ValueError(f"Number of prompts exceeds maximum batch size ({self.generator.model.args.max_batch_size})")
+
+        completions = self._process_batch(prompts)
+        for prompt, completion in zip(prompts, completions):
+            print("Prompt:", prompt)
+            print("Completion:", completion)
+            print()
+
+    def _process_message(self, prompt: str) -> str:
+        self.messages.append({"role": "user", "content": prompt})
+        prompt_tokens = self.generator.tokenizer.apply_chat_template(
+            self.messages, add_generation_prompt=True
+        )
+        completion_tokens = self.generator.generate(
+            [prompt_tokens], self.config
+        )
+        return self.generator.tokenizer.decode(
+            completion_tokens[0], skip_special_tokens=True
+        )
+
+    def _process_batch(self, prompts: List[str]) -> List[str]:
+        prompt_tokens = [
+            self.generator.tokenizer.apply_chat_template(
+                [{"role": "user", "content": prompt}],
+                add_generation_prompt=True
+            )
+            for prompt in prompts
+        ]
+        completion_tokens = self.generator.generate(
+            prompt_tokens, self.config
+        )
+        return self.generator.tokenizer.batch_decode(
+            completion_tokens, skip_special_tokens=True
+        )
+
+
+def initialize_model(
+    ckpt_path: str, config_path: str, dist_env: DistributedEnvironment
+) -> Tuple[Transformer, Any]:
+    """Initialize the model and tokenizer."""
+    torch.set_default_dtype(torch.bfloat16)
+    torch.set_num_threads(8)
+    torch.manual_seed(965)
+
+    with open(config_path) as f:
+        args = ModelArgs(**json.load(f))
+    print(args)
+
+    with torch.device("cuda"):
+        model = Transformer(args)
+    tokenizer = AutoTokenizer.from_pretrained(ckpt_path)
+
+    # Warmup
+    tokenizer.decode(
+        TextGenerator(model, tokenizer).generate(
+            [tokenizer.encode("DeepSeek")],
+            GenerationConfig(max_new_tokens=2, temperature=1.0, eos_id=-1)
+        )[0]
+    )
+
+    load_model(
+        model,
+        os.path.join(
+            ckpt_path,
+            f"model{dist_env.rank}-mp{dist_env.world_size}.safetensors"
+        )
+    )
+    return model, tokenizer


 def main(
     ckpt_path: str,
     config: str,
@@ -86,94 +274,29 @@ def main(
     max_new_tokens: int = 100,
     temperature: float = 1.0,
 ) -> None:
-    """
-    Main function to load the model and perform interactive or batch text generation.
-
-    Args:
-        ckpt_path (str): Path to the model checkpoint directory.
-        config (str): Path to the model configuration file.
-        input_file (str, optional): Path to a file containing input prompts. Defaults to "".
-        interactive (bool, optional): Whether to run in interactive mode. Defaults to True.
-        max_new_tokens (int, optional): Maximum number of new tokens to generate. Defaults to 100.
-        temperature (float, optional): Temperature for sampling. Defaults to 1.0.
-    """
-    world_size = int(os.getenv("WORLD_SIZE", "1"))
-    rank = int(os.getenv("RANK", "0"))
-    local_rank = int(os.getenv("LOCAL_RANK", "0"))
-    if world_size > 1:
-        dist.init_process_group("nccl")
-    global print
-    if rank != 0:
-        print = lambda *_, **__: None
-    torch.cuda.set_device(local_rank)
-    torch.set_default_dtype(torch.bfloat16)
-    torch.set_num_threads(8)
-    torch.manual_seed(965)
-    with open(config) as f:
-        args = ModelArgs(**json.load(f))
-    print(args)
-    with torch.device("cuda"):
-        model = Transformer(args)
-    tokenizer = AutoTokenizer.from_pretrained(ckpt_path)
-    tokenizer.decode(generate(model, [tokenizer.encode("DeepSeek")], 2, -1, 1.)[0])
-    load_model(model, os.path.join(ckpt_path, f"model{rank}-mp{world_size}.safetensors"))
+    dist_env = DistributedEnvironment()
+    dist_env.setup()
+
+    model, tokenizer = initialize_model(ckpt_path, config, dist_env)
+    generator = TextGenerator(model, tokenizer)
+    gen_config = GenerationConfig(
+        max_new_tokens=max_new_tokens,
+        temperature=temperature,
+        eos_id=tokenizer.eos_token_id
+    )
+
+    session = ChatSession(generator, gen_config, dist_env)

     if interactive:
-        messages = []
-        while True:
-            if world_size == 1:
-                prompt = input(">>> ")
-            elif rank == 0:
-                prompt = input(">>> ")
-                objects = [prompt]
-                dist.broadcast_object_list(objects, 0)
-            else:
-                objects = [None]
-                dist.broadcast_object_list(objects, 0)
-                prompt = objects[0]
-            if prompt == "/exit":
-                break
-            elif prompt == "/clear":
-                messages.clear()
-                continue
-            messages.append({"role": "user", "content": prompt})
-            prompt_tokens = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
-            completion_tokens = generate(model, [prompt_tokens], max_new_tokens, tokenizer.eos_token_id, temperature)
-            completion = tokenizer.decode(completion_tokens[0], skip_special_tokens=True)
-            print(completion)
-            messages.append({"role": "assistant", "content": completion})
+        session.run_interactive()
     else:
-        with open(input_file) as f:
-            prompts = [line.strip() for line in f.readlines()]
-        assert len(prompts) <= args.max_batch_size, f"Number of prompts exceeds maximum batch size ({args.max_batch_size})"
-        prompt_tokens = [tokenizer.apply_chat_template([{"role": "user", "content": prompt}], add_generation_prompt=True) for prompt in prompts]
-        completion_tokens = generate(model, prompt_tokens, max_new_tokens, tokenizer.eos_token_id, temperature)
-        completions = tokenizer.batch_decode(completion_tokens, skip_special_tokens=True)
-        for prompt, completion in zip(prompts, completions):
-            print("Prompt:", prompt)
-            print("Completion:", completion)
-            print()
+        session.run_batch(input_file)

-    if world_size > 1:
-        dist.destroy_process_group()
+    dist_env.cleanup()


 if __name__ == "__main__":
-    """
-    Command-line interface for distributed text generation.
-
-    Arguments:
-        --ckpt-path (str): Path to the model checkpoint directory.
-        --config (str): Path to the model configuration file.
-        --input-file (str, optional): File containing prompts for batch processing.
-        --interactive (bool, optional): Enable interactive mode for generating text.
-        --max-new-tokens (int, optional): Maximum number of new tokens to generate. Defaults to 200.
-        --temperature (float, optional): Temperature for sampling. Defaults to 0.2.
-
-    Raises:
-        AssertionError: If neither input-file nor interactive mode is specified.
-    """
-    parser = ArgumentParser()
+    parser = ArgumentParser(description="Distributed text generation system")
     parser.add_argument("--ckpt-path", type=str, required=True)
     parser.add_argument("--config", type=str, required=True)
     parser.add_argument("--input-file", type=str, default="")
@@ -181,5 +304,15 @@ if __name__ == "__main__":
     parser.add_argument("--max-new-tokens", type=int, default=200)
     parser.add_argument("--temperature", type=float, default=0.2)
     args = parser.parse_args()
-    assert args.input_file or args.interactive, "Either input-file or interactive mode must be specified"
-    main(args.ckpt_path, args.config, args.input_file, args.interactive, args.max_new_tokens, args.temperature)
+    if not args.input_file and not args.interactive:
+        raise ValueError("Either input-file or interactive mode must be specified")
+
+    main(
+        args.ckpt_path,
+        args.config,
+        args.input_file,
+        args.interactive,
+        args.max_new_tokens,
+        args.temperature
+    )
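A condensed sketch of how the refactored generation classes above fit together in the single-process case; the checkpoint and config paths are placeholders and the temperature value is an arbitrary choice:

# Hypothetical single-GPU wiring of the refactored classes.
dist_env = DistributedEnvironment()  # reads WORLD_SIZE / RANK / LOCAL_RANK
dist_env.setup()

model, tokenizer = initialize_model("/path/to/ckpt", "/path/to/config.json", dist_env)
generator = TextGenerator(model, tokenizer)

gen_config = GenerationConfig(
    max_new_tokens=100,
    temperature=0.7,
    eos_id=tokenizer.eos_token_id,
)
session = ChatSession(generator, gen_config, dist_env)
session.run_interactive()  # or session.run_batch("prompts.txt")

dist_env.cleanup()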
@@ -1,4 +1,5 @@
 from typing import Tuple
+from dataclasses import dataclass

 import torch
 import triton
@@ -6,19 +7,29 @@ import triton.language as tl
 from triton import Config


-@triton.jit
-def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
-    """
-    Quantizes the input tensor `x_ptr` and stores the result in `y_ptr` and the scaling factor in `s_ptr`.
-
-    Args:
-        x_ptr (triton.Pointer): Pointer to the input tensor.
-        y_ptr (triton.Pointer): Pointer to the output tensor where quantized values will be stored.
-        s_ptr (triton.Pointer): Pointer to the output tensor where scaling factors will be stored.
-        BLOCK_SIZE (tl.constexpr): The size of the block to be processed by each program instance.
-
-    Returns:
-        None
-    """
-    pid = tl.program_id(axis=0)
-    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+@dataclass
+class BlockConfig:
+    """Configuration for block sizes in tensor operations."""
+    size: int = 128
+    size_m: int = 64
+    size_n: int = 64
+    size_k: int = 128
+
+
+class QuantizationKernels:
+    """Collection of Triton kernels for quantization operations."""
+
+    @staticmethod
+    @triton.jit
+    def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
+        """
+        Quantizes activation values using block-wise scaling.
+
+        Args:
+            x_ptr: Input tensor pointer
+            y_ptr: Output quantized tensor pointer
+            s_ptr: Output scaling factors pointer
+            BLOCK_SIZE: Size of processing block
+        """
+        pid = tl.program_id(axis=0)
+        offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
@@ -29,44 +40,19 @@ def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
     tl.store(y_ptr + offs, y)
     tl.store(s_ptr + pid, s)

-def act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
-    """
-    Quantizes the input tensor `x` using block-wise quantization.
-
-    Args:
-        x (torch.Tensor): The input tensor to be quantized. Must be contiguous and its last dimension size must be divisible by `block_size`.
-        block_size (int, optional): The size of the blocks to be used for quantization. Default is 128.
-
-    Returns:
-        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
-        - The quantized tensor with dtype `torch.float8_e4m3fn`.
-        - A tensor of scaling factors with dtype `torch.float32`.
-    """
-    assert x.is_contiguous(), 'Input tensor must be contiguous'
-    assert x.size(-1) % block_size == 0, f'Last dimension size must be divisible by block_size (block_size={block_size})'
-    y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
-    s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype=torch.float32)
-    grid = lambda meta: (triton.cdiv(x.numel(), meta['BLOCK_SIZE']), )
-    act_quant_kernel[grid](x, y, s, BLOCK_SIZE=block_size)
-    return y, s
-
-
-@triton.jit
-def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
-    """
-    Dequantizes weights using the provided scaling factors and stores the result.
-
-    Args:
-        x_ptr (tl.pointer): Pointer to the quantized weights.
-        s_ptr (tl.pointer): Pointer to the scaling factors.
-        y_ptr (tl.pointer): Pointer to the output buffer for dequantized weights.
-        M (int): Number of rows in the weight matrix.
-        N (int): Number of columns in the weight matrix.
-        BLOCK_SIZE (tl.constexpr): Size of the block for tiling.
-
-    Returns:
-        None
-    """
-    pid_m = tl.program_id(axis=0)
-    pid_n = tl.program_id(axis=1)
+    @staticmethod
+    @triton.jit
+    def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
+        """
+        Dequantizes weights using block-wise scaling.
+
+        Args:
+            x_ptr: Quantized weights pointer
+            s_ptr: Scaling factors pointer
+            y_ptr: Output dequantized tensor pointer
+            M: Number of rows
+            N: Number of columns
+            BLOCK_SIZE: Size of processing block
+        """
+        pid_m = tl.program_id(axis=0)
+        pid_n = tl.program_id(axis=1)
@@ -81,84 +67,80 @@ def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
     tl.store(y_ptr + offs, y, mask=mask)


-def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor:
-    """
-    Dequantizes the given weight tensor using the provided scale tensor.
-
-    Args:
-        x (torch.Tensor): The quantized weight tensor of shape (M, N).
-        s (torch.Tensor): The scale tensor of shape (M, N).
-        block_size (int, optional): The block size to use for dequantization. Defaults to 128.
-
-    Returns:
-        torch.Tensor: The dequantized weight tensor of the same shape as `x`.
-
-    Raises:
-        AssertionError: If `x` or `s` are not contiguous or if their dimensions are not 2.
-    """
-    assert x.is_contiguous() and s.is_contiguous(), 'Input tensors must be contiguous'
-    assert x.dim() == 2 and s.dim() == 2, 'Input tensors must have 2 dimensions'
-    M, N = x.size()
-    y = torch.empty_like(x, dtype=torch.get_default_dtype())
-    grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE']), triton.cdiv(N, meta['BLOCK_SIZE']))
-    weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size)
-    return y
-
-
-fp8_gemm_configs = [
-    Config({'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': 128}, num_stages=num_stages, num_warps=8)
-    for block_m in [16, 32, 64] for block_n in [32, 64, 128] for num_stages in [3, 4, 5, 6]
-]
-
-@triton.autotune(configs=fp8_gemm_configs, key=['N', 'K'])
-@triton.jit
-def fp8_gemm_kernel(a_ptr, b_ptr, c_ptr,
-                    a_s_ptr, b_s_ptr,
-                    M, N: tl.constexpr, K: tl.constexpr,
-                    BLOCK_SIZE_M: tl.constexpr,
-                    BLOCK_SIZE_N: tl.constexpr,
-                    BLOCK_SIZE_K: tl.constexpr):
-    """
-    Performs a matrix multiplication operation on FP8 matrices with scaling factors.
-
-    Args:
-        a_ptr (tl.tensor): Pointer to the first input matrix A.
-        b_ptr (tl.tensor): Pointer to the second input matrix B.
-        c_ptr (tl.tensor): Pointer to the output matrix C.
-        a_s_ptr (tl.tensor): Pointer to the scaling factors for matrix A.
-        b_s_ptr (tl.tensor): Pointer to the scaling factors for matrix B.
-        M (int): Number of rows in matrix A and C.
-        N (tl.constexpr): Number of columns in matrix B and C.
-        K (tl.constexpr): Number of columns in matrix A and rows in matrix B.
-        BLOCK_SIZE_M (tl.constexpr): Block size for the M dimension.
-        BLOCK_SIZE_N (tl.constexpr): Block size for the N dimension.
-        BLOCK_SIZE_K (tl.constexpr): Block size for the K dimension.
-
-    Returns:
-        None
-    """
+class MatrixMultKernels:
+    """Collection of Triton kernels for matrix multiplication operations."""
+
+    @staticmethod
+    def get_configs():
+        """Generate configurations for FP8 GEMM autotuning."""
+        return [
+            Config({
+                'BLOCK_SIZE_M': block_m,
+                'BLOCK_SIZE_N': block_n,
+                'BLOCK_SIZE_K': 128
+            }, num_stages=num_stages, num_warps=8)
+            for block_m in [16, 32, 64]
+            for block_n in [32, 64, 128]
+            for num_stages in [3, 4, 5, 6]
+        ]
+
+    @staticmethod
+    @triton.autotune(configs=get_configs(), key=['N', 'K'])
+    @triton.jit
+    def fp8_gemm_kernel(
+        a_ptr, b_ptr, c_ptr,
+        a_s_ptr, b_s_ptr,
+        M, N: tl.constexpr, K: tl.constexpr,
+        BLOCK_SIZE_M: tl.constexpr,
+        BLOCK_SIZE_N: tl.constexpr,
+        BLOCK_SIZE_K: tl.constexpr
+    ):
+        """
+        Performs FP8 matrix multiplication with scaling factors.
+
+        Args:
+            a_ptr: First input matrix pointer
+            b_ptr: Second input matrix pointer
+            c_ptr: Output matrix pointer
+            a_s_ptr: First matrix scaling factors pointer
+            b_s_ptr: Second matrix scaling factors pointer
+            M: First matrix rows
+            N: Second matrix columns
+            K: Inner dimension
+            BLOCK_SIZE_M/N/K: Block sizes for tiling
+        """
     pid_m = tl.program_id(axis=0)
     pid_n = tl.program_id(axis=1)
     k = tl.cdiv(K, BLOCK_SIZE_K)
+    # Calculate offsets
     offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
     offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
     offs_k = tl.arange(0, BLOCK_SIZE_K)
+    # Initialize pointers
     a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :]
     b_ptrs = b_ptr + offs_n[None, :] * K + offs_k[:, None]
     a_s_ptrs = a_s_ptr + offs_m * k
     b_s_ptrs = b_s_ptr + (offs_n // BLOCK_SIZE_K) * k

     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+    # Main computation loop
     for i in range(k):
         a = tl.load(a_ptrs, mask=offs_k[None, :] < K - i * BLOCK_SIZE_K, other=0.0)
         b = tl.load(b_ptrs, mask=offs_k[:, None] < K - i * BLOCK_SIZE_K, other=0.0)
         a_s = tl.load(a_s_ptrs)
         b_s = tl.load(b_s_ptrs)
         accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
+        # Update pointers
         a_ptrs += BLOCK_SIZE_K
         b_ptrs += BLOCK_SIZE_K
         a_s_ptrs += 1
         b_s_ptrs += 1
+    # Store results
     c = accumulator.to(c_ptr.dtype.element_ty)
     offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
     offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
@@ -167,25 +149,86 @@ def fp8_gemm_kernel(a_ptr, b_ptr, c_ptr,
     tl.store(c_ptrs, c, mask=mask)


-def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor):
-    """
-    Perform a matrix multiplication using FP8 precision.
-
-    Args:
-        a (torch.Tensor): The first input matrix, must be contiguous.
-        a_s (torch.Tensor): The scaling factor for the first input matrix, must be contiguous.
-        b (torch.Tensor): The second input matrix, must be contiguous.
-        b_s (torch.Tensor): The scaling factor for the second input matrix, must be contiguous.
-
-    Returns:
-        torch.Tensor: The result of the matrix multiplication.
-    """
-    assert a.is_contiguous() and b.is_contiguous(), 'Input tensors must be contiguous'
-    assert a_s.is_contiguous() and b_s.is_contiguous(), 'Scaling factor tensors must be contiguous'
-    K = a.size(-1)
-    M = a.numel() // K
-    N = b.size(0)
-    c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype())
-    grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']), triton.cdiv(N, META['BLOCK_SIZE_N']))
-    fp8_gemm_kernel[grid](a, b, c, a_s, b_s, M, N, K)
-    return c
+class TensorOps:
+    """High-level interface for tensor operations."""
+
+    @staticmethod
+    def act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Quantize activations using block-wise scaling.
+
+        Args:
+            x: Input tensor
+            block_size: Block size for quantization
+
+        Returns:
+            Tuple of quantized tensor and scaling factors
+        """
+        assert x.is_contiguous()
+        assert x.size(-1) % block_size == 0
+
+        y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
+        s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype=torch.float32)
+
+        grid = lambda meta: (triton.cdiv(x.numel(), meta['BLOCK_SIZE']),)
+        QuantizationKernels.act_quant_kernel[grid](x, y, s, BLOCK_SIZE=block_size)
+
+        return y, s
+
+    @staticmethod
+    def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor:
+        """
+        Dequantize weights using block-wise scaling.
+
+        Args:
+            x: Quantized weight tensor
+            s: Scaling factors tensor
+            block_size: Block size for dequantization
+
+        Returns:
+            Dequantized tensor
+        """
+        assert x.is_contiguous() and s.is_contiguous()
+        assert x.dim() == 2 and s.dim() == 2
+
+        M, N = x.size()
+        y = torch.empty_like(x, dtype=torch.get_default_dtype())
+
+        grid = lambda meta: (
+            triton.cdiv(M, meta['BLOCK_SIZE']),
+            triton.cdiv(N, meta['BLOCK_SIZE'])
+        )
+        QuantizationKernels.weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size)
+
+        return y
+
+    @staticmethod
+    def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor) -> torch.Tensor:
+        """
+        Perform FP8 matrix multiplication.
+
+        Args:
+            a: First input matrix
+            a_s: First matrix scaling factors
+            b: Second input matrix
+            b_s: Second matrix scaling factors
+
+        Returns:
+            Result matrix
+        """
+        assert a.is_contiguous() and b.is_contiguous()
+        assert a_s.is_contiguous() and b_s.is_contiguous()
+
+        K = a.size(-1)
+        M = a.numel() // K
+        N = b.size(0)
+
+        c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype())
+        grid = lambda META: (
+            triton.cdiv(M, META['BLOCK_SIZE_M']),
+            triton.cdiv(N, META['BLOCK_SIZE_N'])
+        )
+        MatrixMultKernels.fp8_gemm_kernel[grid](a, b, c, a_s, b_s, M, N, K)
+
+        return c
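As a usage note, the TensorOps wrappers compose as in the sketch below; the shapes, device, and scale layout are illustrative assumptions (activation scales are per 128-wide block, weight scales per 128x128 block), and fp8_gemm would then consume the quantized activations together with a block-quantized weight and its scales:

# Hypothetical round trip through the refactored wrappers (illustrative shapes; CUDA + Triton assumed).
import torch

torch.set_default_dtype(torch.bfloat16)

# Activation quantization: FP8 values plus one scale per 128-wide block.
x = torch.randn(4, 1024, device="cuda")
x_q, x_s = TensorOps.act_quant(x)
assert x_q.dtype == torch.float8_e4m3fn and x_s.shape == (4, 8)

# Weight dequantization: scales are per 128x128 block, i.e. shape (M/128, N/128).
w_q = torch.randn(256, 512, device="cuda").to(torch.float8_e4m3fn)
w_s = torch.ones(2, 4, device="cuda", dtype=torch.float32)
w = TensorOps.weight_dequant(w_q, w_s)  # back to the default dtype (BF16 here)
assert w.shape == w_q.shape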