mirror of
https://github.com/deepseek-ai/DreamCraft3D.git
synced 2025-02-23 14:28:55 -05:00
1003 lines
39 KiB
Python
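# Variational Score Distillation (VSD) guidance built on Stable Diffusion
# (ProlificDreamer-style), extended in DreamCraft3D with an optional
# diffusion-based reconstruction branch ("du") and a BSD option.
# A frozen pretrained UNet scores the text-conditioned image distribution,
# while a LoRA-augmented copy, conditioned on the camera pose, estimates the
# score of the current rendering distribution; their weighted difference is
# the gradient applied to the rendered latents.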
import random
from contextlib import contextmanager
from dataclasses import dataclass, field

import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers import (
    DDIMScheduler,
    DDPMScheduler,
    DPMSolverMultistepScheduler,
    StableDiffusionPipeline,
    UNet2DConditionModel,
)
from diffusers.loaders import AttnProcsLayers
from diffusers.models.attention_processor import LoRAAttnProcessor
from diffusers.models.embeddings import TimestepEmbedding
from diffusers.utils.import_utils import is_xformers_available

import threestudio
from threestudio.models.prompt_processors.base import PromptProcessorOutput
from threestudio.utils.base import BaseModule
from threestudio.utils.misc import C, cleanup, parse_version
from threestudio.utils.perceptual import PerceptualLoss
from threestudio.utils.typing import *

class ToWeightsDType(nn.Module):
    def __init__(self, module: nn.Module, dtype: torch.dtype):
        super().__init__()
        self.module = module
        self.dtype = dtype

    def forward(self, x: Float[Tensor, "..."]) -> Float[Tensor, "..."]:
        return self.module(x).to(self.dtype)

@threestudio.register("stable-diffusion-vsd-guidance")
class StableDiffusionVSDGuidance(BaseModule):
    @dataclass
    class Config(BaseModule.Config):
        cache_dir: Optional[str] = None
        local_files_only: Optional[bool] = False
        pretrained_model_name_or_path: str = "stabilityai/stable-diffusion-2-1-base"
        pretrained_model_name_or_path_lora: str = "stabilityai/stable-diffusion-2-1"
        enable_memory_efficient_attention: bool = False
        enable_sequential_cpu_offload: bool = False
        enable_attention_slicing: bool = False
        enable_channels_last_format: bool = False
        guidance_scale: float = 7.5
        guidance_scale_lora: float = 1.0
        grad_clip: Optional[
            Any
        ] = None  # field(default_factory=lambda: [0, 2.0, 8.0, 1000])
        half_precision_weights: bool = True
        lora_cfg_training: bool = True
        lora_n_timestamp_samples: int = 1

        min_step_percent: float = 0.02
        max_step_percent: float = 0.98

        view_dependent_prompting: bool = True
        camera_condition_type: str = "extrinsics"

        use_du: bool = False
        per_du_step: int = 10
        start_du_step: int = 0
        du_diffusion_steps: int = 20

        use_bsd: bool = False

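    # Notes on the less common options above: `use_du` enables the
    # diffusion-update branch in `compute_grad_du`, which runs every
    # `per_du_step` steps after `start_du_step` with `du_diffusion_steps`
    # denoising steps; `use_bsd` makes the pretrained branch in
    # `compute_grad_vsd` receive the camera condition as class labels instead
    # of disabling the class embedding.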
    cfg: Config

    def configure(self) -> None:
        threestudio.info(f"Loading Stable Diffusion ...")

        self.weights_dtype = (
            torch.float16 if self.cfg.half_precision_weights else torch.float32
        )

        pipe_kwargs = {
            "tokenizer": None,
            "safety_checker": None,
            "feature_extractor": None,
            "requires_safety_checker": False,
            "torch_dtype": self.weights_dtype,
            "cache_dir": self.cfg.cache_dir,
            "local_files_only": self.cfg.local_files_only,
        }

        pipe_lora_kwargs = {
            "tokenizer": None,
            "safety_checker": None,
            "feature_extractor": None,
            "requires_safety_checker": False,
            "torch_dtype": self.weights_dtype,
            "cache_dir": self.cfg.cache_dir,
            "local_files_only": self.cfg.local_files_only,
        }

        @dataclass
        class SubModules:
            pipe: StableDiffusionPipeline
            pipe_lora: StableDiffusionPipeline

        pipe = StableDiffusionPipeline.from_pretrained(
            self.cfg.pretrained_model_name_or_path,
            **pipe_kwargs,
        ).to(self.device)
        if (
            self.cfg.pretrained_model_name_or_path
            == self.cfg.pretrained_model_name_or_path_lora
        ):
            self.single_model = True
            pipe_lora = pipe
        else:
            self.single_model = False
            pipe_lora = StableDiffusionPipeline.from_pretrained(
                self.cfg.pretrained_model_name_or_path_lora,
                **pipe_lora_kwargs,
            ).to(self.device)
            del pipe_lora.vae
            cleanup()
            pipe_lora.vae = pipe.vae
        self.submodules = SubModules(pipe=pipe, pipe_lora=pipe_lora)

        if self.cfg.enable_memory_efficient_attention:
            if parse_version(torch.__version__) >= parse_version("2"):
                threestudio.info(
                    "PyTorch2.0 uses memory efficient attention by default."
                )
            elif not is_xformers_available():
                threestudio.warn(
                    "xformers is not available, memory efficient attention is not enabled."
                )
            else:
                self.pipe.enable_xformers_memory_efficient_attention()
                self.pipe_lora.enable_xformers_memory_efficient_attention()

        if self.cfg.enable_sequential_cpu_offload:
            self.pipe.enable_sequential_cpu_offload()
            self.pipe_lora.enable_sequential_cpu_offload()

        if self.cfg.enable_attention_slicing:
            self.pipe.enable_attention_slicing(1)
            self.pipe_lora.enable_attention_slicing(1)

        if self.cfg.enable_channels_last_format:
            self.pipe.unet.to(memory_format=torch.channels_last)
            self.pipe_lora.unet.to(memory_format=torch.channels_last)

        del self.pipe.text_encoder
        if not self.single_model:
            del self.pipe_lora.text_encoder
        cleanup()

        for p in self.vae.parameters():
            p.requires_grad_(False)
        for p in self.unet.parameters():
            p.requires_grad_(False)
        for p in self.unet_lora.parameters():
            p.requires_grad_(False)

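        # Camera conditioning for the LoRA branch: the flattened 4x4 camera
        # matrix (16 values) is mapped by a TimestepEmbedding to the UNet's
        # 1280-dim time-embedding space and installed as the class embedding,
        # so `class_labels` passed to the LoRA UNet carry the camera pose.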
        # FIXME: hard-coded dims
        self.camera_embedding = ToWeightsDType(
            TimestepEmbedding(16, 1280), self.weights_dtype
        ).to(self.device)
        self.unet_lora.class_embedding = self.camera_embedding

        # set up LoRA layers
        lora_attn_procs = {}
        for name in self.unet_lora.attn_processors.keys():
            cross_attention_dim = (
                None
                if name.endswith("attn1.processor")
                else self.unet_lora.config.cross_attention_dim
            )
            if name.startswith("mid_block"):
                hidden_size = self.unet_lora.config.block_out_channels[-1]
            elif name.startswith("up_blocks"):
                block_id = int(name[len("up_blocks.")])
                hidden_size = list(reversed(self.unet_lora.config.block_out_channels))[
                    block_id
                ]
            elif name.startswith("down_blocks"):
                block_id = int(name[len("down_blocks.")])
                hidden_size = self.unet_lora.config.block_out_channels[block_id]

            lora_attn_procs[name] = LoRAAttnProcessor(
                hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
            )

        self.unet_lora.set_attn_processor(lora_attn_procs)

        self.lora_layers = AttnProcsLayers(self.unet_lora.attn_processors).to(
            self.device
        )
        self.lora_layers._load_state_dict_pre_hooks.clear()
        self.lora_layers._state_dict_hooks.clear()

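        # Scheduler roles: `self.scheduler` (DDIM) adds noise for the guidance
        # and du branches, `self.scheduler_lora` (DDPM) is used to train the
        # LoRA branch, and the two DPMSolverMultistep schedulers are only used
        # by the sampling/visualization helpers.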
        self.scheduler = DDIMScheduler.from_pretrained(  # DDPM
            self.cfg.pretrained_model_name_or_path,
            subfolder="scheduler",
            torch_dtype=self.weights_dtype,
            cache_dir=self.cfg.cache_dir,
            local_files_only=self.cfg.local_files_only,
        )

        self.scheduler_lora = DDPMScheduler.from_pretrained(
            self.cfg.pretrained_model_name_or_path_lora,
            subfolder="scheduler",
            torch_dtype=self.weights_dtype,
            cache_dir=self.cfg.cache_dir,
            local_files_only=self.cfg.local_files_only,
        )

        self.scheduler_sample = DPMSolverMultistepScheduler.from_config(
            self.pipe.scheduler.config
        )
        self.scheduler_lora_sample = DPMSolverMultistepScheduler.from_config(
            self.pipe_lora.scheduler.config
        )

        self.pipe.scheduler = self.scheduler
        self.pipe_lora.scheduler = self.scheduler_lora

        self.num_train_timesteps = self.scheduler.config.num_train_timesteps
        self.set_min_max_steps()  # set to default value

        self.alphas: Float[Tensor, "..."] = self.scheduler.alphas_cumprod.to(
            self.device
        )

        self.scheduler.alphas_cumprod = self.scheduler.alphas_cumprod.to(self.device)

        self.grad_clip_val: Optional[float] = None

        if self.cfg.use_du:
            self.perceptual_loss = PerceptualLoss().eval().to(self.device)
            for p in self.perceptual_loss.parameters():
                p.requires_grad_(False)

        threestudio.info(f"Loaded Stable Diffusion!")

    @torch.cuda.amp.autocast(enabled=False)
    def set_min_max_steps(self, min_step_percent=0.02, max_step_percent=0.98):
        self.min_step = int(self.num_train_timesteps * min_step_percent)
        self.max_step = int(self.num_train_timesteps * max_step_percent)

    @property
    def pipe(self):
        return self.submodules.pipe

    @property
    def pipe_lora(self):
        return self.submodules.pipe_lora

    @property
    def unet(self):
        return self.submodules.pipe.unet

    @property
    def unet_lora(self):
        return self.submodules.pipe_lora.unet

    @property
    def vae(self):
        return self.submodules.pipe.vae

    @property
    def vae_lora(self):
        return self.submodules.pipe_lora.vae

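    # `_sample` runs a plain classifier-free-guidance sampling loop with the
    # given pipe/scheduler and decodes the result to images. If `latents_inp`
    # is provided, it is noised to a random intermediate timestep and only the
    # remaining steps are run, i.e. SDEdit-style image-to-image refinement.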
    @torch.no_grad()
    @torch.cuda.amp.autocast(enabled=False)
    def _sample(
        self,
        pipe: StableDiffusionPipeline,
        sample_scheduler: DPMSolverMultistepScheduler,
        text_embeddings: Float[Tensor, "BB N Nf"],
        num_inference_steps: int,
        guidance_scale: float,
        num_images_per_prompt: int = 1,
        height: Optional[int] = None,
        width: Optional[int] = None,
        class_labels: Optional[Float[Tensor, "BB 16"]] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents_inp: Optional[Float[Tensor, "..."]] = None,
    ) -> Float[Tensor, "B H W 3"]:
        vae_scale_factor = 2 ** (len(pipe.vae.config.block_out_channels) - 1)
        height = height or pipe.unet.config.sample_size * vae_scale_factor
        width = width or pipe.unet.config.sample_size * vae_scale_factor
        batch_size = text_embeddings.shape[0] // 2
        device = self.device

        sample_scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = sample_scheduler.timesteps
        num_channels_latents = pipe.unet.config.in_channels

        if latents_inp is not None:
            t = torch.randint(
                self.min_step,
                self.max_step,
                [1],
                dtype=torch.long,
                device=self.device,
            )
            noise = torch.randn_like(latents_inp)
            init_timestep = max(
                1,
                min(
                    int(num_inference_steps * t[0].item() / self.num_train_timesteps),
                    num_inference_steps,
                ),
            )
            t_start = max(num_inference_steps - init_timestep, 0)
            latent_timestep = sample_scheduler.timesteps[t_start : t_start + 1].repeat(
                batch_size
            )
            latents = sample_scheduler.add_noise(
                latents_inp, noise, latent_timestep
            ).to(self.weights_dtype)

        else:
            latents = pipe.prepare_latents(
                batch_size * num_images_per_prompt,
                num_channels_latents,
                height,
                width,
                self.weights_dtype,
                device,
                generator,
            )
            t_start = 0

        for i, t in enumerate(timesteps[t_start:]):
            # expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2)
            latent_model_input = sample_scheduler.scale_model_input(
                latent_model_input, t
            )

            # predict the noise residual
            if class_labels is None:
                with self.disable_unet_class_embedding(pipe.unet) as unet:
                    noise_pred = unet(
                        latent_model_input,
                        t,
                        encoder_hidden_states=text_embeddings.to(self.weights_dtype),
                        cross_attention_kwargs=cross_attention_kwargs,
                    ).sample
            else:
                noise_pred = pipe.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=text_embeddings.to(self.weights_dtype),
                    class_labels=class_labels,
                    cross_attention_kwargs=cross_attention_kwargs,
                ).sample

            noise_pred_text, noise_pred_uncond = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (
                noise_pred_text - noise_pred_uncond
            )

            # compute the previous noisy sample x_t -> x_t-1
            latents = sample_scheduler.step(noise_pred, t, latents).prev_sample

        latents = 1 / pipe.vae.config.scaling_factor * latents
        images = pipe.vae.decode(latents).sample
        images = (images / 2 + 0.5).clamp(0, 1)
        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        images = images.permute(0, 2, 3, 1).float()
        return images

    def sample(
        self,
        prompt_utils: PromptProcessorOutput,
        elevation: Float[Tensor, "B"],
        azimuth: Float[Tensor, "B"],
        camera_distances: Float[Tensor, "B"],
        seed: int = 0,
        **kwargs,
    ) -> Float[Tensor, "N H W 3"]:
        # view-dependent text embeddings
        text_embeddings_vd = prompt_utils.get_text_embeddings(
            elevation,
            azimuth,
            camera_distances,
            view_dependent_prompting=self.cfg.view_dependent_prompting,
        )
        cross_attention_kwargs = {"scale": 0.0} if self.single_model else None
        generator = torch.Generator(device=self.device).manual_seed(seed)

        return self._sample(
            pipe=self.pipe,
            sample_scheduler=self.scheduler_sample,
            text_embeddings=text_embeddings_vd,
            num_inference_steps=25,
            guidance_scale=self.cfg.guidance_scale,
            cross_attention_kwargs=cross_attention_kwargs,
            generator=generator,
        )

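    # Note: `sample_img2img` currently bypasses the commented-out `_sample`
    # img2img path below and instead returns the reconstruction targets from
    # `compute_grad_du` for the given rendering.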
    def sample_img2img(
        self,
        rgb: Float[Tensor, "B H W C"],
        prompt_utils: PromptProcessorOutput,
        elevation: Float[Tensor, "B"],
        azimuth: Float[Tensor, "B"],
        camera_distances: Float[Tensor, "B"],
        seed: int = 0,
        mask=None,
        **kwargs,
    ) -> Float[Tensor, "N H W 3"]:
        rgb_BCHW = rgb.permute(0, 3, 1, 2)
        mask_BCHW = mask.permute(0, 3, 1, 2)
        # TODO: with some probability this input may be a diffused-update (du) result or the reference image
        latents = self.get_latents(rgb_BCHW, rgb_as_latents=False)

        # view-dependent text embeddings
        text_embeddings_vd = prompt_utils.get_text_embeddings(
            elevation,
            azimuth,
            camera_distances,
            view_dependent_prompting=self.cfg.view_dependent_prompting,
        )
        cross_attention_kwargs = {"scale": 0.0} if self.single_model else None
        generator = torch.Generator(device=self.device).manual_seed(seed)

        # return self._sample(
        #     pipe=self.pipe,
        #     sample_scheduler=self.scheduler_sample,
        #     text_embeddings=text_embeddings_vd,
        #     num_inference_steps=25,
        #     guidance_scale=self.cfg.guidance_scale,
        #     cross_attention_kwargs=cross_attention_kwargs,
        #     generator=generator,
        #     latents_inp=latents,
        # )

        return self.compute_grad_du(
            latents, rgb_BCHW, text_embeddings_vd, mask=mask_BCHW
        )

    def sample_lora(
        self,
        prompt_utils: PromptProcessorOutput,
        elevation: Float[Tensor, "B"],
        azimuth: Float[Tensor, "B"],
        camera_distances: Float[Tensor, "B"],
        mvp_mtx: Float[Tensor, "B 4 4"],
        c2w: Float[Tensor, "B 4 4"],
        seed: int = 0,
        **kwargs,
    ) -> Float[Tensor, "N H W 3"]:
        # input text embeddings, view-independent
        text_embeddings = prompt_utils.get_text_embeddings(
            elevation, azimuth, camera_distances, view_dependent_prompting=False
        )

        if self.cfg.camera_condition_type == "extrinsics":
            camera_condition = c2w
        elif self.cfg.camera_condition_type == "mvp":
            camera_condition = mvp_mtx
        else:
            raise ValueError(
                f"Unknown camera_condition_type {self.cfg.camera_condition_type}"
            )

        B = elevation.shape[0]
        camera_condition_cfg = torch.cat(
            [
                camera_condition.view(B, -1),
                torch.zeros_like(camera_condition.view(B, -1)),
            ],
            dim=0,
        )

        generator = torch.Generator(device=self.device).manual_seed(seed)
        return self._sample(
            sample_scheduler=self.scheduler_lora_sample,
            pipe=self.pipe_lora,
            text_embeddings=text_embeddings,
            num_inference_steps=25,
            guidance_scale=self.cfg.guidance_scale_lora,
            class_labels=camera_condition_cfg,
            cross_attention_kwargs={"scale": 1.0},
            generator=generator,
        )

    @torch.cuda.amp.autocast(enabled=False)
    def forward_unet(
        self,
        unet: UNet2DConditionModel,
        latents: Float[Tensor, "..."],
        t: Float[Tensor, "..."],
        encoder_hidden_states: Float[Tensor, "..."],
        class_labels: Optional[Float[Tensor, "B 16"]] = None,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Float[Tensor, "..."]:
        input_dtype = latents.dtype
        return unet(
            latents.to(self.weights_dtype),
            t.to(self.weights_dtype),
            encoder_hidden_states=encoder_hidden_states.to(self.weights_dtype),
            class_labels=class_labels,
            cross_attention_kwargs=cross_attention_kwargs,
        ).sample.to(input_dtype)

    @torch.cuda.amp.autocast(enabled=False)
    def encode_images(
        self, imgs: Float[Tensor, "B 3 512 512"]
    ) -> Float[Tensor, "B 4 64 64"]:
        input_dtype = imgs.dtype
        imgs = imgs * 2.0 - 1.0
        posterior = self.vae.encode(imgs.to(self.weights_dtype)).latent_dist
        latents = posterior.sample() * self.vae.config.scaling_factor
        return latents.to(input_dtype)

    @torch.cuda.amp.autocast(enabled=False)
    def decode_latents(
        self,
        latents: Float[Tensor, "B 4 H W"],
        latent_height: int = 64,
        latent_width: int = 64,
    ) -> Float[Tensor, "B 3 512 512"]:
        input_dtype = latents.dtype
        latents = F.interpolate(
            latents, (latent_height, latent_width), mode="bilinear", align_corners=False
        )
        latents = 1 / self.vae.config.scaling_factor * latents
        image = self.vae.decode(latents.to(self.weights_dtype)).sample
        image = (image * 0.5 + 0.5).clamp(0, 1)
        return image.to(input_dtype)

    @contextmanager
    def disable_unet_class_embedding(self, unet: UNet2DConditionModel):
        class_embedding = unet.class_embedding
        try:
            unet.class_embedding = None
            yield unet
        finally:
            unet.class_embedding = class_embedding

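    # Diffusion-update ("du") reconstruction guidance: every `per_du_step`
    # training steps (after `start_du_step`), the rendered image's latent is
    # noised to a random level and denoised for `du_diffusion_steps` steps with
    # the frozen model under classifier-free guidance; the decoded result is
    # used as a pseudo ground truth for L1 and perceptual losses against the
    # current rendering. With a mask, the noise prediction is blended so that
    # only the masked region is edited.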
    def compute_grad_du(
        self,
        latents: Float[Tensor, "B 4 64 64"],
        rgb_BCHW_512: Float[Tensor, "B 3 512 512"],
        text_embeddings: Float[Tensor, "BB 77 768"],
        mask=None,
        **kwargs,
    ):
        batch_size, _, _, _ = latents.shape
        rgb_BCHW_512 = F.interpolate(rgb_BCHW_512, (512, 512), mode="bilinear")
        assert batch_size == 1
        need_diffusion = (
            self.global_step % self.cfg.per_du_step == 0
            and self.global_step > self.cfg.start_du_step
        )
        guidance_out = {}

        if need_diffusion:
            t = torch.randint(
                self.min_step,
                self.max_step,
                [1],
                dtype=torch.long,
                device=self.device,
            )
            self.scheduler.config.num_train_timesteps = t.item()
            self.scheduler.set_timesteps(self.cfg.du_diffusion_steps)

            if mask is not None:
                mask = F.interpolate(mask, (64, 64), mode="bilinear", antialias=True)
            with torch.no_grad():
                # add noise
                noise = torch.randn_like(latents)
                latents = self.scheduler.add_noise(latents, noise, t)  # type: ignore
                for i, timestep in enumerate(self.scheduler.timesteps):
                    # predict the noise residual with unet, NO grad!
                    with torch.no_grad():
                        latent_model_input = torch.cat([latents] * 2)
                        with self.disable_unet_class_embedding(self.unet) as unet:
                            cross_attention_kwargs = (
                                {"scale": 0.0} if self.single_model else None
                            )
                            noise_pred = self.forward_unet(
                                unet,
                                latent_model_input,
                                timestep,
                                encoder_hidden_states=text_embeddings,
                                cross_attention_kwargs=cross_attention_kwargs,
                            )
                    # perform classifier-free guidance
                    noise_pred_text, noise_pred_uncond = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + self.cfg.guidance_scale * (
                        noise_pred_text - noise_pred_uncond
                    )
                    if mask is not None:
                        noise_pred = mask * noise_pred + (1 - mask) * noise
                    # get previous sample, continue loop
                    latents = self.scheduler.step(
                        noise_pred, timestep, latents
                    ).prev_sample
            edit_images = self.decode_latents(latents)
            edit_images = F.interpolate(
                edit_images, (512, 512), mode="bilinear"
            ).permute(0, 2, 3, 1)
            gt_rgb = edit_images
            # import cv2
            # import numpy as np
            # mask_temp = mask_BCHW_512.permute(0,2,3,1)
            # # edit_images = edit_images * mask_temp + torch.rand(3)[None, None, None].to(self.device).repeat(*edit_images.shape[:-1],1) * (1 - mask_temp)
            # temp = (edit_images.detach().cpu()[0].numpy() * 255).astype(np.uint8)
            # cv2.imwrite(f".threestudio_cache/pig_sd_noise_500/test_{kwargs.get('name', 'none')}.jpg", temp[:, :, ::-1])

            guidance_out.update(
                {
                    "loss_l1": torch.nn.functional.l1_loss(
                        rgb_BCHW_512, gt_rgb.permute(0, 3, 1, 2), reduction="sum"
                    ),
                    "loss_p": self.perceptual_loss(
                        rgb_BCHW_512.contiguous(),
                        gt_rgb.permute(0, 3, 1, 2).contiguous(),
                    ).sum(),
                    "edit_image": edit_images.detach(),
                }
            )

        return guidance_out

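    # VSD gradient (ProlificDreamer-style): the same noisy latent is denoised
    # by the frozen pretrained UNet (view-dependent prompt, CFG with
    # `guidance_scale`) and by the camera-conditioned LoRA UNet
    # (view-independent prompt, CFG with `guidance_scale_lora`); the weighted
    # difference
    #     grad = w(t) * (eps_pretrained - eps_lora),  w(t) = 1 - alpha_bar_t
    # is the score-distillation direction applied to the rendered latents.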
    def compute_grad_vsd(
        self,
        latents: Float[Tensor, "B 4 64 64"],
        text_embeddings_vd: Float[Tensor, "BB 77 768"],
        text_embeddings: Float[Tensor, "BB 77 768"],
        camera_condition: Float[Tensor, "B 4 4"],
    ):
        B = latents.shape[0]

        with torch.no_grad():
            # random timestamp
            t = torch.randint(
                self.min_step,
                self.max_step + 1,
                [B],
                dtype=torch.long,
                device=self.device,
            )
            # add noise
            noise = torch.randn_like(latents)
            latents_noisy = self.scheduler.add_noise(latents, noise, t)
            # pred noise
            latent_model_input = torch.cat([latents_noisy] * 2, dim=0)
            if self.cfg.use_bsd:
                cross_attention_kwargs = {"scale": 0.0} if self.single_model else None
                noise_pred_pretrain = self.forward_unet(
                    self.unet,
                    latent_model_input,
                    torch.cat([t] * 2),
                    encoder_hidden_states=text_embeddings_vd,
                    class_labels=torch.cat(
                        [
                            camera_condition.view(B, -1),
                            torch.zeros_like(camera_condition.view(B, -1)),
                        ],
                        dim=0,
                    ),
                    cross_attention_kwargs=cross_attention_kwargs,
                )
            else:
                with self.disable_unet_class_embedding(self.unet) as unet:
                    cross_attention_kwargs = {"scale": 0.0} if self.single_model else None
                    noise_pred_pretrain = self.forward_unet(
                        unet,
                        latent_model_input,
                        torch.cat([t] * 2),
                        encoder_hidden_states=text_embeddings_vd,
                        cross_attention_kwargs=cross_attention_kwargs,
                    )

            # use view-independent text embeddings in LoRA
            text_embeddings_cond, _ = text_embeddings.chunk(2)
            noise_pred_est = self.forward_unet(
                self.unet_lora,
                latent_model_input,
                torch.cat([t] * 2),
                encoder_hidden_states=torch.cat([text_embeddings_cond] * 2),
                class_labels=torch.cat(
                    [
                        camera_condition.view(B, -1),
                        torch.zeros_like(camera_condition.view(B, -1)),
                    ],
                    dim=0,
                ),
                cross_attention_kwargs={"scale": 1.0},
            )

        (
            noise_pred_pretrain_text,
            noise_pred_pretrain_uncond,
        ) = noise_pred_pretrain.chunk(2)

        # NOTE: guidance scale definition here is aligned with diffusers, but different from other guidance
        noise_pred_pretrain = noise_pred_pretrain_uncond + self.cfg.guidance_scale * (
            noise_pred_pretrain_text - noise_pred_pretrain_uncond
        )

        # TODO: more general cases
        assert self.scheduler.config.prediction_type == "epsilon"
        if self.scheduler_lora.config.prediction_type == "v_prediction":
            alphas_cumprod = self.scheduler_lora.alphas_cumprod.to(
                device=latents_noisy.device, dtype=latents_noisy.dtype
            )
            alpha_t = alphas_cumprod[t] ** 0.5
            sigma_t = (1 - alphas_cumprod[t]) ** 0.5

            noise_pred_est = latent_model_input * torch.cat([sigma_t] * 2, dim=0).view(
                -1, 1, 1, 1
            ) + noise_pred_est * torch.cat([alpha_t] * 2, dim=0).view(-1, 1, 1, 1)

        (
            noise_pred_est_camera,
            noise_pred_est_uncond,
        ) = noise_pred_est.chunk(2)

        # NOTE: guidance scale definition here is aligned with diffusers, but different from other guidance
        noise_pred_est = noise_pred_est_uncond + self.cfg.guidance_scale_lora * (
            noise_pred_est_camera - noise_pred_est_uncond
        )

        w = (1 - self.alphas[t]).view(-1, 1, 1, 1)

        grad = w * (noise_pred_pretrain - noise_pred_est)
        return grad

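    # HiFA-style variant (appears experimental): instead of a single-step noise
    # residual, both branches are partially denoised with multi-step solvers,
    # and the difference of the resulting latents is used to build a target
    # (latents_noisy - latents_noisy_lora + latents); debug images are written
    # to .threestudio_cache. Not used by `forward` in this file.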
    def compute_grad_vsd_hifa(
        self,
        latents: Float[Tensor, "B 4 64 64"],
        text_embeddings_vd: Float[Tensor, "BB 77 768"],
        text_embeddings: Float[Tensor, "BB 77 768"],
        camera_condition: Float[Tensor, "B 4 4"],
        mask=None,
    ):
        B, _, DH, DW = latents.shape
        rgb = self.decode_latents(latents)
        self.name = "hifa"

        if mask is not None:
            mask = F.interpolate(mask, (DH, DW), mode="bilinear", antialias=True)
        with torch.no_grad():
            # random timestamp
            t = torch.randint(
                self.min_step,
                self.max_step + 1,
                [B],
                dtype=torch.long,
                device=self.device,
            )
            w = (1 - self.alphas[t]).view(-1, 1, 1, 1)
            # add noise
            noise = torch.randn_like(latents)
            latents_noisy = self.scheduler_sample.add_noise(latents, noise, t)
            latents_noisy_lora = self.scheduler_lora_sample.add_noise(latents, noise, t)
            # pred noise

            self.scheduler_sample.config.num_train_timesteps = t.item()
            self.scheduler_sample.set_timesteps(t.item() // 50 + 1)
            self.scheduler_lora_sample.config.num_train_timesteps = t.item()
            self.scheduler_lora_sample.set_timesteps(t.item() // 50 + 1)

            for i, timestep in enumerate(self.scheduler_sample.timesteps):
                # for i, timestep in tqdm(enumerate(self.scheduler.timesteps)):
                latent_model_input = torch.cat([latents_noisy] * 2, dim=0)
                latent_model_input_lora = torch.cat([latents_noisy_lora] * 2, dim=0)

                # print(latent_model_input.shape)
                with self.disable_unet_class_embedding(self.unet) as unet:
                    cross_attention_kwargs = (
                        {"scale": 0.0} if self.single_model else None
                    )
                    noise_pred_pretrain = self.forward_unet(
                        unet,
                        latent_model_input,
                        timestep,
                        encoder_hidden_states=text_embeddings_vd,
                        cross_attention_kwargs=cross_attention_kwargs,
                    )

                # use view-independent text embeddings in LoRA
                noise_pred_est = self.forward_unet(
                    self.unet_lora,
                    latent_model_input_lora,
                    timestep,
                    encoder_hidden_states=text_embeddings,
                    class_labels=torch.cat(
                        [
                            camera_condition.view(B, -1),
                            torch.zeros_like(camera_condition.view(B, -1)),
                        ],
                        dim=0,
                    ),
                    cross_attention_kwargs={"scale": 1.0},
                )

                (
                    noise_pred_pretrain_text,
                    noise_pred_pretrain_uncond,
                ) = noise_pred_pretrain.chunk(2)

                # NOTE: guidance scale definition here is aligned with diffusers, but different from other guidance
                noise_pred_pretrain = noise_pred_pretrain_uncond + self.cfg.guidance_scale * (
                    noise_pred_pretrain_text - noise_pred_pretrain_uncond
                )
                if mask is not None:
                    noise_pred_pretrain = mask * noise_pred_pretrain + (1 - mask) * noise

                (
                    noise_pred_est_text,
                    noise_pred_est_uncond,
                ) = noise_pred_est.chunk(2)

                # NOTE: guidance scale definition here is aligned with diffusers, but different from other guidance
                # noise_pred_est = noise_pred_est_uncond + self.cfg.guidance_scale_lora * (
                #     noise_pred_est_text - noise_pred_est_uncond
                # )
                noise_pred_est = noise_pred_est_text
                if mask is not None:
                    noise_pred_est = mask * noise_pred_est + (1 - mask) * noise

                latents_noisy = self.scheduler_sample.step(
                    noise_pred_pretrain, timestep, latents_noisy
                ).prev_sample
                latents_noisy_lora = self.scheduler_lora_sample.step(
                    noise_pred_est, timestep, latents_noisy_lora
                ).prev_sample

                # noise = torch.randn_like(latents)
                # latents_noisy = self.scheduler.step(noise_pred_pretrain, timestep, latents_noisy).prev_sample
                # latents_noisy = mask * latents_noisy + (1-mask) * latents
                # latents_noisy = self.scheduler_sample.add_noise(latents_noisy, noise, timestep)

                # latents_noisy_lora = self.scheduler_lora.step(noise_pred_est, timestep, latents_noisy_lora).prev_sample
                # latents_noisy_lora = mask * latents_noisy_lora + (1-mask) * latents
                # latents_noisy_lora = self.scheduler_lora_sample.add_noise(latents_noisy_lora, noise, timestep)

        hifa_images = self.decode_latents(latents_noisy)
        hifa_lora_images = self.decode_latents(latents_noisy_lora)

        import cv2
        import numpy as np

        if mask is not None:
            print("hifa mask!")
            prefix = "vsd_mask"
        else:
            prefix = ""
        temp = (
            hifa_images.permute(0, 2, 3, 1).detach().cpu()[0].numpy() * 255
        ).astype(np.uint8)
        cv2.imwrite(
            ".threestudio_cache/%s%s_test.jpg" % (prefix, self.name), temp[:, :, ::-1]
        )
        temp = (
            hifa_lora_images.permute(0, 2, 3, 1).detach().cpu()[0].numpy() * 255
        ).astype(np.uint8)
        cv2.imwrite(
            ".threestudio_cache/%s%s_test_lora.jpg" % (prefix, self.name),
            temp[:, :, ::-1],
        )

        target = (latents_noisy - latents_noisy_lora + latents).detach()
        # target = latents_noisy.detach()
        targets_rgb = self.decode_latents(target)
        # targets_rgb = (hifa_images - hifa_lora_images + rgb).detach()
        temp = (
            targets_rgb.permute(0, 2, 3, 1).detach().cpu()[0].numpy() * 255
        ).astype(np.uint8)
        cv2.imwrite(".threestudio_cache/%s_target.jpg" % self.name, temp[:, :, ::-1])

        return w * 0.5 * F.mse_loss(target, latents, reduction="sum")

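    # The LoRA branch is trained with the standard denoising objective on the
    # current (detached) latents, conditioned on the camera via class labels;
    # with `lora_cfg_training` the camera condition is dropped 10% of the time
    # so the LoRA model also learns an unconditional prediction for CFG.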
    def train_lora(
        self,
        latents: Float[Tensor, "B 4 64 64"],
        text_embeddings: Float[Tensor, "BB 77 768"],
        camera_condition: Float[Tensor, "B 4 4"],
    ):
        B = latents.shape[0]
        latents = latents.detach().repeat(self.cfg.lora_n_timestamp_samples, 1, 1, 1)

        t = torch.randint(
            int(self.num_train_timesteps * 0.0),
            int(self.num_train_timesteps * 1.0),
            [B * self.cfg.lora_n_timestamp_samples],
            dtype=torch.long,
            device=self.device,
        )

        noise = torch.randn_like(latents)
        noisy_latents = self.scheduler_lora.add_noise(latents, noise, t)
        if self.scheduler_lora.config.prediction_type == "epsilon":
            target = noise
        elif self.scheduler_lora.config.prediction_type == "v_prediction":
            target = self.scheduler_lora.get_velocity(latents, noise, t)
        else:
            raise ValueError(
                f"Unknown prediction type {self.scheduler_lora.config.prediction_type}"
            )
        # use view-independent text embeddings in LoRA
        text_embeddings_cond, _ = text_embeddings.chunk(2)
        if self.cfg.lora_cfg_training and random.random() < 0.1:
            camera_condition = torch.zeros_like(camera_condition)
        noise_pred = self.forward_unet(
            self.unet_lora,
            noisy_latents,
            t,
            encoder_hidden_states=text_embeddings_cond.repeat(
                self.cfg.lora_n_timestamp_samples, 1, 1
            ),
            class_labels=camera_condition.view(B, -1).repeat(
                self.cfg.lora_n_timestamp_samples, 1
            ),
            cross_attention_kwargs={"scale": 1.0},
        )
        return F.mse_loss(noise_pred.float(), target.float(), reduction="mean")

    def get_latents(
        self, rgb_BCHW: Float[Tensor, "B C H W"], rgb_as_latents=False
    ) -> Float[Tensor, "B 4 64 64"]:
        if rgb_as_latents:
            latents = F.interpolate(
                rgb_BCHW, (64, 64), mode="bilinear", align_corners=False
            )
        else:
            rgb_BCHW_512 = F.interpolate(
                rgb_BCHW, (512, 512), mode="bilinear", align_corners=False
            )
            # encode image into latents with vae
            latents = self.encode_images(rgb_BCHW_512)
        return latents

    def forward(
        self,
        rgb: Float[Tensor, "B H W C"],
        prompt_utils: PromptProcessorOutput,
        elevation: Float[Tensor, "B"],
        azimuth: Float[Tensor, "B"],
        camera_distances: Float[Tensor, "B"],
        mvp_mtx: Float[Tensor, "B 4 4"],
        c2w: Float[Tensor, "B 4 4"],
        rgb_as_latents=False,
        mask: Float[Tensor, "B H W 1"] = None,
        lora_prompt_utils=None,
        **kwargs,
    ):
        batch_size = rgb.shape[0]

        rgb_BCHW = rgb.permute(0, 3, 1, 2)
        latents = self.get_latents(rgb_BCHW, rgb_as_latents=rgb_as_latents)

        if mask is not None:
            mask = mask.permute(0, 3, 1, 2)

        # view-dependent text embeddings
        text_embeddings_vd = prompt_utils.get_text_embeddings(
            elevation,
            azimuth,
            camera_distances,
            view_dependent_prompting=self.cfg.view_dependent_prompting,
        )
        if lora_prompt_utils is not None:
            # input text embeddings, view-independent
            text_embeddings = lora_prompt_utils.get_text_embeddings(
                elevation, azimuth, camera_distances, view_dependent_prompting=False
            )
        else:
            # input text embeddings, view-independent
            text_embeddings = prompt_utils.get_text_embeddings(
                elevation, azimuth, camera_distances, view_dependent_prompting=False
            )

        if self.cfg.camera_condition_type == "extrinsics":
            camera_condition = c2w
        elif self.cfg.camera_condition_type == "mvp":
            camera_condition = mvp_mtx
        else:
            raise ValueError(
                f"Unknown camera_condition_type {self.cfg.camera_condition_type}"
            )

        grad = self.compute_grad_vsd(
            latents, text_embeddings_vd, text_embeddings, camera_condition
        )

        grad = torch.nan_to_num(grad)
        # clip grad for stable training?
        if self.grad_clip_val is not None:
            grad = grad.clamp(-self.grad_clip_val, self.grad_clip_val)

        # reparameterization trick
        # d(loss)/d(latents) = latents - target = latents - (latents - grad) = grad
        target = (latents - grad).detach()
        loss_vsd = 0.5 * F.mse_loss(latents, target, reduction="sum") / batch_size

        loss_lora = self.train_lora(latents, text_embeddings, camera_condition)

        guidance_out = {
            "loss_sd": loss_vsd,
            "loss_lora": loss_lora,
            "grad_norm": grad.norm(),
            "min_step": self.min_step,
            "max_step": self.max_step,
        }

        if self.cfg.use_du:
            du_out = self.compute_grad_du(
                latents, rgb_BCHW, text_embeddings_vd, mask=mask
            )
            guidance_out.update(du_out)

        return guidance_out

    def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False):
        # clip grad for stable training as demonstrated in
        # Debiasing Scores and Prompts of 2D Diffusion for Robust Text-to-3D Generation
        # http://arxiv.org/abs/2303.15413
        if self.cfg.grad_clip is not None:
            self.grad_clip_val = C(self.cfg.grad_clip, epoch, global_step)
        self.global_step = global_step
        self.set_min_max_steps(
            min_step_percent=C(self.cfg.min_step_percent, epoch, global_step),
            max_step_percent=C(self.cfg.max_step_percent, epoch, global_step),
        )
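
# Rough usage sketch (illustrative, not part of this module): a threestudio
# system is expected to instantiate this guidance from its config and combine
# the returned losses; the lambda names below are placeholders.
#
#   guidance = threestudio.find("stable-diffusion-vsd-guidance")(cfg.system.guidance)
#   prompt_processor = threestudio.find(cfg.system.prompt_processor_type)(
#       cfg.system.prompt_processor
#   )
#   out = guidance(
#       rgb, prompt_processor(), elevation, azimuth, camera_distances, mvp_mtx, c2w
#   )
#   loss = lambda_sd * out["loss_sd"] + lambda_lora * out["loss_lora"]
#   loss.backward()  # updates both the 3D representation and the LoRA layers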