# Mirror of https://github.com/deepseek-ai/DreamCraft3D.git
# (synced 2025-02-23 06:18:56 -05:00; original file: 35 lines, 1.1 KiB, Python)
import torch
|
|
import torch.nn.functional as F
|
|
|
|
|
|
def generator_loss(discriminator, inputs, reconstructions, cond=None):
    """Compute the adversarial loss for the generator.

    The reconstructions are scored by ``discriminator``, optionally after
    ``cond`` is appended along dim 1. The loss is the negated mean of the
    resulting logits, so minimizing it raises the discriminator's score for
    generated samples.

    Args:
        discriminator: Callable that maps a tensor to logits.
        inputs: Not used here. It is accepted so the signature matches
            ``discriminator_loss``.
        reconstructions: Generated tensor to score. It is not detached, so
            gradients reach whatever produced it.
        cond: Optional conditioning tensor concatenated to the
            reconstructions along dim 1.

    Returns:
        A 0-d tensor equal to ``-mean(discriminator(x))``.
    """
    scored = reconstructions.contiguous()
    if cond is not None:
        # Conditional discriminator: append the conditioning along dim 1.
        # Presumably this is the channel axis; TODO confirm with callers.
        scored = torch.cat((scored, cond), dim=1)
    return -discriminator(scored).mean()
|
|
|
|
|
|
def hinge_d_loss(logits_real, logits_fake):
    """Compute the hinge loss for the discriminator.

    A real logit is penalized by how far it falls below +1, and a fake logit
    by how far it rises above -1. Each penalty is averaged over its tensor,
    and the loss is the mean of those two averages.

    Args:
        logits_real: Discriminator logits for real samples.
        logits_fake: Discriminator logits for generated samples.

    Returns:
        A 0-d tensor equal to
        ``0.5 * (mean(relu(1 - real)) + mean(relu(1 + fake)))``.
    """
    real_penalty = F.relu(1.0 - logits_real).mean()
    fake_penalty = F.relu(1.0 + logits_fake).mean()
    return 0.5 * (real_penalty + fake_penalty)
|
|
|
|
|
|
def discriminator_loss(discriminator, inputs, reconstructions, cond=None):
    """Compute the hinge loss for a discriminator update.

    Real ``inputs`` and generated ``reconstructions`` are scored by
    ``discriminator``. Each is detached first, and ``cond`` is appended
    along dim 1 when given. The two sets of logits are then combined with
    ``hinge_d_loss``.

    Args:
        discriminator: Callable that maps a tensor to logits.
        inputs: Real samples.
        reconstructions: Generated samples.
        cond: Optional conditioning tensor concatenated to both inputs
            along dim 1. It is not detached.

    Returns:
        The 0-d tensor produced by ``hinge_d_loss``.
    """

    def _logits(samples):
        # Detach so this loss does not backprop into whatever produced
        # the samples. Only the discriminator receives gradients.
        samples = samples.contiguous().detach()
        if cond is not None:
            samples = torch.cat((samples, cond), dim=1)
        return discriminator(samples)

    # Score the real samples before the fake ones. The order matters if
    # the discriminator keeps running state between calls.
    logits_real = _logits(inputs)
    logits_fake = _logits(reconstructions)
    # hinge_d_loss already returns a 0-d tensor, so .mean() leaves its
    # value unchanged. It is kept to match the original exactly.
    return hinge_d_loss(logits_real, logits_fake).mean()