"""Weight-initialization helpers (source code for ``esrgan.utils.init``)."""

import copy
import inspect
import math
from typing import Any, Callable, Optional, Union

import torch
from torch import nn

# Public names exported by ``from esrgan.utils.init import *``.
__all__ = [
    "kaiming_normal_",
    "module_init_",
    "net_init_",
]


def kaiming_normal_(
    tensor: torch.Tensor,
    a: float = 0,
    mode: str = "fan_in",
    nonlinearity: str = "leaky_relu",
) -> None:
    """Fills the input `Tensor` with values according to the method
    described in `Delving deep into rectifiers: Surpassing human-level
    performance on ImageNet classification`_.

    Thin wrapper over :func:`torch.nn.init.kaiming_normal_` that also
    accepts ``nonlinearity='prelu'``. For PReLU the tensor is first drawn
    with the ReLU gain (``sqrt(2)``). It is then divided by
    ``sqrt(1 + a ** 2)``. The resulting standard deviation equals the one
    produced by the leaky-ReLU gain ``sqrt(2 / (1 + a ** 2))`` for slope
    ``a``.

    Args:
        tensor: An n-dimensional tensor (modified in place).
        a: The slope of the rectifier used after this layer (only used
            with ``'leaky_relu'`` and ``'prelu'``).
        mode: Either ``'fan_in'`` or ``'fan_out'``. Choosing ``'fan_in'``
            preserves the magnitude of the variance of the weights in the
            forward pass. Choosing ``'fan_out'`` preserves the magnitudes
            in the backwards pass.
        nonlinearity: The non-linear function (`nn.functional` name), or
            ``'prelu'``.

    .. _`Delving deep into rectifiers: Surpassing human-level
        performance on ImageNet classification`:
        https://arxiv.org/pdf/1502.01852.pdf

    """
    # PyTorch's ``calculate_gain`` does not support "prelu". Initialize
    # with the ReLU gain here; the slope correction is applied below.
    base_act = "relu" if nonlinearity == "prelu" else nonlinearity
    nn.init.kaiming_normal_(tensor, a=a, mode=mode, nonlinearity=base_act)

    if nonlinearity == "prelu":
        # sqrt(2) / sqrt(1 + a^2) == leaky-ReLU gain for slope ``a``.
        with torch.no_grad():
            std_correction = math.sqrt(1 + a ** 2)
            tensor.div_(std_correction)
def module_init_(
    module: nn.Module,
    nonlinearity: Union[str, nn.Module, None] = None,
    **kwargs: Any,
) -> None:
    """Initialize module based on the activation function.

    Weight initialization for ``nn.Linear`` / ``nn.Conv{1,2,3}d`` depends
    on the activation:

    * ``sigmoid`` / ``tanh``: Xavier-uniform.
    * ``relu`` / ``elu`` / ``leaky_relu`` / ``prelu``: Kaiming-normal.
    * Anything else (including ``None``): normal with ``std=0.01``.

    Biases of these layers are zeroed. Batch/group normalization layers
    get weight ``1`` and bias ``0`` when those parameters exist. Any other
    module is left untouched.

    Args:
        module: Module to initialize.
        nonlinearity: Activation function applied after ``module``, given
            either by name (e.g. ``'relu'``) or as an ``nn.Module``
            instance. If LeakyReLU/PReLU and of type `nn.Module`, then
            initialization will be adapted by value of slope (for a
            multi-parameter PReLU the mean slope is used).
        **kwargs: Additional params to pass in init function.

    """
    # get name of activation function
    activation_name: Optional[str] = None
    if isinstance(nonlinearity, str):
        activation_name = nonlinearity.lower()
    elif isinstance(nonlinearity, nn.Module):
        activation_name = nonlinearity.__class__.__name__.lower()

    # extract slope param if possible; only modules carry a slope, and an
    # explicitly passed ``a`` always wins (setdefault keeps it)
    init_kwargs = copy.deepcopy(kwargs)
    if activation_name == "leakyrelu":  # leakyrelu == LeakyReLU.lower()
        activation_name = "leaky_relu"
        if isinstance(nonlinearity, nn.Module):
            init_kwargs.setdefault("a", nonlinearity.negative_slope)
    elif activation_name == "prelu" and isinstance(nonlinearity, nn.Module):
        # PReLU may hold one learnable slope per channel; reduce to a
        # single float so the scalar slope correction is well defined.
        slope = float(nonlinearity.weight.detach().float().mean())
        init_kwargs.setdefault("a", slope)

    # select initialization
    weight_init_fn: Callable
    if activation_name in {"sigmoid", "tanh"}:
        weight_init_fn = nn.init.xavier_uniform_
        init_kwargs = kwargs
    elif activation_name in {"relu", "elu", "leaky_relu", "prelu"}:
        weight_init_fn = kaiming_normal_
        # ``torch.nn.init.calculate_gain`` has no entry for "elu"; use the
        # ReLU gain (ELU coincides with ReLU for positive inputs).
        init_kwargs["nonlinearity"] = (
            "relu" if activation_name == "elu" else activation_name
        )
    else:
        weight_init_fn = nn.init.normal_
        init_kwargs["std"] = kwargs.get("std", 0.01)

    # init weights of the module
    if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d)):
        weight_init_fn(module.weight, **init_kwargs)
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(
        module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.GroupNorm)
    ):
        # norm layers built with ``affine=False`` have weight/bias = None
        if module.weight is not None:
            nn.init.constant_(module.weight, 1)
        if module.bias is not None:
            nn.init.constant_(module.bias, 0)
def net_init_(net: nn.Module) -> None:
    """Initialize every submodule of ``net`` in place.

    Modules are visited in reverse registration order. The most recently
    seen activation module is therefore treated as "the activation that
    follows" each earlier module, and is passed to :func:`module_init_`.
    Modules visited before any activation is seen get ``None``.

    Args:
        net: Network to initialize.

    """
    # Every class defined in ``torch.nn.modules.activation`` counts as an
    # activation; re-exported names from other modules are filtered out.
    activation_types = tuple(
        cls
        for _, cls in inspect.getmembers(nn.modules.activation, inspect.isclass)
        if cls.__module__ == "torch.nn.modules.activation"
    )

    all_modules = list(net.modules())
    following_activation: Optional[nn.Module] = None
    for layer in all_modules[::-1]:
        if isinstance(layer, activation_types):
            following_activation = layer
        module_init_(layer, nonlinearity=following_activation)