Criterions
Adversarial Loss
class esrgan.criterions.adversarial.AdversarialLoss(mode: str = 'discriminator')
GAN loss function.
- Parameters
mode – Specifies the loss target: 'generator' or 'discriminator'.
'generator': maximize the probability that fake data is judged to be drawn from the real data distribution (useful when training the generator).
'discriminator': minimize the probability that real and generated data are judged to come from the same distribution.
- Raises
NotImplementedError – If mode is not 'generator' or 'discriminator'.
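The two modes can be illustrated with a small, dependency-free sketch. It assumes the common binary cross-entropy formulation of the GAN loss, which this page does not spell out, and assumes that a higher logit means "more likely real". The helper names `bce_with_logits` and `adversarial_loss` are hypothetical and are not part of `esrgan`.

```python
import math


def bce_with_logits(logit: float, label: float) -> float:
    """Binary cross-entropy on one raw logit, in a numerically stable form."""
    # softplus(x) = log(1 + exp(x)), rewritten to avoid overflow for large |x|.
    softplus = max(logit, 0.0) + math.log1p(math.exp(-abs(logit)))
    return softplus - label * logit


def adversarial_loss(fake_logits, real_logits=None, mode="discriminator"):
    """Hypothetical helper: mean BCE against the targets implied by `mode`."""
    if mode == "generator":
        # The generator wants its samples to be labelled as real (1).
        terms = [bce_with_logits(x, 1.0) for x in fake_logits]
    elif mode == "discriminator":
        # The discriminator wants real samples -> 1 and fake samples -> 0.
        terms = [bce_with_logits(x, 1.0) for x in real_logits]
        terms += [bce_with_logits(x, 0.0) for x in fake_logits]
    else:
        # Mirrors the documented behaviour for unsupported modes.
        raise NotImplementedError(f"Unknown mode: {mode!r}")
    return sum(terms) / len(terms)
```

With an undecided discriminator (all logits equal to 0), both modes in this sketch evaluate to log 2. The actual class may reduce or weight the terms differently.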
class esrgan.criterions.adversarial.RelativisticAdversarialLoss(mode: str = 'discriminator')
Relativistic average GAN loss function.
It was proposed in "The relativistic discriminator: a key element missing from standard GAN".
- Parameters
mode – Specifies the loss target: 'generator' or 'discriminator'.
'generator': maximize the probability that fake data is more realistic than real data (useful when training the generator).
'discriminator': maximize the probability that real data is more realistic than fake data (useful when training the discriminator).
- Raises
NotImplementedError – If mode is not 'generator' or 'discriminator'.
forward(fake_logits: torch.Tensor, real_logits: torch.Tensor) → torch.Tensor
Forward propagation method for the relativistic adversarial loss.
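- Parameters
fake_logits – Raw discriminator outputs (logits, not probabilities) scoring whether generated samples are not real.
real_logits – Raw discriminator outputs (logits, not probabilities) scoring whether real (ground truth) samples are fake.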
- Returns
Loss, scalar.
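The relativistic average idea can be sketched without any dependencies. The code below follows the formulation in the paper cited above: each logit is compared with the mean logit of the opposite set before a sigmoid cross-entropy is applied. It assumes that a higher logit means "more realistic". The library's sign convention, reduction and scaling may differ, and `relativistic_average_loss` is a hypothetical name, not the library's implementation.

```python
import math


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def mean(values):
    return sum(values) / len(values)


def relativistic_average_loss(fake_logits, real_logits, mode="discriminator"):
    """Hypothetical sketch of the relativistic average GAN loss."""
    if mode not in ("generator", "discriminator"):
        raise NotImplementedError(f"Unknown mode: {mode!r}")

    # Judge every sample relative to the average of the opposite set.
    real_vs_fake = [r - mean(fake_logits) for r in real_logits]
    fake_vs_real = [f - mean(real_logits) for f in fake_logits]

    # Identities used: -log(sigmoid(d)) = softplus(-d)
    #                  -log(1 - sigmoid(d)) = softplus(d)
    if mode == "discriminator":
        # Real should look more realistic than fake, and not the reverse.
        loss_real = mean([softplus(-d) for d in real_vs_fake])
        loss_fake = mean([softplus(d) for d in fake_vs_real])
    else:
        # Generator: the targets are swapped.
        loss_real = mean([softplus(d) for d in real_vs_fake])
        loss_fake = mean([softplus(-d) for d in fake_vs_real])
    return loss_real + loss_fake
```

When real logits are far above fake logits, the discriminator-mode value in this sketch is close to zero while the generator-mode value is large, which matches the two training goals described for `mode`.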
Perceptual Loss
class esrgan.criterions.perceptual.PerceptualLoss(layers: Dict[str, float], model: str = 'vgg19', distance: Union[str, Callable] = 'l1', mean: Iterable[float] = (0.485, 0.456, 0.406), std: Iterable[float] = (0.229, 0.224, 0.225))
The Perceptual Loss.
Calculates the loss between features that a model (usually VGG) extracts from the input images (produced by the generator) and from the target (real) images.
- Parameters
layers – Dict mapping layer names to weights (used to balance the contribution of different layers).
model – Model used to extract features.
distance – Method used to compute the distance between features.
mean – List of float values used for data standardization. If there is no need to normalize the data, use [0., 0., 0.].
std – List of float values used for data standardization. If there is no need to normalize the data, use [1., 1., 1.].
- Raises
NotImplementedError – If distance is a string other than 'l1', 'cityblock', 'l2', or 'euclidean'.
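A minimal sketch of how the parameters above fit together, using plain Python lists in place of feature tensors. It is an illustration under assumptions, not the library's implementation: the layer names `conv1` and `conv2` are made up, features are treated as flat lists, 'l1' is taken to be the mean absolute difference, and 'l2' is taken to be the mean squared difference.

```python
def standardize(pixel, mean, std):
    """Per-channel standardization of one RGB pixel: (x - mean) / std."""
    return [(x - m) / s for x, m, s in zip(pixel, mean, std)]


def l1_distance(a, b):
    """Mean absolute difference between two flat feature lists."""
    return sum(abs(x - y) for x, y in zip(a, b)) / len(a)


def l2_distance(a, b):
    """Mean squared difference (an assumed reading of 'l2'/'euclidean')."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


# String aliases accepted by the `distance` parameter.
DISTANCES = {
    "l1": l1_distance,
    "cityblock": l1_distance,
    "l2": l2_distance,
    "euclidean": l2_distance,
}


def perceptual_loss(input_features, target_features, layers, distance="l1"):
    """Hypothetical helper: weighted sum of per-layer feature distances."""
    if callable(distance):
        distance_fn = distance
    elif distance in DISTANCES:
        distance_fn = DISTANCES[distance]
    else:
        raise NotImplementedError(f"Unknown distance: {distance!r}")
    return sum(
        weight * distance_fn(input_features[name], target_features[name])
        for name, weight in layers.items()
    )


# Usage: two made-up layers, the second weighted half as much as the first.
layers = {"conv1": 1.0, "conv2": 0.5}
generated = {"conv1": [1.0, 2.0], "conv2": [2.0, 2.0]}
real = {"conv1": [0.0, 0.0], "conv2": [0.0, 0.0]}
loss = perceptual_loss(generated, real, layers)  # 1.0 * 1.5 + 0.5 * 2.0
```

In this sketch, passing `mean=[0., 0., 0.]` and `std=[1., 1., 1.]` to `standardize` leaves a pixel unchanged, which is why those values are suggested when no normalization is needed.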