improved DataLoader to actually load the images
This commit is contained in:
parent
b371d747fd
commit
cbacd5e03c
|
@ -1,82 +1,97 @@
|
|||
import os
|
||||
from PIL import Image
|
||||
import torch
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision import transforms
|
||||
import random
|
||||
import numpy as np
|
||||
from sklearn.model_selection import train_test_split
|
||||
|
||||
|
||||
class AIIADataLoader:
    """Discover images under *data_dir*, split them train/val and build
    DataLoaders for two self-supervised tasks (denoising, rotation prediction).

    Each image path is tagged with both task labels ('denoise' and 'rotate'),
    so every image contributes one sample per task.

    Args:
        data_dir: root directory scanned recursively for images.
        batch_size: batch size for both loaders.
        val_split: fraction of images held out for validation.
        seed: seed for all RNGs (torch, cuda, numpy, random) AND the split.
        limit: optional cap on the number of images used (after sorting).
    """

    def __init__(self, data_dir, batch_size=32, val_split=0.2, seed=42, limit=None):
        # Seed every RNG we rely on so splits and augmentations are reproducible.
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(seed)
        np.random.seed(seed)
        random.seed(seed)

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f'Using device: {self.device}')

        image_paths = self._load_image_paths(data_dir, limit=limit)

        # BUGFIX: forward the constructor's seed instead of a hard-coded 42,
        # so a different `seed` argument actually changes the train/val split.
        train_paths, val_paths = self._split_data(
            image_paths, val_split=val_split, seed=seed
        )

        # Combined datasets: every path appears once per task.
        train_combined = (
            [(path, 'denoise') for path in train_paths]
            + [(path, 'rotate') for path in train_paths]
        )
        val_combined = (
            [(path, 'denoise') for path in val_paths]
            + [(path, 'rotate') for path in val_paths]
        )

        # ToTensor last: AIIADataset's noise/rotation steps operate on tensors.
        transform = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor()
        ])

        self.train_dataset = AIIADataset(train_combined, transform=transform)
        self.val_dataset = AIIADataset(val_combined, transform=transform)

        self.train_loader = DataLoader(self.train_dataset, batch_size=batch_size, shuffle=True)
        self.val_loader = DataLoader(self.val_dataset, batch_size=batch_size, shuffle=False)

    def _load_image_paths(self, data_dir, limit=None):
        """Recursively collect unique image paths under *data_dir*.

        Only .png/.jpeg/.jpg files (case-insensitive) are kept. The result is
        de-duplicated and sorted so ordering is stable across runs; *limit*
        truncates the sorted list when given.
        """
        extensions = ('.png', '.jpeg', '.jpg')
        image_paths = []
        for root, _dirs, files in os.walk(data_dir):
            for file in files:
                if file.lower().endswith(extensions):
                    image_paths.append(os.path.join(root, file))
        image_paths = sorted(set(image_paths))
        if limit is not None:
            image_paths = image_paths[:limit]
        return image_paths

    def _split_data(self, image_paths, val_split=0.2, seed=42):
        """Split paths into (train, val) lists; *seed* fixes the shuffle.

        The new `seed` parameter defaults to the previously hard-coded 42,
        keeping existing callers' behavior unchanged.
        """
        train_paths, val_paths = train_test_split(
            image_paths, test_size=val_split, random_state=seed
        )
        return train_paths, val_paths
||||
class AIIADataset(torch.utils.data.Dataset):
    """Dataset over (image_path, task) pairs for self-supervised pretraining.

    task == 'denoise': returns (noisy_image, clean_image, task)
    task == 'rotate' : returns (rotated_image, angle_tensor, task)

    Args:
        data_paths: list of (path, task) tuples with task in {'denoise', 'rotate'}.
        transform: optional callable applied to the decoded image; it should
            end with ToTensor since the noise/rotation steps expect a tensor.
        preload: when True, decode every image up front and keep it in memory
            (trades RAM for per-item I/O).
    """

    def __init__(self, data_paths, transform=None, preload=False):
        self.data_paths = data_paths
        self.transform = transform
        self.preload = preload
        self.loaded_images = {}  # path -> decoded RGB image (only when preloading)

        if self.preload:
            for path, _task in self.data_paths:
                self.loaded_images[path] = Image.open(path).convert('RGB')

    def __len__(self):
        return len(self.data_paths)

    def __getitem__(self, idx):
        """Return (input, target, task) for the sample at *idx*.

        Raises:
            ValueError: if the stored task label is neither 'denoise' nor 'rotate'.
        """
        path, task = self.data_paths[idx]
        if self.preload:
            img = self.loaded_images[path]
        else:
            img = Image.open(path).convert('RGB')

        # BUGFIX: the stored transform was never applied, so `img` stayed a
        # PIL image and `torch.randn_like(img)` below would fail.
        if self.transform:
            img = self.transform(img)

        if task == 'denoise':
            # Additive Gaussian noise; the clean tensor is the target.
            noise_std = 0.1
            noisy_img = img + torch.randn_like(img) * noise_std
            return noisy_img, img, task
        elif task == 'rotate':
            # Predict which of the four right-angle rotations was applied.
            angle = random.choice([0, 90, 180, 270])
            rotated_img = transforms.functional.rotate(img, angle)
            return rotated_img, torch.tensor(angle).long(), task
        else:
            raise ValueError(f"Unknown task: {task}")
|
Loading…
Reference in New Issue