# DreamCraft3D/threestudio/utils/GAN/loss.py
# 2023-12-15 17:44:44 +08:00

# 35 lines
# 1.1 KiB
# Python

import torch
import torch.nn.functional as F
def generator_loss(discriminator, inputs, reconstructions, cond=None):
    """Return the generator's adversarial loss for a batch of reconstructions.

    The loss is the negated mean of the discriminator's output on the
    reconstructions, so minimizing it pushes those outputs upward.

    Args:
        discriminator: Callable applied to the (optionally conditioned)
            reconstructions; its output is averaged with ``torch.mean``.
        inputs: Unused here; kept so the signature parallels
            ``discriminator_loss``.
        reconstructions: Tensor produced by the generator. It is NOT detached,
            so gradients flow back through it.
        cond: Optional conditioning tensor. When given, it is concatenated
            after ``reconstructions`` along ``dim=1`` before the
            discriminator is called.

    Returns:
        A scalar tensor equal to ``-mean(discriminator(...))``.
    """
    fake = reconstructions.contiguous()
    if cond is not None:
        # Conditional discriminator: append the conditioning along dim 1.
        fake = torch.cat((fake, cond), dim=1)
    logits_fake = discriminator(fake)
    return -torch.mean(logits_fake)
def hinge_d_loss(logits_real, logits_fake):
    """Return the hinge loss for a discriminator.

    Computes ``0.5 * (mean(relu(1 - logits_real)) + mean(relu(1 + logits_fake)))``.
    Real logits are penalized when below +1 and fake logits when above -1.

    Args:
        logits_real: Discriminator outputs for real samples.
        logits_fake: Discriminator outputs for generated samples.

    Returns:
        A scalar tensor holding the averaged hinge loss.
    """
    loss_real = torch.mean(F.relu(1.0 - logits_real))
    loss_fake = torch.mean(F.relu(1.0 + logits_fake))
    return 0.5 * (loss_real + loss_fake)
def discriminator_loss(discriminator, inputs, reconstructions, cond=None):
    """Return the discriminator's hinge loss on real inputs vs. reconstructions.

    Both ``inputs`` and ``reconstructions`` are detached before being passed
    to the discriminator, so this loss does not propagate gradients into
    whatever produced them.

    Args:
        discriminator: Callable producing logits for a batch.
        inputs: Tensor of real samples.
        reconstructions: Tensor of generated samples.
        cond: Optional conditioning tensor. When given, it is concatenated
            after each batch along ``dim=1``. Note that ``cond`` itself is
            not detached.

    Returns:
        A scalar tensor: ``hinge_d_loss(logits_real, logits_fake)``.
    """
    real = inputs.contiguous().detach()
    fake = reconstructions.contiguous().detach()
    if cond is not None:
        # Conditional discriminator: append the conditioning along dim 1.
        real = torch.cat((real, cond), dim=1)
        fake = torch.cat((fake, cond), dim=1)
    # Real batch is scored first, then the fake batch (same order as before).
    logits_real = discriminator(real)
    logits_fake = discriminator(fake)
    # hinge_d_loss already reduces with torch.mean; the trailing .mean() is a
    # no-op on that scalar and is kept only to preserve the original behavior.
    d_loss = hinge_d_loss(logits_real, logits_fake).mean()
    return d_loss