from dataclasses import dataclass, field

import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers import IFPipeline, DDPMScheduler
from diffusers.utils.import_utils import is_xformers_available
from tqdm import tqdm

import threestudio
from threestudio.models.prompt_processors.base import PromptProcessorOutput
from threestudio.utils.base import BaseObject
from threestudio.utils.misc import C, parse_version
from threestudio.utils.ops import perpendicular_component
from threestudio.utils.typing import *


@threestudio.register("deep-floyd-guidance")
class DeepFloydGuidance(BaseObject):
    @dataclass
    class Config(BaseObject.Config):
        cache_dir: Optional[str] = None
        local_files_only: Optional[bool] = False
        pretrained_model_name_or_path: str = "DeepFloyd/IF-I-XL-v1.0"
        # FIXME: xformers error
        enable_memory_efficient_attention: bool = False
        enable_sequential_cpu_offload: bool = False
        enable_attention_slicing: bool = False
        enable_channels_last_format: bool = True
        guidance_scale: float = 20.0
        grad_clip: Optional[
            Any
        ] = None  # field(default_factory=lambda: [0, 2.0, 8.0, 1000])
        time_prior: Optional[Any] = None  # [m1, m2, s1, s2]
        half_precision_weights: bool = True

        min_step_percent: float = 0.02
        max_step_percent: float = 0.98

        weighting_strategy: str = "sds"

        view_dependent_prompting: bool = True

        """Maximum number of batch items to evaluate guidance for (for debugging) and to save on disk. -1 means save all items."""
        max_items_eval: int = 4

        lora_weights_path: Optional[str] = None

    cfg: Config

    def configure(self) -> None:
        threestudio.info(f"Loading Deep Floyd ...")

        self.weights_dtype = (
            torch.float16 if self.cfg.half_precision_weights else torch.float32
        )

        # Create model
        self.pipe = IFPipeline.from_pretrained(
            self.cfg.pretrained_model_name_or_path,
            text_encoder=None,
            safety_checker=None,
            watermarker=None,
            feature_extractor=None,
            requires_safety_checker=False,
            variant="fp16" if self.cfg.half_precision_weights else None,
            torch_dtype=self.weights_dtype,
            cache_dir=self.cfg.cache_dir,
            local_files_only=self.cfg.local_files_only,
        ).to(self.device)

        # Load LoRA weights
        if self.cfg.lora_weights_path is not None:
            self.pipe.load_lora_weights(self.cfg.lora_weights_path)
            self.pipe.scheduler = self.pipe.scheduler.__class__.from_config(
                self.pipe.scheduler.config, variance_type="fixed_small"
            )

        if self.cfg.enable_memory_efficient_attention:
            if parse_version(torch.__version__) >= parse_version("2"):
                threestudio.info(
                    "PyTorch 2.0 uses memory efficient attention by default."
                )
            elif not is_xformers_available():
                threestudio.warn(
                    "xformers is not available, memory efficient attention is not enabled."
                )
            else:
                threestudio.warn(
                    "Using DeepFloyd with xformers may raise errors, see https://github.com/deep-floyd/IF/issues/52 to track this problem."
                )
                self.pipe.enable_xformers_memory_efficient_attention()

        if self.cfg.enable_sequential_cpu_offload:
            self.pipe.enable_sequential_cpu_offload()

        if self.cfg.enable_attention_slicing:
            self.pipe.enable_attention_slicing(1)

        if self.cfg.enable_channels_last_format:
            self.pipe.unet.to(memory_format=torch.channels_last)

        self.unet = self.pipe.unet.eval()

        for p in self.unet.parameters():
            p.requires_grad_(False)

        self.scheduler = self.pipe.scheduler

        self.num_train_timesteps = self.scheduler.config.num_train_timesteps
        self.set_min_max_steps()  # set to default value

        if self.cfg.time_prior is not None:
            m1, m2, s1, s2 = self.cfg.time_prior
            weights = torch.cat(
                (
                    torch.exp(
                        -((torch.arange(self.num_train_timesteps, m1, -1) - m1) ** 2)
                        / (2 * s1**2)
                    ),
                    torch.ones(m1 - m2 + 1),
                    torch.exp(
                        -((torch.arange(m2 - 1, 0, -1) - m2) ** 2) / (2 * s2**2)
                    ),
                )
            )
            weights = weights / torch.sum(weights)
            self.time_prior_acc_weights = torch.cumsum(weights, dim=0)

        self.alphas: Float[Tensor, "..."] = self.scheduler.alphas_cumprod.to(
            self.device
        )
        self.scheduler.alphas_cumprod = self.scheduler.alphas_cumprod.to(self.device)

        self.grad_clip_val: Optional[float] = None

        threestudio.info(f"Loaded Deep Floyd!")

    @torch.cuda.amp.autocast(enabled=False)
    def set_min_max_steps(self, min_step_percent=0.02, max_step_percent=0.98):
        self.min_step = int(self.num_train_timesteps * min_step_percent)
        self.max_step = int(self.num_train_timesteps * max_step_percent)

    @torch.cuda.amp.autocast(enabled=False)
    def forward_unet(
        self,
        latents: Float[Tensor, "..."],
        t: Float[Tensor, "..."],
        encoder_hidden_states: Float[Tensor, "..."],
    ) -> Float[Tensor, "..."]:
        input_dtype = latents.dtype
        return self.unet(
            latents.to(self.weights_dtype),
            t.to(self.weights_dtype),
            encoder_hidden_states=encoder_hidden_states.to(self.weights_dtype),
        ).sample.to(input_dtype)

    def __call__(
        self,
        rgb: Float[Tensor, "B H W C"],
        prompt_utils: PromptProcessorOutput,
        elevation: Float[Tensor, "B"],
        azimuth: Float[Tensor, "B"],
        camera_distances: Float[Tensor, "B"],
        current_step_ratio=None,
        mask: Float[Tensor, "B H W 1"] = None,
        rgb_as_latents=False,
        guidance_eval=False,
        **kwargs,
    ):
        batch_size = rgb.shape[0]

        rgb_BCHW = rgb.permute(0, 3, 1, 2)
        if mask is not None:
            mask = mask.permute(0, 3, 1, 2)
            mask = F.interpolate(mask, (64, 64), mode="bilinear", align_corners=False)

        assert rgb_as_latents == False, f"No latent space in {self.__class__.__name__}"
        rgb_BCHW = rgb_BCHW * 2.0 - 1.0  # scale to [-1, 1] to match the diffusion range
        latents = F.interpolate(
            rgb_BCHW, (64, 64), mode="bilinear", align_corners=False
        )
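
        # Timestep selection: when cfg.time_prior = [m1, m2, s1, s2] is set, the
        # cumulative prior weights built in configure() map the training progress
        # (current_step_ratio) to a single timestep, annealing from high-noise to
        # low-noise timesteps over the course of training; otherwise t is drawn
        # uniformly from [min_step, max_step].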
        if self.cfg.time_prior is not None:
            time_index = torch.where(
                (self.time_prior_acc_weights - current_step_ratio) > 0
            )[0][0]
            if time_index == 0 or torch.abs(
                self.time_prior_acc_weights[time_index] - current_step_ratio
            ) < torch.abs(
                self.time_prior_acc_weights[time_index - 1] - current_step_ratio
            ):
                t = self.num_train_timesteps - time_index
            else:
                t = self.num_train_timesteps - time_index + 1
            t = torch.clip(t, self.min_step, self.max_step + 1)
            t = torch.full((batch_size,), t, dtype=torch.long, device=self.device)
        else:
            # timestep ~ U(0.02, 0.98) to avoid very high/low noise level
            t = torch.randint(
                self.min_step,
                self.max_step + 1,
                [batch_size],
                dtype=torch.long,
                device=self.device,
            )

        if prompt_utils.use_perp_neg:
            (
                text_embeddings,
                neg_guidance_weights,
            ) = prompt_utils.get_text_embeddings_perp_neg(
                elevation, azimuth, camera_distances, self.cfg.view_dependent_prompting
            )
            with torch.no_grad():
                noise = torch.randn_like(latents)
                latents_noisy = self.scheduler.add_noise(latents, noise, t)
                if mask is not None:
                    latents_noisy = (1 - mask) * latents + mask * latents_noisy
                latent_model_input = torch.cat([latents_noisy] * 4, dim=0)
                noise_pred = self.forward_unet(
                    latent_model_input,
                    torch.cat([t] * 4),
                    encoder_hidden_states=text_embeddings,
                )  # (4B, 6, 64, 64)

            noise_pred_text, _ = noise_pred[:batch_size].split(3, dim=1)
            noise_pred_uncond, _ = noise_pred[batch_size : batch_size * 2].split(
                3, dim=1
            )
            noise_pred_neg, _ = noise_pred[batch_size * 2 :].split(3, dim=1)

            e_pos = noise_pred_text - noise_pred_uncond
            accum_grad = 0
            n_negative_prompts = neg_guidance_weights.shape[-1]
            for i in range(n_negative_prompts):
                e_i_neg = noise_pred_neg[i::n_negative_prompts] - noise_pred_uncond
                accum_grad += neg_guidance_weights[:, i].view(
                    -1, 1, 1, 1
                ) * perpendicular_component(e_i_neg, e_pos)

            noise_pred = noise_pred_uncond + self.cfg.guidance_scale * (
                e_pos + accum_grad
            )
        else:
            neg_guidance_weights = None
            text_embeddings = prompt_utils.get_text_embeddings(
                elevation, azimuth, camera_distances, self.cfg.view_dependent_prompting
            )
            # predict the noise residual with unet, NO grad!
            with torch.no_grad():
                # add noise
                noise = torch.randn_like(latents)  # TODO: use torch generator
                latents_noisy = self.scheduler.add_noise(latents, noise, t)
                if mask is not None:
                    latents_noisy = (1 - mask) * latents + mask * latents_noisy
                # pred noise
                latent_model_input = torch.cat([latents_noisy] * 2, dim=0)
                noise_pred = self.forward_unet(
                    latent_model_input,
                    torch.cat([t] * 2),
                    encoder_hidden_states=text_embeddings,
                )  # (2B, 6, 64, 64)

            # perform guidance (high scale from paper!)
            noise_pred_text, noise_pred_uncond = noise_pred.chunk(2)
            noise_pred_text, predicted_variance = noise_pred_text.split(3, dim=1)
            noise_pred_uncond, _ = noise_pred_uncond.split(3, dim=1)
            noise_pred = noise_pred_text + self.cfg.guidance_scale * (
                noise_pred_text - noise_pred_uncond
            )

        """
        # thresholding, experimental
        if self.cfg.thresholding:
            assert batch_size == 1
            noise_pred = torch.cat([noise_pred, predicted_variance], dim=1)
            noise_pred = custom_ddpm_step(
                self.scheduler,
                noise_pred,
                int(t.item()),
                latents_noisy,
                **self.pipe.prepare_extra_step_kwargs(None, 0.0)
            )
        """

        if self.cfg.weighting_strategy == "sds":
            # w(t), sigma_t^2
            w = (1 - self.alphas[t]).view(-1, 1, 1, 1)
        elif self.cfg.weighting_strategy == "uniform":
            w = 1
        elif self.cfg.weighting_strategy == "fantasia3d":
            w = (self.alphas[t] ** 0.5 * (1 - self.alphas[t])).view(-1, 1, 1, 1)
        else:
            raise ValueError(
                f"Unknown weighting strategy: {self.cfg.weighting_strategy}"
            )

        grad = w * (noise_pred - noise)
        grad = torch.nan_to_num(grad)
        # clip grad for stable training?
        if self.grad_clip_val is not None:
            grad = grad.clamp(-self.grad_clip_val, self.grad_clip_val)

        # loss = SpecifyGradient.apply(latents, grad)
        # SpecifyGradient is not straightforward, use a reparameterization trick instead
        target = (latents - grad).detach()
        # d(loss)/d(latents) = latents - target = latents - (latents - grad) = grad
        loss_sd = 0.5 * F.mse_loss(latents, target, reduction="sum") / batch_size

        guidance_out = {
            "loss_sd": loss_sd,
            "grad_norm": grad.norm(),
            "min_step": self.min_step,
            "max_step": self.max_step,
        }

        # # FIXME: Visualize inpainting results
        # self.scheduler.set_timesteps(20)
        # latents = latents_noisy
        # for t in tqdm(self.scheduler.timesteps):
        #     # pred noise
        #     noise_pred = self.get_noise_pred(
        #         latents, t, text_embeddings, prompt_utils.use_perp_neg, None
        #     )
        #     # get prev latent
        #     prev_latents = latents
        #     latents = self.scheduler.step(noise_pred, t, latents)["prev_sample"]
        #     if mask is not None:
        #         latents = (1 - mask) * prev_latents + mask * latents
        # denoised_img = (latents / 2 + 0.5).permute(0, 2, 3, 1)
        # guidance_out.update(
        #     {"denoised_img": denoised_img}
        # )

        if guidance_eval:
            guidance_eval_utils = {
                "use_perp_neg": prompt_utils.use_perp_neg,
                "neg_guidance_weights": neg_guidance_weights,
                "text_embeddings": text_embeddings,
                "t_orig": t,
                "latents_noisy": latents_noisy,
                "noise_pred": torch.cat([noise_pred, predicted_variance], dim=1),
            }
            guidance_eval_out = self.guidance_eval(**guidance_eval_utils)
            texts = []
            for n, e, a, c in zip(
                guidance_eval_out["noise_levels"], elevation, azimuth, camera_distances
            ):
                texts.append(
                    f"n{n:.02f}\ne{e.item():.01f}\na{a.item():.01f}\nc{c.item():.02f}"
                )
            guidance_eval_out.update({"texts": texts})
            guidance_out.update({"eval": guidance_eval_out})

        return guidance_out

    @torch.cuda.amp.autocast(enabled=False)
    @torch.no_grad()
    def get_noise_pred(
        self,
        latents_noisy,
        t,
        text_embeddings,
        use_perp_neg=False,
        neg_guidance_weights=None,
    ):
        batch_size = latents_noisy.shape[0]

        if use_perp_neg:
            latent_model_input = torch.cat([latents_noisy] * 4, dim=0)
            noise_pred = self.forward_unet(
                latent_model_input,
                torch.cat([t.reshape(1)] * 4).to(self.device),
                encoder_hidden_states=text_embeddings,
            )  # (4B, 6, 64, 64)

            noise_pred_text, _ = noise_pred[:batch_size].split(3, dim=1)
            noise_pred_uncond, _ = noise_pred[batch_size : batch_size * 2].split(
                3, dim=1
            )
            noise_pred_neg, _ = noise_pred[batch_size * 2 :].split(3, dim=1)

            e_pos = noise_pred_text - noise_pred_uncond
            accum_grad = 0
            n_negative_prompts = neg_guidance_weights.shape[-1]
            for i in range(n_negative_prompts):
                e_i_neg = noise_pred_neg[i::n_negative_prompts] - noise_pred_uncond
                accum_grad += neg_guidance_weights[:, i].view(
                    -1, 1, 1, 1
                ) * perpendicular_component(e_i_neg, e_pos)

            noise_pred = noise_pred_uncond + self.cfg.guidance_scale * (
                e_pos + accum_grad
            )
        else:
            latent_model_input = torch.cat([latents_noisy] * 2, dim=0)
            noise_pred = self.forward_unet(
                latent_model_input,
                torch.cat([t.reshape(1)] * 2).to(self.device),
                encoder_hidden_states=text_embeddings,
            )  # (2B, 6, 64, 64)

            # perform guidance (high scale from paper!)
            noise_pred_text, noise_pred_uncond = noise_pred.chunk(2)
            noise_pred_text, predicted_variance = noise_pred_text.split(3, dim=1)
            noise_pred_uncond, _ = noise_pred_uncond.split(3, dim=1)
            noise_pred = noise_pred_text + self.cfg.guidance_scale * (
                noise_pred_text - noise_pred_uncond
            )

        return torch.cat([noise_pred, predicted_variance], dim=1)

    @torch.cuda.amp.autocast(enabled=False)
    @torch.no_grad()
    def guidance_eval(
        self,
        t_orig,
        text_embeddings,
        latents_noisy,
        noise_pred,
        use_perp_neg=False,
        neg_guidance_weights=None,
    ):
        # use only 50 timesteps, and find nearest of those to t
        self.scheduler.set_timesteps(50)
        self.scheduler.timesteps_gpu = self.scheduler.timesteps.to(self.device)
        bs = (
            min(self.cfg.max_items_eval, latents_noisy.shape[0])
            if self.cfg.max_items_eval > 0
            else latents_noisy.shape[0]
        )  # batch size
        large_enough_idxs = self.scheduler.timesteps_gpu.expand([bs, -1]) > t_orig[
            :bs
        ].unsqueeze(
            -1
        )  # sized [bs, 50] > [bs, 1]
        idxs = torch.min(large_enough_idxs, dim=1)[1]
        t = self.scheduler.timesteps_gpu[idxs]

        fracs = list((t / self.scheduler.config.num_train_timesteps).cpu().numpy())
        imgs_noisy = (latents_noisy[:bs] / 2 + 0.5).permute(0, 2, 3, 1)

        # get prev latent
        latents_1step = []
        pred_1orig = []
        for b in range(bs):
            step_output = self.scheduler.step(
                noise_pred[b : b + 1], t[b], latents_noisy[b : b + 1]
            )
            latents_1step.append(step_output["prev_sample"])
            pred_1orig.append(step_output["pred_original_sample"])
        latents_1step = torch.cat(latents_1step)
        pred_1orig = torch.cat(pred_1orig)
        imgs_1step = (latents_1step / 2 + 0.5).permute(0, 2, 3, 1)
        imgs_1orig = (pred_1orig / 2 + 0.5).permute(0, 2, 3, 1)

        latents_final = []
        for b, i in enumerate(idxs):
            latents = latents_1step[b : b + 1]
            text_emb = (
                text_embeddings[
                    [b, b + len(idxs), b + 2 * len(idxs), b + 3 * len(idxs)], ...
                ]
                if use_perp_neg
                else text_embeddings[[b, b + len(idxs)], ...]
            )
            neg_guid = neg_guidance_weights[b : b + 1] if use_perp_neg else None
            for t in tqdm(self.scheduler.timesteps[i + 1 :], leave=False):
                # pred noise
                noise_pred = self.get_noise_pred(
                    latents, t, text_emb, use_perp_neg, neg_guid
                )
                # get prev latent
                latents = self.scheduler.step(noise_pred, t, latents)["prev_sample"]
            latents_final.append(latents)

        latents_final = torch.cat(latents_final)
        imgs_final = (latents_final / 2 + 0.5).permute(0, 2, 3, 1)

        return {
            "bs": bs,
            "noise_levels": fracs,
            "imgs_noisy": imgs_noisy,
            "imgs_1step": imgs_1step,
            "imgs_1orig": imgs_1orig,
            "imgs_final": imgs_final,
        }

    def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False):
        # clip grad for stable training as demonstrated in
        # Debiasing Scores and Prompts of 2D Diffusion for Robust Text-to-3D Generation
        # http://arxiv.org/abs/2303.15413
        if self.cfg.grad_clip is not None:
            self.grad_clip_val = C(self.cfg.grad_clip, epoch, global_step)

        self.set_min_max_steps(
            min_step_percent=C(self.cfg.min_step_percent, epoch, global_step),
            max_step_percent=C(self.cfg.max_step_percent, epoch, global_step),
        )
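
"""
# Minimal usage sketch (illustrative, not part of the original module), kept
# inactive in a string block like the experimental code below. It shows how the
# SDS loss returned by __call__ is typically consumed in a training step.
# `renderer`, `dataloader`, `optimizer`, `max_steps`, and the "comp_rgb" key are
# assumed names standing in for whatever the surrounding threestudio system provides.
guidance = threestudio.find(cfg.system.guidance_type)(cfg.system.guidance)
prompt_utils = threestudio.find(cfg.system.prompt_processor_type)(cfg.system.prompt_processor)()
for global_step, batch in enumerate(dataloader):
    rgb = renderer(**batch)["comp_rgb"]  # [B, H, W, 3], values in [0, 1]
    out = guidance(
        rgb,
        prompt_utils,
        batch["elevation"],
        batch["azimuth"],
        batch["camera_distances"],
        current_step_ratio=global_step / max_steps,
    )
    # backprop through loss_sd injects w(t) * (eps_pred - eps) into the rendered image
    optimizer.zero_grad()
    out["loss_sd"].backward()
    optimizer.step()
    # anneal grad clipping and the [min_step, max_step] range
    guidance.update_step(epoch=0, global_step=global_step)
"""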

"""
# used by thresholding, experimental
def custom_ddpm_step(ddpm, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor, generator=None, return_dict: bool = True):
    self = ddpm
    t = timestep

    prev_t = self.previous_timestep(t)

    if model_output.shape[1] == sample.shape[1] * 2 and self.variance_type in ["learned", "learned_range"]:
        model_output, predicted_variance = torch.split(model_output, sample.shape[1], dim=1)
    else:
        predicted_variance = None

    # 1. compute alphas, betas
    alpha_prod_t = self.alphas_cumprod[t].item()
    alpha_prod_t_prev = self.alphas_cumprod[prev_t].item() if prev_t >= 0 else 1.0
    beta_prod_t = 1 - alpha_prod_t
    beta_prod_t_prev = 1 - alpha_prod_t_prev
    current_alpha_t = alpha_prod_t / alpha_prod_t_prev
    current_beta_t = 1 - current_alpha_t

    # 2. compute predicted original sample from predicted noise also called
    # "predicted x_0" of formula (15) from https://arxiv.org/pdf/2006.11239.pdf
    if self.config.prediction_type == "epsilon":
        pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
    elif self.config.prediction_type == "sample":
        pred_original_sample = model_output
    elif self.config.prediction_type == "v_prediction":
        pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
    else:
        raise ValueError(
            f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample` or"
            " `v_prediction` for the DDPMScheduler."
        )

    # 3. Clip or threshold "predicted x_0"
    if self.config.thresholding:
        pred_original_sample = self._threshold_sample(pred_original_sample)
    elif self.config.clip_sample:
        pred_original_sample = pred_original_sample.clamp(
            -self.config.clip_sample_range, self.config.clip_sample_range
        )

    noise_thresholded = (sample - (alpha_prod_t ** 0.5) * pred_original_sample) / (beta_prod_t ** 0.5)
    return noise_thresholded
"""


if __name__ == '__main__':
    from threestudio.utils.config import load_config
    import pytorch_lightning as pl
    import numpy as np
    import os
    import cv2

    cfg = load_config("configs/debugging/deepfloyd.yaml")
    guidance = threestudio.find(cfg.system.guidance_type)(cfg.system.guidance)
    prompt_processor = threestudio.find(cfg.system.prompt_processor_type)(
        cfg.system.prompt_processor
    )
    prompt_utils = prompt_processor()

    temp = torch.zeros(1).to(guidance.device)
    # rgb_image = guidance.sample(prompt_utils, temp, temp, temp, seed=cfg.seed)
    # rgb_image = (rgb_image[0].detach().cpu().clip(0, 1).numpy()*255).astype(np.uint8)[:, :, ::-1].copy()
    # os.makedirs('.threestudio_cache', exist_ok=True)
    # cv2.imwrite('.threestudio_cache/diffusion_image.jpg', rgb_image)

    ### inpaint
    rgb_image = cv2.imread("assets/test.jpg")[:, :, ::-1].copy() / 255
    mask_image = cv2.imread("assets/mask.png")[:, :, :1].copy() / 255
    rgb_image = cv2.resize(rgb_image, (512, 512))
    mask_image = cv2.resize(mask_image, (512, 512)).reshape(512, 512, 1)
    rgb_image = torch.FloatTensor(rgb_image).unsqueeze(0).to(guidance.device)
    mask_image = torch.FloatTensor(mask_image).unsqueeze(0).to(guidance.device)

    guidance_out = guidance(rgb_image, prompt_utils, temp, temp, temp, mask=mask_image)

    # NOTE: "denoised_img" is only present in guidance_out when the
    # "# FIXME: Visualize inpainting results" block in __call__ is re-enabled.
    edit_image = (
        (guidance_out["denoised_img"][0].detach().cpu().clip(0, 1).numpy() * 255)
        .astype(np.uint8)[:, :, ::-1]
        .copy()
    )
    os.makedirs(".threestudio_cache", exist_ok=True)
    cv2.imwrite(".threestudio_cache/edit_image.jpg", edit_image)