Source code for IQM_Vis.metrics.SSIM.ssim

# taken from https://github.com/VainF/pytorch-msssim/blob/master/pytorch_msssim/ssim.py
# Copyright 2020 by Gongfan Fang, Zhejiang University.
# All rights reserved.
# Author: Gongfan Fang
# Licence: MIT

import warnings
from typing import List, Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import Tensor


def _fspecial_gauss_1d(size: int, sigma: float) -> Tensor:
    r"""Create 1-D gauss kernel
    Args:
        size (int): the size of gauss kernel
        sigma (float): sigma of normal distribution
    Returns:
        torch.Tensor: 1D kernel (1 x 1 x size)
    """
    coords = torch.arange(size, dtype=torch.float)
    coords -= size // 2

    g = torch.exp(-(coords ** 2) / (2 * sigma ** 2))
    g /= g.sum()

    return g.unsqueeze(0).unsqueeze(0)


def gaussian_filter(input: Tensor, win: Tensor) -> Tensor:
    r"""Blur input with a 1-D kernel, applied separably along each spatial dim.

    Args:
        input (torch.Tensor): a batch of tensors to be blurred, 4-D (N,C,H,W) or 5-D (N,C,D,H,W)
        win (torch.Tensor): 1-D gauss kernel; all dims between the first and last must be 1

    Returns:
        torch.Tensor: blurred tensors (valid convolution, so spatial dims shrink)
    """
    assert all(ws == 1 for ws in win.shape[1:-1]), win.shape

    ndim = len(input.shape)
    if ndim == 4:
        conv = F.conv2d
    elif ndim == 5:
        conv = F.conv3d
    else:
        raise NotImplementedError(input.shape)

    num_channels = input.shape[1]
    out = input
    # Convolve one spatial axis at a time by moving the kernel's tap axis into place.
    for axis, extent in enumerate(input.shape[2:]):
        if extent >= win.shape[-1]:
            out = conv(out, weight=win.transpose(2 + axis, -1), stride=1, padding=0, groups=num_channels)
        else:
            # Axis shorter than the kernel: skip it rather than fail.
            warnings.warn(
                f"Skipping Gaussian Smoothing at dimension 2+{axis} for input: {input.shape} and win size: {win.shape[-1]}"
            )
    return out
def _ssim(
    X: Tensor,
    Y: Tensor,
    data_range: float,
    win: Tensor,
    size_average: bool = True,
    K: Union[Tuple[float, float], List[float]] = (0.01, 0.03),
) -> Tuple[Tensor, Tensor]:
    r"""Calculate the SSIM index for X and Y.

    Args:
        X (torch.Tensor): images
        Y (torch.Tensor): images
        data_range (float or int): value range of input images. (usually 1.0 or 255)
        win (torch.Tensor): 1-D gauss kernel
        size_average (bool, optional): unused here; reduction is performed by the caller
        K (list or tuple, optional): stability constants (K1, K2)

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: per-channel ssim and contrast-structure values.
    """
    K1, K2 = K
    # batch, channel, [depth,] height, width = X.shape
    compensation = 1.0

    C1 = (K1 * data_range) ** 2
    C2 = (K2 * data_range) ** 2

    win = win.to(X.device, dtype=X.dtype)

    # Local means via Gaussian windowing.
    mu1 = gaussian_filter(X, win)
    mu2 = gaussian_filter(Y, win)

    mu1_sq = mu1.pow(2)
    mu2_sq = mu2.pow(2)
    mu1_mu2 = mu1 * mu2

    # Local (co)variances: E[x^2] - E[x]^2 etc.
    sigma1_sq = compensation * (gaussian_filter(X * X, win) - mu1_sq)
    sigma2_sq = compensation * (gaussian_filter(Y * Y, win) - mu2_sq)
    sigma12 = compensation * (gaussian_filter(X * Y, win) - mu1_mu2)

    cs_map = (2 * sigma12 + C2) / (sigma1_sq + sigma2_sq + C2)  # set alpha=beta=gamma=1
    ssim_map = ((2 * mu1_mu2 + C1) / (mu1_sq + mu2_sq + C1)) * cs_map

    # Average over all spatial positions, keeping (batch, channel).
    ssim_per_channel = torch.flatten(ssim_map, 2).mean(-1)
    cs = torch.flatten(cs_map, 2).mean(-1)
    return ssim_per_channel, cs


# not currently used so commented out for coverage
# def ssim(
#     X: Tensor,
#     Y: Tensor,
#     data_range: float = 255,
#     size_average: bool = True,
#     win_size: int = 11,
#     win_sigma: float = 1.5,
#     win: Optional[Tensor] = None,
#     K: Union[Tuple[float, float], List[float]] = (0.01, 0.03),
#     nonnegative_ssim: bool = False,
# ) -> Tensor:
#     r""" interface of ssim
#     Args:
#         X (torch.Tensor): a batch of images, (N,C,H,W)
#         Y (torch.Tensor): a batch of images, (N,C,H,W)
#         data_range (float or int, optional): value range of input images. (usually 1.0 or 255)
#         size_average (bool, optional): if size_average=True, ssim of all images will be averaged as a scalar
#         win_size: (int, optional): the size of gauss kernel
#         win_sigma: (float, optional): sigma of normal distribution
#         win (torch.Tensor, optional): 1-D gauss kernel. if None, a new kernel will be created according to win_size and win_sigma
#         K (list or tuple, optional): scalar constants (K1, K2). Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN results.
#         nonnegative_ssim (bool, optional): force the ssim response to be nonnegative with relu
#     Returns:
#         torch.Tensor: ssim results
#     """
#     if not X.shape == Y.shape:
#         raise ValueError(f"Input images should have the same dimensions, but got {X.shape} and {Y.shape}.")
#
#     for d in range(len(X.shape) - 1, 1, -1):
#         X = X.squeeze(dim=d)
#         Y = Y.squeeze(dim=d)
#
#     if len(X.shape) not in (4, 5):
#         raise ValueError(f"Input images should be 4-d or 5-d tensors, but got {X.shape}")
#
#     # if not X.type() == Y.type():
#     #     raise ValueError(f"Input images should have the same dtype, but got {X.type()} and {Y.type()}.")
#
#     if win is not None:  # set win_size
#         win_size = win.shape[-1]
#
#     if not (win_size % 2 == 1):
#         raise ValueError("Window size should be odd.")
#
#     if win is None:
#         win = _fspecial_gauss_1d(win_size, win_sigma)
#         win = win.repeat([X.shape[1]] + [1] * (len(X.shape) - 1))
#
#     ssim_per_channel, cs = _ssim(X, Y, data_range=data_range, win=win, size_average=False, K=K)
#     if nonnegative_ssim:
#         ssim_per_channel = torch.relu(ssim_per_channel)
#
#     if size_average:
#         return ssim_per_channel.mean()
#     else:
#         return ssim_per_channel.mean(1)
def ms_ssim(
    X: Tensor,
    Y: Tensor,
    data_range: float = 255,
    size_average: bool = True,
    win_size: int = 11,
    win_sigma: float = 1.5,
    win: Optional[Tensor] = None,
    weights: Optional[List[float]] = None,
    K: Union[Tuple[float, float], List[float]] = (0.01, 0.03),
) -> Tensor:
    r"""interface of ms-ssim

    Args:
        X (torch.Tensor): a batch of images, (N,C,[T,]H,W)
        Y (torch.Tensor): a batch of images, (N,C,[T,]H,W)
        data_range (float or int, optional): value range of input images. (usually 1.0 or 255)
        size_average (bool, optional): if size_average=True, ssim of all images will be averaged as a scalar
        win_size: (int, optional): the size of gauss kernel
        win_sigma: (float, optional): sigma of normal distribution
        win (torch.Tensor, optional): 1-D gauss kernel. if None, a new kernel will be created according to win_size and win_sigma
        weights (list, optional): weights for different levels
        K (list or tuple, optional): scalar constants (K1, K2). Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN results.

    Returns:
        torch.Tensor: ms-ssim results
    """
    if not X.shape == Y.shape:
        raise ValueError(f"Input images should have the same dimensions, but got {X.shape} and {Y.shape}.")

    # Drop any trailing singleton dims so 4-d/5-d detection below is reliable.
    for d in range(len(X.shape) - 1, 1, -1):
        X = X.squeeze(dim=d)
        Y = Y.squeeze(dim=d)

    # if not X.type() == Y.type():
    #     raise ValueError(f"Input images should have the same dtype, but got {X.type()} and {Y.type()}.")

    ndim = len(X.shape)
    if ndim == 4:
        avg_pool = F.avg_pool2d
    elif ndim == 5:
        avg_pool = F.avg_pool3d
    else:
        raise ValueError(f"Input images should be 4-d or 5-d tensors, but got {X.shape}")

    if win is not None:  # a supplied kernel dictates win_size
        win_size = win.shape[-1]

    if not (win_size % 2 == 1):
        raise ValueError("Window size should be odd.")

    # Each of the 4 downsamplings halves the image; the kernel must still fit at the coarsest level.
    smaller_side = min(X.shape[-2:])
    assert smaller_side > (win_size - 1) * (2 ** 4), (
        "Image size should be larger than %d due to the 4 downsamplings in ms-ssim" % ((win_size - 1) * (2 ** 4))
    )

    if weights is None:
        weights = [0.0448, 0.2856, 0.3001, 0.2363, 0.1333]
    level_weights = X.new_tensor(weights)

    if win is None:
        win = _fspecial_gauss_1d(win_size, win_sigma)
        win = win.repeat([X.shape[1]] + [1] * (len(X.shape) - 1))

    levels = level_weights.shape[0]
    mcs = []
    for level in range(levels):
        ssim_per_channel, cs = _ssim(X, Y, win=win, data_range=data_range, size_average=False, K=K)

        if level < levels - 1:
            # Keep the contrast-structure term, then halve resolution for the next level.
            mcs.append(torch.relu(cs))
            padding = [s % 2 for s in X.shape[2:]]
            X = avg_pool(X, kernel_size=2, padding=padding)
            Y = avg_pool(Y, kernel_size=2, padding=padding)

    ssim_per_channel = torch.relu(ssim_per_channel)  # type: ignore  # (batch, channel)
    mcs_and_ssim = torch.stack(mcs + [ssim_per_channel], dim=0)  # (level, batch, channel)
    ms_ssim_val = torch.prod(mcs_and_ssim ** level_weights.view(-1, 1, 1), dim=0)

    if size_average:
        return ms_ssim_val.mean()
    else:
        return ms_ssim_val.mean(1)
# not currently used so commented out for coverage
# class SSIM(torch.nn.Module):
#     def __init__(
#         self,
#         data_range: float = 255,
#         size_average: bool = True,
#         win_size: int = 11,
#         win_sigma: float = 1.5,
#         channel: int = 3,
#         spatial_dims: int = 2,
#         K: Union[Tuple[float, float], List[float]] = (0.01, 0.03),
#         nonnegative_ssim: bool = False,
#     ) -> None:
#         r""" class for ssim
#         Args:
#             data_range (float or int, optional): value range of input images. (usually 1.0 or 255)
#             size_average (bool, optional): if size_average=True, ssim of all images will be averaged as a scalar
#             win_size: (int, optional): the size of gauss kernel
#             win_sigma: (float, optional): sigma of normal distribution
#             channel (int, optional): input channels (default: 3)
#             K (list or tuple, optional): scalar constants (K1, K2). Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN results.
#             nonnegative_ssim (bool, optional): force the ssim response to be nonnegative with relu.
#         """
#         super(SSIM, self).__init__()
#         self.win_size = win_size
#         self.win = _fspecial_gauss_1d(win_size, win_sigma).repeat([channel, 1] + [1] * spatial_dims)
#         self.size_average = size_average
#         self.data_range = data_range
#         self.K = K
#         self.nonnegative_ssim = nonnegative_ssim
#
#     def forward(self, X: Tensor, Y: Tensor) -> Tensor:
#         return ssim(
#             X,
#             Y,
#             data_range=self.data_range,
#             size_average=self.size_average,
#             win=self.win,
#             K=self.K,
#             nonnegative_ssim=self.nonnegative_ssim,
#         )
#
#
# class MS_SSIM(torch.nn.Module):
#     def __init__(
#         self,
#         data_range: float = 255,
#         size_average: bool = True,
#         win_size: int = 11,
#         win_sigma: float = 1.5,
#         channel: int = 3,
#         spatial_dims: int = 2,
#         weights: Optional[List[float]] = None,
#         K: Union[Tuple[float, float], List[float]] = (0.01, 0.03),
#     ) -> None:
#         r""" class for ms-ssim
#         Args:
#             data_range (float or int, optional): value range of input images. (usually 1.0 or 255)
#             size_average (bool, optional): if size_average=True, ssim of all images will be averaged as a scalar
#             win_size: (int, optional): the size of gauss kernel
#             win_sigma: (float, optional): sigma of normal distribution
#             channel (int, optional): input channels (default: 3)
#             weights (list, optional): weights for different levels
#             K (list or tuple, optional): scalar constants (K1, K2). Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN results.
#         """
#         super(MS_SSIM, self).__init__()
#         self.win_size = win_size
#         self.win = _fspecial_gauss_1d(win_size, win_sigma).repeat([channel, 1] + [1] * spatial_dims)
#         self.size_average = size_average
#         self.data_range = data_range
#         self.weights = weights
#         self.K = K
#
#     def forward(self, X: Tensor, Y: Tensor) -> Tensor:
#         return ms_ssim(
#             X,
#             Y,
#             data_range=self.data_range,
#             size_average=self.size_average,
#             win=self.win,
#             weights=self.weights,
#             K=self.K,
#         )