added first dataloader conzept

This commit is contained in:
Falko Victor Habel 2025-01-11 23:37:16 +01:00
parent 7287ba543f
commit 0ba2a3f23c
1 changed files with 68 additions and 0 deletions

68
src/data/DataLoader.py Normal file
View File

@ -0,0 +1,68 @@
import os
from PIL import Image
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
import random
import numpy as np
class AIIADataLoader:
def __init__(self, data_dir, batch_size=32, val_split=0.2, seed=42):
self.data_dir = data_dir
self.batch_size = batch_size
self.val_split = val_split
self.seed = seed
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Set random seeds for reproducibility
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)
# Load and split dataset
self.image_paths = self._load_image_paths()
self.train_paths, self.val_paths = self._split_data(self.image_paths)
# Define transformations
self.transform = transforms.Compose([
transforms.ToTensor()
])
# Create datasets and dataloaders
self.train_dataset = AIIADataset(self.train_paths, transform=self.transform)
self.val_dataset = AIIADataset(self.val_paths, transform=self.transform)
self.train_loader = DataLoader(self.train_dataset, batch_size=batch_size, shuffle=True)
self.val_loader = DataLoader(self.val_dataset, batch_size=batch_size, shuffle=False)
def _load_image_paths(self):
images = []
for filename in os.listdir(self.data_dir):
if any(filename.endswith(ext) for ext in ['.png', '.jpg', '.jpeg']):
images.append(os.path.join(self.data_dir, filename))
return images
def _split_data(self, paths):
n_val = int(len(paths) * self.val_split)
train_paths = paths[n_val:]
val_paths = paths[:n_val]
return train_paths, val_paths
class AIIADataset(Dataset):
def __init__(self, image_paths, transform=None):
self.image_paths = image_paths
self.transform = transform
def __len__(self):
return len(self.image_paths)
def __getitem__(self, idx):
img_path = self.image_paths[idx]
image = Image.open(img_path).convert('RGB')
if self.transform:
image = self.transform(image)
return image