Updated DataLoader for new AIIA CNN arch
This commit is contained in:
parent
99c3ec38c7
commit
106539f48a
|
@ -1,97 +1,206 @@
|
|||
import os
|
||||
import io
|
||||
from PIL import Image
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision import transforms
|
||||
import random
|
||||
import numpy as np
|
||||
from sklearn.model_selection import train_test_split
|
||||
import re
|
||||
|
||||
|
||||
class FilePathLoader:
    """Load PIL images from a dataset column that holds file-system paths.

    The dataset is expected to expose ``column_names`` (validated up front)
    and to yield dict-like rows supporting ``[]`` and ``.get`` —
    TODO confirm against the actual dataset type used by callers.
    """

    def __init__(self, dataset, file_path_column="file_path", label_column=None):
        self.dataset = dataset
        self.file_path_column = file_path_column
        self.label_column = label_column
        # Running tallies so print_summary() can report conversion results.
        self.successful_count = 0
        self.skipped_count = 0

        # Fail fast if the configured path column does not exist.
        if self.file_path_column not in dataset.column_names:
            raise ValueError(f"Column '{self.file_path_column}' not found in dataset.")

    def _get_image(self, item):
        """Open the image referenced by *item* as RGB; return None on failure."""
        # BUG FIX: bind `path` before the try-block so the except-clause can
        # reference it even when the column lookup itself is what raised.
        path = None
        try:
            path = item[self.file_path_column]
            image = Image.open(path).convert("RGB")
            return image
        except Exception as e:
            # Best-effort loader: log and skip rather than abort the whole run.
            print(f"Error loading image from {path}: {e}")
            return None

    def get_item(self, idx):
        """Return (image, label) or (image,) for row *idx*; None if unloadable."""
        item = self.dataset[idx]
        image = self._get_image(item)
        if image is not None:
            self.successful_count += 1
            if self.label_column is not None:
                label = item.get(self.label_column)
                return (image, label)
            else:
                return (image,)
        else:
            self.skipped_count += 1
            return None

    def print_summary(self):
        """Print how many rows loaded successfully vs. were skipped."""
        print(f"Successfully converted {self.successful_count} images.")
        print(f"Skipped {self.skipped_count} images due to errors.")
|
||||
|
||||
class JPGImageLoader:
    """Decode PIL images from a dataset column that holds raw JPEG bytes.

    Mirrors FilePathLoader's interface: same counters, get_item contract,
    and print_summary, so AIIADataLoader can use either interchangeably.
    """

    def __init__(self, dataset, bytes_column="jpg", label_column=None):
        self.dataset = dataset
        self.bytes_column = bytes_column
        self.label_column = label_column
        # Running tallies so print_summary() can report conversion results.
        self.successful_count = 0
        self.skipped_count = 0

        # Fail fast if the configured bytes column does not exist.
        if self.bytes_column not in dataset.column_names:
            raise ValueError(f"Column '{self.bytes_column}' not found in dataset.")

    def _get_image(self, item):
        """Decode *item*'s byte payload into an RGB image; None on failure."""
        try:
            bytes_data = item[self.bytes_column]
            img_bytes = io.BytesIO(bytes_data)
            image = Image.open(img_bytes).convert("RGB")
            return image
        except Exception as e:
            # Best-effort loader: log and skip rather than abort the whole run.
            print(f"Error loading image from bytes: {e}")
            return None

    def get_item(self, idx):
        """Return (image, label) or (image,) for row *idx*; None if undecodable."""
        item = self.dataset[idx]
        image = self._get_image(item)
        if image is not None:
            self.successful_count += 1
            if self.label_column is not None:
                label = item.get(self.label_column)
                return (image, label)
            else:
                return (image,)
        else:
            self.skipped_count += 1
            return None

    def print_summary(self):
        """Print how many rows decoded successfully vs. were skipped."""
        print(f"Successfully converted {self.successful_count} images.")
        print(f"Skipped {self.skipped_count} images due to errors.")
|
||||
|
||||
|
||||
class AIIADataLoader(DataLoader):
    """Build train/val AIIADatasets from a dataset of image paths or bytes.

    Picks FilePathLoader or JPGImageLoader by inspecting a one-row sample of
    *column*, materializes every row through the chosen loader, then splits
    the items into train/val sets stratified by their task tag.
    """

    def __init__(self, dataset,
                 batch_size=32,
                 val_split=0.2,
                 seed=42,
                 column="file_path",
                 label_column=None):
        # BUG FIX: DataLoader.__init__ requires a dataset argument, so the
        # previous bare super().__init__() always raised TypeError. Pass the
        # dataset and batch_size through; DataLoader stores batch_size itself
        # and forbids re-assigning it after initialization, so there is no
        # separate self.batch_size assignment here.
        super().__init__(dataset, batch_size=batch_size)

        self.val_split = val_split
        self.seed = seed

        # Determine which loader to use from a one-row sample of the column.
        # Assumes a pandas-style column (dropna/head) — TODO confirm.
        # BUG FIX: the sample was previously passed through .astype(str)
        # *before* the isinstance check, so the bytes branch could never fire.
        sample_values = dataset[column].dropna().head(1)
        is_bytes_or_bytestring = any(
            isinstance(value, (bytes, memoryview))
            for value in sample_values
        )

        if is_bytes_or_bytestring:
            self.loader = JPGImageLoader(
                dataset,
                bytes_column=column,
                label_column=label_column
            )
        else:
            # Check if the column contains valid image file paths.
            sample_paths = sample_values.astype(str)

            # Regex pattern for matching image file paths (adjust as needed)
            filepath_pattern = r'.*(?:/|\\).*\.([jJ][pP][gG]|png|gif)$'

            if any(
                re.match(filepath_pattern, path, flags=re.IGNORECASE)
                for path in sample_paths
            ):
                self.loader = FilePathLoader(
                    dataset,
                    file_path_column=column,
                    label_column=label_column
                )
            else:
                # Neither condition met: default to JPGImageLoader
                # (assuming bytes are stored as strings).
                self.loader = JPGImageLoader(
                    dataset,
                    bytes_column=column,
                    label_column=label_column
                )

        # Materialize every row through the chosen loader. Failed rows come
        # back as None and are dropped by the task-tag filter in _split_data.
        self.items = [self.loader.get_item(idx) for idx in range(len(dataset))]

        # Split into train and validation sets and wrap them as datasets.
        train_indices, val_indices = self._split_data()
        self.train_dataset = self._create_subset(train_indices)
        self.val_dataset = self._create_subset(val_indices)

    def _split_data(self):
        """Split self.items into (train_indices, val_indices), stratified by task.

        NOTE(review): item[1] is treated as the task tag, but the loaders emit
        (image, label) tuples — confirm whether the label doubles as the task.
        Returns ([], []) when there are no items or no task-tagged items.
        """
        if len(self.items) == 0:
            return [], []

        # BUG FIX: guard with isinstance(tuple) so the None entries produced
        # by failed get_item calls no longer crash len().
        tasks = [item[1] if isinstance(item, tuple) and len(item) > 1 else None
                 for item in self.items]
        unique_tasks = list(set(tasks)) if tasks.count(None) < len(tasks) else []

        # BUG FIX: use a generator seeded with self.seed (and a deterministic
        # task order) so the split is reproducible; the stored seed was
        # previously never applied to the shuffle.
        rng = random.Random(self.seed)
        unique_tasks.sort(key=repr)

        train_indices = []
        val_indices = []

        for task in unique_tasks:
            task_indices = [i for i, t in enumerate(tasks) if t == task]
            n_val = int(len(task_indices) * self.val_split)

            rng.shuffle(task_indices)

            val_indices.extend(task_indices[:n_val])
            train_indices.extend(task_indices[n_val:])

        return train_indices, val_indices

    def _create_subset(self, indices):
        """Wrap the items at *indices* in an AIIADataset."""
        subset_items = [self.items[i] for i in indices]
        return AIIADataset(subset_items)
|
||||
|
||||
class AIIADataset(torch.utils.data.Dataset):
    """Thin map-style dataset over pre-materialized items.

    Each entry may be an (image, label) pair, an (image, task, label)
    triple — in which case the task-specific augmentation is applied on
    access — or a bare PIL image.
    """

    def __init__(self, items):
        # Items are stored as-is; all interpretation happens in __getitem__.
        self.items = items

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        entry = self.items[idx]
        entry_is_tuple = isinstance(entry, tuple)

        if entry_is_tuple and len(entry) == 2:
            img, lbl = entry
            return (img, lbl)

        if entry_is_tuple and len(entry) == 3:
            img, job, _ = entry
            # NOTE(review): the denoise arithmetic assumes *img* is a tensor
            # (torch.randn_like fails on PIL images) — confirm upstream.
            if job == 'denoise':
                sigma = 0.1
                corrupted = img + torch.randn_like(img) * sigma
                clean = img
                return (corrupted, clean, job)
            if job == 'rotate':
                choice = random.choice([0, 90, 180, 270])
                turned = transforms.functional.rotate(img, choice)
                return (turned, torch.tensor(choice).long(), job)
            raise ValueError(f"Unknown task: {job}")

        # Single images without labels or tasks pass straight through.
        if isinstance(entry, Image.Image):
            return entry
        raise ValueError("Invalid item format.")
|
||||
|
|
Loading…
Reference in New Issue