Criterions

Adversarial Loss

class esrgan.criterions.adversarial.AdversarialLoss(mode: str = 'discriminator')[source]

GAN Loss function.

Parameters

mode – Specifies the loss target: 'generator' or 'discriminator'. 'generator': maximize the probability that fake data is drawn from the real data distribution (useful when training the generator); 'discriminator': minimize the probability that real and generated distributions are similar (useful when training the discriminator).

Raises

NotImplementedError – If mode is not 'generator' or 'discriminator'.

forward_discriminator(fake_logits: torch.Tensor, real_logits: torch.Tensor) → torch.Tensor[source]

Forward pass (discriminator mode).

Parameters
  • fake_logits – Predictions of discriminator for fake data.

  • real_logits – Predictions of discriminator for real data.

Returns

Loss, scalar.
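
To make the discriminator-mode objective concrete, here is a minimal framework-free sketch of a standard binary cross-entropy GAN loss on raw logits. It assumes real samples are labelled 1, fake samples are labelled 0, and the result is averaged over all samples; the library's actual reduction may differ. The helpers softplus, bce_with_logits and discriminator_loss are hypothetical names, and plain floats stand in for torch.Tensor values so the arithmetic can be checked by hand.

```python
import math


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def bce_with_logits(logit: float, target: float) -> float:
    """Binary cross-entropy of one raw logit against a 0/1 target."""
    # Equivalent to -[t * log(sigmoid(z)) + (1 - t) * log(1 - sigmoid(z))].
    return softplus(logit) - target * logit


def discriminator_loss(fake_logits, real_logits) -> float:
    """Assumed objective: real samples -> label 1, fake samples -> label 0."""
    losses = [bce_with_logits(z, 1.0) for z in real_logits]
    losses += [bce_with_logits(z, 0.0) for z in fake_logits]
    return sum(losses) / len(losses)  # assumed reduction: mean over all samples


# A discriminator that cannot tell the two apart (all logits are zero) ...
undecided = discriminator_loss(fake_logits=[0.0, 0.0], real_logits=[0.0, 0.0])
# ... versus one that separates real from fake with confidence.
confident = discriminator_loss(fake_logits=[-4.0, -3.0], real_logits=[3.0, 4.0])
```

With all logits at zero the loss equals log 2, and it shrinks towards zero as the discriminator separates the two sets more confidently.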

forward_generator(fake_logits: torch.Tensor, real_logits: torch.Tensor) → torch.Tensor[source]

Forward pass (generator mode).

Parameters
  • fake_logits – Predictions of discriminator for fake data.

  • real_logits – Predictions of discriminator for real data.

Returns

Loss, scalar.
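
The generator-mode objective can be sketched the same way. The example below assumes the common non-saturating formulation, in which the generator is rewarded when the discriminator assigns label 1 to fake samples. It also assumes real_logits plays no role in this mode and is accepted only to mirror the documented signature; generator_loss is a hypothetical name, and plain floats stand in for tensors.

```python
import math


def generator_loss(fake_logits, real_logits=None) -> float:
    """Assumed non-saturating objective: push fake logits towards label 1."""
    # -log(sigmoid(z)) == log(1 + exp(-z)), written in a numerically stable form.
    per_sample = [
        max(-z, 0.0) + math.log1p(math.exp(-abs(z))) for z in fake_logits
    ]
    # real_logits is ignored in this sketch (assumption, see the text above).
    return sum(per_sample) / len(per_sample)


fooled = generator_loss([4.0, 5.0])    # discriminator believes the fakes
caught = generator_loss([-4.0, -5.0])  # discriminator rejects the fakes
```

The loss is small when the discriminator is fooled and grows roughly linearly with the magnitude of the logit when the fakes are rejected.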

class esrgan.criterions.adversarial.RelativisticAdversarialLoss(mode: str = 'discriminator')[source]

Relativistic average GAN loss function.

Proposed in "The relativistic discriminator: a key element missing from standard GAN".

Parameters

mode – Specifies the loss target: 'generator' or 'discriminator'. 'generator': maximize the probability that fake data is more realistic than real data (useful when training the generator); 'discriminator': maximize the probability that real data is more realistic than fake data (useful when training the discriminator).

Raises

NotImplementedError – If mode is not 'generator' or 'discriminator'.

forward(fake_logits: torch.Tensor, real_logits: torch.Tensor) → torch.Tensor[source]

Forward propagation method for the relativistic adversarial loss.

Parameters
  • fake_logits – Discriminator predictions (logits) for generated (fake) samples.

  • real_logits – Discriminator predictions (logits) for real (ground truth) samples.

Returns

Loss, scalar.
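
A minimal framework-free sketch of the relativistic average idea: each logit is compared against the mean logit of the opposite set before the binary cross-entropy is applied, and the two modes differ only in which set is labelled 1. The sketch assumes the two terms are averaged (the paper sums them), uses plain floats instead of tensors, and all function names are hypothetical rather than the library's own.

```python
import math


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def bce_with_logits(logit: float, target: float) -> float:
    """Binary cross-entropy of one raw logit against a 0/1 target."""
    return softplus(logit) - target * logit


def mean(values) -> float:
    return sum(values) / len(values)


def relativistic_average_loss(fake_logits, real_logits, mode="discriminator"):
    if mode == "discriminator":
        real_target, fake_target = 1.0, 0.0  # real should look more realistic
    elif mode == "generator":
        real_target, fake_target = 0.0, 1.0  # fake should look more realistic
    else:
        raise NotImplementedError(f"Unknown mode: {mode!r}")
    fake_mean, real_mean = mean(fake_logits), mean(real_logits)
    # Each logit is judged relative to the average logit of the other set.
    loss_real = mean(
        [bce_with_logits(z - fake_mean, real_target) for z in real_logits]
    )
    loss_fake = mean(
        [bce_with_logits(z - real_mean, fake_target) for z in fake_logits]
    )
    return (loss_real + loss_fake) / 2  # assumed: average of the two terms


fake, real = [-2.0, -1.0], [1.0, 2.0]
d_loss = relativistic_average_loss(fake, real, mode="discriminator")
g_loss = relativistic_average_loss(fake, real, mode="generator")
```

For these inputs the discriminator already ranks real above fake, so d_loss is small while g_loss is large; the generator lowers its loss by closing that gap.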

Perceptual Loss

class esrgan.criterions.perceptual.PerceptualLoss(layers: Dict[str, float], model: str = 'vgg19', distance: Union[str, Callable] = 'l1', mean: Iterable[float] = (0.485, 0.456, 0.406), std: Iterable[float] = (0.229, 0.224, 0.225))[source]

The Perceptual Loss.

Calculates the loss between features that a model (usually VGG) extracts from the input (generator output) and target (real) images.

Parameters
  • layers – Dict mapping layer names to weights (to balance the contributions of different layers).

  • model – Model to use to extract features.

  • distance – Method to compute distance between features.

  • mean – List of float values used for data standardization. If there is no need to normalize data, please use [0., 0., 0.].

  • std – List of float values used for data standardization. If there is no need to normalize data, please use [1., 1., 1.].
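
As a small illustration of what mean and std do, the sketch below standardizes a single RGB pixel channel by channel with (x - mean) / std. The defaults are the commonly used ImageNet channel statistics. The function name standardize is hypothetical, and pixel values in the [0, 1] range are an assumption; the library applies the same idea to whole image batches.

```python
def standardize(pixel, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
    """Standardize one (R, G, B) pixel channel by channel: (x - mean) / std."""
    return tuple((x - m) / s for x, m, s in zip(pixel, mean, std))


# A pixel equal to the channel means maps to zero in every channel.
centered = standardize((0.485, 0.456, 0.406))

# mean = 0 and std = 1 leave the data unchanged, as the parameter notes suggest.
untouched = standardize(
    (0.2, 0.5, 0.9), mean=(0.0, 0.0, 0.0), std=(1.0, 1.0, 1.0)
)
```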

Raises

NotImplementedError – If distance is not one of: 'l1', 'cityblock', 'l2', or 'euclidean'.
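
The accepted distance names can be pictured as a lookup from aliases to functions, with a callable passed through untouched. This is a sketch under stated assumptions, not the library's code: resolve_distance, l1_distance and l2_distance are hypothetical names, 'l1'/'cityblock' is assumed to mean the mean absolute difference, and 'l2'/'euclidean' the mean squared difference.

```python
from typing import Callable, Sequence, Union

Distance = Callable[[Sequence[float], Sequence[float]], float]


def l1_distance(a: Sequence[float], b: Sequence[float]) -> float:
    """Mean absolute difference (assumed meaning of 'l1' / 'cityblock')."""
    return sum(abs(x - y) for x, y in zip(a, b)) / len(a)


def l2_distance(a: Sequence[float], b: Sequence[float]) -> float:
    """Mean squared difference (assumed meaning of 'l2' / 'euclidean')."""
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


def resolve_distance(distance: Union[str, Distance]) -> Distance:
    """Map a distance name to a function; pass a callable through as-is."""
    if callable(distance):
        return distance
    if distance in ("l1", "cityblock"):
        return l1_distance
    if distance in ("l2", "euclidean"):
        return l2_distance
    raise NotImplementedError(f"Unsupported distance: {distance!r}")


l1_value = resolve_distance("cityblock")([0.0, 0.0], [3.0, 4.0])
l2_value = resolve_distance("euclidean")([0.0, 0.0], [3.0, 4.0])
```

Any other string, for example 'cosine', raises NotImplementedError in this sketch, matching the documented behaviour.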

forward(fake_data: torch.Tensor, real_data: torch.Tensor) → torch.Tensor[source]

Forward propagation method for the perceptual loss.

Parameters
  • fake_data – Batch of input (fake, generated) images.

  • real_data – Batch of target (real, ground truth) images.

Returns

Loss, scalar.
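
The role of the layers weights can be shown with a short framework-free sketch: the loss is taken here as the weighted sum of per-layer distances between fake and real features. The layer names, the flattened activation values and the function perceptual_loss are all hypothetical stand-ins; in practice the features would come from a network such as VGG, and an L1 distance is assumed.

```python
def perceptual_loss(fake_features, real_features, layers) -> float:
    """Weighted sum of per-layer L1 distances between feature maps."""
    total = 0.0
    for name, weight in layers.items():
        fake, real = fake_features[name], real_features[name]
        # Mean absolute difference between the two (flattened) feature maps.
        l1 = sum(abs(f - r) for f, r in zip(fake, real)) / len(fake)
        total += weight * l1
    return total


# Hypothetical layer names and activations, chosen so the sum is easy to check.
layers = {"conv1_2": 0.5, "conv5_4": 1.0}
fake_features = {"conv1_2": [0.0, 2.0], "conv5_4": [1.0, 1.0]}
real_features = {"conv1_2": [1.0, 1.0], "conv5_4": [1.0, 3.0]}

loss = perceptual_loss(fake_features, real_features, layers)
# 0.5 * 1.0 (conv1_2) + 1.0 * 1.0 (conv5_4) = 1.5
```

Raising the weight of a layer increases its influence on the total, which is how the layers dict balances shallow and deep features.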