Source code for esrgan.nn.modules.rrdb
import collections
from typing import Any, Callable, List, Tuple
from torch import nn
from esrgan.nn.modules import container
from esrgan.nn.modules.misc import Conv2d, LeakyReLU
__all__ = ["ResidualDenseBlock", "ResidualInResidualDenseBlock"]
class ResidualDenseBlock(container.ResidualModule):
    """Dense block, the building unit of :py:class:`ResidualInResidualDenseBlock`.

    A stack of ``num_blocks`` "convolution + activation" stages. Stage ``i``
    is constructed for ``num_features + i * growth_channels`` input channels;
    every stage except the last produces ``growth_channels`` channels, while
    the last one produces ``num_features`` channels. The stages are wrapped
    in :py:class:`container.ConcatInputModule` and handed to the parent
    residual module together with ``residual_scaling``.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, H, W)`.
        growth_channels: Number of channels in the latent space, i.e. the
            output width of every stage but the last.
        num_blocks: How many convolutional stages make up the dense block.
        conv: Class constructor or partial object which, when called with
            ``(in_channels, out_channels)``, returns a convolutional layer,
            e.g. :py:class:`nn.Conv2d`.
        activation: Class constructor or partial object which, when called
            without arguments, returns the activation applied after each
            convolution, e.g. :py:class:`nn.LeakyReLU`.
        residual_scaling: Residual connections scaling factor.

    """

    def __init__(
        self,
        num_features: int,
        growth_channels: int,
        num_blocks: int = 5,
        conv: Callable[..., nn.Module] = Conv2d,
        activation: Callable[..., nn.Module] = LeakyReLU,
        residual_scaling: float = 0.2,
    ) -> None:
        last_idx = num_blocks - 1

        stages: List[nn.Module] = []
        for idx in range(num_blocks):
            # NOTE(review): the growing input width suggests that
            # ConcatInputModule feeds each stage the block input concatenated
            # with all previous stage outputs - confirm in `container`.
            stage_in = num_features + idx * growth_channels
            # Only the final stage maps back to the original channel count.
            stage_out = num_features if idx == last_idx else growth_channels

            # Layers are created conv first, activation second.
            conv_layer = conv(stage_in, stage_out)
            act_layer = activation()
            stages.append(
                nn.Sequential(
                    collections.OrderedDict(
                        [("conv", conv_layer), ("act", act_layer)]
                    )
                )
            )

        dense_stack = container.ConcatInputModule(nn.ModuleList(stages))
        super().__init__(module=dense_stack, scale=residual_scaling)
class ResidualInResidualDenseBlock(container.ResidualModule):
    """Residual-in-Residual Dense Block (RRDB).

    Chains ``num_dense_blocks`` :py:class:`ResidualDenseBlock` instances
    (registered as ``block_0``, ``block_1``, ...) into a single
    :py:class:`nn.Sequential` and passes it to the parent residual module.
    The same ``residual_scaling`` value is given both to every inner dense
    block and to the outer residual module.

    See the paper `ESRGAN: Enhanced Super-Resolution Generative
    Adversarial Networks`_ for more details.

    Args:
        num_features: :math:`C` from an expected input of size
            :math:`(N, C, H, W)`.
        growth_channels: Number of channels in the latent space.
        num_dense_blocks: Number of dense blocks that form the `RRDB` block.
        residual_scaling: Residual connections scaling factor.
        **kwargs: Extra keyword arguments forwarded unchanged to every
            :py:class:`ResidualDenseBlock`.

    .. _`ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks`:
        https://arxiv.org/pdf/1809.00219.pdf

    """

    def __init__(
        self,
        num_features: int = 64,
        growth_channels: int = 32,
        num_dense_blocks: int = 3,
        residual_scaling: float = 0.2,
        **kwargs: Any,
    ) -> None:
        # Name -> dense block mapping, built in index order so the
        # sequential container runs block_0, block_1, ... in that order.
        named_blocks = collections.OrderedDict(
            (
                f"block_{idx}",
                ResidualDenseBlock(
                    num_features=num_features,
                    growth_channels=growth_channels,
                    residual_scaling=residual_scaling,
                    **kwargs,
                ),
            )
            for idx in range(num_dense_blocks)
        )

        super().__init__(
            module=nn.Sequential(named_blocks),
            scale=residual_scaling,
        )