improved Dataloader to actually load the images

Falko Victor Habel 2025-01-13 11:06:30 +01:00
parent b371d747fd
commit cbacd5e03c
1 changed file with 76 additions and 61 deletions


@@ -1,82 +1,97 @@
 import os
 from PIL import Image
 import torch
-from torch.utils.data import Dataset, DataLoader
+from torch.utils.data import DataLoader
 from torchvision import transforms
 import random
 import numpy as np
+from sklearn.model_selection import train_test_split

 class AIIADataLoader:
-    def __init__(self, data_dir, batch_size=32, val_split=0.2, seed=42):
-        self.data_dir = data_dir
-        self.batch_size = batch_size
-        self.val_split = val_split
-        self.seed = seed
-        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        # Set random seeds for reproducibility
-        random.seed(seed)
-        np.random.seed(seed)
+    def __init__(self, data_dir, batch_size=32, val_split=0.2, seed=42, limit=None):
+        torch.manual_seed(seed)
+        if torch.cuda.is_available():
+            torch.cuda.manual_seed_all(seed)
+        np.random.seed(seed)
+        random.seed(seed)

-        # Load and split dataset
-        self.image_paths = self._load_image_paths()
-        self.train_paths, self.val_paths = self._split_data(self.image_paths)
+        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        print(f'Using device: {self.device}')

-        # Split train paths into denoising and rotation subsets
-        num_train = len(self.train_paths)
-        mid_index = num_train // 2
-        self.denoise_train_paths = self.train_paths[:mid_index]
-        self.rotation_train_paths = self.train_paths[mid_index:]
-        # Define transformations
-        self.transform = transforms.Compose([
+        image_paths = self._load_image_paths(data_dir, limit=limit)
+        train_paths, val_paths = self._split_data(image_paths, val_split=val_split)

+        # Create combined dataset for training with both denoise and rotate tasks
+        train_denoise_paths = [(path, 'denoise') for path in train_paths]
+        train_rotate_paths = [(path, 'rotate') for path in train_paths]
+        train_combined = train_denoise_paths + train_rotate_paths

+        val_denoise_paths = [(path, 'denoise') for path in val_paths]
+        val_rotate_paths = [(path, 'rotate') for path in val_paths]
+        val_combined = val_denoise_paths + val_rotate_paths

+        transform = transforms.Compose([
             transforms.Resize((256, 256)),
             transforms.ToTensor()
         ])

-        # Create datasets and dataloaders for denoising and rotation
-        self.denoise_dataset = AIIADataset(self.denoise_train_paths, transform=self.transform)
-        self.rotation_dataset = AIIADataset(self.rotation_train_paths, transform=self.transform)
-        self.denoise_loader = DataLoader(self.denoise_dataset, batch_size=batch_size, shuffle=True)
-        self.rotation_loader = DataLoader(self.rotation_dataset, batch_size=batch_size, shuffle=True)
-        # Validation loader
-        self.val_dataset = AIIADataset(self.val_paths, transform=self.transform)
+        self.train_dataset = AIIADataset(train_combined, transform=transform)
+        self.val_dataset = AIIADataset(val_combined, transform=transform)

+        self.train_loader = DataLoader(self.train_dataset, batch_size=batch_size, shuffle=True)
         self.val_loader = DataLoader(self.val_dataset, batch_size=batch_size, shuffle=False)

-class AIIADataset(Dataset):
-    def __init__(self, image_paths, transform=None, task='denoising'):
-        self.image_paths = image_paths
+    def _load_image_paths(self, data_dir, limit=None):
+        extensions = ('.png', '.jpeg', '.jpg')
+        image_paths = []
+        for root, dirs, files in os.walk(data_dir):
+            for file in files:
+                if file.lower().endswith(extensions):
+                    image_paths.append(os.path.join(root, file))
+        image_paths = sorted(list(set(image_paths)))
+        if limit is not None:
+            image_paths = image_paths[:limit]
+        return image_paths

+    def _split_data(self, image_paths, val_split=0.2):
+        train_paths, val_paths = train_test_split(
+            image_paths, test_size=val_split, random_state=42
+        )
+        return train_paths, val_paths

+class AIIADataset(torch.utils.data.Dataset):
+    def __init__(self, data_paths, transform=None, preload=False):
+        self.data_paths = data_paths
         self.transform = transform
-        self.task = task
-        if task == 'denoising':
-            self.noise_transform = transforms.Compose([
-                lambda x: x + torch.randn_like(x) * 0.1  # Adjust noise level as needed
-            ])
-        elif task == 'rotation':
-            self.rotation_angles = [0, 90, 180, 270]
-        else:
-            raise ValueError("Unknown task")
+        self.preload = preload
+        self.loaded_images = {}
+        if self.preload:
+            for path, task in self.data_paths:
+                img = Image.open(path).convert('RGB')
+                self.loaded_images[path] = img

     def __len__(self):
-        return len(self.image_paths)
+        return len(self.data_paths)

     def __getitem__(self, idx):
-        img_path = self.image_paths[idx]
-        image = Image.open(img_path).convert('RGB')
-        if self.transform:
-            image = self.transform(image)
-        if self.task == 'denoising':
-            noisy_image = self.noise_transform(image)
-            return noisy_image, image  # input: noisy, target: clean
-        elif self.task == 'rotation':
-            angle = random.choice(self.rotation_angles)
-            rotated_image = transforms.functional.rotate(image, angle)
-            return rotated_image, angle  # input: rotated image, target: angle
-        else:
-            raise ValueError("Unknown task")
+        path, task = self.data_paths[idx]
+        if self.preload:
+            img = self.loaded_images[path]
+        else:
+            img = Image.open(path).convert('RGB')
+        if task == 'denoise':
+            noise_std = 0.1
+            noisy_img = img + torch.randn_like(img) * noise_std
+            target = img
+            return noisy_img, target, task
+        elif task == 'rotate':
+            angles = [0, 90, 180, 270]
+            angle = random.choice(angles)
+            rotated_img = transforms.functional.rotate(img, angle)
+            target = torch.tensor(angle).long()
+            return rotated_img, target, task
+        else:
+            raise ValueError(f"Unknown task: {task}")
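
A note on the added __getitem__: self.transform is stored but never applied there, so img is still a PIL image when torch.randn_like(img) runs for the 'denoise' task, which raises a TypeError. The sketch below is not part of this commit; it shows one way the Resize/ToTensor transform could be applied before the task-specific step, written as a hypothetical subclass (TransformedAIIADataset is an invented name) so the rest of the loader stays unchanged.

import random

import torch
from PIL import Image
from torchvision import transforms

# Hypothetical sketch, not part of this commit. Assumes the AIIADataset class
# from the file above is importable from the same module.
class TransformedAIIADataset(AIIADataset):
    """Apply self.transform before the task-specific step so both tasks see tensors."""

    def __getitem__(self, idx):
        path, task = self.data_paths[idx]
        img = self.loaded_images[path] if self.preload else Image.open(path).convert('RGB')
        if self.transform is not None:
            img = self.transform(img)  # PIL image -> float tensor of shape (3, 256, 256)
        if task == 'denoise':
            noisy_img = img + torch.randn_like(img) * 0.1  # additive Gaussian noise
            return noisy_img, img, task
        elif task == 'rotate':
            angle = random.choice([0, 90, 180, 270])
            rotated_img = transforms.functional.rotate(img, angle)  # accepts tensor images in recent torchvision
            return rotated_img, torch.tensor(angle).long(), task
        else:
            raise ValueError(f"Unknown task: {task}")

One further caveat, also outside this commit: because 'denoise' samples carry an image-shaped target while 'rotate' samples carry a scalar angle, the default DataLoader collate can only stack a batch whose samples share a task. With the shuffled, mixed-task train_combined list, a custom collate_fn (or separate per-task loaders) would likely be needed.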