Source code for esrgan.model.module.esrnet

import collections
from typing import List, Tuple

import torch
from torch import nn

from esrgan import utils
from esrgan.model.module import blocks
from esrgan.utils.types import ModuleParams


class ESREncoder(nn.Module):
    """'Encoder' part of ESRGAN network, processing images in LR space.

    It has been proposed in `ESRGAN: Enhanced Super-Resolution
    Generative Adversarial Networks`_.

    Layout of ``self.blocks``:

    * index ``0``: head convolution ``in_channels -> out_channels``;
    * indices ``1 .. num_basic_blocks``: Residual-in-Residual Dense Blocks;
    * last index: tail convolution ``out_channels -> out_channels``.

    In ``forward``, the output of the head convolution is added back
    (long skip connection) to the output of the tail convolution.

    Args:
        in_channels: Number of channels in the input image.
        out_channels: Number of channels produced by the encoder.
        growth_channels: Number of channels in the latent space.
        num_basic_blocks: Depth of the encoder, number of
            Residual-in-Residual Dense block (RRDB) to use.
        num_dense_blocks: Number of dense blocks to use to form `RRDB` block.
        num_residual_blocks: Number of convolutions to use to form dense block.
        conv_fn: Convolutional layers parameters.
        activation_fn: Activation function passed to every RRDB block.
        residual_scaling: Residual connections scaling factor.

    .. _`ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks`:
        https://arxiv.org/pdf/1809.00219.pdf

    """

    @utils.process_fn_params
    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 64,
        growth_channels: int = 32,
        num_basic_blocks: int = 23,
        num_dense_blocks: int = 3,
        num_residual_blocks: int = 5,
        conv_fn: ModuleParams = blocks.Conv2d,
        activation_fn: ModuleParams = blocks.LeakyReLU,
        residual_scaling: float = 0.2,
    ) -> None:
        super().__init__()

        # Modules are created head -> trunk -> tail so that parameter
        # initialisation consumes the RNG in a fixed order.
        head = conv_fn(in_channels, out_channels)
        trunk: List[nn.Module] = [
            blocks.ResidualInResidualDenseBlock(
                num_features=out_channels,
                growth_channels=growth_channels,
                conv_fn=conv_fn,
                activation_fn=activation_fn,
                num_dense_blocks=num_dense_blocks,
                num_blocks=num_residual_blocks,
                residual_scaling=residual_scaling,
            )
            for _ in range(num_basic_blocks)
        ]
        tail = conv_fn(out_channels, out_channels)

        # Position in this list defines both the role used by `forward`
        # (index 0 feeds the skip connection) and the state-dict keys.
        self.blocks = nn.ModuleList([head, *trunk, tail])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Args:
            x: Batch of images.

        Returns:
            Batch of embeddings: head-convolution features plus the result
            of running them through the RRDB trunk and the tail convolution.

        """
        skip = self.blocks[0](x)

        features = skip
        for layer in self.blocks[1:]:
            features = layer(features)

        return skip + features
class ESRNetDecoder(nn.Module):
    """'Decoder' part of ESRGAN, converting embeddings to output image.

    It has been proposed in `ESRGAN: Enhanced Super-Resolution
    Generative Adversarial Networks`_.

    The decoder is a stack of ``log2(scale_factor)`` upsampling blocks
    (named ``upsampling_0``, ``upsampling_1``, ...) followed by a
    ``conv -> activation -> conv`` head (named ``conv``) that maps
    ``in_channels`` feature maps to ``out_channels`` image channels.

    Args:
        in_channels: Number of channels in the input embedding.
        out_channels: Number of channels in the output image.
        scale_factor: Ratio between the size of the high-resolution image
            (output) and its low-resolution counterpart (input).
            In other words multiplier for spatial size. Must be a positive
            power of 2.
        conv_fn: Convolutional layers parameters.
        activation_fn: Activation function to use.

    Raises:
        NotImplementedError: If ``scale_factor`` is not a positive power of 2.

    .. _`ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks`:
        https://arxiv.org/pdf/1809.00219.pdf

    """

    @utils.process_fn_params
    def __init__(
        self,
        in_channels: int = 64,
        out_channels: int = 3,
        scale_factor: int = 2,
        conv_fn: ModuleParams = blocks.Conv2d,
        activation_fn: ModuleParams = blocks.LeakyReLU,
    ) -> None:
        super().__init__()

        # check params: for n >= 1, `n & (n - 1) == 0` holds exactly when n
        # is a power of two; anything else is rejected.
        if scale_factor < 1 or scale_factor & (scale_factor - 1):
            raise NotImplementedError(
                f"scale_factor should be power of 2, got {scale_factor}"
            )

        # For a power of two, bit_length() - 1 == log2(scale_factor):
        # 1 -> 0, 2 -> 1, 4 -> 2, 8 -> 3, ...
        # NOTE(review): assumes each InterpolateConv doubles the spatial
        # size (implied by the power-of-two requirement) — confirm.
        num_upsampling_blocks = scale_factor.bit_length() - 1

        blocks_list: List[Tuple[str, nn.Module]] = []

        # upsampling
        for i in range(num_upsampling_blocks):
            upsampling_block = blocks.InterpolateConv(
                num_features=in_channels,
                conv_fn=conv_fn,
                activation_fn=activation_fn,
            )
            blocks_list.append((f"upsampling_{i}", upsampling_block))

        # highres conv + last conv
        last_conv = nn.Sequential(
            conv_fn(in_channels, in_channels),
            activation_fn(),
            conv_fn(in_channels, out_channels),
        )
        blocks_list.append(("conv", last_conv))

        # Names given here become the state-dict keys of the decoder.
        self.blocks = nn.Sequential(collections.OrderedDict(blocks_list))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Args:
            x: Batch of embeddings.

        Returns:
            Batch of upscaled images.

        """
        output = self.blocks(x)

        return output
# Public API of this module: the ESRNet encoder/decoder pair.
__all__ = [
    "ESREncoder",
    "ESRNetDecoder",
]