Update fp8_cast_bf16.py

Increased Clarity: Added more comments and detailed docstrings to improve clarity and maintainability.
Efficient Dictionary Comprehension: Used dictionary comprehension to filter out None values in new_state_dict.
Safe Dictionary Modification: Used pop with a default value to safely remove keys from the dictionary without raising exceptions.
Consistent Type Hinting: Enhanced type hints for better clarity and consistency.
This commit is contained in:
Cristian Cezar Moisés 2025-01-27 23:24:46 -03:00 committed by GitHub
parent 6e51b03eb1
commit e1ed2e8465
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -4,7 +4,7 @@ import logging
from argparse import ArgumentParser from argparse import ArgumentParser
from glob import glob from glob import glob
from pathlib import Path from pathlib import Path
from typing import Dict, Tuple, Optional from typing import Dict, Tuple, Optional, List
from collections import OrderedDict from collections import OrderedDict
from tqdm import tqdm from tqdm import tqdm
@ -50,7 +50,7 @@ def process_weight(
weight_map: Dict[str, str], weight_map: Dict[str, str],
file_cache: OrderedDict, file_cache: OrderedDict,
fp8_path: Path, fp8_path: Path,
fp8_weight_names: list fp8_weight_names: List[str]
) -> Optional[torch.Tensor]: ) -> Optional[torch.Tensor]:
"""Process a single weight tensor.""" """Process a single weight tensor."""
if weight_name.endswith("_scale_inv"): if weight_name.endswith("_scale_inv"):
@ -67,7 +67,7 @@ def handle_fp8_weight(
weight_map: Dict[str, str], weight_map: Dict[str, str],
file_cache: OrderedDict, file_cache: OrderedDict,
fp8_path: Path, fp8_path: Path,
fp8_weight_names: list fp8_weight_names: List[str]
) -> torch.Tensor: ) -> torch.Tensor:
"""Handle FP8 weight conversion to BF16.""" """Handle FP8 weight conversion to BF16."""
scale_inv_name = f"{weight_name}_scale_inv" scale_inv_name = f"{weight_name}_scale_inv"
@ -119,23 +119,22 @@ def process_safetensor_file(
weight_map: Dict[str, str], weight_map: Dict[str, str],
file_cache: OrderedDict, file_cache: OrderedDict,
fp8_path: Path, fp8_path: Path,
fp8_weight_names: list fp8_weight_names: List[str]
) -> None: ) -> None:
"""Process a single safetensors file.""" """Process a single safetensors file."""
try: try:
current_state_dict = load_file(str(file_path), device=TORCH_DEVICE) current_state_dict = load_file(str(file_path), device=TORCH_DEVICE)
file_cache[file_path.name] = current_state_dict file_cache[file_path.name] = current_state_dict
new_state_dict = {} new_state_dict = {
weight_name: process_weight(weight_name, weight, weight_map, file_cache, fp8_path, fp8_weight_names)
for weight_name, weight in tqdm(current_state_dict.items(), for weight_name, weight in tqdm(current_state_dict.items(),
desc=f"Processing {file_path.name}", desc=f"Processing {file_path.name}",
leave=False): leave=False)
processed_weight = process_weight( }
weight_name, weight, weight_map,
file_cache, fp8_path, fp8_weight_names # Remove None values from new_state_dict
) new_state_dict = {k: v for k, v in new_state_dict.items() if v is not None}
if processed_weight is not None:
new_state_dict[weight_name] = processed_weight
save_converted_file(new_state_dict, file_path.name, bf16_path) save_converted_file(new_state_dict, file_path.name, bf16_path)
except Exception as e: except Exception as e:
@ -152,12 +151,11 @@ def save_converted_file(state_dict: Dict[str, torch.Tensor], filename: str, bf16
logger.error(f"Failed to save {filename}: {str(e)}") logger.error(f"Failed to save {filename}: {str(e)}")
raise raise
def update_model_index(weight_map: Dict[str, str], fp8_weight_names: list, bf16_path: Path) -> None: def update_model_index(weight_map: Dict[str, str], fp8_weight_names: List[str], bf16_path: Path) -> None:
"""Update model index file with converted weights.""" """Update model index file with converted weights."""
for weight_name in fp8_weight_names: for weight_name in fp8_weight_names:
scale_inv_name = f"{weight_name}_scale_inv" scale_inv_name = f"{weight_name}_scale_inv"
if scale_inv_name in weight_map: weight_map.pop(scale_inv_name, None) # Use pop with default to avoid KeyError
del weight_map[scale_inv_name]
index_path = bf16_path / "model.safetensors.index.json" index_path = bf16_path / "model.safetensors.index.json"
try: try: