added "denosing" and "rotation" as pretraining options in the DataLoader
This commit is contained in:
parent
34b4f35e51
commit
530f499efb
|
@ -7,11 +7,12 @@ import random
|
|||
import numpy as np
|
||||
|
||||
class AIIADataLoader:
|
||||
def __init__(self, data_dir, batch_size=32, val_split=0.2, seed=42):
|
||||
def __init__(self, data_dir, batch_size=32, val_split=0.2, seed=42, task='denoising'):
|
||||
self.data_dir = data_dir
|
||||
self.batch_size = batch_size
|
||||
self.val_split = val_split
|
||||
self.seed = seed
|
||||
self.task = task
|
||||
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
# Set random seeds for reproducibility
|
||||
|
@ -31,8 +32,8 @@ class AIIADataLoader:
|
|||
])
|
||||
|
||||
# Create datasets and dataloaders
|
||||
self.train_dataset = AIIADataset(self.train_paths, transform=self.transform)
|
||||
self.val_dataset = AIIADataset(self.val_paths, transform=self.transform)
|
||||
self.train_dataset = AIIADataset(self.train_paths, transform=self.transform, task=self.task)
|
||||
self.val_dataset = AIIADataset(self.val_paths, transform=self.transform, task=self.task)
|
||||
|
||||
self.train_loader = DataLoader(self.train_dataset, batch_size=batch_size, shuffle=True)
|
||||
self.val_loader = DataLoader(self.val_dataset, batch_size=batch_size, shuffle=False)
|
||||
|
@ -51,13 +52,22 @@ class AIIADataLoader:
|
|||
return train_paths, val_paths
|
||||
|
||||
class AIIADataset(Dataset):
|
||||
def __init__(self, image_paths, transform=None):
|
||||
def __init__(self, image_paths, transform=None, task='denoising'):
|
||||
self.image_paths = image_paths
|
||||
self.transform = transform
|
||||
|
||||
self.task = task
|
||||
if task == 'denoising':
|
||||
self.noise_transform = transforms.Compose([
|
||||
lambda x: x + torch.randn_like(x) * 0.1 # Adjust noise level as needed
|
||||
])
|
||||
elif task == 'rotation':
|
||||
self.rotation_angles = [0, 90, 180, 270]
|
||||
else:
|
||||
raise ValueError("Unknown task")
|
||||
|
||||
def __len__(self):
|
||||
return len(self.image_paths)
|
||||
|
||||
|
||||
def __getitem__(self, idx):
|
||||
img_path = self.image_paths[idx]
|
||||
image = Image.open(img_path).convert('RGB')
|
||||
|
@ -65,4 +75,12 @@ class AIIADataset(Dataset):
|
|||
if self.transform:
|
||||
image = self.transform(image)
|
||||
|
||||
return image
|
||||
if self.task == 'denoising':
|
||||
noisy_image = self.noise_transform(image)
|
||||
return noisy_image, image # input: noisy, target: clean
|
||||
elif self.task == 'rotation':
|
||||
angle = random.choice(self.rotation_angles)
|
||||
rotated_image = transforms.functional.rotate(image, angle)
|
||||
return rotated_image, angle # input: rotated image, target: angle
|
||||
else:
|
||||
raise ValueError("Unknown task")
|
Loading…
Reference in New Issue