DreamCraft3D/threestudio/models/geometry/implicit_volume.py

from dataclasses import dataclass, field

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import threestudio
from threestudio.models.geometry.base import (
    BaseGeometry,
    BaseImplicitGeometry,
    contract_to_unisphere,
)
from threestudio.models.networks import get_encoding, get_mlp
from threestudio.utils.ops import get_activation
from threestudio.utils.typing import *


@threestudio.register("implicit-volume")
class ImplicitVolume(BaseImplicitGeometry):
    @dataclass
    class Config(BaseImplicitGeometry.Config):
        n_input_dims: int = 3
        n_feature_dims: int = 3
        density_activation: Optional[str] = "softplus"
        density_bias: Union[float, str] = "blob_magic3d"
        density_blob_scale: float = 10.0
        density_blob_std: float = 0.5
        pos_encoding_config: dict = field(
            default_factory=lambda: {
                "otype": "HashGrid",
                "n_levels": 16,
                "n_features_per_level": 2,
                "log2_hashmap_size": 19,
                "base_resolution": 16,
                "per_level_scale": 1.447269237440378,
            }
        )
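        # per_level_scale = (4096 / 16) ** (1 / 15) ≈ 1.4473: 16 hash-grid levels
        # spanning resolutions 16 through 4096.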
        mlp_network_config: dict = field(
            default_factory=lambda: {
                "otype": "VanillaMLP",
                "activation": "ReLU",
                "output_activation": "none",
                "n_neurons": 64,
                "n_hidden_layers": 1,
            }
        )
        normal_type: Optional[
            str
        ] = "finite_difference"  # in ['pred', 'analytic', 'finite_difference', 'finite_difference_laplacian']
        finite_difference_normal_eps: Union[
            float, str
        ] = 0.01  # in [float, "progressive"]
        # a float threshold, or "auto" to determine the threshold automatically
        isosurface_threshold: Union[float, str] = 25.0

        # 4D Gaussian Annealing
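        # Example schedule (hypothetical values) consumed by update_step below:
        #   anneal_density_blob_std_config:
        #     min_anneal_step: 0
        #     max_anneal_step: 2000
        #     start_val: 0.5
        #     end_val: 0.1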
        anneal_density_blob_std_config: Optional[dict] = None

    cfg: Config

    def configure(self) -> None:
        super().configure()
        self.encoding = get_encoding(
            self.cfg.n_input_dims, self.cfg.pos_encoding_config
        )
        self.density_network = get_mlp(
            self.encoding.n_output_dims, 1, self.cfg.mlp_network_config
        )
        if self.cfg.n_feature_dims > 0:
            self.feature_network = get_mlp(
                self.encoding.n_output_dims,
                self.cfg.n_feature_dims,
                self.cfg.mlp_network_config,
            )
        if self.cfg.normal_type == "pred":
            self.normal_network = get_mlp(
                self.encoding.n_output_dims, 3, self.cfg.mlp_network_config
            )
        self.finite_difference_normal_eps: Optional[float] = None

    def get_activated_density(
        self, points: Float[Tensor, "*N Di"], density: Float[Tensor, "*N 1"]
    ) -> Tuple[Float[Tensor, "*N 1"], Float[Tensor, "*N 1"]]:
        density_bias: Union[float, Float[Tensor, "*N 1"]]
        if self.cfg.density_bias == "blob_dreamfusion":
            # pre-activation density bias
            density_bias = (
                self.cfg.density_blob_scale
                * torch.exp(
                    -0.5 * (points**2).sum(dim=-1) / self.cfg.density_blob_std**2
                )[..., None]
            )
        elif self.cfg.density_bias == "blob_magic3d":
            # pre-activation density bias
            density_bias = (
                self.cfg.density_blob_scale
                * (
                    1
                    - torch.sqrt((points**2).sum(dim=-1)) / self.cfg.density_blob_std
                )[..., None]
            )
        elif isinstance(self.cfg.density_bias, float):
            density_bias = self.cfg.density_bias
        else:
            raise ValueError(f"Unknown density bias {self.cfg.density_bias}")
        raw_density: Float[Tensor, "*N 1"] = density + density_bias
        density = get_activation(self.cfg.density_activation)(raw_density)
        return raw_density, density
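
    # The two blob biases above in closed form, with s = density_blob_scale and
    # sigma = density_blob_std:
    #   blob_dreamfusion: bias(x) = s * exp(-||x||^2 / (2 * sigma^2)),
    #                     a Gaussian bump centered at the origin;
    #   blob_magic3d:     bias(x) = s * (1 - ||x|| / sigma),
    #                     a linear falloff that turns negative beyond radius sigma.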

    def forward(
        self, points: Float[Tensor, "*N Di"], output_normal: bool = False
    ) -> Dict[str, Float[Tensor, "..."]]:
        grad_enabled = torch.is_grad_enabled()

        if output_normal and self.cfg.normal_type == "analytic":
            torch.set_grad_enabled(True)
            points.requires_grad_(True)

        points_unscaled = points  # points in the original scale
        points = contract_to_unisphere(
            points, self.bbox, self.unbounded
        )  # points normalized to (0, 1)

        enc = self.encoding(points.view(-1, self.cfg.n_input_dims))
        density = self.density_network(enc).view(*points.shape[:-1], 1)
        raw_density, density = self.get_activated_density(points_unscaled, density)

        output = {
            "density": density,
        }

        if self.cfg.n_feature_dims > 0:
            features = self.feature_network(enc).view(
                *points.shape[:-1], self.cfg.n_feature_dims
            )
            output.update({"features": features})

        if output_normal:
            if (
                self.cfg.normal_type == "finite_difference"
                or self.cfg.normal_type == "finite_difference_laplacian"
            ):
                # TODO: use raw density
                assert self.finite_difference_normal_eps is not None
                eps: float = self.finite_difference_normal_eps
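                # Both branches estimate the shading normal as the negated,
                # normalized density gradient: six taps (central differences)
                # for the "laplacian" variant, three taps (forward differences)
                # otherwise.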
if self.cfg.normal_type == "finite_difference_laplacian":
offsets: Float[Tensor, "6 3"] = torch.as_tensor(
[
[eps, 0.0, 0.0],
[-eps, 0.0, 0.0],
[0.0, eps, 0.0],
[0.0, -eps, 0.0],
[0.0, 0.0, eps],
[0.0, 0.0, -eps],
]
).to(points_unscaled)
points_offset: Float[Tensor, "... 6 3"] = (
points_unscaled[..., None, :] + offsets
).clamp(-self.cfg.radius, self.cfg.radius)
density_offset: Float[Tensor, "... 6 1"] = self.forward_density(
points_offset
)
normal = (
-0.5
* (density_offset[..., 0::2, 0] - density_offset[..., 1::2, 0])
/ eps
)
                else:
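                    # forward difference per axis i:
                    # n_i ≈ -(density(x + eps * e_i) - density(x)) / eps,
                    # reusing the base density computed above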
offsets: Float[Tensor, "3 3"] = torch.as_tensor(
[[eps, 0.0, 0.0], [0.0, eps, 0.0], [0.0, 0.0, eps]]
).to(points_unscaled)
points_offset: Float[Tensor, "... 3 3"] = (
points_unscaled[..., None, :] + offsets
).clamp(-self.cfg.radius, self.cfg.radius)
density_offset: Float[Tensor, "... 3 1"] = self.forward_density(
points_offset
)
normal = -(density_offset[..., 0::1, 0] - density) / eps
normal = F.normalize(normal, dim=-1)
elif self.cfg.normal_type == "pred":
normal = self.normal_network(enc).view(*points.shape[:-1], 3)
normal = F.normalize(normal, dim=-1)
elif self.cfg.normal_type == "analytic":
normal = -torch.autograd.grad(
density,
points_unscaled,
grad_outputs=torch.ones_like(density),
create_graph=True,
)[0]
normal = F.normalize(normal, dim=-1)
if not grad_enabled:
normal = normal.detach()
else:
raise AttributeError(f"Unknown normal type {self.cfg.normal_type}")
output.update({"normal": normal, "shading_normal": normal})
torch.set_grad_enabled(grad_enabled)
return output
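
    # Minimal usage sketch (hypothetical cfg; see Config above for its fields):
    #   geometry = ImplicitVolume(cfg)
    #   out = geometry(points, output_normal=True)  # points: Float[Tensor, "*N 3"]
    #   out["density"]  # Float[Tensor, "*N 1"], softplus-activated by default
    #   out["normal"]   # Float[Tensor, "*N 3"], unit length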

    def forward_density(self, points: Float[Tensor, "*N Di"]) -> Float[Tensor, "*N 1"]:
        points_unscaled = points
        points = contract_to_unisphere(points_unscaled, self.bbox, self.unbounded)

        density = self.density_network(
            self.encoding(points.reshape(-1, self.cfg.n_input_dims))
        ).reshape(*points.shape[:-1], 1)

        _, density = self.get_activated_density(points_unscaled, density)
        return density

    def forward_field(
        self, points: Float[Tensor, "*N Di"]
    ) -> Tuple[Float[Tensor, "*N 1"], Optional[Float[Tensor, "*N 3"]]]:
        if self.cfg.isosurface_deformable_grid:
            threestudio.warn(
                f"{self.__class__.__name__} does not support isosurface_deformable_grid. Ignoring."
            )
        density = self.forward_density(points)
        return density, None

    def forward_level(
        self, field: Float[Tensor, "*N 1"], threshold: float
    ) -> Float[Tensor, "*N 1"]:
        return -(field - threshold)
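
    # Sign convention: threshold - density is negative where density exceeds the
    # threshold, so high-density regions count as "inside" (SDF-style) when the
    # isosurface is extracted.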

    def export(self, points: Float[Tensor, "*N Di"], **kwargs) -> Dict[str, Any]:
        out: Dict[str, Any] = {}
        if self.cfg.n_feature_dims == 0:
            return out
        points_unscaled = points
        points = contract_to_unisphere(points_unscaled, self.bbox, self.unbounded)
        enc = self.encoding(points.reshape(-1, self.cfg.n_input_dims))
        features = self.feature_network(enc).view(
            *points.shape[:-1], self.cfg.n_feature_dims
        )
        out.update(
            {
                "features": features,
            }
        )
        return out

    @staticmethod
    @torch.no_grad()
    def create_from(
        other: BaseGeometry,
        cfg: Optional[Union[dict, DictConfig]] = None,
        copy_net: bool = True,
        **kwargs,
    ) -> "ImplicitVolume":
        if isinstance(other, ImplicitVolume):
            instance = ImplicitVolume(cfg, **kwargs)
            instance.encoding.load_state_dict(other.encoding.state_dict())
            instance.density_network.load_state_dict(
                other.density_network.state_dict()
            )
            if copy_net:
                if (
                    instance.cfg.n_feature_dims > 0
                    and other.cfg.n_feature_dims == instance.cfg.n_feature_dims
                ):
                    instance.feature_network.load_state_dict(
                        other.feature_network.state_dict()
                    )
                if (
                    instance.cfg.normal_type == "pred"
                    and other.cfg.normal_type == "pred"
                ):
                    instance.normal_network.load_state_dict(
                        other.normal_network.state_dict()
                    )
            return instance
        else:
            raise TypeError(
                f"Cannot create {ImplicitVolume.__name__} from {other.__class__.__name__}"
            )
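
    # E.g. in a coarse-to-fine pipeline (hypothetical call):
    #   fine_geometry = ImplicitVolume.create_from(coarse_geometry, cfg=fine_cfg)
    # copies the hash encoding and density MLP, plus the feature/normal networks
    # when the configs are compatible and copy_net is True.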

    # FIXME: use progressive normal eps
    def update_step(
        self, epoch: int, global_step: int, on_load_weights: bool = False
    ) -> None:
        if self.cfg.anneal_density_blob_std_config is not None:
            min_step = self.cfg.anneal_density_blob_std_config.min_anneal_step
            max_step = self.cfg.anneal_density_blob_std_config.max_anneal_step
            if min_step <= global_step <= max_step:
                end_val = self.cfg.anneal_density_blob_std_config.end_val
                start_val = self.cfg.anneal_density_blob_std_config.start_val
                # write to cfg.density_blob_std, which is what get_activated_density
                # reads; writing to a separate instance attribute would leave the
                # blob std effectively un-annealed
                self.cfg.density_blob_std = start_val + (global_step - min_step) * (
                    end_val - start_val
                ) / (max_step - min_step)

        if (
            self.cfg.normal_type == "finite_difference"
            or self.cfg.normal_type == "finite_difference_laplacian"
        ):
            if isinstance(self.cfg.finite_difference_normal_eps, float):
                self.finite_difference_normal_eps = (
                    self.cfg.finite_difference_normal_eps
                )
            elif self.cfg.finite_difference_normal_eps == "progressive":
                # progressive finite difference eps from Neuralangelo
                # https://arxiv.org/abs/2306.03092
                hg_conf: Any = self.cfg.pos_encoding_config
                assert (
                    hg_conf.otype == "ProgressiveBandHashGrid"
                ), "finite_difference_normal_eps=progressive only works with ProgressiveBandHashGrid"
                current_level = min(
                    hg_conf.start_level
                    + max(global_step - hg_conf.start_step, 0) // hg_conf.update_steps,
                    hg_conf.n_levels,
                )
                grid_res = hg_conf.base_resolution * hg_conf.per_level_scale ** (
                    current_level - 1
                )
                grid_size = 2 * self.cfg.radius / grid_res
                if grid_size != self.finite_difference_normal_eps:
                    threestudio.info(
                        f"Update finite_difference_normal_eps to {grid_size}"
                    )
                self.finite_difference_normal_eps = grid_size
            else:
                raise ValueError(
                    f"Unknown finite_difference_normal_eps={self.cfg.finite_difference_normal_eps}"
                )
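
    # Worked example for the progressive eps (hypothetical numbers): with
    # radius = 1, base_resolution = 16, per_level_scale ≈ 1.447 and
    # current_level = 8, grid_res ≈ 16 * 1.447**7 ≈ 213, so the finite-difference
    # eps becomes 2 / 213 ≈ 0.0094, tracking the finest active hash-grid cell.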