Source code for esrgan.nn.criterions.adversarial
import torch
from torch.nn import functional as F
from torch.nn.modules.loss import _Loss
# Public names exported by ``from <this module> import *``.
__all__ = ["AdversarialLoss", "RelativisticAdversarialLoss"]
class AdversarialLoss(_Loss):
    """GAN Loss function.

    Args:
        mode: Specifies loss target: ``'generator'`` or ``'discriminator'``.
            ``'generator'``: maximize probability that fake data are drawn
            from the real data distribution (it is useful when training
            the generator),
            ``'discriminator'``: minimize probability that real and generated
            distributions are similar (useful when training the
            discriminator).

    Raises:
        NotImplementedError: If `mode` not ``'generator'``
            or ``'discriminator'``.
    """

    def __init__(self, mode: str = "discriminator") -> None:
        super().__init__()

        # ``forward`` is bound per instance, so calling the module dispatches
        # directly to the mode-specific implementation below.
        if mode == "generator":
            self.forward = self.forward_generator
        elif mode == "discriminator":
            self.forward = self.forward_discriminator
        else:
            raise NotImplementedError(
                f"Unsupported mode {mode!r}; "
                "expected 'generator' or 'discriminator'."
            )

    def forward_generator(
        self, fake_logits: torch.Tensor, real_logits: torch.Tensor
    ) -> torch.Tensor:
        """Forward pass (generator mode).

        Args:
            fake_logits: Predictions (logits) of discriminator for fake data.
            real_logits: Predictions (logits) of discriminator for real data.
                Unused here; kept so the signature matches
                ``forward_discriminator``.

        Returns:
            Loss, scalar: binary cross-entropy of ``fake_logits`` against a
            target of ones, averaged over all elements.
        """
        loss = F.binary_cross_entropy_with_logits(
            input=fake_logits, target=torch.ones_like(fake_logits)
        )
        return loss

    def forward_discriminator(
        self, fake_logits: torch.Tensor, real_logits: torch.Tensor
    ) -> torch.Tensor:
        """Forward pass (discriminator mode).

        Args:
            fake_logits: Predictions (logits) of discriminator for fake data
                (target label 0).
            real_logits: Predictions (logits) of discriminator for real data
                (target label 1).

        Returns:
            Loss, scalar: binary cross-entropy averaged over every element
            of both ``real_logits`` and ``fake_logits``.
        """
        loss_real = F.binary_cross_entropy_with_logits(
            real_logits, torch.ones_like(real_logits), reduction="sum"
        )
        loss_fake = F.binary_cross_entropy_with_logits(
            fake_logits, torch.zeros_like(fake_logits), reduction="sum"
        )

        # Normalize by the total element count rather than the batch size:
        # this gives a true mean even when the discriminator emits several
        # logits per sample, consistent with ``forward_generator``.
        num_elements = real_logits.numel() + fake_logits.numel()
        loss = (loss_real + loss_fake) / num_elements  # mean loss
        return loss
class RelativisticAdversarialLoss(_Loss):
    """Relativistic average GAN loss function.

    It has been proposed in `The relativistic discriminator: a key element
    missing from standard GAN`_.

    Args:
        mode: Specifies loss target: ``'generator'`` or ``'discriminator'``.
            ``'generator'``: maximize probability that fake data are more
            realistic than real data (it is useful when training the
            generator),
            ``'discriminator'``: maximize probability that real data are more
            realistic than fake data (useful when training the
            discriminator).

    Raises:
        NotImplementedError: If `mode` not ``'generator'``
            or ``'discriminator'``.

    .. _`The relativistic discriminator: a key element missing
        from standard GAN`: https://arxiv.org/pdf/1807.00734.pdf
    """

    def __init__(self, mode: str = "discriminator") -> None:
        super().__init__()

        # Target labels for the two relativistic comparisons:
        #   ``rf`` -- real logits relative to the mean fake logit,
        #   ``fr`` -- fake logits relative to the mean real logit.
        # Generator mode uses the discriminator's targets flipped.
        if mode == "generator":
            self.rf_labels, self.fr_labels = 0, 1
        elif mode == "discriminator":
            self.rf_labels, self.fr_labels = 1, 0
        else:
            raise NotImplementedError(
                f"Unsupported mode {mode!r}; "
                "expected 'generator' or 'discriminator'."
            )

    def forward(
        self, fake_logits: torch.Tensor, real_logits: torch.Tensor
    ) -> torch.Tensor:
        """Forward propagation method for the relativistic adversarial loss.

        Args:
            fake_logits: Raw (pre-sigmoid) discriminator outputs for
                generated samples.
            real_logits: Raw (pre-sigmoid) discriminator outputs for real
                (ground truth) samples.

        Returns:
            Loss, scalar: the average of the two mean binary cross-entropy
            terms (real-vs-average-fake and fake-vs-average-real).
        """
        # Real samples judged relative to the average fake sample.
        loss_rf = F.binary_cross_entropy_with_logits(
            input=(real_logits - fake_logits.mean()),
            target=torch.full_like(real_logits, self.rf_labels),
        )
        # Fake samples judged relative to the average real sample.
        loss_fr = F.binary_cross_entropy_with_logits(
            input=(fake_logits - real_logits.mean()),
            target=torch.full_like(fake_logits, self.fr_labels),
        )

        loss = (loss_fr + loss_rf) / 2
        return loss