'''
Generic image and metric data class constructor.
By default the same image is used as both the reference image and the image to transform.
'''
# Author: Matt Clifford <matt.clifford@bristol.ac.uk>
# License: BSD 3-Clause License
import os
from functools import lru_cache
from collections import namedtuple
import numpy as np
import pandas as pd
import imghdr
import IQM_Vis
from IQM_Vis.data_handlers import base_dataset_loader
# Registry of every function wrapped by `cache_tracked`, kept so that all of
# the caches can be cleared in one go.
CACHED_FUNCTIONS = []


def cache_tracked(func):
    '''Decorator that caches a function and records it in CACHED_FUNCTIONS.

    Args:
        func (callable): function to cache. As with any `lru_cache` wrapped
            function, its arguments must be hashable.

    Returns:
        callable: the `lru_cache` wrapped function (unbounded cache size),
            which exposes `cache_clear()`.
    '''
    cached_func = lru_cache(maxsize=None)(func)
    CACHED_FUNCTIONS.append(cached_func)
    return cached_func
class cache_metric_call:
    '''Cache metric functions that have been calculated already.

    To do this we need to convert numpy arrays to a hashable object (since
    arrays are mutable). To achieve this the caller passes a bytes
    representation of each array (input to __call__), which is then loaded
    back into an array from the buffer.

    Args:
        metric (callable): metric called as metric(ref, trans, **kwargs) with
            numpy arrays.
    '''

    def __init__(self, metric):
        self.metric = metric

    @cache_tracked
    def __call__(self, ref, trans, **kwargs):
        '''Evaluate the wrapped metric, reusing a cached result when available.

        Args:
            ref: hashable object with `bytes`, `dtype` and `shape` attributes
                describing the reference image array.
            trans: same as `ref` but for the transformed image array.
            **kwargs: forwarded to the metric. These are part of the cache
                key so they need to be hashable.

        Returns:
            Whatever the wrapped metric returns.
        '''
        # N.B. we need to copy the array since frombuffer gives a read only
        # array as it is a view of a bytes object (immutable)
        ref = np.frombuffer(ref.bytes, dtype=ref.dtype).reshape(ref.shape).copy()
        trans = np.frombuffer(trans.bytes, dtype=trans.dtype).reshape(trans.shape).copy()
        return self.metric(ref, trans, **kwargs)
class dataset_holder(base_dataset_loader):
    '''Stores images and metrics to communicate with the UI via the IQM-Vis data
    API.

    Args:
        image_list (list): list of image file paths
        metrics (dict): dictionary with keys of the metric names and values
                        of the callable metric function. The dictionary is
                        copied and each metric is wrapped in
                        cache_metric_call (Defaults to None, meaning no metrics)
        metric_images (dict): Optional dictionary with keys of the metric
                              image names and values of the callable metric
                              image function. The dictionary is copied
                              (Defaults to None, meaning no metric images)
        image_loader (function): Optional function which loads an image from
                                 a file path (Defaults to IQM_Vis.utils.load_image)
        image_pre_processing (function): Optional function to apply to an image
                              straight after loading, before any transformations.
                              For example resizing to a smaller image.
                              (Defaults to None)
        image_post_processing (function): Optional function to apply after image
                              transformations. For example cropping an image after rotation.
                              (Defaults to None)
        image_list_to_transform (list): list of image file paths for images
                                        to transform if they are different to the reference
                                        images. If None then will use the same image as
                                        the reference image. (Defaults to None)
        human_exp_csv (str): Optional path to a csv file of human experiment
                             scores. It is read with the first column as the
                             index and rows are looked up by reference image
                             name. (Defaults to None)
    '''

    def __init__(self, image_list: list,  # list of image file names
                 metrics: dict = None,
                 metric_images: dict = None,
                 image_loader=IQM_Vis.utils.load_image,  # function to load image files
                 image_pre_processing=None,  # apply a function to the image before transformations (e.g. resize to smaller)
                 image_post_processing=None,  # apply a function to the image after transformations (e.g. zoom to help with black borders on rotation)
                 image_list_to_transform=None,  # if you want to use a different image to transform than reference
                 human_exp_csv=None  # csv for where the human experiments file is
                 ):
        self.image_storer = namedtuple('image', ['name', 'data'])
        self.bytes_arrays = namedtuple('arr', ['bytes', 'dtype', 'shape'])
        self.image_loader = image_loader
        self.image_pre_processing = image_pre_processing
        self.load_image_list(image_list)
        if image_list_to_transform is not None:
            self.image_list_to_transform = image_list_to_transform
            self._load_image_data(0)  # load the first transform image
        # Build our own dicts so that neither a shared default nor the
        # caller's dictionary is mutated (wrapping a caller's dict in place
        # would double wrap the metrics if the dict is reused elsewhere).
        self.metrics = {
            name: self._as_cached_metric(metric)
            for name, metric in self._own_dict(metrics, 'metrics').items()}
        self.metric_images = self._own_dict(metric_images, 'metric_images')
        self.image_post_processing = image_post_processing
        self.image_post_processing_hash = None
        if human_exp_csv is not None:
            self.human_exp_df = pd.read_csv(human_exp_csv, index_col=0)
            # the first image was loaded before the csv was read so its
            # scores need loading now
            self._load_human_scores(self.image_reference.name)
        self._check_inputs()

    @staticmethod
    def _own_dict(value, var_name):
        # Return a private copy of a dict input (an empty dict for None).
        if value is None:
            return {}
        if not isinstance(value, dict):
            raise TypeError(f'holder input: {var_name} should be a {dict} not {type(value)}')
        return dict(value)

    @staticmethod
    def _as_cached_metric(metric):
        # Wrap a metric in cache_metric_call unless it already is one.
        if isinstance(metric, cache_metric_call):
            return metric
        return cache_metric_call(metric)

    def add_metric(self, key, value):
        '''Add (or replace) the metric stored under key, wrapping it so calls are cached.'''
        self.metrics[key] = self._as_cached_metric(value)

    def add_metric_image(self, key, value):
        '''Add (or replace) the metric image function stored under key.'''
        self.metric_images[key] = value

    def get_image_dataset_list(self):
        '''Return the list of reference image file paths.'''
        return self.image_list

    def load_image_list(self, image_list):
        '''Set the image file list and load the first image.

        Paths which imghdr does not recognise as an image (or which do not
        exist) are dropped. The list of images to transform is reset to be
        the same as the reference list.

        Args:
            image_list (list): list of image file paths. An empty list is
                ignored if an image list has already been loaded.

        Raises:
            ValueError: if image_list is empty and no list was loaded before.
        '''
        if len(image_list) == 0:
            if not hasattr(self, 'image_list'):
                raise ValueError('image_list is empty')
            return
        # remove any non image file paths
        # NOTE(review): imghdr is deprecated and removed in Python 3.13
        just_images = []
        for image_file in image_list:
            try:
                image_format = imghdr.what(image_file)
            except FileNotFoundError:
                image_format = None
            if image_format is not None:
                just_images.append(image_file)
        self.image_list = just_images
        self.image_list_to_transform = just_images
        self.image_names = [get_image_name(file) for file in self.image_list]
        self._load_image_data(0)

    def _load_image_data(self, i):
        '''Load the reference image, the image to transform and any human
        scores for index i of the image lists.'''
        # reset so that get_reference_image post processes the new image
        self.image_post_processing_hash = None
        # reference image
        self.current_file = self.image_list[i]
        image_name_ref = get_image_name(self.current_file)
        image_data_ref = self._cached_image_loader(self.current_file)
        self.reference_unprocessed = image_data_ref
        if self.image_pre_processing is not None:
            image_data_ref = self.image_pre_processing(image_data_ref)
        self.image_reference = self.image_storer(image_name_ref, image_data_ref)
        self.ref_bytes = self.bytes_arrays(
            image_data_ref.tobytes(), image_data_ref.dtype, image_data_ref.shape)
        # image to transform
        file_to_transform = self.image_list_to_transform[i]
        if self.current_file == file_to_transform:
            self.image_to_transform = self.image_storer(image_name_ref, image_data_ref)
        else:
            image_name_trans = get_image_name(file_to_transform)
            image_data_trans = self._cached_image_loader(file_to_transform)
            if self.image_pre_processing is not None:
                image_data_trans = self.image_pre_processing(image_data_trans)
            # store as the image to transform - the reference image must be
            # left untouched
            self.image_to_transform = self.image_storer(image_name_trans, image_data_trans)
        # Human experiments
        self._load_human_scores(image_name_ref)

    def _load_human_scores(self, image_name):
        # Refresh self.human_scores for image_name from the human experiment csv.
        if hasattr(self, 'human_scores'):
            del self.human_scores  # delete old scores (in case we don't have ones for the new image)
        if hasattr(self, 'human_exp_df') and image_name in self.human_exp_df.index:
            self.human_scores = {'mean': self.human_exp_df.loc[image_name].to_dict()}

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, i):
        # N.B. indexing loads image i as the current image, nothing is returned
        self._load_image_data(i)

    @cache_tracked
    def _cached_image_loader(self, file_name):
        # cached per (instance, file_name) so files are only read once
        return self.image_loader(file_name)

    def get_reference_image_by_index(self, index):
        '''Return the loaded (not pre or post processed) reference image at index.

        Raises:
            IndexError: if index is beyond the end of the image list.
        '''
        if index >= len(self.image_list):
            raise IndexError('Index out of range of the length of the image list')
        file_name = self.image_list[index]
        image_data = self._cached_image_loader(file_name)
        return image_data

    def get_reference_image_name(self):
        '''Return the name of the current reference image.'''
        return self.image_reference.name

    def get_reference_unprocessed(self):
        '''Return the current reference image as loaded (before pre processing).'''
        return self.reference_unprocessed

    def get_reference_image(self):
        '''Return the current reference image with post processing applied.

        The post processed image is only recomputed when the image or the
        post processing function has changed since the last call.
        '''
        if hash(self.image_post_processing) != self.image_post_processing_hash:
            # need to post process ref image as either first call or post processing has changed
            self.image_reference_post_processed = self.image_reference.data.copy()
            if self.image_post_processing is not None:
                self.image_reference_post_processed = self.image_post_processing(
                    self.image_reference_post_processed)
            # cache the hash so we can test if the post processing changes
            self.image_post_processing_hash = hash(self.image_post_processing)
            # save the bytes array
            self.ref_bytes = self.bytes_arrays(
                self.image_reference_post_processed.tobytes(),
                self.image_reference_post_processed.dtype,
                self.image_reference_post_processed.shape)
        return self.image_reference_post_processed

    def get_metrics(self, transformed_image, metrics_to_use='all', **kwargs):
        '''Compute the metrics between the reference and a transformed image.

        Args:
            transformed_image (np.ndarray): image to compare to the reference.
            metrics_to_use: 'all' or a container of metric names to compute.
            **kwargs: forwarded to every metric.

        Returns:
            dict: metric name -> result. If the transformed image shape does
                not match the stored reference shape every result is the
                placeholder value 100.
        '''
        # convert array to hashable so we can cache already calculated
        trans_bytes = self.bytes_arrays(
            transformed_image.tobytes(), transformed_image.dtype, transformed_image.shape)
        use_all = metrics_to_use == 'all'
        # get metrics
        results = {}
        for metric in self.metrics:
            if use_all or metric in metrics_to_use:
                if self.ref_bytes.shape != trans_bytes.shape:
                    # There has been a change in the data so need to quit this calc ASAP
                    results[metric] = 100
                else:
                    # calc as normal
                    results[metric] = self.metrics[metric](
                        self.ref_bytes, trans_bytes, **kwargs)
        return results

    def get_metric_images(self, transformed_image, metrics_to_use='all', **kwargs):
        '''Compute the metric images between the reference and a transformed image.

        Args:
            transformed_image (np.ndarray): image to compare to the reference.
            metrics_to_use: 'all' or a container of metric image names to compute.
            **kwargs: forwarded to every metric image function.

        Returns:
            dict: metric image name -> result.
        '''
        use_all = metrics_to_use == 'all'
        results = {}
        for metric in self.metric_images:
            if use_all or metric in metrics_to_use:
                results[metric] = self.metric_images[metric](
                    self.get_reference_image(), transformed_image, **kwargs)
        return results

    def _check_inputs(self):
        '''Raise a TypeError if any of the stored inputs has an unexpected type.'''
        input_types = [('image_reference.name', self.image_reference.name, str),
                       ('image_reference.data', self.image_reference.data, np.ndarray),
                       ('metrics', self.metrics, dict),
                       ('metric_images', self.metric_images, dict)]
        for var_name, value, expected_type in input_types:
            if not isinstance(value, expected_type):
                raise TypeError(f'holder input: {var_name} should be a {expected_type} not {type(value)}')

    def clear_all_cache(self):
        '''Clear the cache of every function registered in CACHED_FUNCTIONS.'''
        for cached_func in CACHED_FUNCTIONS:
            cached_func.cache_clear()
def get_image_name(file_path):
    '''Get the name of an image from its file path.

    Args:
        file_path (str): path to the image file.

    Returns:
        str: the final path component with its last extension removed.
    '''
    return os.path.splitext(os.path.basename(file_path))[0]