mirror of
https://github.com/deepseek-ai/DeepSeek-V3.git
synced 2025-04-19 18:18:57 -04:00
Update fp8_cast_bf16.py
- Increased Clarity: Added more comments and detailed docstrings to improve clarity and maintainability.
- Efficient Dictionary Comprehension: Used a dictionary comprehension to filter out None values in new_state_dict.
- Safe Dictionary Modification: Used pop with a default value to safely remove keys from the dictionary without raising exceptions.
- Consistent Type Hinting: Enhanced type hints for better clarity and consistency.
This commit is contained in:
parent
6e51b03eb1
commit
e1ed2e8465
@ -4,7 +4,7 @@ import logging
|
|||||||
from argparse import ArgumentParser
|
from argparse import ArgumentParser
|
||||||
from glob import glob
|
from glob import glob
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Tuple, Optional
|
from typing import Dict, Tuple, Optional, List
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
@ -50,7 +50,7 @@ def process_weight(
|
|||||||
weight_map: Dict[str, str],
|
weight_map: Dict[str, str],
|
||||||
file_cache: OrderedDict,
|
file_cache: OrderedDict,
|
||||||
fp8_path: Path,
|
fp8_path: Path,
|
||||||
fp8_weight_names: list
|
fp8_weight_names: List[str]
|
||||||
) -> Optional[torch.Tensor]:
|
) -> Optional[torch.Tensor]:
|
||||||
"""Process a single weight tensor."""
|
"""Process a single weight tensor."""
|
||||||
if weight_name.endswith("_scale_inv"):
|
if weight_name.endswith("_scale_inv"):
|
||||||
@ -67,7 +67,7 @@ def handle_fp8_weight(
|
|||||||
weight_map: Dict[str, str],
|
weight_map: Dict[str, str],
|
||||||
file_cache: OrderedDict,
|
file_cache: OrderedDict,
|
||||||
fp8_path: Path,
|
fp8_path: Path,
|
||||||
fp8_weight_names: list
|
fp8_weight_names: List[str]
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Handle FP8 weight conversion to BF16."""
|
"""Handle FP8 weight conversion to BF16."""
|
||||||
scale_inv_name = f"{weight_name}_scale_inv"
|
scale_inv_name = f"{weight_name}_scale_inv"
|
||||||
@ -119,23 +119,22 @@ def process_safetensor_file(
|
|||||||
weight_map: Dict[str, str],
|
weight_map: Dict[str, str],
|
||||||
file_cache: OrderedDict,
|
file_cache: OrderedDict,
|
||||||
fp8_path: Path,
|
fp8_path: Path,
|
||||||
fp8_weight_names: list
|
fp8_weight_names: List[str]
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Process a single safetensors file."""
|
"""Process a single safetensors file."""
|
||||||
try:
|
try:
|
||||||
current_state_dict = load_file(str(file_path), device=TORCH_DEVICE)
|
current_state_dict = load_file(str(file_path), device=TORCH_DEVICE)
|
||||||
file_cache[file_path.name] = current_state_dict
|
file_cache[file_path.name] = current_state_dict
|
||||||
|
|
||||||
new_state_dict = {}
|
new_state_dict = {
|
||||||
|
weight_name: process_weight(weight_name, weight, weight_map, file_cache, fp8_path, fp8_weight_names)
|
||||||
for weight_name, weight in tqdm(current_state_dict.items(),
|
for weight_name, weight in tqdm(current_state_dict.items(),
|
||||||
desc=f"Processing {file_path.name}",
|
desc=f"Processing {file_path.name}",
|
||||||
leave=False):
|
leave=False)
|
||||||
processed_weight = process_weight(
|
}
|
||||||
weight_name, weight, weight_map,
|
|
||||||
file_cache, fp8_path, fp8_weight_names
|
# Remove None values from new_state_dict
|
||||||
)
|
new_state_dict = {k: v for k, v in new_state_dict.items() if v is not None}
|
||||||
if processed_weight is not None:
|
|
||||||
new_state_dict[weight_name] = processed_weight
|
|
||||||
|
|
||||||
save_converted_file(new_state_dict, file_path.name, bf16_path)
|
save_converted_file(new_state_dict, file_path.name, bf16_path)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -152,12 +151,11 @@ def save_converted_file(state_dict: Dict[str, torch.Tensor], filename: str, bf16
|
|||||||
logger.error(f"Failed to save {filename}: {str(e)}")
|
logger.error(f"Failed to save {filename}: {str(e)}")
|
||||||
raise
|
raise
|
||||||
|
|
||||||
def update_model_index(weight_map: Dict[str, str], fp8_weight_names: list, bf16_path: Path) -> None:
|
def update_model_index(weight_map: Dict[str, str], fp8_weight_names: List[str], bf16_path: Path) -> None:
|
||||||
"""Update model index file with converted weights."""
|
"""Update model index file with converted weights."""
|
||||||
for weight_name in fp8_weight_names:
|
for weight_name in fp8_weight_names:
|
||||||
scale_inv_name = f"{weight_name}_scale_inv"
|
scale_inv_name = f"{weight_name}_scale_inv"
|
||||||
if scale_inv_name in weight_map:
|
weight_map.pop(scale_inv_name, None) # Use pop with default to avoid KeyError
|
||||||
del weight_map[scale_inv_name]
|
|
||||||
|
|
||||||
index_path = bf16_path / "model.safetensors.index.json"
|
index_path = bf16_path / "model.safetensors.index.json"
|
||||||
try:
|
try:
|
||||||
|
Loading…
Reference in New Issue
Block a user