Datasets

The esrgan.datasets subpackage contains definitions for the following datasets for image super-resolution:

  • DIV2K

  • Flickr2K

  • Folder of Images

All datasets are subclasses of torch.utils.data.Dataset, i.e., they implement the __getitem__ and __len__ methods. Hence, they can all be passed to a torch.utils.data.DataLoader, which can load multiple samples in parallel using torch.multiprocessing workers. For example:

import torch
import esrgan.datasets
div2k_data = esrgan.datasets.DIV2KDataset('path/to/div2k_root/')
data_loader = torch.utils.data.DataLoader(div2k_data, batch_size=4, shuffle=True, num_workers=4)
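To illustrate why implementing __getitem__ and __len__ is enough, here is a minimal, single-process sketch of the batching that a DataLoader performs. ToyPairDataset and iterate_batches are hypothetical names used only for illustration; unlike the real torch.utils.data.DataLoader, this sketch uses no worker processes and collates samples into plain lists instead of stacked tensors.

```python
import random
from typing import Any, Dict, Iterator, List


class ToyPairDataset:
    """Hypothetical map-style dataset: any object with __len__ and __getitem__."""

    def __init__(self, n: int) -> None:
        self.n = n

    def __len__(self) -> int:
        return self.n

    def __getitem__(self, index: int) -> Dict[str, Any]:
        # Placeholder "images"; a real dataset would return image arrays or tensors.
        return {"lr_image": f"lr_{index}", "hr_image": f"hr_{index}"}


def iterate_batches(dataset, batch_size: int, shuffle: bool = False,
                    seed: int = 0) -> Iterator[Dict[str, List[Any]]]:
    """Simplified, single-process stand-in for DataLoader's batching loop."""
    indices = list(range(len(dataset)))  # only len() is needed here
    if shuffle:
        random.Random(seed).shuffle(indices)
    for start in range(0, len(indices), batch_size):
        # Only indexing is needed to fetch samples.
        samples = [dataset[i] for i in indices[start:start + batch_size]]
        # Collate a list of dicts into a dict of lists (DataLoader would stack tensors).
        yield {key: [s[key] for s in samples] for key in samples[0]}


batches = list(iterate_batches(ToyPairDataset(10), batch_size=4))
print(len(batches))             # 3
print(batches[-1]["hr_image"])  # ['hr_8', 'hr_9']
```

As with the DataLoader default (drop_last=False), the final, smaller batch is kept.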

DIV2K

class esrgan.datasets.DIV2KDataset(root: str, train: bool = True, target_type: str = 'bicubic_X4', patch_size: Tuple[int, int] = (96, 96), transform: Optional[Callable[[Any], Dict]] = None, low_resolution_image_key: str = 'lr_image', high_resolution_image_key: str = 'hr_image', download: bool = False)

DIV2K Dataset.

Parameters
  • root – Root directory where images are downloaded to.

  • train – If True, creates the dataset from the training set; otherwise, creates it from the validation set.

  • target_type – Type of target to use, e.g. 'bicubic_X2', 'unknown_X4', 'X8', 'mild', …

  • patch_size – If train is True, the size of the patches to produce, as a tuple of (height, width); otherwise, the full image is returned.

  • transform – A function/transform that takes in a dictionary (with the low- and high-resolution images) and returns a transformed version.

  • low_resolution_image_key – Key under which the low-resolution image is stored.

  • high_resolution_image_key – Key under which the high-resolution image is stored.

  • download – If True, downloads the dataset from the internet and puts it in the root directory. If the dataset is already downloaded, it is not downloaded again.
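For train=True, patches have to be cut from the low- and high-resolution images at aligned positions. The sketch below assumes that patch_size refers to the high-resolution patch and that the scale factor can be read from an 'X<n>' suffix of target_type; both are assumptions about behaviour not stated above, and scale_from_target_type and paired_random_crop are hypothetical helpers, not part of esrgan.datasets. Images are nested lists (rows of pixel values) to keep the example dependency-free.

```python
import random
import re
from typing import List, Tuple

Image = List[List[int]]  # hypothetical representation: rows of pixel values


def scale_from_target_type(target_type: str) -> int:
    """Hypothetical helper: read the factor from names like 'bicubic_X4' or 'X8'."""
    match = re.search(r"X(\d+)$", target_type)
    if match is None:
        # Names such as 'mild' carry no explicit factor in this naming scheme.
        raise ValueError(f"no scale factor in {target_type!r}")
    return int(match.group(1))


def paired_random_crop(lr: Image, hr: Image, patch_size: Tuple[int, int],
                       scale: int, rng: random.Random) -> Tuple[Image, Image]:
    """Crop an HR patch of `patch_size` and the LR patch covering the same area."""
    hr_h, hr_w = patch_size
    if hr_h % scale or hr_w % scale:
        raise ValueError("patch_size must be divisible by the scale factor")
    lr_h, lr_w = hr_h // scale, hr_w // scale
    # Choose the top-left corner on the LR grid so both crops stay aligned.
    top = rng.randint(0, len(lr) - lr_h)
    left = rng.randint(0, len(lr[0]) - lr_w)
    lr_patch = [row[left:left + lr_w] for row in lr[top:top + lr_h]]
    hr_patch = [row[left * scale:(left + lr_w) * scale]
                for row in hr[top * scale:(top + lr_h) * scale]]
    return lr_patch, hr_patch


hr = [[r * 100 + c for c in range(16)] for r in range(16)]
lr = [row[::4] for row in hr[::4]]  # toy 4x downsampling by subsampling
lr_patch, hr_patch = paired_random_crop(
    lr, hr, (8, 8), scale_from_target_type("bicubic_X4"), random.Random(0))
print(len(hr_patch), len(hr_patch[0]), len(lr_patch), len(lr_patch[0]))  # 8 8 2 2
```

Sampling the corner on the low-resolution grid and multiplying by the scale guarantees that the two patches cover exactly the same region of the scene.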

Flickr2K

class esrgan.datasets.Flickr2KDataset(root: str, train: bool = True, target_type: str = 'bicubic_X4', patch_size: Tuple[int, int] = (96, 96), transform: Optional[Callable[[Any], Dict]] = None, low_resolution_image_key: str = 'lr_image', high_resolution_image_key: str = 'hr_image', download: bool = False)

Flickr2K Dataset.

Parameters
  • root – Root directory where images are downloaded to.

  • train – If True, creates the dataset from the training set; otherwise, creates it from the validation set.

  • target_type – Type of target to use, e.g. 'bicubic_X2', 'unknown_X4', …

  • patch_size – If train is True, the size of the patches to produce, as a tuple of (height, width); otherwise, the full image is returned.

  • transform – A function/transform that takes in a dictionary (with the low- and high-resolution images) and returns a transformed version.

  • low_resolution_image_key – Key under which the low-resolution image is stored.

  • high_resolution_image_key – Key under which the high-resolution image is stored.

  • download – If True, downloads the dataset from the internet and puts it in the root directory. If the dataset is already downloaded, it is not downloaded again.
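Because transform receives the whole sample dictionary, a transform can apply one random decision to both images so that they stay aligned. A minimal sketch, assuming the default 'lr_image'/'hr_image' keys and images stored as nested lists (rows of pixels); PairedHorizontalFlip is a hypothetical name, not a class provided by esrgan.

```python
import random
from typing import Any, Dict, Optional, Tuple


class PairedHorizontalFlip:
    """Hypothetical transform: flip the LR and HR images together with probability p."""

    def __init__(self, p: float = 0.5,
                 keys: Tuple[str, ...] = ("lr_image", "hr_image"),
                 seed: Optional[int] = None) -> None:
        self.p = p
        self.keys = keys
        self.rng = random.Random(seed)

    def __call__(self, sample: Dict[str, Any]) -> Dict[str, Any]:
        out = dict(sample)  # shallow copy: don't mutate the caller's dictionary
        # One random draw per sample, so every listed image gets the same flip.
        if self.rng.random() < self.p:
            for key in self.keys:
                out[key] = [row[::-1] for row in out[key]]
        return out


sample = {"lr_image": [[1, 2]], "hr_image": [[1, 1, 2, 2]]}
flipped = PairedHorizontalFlip(p=1.0)(sample)
print(flipped["lr_image"], flipped["hr_image"])  # [[2, 1]] [[2, 2, 1, 1]]
print(sample["lr_image"])                        # [[1, 2]]
```

Flipping each image with its own independent draw would desynchronise the pair, which is why the decision is made once per dictionary.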

Folder of Images

class esrgan.datasets.ImageFolderDataset(pathname: str, image_key: str = 'image', image_name_key: str = 'filename', transform: Optional[Callable[[Dict], Dict]] = None)

A generic dataset where the images are arranged in this way:

<pathname>/xxx.ext
<pathname>/xxy.ext
<pathname>/xxz.ext
...
<pathname>/123.ext
<pathname>/nsdf3.ext
<pathname>/asd932_.ext

Parameters
  • pathname – Root directory of the dataset.

  • image_key – Key under which the image is stored.

  • image_name_key – Key under which the name of the image is stored.

  • transform – A function/transform that takes in a dictionary and returns its transformed version.
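A rough sketch of how a flat folder like the one above can be exposed as a map-style dataset. FlatImageFolder is hypothetical and is not the esrgan implementation: it treats pathname as a directory, lists its regular files in sorted order, returns raw bytes instead of decoded images, and assumes the image_name_key entry holds the file's base name.

```python
import os
from typing import Any, Callable, Dict, Optional


class FlatImageFolder:
    """Hypothetical sketch of a flat-folder dataset; not esrgan's implementation."""

    def __init__(self, pathname: str, image_key: str = "image",
                 image_name_key: str = "filename",
                 transform: Optional[Callable[[Dict], Dict]] = None) -> None:
        # Sorted, so that an index always refers to the same file.
        self.files = sorted(
            os.path.join(pathname, name)
            for name in os.listdir(pathname)
            if os.path.isfile(os.path.join(pathname, name))
        )
        self.image_key = image_key
        self.image_name_key = image_name_key
        self.transform = transform

    def __len__(self) -> int:
        return len(self.files)

    def __getitem__(self, index: int) -> Dict[str, Any]:
        path = self.files[index]
        with open(path, "rb") as f:
            data = f.read()  # raw bytes; a real dataset would decode the image
        sample = {self.image_key: data,
                  self.image_name_key: os.path.basename(path)}
        return self.transform(sample) if self.transform else sample
```

Subdirectories are skipped, matching the flat <pathname>/<file> layout shown above.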