import os
from dataclasses import dataclass

import cv2
import numpy as np
import torch
import torch.nn.functional as F
from controlnet_aux import CannyDetector, NormalBaeDetector
from diffusers import ControlNetModel, DDIMScheduler, StableDiffusionControlNetPipeline
from diffusers.utils.import_utils import is_xformers_available
from tqdm import tqdm

import threestudio
from threestudio.models.prompt_processors.base import PromptProcessorOutput
from threestudio.utils.base import BaseObject
from threestudio.utils.misc import C, parse_version
from threestudio.utils.perceptual import PerceptualLoss
from threestudio.utils.typing import *

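# ControlNet-conditioned Stable Diffusion guidance. A rendered image is scored against
# a control signal derived from `cond_rgb` (a BAE-predicted normal map, a Canny edge
# map, or a rendered normal map used directly), and the module returns an SDS loss
# (`use_sds`), image-space reconstruction losses against diffusion-edited renderings
# (`use_du`), or both.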
@threestudio.register("stable-diffusion-controlnet-guidance")
class ControlNetGuidance(BaseObject):
    @dataclass
    class Config(BaseObject.Config):
        cache_dir: Optional[str] = None
        pretrained_model_name_or_path: str = "SG161222/Realistic_Vision_V2.0"
        ddim_scheduler_name_or_path: str = "runwayml/stable-diffusion-v1-5"
        control_type: str = "normal"  # normal / canny / input_normal

        enable_memory_efficient_attention: bool = False
        enable_sequential_cpu_offload: bool = False
        enable_attention_slicing: bool = False
        enable_channels_last_format: bool = False
        guidance_scale: float = 7.5
        condition_scale: float = 1.5
        grad_clip: Optional[Any] = None
        half_precision_weights: bool = True

        fixed_size: int = -1

        min_step_percent: float = 0.02
        max_step_percent: float = 0.98

        diffusion_steps: int = 20

        use_sds: bool = False

        use_du: bool = False
        per_du_step: int = 10
        start_du_step: int = 1000
        cache_du: bool = False

        # Canny threshold
        canny_lower_bound: int = 50
        canny_upper_bound: int = 100

    cfg: Config

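    # Load the ControlNet, the Stable Diffusion pipeline, and the DDIM scheduler;
    # freeze the VAE/UNet weights and set up the control-image preprocessor.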
    def configure(self) -> None:
        threestudio.info(f"Loading ControlNet ...")

        controlnet_name_or_path: str
        if self.cfg.control_type in ("normal", "input_normal"):
            controlnet_name_or_path = "lllyasviel/control_v11p_sd15_normalbae"
        elif self.cfg.control_type == "canny":
            controlnet_name_or_path = "lllyasviel/control_v11p_sd15_canny"

        self.weights_dtype = (
            torch.float16 if self.cfg.half_precision_weights else torch.float32
        )

        pipe_kwargs = {
            "safety_checker": None,
            "feature_extractor": None,
            "requires_safety_checker": False,
            "torch_dtype": self.weights_dtype,
            "cache_dir": self.cfg.cache_dir,
        }

        controlnet = ControlNetModel.from_pretrained(
            controlnet_name_or_path,
            torch_dtype=self.weights_dtype,
            cache_dir=self.cfg.cache_dir,
        )
        self.pipe = StableDiffusionControlNetPipeline.from_pretrained(
            self.cfg.pretrained_model_name_or_path, controlnet=controlnet, **pipe_kwargs
        ).to(self.device)
        self.scheduler = DDIMScheduler.from_pretrained(
            self.cfg.ddim_scheduler_name_or_path,
            subfolder="scheduler",
            torch_dtype=self.weights_dtype,
            cache_dir=self.cfg.cache_dir,
        )
        self.scheduler.set_timesteps(self.cfg.diffusion_steps)

        if self.cfg.enable_memory_efficient_attention:
            if parse_version(torch.__version__) >= parse_version("2"):
                threestudio.info(
                    "PyTorch2.0 uses memory efficient attention by default."
                )
            elif not is_xformers_available():
                threestudio.warn(
                    "xformers is not available, memory efficient attention is not enabled."
                )
            else:
                self.pipe.enable_xformers_memory_efficient_attention()

        if self.cfg.enable_sequential_cpu_offload:
            self.pipe.enable_sequential_cpu_offload()

        if self.cfg.enable_attention_slicing:
            self.pipe.enable_attention_slicing(1)

        if self.cfg.enable_channels_last_format:
            self.pipe.unet.to(memory_format=torch.channels_last)

        # Create model
        self.vae = self.pipe.vae.eval()
        self.unet = self.pipe.unet.eval()
        self.controlnet = self.pipe.controlnet.eval()

        if self.cfg.control_type == "normal":
            self.preprocessor = NormalBaeDetector.from_pretrained(
                "lllyasviel/Annotators"
            )
            self.preprocessor.model.to(self.device)
        elif self.cfg.control_type == "canny":
            self.preprocessor = CannyDetector()

        for p in self.vae.parameters():
            p.requires_grad_(False)
        for p in self.unet.parameters():
            p.requires_grad_(False)

        self.num_train_timesteps = self.scheduler.config.num_train_timesteps
        self.set_min_max_steps()  # set to default value

        self.alphas: Float[Tensor, "..."] = self.scheduler.alphas_cumprod.to(
            self.device
        )

        self.grad_clip_val: Optional[float] = None

        if self.cfg.use_du:
            if self.cfg.cache_du:
                self.edit_frames = {}
            self.perceptual_loss = PerceptualLoss().eval().to(self.device)

        threestudio.info(f"Loaded ControlNet!")

    @torch.cuda.amp.autocast(enabled=False)
    def set_min_max_steps(self, min_step_percent=0.02, max_step_percent=0.98):
        self.min_step = int(self.num_train_timesteps * min_step_percent)
        self.max_step = int(self.num_train_timesteps * max_step_percent)

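    # Run the ControlNet branch on the (noisy) latents and the control image; with
    # return_dict=False this yields the down/mid-block residuals that are injected
    # into the UNet in `forward_control_unet`. Inputs are cast to `weights_dtype`
    # (fp16 by default).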
    @torch.cuda.amp.autocast(enabled=False)
    def forward_controlnet(
        self,
        latents: Float[Tensor, "..."],
        t: Float[Tensor, "..."],
        image_cond: Float[Tensor, "..."],
        condition_scale: float,
        encoder_hidden_states: Float[Tensor, "..."],
    ) -> Float[Tensor, "..."]:
        return self.controlnet(
            latents.to(self.weights_dtype),
            t.to(self.weights_dtype),
            encoder_hidden_states=encoder_hidden_states.to(self.weights_dtype),
            controlnet_cond=image_cond.to(self.weights_dtype),
            conditioning_scale=condition_scale,
            return_dict=False,
        )

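    # Denoising UNet forward pass with the ControlNet residuals injected; the noise
    # prediction is returned in the caller's dtype.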
    @torch.cuda.amp.autocast(enabled=False)
    def forward_control_unet(
        self,
        latents: Float[Tensor, "..."],
        t: Float[Tensor, "..."],
        encoder_hidden_states: Float[Tensor, "..."],
        cross_attention_kwargs,
        down_block_additional_residuals,
        mid_block_additional_residual,
    ) -> Float[Tensor, "..."]:
        input_dtype = latents.dtype
        return self.unet(
            latents.to(self.weights_dtype),
            t.to(self.weights_dtype),
            encoder_hidden_states=encoder_hidden_states.to(self.weights_dtype),
            cross_attention_kwargs=cross_attention_kwargs,
            down_block_additional_residuals=down_block_additional_residuals,
            mid_block_additional_residual=mid_block_additional_residual,
        ).sample.to(input_dtype)

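    # Encode RGB images in [0, 1] into scaled VAE latents: pixels are mapped to
    # [-1, 1] before encoding and the sampled latent is multiplied by the VAE
    # scaling factor.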
    @torch.cuda.amp.autocast(enabled=False)
    def encode_images(
        self, imgs: Float[Tensor, "B 3 H W"]
    ) -> Float[Tensor, "B 4 DH DW"]:
        input_dtype = imgs.dtype
        imgs = imgs * 2.0 - 1.0
        posterior = self.vae.encode(imgs.to(self.weights_dtype)).latent_dist
        latents = posterior.sample() * self.vae.config.scaling_factor
        return latents.to(input_dtype)

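    # Build a 3-way conditioning latent batch (cond, cond, unconditional zeros) from
    # the posterior mode, likely carried over from the InstructPix2Pix pipeline; it is
    # not called elsewhere in this file.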
    @torch.cuda.amp.autocast(enabled=False)
    def encode_cond_images(
        self, imgs: Float[Tensor, "B 3 H W"]
    ) -> Float[Tensor, "B 4 DH DW"]:
        input_dtype = imgs.dtype
        imgs = imgs * 2.0 - 1.0
        posterior = self.vae.encode(imgs.to(self.weights_dtype)).latent_dist
        latents = posterior.mode()
        uncond_image_latents = torch.zeros_like(latents)
        latents = torch.cat([latents, latents, uncond_image_latents], dim=0)
        return latents.to(input_dtype)

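    # Decode scaled VAE latents back to RGB images in [0, 1] (inverse of
    # `encode_images`).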
    @torch.cuda.amp.autocast(enabled=False)
    def decode_latents(
        self, latents: Float[Tensor, "B 4 DH DW"]
    ) -> Float[Tensor, "B 3 H W"]:
        input_dtype = latents.dtype
        latents = 1 / self.vae.config.scaling_factor * latents
        image = self.vae.decode(latents.to(self.weights_dtype)).sample
        image = (image * 0.5 + 0.5).clamp(0, 1)
        return image.to(input_dtype)

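    # SDEdit-style refinement of the rendered latents: the DDIM schedule is truncated
    # at the sampled timestep t, noise is added up to t, and the latents are denoised
    # over `diffusion_steps` steps with ControlNet conditioning and classifier-free
    # guidance. If a mask is given, the predicted noise is blended with the original
    # noise outside the mask so unmasked regions stay close to the input render.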
    def edit_latents(
        self,
        text_embeddings: Float[Tensor, "BB 77 768"],
        latents: Float[Tensor, "B 4 DH DW"],
        image_cond: Float[Tensor, "B 3 H W"],
        t: Int[Tensor, "B"],
        mask=None,
    ) -> Float[Tensor, "B 4 DH DW"]:
        self.scheduler.config.num_train_timesteps = t.item()
        self.scheduler.set_timesteps(self.cfg.diffusion_steps)
        if mask is not None:
            mask = F.interpolate(
                mask, (latents.shape[-2], latents.shape[-1]), mode="bilinear"
            )
        with torch.no_grad():
            # add noise
            noise = torch.randn_like(latents)
            latents = self.scheduler.add_noise(latents, noise, t)  # type: ignore

            # sections of code used from https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py
            threestudio.debug("Start editing...")
            for i, t in enumerate(self.scheduler.timesteps):
                # predict the noise residual with unet, NO grad!
                with torch.no_grad():
                    # pred noise
                    latent_model_input = torch.cat([latents] * 2)
                    (
                        down_block_res_samples,
                        mid_block_res_sample,
                    ) = self.forward_controlnet(
                        latent_model_input,
                        t,
                        encoder_hidden_states=text_embeddings,
                        image_cond=image_cond,
                        condition_scale=self.cfg.condition_scale,
                    )

                    noise_pred = self.forward_control_unet(
                        latent_model_input,
                        t,
                        encoder_hidden_states=text_embeddings,
                        cross_attention_kwargs=None,
                        down_block_additional_residuals=down_block_res_samples,
                        mid_block_additional_residual=mid_block_res_sample,
                    )
                # perform classifier-free guidance
                noise_pred_text, noise_pred_uncond = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + self.cfg.guidance_scale * (
                    noise_pred_text - noise_pred_uncond
                )
                if mask is not None:
                    noise_pred = mask * noise_pred + (1 - mask) * noise
                # get previous sample, continue loop
                latents = self.scheduler.step(noise_pred, t, latents).prev_sample
            threestudio.debug("Editing finished.")
        return latents

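    # Convert the conditioning render into the ControlNet control image (B, 3, H, W
    # in [0, 1]): run the BAE normal detector ("normal"), blur + Canny edge detection
    # ("canny"), or pass the rendered normal map through with the x-channel flipped
    # to match the BAE convention ("input_normal").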
    def prepare_image_cond(self, cond_rgb: Float[Tensor, "B H W C"]):
        if self.cfg.control_type == "normal":
            cond_rgb = (
                (cond_rgb[0].detach().cpu().numpy() * 255).astype(np.uint8).copy()
            )
            detected_map = self.preprocessor(cond_rgb)
            control = (
                torch.from_numpy(np.array(detected_map)).float().to(self.device) / 255.0
            )
            control = control.unsqueeze(0)
            control = control.permute(0, 3, 1, 2)
        elif self.cfg.control_type == "canny":
            cond_rgb = (
                (cond_rgb[0].detach().cpu().numpy() * 255).astype(np.uint8).copy()
            )
            blurred_img = cv2.blur(cond_rgb, ksize=(5, 5))
            detected_map = self.preprocessor(
                blurred_img, self.cfg.canny_lower_bound, self.cfg.canny_upper_bound
            )
            control = (
                torch.from_numpy(np.array(detected_map)).float().to(self.device) / 255.0
            )
            # control = control.unsqueeze(-1).repeat(1, 1, 3)
            control = control.unsqueeze(0)
            control = control.permute(0, 3, 1, 2)
        elif self.cfg.control_type == "input_normal":
            cond_rgb[..., 0] = (
                1 - cond_rgb[..., 0]
            )  # Flip the sign on the x-axis to match bae system
            control = cond_rgb.permute(0, 3, 1, 2)
        else:
            raise ValueError(f"Unknown control type: {self.cfg.control_type}")

        return control

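    # Score Distillation Sampling gradient: noise the latents to the sampled timestep
    # t, predict the noise with the ControlNet-conditioned UNet under classifier-free
    # guidance, and return grad = w(t) * (noise_pred - noise) with w(t) = 1 - alpha_bar_t.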
    def compute_grad_sds(
        self,
        text_embeddings: Float[Tensor, "BB 77 768"],
        latents: Float[Tensor, "B 4 DH DW"],
        image_cond: Float[Tensor, "B 3 H W"],
        t: Int[Tensor, "B"],
    ):
        with torch.no_grad():
            # add noise
            noise = torch.randn_like(latents)  # TODO: use torch generator
            latents_noisy = self.scheduler.add_noise(latents, noise, t)
            # pred noise
            latent_model_input = torch.cat([latents_noisy] * 2)
            down_block_res_samples, mid_block_res_sample = self.forward_controlnet(
                latent_model_input,
                t,
                encoder_hidden_states=text_embeddings,
                image_cond=image_cond,
                condition_scale=self.cfg.condition_scale,
            )

            noise_pred = self.forward_control_unet(
                latent_model_input,
                t,
                encoder_hidden_states=text_embeddings,
                cross_attention_kwargs=None,
                down_block_additional_residuals=down_block_res_samples,
                mid_block_additional_residual=mid_block_res_sample,
            )

        # perform classifier-free guidance
        noise_pred_text, noise_pred_uncond = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + self.cfg.guidance_scale * (
            noise_pred_text - noise_pred_uncond
        )

        w = (1 - self.alphas[t]).view(-1, 1, 1, 1)
        grad = w * (noise_pred - noise)
        return grad

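    # "Diffusion update" (DU) supervision: every `per_du_step` steps (after
    # `start_du_step`), refine the current render with `edit_latents`, decode it, and
    # supervise the render with L1 + perceptual losses against the edited image. With
    # `cache_du`, edited frames are cached per view index and reused as targets.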
    def compute_grad_du(
        self,
        latents: Float[Tensor, "B 4 H W"],
        rgb_BCHW_HW8: Float[Tensor, "B 3 RH RW"],
        cond_feature: Float[Tensor, "B 3 RH RW"],
        cond_rgb: Float[Tensor, "B H W 3"],
        text_embeddings: Float[Tensor, "BB 77 768"],
        mask=None,
        **kwargs,
    ):
        batch_size, _, RH, RW = cond_feature.shape
        assert batch_size == 1

        origin_gt_rgb = F.interpolate(
            cond_rgb.permute(0, 3, 1, 2), (RH, RW), mode="bilinear"
        ).permute(0, 2, 3, 1)
        need_diffusion = (
            self.global_step % self.cfg.per_du_step == 0
            and self.global_step > self.cfg.start_du_step
        )
        if self.cfg.cache_du:
            if torch.is_tensor(kwargs["index"]):
                batch_index = kwargs["index"].item()
            else:
                batch_index = kwargs["index"]
            if (
                not (batch_index in self.edit_frames)
            ) and self.global_step > self.cfg.start_du_step:
                need_diffusion = True
        need_loss = self.cfg.cache_du or need_diffusion
        guidance_out = {}

        if need_diffusion:
            t = torch.randint(
                self.min_step,
                self.max_step,
                [1],
                dtype=torch.long,
                device=self.device,
            )
            print("t:", t)
            edit_latents = self.edit_latents(text_embeddings, latents, cond_feature, t, mask)
            edit_images = self.decode_latents(edit_latents)
            edit_images = F.interpolate(
                edit_images, (RH, RW), mode="bilinear"
            ).permute(0, 2, 3, 1)
            self.edit_images = edit_images
            if self.cfg.cache_du:
                self.edit_frames[batch_index] = edit_images.detach().cpu()

        if need_loss:
            if self.cfg.cache_du:
                if batch_index in self.edit_frames:
                    gt_rgb = self.edit_frames[batch_index].to(cond_feature.device)
                else:
                    gt_rgb = origin_gt_rgb
            else:
                gt_rgb = edit_images

            import cv2
            import numpy as np

            temp = (edit_images.detach().cpu()[0].numpy() * 255).astype(np.uint8)
            cv2.imwrite(".threestudio_cache/test.jpg", temp[:, :, ::-1])

            guidance_out.update(
                {
                    "loss_l1": torch.nn.functional.l1_loss(
                        rgb_BCHW_HW8, gt_rgb.permute(0, 3, 1, 2), reduction="sum"
                    ),
                    "loss_p": self.perceptual_loss(
                        rgb_BCHW_HW8.contiguous(),
                        gt_rgb.permute(0, 3, 1, 2).contiguous(),
                    ).sum(),
                }
            )

        return guidance_out

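    # Main entry point, invoked with the rendered image and the conditioning render:
    # resize the render to `fixed_size` (or the nearest multiple of 8), encode it into
    # latents, build the control image, fetch (optionally view-dependent) text
    # embeddings, sample a timestep, and return the SDS and/or DU losses selected in
    # the config.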
    def __call__(
        self,
        rgb: Float[Tensor, "B H W C"],
        cond_rgb: Float[Tensor, "B H W C"],
        prompt_utils: PromptProcessorOutput,
        mask=None,
        **kwargs,
    ):
        batch_size, H, W, _ = rgb.shape
        assert batch_size == 1
        assert rgb.shape[:-1] == cond_rgb.shape[:-1]

        rgb_BCHW = rgb.permute(0, 3, 1, 2)
        if mask is not None:
            mask = mask.permute(0, 3, 1, 2)
        latents: Float[Tensor, "B 4 DH DW"]
        if self.cfg.fixed_size > 0:
            RH, RW = self.cfg.fixed_size, self.cfg.fixed_size
        else:
            RH, RW = H // 8 * 8, W // 8 * 8
        rgb_BCHW_HW8 = F.interpolate(
            rgb_BCHW, (RH, RW), mode="bilinear", align_corners=False
        )
        latents = self.encode_images(rgb_BCHW_HW8)

        image_cond = self.prepare_image_cond(cond_rgb)
        image_cond = F.interpolate(
            image_cond, (RH, RW), mode="bilinear", align_corners=False
        )

        temp = torch.zeros(1).to(rgb.device)
        azimuth = kwargs.get("azimuth", temp)
        camera_distance = kwargs.get("camera_distance", temp)
        view_dependent_prompt = kwargs.get("view_dependent_prompt", False)
        # FIXME: change to view-conditioned prompt
        text_embeddings = prompt_utils.get_text_embeddings(
            temp, azimuth, camera_distance, view_dependent_prompt
        )

        # timestep ~ U(0.02, 0.98) to avoid very high/low noise level
        t = torch.randint(
            self.min_step,
            self.max_step + 1,
            [batch_size],
            dtype=torch.long,
            device=self.device,
        )

        guidance_out = {}
        if self.cfg.use_sds:
            grad = self.compute_grad_sds(text_embeddings, latents, image_cond, t)
            grad = torch.nan_to_num(grad)
            if self.grad_clip_val is not None:
                grad = grad.clamp(-self.grad_clip_val, self.grad_clip_val)
            target = (latents - grad).detach()
            loss_sds = 0.5 * F.mse_loss(latents, target, reduction="sum") / batch_size
            guidance_out.update(
                {
                    "loss_sds": loss_sds,
                    "grad_norm": grad.norm(),
                    "min_step": self.min_step,
                    "max_step": self.max_step,
                }
            )

        if self.cfg.use_du:
            grad = self.compute_grad_du(
                latents, rgb_BCHW_HW8, image_cond, cond_rgb, text_embeddings, mask, **kwargs
            )
            guidance_out.update(grad)

        return guidance_out

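    # Per-step update hook: refreshes the grad-clip value and anneals the min/max
    # timestep bounds according to their (possibly scheduled) config values.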
    def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False):
        # clip grad for stable training as demonstrated in
        # Debiasing Scores and Prompts of 2D Diffusion for Robust Text-to-3D Generation
        # http://arxiv.org/abs/2303.15413
        if self.cfg.grad_clip is not None:
            self.grad_clip_val = C(self.cfg.grad_clip, epoch, global_step)

        self.set_min_max_steps(
            min_step_percent=C(self.cfg.min_step_percent, epoch, global_step),
            max_step_percent=C(self.cfg.max_step_percent, epoch, global_step),
        )

        self.global_step = global_step