# Source code for esrgan.nn.modules.container
from typing import Iterable
import torch
from torch import nn
__all__ = ["ConcatInputModule", "ResidualModule"]
class ResidualModule(nn.Module):
    """Residual wrapper, adds identity connection.

    It has been proposed in `Deep Residual Learning for Image Recognition`_.
    The wrapped layer's output is scaled and added to its input, i.e. the
    module computes ``x + scale * module(x)``.

    Args:
        module: PyTorch layer to wrap. Its output must be addable to its
            input (same or broadcast-compatible shape).
        scale: Residual connections scaling factor. Integer values are
            accepted and converted to a floating-point value.
        requires_grad: If set to ``False``, the layer will not learn
            the strength of the residual connection.

    .. _`Deep Residual Learning for Image Recognition`:
        https://arxiv.org/pdf/1512.03385.pdf
    """

    def __init__(
        self,
        module: nn.Module,
        scale: float = 1.0,
        requires_grad: bool = False,
    ) -> None:
        super().__init__()
        self.module = module
        # Coerce to a Python float so the parameter always has a floating
        # point dtype: ``torch.tensor(1)`` would be int64, which cannot
        # require gradients (``requires_grad=True`` would raise) and is
        # skipped by float dtype casts such as ``module.half()``.
        # Kept as an ``nn.Parameter`` even when frozen so it is registered
        # on the module (saved in ``state_dict``, moved by ``.to()``).
        self.scale = nn.Parameter(
            torch.tensor(float(scale)), requires_grad=requires_grad
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass.

        Args:
            x: Batch of inputs.

        Returns:
            Processed batch, ``x + scale * module(x)``.
        """
        return x + self.scale * self.module(x)