import collections
from typing import List, Tuple
import torch
from torch import nn
from esrgan import utils
from esrgan.model.module import blocks
from esrgan.utils.types import ModuleParams
class ESREncoder(nn.Module):
    """'Encoder' part of ESRGAN network, processing images in LR space.

    It has been proposed in `ESRGAN: Enhanced Super-Resolution
    Generative Adversarial Networks`_.

    The network is a flat stack ``conv -> RRDB x num_basic_blocks -> conv``;
    the result of the whole stack is added to the output of the first
    convolution (a long skip connection over the RRDB trunk and tail conv).

    Args:
        in_channels: Number of channels in the input image.
        out_channels: Number of channels produced by the encoder.
        growth_channels: Number of channels in the latent space.
        num_basic_blocks: Depth of the encoder, number of Residual-in-Residual
            Dense block (RRDB) to use.
        num_dense_blocks: Number of dense blocks to use to form `RRDB` block.
        num_residual_blocks: Number of convolutions to use to form dense block.
        conv_fn: Convolutional layers parameters.
        activation_fn: Activation function passed to every RRDB block.
        residual_scaling: Residual connections scaling factor.

    .. _`ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks`:
        https://arxiv.org/pdf/1809.00219.pdf

    """

    @utils.process_fn_params
    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 64,
        growth_channels: int = 32,
        num_basic_blocks: int = 23,
        num_dense_blocks: int = 3,
        num_residual_blocks: int = 5,
        conv_fn: ModuleParams = blocks.Conv2d,
        activation_fn: ModuleParams = blocks.LeakyReLU,
        residual_scaling: float = 0.2,
    ) -> None:
        super().__init__()

        # Modules are created strictly in the order head -> RRDBs -> tail,
        # which keeps seeded weight initialisation reproducible.
        head = conv_fn(in_channels, out_channels)
        trunk = [
            blocks.ResidualInResidualDenseBlock(
                num_features=out_channels,
                growth_channels=growth_channels,
                conv_fn=conv_fn,
                activation_fn=activation_fn,
                num_dense_blocks=num_dense_blocks,
                num_blocks=num_residual_blocks,
                residual_scaling=residual_scaling,
            )
            for _ in range(num_basic_blocks)
        ]
        tail = conv_fn(out_channels, out_channels)

        # One flat ModuleList: index 0 is the head conv, the last entry is the
        # tail conv and everything in between is an RRDB block.
        self.blocks = nn.ModuleList([head, *trunk, tail])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Args:
            x: Batch of images.

        Returns:
            Batch of embeddings: head features plus the output of the
            RRDB trunk and tail conv applied to them.

        """
        head, *rest = self.blocks

        shortcut = head(x)
        features = shortcut
        for layer in rest:
            features = layer(features)

        return shortcut + features
class ESRNetDecoder(nn.Module):
    """'Decoder' part of ESRGAN, converting embeddings to output image.

    It has been proposed in `ESRGAN: Enhanced Super-Resolution
    Generative Adversarial Networks`_.

    The decoder is a chain of ``log2(scale_factor)`` upsampling stages
    followed by a ``conv -> activation -> conv`` head that maps
    ``in_channels`` to ``out_channels``.

    Args:
        in_channels: Number of channels in the input embedding.
        out_channels: Number of channels in the output image.
        scale_factor: Ratio between the size of the high-resolution image
            (output) and its low-resolution counterpart (input).
            In other words multiplier for spatial size. Must be a positive
            power of two (1, 2, 4, 8, ...).
        conv_fn: Convolutional layers parameters.
        activation_fn: Activation function to use.

    Raises:
        NotImplementedError: If ``scale_factor`` is not a positive power of
            two.

    .. _`ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks`:
        https://arxiv.org/pdf/1809.00219.pdf

    """

    @utils.process_fn_params
    def __init__(
        self,
        in_channels: int = 64,
        out_channels: int = 3,
        scale_factor: int = 2,
        conv_fn: ModuleParams = blocks.Conv2d,
        activation_fn: ModuleParams = blocks.LeakyReLU,
    ) -> None:
        super().__init__()

        # check params: reject anything that is NOT a positive power of two.
        # A positive integer is a power of two iff exactly one bit is set,
        # i.e. ``n & (n - 1) == 0``. Checked locally so the guard cannot be
        # accidentally inverted again.
        if scale_factor < 1 or scale_factor & (scale_factor - 1):
            raise NotImplementedError(
                f"scale_factor should be power of 2, got {scale_factor}"
            )

        # scale_factor=2 maps to exactly one InterpolateConv stage, so each
        # stage presumably doubles H and W (verify against
        # blocks.InterpolateConv). A 2**k scale therefore needs k stages;
        # for powers of two ``bit_length() - 1`` is exactly log2. (The old
        # ``scale_factor // 2`` gave the same count for 2 and 4 but
        # over-upsampled from 8 upwards, e.g. 4 stages = x16 for 8.)
        num_upsampling_blocks = scale_factor.bit_length() - 1

        blocks_list: List[Tuple[str, nn.Module]] = []

        # upsampling
        for i in range(num_upsampling_blocks):
            upsampling_block = blocks.InterpolateConv(
                num_features=in_channels,
                conv_fn=conv_fn,
                activation_fn=activation_fn,
            )
            blocks_list.append((f"upsampling_{i}", upsampling_block))

        # highres conv + last conv
        last_conv = nn.Sequential(
            conv_fn(in_channels, in_channels),
            activation_fn(),
            conv_fn(in_channels, out_channels),
        )
        blocks_list.append(("conv", last_conv))

        # Named stages keep state_dict keys stable: ``upsampling_{i}`` / ``conv``.
        self.blocks = nn.Sequential(collections.OrderedDict(blocks_list))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Args:
            x: Batch of embeddings.

        Returns:
            Batch of upscaled images.

        """
        output = self.blocks(x)

        return output
# Public API of this module: the ESRGAN encoder and decoder.
__all__ = ["ESREncoder", "ESRNetDecoder"]