added first dataloader conzept
This commit is contained in:
parent
7287ba543f
commit
0ba2a3f23c
|
@ -0,0 +1,68 @@
|
|||
import os
|
||||
from PIL import Image
|
||||
import torch
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from torchvision import transforms
|
||||
import random
|
||||
import numpy as np
|
||||
|
||||
class AIIADataLoader:
|
||||
def __init__(self, data_dir, batch_size=32, val_split=0.2, seed=42):
|
||||
self.data_dir = data_dir
|
||||
self.batch_size = batch_size
|
||||
self.val_split = val_split
|
||||
self.seed = seed
|
||||
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
# Set random seeds for reproducibility
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
# Load and split dataset
|
||||
self.image_paths = self._load_image_paths()
|
||||
self.train_paths, self.val_paths = self._split_data(self.image_paths)
|
||||
|
||||
# Define transformations
|
||||
self.transform = transforms.Compose([
|
||||
transforms.ToTensor()
|
||||
])
|
||||
|
||||
# Create datasets and dataloaders
|
||||
self.train_dataset = AIIADataset(self.train_paths, transform=self.transform)
|
||||
self.val_dataset = AIIADataset(self.val_paths, transform=self.transform)
|
||||
|
||||
self.train_loader = DataLoader(self.train_dataset, batch_size=batch_size, shuffle=True)
|
||||
self.val_loader = DataLoader(self.val_dataset, batch_size=batch_size, shuffle=False)
|
||||
|
||||
def _load_image_paths(self):
|
||||
images = []
|
||||
for filename in os.listdir(self.data_dir):
|
||||
if any(filename.endswith(ext) for ext in ['.png', '.jpg', '.jpeg']):
|
||||
images.append(os.path.join(self.data_dir, filename))
|
||||
return images
|
||||
|
||||
def _split_data(self, paths):
|
||||
n_val = int(len(paths) * self.val_split)
|
||||
train_paths = paths[n_val:]
|
||||
val_paths = paths[:n_val]
|
||||
return train_paths, val_paths
|
||||
|
||||
class AIIADataset(Dataset):
|
||||
def __init__(self, image_paths, transform=None):
|
||||
self.image_paths = image_paths
|
||||
self.transform = transform
|
||||
|
||||
def __len__(self):
|
||||
return len(self.image_paths)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
img_path = self.image_paths[idx]
|
||||
image = Image.open(img_path).convert('RGB')
|
||||
|
||||
if self.transform:
|
||||
image = self.transform(image)
|
||||
|
||||
return image
|
Loading…
Reference in New Issue