From de7df86119dc7a2a5730169a4c67ee22437596f6 Mon Sep 17 00:00:00 2001 From: pratiyankkumar Date: Wed, 29 Jan 2025 10:37:09 +0530 Subject: [PATCH] Refactored the codebase by defining seperate classes for different operations and implemented better type safety --- inference/fp8_cast_bf16.py | 205 +++++++++++-------- inference/generate.py | 410 ++++++++++++++++++++++++------------- inference/kernel.py | 385 ++++++++++++++++++---------------- 3 files changed, 605 insertions(+), 395 deletions(-) diff --git a/inference/fp8_cast_bf16.py b/inference/fp8_cast_bf16.py index 4037342..966a8c9 100644 --- a/inference/fp8_cast_bf16.py +++ b/inference/fp8_cast_bf16.py @@ -2,6 +2,7 @@ import os import json from argparse import ArgumentParser from glob import glob +from typing import Dict, Any from tqdm import tqdm import torch @@ -9,104 +10,142 @@ from safetensors.torch import load_file, save_file from kernel import weight_dequant -def main(fp8_path, bf16_path): - """ - Converts FP8 weights to BF16 and saves the converted weights. - This function reads FP8 weights from the specified directory, converts them to BF16, - and saves the converted weights to another specified directory. It also updates the - model index file to reflect the changes. - - Args: - fp8_path (str): The path to the directory containing the FP8 weights and model index file. - bf16_path (str): The path to the directory where the converted BF16 weights will be saved. - - Raises: - KeyError: If a required scale_inv tensor is missing for a weight. - - Notes: - - The function assumes that the FP8 weights are stored in safetensor files. - - The function caches loaded safetensor files to optimize memory usage. - - The function updates the model index file to remove references to scale_inv tensors. - """ - torch.set_default_dtype(torch.bfloat16) - os.makedirs(bf16_path, exist_ok=True) - model_index_file = os.path.join(fp8_path, "model.safetensors.index.json") - with open(model_index_file, "r") as f: - model_index = json.load(f) - weight_map = model_index["weight_map"] - - # Cache for loaded safetensor files - loaded_files = {} - fp8_weight_names = [] - - # Helper function to get tensor from the correct file - def get_tensor(tensor_name): +class WeightConverter: + def __init__(self, fp8_path: str, bf16_path: str): """ - Retrieves a tensor from the cached safetensor files or loads it from disk if not cached. + Initialize the weight converter with input and output paths. Args: - tensor_name (str): The name of the tensor to retrieve. + fp8_path (str): Path to the directory containing FP8 weights + bf16_path (str): Path to save the converted BF16 weights + """ + self.fp8_path = fp8_path + self.bf16_path = bf16_path + self.loaded_files: Dict[str, Dict[str, torch.Tensor]] = {} + self.fp8_weight_names: list = [] + self.weight_map: Dict[str, str] = self._load_model_index() + + def _load_model_index(self) -> Dict[str, str]: + """ + Load the model index file. Returns: - torch.Tensor: The retrieved tensor. + Dict[str, str]: Weight mapping from the index file + """ + model_index_file = os.path.join(self.fp8_path, "model.safetensors.index.json") + with open(model_index_file, "r") as f: + return json.load(f)["weight_map"] + + def _get_tensor(self, tensor_name: str) -> torch.Tensor: + """ + Get a tensor from cache or load it from disk. + + Args: + tensor_name (str): Name of the tensor to retrieve + + Returns: + torch.Tensor: The requested tensor Raises: - KeyError: If the tensor does not exist in the safetensor file. 
+ KeyError: If tensor doesn't exist in the safetensor file """ - file_name = weight_map[tensor_name] - if file_name not in loaded_files: - file_path = os.path.join(fp8_path, file_name) - loaded_files[file_name] = load_file(file_path, device="cuda") - return loaded_files[file_name][tensor_name] + file_name = self.weight_map[tensor_name] + if file_name not in self.loaded_files: + file_path = os.path.join(self.fp8_path, file_name) + self.loaded_files[file_name] = load_file(file_path, device="cuda") + return self.loaded_files[file_name][tensor_name] - safetensor_files = list(glob(os.path.join(fp8_path, "*.safetensors"))) - safetensor_files.sort() - for safetensor_file in tqdm(safetensor_files): - file_name = os.path.basename(safetensor_file) - current_state_dict = load_file(safetensor_file, device="cuda") - loaded_files[file_name] = current_state_dict - - new_state_dict = {} - for weight_name, weight in current_state_dict.items(): - if weight_name.endswith("_scale_inv"): - continue - elif weight.element_size() == 1: # FP8 weight - scale_inv_name = f"{weight_name}_scale_inv" - try: - # Get scale_inv from the correct file - scale_inv = get_tensor(scale_inv_name) - fp8_weight_names.append(weight_name) - new_state_dict[weight_name] = weight_dequant(weight, scale_inv) - except KeyError: - print(f"Warning: Missing scale_inv tensor for {weight_name}, skipping conversion") - new_state_dict[weight_name] = weight - else: - new_state_dict[weight_name] = weight - - new_safetensor_file = os.path.join(bf16_path, file_name) - save_file(new_state_dict, new_safetensor_file) - - # Memory management: keep only the 2 most recently used files - if len(loaded_files) > 2: - oldest_file = next(iter(loaded_files)) - del loaded_files[oldest_file] + def _manage_memory(self): + """ + Keep only the 2 most recently used files in memory. + """ + if len(self.loaded_files) > 2: + oldest_file = next(iter(self.loaded_files)) + del self.loaded_files[oldest_file] torch.cuda.empty_cache() - - # Update model index - new_model_index_file = os.path.join(bf16_path, "model.safetensors.index.json") - for weight_name in fp8_weight_names: - scale_inv_name = f"{weight_name}_scale_inv" - if scale_inv_name in weight_map: - weight_map.pop(scale_inv_name) - with open(new_model_index_file, "w") as f: - json.dump({"metadata": {}, "weight_map": weight_map}, f, indent=2) + + def _process_weight(self, weight_name: str, weight: torch.Tensor) -> torch.Tensor: + """ + Process a single weight tensor. + + Args: + weight_name (str): Name of the weight tensor + weight (torch.Tensor): The weight tensor to process + + Returns: + torch.Tensor: Processed weight tensor + """ + if weight_name.endswith("_scale_inv"): + return None + if weight.element_size() == 1: # FP8 weight + scale_inv_name = f"{weight_name}_scale_inv" + try: + scale_inv = self._get_tensor(scale_inv_name) + self.fp8_weight_names.append(weight_name) + return weight_dequant(weight, scale_inv) + except KeyError: + print(f"Warning: Missing scale_inv tensor for {weight_name}, skipping conversion") + return weight + return weight + + def _save_model_index(self): + """ + Save the updated model index file. 
+ """ + new_model_index_file = os.path.join(self.bf16_path, "model.safetensors.index.json") + for weight_name in self.fp8_weight_names: + scale_inv_name = f"{weight_name}_scale_inv" + if scale_inv_name in self.weight_map: + self.weight_map.pop(scale_inv_name) + + with open(new_model_index_file, "w") as f: + json.dump({"metadata": {}, "weight_map": self.weight_map}, f, indent=2) + + def convert(self): + """ + Convert FP8 weights to BF16 format. + """ + torch.set_default_dtype(torch.bfloat16) + os.makedirs(self.bf16_path, exist_ok=True) + + safetensor_files = sorted(glob(os.path.join(self.fp8_path, "*.safetensors"))) + + for safetensor_file in tqdm(safetensor_files): + file_name = os.path.basename(safetensor_file) + current_state_dict = load_file(safetensor_file, device="cuda") + self.loaded_files[file_name] = current_state_dict + + new_state_dict = {} + for weight_name, weight in current_state_dict.items(): + processed_weight = self._process_weight(weight_name, weight) + if processed_weight is not None: + new_state_dict[weight_name] = processed_weight + + new_safetensor_file = os.path.join(self.bf16_path, file_name) + save_file(new_state_dict, new_safetensor_file) + + self._manage_memory() + + self._save_model_index() + + +def main(fp8_path: str, bf16_path: str): + """ + Main function to convert FP8 weights to BF16. + + Args: + fp8_path (str): Input directory containing FP8 weights + bf16_path (str): Output directory for BF16 weights + """ + converter = WeightConverter(fp8_path, bf16_path) + converter.convert() + if __name__ == "__main__": parser = ArgumentParser() parser.add_argument("--input-fp8-hf-path", type=str, required=True) parser.add_argument("--output-bf16-hf-path", type=str, required=True) args = parser.parse_args() - main(args.input_fp8_hf_path, args.output_bf16_hf_path) - + main(args.input_fp8_hf_path, args.output_bf16_hf_path) \ No newline at end of file diff --git a/inference/generate.py b/inference/generate.py index fbf3ab8..9432a5f 100644 --- a/inference/generate.py +++ b/inference/generate.py @@ -1,7 +1,8 @@ import os import json from argparse import ArgumentParser -from typing import List +from typing import List, Optional, Dict, Any, Tuple +from dataclasses import dataclass import torch import torch.distributed as dist @@ -11,71 +12,255 @@ from safetensors.torch import load_model from model import Transformer, ModelArgs -def sample(logits, temperature: float = 1.0): - """ - Samples a token from the logits using temperature scaling. - - Args: - logits (torch.Tensor): The logits tensor for token predictions. - temperature (float, optional): Temperature for scaling logits. Defaults to 1.0. - - Returns: - torch.Tensor: The sampled token. - """ - logits = logits / max(temperature, 1e-5) - probs = torch.softmax(logits, dim=-1) - return probs.div_(torch.empty_like(probs).exponential_(1)).argmax(dim=-1) +@dataclass +class GenerationConfig: + max_new_tokens: int + temperature: float + eos_id: int -@torch.inference_mode() -def generate( - model: Transformer, - prompt_tokens: List[List[int]], - max_new_tokens: int, - eos_id: int, - temperature: float = 1.0 -) -> List[List[int]]: - """ - Generates new tokens based on the given prompt tokens using the specified model. +class TokenSampler: + @staticmethod + def sample(logits: torch.Tensor, temperature: float = 1.0) -> torch.Tensor: + """ + Samples a token from the logits using temperature scaling. - Args: - model (Transformer): The transformer model used for token generation. 
- prompt_tokens (List[List[int]]): A list of lists containing the prompt tokens for each sequence. - max_new_tokens (int): The maximum number of new tokens to generate. - eos_id (int): The end-of-sequence token ID. - temperature (float, optional): The temperature value for sampling. Defaults to 1.0. + Args: + logits (torch.Tensor): The logits tensor for token predictions. + temperature (float): Temperature for scaling logits. - Returns: - List[List[int]]: A list of lists containing the generated tokens for each sequence. - """ - prompt_lens = [len(t) for t in prompt_tokens] - assert max(prompt_lens) <= model.max_seq_len - total_len = min(model.max_seq_len, max_new_tokens + max(prompt_lens)) - tokens = torch.full((len(prompt_tokens), total_len), -1, dtype=torch.long, device="cuda") - for i, t in enumerate(prompt_tokens): - tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long, device="cuda") - prev_pos = 0 - finished = torch.tensor([False] * len(prompt_tokens), device="cuda") - prompt_mask = tokens != -1 - for cur_pos in range(min(prompt_lens), total_len): - logits = model.forward(tokens[:, prev_pos:cur_pos], prev_pos) + Returns: + torch.Tensor: The sampled token. + """ + logits = logits / max(temperature, 1e-5) + probs = torch.softmax(logits, dim=-1) + return probs.div_(torch.empty_like(probs).exponential_(1)).argmax(dim=-1) + + +class TextGenerator: + def __init__(self, model: Transformer, tokenizer: Any): + self.model = model + self.tokenizer = tokenizer + + @torch.inference_mode() + def generate( + self, + prompt_tokens: List[List[int]], + config: GenerationConfig + ) -> List[List[int]]: + """ + Generates new tokens based on the given prompt tokens. + + Args: + prompt_tokens: A list of lists containing the prompt tokens for each sequence. + config: Generation configuration parameters. + + Returns: + List[List[int]]: Generated tokens for each sequence. 
+ """ + prompt_lens = [len(t) for t in prompt_tokens] + assert max(prompt_lens) <= self.model.max_seq_len + + total_len = min(self.model.max_seq_len, config.max_new_tokens + max(prompt_lens)) + tokens = self._initialize_tokens(prompt_tokens, total_len) + + completion_tokens = self._generate_tokens( + tokens, prompt_lens, total_len, config + ) + return completion_tokens + + def _initialize_tokens( + self, prompt_tokens: List[List[int]], total_len: int + ) -> torch.Tensor: + tokens = torch.full( + (len(prompt_tokens), total_len), -1, dtype=torch.long, device="cuda" + ) + for i, t in enumerate(prompt_tokens): + tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long, device="cuda") + return tokens + + def _generate_tokens( + self, + tokens: torch.Tensor, + prompt_lens: List[int], + total_len: int, + config: GenerationConfig + ) -> List[List[int]]: + prev_pos = 0 + finished = torch.tensor([False] * len(prompt_lens), device="cuda") + prompt_mask = tokens != -1 + + for cur_pos in range(min(prompt_lens), total_len): + logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos) + next_token = self._get_next_token(logits, config.temperature) + next_token = torch.where( + prompt_mask[:, cur_pos], tokens[:, cur_pos], next_token + ) + + tokens[:, cur_pos] = next_token + finished |= torch.logical_and( + ~prompt_mask[:, cur_pos], next_token == config.eos_id + ) + prev_pos = cur_pos + + if finished.all(): + break + + return self._process_completion_tokens( + tokens, prompt_lens, config.max_new_tokens, config.eos_id + ) + + def _get_next_token( + self, logits: torch.Tensor, temperature: float + ) -> torch.Tensor: if temperature > 0: - next_token = sample(logits, temperature) + return TokenSampler.sample(logits, temperature) + return logits.argmax(dim=-1) + + def _process_completion_tokens( + self, + tokens: torch.Tensor, + prompt_lens: List[int], + max_new_tokens: int, + eos_id: int + ) -> List[List[int]]: + completion_tokens = [] + for i, toks in enumerate(tokens.tolist()): + toks = toks[prompt_lens[i]:prompt_lens[i] + max_new_tokens] + if eos_id in toks: + toks = toks[:toks.index(eos_id)] + completion_tokens.append(toks) + return completion_tokens + + +class DistributedEnvironment: + def __init__(self): + self.world_size = int(os.getenv("WORLD_SIZE", "1")) + self.rank = int(os.getenv("RANK", "0")) + self.local_rank = int(os.getenv("LOCAL_RANK", "0")) + + def setup(self): + if self.world_size > 1: + dist.init_process_group("nccl") + if self.rank != 0: + global print + print = lambda *_, **__: None + torch.cuda.set_device(self.local_rank) + + def cleanup(self): + if self.world_size > 1: + dist.destroy_process_group() + + def broadcast_prompt(self, prompt: Optional[str] = None) -> str: + if self.world_size == 1: + return input(">>> ") + elif self.rank == 0: + prompt = input(">>> ") + objects = [prompt] + dist.broadcast_object_list(objects, 0) + return prompt else: - next_token = logits.argmax(dim=-1) - next_token = torch.where(prompt_mask[:, cur_pos], tokens[:, cur_pos], next_token) - tokens[:, cur_pos] = next_token - finished |= torch.logical_and(~prompt_mask[:, cur_pos], next_token == eos_id) - prev_pos = cur_pos - if finished.all(): - break - completion_tokens = [] - for i, toks in enumerate(tokens.tolist()): - toks = toks[prompt_lens[i]:prompt_lens[i]+max_new_tokens] - if eos_id in toks: - toks = toks[:toks.index(eos_id)] - completion_tokens.append(toks) - return completion_tokens + objects = [None] + dist.broadcast_object_list(objects, 0) + return objects[0] + + +class ChatSession: + def 
__init__( + self, + generator: TextGenerator, + config: GenerationConfig, + dist_env: DistributedEnvironment + ): + self.generator = generator + self.config = config + self.dist_env = dist_env + self.messages = [] + + def run_interactive(self): + while True: + prompt = self.dist_env.broadcast_prompt() + if prompt == "/exit": + break + elif prompt == "/clear": + self.messages.clear() + continue + + completion = self._process_message(prompt) + print(completion) + self.messages.append({"role": "assistant", "content": completion}) + + def run_batch(self, input_file: str): + with open(input_file) as f: + prompts = [line.strip() for line in f.readlines()] + assert len(prompts) <= self.generator.model.args.max_batch_size + + completions = self._process_batch(prompts) + for prompt, completion in zip(prompts, completions): + print("Prompt:", prompt) + print("Completion:", completion) + print() + + def _process_message(self, prompt: str) -> str: + self.messages.append({"role": "user", "content": prompt}) + prompt_tokens = self.generator.tokenizer.apply_chat_template( + self.messages, add_generation_prompt=True + ) + completion_tokens = self.generator.generate( + [prompt_tokens], self.config + ) + return self.generator.tokenizer.decode( + completion_tokens[0], skip_special_tokens=True + ) + + def _process_batch(self, prompts: List[str]) -> List[str]: + prompt_tokens = [ + self.generator.tokenizer.apply_chat_template( + [{"role": "user", "content": prompt}], + add_generation_prompt=True + ) + for prompt in prompts + ] + completion_tokens = self.generator.generate( + prompt_tokens, self.config + ) + return self.generator.tokenizer.batch_decode( + completion_tokens, skip_special_tokens=True + ) + + +def initialize_model( + ckpt_path: str, config_path: str, dist_env: DistributedEnvironment +) -> Tuple[Transformer, Any]: + """Initialize the model and tokenizer.""" + torch.set_default_dtype(torch.bfloat16) + torch.set_num_threads(8) + torch.manual_seed(965) + + with open(config_path) as f: + args = ModelArgs(**json.load(f)) + print(args) + + with torch.device("cuda"): + model = Transformer(args) + tokenizer = AutoTokenizer.from_pretrained(ckpt_path) + + # Warmup + tokenizer.decode( + TextGenerator(model, tokenizer).generate( + [tokenizer.encode("DeepSeek")], + GenerationConfig(max_new_tokens=2, temperature=1.0, eos_id=-1) + )[0] + ) + + load_model( + model, + os.path.join( + ckpt_path, + f"model{dist_env.rank}-mp{dist_env.world_size}.safetensors" + ) + ) + return model, tokenizer def main( @@ -86,94 +271,29 @@ def main( max_new_tokens: int = 100, temperature: float = 1.0, ) -> None: - """ - Main function to load the model and perform interactive or batch text generation. + dist_env = DistributedEnvironment() + dist_env.setup() - Args: - ckpt_path (str): Path to the model checkpoint directory. - config (str): Path to the model configuration file. - input_file (str, optional): Path to a file containing input prompts. Defaults to "". - interactive (bool, optional): Whether to run in interactive mode. Defaults to True. - max_new_tokens (int, optional): Maximum number of new tokens to generate. Defaults to 100. - temperature (float, optional): Temperature for sampling. Defaults to 1.0. 
- """ - world_size = int(os.getenv("WORLD_SIZE", "1")) - rank = int(os.getenv("RANK", "0")) - local_rank = int(os.getenv("LOCAL_RANK", "0")) - if world_size > 1: - dist.init_process_group("nccl") - global print - if rank != 0: - print = lambda *_, **__: None - torch.cuda.set_device(local_rank) - torch.set_default_dtype(torch.bfloat16) - torch.set_num_threads(8) - torch.manual_seed(965) - with open(config) as f: - args = ModelArgs(**json.load(f)) - print(args) - with torch.device("cuda"): - model = Transformer(args) - tokenizer = AutoTokenizer.from_pretrained(ckpt_path) - tokenizer.decode(generate(model, [tokenizer.encode("DeepSeek")], 2, -1, 1.)[0]) - load_model(model, os.path.join(ckpt_path, f"model{rank}-mp{world_size}.safetensors")) + model, tokenizer = initialize_model(ckpt_path, config, dist_env) + generator = TextGenerator(model, tokenizer) + gen_config = GenerationConfig( + max_new_tokens=max_new_tokens, + temperature=temperature, + eos_id=tokenizer.eos_token_id + ) + session = ChatSession(generator, gen_config, dist_env) + if interactive: - messages = [] - while True: - if world_size == 1: - prompt = input(">>> ") - elif rank == 0: - prompt = input(">>> ") - objects = [prompt] - dist.broadcast_object_list(objects, 0) - else: - objects = [None] - dist.broadcast_object_list(objects, 0) - prompt = objects[0] - if prompt == "/exit": - break - elif prompt == "/clear": - messages.clear() - continue - messages.append({"role": "user", "content": prompt}) - prompt_tokens = tokenizer.apply_chat_template(messages, add_generation_prompt=True) - completion_tokens = generate(model, [prompt_tokens], max_new_tokens, tokenizer.eos_token_id, temperature) - completion = tokenizer.decode(completion_tokens[0], skip_special_tokens=True) - print(completion) - messages.append({"role": "assistant", "content": completion}) + session.run_interactive() else: - with open(input_file) as f: - prompts = [line.strip() for line in f.readlines()] - assert len(prompts) <= args.max_batch_size - prompt_tokens = [tokenizer.apply_chat_template([{"role": "user", "content": prompt}], add_generation_prompt=True) for prompt in prompts] - completion_tokens = generate(model, prompt_tokens, max_new_tokens, tokenizer.eos_token_id, temperature) - completions = tokenizer.batch_decode(completion_tokens, skip_special_tokens=True) - for prompt, completion in zip(prompts, completions): - print("Prompt:", prompt) - print("Completion:", completion) - print() + session.run_batch(input_file) - if world_size > 1: - dist.destroy_process_group() + dist_env.cleanup() if __name__ == "__main__": - """ - Command-line interface for distributed text generation. - - Arguments: - --ckpt-path (str): Path to the model checkpoint directory. - --config (str): Path to the model configuration file. - --input-file (str, optional): File containing prompts for batch processing. - --interactive (bool, optional): Enable interactive mode for generating text. - --max-new-tokens (int, optional): Maximum number of new tokens to generate. Defaults to 200. - --temperature (float, optional): Temperature for sampling. Defaults to 0.2. - - Raises: - AssertionError: If neither input-file nor interactive mode is specified. 
- """ - parser = ArgumentParser() + parser = ArgumentParser(description="Distributed text generation system") parser.add_argument("--ckpt-path", type=str, required=True) parser.add_argument("--config", type=str, required=True) parser.add_argument("--input-file", type=str, default="") @@ -181,5 +301,13 @@ if __name__ == "__main__": parser.add_argument("--max-new-tokens", type=int, default=200) parser.add_argument("--temperature", type=float, default=0.2) args = parser.parse_args() + assert args.input_file or args.interactive - main(args.ckpt_path, args.config, args.input_file, args.interactive, args.max_new_tokens, args.temperature) + main( + args.ckpt_path, + args.config, + args.input_file, + args.interactive, + args.max_new_tokens, + args.temperature + ) \ No newline at end of file diff --git a/inference/kernel.py b/inference/kernel.py index dec8639..1eaec2a 100644 --- a/inference/kernel.py +++ b/inference/kernel.py @@ -1,4 +1,5 @@ from typing import Tuple +from dataclasses import dataclass import torch import triton @@ -6,186 +7,228 @@ import triton.language as tl from triton import Config -@triton.jit -def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr): - """ - Quantizes the input tensor `x_ptr` and stores the result in `y_ptr` and the scaling factor in `s_ptr`. - - Args: - x_ptr (triton.Pointer): Pointer to the input tensor. - y_ptr (triton.Pointer): Pointer to the output tensor where quantized values will be stored. - s_ptr (triton.Pointer): Pointer to the output tensor where scaling factors will be stored. - BLOCK_SIZE (tl.constexpr): The size of the block to be processed by each program instance. - - Returns: - None - """ - pid = tl.program_id(axis=0) - offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - x = tl.load(x_ptr + offs).to(tl.float32) - s = tl.max(tl.abs(x)) / 448. - y = x / s - y = y.to(y_ptr.dtype.element_ty) - tl.store(y_ptr + offs, y) - tl.store(s_ptr + pid, s) +@dataclass +class BlockConfig: + """Configuration for block sizes in tensor operations.""" + size: int = 128 + size_m: int = 64 + size_n: int = 64 + size_k: int = 128 -def act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, torch.Tensor]: - """ - Quantizes the input tensor `x` using block-wise quantization. +class QuantizationKernels: + """Collection of Triton kernels for quantization operations.""" + + @staticmethod + @triton.jit + def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr): + """ + Quantizes activation values using block-wise scaling. + + Args: + x_ptr: Input tensor pointer + y_ptr: Output quantized tensor pointer + s_ptr: Output scaling factors pointer + BLOCK_SIZE: Size of processing block + """ + pid = tl.program_id(axis=0) + offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + x = tl.load(x_ptr + offs).to(tl.float32) + s = tl.max(tl.abs(x)) / 448. + y = x / s + y = y.to(y_ptr.dtype.element_ty) + tl.store(y_ptr + offs, y) + tl.store(s_ptr + pid, s) - Args: - x (torch.Tensor): The input tensor to be quantized. Must be contiguous and its last dimension size must be divisible by `block_size`. - block_size (int, optional): The size of the blocks to be used for quantization. Default is 128. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: A tuple containing: - - The quantized tensor with dtype `torch.float8_e4m3fn`. - - A tensor of scaling factors with dtype `torch.float32`. 
- """ - assert x.is_contiguous() - assert x.size(-1) % block_size == 0 - y = torch.empty_like(x, dtype=torch.float8_e4m3fn) - s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype=torch.float32) - grid = lambda meta: (triton.cdiv(x.numel(), meta['BLOCK_SIZE']), ) - act_quant_kernel[grid](x, y, s, BLOCK_SIZE=block_size) - return y, s + @staticmethod + @triton.jit + def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr): + """ + Dequantizes weights using block-wise scaling. + + Args: + x_ptr: Quantized weights pointer + s_ptr: Scaling factors pointer + y_ptr: Output dequantized tensor pointer + M: Number of rows + N: Number of columns + BLOCK_SIZE: Size of processing block + """ + pid_m = tl.program_id(axis=0) + pid_n = tl.program_id(axis=1) + n = tl.cdiv(N, BLOCK_SIZE) + offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + offs = offs_m[:, None] * N + offs_n[None, :] + mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) + x = tl.load(x_ptr + offs, mask=mask).to(tl.float32) + s = tl.load(s_ptr + pid_m * n + pid_n) + y = x * s + tl.store(y_ptr + offs, y, mask=mask) -@triton.jit -def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr): - """ - Dequantizes weights using the provided scaling factors and stores the result. +class MatrixMultKernels: + """Collection of Triton kernels for matrix multiplication operations.""" + + @staticmethod + def get_configs(): + """Generate configurations for FP8 GEMM autotuning.""" + return [ + Config({ + 'BLOCK_SIZE_M': block_m, + 'BLOCK_SIZE_N': block_n, + 'BLOCK_SIZE_K': 128 + }, num_stages=num_stages, num_warps=8) + for block_m in [16, 32, 64] + for block_n in [32, 64, 128] + for num_stages in [3, 4, 5, 6] + ] - Args: - x_ptr (tl.pointer): Pointer to the quantized weights. - s_ptr (tl.pointer): Pointer to the scaling factors. - y_ptr (tl.pointer): Pointer to the output buffer for dequantized weights. - M (int): Number of rows in the weight matrix. - N (int): Number of columns in the weight matrix. - BLOCK_SIZE (tl.constexpr): Size of the block for tiling. + @staticmethod + @triton.autotune(configs=get_configs(), key=['N', 'K']) + @triton.jit + def fp8_gemm_kernel( + a_ptr, b_ptr, c_ptr, + a_s_ptr, b_s_ptr, + M, N: tl.constexpr, K: tl.constexpr, + BLOCK_SIZE_M: tl.constexpr, + BLOCK_SIZE_N: tl.constexpr, + BLOCK_SIZE_K: tl.constexpr + ): + """ + Performs FP8 matrix multiplication with scaling factors. 
+ + Args: + a_ptr: First input matrix pointer + b_ptr: Second input matrix pointer + c_ptr: Output matrix pointer + a_s_ptr: First matrix scaling factors pointer + b_s_ptr: Second matrix scaling factors pointer + M: First matrix rows + N: Second matrix columns + K: Inner dimension + BLOCK_SIZE_M/N/K: Block sizes for tiling + """ + pid_m = tl.program_id(axis=0) + pid_n = tl.program_id(axis=1) + k = tl.cdiv(K, BLOCK_SIZE_K) + + # Calculate offsets + offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M + offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N + offs_k = tl.arange(0, BLOCK_SIZE_K) + + # Initialize pointers + a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :] + b_ptrs = b_ptr + offs_n[None, :] * K + offs_k[:, None] + a_s_ptrs = a_s_ptr + offs_m * k + b_s_ptrs = b_s_ptr + (offs_n // BLOCK_SIZE_K) * k - Returns: - None - """ - pid_m = tl.program_id(axis=0) - pid_n = tl.program_id(axis=1) - n = tl.cdiv(N, BLOCK_SIZE) - offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) - offs = offs_m[:, None] * N + offs_n[None, :] - mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) - x = tl.load(x_ptr + offs, mask=mask).to(tl.float32) - s = tl.load(s_ptr + pid_m * n + pid_n) - y = x * s - tl.store(y_ptr + offs, y, mask=mask) + accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) + + # Main computation loop + for i in range(k): + a = tl.load(a_ptrs, mask=offs_k[None, :] < K - i * BLOCK_SIZE_K, other=0.0) + b = tl.load(b_ptrs, mask=offs_k[:, None] < K - i * BLOCK_SIZE_K, other=0.0) + a_s = tl.load(a_s_ptrs) + b_s = tl.load(b_s_ptrs) + accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :] + + # Update pointers + a_ptrs += BLOCK_SIZE_K + b_ptrs += BLOCK_SIZE_K + a_s_ptrs += 1 + b_s_ptrs += 1 + + # Store results + c = accumulator.to(c_ptr.dtype.element_ty) + offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) + offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) + c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :] + mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) + tl.store(c_ptrs, c, mask=mask) -def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor: - """ - Dequantizes the given weight tensor using the provided scale tensor. +class TensorOps: + """High-level interface for tensor operations.""" - Args: - x (torch.Tensor): The quantized weight tensor of shape (M, N). - s (torch.Tensor): The scale tensor of shape (M, N). - block_size (int, optional): The block size to use for dequantization. Defaults to 128. + @staticmethod + def act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Quantize activations using block-wise scaling. + + Args: + x: Input tensor + block_size: Block size for quantization + + Returns: + Tuple of quantized tensor and scaling factors + """ + assert x.is_contiguous() + assert x.size(-1) % block_size == 0 + + y = torch.empty_like(x, dtype=torch.float8_e4m3fn) + s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype=torch.float32) + + grid = lambda meta: (triton.cdiv(x.numel(), meta['BLOCK_SIZE']),) + QuantizationKernels.act_quant_kernel[grid](x, y, s, BLOCK_SIZE=block_size) + + return y, s - Returns: - torch.Tensor: The dequantized weight tensor of the same shape as `x`. + @staticmethod + def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor: + """ + Dequantize weights using block-wise scaling. 
+ + Args: + x: Quantized weight tensor + s: Scaling factors tensor + block_size: Block size for dequantization + + Returns: + Dequantized tensor + """ + assert x.is_contiguous() and s.is_contiguous() + assert x.dim() == 2 and s.dim() == 2 + + M, N = x.size() + y = torch.empty_like(x, dtype=torch.get_default_dtype()) + + grid = lambda meta: ( + triton.cdiv(M, meta['BLOCK_SIZE']), + triton.cdiv(N, meta['BLOCK_SIZE']) + ) + QuantizationKernels.weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size) + + return y - Raises: - AssertionError: If `x` or `s` are not contiguous or if their dimensions are not 2. - """ - assert x.is_contiguous() and s.is_contiguous() - assert x.dim() == 2 and s.dim() == 2 - M, N = x.size() - y = torch.empty_like(x, dtype=torch.get_default_dtype()) - grid = lambda meta: (triton.cdiv(M, meta['BLOCK_SIZE']), triton.cdiv(N, meta['BLOCK_SIZE'])) - weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size) - return y - - -fp8_gemm_configs = [ - Config({'BLOCK_SIZE_M': block_m, 'BLOCK_SIZE_N': block_n, 'BLOCK_SIZE_K': 128}, num_stages=num_stages, num_warps=8) - for block_m in [16, 32, 64] for block_n in [32, 64, 128] for num_stages in [3, 4, 5, 6] -] - -@triton.autotune(configs=fp8_gemm_configs, key=['N', 'K']) -@triton.jit -def fp8_gemm_kernel(a_ptr, b_ptr, c_ptr, - a_s_ptr, b_s_ptr, - M, N: tl.constexpr, K: tl.constexpr, - BLOCK_SIZE_M: tl.constexpr, - BLOCK_SIZE_N: tl.constexpr, - BLOCK_SIZE_K: tl.constexpr): - """ - Performs a matrix multiplication operation on FP8 matrices with scaling factors. - - Args: - a_ptr (tl.tensor): Pointer to the first input matrix A. - b_ptr (tl.tensor): Pointer to the second input matrix B. - c_ptr (tl.tensor): Pointer to the output matrix C. - a_s_ptr (tl.tensor): Pointer to the scaling factors for matrix A. - b_s_ptr (tl.tensor): Pointer to the scaling factors for matrix B. - M (int): Number of rows in matrix A and C. - N (tl.constexpr): Number of columns in matrix B and C. - K (tl.constexpr): Number of columns in matrix A and rows in matrix B. - BLOCK_SIZE_M (tl.constexpr): Block size for the M dimension. - BLOCK_SIZE_N (tl.constexpr): Block size for the N dimension. - BLOCK_SIZE_K (tl.constexpr): Block size for the K dimension. 
- - Returns: - None - """ - pid_m = tl.program_id(axis=0) - pid_n = tl.program_id(axis=1) - k = tl.cdiv(K, BLOCK_SIZE_K) - offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M - offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N - offs_k = tl.arange(0, BLOCK_SIZE_K) - a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :] - b_ptrs = b_ptr + offs_n[None, :] * K + offs_k[:, None] - a_s_ptrs = a_s_ptr + offs_m * k - b_s_ptrs = b_s_ptr + (offs_n // BLOCK_SIZE_K) * k - - accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32) - for i in range(k): - a = tl.load(a_ptrs, mask=offs_k[None, :] < K - i * BLOCK_SIZE_K, other=0.0) - b = tl.load(b_ptrs, mask=offs_k[:, None] < K - i * BLOCK_SIZE_K, other=0.0) - a_s = tl.load(a_s_ptrs) - b_s = tl.load(b_s_ptrs) - accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :] - a_ptrs += BLOCK_SIZE_K - b_ptrs += BLOCK_SIZE_K - a_s_ptrs += 1 - b_s_ptrs += 1 - c = accumulator.to(c_ptr.dtype.element_ty) - offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M) - offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N) - c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :] - mask = (offs_m[:, None] < M) & (offs_n[None, :] < N) - tl.store(c_ptrs, c, mask=mask) - - -def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor): - """ - Perform a matrix multiplication using FP8 precision. - - Args: - a (torch.Tensor): The first input matrix, must be contiguous. - a_s (torch.Tensor): The scaling factor for the first input matrix, must be contiguous. - b (torch.Tensor): The second input matrix, must be contiguous. - b_s (torch.Tensor): The scaling factor for the second input matrix, must be contiguous. - - Returns: - torch.Tensor: The result of the matrix multiplication. - """ - assert a.is_contiguous() and b.is_contiguous() - assert a_s.is_contiguous() and b_s.is_contiguous() - K = a.size(-1) - M = a.numel() // K - N = b.size(0) - c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype()) - grid = lambda META: (triton.cdiv(M, META['BLOCK_SIZE_M']), triton.cdiv(N, META['BLOCK_SIZE_N'])) - fp8_gemm_kernel[grid](a, b, c, a_s, b_s, M, N, K) - return c + @staticmethod + def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor) -> torch.Tensor: + """ + Perform FP8 matrix multiplication. + + Args: + a: First input matrix + a_s: First matrix scaling factors + b: Second input matrix + b_s: Second matrix scaling factors + + Returns: + Result matrix + """ + assert a.is_contiguous() and b.is_contiguous() + assert a_s.is_contiguous() and b_s.is_contiguous() + + K = a.size(-1) + M = a.numel() // K + N = b.size(0) + + c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype()) + + grid = lambda META: ( + triton.cdiv(M, META['BLOCK_SIZE_M']), + triton.cdiv(N, META['BLOCK_SIZE_N']) + ) + MatrixMultKernels.fp8_gemm_kernel[grid](a, b, c, a_s, b_s, M, N, K) + + return c \ No newline at end of file
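
Usage sketch (illustrative only, not part of the patch): the snippet below exercises the refactored TensorOps interface from inference/kernel.py end to end -- block-wise FP8 activation quantization, an FP8 GEMM against a weight quantized in 128x128 tiles, and weight dequantization. It assumes a CUDA GPU with Triton installed and Python 3.10 or newer (older interpreters cannot call the staticmethod get_configs() inside the MatrixMultKernels class body). The helper quant_weight_blocks and all shapes, names, and values here are illustrative assumptions, not code taken from the repository.

# fp8_ops_sketch.py -- illustrative usage of the refactored TensorOps API.
# Assumes: CUDA GPU, Triton available, Python 3.10+. Hypothetical helper and
# arbitrary shapes below are the editor's, not part of this patch.
import torch

from kernel import TensorOps  # high-level interface introduced by this patch

torch.set_default_dtype(torch.bfloat16)
BLOCK = 128


def quant_weight_blocks(w: torch.Tensor, block: int = BLOCK):
    """Hypothetical helper: FP8-quantize an (N, K) weight with one float32
    scale per (block x block) tile, matching the b_s layout that
    TensorOps.fp8_gemm indexes (row-major over (N/block, K/block))."""
    n, k = w.shape
    tiles = w.float().view(n // block, block, k // block, block)
    s = tiles.abs().amax(dim=(1, 3)) / 448.0                 # (N/128, K/128)
    w_q = (tiles / s[:, None, :, None]).view(n, k).to(torch.float8_e4m3fn)
    return w_q, s


M, K, N = 4, 1024, 512
x = torch.randn(M, K, device="cuda")
w = torch.randn(N, K, device="cuda")

# Activations: one scale per 128-wide slice of the last dimension,
# chosen as max(|x|) / 448 so values fit the float8_e4m3fn range.
x_q, x_s = TensorOps.act_quant(x, block_size=BLOCK)          # fp8 values, fp32 scales

# Weights: 128x128 block scales, then FP8 GEMM with float32 accumulation;
# each tile product is rescaled by the corresponding a_s and b_s factors.
w_q, w_s = quant_weight_blocks(w)
y = TensorOps.fp8_gemm(x_q, x_s, w_q, w_s)                   # (M, N), default dtype

# Dequantizing the weights with the same scales approximately recovers w;
# residual differences reflect FP8 (e4m3) rounding, so only print them.
w_back = TensorOps.weight_dequant(w_q, w_s, block_size=BLOCK)

print("gemm abs error:   ", (y.float() - x.float() @ w.float().t()).abs().max().item())
print("dequant abs error:", (w_back.float() - w.float()).abs().max().item())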