# Mirror of https://github.com/deepseek-ai/DreamCraft3D.git (synced 2025-02-23 06:18:56 -05:00); 154 lines, 4.6 KiB, Python.
|
import hashlib
|
||
|
import os
|
||
|
|
||
|
import requests
|
||
|
from tqdm import tqdm
|
||
|
|
||
|
# Per-checkpoint lookup tables, all keyed by checkpoint name:
#   URL_MAP  -> where the checkpoint is downloaded from,
#   CKPT_MAP -> the filename it is stored under (relative to a root directory),
#   MD5_MAP  -> the expected hex MD5 digest of that file.
URL_MAP = dict(
    vgg_lpips="https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1",
)

CKPT_MAP = dict(
    vgg_lpips="vgg.pth",
)

MD5_MAP = dict(
    vgg_lpips="d507d7349b931f0638a25a48a722f98a",
)
|
||
|
|
||
|
|
||
|
def download(url, local_path, chunk_size=1024):
    """Stream ``url`` to ``local_path`` while showing a tqdm progress bar.

    The parent directory of ``local_path`` is created if needed. Data is first
    written to ``local_path + ".part"`` and moved into place only after the
    whole response has been received, so an interrupted or failed transfer
    never leaves a truncated file at ``local_path``.

    Parameters
    ----------
    url : str
        Source URL, fetched with ``requests.get(..., stream=True)``.
    local_path : str
        Destination file path.
    chunk_size : int
        Number of bytes requested per ``iter_content`` chunk.

    Raises
    ------
    requests.HTTPError
        If the server responds with an error status. The error page is not
        written to disk.
    """
    directory = os.path.dirname(local_path)
    # os.makedirs("") raises, so only create a directory when there is one.
    if directory:
        os.makedirs(directory, exist_ok=True)

    tmp_path = local_path + ".part"
    try:
        with requests.get(url, stream=True) as r:
            r.raise_for_status()
            total_size = int(r.headers.get("content-length", 0))
            # total=None tells tqdm the size is unknown (no content-length).
            with tqdm(total=total_size or None, unit="B", unit_scale=True) as pbar:
                with open(tmp_path, "wb") as f:
                    for data in r.iter_content(chunk_size=chunk_size):
                        if data:
                            f.write(data)
                            # The final chunk may be shorter than chunk_size.
                            pbar.update(len(data))
        # Atomically publish the completed file.
        os.replace(tmp_path, local_path)
    except BaseException:
        # Clean up the partial file (also on Ctrl-C), then propagate.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
|
||
|
|
||
|
|
||
|
def md5_hash(path, chunk_size=1 << 20):
    """Return the hex MD5 digest of the file at ``path``.

    The file is read in ``chunk_size``-byte pieces, so large checkpoint files
    are not loaded into memory all at once. ``chunk_size`` must be positive.
    """
    digest = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()
|
||
|
|
||
|
|
||
|
def get_ckpt_path(name, root, check=False):
    """Return the local path of checkpoint ``name`` under ``root``, downloading it if needed.

    Parameters
    ----------
    name : str
        Checkpoint name; must be a key of ``URL_MAP`` (and of ``CKPT_MAP`` and
        ``MD5_MAP``).
    root : str
        Directory that holds (or will hold) the checkpoint file.
    check : bool
        If ``True``, an existing file whose MD5 differs from ``MD5_MAP[name]``
        is downloaded again. If ``False``, any existing file is used as-is.

    Returns
    -------
    str
        ``os.path.join(root, CKPT_MAP[name])``.

    Raises
    ------
    AssertionError
        If ``name`` is unknown, or if a freshly downloaded file does not match
        its expected MD5. These are raised explicitly rather than via
        ``assert`` so the checks still run under ``python -O``. The exception
        type matches the previous ``assert`` statements.
    """
    if name not in URL_MAP:
        raise AssertionError("Unknown checkpoint name: {}".format(name))

    path = os.path.join(root, CKPT_MAP[name])
    expected_md5 = MD5_MAP[name]

    if not os.path.exists(path) or (check and md5_hash(path) != expected_md5):
        print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
        download(URL_MAP[name], path)
        md5 = md5_hash(path)
        if md5 != expected_md5:
            raise AssertionError(
                "MD5 mismatch for {}: expected {}, got {}".format(path, expected_md5, md5)
            )
    return path
|
||
|
|
||
|
|
||
|
class KeyNotFoundError(Exception):
    """Error raised by ``retrieve`` when a path cannot be resolved.

    Attributes
    ----------
    cause : object
        The underlying reason (typically the original exception).
    keys : list or None
        The full list of path components that was requested.
    visited : list or None
        The path components successfully traversed before the failure.

    The exception message lists ``keys`` and ``visited`` (each only when not
    ``None``), followed by the cause.
    """

    def __init__(self, cause, keys=None, visited=None):
        self.cause = cause
        self.keys = keys
        self.visited = visited

        # Optional header lines, kept in a fixed order and skipped when absent.
        optional_lines = [
            template.format(value)
            for template, value in (
                ("Key not found: {}", keys),
                ("Visited: {}", visited),
            )
            if value is not None
        ]
        body = optional_lines + ["Cause:\n{}".format(cause)]
        super().__init__("\n".join(body))
|
||
|
|
||
|
|
||
|
def retrieve(
    list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False
):
    """Given a nested list or dict return the desired value at key expanding
    callable nodes if necessary and :attr:`expand` is ``True``. The expansion
    is done in-place: an expanded callable is replaced by its result inside
    the container that held it (a callable passed as the root object itself
    has no container and is only called).

    Parameters
    ----------
    list_or_dict : list or dict
        Possibly nested list or dictionary. Any ``collections.abc.Mapping``
        is indexed with the raw string key; any other object is indexed with
        ``int(key)``.
    key : str
        key/to/value, path like string describing all keys necessary to
        consider to get to the desired value. List indices can also be
        passed here.
    splitval : str
        String that defines the delimiter between keys of the
        different depth levels in `key`.
    default : obj
        Value returned if :attr:`key` is not found. ``None`` means "no
        default": the lookup error is raised instead.
    expand : bool
        Whether to expand callable nodes on the path or not.
    pass_success : bool
        If ``True``, return ``(value, success)`` where ``success`` is
        ``False`` when :attr:`default` was substituted.

    Returns
    -------
    The desired value or if :attr:`default` is not ``None`` and the
    :attr:`key` is not found returns ``default``. A ``(value, success)``
    tuple is returned instead when :attr:`pass_success` is ``True``.

    Raises
    ------
    KeyNotFoundError
        If ``key`` cannot be resolved and :attr:`default` is ``None``.
        Unresolvable means a missing key, a non-integer or out-of-range
        index, indexing into a non-subscriptable leaf, or a callable on the
        path while ``expand`` is ``False``.
    """
    from collections.abc import Mapping

    keys = key.split(splitval)

    success = True
    try:
        visited = []
        parent = None
        # The subscript actually used on `parent`: str for mappings, int for
        # sequences. The in-place write-back must use the same slot.
        last_key = None
        for part in keys:
            if callable(list_or_dict):
                if not expand:
                    raise KeyNotFoundError(
                        ValueError(
                            "Trying to get past callable node with expand=False."
                        ),
                        keys=keys,
                        visited=visited,
                    )
                list_or_dict = list_or_dict()
                # A callable root object has no container to write back into.
                if parent is not None:
                    parent[last_key] = list_or_dict

            parent = list_or_dict

            try:
                if isinstance(list_or_dict, Mapping):
                    last_key = part
                else:
                    last_key = int(part)
                list_or_dict = list_or_dict[last_key]
            except (KeyError, IndexError, ValueError, TypeError) as e:
                # TypeError: the path descends into a non-subscriptable leaf.
                raise KeyNotFoundError(e, keys=keys, visited=visited) from e

            visited += [part]
        # final expansion of retrieved value
        if expand and callable(list_or_dict):
            list_or_dict = list_or_dict()
            parent[last_key] = list_or_dict
    except KeyNotFoundError:
        if default is None:
            raise
        list_or_dict = default
        success = False

    if not pass_success:
        return list_or_dict
    return list_or_dict, success
|
||
|
|
||
|
|
||
|
if __name__ == "__main__":
    # Manual smoke demo: wrap a small nested dict in an OmegaConf config,
    # print it, then look up a top-level key with retrieve().
    from omegaconf import OmegaConf

    config = OmegaConf.create(
        dict(
            keya="a",
            keyb="b",
            keyc=dict(cc1=1, cc2=2),
        )
    )
    print(config)
    retrieve(config, "keya")
|