"""
Copyright (C) 2019 NVIDIA Corporation. Ting-Chun Wang, Ming-Yu Liu, Jun-Yan Zhu.
BSD License. All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
  list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice,
  this list of conditions and the following disclaimer in the documentation
  and/or other materials provided with the distribution.

THE AUTHOR DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE, INCLUDING ALL
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR ANY PARTICULAR PURPOSE.
IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY SPECIAL, INDIRECT OR CONSEQUENTIAL
DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS,
WHETHER IN AN ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING
OUT OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
"""
import functools

import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable


###############################################################################
# Functions
###############################################################################
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find("BatchNorm2d") != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

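# Illustrative note (added): `weights_init` is intended to be passed to
# `nn.Module.apply`, which calls it on every submodule; `define_G` below uses it as
# `netG.apply(weights_init)`.
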
def get_norm_layer(norm_type="instance"):
    if norm_type == "batch":
        norm_layer = functools.partial(nn.BatchNorm2d, affine=True)
    elif norm_type == "instance":
        norm_layer = functools.partial(nn.InstanceNorm2d, affine=False)
    else:
        raise NotImplementedError("normalization layer [%s] is not found" % norm_type)
    return norm_layer

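# Illustrative note (added): the returned `functools.partial` acts as a layer
# constructor, e.g. `get_norm_layer("instance")(64)` builds
# `nn.InstanceNorm2d(64, affine=False)`; the generator classes below use it this way.
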
def define_G(
    input_nc,
    output_nc,
    ngf,
    netG,
    n_downsample_global=3,
    n_blocks_global=9,
    n_local_enhancers=1,
    n_blocks_local=3,
    norm="instance",
    gpu_ids=[],
    last_op=nn.Tanh(),
):
    norm_layer = get_norm_layer(norm_type=norm)
    if netG == "global":
        netG = GlobalGenerator(
            input_nc,
            output_nc,
            ngf,
            n_downsample_global,
            n_blocks_global,
            norm_layer,
            last_op=last_op,
        )
    elif netG == "local":
        netG = LocalEnhancer(
            input_nc,
            output_nc,
            ngf,
            n_downsample_global,
            n_blocks_global,
            n_local_enhancers,
            n_blocks_local,
            norm_layer,
        )
    elif netG == "encoder":
        netG = Encoder(input_nc, output_nc, ngf, n_downsample_global, norm_layer)
    else:
        raise NotImplementedError("generator not implemented!")
    # print(netG)
    if len(gpu_ids) > 0:
        assert torch.cuda.is_available()
        netG.cuda(gpu_ids[0])
    netG.apply(weights_init)
    return netG


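# Illustrative usage sketch (added; assumes `GlobalGenerator`, referenced above, is
# defined elsewhere in this module):
#
#     netG = define_G(input_nc=3, output_nc=3, ngf=64, netG="global")
#     fake = netG(torch.randn(1, 3, 512, 512))
#
# With the default empty `gpu_ids` the network stays on CPU; `weights_init` is
# applied before the generator is returned.
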
def print_network(net):
    if isinstance(net, list):
        net = net[0]
    num_params = 0
    for param in net.parameters():
        num_params += param.numel()
    print(net)
    print("Total number of parameters: %d" % num_params)


##############################################################################
# Generator
##############################################################################
class LocalEnhancer(nn.Module):
    def __init__(
        self,
        input_nc,
        output_nc,
        ngf=32,
        n_downsample_global=3,
        n_blocks_global=9,
        n_local_enhancers=1,
        n_blocks_local=3,
        norm_layer=nn.BatchNorm2d,
        padding_type="reflect",
    ):
        super(LocalEnhancer, self).__init__()
        self.n_local_enhancers = n_local_enhancers

        ###### global generator model #####
        ngf_global = ngf * (2**n_local_enhancers)
        model_global = GlobalGenerator(
            input_nc,
            output_nc,
            ngf_global,
            n_downsample_global,
            n_blocks_global,
            norm_layer,
        ).model
        model_global = [
            model_global[i] for i in range(len(model_global) - 3)
        ]  # get rid of final convolution layers
        self.model = nn.Sequential(*model_global)

        ###### local enhancer layers #####
        for n in range(1, n_local_enhancers + 1):
            ### downsample
            ngf_global = ngf * (2 ** (n_local_enhancers - n))
            model_downsample = [
                nn.ReflectionPad2d(3),
                nn.Conv2d(input_nc, ngf_global, kernel_size=7, padding=0),
                norm_layer(ngf_global),
                nn.ReLU(True),
                nn.Conv2d(
                    ngf_global, ngf_global * 2, kernel_size=3, stride=2, padding=1
                ),
                norm_layer(ngf_global * 2),
                nn.ReLU(True),
            ]
            ### residual blocks
            model_upsample = []
            for i in range(n_blocks_local):
                model_upsample += [
                    ResnetBlock(
                        ngf_global * 2, padding_type=padding_type, norm_layer=norm_layer
                    )
                ]

            ### upsample
            model_upsample += [
                nn.ConvTranspose2d(
                    ngf_global * 2,
                    ngf_global,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    output_padding=1,
                ),
                norm_layer(ngf_global),
                nn.ReLU(True),
            ]

            ### final convolution
            if n == n_local_enhancers:
                model_upsample += [
                    nn.ReflectionPad2d(3),
                    nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
                    nn.Tanh(),
                ]

            setattr(self, "model" + str(n) + "_1", nn.Sequential(*model_downsample))
            setattr(self, "model" + str(n) + "_2", nn.Sequential(*model_upsample))

        self.downsample = nn.AvgPool2d(
            3, stride=2, padding=[1, 1], count_include_pad=False
        )

    def forward(self, input):
        ### create input pyramid
        input_downsampled = [input]
        for i in range(self.n_local_enhancers):
            input_downsampled.append(self.downsample(input_downsampled[-1]))

        ### output at coarsest level
        output_prev = self.model(input_downsampled[-1])
        ### build up one layer at a time
        for n_local_enhancers in range(1, self.n_local_enhancers + 1):
            model_downsample = getattr(self, "model" + str(n_local_enhancers) + "_1")
            model_upsample = getattr(self, "model" + str(n_local_enhancers) + "_2")
            input_i = input_downsampled[self.n_local_enhancers - n_local_enhancers]
            output_prev = model_upsample(model_downsample(input_i) + output_prev)
        return output_prev

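# Reading of the forward pass (comment added; assumes the standard pix2pixHD
# GlobalGenerator layout): for n_local_enhancers=1 the input pyramid is
# [x, avgpool(x)]; the trimmed GlobalGenerator runs on the half-resolution input,
# the enhancer's stride-2 branch brings the full-resolution input to the same
# feature size, the two are summed, and the upsampling branch (ending in Tanh)
# restores the original resolution.
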
class NormalNet(nn.Module):
    def __init__(
        self,
        name="normalnet",
        input_nc=3,
        output_nc=3,
        ngf=64,
        n_downsampling=4,
        n_blocks=9,
        norm_layer=nn.BatchNorm2d,
        padding_type="reflect",
        last_op=nn.Sigmoid(),
    ):
        assert n_blocks >= 0
        super(NormalNet, self).__init__()
        self.name = name
        activation = nn.ReLU(True)

        model = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0),
            nn.BatchNorm2d(ngf),
            activation,
        ]
        ### downsample
        for i in range(n_downsampling):
            mult = 2**i
            model += [
                nn.Conv2d(
                    ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1
                ),
                nn.BatchNorm2d(ngf * mult * 2),
                activation,
            ]

        ### resnet blocks
        mult = 2**n_downsampling
        for i in range(n_blocks):
            model += [
                ResnetBlock(
                    ngf * mult,
                    padding_type=padding_type,
                    activation=activation,
                    norm_layer=norm_layer,
                )
            ]

        ### upsample
        for i in range(n_downsampling):
            mult = 2 ** (n_downsampling - i)
            model += [
                nn.Upsample(scale_factor=2),
                nn.Conv2d(
                    ngf * mult, int(ngf * mult / 2), kernel_size=3, stride=1, padding=1
                ),
                nn.BatchNorm2d(int(ngf * mult / 2)),
                activation,
            ]
        model += [
            nn.ReflectionPad2d(3),
            nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
        ]
        if last_op is not None:
            model += [last_op]
        self.model = nn.Sequential(*model)

    def forward(self, in_x, label=None):
        return self.model(in_x)

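# Illustrative note (added): with the default `last_op=nn.Sigmoid()` the output lies
# in [0, 1], and because each stride-2 downsampling is paired with an
# `nn.Upsample(scale_factor=2)`, the output spatial size matches the input for sizes
# divisible by 2**n_downsampling. A runnable sketch is included at the end of the file.
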
# Define a resnet block
class ResnetBlock(nn.Module):
    def __init__(
        self, dim, padding_type, norm_layer, activation=nn.ReLU(True), use_dropout=False
    ):
        super(ResnetBlock, self).__init__()
        self.conv_block = self.build_conv_block(
            dim, padding_type, norm_layer, activation, use_dropout
        )

    def build_conv_block(self, dim, padding_type, norm_layer, activation, use_dropout):
        conv_block = []
        p = 0
        if padding_type == "reflect":
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == "replicate":
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == "zero":
            p = 1
        else:
            raise NotImplementedError("padding [%s] is not implemented" % padding_type)

        conv_block += [
            nn.Conv2d(dim, dim, kernel_size=3, padding=p),
            nn.BatchNorm2d(dim),
            activation,
        ]
        if use_dropout:
            conv_block += [nn.Dropout(0.5)]

        p = 0
        if padding_type == "reflect":
            conv_block += [nn.ReflectionPad2d(1)]
        elif padding_type == "replicate":
            conv_block += [nn.ReplicationPad2d(1)]
        elif padding_type == "zero":
            p = 1
        else:
            raise NotImplementedError("padding [%s] is not implemented" % padding_type)
        conv_block += [
            nn.Conv2d(dim, dim, kernel_size=3, padding=p),
            nn.BatchNorm2d(dim),
        ]

        return nn.Sequential(*conv_block)

    def forward(self, x):
        out = x + self.conv_block(x)
        return out

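# Note (added): `build_conv_block` above instantiates `nn.BatchNorm2d` directly, so
# the `norm_layer` argument is accepted but not actually used inside the block.
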
class Encoder(nn.Module):
    def __init__(
        self, input_nc, output_nc, ngf=32, n_downsampling=4, norm_layer=nn.BatchNorm2d
    ):
        super(Encoder, self).__init__()
        self.output_nc = output_nc

        model = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0),
            norm_layer(ngf),
            nn.ReLU(True),
        ]
        ### downsample
        for i in range(n_downsampling):
            mult = 2**i
            model += [
                nn.Conv2d(
                    ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1
                ),
                norm_layer(ngf * mult * 2),
                nn.ReLU(True),
            ]

        ### upsample
        for i in range(n_downsampling):
            mult = 2 ** (n_downsampling - i)
            model += [
                nn.ConvTranspose2d(
                    ngf * mult,
                    int(ngf * mult / 2),
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    output_padding=1,
                ),
                norm_layer(int(ngf * mult / 2)),
                nn.ReLU(True),
            ]

        model += [
            nn.ReflectionPad2d(3),
            nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
            nn.Tanh(),
        ]
        self.model = nn.Sequential(*model)

    def forward(self, input, inst):
        outputs = self.model(input)

        # instance-wise average pooling
        outputs_mean = outputs.clone()
        inst_list = np.unique(inst.cpu().numpy().astype(int))
        for i in inst_list:
            for b in range(input.size()[0]):
                indices = (inst[b : b + 1] == int(i)).nonzero()  # n x 4
                for j in range(self.output_nc):
                    output_ins = outputs[
                        indices[:, 0] + b,
                        indices[:, 1] + j,
                        indices[:, 2],
                        indices[:, 3],
                    ]
                    mean_feat = torch.mean(output_ins).expand_as(output_ins)
                    outputs_mean[
                        indices[:, 0] + b,
                        indices[:, 1] + j,
                        indices[:, 2],
                        indices[:, 3],
                    ] = mean_feat
        return outputs_mean
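

if __name__ == "__main__":
    # Minimal smoke test (illustrative addition, not part of the original file).
    # LocalEnhancer and the "global"/"local" branches of define_G are skipped here
    # because they depend on GlobalGenerator, which is defined elsewhere.
    with torch.no_grad():
        x = torch.randn(1, 3, 64, 64)

        block = ResnetBlock(8, padding_type="reflect", norm_layer=nn.BatchNorm2d).eval()
        y = block(torch.randn(1, 8, 32, 32))
        print("ResnetBlock output:", tuple(y.shape))  # (1, 8, 32, 32)

        normal_net = NormalNet(ngf=8, n_downsampling=2, n_blocks=1).eval()
        normals = normal_net(x)
        print("NormalNet output:", tuple(normals.shape))  # (1, 3, 64, 64), values in [0, 1]

        encoder = Encoder(input_nc=3, output_nc=3, ngf=8, n_downsampling=2).eval()
        inst = torch.zeros(1, 1, 64, 64, dtype=torch.long)  # one instance label covering the image
        feats = encoder(x, inst)
        print("Encoder output:", tuple(feats.shape))  # (1, 3, 64, 64)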