Improved data handling

This commit is contained in:
Falko Victor Habel 2025-01-12 20:45:41 +01:00
parent 530f499efb
commit 6757718569
3 changed files with 16 additions and 19 deletions

View File

@ -7,12 +7,11 @@ import random
import numpy as np
class AIIADataLoader:
def __init__(self, data_dir, batch_size=32, val_split=0.2, seed=42, task='denoising'):
def __init__(self, data_dir, batch_size=32, val_split=0.2, seed=42):
self.data_dir = data_dir
self.batch_size = batch_size
self.val_split = val_split
self.seed = seed
self.task = task
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Set random seeds for reproducibility
@ -26,30 +25,27 @@ class AIIADataLoader:
self.image_paths = self._load_image_paths()
self.train_paths, self.val_paths = self._split_data(self.image_paths)
# Split train paths into denoising and rotation subsets
num_train = len(self.train_paths)
mid_index = num_train // 2
self.denoise_train_paths = self.train_paths[:mid_index]
self.rotation_train_paths = self.train_paths[mid_index:]
# Define transformations
self.transform = transforms.Compose([
transforms.ToTensor()
])
# Create datasets and dataloaders
self.train_dataset = AIIADataset(self.train_paths, transform=self.transform, task=self.task)
self.val_dataset = AIIADataset(self.val_paths, transform=self.transform, task=self.task)
# Create datasets and dataloaders for denoising and rotation
self.denoise_dataset = AIIADataset(self.denoise_train_paths, transform=self.transform)
self.rotation_dataset = AIIADataset(self.rotation_train_paths, transform=self.transform)
self.train_loader = DataLoader(self.train_dataset, batch_size=batch_size, shuffle=True)
self.denoise_loader = DataLoader(self.denoise_dataset, batch_size=batch_size, shuffle=True)
self.rotation_loader = DataLoader(self.rotation_dataset, batch_size=batch_size, shuffle=True)
# Validation loader
self.val_dataset = AIIADataset(self.val_paths, transform=self.transform)
self.val_loader = DataLoader(self.val_dataset, batch_size=batch_size, shuffle=False)
def _load_image_paths(self):
    """Collect paths of image files directly inside ``self.data_dir``.

    Returns:
        list[str]: Paths (joined with ``self.data_dir``) of files whose
        name ends with ``.png``, ``.jpg`` or ``.jpeg``, in sorted order.
    """
    extensions = ('.png', '.jpg', '.jpeg')
    # os.listdir yields entries in arbitrary, platform-dependent order.
    # Sort the result so the slicing-based train/val split downstream is
    # deterministic — the class seeds its RNGs for reproducibility, and an
    # unsorted listing would silently break that guarantee.
    return sorted(
        os.path.join(self.data_dir, filename)
        for filename in os.listdir(self.data_dir)
        if filename.endswith(extensions)
    )
def _split_data(self, paths):
    """Split *paths* into training and validation lists.

    The leading ``val_split`` fraction of the (already ordered) list is
    taken as the validation set; everything after it is the training set.

    Returns:
        tuple[list, list]: ``(train_paths, val_paths)``.
    """
    cutoff = int(len(paths) * self.val_split)
    return paths[cutoff:], paths[:cutoff]
class AIIADataset(Dataset):
def __init__(self, image_paths, transform=None, task='denoising'):

View File

@ -0,0 +1 @@
from .DataLoader import AIIADataLoader

View File