added "denosing" and "rotation" as pretraining options in the DataLoader

This commit is contained in:
Falko Victor Habel 2025-01-11 23:48:31 +01:00
parent 34b4f35e51
commit 530f499efb
1 changed files with 25 additions and 7 deletions

View File

@ -7,11 +7,12 @@ import random
import numpy as np
class AIIADataLoader:
def __init__(self, data_dir, batch_size=32, val_split=0.2, seed=42):
def __init__(self, data_dir, batch_size=32, val_split=0.2, seed=42, task='denoising'):
self.data_dir = data_dir
self.batch_size = batch_size
self.val_split = val_split
self.seed = seed
self.task = task
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Set random seeds for reproducibility
@ -31,8 +32,8 @@ class AIIADataLoader:
])
# Create datasets and dataloaders
self.train_dataset = AIIADataset(self.train_paths, transform=self.transform)
self.val_dataset = AIIADataset(self.val_paths, transform=self.transform)
self.train_dataset = AIIADataset(self.train_paths, transform=self.transform, task=self.task)
self.val_dataset = AIIADataset(self.val_paths, transform=self.transform, task=self.task)
self.train_loader = DataLoader(self.train_dataset, batch_size=batch_size, shuffle=True)
self.val_loader = DataLoader(self.val_dataset, batch_size=batch_size, shuffle=False)
@ -51,13 +52,22 @@ class AIIADataLoader:
return train_paths, val_paths
class AIIADataset(Dataset):
def __init__(self, image_paths, transform=None):
def __init__(self, image_paths, transform=None, task='denoising'):
self.image_paths = image_paths
self.transform = transform
self.task = task
if task == 'denoising':
self.noise_transform = transforms.Compose([
lambda x: x + torch.randn_like(x) * 0.1 # Adjust noise level as needed
])
elif task == 'rotation':
self.rotation_angles = [0, 90, 180, 270]
else:
raise ValueError("Unknown task")
def __len__(self):
return len(self.image_paths)
def __getitem__(self, idx):
img_path = self.image_paths[idx]
image = Image.open(img_path).convert('RGB')
@ -65,4 +75,12 @@ class AIIADataset(Dataset):
if self.transform:
image = self.transform(image)
return image
if self.task == 'denoising':
noisy_image = self.noise_transform(image)
return noisy_image, image # input: noisy, target: clean
elif self.task == 'rotation':
angle = random.choice(self.rotation_angles)
rotated_image = transforms.functional.rotate(image, angle)
return rotated_image, angle # input: rotated image, target: angle
else:
raise ValueError("Unknown task")