Source code for esrgan.models.generator
import torch
from torch import nn
from esrgan import utils
__all__ = ["EncoderDecoderNet"]
[docs]class EncoderDecoderNet(nn.Module):
"""Generalized Encoder-Decoder network.
Args:
encoder: Encoder module, usually used for the extraction
of embeddings from input signals.
decoder: Decoder module, usually used for embeddings processing
e.g. generation of signal similar to the input one (in GANs).
"""
def __init__(self, encoder: nn.Module, decoder: nn.Module) -> None:
super().__init__()
self.encoder = encoder
self.decoder = decoder
utils.net_init_(self)
[docs] def forward(self, x: torch.Tensor) -> torch.Tensor:
"""Forward pass method.
Args:
x: Batch of input signals e.g. images.
Returns:
Batch of generated signals e.g. images.
"""
x = self.encoder(x)
x = self.decoder(x)
x = torch.clamp(x, min=0.0, max=1.0)
return x