Source code for esrgan.model.module.conv

import collections
from typing import Callable, Dict, Iterable, List, Optional, Tuple

import torch
from torch import nn

from esrgan import utils
from esrgan.model.module import blocks
from esrgan.utils.types import ModuleParams


class StridedConvEncoder(nn.Module):
    """Generalized Fully Convolutional encoder.

    The network is an ``nn.Sequential`` of blocks named ``block_0``,
    ``block_1``, ... ``block_0`` is always ``conv -> activation`` (it ignores
    ``layer_order`` and has no normalization) and maps ``layers[0]`` to
    ``layers[1]`` channels. Each following block is built from one
    ``(in_ch, out_ch)`` pair yielded by ``utils.pairwise(layers[1:])`` and
    contains the layers listed in ``layer_order``; its convolution is created
    with ``stride=out_ch // in_ch``.

    Args:
        layers: List of feature maps sizes of each block. At least two
            entries are required because ``layers[0]`` and ``layers[1]``
            are read to build the first block.
        layer_order: Ordered list of layers applied within each block.
            For instance, if you don't want to use normalization layer
            just exclude it from this list. Each entry must be one of
            ``"conv"``, ``"norm"`` or ``"activation"``. Any iterable is
            accepted, including one-shot iterators.
        conv_fn: Convolutional layer params.
        activation_fn: Activation function to use.
        norm_fn: Normalization layer params,
            e.g. :py:class:`.nn.BatchNorm2d`.
        residual_fn: Block wrapper function, e.g.
            :py:class:`~.blocks.container.ResidualModule` can be used
            to add residual connections between blocks. It is applied only
            to blocks whose input and output channel counts are equal.

    """

    # NOTE(review): `process_fn_params` is defined elsewhere; presumably it
    # turns the `*_fn` arguments into callables before the body runs - confirm.
    @utils.process_fn_params
    def __init__(
        self,
        layers: Iterable[int] = (3, 64, 128, 128, 256, 256, 512, 512),
        layer_order: Iterable[str] = ("conv", "norm", "activation"),
        conv_fn: ModuleParams = blocks.Conv2d,
        activation_fn: ModuleParams = blocks.LeakyReLU,
        norm_fn: Optional[ModuleParams] = nn.BatchNorm2d,
        residual_fn: Optional[ModuleParams] = None,
    ) -> None:
        super().__init__()

        name2fn: Dict[str, Callable[..., nn.Module]] = {
            "activation": activation_fn,
            "conv": conv_fn,
            "norm": norm_fn,
        }

        self._layers = list(layers)
        # `layer_order` is walked once per block below. Materialize it so a
        # one-shot iterator (generator, `iter(...)`, `map(...)`) is not
        # exhausted after the first block, which would leave every later
        # block empty.
        layer_order = tuple(layer_order)

        net: List[Tuple[str, nn.Module]] = []

        # The first block is fixed: convolution followed by activation.
        first_conv = collections.OrderedDict([
            ("conv_0", name2fn["conv"](self._layers[0], self._layers[1])),
            ("act", name2fn["activation"]()),
        ])
        net.append(("block_0", nn.Sequential(first_conv)))

        channels = utils.pairwise(self._layers[1:])
        for i, (in_ch, out_ch) in enumerate(channels, start=1):
            block_list: List[Tuple[str, nn.Module]] = []
            for name in layer_order:
                # `conv + 2x2 pooling` is equal to `conv with stride=2`
                # NOTE(review): integer division yields stride 0 when
                # out_ch < in_ch - confirm shrinking layers are unsupported.
                kwargs = {"stride": out_ch // in_ch} if name == "conv" else {}

                module = utils.create_layer(
                    layer_name=name,
                    layer=name2fn[name],
                    in_channels=in_ch,
                    out_channels=out_ch,
                    **kwargs
                )
                block_list.append((name, module))
            block = nn.Sequential(collections.OrderedDict(block_list))

            # add residual connection, like in resnet blocks
            if residual_fn is not None and in_ch == out_ch:
                block = residual_fn(block)

            net.append((f"block_{i}", block))

        self.net = nn.Sequential(collections.OrderedDict(net))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Args:
            x: Input tensor, passed through all blocks of ``self.net``
                in order.

        Returns:
            Output of the last block.

        """
        return self.net(x)

    @property
    def in_channels(self) -> int:
        """The number of channels in the feature map of the input."""
        return self._layers[0]

    @property
    def out_channels(self) -> int:
        """Number of channels produced by the block."""
        return self._layers[-1]
# Public API of this module: only the encoder class is listed for export.
__all__ = ["StridedConvEncoder"]