# DreamCraft3D/threestudio/models/prompt_processors/base.py

import json
import os
from dataclasses import dataclass, field

import torch
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
from pytorch_lightning.utilities.rank_zero import rank_zero_only
from transformers import AutoTokenizer, BertForMaskedLM

import threestudio
from threestudio.utils.base import BaseObject
from threestudio.utils.misc import barrier, cleanup, get_rank

# note: "expotional" (sic) is the identifier's actual spelling in threestudio.utils.ops
from threestudio.utils.ops import shifted_cosine_decay, shifted_expotional_decay
from threestudio.utils.typing import *


def hash_prompt(model: str, prompt: str) -> str:
    import hashlib

    identifier = f"{model}-{prompt}"
    return hashlib.md5(identifier.encode()).hexdigest()


@dataclass
class DirectionConfig:
    name: str
    prompt: Callable[[str], str]
    negative_prompt: Callable[[str], str]
    condition: Callable[
        [Float[Tensor, "B"], Float[Tensor, "B"], Float[Tensor, "B"]],
        Float[Tensor, "B"],
    ]
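
# Each DirectionConfig pairs a view name with a prompt template, a negative
# prompt template, and a boolean predicate over batched
# (elevation, azimuth, camera_distance) tensors; PromptProcessor.configure
# below builds the four concrete side/front/back/overhead instances.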


@dataclass
class PromptProcessorOutput:
    text_embeddings: Float[Tensor, "N Nf"]
    uncond_text_embeddings: Float[Tensor, "N Nf"]
    text_embeddings_vd: Float[Tensor, "Nv N Nf"]
    uncond_text_embeddings_vd: Float[Tensor, "Nv N Nf"]
    directions: List[DirectionConfig]
    direction2idx: Dict[str, int]
    use_perp_neg: bool
    perp_neg_f_sb: Tuple[float, float, float]
    perp_neg_f_fsb: Tuple[float, float, float]
    perp_neg_f_fs: Tuple[float, float, float]
    perp_neg_f_sf: Tuple[float, float, float]

    def get_text_embeddings(
        self,
        elevation: Float[Tensor, "B"],
        azimuth: Float[Tensor, "B"],
        camera_distances: Float[Tensor, "B"],
        view_dependent_prompting: bool = True,
    ) -> Float[Tensor, "BB N Nf"]:
        batch_size = elevation.shape[0]

        if view_dependent_prompting:
            # get the direction index of each sample
            direction_idx = torch.zeros_like(elevation, dtype=torch.long)
            for d in self.directions:
                direction_idx[
                    d.condition(elevation, azimuth, camera_distances)
                ] = self.direction2idx[d.name]

            # pick the view-dependent text embeddings
            text_embeddings = self.text_embeddings_vd[direction_idx]  # type: ignore
            uncond_text_embeddings = self.uncond_text_embeddings_vd[direction_idx]  # type: ignore
        else:
            text_embeddings = self.text_embeddings.expand(batch_size, -1, -1)  # type: ignore
            uncond_text_embeddings = self.uncond_text_embeddings.expand(  # type: ignore
                batch_size, -1, -1
            )

        # IMPORTANT: we return (cond, uncond), which is a different order from
        # most other implementations!
        return torch.cat([text_embeddings, uncond_text_embeddings], dim=0)
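
    # Downstream guidance typically splits the returned [2B, N, Nf] batch back
    # into halves, e.g. (a sketch, assuming the usual classifier-free guidance
    # recipe; cond_pred/uncond_pred/cfg_scale are hypothetical names):
    #   emb = output.get_text_embeddings(elevation, azimuth, camera_distances)
    #   cond, uncond = emb.chunk(2)  # cond first, the opposite of diffusers
    #   noise_pred = uncond_pred + cfg_scale * (cond_pred - uncond_pred)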

    def get_text_embeddings_perp_neg(
        self,
        elevation: Float[Tensor, "B"],
        azimuth: Float[Tensor, "B"],
        camera_distances: Float[Tensor, "B"],
        view_dependent_prompting: bool = True,
    ) -> Tuple[Float[Tensor, "BBBB N Nf"], Float[Tensor, "B 2"]]:
        assert (
            view_dependent_prompting
        ), "Perp-Neg only works with view-dependent prompting"

        batch_size = elevation.shape[0]

        direction_idx = torch.zeros_like(elevation, dtype=torch.long)
        for d in self.directions:
            direction_idx[
                d.condition(elevation, azimuth, camera_distances)
            ] = self.direction2idx[d.name]

        # direction indices:
        # 0 - side view
        # 1 - front view
        # 2 - back view
        # 3 - overhead view

        pos_text_embeddings = []
        neg_text_embeddings = []
        neg_guidance_weights = []
        uncond_text_embeddings = []

        side_emb = self.text_embeddings_vd[0]
        front_emb = self.text_embeddings_vd[1]
        back_emb = self.text_embeddings_vd[2]
        overhead_emb = self.text_embeddings_vd[3]

        for idx, ele, azi, dis in zip(
            direction_idx, elevation, azimuth, camera_distances
        ):
            azi = shift_azimuth_deg(azi)  # to [-180, 180)
            uncond_text_embeddings.append(
                self.uncond_text_embeddings_vd[idx]
            )  # should be ""
            if idx.item() == 3:  # overhead view
                pos_text_embeddings.append(overhead_emb)
                # dummy negatives with zero weight
                neg_text_embeddings += [
                    self.uncond_text_embeddings_vd[idx],
                    self.uncond_text_embeddings_vd[idx],
                ]
                neg_guidance_weights += [0.0, 0.0]
            else:  # interpolating views
                if torch.abs(azi) < 90:
                    # front-side interpolation
                    # 0 - complete side, 1 - complete front
                    r_inter = 1 - torch.abs(azi) / 90
                    pos_text_embeddings.append(
                        r_inter * front_emb + (1 - r_inter) * side_emb
                    )
                    neg_text_embeddings += [front_emb, side_emb]
                    neg_guidance_weights += [
                        -shifted_expotional_decay(*self.perp_neg_f_fs, r_inter),
                        -shifted_expotional_decay(*self.perp_neg_f_sf, 1 - r_inter),
                    ]
                else:
                    # side-back interpolation
                    # 0 - complete back, 1 - complete side
                    r_inter = 2.0 - torch.abs(azi) / 90
                    pos_text_embeddings.append(
                        r_inter * side_emb + (1 - r_inter) * back_emb
                    )
                    neg_text_embeddings += [side_emb, front_emb]
                    neg_guidance_weights += [
                        -shifted_expotional_decay(*self.perp_neg_f_sb, r_inter),
                        -shifted_expotional_decay(*self.perp_neg_f_fsb, r_inter),
                    ]

        text_embeddings = torch.cat(
            [
                torch.stack(pos_text_embeddings, dim=0),
                torch.stack(uncond_text_embeddings, dim=0),
                torch.stack(neg_text_embeddings, dim=0),
            ],
            dim=0,
        )

        return text_embeddings, torch.as_tensor(
            neg_guidance_weights, device=elevation.device
        ).reshape(batch_size, 2)
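
    # The returned embeddings stack, in order: B positive rows, B unconditional
    # rows, and 2B negative rows (two per sample), i.e. shape [4B, N, Nf],
    # hence the "BBBB" annotation. neg_guidance_weights has shape [B, 2] and is
    # all-zero exactly for overhead views, whose negatives are dummies.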


def shift_azimuth_deg(azimuth: Float[Tensor, "..."]) -> Float[Tensor, "..."]:
    # shift azimuth angle (in degrees) to the range [-180, 180)
    return (azimuth + 180) % 360 - 180
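
# Example: shift_azimuth_deg(torch.tensor([0.0, 90.0, 180.0, 270.0]))
# -> tensor([0., 90., -180., -90.])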


class PromptProcessor(BaseObject):
    @dataclass
    class Config(BaseObject.Config):
        prompt: str = "a hamburger"

        # manually assigned view-dependent prompts
        prompt_front: Optional[str] = None
        prompt_side: Optional[str] = None
        prompt_back: Optional[str] = None
        prompt_overhead: Optional[str] = None
        negative_prompt: str = ""
        pretrained_model_name_or_path: str = "runwayml/stable-diffusion-v1-5"
        overhead_threshold: float = 60.0
        front_threshold: float = 45.0
        back_threshold: float = 45.0
        view_dependent_prompt_front: bool = False
        use_cache: bool = True
        spawn: bool = True

        # perp neg
        use_perp_neg: bool = False
        # negative-weight decay f(r) = a * exp(-b * r) + c
        # (for f_sb below, the constants satisfy f(1) = a * exp(-b) + c ≈ 0)
        perp_neg_f_sb: Tuple[float, float, float] = (1, 0.5, -0.606)
        perp_neg_f_fsb: Tuple[float, float, float] = (1, 0.5, +0.967)
        perp_neg_f_fs: Tuple[float, float, float] = (
            4,
            0.5,
            -2.426,
        )  # f_fs(1) = 0; a, b > 0
        perp_neg_f_sf: Tuple[float, float, float] = (4, 0.5, -2.426)

        # prompt debiasing
        use_prompt_debiasing: bool = False
        pretrained_model_name_or_path_prompt_debiasing: str = "bert-base-uncased"
        # indices of words that can potentially be removed
        prompt_debiasing_mask_ids: Optional[List[int]] = None
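
        # A typical YAML override for this config might look like (illustrative
        # values, not taken from the repo):
        #   prompt_processor:
        #     prompt: "a DSLR photo of a corgi"
        #     negative_prompt: "low quality, blurry"
        #     use_perp_neg: true
        #     front_threshold: 30.0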

    cfg: Config

    @rank_zero_only
    def configure_text_encoder(self) -> None:
        raise NotImplementedError

    @rank_zero_only
    def destroy_text_encoder(self) -> None:
        raise NotImplementedError

    def configure(self) -> None:
        self._cache_dir = ".threestudio_cache/text_embeddings"  # FIXME: hard-coded path

        # view-dependent text embeddings
        self.directions: List[DirectionConfig]
        if self.cfg.view_dependent_prompt_front:
            self.directions = [
                DirectionConfig(
                    "side",
                    lambda s: f"side view of {s}",
                    lambda s: s,
                    lambda ele, azi, dis: torch.ones_like(ele, dtype=torch.bool),
                ),
                DirectionConfig(
                    "front",
                    lambda s: f"front view of {s}",
                    lambda s: s,
                    lambda ele, azi, dis: (
                        shift_azimuth_deg(azi) > -self.cfg.front_threshold
                    )
                    & (shift_azimuth_deg(azi) < self.cfg.front_threshold),
                ),
                DirectionConfig(
                    "back",
                    lambda s: f"backside view of {s}",
                    lambda s: s,
                    lambda ele, azi, dis: (
                        shift_azimuth_deg(azi) > 180 - self.cfg.back_threshold
                    )
                    | (shift_azimuth_deg(azi) < -180 + self.cfg.back_threshold),
                ),
                DirectionConfig(
                    "overhead",
                    lambda s: f"overhead view of {s}",
                    lambda s: s,
                    lambda ele, azi, dis: ele > self.cfg.overhead_threshold,
                ),
            ]
        else:
            self.directions = [
                DirectionConfig(
                    "side",
                    lambda s: f"{s}, side view",
                    lambda s: s,
                    lambda ele, azi, dis: torch.ones_like(ele, dtype=torch.bool),
                ),
                DirectionConfig(
                    "front",
                    lambda s: f"{s}, front view",
                    lambda s: s,
                    lambda ele, azi, dis: (
                        shift_azimuth_deg(azi) > -self.cfg.front_threshold
                    )
                    & (shift_azimuth_deg(azi) < self.cfg.front_threshold),
                ),
                DirectionConfig(
                    "back",
                    lambda s: f"{s}, back view",
                    lambda s: s,
                    lambda ele, azi, dis: (
                        shift_azimuth_deg(azi) > 180 - self.cfg.back_threshold
                    )
                    | (shift_azimuth_deg(azi) < -180 + self.cfg.back_threshold),
                ),
                DirectionConfig(
                    "overhead",
                    lambda s: f"{s}, overhead view",
                    lambda s: s,
                    lambda ele, azi, dis: ele > self.cfg.overhead_threshold,
                ),
            ]

        self.direction2idx = {d.name: i for i, d in enumerate(self.directions)}

        with open("load/prompt_library.json", "r") as f:
            self.prompt_library = json.load(f)
        # use the provided prompt, or look it up in the library ("lib:" prefix)
        self.prompt = self.preprocess_prompt(self.cfg.prompt)
        # use the provided negative prompt as-is
        self.negative_prompt = self.cfg.negative_prompt

        threestudio.info(
            f"Using prompt [{self.prompt}] and negative prompt [{self.negative_prompt}]"
        )

        # view-dependent prompting
        if self.cfg.use_prompt_debiasing:
            assert (
                self.cfg.prompt_side is None
                and self.cfg.prompt_back is None
                and self.cfg.prompt_overhead is None
            ), "Do not manually assign prompt_side, prompt_back or prompt_overhead when using prompt debiasing"
            prompts = self.get_debiased_prompt(self.prompt)
            self.prompts_vd = [
                d.prompt(prompt) for d, prompt in zip(self.directions, prompts)
            ]
        else:
            self.prompts_vd = [
                self.cfg.get(f"prompt_{d.name}", None) or d.prompt(self.prompt)  # type: ignore
                for d in self.directions
            ]

        prompts_vd_display = " ".join(
            [
                f"[{d.name}]:[{prompt}]"
                for prompt, d in zip(self.prompts_vd, self.directions)
            ]
        )
        threestudio.info(f"Using view-dependent prompts {prompts_vd_display}")

        self.negative_prompts_vd = [
            d.negative_prompt(self.negative_prompt) for d in self.directions
        ]

        self.prepare_text_embeddings()
        self.load_text_embeddings()

    @staticmethod
    def spawn_func(pretrained_model_name_or_path, prompts, cache_dir, device):
        raise NotImplementedError
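
    # A concrete subclass (e.g. a Stable Diffusion prompt processor) is
    # expected to load the tokenizer and text encoder from
    # pretrained_model_name_or_path onto `device`, encode each prompt, and save
    # one tensor per prompt to
    # os.path.join(cache_dir, f"{hash_prompt(model, prompt)}.pt"), which is
    # where load_from_cache below looks for it. Running this in a spawned
    # subprocess (see prepare_text_embeddings) lets the text encoder's memory
    # be reclaimed when the subprocess exits.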

    @rank_zero_only
    def prepare_text_embeddings(self):
        os.makedirs(self._cache_dir, exist_ok=True)

        all_prompts = (
            [self.prompt]
            + [self.negative_prompt]
            + self.prompts_vd
            + self.negative_prompts_vd
        )
        prompts_to_process = []
        for prompt in all_prompts:
            if self.cfg.use_cache:
                # some text embeddings are already in cache
                # do not process them
                cache_path = os.path.join(
                    self._cache_dir,
                    f"{hash_prompt(self.cfg.pretrained_model_name_or_path, prompt)}.pt",
                )
                if os.path.exists(cache_path):
                    threestudio.debug(
                        f"Text embeddings for model {self.cfg.pretrained_model_name_or_path} and prompt [{prompt}] are already in cache, skip processing."
                    )
                    continue
            prompts_to_process.append(prompt)

        if len(prompts_to_process) > 0:
            if self.cfg.spawn:
                ctx = mp.get_context("spawn")
                subprocess = ctx.Process(
                    target=self.spawn_func,
                    args=(
                        self.cfg.pretrained_model_name_or_path,
                        prompts_to_process,
                        self._cache_dir,
                        self.device,
                    ),
                )
                subprocess.start()
                subprocess.join()
            else:
                self.spawn_func(
                    self.cfg.pretrained_model_name_or_path,
                    prompts_to_process,
                    self._cache_dir,
                    self.device,
                )
            cleanup()

    def load_text_embeddings(self):
        # synchronize, to ensure the text embeddings have been computed and saved to cache
        barrier()
        self.text_embeddings = self.load_from_cache(self.prompt)[None, ...]
        self.uncond_text_embeddings = self.load_from_cache(self.negative_prompt)[
            None, ...
        ]
        self.text_embeddings_vd = torch.stack(
            [self.load_from_cache(prompt) for prompt in self.prompts_vd], dim=0
        )
        self.uncond_text_embeddings_vd = torch.stack(
            [self.load_from_cache(prompt) for prompt in self.negative_prompts_vd],
            dim=0,
        )
        threestudio.debug("Loaded text embeddings.")

    def load_from_cache(self, prompt):
        cache_path = os.path.join(
            self._cache_dir,
            f"{hash_prompt(self.cfg.pretrained_model_name_or_path, prompt)}.pt",
        )
        if not os.path.exists(cache_path):
            raise FileNotFoundError(
                f"Text embedding file {cache_path} for model {self.cfg.pretrained_model_name_or_path} and prompt [{prompt}] not found."
            )
        return torch.load(cache_path, map_location=self.device)

    def preprocess_prompt(self, prompt: str) -> str:
        if prompt.startswith("lib:"):
            # find matches in the library
            candidate = None
            keywords = prompt[4:].lower().split("_")
            for library_prompt in self.prompt_library["dreamfusion"]:
                if all([k in library_prompt.lower() for k in keywords]):
                    if candidate is not None:
                        raise ValueError(
                            f"Multiple prompts matched with keywords {keywords} in library"
                        )
                    candidate = library_prompt
            if candidate is None:
                raise ValueError(
                    f"Cannot find prompt with keywords {keywords} in library"
                )
            threestudio.info("Found matching prompt in library: " + candidate)
            return candidate
        else:
            return prompt
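
    # Example (illustrative keywords): prompt="lib:hamburger" scans the
    # "dreamfusion" list in load/prompt_library.json for the single entry whose
    # lowercase text contains every "_"-separated keyword, raising on zero or
    # multiple matches.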

    def get_text_embeddings(
        self, prompt: Union[str, List[str]], negative_prompt: Union[str, List[str]]
    ) -> Tuple[Float[Tensor, "B ..."], Float[Tensor, "B ..."]]:
        raise NotImplementedError

    def get_debiased_prompt(self, prompt: str) -> List[str]:
        os.environ["TOKENIZERS_PARALLELISM"] = "false"

        tokenizer = AutoTokenizer.from_pretrained(
            self.cfg.pretrained_model_name_or_path_prompt_debiasing
        )
        model = BertForMaskedLM.from_pretrained(
            self.cfg.pretrained_model_name_or_path_prompt_debiasing
        )

        views = [d.name for d in self.directions]
        view_ids = tokenizer(" ".join(views), return_tensors="pt").input_ids[0]
        view_ids = view_ids[1:5]  # drop [CLS]/[SEP]; keep the four view tokens

        def modulate(prompt):
            prompt_vd = f"This image is depicting a [MASK] view of {prompt}"
            tokens = tokenizer(
                prompt_vd,
                padding="max_length",
                truncation=True,
                add_special_tokens=True,
                return_tensors="pt",
            )
            mask_idx = torch.where(tokens.input_ids == tokenizer.mask_token_id)[1]

            logits = model(**tokens).logits
            logits = F.softmax(logits[0, mask_idx], dim=-1)
            logits = logits[0, view_ids]
            probes = logits / logits.sum()
            return probes
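
        # modulate() returns a length-4 distribution over the four view words
        # at the [MASK] position, renormalized over just those tokens. Below,
        # the ratio full_probe / lerp(part_probe, full_probe, 0.5) measures how
        # much each word contributes to the view prediction; words whose
        # removal barely shifts it (ratio < 0.95) are dropped for that view.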

        prompts = [prompt.split(" ") for _ in range(4)]
        full_probe = modulate(prompt)
        n_words = len(prompt.split(" "))
        prompt_debiasing_mask_ids = (
            self.cfg.prompt_debiasing_mask_ids
            if self.cfg.prompt_debiasing_mask_ids is not None
            else list(range(n_words))
        )
        words_to_debias = [prompt.split(" ")[idx] for idx in prompt_debiasing_mask_ids]
        threestudio.info(f"Words that can potentially be removed: {words_to_debias}")
        for idx in prompt_debiasing_mask_ids:
            words = prompt.split(" ")
            prompt_ = " ".join(words[:idx] + words[(idx + 1) :])
            part_probe = modulate(prompt_)

            pmi = full_probe / torch.lerp(part_probe, full_probe, 0.5)
            for i in range(pmi.shape[0]):
                if pmi[i].item() < 0.95:
                    prompts[i][idx] = ""

        debiased_prompts = [" ".join([word for word in p if word]) for p in prompts]
        for d, debiased_prompt in zip(views, debiased_prompts):
            threestudio.info(f"Debiased prompt of the {d} view is [{debiased_prompt}]")

        del tokenizer, model
        cleanup()

        return debiased_prompts

    def __call__(self) -> PromptProcessorOutput:
        return PromptProcessorOutput(
            text_embeddings=self.text_embeddings,
            uncond_text_embeddings=self.uncond_text_embeddings,
            text_embeddings_vd=self.text_embeddings_vd,
            uncond_text_embeddings_vd=self.uncond_text_embeddings_vd,
            directions=self.directions,
            direction2idx=self.direction2idx,
            use_perp_neg=self.cfg.use_perp_neg,
            perp_neg_f_sb=self.cfg.perp_neg_f_sb,
            perp_neg_f_fsb=self.cfg.perp_neg_f_fsb,
            perp_neg_f_fs=self.cfg.perp_neg_f_fs,
            perp_neg_f_sf=self.cfg.perp_neg_f_sf,
        )
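

if __name__ == "__main__":
    # Minimal smoke test for the pure helpers above; a hedged sketch, since the
    # full pipeline needs a diffusion checkpoint and (typically) a GPU, so only
    # the checkpoint-free math is exercised here.
    import math

    # Azimuth wrapping: 270 degrees is the same physical view as -90 degrees.
    assert shift_azimuth_deg(torch.tensor(270.0)).item() == -90.0

    # Cache keys are stable 32-character MD5 hex digests of "<model>-<prompt>".
    assert len(hash_prompt("runwayml/stable-diffusion-v1-5", "a hamburger")) == 32

    # Perp-Neg decay: with the default perp_neg_f_fs = (4, 0.5, -2.426),
    # f(r) = a * exp(-b * r) + c is approximately zero at r = 1 (a pure front
    # view), so the side-view negative weight fades out facing front.
    a, b, c = 4, 0.5, -2.426
    assert abs(a * math.exp(-b * 1.0) + c) < 1e-3

    print("base.py helper smoke tests passed")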