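"""Implicit signed-distance-field (SDF) geometry for threestudio.

Registered as "implicit-sdf". A hash-grid position encoding feeds small MLPs
that predict an SDF plus optional feature, normal, and deformation fields;
the SDF can be initialized to a sphere, ellipsoid, or mesh before optimization.
"""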
import os
from dataclasses import dataclass, field

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import threestudio
from threestudio.models.geometry.base import BaseImplicitGeometry, contract_to_unisphere
from threestudio.models.mesh import Mesh
from threestudio.models.networks import get_encoding, get_mlp
from threestudio.utils.misc import broadcast, get_rank
from threestudio.utils.typing import *


@threestudio.register("implicit-sdf")
class ImplicitSDF(BaseImplicitGeometry):
    @dataclass
    class Config(BaseImplicitGeometry.Config):
        n_input_dims: int = 3
        n_feature_dims: int = 3
        pos_encoding_config: dict = field(
            default_factory=lambda: {
                "otype": "HashGrid",
                "n_levels": 16,
                "n_features_per_level": 2,
                "log2_hashmap_size": 19,
                "base_resolution": 16,
                "per_level_scale": 1.447269237440378,
            }
        )
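        # note: per_level_scale above is approximately (4096 / 16) ** (1 / 15),
        # i.e. the hash-grid resolution grows geometrically from 16 to roughly
        # 4096 across the 16 levels (Instant-NGP convention:
        # resolution_l = base_resolution * per_level_scale ** l)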
        mlp_network_config: dict = field(
            default_factory=lambda: {
                "otype": "VanillaMLP",
                "activation": "ReLU",
                "output_activation": "none",
                "n_neurons": 64,
                "n_hidden_layers": 1,
            }
        )
        normal_type: Optional[
            str
        ] = "finite_difference"  # in ['pred', 'analytic', 'finite_difference', 'finite_difference_laplacian']
        finite_difference_normal_eps: Union[
            float, str
        ] = 0.01  # in [float, "progressive"]
        shape_init: Optional[str] = None
        shape_init_params: Optional[Any] = None
        shape_init_mesh_up: str = "+z"
        shape_init_mesh_front: str = "+x"
        force_shape_init: bool = False
        sdf_bias: Union[float, str] = 0.0
        sdf_bias_params: Optional[Any] = None

        # no need to remove outliers for an SDF
        isosurface_remove_outliers: bool = False

    cfg: Config

    def configure(self) -> None:
        super().configure()
        self.encoding = get_encoding(
            self.cfg.n_input_dims, self.cfg.pos_encoding_config
        )
        self.sdf_network = get_mlp(
            self.encoding.n_output_dims, 1, self.cfg.mlp_network_config
        )

        if self.cfg.n_feature_dims > 0:
            self.feature_network = get_mlp(
                self.encoding.n_output_dims,
                self.cfg.n_feature_dims,
                self.cfg.mlp_network_config,
            )

        if self.cfg.normal_type == "pred":
            self.normal_network = get_mlp(
                self.encoding.n_output_dims, 3, self.cfg.mlp_network_config
            )
        if self.cfg.isosurface_deformable_grid:
            assert (
                self.cfg.isosurface_method == "mt"
            ), "isosurface_deformable_grid only works with mt"
            self.deformation_network = get_mlp(
                self.encoding.n_output_dims, 3, self.cfg.mlp_network_config
            )

        self.finite_difference_normal_eps: Optional[float] = None
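
    # initialize_shape fits the freshly initialized SDF network to an analytic
    # target (sphere, ellipsoid, or mesh SDF) by regressing the predicted SDF
    # against ground-truth values at uniformly sampled points, giving the
    # optimization a sensible starting surface.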
    def initialize_shape(self) -> None:
        if self.cfg.shape_init is None and not self.cfg.force_shape_init:
            return

        # do not initialize shape if weights are provided
        if self.cfg.weights is not None and not self.cfg.force_shape_init:
            return

        if self.cfg.sdf_bias != 0.0:
            threestudio.warn(
                "shape_init and sdf_bias are both specified, which may lead to unexpected results."
            )

        get_gt_sdf: Callable[[Float[Tensor, "N 3"]], Float[Tensor, "N 1"]]
        assert isinstance(self.cfg.shape_init, str)
        if self.cfg.shape_init == "ellipsoid":
            assert (
                isinstance(self.cfg.shape_init_params, Sized)
                and len(self.cfg.shape_init_params) == 3
            )
            size = torch.as_tensor(self.cfg.shape_init_params).to(self.device)

            def func(points_rand: Float[Tensor, "N 3"]) -> Float[Tensor, "N 1"]:
                return ((points_rand / size) ** 2).sum(
                    dim=-1, keepdim=True
                ).sqrt() - 1.0  # pseudo signed distance of an ellipsoid

            get_gt_sdf = func
        elif self.cfg.shape_init == "sphere":
            assert isinstance(self.cfg.shape_init_params, float)
            radius = self.cfg.shape_init_params

            def func(points_rand: Float[Tensor, "N 3"]) -> Float[Tensor, "N 1"]:
                return (points_rand**2).sum(dim=-1, keepdim=True).sqrt() - radius

            get_gt_sdf = func
        elif self.cfg.shape_init.startswith("mesh:"):
            assert isinstance(self.cfg.shape_init_params, float)
            mesh_path = self.cfg.shape_init[5:]
            if not os.path.exists(mesh_path):
                raise ValueError(f"Mesh file {mesh_path} does not exist.")

            import trimesh

            scene = trimesh.load(mesh_path)
            if isinstance(scene, trimesh.Trimesh):
                mesh = scene
            elif isinstance(scene, trimesh.scene.Scene):
                mesh = trimesh.Trimesh()
                for obj in scene.geometry.values():
                    mesh = trimesh.util.concatenate([mesh, obj])
            else:
                raise ValueError(f"Unknown mesh type at {mesh_path}.")

            # move to center
            centroid = mesh.vertices.mean(0)
            mesh.vertices = mesh.vertices - centroid

            # align to up-z and front-x
            dirs = ["+x", "+y", "+z", "-x", "-y", "-z"]
            dir2vec = {
                "+x": np.array([1, 0, 0]),
                "+y": np.array([0, 1, 0]),
                "+z": np.array([0, 0, 1]),
                "-x": np.array([-1, 0, 0]),
                "-y": np.array([0, -1, 0]),
                "-z": np.array([0, 0, -1]),
            }
            if (
                self.cfg.shape_init_mesh_up not in dirs
                or self.cfg.shape_init_mesh_front not in dirs
            ):
                raise ValueError(
                    f"shape_init_mesh_up and shape_init_mesh_front must be one of {dirs}."
                )
            if self.cfg.shape_init_mesh_up[1] == self.cfg.shape_init_mesh_front[1]:
                raise ValueError(
                    "shape_init_mesh_up and shape_init_mesh_front must be orthogonal."
                )
            z_, x_ = (
                dir2vec[self.cfg.shape_init_mesh_up],
                dir2vec[self.cfg.shape_init_mesh_front],
            )
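            # y_ = z x x completes a right-handed frame; stacking (x_, y_, z_)
            # as columns gives the rotation std2mesh that takes the canonical
            # frame (front = +x, up = +z) onto the mesh's own axes, and its
            # inverse rotates mesh vertices into the canonical frame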
            y_ = np.cross(z_, x_)
            std2mesh = np.stack([x_, y_, z_], axis=0).T
            mesh2std = np.linalg.inv(std2mesh)

            # scaling
            scale = np.abs(mesh.vertices).max()
            mesh.vertices = mesh.vertices / scale * self.cfg.shape_init_params
            mesh.vertices = np.dot(mesh2std, mesh.vertices.T).T

            from pysdf import SDF

            sdf = SDF(mesh.vertices, mesh.faces)

            def func(points_rand: Float[Tensor, "N 3"]) -> Float[Tensor, "N 1"]:
                # add a negative sign here, as in pysdf the inside of the shape
                # has positive signed distance
                return torch.from_numpy(-sdf(points_rand.cpu().numpy())).to(
                    points_rand
                )[..., None]

            get_gt_sdf = func

        else:
            raise ValueError(
                f"Unknown shape initialization type: {self.cfg.shape_init}"
            )

        # Initialize SDF to a given shape when no weights are provided or force_shape_init is True
        optim = torch.optim.Adam(self.parameters(), lr=1e-3)
        from tqdm import tqdm

        for _ in tqdm(
            range(1000),
            desc=f"Initializing SDF to a(n) {self.cfg.shape_init}:",
            disable=get_rank() != 0,
        ):
            points_rand = (
                torch.rand((10000, 3), dtype=torch.float32).to(self.device) * 2.0 - 1.0
            )
            sdf_gt = get_gt_sdf(points_rand)
            sdf_pred = self.forward_sdf(points_rand)
            loss = F.mse_loss(sdf_pred, sdf_gt)
            optim.zero_grad()
            loss.backward()
            optim.step()

        # explicit broadcast to ensure param consistency across ranks
        for param in self.parameters():
            broadcast(param, src=0)
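
    # get_shifted_sdf adds an analytic bias (a sphere or ellipsoid distance, or
    # a constant) to the raw network output, so the zero level set starts out
    # near a simple primitive even without running initialize_shape.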
    def get_shifted_sdf(
        self, points: Float[Tensor, "*N Di"], sdf: Float[Tensor, "*N 1"]
    ) -> Float[Tensor, "*N 1"]:
        sdf_bias: Union[float, Float[Tensor, "*N 1"]]
        if self.cfg.sdf_bias == "ellipsoid":
            assert (
                isinstance(self.cfg.sdf_bias_params, Sized)
                and len(self.cfg.sdf_bias_params) == 3
            )
            size = torch.as_tensor(self.cfg.sdf_bias_params).to(points)
            sdf_bias = ((points / size) ** 2).sum(
                dim=-1, keepdim=True
            ).sqrt() - 1.0  # pseudo signed distance of an ellipsoid
        elif self.cfg.sdf_bias == "sphere":
            assert isinstance(self.cfg.sdf_bias_params, float)
            radius = self.cfg.sdf_bias_params
            sdf_bias = (points**2).sum(dim=-1, keepdim=True).sqrt() - radius
        elif isinstance(self.cfg.sdf_bias, float):
            sdf_bias = self.cfg.sdf_bias
        else:
            raise ValueError(f"Unknown sdf bias {self.cfg.sdf_bias}")
        return sdf + sdf_bias

    def forward(
        self, points: Float[Tensor, "*N Di"], output_normal: bool = False
    ) -> Dict[str, Float[Tensor, "..."]]:
        grad_enabled = torch.is_grad_enabled()

        if output_normal and self.cfg.normal_type == "analytic":
            torch.set_grad_enabled(True)
            points.requires_grad_(True)

        points_unscaled = points  # points in the original scale
        points = contract_to_unisphere(
            points, self.bbox, self.unbounded
        )  # points normalized to (0, 1)

        enc = self.encoding(points.view(-1, self.cfg.n_input_dims))
        sdf = self.sdf_network(enc).view(*points.shape[:-1], 1)
        sdf = self.get_shifted_sdf(points_unscaled, sdf)
        output = {"sdf": sdf}

        if self.cfg.n_feature_dims > 0:
            features = self.feature_network(enc).view(
                *points.shape[:-1], self.cfg.n_feature_dims
            )
            output.update({"features": features})

        if output_normal:
            if (
                self.cfg.normal_type == "finite_difference"
                or self.cfg.normal_type == "finite_difference_laplacian"
            ):
                assert self.finite_difference_normal_eps is not None
                eps: float = self.finite_difference_normal_eps
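                # estimate the SDF gradient numerically: the "laplacian" variant
                # uses 6-tap central differences, d/dx ~ (f(x + eps) - f(x - eps)) / (2 * eps),
                # while the plain variant uses 3-tap forward differences,
                # d/dx ~ (f(x + eps) - f(x)) / eps, costing half as many extra queries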
                if self.cfg.normal_type == "finite_difference_laplacian":
                    offsets: Float[Tensor, "6 3"] = torch.as_tensor(
                        [
                            [eps, 0.0, 0.0],
                            [-eps, 0.0, 0.0],
                            [0.0, eps, 0.0],
                            [0.0, -eps, 0.0],
                            [0.0, 0.0, eps],
                            [0.0, 0.0, -eps],
                        ]
                    ).to(points_unscaled)
                    points_offset: Float[Tensor, "... 6 3"] = (
                        points_unscaled[..., None, :] + offsets
                    ).clamp(-self.cfg.radius, self.cfg.radius)
                    sdf_offset: Float[Tensor, "... 6 1"] = self.forward_sdf(
                        points_offset
                    )
                    sdf_grad = (
                        0.5
                        * (sdf_offset[..., 0::2, 0] - sdf_offset[..., 1::2, 0])
                        / eps
                    )
                else:
                    offsets: Float[Tensor, "3 3"] = torch.as_tensor(
                        [[eps, 0.0, 0.0], [0.0, eps, 0.0], [0.0, 0.0, eps]]
                    ).to(points_unscaled)
                    points_offset: Float[Tensor, "... 3 3"] = (
                        points_unscaled[..., None, :] + offsets
                    ).clamp(-self.cfg.radius, self.cfg.radius)
                    sdf_offset: Float[Tensor, "... 3 1"] = self.forward_sdf(
                        points_offset
                    )
                    sdf_grad = (sdf_offset[..., 0::1, 0] - sdf) / eps
                normal = F.normalize(sdf_grad, dim=-1)
            elif self.cfg.normal_type == "pred":
                normal = self.normal_network(enc).view(*points.shape[:-1], 3)
                normal = F.normalize(normal, dim=-1)
                sdf_grad = normal
            elif self.cfg.normal_type == "analytic":
                sdf_grad = -torch.autograd.grad(
                    sdf,
                    points_unscaled,
                    grad_outputs=torch.ones_like(sdf),
                    create_graph=True,
                )[0]
                normal = F.normalize(sdf_grad, dim=-1)
                if not grad_enabled:
                    sdf_grad = sdf_grad.detach()
                    normal = normal.detach()
            else:
                raise AttributeError(f"Unknown normal type {self.cfg.normal_type}")
            output.update(
                {"normal": normal, "shading_normal": normal, "sdf_grad": sdf_grad}
            )
        return output

    def forward_sdf(self, points: Float[Tensor, "*N Di"]) -> Float[Tensor, "*N 1"]:
        points_unscaled = points
        points = contract_to_unisphere(points_unscaled, self.bbox, self.unbounded)

        sdf = self.sdf_network(
            self.encoding(points.reshape(-1, self.cfg.n_input_dims))
        ).reshape(*points.shape[:-1], 1)
        sdf = self.get_shifted_sdf(points_unscaled, sdf)
        return sdf

    def forward_field(
        self, points: Float[Tensor, "*N Di"]
    ) -> Tuple[Float[Tensor, "*N 1"], Optional[Float[Tensor, "*N 3"]]]:
        points_unscaled = points
        points = contract_to_unisphere(points_unscaled, self.bbox, self.unbounded)
        enc = self.encoding(points.reshape(-1, self.cfg.n_input_dims))
        sdf = self.sdf_network(enc).reshape(*points.shape[:-1], 1)
        sdf = self.get_shifted_sdf(points_unscaled, sdf)
        deformation: Optional[Float[Tensor, "*N 3"]] = None
        if self.cfg.isosurface_deformable_grid:
            deformation = self.deformation_network(enc).reshape(*points.shape[:-1], 3)
        return sdf, deformation

    def forward_level(
        self, field: Float[Tensor, "*N 1"], threshold: float
    ) -> Float[Tensor, "*N 1"]:
        return field - threshold

    def export(self, points: Float[Tensor, "*N Di"], **kwargs) -> Dict[str, Any]:
        out: Dict[str, Any] = {}
        if self.cfg.n_feature_dims == 0:
            return out
        points_unscaled = points
        points = contract_to_unisphere(points_unscaled, self.bbox, self.unbounded)
        enc = self.encoding(points.reshape(-1, self.cfg.n_input_dims))
        features = self.feature_network(enc).view(
            *points.shape[:-1], self.cfg.n_feature_dims
        )
        out.update(
            {
                "features": features,
            }
        )
        return out

    def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False):
        if (
            self.cfg.normal_type == "finite_difference"
            or self.cfg.normal_type == "finite_difference_laplacian"
        ):
            if isinstance(self.cfg.finite_difference_normal_eps, float):
                self.finite_difference_normal_eps = (
                    self.cfg.finite_difference_normal_eps
                )
            elif self.cfg.finite_difference_normal_eps == "progressive":
                # progressive finite difference eps from Neuralangelo
                # https://arxiv.org/abs/2306.03092
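                # eps is matched to the cell size of the hash grid's currently
                # active resolution, so normals are estimated at the finest
                # scale the encoding can represent at this step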
                hg_conf: Any = self.cfg.pos_encoding_config
                assert (
                    hg_conf.otype == "ProgressiveBandHashGrid"
                ), "finite_difference_normal_eps=progressive only works with ProgressiveBandHashGrid"
                current_level = min(
                    hg_conf.start_level
                    + max(global_step - hg_conf.start_step, 0) // hg_conf.update_steps,
                    hg_conf.n_levels,
                )
                grid_res = hg_conf.base_resolution * hg_conf.per_level_scale ** (
                    current_level - 1
                )
                grid_size = 2 * self.cfg.radius / grid_res
                if grid_size != self.finite_difference_normal_eps:
                    threestudio.info(
                        f"Update finite_difference_normal_eps to {grid_size}"
                    )
                    self.finite_difference_normal_eps = grid_size
            else:
                raise ValueError(
                    f"Unknown finite_difference_normal_eps={self.cfg.finite_difference_normal_eps}"
                )
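

# A minimal usage sketch (hypothetical YAML, not part of this file): the class
# above is selected through threestudio's registry, and the fields mirror the
# Config dataclass:
#
#   system:
#     geometry_type: implicit-sdf
#     geometry:
#       radius: 1.0
#       normal_type: finite_difference
#       shape_init: sphere
#       shape_init_params: 0.5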