Utilities
A set of utilities that make common tasks a little easier:
Augmentation
Model init
- esrgan.utils.init.kaiming_normal_(tensor: torch.Tensor, a: float = 0, mode: str = 'fan_in', nonlinearity: str = 'leaky_relu') → None
Fills the input tensor with values drawn from a normal distribution, following the method described in "Delving Deep into Rectifiers: Surpassing Human-Level Performance on ImageNet Classification" (He et al., 2015).
- Parameters
tensor – An n-dimensional tensor.
a – The negative slope of the rectifier used after this layer (only used with 'leaky_relu' and 'prelu').
mode – Either 'fan_in' or 'fan_out'. Choosing 'fan_in' preserves the magnitude of the variance of the weights in the forward pass; choosing 'fan_out' preserves the magnitudes in the backward pass.
nonlinearity – The non-linear function (nn.functional name).
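As a rough illustration of what these parameters control, the sketch below computes the standard deviation used by Kaiming normal initialization, σ = gain / √fan, where fan is fan_in or fan_out depending on mode, and then draws samples from N(0, σ²). It is a pure-Python sketch, not esrgan's implementation: the helper names (compute_fans, compute_gain, kaiming_std, kaiming_normal_sample) are hypothetical, the fan computation and gain values follow the conventions of PyTorch's torch.nn.init, and treating 'prelu' like 'leaky_relu' is an assumption based on the description of a above.

```python
import math
import random
from typing import List, Optional, Sequence, Tuple


def compute_fans(shape: Sequence[int]) -> Tuple[int, int]:
    """Return (fan_in, fan_out) for a weight of shape (out, in, *kernel_dims)."""
    if len(shape) < 2:
        raise ValueError("fan_in/fan_out need a tensor with at least 2 dimensions")
    receptive_field = math.prod(shape[2:])  # 1 for linear layers
    return shape[1] * receptive_field, shape[0] * receptive_field


def compute_gain(nonlinearity: str, a: float = 0.0) -> float:
    """Gain values as in torch.nn.init.calculate_gain ('prelu' handling is an assumption)."""
    if nonlinearity in ("leaky_relu", "prelu"):
        return math.sqrt(2.0 / (1.0 + a ** 2))
    if nonlinearity == "relu":
        return math.sqrt(2.0)
    if nonlinearity == "tanh":
        return 5.0 / 3.0
    if nonlinearity in ("linear", "sigmoid", "conv1d", "conv2d", "conv3d"):
        return 1.0
    raise ValueError(f"unsupported nonlinearity: {nonlinearity!r}")


def kaiming_std(
    shape: Sequence[int],
    a: float = 0.0,
    mode: str = "fan_in",
    nonlinearity: str = "leaky_relu",
) -> float:
    """Standard deviation sigma = gain / sqrt(fan) of Kaiming normal init."""
    if mode not in ("fan_in", "fan_out"):
        raise ValueError(f"mode must be 'fan_in' or 'fan_out', got {mode!r}")
    fan_in, fan_out = compute_fans(shape)
    fan = fan_in if mode == "fan_in" else fan_out
    return compute_gain(nonlinearity, a) / math.sqrt(fan)


def kaiming_normal_sample(
    shape: Sequence[int],
    a: float = 0.0,
    mode: str = "fan_in",
    nonlinearity: str = "leaky_relu",
    seed: Optional[int] = None,
) -> List[float]:
    """Draw a flat list of math.prod(shape) values from N(0, sigma^2)."""
    std = kaiming_std(shape, a, mode, nonlinearity)
    rng = random.Random(seed)
    return [rng.gauss(0.0, std) for _ in range(math.prod(shape))]
```

For a 3×3 convolution with 3 input and 64 output channels, kaiming_std((64, 3, 3, 3)) gives √(2/27) ≈ 0.272 with the defaults (a = 0 makes the leaky_relu gain √2), while mode='fan_out' uses 64·9 = 576 and gives √2/24 ≈ 0.059.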
- esrgan.utils.init.module_init_(module: torch.nn.modules.module.Module, nonlinearity: Optional[Union[str, torch.nn.modules.module.Module]] = None, **kwargs: Any) → None
Initialize module based on the activation function.
- Parameters
module – Module to initialize.
nonlinearity – Activation function, given either as a name or as an nn.Module. If it is an nn.LeakyReLU or nn.PReLU instance, the initialization is adapted to its slope.
**kwargs – Additional parameters passed to the init function.
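The description of nonlinearity suggests that module_init_ first turns its argument into a nonlinearity name plus a slope and then initializes the weights accordingly. The pure-Python sketch below shows only that resolution step, under stated assumptions: resolve_nonlinearity is hypothetical and not part of esrgan, the two classes are minimal stand-ins that mimic the negative_slope attribute of torch.nn.LeakyReLU and the init value of torch.nn.PReLU, and the fallbacks for None and for other modules are guesses. The resulting pair could then be passed as the nonlinearity and a arguments of kaiming_normal_.

```python
from typing import Optional, Tuple, Union


class LeakyReLU:
    """Stand-in for torch.nn.LeakyReLU: only keeps the negative_slope attribute."""

    def __init__(self, negative_slope: float = 0.01) -> None:
        self.negative_slope = negative_slope


class PReLU:
    """Stand-in for torch.nn.PReLU: keeps the initial slope passed as `init`."""

    def __init__(self, init: float = 0.25) -> None:
        self.init = init


def resolve_nonlinearity(nonlinearity: Optional[Union[str, object]]) -> Tuple[str, float]:
    """Map an activation (name or module) to a (nonlinearity name, slope) pair."""
    if nonlinearity is None:
        # Assumption: no activation means a linear layer, so the gain is 1.
        return "linear", 0.0
    if isinstance(nonlinearity, str):
        # Names such as "relu" or "leaky_relu" are used as-is; slope stays at kaiming_normal_'s default a = 0.
        return nonlinearity.lower(), 0.0
    if isinstance(nonlinearity, LeakyReLU):
        # Adapt the initialization to the module's negative slope.
        return "leaky_relu", float(nonlinearity.negative_slope)
    if isinstance(nonlinearity, PReLU):
        # PReLU learns its slope; at initialization time it equals `init`.
        return "leaky_relu", float(nonlinearity.init)
    # Assumption: other modules map to their lower-cased class name, e.g. ReLU -> "relu".
    return type(nonlinearity).__name__.lower(), 0.0
```

For example, a LeakyReLU stand-in with negative_slope=0.2 resolves to ("leaky_relu", 0.2), so the weights would be drawn with gain √(2 / 1.04) instead of √2.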
Model params
- esrgan.utils.module_params.create_layer(layer: Callable[..., torch.nn.modules.module.Module], in_channels: Optional[int] = None, out_channels: Optional[int] = None, layer_name: Optional[str] = None, **kwargs: Any) → torch.nn.modules.module.Module
Helper function to generalize layer creation.
- Parameters
layer – Layer constructor.
in_channels – Size of the input sample.
out_channels – Size of the output, e.g. the number of channels produced by a convolution.
layer_name – Name of the layer, e.g. 'activation'.
**kwargs – Additional parameters passed to the layer constructor.
- Returns
Layer.
Examples
>>> from torch import nn
>>> from esrgan.utils.module_params import create_layer
>>> in_channels, out_channels = 10, 5
>>> create_layer(nn.Linear, in_channels, out_channels)
Linear(in_features=10, out_features=5, bias=True)
>>> create_layer(nn.ReLU, in_channels, out_channels, layer_name='act')
ReLU()
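The second example implies that create_layer drops in_channels and out_channels when the constructor does not take them, since nn.ReLU accepts neither. One way to get that behaviour is to inspect the constructor's signature, as in the pure-Python sketch below. This is an assumption, not esrgan's code: create_layer_sketch and the Linear/ReLU stand-ins (which mirror the argument names of torch.nn.Linear and torch.nn.ReLU) are hypothetical, and layer_name is accepted but ignored because its exact role is not documented here.

```python
import inspect
from typing import Any, Callable, Optional

# Constructor argument names that may receive the channel sizes (assumption).
IN_NAMES = ("in_channels", "in_features")
OUT_NAMES = ("out_channels", "out_features")


def create_layer_sketch(
    layer: Callable[..., Any],
    in_channels: Optional[int] = None,
    out_channels: Optional[int] = None,
    layer_name: Optional[str] = None,  # accepted for parity, ignored here
    **kwargs: Any,
) -> Any:
    """Call `layer`, forwarding channel sizes only to arguments it accepts."""
    accepted = inspect.signature(layer).parameters
    for names, value in ((IN_NAMES, in_channels), (OUT_NAMES, out_channels)):
        if value is None:
            continue
        for name in names:
            # Explicit keyword arguments from the caller take precedence.
            if name in accepted and name not in kwargs:
                kwargs[name] = value
                break
    return layer(**kwargs)


class Linear:
    """Stand-in with the same leading arguments as torch.nn.Linear."""

    def __init__(self, in_features: int, out_features: int, bias: bool = True) -> None:
        self.in_features, self.out_features, self.bias = in_features, out_features, bias

    def __repr__(self) -> str:
        return (
            f"Linear(in_features={self.in_features}, "
            f"out_features={self.out_features}, bias={self.bias})"
        )


class ReLU:
    """Stand-in with the same argument as torch.nn.ReLU."""

    def __init__(self, inplace: bool = False) -> None:
        self.inplace = inplace

    def __repr__(self) -> str:
        return "ReLU()"
```

With these stand-ins, create_layer_sketch(Linear, 10, 5) and create_layer_sketch(ReLU, 10, 5, layer_name='act') reproduce the two outputs shown above. Inspecting the signature avoids hard-coding per-layer rules; the trade-off is that a constructor hiding its arguments behind **kwargs receives no channel sizes.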