import functools
import gc
import os
import re

import tinycudann as tcnn
import torch
from packaging import version

from threestudio.utils.config import config_to_primitive
from threestudio.utils.typing import *


def parse_version(ver: str):
    """Parse the version string ``ver`` with ``packaging.version.parse``."""
    return version.parse(ver)


def get_rank():
    """Return the process rank read from the environment.

    The first of ``LOCAL_RANK``, ``RANK``, ``SLURM_PROCID`` and
    ``JSM_NAMESPACE_RANK`` that is set is converted to ``int`` and returned;
    if none of them is set the rank is 0.
    """
    # SLURM_PROCID can be set even if SLURM is not managing the multiprocessing,
    # therefore LOCAL_RANK needs to be checked first
    rank_keys = ("LOCAL_RANK", "RANK", "SLURM_PROCID", "JSM_NAMESPACE_RANK")
    for key in rank_keys:
        rank = os.environ.get(key)
        if rank is not None:
            return int(rank)
    return 0


def get_device():
    """Return the ``torch.device`` ``cuda:<rank>`` for the rank from :func:`get_rank`."""
    return torch.device(f"cuda:{get_rank()}")


def load_module_weights(
    path, module_name=None, ignore_modules=None, map_location=None
) -> Tuple[dict, int, int]:
    """Load (a subset of) the ``state_dict`` stored in a checkpoint file.

    Args:
        path: Checkpoint location, passed straight to ``torch.load``.
        module_name: If given, keep only the entries whose key starts with
            ``"<module_name>."`` and strip that prefix from the returned keys.
            The name is matched literally.
        ignore_modules: If given, an iterable of module names; every entry
            whose key starts with ``"<name>."`` for one of them is dropped.
        map_location: Forwarded to ``torch.load``; defaults to
            :func:`get_device`.

    Returns:
        ``(state_dict_to_load, ckpt["epoch"], ckpt["global_step"])``.

    Raises:
        ValueError: If both ``module_name`` and ``ignore_modules`` are set.
    """
    if module_name is not None and ignore_modules is not None:
        raise ValueError("module_name and ignore_modules cannot be both set")
    if map_location is None:
        map_location = get_device()

    # NOTE(review): torch.load deserializes with pickle (depending on the
    # PyTorch version / weights_only setting) - only load trusted checkpoints.
    ckpt = torch.load(path, map_location=map_location)
    state_dict = ckpt["state_dict"]
    state_dict_to_load = state_dict

    if ignore_modules is not None:
        # str.startswith accepts a tuple, so build the prefixes once.
        ignored_prefixes = tuple(name + "." for name in ignore_modules)
        state_dict_to_load = {
            k: v for k, v in state_dict.items() if not k.startswith(ignored_prefixes)
        }

    if module_name is not None:
        # Literal prefix match: interpolating module_name into a regex would
        # let characters such as "." in dotted names act as wildcards.
        prefix = module_name + "."
        state_dict_to_load = {
            k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)
        }

    return state_dict_to_load, ckpt["epoch"], ckpt["global_step"]


def C(value: Any, epoch: int, global_step: int) -> float:
    """Evaluate a (possibly scheduled) scalar at the current training progress.

    ``value`` is either a plain ``int``/``float``, returned unchanged, or a
    specification that ``config_to_primitive`` turns into a list:

    * ``[start_step, start_value, end_value, end_step]`` - linear ramp from
      ``start_value`` to ``end_value``, clamped outside the step range.
    * ``[start_value, end_value, end_step]`` - same, with ``start_step = 0``.
    * six or more elements - the four-element form followed by further
      ``(end_value, end_step)`` pairs; the active segment is chosen by
      comparing ``global_step`` against the segment end steps.

    If ``end_step`` is an ``int`` progress is measured in ``global_step``; if
    it is a ``float`` progress is measured in ``epoch``. For any other
    ``end_step`` type the four-element list itself is returned unchanged.

    Raises:
        TypeError: If the converted specification is not a list.
        AssertionError: If the specification does not reduce to four elements.
    """
    if isinstance(value, (int, float)):
        return value

    value = config_to_primitive(value)
    if not isinstance(value, list):
        raise TypeError(f"Scalar specification only supports list, got {type(value)}")
    if len(value) == 3:
        value = [0] + value
    if len(value) >= 6:
        # Indices 3, 5, 7, ... hold the end step of each segment. Walk them
        # and remember the end-step index of the segment that follows the
        # last one global_step has already reached.
        select_i = 3
        for i in range(3, len(value) - 2, 2):
            if global_step >= value[i]:
                select_i = i + 2
        if select_i != 3:
            # A later segment starts where the previous one ended.
            start_value, start_step = value[select_i - 3], value[select_i - 2]
        else:
            start_step, start_value = value[:2]
        end_value, end_step = value[select_i - 1], value[select_i]
        value = [start_step, start_value, end_value, end_step]
    assert len(value) == 4
    start_step, start_value, end_value, end_step = value

    if isinstance(end_step, int):
        current_step = global_step
    elif isinstance(end_step, float):
        current_step = epoch
    else:
        # Neither a step- nor an epoch-based schedule: hand the list back as-is.
        return value

    progress = (current_step - start_step) / (end_step - start_step)
    return start_value + (end_value - start_value) * max(min(1.0, progress), 0.0)


def cleanup():
    """Run ``gc.collect()``, empty the CUDA cache and free tcnn temporary memory."""
    gc.collect()
    torch.cuda.empty_cache()
    tcnn.free_temporary_memory()


def finish_with_cleanup(func: Callable):
    """Decorator that calls :func:`cleanup` after ``func`` returns.

    The wrapped function's result is returned unchanged. ``cleanup`` is not
    called when ``func`` raises.
    """

    @functools.wraps(func)  # keep func's __name__/__doc__ on the wrapper
    def wrapper(*args, **kwargs):
        out = func(*args, **kwargs)
        cleanup()
        return out

    return wrapper


def _distributed_available():
    """Return True when ``torch.distributed`` is available and initialized."""
    return torch.distributed.is_available() and torch.distributed.is_initialized()


def barrier():
    """Call ``torch.distributed.barrier()``; do nothing if distributed is not set up."""
    if not _distributed_available():
        return
    torch.distributed.barrier()


def broadcast(tensor, src=0):
    """Broadcast ``tensor`` from rank ``src`` and return it.

    When distributed is not set up, ``tensor`` is returned untouched.
    """
    if not _distributed_available():
        return tensor
    torch.distributed.broadcast(tensor, src=src)
    return tensor


def enable_gradient(model, enabled: bool = True) -> None:
    """Call ``requires_grad_(enabled)`` on every parameter of ``model``."""
    for param in model.parameters():
        param.requires_grad_(enabled)


def find_last_path(path: str):
    """Resolve a ``LAST`` placeholder in ``path`` to an existing path.

    If ``path`` is ``None`` or does not contain ``"LAST"`` it is returned
    unchanged. Otherwise spaces are replaced by underscores, the path is split
    at the first ``"LAST"``, and the directory part before it is listed. Of
    the entries whose joined path starts with the text preceding ``"LAST"``,
    the greatest in string order is taken, and the text following ``"LAST"``
    is appended to it.

    Raises:
        FileNotFoundError: If no entry matches the prefix, or the resolved
            path does not exist.
    """
    if path is None or "LAST" not in path:
        return path

    path = path.replace(" ", "_")
    base_dir_prefix, suffix = path.split("LAST", 1)
    base_dir = os.path.dirname(base_dir_prefix)
    prefix = os.path.split(base_dir_prefix)[-1]
    base_dir_prefix = os.path.join(base_dir, prefix)

    # An empty dirname means the placeholder sits directly in the current
    # directory; os.listdir("") would raise, so list os.curdir instead.
    search_dir = base_dir or os.curdir
    candidates = [os.path.join(base_dir, name) for name in os.listdir(search_dir)]
    matching = [c for c in candidates if c.startswith(base_dir_prefix)]
    if not matching:
        raise FileNotFoundError(
            f"no entry in {search_dir!r} starts with {prefix!r} (from {path!r})"
        )

    # Same pick as sorting in reverse order and taking the first element.
    new_path = max(matching) + suffix
    if os.path.exists(new_path):
        return new_path
    raise FileNotFoundError(new_path)