DreamCraft3D/threestudio/models/guidance/zero123_guidance.py
2023-12-15 17:44:44 +08:00

528 lines
18 KiB
Python

import importlib
import os
from dataclasses import dataclass, field
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers import DDIMScheduler, DDPMScheduler, StableDiffusionPipeline
from diffusers.utils.import_utils import is_xformers_available
from omegaconf import OmegaConf
from tqdm import tqdm
import threestudio
from threestudio.utils.base import BaseObject
from threestudio.utils.misc import C, parse_version
from threestudio.utils.typing import *
def get_obj_from_str(string, reload=False):
    """Resolve a dotted path such as ``"package.module.Name"`` to an object.

    Args:
        string: Dotted path; everything before the last ``.`` is imported as a
            module and the remainder is looked up as an attribute on it.
        reload: When true, the module is reloaded before the lookup.

    Returns:
        The attribute named by the last path component.

    Raises:
        ValueError: If ``string`` contains no ``.`` to split on.
    """
    module_name, attr_name = string.rsplit(".", 1)
    module_obj = importlib.import_module(module_name)
    if reload:
        # importlib.reload hands back the (re-initialised) module object.
        module_obj = importlib.reload(module_obj)
    return getattr(module_obj, attr_name)
def instantiate_from_config(config):
    """Build the object described by ``config``.

    Args:
        config: Mapping with a ``"target"`` dotted path and optional
            ``"params"`` keyword arguments, or one of the sentinel strings
            ``"__is_first_stage__"`` / ``"__is_unconditional__"``.

    Returns:
        ``target(**params)``, or ``None`` for either sentinel string.

    Raises:
        KeyError: If ``config`` has no ``"target"`` and is not a sentinel.
    """
    if "target" in config:
        factory = get_obj_from_str(config["target"])
        return factory(**config.get("params", {}))

    # The two sentinels stand for "nothing to instantiate here".
    if config in ("__is_first_stage__", "__is_unconditional__"):
        return None

    raise KeyError("Expected key `target` to instantiate.")
# load model
def load_model_from_config(config, ckpt, device, vram_O=True, verbose=False):
    """Instantiate ``config.model`` and populate it from a checkpoint.

    Args:
        config: Configuration object whose ``model`` entry is passed to
            ``instantiate_from_config``.
        ckpt: Checkpoint handed to ``torch.load``; it must contain a
            ``"state_dict"`` entry.
        device: Device the model is moved to before being returned.
        vram_O: When true, the first-stage decoder is deleted to save memory.
        verbose: When true, print the global step, missing/unexpected keys and
            an EMA-loading notice.

    Returns:
        The model, switched to eval mode and moved to ``device``.
    """
    # NOTE(review): torch.load unpickles the file - only load trusted checkpoints.
    checkpoint = torch.load(ckpt, map_location="cpu")
    if verbose and "global_step" in checkpoint:
        print(f'[INFO] Global Step: {checkpoint["global_step"]}')
    state_dict = checkpoint["state_dict"]

    model = instantiate_from_config(config.model)
    # strict=False: mismatching keys are reported below instead of raising.
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    if verbose:
        if len(missing) > 0:
            print("[INFO] missing keys: \n", missing)
        if len(unexpected) > 0:
            print("[INFO] unexpected keys: \n", unexpected)

    # manually load ema and delete it to save GPU memory
    if model.use_ema:
        if verbose:
            print("[INFO] loading EMA...")
        model.model_ema.copy_to(model.model)
        del model.model_ema

    if vram_O:
        # we don't need decoder
        del model.first_stage_model.decoder

    torch.cuda.empty_cache()

    model.eval().to(device)
    return model
@threestudio.register("zero123-guidance")
class Zero123Guidance(BaseObject):
    """Score-distillation guidance driven by a Zero123 view-conditioned model.

    A single conditioning image (``cfg.cond_image_path``) is embedded once in
    ``configure``. Each call then conditions the diffusion model on that image
    plus the requested camera pose expressed relative to the conditioning pose
    (``cfg.cond_elevation_deg``, ``cfg.cond_azimuth_deg``,
    ``cfg.cond_camera_distance``).
    """

    @dataclass
    class Config(BaseObject.Config):
        pretrained_model_name_or_path: str = "load/zero123/105000.ckpt"
        pretrained_config: str = "load/zero123/sd-objaverse-finetune-c_concat-256.yaml"
        vram_O: bool = True

        cond_image_path: str = "load/images/hamburger_rgba.png"
        cond_elevation_deg: float = 0.0
        cond_azimuth_deg: float = 0.0
        cond_camera_distance: float = 1.2

        guidance_scale: float = 5.0

        grad_clip: Optional[
            Any
        ] = None  # field(default_factory=lambda: [0, 2.0, 8.0, 1000])
        half_precision_weights: bool = False

        min_step_percent: float = 0.02
        max_step_percent: float = 0.98

        # Maximum number of batch items to evaluate guidance for (for
        # debugging) and to save on disk. Any value <= 0 (e.g. -1) means all
        # items.
        max_items_eval: int = 4

    cfg: Config

    def configure(self) -> None:
        """Load the model, build the scheduler and cache the image embeddings."""
        threestudio.info("Loading Zero123 ...")

        self.config = OmegaConf.load(self.cfg.pretrained_config)
        # TODO: seems it cannot load into fp16...
        self.weights_dtype = torch.float32
        self.model = load_model_from_config(
            self.config,
            self.cfg.pretrained_model_name_or_path,
            device=self.device,
            vram_O=self.cfg.vram_O,
        )

        # The diffusion model is only queried, never optimised.
        for p in self.model.parameters():
            p.requires_grad_(False)

        # timesteps: use diffuser for convenience... hope it's alright.
        self.num_train_timesteps = self.config.model.params.timesteps

        self.scheduler = DDIMScheduler(
            self.num_train_timesteps,
            self.config.model.params.linear_start,
            self.config.model.params.linear_end,
            beta_schedule="scaled_linear",
            clip_sample=False,
            set_alpha_to_one=False,
            steps_offset=1,
        )

        self.num_train_timesteps = self.scheduler.config.num_train_timesteps
        # Starts from the method defaults; `update_step` later applies the
        # (possibly scheduled) cfg.min_step_percent / cfg.max_step_percent.
        self.set_min_max_steps()

        self.alphas: Float[Tensor, "..."] = self.scheduler.alphas_cumprod.to(
            self.device
        )

        self.grad_clip_val: Optional[float] = None

        self.prepare_embeddings(self.cfg.cond_image_path)

        threestudio.info("Loaded Zero123!")

    @torch.cuda.amp.autocast(enabled=False)
    def set_min_max_steps(self, min_step_percent=0.02, max_step_percent=0.98):
        """Set the timestep sampling range as fractions of the training steps."""
        self.min_step = int(self.num_train_timesteps * min_step_percent)
        self.max_step = int(self.num_train_timesteps * max_step_percent)

    @torch.cuda.amp.autocast(enabled=False)
    def prepare_embeddings(self, image_path: str) -> None:
        """Load the conditioning image and cache its embeddings.

        The image is read with its alpha channel, resized to 256x256, composited
        onto a white background and stored as ``self.rgb_256``. Its embeddings
        are stored as ``self.c_crossattn`` and ``self.c_concat``.

        Args:
            image_path: Path to the conditioning image. The colour conversion
                below expects a 4-channel (BGRA) image.
        """
        # load cond image for zero123
        assert os.path.exists(
            image_path
        ), f"Conditioning image not found: {image_path}"
        rgba = cv2.cvtColor(
            cv2.imread(image_path, cv2.IMREAD_UNCHANGED), cv2.COLOR_BGRA2RGBA
        )
        rgba = (
            cv2.resize(rgba, (256, 256), interpolation=cv2.INTER_AREA).astype(
                np.float32
            )
            / 255.0
        )
        # Alpha-composite onto white: rgb * a + 1 * (1 - a).
        alpha = rgba[..., 3:]
        rgb = rgba[..., :3] * alpha + (1 - alpha)
        self.rgb_256: Float[Tensor, "1 3 H W"] = (
            torch.from_numpy(rgb)
            .unsqueeze(0)
            .permute(0, 3, 1, 2)
            .contiguous()
            .to(self.device)
        )
        self.c_crossattn, self.c_concat = self.get_img_embeds(self.rgb_256)

    @torch.cuda.amp.autocast(enabled=False)
    @torch.no_grad()
    def get_img_embeds(
        self,
        img: Float[Tensor, "B 3 256 256"],
    ) -> Tuple[Float[Tensor, "B 1 768"], Float[Tensor, "B 4 32 32"]]:
        """Compute the two conditioning embeddings for ``img``.

        Args:
            img: Image batch; it is rescaled with ``img * 2 - 1`` before
                encoding, so values are presumably expected in [0, 1].

        Returns:
            ``(c_crossattn, c_concat)``: the model's learned conditioning of
            the image and the mode of its first-stage encoding.
        """
        img = img * 2.0 - 1.0
        img = img.to(self.weights_dtype)
        c_crossattn = self.model.get_learned_conditioning(img)
        c_concat = self.model.encode_first_stage(img).mode()
        return c_crossattn, c_concat

    @torch.cuda.amp.autocast(enabled=False)
    def encode_images(
        self, imgs: Float[Tensor, "B 3 256 256"]
    ) -> Float[Tensor, "B 4 32 32"]:
        """Encode images into first-stage latents, returned in the input dtype.

        Args:
            imgs: Image batch; rescaled with ``imgs * 2 - 1`` before encoding.
        """
        input_dtype = imgs.dtype
        imgs = imgs * 2.0 - 1.0
        latents = self.model.get_first_stage_encoding(
            self.model.encode_first_stage(imgs.to(self.weights_dtype))
        )
        return latents.to(input_dtype)  # [B, 4, 32, 32] Latent space image

    @torch.cuda.amp.autocast(enabled=False)
    def decode_latents(
        self,
        latents: Float[Tensor, "B 4 H W"],
    ) -> Float[Tensor, "B 3 512 512"]:
        """Decode latents to images clamped to [0, 1], in the input dtype.

        NOTE(review): needs the first-stage decoder, which
        ``load_model_from_config`` deletes when ``vram_O`` is true. The
        512x512 in the return annotation is inherited and unverified here.
        """
        input_dtype = latents.dtype
        image = self.model.decode_first_stage(latents)
        image = (image * 0.5 + 0.5).clamp(0, 1)
        return image.to(input_dtype)

    @torch.cuda.amp.autocast(enabled=False)
    @torch.no_grad()
    def get_cond(
        self,
        elevation: Float[Tensor, "B"],
        azimuth: Float[Tensor, "B"],
        camera_distances: Float[Tensor, "B"],
        c_crossattn=None,
        c_concat=None,
        **kwargs,
    ) -> dict:
        """Build the classifier-free-guidance conditioning for a batch of views.

        Args:
            elevation: Target elevations in degrees.
            azimuth: Target azimuths in degrees.
            camera_distances: Target camera distances.
            c_crossattn: Optional override for the cached ``self.c_crossattn``.
            c_concat: Optional override for the cached ``self.c_concat``.
            **kwargs: Ignored.

        Returns:
            Dict with ``"c_crossattn"`` and ``"c_concat"``, each a one-element
            list holding a tensor whose first half (along dim 0) is the
            all-zero unconditional branch and whose second half is the
            conditional branch.
        """
        if c_crossattn is None:
            c_crossattn = self.c_crossattn
        if c_concat is None:
            c_concat = self.c_concat

        # Relative pose: (delta polar, sin/cos of delta azimuth, delta distance).
        T = torch.stack(
            [
                torch.deg2rad(
                    (90 - elevation) - (90 - self.cfg.cond_elevation_deg)
                ),  # Zero123 polar is 90-elevation
                torch.sin(torch.deg2rad(azimuth - self.cfg.cond_azimuth_deg)),
                torch.cos(torch.deg2rad(azimuth - self.cfg.cond_azimuth_deg)),
                camera_distances - self.cfg.cond_camera_distance,
            ],
            dim=-1,
        )[:, None, :].to(self.device)
        n_views = len(T)

        clip_emb = self.model.cc_projection(
            torch.cat([c_crossattn.repeat(n_views, 1, 1), T], dim=-1)
        )
        # Zeros are derived from the embedding actually in use, so a caller
        # supplied `c_concat` gets an unconditional half of matching shape.
        concat_cond = c_concat.repeat(n_views, 1, 1, 1)

        cond = {}
        cond["c_crossattn"] = [
            torch.cat([torch.zeros_like(clip_emb).to(self.device), clip_emb], dim=0)
        ]
        cond["c_concat"] = [
            torch.cat(
                [torch.zeros_like(concat_cond).to(self.device), concat_cond],
                dim=0,
            )
        ]
        return cond

    def __call__(
        self,
        rgb: Float[Tensor, "B H W C"],
        elevation: Float[Tensor, "B"],
        azimuth: Float[Tensor, "B"],
        camera_distances: Float[Tensor, "B"],
        rgb_as_latents=False,
        guidance_eval=False,
        **kwargs,
    ):
        """Compute the score-distillation loss for rendered views.

        Args:
            rgb: Rendered images, channels last.
            elevation: Per-view elevation in degrees.
            azimuth: Per-view azimuth in degrees.
            camera_distances: Per-view camera distance.
            rgb_as_latents: When true, ``rgb`` is resized to 32x32, rescaled
                with ``x * 2 - 1`` and used directly as latents instead of
                being encoded.
            guidance_eval: When true, also run ``guidance_eval`` and attach its
                output under ``"eval"``.
            **kwargs: Ignored.

        Returns:
            Dict with ``"loss_sd"``, ``"grad_norm"``, ``"min_step"``,
            ``"max_step"`` and, if requested, ``"eval"``.
        """
        batch_size = rgb.shape[0]

        rgb_BCHW = rgb.permute(0, 3, 1, 2)
        latents: Float[Tensor, "B 4 32 32"]
        if rgb_as_latents:
            latents = (
                F.interpolate(rgb_BCHW, (32, 32), mode="bilinear", align_corners=False)
                * 2
                - 1
            )
        else:
            rgb_BCHW_256 = F.interpolate(
                rgb_BCHW, (256, 256), mode="bilinear", align_corners=False
            )
            # encode image into latents with vae
            latents = self.encode_images(rgb_BCHW_256)

        cond = self.get_cond(elevation, azimuth, camera_distances)

        # timestep sampled uniformly from [min_step, max_step] to avoid very
        # high/low noise levels
        t = torch.randint(
            self.min_step,
            self.max_step + 1,
            [batch_size],
            dtype=torch.long,
            device=self.device,
        )

        # predict the noise residual with unet, NO grad!
        with torch.no_grad():
            # add noise
            noise = torch.randn_like(latents)  # TODO: use torch generator
            latents_noisy = self.scheduler.add_noise(latents, noise, t)
            # pred noise; inputs are duplicated for the uncond/cond halves
            x_in = torch.cat([latents_noisy] * 2)
            t_in = torch.cat([t] * 2)
            noise_pred = self.model.apply_model(x_in, t_in, cond)

        # perform guidance
        noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + self.cfg.guidance_scale * (
            noise_pred_cond - noise_pred_uncond
        )

        w = (1 - self.alphas[t]).reshape(-1, 1, 1, 1)
        grad = w * (noise_pred - noise)
        grad = torch.nan_to_num(grad)
        # clip grad for stable training?
        if self.grad_clip_val is not None:
            grad = grad.clamp(-self.grad_clip_val, self.grad_clip_val)

        # loss = SpecifyGradient.apply(latents, grad)
        # SpecifyGradient is not straightforward, use a reparameterization trick instead
        target = (latents - grad).detach()
        # d(loss)/d(latents) = latents - target = latents - (latents - grad) = grad
        loss_sds = 0.5 * F.mse_loss(latents, target, reduction="sum") / batch_size

        guidance_out = {
            "loss_sd": loss_sds,  # loss_sds
            "grad_norm": grad.norm(),
            "min_step": self.min_step,
            "max_step": self.max_step,
        }

        if guidance_eval:
            guidance_eval_out = self.guidance_eval(
                cond=cond,
                t_orig=t,
                latents_noisy=latents_noisy,
                noise_pred=noise_pred,
            )
            # One caption per evaluated item (zip stops at the shortest input).
            texts = [
                f"n{n:.02f}\ne{e.item():.01f}\na{a.item():.01f}\nc{c.item():.02f}"
                for n, e, a, c in zip(
                    guidance_eval_out["noise_levels"],
                    elevation,
                    azimuth,
                    camera_distances,
                )
            ]
            guidance_eval_out.update({"texts": texts})
            guidance_out.update({"eval": guidance_eval_out})

        return guidance_out

    @torch.cuda.amp.autocast(enabled=False)
    @torch.no_grad()
    def guidance_eval(self, cond, t_orig, latents_noisy, noise_pred):
        """Produce debug images for the first few batch items.

        Args:
            cond: Conditioning dict from ``get_cond`` for the full batch
                (first half unconditional, second half conditional).
            t_orig: Timesteps sampled in ``__call__``, one per batch item.
            latents_noisy: Noised latents for the full batch.
            noise_pred: Guided noise prediction for the full batch.

        Returns:
            Dict with ``"bs"`` (number of items evaluated), ``"noise_levels"``
            and the decoded images ``"imgs_noisy"``, ``"imgs_1step"``,
            ``"imgs_1orig"`` and ``"imgs_final"`` (channels last).
        """
        # use only 50 timesteps, and find nearest of those to t
        self.scheduler.set_timesteps(50)
        self.scheduler.timesteps_gpu = self.scheduler.timesteps.to(self.device)
        bs = (
            min(self.cfg.max_items_eval, latents_noisy.shape[0])
            if self.cfg.max_items_eval > 0
            else latents_noisy.shape[0]
        )  # batch size
        large_enough_idxs = self.scheduler.timesteps_gpu.expand([bs, -1]) > t_orig[
            :bs
        ].unsqueeze(
            -1
        )  # sized [bs,50] > [bs,1]
        # Index of the first scheduler timestep that is not above t_orig.
        idxs = torch.min(large_enough_idxs, dim=1)[1]
        t = self.scheduler.timesteps_gpu[idxs]

        fracs = list((t / self.scheduler.config.num_train_timesteps).cpu().numpy())
        imgs_noisy = self.decode_latents(latents_noisy[:bs]).permute(0, 2, 3, 1)

        # get prev latent
        latents_1step = []
        pred_1orig = []
        for b in range(bs):
            step_output = self.scheduler.step(
                noise_pred[b : b + 1], t[b], latents_noisy[b : b + 1], eta=1
            )
            latents_1step.append(step_output["prev_sample"])
            pred_1orig.append(step_output["pred_original_sample"])
        latents_1step = torch.cat(latents_1step)
        pred_1orig = torch.cat(pred_1orig)
        imgs_1step = self.decode_latents(latents_1step).permute(0, 2, 3, 1)
        imgs_1orig = self.decode_latents(pred_1orig).permute(0, 2, 3, 1)

        # `cond` covers the FULL batch as [uncond x B, cond x B]. The
        # conditional row of item b is therefore b + B, not b + bs; the two
        # differ whenever the batch is larger than cfg.max_items_eval.
        cond_offset = cond["c_crossattn"][0].shape[0] // 2

        latents_final = []
        for b, i in enumerate(idxs):
            latents = latents_1step[b : b + 1]
            pair = [b, b + cond_offset]
            c = {
                "c_crossattn": [cond["c_crossattn"][0][pair, ...]],
                "c_concat": [cond["c_concat"][0][pair, ...]],
            }
            # Finish denoising from the step after the matched timestep.
            for t_step in tqdm(self.scheduler.timesteps[i + 1 :], leave=False):
                # pred noise
                x_in = torch.cat([latents] * 2)
                t_in = torch.cat([t_step.reshape(1)] * 2).to(self.device)
                eps = self.model.apply_model(x_in, t_in, c)
                # perform guidance
                eps_uncond, eps_cond = eps.chunk(2)
                eps = eps_uncond + self.cfg.guidance_scale * (eps_cond - eps_uncond)
                # get prev latent
                latents = self.scheduler.step(eps, t_step, latents, eta=1)[
                    "prev_sample"
                ]
            latents_final.append(latents)

        latents_final = torch.cat(latents_final)
        imgs_final = self.decode_latents(latents_final).permute(0, 2, 3, 1)

        return {
            "bs": bs,
            "noise_levels": fracs,
            "imgs_noisy": imgs_noisy,
            "imgs_1step": imgs_1step,
            "imgs_1orig": imgs_1orig,
            "imgs_final": imgs_final,
        }

    def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False):
        """Refresh the gradient clip value and timestep range for this step."""
        # clip grad for stable training as demonstrated in
        # Debiasing Scores and Prompts of 2D Diffusion for Robust Text-to-3D Generation
        # http://arxiv.org/abs/2303.15413
        if self.cfg.grad_clip is not None:
            self.grad_clip_val = C(self.cfg.grad_clip, epoch, global_step)

        self.set_min_max_steps(
            min_step_percent=C(self.cfg.min_step_percent, epoch, global_step),
            max_step_percent=C(self.cfg.max_step_percent, epoch, global_step),
        )

    # verification - requires `vram_O = False` in load_model_from_config
    @torch.no_grad()
    def generate(
        self,
        image,  # image tensor [1, 3, H, W] in [0, 1]
        elevation=0,
        azimuth=0,
        camera_distances=0,  # new view params
        c_crossattn=None,
        c_concat=None,
        scale=3,
        ddim_steps=50,
        post_process=True,
        ddim_eta=1,
    ):
        """Sample novel views of ``image`` for the given camera parameters.

        ``image`` is only embedded when ``c_crossattn`` is not supplied. The
        remaining arguments are forwarded to ``get_cond`` and
        ``gen_from_cond``.
        """
        if c_crossattn is None:
            c_crossattn, c_concat = self.get_img_embeds(image)

        cond = self.get_cond(
            elevation, azimuth, camera_distances, c_crossattn, c_concat
        )

        imgs = self.gen_from_cond(cond, scale, ddim_steps, post_process, ddim_eta)

        return imgs

    # verification - requires `vram_O = False` in load_model_from_config
    @torch.no_grad()
    def gen_from_cond(
        self,
        cond,
        scale=3,
        ddim_steps=50,
        post_process=True,
        ddim_eta=1,
    ):
        """Run the sampling loop from random latents under ``cond``.

        Args:
            cond: Conditioning dict as returned by ``get_cond``.
            scale: Classifier-free guidance scale.
            ddim_steps: Number of scheduler steps.
            post_process: When true, return a channels-last NumPy array on the
                CPU; otherwise return the decoded tensor.
            ddim_eta: ``eta`` passed to the scheduler step.
        """
        # produce latents loop
        # cond stacks uncond and cond halves, so the view count is half of dim 0.
        B = cond["c_crossattn"][0].shape[0] // 2
        latents = torch.randn((B, 4, 32, 32), device=self.device)
        self.scheduler.set_timesteps(ddim_steps)

        for t in self.scheduler.timesteps:
            x_in = torch.cat([latents] * 2)
            t_in = torch.cat([t.reshape(1).repeat(B)] * 2).to(self.device)

            noise_pred = self.model.apply_model(x_in, t_in, cond)
            noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + scale * (
                noise_pred_cond - noise_pred_uncond
            )

            latents = self.scheduler.step(noise_pred, t, latents, eta=ddim_eta)[
                "prev_sample"
            ]

        imgs = self.decode_latents(latents)
        imgs = imgs.cpu().numpy().transpose(0, 2, 3, 1) if post_process else imgs

        return imgs
if __name__ == '__main__':
    # Manual smoke test: sample novel views over a small grid of camera poses
    # and write each result to disk as a PNG.
    from threestudio.utils.config import load_config
    import pytorch_lightning as pl
    import numpy as np
    import os
    import cv2
    from torchvision.utils import save_image

    cfg = load_config("configs/experimental/zero123.yaml")
    guidance = threestudio.find(cfg.system.guidance_type)(cfg.system.guidance)

    elevation_grid = [0, 20, -20]
    azimuth_grid = [45, 90, 135, -45, -90]
    radius = torch.tensor([3.8]).to(guidance.device)

    outdir = ".threestudio_cache/saiyan"
    os.makedirs(outdir, exist_ok=True)

    # Reference image: BGR -> RGB, scaled by 1/255, resized, then NHWC -> NCHW.
    reference = cv2.imread(cfg.system.guidance.cond_image_path)
    reference = reference[:, :, ::-1].copy() / 255
    reference = cv2.resize(reference, (256, 256))
    rgb_image = (
        torch.FloatTensor(reference)
        .unsqueeze(0)
        .to(guidance.device)
        .permute(0, 3, 1, 2)
    )

    for elevation in elevation_grid:
        elevation_t = torch.tensor([elevation]).to(guidance.device)
        for azimuth in azimuth_grid:
            azimuth_t = torch.tensor([azimuth]).to(guidance.device)
            # The cached embeddings are passed in, so `rgb_image` is not
            # re-embedded by `generate`.
            output1 = guidance.generate(
                rgb_image,
                elevation_t,
                azimuth_t,
                radius,
                c_crossattn=guidance.c_crossattn,
                c_concat=guidance.c_concat,
            )
            result = torch.tensor(output1).float().permute(0, 3, 1, 2)
            save_image(
                result,
                f"{outdir}/result_e_{elevation}_a_{azimuth}.png",
                normalize=True,
                value_range=(0, 1),
            )