# Source code for IQM_Vis.data_handlers.data_api

'''
Generic image and metric data class constructor.
By default the same image is used for both the reference and the transformed image.
'''
# Author: Matt Clifford <matt.clifford@bristol.ac.uk>
# License: BSD 3-Clause License

import os
from functools import lru_cache
from collections import namedtuple
import numpy as np
import pandas as pd
import imghdr
import IQM_Vis
from IQM_Vis.data_handlers import base_dataset_loader


# registry of every cached function, kept so that all the caches can be cleared in one go
CACHED_FUNCTIONS = []


def cache_tracked(func):
    '''Decorator which caches a function and records the cached version.

    Args:
        func (function): function to cache with an unbounded lru_cache

    Returns:
        function: the lru_cache wrapped version of func, which has also been
            appended to CACHED_FUNCTIONS
    '''
    memoised = lru_cache(maxsize=None)(func)
    CACHED_FUNCTIONS.append(memoised)
    return memoised
class cache_metric_call:
    '''Wrap a metric function so that results which have already been calculated are cached.

    numpy arrays are mutable and so cannot be hashed for the cache lookup. To get
    around this, __call__ takes each image as a hashable object holding the raw
    bytes of the array along with its dtype and shape. The arrays are rebuilt from
    the bytes buffer before being passed on to the metric.

    Args:
        metric (function): callable metric, called as metric(ref, trans, **kwargs)
    '''
    def __init__(self, metric):
        self.metric = metric

    @staticmethod
    def _array_from_bytes(hashable_arr):
        # np.frombuffer gives a read only view of the (immutable) bytes, so take a
        # copy to hand a normal writeable array over to the metric
        flat = np.frombuffer(hashable_arr.bytes, dtype=hashable_arr.dtype)
        return flat.reshape(hashable_arr.shape).copy()

    @cache_tracked
    def __call__(self, ref, trans, **kwargs):
        '''Calculate the metric, results are cached on the inputs given.

        Args:
            ref: hashable object with .bytes, .dtype and .shape of the reference image
            trans: hashable object with .bytes, .dtype and .shape of the transformed image
            **kwargs: passed straight through to the metric

        Returns:
            whatever the wrapped metric returns
        '''
        ref_arr = self._array_from_bytes(ref)
        trans_arr = self._array_from_bytes(trans)
        return self.metric(ref_arr, trans_arr, **kwargs)
class dataset_holder(base_dataset_loader):
    '''Stores images and metrics to communicate with the UI via the IQM-Vis data API.

    Args:
        image_list (list): list of image file paths
        metrics (dict): dictionary with keys of the metric names and values of the
            callable metric function. The dictionary is copied and every metric is
            wrapped in cache_metric_call, the dictionary passed in is left untouched.
            (Defaults to None which gives no metrics)
        metric_images (dict): Optional dictionary with keys of the metric image names
            and values of the callable metric image function. The dictionary is copied.
            (Defaults to None which gives no metric images)
        image_loader (function): Optional function which loads an image from a file path
            (Defaults to IQM_Vis.utils.load_image)
        image_pre_processing (function): Optional function to apply to an image straight
            after loading and before any transformations. For example resizing to a
            smaller image. (Defaults to None)
        image_post_processing (function): Optional function to apply after image
            transformations. For example cropping an image after rotation.
            (Defaults to None)
        image_list_to_transform (list): list of image file paths for images to transform
            if they are different to the reference images. If None then will use the
            same image as the reference image. (Defaults to None)
        human_exp_csv (str): Optional csv file of human experiment scores. It is read
            with the first column as the index and rows are looked up by image name.
            (Defaults to None)

    Raises:
        ValueError: if image_list is empty or contains no valid image files
        TypeError: if the loaded image or the metric dictionaries are the wrong type
    '''
    def __init__(self, image_list: list,  # list of image file names
                 metrics: dict=None,
                 metric_images: dict=None,
                 image_loader=IQM_Vis.utils.load_image,  # function to load image files
                 image_pre_processing=None,  # apply a function to the image before transformations (e.g. resize to smaller)
                 image_post_processing=None,  # apply a function to the image after transformations (e.g. zoom to help with black boarders on rotation)
                 image_list_to_transform=None,  # if you want to use a different image to transform than reference
                 human_exp_csv=None  # csv for where the human experiments file is
                 ):
        self.image_storer = namedtuple('image', ['name', 'data'])
        self.bytes_arrays = namedtuple('arr', ['bytes', 'dtype', 'shape'])
        self.image_loader = image_loader
        self.image_pre_processing = image_pre_processing
        self.image_post_processing = image_post_processing
        self.image_post_processing_hash = None
        # read the human scores before any image is loaded, otherwise
        # _load_image_data cannot attach the scores for the first image
        if human_exp_csv is not None:
            self.human_exp_df = pd.read_csv(human_exp_csv, index_col=0)
        self.load_image_list(image_list)
        if image_list_to_transform is not None:
            # NOTE(review): this list is not filtered like image_list is, so it is
            # assumed to line up index-for-index with the valid reference images
            self.image_list_to_transform = image_list_to_transform
            self._load_image_data(0)  # load the first transform image
        # check what the caller gave us before anything gets copied or wrapped
        self.metrics = {} if metrics is None else metrics
        self.metric_images = {} if metric_images is None else metric_images
        self._check_inputs()
        # store copies so neither the caller's dictionaries nor a shared default are
        # mutated. Wrapping in place would double wrap the metrics when the same
        # dictionary is given to a second dataset_holder
        given_metrics = self.metrics
        self.metrics = {}
        for name, metric in given_metrics.items():
            self.add_metric(name, metric)
        self.metric_images = dict(self.metric_images)

    def add_metric(self, key, value):
        '''Add (or replace) a metric, wrapping it in cache_metric_call if it is not already.

        Args:
            key (str): name of the metric
            value (function): callable metric
        '''
        if not isinstance(value, cache_metric_call):
            value = cache_metric_call(value)
        self.metrics[key] = value

    def add_metric_image(self, key, value):
        '''Add (or replace) a metric image function.

        Args:
            key (str): name of the metric image
            value (function): callable metric image function
        '''
        self.metric_images[key] = value

    def get_image_dataset_list(self):
        '''Get the list of image files currently held.'''
        return self.image_list

    @staticmethod
    def _is_image_file(image_file):
        # NOTE: imghdr is deprecated since Python 3.11 and removed in Python 3.13
        try:
            return imghdr.what(image_file) is not None
        except OSError:
            # missing files, directories and unreadable paths are not images
            return False

    def load_image_list(self, image_list):
        '''Replace the held images with the valid image files in image_list.

        Paths which are not recognised as image files are dropped. The images to
        transform are reset to be the same as the reference images and the first
        image is loaded. If there is nothing valid to load then the images already
        held (if any) are kept.

        Args:
            image_list (list): list of image file paths

        Raises:
            ValueError: if there are no valid images given and none already held
        '''
        just_images = [file for file in image_list if self._is_image_file(file)]
        if len(just_images) == 0:
            if not hasattr(self, 'image_list'):
                if len(image_list) == 0:
                    raise ValueError('image_list is empty')
                raise ValueError('image_list does not contain any valid image files')
            return
        self.image_list = just_images
        self.image_list_to_transform = just_images
        self.image_names = [get_image_name(file) for file in self.image_list]
        self._load_image_data(0)

    def _load_image_data(self, i):
        '''Load the reference image and image to transform at index i of the image lists.'''
        # reference image
        self.image_post_processing_hash = None  # force post processing of the new image
        self.current_file = self.image_list[i]
        image_name_ref = get_image_name(self.current_file)
        image_data_ref = self._cached_image_loader(self.current_file)
        self.reference_unprocessed = image_data_ref
        if self.image_pre_processing is not None:
            image_data_ref = self.image_pre_processing(image_data_ref)
        self.image_reference = self.image_storer(image_name_ref, image_data_ref)
        self.ref_bytes = self.bytes_arrays(
            image_data_ref.tobytes(), image_data_ref.dtype, image_data_ref.shape)
        # image to transform
        file_to_transform = self.image_list_to_transform[i]
        if self.current_file == file_to_transform:
            self.image_to_transform = self.image_storer(image_name_ref, image_data_ref)
        else:
            image_name_trans = get_image_name(file_to_transform)
            image_data_trans = self._cached_image_loader(file_to_transform)
            if self.image_pre_processing is not None:
                image_data_trans = self.image_pre_processing(image_data_trans)
            # store as the image to transform, the reference image must be left alone
            self.image_to_transform = self.image_storer(image_name_trans, image_data_trans)
        # Human experiments
        if hasattr(self, 'human_scores'):
            del self.human_scores  # delete old scores (incase we dont have ones for new image)
        if hasattr(self, 'human_exp_df'):
            if image_name_ref in self.human_exp_df.index:
                self.human_scores = {'mean': self.human_exp_df.loc[image_name_ref].to_dict()}

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, i):
        # indexing makes image i the current image, nothing is returned
        self._load_image_data(i)

    @cache_tracked
    def _cached_image_loader(self, file_name):
        # cached so that going back to an image does not load it from file again
        return self.image_loader(file_name)

    def get_reference_image_by_index(self, index):
        '''Load the reference image at index as given by the image loader (no pre processing).

        Args:
            index (int): position in the image list

        Returns:
            the loaded image data

        Raises:
            IndexError: if index is past the end of the image list
        '''
        if index >= len(self.image_list):
            raise IndexError('Index out of range of the length of the image list')
        file_name = self.image_list[index]
        image_data = self._cached_image_loader(file_name)
        return image_data

    def get_reference_image_name(self):
        '''Get the name of the current reference image.'''
        return self.image_reference.name

    def get_reference_unprocessed(self):
        '''Get the current reference image as loaded, before any pre processing.'''
        return self.reference_unprocessed

    def get_reference_image(self):
        '''Get the current reference image with the post processing applied.

        The post processed image is only recalculated when the image or the post
        processing function has changed since the last call. When recalculated, the
        reference bytes used by get_metrics are updated to match.
        '''
        if hash(self.image_post_processing) != self.image_post_processing_hash:
            # need to post process ref image as either first call or post processing has changed
            self.image_reference_post_processed = self.image_reference.data.copy()
            if self.image_post_processing is not None:
                self.image_reference_post_processed = self.image_post_processing(
                    self.image_reference_post_processed)
            # cache the hash so we can test if the post processing changes
            self.image_post_processing_hash = hash(self.image_post_processing)
            # save the bytes array
            self.ref_bytes = self.bytes_arrays(
                self.image_reference_post_processed.tobytes(),
                self.image_reference_post_processed.dtype,
                self.image_reference_post_processed.shape)
        return self.image_reference_post_processed

    def get_image_to_transform_name(self):
        '''Get the name of the current image to transform.'''
        return self.image_to_transform.name

    def get_image_to_transform(self):
        '''Get the data of the current image to transform.'''
        return self.image_to_transform.data

    def get_metrics(self, transformed_image, metrics_to_use='all', **kwargs):
        '''Calculate the metrics between the current reference and a transformed image.

        Args:
            transformed_image (np.ndarray): image to compare against the reference
            metrics_to_use: 'all' or a collection of the metric names to calculate
                (Defaults to 'all')
            **kwargs: passed to every metric, values need to be hashable for the cache

        Returns:
            dict: metric names as keys and the metric results as values
        '''
        # convert array to hashable so we can cache already calculated
        trans_bytes = self.bytes_arrays(
            transformed_image.tobytes(), transformed_image.dtype, transformed_image.shape)
        shapes_match = self.ref_bytes.shape == trans_bytes.shape
        # get metrics
        results = {}
        for metric in self.metrics:
            if metrics_to_use == 'all' or metric in metrics_to_use:
                if not shapes_match:
                    # There has been a change in the data so need to quit this calc ASAP
                    results[metric] = 100
                else:
                    # calc as normal
                    results[metric] = self.metrics[metric](
                        self.ref_bytes, trans_bytes, **kwargs)
        return results

    def get_metric_images(self, transformed_image, metrics_to_use='all', **kwargs):
        '''Calculate the metric images between the reference and a transformed image.

        Args:
            transformed_image (np.ndarray): image to compare against the reference
            metrics_to_use: 'all' or a collection of the metric image names to calculate
                (Defaults to 'all')
            **kwargs: passed to every metric image function

        Returns:
            dict: metric image names as keys and the metric image results as values
        '''
        results = {}
        for metric in self.metric_images:
            if metrics_to_use == 'all' or metric in metrics_to_use:
                results[metric] = self.metric_images[metric](
                    self.get_reference_image(), transformed_image, **kwargs)
        return results

    def _check_inputs(self):
        '''Raise a TypeError if the loaded image or metric dictionaries are the wrong type.'''
        input_types = [('image_reference.name', self.image_reference.name, str),
                       ('image_reference.data', self.image_reference.data, np.ndarray),
                       ('metrics', self.metrics, dict),
                       ('metric_images', self.metric_images, dict)]
        for var_name, value, expected_type in input_types:
            if not isinstance(value, expected_type):
                raise TypeError(
                    f'holder input: {var_name} should be a {expected_type} not {type(value)}')

    def clear_all_cache(self):
        '''Clear the cache of every function registered in CACHED_FUNCTIONS.'''
        for cached_func in CACHED_FUNCTIONS:
            cached_func.cache_clear()
def get_image_name(file_path):
    '''Get the name of an image from its file path.

    Args:
        file_path (str): path to the image file

    Returns:
        str: the file name with any directories and the last extension removed
    '''
    base_name = os.path.basename(file_path)
    stem, _extension = os.path.splitext(base_name)
    return stem