"""Miscellaneous layer helpers for ``esrgan.nn.modules`` (misc)."""

import functools
from typing import Callable, Tuple, Union

from torch import nn
from torch.nn.utils.spectral_norm import SpectralNorm

# Public API: convenience partials plus spectrally-normalized layer variants.
__all__ = ["Conv2d", "Conv2dSN", "LeakyReLU", "LinearSN"]


# 3x3 convolution with padding of 1; for stride 1 this keeps spatial size.
Conv2d: Callable[..., nn.Module] = functools.partial(
    nn.Conv2d, padding=1, kernel_size=(3, 3)
)
LeakyReLU: Callable[..., nn.Module] = functools.partial(
    nn.LeakyReLU, negative_slope=0.2, inplace=True
)


class Conv2dSN(nn.Conv2d):
    r""":py:class:`nn.Conv2d` + spectral normalization.

    Applies a 2D convolution over an input signal composed of several input
    planes. In the simplest case, the output value of the layer with input
    size :math:`(N, C_{\text{in}}, H, W)` and output
    :math:`(N, C_{\text{out}}, H_{\text{out}}, W_{\text{out}})`
    can be precisely described as:

    .. math::
        \text{out}(N_i, C_{\text{out}_j}) = \text{bias}(C_{\text{out}_j}) +
        \sum_{k = 0}^{C_{\text{in}} - 1}
        \text{weight}(C_{\text{out}_j}, k) \star \text{input}(N_i, k)

    where :math:`\star` is the valid 2D `cross-correlation`_ operator,
    :math:`N` is a batch size, :math:`C` denotes a number of channels,
    :math:`H` is a height of input planes in pixels, and :math:`W` is
    width in pixels.

    Spectral normalization stabilizes the training of discriminators
    (critics) in Generative Adversarial Networks (GANs) by rescaling the
    weight tensor with spectral norm :math:`\sigma` of the weight matrix
    calculated using power iteration method. See
    `Spectral Normalization for Generative Adversarial Networks`_ for
    details.

    Args:
        in_channels: Number of channels in the input image.
        out_channels: Number of channels produced by the convolution.
        kernel_size: Size of the convolving kernel.
        stride: Stride of the convolution.
        padding: Padding added to both sides of the input.
        dilation: Spacing between kernel elements.
        groups: Number of blocked connections from input channels
            to output channels.
        bias: If ``True``, adds a learnable bias to the output.
        padding_mode: ``'zeros'``, ``'reflect'``, ``'replicate'``
            or ``'circular'``.
        n_power_iterations: Number of power iterations to calculate
            spectral norm.

    .. _cross-correlation:
        https://en.wikipedia.org/wiki/Cross-correlation
    .. _`Spectral Normalization for Generative Adversarial Networks`:
        https://arxiv.org/abs/1802.05957
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Tuple[int, int]] = (3, 3),
        stride: Union[int, Tuple[int, int]] = 1,
        padding: Union[int, Tuple[int, int]] = 0,
        dilation: Union[int, Tuple[int, int]] = 1,
        groups: int = 1,
        bias: bool = True,
        padding_mode: str = "zeros",
        n_power_iterations: int = 1,
    ) -> None:
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
            padding_mode=padding_mode,
        )

        # Attach spectral normalization to this module's own ``weight``
        # parameter in place; ``SpectralNorm.apply`` registers the
        # ``weight_orig`` / ``weight_u`` / ``weight_v`` buffers on ``self``.
        SpectralNorm.apply(
            module=self,
            name="weight",
            n_power_iterations=n_power_iterations,
            dim=0,
            eps=1e-12,
        )
class LinearSN(nn.Linear):
    r""":py:class:`nn.Linear` + spectral normalization.

    Applies a linear transformation to the incoming data:
    :math:`y = xA^T + b`.

    Spectral normalization stabilizes the training of discriminators
    (critics) in Generative Adversarial Networks (GANs) by rescaling the
    weight tensor with spectral norm :math:`\sigma` of the weight matrix
    calculated using power iteration method. See
    `Spectral Normalization for Generative Adversarial Networks`_ for
    details.

    Args:
        in_features: Size of each input sample.
        out_features: Size of each output sample.
        bias: If set to ``False``, the layer will not learn
            an additive bias.
        n_power_iterations: Number of power iterations to calculate
            spectral norm.

    .. _`Spectral Normalization for Generative Adversarial Networks`:
        https://arxiv.org/abs/1802.05957
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        n_power_iterations: int = 1,
    ) -> None:
        super().__init__(
            in_features=in_features,
            out_features=out_features,
            bias=bias,
        )

        # Attach spectral normalization to this module's own ``weight``
        # parameter in place, mirroring :class:`Conv2dSN` above.
        SpectralNorm.apply(
            module=self,
            name="weight",
            n_power_iterations=n_power_iterations,
            dim=0,
            eps=1e-12,
        )