Source code for esrgan.utils.module_params
from typing import Any, Callable, Optional
from torch import nn
__all__ = ["create_layer"]
[docs]def create_layer(
layer: Callable[..., nn.Module],
in_channels: Optional[int] = None,
out_channels: Optional[int] = None,
layer_name: Optional[str] = None,
**kwargs: Any,
) -> nn.Module:
"""Helper function to generalize layer creation.
Args:
layer: Layer constructor.
in_channels: Size of the input sample.
out_channels: Size of the output e.g. number of channels
produced by the convolution.
layer_name: Name of the layer e.g. ``'activation'``.
**kwargs: Additional params to pass into `layer` function.
Returns:
Layer.
Examples:
>>> in_channels, out_channels = 10, 5
>>> create_layer(nn.Linear, in_channels, out_channels)
Linear(in_features=10, out_features=5, bias=True)
>>> create_layer(nn.ReLU, in_channels, out_channels, layer_name='act')
ReLU()
"""
module: nn.Module
if layer_name in {"activation", "act", "dropout", "pool", "pooling"}:
module = layer(**kwargs)
elif layer_name in {"normalization", "norm", "bn"}:
module = layer(out_channels, **kwargs)
else:
module = layer(in_channels, out_channels, **kwargs)
return module