improved data handling
This commit is contained in:
parent 530f499efb
commit 6757718569
@@ -7,12 +7,11 @@ import random
 import numpy as np
 
 class AIIADataLoader:
-    def __init__(self, data_dir, batch_size=32, val_split=0.2, seed=42, task='denoising'):
+    def __init__(self, data_dir, batch_size=32, val_split=0.2, seed=42):
         self.data_dir = data_dir
         self.batch_size = batch_size
         self.val_split = val_split
         self.seed = seed
-        self.task = task
         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
         # Set random seeds for reproducibility
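This hunk drops the `task` parameter and the `self.task` attribute: the loader no longer serves a single task chosen at construction time. The hunk cuts off at the reproducibility comment, so the seeding code itself is not shown; given that `random`, `numpy`, and `torch` are all in scope here, it presumably follows the usual pattern. The sketch below is an assumption about that elided code, not part of this diff:

    import random
    import numpy as np
    import torch

    def set_seeds(seed: int) -> None:
        # Seed every RNG the loader touches so shuffling and any
        # random augmentation are reproducible across runs.
        # (Typical pattern only; the file's actual lines are outside this hunk.)
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)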
@@ -26,30 +25,27 @@ class AIIADataLoader:
         self.image_paths = self._load_image_paths()
         self.train_paths, self.val_paths = self._split_data(self.image_paths)
 
+        # Split train paths into denoising and rotation subsets
+        num_train = len(self.train_paths)
+        mid_index = num_train // 2
+        self.denoise_train_paths = self.train_paths[:mid_index]
+        self.rotation_train_paths = self.train_paths[mid_index:]
+
         # Define transformations
         self.transform = transforms.Compose([
             transforms.ToTensor()
         ])
 
-        # Create datasets and dataloaders
-        self.train_dataset = AIIADataset(self.train_paths, transform=self.transform, task=self.task)
-        self.val_dataset = AIIADataset(self.val_paths, transform=self.transform, task=self.task)
+        # Create datasets and dataloaders for denoising and rotation
+        self.denoise_dataset = AIIADataset(self.denoise_train_paths, transform=self.transform)
+        self.rotation_dataset = AIIADataset(self.rotation_train_paths, transform=self.transform)
 
-        self.train_loader = DataLoader(self.train_dataset, batch_size=batch_size, shuffle=True)
+        self.denoise_loader = DataLoader(self.denoise_dataset, batch_size=batch_size, shuffle=True)
+        self.rotation_loader = DataLoader(self.rotation_dataset, batch_size=batch_size, shuffle=True)
 
         # Validation loader
+        self.val_dataset = AIIADataset(self.val_paths, transform=self.transform)
         self.val_loader = DataLoader(self.val_dataset, batch_size=batch_size, shuffle=False)
 
     def _load_image_paths(self):
         images = []
         for filename in os.listdir(self.data_dir):
             if any(filename.endswith(ext) for ext in ['.png', '.jpg', '.jpeg']):
                 images.append(os.path.join(self.data_dir, filename))
         return images
 
     def _split_data(self, paths):
         n_val = int(len(paths) * self.val_split)
         train_paths = paths[n_val:]
         val_paths = paths[:n_val]
         return train_paths, val_paths
 
 class AIIADataset(Dataset):
     def __init__(self, image_paths, transform=None, task='denoising'):
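After this change, the training images are split in half: the first half feeds a denoising dataset, the second a rotation dataset, each with its own shuffled DataLoader, while validation keeps a single unshuffled loader. A rough usage sketch of the new attributes (the directory path is a hypothetical example):

    loader = AIIADataLoader("data/images", batch_size=32, val_split=0.2, seed=42)

    # Each training image now belongs to exactly one pretext task.
    for batch in loader.denoise_loader:
        ...  # denoising step

    for batch in loader.rotation_loader:
        ...  # rotation-prediction step

    for batch in loader.val_loader:
        ...  # validation, unshuffled

Two things worth noting about the design. The halves come from the unshuffled path list, so the task assignment is deterministic for a given directory listing. And since `AIIADataset.__init__` still defaults to `task='denoising'` and neither new dataset passes `task`, the rotation dataset presumably needs `task='rotation'` supplied explicitly to behave differently from the denoising one.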
@@ -0,0 +1 @@
+from .DataLoader import AIIADataLoader
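The new one-line `__init__.py` re-exports the loader at package level, so callers can import it without knowing the internal module layout. A sketch, with `aiia_data` standing in for whatever the package directory is actually named:

    # The package name below is hypothetical; only the re-export is in this diff.
    from aiia_data import AIIADataLoader

    loader = AIIADataLoader("data/images")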