mirror of
https://github.com/deepseek-ai/DreamCraft3D.git
synced 2025-02-23 14:28:55 -05:00
159 lines
5.9 KiB
Python
159 lines
5.9 KiB
Python
from dataclasses import dataclass
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
|
|
import threestudio
|
|
from threestudio.models.background.base import BaseBackground
|
|
from threestudio.models.geometry.base import BaseImplicitGeometry
|
|
from threestudio.models.materials.base import BaseMaterial
|
|
from threestudio.models.renderers.base import VolumeRenderer
|
|
from threestudio.utils.GAN.discriminator import NLayerDiscriminator, weights_init
|
|
from threestudio.utils.GAN.distribution import DiagonalGaussianDistribution
|
|
from threestudio.utils.GAN.mobilenet import MobileNetV3 as GlobalEncoder
|
|
from threestudio.utils.GAN.vae import Decoder as Generator
|
|
from threestudio.utils.GAN.vae import Encoder as LocalEncoder
|
|
from threestudio.utils.typing import *
|
|
|
|
|
|
@threestudio.register("gan-volume-renderer")
class GANVolumeRenderer(VolumeRenderer):
    """Volume renderer wrapped by a GAN-based super-resolution head.

    The wrapped base renderer produces a low-resolution image whose channels
    are 3 RGB values plus a 4-channel latent (see the ``[..., :3]`` /
    ``[..., 3:]`` splits in :meth:`forward`). The latent parameterizes a
    diagonal Gaussian; a sampled (or mode) z-map is concatenated with the
    low-res RGB and decoded by ``self.generator`` into the full-resolution
    image, conditioned on a global code from ``self.global_encoder``. A
    discriminator is also instantiated here for adversarial training
    (it is not called inside this module's forward pass).
    """

    @dataclass
    class Config(VolumeRenderer.Config):
        # Registry name of the wrapped base renderer (resolved via
        # threestudio.find in configure()).
        base_renderer_type: str = ""
        # Config object forwarded to the base renderer's constructor.
        base_renderer: Optional[VolumeRenderer.Config] = None

    cfg: Config

    def configure(
        self,
        geometry: BaseImplicitGeometry,
        material: BaseMaterial,
        background: BaseBackground,
    ) -> None:
        """Instantiate the base renderer and all GAN components."""
        # Build the wrapped base renderer from the registry, sharing the
        # geometry/material/background passed to this renderer.
        self.base_renderer = threestudio.find(self.cfg.base_renderer_type)(
            self.cfg.base_renderer,
            geometry=geometry,
            material=material,
            background=background,
        )
        # Three resolution stages -> generator/encoder scale factor of
        # 2 ** (len(ch_mult) - 1) = 4x (used as scale_ratio in forward()).
        self.ch_mult = [1, 2, 4]
        # Decoder: input is 3 RGB + 4 latent channels (in_channels=7),
        # output is a 3-channel RGB image; also takes a global code
        # (see the two-argument call in forward()).
        self.generator = Generator(
            ch=64,
            out_ch=3,
            ch_mult=self.ch_mult,
            num_res_blocks=1,
            attn_resolutions=[],
            dropout=0.0,
            resamp_with_conv=True,
            in_channels=7,
            resolution=512,
            z_channels=4,
        )
        # Encoder: maps a full-resolution RGB image to a 4-channel latent
        # used as Gaussian moments (see the generator_level == 2 branch).
        self.local_encoder = LocalEncoder(
            ch=32,
            out_ch=3,
            ch_mult=self.ch_mult,
            num_res_blocks=1,
            attn_resolutions=[],
            dropout=0.0,
            resamp_with_conv=True,
            in_channels=3,
            resolution=512,
            z_channels=4,
        )
        # Global conditioning code (64-dim) from a 224x224-resized image.
        self.global_encoder = GlobalEncoder(n_class=64)
        # PatchGAN-style discriminator for adversarial training.
        self.discriminator = NLayerDiscriminator(
            input_nc=3, n_layers=3, use_actnorm=False, ndf=64
        ).apply(weights_init)

    def forward(
        self,
        rays_o: Float[Tensor, "B H W 3"],
        rays_d: Float[Tensor, "B H W 3"],
        light_positions: Float[Tensor, "B 3"],
        bg_color: Optional[Tensor] = None,
        gt_rgb: Optional[Float[Tensor, "B H W 3"]] = None,
        multi_level_guidance: Bool = False,
        **kwargs
    ) -> Dict[str, Float[Tensor, "..."]]:
        """Render at reduced resolution, then upsample with the GAN decoder.

        Args:
            rays_o/rays_d: per-pixel ray origins/directions at the full
                target resolution (H, W); downsampled internally by 4x
                before volume rendering.
            light_positions: per-batch light positions, forwarded to the
                base renderer.
            bg_color: optional background color, forwarded to the base
                renderer.
            gt_rgb: ground-truth image; only used when
                ``multi_level_guidance`` is on (training).
            multi_level_guidance: when True (and gt_rgb is given), randomly
                picks one of three conditioning levels for the generator and
                additionally renders a sparse full-resolution ray subset for
                direct supervision.

        Returns:
            The base renderer's output dict, augmented with:
            ``comp_lr_rgb`` (low-res render), ``comp_rgb`` (low-res render
            bilinearly upsampled to (H, W)), ``comp_gan_rgb`` (GAN output at
            (H, W)), ``posterior``, ``generator_level``, and — in guidance
            mode — ``comp_int_rgb`` / ``comp_gt_rgb`` (strided render vs.
            matching GT pixels).
        """
        B, H, W, _ = rays_o.shape
        if gt_rgb is not None and multi_level_guidance:
            # Randomly choose the conditioning source used below:
            # 0: rendered RGB, 1: GT global code, 2: GT global + local codes.
            generator_level = torch.randint(0, 3, (1,)).item()
            # Render a random 8-strided subset of the full-resolution rays so
            # the base renderer can be supervised directly against GT pixels.
            interval_x = torch.randint(0, 8, (1,)).item()
            interval_y = torch.randint(0, 8, (1,)).item()
            int_rays_o = rays_o[:, interval_y::8, interval_x::8]
            int_rays_d = rays_d[:, interval_y::8, interval_x::8]
            out = self.base_renderer(
                int_rays_o, int_rays_d, light_positions, bg_color, **kwargs
            )
            # Base renderer output carries 3 RGB + 4 latent channels; only the
            # RGB part is compared against the matching GT pixels.
            comp_int_rgb = out["comp_rgb"][..., :3]
            comp_gt_rgb = gt_rgb[:, interval_y::8, interval_x::8]
        else:
            generator_level = 0
        # The generator upsamples by this factor (4x for ch_mult=[1,2,4]),
        # so volume-render at the correspondingly reduced resolution.
        scale_ratio = 2 ** (len(self.ch_mult) - 1)
        rays_o = torch.nn.functional.interpolate(
            rays_o.permute(0, 3, 1, 2),
            (H // scale_ratio, W // scale_ratio),
            mode="bilinear",
        ).permute(0, 2, 3, 1)
        rays_d = torch.nn.functional.interpolate(
            rays_d.permute(0, 3, 1, 2),
            (H // scale_ratio, W // scale_ratio),
            mode="bilinear",
        ).permute(0, 2, 3, 1)

        # Low-resolution volume render; split RGB and the 4-channel latent.
        out = self.base_renderer(rays_o, rays_d, light_positions, bg_color, **kwargs)
        comp_rgb = out["comp_rgb"][..., :3]
        latent = out["comp_rgb"][..., 3:]
        out["comp_lr_rgb"] = comp_rgb.clone()

        # Treat the rendered latent as Gaussian moments (NCHW layout).
        posterior = DiagonalGaussianDistribution(latent.permute(0, 3, 1, 2))
        if multi_level_guidance:
            # Stochastic sample during guided training ...
            z_map = posterior.sample()
        else:
            # ... deterministic mode otherwise (e.g. inference).
            z_map = posterior.mode()
        lr_rgb = comp_rgb.permute(0, 3, 1, 2)

        if generator_level == 0:
            # Condition on the low-res render itself (224x224 for the
            # MobileNet-based global encoder).
            g_code_rgb = self.global_encoder(F.interpolate(lr_rgb, (224, 224)))
            comp_gan_rgb = self.generator(torch.cat([lr_rgb, z_map], dim=1), g_code_rgb)
        elif generator_level == 1:
            # Condition on the ground-truth image's global code.
            g_code_rgb = self.global_encoder(
                F.interpolate(gt_rgb.permute(0, 3, 1, 2), (224, 224))
            )
            comp_gan_rgb = self.generator(torch.cat([lr_rgb, z_map], dim=1), g_code_rgb)
        elif generator_level == 2:
            # Strongest guidance: GT global code plus a z-map re-sampled from
            # the GT image's locally-encoded posterior (replaces the rendered
            # latent's posterior in the output dict as well).
            g_code_rgb = self.global_encoder(
                F.interpolate(gt_rgb.permute(0, 3, 1, 2), (224, 224))
            )
            l_code_rgb = self.local_encoder(gt_rgb.permute(0, 3, 1, 2))
            posterior = DiagonalGaussianDistribution(l_code_rgb)
            z_map = posterior.sample()
            comp_gan_rgb = self.generator(torch.cat([lr_rgb, z_map], dim=1), g_code_rgb)

        # Upsample both the raw render and the GAN output to the requested
        # (H, W); outputs are returned in NHWC layout.
        comp_rgb = F.interpolate(comp_rgb.permute(0, 3, 1, 2), (H, W), mode="bilinear")
        comp_gan_rgb = F.interpolate(comp_gan_rgb, (H, W), mode="bilinear")
        out.update(
            {
                "posterior": posterior,
                "comp_gan_rgb": comp_gan_rgb.permute(0, 2, 3, 1),
                "comp_rgb": comp_rgb.permute(0, 2, 3, 1),
                "generator_level": generator_level,
            }
        )

        if gt_rgb is not None and multi_level_guidance:
            out.update({"comp_int_rgb": comp_int_rgb, "comp_gt_rgb": comp_gt_rgb})
        return out

    def update_step(
        self, epoch: int, global_step: int, on_load_weights: bool = False
    ) -> None:
        # Forward scheduler/progressive-training updates to the base renderer.
        self.base_renderer.update_step(epoch, global_step, on_load_weights)

    def train(self, mode=True):
        # NOTE(review): only the base renderer's mode is toggled here; this
        # module itself and the GAN submodules (generator, encoders,
        # discriminator) keep their current train/eval state — confirm this
        # is intentional.
        return self.base_renderer.train(mode)

    def eval(self):
        # See NOTE in train(): delegates to the base renderer only.
        return self.base_renderer.eval()