Source code for esrgan.model.module.blocks.upsampling

import torch
from torch import nn
from torch.nn import functional as F

from esrgan import utils
from esrgan.model.module.blocks.misc import Conv2d, LeakyReLU
from esrgan.utils.types import ModuleParams


class SubPixelConv(nn.Module):
    r"""Rearranges elements in a tensor of shape :math:`(B, C \times r^2, H, W)`
    to a tensor of shape :math:`(B, C, H \times r, W \times r)`.

    Look at the paper:
    `Real-Time Single Image and Video Super-Resolution Using an Efficient
    Sub-Pixel Convolutional Neural Network`_ for more details.

    The block is ``conv -> PixelShuffle(scale_factor) -> activation``. The
    convolution expands ``num_features`` channels to
    ``num_features * scale_factor ** 2`` so that, after the pixel shuffle,
    the output again has ``num_features`` channels.

    Args:
        num_features: Number of channels in the input tensor.
        scale_factor: Factor to increase spatial resolution by.
        conv_fn: Convolution layer params.
        activation_fn: Activation function to use after sub-pixel convolution.

    .. _`Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network`:
        https://arxiv.org/pdf/1609.05158.pdf

    """

    @utils.process_fn_params
    def __init__(
        self,
        num_features: int,
        scale_factor: int = 2,
        conv_fn: ModuleParams = Conv2d,
        activation_fn: ModuleParams = nn.PReLU,
    ) -> None:
        super().__init__()

        # PixelShuffle(r) folds every group of ``r ** 2`` channels into one
        # channel of an r-times larger image, so the convolution must emit
        # ``num_features * r ** 2`` channels for any scale factor (this equals
        # ``num_features * 4`` for the default ``scale_factor=2``).
        self.block = nn.Sequential(
            conv_fn(num_features, num_features * scale_factor ** 2),
            nn.PixelShuffle(upscale_factor=scale_factor),
            activation_fn(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Apply conv -> shuffle pixels -> apply nonlinearity.

        Args:
            x: Batch of inputs.

        Returns:
            Upscaled input.

        """
        output = self.block(x)

        return output
class InterpolateConv(nn.Module):
    """Upsample multi-channel 2D (spatial) data and refine it with a convolution.

    The input is first enlarged with nearest-neighbour interpolation by
    ``scale_factor``. It is then passed through a convolution created as
    ``conv_fn(num_features, num_features)``, followed by ``activation_fn``.

    Args:
        num_features: Number of channels in the input tensor.
        scale_factor: Factor to increase spatial resolution by.
        conv_fn: Convolutional layer params.
        activation_fn: Activation function to use after convolution.

    """

    @utils.process_fn_params
    def __init__(
        self,
        num_features: int,
        scale_factor: int = 2,
        conv_fn: ModuleParams = Conv2d,
        activation_fn: ModuleParams = LeakyReLU,
    ) -> None:
        super().__init__()

        self.scale_factor = scale_factor

        # Layers are built in the same order (conv first, then activation),
        # and stored under ``self.block`` so state_dict keys stay stable.
        layers = [conv_fn(num_features, num_features), activation_fn()]
        self.block = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Upscale input (nearest neighbour) -> apply conv -> apply nonlinearity.

        Args:
            x: Batch of inputs.

        Returns:
            Upscaled data.

        """
        upscaled = F.interpolate(x, scale_factor=self.scale_factor, mode="nearest")
        return self.block(upscaled)
# Public API of this module: the two upsampling blocks.
__all__ = [
    "SubPixelConv",
    "InterpolateConv",
]