Mirror of https://github.com/deepseek-ai/DeepSeek-V3.git, synced 2025-04-19 10:08:59 -04:00
Update fp8_cast_bf16.py
- Increased clarity: added more comments and detailed docstrings to improve clarity and maintainability.
- Efficient dictionary comprehension: used a dictionary comprehension to filter out None values in new_state_dict.
- Safe dictionary modification: used pop with a default value to safely remove keys from the dictionary without raising exceptions.
- Consistent type hinting: enhanced type hints for better clarity and consistency.
parent 6e51b03eb1
commit e1ed2e8465
@@ -4,7 +4,7 @@ import logging
 from argparse import ArgumentParser
 from glob import glob
 from pathlib import Path
-from typing import Dict, Tuple, Optional
+from typing import Dict, Tuple, Optional, List
 from collections import OrderedDict
 
 from tqdm import tqdm
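A note on this import change: on Python versions before 3.9, parameterized annotations such as List[str] must come from the typing module, which is why the signatures below need the added import; the bare built-in list carries no element type there. A minimal sketch (the function name is hypothetical, not from the module):

from typing import List

def count_weights(fp8_weight_names: List[str]) -> int:
    # The annotation documents that the list holds weight-name strings.
    return len(fp8_weight_names)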
@@ -50,7 +50,7 @@ def process_weight(
     weight_map: Dict[str, str],
     file_cache: OrderedDict,
     fp8_path: Path,
-    fp8_weight_names: list
+    fp8_weight_names: List[str]
 ) -> Optional[torch.Tensor]:
     """Process a single weight tensor."""
     if weight_name.endswith("_scale_inv"):
@@ -67,7 +67,7 @@ def handle_fp8_weight(
     weight_map: Dict[str, str],
     file_cache: OrderedDict,
     fp8_path: Path,
-    fp8_weight_names: list
+    fp8_weight_names: List[str]
 ) -> torch.Tensor:
     """Handle FP8 weight conversion to BF16."""
     scale_inv_name = f"{weight_name}_scale_inv"
@@ -119,24 +119,23 @@ def process_safetensor_file(
     weight_map: Dict[str, str],
     file_cache: OrderedDict,
     fp8_path: Path,
-    fp8_weight_names: list
+    fp8_weight_names: List[str]
 ) -> None:
     """Process a single safetensors file."""
     try:
         current_state_dict = load_file(str(file_path), device=TORCH_DEVICE)
         file_cache[file_path.name] = current_state_dict
 
-        new_state_dict = {}
-        for weight_name, weight in tqdm(current_state_dict.items(),
-                                        desc=f"Processing {file_path.name}",
-                                        leave=False):
-            processed_weight = process_weight(
-                weight_name, weight, weight_map,
-                file_cache, fp8_path, fp8_weight_names
-            )
-            if processed_weight is not None:
-                new_state_dict[weight_name] = processed_weight
+        new_state_dict = {
+            weight_name: process_weight(weight_name, weight, weight_map, file_cache, fp8_path, fp8_weight_names)
+            for weight_name, weight in tqdm(current_state_dict.items(),
+                                            desc=f"Processing {file_path.name}",
+                                            leave=False)
+        }
+
+        # Remove None values from new_state_dict
+        new_state_dict = {k: v for k, v in new_state_dict.items() if v is not None}
 
         save_converted_file(new_state_dict, file_path.name, bf16_path)
     except Exception as e:
         logger.error(f"Failed to process {file_path.name}: {str(e)}")
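A minimal standalone sketch of the comprehension pattern this hunk introduces (all names here are hypothetical stand-ins, not from the module): build the mapping in one comprehension, then drop entries whose processing returned None.

from typing import Optional

def process(name: str, value: int) -> Optional[int]:
    # Stand-in for process_weight: reject auxiliary entries.
    return None if name.endswith("_scale_inv") else value * 2

raw = {"w1": 1, "w1_scale_inv": 2, "w2": 3}

# First pass: apply the transform to every entry.
state = {name: process(name, value) for name, value in raw.items()}
# Second pass: keep only entries the transform accepted.
state = {k: v for k, v in state.items() if v is not None}

assert state == {"w1": 2, "w2": 6}

Note that this form walks the dictionary twice; the commit keeps the transform and the None filter as two separate, readable steps rather than folding both into one comprehension.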
@@ -152,12 +151,11 @@ def save_converted_file(state_dict: Dict[str, torch.Tensor], filename: str, bf16
         logger.error(f"Failed to save {filename}: {str(e)}")
         raise
 
-def update_model_index(weight_map: Dict[str, str], fp8_weight_names: list, bf16_path: Path) -> None:
+def update_model_index(weight_map: Dict[str, str], fp8_weight_names: List[str], bf16_path: Path) -> None:
     """Update model index file with converted weights."""
     for weight_name in fp8_weight_names:
         scale_inv_name = f"{weight_name}_scale_inv"
-        if scale_inv_name in weight_map:
-            del weight_map[scale_inv_name]
+        weight_map.pop(scale_inv_name, None)  # Use pop with default to avoid KeyError
 
     index_path = bf16_path / "model.safetensors.index.json"
     try:
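The pop idiom adopted above, as a minimal sketch: dict.pop(key, default) removes the key if present and returns the default otherwise, so it never raises KeyError and replaces the membership-check-plus-del pair with a single call. The example data below is hypothetical.

weight_map = {"w1": "file-00001", "w1_scale_inv": "file-00001"}

# Old shape: explicit membership check before deletion.
if "w1_scale_inv" in weight_map:
    del weight_map["w1_scale_inv"]

# New shape: pop with a default never raises, whether or not the key exists.
weight_map.pop("w1_scale_inv", None)  # key already removed above: returns None
weight_map.pop("no_such_key", None)   # absent key: also returns None

assert weight_map == {"w1": "file-00001"}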