Source code for esrgan.nn.criterions.perceptual

from typing import Callable, Dict, Iterable, Union

import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.modules.loss import _Loss
import torchvision

__all__ = ["PerceptualLoss"]


def _layer2index_vgg16(layer: str) -> int:
    """Map name of VGG layer to corresponding number in torchvision layer.

    Args:
        layer: name of the layer e.g. ``'conv1_1'``

    Returns:
        Number of layer (in network) with name `layer`.

    Examples:
        >>> _layer2index_vgg16('conv1_1')
        0
        >>> _layer2index_vgg16('pool5')
        30

    """
    block1 = ("conv1_1", "relu1_1", "conv1_2", "relu1_2", "pool1")
    block2 = ("conv2_1", "relu2_1", "conv2_2", "relu2_2", "pool2")
    block3 = ("conv3_1", "relu3_1", "conv3_2", "relu3_2", "conv3_3", "relu3_3", "pool3")  # noqa: E501
    block4 = ("conv4_1", "relu4_1", "conv4_2", "relu4_2", "conv4_3", "relu4_3", "pool4")  # noqa: E501
    block5 = ("conv5_1", "relu5_1", "conv5_2", "relu5_2", "conv5_3", "relu5_3", "pool5")  # noqa: E501
    layers_order = block1 + block2 + block3 + block4 + block5
    vgg16_layers = {n: idx for idx, n in enumerate(layers_order)}

    return vgg16_layers[layer]


def _layer2index_vgg19(layer: str) -> int:
    """Map name of VGG layer to corresponding number in torchvision layer.

    Args:
        layer: name of the layer e.g. ``'conv1_1'``

    Returns:
        Number of layer (in network) with name `layer`.

    Examples:
        >>> _layer2index_vgg16('conv1_1')
        0
        >>> _layer2index_vgg16('pool5')
        36

    """
    block1 = ("conv1_1", "relu1_1", "conv1_2", "relu1_2", "pool1")
    block2 = ("conv2_1", "relu2_1", "conv2_2", "relu2_2", "pool2")
    block3 = ("conv3_1", "relu3_1", "conv3_2", "relu3_2", "conv3_3", "relu3_3", "conv3_4", "relu3_4", "pool3")  # noqa: E501
    block4 = ("conv4_1", "relu4_1", "conv4_2", "relu4_2", "conv4_3", "relu4_3", "conv4_4", "relu4_4", "pool4")  # noqa: E501
    block5 = ("conv5_1", "relu5_1", "conv5_2", "relu5_2", "conv5_3", "relu5_3", "conv5_4", "relu5_4", "pool5")  # noqa: E501
    layers_order = block1 + block2 + block3 + block4 + block5
    vgg19_layers = {n: idx for idx, n in enumerate(layers_order)}

    return vgg19_layers[layer]
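
# Illustrative check (assumption: torchvision's standard VGG layout): the
# index returned by the helpers above addresses the ``features`` module of
# the corresponding torchvision model directly, e.g.:
#
#   >>> torchvision.models.vgg19().features[_layer2index_vgg19("conv1_1")]
#   Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))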


class PerceptualLoss(_Loss):
    """The Perceptual Loss.

    Calculates loss between features of `model` (usually VGG is used) for
    input (produced by generator) and target (real) images.

    Args:
        layers: Dict of layer names and weights (to balance different layers).
        model: Model to use to extract features.
        distance: Method to compute distance between features.
        mean: List of float values used for data standardization.
            If there is no need to normalize data, use ``[0., 0., 0.]``.
        std: List of float values used for data standardization.
            If there is no need to normalize data, use ``[1., 1., 1.]``.

    Raises:
        NotImplementedError: If `distance` is a string other than ``'l1'``,
            ``'cityblock'``, ``'l2'``, or ``'euclidean'``.

    """

    def __init__(
        self,
        layers: Dict[str, float],
        model: str = "vgg19",
        distance: Union[str, Callable] = "l1",
        mean: Iterable[float] = (0.485, 0.456, 0.406),
        std: Iterable[float] = (0.229, 0.224, 0.225),
    ) -> None:
        super().__init__()

        model_fn = torchvision.models.__dict__[model]
        layer2idx = globals()[f"_layer2index_{model}"]

        # normalize layer weights so that they sum to 1
        w_sum = sum(layers.values())
        self.layers = {str(layer2idx(k)): w / w_sum for k, w in layers.items()}

        # keep the feature extractor only up to the deepest requested layer
        last_layer = max(map(layer2idx, layers))
        backbone = model_fn(pretrained=True)
        network = nn.Sequential(*list(backbone.features)[:last_layer + 1]).eval()
        for param in network.parameters():
            param.requires_grad = False
        self.model = network

        if callable(distance):
            self.distance = distance
        elif distance.lower() in {"l1", "cityblock"}:
            self.distance = F.l1_loss
        elif distance.lower() in {"l2", "euclidean"}:
            self.distance = F.mse_loss
        else:
            raise NotImplementedError()

        self.mean = torch.tensor(mean).view(1, -1, 1, 1)
        self.std = torch.tensor(std).view(1, -1, 1, 1)
    def forward(
        self, fake_data: torch.Tensor, real_data: torch.Tensor
    ) -> torch.Tensor:
        """Forward propagation method for the perceptual loss.

        Args:
            fake_data: Batch of input (fake, generated) images.
            real_data: Batch of target (real, ground truth) images.

        Returns:
            Loss, scalar.

        """
        fake_features = self._get_features(fake_data)
        real_features = self._get_features(real_data)

        # calculate weighted sum of distances between real and fake features;
        # use the configured distance, not a hard-coded L1
        loss = torch.tensor(0.0, requires_grad=True).to(fake_data)
        for layer, weight in self.layers.items():
            layer_loss = self.distance(fake_features[layer], real_features[layer])
            loss = loss + weight * layer_loss

        return loss
    def _get_features(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        # normalize input tensor
        x = (x - self.mean.to(x)) / self.std.to(x)

        # extract net features, collecting outputs of the requested layers
        features: Dict[str, torch.Tensor] = {}
        for name, module in self.model._modules.items():
            x = module(x)
            if name in self.layers:
                features[name] = x

        return features
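
# Minimal usage sketch (illustration only, not part of the module): builds the
# perceptual loss on VGG19 features and evaluates it on random tensors. The
# chosen layers and weights are assumptions made for this example, not values
# mandated by the module; real usage would pass generator output and
# ground-truth images instead of random data.
if __name__ == "__main__":
    criterion = PerceptualLoss(layers={"conv1_2": 0.2, "conv5_4": 0.8})

    fake = torch.rand(2, 3, 96, 96)  # stand-in for generator output in [0, 1]
    real = torch.rand(2, 3, 96, 96)  # stand-in for ground-truth images

    loss = criterion(fake, real)
    print(loss.item())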