# Author: Matt Clifford <matt.clifford@bristol.ac.uk>
# License: BSD 3-Clause License
import warnings
import torch
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity as lpips_torch
from DISTS_pytorch import DISTS as dists_original
from IQM_Vis.metrics.metric_utils import _check_shapes, _numpy_to_torch_image
[docs]class LPIPS:
    '''Learned Perceptual Image Patch Similarity between two images.
       Images must have the same dimensions.
    Args:
        network (str): Pretrained network to use. Choose between ‘alex’, ‘vgg’
                       or ‘squeeze’. (Defaults to 'alex')
        reduction (str): How to reduce over the batch dimension. Choose between
                         ‘sum’ or ‘mean’. (Defaults to 'mean')
    '''
    def __init__(self, network='alex', reduction='mean'):
        self.initialised = False   # initialse fully on first __call__ to save load up time
        self.network = network
        self.reduction = reduction
        # self.metric = lpips_torch
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        self.preproccess_function = _numpy_to_torch_image
[docs]    def __call__(self, im_ref, im_comp, **kwargs):
        '''When an instance is called
        Args:
            im_ref (np.array): Reference image
            im_comp (np.array): Comparison image
            **kwargs: Arbitrary keyword arguments
        Returns:
            score (np.array): LPIPS score
        '''
        if self.initialised == False:
            with warnings.catch_warnings():    # we don't care about the warnings these give
                warnings.simplefilter("ignore")
                self.metric = lpips_torch(net_type=self.network,
                                          reduction=self.reduction,
                                          normalize=True)
                self.metric.to(self.device)
            self.initialised = True
        _check_shapes(im_ref, im_comp)
        im_ref = self.preproccess_function(im_ref).to(
            device=self.device, dtype=torch.float)
        im_comp = self.preproccess_function(im_comp).to(
            device=self.device, dtype=torch.float)
        with warnings.catch_warnings():    # warning because we have reset
            warnings.simplefilter("ignore")
            _score = self.metric(im_ref, im_comp)
        score = _score.cpu().detach().numpy()
        self.metric.reset()
        return score  
[docs]class DISTS:
    '''Deep Image Structure and Texture Similarity (DISTS) Metric. Uses the
        code from https://github.com/dingkeyan93/DISTS. Uses the PyTorch backend.
        It is robust to texture variance (e.g., evaluating the images generated
        by GANs) and mild geometric transformations (e.g., evaluating the image
        pairs that are not strictly point-by-point aligned).
    '''
    def __init__(self):
        self.initialised = False   # initialse fully on first __call__ to save load up time
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        self.preproccess_function = _numpy_to_torch_image
[docs]    def __call__(self, im_ref, im_comp, **kwargs):
        '''When an instance is called
        Args:
            im_ref (np.array): Reference image
            im_comp (np.array): Comparison image
            **kwargs: Arbitrary keyword arguments
        Returns:
            score (np.array): DISTS score
        '''
        # load model on first time called
        if self.initialised == False:
            with warnings.catch_warnings():    # we don't care about the warnings these give
                warnings.simplefilter("ignore")
                self.metric = dists_original()
            self.metric.to(self.device)
            self.initialised = True
        _check_shapes(im_ref, im_comp)
        im_ref = self.preproccess_function(im_ref).to(
            device=self.device, dtype=torch.float)
        im_comp = self.preproccess_function(im_comp).to(
            device=self.device, dtype=torch.float)
        _score = self.metric(im_ref, im_comp)
        score = _score.cpu().detach().numpy()
        return score