This commit is contained in:
Pratiyank Kumar 2025-04-08 22:12:19 +08:00 committed by GitHub
commit 7c6911dce9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 733 additions and 436 deletions

View File

@ -2,13 +2,19 @@ import os
import shutil
from argparse import ArgumentParser
from glob import glob
from tqdm import tqdm, trange
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
import torch
from safetensors.torch import safe_open, save_file
from tqdm import tqdm, trange
# Constants and type definitions
TensorMapping = Dict[str, Tuple[str, Optional[int]]]
StateDict = Dict[str, torch.Tensor]
mapping = {
# Define mapping as a constant at module level
TENSOR_MAPPING: TensorMapping = {
"embed_tokens": ("embed", 0),
"input_layernorm": ("attn_norm", None),
"post_attention_layernorm": ("ffn_norm", None),
@ -29,68 +35,144 @@ mapping = {
"scale": ("scale", None),
}
def main(hf_ckpt_path, save_path, n_experts, mp):
def process_tensor_name(name: str) -> str:
"""
Converts and saves model checkpoint files into a specified format.
Process tensor name by removing prefixes and replacing common patterns.
Args:
hf_ckpt_path (str): Path to the directory containing the input checkpoint files.
save_path (str): Path to the directory where the converted checkpoint files will be saved.
n_experts (int): Total number of experts in the model.
mp (int): Model parallelism factor.
name: Original tensor name
Returns:
None
Processed tensor name
"""
if name.startswith("model."):
name = name[len("model."):]
replacements = {
"self_attn": "attn",
"mlp": "ffn",
"weight_scale_inv": "scale",
"e_score_correction_bias": "bias"
}
for old, new in replacements.items():
name = name.replace(old, new)
return name
def shard_tensor(param: torch.Tensor, mp_idx: int, mp_count: int, dim: int) -> torch.Tensor:
"""
Shard a tensor along specified dimension for model parallelism.
Args:
param: Input tensor to shard
mp_idx: Index of current model parallel rank
mp_count: Total number of model parallel ranks
dim: Dimension along which to shard
Returns:
Sharded tensor slice
"""
if param.size(dim) % mp_count != 0:
raise ValueError(f"Tensor size {param.size(dim)} not divisible by mp_count {mp_count}")
shard_size = param.size(dim) // mp_count
return param.narrow(dim, mp_idx * shard_size, shard_size).contiguous()
def convert_checkpoint(
hf_ckpt_path: Union[str, Path],
save_path: Union[str, Path],
n_experts: int,
mp: int
) -> None:
"""
Convert and save model checkpoint files into a specified format.
Args:
hf_ckpt_path: Path to input checkpoint directory
save_path: Path to output directory for converted checkpoints
n_experts: Total number of experts in model
mp: Model parallelism factor
Raises:
ValueError: If n_experts is not divisible by mp
FileNotFoundError: If input path doesn't exist or contain safetensors
"""
if n_experts % mp != 0:
raise ValueError(f"Number of experts ({n_experts}) must be divisible by model parallel size ({mp})")
hf_ckpt_path = Path(hf_ckpt_path)
save_path = Path(save_path)
if not hf_ckpt_path.exists():
raise FileNotFoundError(f"Checkpoint path {hf_ckpt_path} does not exist")
safetensor_files = list(hf_ckpt_path.glob("*.safetensors"))
if not safetensor_files:
raise FileNotFoundError(f"No safetensor files found in {hf_ckpt_path}")
torch.set_num_threads(8)
n_local_experts = n_experts // mp
state_dicts = [{} for _ in range(mp)]
state_dicts: List[StateDict] = [{} for _ in range(mp)]
for file_path in tqdm(glob(os.path.join(hf_ckpt_path, "*.safetensors"))):
# Process each checkpoint file
for file_path in tqdm(safetensor_files, desc="Processing checkpoint files"):
with safe_open(file_path, framework="pt", device="cpu") as f:
for name in f.keys():
if "model.layers.61" in name:
continue
param: torch.Tensor = f.get_tensor(name)
if name.startswith("model."):
name = name[len("model."):]
name = name.replace("self_attn", "attn")
name = name.replace("mlp", "ffn")
name = name.replace("weight_scale_inv", "scale")
name = name.replace("e_score_correction_bias", "bias")
name = process_tensor_name(name)
key = name.split(".")[-2]
assert key in mapping, f"Key {key} not found in mapping"
new_key, dim = mapping[key]
if key not in TENSOR_MAPPING:
raise ValueError(f"Unknown tensor key: {key}")
new_key, dim = TENSOR_MAPPING[key]
name = name.replace(key, new_key)
# Distribute tensors across model parallel ranks
for i in range(mp):
new_param = param
if "experts" in name and "shared_experts" not in name:
idx = int(name.split(".")[-3])
if idx < i * n_local_experts or idx >= (i + 1) * n_local_experts:
if not (i * n_local_experts <= idx < (i + 1) * n_local_experts):
continue
elif dim is not None:
assert param.size(dim) % mp == 0, f"Dimension {dim} must be divisible by {mp}"
shard_size = param.size(dim) // mp
new_param = param.narrow(dim, i * shard_size, shard_size).contiguous()
new_param = shard_tensor(param, i, mp, dim)
state_dicts[i][name] = new_param
os.makedirs(save_path, exist_ok=True)
# Save converted checkpoints
save_path.mkdir(parents=True, exist_ok=True)
for i in trange(mp):
save_file(state_dicts[i], os.path.join(save_path, f"model{i}-mp{mp}.safetensors"))
for i in trange(mp, desc="Saving converted checkpoints"):
output_file = save_path / f"model{i}-mp{mp}.safetensors"
save_file(state_dicts[i], str(output_file))
for file_path in glob(os.path.join(hf_ckpt_path, "*token*")):
new_file_path = os.path.join(save_path, os.path.basename(file_path))
shutil.copyfile(file_path, new_file_path)
# Copy tokenizer files
for file_path in hf_ckpt_path.glob("*token*"):
shutil.copyfile(file_path, save_path / file_path.name)
def main():
"""Parse command line arguments and run the conversion."""
parser = ArgumentParser(description="Convert HuggingFace checkpoints to custom format")
parser.add_argument("--hf-ckpt-path", type=str, required=True,
help="Path to input HuggingFace checkpoint directory")
parser.add_argument("--save-path", type=str, required=True,
help="Path to output directory for converted checkpoints")
parser.add_argument("--n-experts", type=int, required=True,
help="Total number of experts in the model")
parser.add_argument("--model-parallel", type=int, required=True,
help="Model parallelism factor")
args = parser.parse_args()
try:
convert_checkpoint(args.hf_ckpt_path, args.save_path, args.n_experts, args.model_parallel)
except Exception as e:
print(f"Error during conversion: {str(e)}")
raise
if __name__ == "__main__":
parser = ArgumentParser()
parser.add_argument("--hf-ckpt-path", type=str, required=True)
parser.add_argument("--save-path", type=str, required=True)
parser.add_argument("--n-experts", type=int, required=True)
parser.add_argument("--model-parallel", type=int, required=True)
args = parser.parse_args()
assert args.n_experts % args.model_parallel == 0, "Number of experts must be divisible by model parallelism"
main(args.hf_ckpt_path, args.save_path, args.n_experts, args.model_parallel)
main()

View File

@ -2,6 +2,7 @@ import os
import json
from argparse import ArgumentParser
from glob import glob
from typing import Dict, Any
from tqdm import tqdm
import torch
@ -9,98 +10,137 @@ from safetensors.torch import load_file, save_file
from kernel import weight_dequant
def main(fp8_path, bf16_path):
"""
Converts FP8 weights to BF16 and saves the converted weights.
This function reads FP8 weights from the specified directory, converts them to BF16,
and saves the converted weights to another specified directory. It also updates the
model index file to reflect the changes.
class WeightConverter:
def __init__(self, fp8_path: str, bf16_path: str):
"""
Initialize the weight converter with input and output paths.
Args:
fp8_path (str): The path to the directory containing the FP8 weights and model index file.
bf16_path (str): The path to the directory where the converted BF16 weights will be saved.
Raises:
KeyError: If a required scale_inv tensor is missing for a weight.
Notes:
- The function assumes that the FP8 weights are stored in safetensor files.
- The function caches loaded safetensor files to optimize memory usage.
- The function updates the model index file to remove references to scale_inv tensors.
fp8_path (str): Path to the directory containing FP8 weights
bf16_path (str): Path to save the converted BF16 weights
"""
torch.set_default_dtype(torch.bfloat16)
os.makedirs(bf16_path, exist_ok=True)
model_index_file = os.path.join(fp8_path, "model.safetensors.index.json")
with open(model_index_file, "r") as f:
model_index = json.load(f)
weight_map = model_index["weight_map"]
self.fp8_path = fp8_path
self.bf16_path = bf16_path
self.loaded_files: Dict[str, Dict[str, torch.Tensor]] = {}
self.fp8_weight_names: list = []
self.weight_map: Dict[str, str] = self._load_model_index()
# Cache for loaded safetensor files
loaded_files = {}
fp8_weight_names = []
# Helper function to get tensor from the correct file
def get_tensor(tensor_name):
def _load_model_index(self) -> Dict[str, str]:
"""
Retrieves a tensor from the cached safetensor files or loads it from disk if not cached.
Args:
tensor_name (str): The name of the tensor to retrieve.
Load the model index file.
Returns:
torch.Tensor: The retrieved tensor.
Dict[str, str]: Weight mapping from the index file
"""
model_index_file = os.path.join(self.fp8_path, "model.safetensors.index.json")
with open(model_index_file, "r") as f:
return json.load(f)["weight_map"]
def _get_tensor(self, tensor_name: str) -> torch.Tensor:
"""
Get a tensor from cache or load it from disk.
Args:
tensor_name (str): Name of the tensor to retrieve
Returns:
torch.Tensor: The requested tensor
Raises:
KeyError: If the tensor does not exist in the safetensor file.
KeyError: If tensor doesn't exist in the safetensor file
"""
file_name = weight_map[tensor_name]
if file_name not in loaded_files:
file_path = os.path.join(fp8_path, file_name)
loaded_files[file_name] = load_file(file_path, device="cuda")
return loaded_files[file_name][tensor_name]
file_name = self.weight_map[tensor_name]
if file_name not in self.loaded_files:
file_path = os.path.join(self.fp8_path, file_name)
self.loaded_files[file_name] = load_file(file_path, device="cuda")
return self.loaded_files[file_name][tensor_name]
def _manage_memory(self):
"""
Keep only the 2 most recently used files in memory.
"""
if len(self.loaded_files) > 2:
oldest_file = next(iter(self.loaded_files))
del self.loaded_files[oldest_file]
torch.cuda.empty_cache()
def _process_weight(self, weight_name: str, weight: torch.Tensor) -> torch.Tensor:
"""
Process a single weight tensor.
Args:
weight_name (str): Name of the weight tensor
weight (torch.Tensor): The weight tensor to process
Returns:
torch.Tensor: Processed weight tensor
"""
if weight_name.endswith("_scale_inv"):
return None
if weight.element_size() == 1: # FP8 weight
scale_inv_name = f"{weight_name}_scale_inv"
try:
scale_inv = self._get_tensor(scale_inv_name)
self.fp8_weight_names.append(weight_name)
return weight_dequant(weight, scale_inv)
except KeyError:
print(f"Warning: Missing scale_inv tensor for {weight_name}, skipping conversion")
return weight
return weight
def _save_model_index(self):
"""
Save the updated model index file.
"""
new_model_index_file = os.path.join(self.bf16_path, "model.safetensors.index.json")
for weight_name in self.fp8_weight_names:
scale_inv_name = f"{weight_name}_scale_inv"
if scale_inv_name in self.weight_map:
self.weight_map.pop(scale_inv_name)
with open(new_model_index_file, "w") as f:
json.dump({"metadata": {}, "weight_map": self.weight_map}, f, indent=2)
def convert(self):
"""
Convert FP8 weights to BF16 format.
"""
torch.set_default_dtype(torch.bfloat16)
os.makedirs(self.bf16_path, exist_ok=True)
safetensor_files = sorted(glob(os.path.join(self.fp8_path, "*.safetensors")))
safetensor_files = list(glob(os.path.join(fp8_path, "*.safetensors")))
safetensor_files.sort()
for safetensor_file in tqdm(safetensor_files):
file_name = os.path.basename(safetensor_file)
current_state_dict = load_file(safetensor_file, device="cuda")
loaded_files[file_name] = current_state_dict
self.loaded_files[file_name] = current_state_dict
new_state_dict = {}
for weight_name, weight in current_state_dict.items():
if weight_name.endswith("_scale_inv"):
continue
elif weight.element_size() == 1: # FP8 weight
scale_inv_name = f"{weight_name}_scale_inv"
try:
# Get scale_inv from the correct file
scale_inv = get_tensor(scale_inv_name)
fp8_weight_names.append(weight_name)
new_state_dict[weight_name] = weight_dequant(weight, scale_inv)
except KeyError:
print(f"Warning: Missing scale_inv tensor for {weight_name}, skipping conversion")
new_state_dict[weight_name] = weight
else:
new_state_dict[weight_name] = weight
processed_weight = self._process_weight(weight_name, weight)
if processed_weight is not None:
new_state_dict[weight_name] = processed_weight
new_safetensor_file = os.path.join(bf16_path, file_name)
new_safetensor_file = os.path.join(self.bf16_path, file_name)
save_file(new_state_dict, new_safetensor_file)
# Memory management: keep only the 2 most recently used files
if len(loaded_files) > 2:
oldest_file = next(iter(loaded_files))
del loaded_files[oldest_file]
torch.cuda.empty_cache()
self._manage_memory()
# Update model index
new_model_index_file = os.path.join(bf16_path, "model.safetensors.index.json")
for weight_name in fp8_weight_names:
scale_inv_name = f"{weight_name}_scale_inv"
if scale_inv_name in weight_map:
weight_map.pop(scale_inv_name)
with open(new_model_index_file, "w") as f:
json.dump({"metadata": {}, "weight_map": weight_map}, f, indent=2)
self._save_model_index()
def main(fp8_path: str, bf16_path: str):
"""
Main function to convert FP8 weights to BF16.
Args:
fp8_path (str): Input directory containing FP8 weights
bf16_path (str): Output directory for BF16 weights
"""
converter = WeightConverter(fp8_path, bf16_path)
converter.convert()
if __name__ == "__main__":
@ -109,4 +149,3 @@ if __name__ == "__main__":
parser.add_argument("--output-bf16-hf-path", type=str, required=True)
args = parser.parse_args()
main(args.input_fp8_hf_path, args.output_bf16_hf_path)

View File

@ -1,7 +1,8 @@
import os
import json
from argparse import ArgumentParser
from typing import List
from typing import List, Optional, Dict, Any, Tuple
from dataclasses import dataclass
import torch
import torch.distributed as dist
@ -11,13 +12,22 @@ from safetensors.torch import load_model
from model import Transformer, ModelArgs
def sample(logits, temperature: float = 1.0):
@dataclass
class GenerationConfig:
max_new_tokens: int
temperature: float
eos_id: int
class TokenSampler:
@staticmethod
def sample(logits: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
"""
Samples a token from the logits using temperature scaling.
Args:
logits (torch.Tensor): The logits tensor for token predictions.
temperature (float, optional): Temperature for scaling logits. Defaults to 1.0.
temperature (float): Temperature for scaling logits.
Returns:
torch.Tensor: The sampled token.
@ -27,48 +37,94 @@ def sample(logits, temperature: float = 1.0):
return probs.div_(torch.empty_like(probs).exponential_(1)).argmax(dim=-1)
class TextGenerator:
def __init__(self, model: Transformer, tokenizer: Any):
self.model = model
self.tokenizer = tokenizer
@torch.inference_mode()
def generate(
model: Transformer,
self,
prompt_tokens: List[List[int]],
max_new_tokens: int,
eos_id: int,
temperature: float = 1.0
config: GenerationConfig
) -> List[List[int]]:
"""
Generates new tokens based on the given prompt tokens using the specified model.
Generates new tokens based on the given prompt tokens.
Args:
model (Transformer): The transformer model used for token generation.
prompt_tokens (List[List[int]]): A list of lists containing the prompt tokens for each sequence.
max_new_tokens (int): The maximum number of new tokens to generate.
eos_id (int): The end-of-sequence token ID.
temperature (float, optional): The temperature value for sampling. Defaults to 1.0.
prompt_tokens: A list of lists containing the prompt tokens for each sequence.
config: Generation configuration parameters.
Returns:
List[List[int]]: A list of lists containing the generated tokens for each sequence.
List[List[int]]: Generated tokens for each sequence.
"""
prompt_lens = [len(t) for t in prompt_tokens]
assert max(prompt_lens) <= model.max_seq_len, f"Prompt length exceeds model maximum sequence length (max_seq_len={model.max_seq_len})"
total_len = min(model.max_seq_len, max_new_tokens + max(prompt_lens))
tokens = torch.full((len(prompt_tokens), total_len), -1, dtype=torch.long, device="cuda")
if max(prompt_lens) > self.model.max_seq_len:
raise ValueError(f"Prompt length exceeds model maximum sequence length (max_seq_len={self.model.max_seq_len})")
total_len = min(self.model.max_seq_len, config.max_new_tokens + max(prompt_lens))
tokens = self._initialize_tokens(prompt_tokens, total_len)
completion_tokens = self._generate_tokens(
tokens, prompt_lens, total_len, config
)
return completion_tokens
def _initialize_tokens(
self, prompt_tokens: List[List[int]], total_len: int
) -> torch.Tensor:
tokens = torch.full(
(len(prompt_tokens), total_len), -1, dtype=torch.long, device="cuda"
)
for i, t in enumerate(prompt_tokens):
tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
return tokens
def _generate_tokens(
self,
tokens: torch.Tensor,
prompt_lens: List[int],
total_len: int,
config: GenerationConfig
) -> List[List[int]]:
prev_pos = 0
finished = torch.tensor([False] * len(prompt_tokens), device="cuda")
finished = torch.tensor([False] * len(prompt_lens), device="cuda")
prompt_mask = tokens != -1
for cur_pos in range(min(prompt_lens), total_len):
logits = model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
if temperature > 0:
next_token = sample(logits, temperature)
else:
next_token = logits.argmax(dim=-1)
next_token = torch.where(prompt_mask[:, cur_pos], tokens[:, cur_pos], next_token)
logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
next_token = self._get_next_token(logits, config.temperature)
next_token = torch.where(
prompt_mask[:, cur_pos], tokens[:, cur_pos], next_token
)
tokens[:, cur_pos] = next_token
finished |= torch.logical_and(~prompt_mask[:, cur_pos], next_token == eos_id)
finished |= torch.logical_and(
~prompt_mask[:, cur_pos], next_token == config.eos_id
)
prev_pos = cur_pos
if finished.all():
break
return self._process_completion_tokens(
tokens, prompt_lens, config.max_new_tokens, config.eos_id
)
def _get_next_token(
self, logits: torch.Tensor, temperature: float
) -> torch.Tensor:
if temperature > 0:
return TokenSampler.sample(logits, temperature)
return logits.argmax(dim=-1)
def _process_completion_tokens(
self,
tokens: torch.Tensor,
prompt_lens: List[int],
max_new_tokens: int,
eos_id: int
) -> List[List[int]]:
completion_tokens = []
for i, toks in enumerate(tokens.tolist()):
toks = toks[prompt_lens[i]:prompt_lens[i] + max_new_tokens]
@ -78,6 +134,138 @@ def generate(
return completion_tokens
class DistributedEnvironment:
def __init__(self):
self.world_size = int(os.getenv("WORLD_SIZE", "1"))
self.rank = int(os.getenv("RANK", "0"))
self.local_rank = int(os.getenv("LOCAL_RANK", "0"))
def setup(self):
if self.world_size > 1:
dist.init_process_group("nccl")
if self.rank != 0:
global print
print = lambda *_, **__: None
torch.cuda.set_device(self.local_rank)
def cleanup(self):
if self.world_size > 1:
dist.destroy_process_group()
def broadcast_prompt(self, prompt: Optional[str] = None) -> str:
if self.world_size == 1:
return input(">>> ")
elif self.rank == 0:
prompt = input(">>> ")
objects = [prompt]
dist.broadcast_object_list(objects, 0)
return prompt
else:
objects = [None]
dist.broadcast_object_list(objects, 0)
return objects[0]
class ChatSession:
def __init__(
self,
generator: TextGenerator,
config: GenerationConfig,
dist_env: DistributedEnvironment
):
self.generator = generator
self.config = config
self.dist_env = dist_env
self.messages = []
def run_interactive(self):
while True:
prompt = self.dist_env.broadcast_prompt()
if prompt == "/exit":
break
elif prompt == "/clear":
self.messages.clear()
continue
completion = self._process_message(prompt)
print(completion)
self.messages.append({"role": "assistant", "content": completion})
def run_batch(self, input_file: str):
with open(input_file) as f:
prompts = [line.strip() for line in f.readlines()]
if len(prompts) > self.generator.model.args.max_batch_size:
raise ValueError(f"Number of prompts exceeds maximum batch size ({self.generator.model.args.max_batch_size})")
completions = self._process_batch(prompts)
for prompt, completion in zip(prompts, completions):
print("Prompt:", prompt)
print("Completion:", completion)
print()
def _process_message(self, prompt: str) -> str:
self.messages.append({"role": "user", "content": prompt})
prompt_tokens = self.generator.tokenizer.apply_chat_template(
self.messages, add_generation_prompt=True
)
completion_tokens = self.generator.generate(
[prompt_tokens], self.config
)
return self.generator.tokenizer.decode(
completion_tokens[0], skip_special_tokens=True
)
def _process_batch(self, prompts: List[str]) -> List[str]:
prompt_tokens = [
self.generator.tokenizer.apply_chat_template(
[{"role": "user", "content": prompt}],
add_generation_prompt=True
)
for prompt in prompts
]
completion_tokens = self.generator.generate(
prompt_tokens, self.config
)
return self.generator.tokenizer.batch_decode(
completion_tokens, skip_special_tokens=True
)
def initialize_model(
ckpt_path: str, config_path: str, dist_env: DistributedEnvironment
) -> Tuple[Transformer, Any]:
"""Initialize the model and tokenizer."""
torch.set_default_dtype(torch.bfloat16)
torch.set_num_threads(8)
torch.manual_seed(965)
with open(config_path) as f:
args = ModelArgs(**json.load(f))
print(args)
with torch.device("cuda"):
model = Transformer(args)
tokenizer = AutoTokenizer.from_pretrained(ckpt_path)
# Warmup
tokenizer.decode(
TextGenerator(model, tokenizer).generate(
[tokenizer.encode("DeepSeek")],
GenerationConfig(max_new_tokens=2, temperature=1.0, eos_id=-1)
)[0]
)
load_model(
model,
os.path.join(
ckpt_path,
f"model{dist_env.rank}-mp{dist_env.world_size}.safetensors"
)
)
return model, tokenizer
def main(
ckpt_path: str,
config: str,
@ -86,94 +274,29 @@ def main(
max_new_tokens: int = 100,
temperature: float = 1.0,
) -> None:
"""
Main function to load the model and perform interactive or batch text generation.
dist_env = DistributedEnvironment()
dist_env.setup()
Args:
ckpt_path (str): Path to the model checkpoint directory.
config (str): Path to the model configuration file.
input_file (str, optional): Path to a file containing input prompts. Defaults to "".
interactive (bool, optional): Whether to run in interactive mode. Defaults to True.
max_new_tokens (int, optional): Maximum number of new tokens to generate. Defaults to 100.
temperature (float, optional): Temperature for sampling. Defaults to 1.0.
"""
world_size = int(os.getenv("WORLD_SIZE", "1"))
rank = int(os.getenv("RANK", "0"))
local_rank = int(os.getenv("LOCAL_RANK", "0"))
if world_size > 1:
dist.init_process_group("nccl")
global print
if rank != 0:
print = lambda *_, **__: None
torch.cuda.set_device(local_rank)
torch.set_default_dtype(torch.bfloat16)
torch.set_num_threads(8)
torch.manual_seed(965)
with open(config) as f:
args = ModelArgs(**json.load(f))
print(args)
with torch.device("cuda"):
model = Transformer(args)
tokenizer = AutoTokenizer.from_pretrained(ckpt_path)
tokenizer.decode(generate(model, [tokenizer.encode("DeepSeek")], 2, -1, 1.)[0])
load_model(model, os.path.join(ckpt_path, f"model{rank}-mp{world_size}.safetensors"))
model, tokenizer = initialize_model(ckpt_path, config, dist_env)
generator = TextGenerator(model, tokenizer)
gen_config = GenerationConfig(
max_new_tokens=max_new_tokens,
temperature=temperature,
eos_id=tokenizer.eos_token_id
)
session = ChatSession(generator, gen_config, dist_env)
if interactive:
messages = []
while True:
if world_size == 1:
prompt = input(">>> ")
elif rank == 0:
prompt = input(">>> ")
objects = [prompt]
dist.broadcast_object_list(objects, 0)
session.run_interactive()
else:
objects = [None]
dist.broadcast_object_list(objects, 0)
prompt = objects[0]
if prompt == "/exit":
break
elif prompt == "/clear":
messages.clear()
continue
messages.append({"role": "user", "content": prompt})
prompt_tokens = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
completion_tokens = generate(model, [prompt_tokens], max_new_tokens, tokenizer.eos_token_id, temperature)
completion = tokenizer.decode(completion_tokens[0], skip_special_tokens=True)
print(completion)
messages.append({"role": "assistant", "content": completion})
else:
with open(input_file) as f:
prompts = [line.strip() for line in f.readlines()]
assert len(prompts) <= args.max_batch_size, f"Number of prompts exceeds maximum batch size ({args.max_batch_size})"
prompt_tokens = [tokenizer.apply_chat_template([{"role": "user", "content": prompt}], add_generation_prompt=True) for prompt in prompts]
completion_tokens = generate(model, prompt_tokens, max_new_tokens, tokenizer.eos_token_id, temperature)
completions = tokenizer.batch_decode(completion_tokens, skip_special_tokens=True)
for prompt, completion in zip(prompts, completions):
print("Prompt:", prompt)
print("Completion:", completion)
print()
session.run_batch(input_file)
if world_size > 1:
dist.destroy_process_group()
dist_env.cleanup()
if __name__ == "__main__":
"""
Command-line interface for distributed text generation.
Arguments:
--ckpt-path (str): Path to the model checkpoint directory.
--config (str): Path to the model configuration file.
--input-file (str, optional): File containing prompts for batch processing.
--interactive (bool, optional): Enable interactive mode for generating text.
--max-new-tokens (int, optional): Maximum number of new tokens to generate. Defaults to 200.
--temperature (float, optional): Temperature for sampling. Defaults to 0.2.
Raises:
AssertionError: If neither input-file nor interactive mode is specified.
"""
parser = ArgumentParser()
parser = ArgumentParser(description="Distributed text generation system")
parser.add_argument("--ckpt-path", type=str, required=True)
parser.add_argument("--config", type=str, required=True)
parser.add_argument("--input-file", type=str, default="")
@ -181,5 +304,15 @@ if __name__ == "__main__":
parser.add_argument("--max-new-tokens", type=int, default=200)
parser.add_argument("--temperature", type=float, default=0.2)
args = parser.parse_args()
assert args.input_file or args.interactive, "Either input-file or interactive mode must be specified"
main(args.ckpt_path, args.config, args.input_file, args.interactive, args.max_new_tokens, args.temperature)
if not args.input_file and not args.interactive:
raise ValueError("Either input-file or interactive mode must be specified")
main(
args.ckpt_path,
args.config,
args.input_file,
args.interactive,
args.max_new_tokens,
args.temperature
)

View File

@ -1,4 +1,5 @@
from typing import Tuple
from dataclasses import dataclass
import torch
import triton
@ -6,19 +7,29 @@ import triton.language as tl
from triton import Config
@dataclass
class BlockConfig:
"""Configuration for block sizes in tensor operations."""
size: int = 128
size_m: int = 64
size_n: int = 64
size_k: int = 128
class QuantizationKernels:
"""Collection of Triton kernels for quantization operations."""
@staticmethod
@triton.jit
def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
"""
Quantizes the input tensor `x_ptr` and stores the result in `y_ptr` and the scaling factor in `s_ptr`.
Quantizes activation values using block-wise scaling.
Args:
x_ptr (triton.Pointer): Pointer to the input tensor.
y_ptr (triton.Pointer): Pointer to the output tensor where quantized values will be stored.
s_ptr (triton.Pointer): Pointer to the output tensor where scaling factors will be stored.
BLOCK_SIZE (tl.constexpr): The size of the block to be processed by each program instance.
Returns:
None
x_ptr: Input tensor pointer
y_ptr: Output quantized tensor pointer
s_ptr: Output scaling factors pointer
BLOCK_SIZE: Size of processing block
"""
pid = tl.program_id(axis=0)
offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
@ -29,44 +40,19 @@ def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
tl.store(y_ptr + offs, y)
tl.store(s_ptr + pid, s)
def act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Quantizes the input tensor `x` using block-wise quantization.
Args:
x (torch.Tensor): The input tensor to be quantized. Must be contiguous and its last dimension size must be divisible by `block_size`.
block_size (int, optional): The size of the blocks to be used for quantization. Default is 128.
Returns:
Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
- The quantized tensor with dtype `torch.float8_e4m3fn`.
- A tensor of scaling factors with dtype `torch.float32`.
"""
assert x.is_contiguous(), 'Input tensor must be contiguous'
assert x.size(-1) % block_size == 0, f'Last dimension size must be divisible by block_size (block_size={block_size})'
y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype=torch.float32)
grid = lambda meta: (triton.cdiv(x.numel(), meta['BLOCK_SIZE']), )
act_quant_kernel[grid](x, y, s, BLOCK_SIZE=block_size)
return y, s
@staticmethod
@triton.jit
def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
"""
Dequantizes weights using the provided scaling factors and stores the result.
Dequantizes weights using block-wise scaling.
Args:
x_ptr (tl.pointer): Pointer to the quantized weights.
s_ptr (tl.pointer): Pointer to the scaling factors.
y_ptr (tl.pointer): Pointer to the output buffer for dequantized weights.
M (int): Number of rows in the weight matrix.
N (int): Number of columns in the weight matrix.
BLOCK_SIZE (tl.constexpr): Size of the block for tiling.
Returns:
None
x_ptr: Quantized weights pointer
s_ptr: Scaling factors pointer
y_ptr: Output dequantized tensor pointer
M: Number of rows
N: Number of columns
BLOCK_SIZE: Size of processing block
"""
pid_m = tl.program_id(axis=0)
pid_n = tl.program_id(axis=1)
@ -81,84 +67,80 @@ def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
tl.store(y_ptr + offs, y, mask=mask)
def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor:
"""
Dequantizes the given weight tensor using the provided scale tensor.
class MatrixMultKernels:
"""Collection of Triton kernels for matrix multiplication operations."""
Args:
x (torch.Tensor): The quantized weight tensor of shape (M, N).
s (torch.Tensor): The scale tensor of shape (M, N).
block_size (int, optional): The block size to use for dequantization. Defaults to 128.
Returns:
torch.Tensor: The dequantized weight tensor of the same shape as `x`.
Raises:
AssertionError: If `x` or `s` are not contiguous or if their dimensions are not 2.
"""
assert x.is_contiguous() and s.is_contiguous(), 'Input tensors must be contiguous'
assert x.dim() == 2 and s.dim() == 2, 'Input tensors must have 2 dimensions'
M, N = x.size()
y = torch.empty_like(x, dtype=torch.get_default_dtype())
grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE']), triton.cdiv(N, meta['BLOCK_SIZE']))
weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size)
return y
fp8_gemm_configs = [
Config({'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': 128}, num_stages=num_stages, num_warps=8)
for block_m in [16, 32, 64] for block_n in [32, 64, 128] for num_stages in [3, 4, 5, 6]
@staticmethod
def get_configs():
"""Generate configurations for FP8 GEMM autotuning."""
return [
Config({
'BLOCK_SIZE_M': block_m,
'BLOCK_SIZE_N': block_n,
'BLOCK_SIZE_K': 128
}, num_stages=num_stages, num_warps=8)
for block_m in [16, 32, 64]
for block_n in [32, 64, 128]
for num_stages in [3, 4, 5, 6]
]
@triton.autotune(configs=fp8_gemm_configs, key=['N', 'K'])
@staticmethod
@triton.autotune(configs=get_configs(), key=['N', 'K'])
@triton.jit
def fp8_gemm_kernel(a_ptr, b_ptr, c_ptr,
def fp8_gemm_kernel(
a_ptr, b_ptr, c_ptr,
a_s_ptr, b_s_ptr,
M, N: tl.constexpr, K: tl.constexpr,
BLOCK_SIZE_M: tl.constexpr,
BLOCK_SIZE_N: tl.constexpr,
BLOCK_SIZE_K: tl.constexpr):
BLOCK_SIZE_K: tl.constexpr
):
"""
Performs a matrix multiplication operation on FP8 matrices with scaling factors.
Performs FP8 matrix multiplication with scaling factors.
Args:
a_ptr (tl.tensor): Pointer to the first input matrix A.
b_ptr (tl.tensor): Pointer to the second input matrix B.
c_ptr (tl.tensor): Pointer to the output matrix C.
a_s_ptr (tl.tensor): Pointer to the scaling factors for matrix A.
b_s_ptr (tl.tensor): Pointer to the scaling factors for matrix B.
M (int): Number of rows in matrix A and C.
N (tl.constexpr): Number of columns in matrix B and C.
K (tl.constexpr): Number of columns in matrix A and rows in matrix B.
BLOCK_SIZE_M (tl.constexpr): Block size for the M dimension.
BLOCK_SIZE_N (tl.constexpr): Block size for the N dimension.
BLOCK_SIZE_K (tl.constexpr): Block size for the K dimension.
Returns:
None
a_ptr: First input matrix pointer
b_ptr: Second input matrix pointer
c_ptr: Output matrix pointer
a_s_ptr: First matrix scaling factors pointer
b_s_ptr: Second matrix scaling factors pointer
M: First matrix rows
N: Second matrix columns
K: Inner dimension
BLOCK_SIZE_M/N/K: Block sizes for tiling
"""
pid_m = tl.program_id(axis=0)
pid_n = tl.program_id(axis=1)
k = tl.cdiv(K, BLOCK_SIZE_K)
# Calculate offsets
offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
# Initialize pointers
a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :]
b_ptrs = b_ptr + offs_n[None, :] * K + offs_k[:, None]
a_s_ptrs = a_s_ptr + offs_m * k
b_s_ptrs = b_s_ptr + (offs_n // BLOCK_SIZE_K) * k
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
# Main computation loop
for i in range(k):
a = tl.load(a_ptrs, mask=offs_k[None, :] < K - i * BLOCK_SIZE_K, other=0.0)
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - i * BLOCK_SIZE_K, other=0.0)
a_s = tl.load(a_s_ptrs)
b_s = tl.load(b_s_ptrs)
accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
# Update pointers
a_ptrs += BLOCK_SIZE_K
b_ptrs += BLOCK_SIZE_K
a_s_ptrs += 1
b_s_ptrs += 1
# Store results
c = accumulator.to(c_ptr.dtype.element_ty)
offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
@ -167,25 +149,86 @@ def fp8_gemm_kernel(a_ptr, b_ptr, c_ptr,
tl.store(c_ptrs, c, mask=mask)
def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor):
class TensorOps:
"""High-level interface for tensor operations."""
@staticmethod
def act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Perform a matrix multiplication using FP8 precision.
Quantize activations using block-wise scaling.
Args:
a (torch.Tensor): The first input matrix, must be contiguous.
a_s (torch.Tensor): The scaling factor for the first input matrix, must be contiguous.
b (torch.Tensor): The second input matrix, must be contiguous.
b_s (torch.Tensor): The scaling factor for the second input matrix, must be contiguous.
x: Input tensor
block_size: Block size for quantization
Returns:
torch.Tensor: The result of the matrix multiplication.
Tuple of quantized tensor and scaling factors
"""
assert a.is_contiguous() and b.is_contiguous(), 'Input tensors must be contiguous'
assert a_s.is_contiguous() and b_s.is_contiguous(), 'Scaling factor tensors must be contiguous'
assert x.is_contiguous()
assert x.size(-1) % block_size == 0
y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype=torch.float32)
grid = lambda meta: (triton.cdiv(x.numel(), meta['BLOCK_SIZE']),)
QuantizationKernels.act_quant_kernel[grid](x, y, s, BLOCK_SIZE=block_size)
return y, s
@staticmethod
def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor:
"""
Dequantize weights using block-wise scaling.
Args:
x: Quantized weight tensor
s: Scaling factors tensor
block_size: Block size for dequantization
Returns:
Dequantized tensor
"""
assert x.is_contiguous() and s.is_contiguous()
assert x.dim() == 2 and s.dim() == 2
M, N = x.size()
y = torch.empty_like(x, dtype=torch.get_default_dtype())
grid = lambda meta: (
triton.cdiv(M, meta['BLOCK_SIZE']),
triton.cdiv(N, meta['BLOCK_SIZE'])
)
QuantizationKernels.weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size)
return y
@staticmethod
def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor) -> torch.Tensor:
"""
Perform FP8 matrix multiplication.
Args:
a: First input matrix
a_s: First matrix scaling factors
b: Second input matrix
b_s: Second matrix scaling factors
Returns:
Result matrix
"""
assert a.is_contiguous() and b.is_contiguous()
assert a_s.is_contiguous() and b_s.is_contiguous()
K = a.size(-1)
M = a.numel() // K
N = b.size(0)
c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype())
grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']), triton.cdiv(N, META['BLOCK_SIZE_N']))
fp8_gemm_kernel[grid](a, b, c, a_s, b_s, M, N, K)
grid = lambda META: (
triton.cdiv(M, META['BLOCK_SIZE_M']),
triton.cdiv(N, META['BLOCK_SIZE_N'])
)
MatrixMultKernels.fp8_gemm_kernel[grid](a, b, c, a_s, b_s, M, N, K)
return c