mirror of
https://github.com/deepseek-ai/DreamCraft3D.git
synced 2025-02-23 14:28:55 -05:00
64 lines
2.2 KiB
Python
64 lines
2.2 KiB
Python
|
import random
|
||
|
from dataclasses import dataclass, field
|
||
|
|
||
|
import torch
|
||
|
import torch.nn as nn
|
||
|
import torch.nn.functional as F
|
||
|
|
||
|
import threestudio
|
||
|
from threestudio.models.materials.base import BaseMaterial
|
||
|
from threestudio.models.networks import get_encoding, get_mlp
|
||
|
from threestudio.utils.ops import dot, get_activation
|
||
|
from threestudio.utils.typing import *
|
||
|
|
||
|
|
||
|
@threestudio.register("no-material")
class NoMaterial(BaseMaterial):
    """Material that turns per-point features directly into colors.

    The mode is chosen in ``configure``:

    * pass-through: when no MLP is configured, the incoming features are
      used as the pre-activation color, so their last dimension must equal
      ``cfg.n_output_dims``;
    * MLP: when both ``cfg.input_feature_dims`` and
      ``cfg.mlp_network_config`` are set, an MLP built by ``get_mlp`` maps
      the features to ``cfg.n_output_dims`` channels.

    In both modes the activation named by ``cfg.color_activation`` is
    applied to the result.
    """

    @dataclass
    class Config(BaseMaterial.Config):
        # Number of output color channels produced by ``forward``.
        n_output_dims: int = 3
        # Name handed to ``get_activation`` for the final color activation.
        color_activation: str = "sigmoid"
        # Last-dimension size of the input features; together with
        # ``mlp_network_config`` it enables the MLP mode.
        input_feature_dims: Optional[int] = None
        # Config forwarded to ``get_mlp``; ``None`` disables the MLP mode.
        mlp_network_config: Optional[dict] = None
        # Copied to ``self.requires_normal`` in ``configure``; presumably
        # read by the caller/renderer to decide whether normals are needed
        # (NOTE(review): confirm against the renderer).
        requires_normal: bool = False

    cfg: Config

    def configure(self) -> None:
        """Build the optional MLP and expose ``requires_normal``.

        Sets ``self.use_network`` to True only when both
        ``cfg.input_feature_dims`` and ``cfg.mlp_network_config`` are given.
        In that case ``self.network`` is created. In pass-through mode
        ``self.network`` is never assigned.
        """
        self.use_network = False
        if (
            self.cfg.input_feature_dims is not None
            and self.cfg.mlp_network_config is not None
        ):
            self.network = get_mlp(
                self.cfg.input_feature_dims,
                self.cfg.n_output_dims,
                self.cfg.mlp_network_config,
            )
            self.use_network = True
        self.requires_normal = self.cfg.requires_normal

    def forward(
        self, features: Float[Tensor, "B ... Nf"], **kwargs
    ) -> Float[Tensor, "B ... Nc"]:
        """Map features to activated colors.

        Args:
            features: Tensor whose last dimension holds the per-point
                features; all leading dimensions are preserved.
            **kwargs: Accepted for interface compatibility; unused here.

        Returns:
            Tensor with the same leading shape as ``features`` and a last
            dimension of ``cfg.n_output_dims``, after applying
            ``cfg.color_activation``.

        Raises:
            AssertionError: in pass-through mode, if the last dimension of
                ``features`` differs from ``cfg.n_output_dims``.
        """
        if self.use_network:
            # Flatten all leading dims into one batch axis for the MLP, then
            # restore them. ``reshape`` (rather than ``view``) also accepts
            # non-contiguous inputs such as features sliced from a larger
            # tensor; for contiguous inputs it returns the same view.
            flat = features.reshape(-1, features.shape[-1])
            raw = self.network(flat).reshape(
                *features.shape[:-1], self.cfg.n_output_dims
            )
        else:
            # Without an MLP the features themselves are the color logits.
            assert features.shape[-1] == self.cfg.n_output_dims, (
                f"Expected {self.cfg.n_output_dims} input feature dims when no "
                f"MLP is configured, got {features.shape[-1]}."
            )
            raw = features
        return get_activation(self.cfg.color_activation)(raw)

    def export(self, features: Float[Tensor, "*N Nf"], **kwargs) -> Dict[str, Any]:
        """Compute exportable material attributes from ``features``.

        Colors come from ``forward``, are clamped to [0, 1], and only the
        first three channels are kept as RGB.

        Returns:
            ``{"albedo": rgb}`` where ``rgb`` has a last dimension of 3.

        Raises:
            AssertionError: if fewer than 3 color channels are produced.
        """
        color = self(features, **kwargs).clamp(0, 1)
        assert color.shape[-1] >= 3, "Output color must have at least 3 channels"
        if color.shape[-1] > 3:
            # Extra channels are dropped, not merged; warn so it is visible.
            threestudio.warn(
                "Output color has >3 channels, treating the first 3 as RGB"
            )
        return {"albedo": color[..., :3]}
|