AIIA/src/aiia/data/DataLoader.py

229 lines
8.8 KiB
Python

import io
from PIL import Image
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
import random
import re
import base64
class FilePathLoader:
    """Load images from file paths stored in one column of a tabular dataset.

    Rows are fetched with ``dataset.iloc[idx]`` and every successfully opened
    image is returned as a fully decoded RGB image. Load failures are
    reported with ``print`` and counted instead of being raised.

    Args:
        dataset: Tabular object exposing ``iloc`` and either ``columns`` or
            ``column_names`` (presumably a pandas DataFrame - confirm
            against callers).
        file_path_column: Name of the column holding the image paths.
        label_column: Optional name of the column holding labels.

    Raises:
        ValueError: If ``file_path_column`` is not a column of ``dataset``.
    """

    def __init__(self, dataset, file_path_column="file_path", label_column=None):
        self.dataset = dataset
        self.file_path_column = file_path_column
        self.label_column = label_column
        self.successful_count = 0
        self.skipped_count = 0

        # Rows are read through ``dataset.iloc``, so the column listing is
        # looked up via ``columns`` first; ``column_names`` is kept as a
        # fallback for objects that only expose that attribute.
        columns = getattr(dataset, "columns", None)
        if columns is None:
            columns = getattr(dataset, "column_names", ())
        if self.file_path_column not in columns:
            raise ValueError(f"Column '{self.file_path_column}' not found in dataset.")

    @staticmethod
    def _to_rgb(image):
        """Return a decoded RGB copy of ``image``; RGBA is flattened onto black."""
        if image.mode == 'RGBA':
            background = Image.new('RGB', image.size, (0, 0, 0))
            # Use the alpha channel as paste mask so transparency becomes black.
            background.paste(image, mask=image.split()[3])
            return background
        # ``convert`` decodes the pixel data and returns a new image, so the
        # result no longer depends on the underlying file.
        return image.convert('RGB')

    def _get_image(self, item):
        """Open the image referenced by ``item`` and return it in RGB mode.

        Returns:
            The decoded image, or ``None`` if the path is missing or the
            file cannot be opened or decoded.
        """
        # Bound before the ``try`` so the error message below can always
        # reference it, even when the column lookup itself fails.
        path = None
        try:
            path = item[self.file_path_column]
            # The context manager closes the file handle once the pixels
            # have been decoded inside ``_to_rgb``.
            with Image.open(path) as image:
                return self._to_rgb(image)
        except Exception as e:
            # Deliberately best-effort: a bad row is reported and skipped.
            print(f"Error loading image from {path}: {e}")
            return None

    def get_item(self, idx):
        """Load the row at position ``idx``.

        Returns:
            ``(image, label)`` when ``label_column`` is set, ``(image,)``
            otherwise, or ``None`` when the image could not be loaded.
        """
        item = self.dataset.iloc[idx]
        image = self._get_image(item)
        if image is None:
            self.skipped_count += 1
            return None

        self.successful_count += 1
        if self.label_column is not None:
            label = item.get(self.label_column)
            return (image, label)
        return (image,)

    def print_summary(self):
        """Print how many images were loaded and how many were skipped."""
        print(f"Successfully converted {self.successful_count} images.")
        print(f"Skipped {self.skipped_count} images due to errors.")
class JPGImageLoader:
    """Load images from encoded image data stored in one dataset column.

    A cell may hold raw ``bytes``, the textual repr of a bytes literal
    (``"b'...'"`` or ``'b"..."'``), a ``data:`` URI, or a plain base64
    string. Rows are fetched with ``dataset.iloc[idx]``; failures are
    reported with ``print`` and counted instead of being raised.

    Args:
        dataset: Tabular object exposing ``columns`` and ``iloc``
            (presumably a pandas DataFrame - confirm against callers).
        bytes_column: Name of the column holding the encoded image data.
        label_column: Optional name of the column holding labels.

    Raises:
        ValueError: If ``bytes_column`` is not a column of ``dataset``.
    """

    def __init__(self, dataset, bytes_column="jpg", label_column=None):
        self.dataset = dataset
        self.bytes_column = bytes_column
        self.label_column = label_column
        self.successful_count = 0
        self.skipped_count = 0
        if self.bytes_column not in dataset.columns:
            raise ValueError(f"Column '{self.bytes_column}' not found in dataset.")

    @staticmethod
    def _decode_bytes(data):
        """Normalise a cell value to the raw bytes of an encoded image."""
        if not isinstance(data, str):
            return data
        if data.startswith(("b'", 'b"')):
            # Textual repr of a bytes literal: drop the ``b`` prefix and the
            # surrounding quotes, then resolve escapes such as ``\xff``.
            return data[2:-1].encode('latin1').decode('unicode-escape').encode('latin1')
        if data.startswith('data:'):
            # Data URI: only the part after the first comma is payload. The
            # header must be removed because its characters are valid base64
            # alphabet and would corrupt the decoded bytes.
            # NOTE(review): assumes a base64-encoded payload - confirm.
            data = data.partition(',')[2]
        return base64.b64decode(data)

    def _get_image(self, item):
        """Decode the image stored in ``item`` and return it in RGB mode.

        Returns:
            The decoded image, or ``None`` if the data is missing or cannot
            be decoded.
        """
        try:
            bytes_data = self._decode_bytes(item[self.bytes_column])
            with Image.open(io.BytesIO(bytes_data)) as image:
                if image.mode == 'RGBA':
                    background = Image.new('RGB', image.size, (0, 0, 0))
                    # Alpha channel as paste mask: transparency becomes black.
                    background.paste(image, mask=image.split()[3])
                    return background
                # ``convert`` forces the pixel data to be decoded here, so a
                # corrupt payload is caught below rather than surfacing later.
                return image.convert('RGB')
        except Exception as e:
            # Deliberately best-effort: a bad row is reported and skipped.
            print(f"Error loading image from bytes: {e}")
            return None

    def get_item(self, idx):
        """Load the row at position ``idx``.

        Returns:
            ``(image, label)`` when ``label_column`` is set, ``(image,)``
            otherwise, or ``None`` when the image could not be loaded.
        """
        item = self.dataset.iloc[idx]
        image = self._get_image(item)
        if image is None:
            self.skipped_count += 1
            return None

        self.successful_count += 1
        if self.label_column is not None:
            label = item.get(self.label_column)
            return (image, label)
        return (image,)

    def print_summary(self):
        """Print how many images were loaded and how many were skipped."""
        print(f"Successfully converted {self.successful_count} images.")
        print(f"Skipped {self.skipped_count} images due to errors.")
class AIIADataLoader:
    """Load every image of a dataset and expose train/validation loaders.

    The first value of ``column`` decides whether the column is read as
    encoded image data (``JPGImageLoader``) or as file paths
    (``FilePathLoader``). All rows are loaded during construction, rows that
    fail to load are dropped, and the remainder is shuffled and split into
    ``train_dataset``/``val_dataset`` with matching ``train_loader`` and
    ``val_loader``.

    Args:
        dataset: Tabular object supporting ``len``, ``dataset[column]`` and
            ``iloc`` (presumably a pandas DataFrame - confirm).
        batch_size: Batch size used for both loaders.
        val_split: Fraction of the loaded items placed in the validation set.
        seed: Seed applied to the module-level ``random`` generator.
        column: Name of the column holding paths or encoded image data.
        label_column: Optional name of the label column.
        pretraining: When true, each image yields one ``'denoise'`` and one
            ``'rotate'`` item instead of a labelled item.
        **dataloader_kwargs: Forwarded to both ``DataLoader`` instances.

    Raises:
        ValueError: If the dataset is empty or no row could be loaded.
    """

    # A value ending in one of these extensions is treated as a file path.
    # Base64 text never contains ".", so it cannot match by accident.
    _FILE_PATH_PATTERN = re.compile(
        r'.+\.(?:jpe?g|png|gif|bmp|webp|tiff?)$', flags=re.IGNORECASE
    )
    # String prefixes marking a bytes-literal repr or a data URI.
    _ENCODED_PREFIXES = ("b'", 'b"', 'data:image')

    def __init__(self, dataset, batch_size=32, val_split=0.2, seed=42, column="file_path", label_column=None, pretraining=False, **dataloader_kwargs):
        self.batch_size = batch_size
        self.val_split = val_split
        self.seed = seed
        self.pretraining = pretraining
        # Seeds the module-level generator used by the shuffle in
        # ``_split_data`` (a process-wide side effect).
        random.seed(seed)

        # Fail with a clear message instead of an IndexError on ``iloc[0]``.
        if len(dataset) == 0:
            raise ValueError("No valid items were loaded from the dataset")

        self.loader = self._select_loader(dataset, column, label_column)

        self.items = []
        for idx in range(len(dataset)):
            item = self.loader.get_item(idx)
            if item is None:
                # Rows that failed to load are skipped.
                continue
            if self.pretraining:
                img = item[0] if isinstance(item, tuple) else item
                self.items.append((img, 'denoise', img))
                self.items.append((img, 'rotate', 0))
            else:
                self.items.append(item)

        if not self.items:
            raise ValueError("No valid items were loaded from the dataset")

        train_indices, val_indices = self._split_data()
        self.train_dataset = self._create_subset(train_indices)
        self.val_dataset = self._create_subset(val_indices)
        self.train_loader = DataLoader(self.train_dataset, batch_size=batch_size, shuffle=True, **dataloader_kwargs)
        self.val_loader = DataLoader(self.val_dataset, batch_size=batch_size, shuffle=False, **dataloader_kwargs)

    @classmethod
    def _is_encoded_image(cls, value):
        """Return True for raw bytes, a bytes-literal repr or a data URI."""
        if isinstance(value, bytes):
            return True
        return isinstance(value, str) and value.startswith(cls._ENCODED_PREFIXES)

    @classmethod
    def _looks_like_file_path(cls, value):
        """Return True when ``value`` ends in a known image file extension."""
        return cls._FILE_PATH_PATTERN.match(value) is not None

    def _select_loader(self, dataset, column, label_column):
        """Pick the row loader matching the kind of data found in ``column``."""
        sample_value = dataset[column].iloc[0]
        if self._is_encoded_image(sample_value):
            return JPGImageLoader(dataset, bytes_column=column, label_column=label_column)

        # Probe the first non-null value, rendered as text, for a file path.
        sample_paths = dataset[column].dropna().head(1).astype(str)
        if any(self._looks_like_file_path(path) for path in sample_paths):
            return FilePathLoader(dataset, file_path_column=column, label_column=label_column)
        # Anything else is handed to the bytes/base64 loader.
        return JPGImageLoader(dataset, bytes_column=column, label_column=label_column)

    def _split_data(self):
        """Shuffle the item indices and split them into train/validation.

        Returns:
            A ``(train_indices, val_indices)`` pair of index lists.

        Raises:
            ValueError: If there are no items to split.
        """
        if not self.items:
            raise ValueError("No items to split")

        num_samples = len(self.items)
        indices = list(range(num_samples))
        random.shuffle(indices)

        # Round before truncating so float noise cannot drop a sample, e.g.
        # (1 - 0.9) * 10 evaluates to 0.9999999999999998, not 1.0.
        split_idx = int(round((1 - self.val_split) * num_samples, 9))
        return indices[:split_idx], indices[split_idx:]

    def _create_subset(self, indices):
        """Build an ``AIIADataset`` from the items at ``indices``."""
        subset_items = [self.items[i] for i in indices]
        return AIIADataset(subset_items, pretraining=self.pretraining)
class AIIADataset(torch.utils.data.Dataset):
    """Torch dataset that turns loaded images into 3x224x224 tensors.

    Args:
        items: Sequence of samples. In pretraining mode each sample is an
            ``(image, task, label)`` triple with ``task`` being ``'denoise'``
            or ``'rotate'``. Otherwise a sample is an ``(image, label)``
            pair, a bare image, or a sequence whose first element is the
            image.
        pretraining: Selects which of the two sample layouts is expected.
    """

    def __init__(self, items, pretraining=False):
        self.items = items
        self.pretraining = pretraining
        pipeline = [transforms.Resize((224, 224)), transforms.ToTensor()]
        self.transform = transforms.Compose(pipeline)

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        """Return the transformed sample stored at position ``idx``.

        Raises:
            ValueError: For a non-image entry in the validated layouts, an
                unexpected tensor shape, or an unknown pretraining task.
        """
        entry = self.items[idx]
        if self.pretraining:
            return self._pretraining_sample(entry, idx)
        return self._plain_sample(entry, idx)

    @staticmethod
    def _require_pil(picture, idx):
        """Raise ``ValueError`` unless ``picture`` is a PIL image."""
        if not isinstance(picture, Image.Image):
            raise ValueError(f"Invalid image at index {idx}")

    def _to_checked_tensor(self, picture, idx):
        """Apply the transform and verify the result is 3x224x224."""
        tensor = self.transform(picture)
        if tensor.shape != (3, 224, 224):
            raise ValueError(f"Invalid image shape at index {idx}: {tensor.shape}")
        return tensor

    def _pretraining_sample(self, entry, idx):
        """Build the ``(input, target, task)`` triple for a pretraining item."""
        # The stored third element is not used; targets are derived below.
        picture, task, _stored_label = entry
        self._require_pil(picture, idx)
        tensor = self._to_checked_tensor(picture, idx)

        if task == 'denoise':
            # Input is the image plus Gaussian noise; target is the clean image.
            noise_std = 0.1
            noisy = tensor + torch.randn_like(tensor) * noise_std
            clean = tensor.clone()
            return noisy, clean, task

        if task == 'rotate':
            # Input is rotated by a random quarter turn; target is the
            # number of quarter turns (0-3).
            angle = random.choice([0, 90, 180, 270])
            rotated = transforms.functional.rotate(tensor, angle)
            quarter_turns = torch.tensor(angle / 90).long()
            return rotated, quarter_turns, task

        raise ValueError(f"Invalid task at index {idx}: {task}")

    def _plain_sample(self, entry, idx):
        """Build ``(image, label)`` or a bare image tensor for a regular item."""
        if isinstance(entry, tuple) and len(entry) == 2:
            picture, label = entry
            self._require_pil(picture, idx)
            return self._to_checked_tensor(picture, idx), label

        # Label-less sample: either the image itself or a sequence holding it
        # in first position.
        picture = entry if isinstance(entry, Image.Image) else entry[0]
        return self._to_checked_tensor(picture, idx)