from dataclasses import dataclass, field

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import threestudio
from threestudio.models.geometry.base import (
    BaseGeometry,
    BaseImplicitGeometry,
    contract_to_unisphere,
)
from threestudio.models.networks import get_encoding, get_mlp
from threestudio.utils.ops import get_activation
from threestudio.utils.typing import *


@threestudio.register("implicit-volume")
class ImplicitVolume(BaseImplicitGeometry):
    @dataclass
    class Config(BaseImplicitGeometry.Config):
        n_input_dims: int = 3
        n_feature_dims: int = 3
        density_activation: Optional[str] = "softplus"
        density_bias: Union[float, str] = "blob_magic3d"
        density_blob_scale: float = 10.0
        density_blob_std: float = 0.5
        pos_encoding_config: dict = field(
            default_factory=lambda: {
                "otype": "HashGrid",
                "n_levels": 16,
                "n_features_per_level": 2,
                "log2_hashmap_size": 19,
                "base_resolution": 16,
                "per_level_scale": 1.447269237440378,
            }
        )
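        # Note: per_level_scale = (4096 / 16) ** (1 / 15) ≈ 1.447269, so the hash
        # grid resolution grows geometrically from 16 up to roughly 4096.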
        mlp_network_config: dict = field(
            default_factory=lambda: {
                "otype": "VanillaMLP",
                "activation": "ReLU",
                "output_activation": "none",
                "n_neurons": 64,
                "n_hidden_layers": 1,
            }
        )
        normal_type: Optional[
            str
        ] = "finite_difference"  # in ['pred', 'analytic', 'finite_difference', 'finite_difference_laplacian']
        finite_difference_normal_eps: Union[
            float, str
        ] = 0.01  # in [float, "progressive"]

        # automatically determine the threshold when set to "auto"
        isosurface_threshold: Union[float, str] = 25.0

        # 4D Gaussian annealing of the density blob std
        anneal_density_blob_std_config: Optional[dict] = None

    cfg: Config

    def configure(self) -> None:
        super().configure()
        self.encoding = get_encoding(
            self.cfg.n_input_dims, self.cfg.pos_encoding_config
        )
        self.density_network = get_mlp(
            self.encoding.n_output_dims, 1, self.cfg.mlp_network_config
        )
        if self.cfg.n_feature_dims > 0:
            self.feature_network = get_mlp(
                self.encoding.n_output_dims,
                self.cfg.n_feature_dims,
                self.cfg.mlp_network_config,
            )
        if self.cfg.normal_type == "pred":
            self.normal_network = get_mlp(
                self.encoding.n_output_dims, 3, self.cfg.mlp_network_config
            )
        # runtime copy of the blob std so update_step() can anneal it; the
        # annealed value (not the static config value) is read in
        # get_activated_density, otherwise the annealing would have no effect
        self.density_blob_std: float = self.cfg.density_blob_std
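
        # filled in by update_step(); None until the first call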
        self.finite_difference_normal_eps: Optional[float] = None

    def get_activated_density(
        self, points: Float[Tensor, "*N Di"], density: Float[Tensor, "*N 1"]
    ) -> Tuple[Float[Tensor, "*N 1"], Float[Tensor, "*N 1"]]:
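        """Add the pre-activation density bias, then apply the density activation.

        blob_dreamfusion: bias = scale * exp(-||x||^2 / (2 * std^2))  (Gaussian blob)
        blob_magic3d:     bias = scale * (1 - ||x|| / std)            (linear blob)

        Returns (raw_density, activated_density).
        """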
        density_bias: Union[float, Float[Tensor, "*N 1"]]
        if self.cfg.density_bias == "blob_dreamfusion":
            # pre-activation density bias
            density_bias = (
                self.cfg.density_blob_scale
                * torch.exp(
                    -0.5 * (points**2).sum(dim=-1) / self.density_blob_std**2
                )[..., None]
            )
        elif self.cfg.density_bias == "blob_magic3d":
            # pre-activation density bias
            density_bias = (
                self.cfg.density_blob_scale
                * (
                    1
                    - torch.sqrt((points**2).sum(dim=-1)) / self.density_blob_std
                )[..., None]
            )
        elif isinstance(self.cfg.density_bias, float):
            density_bias = self.cfg.density_bias
        else:
            raise ValueError(f"Unknown density bias {self.cfg.density_bias}")
        raw_density: Float[Tensor, "*N 1"] = density + density_bias
        density = get_activation(self.cfg.density_activation)(raw_density)
        return raw_density, density

    def forward(
        self, points: Float[Tensor, "*N Di"], output_normal: bool = False
    ) -> Dict[str, Float[Tensor, "..."]]:
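        """Query the field at the given points.

        Returns a dict with "density", optionally "features", and, when
        output_normal is True, "normal" and "shading_normal".
        """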
        grad_enabled = torch.is_grad_enabled()

        if output_normal and self.cfg.normal_type == "analytic":
            torch.set_grad_enabled(True)
            points.requires_grad_(True)

        points_unscaled = points  # points in the original scale
        points = contract_to_unisphere(
            points, self.bbox, self.unbounded
        )  # points normalized to (0, 1)

        enc = self.encoding(points.view(-1, self.cfg.n_input_dims))
        density = self.density_network(enc).view(*points.shape[:-1], 1)
        raw_density, density = self.get_activated_density(points_unscaled, density)

        output = {
            "density": density,
        }

        if self.cfg.n_feature_dims > 0:
            features = self.feature_network(enc).view(
                *points.shape[:-1], self.cfg.n_feature_dims
            )
            output.update({"features": features})

        if output_normal:
            if (
                self.cfg.normal_type == "finite_difference"
                or self.cfg.normal_type == "finite_difference_laplacian"
            ):
                # TODO: use raw density
                assert self.finite_difference_normal_eps is not None
                eps: float = self.finite_difference_normal_eps
                if self.cfg.normal_type == "finite_difference_laplacian":
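                    # central differences: six taps around x estimate -∇σ,
                    # n_i = -0.5 * (σ(x + ε e_i) - σ(x - ε e_i)) / ε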
                    offsets: Float[Tensor, "6 3"] = torch.as_tensor(
                        [
                            [eps, 0.0, 0.0],
                            [-eps, 0.0, 0.0],
                            [0.0, eps, 0.0],
                            [0.0, -eps, 0.0],
                            [0.0, 0.0, eps],
                            [0.0, 0.0, -eps],
                        ]
                    ).to(points_unscaled)
                    points_offset: Float[Tensor, "... 6 3"] = (
                        points_unscaled[..., None, :] + offsets
                    ).clamp(-self.cfg.radius, self.cfg.radius)
                    density_offset: Float[Tensor, "... 6 1"] = self.forward_density(
                        points_offset
                    )
                    normal = (
                        -0.5
                        * (density_offset[..., 0::2, 0] - density_offset[..., 1::2, 0])
                        / eps
                    )
                else:
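                    # forward differences: three taps, n_i = -(σ(x + ε e_i) - σ(x)) / ε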
                    offsets: Float[Tensor, "3 3"] = torch.as_tensor(
                        [[eps, 0.0, 0.0], [0.0, eps, 0.0], [0.0, 0.0, eps]]
                    ).to(points_unscaled)
                    points_offset: Float[Tensor, "... 3 3"] = (
                        points_unscaled[..., None, :] + offsets
                    ).clamp(-self.cfg.radius, self.cfg.radius)
                    density_offset: Float[Tensor, "... 3 1"] = self.forward_density(
                        points_offset
                    )
                    normal = -(density_offset[..., :, 0] - density) / eps
                normal = F.normalize(normal, dim=-1)
            elif self.cfg.normal_type == "pred":
                normal = self.normal_network(enc).view(*points.shape[:-1], 3)
                normal = F.normalize(normal, dim=-1)
            elif self.cfg.normal_type == "analytic":
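                # analytic normal: -∇σ(x) via autograd w.r.t. the unscaled points;
                # create_graph=True keeps the graph so gradients can flow through
                # the normal itself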
                normal = -torch.autograd.grad(
                    density,
                    points_unscaled,
                    grad_outputs=torch.ones_like(density),
                    create_graph=True,
                )[0]
                normal = F.normalize(normal, dim=-1)
                if not grad_enabled:
                    normal = normal.detach()
            else:
                raise AttributeError(f"Unknown normal type {self.cfg.normal_type}")
            output.update({"normal": normal, "shading_normal": normal})

        torch.set_grad_enabled(grad_enabled)
        return output

    def forward_density(self, points: Float[Tensor, "*N Di"]) -> Float[Tensor, "*N 1"]:
        points_unscaled = points
        points = contract_to_unisphere(points_unscaled, self.bbox, self.unbounded)

        density = self.density_network(
            self.encoding(points.reshape(-1, self.cfg.n_input_dims))
        ).reshape(*points.shape[:-1], 1)

        _, density = self.get_activated_density(points_unscaled, density)
        return density

    def forward_field(
        self, points: Float[Tensor, "*N Di"]
    ) -> Tuple[Float[Tensor, "*N 1"], Optional[Float[Tensor, "*N 3"]]]:
        if self.cfg.isosurface_deformable_grid:
            threestudio.warn(
                f"{self.__class__.__name__} does not support isosurface_deformable_grid. Ignoring."
            )
        density = self.forward_density(points)
        return density, None

    def forward_level(
        self, field: Float[Tensor, "*N 1"], threshold: float
    ) -> Float[Tensor, "*N 1"]:
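        # sign convention: negative inside (field > threshold), so the zero
        # level set sits exactly at field == threshold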
        return -(field - threshold)

    def export(self, points: Float[Tensor, "*N Di"], **kwargs) -> Dict[str, Any]:
        out: Dict[str, Any] = {}
        if self.cfg.n_feature_dims == 0:
            return out
        points_unscaled = points
        points = contract_to_unisphere(points_unscaled, self.bbox, self.unbounded)
        enc = self.encoding(points.reshape(-1, self.cfg.n_input_dims))
        features = self.feature_network(enc).view(
            *points.shape[:-1], self.cfg.n_feature_dims
        )
        out.update(
            {
                "features": features,
            }
        )
        return out

    @staticmethod
    @torch.no_grad()
    def create_from(
        other: BaseGeometry,
        cfg: Optional[Union[dict, DictConfig]] = None,
        copy_net: bool = True,
        **kwargs,
    ) -> "ImplicitVolume":
        if isinstance(other, ImplicitVolume):
            instance = ImplicitVolume(cfg, **kwargs)
            instance.encoding.load_state_dict(other.encoding.state_dict())
            instance.density_network.load_state_dict(other.density_network.state_dict())
            if copy_net:
                if (
                    instance.cfg.n_feature_dims > 0
                    and other.cfg.n_feature_dims == instance.cfg.n_feature_dims
                ):
                    instance.feature_network.load_state_dict(
                        other.feature_network.state_dict()
                    )
                if (
                    instance.cfg.normal_type == "pred"
                    and other.cfg.normal_type == "pred"
                ):
                    instance.normal_network.load_state_dict(
                        other.normal_network.state_dict()
                    )
            return instance
        else:
            raise TypeError(
                f"Cannot create {ImplicitVolume.__name__} from {other.__class__.__name__}"
            )

    # FIXME: use progressive normal eps
    def update_step(
        self, epoch: int, global_step: int, on_load_weights: bool = False
    ) -> None:
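        # linearly anneal density_blob_std from start_val to end_val over
        # [min_anneal_step, max_anneal_step] ("4D Gaussian annealing")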
        if self.cfg.anneal_density_blob_std_config is not None:
            min_step = self.cfg.anneal_density_blob_std_config.min_anneal_step
            max_step = self.cfg.anneal_density_blob_std_config.max_anneal_step
            if global_step >= min_step and global_step <= max_step:
                end_val = self.cfg.anneal_density_blob_std_config.end_val
                start_val = self.cfg.anneal_density_blob_std_config.start_val
                self.density_blob_std = start_val + (global_step - min_step) * (
                    end_val - start_val
                ) / (max_step - min_step)

        if (
            self.cfg.normal_type == "finite_difference"
            or self.cfg.normal_type == "finite_difference_laplacian"
        ):
            if isinstance(self.cfg.finite_difference_normal_eps, float):
                self.finite_difference_normal_eps = (
                    self.cfg.finite_difference_normal_eps
                )
            elif self.cfg.finite_difference_normal_eps == "progressive":
                # progressive finite difference eps from Neuralangelo
                # https://arxiv.org/abs/2306.03092
                hg_conf: Any = self.cfg.pos_encoding_config
                assert (
                    hg_conf.otype == "ProgressiveBandHashGrid"
                ), "finite_difference_normal_eps=progressive only works with ProgressiveBandHashGrid"
                current_level = min(
                    hg_conf.start_level
                    + max(global_step - hg_conf.start_step, 0) // hg_conf.update_steps,
                    hg_conf.n_levels,
                )
                grid_res = hg_conf.base_resolution * hg_conf.per_level_scale ** (
                    current_level - 1
                )
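                # the scene spans [-radius, radius], so one cell of the finest
                # active hash level measures 2 * radius / grid_res world units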
                grid_size = 2 * self.cfg.radius / grid_res
                if grid_size != self.finite_difference_normal_eps:
                    threestudio.info(
                        f"Update finite_difference_normal_eps to {grid_size}"
                    )
                self.finite_difference_normal_eps = grid_size
            else:
                raise ValueError(
                    f"Unknown finite_difference_normal_eps={self.cfg.finite_difference_normal_eps}"
                )
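

# Minimal usage sketch (an assumption based on threestudio conventions, not part
# of the original file; the tcnn HashGrid encoding requires a CUDA device, and
# "analytic" normals avoid needing update_step() to set the finite-difference eps):
#
#   geometry = ImplicitVolume({"normal_type": "analytic"}).cuda()
#   pts = (torch.rand(128, 3, device="cuda") * 2 - 1)  # points in [-1, 1]^3
#   out = geometry(pts, output_normal=True)
#   out["density"].shape  # -> (128, 1); out["normal"] holds unit normals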