mirror of
https://github.com/deepseek-ai/DreamCraft3D.git
synced 2025-02-23 14:28:55 -05:00
37 lines
1.1 KiB
Python
37 lines
1.1 KiB
Python
import random
|
|
from dataclasses import dataclass, field
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
import threestudio
|
|
from threestudio.models.materials.base import BaseMaterial
|
|
from threestudio.models.networks import get_encoding, get_mlp
|
|
from threestudio.utils.ops import dot, get_activation
|
|
from threestudio.utils.typing import *
|
|
|
|
|
|
@threestudio.register("hybrid-rgb-latent-material")
|
|
class HybridRGBLatentMaterial(BaseMaterial):
|
|
@dataclass
|
|
class Config(BaseMaterial.Config):
|
|
n_output_dims: int = 3
|
|
color_activation: str = "sigmoid"
|
|
requires_normal: bool = True
|
|
|
|
cfg: Config
|
|
|
|
def configure(self) -> None:
|
|
self.requires_normal = self.cfg.requires_normal
|
|
|
|
def forward(
|
|
self, features: Float[Tensor, "B ... Nf"], **kwargs
|
|
) -> Float[Tensor, "B ... Nc"]:
|
|
assert (
|
|
features.shape[-1] == self.cfg.n_output_dims
|
|
), f"Expected {self.cfg.n_output_dims} output dims, only got {features.shape[-1]} dims input."
|
|
color = features
|
|
color[..., :3] = get_activation(self.cfg.color_activation)(color[..., :3])
|
|
return color
|