diff --git a/inference/convert.py b/inference/convert.py
index 3277cea..1e68e9f 100644
--- a/inference/convert.py
+++ b/inference/convert.py
@@ -42,19 +42,19 @@ MAPPING: TensorMapping = {
 def validate_paths(hf_ckpt_path: str, save_path: str) -> None:
     """Validate input and output paths."""
     if not os.path.isdir(hf_ckpt_path):
+        logger.error(f"Input directory {hf_ckpt_path} does not exist")
         raise ValueError(f"Input directory {hf_ckpt_path} does not exist")
 
     os.makedirs(save_path, exist_ok=True)
     if not os.access(save_path, os.W_OK):
+        logger.error(f"No write permission for output directory {save_path}")
         raise PermissionError(f"No write permission for output directory {save_path}")
 
 def process_tensor_name(name: str) -> str:
     """Process and normalize tensor names."""
-    # Remove 'model.' prefix if present
     if name.startswith("model."):
         name = name[len("model."):]
 
-    # Replace specific patterns
     replacements = {
         "self_attn": "attn",
         "mlp": "ffn",
@@ -73,8 +73,8 @@ def split_tensor(param: torch.Tensor, dim: Optional[int], mp: int, idx: int) ->
         return param
 
     if param.size(dim) % mp != 0:
-        raise ValueError(f"Dimension {dim} of tensor with shape {param.shape} "
-                         f"is not divisible by model parallelism factor {mp}")
+        logger.error(f"Dimension {dim} of tensor with shape {param.shape} is not divisible by model parallelism factor {mp}")
+        raise ValueError(f"Dimension {dim} of tensor with shape {param.shape} is not divisible by model parallelism factor {mp}")
 
     shard_size = param.size(dim) // mp
     return param.narrow(dim, idx * shard_size, shard_size).contiguous()
@@ -86,8 +86,7 @@ def process_checkpoint_files(
     state_dicts: List[Dict[str, torch.Tensor]]
 ) -> None:
     """Process all checkpoint files and populate state dictionaries."""
-    for file_path in tqdm(glob(os.path.join(hf_ckpt_path, "*.safetensors")),
-                          desc="Processing checkpoint files"):
+    for file_path in tqdm(glob(os.path.join(hf_ckpt_path, "*.safetensors")), desc="Processing checkpoint files"):
         try:
             with safe_open(file_path, framework="pt", device="cpu") as f:
                 for name in tqdm(f.keys(), desc=f"Processing {os.path.basename(file_path)}", leave=False):
@@ -100,6 +99,7 @@ def process_checkpoint_files(
                     key = processed_name.split(".")[-2]
 
                     if key not in MAPPING:
+                        logger.error(f"Unexpected tensor key: {key} in tensor {name}")
                         raise KeyError(f"Unexpected tensor key: {key} in tensor {name}")
 
                     new_key, dim = MAPPING[key]
@@ -128,7 +128,10 @@ def save_output_files(
         output_file = os.path.join(save_path, f"model{i}-mp{mp}.safetensors")
         save_file(state_dicts[i], output_file, metadata={"format": "pt"})
 
-    # Copy token-related files
+    copy_token_files(hf_ckpt_path, save_path)
+
+def copy_token_files(hf_ckpt_path: str, save_path: str) -> None:
+    """Copy token-related files from the checkpoint path to the save path."""
     for file_path in glob(os.path.join(hf_ckpt_path, "*token*")):
         try:
             shutil.copy(file_path, os.path.join(save_path, os.path.basename(file_path)))