# taken from https://github.com/VainF/pytorch-msssim/blob/master/pytorch_msssim/ssim.py
# Copyright 2020 by Gongfan Fang, Zhejiang University.
# All rights reserved.
# Author: Gongfan Fang
# Licence: MIT
import warnings
from typing import List, Optional, Tuple, Union
import torch
import torch.nn.functional as F
from torch import Tensor
def _fspecial_gauss_1d(size: int, sigma: float) -> Tensor:
    r"""Create a normalized 1-D gauss kernel

    Args:
        size (int): the size of gauss kernel (number of taps); must be >= 1
        sigma (float): sigma of normal distribution; must be non-zero
            (only ``sigma ** 2`` is used, so the sign is irrelevant)

    Returns:
        torch.Tensor: 1D kernel (1 x 1 x size) whose entries sum to 1

    Raises:
        ValueError: if ``size`` is smaller than 1, or if ``2 * sigma ** 2``
            is not a positive number (sigma is 0, NaN, or underflows)
    """
    if size < 1:
        raise ValueError(f"Kernel size should be positive, but got {size}.")

    # A zero denominator would turn the whole kernel into NaN and silently
    # poison every value computed from it, so reject it up front.
    denom = 2 * sigma ** 2
    if not denom > 0:
        raise ValueError(f"Sigma should be non-zero, but got {sigma}.")

    coords = torch.arange(size, dtype=torch.float)
    coords -= size // 2  # tap ``size // 2`` becomes the origin of the Gaussian

    g = torch.exp(-(coords ** 2) / denom)
    g /= g.sum()  # normalize so that filtering preserves the mean

    return g.unsqueeze(0).unsqueeze(0)
def gaussian_filter(input: Tensor, win: Tensor) -> Tensor:
    r""" Blur input with 1-D kernel

    The kernel is applied along every dimension after the first two, one
    dimension at a time, as an unpadded stride-1 convolution with ``groups``
    equal to the channel count. Each filtered dimension therefore shrinks by
    ``win.shape[-1] - 1``. A dimension shorter than the kernel is left
    unfiltered and a warning is emitted.

    Args:
        input (torch.Tensor): a batch of tensors to be blurred (4-D or 5-D)
        win (torch.Tensor): 1-D gauss kernel; every dimension between its
            first and its last one must have size 1

    Returns:
        torch.Tensor: blurred tensors

    Raises:
        AssertionError: if ``win`` has a non-singleton middle dimension
        NotImplementedError: if ``input`` is neither 4-D nor 5-D
    """
    # Raised explicitly instead of via ``assert`` so the check is not stripped
    # under ``python -O``; the exception type is kept for backward compatibility.
    if not all(ws == 1 for ws in win.shape[1:-1]):
        raise AssertionError(win.shape)

    if len(input.shape) == 4:
        conv = F.conv2d
    elif len(input.shape) == 5:
        conv = F.conv3d
    else:
        raise NotImplementedError(input.shape)

    C = input.shape[1]
    win_size = win.shape[-1]
    out = input
    for i, s in enumerate(input.shape[2:]):
        if s >= win_size:
            # Move the kernel taps onto dimension ``2 + i`` so the same 1-D
            # kernel filters along that axis.
            out = conv(out, weight=win.transpose(2 + i, -1), stride=1, padding=0, groups=C)
        else:
            warnings.warn(
                f"Skipping Gaussian Smoothing at dimension 2+{i} for input: {input.shape} and win size: {win.shape[-1]}"
            )

    return out
def _ssim(
    X: Tensor,
    Y: Tensor,
    data_range: float,
    win: Tensor,
    size_average: bool = True,
    K: Union[Tuple[float, float], List[float]] = (0.01, 0.03)
) -> Tuple[Tensor, Tensor]:
    r""" Compute the per-channel SSIM and contrast-structure terms for X and Y

    Args:
        X (torch.Tensor): images
        Y (torch.Tensor): images
        data_range (float or int): value range of input images. (usually 1.0 or 255)
        win (torch.Tensor): 1-D gauss kernel; it is moved to the device and dtype of X
        size_average (bool, optional): accepted for signature compatibility; this
            function never reads it and always returns per-channel values
        K (list or tuple, optional): scalar constants (K1, K2)

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: the SSIM map and the contrast-structure
        map, each averaged over every dimension after the first two.
    """
    k1, k2 = K
    c1 = (k1 * data_range) ** 2
    c2 = (k2 * data_range) ** 2
    compensation = 1.0

    kernel = win.to(X.device, dtype=X.dtype)

    def blur(t: Tensor) -> Tensor:
        return gaussian_filter(t, kernel)

    # First-order local statistics.
    mean_x = blur(X)
    mean_y = blur(Y)
    mean_x_sq = mean_x.pow(2)
    mean_y_sq = mean_y.pow(2)
    mean_xy = mean_x * mean_y

    # Second-order local statistics: E[ab] - E[a]E[b] under the same window.
    var_x = compensation * (blur(X * X) - mean_x_sq)
    var_y = compensation * (blur(Y * Y) - mean_y_sq)
    cov_xy = compensation * (blur(X * Y) - mean_xy)

    # set alpha=beta=gamma=1
    contrast_structure = (2 * cov_xy + c2) / (var_x + var_y + c2)
    luminance = (2 * mean_xy + c1) / (mean_x_sq + mean_y_sq + c1)
    full_map = luminance * contrast_structure

    # Collapse everything after (batch, channel) into a single mean.
    ssim_per_channel, cs = (m.flatten(2).mean(-1) for m in (full_map, contrast_structure))
    return ssim_per_channel, cs
# not currently used, so commented out for coverage
# def ssim(
# X: Tensor,
# Y: Tensor,
# data_range: float = 255,
# size_average: bool = True,
# win_size: int = 11,
# win_sigma: float = 1.5,
# win: Optional[Tensor] = None,
# K: Union[Tuple[float, float], List[float]] = (0.01, 0.03),
# nonnegative_ssim: bool = False,
# ) -> Tensor:
# r""" interface of ssim
# Args:
# X (torch.Tensor): a batch of images, (N,C,H,W)
# Y (torch.Tensor): a batch of images, (N,C,H,W)
# data_range (float or int, optional): value range of input images. (usually 1.0 or 255)
# size_average (bool, optional): if size_average=True, ssim of all images will be averaged as a scalar
# win_size: (int, optional): the size of gauss kernel
# win_sigma: (float, optional): sigma of normal distribution
# win (torch.Tensor, optional): 1-D gauss kernel. if None, a new kernel will be created according to win_size and win_sigma
# K (list or tuple, optional): scalar constants (K1, K2). Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN results.
# nonnegative_ssim (bool, optional): force the ssim response to be nonnegative with relu
# Returns:
# torch.Tensor: ssim results
# """
# if not X.shape == Y.shape:
# raise ValueError(f"Input images should have the same dimensions, but got {X.shape} and {Y.shape}.")
# for d in range(len(X.shape) - 1, 1, -1):
# X = X.squeeze(dim=d)
# Y = Y.squeeze(dim=d)
# if len(X.shape) not in (4, 5):
# raise ValueError(f"Input images should be 4-d or 5-d tensors, but got {X.shape}")
# #if not X.type() == Y.type():
# # raise ValueError(f"Input images should have the same dtype, but got {X.type()} and {Y.type()}.")
# if win is not None: # set win_size
# win_size = win.shape[-1]
# if not (win_size % 2 == 1):
# raise ValueError("Window size should be odd.")
# if win is None:
# win = _fspecial_gauss_1d(win_size, win_sigma)
# win = win.repeat([X.shape[1]] + [1] * (len(X.shape) - 1))
# ssim_per_channel, cs = _ssim(X, Y, data_range=data_range, win=win, size_average=False, K=K)
# if nonnegative_ssim:
# ssim_per_channel = torch.relu(ssim_per_channel)
# if size_average:
# return ssim_per_channel.mean()
# else:
# return ssim_per_channel.mean(1)
def ms_ssim(
    X: Tensor,
    Y: Tensor,
    data_range: float = 255,
    size_average: bool = True,
    win_size: int = 11,
    win_sigma: float = 1.5,
    win: Optional[Tensor] = None,
    weights: Optional[List[float]] = None,
    K: Union[Tuple[float, float], List[float]] = (0.01, 0.03)
) -> Tensor:
    r""" interface of ms-ssim

    Args:
        X (torch.Tensor): a batch of images, (N,C,[T,]H,W)
        Y (torch.Tensor): a batch of images, (N,C,[T,]H,W)
        data_range (float or int, optional): value range of input images. (usually 1.0 or 255)
        size_average (bool, optional): if size_average=True, ssim of all images will be averaged as a scalar
        win_size: (int, optional): the size of gauss kernel
        win_sigma: (float, optional): sigma of normal distribution
        win (torch.Tensor, optional): 1-D gauss kernel. if None, a new kernel will be created according to win_size and win_sigma
        weights (list, optional): weights for different levels; one entry per scale, so
            the images are downsampled ``len(weights) - 1`` times
        K (list or tuple, optional): scalar constants (K1, K2). Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN results.

    Returns:
        torch.Tensor: ms-ssim results

    Raises:
        ValueError: if X and Y differ in shape, are not 4-d or 5-d after singleton
            dimensions past the first two are squeezed, the window size is even,
            or ``weights`` is empty
        AssertionError: if the smaller of the last two image dimensions is not
            larger than ``(win_size - 1) * 2 ** (len(weights) - 1)``
    """
    if not X.shape == Y.shape:
        raise ValueError(f"Input images should have the same dimensions, but got {X.shape} and {Y.shape}.")

    # Drop singleton dimensions after (batch, channel), walking backwards so
    # that removing one does not shift the indices still to be visited.
    for d in range(len(X.shape) - 1, 1, -1):
        X = X.squeeze(dim=d)
        Y = Y.squeeze(dim=d)

    if len(X.shape) == 4:
        avg_pool = F.avg_pool2d
    elif len(X.shape) == 5:
        avg_pool = F.avg_pool3d
    else:
        raise ValueError(f"Input images should be 4-d or 5-d tensors, but got {X.shape}")

    if win is not None:  # set win_size
        win_size = win.shape[-1]

    if not (win_size % 2 == 1):
        raise ValueError("Window size should be odd.")

    if weights is None:
        weights = [0.0448, 0.2856, 0.3001, 0.2363, 0.1333]
    weights_tensor = X.new_tensor(weights)

    levels = weights_tensor.shape[0]
    if levels < 1:
        raise ValueError("At least one weight (one scale level) is required.")

    # The coarsest scale must still be larger than the window. The number of
    # downsamplings follows from the number of levels instead of being fixed
    # at 4, so custom ``weights`` get the correct bound. Raised explicitly so
    # the check survives ``python -O``; the exception type is unchanged.
    downsamplings = levels - 1
    min_side = (win_size - 1) * (2 ** downsamplings)
    smaller_side = min(X.shape[-2:])
    if not smaller_side > min_side:
        raise AssertionError(
            "Image size should be larger than %d due to the %d downsamplings in ms-ssim" % (min_side, downsamplings)
        )

    if win is None:
        win = _fspecial_gauss_1d(win_size, win_sigma)
        win = win.repeat([X.shape[1]] + [1] * (len(X.shape) - 1))

    mcs = []
    for i in range(levels):
        ssim_per_channel, cs = _ssim(X, Y, win=win, data_range=data_range, size_average=False, K=K)

        if i < levels - 1:
            mcs.append(torch.relu(cs))
            # Pad odd-sized dimensions by one so that 2x pooling covers them.
            padding = [s % 2 for s in X.shape[2:]]
            X = avg_pool(X, kernel_size=2, padding=padding)
            Y = avg_pool(Y, kernel_size=2, padding=padding)

    ssim_per_channel = torch.relu(ssim_per_channel)  # type: ignore  # (batch, channel)
    mcs_and_ssim = torch.stack(mcs + [ssim_per_channel], dim=0)  # (level, batch, channel)
    ms_ssim_val = torch.prod(mcs_and_ssim ** weights_tensor.view(-1, 1, 1), dim=0)

    if size_average:
        return ms_ssim_val.mean()
    else:
        return ms_ssim_val.mean(1)
# not currently used, so commented out for coverage
# class SSIM(torch.nn.Module):
# def __init__(
# self,
# data_range: float = 255,
# size_average: bool = True,
# win_size: int = 11,
# win_sigma: float = 1.5,
# channel: int = 3,
# spatial_dims: int = 2,
# K: Union[Tuple[float, float], List[float]] = (0.01, 0.03),
# nonnegative_ssim: bool = False,
# ) -> None:
# r""" class for ssim
# Args:
# data_range (float or int, optional): value range of input images. (usually 1.0 or 255)
# size_average (bool, optional): if size_average=True, ssim of all images will be averaged as a scalar
# win_size: (int, optional): the size of gauss kernel
# win_sigma: (float, optional): sigma of normal distribution
# channel (int, optional): input channels (default: 3)
# K (list or tuple, optional): scalar constants (K1, K2). Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN results.
# nonnegative_ssim (bool, optional): force the ssim response to be nonnegative with relu.
# """
# super(SSIM, self).__init__()
# self.win_size = win_size
# self.win = _fspecial_gauss_1d(win_size, win_sigma).repeat([channel, 1] + [1] * spatial_dims)
# self.size_average = size_average
# self.data_range = data_range
# self.K = K
# self.nonnegative_ssim = nonnegative_ssim
# def forward(self, X: Tensor, Y: Tensor) -> Tensor:
# return ssim(
# X,
# Y,
# data_range=self.data_range,
# size_average=self.size_average,
# win=self.win,
# K=self.K,
# nonnegative_ssim=self.nonnegative_ssim,
# )
# class MS_SSIM(torch.nn.Module):
# def __init__(
# self,
# data_range: float = 255,
# size_average: bool = True,
# win_size: int = 11,
# win_sigma: float = 1.5,
# channel: int = 3,
# spatial_dims: int = 2,
# weights: Optional[List[float]] = None,
# K: Union[Tuple[float, float], List[float]] = (0.01, 0.03),
# ) -> None:
# r""" class for ms-ssim
# Args:
# data_range (float or int, optional): value range of input images. (usually 1.0 or 255)
# size_average (bool, optional): if size_average=True, ssim of all images will be averaged as a scalar
# win_size: (int, optional): the size of gauss kernel
# win_sigma: (float, optional): sigma of normal distribution
# channel (int, optional): input channels (default: 3)
# weights (list, optional): weights for different levels
# K (list or tuple, optional): scalar constants (K1, K2). Try a larger K2 constant (e.g. 0.4) if you get a negative or NaN results.
# """
# super(MS_SSIM, self).__init__()
# self.win_size = win_size
# self.win = _fspecial_gauss_1d(win_size, win_sigma).repeat([channel, 1] + [1] * spatial_dims)
# self.size_average = size_average
# self.data_range = data_range
# self.weights = weights
# self.K = K
# def forward(self, X: Tensor, Y: Tensor) -> Tensor:
# return ms_ssim(
# X,
# Y,
# data_range=self.data_range,
# size_average=self.size_average,
# win=self.win,
# weights=self.weights,
# K=self.K,
# )