Update fp8_cast_bf16.py

Increased Clarity: Added more comments and detailed docstrings to improve clarity and maintainability.
Efficient Dictionary Comprehension: Used dictionary comprehension to filter out None values in new_state_dict.
Safe Dictionary Modification: Used pop with a default value to safely remove keys from the dictionary without raising exceptions.
Consistent Type Hinting: Enhanced type hints for better clarity and consistency.
This commit is contained in:
Cristian Cezar Moisés 2025-01-27 23:24:46 -03:00 committed by GitHub
parent 6e51b03eb1
commit e1ed2e8465
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -4,7 +4,7 @@ import logging
from argparse import ArgumentParser
from glob import glob
from pathlib import Path
from typing import Dict, Tuple, Optional
from typing import Dict, Tuple, Optional, List
from collections import OrderedDict
from tqdm import tqdm
@ -50,7 +50,7 @@ def process_weight(
weight_map: Dict[str, str],
file_cache: OrderedDict,
fp8_path: Path,
fp8_weight_names: list
fp8_weight_names: List[str]
) -> Optional[torch.Tensor]:
"""Process a single weight tensor."""
if weight_name.endswith("_scale_inv"):
@ -67,7 +67,7 @@ def handle_fp8_weight(
weight_map: Dict[str, str],
file_cache: OrderedDict,
fp8_path: Path,
fp8_weight_names: list
fp8_weight_names: List[str]
) -> torch.Tensor:
"""Handle FP8 weight conversion to BF16."""
scale_inv_name = f"{weight_name}_scale_inv"
@ -119,24 +119,23 @@ def process_safetensor_file(
weight_map: Dict[str, str],
file_cache: OrderedDict,
fp8_path: Path,
fp8_weight_names: list
fp8_weight_names: List[str]
) -> None:
"""Process a single safetensors file."""
try:
current_state_dict = load_file(str(file_path), device=TORCH_DEVICE)
file_cache[file_path.name] = current_state_dict
new_state_dict = {}
for weight_name, weight in tqdm(current_state_dict.items(),
desc=f"Processing {file_path.name}",
leave=False):
processed_weight = process_weight(
weight_name, weight, weight_map,
file_cache, fp8_path, fp8_weight_names
)
if processed_weight is not None:
new_state_dict[weight_name] = processed_weight
new_state_dict = {
weight_name: process_weight(weight_name, weight, weight_map, file_cache, fp8_path, fp8_weight_names)
for weight_name, weight in tqdm(current_state_dict.items(),
desc=f"Processing {file_path.name}",
leave=False)
}
# Remove None values from new_state_dict
new_state_dict = {k: v for k, v in new_state_dict.items() if v is not None}
save_converted_file(new_state_dict, file_path.name, bf16_path)
except Exception as e:
logger.error(f"Failed to process {file_path.name}: {str(e)}")
@ -152,12 +151,11 @@ def save_converted_file(state_dict: Dict[str, torch.Tensor], filename: str, bf16
logger.error(f"Failed to save {filename}: {str(e)}")
raise
def update_model_index(weight_map: Dict[str, str], fp8_weight_names: list, bf16_path: Path) -> None:
def update_model_index(weight_map: Dict[str, str], fp8_weight_names: List[str], bf16_path: Path) -> None:
"""Update model index file with converted weights."""
for weight_name in fp8_weight_names:
scale_inv_name = f"{weight_name}_scale_inv"
if scale_inv_name in weight_map:
del weight_map[scale_inv_name]
weight_map.pop(scale_inv_name, None) # Use pop with default to avoid KeyError
index_path = bf16_path / "model.safetensors.index.json"
try: