mirror of
https://github.com/deepseek-ai/DreamCraft3D.git
synced 2025-02-23 14:28:55 -05:00
369 lines
14 KiB
Python
369 lines
14 KiB
Python
|
import os
|
||
|
from dataclasses import dataclass, field
|
||
|
|
||
|
import numpy as np
|
||
|
import torch
|
||
|
import torch.nn as nn
|
||
|
import torch.nn.functional as F
|
||
|
|
||
|
import threestudio
|
||
|
from threestudio.models.geometry.base import (
|
||
|
BaseExplicitGeometry,
|
||
|
BaseGeometry,
|
||
|
contract_to_unisphere,
|
||
|
)
|
||
|
from threestudio.models.geometry.implicit_sdf import ImplicitSDF
|
||
|
from threestudio.models.geometry.implicit_volume import ImplicitVolume
|
||
|
from threestudio.models.isosurface import MarchingTetrahedraHelper
|
||
|
from threestudio.models.mesh import Mesh
|
||
|
from threestudio.models.networks import get_encoding, get_mlp
|
||
|
from threestudio.utils.misc import broadcast
|
||
|
from threestudio.utils.ops import scale_tensor
|
||
|
from threestudio.utils.typing import *
|
||
|
|
||
|
|
||
|
@threestudio.register("tetrahedra-sdf-grid")
|
||
|
class TetrahedraSDFGrid(BaseExplicitGeometry):
|
||
|
@dataclass
|
||
|
class Config(BaseExplicitGeometry.Config):
|
||
|
isosurface_resolution: int = 128
|
||
|
isosurface_deformable_grid: bool = True
|
||
|
isosurface_remove_outliers: bool = False
|
||
|
isosurface_outlier_n_faces_threshold: Union[int, float] = 0.01
|
||
|
|
||
|
n_input_dims: int = 3
|
||
|
n_feature_dims: int = 3
|
||
|
pos_encoding_config: dict = field(
|
||
|
default_factory=lambda: {
|
||
|
"otype": "HashGrid",
|
||
|
"n_levels": 16,
|
||
|
"n_features_per_level": 2,
|
||
|
"log2_hashmap_size": 19,
|
||
|
"base_resolution": 16,
|
||
|
"per_level_scale": 1.447269237440378,
|
||
|
}
|
||
|
)
|
||
|
mlp_network_config: dict = field(
|
||
|
default_factory=lambda: {
|
||
|
"otype": "VanillaMLP",
|
||
|
"activation": "ReLU",
|
||
|
"output_activation": "none",
|
||
|
"n_neurons": 64,
|
||
|
"n_hidden_layers": 1,
|
||
|
}
|
||
|
)
|
||
|
shape_init: Optional[str] = None
|
||
|
shape_init_params: Optional[Any] = None
|
||
|
shape_init_mesh_up: str = "+z"
|
||
|
shape_init_mesh_front: str = "+x"
|
||
|
force_shape_init: bool = False
|
||
|
geometry_only: bool = False
|
||
|
fix_geometry: bool = False
|
||
|
|
||
|
cfg: Config
|
||
|
|
||
|
def configure(self) -> None:
|
||
|
super().configure()
|
||
|
|
||
|
# this should be saved to state_dict, register as buffer
|
||
|
self.isosurface_bbox: Float[Tensor, "2 3"]
|
||
|
self.register_buffer("isosurface_bbox", self.bbox.clone())
|
||
|
|
||
|
self.isosurface_helper = MarchingTetrahedraHelper(
|
||
|
self.cfg.isosurface_resolution,
|
||
|
f"load/tets/{self.cfg.isosurface_resolution}_tets.npz",
|
||
|
)
|
||
|
|
||
|
self.sdf: Float[Tensor, "Nv 1"]
|
||
|
self.deformation: Optional[Float[Tensor, "Nv 3"]]
|
||
|
|
||
|
if not self.cfg.fix_geometry:
|
||
|
self.register_parameter(
|
||
|
"sdf",
|
||
|
nn.Parameter(
|
||
|
torch.zeros(
|
||
|
(self.isosurface_helper.grid_vertices.shape[0], 1),
|
||
|
dtype=torch.float32,
|
||
|
)
|
||
|
),
|
||
|
)
|
||
|
if self.cfg.isosurface_deformable_grid:
|
||
|
self.register_parameter(
|
||
|
"deformation",
|
||
|
nn.Parameter(
|
||
|
torch.zeros_like(self.isosurface_helper.grid_vertices)
|
||
|
),
|
||
|
)
|
||
|
else:
|
||
|
self.deformation = None
|
||
|
else:
|
||
|
self.register_buffer(
|
||
|
"sdf",
|
||
|
torch.zeros(
|
||
|
(self.isosurface_helper.grid_vertices.shape[0], 1),
|
||
|
dtype=torch.float32,
|
||
|
),
|
||
|
)
|
||
|
if self.cfg.isosurface_deformable_grid:
|
||
|
self.register_buffer(
|
||
|
"deformation",
|
||
|
torch.zeros_like(self.isosurface_helper.grid_vertices),
|
||
|
)
|
||
|
else:
|
||
|
self.deformation = None
|
||
|
|
||
|
if not self.cfg.geometry_only:
|
||
|
self.encoding = get_encoding(
|
||
|
self.cfg.n_input_dims, self.cfg.pos_encoding_config
|
||
|
)
|
||
|
self.feature_network = get_mlp(
|
||
|
self.encoding.n_output_dims,
|
||
|
self.cfg.n_feature_dims,
|
||
|
self.cfg.mlp_network_config,
|
||
|
)
|
||
|
|
||
|
self.mesh: Optional[Mesh] = None
|
||
|
|
||
|
def initialize_shape(self) -> None:
|
||
|
if self.cfg.shape_init is None and not self.cfg.force_shape_init:
|
||
|
return
|
||
|
|
||
|
# do not initialize shape if weights are provided
|
||
|
if self.cfg.weights is not None and not self.cfg.force_shape_init:
|
||
|
return
|
||
|
|
||
|
get_gt_sdf: Callable[[Float[Tensor, "N 3"]], Float[Tensor, "N 1"]]
|
||
|
assert isinstance(self.cfg.shape_init, str)
|
||
|
if self.cfg.shape_init == "ellipsoid":
|
||
|
assert (
|
||
|
isinstance(self.cfg.shape_init_params, Sized)
|
||
|
and len(self.cfg.shape_init_params) == 3
|
||
|
)
|
||
|
size = torch.as_tensor(self.cfg.shape_init_params).to(self.device)
|
||
|
|
||
|
def func(points_rand: Float[Tensor, "N 3"]) -> Float[Tensor, "N 1"]:
|
||
|
return ((points_rand / size) ** 2).sum(
|
||
|
dim=-1, keepdim=True
|
||
|
).sqrt() - 1.0 # pseudo signed distance of an ellipsoid
|
||
|
|
||
|
get_gt_sdf = func
|
||
|
elif self.cfg.shape_init == "sphere":
|
||
|
assert isinstance(self.cfg.shape_init_params, float)
|
||
|
radius = self.cfg.shape_init_params
|
||
|
|
||
|
def func(points_rand: Float[Tensor, "N 3"]) -> Float[Tensor, "N 1"]:
|
||
|
return (points_rand**2).sum(dim=-1, keepdim=True).sqrt() - radius
|
||
|
|
||
|
get_gt_sdf = func
|
||
|
elif self.cfg.shape_init.startswith("mesh:"):
|
||
|
assert isinstance(self.cfg.shape_init_params, float)
|
||
|
mesh_path = self.cfg.shape_init[5:]
|
||
|
if not os.path.exists(mesh_path):
|
||
|
raise ValueError(f"Mesh file {mesh_path} does not exist.")
|
||
|
|
||
|
import trimesh
|
||
|
|
||
|
mesh = trimesh.load(mesh_path)
|
||
|
|
||
|
# move to center
|
||
|
centroid = mesh.vertices.mean(0)
|
||
|
mesh.vertices = mesh.vertices - centroid
|
||
|
|
||
|
# align to up-z and front-x
|
||
|
dirs = ["+x", "+y", "+z", "-x", "-y", "-z"]
|
||
|
dir2vec = {
|
||
|
"+x": np.array([1, 0, 0]),
|
||
|
"+y": np.array([0, 1, 0]),
|
||
|
"+z": np.array([0, 0, 1]),
|
||
|
"-x": np.array([-1, 0, 0]),
|
||
|
"-y": np.array([0, -1, 0]),
|
||
|
"-z": np.array([0, 0, -1]),
|
||
|
}
|
||
|
if (
|
||
|
self.cfg.shape_init_mesh_up not in dirs
|
||
|
or self.cfg.shape_init_mesh_front not in dirs
|
||
|
):
|
||
|
raise ValueError(
|
||
|
f"shape_init_mesh_up and shape_init_mesh_front must be one of {dirs}."
|
||
|
)
|
||
|
if self.cfg.shape_init_mesh_up[1] == self.cfg.shape_init_mesh_front[1]:
|
||
|
raise ValueError(
|
||
|
"shape_init_mesh_up and shape_init_mesh_front must be orthogonal."
|
||
|
)
|
||
|
z_, x_ = (
|
||
|
dir2vec[self.cfg.shape_init_mesh_up],
|
||
|
dir2vec[self.cfg.shape_init_mesh_front],
|
||
|
)
|
||
|
y_ = np.cross(z_, x_)
|
||
|
std2mesh = np.stack([x_, y_, z_], axis=0).T
|
||
|
mesh2std = np.linalg.inv(std2mesh)
|
||
|
|
||
|
# scaling
|
||
|
scale = np.abs(mesh.vertices).max()
|
||
|
mesh.vertices = mesh.vertices / scale * self.cfg.shape_init_params
|
||
|
mesh.vertices = np.dot(mesh2std, mesh.vertices.T).T
|
||
|
|
||
|
from pysdf import SDF
|
||
|
|
||
|
sdf = SDF(mesh.vertices, mesh.faces)
|
||
|
|
||
|
def func(points_rand: Float[Tensor, "N 3"]) -> Float[Tensor, "N 1"]:
|
||
|
# add a negative signed here
|
||
|
# as in pysdf the inside of the shape has positive signed distance
|
||
|
return torch.from_numpy(-sdf(points_rand.cpu().numpy())).to(
|
||
|
points_rand
|
||
|
)[..., None]
|
||
|
|
||
|
get_gt_sdf = func
|
||
|
|
||
|
else:
|
||
|
raise ValueError(
|
||
|
f"Unknown shape initialization type: {self.cfg.shape_init}"
|
||
|
)
|
||
|
|
||
|
sdf_gt = get_gt_sdf(
|
||
|
scale_tensor(
|
||
|
self.isosurface_helper.grid_vertices,
|
||
|
self.isosurface_helper.points_range,
|
||
|
self.isosurface_bbox,
|
||
|
)
|
||
|
)
|
||
|
self.sdf.data = sdf_gt
|
||
|
|
||
|
# explicit broadcast to ensure param consistency across ranks
|
||
|
for param in self.parameters():
|
||
|
broadcast(param, src=0)
|
||
|
|
||
|
def isosurface(self) -> Mesh:
|
||
|
# return cached mesh if fix_geometry is True to save computation
|
||
|
if self.cfg.fix_geometry and self.mesh is not None:
|
||
|
return self.mesh
|
||
|
mesh = self.isosurface_helper(self.sdf, self.deformation)
|
||
|
mesh.v_pos = scale_tensor(
|
||
|
mesh.v_pos, self.isosurface_helper.points_range, self.isosurface_bbox
|
||
|
)
|
||
|
if self.cfg.isosurface_remove_outliers:
|
||
|
mesh = mesh.remove_outlier(self.cfg.isosurface_outlier_n_faces_threshold)
|
||
|
self.mesh = mesh
|
||
|
return mesh
|
||
|
|
||
|
def forward(
|
||
|
self, points: Float[Tensor, "*N Di"], output_normal: bool = False
|
||
|
) -> Dict[str, Float[Tensor, "..."]]:
|
||
|
if self.cfg.geometry_only:
|
||
|
return {}
|
||
|
assert (
|
||
|
output_normal == False
|
||
|
), f"Normal output is not supported for {self.__class__.__name__}"
|
||
|
points_unscaled = points # points in the original scale
|
||
|
points = contract_to_unisphere(points, self.bbox) # points normalized to (0, 1)
|
||
|
enc = self.encoding(points.view(-1, self.cfg.n_input_dims))
|
||
|
features = self.feature_network(enc).view(
|
||
|
*points.shape[:-1], self.cfg.n_feature_dims
|
||
|
)
|
||
|
return {"features": features}
|
||
|
|
||
|
@staticmethod
|
||
|
@torch.no_grad()
|
||
|
def create_from(
|
||
|
other: BaseGeometry,
|
||
|
cfg: Optional[Union[dict, DictConfig]] = None,
|
||
|
copy_net: bool = True,
|
||
|
**kwargs,
|
||
|
) -> "TetrahedraSDFGrid":
|
||
|
if isinstance(other, TetrahedraSDFGrid):
|
||
|
instance = TetrahedraSDFGrid(cfg, **kwargs)
|
||
|
assert instance.cfg.isosurface_resolution == other.cfg.isosurface_resolution
|
||
|
instance.isosurface_bbox = other.isosurface_bbox.clone()
|
||
|
instance.sdf.data = other.sdf.data.clone()
|
||
|
if (
|
||
|
instance.cfg.isosurface_deformable_grid
|
||
|
and other.cfg.isosurface_deformable_grid
|
||
|
):
|
||
|
assert (
|
||
|
instance.deformation is not None and other.deformation is not None
|
||
|
)
|
||
|
instance.deformation.data = other.deformation.data.clone()
|
||
|
if (
|
||
|
not instance.cfg.geometry_only
|
||
|
and not other.cfg.geometry_only
|
||
|
and copy_net
|
||
|
):
|
||
|
instance.encoding.load_state_dict(other.encoding.state_dict())
|
||
|
instance.feature_network.load_state_dict(
|
||
|
other.feature_network.state_dict()
|
||
|
)
|
||
|
return instance
|
||
|
elif isinstance(other, ImplicitVolume):
|
||
|
instance = TetrahedraSDFGrid(cfg, **kwargs)
|
||
|
if other.cfg.isosurface_method != "mt":
|
||
|
other.cfg.isosurface_method = "mt"
|
||
|
threestudio.warn(
|
||
|
f"Override isosurface_method of the source geometry to 'mt'"
|
||
|
)
|
||
|
if other.cfg.isosurface_resolution != instance.cfg.isosurface_resolution:
|
||
|
other.cfg.isosurface_resolution = instance.cfg.isosurface_resolution
|
||
|
threestudio.warn(
|
||
|
f"Override isosurface_resolution of the source geometry to {instance.cfg.isosurface_resolution}"
|
||
|
)
|
||
|
mesh = other.isosurface()
|
||
|
instance.isosurface_bbox = mesh.extras["bbox"]
|
||
|
instance.sdf.data = (
|
||
|
mesh.extras["grid_level"].to(instance.sdf.data).clamp(-1, 1)
|
||
|
)
|
||
|
if not instance.cfg.geometry_only and copy_net:
|
||
|
instance.encoding.load_state_dict(other.encoding.state_dict())
|
||
|
instance.feature_network.load_state_dict(
|
||
|
other.feature_network.state_dict()
|
||
|
)
|
||
|
return instance
|
||
|
elif isinstance(other, ImplicitSDF):
|
||
|
instance = TetrahedraSDFGrid(cfg, **kwargs)
|
||
|
if other.cfg.isosurface_method != "mt":
|
||
|
other.cfg.isosurface_method = "mt"
|
||
|
threestudio.warn(
|
||
|
f"Override isosurface_method of the source geometry to 'mt'"
|
||
|
)
|
||
|
if other.cfg.isosurface_resolution != instance.cfg.isosurface_resolution:
|
||
|
other.cfg.isosurface_resolution = instance.cfg.isosurface_resolution
|
||
|
threestudio.warn(
|
||
|
f"Override isosurface_resolution of the source geometry to {instance.cfg.isosurface_resolution}"
|
||
|
)
|
||
|
mesh = other.isosurface()
|
||
|
instance.isosurface_bbox = mesh.extras["bbox"]
|
||
|
instance.sdf.data = mesh.extras["grid_level"].to(instance.sdf.data)
|
||
|
if (
|
||
|
instance.cfg.isosurface_deformable_grid
|
||
|
and other.cfg.isosurface_deformable_grid
|
||
|
):
|
||
|
assert instance.deformation is not None
|
||
|
instance.deformation.data = mesh.extras["grid_deformation"].to(
|
||
|
instance.deformation.data
|
||
|
)
|
||
|
if not instance.cfg.geometry_only and copy_net:
|
||
|
instance.encoding.load_state_dict(other.encoding.state_dict())
|
||
|
instance.feature_network.load_state_dict(
|
||
|
other.feature_network.state_dict()
|
||
|
)
|
||
|
return instance
|
||
|
else:
|
||
|
raise TypeError(
|
||
|
f"Cannot create {TetrahedraSDFGrid.__name__} from {other.__class__.__name__}"
|
||
|
)
|
||
|
|
||
|
def export(self, points: Float[Tensor, "*N Di"], **kwargs) -> Dict[str, Any]:
|
||
|
out: Dict[str, Any] = {}
|
||
|
if self.cfg.geometry_only or self.cfg.n_feature_dims == 0:
|
||
|
return out
|
||
|
points_unscaled = points
|
||
|
points = contract_to_unisphere(points_unscaled, self.bbox)
|
||
|
enc = self.encoding(points.reshape(-1, self.cfg.n_input_dims))
|
||
|
features = self.feature_network(enc).view(
|
||
|
*points.shape[:-1], self.cfg.n_feature_dims
|
||
|
)
|
||
|
out.update(
|
||
|
{
|
||
|
"features": features,
|
||
|
}
|
||
|
)
|
||
|
return out
|