Source code for esrgan.model.discriminator

import copy
from typing import Optional

from catalyst.registry import MODULE
import torch
from torch import nn

from esrgan import utils


class VGGConv(nn.Module):
    """VGG-like neural network for image classification.

    The network is a plain three-stage pipeline: ``encoder`` -> ``pool`` ->
    flatten to ``(batch, features)`` -> ``head``.

    Args:
        encoder: Image encoder module, usually used for the extraction
            of embeddings from input signals.
        pool: Pooling layer, used to reduce embeddings from the encoder.
        head: Classification head, usually consists of Fully Connected
            layers.

    """

    def __init__(
        self,
        encoder: nn.Module,
        pool: nn.Module,
        head: nn.Module,
    ) -> None:
        super().__init__()

        self.encoder = encoder
        self.pool = pool
        self.head = head

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward call.

        Args:
            x: Batch of images.

        Returns:
            Batch of logits.

        """
        x = self.pool(self.encoder(x))
        # ``reshape`` rather than ``view``: ``view`` raises on non-contiguous
        # tensors (e.g. channels_last inputs or a transposing sub-module),
        # while ``reshape`` returns a view when it can and copies otherwise.
        x = x.reshape(x.shape[0], -1)
        x = self.head(x)

        return x

    @staticmethod
    def _build_module(params: Optional[dict]) -> nn.Module:
        """Instantiate one sub-module from its config, or ``nn.Identity``.

        Args:
            params: Module config; its ``"module"`` key is looked up in the
                ``MODULE`` registry and the remaining items are passed to the
                resulting factory as keyword arguments. ``None`` means
                "no module".

        Returns:
            The constructed module, or ``nn.Identity()`` if ``params``
            is ``None``.

        Raises:
            KeyError: If ``params`` is given but has no ``"module"`` key.

        """
        if params is None:
            return nn.Identity()

        # Work on a deep copy so that popping "module" never mutates the
        # caller's config dict.
        params_ = copy.deepcopy(params)
        module_fn = MODULE.get(params_.pop("module"))
        return module_fn(**params_)

    @classmethod
    def get_from_params(
        cls,
        encoder_params: Optional[dict] = None,
        pooling_params: Optional[dict] = None,
        head_params: Optional[dict] = None,
    ) -> "VGGConv":
        """Create model based on its config.

        Each ``*_params`` dict must contain a ``"module"`` key naming an entry
        of the ``MODULE`` registry; the remaining items are forwarded to that
        entry as keyword arguments. A part whose params are ``None`` is
        replaced with ``nn.Identity``. The input dicts are not modified.

        Args:
            encoder_params: Params of encoder module.
            pooling_params: Params of the pooling layer.
            head_params: 'Head' module params.

        Returns:
            Model.

        """
        encoder = cls._build_module(encoder_params)
        pool = cls._build_module(pooling_params)
        head = cls._build_module(head_params)

        net = cls(encoder=encoder, pool=pool, head=head)
        # NOTE(review): presumably initializes the weights in place (trailing
        # underscore convention) - confirm against ``esrgan.utils``.
        utils.net_init_(net)

        return net
# Public API of this module: only the discriminator class is exported by
# ``from <module> import *``.
__all__ = ["VGGConv"]