NN

These are the basic building blocks for graphs:

Containers

ConcatInputModule

class esrgan.nn.ConcatInputModule(module: Iterable[torch.nn.modules.module.Module])[source]

Module wrapper that passes the outputs of all previous layers into each subsequent layer.

Parameters

module – Iterable of PyTorch layers to wrap.

forward(x: torch.Tensor) → torch.Tensor[source]

Forward pass.

Parameters

x – Batch of inputs.

Returns

Processed batch.

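A minimal usage sketch, assuming the wrapper concatenates along the channel dimension (dense-style), so each layer's in_channels must cover the input plus all previous outputs:

>>> import torch
>>> from torch import nn
>>> from esrgan.nn import ConcatInputModule
>>> layers = nn.ModuleList([
...     nn.Conv2d(8, 8, kernel_size=3, padding=1),   # sees the raw input (8 ch)
...     nn.Conv2d(16, 8, kernel_size=3, padding=1),  # sees input + 1st output (8 + 8 ch)
... ])
>>> block = ConcatInputModule(layers)
>>> out = block(torch.randn(1, 8, 32, 32))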

ResidualModule

class esrgan.nn.ResidualModule(module: torch.nn.modules.module.Module, scale: float = 1.0, requires_grad: bool = False)[source]

Residual wrapper that adds an identity connection.

It has been proposed in Deep Residual Learning for Image Recognition.

Parameters
  • module – PyTorch layer to wrap.

  • scale – Residual connections scaling factor.

  • requires_grad – If set to False, the layer will not learn the strength of the residual connection.

forward(x: torch.Tensor) → torch.Tensor[source]

Forward pass.

Parameters

x – Batch of inputs.

Returns

Processed batch.

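A minimal sketch, assuming the wrapper computes x + scale * module(x); the wrapped layer must therefore preserve the input shape:

>>> import torch
>>> from torch import nn
>>> from esrgan.nn import ResidualModule
>>> body = nn.Sequential(
...     nn.Conv2d(64, 64, kernel_size=3, padding=1),
...     nn.LeakyReLU(0.2, inplace=True),
...     nn.Conv2d(64, 64, kernel_size=3, padding=1),
... )
>>> block = ResidualModule(body, scale=0.2)
>>> x = torch.randn(1, 64, 16, 16)
>>> block(x).shape == x.shape
True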

Residual-in-Residual layers

ResidualDenseBlock

class esrgan.nn.ResidualDenseBlock(num_features: int, growth_channels: int, num_blocks: int = 5, conv: Callable[[...], torch.nn.modules.module.Module] = functools.partial(<class 'torch.nn.modules.conv.Conv2d'>, kernel_size=(3, 3), padding=1), activation: Callable[[...], torch.nn.modules.module.Module] = functools.partial(<class 'torch.nn.modules.activation.LeakyReLU'>, negative_slope=0.2, inplace=True), residual_scaling: float = 0.2)[source]

Basic block of ResidualInResidualDenseBlock.

Parameters
  • num_features – \(C\) from an expected input of size \((N, C, H, W)\).

  • growth_channels – Number of channels in the latent space.

  • num_blocks – Number of convolutional blocks used to form the dense block.

  • conv – Class constructor or partial object which, when called, should return a convolutional layer, e.g. nn.Conv2d.

  • activation – Class constructor or partial object which, when called, should return an activation function to use after the convolution, e.g. nn.LeakyReLU.

  • residual_scaling – Residual connections scaling factor.

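A usage sketch with the default conv and activation; the block is assumed to map num_features channels back to num_features channels, as required for stacking inside an RRDB:

>>> import torch
>>> from esrgan.nn import ResidualDenseBlock
>>> block = ResidualDenseBlock(num_features=64, growth_channels=32)
>>> x = torch.randn(1, 64, 32, 32)
>>> block(x).shape == x.shape
True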

ResidualInResidualDenseBlock

class esrgan.nn.ResidualInResidualDenseBlock(num_features: int = 64, growth_channels: int = 32, num_dense_blocks: int = 3, residual_scaling: float = 0.2, **kwargs: Any)[source]

Residual-in-Residual Dense Block (RRDB).

Look at the paper: ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks for more details.

Parameters
  • num_features – \(C\) from an expected input of size \((N, C, H, W)\).

  • growth_channels – Number of channels in the latent space.

  • num_dense_blocks – Number of dense blocks used to form the RRDB block.

  • residual_scaling – Residual connections scaling factor.

  • **kwargs – Dense block params.

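A usage sketch; shape preservation is assumed, which is what allows RRDBs to be chained in the ESRGAN generator (extra keyword arguments are forwarded to the inner dense blocks):

>>> import torch
>>> from esrgan.nn import ResidualInResidualDenseBlock
>>> rrdb = ResidualInResidualDenseBlock(num_features=64, growth_channels=32, num_dense_blocks=3)
>>> x = torch.randn(1, 64, 32, 32)
>>> rrdb(x).shape == x.shape
True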

UpSampling layers

InterpolateConv

class esrgan.nn.InterpolateConv(num_features: int, scale_factor: int = 2, conv: Callable[[...], torch.nn.modules.module.Module] = functools.partial(<class 'torch.nn.modules.conv.Conv2d'>, kernel_size=(3, 3), padding=1), activation: Callable[[...], torch.nn.modules.module.Module] = functools.partial(<class 'torch.nn.modules.activation.LeakyReLU'>, negative_slope=0.2, inplace=True))[source]

Upsamples given multi-channel 2D (spatial) data.

Parameters
  • num_features – Number of channels in the input tensor.

  • scale_factor – Factor to increase spatial resolution by.

  • conv – Class constructor or partial object which, when called, should return a convolutional layer, e.g. nn.Conv2d.

  • activation – Class constructor or partial object which, when called, should return an activation function to use after the convolution, e.g. nn.PReLU.

forward(x: torch.Tensor) → torch.Tensor[source]

Forward pass. Upscale input -> apply conv -> apply nonlinearity.

Parameters

x – Batch of inputs.

Returns

Upscaled data.

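A usage sketch; with scale_factor=2 the spatial size should double while the channel count stays at num_features (assuming the default conv maps 64 to 64 channels):

>>> import torch
>>> from esrgan.nn import InterpolateConv
>>> up = InterpolateConv(num_features=64, scale_factor=2)
>>> up(torch.randn(1, 64, 16, 16)).shape
torch.Size([1, 64, 32, 32])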

SubPixelConv

class esrgan.nn.SubPixelConv(num_features: int, scale_factor: int = 2, conv: Callable[[...], torch.nn.modules.module.Module] = functools.partial(<class 'torch.nn.modules.conv.Conv2d'>, kernel_size=(3, 3), padding=1), activation: Callable[[...], torch.nn.modules.module.Module] = <class 'torch.nn.modules.activation.PReLU'>)[source]

Rearranges elements in a tensor of shape \((B, C \times r^2, H, W)\) to a tensor of shape \((B, C, H \times r, W \times r)\).

Look at the paper: Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network for more details.

Parameters
  • num_features – Number of channels in the input tensor.

  • scale_factor – Factor to increase spatial resolution by.

  • conv – Class constructor or partial object which, when called, should return a convolutional layer, e.g. nn.Conv2d.

  • activation – Class constructor or partial object which, when called, should return an activation function to use after the sub-pixel convolution, e.g. nn.PReLU.

forward(x: torch.Tensor) → torch.Tensor[source]

Forward pass. Apply conv -> shuffle pixels -> apply nonlinearity.

Parameters

x – Batch of inputs.

Returns

Upscaled input.

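A usage sketch; the internal convolution is assumed to expand channels to num_features * scale_factor**2 so that the pixel shuffle trades them back for spatial resolution:

>>> import torch
>>> from esrgan.nn import SubPixelConv
>>> up = SubPixelConv(num_features=64, scale_factor=2)
>>> up(torch.randn(1, 64, 16, 16)).shape
torch.Size([1, 64, 32, 32])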

Loss functions

AdversarialLoss

class esrgan.nn.AdversarialLoss(mode: str = 'discriminator')[source]

GAN Loss function.

Parameters

mode – Specifies loss target: 'generator' or 'discriminator'. 'generator': maximize the probability that fake data is drawn from the real data distribution (useful when training the generator); 'discriminator': minimize the probability that the real and generated distributions are similar.

Raises

NotImplementedError – If mode is not 'generator' or 'discriminator'.

forward_discriminator(fake_logits: torch.Tensor, real_logits: torch.Tensor) → torch.Tensor[source]

Forward pass (discriminator mode).

Parameters
  • fake_logits – Predictions of discriminator for fake data.

  • real_logits – Predictions of discriminator for real data.

Returns

Loss, scalar.

forward_generator(fake_logits: torch.Tensor, real_logits: torch.Tensor) → torch.Tensor[source]

Forward pass (generator mode).

Parameters
  • fake_logits – Predictions of discriminator for fake data.

  • real_logits – Predictions of discriminator for real data.

Returns

Loss, scalar.

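A usage sketch; calling the criterion is assumed to dispatch to forward_generator or forward_discriminator according to mode:

>>> import torch
>>> from esrgan.nn import AdversarialLoss
>>> fake_logits = torch.randn(8, 1)  # discriminator outputs for generated images
>>> real_logits = torch.randn(8, 1)  # discriminator outputs for real images
>>> g_loss = AdversarialLoss(mode='generator')(fake_logits, real_logits)
>>> d_loss = AdversarialLoss(mode='discriminator')(fake_logits, real_logits)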

RelativisticAdversarialLoss

class esrgan.nn.RelativisticAdversarialLoss(mode: str = 'discriminator')[source]

Relativistic average GAN loss function.

It has been proposed in The relativistic discriminator: a key element missing from standard GAN.

Parameters

mode – Specifies loss target: 'generator' or 'discriminator'. 'generator': maximize the probability that fake data is more realistic than real data (useful when training the generator); 'discriminator': maximize the probability that real data is more realistic than fake data (useful when training the discriminator).

Raises

NotImplementedError – If mode is not 'generator' or 'discriminator'.

forward(fake_logits: torch.Tensor, real_logits: torch.Tensor) → torch.Tensor[source]

Forward propagation method for the relativistic adversarial loss.

Parameters
  • fake_logits – Probability that generated samples are not real.

  • real_logits – Probability that real (ground truth) samples are fake.

Returns

Loss, scalar.

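A usage sketch, mirroring AdversarialLoss above:

>>> import torch
>>> from esrgan.nn import RelativisticAdversarialLoss
>>> fake_logits = torch.randn(8, 1)
>>> real_logits = torch.randn(8, 1)
>>> g_loss = RelativisticAdversarialLoss(mode='generator')(fake_logits, real_logits)
>>> d_loss = RelativisticAdversarialLoss(mode='discriminator')(fake_logits, real_logits)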

PerceptualLoss

class esrgan.nn.PerceptualLoss(layers: Dict[str, float], model: str = 'vgg19', distance: Union[str, Callable] = 'l1', mean: Iterable[float] = (0.485, 0.456, 0.406), std: Iterable[float] = (0.229, 0.224, 0.225))[source]

The Perceptual Loss.

Calculates the loss between features of a model (usually VGG) extracted from input (produced by the generator) and target (real) images.

Parameters
  • layers – Dict of layer names and their weights (to balance different layers).

  • model – Model to use to extract features.

  • distance – Method to compute distance between features.

  • mean – List of float values used for data standardization. If there is no need to normalize the data, use [0., 0., 0.].

  • std – List of float values used for data standardization. If there is no need to normalize the data, use [1., 1., 1.].

Raises

NotImplementedError – If distance is not one of: 'l1', 'cityblock', 'l2', or 'euclidean'.

forward(fake_data: torch.Tensor, real_data: torch.Tensor) → torch.Tensor[source]

Forward propagation method for the perceptual loss.

Parameters
  • fake_data – Batch of input (fake, generated) images.

  • real_data – Batch of target (real, ground truth) images.

Returns

Loss, scalar.

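A usage sketch; the layer key 'conv5_4' is an assumption and must match a layer name exposed by the chosen feature extractor:

>>> import torch
>>> from esrgan.nn import PerceptualLoss
>>> # 'conv5_4' is a hypothetical layer name; the weight balances multiple layers
>>> criterion = PerceptualLoss(layers={'conv5_4': 1.0}, model='vgg19', distance='l1')
>>> fake = torch.rand(2, 3, 128, 128)  # generated images, values in [0, 1]
>>> real = torch.rand(2, 3, 128, 128)  # ground-truth images
>>> loss = criterion(fake, real)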

Misc

Conv2dSN

class esrgan.nn.Conv2dSN(in_channels: int, out_channels: int, kernel_size: Union[int, Tuple[int, int]] = (3, 3), stride: Union[int, Tuple[int, int]] = 1, padding: Union[int, Tuple[int, int]] = 0, dilation: Union[int, Tuple[int, int]] = 1, groups: int = 1, bias: bool = True, padding_mode: str = 'zeros', n_power_iterations: int = 1)[source]

nn.Conv2d + spectral normalization.

Applies a 2D convolution over an input signal composed of several input planes. In the simplest case, the output value of the layer with input size \((N, C_{\text{in}}, H, W)\) and output \((N, C_{\text{out}}, H_{\text{out}}, W_{\text{out}})\) can be precisely described as:

\[\text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) + \sum_{k = 0}^{C_{\text{in}} - 1} \text{weight}(C_{\text{out}_j}, k) \star \text{input}(N_i, k)\]

where \(\star\) is the valid 2D cross-correlation operator, \(N\) is a batch size, \(C\) denotes a number of channels, \(H\) is a height of input planes in pixels, and \(W\) is width in pixels.

Spectral normalization stabilizes the training of discriminators (critics) in Generative Adversarial Networks (GANs) by rescaling the weight tensor with spectral norm \(\sigma\) of the weight matrix calculated using power iteration method. See Spectral Normalization for Generative Adversarial Networks for details.

Parameters
  • in_channels – Number of channels in the input image.

  • out_channels – Number of channels produced by the convolution.

  • kernel_size – Size of the convolving kernel.

  • stride – Stride of the convolution.

  • padding – Padding added to both sides of the input.

  • dilation – Spacing between kernel elements.

  • groups – Number of blocked connections from input channels to output channels.

  • bias – If True, adds a learnable bias to the output.

  • padding_mode'zeros', 'reflect', 'replicate' or 'circular'.

  • n_power_iterations – Number of power iterations to calculate spectral norm.

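A usage sketch; the layer behaves as a drop-in replacement for nn.Conv2d, e.g. in a GAN discriminator:

>>> import torch
>>> from esrgan.nn import Conv2dSN
>>> conv = Conv2dSN(in_channels=64, out_channels=128, kernel_size=(3, 3), stride=2, padding=1)
>>> conv(torch.randn(1, 64, 32, 32)).shape
torch.Size([1, 128, 16, 16])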

LinearSN

class esrgan.nn.LinearSN(in_features: int, out_features: int, bias: bool = True, n_power_iterations: int = 1)[source]

nn.Linear + spectral normalization.

Applies a linear transformation to the incoming data: \(y = xA^T + b\).

Spectral normalization stabilizes the training of discriminators (critics) in Generative Adversarial Networks (GANs) by rescaling the weight tensor with spectral norm \(\sigma\) of the weight matrix calculated using power iteration method. See Spectral Normalization for Generative Adversarial Networks for details.

Parameters
  • in_features – Size of each input sample.

  • out_features – Size of each output sample.

  • bias – If set to False, the layer will not learn an additive bias.

  • n_power_iterations – Number of power iterations to calculate spectral norm.

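A usage sketch; a typical use is the final logit layer of a discriminator head:

>>> import torch
>>> from esrgan.nn import LinearSN
>>> head = LinearSN(in_features=512, out_features=1)
>>> head(torch.randn(8, 512)).shape
torch.Size([8, 1])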