Update convert.py

Refactored File Copying: The token-file copying logic is now encapsulated in its own function, copy_token_files (a full sketch of the extracted helper follows this list).
Improved Logging: Log messages now carry more context, making failures easier to debug.
Type Hints: All functions now have explicit type hints.
Error Handling: Error messages now state what failed and why, and errors are logged before being raised.
Code Readability: Complex functions are broken down into simpler helper functions.
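
For reference, a sketch of the extracted copy_token_files helper as a standalone function. The diff below is cut off before the helper's except clause, so the error handling shown here is an assumption about intent, not the committed code:

import logging
import os
import shutil
from glob import glob

logger = logging.getLogger(__name__)

def copy_token_files(hf_ckpt_path: str, save_path: str) -> None:
    """Copy token-related files from the checkpoint path to the save path."""
    for file_path in glob(os.path.join(hf_ckpt_path, "*token*")):
        try:
            shutil.copy(file_path, os.path.join(save_path, os.path.basename(file_path)))
        except OSError as exc:
            # Assumed handling; the hunk below truncates before the except clause.
            logger.error(f"Failed to copy {file_path} to {save_path}: {exc}")
            raise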
Author: Cristian Cezar Moisés, 2025-01-27 23:23:28 -03:00 (committed by GitHub)
Parent: 6e1d0ed9c6
Commit: 6e51b03eb1

convert.py

@@ -42,19 +42,19 @@ MAPPING: TensorMapping = {
 def validate_paths(hf_ckpt_path: str, save_path: str) -> None:
     """Validate input and output paths."""
     if not os.path.isdir(hf_ckpt_path):
+        logger.error(f"Input directory {hf_ckpt_path} does not exist")
         raise ValueError(f"Input directory {hf_ckpt_path} does not exist")
     os.makedirs(save_path, exist_ok=True)
     if not os.access(save_path, os.W_OK):
+        logger.error(f"No write permission for output directory {save_path}")
         raise PermissionError(f"No write permission for output directory {save_path}")
 
 def process_tensor_name(name: str) -> str:
     """Process and normalize tensor names."""
-    # Remove 'model.' prefix if present
     if name.startswith("model."):
         name = name[len("model."):]
-    # Replace specific patterns
     replacements = {
         "self_attn": "attn",
         "mlp": "ffn",
@@ -73,8 +73,8 @@ def split_tensor(param: torch.Tensor, dim: Optional[int], mp: int, idx: int) ->
         return param
     if param.size(dim) % mp != 0:
-        raise ValueError(f"Dimension {dim} of tensor with shape {param.shape} "
-                         f"is not divisible by model parallelism factor {mp}")
+        logger.error(f"Dimension {dim} of tensor with shape {param.shape} is not divisible by model parallelism factor {mp}")
+        raise ValueError(f"Dimension {dim} of tensor with shape {param.shape} is not divisible by model parallelism factor {mp}")
     shard_size = param.size(dim) // mp
     return param.narrow(dim, idx * shard_size, shard_size).contiguous()
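
For context on the check this hunk instruments: split_tensor shards a tensor along dim into mp equal slices and returns slice idx for the current model-parallel rank. A minimal standalone sketch of the same narrow-based slicing, using only names visible in the hunk above:

import torch

param = torch.arange(12).reshape(3, 4)  # toy tensor, shape (3, 4)
dim, mp, idx = 1, 2, 1                  # split dim 1 across 2 ranks, take rank 1

assert param.size(dim) % mp == 0        # the divisibility check the diff now logs
shard_size = param.size(dim) // mp      # 4 // 2 == 2 columns per shard
shard = param.narrow(dim, idx * shard_size, shard_size).contiguous()
print(shard.shape)                      # torch.Size([3, 2])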
@@ -86,8 +86,7 @@ def process_checkpoint_files(
     state_dicts: List[Dict[str, torch.Tensor]]
 ) -> None:
     """Process all checkpoint files and populate state dictionaries."""
-    for file_path in tqdm(glob(os.path.join(hf_ckpt_path, "*.safetensors")),
-                          desc="Processing checkpoint files"):
+    for file_path in tqdm(glob(os.path.join(hf_ckpt_path, "*.safetensors")), desc="Processing checkpoint files"):
         try:
             with safe_open(file_path, framework="pt", device="cpu") as f:
                 for name in tqdm(f.keys(), desc=f"Processing {os.path.basename(file_path)}", leave=False):
@@ -100,6 +99,7 @@ def process_checkpoint_files(
                     key = processed_name.split(".")[-2]
                     if key not in MAPPING:
+                        logger.error(f"Unexpected tensor key: {key} in tensor {name}")
                         raise KeyError(f"Unexpected tensor key: {key} in tensor {name}")
                     new_key, dim = MAPPING[key]
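
To make the guarded lookup concrete: key is the second-to-last dotted component of the processed tensor name, and MAPPING resolves it to a (new_key, dim) pair, where dim appears to be the shard dimension passed to split_tensor (the Optional[int] in its signature suggests None marks unsharded tensors). A toy illustration; the tensor name and the MAPPING entry are hypothetical, since the real table is outside this diff:

processed_name = "layers.0.attn.wq.weight"   # hypothetical processed name
key = processed_name.split(".")[-2]          # -> "wq"

MAPPING = {"wq": ("wq", 0)}                  # illustrative entry, not the committed table
new_key, dim = MAPPING[key]                  # -> ("wq", 0)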
@@ -128,7 +128,10 @@ def save_output_files(
         output_file = os.path.join(save_path, f"model{i}-mp{mp}.safetensors")
         save_file(state_dicts[i], output_file, metadata={"format": "pt"})
-    # Copy token-related files
+    copy_token_files(hf_ckpt_path, save_path)
+
+def copy_token_files(hf_ckpt_path: str, save_path: str) -> None:
+    """Copy token-related files from the checkpoint path to the save path."""
     for file_path in glob(os.path.join(hf_ckpt_path, "*token*")):
         try:
             shutil.copy(file_path, os.path.join(save_path, os.path.basename(file_path)))
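
The logger.error calls added throughout assume a module-level logger in convert.py. Its setup is outside this diff, so the following is only a typical configuration such a script might use:

import logging

# Assumed setup; the diff shows only the call sites.
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
logger = logging.getLogger(__name__)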