Source code for esrgan.datasets

import glob
from pathlib import Path
import random
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union

from albumentations.augmentations.crops import functional as F
from catalyst import data
from catalyst.contrib.datasets import misc
import numpy as np
from torch.utils.data import Dataset

from esrgan import utils

__all__ = ["DIV2KDataset", "Flickr2KDataset", "ImageFolderDataset"]


def has_image_extension(uri: Union[str, Path]) -> bool:
    """Tell whether ``uri`` names a file with a recognised image suffix.

    The suffix comparison is case-insensitive, so ``"photo.PNG"`` matches.

    Args:
        uri: Path (string or :class:`~pathlib.Path`) of the file to check.

    Returns:
        ``True`` when the suffix is one of ``.bmp``, ``.png``, ``.jpeg``,
        ``.jpg``, ``.tif`` or ``.tiff``; ``False`` otherwise.

    """
    suffix = Path(uri).suffix.lower()
    known_suffixes = (".bmp", ".png", ".jpeg", ".jpg", ".tif", ".tiff")
    return suffix in known_suffixes


def images_in_dir(*args: Union[str, Path]) -> List[str]:
    """Searches (recursively) for all images in the directory.

    Args:
        *args: Path to the folder with images.
            Each element of path segments can be either a string
            representing a path segment, an object implementing
            the :py:class:`os.PathLike` interface which returns a string,
            or another path object.

    Returns:
        Sorted list of images in the folder or its subfolders.

    """
    path = Path(*args)
    # Fix path to dir for the `NTIRE 2017` datasets: if the directory does
    # not exist literally, the last ``_`` of its name is treated as a path
    # separator, e.g. ``DIV2K_train_LR_bicubic_X4`` is looked up as
    # ``DIV2K_train_LR_bicubic/X4``. Names without ``_`` are left alone
    # (previously ``rfind`` returned -1 and mangled ``foo`` into ``fo/foo``).
    if not path.exists():
        head, sep, tail = path.name.rpartition("_")
        if sep:
            path = path.parent / head / tail

    # Escape the directory part so that glob metacharacters (``[``, ``*``,
    # ``?``) in real directory names are matched literally.
    pattern = f"{glob.escape(str(path))}/**/*"
    files = glob.iglob(pattern, recursive=True)
    images = sorted(filter(has_image_extension, files))

    return images


def paired_random_crop(
    images: Iterable[np.ndarray], crops_sizes: Iterable[Tuple[int, int]],
) -> Iterable[np.ndarray]:
    """Crop spatially aligned random patches from a group of images.

    One random position is drawn and shared by all images, so that e.g. a
    low resolution patch and its high resolution counterpart cover the same
    part of the scene. The offset is sampled on the first (reference) image
    and rescaled to every other image by the ratio of the crop sizes. When
    that ratio is an integer (``hr_crop == scale * lr_crop``), the crops are
    pixel-aligned. Sampling each image independently (as with
    ``int((h - crop_h + 1) * h_start)`` per image) shifts the HR patch by a
    few pixels relative to the LR patch.

    Args:
        images: Sequence of images with ``(height, width, ...)`` layout.
        crops_sizes: Sequence of crop sizes ``(height, width)``, one per image.

    Returns:
        List of crops (views into the input arrays).

    Raises:
        ValueError: If a requested crop is larger than its image.

    """
    # Draw the random numbers up front, as before, so that the RNG stream
    # is consumed identically regardless of the inputs.
    h_start, w_start = random.random(), random.random()

    pairs = list(zip(images, crops_sizes))
    if not pairs:
        return []

    for image, (height, width) in pairs:
        img_h, img_w = image.shape[:2]
        if height > img_h or width > img_w:
            raise ValueError(
                f"Requested crop size ({height}, {width}) is larger than "
                f"the image size ({img_h}, {img_w})"
            )

    # Offset on the reference image; h_start in [0, 1) maps onto
    # [0, img_h - crop_h] inclusive.
    ref_image, (ref_h, ref_w) = pairs[0]
    ref_y = int((ref_image.shape[0] - ref_h + 1) * h_start)
    ref_x = int((ref_image.shape[1] - ref_w + 1) * w_start)

    crops = []
    for image, (height, width) in pairs:
        img_h, img_w = image.shape[:2]
        # rescale the reference offset to this image's resolution
        y1 = round(ref_y * height / ref_h) if ref_h else 0
        x1 = round(ref_x * width / ref_w) if ref_w else 0
        # clamp in case the images are not exact multiples of each other
        y1 = min(y1, img_h - height)
        x1 = min(x1, img_w - width)
        crops.append(image[y1:y1 + height, x1:x1 + width])

    return crops


class _PairedImagesDataset(Dataset):
    """Base Dataset for the Image Super-Resolution task.

    Subclasses fill ``self.data`` with records of the form
    ``{"lr_image": <path>, "hr_image": <path>}``.

    Args:
        train: If True, creates dataset from training set,
            otherwise creates from validation set. In training mode aligned
            random patches are cropped from every LR/HR pair.
        target_type: Type of target to use, ``'bicubic_X2'``, ``'unknown_X4'``,
            ``'x8'``, ``'mild'``, ... The downscaling factor is parsed from
            the trailing ``X<n>`` / ``x<n>`` token; targets without such a
            token (e.g. ``'mild'``) fall back to a factor of 4.
        patch_size: If ``train == True``, define sizes of patches to produce,
            return full image otherwise. Tuple of height and width.
        transform: A function / transform that takes in dictionary (with low
            and high resolution images) and returns a transformed version.
        low_resolution_image_key: Key to use to store images of low resolution.
        high_resolution_image_key: Key to use to store high resolution images.

    """

    def __init__(
        self,
        train: bool = True,
        target_type: str = "bicubic_X4",
        patch_size: Tuple[int, int] = (96, 96),
        transform: Optional[Callable[[Any], Dict]] = None,
        low_resolution_image_key: str = "lr_image",
        high_resolution_image_key: str = "hr_image",
    ) -> None:
        self.train = train

        self.lr_key = low_resolution_image_key
        self.hr_key = high_resolution_image_key

        # populated by subclasses with {"lr_image": path, "hr_image": path}
        self.data: List[Dict[str, str]] = []
        self.open_fn = data.ReaderCompose([
            data.ImageReader(input_key="lr_image", output_key=self.lr_key),
            data.ImageReader(input_key="hr_image", output_key=self.hr_key),
        ])

        self.scale = self._parse_scale(target_type)
        height, width = patch_size
        self.target_patch_size = patch_size
        # the LR patch is the HR patch shrunk by the downscaling factor
        self.input_patch_size = (height // self.scale, width // self.scale)

        self.transform = utils.Augmentor(transform)

    @staticmethod
    def _parse_scale(target_type: str, default: int = 4) -> int:
        """Extract the downscaling factor from a target type name.

        ``'bicubic_X2'`` -> 2, ``'x8'`` -> 8, ``'mild'`` -> ``default``.
        The previous ``split("_")`` + ``isdigit()`` logic crashed on names
        without ``_`` and never recognised ``X2``/``X3`` (always yielding 4).

        Args:
            target_type: Target type name, e.g. ``'unknown_X3'``.
            default: Factor returned when the name carries no number.

        Returns:
            The downscaling factor.

        """
        token = target_type.rsplit("_", 1)[-1].lower().lstrip("x")
        return int(token) if token.isdecimal() else default

    def __getitem__(self, index: int) -> Dict:
        """Gets element of the dataset.

        Args:
            index: Index of the element in the dataset.

        Returns:
            Dict of low and high resolution images.

        """
        record = self.data[index]

        sample_dict = self.open_fn(record)

        if self.train:
            # use random (aligned) crops during training
            lr_crop, hr_crop = paired_random_crop(
                (sample_dict[self.lr_key], sample_dict[self.hr_key]),
                (self.input_patch_size, self.target_patch_size),
            )
            sample_dict.update({self.lr_key: lr_crop, self.hr_key: hr_crop})

        sample_dict = self.transform(sample_dict)

        return sample_dict

    def __len__(self) -> int:
        """Get length of the dataset.

        Returns:
            Length of the dataset.

        """
        return len(self.data)


class DIV2KDataset(_PairedImagesDataset):
    """`DIV2K <https://data.vision.ee.ethz.ch/cvl/DIV2K>`_ Dataset.

    LR and HR images are matched by position in the sorted file lists of
    the two image directories.

    Args:
        root: Root directory where images are downloaded to.
        train: If True, uses the ``train`` split, otherwise the ``valid``
            split.
        target_type: Type of low resolution images to use; it selects the
            ``DIV2K_{split}_LR_{target_type}`` archive, e.g.
            ``'bicubic_X2'``, ``'unknown_X4'``, ``'x8'``, ``'mild'``,
            ``'difficult'``, ``'wild'``.
        patch_size: If ``train == True``, define sizes of patches to produce,
            return full image otherwise. Tuple of height and width.
        transform: A function / transform that takes in dictionary (with low
            and high resolution images) and returns a transformed version.
        low_resolution_image_key: Key to use to store images of low resolution.
        high_resolution_image_key: Key to use to store high resolution images.
        download: If true, downloads the HR archive and the LR archive for
            ``target_type`` and extracts them into ``root``. If dataset is
            already downloaded, it is not downloaded again.

    Raises:
        KeyError: If ``download`` is set and no archive is known for the
            requested split / ``target_type``.
        AssertionError: If the LR and HR directories contain a different
            number of images.

    """

    url = "http://data.vision.ee.ethz.ch/cvl/DIV2K/"
    resources = {
        "DIV2K_train_LR_bicubic_X2.zip": "9a637d2ef4db0d0a81182be37fb00692",
        "DIV2K_train_LR_unknown_X2.zip": "1396d023072c9aaeb999c28b81315233",
        "DIV2K_valid_LR_bicubic_X2.zip": "1512c9a3f7bde2a1a21a73044e46b9cb",
        "DIV2K_valid_LR_unknown_X2.zip": "d319bd9033573d21de5395e6454f34f8",
        "DIV2K_train_LR_bicubic_X3.zip": "ad80b9fe40c049a07a8a6c51bfab3b6d",
        "DIV2K_train_LR_unknown_X3.zip": "4e651308aaa54d917fb1264395b7f6fa",
        "DIV2K_valid_LR_bicubic_X3.zip": "18b1d310f9f88c13618c287927b29898",
        "DIV2K_valid_LR_unknown_X3.zip": "05184168e3608b5c539fbfb46bcade4f",
        "DIV2K_train_LR_bicubic_X4.zip": "76c43ec4155851901ebbe8339846d93d",
        "DIV2K_train_LR_unknown_X4.zip": "e3c7febb1b3f78bd30f9ba15fe8e3956",
        "DIV2K_valid_LR_bicubic_X4.zip": "21962de700c8d368c6ff83314480eff0",
        "DIV2K_valid_LR_unknown_X4.zip": "8ac3413102bb3d0adc67012efb8a6c94",
        "DIV2K_train_LR_x8.zip": "613db1b855721b3d2b26f4194a1d22a6",
        "DIV2K_train_LR_mild.zip": "807b3e3a5156f35bd3a86c5bbfb674bc",
        "DIV2K_train_LR_difficult.zip": "5a8f2b9e0c5f5ed0dac271c1293662f4",
        "DIV2K_train_LR_wild.zip": "d00982366bffee7c4739ba7ff1316b3b",
        "DIV2K_valid_LR_x8.zip": "c5aeea2004e297e9ff3abfbe143576a5",
        "DIV2K_valid_LR_mild.zip": "8c433f812ca532eed62c11ec0de08370",
        "DIV2K_valid_LR_difficult.zip": "1620af11bf82996bc94df655cb6490fe",
        "DIV2K_valid_LR_wild.zip": "aacae8db6bec39151ca5bb9c80bf2f6c",
        "DIV2K_train_HR.zip": "bdc2d9338d4e574fe81bf7d158758658",
        "DIV2K_valid_HR.zip": "9fcdda83005c5e5997799b69f955ff88",
    }

    def __init__(
        self,
        root: str,
        train: bool = True,
        target_type: str = "bicubic_X4",
        patch_size: Tuple[int, int] = (96, 96),
        transform: Optional[Callable[[Any], Dict]] = None,
        low_resolution_image_key: str = "lr_image",
        high_resolution_image_key: str = "hr_image",
        download: bool = False,
    ) -> None:
        super().__init__(
            train=train,
            target_type=target_type,
            patch_size=patch_size,
            transform=transform,
            low_resolution_image_key=low_resolution_image_key,
            high_resolution_image_key=high_resolution_image_key,
        )

        split = "train" if train else "valid"
        hr_archive = f"DIV2K_{split}_HR.zip"
        lr_archive = f"DIV2K_{split}_LR_{target_type}.zip"

        if download:
            # HR (target) archive first, then the LR (input) archive
            for archive in (hr_archive, lr_archive):
                misc.download_and_extract_archive(
                    self.url + archive,
                    download_root=root,
                    filename=archive,
                    md5=self.resources[archive],
                )

        # each archive extracts into a directory named after its stem
        lr_files = images_in_dir(root, Path(lr_archive).stem)
        hr_files = images_in_dir(root, Path(hr_archive).stem)
        assert len(lr_files) == len(hr_files)

        self.data = [
            dict(lr_image=lr_path, hr_image=hr_path)
            for lr_path, hr_path in zip(lr_files, hr_files)
        ]


class Flickr2KDataset(_PairedImagesDataset):
    """`Flickr2K <https://github.com/LimBee/NTIRE2017>`_ Dataset.

    LR and HR images are matched by position in the sorted file lists of
    the two image directories.

    Args:
        root: Root directory where images are downloaded to.
        train: Forwarded to the base class, where it switches random patch
            cropping on or off. It is not used to select files here, so
            the same images are loaded either way.
        target_type: Type of target to use, ``'bicubic_X2'``,
            ``'unknown_X4'``, ... It is split on ``_`` into a degradation
            and a downscaling part, which select the
            ``Flickr2K_LR_{degradation}/{downscaling}`` directory.
        patch_size: If ``train == True``, define sizes of patches to produce,
            return full image otherwise. Tuple of height and width.
        transform: A function / transform that takes in dictionary (with low
            and high resolution images) and returns a transformed version.
        low_resolution_image_key: Key to use to store images of low resolution.
        high_resolution_image_key: Key to use to store high resolution images.
        download: If true, downloads the dataset from the internet
            and puts it in root directory. If dataset is already downloaded,
            it is not downloaded again.

    """

    url = "https://cv.snu.ac.kr/research/EDSR/"
    resources = {
        "Flickr2K.tar": "5d3f39443d5e9489bff8963f8f26cb03",
    }

    def __init__(
        self,
        root: str,
        train: bool = True,
        target_type: str = "bicubic_X4",
        patch_size: Tuple[int, int] = (96, 96),
        transform: Optional[Callable[[Any], Dict]] = None,
        low_resolution_image_key: str = "lr_image",
        high_resolution_image_key: str = "hr_image",
        download: bool = False,
    ) -> None:
        super().__init__(
            train=train,
            target_type=target_type,
            patch_size=patch_size,
            transform=transform,
            low_resolution_image_key=low_resolution_image_key,
            high_resolution_image_key=high_resolution_image_key,
        )

        archive = "Flickr2K.tar"
        if download:
            misc.download_and_extract_archive(
                self.url + archive,
                download_root=root,
                filename=archive,
                md5=self.resources[archive],
            )

        degradation, downscaling = target_type.split("_")

        # both image folders live inside the directory named after the archive
        dataset_dir = Path(archive).stem
        lr_dir = Path(f"Flickr2K_LR_{degradation}", downscaling)
        lr_files = images_in_dir(root, dataset_dir, lr_dir)
        hr_files = images_in_dir(root, dataset_dir, "Flickr2K_HR")
        assert len(lr_files) == len(hr_files)

        self.data = [
            dict(lr_image=lr_path, hr_image=hr_path)
            for lr_path, hr_path in zip(lr_files, hr_files)
        ]
class ImageFolderDataset(data.ListDataset):
    """A generic data loader where the samples are arranged in this way: ::

        <pathname>/xxx.ext
        <pathname>/xxy.ext
        <pathname>/xxz.ext
        ...
        <pathname>/123.ext
        <pathname>/nsdf3.ext
        <pathname>/asd932_.ext

    Args:
        pathname: Root directory of dataset, or a glob pattern (``**`` is
            expanded recursively) selecting the image files. When an
            existing directory is given, every image inside it and its
            subfolders is used. Previously a directory was globbed
            literally and produced an empty dataset.
        image_key: Key to use to store image.
        image_name_key: Key to use to store name of the image.
        transform: A function / transform that takes in dictionary
            and returns its transformed version.

    """

    def __init__(
        self,
        pathname: str,
        image_key: str = "image",
        image_name_key: str = "filename",
        transform: Optional[Callable[[Dict], Dict]] = None,
    ) -> None:
        if Path(pathname).is_dir():
            # plain directory: search it recursively; escape it so that glob
            # metacharacters in the directory name are matched literally
            pattern = f"{glob.escape(str(Path(pathname)))}/**/*"
        else:
            pattern = pathname
        files = glob.iglob(pattern, recursive=True)
        images = sorted(filter(has_image_extension, files))

        list_data = [{"image": filename} for filename in images]
        open_fn = data.ReaderCompose([
            data.ImageReader(input_key="image", output_key=image_key),
            # stores the file path itself under ``image_name_key``
            data.LambdaReader(input_key="image", output_key=image_name_key),
        ])
        transform = utils.Augmentor(transform)

        super().__init__(
            list_data=list_data, open_fn=open_fn, dict_transform=transform
        )