Refactored the codebase by defining separate classes for different operations and improving type safety

This commit is contained in:
pratiyankkumar 2025-01-29 10:37:09 +05:30
parent 70ff909fdc
commit de7df86119
3 changed files with 605 additions and 395 deletions

View File

@ -2,6 +2,7 @@ import os
import json
from argparse import ArgumentParser
from glob import glob
from typing import Dict, Any
from tqdm import tqdm
import torch
@ -9,98 +10,137 @@ from safetensors.torch import load_file, save_file
from kernel import weight_dequant


class WeightConverter:
    def __init__(self, fp8_path: str, bf16_path: str):
        """
        Initialize the weight converter with input and output paths.

        Args:
            fp8_path (str): Path to the directory containing FP8 weights
            bf16_path (str): Path to save the converted BF16 weights
        """
        self.fp8_path = fp8_path
        self.bf16_path = bf16_path
        self.loaded_files: Dict[str, Dict[str, torch.Tensor]] = {}
        self.fp8_weight_names: list = []
        self.weight_map: Dict[str, str] = self._load_model_index()

    def _load_model_index(self) -> Dict[str, str]:
        """
        Load the model index file.

        Returns:
            Dict[str, str]: Weight mapping from the index file
        """
        model_index_file = os.path.join(self.fp8_path, "model.safetensors.index.json")
        with open(model_index_file, "r") as f:
            return json.load(f)["weight_map"]

    def _get_tensor(self, tensor_name: str) -> torch.Tensor:
        """
        Get a tensor from cache or load it from disk.

        Args:
            tensor_name (str): Name of the tensor to retrieve

        Returns:
            torch.Tensor: The requested tensor

        Raises:
            KeyError: If the tensor doesn't exist in the safetensor file
        """
        file_name = self.weight_map[tensor_name]
        if file_name not in self.loaded_files:
            file_path = os.path.join(self.fp8_path, file_name)
            self.loaded_files[file_name] = load_file(file_path, device="cuda")
        return self.loaded_files[file_name][tensor_name]

    def _manage_memory(self):
        """
        Keep only the 2 most recently used files in memory.
        """
        if len(self.loaded_files) > 2:
            oldest_file = next(iter(self.loaded_files))
            del self.loaded_files[oldest_file]
        torch.cuda.empty_cache()

    def _process_weight(self, weight_name: str, weight: torch.Tensor) -> torch.Tensor:
        """
        Process a single weight tensor.

        Args:
            weight_name (str): Name of the weight tensor
            weight (torch.Tensor): The weight tensor to process

        Returns:
            torch.Tensor: Processed weight tensor
        """
        if weight_name.endswith("_scale_inv"):
            return None
        if weight.element_size() == 1:  # FP8 weight
            scale_inv_name = f"{weight_name}_scale_inv"
            try:
                scale_inv = self._get_tensor(scale_inv_name)
                self.fp8_weight_names.append(weight_name)
                return weight_dequant(weight, scale_inv)
            except KeyError:
                print(f"Warning: Missing scale_inv tensor for {weight_name}, skipping conversion")
                return weight
        return weight

    def _save_model_index(self):
        """
        Save the updated model index file.
        """
        new_model_index_file = os.path.join(self.bf16_path, "model.safetensors.index.json")
        for weight_name in self.fp8_weight_names:
            scale_inv_name = f"{weight_name}_scale_inv"
            if scale_inv_name in self.weight_map:
                self.weight_map.pop(scale_inv_name)
        with open(new_model_index_file, "w") as f:
            json.dump({"metadata": {}, "weight_map": self.weight_map}, f, indent=2)

    def convert(self):
        """
        Convert FP8 weights to BF16 format.
        """
        torch.set_default_dtype(torch.bfloat16)
        os.makedirs(self.bf16_path, exist_ok=True)
        safetensor_files = sorted(glob(os.path.join(self.fp8_path, "*.safetensors")))

        for safetensor_file in tqdm(safetensor_files):
            file_name = os.path.basename(safetensor_file)
            current_state_dict = load_file(safetensor_file, device="cuda")
            self.loaded_files[file_name] = current_state_dict

            new_state_dict = {}
            for weight_name, weight in current_state_dict.items():
                processed_weight = self._process_weight(weight_name, weight)
                if processed_weight is not None:
                    new_state_dict[weight_name] = processed_weight

            new_safetensor_file = os.path.join(self.bf16_path, file_name)
            save_file(new_state_dict, new_safetensor_file)
            self._manage_memory()

        self._save_model_index()


def main(fp8_path: str, bf16_path: str):
    """
    Main function to convert FP8 weights to BF16.

    Args:
        fp8_path (str): Input directory containing FP8 weights
        bf16_path (str): Output directory for BF16 weights
    """
    converter = WeightConverter(fp8_path, bf16_path)
    converter.convert()


if __name__ == "__main__":
@ -109,4 +149,3 @@ if __name__ == "__main__":
    parser.add_argument("--output-bf16-hf-path", type=str, required=True)
    args = parser.parse_args()
    main(args.input_fp8_hf_path, args.output_bf16_hf_path)
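
Besides the CLI entry point above, the refactored converter can also be driven directly from Python. This is a minimal sketch, not part of the commit; the module name and the directory paths are assumptions.

# Hypothetical programmatic use of the new class; module name and paths are placeholders.
from fp8_cast_bf16 import WeightConverter

converter = WeightConverter(fp8_path="ckpts/fp8", bf16_path="ckpts/bf16")
converter.convert()  # dequantizes FP8 shards using their scale_inv tensors and rewrites the index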

View File

@ -1,7 +1,8 @@
import os
import json
from argparse import ArgumentParser
from typing import List, Optional, Dict, Any, Tuple
from dataclasses import dataclass

import torch
import torch.distributed as dist
@ -11,71 +12,255 @@ from safetensors.torch import load_model
from model import Transformer, ModelArgs


@dataclass
class GenerationConfig:
    max_new_tokens: int
    temperature: float
    eos_id: int


class TokenSampler:
    @staticmethod
    def sample(logits: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
        """
        Samples a token from the logits using temperature scaling.

        Args:
            logits (torch.Tensor): The logits tensor for token predictions.
            temperature (float): Temperature for scaling logits.

        Returns:
            torch.Tensor: The sampled token.
        """
        logits = logits / max(temperature, 1e-5)
        probs = torch.softmax(logits, dim=-1)
        return probs.div_(torch.empty_like(probs).exponential_(1)).argmax(dim=-1)


class TextGenerator:
    def __init__(self, model: Transformer, tokenizer: Any):
        self.model = model
        self.tokenizer = tokenizer

    @torch.inference_mode()
    def generate(
        self,
        prompt_tokens: List[List[int]],
        config: GenerationConfig
    ) -> List[List[int]]:
        """
        Generates new tokens based on the given prompt tokens.

        Args:
            prompt_tokens: A list of lists containing the prompt tokens for each sequence.
            config: Generation configuration parameters.

        Returns:
            List[List[int]]: Generated tokens for each sequence.
        """
        prompt_lens = [len(t) for t in prompt_tokens]
        assert max(prompt_lens) <= self.model.max_seq_len
        total_len = min(self.model.max_seq_len, config.max_new_tokens + max(prompt_lens))

        tokens = self._initialize_tokens(prompt_tokens, total_len)
        completion_tokens = self._generate_tokens(
            tokens, prompt_lens, total_len, config
        )
        return completion_tokens

    def _initialize_tokens(
        self, prompt_tokens: List[List[int]], total_len: int
    ) -> torch.Tensor:
        tokens = torch.full(
            (len(prompt_tokens), total_len), -1, dtype=torch.long, device="cuda"
        )
        for i, t in enumerate(prompt_tokens):
            tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
        return tokens

    def _generate_tokens(
        self,
        tokens: torch.Tensor,
        prompt_lens: List[int],
        total_len: int,
        config: GenerationConfig
    ) -> List[List[int]]:
        prev_pos = 0
        finished = torch.tensor([False] * len(prompt_lens), device="cuda")
        prompt_mask = tokens != -1

        for cur_pos in range(min(prompt_lens), total_len):
            logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
            next_token = self._get_next_token(logits, config.temperature)
            next_token = torch.where(
                prompt_mask[:, cur_pos], tokens[:, cur_pos], next_token
            )
            tokens[:, cur_pos] = next_token
            finished |= torch.logical_and(
                ~prompt_mask[:, cur_pos], next_token == config.eos_id
            )
            prev_pos = cur_pos
            if finished.all():
                break

        return self._process_completion_tokens(
            tokens, prompt_lens, config.max_new_tokens, config.eos_id
        )

    def _get_next_token(
        self, logits: torch.Tensor, temperature: float
    ) -> torch.Tensor:
        if temperature > 0:
            return TokenSampler.sample(logits, temperature)
        return logits.argmax(dim=-1)

    def _process_completion_tokens(
        self,
        tokens: torch.Tensor,
        prompt_lens: List[int],
        max_new_tokens: int,
        eos_id: int
    ) -> List[List[int]]:
        completion_tokens = []
        for i, toks in enumerate(tokens.tolist()):
            toks = toks[prompt_lens[i]:prompt_lens[i] + max_new_tokens]
            if eos_id in toks:
                toks = toks[:toks.index(eos_id)]
            completion_tokens.append(toks)
        return completion_tokens


class DistributedEnvironment:
    def __init__(self):
        self.world_size = int(os.getenv("WORLD_SIZE", "1"))
        self.rank = int(os.getenv("RANK", "0"))
        self.local_rank = int(os.getenv("LOCAL_RANK", "0"))

    def setup(self):
        if self.world_size > 1:
            dist.init_process_group("nccl")
        if self.rank != 0:
            global print
            print = lambda *_, **__: None
        torch.cuda.set_device(self.local_rank)

    def cleanup(self):
        if self.world_size > 1:
            dist.destroy_process_group()

    def broadcast_prompt(self, prompt: Optional[str] = None) -> str:
        if self.world_size == 1:
            return input(">>> ")
        elif self.rank == 0:
            prompt = input(">>> ")
            objects = [prompt]
            dist.broadcast_object_list(objects, 0)
            return prompt
        else:
            objects = [None]
            dist.broadcast_object_list(objects, 0)
            return objects[0]


class ChatSession:
    def __init__(
        self,
        generator: TextGenerator,
        config: GenerationConfig,
        dist_env: DistributedEnvironment
    ):
        self.generator = generator
        self.config = config
        self.dist_env = dist_env
        self.messages = []

    def run_interactive(self):
        while True:
            prompt = self.dist_env.broadcast_prompt()
            if prompt == "/exit":
                break
            elif prompt == "/clear":
                self.messages.clear()
                continue
            completion = self._process_message(prompt)
            print(completion)
            self.messages.append({"role": "assistant", "content": completion})

    def run_batch(self, input_file: str):
        with open(input_file) as f:
            prompts = [line.strip() for line in f.readlines()]
        assert len(prompts) <= self.generator.model.args.max_batch_size
        completions = self._process_batch(prompts)
        for prompt, completion in zip(prompts, completions):
            print("Prompt:", prompt)
            print("Completion:", completion)
            print()

    def _process_message(self, prompt: str) -> str:
        self.messages.append({"role": "user", "content": prompt})
        prompt_tokens = self.generator.tokenizer.apply_chat_template(
            self.messages, add_generation_prompt=True
        )
        completion_tokens = self.generator.generate(
            [prompt_tokens], self.config
        )
        return self.generator.tokenizer.decode(
            completion_tokens[0], skip_special_tokens=True
        )

    def _process_batch(self, prompts: List[str]) -> List[str]:
        prompt_tokens = [
            self.generator.tokenizer.apply_chat_template(
                [{"role": "user", "content": prompt}],
                add_generation_prompt=True
            )
            for prompt in prompts
        ]
        completion_tokens = self.generator.generate(
            prompt_tokens, self.config
        )
        return self.generator.tokenizer.batch_decode(
            completion_tokens, skip_special_tokens=True
        )


def initialize_model(
    ckpt_path: str, config_path: str, dist_env: DistributedEnvironment
) -> Tuple[Transformer, Any]:
    """Initialize the model and tokenizer."""
    torch.set_default_dtype(torch.bfloat16)
    torch.set_num_threads(8)
    torch.manual_seed(965)

    with open(config_path) as f:
        args = ModelArgs(**json.load(f))
    print(args)

    with torch.device("cuda"):
        model = Transformer(args)

    tokenizer = AutoTokenizer.from_pretrained(ckpt_path)

    # Warmup
    tokenizer.decode(
        TextGenerator(model, tokenizer).generate(
            [tokenizer.encode("DeepSeek")],
            GenerationConfig(max_new_tokens=2, temperature=1.0, eos_id=-1)
        )[0]
    )

    load_model(
        model,
        os.path.join(
            ckpt_path,
            f"model{dist_env.rank}-mp{dist_env.world_size}.safetensors"
        )
    )
    return model, tokenizer


def main(
@ -86,94 +271,29 @@ def main(
    max_new_tokens: int = 100,
    temperature: float = 1.0,
) -> None:
    dist_env = DistributedEnvironment()
    dist_env.setup()

    model, tokenizer = initialize_model(ckpt_path, config, dist_env)
    generator = TextGenerator(model, tokenizer)

    gen_config = GenerationConfig(
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        eos_id=tokenizer.eos_token_id
    )

    session = ChatSession(generator, gen_config, dist_env)

    if interactive:
        session.run_interactive()
    else:
        session.run_batch(input_file)

    dist_env.cleanup()


if __name__ == "__main__":
    parser = ArgumentParser(description="Distributed text generation system")
    parser.add_argument("--ckpt-path", type=str, required=True)
    parser.add_argument("--config", type=str, required=True)
    parser.add_argument("--input-file", type=str, default="")
@ -181,5 +301,13 @@ if __name__ == "__main__":
    parser.add_argument("--max-new-tokens", type=int, default=200)
    parser.add_argument("--temperature", type=float, default=0.2)
    args = parser.parse_args()
    assert args.input_file or args.interactive
    main(
        args.ckpt_path,
        args.config,
        args.input_file,
        args.interactive,
        args.max_new_tokens,
        args.temperature
    )
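
The interactive and batch paths that used to live inline in main() are now composed from the classes above. A minimal wiring sketch, assuming a model and tokenizer have already been created as in initialize_model; the variable names and prompt file are illustrative, not part of the commit.

# Hypothetical wiring of the refactored pieces; model/tokenizer setup elided.
dist_env = DistributedEnvironment()
dist_env.setup()

generator = TextGenerator(model, tokenizer)
gen_config = GenerationConfig(
    max_new_tokens=100,
    temperature=0.7,
    eos_id=tokenizer.eos_token_id
)

session = ChatSession(generator, gen_config, dist_env)
session.run_interactive()        # or: session.run_batch("prompts.txt")
dist_env.cleanup()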

View File

@ -1,4 +1,5 @@
from typing import Tuple
from dataclasses import dataclass

import torch
import triton
@ -6,186 +7,228 @@ import triton.language as tl
from triton import Config


@dataclass
class BlockConfig:
    """Configuration for block sizes in tensor operations."""
    size: int = 128
    size_m: int = 64
    size_n: int = 64
    size_k: int = 128


class QuantizationKernels:
    """Collection of Triton kernels for quantization operations."""

    @staticmethod
    @triton.jit
    def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
        """
        Quantizes activation values using block-wise scaling.

        Args:
            x_ptr: Input tensor pointer
            y_ptr: Output quantized tensor pointer
            s_ptr: Output scaling factors pointer
            BLOCK_SIZE: Size of processing block
        """
        pid = tl.program_id(axis=0)
        offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        x = tl.load(x_ptr + offs).to(tl.float32)
        s = tl.max(tl.abs(x)) / 448.
        y = x / s
        y = y.to(y_ptr.dtype.element_ty)
        tl.store(y_ptr + offs, y)
        tl.store(s_ptr + pid, s)

    @staticmethod
    @triton.jit
    def weight_dequant_kernel(x_ptr, s_ptr, y_ptr, M, N, BLOCK_SIZE: tl.constexpr):
        """
        Dequantizes weights using block-wise scaling.

        Args:
            x_ptr: Quantized weights pointer
            s_ptr: Scaling factors pointer
            y_ptr: Output dequantized tensor pointer
            M: Number of rows
            N: Number of columns
            BLOCK_SIZE: Size of processing block
        """
        pid_m = tl.program_id(axis=0)
        pid_n = tl.program_id(axis=1)
        n = tl.cdiv(N, BLOCK_SIZE)
        offs_m = pid_m * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        offs_n = pid_n * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        offs = offs_m[:, None] * N + offs_n[None, :]
        mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
        x = tl.load(x_ptr + offs, mask=mask).to(tl.float32)
        s = tl.load(s_ptr + pid_m * n + pid_n)
        y = x * s
        tl.store(y_ptr + offs, y, mask=mask)


class MatrixMultKernels:
    """Collection of Triton kernels for matrix multiplication operations."""

    @staticmethod
    def get_configs():
        """Generate configurations for FP8 GEMM autotuning."""
        return [
            Config({
                'BLOCK_SIZE_M': block_m,
                'BLOCK_SIZE_N': block_n,
                'BLOCK_SIZE_K': 128
            }, num_stages=num_stages, num_warps=8)
            for block_m in [16, 32, 64]
            for block_n in [32, 64, 128]
            for num_stages in [3, 4, 5, 6]
        ]

    @staticmethod
    @triton.autotune(configs=get_configs(), key=['N', 'K'])
    @triton.jit
    def fp8_gemm_kernel(
        a_ptr, b_ptr, c_ptr,
        a_s_ptr, b_s_ptr,
        M, N: tl.constexpr, K: tl.constexpr,
        BLOCK_SIZE_M: tl.constexpr,
        BLOCK_SIZE_N: tl.constexpr,
        BLOCK_SIZE_K: tl.constexpr
    ):
        """
        Performs FP8 matrix multiplication with scaling factors.

        Args:
            a_ptr: First input matrix pointer
            b_ptr: Second input matrix pointer
            c_ptr: Output matrix pointer
            a_s_ptr: First matrix scaling factors pointer
            b_s_ptr: Second matrix scaling factors pointer
            M: First matrix rows
            N: Second matrix columns
            K: Inner dimension
            BLOCK_SIZE_M/N/K: Block sizes for tiling
        """
        pid_m = tl.program_id(axis=0)
        pid_n = tl.program_id(axis=1)
        k = tl.cdiv(K, BLOCK_SIZE_K)

        # Calculate offsets
        offs_m = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
        offs_n = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
        offs_k = tl.arange(0, BLOCK_SIZE_K)

        # Initialize pointers
        a_ptrs = a_ptr + offs_m[:, None] * K + offs_k[None, :]
        b_ptrs = b_ptr + offs_n[None, :] * K + offs_k[:, None]
        a_s_ptrs = a_s_ptr + offs_m * k
        b_s_ptrs = b_s_ptr + (offs_n // BLOCK_SIZE_K) * k

        accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

        # Main computation loop
        for i in range(k):
            a = tl.load(a_ptrs, mask=offs_k[None, :] < K - i * BLOCK_SIZE_K, other=0.0)
            b = tl.load(b_ptrs, mask=offs_k[:, None] < K - i * BLOCK_SIZE_K, other=0.0)
            a_s = tl.load(a_s_ptrs)
            b_s = tl.load(b_s_ptrs)
            accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]

            # Update pointers
            a_ptrs += BLOCK_SIZE_K
            b_ptrs += BLOCK_SIZE_K
            a_s_ptrs += 1
            b_s_ptrs += 1

        # Store results
        c = accumulator.to(c_ptr.dtype.element_ty)
        offs_m = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
        offs_n = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
        c_ptrs = c_ptr + offs_m[:, None] * N + offs_n[None, :]
        mask = (offs_m[:, None] < M) & (offs_n[None, :] < N)
        tl.store(c_ptrs, c, mask=mask)


class TensorOps:
    """High-level interface for tensor operations."""

    @staticmethod
    def act_quant(x: torch.Tensor, block_size: int = 128) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Quantize activations using block-wise scaling.

        Args:
            x: Input tensor
            block_size: Block size for quantization

        Returns:
            Tuple of quantized tensor and scaling factors
        """
        assert x.is_contiguous()
        assert x.size(-1) % block_size == 0

        y = torch.empty_like(x, dtype=torch.float8_e4m3fn)
        s = x.new_empty(*x.size()[:-1], x.size(-1) // block_size, dtype=torch.float32)

        grid = lambda meta: (triton.cdiv(x.numel(), meta['BLOCK_SIZE']),)
        QuantizationKernels.act_quant_kernel[grid](x, y, s, BLOCK_SIZE=block_size)

        return y, s

    @staticmethod
    def weight_dequant(x: torch.Tensor, s: torch.Tensor, block_size: int = 128) -> torch.Tensor:
        """
        Dequantize weights using block-wise scaling.

        Args:
            x: Quantized weight tensor
            s: Scaling factors tensor
            block_size: Block size for dequantization

        Returns:
            Dequantized tensor
        """
        assert x.is_contiguous() and s.is_contiguous()
        assert x.dim() == 2 and s.dim() == 2

        M, N = x.size()
        y = torch.empty_like(x, dtype=torch.get_default_dtype())

        grid = lambda meta: (
            triton.cdiv(M, meta['BLOCK_SIZE']),
            triton.cdiv(N, meta['BLOCK_SIZE'])
        )
        QuantizationKernels.weight_dequant_kernel[grid](x, s, y, M, N, BLOCK_SIZE=block_size)

        return y

    @staticmethod
    def fp8_gemm(a: torch.Tensor, a_s: torch.Tensor, b: torch.Tensor, b_s: torch.Tensor) -> torch.Tensor:
        """
        Perform FP8 matrix multiplication.

        Args:
            a: First input matrix
            a_s: First matrix scaling factors
            b: Second input matrix
            b_s: Second matrix scaling factors

        Returns:
            Result matrix
        """
        assert a.is_contiguous() and b.is_contiguous()
        assert a_s.is_contiguous() and b_s.is_contiguous()

        K = a.size(-1)
        M = a.numel() // K
        N = b.size(0)
        c = a.new_empty(*a.size()[:-1], N, dtype=torch.get_default_dtype())

        grid = lambda META: (
            triton.cdiv(M, META['BLOCK_SIZE_M']),
            triton.cdiv(N, META['BLOCK_SIZE_N'])
        )
        MatrixMultKernels.fp8_gemm_kernel[grid](a, b, c, a_s, b_s, M, N, K)
        return c
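
The former module-level helpers are now collected behind the TensorOps namespace. A minimal round-trip sketch, assuming a CUDA device with FP8 and Triton support; the shapes and all-ones weight scales below are made up for illustration and are not part of the commit.

# Illustrative only: quantize activations, pair them with a block-scaled FP8 weight,
# and run the FP8 GEMM. Shapes and scale values are arbitrary.
import torch

torch.set_default_dtype(torch.bfloat16)

x = torch.randn(4, 1024, device="cuda")                                # activations (batch, K)
x_q, x_s = TensorOps.act_quant(x)                                      # FP8 values + per-128-block scales

w_q = torch.randn(512, 1024, device="cuda").to(torch.float8_e4m3fn)    # placeholder FP8 weight (N, K)
w_s = torch.ones(512 // 128, 1024 // 128, device="cuda", dtype=torch.float32)  # per-block weight scales

y = TensorOps.fp8_gemm(x_q, x_s, w_q, w_s)                             # bfloat16 result of shape (4, 512)
w_bf16 = TensorOps.weight_dequant(w_q, w_s)                            # dequantized copy of the weight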