NN
These are the basic building blocks for graphs:
Containers
ConcatInputModule
- class esrgan.nn.ConcatInputModule(module: Iterable[torch.nn.modules.module.Module])
Module wrapper that passes the outputs of all previous layers into each next layer.
- Parameters
module – PyTorch layers to wrap.
- forward(x: torch.Tensor) → torch.Tensor
Forward pass.
- Parameters
x – Batch of inputs.
- Returns
Processed batch.
- training: bool
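To make the data flow concrete, here is a minimal PyTorch sketch of the dense-connection pattern this wrapper describes. It is an illustration, not the library source: the class name ConcatInputSketch is hypothetical, and it assumes the accumulated tensors are concatenated along the channel dimension (dim 1) and that only the last layer's output is returned.

```python
from typing import Iterable

import torch
from torch import nn


class ConcatInputSketch(nn.Module):
    """Hypothetical stand-in for ConcatInputModule (not the library source)."""

    def __init__(self, module: Iterable[nn.Module]) -> None:
        super().__init__()
        self.module = nn.ModuleList(module)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        outputs = [x]
        for layer in self.module:
            # Assumption: inputs are concatenated along the channel axis of (N, C, H, W).
            outputs.append(layer(torch.cat(outputs, dim=1)))
        # Assumption: only the last layer's output is returned.
        return outputs[-1]


# Layer i sees the block input (8 channels) plus i earlier outputs (4 channels each).
layers = [nn.Conv2d(8 + 4 * i, 4, kernel_size=3, padding=1) for i in range(3)]
block = ConcatInputSketch(layers)
y = block(torch.randn(2, 8, 16, 16))
print(y.shape)  # torch.Size([2, 4, 16, 16])
```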
ResidualModule
- class esrgan.nn.ResidualModule(module: torch.nn.modules.module.Module, scale: float = 1.0, requires_grad: bool = False)
Residual wrapper that adds an identity (skip) connection around the wrapped module.
It was proposed in Deep Residual Learning for Image Recognition.
- Parameters
module – PyTorch layer to wrap.
scale – Residual connections scaling factor.
requires_grad – If set to False, the layer will not learn the strength of the residual connection.
- forward(x: torch.Tensor) → torch.Tensor
Forward pass.
- Parameters
x – Batch of inputs.
- Returns
Processed batch.
- training: bool
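A minimal sketch of the computation \(y = x + \text{scale} \cdot f(x)\), assuming scale is stored as a 0-dimensional parameter whose requires_grad flag follows the constructor argument. ResidualSketch is a hypothetical name, not the library class; nn.Identity stands in for the wrapped layer so the result is easy to check.

```python
import torch
from torch import nn


class ResidualSketch(nn.Module):
    """Hypothetical stand-in for ResidualModule: y = x + scale * module(x)."""

    def __init__(self, module: nn.Module, scale: float = 1.0,
                 requires_grad: bool = False) -> None:
        super().__init__()
        self.module = module
        # Assumption: scale is a 0-dim Parameter, trainable only if requested.
        self.scale = nn.Parameter(torch.tensor(scale), requires_grad=requires_grad)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Identity (skip) connection plus the scaled output of the wrapped module.
        return x + self.scale * self.module(x)


fixed = ResidualSketch(nn.Identity(), scale=0.5)
learned = ResidualSketch(nn.Identity(), scale=0.5, requires_grad=True)
out = fixed(torch.ones(1, 3, 4, 4))
print(out.unique().tolist())  # [1.5]
print(fixed.scale.requires_grad, learned.scale.requires_grad)  # False True
```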
Residual-in-Residual layers
ResidualDenseBlock
- class esrgan.nn.ResidualDenseBlock(num_features: int, growth_channels: int, num_blocks: int = 5, conv: Callable[[...], torch.nn.modules.module.Module] = functools.partial(<class 'torch.nn.modules.conv.Conv2d'>, kernel_size=(3, 3), padding=1), activation: Callable[[...], torch.nn.modules.module.Module] = functools.partial(<class 'torch.nn.modules.activation.LeakyReLU'>, negative_slope=0.2, inplace=True), residual_scaling: float = 0.2)
Basic block of ResidualInResidualDenseBlock.
- Parameters
num_features – \(C\) from an expected input of size \((N, C, H, W)\).
growth_channels – Number of channels in the latent space.
num_blocks – Number of convolutional blocks used to form the dense block.
conv – Class constructor or partial object which, when called, returns a convolutional layer, e.g., nn.Conv2d.
activation – Class constructor or partial object which, when called, returns the activation function to use after convolution, e.g., nn.LeakyReLU.
residual_scaling – Residual connections scaling factor.
- training: bool
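The sketch below shows how these parameters typically fit together in an ESRGAN-style dense block. It is an assumption-based illustration (DenseBlockSketch is hypothetical): convolution i receives num_features + i * growth_channels channels, intermediate convolutions emit growth_channels channels, the last one maps back to num_features without an activation, and its output is added to the input after scaling by residual_scaling.

```python
import torch
from torch import nn


class DenseBlockSketch(nn.Module):
    """Hypothetical ESRGAN-style residual dense block (not the library source)."""

    def __init__(self, num_features: int = 64, growth_channels: int = 32,
                 num_blocks: int = 5, residual_scaling: float = 0.2) -> None:
        super().__init__()
        self.residual_scaling = residual_scaling
        self.convs = nn.ModuleList()
        for i in range(num_blocks):
            in_channels = num_features + i * growth_channels
            # The last conv maps back to num_features so the residual sum is valid.
            out_channels = num_features if i == num_blocks - 1 else growth_channels
            self.convs.append(nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1))
        self.act = nn.LeakyReLU(negative_slope=0.2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        features = [x]
        for i, conv in enumerate(self.convs):
            out = conv(torch.cat(features, dim=1))  # dense connection
            if i < len(self.convs) - 1:  # assumption: no activation after the last conv
                out = self.act(out)
            features.append(out)
        # Local residual connection, scaled down to stabilize training.
        return x + self.residual_scaling * features[-1]


block = DenseBlockSketch(num_features=16, growth_channels=8)
y = block(torch.randn(2, 16, 12, 12))
print(y.shape)  # torch.Size([2, 16, 12, 12])
```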
ResidualInResidualDenseBlock
- class esrgan.nn.ResidualInResidualDenseBlock(num_features: int = 64, growth_channels: int = 32, num_dense_blocks: int = 3, residual_scaling: float = 0.2, **kwargs: Any)
Residual-in-Residual Dense Block (RRDB).
See the paper ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks for more details.
- Parameters
num_features – \(C\) from an expected input of size \((N, C, H, W)\).
growth_channels – Number of channels in the latent space.
num_dense_blocks – Number of dense blocks used to form the RRDB block.
residual_scaling – Residual connections scaling factor.
**kwargs – Additional parameters passed to the dense blocks.
- training: bool
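Following the ESRGAN paper's description, an RRDB chains several dense blocks and wraps the chain in an outer residual connection scaled by residual_scaling. The sketch below assumes that structure; RRDBSketch and the compact TinyDenseBlock helper are hypothetical and only mirror the parameters documented above.

```python
import torch
from torch import nn


class TinyDenseBlock(nn.Module):
    """Hypothetical compact dense block: n convs, dense inputs, scaled residual."""

    def __init__(self, nf: int, gc: int, n: int = 5, beta: float = 0.2) -> None:
        super().__init__()
        self.beta = beta
        self.convs = nn.ModuleList([
            nn.Conv2d(nf + i * gc, nf if i == n - 1 else gc, kernel_size=3, padding=1)
            for i in range(n)
        ])
        self.act = nn.LeakyReLU(0.2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        feats = [x]
        for i, conv in enumerate(self.convs):
            out = conv(torch.cat(feats, dim=1))
            feats.append(out if i == len(self.convs) - 1 else self.act(out))
        return x + self.beta * feats[-1]


class RRDBSketch(nn.Module):
    """Hypothetical Residual-in-Residual Dense Block."""

    def __init__(self, num_features: int = 64, growth_channels: int = 32,
                 num_dense_blocks: int = 3, residual_scaling: float = 0.2) -> None:
        super().__init__()
        self.residual_scaling = residual_scaling
        self.blocks = nn.Sequential(*[
            TinyDenseBlock(num_features, growth_channels, beta=residual_scaling)
            for _ in range(num_dense_blocks)
        ])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Outer ("residual-in-residual") skip connection around the dense blocks.
        return x + self.residual_scaling * self.blocks(x)


rrdb = RRDBSketch(num_features=16, growth_channels=8)
y = rrdb(torch.randn(1, 16, 10, 10))
print(y.shape)  # torch.Size([1, 16, 10, 10])
```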
UpSampling layers
InterpolateConv
- class esrgan.nn.InterpolateConv(num_features: int, scale_factor: int = 2, conv: Callable[[...], torch.nn.modules.module.Module] = functools.partial(<class 'torch.nn.modules.conv.Conv2d'>, kernel_size=(3, 3), padding=1), activation: Callable[[...], torch.nn.modules.module.Module] = functools.partial(<class 'torch.nn.modules.activation.LeakyReLU'>, negative_slope=0.2, inplace=True))
Upsamples multi-channel 2D (spatial) data.
- Parameters
num_features – Number of channels in the input tensor.
scale_factor – Factor to increase spatial resolution by.
conv – Class constructor or partial object which, when called, returns a convolutional layer, e.g., nn.Conv2d.
activation – Class constructor or partial object which, when called, returns the activation function to use after convolution, e.g., nn.PReLU.
- forward(x: torch.Tensor) → torch.Tensor
Forward pass. Upscale input -> apply conv -> apply nonlinearity.
- Parameters
x – Batch of inputs.
- Returns
Upscaled data.
- training: bool
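A minimal sketch of the documented forward pass (upscale, convolve, apply nonlinearity). It assumes nearest-neighbor interpolation via torch.nn.functional.interpolate and a convolution that keeps the channel count; the interpolation mode the library actually uses is not stated here, and InterpolateConvSketch is a hypothetical name.

```python
import torch
import torch.nn.functional as F
from torch import nn


class InterpolateConvSketch(nn.Module):
    """Hypothetical stand-in for InterpolateConv: upscale -> conv -> activation."""

    def __init__(self, num_features: int, scale_factor: int = 2) -> None:
        super().__init__()
        self.scale_factor = scale_factor
        # Assumption: the conv keeps the number of channels unchanged.
        self.conv = nn.Conv2d(num_features, num_features, kernel_size=3, padding=1)
        self.act = nn.LeakyReLU(negative_slope=0.2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Assumption: nearest-neighbor upsampling.
        x = F.interpolate(x, scale_factor=self.scale_factor, mode="nearest")
        return self.act(self.conv(x))


up = InterpolateConvSketch(num_features=8, scale_factor=2)
y = up(torch.randn(1, 8, 16, 16))
print(y.shape)  # torch.Size([1, 8, 32, 32])
```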
SubPixelConv
- class esrgan.nn.SubPixelConv(num_features: int, scale_factor: int = 2, conv: Callable[[...], torch.nn.modules.module.Module] = functools.partial(<class 'torch.nn.modules.conv.Conv2d'>, kernel_size=(3, 3), padding=1), activation: Callable[[...], torch.nn.modules.module.Module] = <class 'torch.nn.modules.activation.PReLU'>)
Applies a convolution, then rearranges elements in a tensor of shape \((B, C \times r^2, H, W)\) to a tensor of shape \((B, C, H \times r, W \times r)\), where \(r\) is the scale factor.
See the paper Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network for more details.
- Parameters
num_features – Number of channels in the input tensor.
scale_factor – Factor to increase spatial resolution by.
conv – Class constructor or partial object which, when called, returns a convolutional layer, e.g., nn.Conv2d.
activation – Class constructor or partial object which, when called, returns the activation function to use after sub-pixel convolution, e.g., nn.PReLU.
- forward(x: torch.Tensor) → torch.Tensor
Forward pass. Apply conv -> shuffle pixels -> apply nonlinearity.
- Parameters
x – Batch of inputs.
- Returns
Upscaled input.
- training: bool
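The sketch below assumes the convolution maps \(C\) channels to \(C \times r^2\) so that torch.nn.PixelShuffle can fold them into an \(r\)-times larger grid, followed by nn.PReLU (the default activation in the signature). SubPixelConvSketch is a hypothetical name. The last lines show the rearrangement on a tiny tensor: each group of \(r^2\) consecutive channels fills one \(r \times r\) output patch.

```python
import torch
from torch import nn


class SubPixelConvSketch(nn.Module):
    """Hypothetical stand-in for SubPixelConv: conv -> pixel shuffle -> activation."""

    def __init__(self, num_features: int, scale_factor: int = 2) -> None:
        super().__init__()
        # Produce C * r^2 channels so PixelShuffle can trade channels for resolution.
        self.conv = nn.Conv2d(num_features, num_features * scale_factor ** 2,
                              kernel_size=3, padding=1)
        self.shuffle = nn.PixelShuffle(scale_factor)
        self.act = nn.PReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.act(self.shuffle(self.conv(x)))


up = SubPixelConvSketch(num_features=8, scale_factor=2)
y = up(torch.randn(1, 8, 16, 16))
print(y.shape)  # torch.Size([1, 8, 32, 32])

# PixelShuffle alone: shape (1, 4, 1, 1) -> (1, 1, 2, 2)
demo = nn.PixelShuffle(2)(torch.arange(4.0).view(1, 4, 1, 1))
print(demo[0, 0].tolist())  # [[0.0, 1.0], [2.0, 3.0]]
```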
Loss functions
AdversarialLoss
- class esrgan.nn.AdversarialLoss(mode: str = 'discriminator')
GAN loss function.
- Parameters
mode – Specifies the loss target: 'generator' or 'discriminator'.
'generator': maximize the probability that fake data are drawn from the real data distribution (useful when training the generator).
'discriminator': minimize the probability that real and generated distributions are similar (useful when training the discriminator).
- Raises
NotImplementedError – If mode is not 'generator' or 'discriminator'.
- forward_discriminator(fake_logits: torch.Tensor, real_logits: torch.Tensor) → torch.Tensor
Forward pass (discriminator mode).
- Parameters
fake_logits – Predictions of discriminator for fake data.
real_logits – Predictions of discriminator for real data.
- Returns
Loss, scalar.
- forward_generator(fake_logits: torch.Tensor, real_logits: torch.Tensor) → torch.Tensor
Forward pass (generator mode).
- Parameters
fake_logits – Predictions of discriminator for fake data.
real_logits – Predictions of discriminator for real data.
- Returns
Loss, scalar.
- reduction: str
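A minimal sketch of a standard (non-saturating) GAN loss on raw logits, using torch.nn.functional.binary_cross_entropy_with_logits. It assumes label 1 means "real" and that the discriminator loss averages its real and fake terms; the exact formulation used by esrgan.nn.AdversarialLoss is not given above, and adversarial_loss_sketch is a hypothetical helper.

```python
import torch
import torch.nn.functional as F


def adversarial_loss_sketch(fake_logits: torch.Tensor, real_logits: torch.Tensor,
                            mode: str = "discriminator") -> torch.Tensor:
    """Hypothetical GAN loss; label 1 = real, label 0 = fake."""
    if mode == "generator":
        # The generator wants the discriminator to label fake data as real.
        return F.binary_cross_entropy_with_logits(fake_logits, torch.ones_like(fake_logits))
    if mode == "discriminator":
        # The discriminator wants real -> 1 and fake -> 0 (assumption: terms averaged).
        real_loss = F.binary_cross_entropy_with_logits(real_logits, torch.ones_like(real_logits))
        fake_loss = F.binary_cross_entropy_with_logits(fake_logits, torch.zeros_like(fake_logits))
        return (real_loss + fake_loss) / 2
    raise NotImplementedError(f"mode must be 'generator' or 'discriminator', got {mode!r}")


fake = torch.zeros(4, 1)  # a logit of 0 means probability 0.5
real = torch.zeros(4, 1)
print(adversarial_loss_sketch(fake, real, "generator").item())      # about 0.6931 (ln 2)
print(adversarial_loss_sketch(fake, real, "discriminator").item())  # about 0.6931
```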
RelativisticAdversarialLoss
- class esrgan.nn.RelativisticAdversarialLoss(mode: str = 'discriminator')
Relativistic average GAN loss function.
It was proposed in The relativistic discriminator: a key element missing from standard GAN.
- Parameters
mode – Specifies the loss target: 'generator' or 'discriminator'.
'generator': maximize the probability that fake data are more realistic than real data (useful when training the generator).
'discriminator': maximize the probability that real data are more realistic than fake data (useful when training the discriminator).
- Raises
NotImplementedError – If mode is not 'generator' or 'discriminator'.
- forward(fake_logits: torch.Tensor, real_logits: torch.Tensor) → torch.Tensor
Forward propagation method for the relativistic adversarial loss.
- Parameters
fake_logits – Predictions (logits) of the discriminator for generated (fake) samples.
real_logits – Predictions (logits) of the discriminator for real (ground truth) samples.
- Returns
Loss, scalar.
- reduction: str
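A minimal sketch of the relativistic average formulation as described in the RaGAN and ESRGAN papers: each logit is compared with the mean logit of the opposite batch, and the generator objective swaps the targets of the discriminator objective. Averaging the two terms and the helper name relativistic_loss_sketch are assumptions, not the library's verified implementation.

```python
import torch
import torch.nn.functional as F


def relativistic_loss_sketch(fake_logits: torch.Tensor, real_logits: torch.Tensor,
                             mode: str = "discriminator") -> torch.Tensor:
    """Hypothetical relativistic average GAN loss on raw discriminator logits."""
    # How much more realistic each real sample looks than the average fake one, and vice versa.
    real_vs_fake = real_logits - fake_logits.mean()
    fake_vs_real = fake_logits - real_logits.mean()
    if mode == "discriminator":
        target_real, target_fake = 1.0, 0.0
    elif mode == "generator":
        target_real, target_fake = 0.0, 1.0  # targets are swapped for the generator
    else:
        raise NotImplementedError(f"mode must be 'generator' or 'discriminator', got {mode!r}")
    loss_real = F.binary_cross_entropy_with_logits(
        real_vs_fake, torch.full_like(real_vs_fake, target_real))
    loss_fake = F.binary_cross_entropy_with_logits(
        fake_vs_real, torch.full_like(fake_vs_real, target_fake))
    return (loss_real + loss_fake) / 2  # assumption: the two terms are averaged


real = torch.full((8, 1), 5.0)   # the discriminator is confident these are real
fake = torch.full((8, 1), -5.0)  # ... and that these are fake
d_loss = relativistic_loss_sketch(fake, real, "discriminator")
g_loss = relativistic_loss_sketch(fake, real, "generator")
print(d_loss.item() < g_loss.item())  # True
```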
PerceptualLoss
- class esrgan.nn.PerceptualLoss(layers: Dict[str, float], model: str = 'vgg19', distance: Union[str, Callable] = 'l1', mean: Iterable[float] = (0.485, 0.456, 0.406), std: Iterable[float] = (0.229, 0.224, 0.225))
The Perceptual Loss.
Calculates the loss between features extracted by a model (usually VGG) from the input (generated) images and from the target (real) images.
- Parameters
layers – Dict of layer names and weights (to balance the contributions of different layers).
model – Model used to extract features.
distance – Method used to compute the distance between features.
mean – List of float values used for data standardization. If there is no need to normalize data, use [0., 0., 0.].
std – List of float values used for data standardization. If there is no need to normalize data, use [1., 1., 1.].
- Raises
NotImplementedError – If distance is not one of 'l1', 'cityblock', 'l2', or 'euclidean'.
- forward(fake_data: torch.Tensor, real_data: torch.Tensor) → torch.Tensor
Forward propagation method for the perceptual loss.
- Parameters
fake_data – Batch of input (fake, generated) images.
real_data – Batch of target (real, ground truth) images.
- Returns
Loss, scalar.
- reduction: str
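The sketch below illustrates the documented behavior with a small stand-in feature extractor instead of a pretrained VGG-19 (which would require downloading weights). Layers are addressed by their integer index in an nn.Sequential, 'l1'/'cityblock' maps to F.l1_loss, and 'l2'/'euclidean' maps to F.mse_loss (mean squared error). These choices and the name PerceptualLossSketch are assumptions; the library itself addresses layers by name and builds the model selected by model.

```python
from typing import Dict, Iterable

import torch
import torch.nn.functional as F
from torch import nn


class PerceptualLossSketch(nn.Module):
    """Hypothetical perceptual loss over selected layers of a frozen extractor."""

    def __init__(self, features: nn.Sequential, layers: Dict[str, float],
                 distance: str = "l1",
                 mean: Iterable[float] = (0.485, 0.456, 0.406),
                 std: Iterable[float] = (0.229, 0.224, 0.225)) -> None:
        super().__init__()
        if distance in ("l1", "cityblock"):
            self.distance = F.l1_loss
        elif distance in ("l2", "euclidean"):
            self.distance = F.mse_loss  # assumption: squared L2 (MSE)
        else:
            raise NotImplementedError(f"unsupported distance: {distance!r}")
        # Assumption: layer keys are string indices into the Sequential.
        self.layers = {int(name): weight for name, weight in layers.items()}
        # Keep only the layers that are needed and freeze them.
        self.features = features[: max(self.layers) + 1].eval()
        for param in self.features.parameters():
            param.requires_grad_(False)
        self.register_buffer("mean", torch.tensor(list(mean)).view(1, -1, 1, 1))
        self.register_buffer("std", torch.tensor(list(std)).view(1, -1, 1, 1))

    def forward(self, fake_data: torch.Tensor, real_data: torch.Tensor) -> torch.Tensor:
        x = (fake_data - self.mean) / self.std
        y = (real_data - self.mean) / self.std
        loss = fake_data.new_zeros(())
        for index, layer in enumerate(self.features):
            x, y = layer(x), layer(y)
            if index in self.layers:
                # Weighted distance between the feature maps of this layer.
                loss = loss + self.layers[index] * self.distance(x, y)
        return loss


# Hypothetical stand-in for a pretrained VGG feature extractor.
features = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1), nn.ReLU(),
    nn.Conv2d(8, 8, kernel_size=3, padding=1), nn.ReLU(),
)
criterion = PerceptualLossSketch(features, layers={"1": 0.5, "3": 1.0})
fake = torch.rand(2, 3, 16, 16, requires_grad=True)
real = torch.rand(2, 3, 16, 16)
loss = criterion(fake, real)
loss.backward()  # gradients flow to the generated images, not to the frozen extractor
print(loss.item(), fake.grad is not None)
```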
Misc
- class esrgan.nn.Conv2dSN(in_channels: int, out_channels: int, kernel_size: Union[int, Tuple[int, int]] = (3, 3), stride: Union[int, Tuple[int, int]] = 1, padding: Union[int, Tuple[int, int]] = 0, dilation: Union[int, Tuple[int, int]] = 1, groups: int = 1, bias: bool = True, padding_mode: str = 'zeros', n_power_iterations: int = 1)
nn.Conv2d + spectral normalization.
Applies a 2D convolution over an input signal composed of several input planes. In the simplest case, the output value of the layer with input size \((N, C_{\text{in}}, H, W)\) and output \((N, C_{\text{out}}, H_{\text{out}}, W_{\text{out}})\) can be precisely described as:
\[\text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) + \sum_{k = 0}^{C_{\text{in}} - 1} \text{weight}(C_{\text{out}_j}, k) \star \text{input}(N_i, k)\]
where \(\star\) is the valid 2D cross-correlation operator, \(N\) is the batch size, \(C\) denotes the number of channels, \(H\) is the height of input planes in pixels, and \(W\) is the width in pixels.
Spectral normalization stabilizes the training of discriminators (critics) in Generative Adversarial Networks (GANs) by rescaling the weight tensor by the spectral norm \(\sigma\) of the weight matrix, which is estimated with the power iteration method. See Spectral Normalization for Generative Adversarial Networks for details.
- Parameters
in_channels – Number of channels in the input image.
out_channels – Number of channels produced by the convolution.
kernel_size – Size of the convolving kernel.
stride – Stride of the convolution.
padding – Padding added to both sides of the input.
dilation – Spacing between kernel elements.
groups – Number of blocked connections from input channels to output channels.
bias – If True, adds a learnable bias to the output.
padding_mode – 'zeros', 'reflect', 'replicate' or 'circular'.
n_power_iterations – Number of power iterations used to calculate the spectral norm.
- bias: Optional[torch.Tensor]
- dilation: Tuple[int, ...]
- groups: int
- kernel_size: Tuple[int, ...]
- out_channels: int
- output_padding: Tuple[int, ...]
- padding: Tuple[int, ...]
- padding_mode: str
- stride: Tuple[int, ...]
- transposed: bool
- weight: torch.Tensor
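For comparison, PyTorch ships a generic spectral-normalization wrapper, torch.nn.utils.spectral_norm, that can be applied to a plain nn.Conv2d; treating it as equivalent to Conv2dSN is an assumption. The sketch checks that after repeated training-mode forward passes (each runs n_power_iterations power-iteration steps) the largest singular value of the normalized weight, reshaped to a matrix of shape (out_channels, -1), is close to 1.

```python
import torch
from torch import nn
from torch.nn.utils import spectral_norm

torch.manual_seed(0)
# Plain conv wrapped with PyTorch's spectral normalization hook.
conv_sn = spectral_norm(nn.Conv2d(3, 16, kernel_size=3, padding=1), n_power_iterations=1)

x = torch.randn(4, 3, 32, 32)
with torch.no_grad():
    for _ in range(100):
        # In training mode every forward pass refines the estimate of sigma.
        y = conv_sn(x)

# The normalized weight is recomputed as weight_orig / sigma on each forward pass.
weight_matrix = conv_sn.weight.reshape(16, -1)
sigma = torch.linalg.svdvals(weight_matrix)[0]
print(y.shape)       # torch.Size([4, 16, 32, 32])
print(sigma.item())  # close to 1.0
```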
- class esrgan.nn.LinearSN(in_features: int, out_features: int, bias: bool = True, n_power_iterations: int = 1)
nn.Linear + spectral normalization.
Applies a linear transformation to the incoming data: \(y = xA^T + b\).
Spectral normalization stabilizes the training of discriminators (critics) in Generative Adversarial Networks (GANs) by rescaling the weight tensor by the spectral norm \(\sigma\) of the weight matrix, which is estimated with the power iteration method. See Spectral Normalization for Generative Adversarial Networks for details.
- Parameters
in_features – Size of each input sample.
out_features – Size of each output sample.
bias – If set to False, the layer will not learn an additive bias.
n_power_iterations – Number of power iterations used to calculate the spectral norm.
- in_features: int
- out_features: int
- weight: torch.Tensor
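The power-iteration estimate behind spectral normalization can be written out explicitly. This hand-rolled sketch (the function name power_iteration_sigma is hypothetical) follows the usual scheme: alternate \(v \leftarrow W^\top u / \lVert W^\top u \rVert\) and \(u \leftarrow W v / \lVert W v \rVert\), estimate \(\sigma \approx u^\top W v\), then divide the weight by \(\sigma\). Keeping u between calls, as in the loop below, lets a small n_power_iterations still converge over many training steps.

```python
from typing import Optional, Tuple

import torch
import torch.nn.functional as F


def power_iteration_sigma(weight: torch.Tensor, n_power_iterations: int = 1,
                          u: Optional[torch.Tensor] = None,
                          eps: float = 1e-12) -> Tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical helper: estimate the largest singular value of a 2-D weight."""
    if n_power_iterations < 1:
        raise ValueError("n_power_iterations must be at least 1")
    if u is None:
        u = F.normalize(torch.randn(weight.shape[0]), dim=0, eps=eps)
    for _ in range(n_power_iterations):
        v = F.normalize(weight.t() @ u, dim=0, eps=eps)
        u = F.normalize(weight @ v, dim=0, eps=eps)
    sigma = torch.dot(u, weight @ v)
    return sigma, u


torch.manual_seed(0)
weight = torch.randn(32, 64)  # (out_features, in_features), as in nn.Linear
u = None
for _ in range(200):  # one iteration per simulated training step, reusing u
    sigma, u = power_iteration_sigma(weight, n_power_iterations=1, u=u)

weight_sn = weight / sigma  # spectrally normalized weight
print(sigma.item(), torch.linalg.svdvals(weight)[0].item())  # nearly equal
print(torch.linalg.svdvals(weight_sn)[0].item())             # close to 1.0
```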