"""Utilities for ESFT (expert-specialized fine-tuning) checkpoints.

The helpers here freeze a DeepSeek-V2 MoE model down to a selected set of
trainable experts (`to_esft`), merge adapter weights stored as safetensors
shards (`load_state_dict`), and attach an adapter directory onto a base model
(`load_esft_model` / `add_adapter`).
"""

import os
import json

import torch
from torch import nn
from safetensors.torch import load_file
from transformers import AutoModelForCausalLM, AutoTokenizer


def to_buffer(module, mark_param=True):
    """Turns all parameters of a module into buffers.

    The converted parameter names are recorded in `module.param_list` so that
    `to_param` can later restore them as trainable parameters.
    """
    if module is None:
        return
    modules = module.modules()
    module = next(modules)
    delattrs = []
    for name, param in module.named_parameters(recurse=False):
        delattrs.append([module, name, param])
    if mark_param and delattrs:
        old_param_list = getattr(module, 'param_list', [])
        module.param_list = old_param_list + [name for _, name, _ in delattrs]
    for module, name, _ in delattrs:
        delattr(module, name)  # Unregister parameter
    for module, name, param in delattrs:
        # Re-register the tensor as a non-persistent buffer: it still participates in
        # forward passes but is no longer trainable and is excluded from state_dict.
        module.register_buffer(name, param.data, persistent=False)
    for module in modules:
        to_buffer(module, mark_param=mark_param)


def to_param(module):
    """Turns all buffers of a module back into parameters."""
    if module is None:
        return
    modules = module.modules()
    module = next(modules)
    param_list = getattr(module, 'param_list', [])
    for name in param_list:
        buffer = getattr(module, name)
        delattr(module, name)  # Delete buffer
        setattr(module, name, nn.Parameter(buffer))
    for module in modules:
        to_param(module)


def recursive_getattr(model, module_name):
    """Fetches a submodule by its dotted name, e.g. 'model.layers.0.mlp'."""
    split_list = module_name.split('.')
    output = model
    for name in split_list:
        output = getattr(output, name)
    return output


def recursive_setattr(model, module_name, module):
    """Replaces the submodule at the dotted name `module_name` with `module`."""
    split_list = module_name.split('.')
    output = model
    for name in split_list[:-1]:
        output = getattr(output, name)
    output.__setattr__(split_list[-1], module)


def to_esft(model, adapter_config):
    """Freezes the model so that only the modules selected in `adapter_config` stay trainable.

    `adapter_config` is expected to contain:
      - 'experts': dict mapping a layer index (as a string) to the list of trainable
        expert ids in that layer,
      - 'non_expert_modules' / 'shared_experts': booleans controlling whether the
        remaining modules stay trainable.
    """
    if not adapter_config.get('non_expert_modules', False):
        to_buffer(model)
    else:
        to_param(model)
    for idx, layer in enumerate(model.model.layers):
        if type(layer.mlp).__name__ != "DeepseekV2MoE":
            continue
        if adapter_config.get('shared_experts', False):
            to_param(layer.mlp.shared_experts)
        else:
            to_buffer(layer.mlp.shared_experts)
        trainable_experts = adapter_config['experts'][str(idx)]
        for expert_id in range(len(layer.mlp.experts)):
            if expert_id in trainable_experts:
                to_param(layer.mlp.experts[expert_id])
            else:
                to_buffer(layer.mlp.experts[expert_id])
    return model


def load_state_dict(folder_path):
    """Merges all .safetensors shards in `folder_path` into a single state dict."""
    # Initialize an empty state_dict
    combined_state_dict = {}

    # Iterate over all files in the folder
    for file_name in os.listdir(folder_path):
        if file_name.endswith('.safetensors'):
            file_path = os.path.join(folder_path, file_name)
            state_dict = load_file(file_path)
            combined_state_dict.update(state_dict)

    # legacy for loading v1 checkpoints: add prefix "model." for parameters
    for k in list(combined_state_dict.keys()):
        if k.startswith("layers"):
            k_new = "model." + k
            combined_state_dict[k_new] = combined_state_dict[k]
            del combined_state_dict[k]

    return combined_state_dict


def load_esft_model(base_model_path, adapter_dir):
    """Loads the base model, converts it to ESFT form, and applies the adapter weights."""
    adapter_config = json.load(open(adapter_dir + "/expert_cfg.json"))
    adapter_state_dict = load_state_dict(adapter_dir)

    # load pretrained model:
    model = AutoModelForCausalLM.from_pretrained(base_model_path, trust_remote_code=True, torch_dtype=torch.bfloat16)
    tokenizer = AutoTokenizer.from_pretrained(base_model_path)
    to_esft(model, adapter_config)
    model.load_state_dict(adapter_state_dict)

    return model, tokenizer


def load_base_model(base_model_path):
    """Loads the untouched pretrained model and its tokenizer."""
    # load pretrained model:
    model = AutoModelForCausalLM.from_pretrained(base_model_path, trust_remote_code=True, torch_dtype=torch.bfloat16)
    tokenizer = AutoTokenizer.from_pretrained(base_model_path)
    return model, tokenizer


def _get_expert_id(key):
    """input: model.layers.25.mlp.experts.10.xx.xx, output: ('25', 10)"""
    return str(key.split(".")[2]), int(key.split(".")[5])


def _build_expert_dict(expert_list):
    """Groups (layer, expert_id) pairs into {layer: sorted list of unique expert ids}."""
    expert_dict = {}
    for layer, expert in expert_list:
        if layer not in expert_dict:
            expert_dict[layer] = []
        expert_dict[layer].append(expert)
    for layer in expert_dict:
        expert_dict[layer] = sorted(list(set(expert_dict[layer])))
    return expert_dict


def _dict_equal(dict1, dict2):
    """Compares two {layer: expert ids} dicts, ignoring the order of the expert ids."""
    if set(dict1.keys()) != set(dict2.keys()):
        return False
    else:
        keys = set(dict1.keys())
        return all([sorted(tuple(dict1[k])) == sorted(tuple(dict2[k])) for k in keys])


def add_adapter(base_model, adapter_dir, return_original_states=False, expert_config=None):
    """Converts `base_model` to ESFT form and loads the adapter weights from `adapter_dir`.

    If `return_original_states` is set, the pre-adapter state dict is also returned so the
    adapter can be reverted later.
    """
    if expert_config is not None:
        adapter_config = json.load(open(expert_config))
    else:
        adapter_config = json.load(open(adapter_dir + "/expert_cfg.json"))
    adapter_state_dict = load_state_dict(adapter_dir)

    expert_in_param = [_get_expert_id(k) for k in adapter_state_dict.keys() if "expert" in k]
    # expert_in_param: [('1', 8), ('1', 9), ('2', 0), ('2', 1), ('2', 2), ('2', 3), ('2', 4), ('2', 5), ('2', 6), ('2', 7)]
    # expert_in_param_dict: {'1': [8, 9], '2': [0, 1, 2, 3, 4, 5, 6, 7]}
    expert_in_param_dict = _build_expert_dict(expert_in_param)
    if not _dict_equal(adapter_config['experts'], expert_in_param_dict):
        print(adapter_config['experts'])
        print(expert_in_param_dict)
        raise ValueError("expert_config and expert_in_param_dict are not consistent")

    to_esft(base_model, adapter_config)
    if return_original_states:
        original_state_dict = {k: v.cpu() for k, v in base_model.state_dict().items()}
        base_model.load_state_dict(adapter_state_dict, strict=False)
        return base_model, original_state_dict
    else:
        base_model.load_state_dict(adapter_state_dict)
        return base_model
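

# --- Usage sketch ----------------------------------------------------------
# A minimal, illustrative example of the intended flow; the two paths below are
# placeholders (assumptions), not files shipped with this module. Loading an
# adapter directory with load_esft_model() freezes everything except the experts
# listed in its expert_cfg.json and then loads the fine-tuned expert weights.
if __name__ == "__main__":
    base_model_path = "/path/to/deepseek-v2-base"  # placeholder path (assumption)
    adapter_dir = "/path/to/esft_adapter"          # placeholder path (assumption)

    model, tokenizer = load_esft_model(base_model_path, adapter_dir)

    inputs = tokenizer("Hello, world!", return_tensors="pt")
    outputs = model.generate(**inputs, max_new_tokens=32)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))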