# mirror of https://github.com/deepseek-ai/DreamCraft3D.git (synced 2025-02-23 14:28:55 -05:00)
import functools

import torch
import torch.nn as nn


def count_params(model):
    total_params = sum(p.numel() for p in model.parameters())
    return total_params


class ActNorm(nn.Module):
    def __init__(
        self, num_features, logdet=False, affine=True, allow_reverse_init=False
    ):
        assert affine
        super().__init__()
        self.logdet = logdet
        self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
        self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
        self.allow_reverse_init = allow_reverse_init

        self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8))

    def initialize(self, input):
        with torch.no_grad():
            flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
            mean = (
                flatten.mean(1)
                .unsqueeze(1)
                .unsqueeze(2)
                .unsqueeze(3)
                .permute(1, 0, 2, 3)
            )
            std = (
                flatten.std(1)
                .unsqueeze(1)
                .unsqueeze(2)
                .unsqueeze(3)
                .permute(1, 0, 2, 3)
            )

            self.loc.data.copy_(-mean)
            self.scale.data.copy_(1 / (std + 1e-6))

    def forward(self, input, reverse=False):
        if reverse:
            return self.reverse(input)
        if len(input.shape) == 2:
            input = input[:, :, None, None]
            squeeze = True
        else:
            squeeze = False

        _, _, height, width = input.shape

        if self.training and self.initialized.item() == 0:
            self.initialize(input)
            self.initialized.fill_(1)

        h = self.scale * (input + self.loc)

        if squeeze:
            h = h.squeeze(-1).squeeze(-1)

        if self.logdet:
            log_abs = torch.log(torch.abs(self.scale))
            logdet = height * width * torch.sum(log_abs)
            logdet = logdet * torch.ones(input.shape[0]).to(input)
            return h, logdet

        return h

    def reverse(self, output):
        if self.training and self.initialized.item() == 0:
            if not self.allow_reverse_init:
                raise RuntimeError(
                    "Initializing ActNorm in reverse direction is "
                    "disabled by default. Use allow_reverse_init=True to enable."
                )
            else:
                self.initialize(output)
                self.initialized.fill_(1)

        if len(output.shape) == 2:
            output = output[:, :, None, None]
            squeeze = True
        else:
            squeeze = False

        h = output / self.scale - self.loc

        if squeeze:
            h = h.squeeze(-1).squeeze(-1)
        return h


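# Usage sketch (illustrative, not part of the original file): ActNorm performs
# a data-dependent initialization on its first forward pass in training mode,
# setting `loc`/`scale` from the batch statistics; after that, forward and
# reverse are exact inverses of each other.
def _demo_actnorm():
    actnorm = ActNorm(num_features=8, logdet=True)
    x = torch.randn(4, 8, 16, 16)
    h, logdet = actnorm(x)  # first call initializes loc/scale from x
    print(logdet.shape)  # per-sample H*W*sum(log|scale|), shape (4,)
    x_rec = actnorm(h, reverse=True)  # undo the normalization
    assert torch.allclose(x, x_rec, atol=1e-5)

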
class AbstractEncoder(nn.Module):
    def __init__(self):
        super().__init__()

    def encode(self, *args, **kwargs):
        raise NotImplementedError


class Labelator(AbstractEncoder):
    """Net2Net Interface for Class-Conditional Model"""

    def __init__(self, n_classes, quantize_interface=True):
        super().__init__()
        self.n_classes = n_classes
        self.quantize_interface = quantize_interface

    def encode(self, c):
        c = c[:, None]
        if self.quantize_interface:
            return c, None, [None, None, c.long()]
        return c


class SOSProvider(AbstractEncoder):
    # for unconditional training
    def __init__(self, sos_token, quantize_interface=True):
        super().__init__()
        self.sos_token = sos_token
        self.quantize_interface = quantize_interface

    def encode(self, x):
        # get batch size from data and replicate sos_token
        c = torch.ones(x.shape[0], 1) * self.sos_token
        c = c.long().to(x.device)
        if self.quantize_interface:
            return c, None, [None, None, c]
        return c


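# Usage sketch (illustrative, not part of the original file): both encoders
# return conditioning in the quantizer-style interface (code, loss
# placeholder, info tuple) named in the Labelator docstring above.
def _demo_encoders():
    labels = torch.tensor([3, 7])
    c, _, info = Labelator(n_classes=10).encode(labels)  # c: shape (2, 1); info[2] is c.long()
    x = torch.randn(2, 3, 32, 32)
    sos, _, _ = SOSProvider(sos_token=0).encode(x)  # (2, 1) tensor of start tokens
    print(c.shape, sos.shape)

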
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find("BatchNorm") != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)


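# Usage sketch (illustrative, not part of the original file): weights_init
# applies a DCGAN-style initialization and is meant to be used recursively
# via Module.apply, matching submodules by class name.
def _demo_weights_init():
    net = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8))
    net.apply(weights_init)  # visits every submodule; Conv/BatchNorm layers are re-initialized

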
class NLayerDiscriminator(nn.Module):
    """Defines a PatchGAN discriminator as in Pix2Pix
    --> see https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/models/networks.py
    """

    def __init__(self, input_nc=3, ndf=64, n_layers=3, use_actnorm=False):
        """Construct a PatchGAN discriminator
        Parameters:
            input_nc (int)  -- the number of channels in input images
            ndf (int)       -- the number of filters in the last conv layer
            n_layers (int)  -- the number of conv layers in the discriminator
            norm_layer      -- normalization layer
        """
        super(NLayerDiscriminator, self).__init__()
        if not use_actnorm:
            norm_layer = nn.BatchNorm2d
        else:
            norm_layer = ActNorm
        if (
            type(norm_layer) == functools.partial
        ):  # no need to use bias as BatchNorm2d has affine parameters
            use_bias = norm_layer.func != nn.BatchNorm2d
        else:
            use_bias = norm_layer != nn.BatchNorm2d

        kw = 4
        padw = 1
        sequence = [
            nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
            nn.LeakyReLU(0.2, True),
        ]
        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):  # gradually increase the number of filters
            nf_mult_prev = nf_mult
            nf_mult = min(2**n, 8)
            sequence += [
                nn.Conv2d(
                    ndf * nf_mult_prev,
                    ndf * nf_mult,
                    kernel_size=kw,
                    stride=2,
                    padding=padw,
                    bias=use_bias,
                ),
                norm_layer(ndf * nf_mult),
                nn.LeakyReLU(0.2, True),
            ]

        nf_mult_prev = nf_mult
        nf_mult = min(2**n_layers, 8)
        sequence += [
            nn.Conv2d(
                ndf * nf_mult_prev,
                ndf * nf_mult,
                kernel_size=kw,
                stride=1,
                padding=padw,
                bias=use_bias,
            ),
            norm_layer(ndf * nf_mult),
            nn.LeakyReLU(0.2, True),
        ]

        sequence += [
            nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)
        ]  # output 1 channel prediction map
        self.main = nn.Sequential(*sequence)

    def forward(self, input):
        """Standard forward."""
        return self.main(input)
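
# Usage sketch (illustrative, not part of the original file): the discriminator
# outputs a grid of per-patch logits rather than a single scalar; with the
# defaults (n_layers=3), a 256x256 input yields a 30x30 prediction map.
def _demo_discriminator():
    disc = NLayerDiscriminator(input_nc=3, ndf=64, n_layers=3, use_actnorm=False)
    disc.apply(weights_init)
    logits = disc(torch.randn(1, 3, 256, 256))
    print(logits.shape, count_params(disc))  # torch.Size([1, 1, 30, 30]), ~2.77M params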