# Source code for esrgan.utils.module_params

import copy
import functools
import re
from typing import Any, Callable, Dict, Optional

from catalyst.registry import MODULE
from torch import nn

from esrgan.utils.types import ModuleParams


def process_fn_params(function: Callable) -> Callable:
    """Decorator for `fn_params` processing.

    Decorator that processes all keyword ``*_fn`` parameters and replaces
    ``str`` and ``dict`` values with the corresponding constructors of `nn`
    modules. For example, for ``act_fn='ReLU'`` and ``act_fn=nn.ReLU``
    parameters the result will be the ``nn.ReLU`` constructor of the ReLU
    activation function, and for ``act_fn={'act': 'ReLU', 'inplace': True}``
    the result will be a ``functools.partial`` constructor of ``nn.ReLU`` in
    which the ``inplace`` argument is set to ``True``.

    Only keyword arguments whose name ends with ``_fn`` are processed.
    Positional arguments and falsy values (e.g. ``None``) are passed through
    unchanged.

    Args:
        function: Function to wrap.

    Returns:
        Wrapped function.
    """

    @functools.wraps(function)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        kwargs_: Dict[str, Any] = {}
        for key, value in kwargs.items():
            # ``fullmatch`` (not ``match``) so that only names *ending* in
            # ``_fn`` are treated as module params; ``re.match`` would also
            # accept names like ``act_fn_kwargs``.
            if (match := re.fullmatch(r"(\w+)_fn", key)) and value:
                # The name prefix (e.g. ``act`` for ``act_fn``) is the dict
                # entry that holds the module name/constructor.
                value = _process_fn_params(params=value, key=match.group(1))
            kwargs_[key] = value
        return function(*args, **kwargs_)

    return wrapper
def _process_fn_params(
    params: ModuleParams, key: Optional[str] = None
) -> Callable[..., nn.Module]:
    """Convert a module parameter into a module constructor.

    Args:
        params: One of:
            * a callable, returned as is;
            * a ``str``, looked up in the catalyst ``MODULE`` registry;
            * a ``dict`` whose ``key`` entry holds a name or callable; the
              remaining entries are bound as keyword arguments via
              ``functools.partial``.
        key: Name of the dict entry that holds the module name/constructor.
            Required when ``params`` is a dict.

    Returns:
        Module constructor.

    Raises:
        KeyError: If ``params`` is a dict without a ``key`` entry.
        NotImplementedError: If ``params`` has an unsupported type, or is a
            dict but ``key`` is ``None``.
    """
    if callable(params):
        return params

    if isinstance(params, str):
        return MODULE.get(params)

    if isinstance(params, dict):
        if key is None:
            raise NotImplementedError(
                "`key` is required to process dict module params"
            )
        # Deep copy so ``pop`` does not mutate the caller's config dict.
        kwargs = copy.deepcopy(params)
        name_or_fn = kwargs.pop(key)
        module_fn = _process_fn_params(name_or_fn)
        return functools.partial(module_fn, **kwargs)

    # Fail loudly instead of falling through to an unbound return value.
    raise NotImplementedError(
        f"Unsupported module params type: {type(params).__name__}"
    )
def create_layer(
    layer: Callable[..., nn.Module],
    in_channels: Optional[int] = None,
    out_channels: Optional[int] = None,
    layer_name: Optional[str] = None,
    **kwargs: Any,
) -> nn.Module:
    """Build a layer, supplying channel sizes only where the layer takes them.

    Channel-agnostic layers (activation, dropout, pooling) are called with
    ``kwargs`` only. Normalization layers receive ``out_channels`` as their
    single positional argument. Any other layer receives ``in_channels`` and
    ``out_channels`` positionally.

    Args:
        layer: Layer constructor.
        in_channels: Size of the input sample.
        out_channels: Size of the output e.g. number of channels produced
            by the convolution.
        layer_name: Name of the layer e.g. ``'activation'``.
        **kwargs: Additional params to pass into `layer` function.

    Returns:
        Layer.

    Examples:
        >>> in_channels, out_channels = 10, 5
        >>> create_layer(nn.Linear, in_channels, out_channels)
        Linear(in_features=10, out_features=5, bias=True)
        >>> create_layer(nn.ReLU, in_channels, out_channels, layer_name='act')
        ReLU()
    """
    channel_free_names = {"activation", "act", "dropout", "pool", "pooling"}
    norm_names = {"normalization", "norm", "bn"}

    positional: tuple
    if layer_name in channel_free_names:
        positional = ()
    elif layer_name in norm_names:
        positional = (out_channels,)
    else:
        positional = (in_channels, out_channels)

    return layer(*positional, **kwargs)
# Public API exported by ``from esrgan.utils.module_params import *``.
__all__ = [
    "process_fn_params",
    "create_layer",
]