# Mirror of https://github.com/deepseek-ai/DreamCraft3D.git (synced 2025-02-23)
import importlib
import os
from dataclasses import dataclass, field

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers import DDIMScheduler, DDPMScheduler, StableDiffusionPipeline
from diffusers.utils.import_utils import is_xformers_available
from omegaconf import OmegaConf
from tqdm import tqdm

import threestudio
from threestudio.utils.base import BaseObject
from threestudio.utils.misc import C, parse_version
from threestudio.utils.typing import *


def get_obj_from_str(string, reload=False):
    module, cls = string.rsplit(".", 1)
    if reload:
        module_imp = importlib.import_module(module)
        importlib.reload(module_imp)
    return getattr(importlib.import_module(module, package=None), cls)


def instantiate_from_config(config):
    if "target" not in config:
        if config == "__is_first_stage__":
            return None
        elif config == "__is_unconditional__":
            return None
        raise KeyError("Expected key `target` to instantiate.")
    return get_obj_from_str(config["target"])(**config.get("params", dict()))


# load model
def load_model_from_config(config, ckpt, device, vram_O=True, verbose=False):
    pl_sd = torch.load(ckpt, map_location="cpu")

    if "global_step" in pl_sd and verbose:
        print(f'[INFO] Global Step: {pl_sd["global_step"]}')

    sd = pl_sd["state_dict"]

    model = instantiate_from_config(config.model)
    m, u = model.load_state_dict(sd, strict=False)

    if len(m) > 0 and verbose:
        print("[INFO] missing keys: \n", m)
    if len(u) > 0 and verbose:
        print("[INFO] unexpected keys: \n", u)

    # manually load ema and delete it to save GPU memory
    if model.use_ema:
        if verbose:
            print("[INFO] loading EMA...")
        model.model_ema.copy_to(model.model)
        del model.model_ema

    if vram_O:
        # we don't need decoder
        del model.first_stage_model.decoder

    torch.cuda.empty_cache()

    model.eval().to(device)

    return model


@threestudio.register("zero123-guidance")
class Zero123Guidance(BaseObject):
    @dataclass
    class Config(BaseObject.Config):
        pretrained_model_name_or_path: str = "load/zero123/105000.ckpt"
        pretrained_config: str = "load/zero123/sd-objaverse-finetune-c_concat-256.yaml"
        vram_O: bool = True

        cond_image_path: str = "load/images/hamburger_rgba.png"
        cond_elevation_deg: float = 0.0
        cond_azimuth_deg: float = 0.0
        cond_camera_distance: float = 1.2

        guidance_scale: float = 5.0

        grad_clip: Optional[
            Any
        ] = None  # field(default_factory=lambda: [0, 2.0, 8.0, 1000])
        half_precision_weights: bool = False

        min_step_percent: float = 0.02
        max_step_percent: float = 0.98

        """Maximum number of batch items to evaluate guidance for (for debugging) and to save on disk. -1 means save all items."""
        max_items_eval: int = 4

    cfg: Config

    def configure(self) -> None:
        threestudio.info(f"Loading Zero123 ...")

        self.config = OmegaConf.load(self.cfg.pretrained_config)
        # TODO: seems it cannot load into fp16...
        self.weights_dtype = torch.float32
        self.model = load_model_from_config(
            self.config,
            self.cfg.pretrained_model_name_or_path,
            device=self.device,
            vram_O=self.cfg.vram_O,
        )

        for p in self.model.parameters():
            p.requires_grad_(False)

        # timesteps: use diffuser for convenience... hope it's alright.
        self.num_train_timesteps = self.config.model.params.timesteps

        self.scheduler = DDIMScheduler(
            self.num_train_timesteps,
            self.config.model.params.linear_start,
            self.config.model.params.linear_end,
            beta_schedule="scaled_linear",
            clip_sample=False,
            set_alpha_to_one=False,
            steps_offset=1,
        )

        self.num_train_timesteps = self.scheduler.config.num_train_timesteps
        self.set_min_max_steps()  # set to default value

        self.alphas: Float[Tensor, "..."] = self.scheduler.alphas_cumprod.to(
            self.device
        )

        self.grad_clip_val: Optional[float] = None

        self.prepare_embeddings(self.cfg.cond_image_path)

        threestudio.info(f"Loaded Zero123!")

    @torch.cuda.amp.autocast(enabled=False)
    def set_min_max_steps(self, min_step_percent=0.02, max_step_percent=0.98):
        self.min_step = int(self.num_train_timesteps * min_step_percent)
        self.max_step = int(self.num_train_timesteps * max_step_percent)

    @torch.cuda.amp.autocast(enabled=False)
    def prepare_embeddings(self, image_path: str) -> None:
        # load cond image for zero123
        assert os.path.exists(image_path)
        rgba = cv2.cvtColor(
            cv2.imread(image_path, cv2.IMREAD_UNCHANGED), cv2.COLOR_BGRA2RGBA
        )
        rgba = (
            cv2.resize(rgba, (256, 256), interpolation=cv2.INTER_AREA).astype(
                np.float32
            )
            / 255.0
        )
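        # composite the RGBA image onto a white background using its alpha channel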
        rgb = rgba[..., :3] * rgba[..., 3:] + (1 - rgba[..., 3:])
        self.rgb_256: Float[Tensor, "1 3 H W"] = (
            torch.from_numpy(rgb)
            .unsqueeze(0)
            .permute(0, 3, 1, 2)
            .contiguous()
            .to(self.device)
        )
        self.c_crossattn, self.c_concat = self.get_img_embeds(self.rgb_256)

    @torch.cuda.amp.autocast(enabled=False)
    @torch.no_grad()
    def get_img_embeds(
        self,
        img: Float[Tensor, "B 3 256 256"],
    ) -> Tuple[Float[Tensor, "B 1 768"], Float[Tensor, "B 4 32 32"]]:
        img = img * 2.0 - 1.0
        c_crossattn = self.model.get_learned_conditioning(img.to(self.weights_dtype))
        c_concat = self.model.encode_first_stage(img.to(self.weights_dtype)).mode()
        return c_crossattn, c_concat

    @torch.cuda.amp.autocast(enabled=False)
    def encode_images(
        self, imgs: Float[Tensor, "B 3 256 256"]
    ) -> Float[Tensor, "B 4 32 32"]:
        input_dtype = imgs.dtype
        imgs = imgs * 2.0 - 1.0
        latents = self.model.get_first_stage_encoding(
            self.model.encode_first_stage(imgs.to(self.weights_dtype))
        )
        return latents.to(input_dtype)  # [B, 4, 32, 32] Latent space image

    @torch.cuda.amp.autocast(enabled=False)
    def decode_latents(
        self,
        latents: Float[Tensor, "B 4 H W"],
    ) -> Float[Tensor, "B 3 256 256"]:
        input_dtype = latents.dtype
        image = self.model.decode_first_stage(latents)
        image = (image * 0.5 + 0.5).clamp(0, 1)
        return image.to(input_dtype)

    @torch.cuda.amp.autocast(enabled=False)
    @torch.no_grad()
    def get_cond(
        self,
        elevation: Float[Tensor, "B"],
        azimuth: Float[Tensor, "B"],
        camera_distances: Float[Tensor, "B"],
        c_crossattn=None,
        c_concat=None,
        **kwargs,
    ) -> dict:
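        # relative camera pose w.r.t. the conditioning view, in the format Zero123 expects:
        # (delta polar angle, sin(delta azimuth), cos(delta azimuth), delta radius)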
        T = torch.stack(
            [
                torch.deg2rad(
                    (90 - elevation) - (90 - self.cfg.cond_elevation_deg)
                ),  # Zero123 polar is 90-elevation
                torch.sin(torch.deg2rad(azimuth - self.cfg.cond_azimuth_deg)),
                torch.cos(torch.deg2rad(azimuth - self.cfg.cond_azimuth_deg)),
                camera_distances - self.cfg.cond_camera_distance,
            ],
            dim=-1,
        )[:, None, :].to(self.device)
        cond = {}
        clip_emb = self.model.cc_projection(
            torch.cat(
                [
                    (self.c_crossattn if c_crossattn is None else c_crossattn).repeat(
                        len(T), 1, 1
                    ),
                    T,
                ],
                dim=-1,
            )
        )
        cond["c_crossattn"] = [
            torch.cat([torch.zeros_like(clip_emb).to(self.device), clip_emb], dim=0)
        ]
        cond["c_concat"] = [
            torch.cat(
                [
                    torch.zeros_like(self.c_concat)
                    .repeat(len(T), 1, 1, 1)
                    .to(self.device),
                    (self.c_concat if c_concat is None else c_concat).repeat(
                        len(T), 1, 1, 1
                    ),
                ],
                dim=0,
            )
        ]
        return cond

    def __call__(
        self,
        rgb: Float[Tensor, "B H W C"],
        elevation: Float[Tensor, "B"],
        azimuth: Float[Tensor, "B"],
        camera_distances: Float[Tensor, "B"],
        rgb_as_latents=False,
        guidance_eval=False,
        **kwargs,
    ):
        batch_size = rgb.shape[0]

        rgb_BCHW = rgb.permute(0, 3, 1, 2)
        latents: Float[Tensor, "B 4 32 32"]
        if rgb_as_latents:
            latents = (
                F.interpolate(rgb_BCHW, (32, 32), mode="bilinear", align_corners=False)
                * 2
                - 1
            )
        else:
            rgb_BCHW_512 = F.interpolate(
                rgb_BCHW, (256, 256), mode="bilinear", align_corners=False
            )
            # encode image into latents with vae
            latents = self.encode_images(rgb_BCHW_512)

        cond = self.get_cond(elevation, azimuth, camera_distances)

        # timestep ~ U(0.02, 0.98) to avoid very high/low noise level
        t = torch.randint(
            self.min_step,
            self.max_step + 1,
            [batch_size],
            dtype=torch.long,
            device=self.device,
        )

        # predict the noise residual with unet, NO grad!
        with torch.no_grad():
            # add noise
            noise = torch.randn_like(latents)  # TODO: use torch generator
            latents_noisy = self.scheduler.add_noise(latents, noise, t)
            # pred noise
            x_in = torch.cat([latents_noisy] * 2)
            t_in = torch.cat([t] * 2)
            noise_pred = self.model.apply_model(x_in, t_in, cond)

        # perform guidance
        noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + self.cfg.guidance_scale * (
            noise_pred_cond - noise_pred_uncond
        )

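        # SDS weighting: w(t) = 1 - alpha_cumprod(t), so the gradient pushed into the
        # latents is w(t) * (noise_pred - noise)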
        w = (1 - self.alphas[t]).reshape(-1, 1, 1, 1)
        grad = w * (noise_pred - noise)
        grad = torch.nan_to_num(grad)
        # clip grad for stable training?
        if self.grad_clip_val is not None:
            grad = grad.clamp(-self.grad_clip_val, self.grad_clip_val)

        # loss = SpecifyGradient.apply(latents, grad)
        # SpecifyGradient is not straightforward, use a reparameterization trick instead
        target = (latents - grad).detach()
        # d(loss)/d(latents) = latents - target = latents - (latents - grad) = grad
        loss_sds = 0.5 * F.mse_loss(latents, target, reduction="sum") / batch_size

        guidance_out = {
            "loss_sd": loss_sds,  # loss_sds
            "grad_norm": grad.norm(),
            "min_step": self.min_step,
            "max_step": self.max_step,
        }

        if guidance_eval:
            guidance_eval_utils = {
                "cond": cond,
                "t_orig": t,
                "latents_noisy": latents_noisy,
                "noise_pred": noise_pred,
            }
            guidance_eval_out = self.guidance_eval(**guidance_eval_utils)
            texts = []
            for n, e, a, c in zip(
                guidance_eval_out["noise_levels"], elevation, azimuth, camera_distances
            ):
                texts.append(
                    f"n{n:.02f}\ne{e.item():.01f}\na{a.item():.01f}\nc{c.item():.02f}"
                )
            guidance_eval_out.update({"texts": texts})
            guidance_out.update({"eval": guidance_eval_out})

        return guidance_out

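    # For visualization only: map each sampled t to the nearest of 50 DDIM steps, then
    # decode the noisy latent, the one-step prediction, and the fully denoised sample.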
    @torch.cuda.amp.autocast(enabled=False)
    @torch.no_grad()
    def guidance_eval(self, cond, t_orig, latents_noisy, noise_pred):
        # use only 50 timesteps, and find nearest of those to t
        self.scheduler.set_timesteps(50)
        self.scheduler.timesteps_gpu = self.scheduler.timesteps.to(self.device)
        bs = (
            min(self.cfg.max_items_eval, latents_noisy.shape[0])
            if self.cfg.max_items_eval > 0
            else latents_noisy.shape[0]
        )  # batch size
        large_enough_idxs = self.scheduler.timesteps_gpu.expand([bs, -1]) > t_orig[
            :bs
        ].unsqueeze(
            -1
        )  # sized [bs,50] > [bs,1]
        idxs = torch.min(large_enough_idxs, dim=1)[1]
        t = self.scheduler.timesteps_gpu[idxs]

        fracs = list((t / self.scheduler.config.num_train_timesteps).cpu().numpy())
        imgs_noisy = self.decode_latents(latents_noisy[:bs]).permute(0, 2, 3, 1)

        # get prev latent
        latents_1step = []
        pred_1orig = []
        for b in range(bs):
            step_output = self.scheduler.step(
                noise_pred[b : b + 1], t[b], latents_noisy[b : b + 1], eta=1
            )
            latents_1step.append(step_output["prev_sample"])
            pred_1orig.append(step_output["pred_original_sample"])
        latents_1step = torch.cat(latents_1step)
        pred_1orig = torch.cat(pred_1orig)
        imgs_1step = self.decode_latents(latents_1step).permute(0, 2, 3, 1)
        imgs_1orig = self.decode_latents(pred_1orig).permute(0, 2, 3, 1)

        latents_final = []
        for b, i in enumerate(idxs):
            latents = latents_1step[b : b + 1]
            c = {
                "c_crossattn": [cond["c_crossattn"][0][[b, b + len(idxs)], ...]],
                "c_concat": [cond["c_concat"][0][[b, b + len(idxs)], ...]],
            }
            for t in tqdm(self.scheduler.timesteps[i + 1 :], leave=False):
                # pred noise
                x_in = torch.cat([latents] * 2)
                t_in = torch.cat([t.reshape(1)] * 2).to(self.device)
                noise_pred = self.model.apply_model(x_in, t_in, c)
                # perform guidance
                noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + self.cfg.guidance_scale * (
                    noise_pred_cond - noise_pred_uncond
                )
                # get prev latent
                latents = self.scheduler.step(noise_pred, t, latents, eta=1)[
                    "prev_sample"
                ]
            latents_final.append(latents)

        latents_final = torch.cat(latents_final)
        imgs_final = self.decode_latents(latents_final).permute(0, 2, 3, 1)

        return {
            "bs": bs,
            "noise_levels": fracs,
            "imgs_noisy": imgs_noisy,
            "imgs_1step": imgs_1step,
            "imgs_1orig": imgs_1orig,
            "imgs_final": imgs_final,
        }

    def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False):
        # clip grad for stable training as demonstrated in
        # Debiasing Scores and Prompts of 2D Diffusion for Robust Text-to-3D Generation
        # http://arxiv.org/abs/2303.15413
        if self.cfg.grad_clip is not None:
            self.grad_clip_val = C(self.cfg.grad_clip, epoch, global_step)

        self.set_min_max_steps(
            min_step_percent=C(self.cfg.min_step_percent, epoch, global_step),
            max_step_percent=C(self.cfg.max_step_percent, epoch, global_step),
        )

    # verification - requires `vram_O = False` in load_model_from_config
    @torch.no_grad()
    def generate(
        self,
        image,  # image tensor [1, 3, H, W] in [0, 1]
        elevation=0,
        azimuth=0,
        camera_distances=0,  # new view params
        c_crossattn=None,
        c_concat=None,
        scale=3,
        ddim_steps=50,
        post_process=True,
        ddim_eta=1,
    ):
        if c_crossattn is None:
            c_crossattn, c_concat = self.get_img_embeds(image)

        cond = self.get_cond(
            elevation, azimuth, camera_distances, c_crossattn, c_concat
        )

        imgs = self.gen_from_cond(cond, scale, ddim_steps, post_process, ddim_eta)

        return imgs

    # verification - requires `vram_O = False` in load_model_from_config
    @torch.no_grad()
    def gen_from_cond(
        self,
        cond,
        scale=3,
        ddim_steps=50,
        post_process=True,
        ddim_eta=1,
    ):
        # produce latents loop
        B = cond["c_crossattn"][0].shape[0] // 2
        latents = torch.randn((B, 4, 32, 32), device=self.device)
        self.scheduler.set_timesteps(ddim_steps)

        for t in self.scheduler.timesteps:
            x_in = torch.cat([latents] * 2)
            t_in = torch.cat([t.reshape(1).repeat(B)] * 2).to(self.device)

            noise_pred = self.model.apply_model(x_in, t_in, cond)
            noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + scale * (
                noise_pred_cond - noise_pred_uncond
            )

            latents = self.scheduler.step(noise_pred, t, latents, eta=ddim_eta)[
                "prev_sample"
            ]

        imgs = self.decode_latents(latents)
        imgs = imgs.cpu().numpy().transpose(0, 2, 3, 1) if post_process else imgs

        return imgs


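# Quick verification script: build the guidance from a config, then render and save novel
# views of the conditioning image over a small grid of elevations and azimuths.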
if __name__ == '__main__':
    from threestudio.utils.config import load_config
    import pytorch_lightning as pl
    import numpy as np
    import os
    import cv2

    cfg = load_config("configs/experimental/zero123.yaml")
    guidance = threestudio.find(cfg.system.guidance_type)(cfg.system.guidance)
    elevations = [0, 20, -20]
    azimuths = [45, 90, 135, -45, -90]
    radius = torch.tensor([3.8]).to(guidance.device)
    outdir = ".threestudio_cache/saiyan"
    os.makedirs(outdir, exist_ok=True)
    # rgb_image = (rgb_image[0].detach().cpu().clip(0, 1).numpy()*255).astype(np.uint8)[:, :, ::-1].copy()
    # os.makedirs('.threestudio_cache', exist_ok=True)
    # cv2.imwrite('.threestudio_cache/diffusion_image.jpg', rgb_image)


    rgb_image = cv2.imread(cfg.system.guidance.cond_image_path)[:, :, ::-1].copy() / 255
    rgb_image = cv2.resize(rgb_image, (256, 256))
    rgb_image = torch.FloatTensor(rgb_image).unsqueeze(0).to(guidance.device).permute(0, 3, 1, 2)

    for elevation in elevations:
        for azimuth in azimuths:
            output1 = guidance.generate(
                rgb_image,
                torch.tensor([elevation]).to(guidance.device),
                torch.tensor([azimuth]).to(guidance.device),
                radius,
                c_crossattn=guidance.c_crossattn,
                c_concat=guidance.c_concat,
            )
            from torchvision.utils import save_image

            save_image(
                torch.tensor(output1).float().permute(0, 3, 1, 2),
                f"{outdir}/result_e_{elevation}_a_{azimuth}.png",
                normalize=True,
                value_range=(0, 1),
            )