mirror of
https://github.com/deepseek-ai/DreamCraft3D.git
synced 2025-02-23 14:28:55 -05:00
517 lines · 20 KiB · Python
import json
import os
from dataclasses import dataclass, field

import torch
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
from pytorch_lightning.utilities.rank_zero import rank_zero_only
from transformers import AutoTokenizer, BertForMaskedLM

import threestudio
from threestudio.utils.base import BaseObject
from threestudio.utils.misc import barrier, cleanup, get_rank

# NOTE: "shifted_expotional_decay" is the (misspelled) name actually exported
# by threestudio.utils.ops, so it is imported as-is.
from threestudio.utils.ops import shifted_cosine_decay, shifted_expotional_decay
from threestudio.utils.typing import *


def hash_prompt(model: str, prompt: str) -> str:
    import hashlib

    identifier = f"{model}-{prompt}"
    return hashlib.md5(identifier.encode()).hexdigest()
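
# Usage sketch (illustrative): the digest doubles as the cache filename, so the
# same prompt under a different model gets a separate cache entry, e.g.
#   hash_prompt("runwayml/stable-diffusion-v1-5", "a hamburger") + ".pt"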


@dataclass
class DirectionConfig:
    name: str
    prompt: Callable[[str], str]
    negative_prompt: Callable[[str], str]
    condition: Callable[
        [Float[Tensor, "B"], Float[Tensor, "B"], Float[Tensor, "B"]],
        Float[Tensor, "B"],
    ]
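
# Example (mirrors the "front" entry built in PromptProcessor.configure below;
# the 45.0 threshold is the config default): a DirectionConfig bundles a view
# name, prompt/negative-prompt transforms, and a boolean mask over a batch of
# camera angles.
#   DirectionConfig(
#       "front",
#       lambda s: f"{s}, front view",
#       lambda s: s,
#       lambda ele, azi, dis: (shift_azimuth_deg(azi) > -45.0)
#       & (shift_azimuth_deg(azi) < 45.0),
#   )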


@dataclass
class PromptProcessorOutput:
    text_embeddings: Float[Tensor, "N Nf"]
    uncond_text_embeddings: Float[Tensor, "N Nf"]
    text_embeddings_vd: Float[Tensor, "Nv N Nf"]
    uncond_text_embeddings_vd: Float[Tensor, "Nv N Nf"]
    directions: List[DirectionConfig]
    direction2idx: Dict[str, int]
    use_perp_neg: bool
    perp_neg_f_sb: Tuple[float, float, float]
    perp_neg_f_fsb: Tuple[float, float, float]
    perp_neg_f_fs: Tuple[float, float, float]
    perp_neg_f_sf: Tuple[float, float, float]

    def get_text_embeddings(
        self,
        elevation: Float[Tensor, "B"],
        azimuth: Float[Tensor, "B"],
        camera_distances: Float[Tensor, "B"],
        view_dependent_prompting: bool = True,
    ) -> Float[Tensor, "BB N Nf"]:
        batch_size = elevation.shape[0]

        if view_dependent_prompting:
            # Get direction
            direction_idx = torch.zeros_like(elevation, dtype=torch.long)
            for d in self.directions:
                direction_idx[
                    d.condition(elevation, azimuth, camera_distances)
                ] = self.direction2idx[d.name]

            # Get text embeddings
            text_embeddings = self.text_embeddings_vd[direction_idx]  # type: ignore
            uncond_text_embeddings = self.uncond_text_embeddings_vd[direction_idx]  # type: ignore
        else:
            text_embeddings = self.text_embeddings.expand(batch_size, -1, -1)  # type: ignore
            uncond_text_embeddings = self.uncond_text_embeddings.expand(  # type: ignore
                batch_size, -1, -1
            )

        # IMPORTANT: we return (cond, uncond), which is in a different order
        # than other implementations!
        return torch.cat([text_embeddings, uncond_text_embeddings], dim=0)
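
    # Shape sketch (illustrative): for batch size B the result is [2B, N, Nf];
    # rows [0, B) are the conditional embeddings and rows [B, 2B) the
    # unconditional ones, i.e. the conditional half comes first.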

    def get_text_embeddings_perp_neg(
        self,
        elevation: Float[Tensor, "B"],
        azimuth: Float[Tensor, "B"],
        camera_distances: Float[Tensor, "B"],
        view_dependent_prompting: bool = True,
    ) -> Tuple[Float[Tensor, "BBBB N Nf"], Float[Tensor, "B 2"]]:
        assert (
            view_dependent_prompting
        ), "Perp-Neg only works with view-dependent prompting"

        batch_size = elevation.shape[0]

        direction_idx = torch.zeros_like(elevation, dtype=torch.long)
        for d in self.directions:
            direction_idx[
                d.condition(elevation, azimuth, camera_distances)
            ] = self.direction2idx[d.name]
        # 0 - side view
        # 1 - front view
        # 2 - back view
        # 3 - overhead view

        pos_text_embeddings = []
        neg_text_embeddings = []
        neg_guidance_weights = []
        uncond_text_embeddings = []

        side_emb = self.text_embeddings_vd[0]
        front_emb = self.text_embeddings_vd[1]
        back_emb = self.text_embeddings_vd[2]
        overhead_emb = self.text_embeddings_vd[3]

        for idx, ele, azi, dis in zip(
            direction_idx, elevation, azimuth, camera_distances
        ):
            azi = shift_azimuth_deg(azi)  # to [-180, 180)
            uncond_text_embeddings.append(
                self.uncond_text_embeddings_vd[idx]
            )  # should be ""
            if idx.item() == 3:  # overhead view
                pos_text_embeddings.append(overhead_emb)
                # dummy negatives with zero weight
                neg_text_embeddings += [
                    self.uncond_text_embeddings_vd[idx],
                    self.uncond_text_embeddings_vd[idx],
                ]
                neg_guidance_weights += [0.0, 0.0]
            else:  # interpolating views
                if torch.abs(azi) < 90:
                    # front-side interpolation
                    # 0 - complete side, 1 - complete front
                    r_inter = 1 - torch.abs(azi) / 90
                    pos_text_embeddings.append(
                        r_inter * front_emb + (1 - r_inter) * side_emb
                    )
                    neg_text_embeddings += [front_emb, side_emb]
                    neg_guidance_weights += [
                        -shifted_expotional_decay(*self.perp_neg_f_fs, r_inter),
                        -shifted_expotional_decay(*self.perp_neg_f_sf, 1 - r_inter),
                    ]
                else:
                    # side-back interpolation
                    # 0 - complete back, 1 - complete side
                    r_inter = 2.0 - torch.abs(azi) / 90
                    pos_text_embeddings.append(
                        r_inter * side_emb + (1 - r_inter) * back_emb
                    )
                    neg_text_embeddings += [side_emb, front_emb]
                    neg_guidance_weights += [
                        -shifted_expotional_decay(*self.perp_neg_f_sb, r_inter),
                        -shifted_expotional_decay(*self.perp_neg_f_fsb, r_inter),
                    ]

        text_embeddings = torch.cat(
            [
                torch.stack(pos_text_embeddings, dim=0),
                torch.stack(uncond_text_embeddings, dim=0),
                torch.stack(neg_text_embeddings, dim=0),
            ],
            dim=0,
        )

        return text_embeddings, torch.as_tensor(
            neg_guidance_weights, device=elevation.device
        ).reshape(batch_size, 2)
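
    # Layout sketch (illustrative): for batch size B, the first return value
    # stacks [B positive | B unconditional | 2B negative] embeddings along
    # dim 0 (4B rows total, matching the "BBBB" annotation); the second holds
    # the two per-sample negative guidance weights, shape [B, 2].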


def shift_azimuth_deg(azimuth: Float[Tensor, "..."]) -> Float[Tensor, "..."]:
    # shift the azimuth angle (in degrees) into [-180, 180)
    return (azimuth + 180) % 360 - 180
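
# Examples (verifiable by hand): shift_azimuth_deg maps 270 -> -90 and
# 180 -> -180, which is why the "back" condition in configure below can use a
# single symmetric band around +/-180.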


class PromptProcessor(BaseObject):
    @dataclass
    class Config(BaseObject.Config):
        prompt: str = "a hamburger"

        # manually assigned view-dependent prompts
        prompt_front: Optional[str] = None
        prompt_side: Optional[str] = None
        prompt_back: Optional[str] = None
        prompt_overhead: Optional[str] = None

        negative_prompt: str = ""
        pretrained_model_name_or_path: str = "runwayml/stable-diffusion-v1-5"
        overhead_threshold: float = 60.0
        front_threshold: float = 45.0
        back_threshold: float = 45.0
        view_dependent_prompt_front: bool = False
        use_cache: bool = True
        spawn: bool = True

        # perp neg
        use_perp_neg: bool = False
        # each weight schedule has the form a * e^(-b * r) + c
        # (c is typically chosen so that a * e^(-b) + c = 0, i.e. f(1) = 0)
        perp_neg_f_sb: Tuple[float, float, float] = (1, 0.5, -0.606)
        perp_neg_f_fsb: Tuple[float, float, float] = (1, 0.5, +0.967)
        perp_neg_f_fs: Tuple[float, float, float] = (
            4,
            0.5,
            -2.426,
        )  # f_fs(1) = 0, a, b > 0
        perp_neg_f_sf: Tuple[float, float, float] = (4, 0.5, -2.426)

        # prompt debiasing
        use_prompt_debiasing: bool = False
        pretrained_model_name_or_path_prompt_debiasing: str = "bert-base-uncased"
        # indices of words that can potentially be removed
        prompt_debiasing_mask_ids: Optional[List[int]] = None

    cfg: Config
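
    # Numeric sketch (illustrative): with the a * e^(-b * r) + c form noted
    # above, the default perp_neg_f_sb = (1, 0.5, -0.606) gives
    #   1 * exp(-0.5 * 1) - 0.606 ~= 0.0005
    # at r = 1, i.e. the side/back negative weight is tuned to (nearly) vanish
    # at a pure side view.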

    @rank_zero_only
    def configure_text_encoder(self) -> None:
        raise NotImplementedError

    @rank_zero_only
    def destroy_text_encoder(self) -> None:
        raise NotImplementedError

    def configure(self) -> None:
        self._cache_dir = ".threestudio_cache/text_embeddings"  # FIXME: hard-coded path

        # view-dependent text embeddings
        self.directions: List[DirectionConfig]
        if self.cfg.view_dependent_prompt_front:
            self.directions = [
                DirectionConfig(
                    "side",
                    lambda s: f"side view of {s}",
                    lambda s: s,
                    lambda ele, azi, dis: torch.ones_like(ele, dtype=torch.bool),
                ),
                DirectionConfig(
                    "front",
                    lambda s: f"front view of {s}",
                    lambda s: s,
                    lambda ele, azi, dis: (
                        shift_azimuth_deg(azi) > -self.cfg.front_threshold
                    )
                    & (shift_azimuth_deg(azi) < self.cfg.front_threshold),
                ),
                DirectionConfig(
                    "back",
                    lambda s: f"backside view of {s}",
                    lambda s: s,
                    lambda ele, azi, dis: (
                        shift_azimuth_deg(azi) > 180 - self.cfg.back_threshold
                    )
                    | (shift_azimuth_deg(azi) < -180 + self.cfg.back_threshold),
                ),
                DirectionConfig(
                    "overhead",
                    lambda s: f"overhead view of {s}",
                    lambda s: s,
                    lambda ele, azi, dis: ele > self.cfg.overhead_threshold,
                ),
            ]
        else:
            self.directions = [
                DirectionConfig(
                    "side",
                    lambda s: f"{s}, side view",
                    lambda s: s,
                    lambda ele, azi, dis: torch.ones_like(ele, dtype=torch.bool),
                ),
                DirectionConfig(
                    "front",
                    lambda s: f"{s}, front view",
                    lambda s: s,
                    lambda ele, azi, dis: (
                        shift_azimuth_deg(azi) > -self.cfg.front_threshold
                    )
                    & (shift_azimuth_deg(azi) < self.cfg.front_threshold),
                ),
                DirectionConfig(
                    "back",
                    lambda s: f"{s}, back view",
                    lambda s: s,
                    lambda ele, azi, dis: (
                        shift_azimuth_deg(azi) > 180 - self.cfg.back_threshold
                    )
                    | (shift_azimuth_deg(azi) < -180 + self.cfg.back_threshold),
                ),
                DirectionConfig(
                    "overhead",
                    lambda s: f"{s}, overhead view",
                    lambda s: s,
                    lambda ele, azi, dis: ele > self.cfg.overhead_threshold,
                ),
            ]

        self.direction2idx = {d.name: i for i, d in enumerate(self.directions)}
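        # Resulting index mapping, given the list order above (illustrative):
        #   {"side": 0, "front": 1, "back": 2, "overhead": 3}
        # get_text_embeddings_perp_neg relies on exactly this ordering.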

        with open(os.path.join("load/prompt_library.json"), "r") as f:
            self.prompt_library = json.load(f)
        # use provided prompt or find prompt in library
        self.prompt = self.preprocess_prompt(self.cfg.prompt)
        # use provided negative prompt
        self.negative_prompt = self.cfg.negative_prompt

        threestudio.info(
            f"Using prompt [{self.prompt}] and negative prompt [{self.negative_prompt}]"
        )

        # view-dependent prompting
        if self.cfg.use_prompt_debiasing:
            assert (
                self.cfg.prompt_side is None
                and self.cfg.prompt_back is None
                and self.cfg.prompt_overhead is None
            ), "Do not manually assign prompt_side, prompt_back or prompt_overhead when using prompt debiasing"
            prompts = self.get_debiased_prompt(self.prompt)
            self.prompts_vd = [
                d.prompt(prompt) for d, prompt in zip(self.directions, prompts)
            ]
        else:
            self.prompts_vd = [
                self.cfg.get(f"prompt_{d.name}", None) or d.prompt(self.prompt)  # type: ignore
                for d in self.directions
            ]

        prompts_vd_display = " ".join(
            [
                f"[{d.name}]:[{prompt}]"
                for prompt, d in zip(self.prompts_vd, self.directions)
            ]
        )
        threestudio.info(f"Using view-dependent prompts {prompts_vd_display}")

        self.negative_prompts_vd = [
            d.negative_prompt(self.negative_prompt) for d in self.directions
        ]

        self.prepare_text_embeddings()
        self.load_text_embeddings()

    @staticmethod
    def spawn_func(pretrained_model_name_or_path, prompts, cache_dir, device):
        raise NotImplementedError

    @rank_zero_only
    def prepare_text_embeddings(self):
        os.makedirs(self._cache_dir, exist_ok=True)

        all_prompts = (
            [self.prompt]
            + [self.negative_prompt]
            + self.prompts_vd
            + self.negative_prompts_vd
        )
        prompts_to_process = []
        for prompt in all_prompts:
            if self.cfg.use_cache:
                # some text embeddings are already in cache;
                # do not process them again
                cache_path = os.path.join(
                    self._cache_dir,
                    f"{hash_prompt(self.cfg.pretrained_model_name_or_path, prompt)}.pt",
                )
                if os.path.exists(cache_path):
                    threestudio.debug(
                        f"Text embeddings for model {self.cfg.pretrained_model_name_or_path} and prompt [{prompt}] are already in cache, skip processing."
                    )
                    continue
            prompts_to_process.append(prompt)

        if len(prompts_to_process) > 0:
            if self.cfg.spawn:
                ctx = mp.get_context("spawn")
                subprocess = ctx.Process(
                    target=self.spawn_func,
                    args=(
                        self.cfg.pretrained_model_name_or_path,
                        prompts_to_process,
                        self._cache_dir,
                        self.device,
                    ),
                )
                subprocess.start()
                subprocess.join()
            else:
                self.spawn_func(
                    self.cfg.pretrained_model_name_or_path,
                    prompts_to_process,
                    self._cache_dir,
                    self.device,
                )
            cleanup()
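
    # Cache layout note (the path format follows from the code above; the
    # rationale is a presumed design note, not asserted by the code): each
    # prompt is saved as
    #   .threestudio_cache/text_embeddings/<hash_prompt(model, prompt)>.pt
    # and running spawn_func in a spawned subprocess lets the text encoder be
    # released together with that process instead of lingering in the trainer.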

    def load_text_embeddings(self):
        # synchronize, to ensure the text embeddings have been computed and saved to cache
        barrier()
        self.text_embeddings = self.load_from_cache(self.prompt)[None, ...]
        self.uncond_text_embeddings = self.load_from_cache(self.negative_prompt)[
            None, ...
        ]
        self.text_embeddings_vd = torch.stack(
            [self.load_from_cache(prompt) for prompt in self.prompts_vd], dim=0
        )
        self.uncond_text_embeddings_vd = torch.stack(
            [self.load_from_cache(prompt) for prompt in self.negative_prompts_vd], dim=0
        )
        threestudio.debug("Loaded text embeddings.")

    def load_from_cache(self, prompt):
        cache_path = os.path.join(
            self._cache_dir,
            f"{hash_prompt(self.cfg.pretrained_model_name_or_path, prompt)}.pt",
        )
        if not os.path.exists(cache_path):
            raise FileNotFoundError(
                f"Text embedding file {cache_path} for model {self.cfg.pretrained_model_name_or_path} and prompt [{prompt}] not found."
            )
        return torch.load(cache_path, map_location=self.device)

    def preprocess_prompt(self, prompt: str) -> str:
        if prompt.startswith("lib:"):
            # find a unique match in the prompt library
            candidate = None
            keywords = prompt[4:].lower().split("_")
            for lib_prompt in self.prompt_library["dreamfusion"]:
                if all([k in lib_prompt.lower() for k in keywords]):
                    if candidate is not None:
                        raise ValueError(
                            f"Multiple prompts matched with keywords {keywords} in library"
                        )
                    candidate = lib_prompt
            if candidate is None:
                raise ValueError(
                    f"Cannot find prompt with keywords {keywords} in library"
                )
            threestudio.info("Found matching prompt in library: " + candidate)
            return candidate
        else:
            return prompt
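
    # Usage sketch (illustrative; the keyword string is hypothetical): a prompt
    # such as "lib:delicious_hamburger" is split into keywords
    # ["delicious", "hamburger"], which must match exactly one entry of the
    # "dreamfusion" list in load/prompt_library.json; that entry is returned.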

    def get_text_embeddings(
        self, prompt: Union[str, List[str]], negative_prompt: Union[str, List[str]]
    ) -> Tuple[Float[Tensor, "B ..."], Float[Tensor, "B ..."]]:
        raise NotImplementedError

    def get_debiased_prompt(self, prompt: str) -> List[str]:
        os.environ["TOKENIZERS_PARALLELISM"] = "false"

        tokenizer = AutoTokenizer.from_pretrained(
            self.cfg.pretrained_model_name_or_path_prompt_debiasing
        )
        model = BertForMaskedLM.from_pretrained(
            self.cfg.pretrained_model_name_or_path_prompt_debiasing
        )

        views = [d.name for d in self.directions]
        view_ids = tokenizer(" ".join(views), return_tensors="pt").input_ids[0]
        # skip the leading [CLS] token; keep the ids of the four view words
        view_ids = view_ids[1:5]

        def modulate(prompt):
            prompt_vd = f"This image is depicting a [MASK] view of {prompt}"
            tokens = tokenizer(
                prompt_vd,
                padding="max_length",
                truncation=True,
                add_special_tokens=True,
                return_tensors="pt",
            )
            mask_idx = torch.where(tokens.input_ids == tokenizer.mask_token_id)[1]

            logits = model(**tokens).logits
            logits = F.softmax(logits[0, mask_idx], dim=-1)
            logits = logits[0, view_ids]
            probes = logits / logits.sum()
            return probes

        prompts = [prompt.split(" ") for _ in range(4)]
        full_probe = modulate(prompt)
        n_words = len(prompt.split(" "))
        prompt_debiasing_mask_ids = (
            self.cfg.prompt_debiasing_mask_ids
            if self.cfg.prompt_debiasing_mask_ids is not None
            else list(range(n_words))
        )
        words_to_debias = [prompt.split(" ")[idx] for idx in prompt_debiasing_mask_ids]
        threestudio.info(f"Words that can potentially be removed: {words_to_debias}")
        for idx in prompt_debiasing_mask_ids:
            words = prompt.split(" ")
            prompt_ = " ".join(words[:idx] + words[(idx + 1) :])
            part_probe = modulate(prompt_)

            pmi = full_probe / torch.lerp(part_probe, full_probe, 0.5)
            for i in range(pmi.shape[0]):
                if pmi[i].item() < 0.95:
                    prompts[i][idx] = ""

        debiased_prompts = [" ".join([word for word in p if word]) for p in prompts]
        for d, debiased_prompt in zip(views, debiased_prompts):
            threestudio.info(f"Debiased prompt of the {d} view is [{debiased_prompt}]")

        del tokenizer, model
        cleanup()

        return debiased_prompts
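
    # Note on the loop above (descriptive): full_probe is BERT's normalized
    # probability over the four view words given the full prompt; part_probe is
    # the same with one word removed. If the smoothed ratio pmi falls below
    # 0.95 for some view, that word suppresses the view and is blanked out of
    # that view's copy of the prompt.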

    def __call__(self) -> PromptProcessorOutput:
        return PromptProcessorOutput(
            text_embeddings=self.text_embeddings,
            uncond_text_embeddings=self.uncond_text_embeddings,
            text_embeddings_vd=self.text_embeddings_vd,
            uncond_text_embeddings_vd=self.uncond_text_embeddings_vd,
            directions=self.directions,
            direction2idx=self.direction2idx,
            use_perp_neg=self.cfg.use_perp_neg,
            perp_neg_f_sb=self.cfg.perp_neg_f_sb,
            perp_neg_f_fsb=self.cfg.perp_neg_f_fsb,
            perp_neg_f_fs=self.cfg.perp_neg_f_fs,
            perp_neg_f_sf=self.cfg.perp_neg_f_sf,
        )
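

# Overall usage sketch (illustrative; "SomePromptProcessor" is a hypothetical
# concrete subclass; real ones elsewhere in threestudio implement
# configure_text_encoder, spawn_func and get_text_embeddings):
#
#   processor = SomePromptProcessor(cfg)  # computes & caches embeddings
#   out = processor()                     # a PromptProcessorOutput
#   emb = out.get_text_embeddings(elevation, azimuth, camera_distances)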