DreamCraft3D/threestudio/models/materials/hybrid_rgb_latent_material.py

37 lines
1.1 KiB
Python
Raw Normal View History

2023-12-12 11:17:53 -05:00
import random
from dataclasses import dataclass, field
import torch
import torch.nn as nn
import torch.nn.functional as F
import threestudio
from threestudio.models.materials.base import BaseMaterial
from threestudio.models.networks import get_encoding, get_mlp
from threestudio.utils.ops import dot, get_activation
from threestudio.utils.typing import *
@threestudio.register("hybrid-rgb-latent-material")
class HybridRGBLatentMaterial(BaseMaterial):
@dataclass
class Config(BaseMaterial.Config):
n_output_dims: int = 3
color_activation: str = "sigmoid"
requires_normal: bool = True
cfg: Config
def configure(self) -> None:
self.requires_normal = self.cfg.requires_normal
def forward(
self, features: Float[Tensor, "B ... Nf"], **kwargs
) -> Float[Tensor, "B ... Nc"]:
assert (
features.shape[-1] == self.cfg.n_output_dims
), f"Expected {self.cfg.n_output_dims} output dims, only got {features.shape[-1]} dims input."
color = features
color[..., :3] = get_activation(self.cfg.color_activation)(color[..., :3])
return color