Source code for esrgan.core.runner

from typing import Dict

from catalyst import dl
import torch


class GANRunner(dl.Runner):
    """Runner for experiments with supervised / GAN model.

    The runner switches its batch handler depending on the name of the
    current stage: stages ending with ``'_supervised'`` (and the ``'infer'``
    stage) run the generator only, stages ending with ``'_gan'`` run both
    the generator and the discriminator.
    """

    def _init(
        self,
        input_key: str = "image",
        target_key: str = "real_image",
        generator_output_key: str = "fake_image",
        discriminator_real_output_gkey: str = "g_real_logits",
        discriminator_fake_output_gkey: str = "g_fake_logits",
        discriminator_real_output_dkey: str = "d_real_logits",
        discriminator_fake_output_dkey: str = "d_fake_logits",
        generator_key: str = "generator",
        discriminator_key: str = "discriminator",
    ) -> None:
        """Catalyst-specific helper method for `__init__`.

        Args:
            input_key: Key in batch dict mapping for model input.
            target_key: Key in batch dict mapping for target.
            generator_output_key: Key in output dict model output
                of the generator will be stored under.
            discriminator_real_output_gkey: Key to store predictions of
                discriminator for real inputs, contain gradients for
                generator.
            discriminator_fake_output_gkey: Key to store predictions of
                discriminator for predictions of generator,
                contain gradients for generator.
            discriminator_real_output_dkey: Key to store predictions of
                discriminator for real inputs,
                contain gradients for discriminator only.
            discriminator_fake_output_dkey: Key to store predictions of
                discriminator for items produced by generator,
                contain gradients for discriminator only.
            generator_key: Key in model dict mapping for generator model.
            discriminator_key: Key in model dict mapping for discriminator
                model (will be used in gan stages only).

        """
        super()._init()

        # keys of the models inside `self.model`
        self.generator_key = generator_key
        self.discriminator_key = discriminator_key

        # keys of the tensors inside the input batch
        self.input_key = input_key
        self.target_key = target_key

        # keys of the tensors stored in `self.output`
        self.generator_output_key = generator_output_key
        self.discriminator_real_output_gkey = discriminator_real_output_gkey
        self.discriminator_fake_output_gkey = discriminator_fake_output_gkey
        self.discriminator_real_output_dkey = discriminator_real_output_dkey
        self.discriminator_fake_output_dkey = discriminator_fake_output_dkey

    def _prepare_for_stage(self, stage: str) -> None:
        """Prepare `_handle_batch` method for current stage.

        Args:
            stage: Name of current stage.

        Raises:
            NotImplementedError: Name of the `stage` should end with
                ``'_supervised'``, ``'_gan'`` or should be ``'infer'``,
                raise error otherwise.

        """
        super()._prepare_for_stage(stage=stage)

        if stage.endswith("_supervised") or stage == "infer":
            self._handle_batch = self._handle_batch_supervised
        elif stage.endswith("_gan"):
            self._handle_batch = self._handle_batch_gan
        else:
            # same exception type as before, but tell the user what is wrong
            raise NotImplementedError(
                f"Unsupported stage name {stage!r}: expected 'infer' or "
                "a name ending with '_supervised' or '_gan'."
            )

    def _handle_batch(self, batch: Dict[str, torch.Tensor]) -> None:
        """Placeholder batch handler, replaced in `_prepare_for_stage`.

        Args:
            batch: Input batch (batch of samples).

        Raises:
            NotImplementedError: Always, if called before a stage-specific
                handler has been assigned.

        """
        # `_handle_batch` method is @abstractmethod so it must be defined
        # even if it overwrites in `_prepare_for_stage`
        raise NotImplementedError()

    def _handle_batch_supervised(self, batch: Dict[str, torch.Tensor]) -> None:
        """Process train/valid batch, supervised mode.

        Runs the generator on ``batch[self.input_key]`` and stores the
        result in ``self.output`` under ``self.generator_output_key``.

        Args:
            batch: Input batch (batch of samples).

        """
        model = self.model[self.generator_key]
        output = model(batch[self.input_key])

        self.output = {self.generator_output_key: output}

    def _handle_batch_gan(self, batch: Dict[str, torch.Tensor]) -> None:
        """Process train/valid batch, GAN mode.

        Runs the generator on ``batch[self.input_key]``, perturbs the
        ground truth images with Gaussian noise and stores the generator
        output together with four sets of discriminator predictions
        in ``self.output``.

        Args:
            batch: Input batch, should contain raw samples for generator
                and ground truth samples for discriminator.

        """
        generator = self.model[self.generator_key]
        discriminator = self.model[self.discriminator_key]

        real_image = batch[self.target_key]
        fake_image = generator(batch[self.input_key])

        # Add Gaussian noise (scaled by 0.05) to the real images and clamp
        # the result back to [0, 1]. `randn_like` creates the noise with the
        # dtype and device of `real_image` itself, so the sum is never
        # silently promoted to another dtype and never mixes devices.
        noise = torch.randn_like(real_image)
        real_image = torch.clamp((real_image + 0.05 * noise), min=0.0, max=1.0)

        # predictions used in calculation of adversarial loss of generator
        real_logits_g = discriminator(real_image)
        fake_logits_g = discriminator(fake_image)

        # predictions used in calculation of adversarial loss of discriminator
        # (separate forward passes build separate autograd graphs; `detach`
        # stops gradients from flowing back into the generator)
        real_logits_d = discriminator(real_image)
        fake_logits_d = discriminator(fake_image.detach())

        self.output = {
            self.generator_output_key: fake_image,
            self.discriminator_real_output_gkey: real_logits_g,
            self.discriminator_fake_output_gkey: fake_logits_g,
            self.discriminator_real_output_dkey: real_logits_d,
            self.discriminator_fake_output_dkey: fake_logits_d,
        }

    def predict_batch(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Generate predictions based on input batch (generator inference).

        The input tensor is moved to ``self.device`` before it is passed
        to the generator.

        Args:
            batch: Input batch (batch of samples to adjust e.g. zoom).

        Returns:
            Batch of predictions of the generator.

        """
        model = self.model[self.generator_key]
        output = model(batch[self.input_key].to(self.device))

        return output
# Public API of this module: only the runner class is exported.
__all__ = ["GANRunner"]