import importlib
import os
from dataclasses import dataclass, field

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers import DDIMScheduler, DDPMScheduler, StableDiffusionPipeline
from diffusers.utils.import_utils import is_xformers_available
from omegaconf import OmegaConf
from tqdm import tqdm

import threestudio
from threestudio.utils.base import BaseObject
from threestudio.utils.misc import C, parse_version
from threestudio.utils.typing import *


def get_obj_from_str(string, reload=False):
    module, cls = string.rsplit(".", 1)
    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)
    return getattr(importlib.import_module(module, package=None), cls)


def instantiate_from_config(config):
    if "target" not in config:
        if config == "__is_first_stage__":
            return None
        elif config == "__is_unconditional__":
            return None
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", dict()))


# load model
def load_model_from_config(config, ckpt, device, vram_O=True, verbose=False):
    pl_sd = torch.load(ckpt, map_location="cpu")

    if "global_step" in pl_sd and verbose:
        print(f'[INFO] Global Step: {pl_sd["global_step"]}')

    sd = pl_sd["state_dict"]

    model = instantiate_from_config(config.model)
    m, u = model.load_state_dict(sd, strict=False)

    if len(m) > 0 and verbose:
        print("[INFO] missing keys: \n", m)
    if len(u) > 0 and verbose:
        print("[INFO] unexpected keys: \n", u)

    # manually load EMA weights and delete them to save GPU memory
    if model.use_ema:
        if verbose:
            print("[INFO] loading EMA...")
        model.model_ema.copy_to(model.model)
        del model.model_ema

    if vram_O:
        # the decoder is not needed for SDS training (only the encoder is used)
        del model.first_stage_model.decoder

    torch.cuda.empty_cache()

    model.eval().to(device)

    return model


@threestudio.register("zero123-guidance")
class Zero123Guidance(BaseObject):
    @dataclass
    class Config(BaseObject.Config):
        pretrained_model_name_or_path: str = "load/zero123/105000.ckpt"
        pretrained_config: str = "load/zero123/sd-objaverse-finetune-c_concat-256.yaml"
        vram_O: bool = True

        cond_image_path: str = "load/images/hamburger_rgba.png"
        cond_elevation_deg: float = 0.0
        cond_azimuth_deg: float = 0.0
        cond_camera_distance: float = 1.2

        guidance_scale: float = 5.0

        grad_clip: Optional[Any] = None  # field(default_factory=lambda: [0, 2.0, 8.0, 1000])
        half_precision_weights: bool = False

        min_step_percent: float = 0.02
        max_step_percent: float = 0.98

        """Maximum number of batch items to evaluate guidance for (for debugging) and to save on disk. -1 means save all items."""
        max_items_eval: int = 4

    cfg: Config

    def configure(self) -> None:
        threestudio.info(f"Loading Zero123 ...")

        self.config = OmegaConf.load(self.cfg.pretrained_config)
        # TODO: seems it cannot load into fp16...
        self.weights_dtype = torch.float32
        self.model = load_model_from_config(
            self.config,
            self.cfg.pretrained_model_name_or_path,
            device=self.device,
            vram_O=self.cfg.vram_O,
        )

        for p in self.model.parameters():
            p.requires_grad_(False)

        # timesteps: use a diffusers scheduler for convenience
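        # The DDIMScheduler below only mirrors the Zero123 LDM noise schedule
        # (linear_start/linear_end, scaled_linear betas); it supplies add_noise,
        # alphas_cumprod and step, while the UNet itself is still queried through
        # self.model.apply_model.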
        self.num_train_timesteps = self.config.model.params.timesteps

        self.scheduler = DDIMScheduler(
            self.num_train_timesteps,
            self.config.model.params.linear_start,
            self.config.model.params.linear_end,
            beta_schedule="scaled_linear",
            clip_sample=False,
            set_alpha_to_one=False,
            steps_offset=1,
        )

        self.num_train_timesteps = self.scheduler.config.num_train_timesteps
        self.set_min_max_steps()  # set to default values; overridden in update_step

        self.alphas: Float[Tensor, "..."] = self.scheduler.alphas_cumprod.to(
            self.device
        )

        self.grad_clip_val: Optional[float] = None

        self.prepare_embeddings(self.cfg.cond_image_path)

        threestudio.info(f"Loaded Zero123!")

    @torch.cuda.amp.autocast(enabled=False)
    def set_min_max_steps(self, min_step_percent=0.02, max_step_percent=0.98):
        self.min_step = int(self.num_train_timesteps * min_step_percent)
        self.max_step = int(self.num_train_timesteps * max_step_percent)

    @torch.cuda.amp.autocast(enabled=False)
    def prepare_embeddings(self, image_path: str) -> None:
        # load cond image for zero123
        assert os.path.exists(image_path)
        rgba = cv2.cvtColor(
            cv2.imread(image_path, cv2.IMREAD_UNCHANGED), cv2.COLOR_BGRA2RGBA
        )
        rgba = (
            cv2.resize(rgba, (256, 256), interpolation=cv2.INTER_AREA).astype(
                np.float32
            )
            / 255.0
        )
        # composite onto a white background using the alpha channel
        rgb = rgba[..., :3] * rgba[..., 3:] + (1 - rgba[..., 3:])
        self.rgb_256: Float[Tensor, "1 3 H W"] = (
            torch.from_numpy(rgb)
            .unsqueeze(0)
            .permute(0, 3, 1, 2)
            .contiguous()
            .to(self.device)
        )
        self.c_crossattn, self.c_concat = self.get_img_embeds(self.rgb_256)

    @torch.cuda.amp.autocast(enabled=False)
    @torch.no_grad()
    def get_img_embeds(
        self,
        img: Float[Tensor, "B 3 256 256"],
    ) -> Tuple[Float[Tensor, "B 1 768"], Float[Tensor, "B 4 32 32"]]:
        img = img * 2.0 - 1.0
        c_crossattn = self.model.get_learned_conditioning(img.to(self.weights_dtype))
        c_concat = self.model.encode_first_stage(img.to(self.weights_dtype)).mode()
        return c_crossattn, c_concat

    @torch.cuda.amp.autocast(enabled=False)
    def encode_images(
        self, imgs: Float[Tensor, "B 3 256 256"]
    ) -> Float[Tensor, "B 4 32 32"]:
        input_dtype = imgs.dtype
        imgs = imgs * 2.0 - 1.0
        latents = self.model.get_first_stage_encoding(
            self.model.encode_first_stage(imgs.to(self.weights_dtype))
        )
        return latents.to(input_dtype)  # [B, 4, 32, 32] latent-space image

    @torch.cuda.amp.autocast(enabled=False)
    def decode_latents(
        self,
        latents: Float[Tensor, "B 4 H W"],
    ) -> Float[Tensor, "B 3 256 256"]:
        input_dtype = latents.dtype
        image = self.model.decode_first_stage(latents)
        image = (image * 0.5 + 0.5).clamp(0, 1)
        return image.to(input_dtype)

    @torch.cuda.amp.autocast(enabled=False)
    @torch.no_grad()
    def get_cond(
        self,
        elevation: Float[Tensor, "B"],
        azimuth: Float[Tensor, "B"],
        camera_distances: Float[Tensor, "B"],
        c_crossattn=None,
        c_concat=None,
        **kwargs,
    ) -> dict:
        # relative camera pose w.r.t. the conditioning view, as expected by Zero123
        T = torch.stack(
            [
                torch.deg2rad(
                    (90 - elevation) - (90 - self.cfg.cond_elevation_deg)
                ),  # Zero123 polar is 90 - elevation
                torch.sin(torch.deg2rad(azimuth - self.cfg.cond_azimuth_deg)),
                torch.cos(torch.deg2rad(azimuth - self.cfg.cond_azimuth_deg)),
                camera_distances - self.cfg.cond_camera_distance,
            ],
            dim=-1,
        )[:, None, :].to(self.device)
        cond = {}
        clip_emb = self.model.cc_projection(
            torch.cat(
                [
                    (self.c_crossattn if c_crossattn is None else c_crossattn).repeat(
                        len(T), 1, 1
                    ),
                    T,
                ],
                dim=-1,
            )
        )
        # stack unconditional (zeroed) and conditional inputs for classifier-free guidance
        cond["c_crossattn"] = [
            torch.cat([torch.zeros_like(clip_emb).to(self.device), clip_emb], dim=0)
        ]
        cond["c_concat"] = [
            torch.cat(
                [
                    torch.zeros_like(self.c_concat)
                    .repeat(len(T), 1, 1, 1)
                    .to(self.device),
                    (self.c_concat if c_concat is None else c_concat).repeat(
                        len(T), 1, 1, 1
                    ),
                ],
                dim=0,
            )
        ]
        return cond

    def __call__(
        self,
        rgb: Float[Tensor, "B H W C"],
        elevation: Float[Tensor, "B"],
        azimuth: Float[Tensor, "B"],
        camera_distances: Float[Tensor, "B"],
        rgb_as_latents=False,
        guidance_eval=False,
        **kwargs,
    ):
        batch_size = rgb.shape[0]

        rgb_BCHW = rgb.permute(0, 3, 1, 2)
        latents: Float[Tensor, "B 4 32 32"]
        if rgb_as_latents:
            latents = (
                F.interpolate(rgb_BCHW, (32, 32), mode="bilinear", align_corners=False)
                * 2
                - 1
            )
        else:
            rgb_BCHW_256 = F.interpolate(
                rgb_BCHW, (256, 256), mode="bilinear", align_corners=False
            )
            # encode image into latents with the VAE
            latents = self.encode_images(rgb_BCHW_256)

        cond = self.get_cond(elevation, azimuth, camera_distances)

        # timestep ~ U(min_step, max_step) to avoid very high/low noise levels
        t = torch.randint(
            self.min_step,
            self.max_step + 1,
            [batch_size],
            dtype=torch.long,
            device=self.device,
        )

        # predict the noise residual with the UNet, NO grad!
        with torch.no_grad():
            # add noise
            noise = torch.randn_like(latents)  # TODO: use torch generator
            latents_noisy = self.scheduler.add_noise(latents, noise, t)
            # pred noise
            x_in = torch.cat([latents_noisy] * 2)
            t_in = torch.cat([t] * 2)
            noise_pred = self.model.apply_model(x_in, t_in, cond)

        # perform classifier-free guidance
        noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + self.cfg.guidance_scale * (
            noise_pred_cond - noise_pred_uncond
        )

        w = (1 - self.alphas[t]).reshape(-1, 1, 1, 1)
        grad = w * (noise_pred - noise)
        grad = torch.nan_to_num(grad)
        # clip grad for stable training
        if self.grad_clip_val is not None:
            grad = grad.clamp(-self.grad_clip_val, self.grad_clip_val)

        # loss = SpecifyGradient.apply(latents, grad)
        # SpecifyGradient is not straightforward; use a reparameterization trick instead
        target = (latents - grad).detach()
        # d(loss)/d(latents) = latents - target = latents - (latents - grad) = grad
        loss_sds = 0.5 * F.mse_loss(latents, target, reduction="sum") / batch_size

        guidance_out = {
            "loss_sd": loss_sds,  # loss_sds
            "grad_norm": grad.norm(),
            "min_step": self.min_step,
            "max_step": self.max_step,
        }

        if guidance_eval:
            guidance_eval_utils = {
                "cond": cond,
                "t_orig": t,
                "latents_noisy": latents_noisy,
                "noise_pred": noise_pred,
            }
            guidance_eval_out = self.guidance_eval(**guidance_eval_utils)
            texts = []
            for n, e, a, c in zip(
                guidance_eval_out["noise_levels"], elevation, azimuth, camera_distances
            ):
                texts.append(
                    f"n{n:.02f}\ne{e.item():.01f}\na{a.item():.01f}\nc{c.item():.02f}"
                )
            guidance_eval_out.update({"texts": texts})
            guidance_out.update({"eval": guidance_eval_out})

        return guidance_out

    @torch.cuda.amp.autocast(enabled=False)
    @torch.no_grad()
    def guidance_eval(self, cond, t_orig, latents_noisy, noise_pred):
        # use only 50 timesteps, and find the nearest of those to each t
        self.scheduler.set_timesteps(50)
        self.scheduler.timesteps_gpu = self.scheduler.timesteps.to(self.device)
        bs = (
            min(self.cfg.max_items_eval, latents_noisy.shape[0])
            if self.cfg.max_items_eval > 0
            else latents_noisy.shape[0]
        )  # batch size
        large_enough_idxs = self.scheduler.timesteps_gpu.expand([bs, -1]) > t_orig[
            :bs
        ].unsqueeze(
            -1
        )  # sized [bs, 50] > [bs, 1]
        idxs = torch.min(large_enough_idxs, dim=1)[1]
        t = self.scheduler.timesteps_gpu[idxs]

        fracs = list((t / self.scheduler.config.num_train_timesteps).cpu().numpy())
        imgs_noisy = self.decode_latents(latents_noisy[:bs]).permute(0, 2, 3, 1)

        # get prev latent
        latents_1step = []
        pred_1orig = []
        for b in range(bs):
            step_output = self.scheduler.step(
                noise_pred[b : b + 1], t[b], latents_noisy[b : b + 1], eta=1
            )
            latents_1step.append(step_output["prev_sample"])
            pred_1orig.append(step_output["pred_original_sample"])
        latents_1step = torch.cat(latents_1step)
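        # pred_original_sample is the scheduler's one-step estimate of the clean
        # latent; decode it and the one-step prev_sample alongside the noisy input
        # before running the full DDIM rollout below.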
        pred_1orig = torch.cat(pred_1orig)
        imgs_1step = self.decode_latents(latents_1step).permute(0, 2, 3, 1)
        imgs_1orig = self.decode_latents(pred_1orig).permute(0, 2, 3, 1)

        # run the remaining DDIM steps per item to obtain fully denoised images
        latents_final = []
        for b, i in enumerate(idxs):
            latents = latents_1step[b : b + 1]
            c = {
                "c_crossattn": [cond["c_crossattn"][0][[b, b + len(idxs)], ...]],
                "c_concat": [cond["c_concat"][0][[b, b + len(idxs)], ...]],
            }
            for t in tqdm(self.scheduler.timesteps[i + 1 :], leave=False):
                # pred noise
                x_in = torch.cat([latents] * 2)
                t_in = torch.cat([t.reshape(1)] * 2).to(self.device)
                noise_pred = self.model.apply_model(x_in, t_in, c)
                # perform classifier-free guidance
                noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + self.cfg.guidance_scale * (
                    noise_pred_cond - noise_pred_uncond
                )
                # get prev latent
                latents = self.scheduler.step(noise_pred, t, latents, eta=1)[
                    "prev_sample"
                ]
            latents_final.append(latents)

        latents_final = torch.cat(latents_final)
        imgs_final = self.decode_latents(latents_final).permute(0, 2, 3, 1)

        return {
            "bs": bs,
            "noise_levels": fracs,
            "imgs_noisy": imgs_noisy,
            "imgs_1step": imgs_1step,
            "imgs_1orig": imgs_1orig,
            "imgs_final": imgs_final,
        }

    def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False):
        # clip grad for stable training as demonstrated in
        # Debiasing Scores and Prompts of 2D Diffusion for Robust Text-to-3D Generation
        # http://arxiv.org/abs/2303.15413
        if self.cfg.grad_clip is not None:
            self.grad_clip_val = C(self.cfg.grad_clip, epoch, global_step)

        self.set_min_max_steps(
            min_step_percent=C(self.cfg.min_step_percent, epoch, global_step),
            max_step_percent=C(self.cfg.max_step_percent, epoch, global_step),
        )

    # verification - requires `vram_O = False` in load_model_from_config
    @torch.no_grad()
    def generate(
        self,
        image,  # image tensor [1, 3, H, W] in [0, 1]
        elevation=0,
        azimuth=0,
        camera_distances=0,  # new view params
        c_crossattn=None,
        c_concat=None,
        scale=3,
        ddim_steps=50,
        post_process=True,
        ddim_eta=1,
    ):
        if c_crossattn is None:
            c_crossattn, c_concat = self.get_img_embeds(image)

        cond = self.get_cond(
            elevation, azimuth, camera_distances, c_crossattn, c_concat
        )

        imgs = self.gen_from_cond(cond, scale, ddim_steps, post_process, ddim_eta)

        return imgs

    # verification - requires `vram_O = False` in load_model_from_config
    @torch.no_grad()
    def gen_from_cond(
        self,
        cond,
        scale=3,
        ddim_steps=50,
        post_process=True,
        ddim_eta=1,
    ):
        # DDIM sampling loop starting from random latents
        B = cond["c_crossattn"][0].shape[0] // 2
        latents = torch.randn((B, 4, 32, 32), device=self.device)
        self.scheduler.set_timesteps(ddim_steps)

        for t in self.scheduler.timesteps:
            x_in = torch.cat([latents] * 2)
            t_in = torch.cat([t.reshape(1).repeat(B)] * 2).to(self.device)

            noise_pred = self.model.apply_model(x_in, t_in, cond)
            noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + scale * (
                noise_pred_cond - noise_pred_uncond
            )

            latents = self.scheduler.step(noise_pred, t, latents, eta=ddim_eta)[
                "prev_sample"
            ]

        imgs = self.decode_latents(latents)
        imgs = imgs.cpu().numpy().transpose(0, 2, 3, 1) if post_process else imgs

        return imgs


if __name__ == "__main__":
    from threestudio.utils.config import load_config
    import pytorch_lightning as pl
    import numpy as np
    import os
    import cv2
    from torchvision.utils import save_image

    cfg = load_config("configs/experimental/zero123.yaml")
    guidance = threestudio.find(cfg.system.guidance_type)(cfg.system.guidance)

    elevations = [0, 20, -20]
    azimuths = [45, 90, 135, -45, -90]
    radius = torch.tensor([3.8]).to(guidance.device)

    outdir = ".threestudio_cache/saiyan"
    os.makedirs(outdir, exist_ok=True)
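    # Render a small grid of novel views of the conditioning image with `generate`
    # and save them to `outdir` for visual inspection.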
    # rgb_image = (rgb_image[0].detach().cpu().clip(0, 1).numpy()*255).astype(np.uint8)[:, :, ::-1].copy()
    # os.makedirs('.threestudio_cache', exist_ok=True)
    # cv2.imwrite('.threestudio_cache/diffusion_image.jpg', rgb_image)

    # load the conditioning image as a [1, 3, 256, 256] RGB tensor in [0, 1]
    rgb_image = (
        cv2.imread(cfg.system.guidance.cond_image_path)[:, :, ::-1].copy() / 255
    )
    rgb_image = cv2.resize(rgb_image, (256, 256))
    rgb_image = (
        torch.FloatTensor(rgb_image)
        .unsqueeze(0)
        .to(guidance.device)
        .permute(0, 3, 1, 2)
    )

    for elevation in elevations:
        for azimuth in azimuths:
            output1 = guidance.generate(
                rgb_image,
                torch.tensor([elevation]).to(guidance.device),
                torch.tensor([azimuth]).to(guidance.device),
                radius,
                c_crossattn=guidance.c_crossattn,
                c_concat=guidance.c_concat,
            )
            save_image(
                torch.tensor(output1).float().permute(0, 3, 1, 2),
                f"{outdir}/result_e_{elevation}_a_{azimuth}.png",
                normalize=True,
                value_range=(0, 1),
            )
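
    # Minimal sketch of the SDS path: feed a random stand-in for a rendered view
    # through `__call__` and check that `loss_sd` is finite and a gradient reaches
    # the input. In an actual pipeline this tensor would come from a differentiable
    # renderer; the resolution and view angles below are arbitrary illustration values.
    rgb_render = torch.rand(
        1, 128, 128, 3, device=guidance.device, requires_grad=True
    )
    sds_out = guidance(
        rgb_render,
        torch.tensor([10.0]).to(guidance.device),
        torch.tensor([30.0]).to(guidance.device),
        radius,
    )
    sds_out["loss_sd"].backward()
    print(
        f"loss_sd={sds_out['loss_sd'].item():.4f}, "
        f"grad_norm={sds_out['grad_norm'].item():.4f}"
    )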