Models
The models subpackage contains model definitions for image super-resolution tasks, grouped below into generators and discriminators.
Generators
EncoderDecoderNet
- class esrgan.models.EncoderDecoderNet(encoder: torch.nn.modules.module.Module, decoder: torch.nn.modules.module.Module)
Generalized Encoder-Decoder network.
- Parameters
encoder – Encoder module, usually used to extract embeddings from input signals.
decoder – Decoder module, usually used to process embeddings, e.g. to generate a signal similar to the input one (in GANs).
- forward(x: torch.Tensor) → torch.Tensor
Forward pass method.
- Parameters
x – Batch of input signals e.g. images.
- Returns
Batch of generated signals e.g. images.
- training: bool
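The encoder/decoder composition described above can be sketched in plain PyTorch. This is a minimal illustration, not the library's implementation: TinyEncoderDecoder and the toy convolution layers are hypothetical stand-ins, and the only behaviour assumed is that the decoder is applied to the encoder's output.

```python
import torch
from torch import nn


class TinyEncoderDecoder(nn.Module):
    """Hypothetical stand-in that chains an encoder and a decoder."""

    def __init__(self, encoder: nn.Module, decoder: nn.Module) -> None:
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Embeddings are extracted first, then turned back into a signal.
        return self.decoder(self.encoder(x))


# Toy encoder/decoder pair; the layer choices are illustrative only.
encoder = nn.Conv2d(3, 16, kernel_size=3, padding=1)
decoder = nn.Conv2d(16, 3, kernel_size=3, padding=1)
model = TinyEncoderDecoder(encoder, decoder)

x = torch.rand(2, 3, 8, 8)  # batch of two RGB 8x8 "images"
with torch.no_grad():
    y = model(x)
```

Because both toy layers preserve the spatial size, y has the same shape as x; a real decoder would typically enlarge it.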
SRGAN
- class esrgan.models.SRResNetEncoder(in_channels: int = 3, out_channels: int = 64, num_basic_blocks: int = 16, conv: Callable[[...], torch.nn.modules.module.Module] = functools.partial(<class 'torch.nn.modules.conv.Conv2d'>, kernel_size=(3, 3), padding=1), norm: Callable[[...], torch.nn.modules.module.Module] = <class 'torch.nn.modules.batchnorm.BatchNorm2d'>, activation: Callable[[...], torch.nn.modules.module.Module] = <class 'torch.nn.modules.activation.PReLU'>)
‘Encoder’ part of the SRResNet network, processing images in low-resolution (LR) space.
It was proposed in “Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network”.
- Parameters
in_channels – Number of channels in the input image.
out_channels – Number of channels produced by the encoder.
num_basic_blocks – Depth of the encoder, number of basic blocks to use.
conv – Class constructor or partial object which, when called, should return a convolutional layer, e.g. nn.Conv2d.
norm – Class constructor or partial object which, when called, should return a normalization layer, e.g. nn.BatchNorm2d.
activation – Class constructor or partial object which, when called, should return the activation function to use after BN layers, e.g. nn.PReLU.
- forward(x: torch.Tensor) → torch.Tensor
Forward pass.
- Parameters
x – Batch of images.
- Returns
Batch of embeddings.
- training: bool
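The conv, norm and activation parameters are factories rather than layer instances. The sketch below writes out the defaults from the signature and shows how such factories behave when called. How the encoder itself calls them is not documented here, so the call arguments used (input and output channels for conv, number of features for norm, no arguments for activation) are assumptions, and wide_conv is a hypothetical replacement.

```python
import functools

from torch import nn

# The defaults shown in the signature, written out explicitly.
conv = functools.partial(nn.Conv2d, kernel_size=(3, 3), padding=1)
norm = nn.BatchNorm2d
activation = nn.PReLU

# The remaining constructor arguments are supplied at call time (assumed order).
block = nn.Sequential(conv(3, 64), norm(64), activation())

# Swapping a factory, e.g. a 5x5 convolution that keeps the spatial size.
wide_conv = functools.partial(nn.Conv2d, kernel_size=(5, 5), padding=2)
first_layer = wide_conv(3, 64)
```

A factory such as wide_conv could then be passed through the conv argument of any class in this section that accepts one.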
- class esrgan.models.SRResNetDecoder(in_channels: int = 64, out_channels: int = 3, scale_factor: int = 2, conv: Callable[[...], torch.nn.modules.module.Module] = functools.partial(<class 'torch.nn.modules.conv.Conv2d'>, kernel_size=(3, 3), padding=1), activation: Callable[[...], torch.nn.modules.module.Module] = <class 'torch.nn.modules.activation.PReLU'>)
‘Decoder’ part of SRResNet, converting embeddings to the output image.
It was proposed in “Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network”.
- Parameters
in_channels – Number of channels in the input embedding.
out_channels – Number of channels in the output image.
scale_factor – Ratio between the size of the high-resolution image (output) and its low-resolution counterpart (input), i.e. the multiplier for the spatial size.
conv – Class constructor or partial object which, when called, should return a convolutional layer, e.g. nn.Conv2d.
activation – Class constructor or partial object which, when called, should return the activation function to use, e.g. nn.ReLU.
- forward(x: torch.Tensor) → torch.Tensor
Forward pass.
- Parameters
x – Batch of embeddings.
- Returns
Batch of upscaled images.
- training: bool
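The effect of scale_factor on tensor shapes can be worked out with simple arithmetic. The helper below, upscaled_shape, is a hypothetical name introduced for illustration; it assumes the usual (batch, channels, height, width) layout and only restates the definition of scale_factor given above.

```python
from typing import Tuple

Shape = Tuple[int, int, int, int]


def upscaled_shape(
    embedding_shape: Shape, out_channels: int = 3, scale_factor: int = 2
) -> Shape:
    """Expected output shape for a batch of embeddings (N, C, H, W)."""
    batch, _, height, width = embedding_shape
    # Both spatial dimensions are multiplied by the same factor.
    return (batch, out_channels, height * scale_factor, width * scale_factor)


# 24x24 embeddings with 64 channels, upscaled four times.
shape = upscaled_shape((8, 64, 24, 24), out_channels=3, scale_factor=4)
print(shape)  # (8, 3, 96, 96)
```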
ESRGAN
- class esrgan.models.ESREncoder(in_channels: int = 3, out_channels: int = 64, growth_channels: int = 32, num_basic_blocks: int = 23, num_dense_blocks: int = 3, num_residual_blocks: int = 5, conv: Callable[[...], torch.nn.modules.module.Module] = functools.partial(<class 'torch.nn.modules.conv.Conv2d'>, kernel_size=(3, 3), padding=1), activation: Callable[[...], torch.nn.modules.module.Module] = functools.partial(<class 'torch.nn.modules.activation.LeakyReLU'>, negative_slope=0.2, inplace=True), residual_scaling: float = 0.2)
‘Encoder’ part of the ESRGAN network, processing images in low-resolution (LR) space.
It was proposed in “ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks”.
- Parameters
in_channels – Number of channels in the input image.
out_channels – Number of channels produced by the encoder.
growth_channels – Number of channels in the latent space.
num_basic_blocks – Depth of the encoder, i.e. the number of Residual-in-Residual Dense Blocks (RRDB) to use.
num_dense_blocks – Number of dense blocks used to form one RRDB.
num_residual_blocks – Number of convolutions used to form one dense block.
conv – Class constructor or partial object which, when called, should return a convolutional layer, e.g. nn.Conv2d.
activation – Class constructor or partial object which, when called, should return the activation function to use, e.g. nn.ReLU.
residual_scaling – Residual connections scaling factor.
- forward(x: torch.Tensor) → torch.Tensor
Forward pass.
- Parameters
x – Batch of images.
- Returns
Batch of embeddings.
- training: bool
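The three depth parameters multiply together, which is easy to underestimate. The sketch below counts the convolutions in the RRDB trunk for the default arguments. The function rrdb_trunk_summary is hypothetical, and two things are assumptions rather than documented behaviour: that each convolution in a dense block receives the block input concatenated with all earlier outputs (the dense connectivity of the ESRGAN paper), and that growth_channels is the number of channels each intermediate convolution adds. Any convolutions outside the RRDB trunk are not counted.

```python
from typing import List, Tuple


def rrdb_trunk_summary(
    out_channels: int = 64,
    growth_channels: int = 32,
    num_basic_blocks: int = 23,
    num_dense_blocks: int = 3,
    num_residual_blocks: int = 5,
) -> Tuple[List[int], int, int]:
    """Return (conv input widths in one dense block, convs per RRDB, total)."""
    # Assumed dense connectivity: conv i sees the block input plus i earlier outputs.
    dense_inputs = [
        out_channels + i * growth_channels for i in range(num_residual_blocks)
    ]
    convs_per_rrdb = num_dense_blocks * num_residual_blocks
    total_convs = num_basic_blocks * convs_per_rrdb
    return dense_inputs, convs_per_rrdb, total_convs


dense_inputs, convs_per_rrdb, total_convs = rrdb_trunk_summary()
print(dense_inputs)    # [64, 96, 128, 160, 192]
print(convs_per_rrdb)  # 15
print(total_convs)     # 345
```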
- class esrgan.models.ESRNetDecoder(in_channels: int = 64, out_channels: int = 3, scale_factor: int = 2, conv: Callable[[...], torch.nn.modules.module.Module] = functools.partial(<class 'torch.nn.modules.conv.Conv2d'>, kernel_size=(3, 3), padding=1), activation: Callable[[...], torch.nn.modules.module.Module] = functools.partial(<class 'torch.nn.modules.activation.LeakyReLU'>, negative_slope=0.2, inplace=True))
‘Decoder’ part of ESRGAN, converting embeddings to the output image.
It was proposed in “ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks”.
- Parameters
in_channels – Number of channels in the input embedding.
out_channels – Number of channels in the output image.
scale_factor – Ratio between the size of the high-resolution image (output) and its low-resolution counterpart (input), i.e. the multiplier for the spatial size.
conv – Class constructor or partial object which, when called, should return a convolutional layer, e.g. nn.Conv2d.
activation – Class constructor or partial object which, when called, should return the activation function to use, e.g. nn.ReLU.
- forward(x: torch.Tensor) → torch.Tensor
Forward pass.
- Parameters
x – Batch of embeddings.
- Returns
Batch of upscaled images.
- training: bool
Discriminators
VGGConv
- class esrgan.models.VGGConv(encoder: torch.nn.modules.module.Module, pool: torch.nn.modules.module.Module, head: torch.nn.modules.module.Module)
VGG-like neural network for image classification.
- Parameters
encoder – Image encoder module, usually used for the extraction of embeddings from input signals.
pool – Pooling layer, used to reduce embeddings from the encoder.
head – Classification head, usually consists of Fully Connected layers.
- forward(x: torch.Tensor) → torch.Tensor
Forward call.
- Parameters
x – Batch of images.
- Returns
Batch of logits.
- training: bool
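The encoder, pool and head stages can be illustrated with a small stand-in built from standard PyTorch layers. This is a sketch under assumptions, not the library's code: the three toy modules are hypothetical, and because it is not stated whether VGGConv flattens the pooled embeddings itself, the flattening is placed inside the head here.

```python
import torch
from torch import nn

# Toy encoder: one strided convolution, halving the spatial size.
encoder = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, stride=2, padding=1),
    nn.LeakyReLU(0.2),
)
# Pooling reduces every feature map to a single value.
pool = nn.AdaptiveAvgPool2d(1)
# Head: flatten (N, 8, 1, 1) to (N, 8), then map to one logit per image.
head = nn.Sequential(nn.Flatten(), nn.Linear(8, 1))

classifier = nn.Sequential(encoder, pool, head)

images = torch.rand(4, 3, 32, 32)
with torch.no_grad():
    logits = classifier(images)
```

With a single output unit, each logit can serve as a real/fake score for one input image.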
StridedConvEncoder
- class esrgan.models.StridedConvEncoder(layers: Iterable[int] = (3, 64, 128, 128, 256, 256, 512, 512), layer_order: Iterable[str] = ('conv', 'norm', 'activation'), conv: Callable[[...], torch.nn.modules.module.Module] = functools.partial(<class 'torch.nn.modules.conv.Conv2d'>, kernel_size=(3, 3), padding=1), norm: Optional[Callable[[...], torch.nn.modules.module.Module]] = <class 'torch.nn.modules.batchnorm.BatchNorm2d'>, activation: Callable[[...], torch.nn.modules.module.Module] = functools.partial(<class 'torch.nn.modules.activation.LeakyReLU'>, negative_slope=0.2, inplace=True), residual: Optional[Callable[[...], torch.nn.modules.module.Module]] = None)
Generalized Fully Convolutional encoder.
- Parameters
layers – Feature map sizes (numbers of channels) of each block.
layer_order – Ordered list of layers applied within each block. For instance, if you don’t want to use a normalization layer, just exclude it from this list.
conv – Class constructor or partial object which, when called, should return a convolutional layer, e.g. nn.Conv2d.
norm – Class constructor or partial object which, when called, should return a normalization layer, e.g. nn.BatchNorm2d.
activation – Class constructor or partial object which, when called, should return the activation function to use, e.g. nn.ReLU.
residual – Class constructor or partial object which, when called, should return a block wrapper module, e.g. esrgan.nn.ResidualModule can be used to add residual connections between blocks.
- forward(x: torch.Tensor) → torch.Tensor
Forward pass.
- Parameters
x – Batch of inputs.
- Returns
Batch of embeddings.
- property in_channels: int
The number of channels in the feature map of the input.
- Returns
Size of the input feature map.
- property out_channels: int
Number of channels produced by the block.
- Returns
Size of the output feature map.
- training: bool
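One plausible reading of the layers argument is that consecutive entries give the input and output channels of each block. The sketch below applies that reading to the default value; block_channels is a hypothetical helper, and both the pairing and the mapping of the in_channels and out_channels properties to the first and last entries are assumptions.

```python
from typing import Iterable, List, Tuple


def block_channels(layers: Iterable[int]) -> List[Tuple[int, int]]:
    """Pair consecutive sizes into (in_channels, out_channels) per block."""
    sizes = list(layers)
    return list(zip(sizes[:-1], sizes[1:]))


layers = (3, 64, 128, 128, 256, 256, 512, 512)
blocks = block_channels(layers)
for in_ch, out_ch in blocks:
    print(f"{in_ch:>3} -> {out_ch}")

# Under this reading the encoder takes 3 channels and produces 512.
in_channels, out_channels = layers[0], layers[-1]
```

Under the same assumption, eight entries describe seven blocks.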
LinearHead
- class esrgan.models.LinearHead(in_channels: int, out_channels: int, latent_channels: Optional[Iterable[int]] = None, layer_order: Iterable[str] = ('linear', 'activation'), linear: Callable[[...], torch.nn.modules.module.Module] = <class 'torch.nn.modules.linear.Linear'>, activation: Callable[[...], torch.nn.modules.module.Module] = functools.partial(<class 'torch.nn.modules.activation.LeakyReLU'>, negative_slope=0.2, inplace=True), norm: Optional[Callable[[...], torch.nn.modules.module.Module]] = None, dropout: Optional[Callable[[...], torch.nn.modules.module.Module]] = None)
Stack of linear layers used for embeddings classification.
- Parameters
in_channels – Size of each input sample.
out_channels – Size of each output sample.
latent_channels – Size of the latent space.
layer_order – Ordered list of layers applied within each block. For instance, if you don’t want to use an activation function, just exclude it from this list.
linear – Class constructor or partial object which, when called, should return a linear layer, e.g. nn.Linear.
activation – Class constructor or partial object which, when called, should return an activation function layer, e.g. nn.ReLU.
norm – Class constructor or partial object which, when called, should return a normalization layer, e.g. nn.BatchNorm1d.
dropout – Class constructor or partial object which, when called, should return a dropout layer, e.g. nn.Dropout.
- forward(x: torch.Tensor) → torch.Tensor
Forward pass.
- Parameters
x – Batch of inputs e.g. images.
- Returns
Batch of logits.
- training: bool
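How latent_channels and layer_order might combine can be sketched without building any real layers. The helper plan_linear_head below is hypothetical and only lists layer descriptions. It assumes that the widths run from in_channels through each latent size to out_channels, and that every block, including the last, applies the full layer_order; the actual class may treat the final block differently.

```python
from typing import Iterable, List, Optional


def plan_linear_head(
    in_channels: int,
    out_channels: int,
    latent_channels: Optional[Iterable[int]] = None,
    layer_order: Iterable[str] = ("linear", "activation"),
) -> List[str]:
    """Describe the layers of a hypothetical head, one string per layer."""
    order = list(layer_order)
    widths = [in_channels, *(latent_channels or []), out_channels]
    plan = []
    for in_features, out_features in zip(widths[:-1], widths[1:]):
        width = in_features  # width seen by layers placed before "linear"
        for name in order:
            if name == "linear":
                plan.append(f"linear({in_features}->{out_features})")
                width = out_features
            else:
                plan.append(f"{name}({width})")
    return plan


with_latent = plan_linear_head(512, 1, latent_channels=[128])
linear_only = plan_linear_head(512, 1, layer_order=("linear",))
print(with_latent)
print(linear_only)
```

Dropping a name from layer_order removes that layer from every block, which matches the description of layer_order given above.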