"""
Functions for dataset processing.
"""
import numpy as np
from PIL import Image
from tqdm.notebook import tqdm
import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.data import WeightedRandomSampler
[docs]class StimulusDataset(Dataset):
"""Torch dataset for stimulus images.
Parameters
----------
x_data:
A 3-d numpy array of input stimulus image.
y_data:
A 1-d numpy array of output class labels.
be "train" or "test".
img_transforms:
A dictionary of torchvision transforms.
class2idx:
A mapping between class and it's corresponding integer idx.
Attributes
----------
x_data: np.ndarray
A 3-d numpy array of input stimulus image.
y_data: np.ndarray
A 1-d numpy array of output class labels.
be "train" or "test".
img_transforms: dict
A dictionary of torchvision transforms.
class2idx: dict
A mapping between class and it's corresponding integer idx.
"""
def __init__(
self,
x_data: np.ndarray,
y_data: np.ndarray,
img_transform: dict,
class2idx: dict,
):
self.x_data = x_data
self.y_data = y_data
self.img_transform = img_transform
self.class2idx = class2idx
def __getitem__(self, idx):
x = Image.fromarray(self.x_data[idx].astype(np.uint8)).convert("RGB")
y = torch.tensor(self.class2idx[self.y_data[idx]], dtype=torch.long)
if self.img_transform:
x = self.img_transform(image=np.array(x))["image"]
return x, y
def __len__(self):
assert len(self.x_data) == len(self.y_data)
return len(self.x_data)
[docs]def calculate_mean_std(dataset: Dataset):
"""Calculate dataset mean and standard deviation.
Parameters
----------
dataset:
PyTorch dataset object.
Returns
-------
tuple
A tuple that consists a tuple of mean and std for 3 channels (mean, std).
"""
data_loader = DataLoader(dataset=dataset, shuffle=False, batch_size=1)
img_size = dataset[0][0].shape[1]
pixel_sum = torch.tensor([0.0, 0.0, 0.0])
pixel_sum_squared = torch.tensor([0.0, 0.0, 0.0])
for img, _ in tqdm(data_loader):
pixel_sum += img.sum(axis=[0, 2, 3])
pixel_sum_squared += (img ** 2).sum(axis=[0, 2, 3])
total_num_pixels = img_size * img_size * len(data_loader)
total_mean = pixel_sum / total_num_pixels
total_var = (pixel_sum_squared / total_num_pixels) - (total_mean ** 2)
total_std = torch.sqrt(total_var)
return total_mean, total_std
[docs]class FMRIDataset(Dataset):
"""Torch dataset for fMRI data.
Parameters
----------
x_data:
A 2-d numpy array of input fmri data.
y_data:
A 1-d numpy array of output class labels.
class2idx:
A dictionary that maps class string to integers.
Attributes
----------
x_data:
A 2-d numpy array of input fmri data.
y_data:
A 1-d numpy array of output class labels.
class2idx:
A dictionary that maps class string to integers.
"""
def __init__(self, x_data: np.ndarray, y_data: np.ndarray, class2idx: dict):
self.class2idx = class2idx
self.x_data = x_data
self.y_data = y_data
def __getitem__(self, idx):
x = torch.tensor(self.x_data, dtype=torch.float)[idx]
y = torch.tensor(self.class2idx[self.y_data[idx]], dtype=torch.long)
return x, y
def __len__(self):
assert len(self.x_data) == len(self.y_data)
return len(self.x_data)
[docs]def create_weighted_sampler(y_data: np.ndarray, class2idx: dict):
"""
Create a weighted random sampler.
Parameters
----------
y_data:
The output class labels.
class2idx:
A dictionary to convert class strings to integer.
"""
y_data = np.array([class2idx[t] for t in y_data])
class_count_list = np.array(
[len(np.where(y_data == t)[0]) for t in np.unique(y_data)]
)
class_weight_list = 1.0 / class_count_list
samples_weight_list = [class_weight_list[t] for t in y_data]
class_weight_tensor = torch.tensor(samples_weight_list)
weighted_random_sampler = WeightedRandomSampler(
class_weight_tensor, len(class_weight_tensor)
)
return weighted_random_sampler