mirror of
https://github.com/deepseek-ai/DreamCraft3D.git
synced 2025-02-23 14:28:55 -05:00
106 lines
3.7 KiB
Python
106 lines
3.7 KiB
Python
from dataclasses import dataclass
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
|
|
import threestudio
|
|
from threestudio.models.background.base import BaseBackground
|
|
from threestudio.models.geometry.base import BaseImplicitGeometry
|
|
from threestudio.models.materials.base import BaseMaterial
|
|
from threestudio.models.renderers.base import VolumeRenderer
|
|
from threestudio.utils.typing import *
|
|
|
|
|
|
@threestudio.register("patch-renderer")
class PatchRenderer(VolumeRenderer):
    """Wrap another renderer so training renders a cheap global view plus one sharp patch.

    In training mode (as reported by ``self.base_renderer.training``) each call
    performs two passes of the wrapped renderer:

    1. a global pass on the ray grid bilinearly downsampled by
       ``cfg.global_downsample``;
    2. a full-resolution pass on a randomly placed patch of at most
       ``cfg.patch_size`` x ``cfg.patch_size`` pixels.

    Every per-pixel output of the patch pass is pasted into the global output.
    A per-pixel output is a tensor whose shape matches ``comp_rgb`` in all but
    the last dimension. Before pasting, the global output is bilinearly
    upsampled back to ``H x W``. Keys that are not per-pixel maps are taken from
    the global pass unchanged. Outside training, the wrapped renderer is simply
    called on the full ray grid.
    """

    @dataclass
    class Config(VolumeRenderer.Config):
        # Side length in pixels of the full-resolution patch (clamped to the image size).
        patch_size: int = 128
        # Registry name passed to ``threestudio.find`` to build the wrapped renderer.
        base_renderer_type: str = ""
        # Config object forwarded to the wrapped renderer's constructor.
        base_renderer: Optional[VolumeRenderer.Config] = None
        # If True, the upsampled global maps are detached before the patch is pasted
        # in, so gradients for those keys flow only through the patch pixels.
        global_detach: bool = False
        # Integer factor by which the ray grid is shrunk for the global pass.
        global_downsample: int = 4

    cfg: Config

    def configure(
        self,
        geometry: BaseImplicitGeometry,
        material: BaseMaterial,
        background: BaseBackground,
    ) -> None:
        """Instantiate the wrapped renderer named by ``cfg.base_renderer_type``.

        The geometry, material and background are handed straight to the wrapped
        renderer; this wrapper keeps no references of its own to them.
        """
        self.base_renderer = threestudio.find(self.cfg.base_renderer_type)(
            self.cfg.base_renderer,
            geometry=geometry,
            material=material,
            background=background,
        )

    @staticmethod
    def _resize_map(x: Tensor, size) -> Tensor:
        """Bilinearly resize a channels-last ``(B, H, W, C)`` tensor to ``size=(h, w)``."""
        # align_corners=False is what F.interpolate uses when the argument is omitted;
        # it is spelled out here only for clarity.
        return F.interpolate(
            x.permute(0, 3, 1, 2), size, mode="bilinear", align_corners=False
        ).permute(0, 2, 3, 1)

    def forward(
        self,
        rays_o: Float[Tensor, "B H W 3"],
        rays_d: Float[Tensor, "B H W 3"],
        light_positions: Float[Tensor, "B 3"],
        bg_color: Optional[Tensor] = None,
        **kwargs,
    ) -> Dict[str, Float[Tensor, "..."]]:
        """Render the given rays; see the class docstring for the training path.

        Returns the wrapped renderer's output dict. In training, per-pixel entries
        are ``H x W`` composites: the upsampled global render with the patch
        pasted in. All other entries come from the low-resolution global pass.
        """
        if not self.base_renderer.training:
            # Evaluation / inference: render every pixel at full resolution.
            return self.base_renderer(
                rays_o, rays_d, light_positions, bg_color, **kwargs
            )

        _, H, W, _ = rays_o.shape

        # Global pass on a downsampled ray grid. max(1, .) keeps the grid non-empty
        # when the image is smaller than the downsample factor.
        # NOTE(review): bilinearly averaged rays_d are not re-normalised; confirm the
        # wrapped renderer does not require unit-length directions.
        down = self.cfg.global_downsample
        global_size = (max(1, H // down), max(1, W // down))
        out_global = self.base_renderer(
            self._resize_map(rays_o, global_size),
            self._resize_map(rays_d, global_size),
            light_positions,
            bg_color,
            **kwargs,
        )

        # Full-resolution pass on a random patch, clamped to the image size.
        # torch.randint's upper bound is exclusive, so "+ 1" makes the last valid
        # offset reachable. It also avoids randint(0, 0) when the image equals
        # the patch size.
        ps_h = min(self.cfg.patch_size, H)
        ps_w = min(self.cfg.patch_size, W)
        patch_x = torch.randint(0, W - ps_w + 1, (1,)).item()
        patch_y = torch.randint(0, H - ps_h + 1, (1,)).item()
        rows = slice(patch_y, patch_y + ps_h)
        cols = slice(patch_x, patch_x + ps_w)
        out_patch = self.base_renderer(
            rays_o[:, rows, cols],
            rays_d[:, rows, cols],
            light_positions,
            bg_color,
            **kwargs,
        )

        # Paste every per-pixel map (same leading dims as comp_rgb) into the
        # full-resolution upsampled global map stored under the same key.
        pixel_shape = out_patch["comp_rgb"].shape[:-1]
        for key, patch_value in out_patch.items():
            if not torch.is_tensor(patch_value):
                continue
            if patch_value.shape[:-1] != pixel_shape:
                continue
            full = self._resize_map(out_global[key], (H, W))
            if self.cfg.global_detach:
                full = full.detach()
            full[:, rows, cols] = patch_value
            out_global[key] = full
        return out_global

    def update_step(
        self, epoch: int, global_step: int, on_load_weights: bool = False
    ) -> None:
        """Forward per-step updates to the wrapped renderer."""
        self.base_renderer.update_step(epoch, global_step, on_load_weights)

    def train(self, mode: bool = True):
        """Set training mode on this wrapper and, by recursion, on the wrapped renderer.

        ``nn.Module.train`` updates ``self.training`` and calls ``train(mode)`` on
        every registered child module, including ``self.base_renderer``, so any
        ``train`` override there still runs. Returns ``self``, as
        ``nn.Module.train`` does.
        """
        return super().train(mode)

    def eval(self):
        """Switch to evaluation mode; equivalent to ``self.train(False)``."""
        return self.train(False)