# DreamCraft3D/threestudio/models/guidance/stable_diffusion_vsd_guidance.py

import random
from contextlib import contextmanager
from dataclasses import dataclass, field
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers import (
DDIMScheduler,
DDPMScheduler,
DPMSolverMultistepScheduler,
StableDiffusionPipeline,
UNet2DConditionModel,
)
from diffusers.loaders import AttnProcsLayers
from diffusers.models.attention_processor import LoRAAttnProcessor
from diffusers.models.embeddings import TimestepEmbedding
from diffusers.utils.import_utils import is_xformers_available
import threestudio
from threestudio.models.prompt_processors.base import PromptProcessorOutput
from threestudio.utils.base import BaseModule
from threestudio.utils.misc import C, cleanup, parse_version
from threestudio.utils.perceptual import PerceptualLoss
from threestudio.utils.typing import *
class ToWeightsDType(nn.Module):
def __init__(self, module: nn.Module, dtype: torch.dtype):
super().__init__()
self.module = module
self.dtype = dtype
def forward(self, x: Float[Tensor, "..."]) -> Float[Tensor, "..."]:
return self.module(x).to(self.dtype)
@threestudio.register("stable-diffusion-vsd-guidance")
class StableDiffusionVSDGuidance(BaseModule):
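    """Variational Score Distillation (VSD) guidance built on Stable Diffusion.

    A frozen, pretrained UNet scores the rendered views with classifier-free
    guidance, while a LoRA copy conditioned on the camera pose (a flattened camera
    matrix fed through a `TimestepEmbedding` as `class_labels`) is trained online to
    model the distribution of the current renderings; the difference between the two
    noise predictions drives the 3D optimization. Optional extras controlled by the
    config: a "DU" branch that periodically noises and re-denoises the rendering and
    supervises it in image space (`use_du`), and a `use_bsd` switch that keeps the
    camera conditioning on the pretrained branch (presumably the bootstrapped score
    distillation used in DreamCraft3D).
    """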
@dataclass
class Config(BaseModule.Config):
cache_dir: Optional[str] = None
local_files_only: Optional[bool] = False
pretrained_model_name_or_path: str = "stabilityai/stable-diffusion-2-1-base"
pretrained_model_name_or_path_lora: str = "stabilityai/stable-diffusion-2-1"
enable_memory_efficient_attention: bool = False
enable_sequential_cpu_offload: bool = False
enable_attention_slicing: bool = False
enable_channels_last_format: bool = False
guidance_scale: float = 7.5
guidance_scale_lora: float = 1.0
grad_clip: Optional[
Any
] = None # field(default_factory=lambda: [0, 2.0, 8.0, 1000])
half_precision_weights: bool = True
lora_cfg_training: bool = True
lora_n_timestamp_samples: int = 1
min_step_percent: float = 0.02
max_step_percent: float = 0.98
view_dependent_prompting: bool = True
camera_condition_type: str = "extrinsics"
use_du: bool = False
per_du_step: int = 10
start_du_step: int = 0
du_diffusion_steps: int = 20
use_bsd: bool = False
cfg: Config
def configure(self) -> None:
threestudio.info(f"Loading Stable Diffusion ...")
self.weights_dtype = (
torch.float16 if self.cfg.half_precision_weights else torch.float32
)
pipe_kwargs = {
"tokenizer": None,
"safety_checker": None,
"feature_extractor": None,
"requires_safety_checker": False,
"torch_dtype": self.weights_dtype,
"cache_dir": self.cfg.cache_dir,
"local_files_only": self.cfg.local_files_only
}
pipe_lora_kwargs = {
"tokenizer": None,
"safety_checker": None,
"feature_extractor": None,
"requires_safety_checker": False,
"torch_dtype": self.weights_dtype,
"cache_dir": self.cfg.cache_dir,
"local_files_only": self.cfg.local_files_only
}
@dataclass
class SubModules:
pipe: StableDiffusionPipeline
pipe_lora: StableDiffusionPipeline
pipe = StableDiffusionPipeline.from_pretrained(
self.cfg.pretrained_model_name_or_path,
**pipe_kwargs,
).to(self.device)
if (
self.cfg.pretrained_model_name_or_path
== self.cfg.pretrained_model_name_or_path_lora
):
self.single_model = True
pipe_lora = pipe
else:
self.single_model = False
pipe_lora = StableDiffusionPipeline.from_pretrained(
self.cfg.pretrained_model_name_or_path_lora,
**pipe_lora_kwargs,
).to(self.device)
del pipe_lora.vae
cleanup()
pipe_lora.vae = pipe.vae
self.submodules = SubModules(pipe=pipe, pipe_lora=pipe_lora)
if self.cfg.enable_memory_efficient_attention:
if parse_version(torch.__version__) >= parse_version("2"):
threestudio.info(
"PyTorch2.0 uses memory efficient attention by default."
)
elif not is_xformers_available():
threestudio.warn(
"xformers is not available, memory efficient attention is not enabled."
)
else:
self.pipe.enable_xformers_memory_efficient_attention()
self.pipe_lora.enable_xformers_memory_efficient_attention()
if self.cfg.enable_sequential_cpu_offload:
self.pipe.enable_sequential_cpu_offload()
self.pipe_lora.enable_sequential_cpu_offload()
if self.cfg.enable_attention_slicing:
self.pipe.enable_attention_slicing(1)
self.pipe_lora.enable_attention_slicing(1)
if self.cfg.enable_channels_last_format:
self.pipe.unet.to(memory_format=torch.channels_last)
self.pipe_lora.unet.to(memory_format=torch.channels_last)
del self.pipe.text_encoder
if not self.single_model:
del self.pipe_lora.text_encoder
cleanup()
for p in self.vae.parameters():
p.requires_grad_(False)
for p in self.unet.parameters():
p.requires_grad_(False)
for p in self.unet_lora.parameters():
p.requires_grad_(False)
# FIXME: hard-coded dims
self.camera_embedding = ToWeightsDType(
TimestepEmbedding(16, 1280), self.weights_dtype
).to(self.device)
self.unet_lora.class_embedding = self.camera_embedding
# set up LoRA layers
lora_attn_procs = {}
for name in self.unet_lora.attn_processors.keys():
cross_attention_dim = (
None
if name.endswith("attn1.processor")
else self.unet_lora.config.cross_attention_dim
)
if name.startswith("mid_block"):
hidden_size = self.unet_lora.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(self.unet_lora.config.block_out_channels))[
block_id
]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = self.unet_lora.config.block_out_channels[block_id]
lora_attn_procs[name] = LoRAAttnProcessor(
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
)
self.unet_lora.set_attn_processor(lora_attn_procs)
self.lora_layers = AttnProcsLayers(self.unet_lora.attn_processors).to(
self.device
)
self.lora_layers._load_state_dict_pre_hooks.clear()
self.lora_layers._state_dict_hooks.clear()
        self.scheduler = DDIMScheduler.from_pretrained(  # originally a DDPMScheduler
self.cfg.pretrained_model_name_or_path,
subfolder="scheduler",
torch_dtype=self.weights_dtype,
cache_dir=self.cfg.cache_dir,
local_files_only=self.cfg.local_files_only,
)
self.scheduler_lora = DDPMScheduler.from_pretrained(
self.cfg.pretrained_model_name_or_path_lora,
subfolder="scheduler",
torch_dtype=self.weights_dtype,
cache_dir=self.cfg.cache_dir,
local_files_only=self.cfg.local_files_only,
)
self.scheduler_sample = DPMSolverMultistepScheduler.from_config(
self.pipe.scheduler.config
)
self.scheduler_lora_sample = DPMSolverMultistepScheduler.from_config(
self.pipe_lora.scheduler.config
)
self.pipe.scheduler = self.scheduler
self.pipe_lora.scheduler = self.scheduler_lora
self.num_train_timesteps = self.scheduler.config.num_train_timesteps
self.set_min_max_steps() # set to default value
self.alphas: Float[Tensor, "..."] = self.scheduler.alphas_cumprod.to(
self.device
)
self.scheduler.alphas_cumprod = self.scheduler.alphas_cumprod.to(self.device)
self.grad_clip_val: Optional[float] = None
if self.cfg.use_du:
self.perceptual_loss = PerceptualLoss().eval().to(self.device)
for p in self.perceptual_loss.parameters():
p.requires_grad_(False)
threestudio.info(f"Loaded Stable Diffusion!")
@torch.cuda.amp.autocast(enabled=False)
def set_min_max_steps(self, min_step_percent=0.02, max_step_percent=0.98):
self.min_step = int(self.num_train_timesteps * min_step_percent)
self.max_step = int(self.num_train_timesteps * max_step_percent)
@property
def pipe(self):
return self.submodules.pipe
@property
def pipe_lora(self):
return self.submodules.pipe_lora
@property
def unet(self):
return self.submodules.pipe.unet
@property
def unet_lora(self):
return self.submodules.pipe_lora.unet
@property
def vae(self):
return self.submodules.pipe.vae
@property
def vae_lora(self):
return self.submodules.pipe_lora.vae
@torch.no_grad()
@torch.cuda.amp.autocast(enabled=False)
def _sample(
self,
pipe: StableDiffusionPipeline,
sample_scheduler: DPMSolverMultistepScheduler,
text_embeddings: Float[Tensor, "BB N Nf"],
num_inference_steps: int,
guidance_scale: float,
num_images_per_prompt: int = 1,
height: Optional[int] = None,
width: Optional[int] = None,
class_labels: Optional[Float[Tensor, "BB 16"]] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents_inp: Optional[Float[Tensor, "..."]] = None,
) -> Float[Tensor, "B H W 3"]:
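        """Run a full denoising loop with classifier-free guidance and decode to RGB.

        If `latents_inp` is given, it is noised to a randomly chosen level and
        denoising starts from the matching timestep (img2img); otherwise sampling
        starts from pure noise. `class_labels` carries the camera embedding for the
        LoRA UNet; when it is None, the UNet's class embedding is temporarily
        disabled so the base model is queried without camera conditioning.
        """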
vae_scale_factor = 2 ** (len(pipe.vae.config.block_out_channels) - 1)
height = height or pipe.unet.config.sample_size * vae_scale_factor
width = width or pipe.unet.config.sample_size * vae_scale_factor
batch_size = text_embeddings.shape[0] // 2
device = self.device
sample_scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = sample_scheduler.timesteps
num_channels_latents = pipe.unet.config.in_channels
if latents_inp is not None:
t = torch.randint(
self.min_step,
self.max_step,
[1],
dtype=torch.long,
device=self.device,
)
noise = torch.randn_like(latents_inp)
init_timestep = max(1, min(int(num_inference_steps * t[0].item() / self.num_train_timesteps), num_inference_steps))
t_start = max(num_inference_steps - init_timestep, 0)
latent_timestep = sample_scheduler.timesteps[t_start : t_start + 1].repeat(batch_size)
latents = sample_scheduler.add_noise(latents_inp, noise, latent_timestep).to(self.weights_dtype)
else:
latents = pipe.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
self.weights_dtype,
device,
generator,
)
t_start = 0
for i, t in enumerate(timesteps[t_start:]):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2)
latent_model_input = sample_scheduler.scale_model_input(
latent_model_input, t
)
# predict the noise residual
if class_labels is None:
with self.disable_unet_class_embedding(pipe.unet) as unet:
noise_pred = unet(
latent_model_input,
t,
encoder_hidden_states=text_embeddings.to(self.weights_dtype),
cross_attention_kwargs=cross_attention_kwargs,
).sample
else:
noise_pred = pipe.unet(
latent_model_input,
t,
encoder_hidden_states=text_embeddings.to(self.weights_dtype),
class_labels=class_labels,
cross_attention_kwargs=cross_attention_kwargs,
).sample
noise_pred_text, noise_pred_uncond = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (
noise_pred_text - noise_pred_uncond
)
# compute the previous noisy sample x_t -> x_t-1
latents = sample_scheduler.step(noise_pred, t, latents).prev_sample
latents = 1 / pipe.vae.config.scaling_factor * latents
images = pipe.vae.decode(latents).sample
images = (images / 2 + 0.5).clamp(0, 1)
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
images = images.permute(0, 2, 3, 1).float()
return images
def sample(
self,
prompt_utils: PromptProcessorOutput,
elevation: Float[Tensor, "B"],
azimuth: Float[Tensor, "B"],
camera_distances: Float[Tensor, "B"],
seed: int = 0,
**kwargs,
) -> Float[Tensor, "N H W 3"]:
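        """Sample images from the pretrained model with view-dependent prompts."""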
# view-dependent text embeddings
text_embeddings_vd = prompt_utils.get_text_embeddings(
elevation,
azimuth,
camera_distances,
view_dependent_prompting=self.cfg.view_dependent_prompting,
)
cross_attention_kwargs = {"scale": 0.0} if self.single_model else None
generator = torch.Generator(device=self.device).manual_seed(seed)
return self._sample(
pipe=self.pipe,
sample_scheduler=self.scheduler_sample,
text_embeddings=text_embeddings_vd,
num_inference_steps=25,
guidance_scale=self.cfg.guidance_scale,
cross_attention_kwargs=cross_attention_kwargs,
generator=generator,
)
def sample_img2img(
self,
rgb: Float[Tensor, "B H W C"],
prompt_utils: PromptProcessorOutput,
elevation: Float[Tensor, "B"],
azimuth: Float[Tensor, "B"],
camera_distances: Float[Tensor, "B"],
seed: int = 0,
        mask=None,
**kwargs,
) -> Float[Tensor, "N H W 3"]:
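        """Refine a rendering with the pretrained model (img2img style).

        Currently delegates to `compute_grad_du`, which noises the encoded rendering
        and denoises it for a fixed number of steps; the commented-out `_sample`
        call below is the full img2img sampling alternative. Note that
        `compute_grad_du` returns a dict of losses and the edited image rather than
        the image tensor promised by the return annotation.
        """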
rgb_BCHW = rgb.permute(0, 3, 1, 2)
mask_BCHW = mask.permute(0, 3, 1, 2)
        latents = self.get_latents(rgb_BCHW, rgb_as_latents=False)  # TODO: with some probability this should be the DU output or the reference image
# view-dependent text embeddings
text_embeddings_vd = prompt_utils.get_text_embeddings(
elevation,
azimuth,
camera_distances,
view_dependent_prompting=self.cfg.view_dependent_prompting,
)
cross_attention_kwargs = {"scale": 0.0} if self.single_model else None
generator = torch.Generator(device=self.device).manual_seed(seed)
# return self._sample(
# pipe=self.pipe,
# sample_scheduler=self.scheduler_sample,
# text_embeddings=text_embeddings_vd,
# num_inference_steps=25,
# guidance_scale=self.cfg.guidance_scale,
# cross_attention_kwargs=cross_attention_kwargs,
# generator=generator,
# latents_inp=latents
# )
return self.compute_grad_du(latents, rgb_BCHW, text_embeddings_vd, mask=mask_BCHW)
def sample_lora(
self,
prompt_utils: PromptProcessorOutput,
elevation: Float[Tensor, "B"],
azimuth: Float[Tensor, "B"],
camera_distances: Float[Tensor, "B"],
mvp_mtx: Float[Tensor, "B 4 4"],
c2w: Float[Tensor, "B 4 4"],
seed: int = 0,
**kwargs,
) -> Float[Tensor, "N H W 3"]:
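        """Sample images from the camera-conditioned LoRA model."""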
# input text embeddings, view-independent
text_embeddings = prompt_utils.get_text_embeddings(
elevation, azimuth, camera_distances, view_dependent_prompting=False
)
if self.cfg.camera_condition_type == "extrinsics":
camera_condition = c2w
elif self.cfg.camera_condition_type == "mvp":
camera_condition = mvp_mtx
else:
raise ValueError(
f"Unknown camera_condition_type {self.cfg.camera_condition_type}"
)
B = elevation.shape[0]
camera_condition_cfg = torch.cat(
[
camera_condition.view(B, -1),
torch.zeros_like(camera_condition.view(B, -1)),
],
dim=0,
)
generator = torch.Generator(device=self.device).manual_seed(seed)
return self._sample(
sample_scheduler=self.scheduler_lora_sample,
pipe=self.pipe_lora,
text_embeddings=text_embeddings,
num_inference_steps=25,
guidance_scale=self.cfg.guidance_scale_lora,
class_labels=camera_condition_cfg,
cross_attention_kwargs={"scale": 1.0},
generator=generator,
)
@torch.cuda.amp.autocast(enabled=False)
def forward_unet(
self,
unet: UNet2DConditionModel,
latents: Float[Tensor, "..."],
t: Float[Tensor, "..."],
encoder_hidden_states: Float[Tensor, "..."],
class_labels: Optional[Float[Tensor, "B 16"]] = None,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
) -> Float[Tensor, "..."]:
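        """Run a UNet in the weights dtype and cast the prediction back to the
        input dtype."""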
input_dtype = latents.dtype
return unet(
latents.to(self.weights_dtype),
t.to(self.weights_dtype),
encoder_hidden_states=encoder_hidden_states.to(self.weights_dtype),
class_labels=class_labels,
cross_attention_kwargs=cross_attention_kwargs,
).sample.to(input_dtype)
@torch.cuda.amp.autocast(enabled=False)
def encode_images(
self, imgs: Float[Tensor, "B 3 512 512"]
) -> Float[Tensor, "B 4 64 64"]:
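        """Encode [0, 1] images into VAE latents scaled by `vae.config.scaling_factor`."""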
input_dtype = imgs.dtype
imgs = imgs * 2.0 - 1.0
posterior = self.vae.encode(imgs.to(self.weights_dtype)).latent_dist
latents = posterior.sample() * self.vae.config.scaling_factor
return latents.to(input_dtype)
@torch.cuda.amp.autocast(enabled=False)
def decode_latents(
self,
latents: Float[Tensor, "B 4 H W"],
latent_height: int = 64,
latent_width: int = 64,
) -> Float[Tensor, "B 3 512 512"]:
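        """Decode latents back to [0, 1] RGB images, resizing the latents first if
        a different latent resolution is requested."""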
input_dtype = latents.dtype
latents = F.interpolate(
latents, (latent_height, latent_width), mode="bilinear", align_corners=False
)
latents = 1 / self.vae.config.scaling_factor * latents
image = self.vae.decode(latents.to(self.weights_dtype)).sample
image = (image * 0.5 + 0.5).clamp(0, 1)
return image.to(input_dtype)
@contextmanager
def disable_unet_class_embedding(self, unet: UNet2DConditionModel):
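        """Temporarily remove the UNet's class embedding (the camera embedding), so
        a shared UNet can be queried as the plain pretrained model."""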
class_embedding = unet.class_embedding
try:
unet.class_embedding = None
yield unet
finally:
unet.class_embedding = class_embedding
def compute_grad_du(
self,
latents: Float[Tensor, "B 4 64 64"],
rgb_BCHW_512: Float[Tensor, "B 3 512 512"],
text_embeddings: Float[Tensor, "BB 77 768"],
mask = None,
**kwargs,
):
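        """Periodic image-space supervision from the diffusion model (the "DU" branch).

        Every `per_du_step` global steps (after `start_du_step`), the rendered
        latents are noised to a random level and denoised for `du_diffusion_steps`
        steps with the frozen model; the decoded result is treated as a pseudo
        ground truth and compared to the rendering with L1 and perceptual losses.
        With a mask, noise predictions outside the mask are replaced by the injected
        noise so that region is left unchanged. Despite the name, this returns a
        dict of losses (and the edited image), not a gradient.
        """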
batch_size, _, _, _ = latents.shape
rgb_BCHW_512 = F.interpolate(rgb_BCHW_512, (512, 512), mode="bilinear")
assert batch_size == 1
need_diffusion = (
self.global_step % self.cfg.per_du_step == 0
and self.global_step > self.cfg.start_du_step
)
guidance_out = {}
if need_diffusion:
t = torch.randint(
self.min_step,
self.max_step,
[1],
dtype=torch.long,
device=self.device,
)
self.scheduler.config.num_train_timesteps = t.item()
self.scheduler.set_timesteps(self.cfg.du_diffusion_steps)
if mask is not None:
mask = F.interpolate(mask, (64, 64), mode="bilinear", antialias=True)
with torch.no_grad():
# add noise
noise = torch.randn_like(latents)
latents = self.scheduler.add_noise(latents, noise, t) # type: ignore
for i, timestep in enumerate(self.scheduler.timesteps):
# predict the noise residual with unet, NO grad!
with torch.no_grad():
latent_model_input = torch.cat([latents] * 2)
with self.disable_unet_class_embedding(self.unet) as unet:
cross_attention_kwargs = (
{"scale": 0.0} if self.single_model else None
)
noise_pred = self.forward_unet(
unet,
latent_model_input,
timestep,
encoder_hidden_states=text_embeddings,
cross_attention_kwargs=cross_attention_kwargs,
)
# perform classifier-free guidance
noise_pred_text, noise_pred_uncond = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + self.cfg.guidance_scale * (
noise_pred_text - noise_pred_uncond
)
if mask is not None:
noise_pred = mask * noise_pred + (1 - mask) * noise
# get previous sample, continue loop
latents = self.scheduler.step(
noise_pred, timestep, latents
).prev_sample
edit_images = self.decode_latents(latents)
edit_images = F.interpolate(
edit_images, (512, 512), mode="bilinear"
).permute(0, 2, 3, 1)
gt_rgb = edit_images
guidance_out.update(
{
"loss_l1": torch.nn.functional.l1_loss(
rgb_BCHW_512, gt_rgb.permute(0, 3, 1, 2), reduction="sum"
),
"loss_p": self.perceptual_loss(
rgb_BCHW_512.contiguous(),
gt_rgb.permute(0, 3, 1, 2).contiguous(),
).sum(),
"edit_image": edit_images.detach()
}
)
return guidance_out
def compute_grad_vsd(
self,
latents: Float[Tensor, "B 4 64 64"],
text_embeddings_vd: Float[Tensor, "BB 77 768"],
text_embeddings: Float[Tensor, "BB 77 768"],
camera_condition: Float[Tensor, "B 4 4"],
):
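        """Compute the VSD gradient.

        Both the frozen pretrained UNet and the camera-conditioned LoRA UNet predict
        the noise for the same noisy latents, and the gradient is

            grad = w(t) * (eps_pretrained - eps_lora),

        with classifier-free guidance applied to the pretrained prediction and a
        (typically weaker) CFG on the LoRA one. The LoRA branch uses view-independent
        text embeddings plus the flattened camera matrix as class labels; if its
        scheduler is v-prediction, its output is converted to an epsilon prediction
        first. When `cfg.use_bsd` is set, the pretrained branch also receives the
        camera condition.
        """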
B = latents.shape[0]
with torch.no_grad():
            # random timestep
t = torch.randint(
self.min_step,
self.max_step + 1,
[B],
dtype=torch.long,
device=self.device,
)
# add noise
noise = torch.randn_like(latents)
latents_noisy = self.scheduler.add_noise(latents, noise, t)
# pred noise
latent_model_input = torch.cat([latents_noisy] * 2, dim=0)
if self.cfg.use_bsd:
cross_attention_kwargs = {"scale": 0.0} if self.single_model else None
noise_pred_pretrain = self.forward_unet(
self.unet,
latent_model_input,
torch.cat([t] * 2),
encoder_hidden_states=text_embeddings_vd,
class_labels=torch.cat(
[
camera_condition.view(B, -1),
torch.zeros_like(camera_condition.view(B, -1)),
],
dim=0,
),
cross_attention_kwargs=cross_attention_kwargs,
)
else:
with self.disable_unet_class_embedding(self.unet) as unet:
cross_attention_kwargs = {"scale": 0.0} if self.single_model else None
noise_pred_pretrain = self.forward_unet(
unet,
latent_model_input,
torch.cat([t] * 2),
encoder_hidden_states=text_embeddings_vd,
cross_attention_kwargs=cross_attention_kwargs,
)
# use view-independent text embeddings in LoRA
text_embeddings_cond, _ = text_embeddings.chunk(2)
noise_pred_est = self.forward_unet(
self.unet_lora,
latent_model_input,
torch.cat([t] * 2),
encoder_hidden_states=torch.cat([text_embeddings_cond] * 2),
class_labels=torch.cat(
[
camera_condition.view(B, -1),
torch.zeros_like(camera_condition.view(B, -1)),
],
dim=0,
),
cross_attention_kwargs={"scale": 1.0},
)
(
noise_pred_pretrain_text,
noise_pred_pretrain_uncond,
) = noise_pred_pretrain.chunk(2)
# NOTE: guidance scale definition here is aligned with diffusers, but different from other guidance
noise_pred_pretrain = noise_pred_pretrain_uncond + self.cfg.guidance_scale * (
noise_pred_pretrain_text - noise_pred_pretrain_uncond
)
# TODO: more general cases
assert self.scheduler.config.prediction_type == "epsilon"
if self.scheduler_lora.config.prediction_type == "v_prediction":
alphas_cumprod = self.scheduler_lora.alphas_cumprod.to(
device=latents_noisy.device, dtype=latents_noisy.dtype
)
alpha_t = alphas_cumprod[t] ** 0.5
sigma_t = (1 - alphas_cumprod[t]) ** 0.5
noise_pred_est = latent_model_input * torch.cat([sigma_t] * 2, dim=0).view(
-1, 1, 1, 1
) + noise_pred_est * torch.cat([alpha_t] * 2, dim=0).view(-1, 1, 1, 1)
(
noise_pred_est_camera,
noise_pred_est_uncond,
) = noise_pred_est.chunk(2)
# NOTE: guidance scale definition here is aligned with diffusers, but different from other guidance
noise_pred_est = noise_pred_est_uncond + self.cfg.guidance_scale_lora * (
noise_pred_est_camera - noise_pred_est_uncond
)
w = (1 - self.alphas[t]).view(-1, 1, 1, 1)
grad = w * (noise_pred_pretrain - noise_pred_est)
return grad
def compute_grad_vsd_hifa(
self,
latents: Float[Tensor, "B 4 64 64"],
text_embeddings_vd: Float[Tensor, "BB 77 768"],
text_embeddings: Float[Tensor, "BB 77 768"],
camera_condition: Float[Tensor, "B 4 4"],
mask=None,
):
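        """Experimental HiFA-style multi-step variant of the VSD loss.

        Instead of a single noise prediction, both the pretrained and the LoRA model
        run a short denoising loop from the same noisy latents; the target is built
        from the difference of the two denoised results and supervised with an MSE
        loss in latent space. Assumes batch size 1 (it calls `t.item()`).
        """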
B, _, DH, DW = latents.shape
rgb = self.decode_latents(latents)
self.name = "hifa"
if mask is not None:
mask = F.interpolate(mask, (DH, DW), mode="bilinear", antialias=True)
with torch.no_grad():
            # random timestep
t = torch.randint(
self.min_step,
self.max_step + 1,
[B],
dtype=torch.long,
device=self.device,
)
w = (1 - self.alphas[t]).view(-1, 1, 1, 1)
# add noise
noise = torch.randn_like(latents)
latents_noisy = self.scheduler_sample.add_noise(latents, noise, t)
latents_noisy_lora = self.scheduler_lora_sample.add_noise(latents, noise, t)
# pred noise
self.scheduler_sample.config.num_train_timesteps = t.item()
self.scheduler_sample.set_timesteps(t.item() // 50 + 1)
self.scheduler_lora_sample.config.num_train_timesteps = t.item()
self.scheduler_lora_sample.set_timesteps(t.item() // 50 + 1)
for i, timestep in enumerate(self.scheduler_sample.timesteps):
latent_model_input = torch.cat([latents_noisy] * 2, dim=0)
latent_model_input_lora = torch.cat([latents_noisy_lora] * 2, dim=0)
with self.disable_unet_class_embedding(self.unet) as unet:
cross_attention_kwargs = {"scale": 0.0} if self.single_model else None
noise_pred_pretrain = self.forward_unet(
unet,
latent_model_input,
timestep,
encoder_hidden_states=text_embeddings_vd,
cross_attention_kwargs=cross_attention_kwargs,
)
# use view-independent text embeddings in LoRA
noise_pred_est = self.forward_unet(
self.unet_lora,
latent_model_input_lora,
timestep,
encoder_hidden_states=text_embeddings,
class_labels=torch.cat(
[
camera_condition.view(B, -1),
torch.zeros_like(camera_condition.view(B, -1)),
],
dim=0,
),
cross_attention_kwargs={"scale": 1.0},
)
(
noise_pred_pretrain_text,
noise_pred_pretrain_uncond,
) = noise_pred_pretrain.chunk(2)
# NOTE: guidance scale definition here is aligned with diffusers, but different from other guidance
noise_pred_pretrain = noise_pred_pretrain_uncond + self.cfg.guidance_scale * (
noise_pred_pretrain_text - noise_pred_pretrain_uncond
)
if mask is not None:
noise_pred_pretrain = mask * noise_pred_pretrain + (1 - mask) * noise
(
noise_pred_est_text,
noise_pred_est_uncond,
) = noise_pred_est.chunk(2)
# NOTE: guidance scale definition here is aligned with diffusers, but different from other guidance
# noise_pred_est = noise_pred_est_uncond + self.cfg.guidance_scale_lora * (
# noise_pred_est_text - noise_pred_est_uncond
# )
noise_pred_est = noise_pred_est_text
if mask is not None:
noise_pred_est = mask * noise_pred_est + (1 - mask) * noise
latents_noisy = self.scheduler_sample.step(noise_pred_pretrain, timestep, latents_noisy).prev_sample
latents_noisy_lora = self.scheduler_lora_sample.step(noise_pred_est, timestep, latents_noisy_lora).prev_sample
# noise = torch.randn_like(latents)
# latents_noisy = self.scheduler.step(noise_pred_pretrain, timestep, latents_noisy).prev_sample
# latents_noisy = mask * latents_noisy + (1-mask) * latents
# latents_noisy = self.scheduler_sample.add_noise(latents_noisy, noise, timestep)
# latents_noisy_lora = self.scheduler_lora.step(noise_pred_est, timestep, latents_noisy_lora).prev_sample
# latents_noisy_lora = mask * latents_noisy_lora + (1-mask) * latents
# latents_noisy_lora = self.scheduler_lora_sample.add_noise(latents_noisy_lora, noise, timestep)
hifa_images = self.decode_latents(latents_noisy)
hifa_lora_images = self.decode_latents(latents_noisy_lora)
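        # debug visualization: the code below dumps intermediate denoised images to
        # hard-coded .threestudio_cache/ paths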
import cv2
import numpy as np
if mask is not None:
print('hifa mask!')
prefix = 'vsd_mask'
else:
prefix = ''
temp = (hifa_images.permute(0, 2, 3, 1).detach().cpu()[0].numpy() * 255).astype(np.uint8)
cv2.imwrite(".threestudio_cache/%s%s_test.jpg" % (prefix, self.name), temp[:, :, ::-1])
temp = (hifa_lora_images.permute(0, 2, 3, 1).detach().cpu()[0].numpy() * 255).astype(np.uint8)
cv2.imwrite(".threestudio_cache/%s%s_test_lora.jpg" % (prefix, self.name), temp[:, :, ::-1])
target = (latents_noisy - latents_noisy_lora + latents).detach()
# target = latents_noisy.detach()
targets_rgb = self.decode_latents(target)
# targets_rgb = (hifa_images - hifa_lora_images + rgb).detach()
temp = (targets_rgb.permute(0, 2, 3, 1).detach().cpu()[0].numpy() * 255).astype(np.uint8)
cv2.imwrite(".threestudio_cache/%s_target.jpg" % self.name, temp[:, :, ::-1])
return w * 0.5 * F.mse_loss(target, latents, reduction='sum')
def train_lora(
self,
latents: Float[Tensor, "B 4 64 64"],
text_embeddings: Float[Tensor, "BB 77 768"],
camera_condition: Float[Tensor, "B 4 4"],
):
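        """Standard denoising loss that fits the LoRA (and camera embedding) to the
        current renderings.

        The camera condition is zeroed out 10% of the time when `lora_cfg_training`
        is enabled, so the LoRA branch also supports classifier-free guidance.
        """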
B = latents.shape[0]
latents = latents.detach().repeat(self.cfg.lora_n_timestamp_samples, 1, 1, 1)
t = torch.randint(
int(self.num_train_timesteps * 0.0),
int(self.num_train_timesteps * 1.0),
[B * self.cfg.lora_n_timestamp_samples],
dtype=torch.long,
device=self.device,
)
noise = torch.randn_like(latents)
noisy_latents = self.scheduler_lora.add_noise(latents, noise, t)
if self.scheduler_lora.config.prediction_type == "epsilon":
target = noise
elif self.scheduler_lora.config.prediction_type == "v_prediction":
target = self.scheduler_lora.get_velocity(latents, noise, t)
else:
raise ValueError(
f"Unknown prediction type {self.scheduler_lora.config.prediction_type}"
)
# use view-independent text embeddings in LoRA
text_embeddings_cond, _ = text_embeddings.chunk(2)
if self.cfg.lora_cfg_training and random.random() < 0.1:
camera_condition = torch.zeros_like(camera_condition)
noise_pred = self.forward_unet(
self.unet_lora,
noisy_latents,
t,
encoder_hidden_states=text_embeddings_cond.repeat(
self.cfg.lora_n_timestamp_samples, 1, 1
),
class_labels=camera_condition.view(B, -1).repeat(
self.cfg.lora_n_timestamp_samples, 1
),
cross_attention_kwargs={"scale": 1.0},
)
return F.mse_loss(noise_pred.float(), target.float(), reduction="mean")
def get_latents(
self, rgb_BCHW: Float[Tensor, "B C H W"], rgb_as_latents=False
) -> Float[Tensor, "B 4 64 64"]:
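        """Encode renderings into VAE latents, or just resize them if `rgb_as_latents`."""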
if rgb_as_latents:
latents = F.interpolate(
rgb_BCHW, (64, 64), mode="bilinear", align_corners=False
)
else:
rgb_BCHW_512 = F.interpolate(
rgb_BCHW, (512, 512), mode="bilinear", align_corners=False
)
# encode image into latents with vae
latents = self.encode_images(rgb_BCHW_512)
return latents
def forward(
self,
rgb: Float[Tensor, "B H W C"],
prompt_utils: PromptProcessorOutput,
elevation: Float[Tensor, "B"],
azimuth: Float[Tensor, "B"],
camera_distances: Float[Tensor, "B"],
mvp_mtx: Float[Tensor, "B 4 4"],
c2w: Float[Tensor, "B 4 4"],
rgb_as_latents=False,
mask: Float[Tensor, "B H W 1"] = None,
        lora_prompt_utils=None,
**kwargs,
):
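        """Compute the guidance losses for a batch of renderings.

        Returns a dict with `loss_sd` (the VSD loss, via the reparameterization
        trick below), `loss_lora` (the LoRA denoising loss), gradient statistics,
        and, when `cfg.use_du` is enabled, the image-space DU losses.
        """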
batch_size = rgb.shape[0]
rgb_BCHW = rgb.permute(0, 3, 1, 2)
latents = self.get_latents(rgb_BCHW, rgb_as_latents=rgb_as_latents)
        if mask is not None:
            mask = mask.permute(0, 3, 1, 2)
# view-dependent text embeddings
text_embeddings_vd = prompt_utils.get_text_embeddings(
elevation,
azimuth,
camera_distances,
view_dependent_prompting=self.cfg.view_dependent_prompting,
)
if lora_prompt_utils is not None:
# input text embeddings, view-independent
text_embeddings = lora_prompt_utils.get_text_embeddings(
elevation, azimuth, camera_distances, view_dependent_prompting=False
)
else:
# input text embeddings, view-independent
text_embeddings = prompt_utils.get_text_embeddings(
elevation, azimuth, camera_distances, view_dependent_prompting=False
)
if self.cfg.camera_condition_type == "extrinsics":
camera_condition = c2w
elif self.cfg.camera_condition_type == "mvp":
camera_condition = mvp_mtx
else:
raise ValueError(
f"Unknown camera_condition_type {self.cfg.camera_condition_type}"
)
grad = self.compute_grad_vsd(
latents, text_embeddings_vd, text_embeddings, camera_condition
)
grad = torch.nan_to_num(grad)
# clip grad for stable training?
if self.grad_clip_val is not None:
grad = grad.clamp(-self.grad_clip_val, self.grad_clip_val)
# reparameterization trick
# d(loss)/d(latents) = latents - target = latents - (latents - grad) = grad
target = (latents - grad).detach()
loss_vsd = 0.5 * F.mse_loss(latents, target, reduction="sum") / batch_size
loss_lora = self.train_lora(latents, text_embeddings, camera_condition)
guidance_out = {
"loss_sd": loss_vsd,
"loss_lora": loss_lora,
"grad_norm": grad.norm(),
"min_step": self.min_step,
"max_step": self.max_step,
}
if self.cfg.use_du:
du_out = self.compute_grad_du(latents, rgb_BCHW, text_embeddings_vd, mask=mask)
guidance_out.update(du_out)
return guidance_out
def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False):
# clip grad for stable training as demonstrated in
# Debiasing Scores and Prompts of 2D Diffusion for Robust Text-to-3D Generation
# http://arxiv.org/abs/2303.15413
if self.cfg.grad_clip is not None:
self.grad_clip_val = C(self.cfg.grad_clip, epoch, global_step)
self.global_step = global_step
self.set_min_max_steps(
min_step_percent=C(self.cfg.min_step_percent, epoch, global_step),
max_step_percent=C(self.cfg.max_step_percent, epoch, global_step),
)
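
# Usage sketch (not part of this file): within a threestudio system the guidance is
# typically created and called roughly as below. `guidance_cfg`, `prompt_processor`,
# the batch keys, and the loss weights are assumptions following the threestudio
# convention, not definitions made here.
#
#   guidance = threestudio.find("stable-diffusion-vsd-guidance")(guidance_cfg)
#   prompt_utils = prompt_processor()
#   out = guidance(
#       rgb=render_out["comp_rgb"],  # B H W 3 in [0, 1]
#       prompt_utils=prompt_utils,
#       elevation=batch["elevation"],
#       azimuth=batch["azimuth"],
#       camera_distances=batch["camera_distances"],
#       mvp_mtx=batch["mvp_mtx"],
#       c2w=batch["c2w"],
#   )
#   loss = lambda_vsd * out["loss_sd"] + lambda_lora * out["loss_lora"]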