"""Image data-loading utilities.

Provides loaders for datasets whose image column holds either filesystem
paths or encoded image bytes, plus a train/val DataLoader wrapper.
"""
import base64
import io
import random
import re

from PIL import Image
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
class FilePathLoader:
    """Loads PIL images from filesystem paths stored in a DataFrame column.

    Tracks how many rows converted successfully vs. were skipped due to
    load errors.
    """

    def __init__(self, dataset, file_path_column="file_path", label_column=None):
        """
        Args:
            dataset: pandas DataFrame (rows are accessed via ``.iloc``).
            file_path_column: name of the column holding image file paths.
            label_column: optional column holding labels; when set,
                ``get_item`` returns ``(image, label)`` instead of ``(image,)``.

        Raises:
            ValueError: if ``file_path_column`` is missing from the dataset.
        """
        self.dataset = dataset
        self.file_path_column = file_path_column
        self.label_column = label_column
        self.successful_count = 0
        self.skipped_count = 0

        # Rows are read with .iloc below, i.e. the dataset is expected to be
        # a pandas DataFrame, whose column index is `.columns`. (The original
        # checked `.column_names`, which pandas DataFrames do not have and
        # which the sibling JPGImageLoader does not use either.)
        if self.file_path_column not in dataset.columns:
            raise ValueError(f"Column '{self.file_path_column}' not found in dataset.")

    def _get_image(self, item):
        """Open the image referenced by ``item`` and normalize it to RGB.

        RGBA images are composited onto a black background using their
        alpha channel; other non-RGB modes are converted directly.
        Returns None (after logging) on any failure.
        """
        path = None  # pre-bind so the except handler can always reference it
        try:
            path = item[self.file_path_column]
            image = Image.open(path)
            if image.mode == 'RGBA':
                # Flatten the alpha channel onto black instead of a plain
                # convert(), which would drop transparency information.
                background = Image.new('RGB', image.size, (0, 0, 0))
                background.paste(image, mask=image.split()[3])
                image = background
            elif image.mode != 'RGB':
                image = image.convert('RGB')
            return image
        except Exception as e:
            # `path` stays None if the column lookup itself failed.
            print(f"Error loading image from {path}: {e}")
            return None

    def get_item(self, idx):
        """Return ``(image,)`` or ``(image, label)`` for row ``idx``.

        Returns None (and counts the row as skipped) if the image could
        not be loaded.
        """
        item = self.dataset.iloc[idx]
        image = self._get_image(item)
        if image is None:
            self.skipped_count += 1
            return None
        self.successful_count += 1
        if self.label_column is not None:
            return (image, item.get(self.label_column))
        return (image,)

    def print_summary(self):
        """Print how many images were converted and how many were skipped."""
        print(f"Successfully converted {self.successful_count} images.")
        print(f"Skipped {self.skipped_count} images due to errors.")
class JPGImageLoader:
    """Loads PIL images from encoded image bytes stored in a dataset column.

    The column may hold raw ``bytes``, a stringified Python bytes literal
    (``b'...'`` or ``b"..."``), or a base64-encoded string. Tracks how many
    rows converted successfully vs. were skipped due to errors.
    """

    def __init__(self, dataset, bytes_column="jpg", label_column=None):
        """
        Args:
            dataset: pandas DataFrame (rows are accessed via ``.iloc``).
            bytes_column: name of the column holding the encoded image bytes.
            label_column: optional column holding labels; when set,
                ``get_item`` returns ``(image, label)`` instead of ``(image,)``.

        Raises:
            ValueError: if ``bytes_column`` is missing from the dataset.
        """
        self.dataset = dataset
        self.bytes_column = bytes_column
        self.label_column = label_column
        self.successful_count = 0
        self.skipped_count = 0

        if self.bytes_column not in dataset.columns:
            raise ValueError(f"Column '{self.bytes_column}' not found in dataset.")

    def _get_image(self, item):
        """Decode the bytes in ``item`` into a PIL image normalized to RGB.

        Returns None (after logging) on any failure.
        """
        try:
            data = item[self.bytes_column]

            if isinstance(data, str) and data.startswith(("b'", 'b"')):
                # A Python bytes literal that was stringified (e.g. via a CSV
                # round-trip): strip the 2-char prefix and the closing quote,
                # then undo the escape sequences to recover the raw bytes.
                # The original accepted only b'...'; b"..." strings — which
                # AIIADataLoader explicitly routes to this loader — fell
                # through to base64 decoding and failed.
                bytes_data = data[2:-1].encode('latin1').decode('unicode-escape').encode('latin1')
            elif isinstance(data, str):
                # Any other string is assumed to be base64 text.
                bytes_data = base64.b64decode(data)
            else:
                bytes_data = data

            image = Image.open(io.BytesIO(bytes_data))
            if image.mode == 'RGBA':
                # Flatten the alpha channel onto black instead of a plain
                # convert(), which would drop transparency information.
                background = Image.new('RGB', image.size, (0, 0, 0))
                background.paste(image, mask=image.split()[3])
                image = background
            elif image.mode != 'RGB':
                image = image.convert('RGB')
            return image
        except Exception as e:
            print(f"Error loading image from bytes: {e}")
            return None

    def get_item(self, idx):
        """Return ``(image,)`` or ``(image, label)`` for row ``idx``.

        Returns None (and counts the row as skipped) if the image could
        not be decoded.
        """
        item = self.dataset.iloc[idx]
        image = self._get_image(item)
        if image is None:
            self.skipped_count += 1
            return None
        self.successful_count += 1
        if self.label_column is not None:
            return (image, item.get(self.label_column))
        return (image,)

    def print_summary(self):
        """Print how many images were converted and how many were skipped."""
        print(f"Successfully converted {self.successful_count} images.")
        print(f"Skipped {self.skipped_count} images due to errors.")
class AIIADataLoader:
    """Builds shuffled train/val DataLoaders from an image dataset column.

    Inspects the first value of ``column`` to decide whether it contains
    encoded image bytes (handled by JPGImageLoader) or filesystem paths
    (handled by FilePathLoader), materializes all loadable items, and
    splits them into train/validation DataLoaders.
    """

    def __init__(self, dataset, batch_size=32, val_split=0.2, seed=42, column="file_path", label_column=None, pretraining=False, **dataloader_kwargs):
        """
        Args:
            dataset: pandas DataFrame holding the image column.
            batch_size: batch size for both DataLoaders.
            val_split: fraction of items assigned to the validation set.
            seed: seed for the global ``random`` module (affects the split
                shuffle and downstream random choices).
            column: name of the image column to inspect and load from.
            label_column: optional label column passed through to the loader.
            pretraining: when True, each image yields two self-supervised
                samples ('denoise' and 'rotate') instead of (image, label).
            **dataloader_kwargs: forwarded to both DataLoader constructors.

        Raises:
            ValueError: if no items could be loaded from the dataset.
        """
        self.batch_size = batch_size
        self.val_split = val_split
        self.seed = seed
        self.pretraining = pretraining
        random.seed(seed)  # global seeding on purpose: fixes the split shuffle

        probe = dataset[column].iloc[0]
        # Raw bytes, stringified bytes literals, and data URIs all indicate
        # the column stores encoded image bytes rather than a file path.
        holds_bytes = isinstance(probe, bytes) or (
            isinstance(probe, str)
            and probe.startswith(("b'", 'b"', 'data:image'))
        )

        if holds_bytes:
            self.loader = JPGImageLoader(dataset, bytes_column=column, label_column=label_column)
        else:
            first_values = dataset[column].dropna().head(1).astype(str)
            filepath_pattern = r'.*(?:/|\\).*\.([jJ][pP][gG]|[pP][nN][gG]|[gG][iI][fF])$'
            looks_like_path = any(
                re.match(filepath_pattern, candidate, flags=re.IGNORECASE)
                for candidate in first_values
            )
            if looks_like_path:
                self.loader = FilePathLoader(dataset, file_path_column=column, label_column=label_column)
            else:
                # Fall back to byte decoding when the value is neither a
                # recognizable path nor an obvious bytes string.
                self.loader = JPGImageLoader(dataset, bytes_column=column, label_column=label_column)

        self.items = []
        for row_idx in range(len(dataset)):
            loaded = self.loader.get_item(row_idx)
            if loaded is None:
                continue  # row failed to load; the loader already counted it
            if self.pretraining:
                img = loaded[0] if isinstance(loaded, tuple) else loaded
                # Two self-supervised samples per image. The rotate label 0
                # is a placeholder: AIIADataset draws a fresh random angle.
                self.items.append((img, 'denoise', img))
                self.items.append((img, 'rotate', 0))
            else:
                self.items.append(loaded)

        if not self.items:
            raise ValueError("No valid items were loaded from the dataset")

        train_indices, val_indices = self._split_data()
        self.train_dataset = self._create_subset(train_indices)
        self.val_dataset = self._create_subset(val_indices)
        self.train_loader = DataLoader(self.train_dataset, batch_size=batch_size, shuffle=True, **dataloader_kwargs)
        self.val_loader = DataLoader(self.val_dataset, batch_size=batch_size, shuffle=False, **dataloader_kwargs)

    def _split_data(self):
        """Shuffle item indices and cut them into (train, val) lists."""
        if not self.items:
            raise ValueError("No items to split")

        indices = list(range(len(self.items)))
        random.shuffle(indices)
        cut = int((1 - self.val_split) * len(self.items))
        return indices[:cut], indices[cut:]

    def _create_subset(self, indices):
        """Wrap the items at ``indices`` in an AIIADataset."""
        selected = [self.items[i] for i in indices]
        return AIIADataset(selected, pretraining=self.pretraining)
class AIIADataset(torch.utils.data.Dataset):
    """Dataset over preloaded PIL images, resized to 224x224 tensors.

    In pretraining mode each item is an ``(image, task, label)`` triple and
    ``__getitem__`` returns ``(input_tensor, target, task)`` for the
    'denoise' or 'rotate' self-supervised task. Otherwise items are
    ``(image, label)`` pairs (returned as tensor + label) or bare images
    (returned as a tensor alone).
    """

    def __init__(self, items, pretraining=False):
        """
        Args:
            items: list of preloaded items (see class docstring for shapes).
            pretraining: selects the self-supervised __getitem__ path.
        """
        self.items = items
        self.pretraining = pretraining
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
        ])

    def __len__(self):
        return len(self.items)

    def _prepare(self, image, idx, require_pil=True):
        """Validate one image and return it as a (3, 224, 224) tensor.

        Raises ValueError when the input is not a PIL image (if
        ``require_pil``) or the transformed tensor has the wrong shape.
        """
        if require_pil and not isinstance(image, Image.Image):
            raise ValueError(f"Invalid image at index {idx}")
        tensor = self.transform(image)
        if tensor.shape != (3, 224, 224):
            raise ValueError(f"Invalid image shape at index {idx}: {tensor.shape}")
        return tensor

    def __getitem__(self, idx):
        item = self.items[idx]

        if self.pretraining:
            # Stored label is ignored: denoise targets the clean image and
            # rotate draws a fresh random angle on every access.
            image, task, _ = item
            tensor = self._prepare(image, idx)

            if task == 'denoise':
                # Input = image + Gaussian noise (std 0.1); target = clean image.
                noisy = tensor + torch.randn_like(tensor) * 0.1
                return noisy, tensor.clone(), task
            if task == 'rotate':
                # 4-way rotation classification: class id = angle / 90.
                angle = random.choice([0, 90, 180, 270])
                rotated = transforms.functional.rotate(tensor, angle)
                return rotated, torch.tensor(angle / 90).long(), task
            raise ValueError(f"Invalid task at index {idx}: {task}")

        if isinstance(item, tuple) and len(item) == 2:
            image, label = item
            return self._prepare(image, idx), label

        # Bare image, or a 1-tuple wrapping one; no PIL type check here,
        # matching the lenient handling of label-less items.
        source = item if isinstance(item, Image.Image) else item[0]
        return self._prepare(source, idx, require_pil=False)