fixed Dataloader and added test cases.

This commit is contained in:
Falko Victor Habel 2025-03-15 23:42:08 +01:00
parent 24a3a7bf56
commit fd518ff080
2 changed files with 163 additions and 41 deletions

View File

@ -15,7 +15,7 @@ class FilePathLoader:
self.successful_count = 0
self.skipped_count = 0
if self.file_path_column not in dataset.column_names:
if self.file_path_column not in dataset.columns:
raise ValueError(f"Column '{self.file_path_column}' not found in dataset.")
def _get_image(self, item):
@ -106,7 +106,11 @@ class JPGImageLoader:
print(f"Skipped {self.skipped_count} images due to errors.")
class AIIADataLoader:
def __init__(self, dataset, batch_size=32, val_split=0.2, seed=42, column="file_path", label_column=None, pretraining=False, **dataloader_kwargs):
def __init__(self, dataset, batch_size=32, val_split=0.2, seed=42, column="file_path",
label_column=None, pretraining=False, **dataloader_kwargs):
if column not in dataset.columns:
raise ValueError(f"Column '{column}' not found in dataset")
self.batch_size = batch_size
self.val_split = val_split
self.seed = seed
@ -145,7 +149,6 @@ class AIIADataLoader:
if not self.items:
raise ValueError("No valid items were loaded from the dataset")
train_indices, val_indices = self._split_data()
self.train_dataset = self._create_subset(train_indices)
@ -192,9 +195,11 @@ class AIIADataset(torch.utils.data.Dataset):
if not isinstance(image, Image.Image):
raise ValueError(f"Invalid image at index {idx}")
# Check image dimensions before transform
if image.size[0] < 224 or image.size[1] < 224:
raise ValueError("Invalid image dimensions")
image = self.transform(image)
if image.shape != (3, 224, 224):
raise ValueError(f"Invalid image shape at index {idx}: {image.shape}")
if task == 'denoise':
noise_std = 0.1
@ -214,15 +219,20 @@ class AIIADataset(torch.utils.data.Dataset):
image, label = item
if not isinstance(image, Image.Image):
raise ValueError(f"Invalid image at index {idx}")
# Check image dimensions before transform
if image.size[0] < 224 or image.size[1] < 224:
raise ValueError("Invalid image dimensions")
image = self.transform(image)
if image.shape != (3, 224, 224):
raise ValueError(f"Invalid image shape at index {idx}: {image.shape}")
return image, label
else:
if isinstance(item, Image.Image):
image = self.transform(item)
else:
image = self.transform(item[0])
if image.shape != (3, 224, 224):
raise ValueError(f"Invalid image shape at index {idx}: {image.shape}")
image = item[0] if isinstance(item, tuple) else item
if not isinstance(image, Image.Image):
raise ValueError(f"Invalid image at index {idx}")
# Check image dimensions before transform
if image.size[0] < 224 or image.size[1] < 224:
raise ValueError("Invalid image dimensions")
image = self.transform(image)
return image

View File

@ -0,0 +1,112 @@
import pytest
from PIL import Image
import torch
from torch.utils.data import DataLoader
from torchvision import transforms
import pandas as pd
import numpy as np
from aiia.data.DataLoader import FilePathLoader, JPGImageLoader, AIIADataLoader, AIIADataset
def create_sample_dataset(file_paths=None):
if file_paths is None:
file_paths = ['path/to/image1.jpg', 'path/to/image2.png']
data = {
'file_path': file_paths,
'label': [0] * len(file_paths) # Match length of labels to file_paths
}
df = pd.DataFrame(data)
return df
def create_sample_bytes_dataset(bytes_data=None):
if bytes_data is None:
bytes_data = [b'fake_image_data_1', b'fake_image_data_2']
data = {
'jpg': bytes_data,
'label': [0] * len(bytes_data) # Match length of labels to bytes_data
}
df = pd.DataFrame(data)
return df
def test_file_path_loader(mocker):
# Mock Image.open to return a fake image
mock_image = Image.new('RGB', (224, 224))
mocker.patch('PIL.Image.open', return_value=mock_image)
dataset = create_sample_dataset()
loader = FilePathLoader(dataset, label_column='label') # Added label_column
item = loader.get_item(0)
assert isinstance(item[0], Image.Image)
assert item[1] == 0
loader.print_summary()
def test_jpg_image_loader(mocker):
# Mock Image.open to return a fake image
mock_image = Image.new('RGB', (224, 224))
mocker.patch('PIL.Image.open', return_value=mock_image)
dataset = create_sample_bytes_dataset()
loader = JPGImageLoader(dataset, label_column='label') # Added label_column
item = loader.get_item(0)
assert isinstance(item[0], Image.Image)
assert item[1] == 0
loader.print_summary()
def test_aiia_data_loader(mocker):
# Mock Image.open to return a fake image
mock_image = Image.new('RGB', (224, 224))
mocker.patch('PIL.Image.open', return_value=mock_image)
dataset = create_sample_dataset()
data_loader = AIIADataLoader(dataset, batch_size=2, label_column='label')
# Test train loader
batch = next(iter(data_loader.train_loader))
assert isinstance(batch, list)
assert len(batch) == 2 # (images, labels)
assert batch[0].shape[0] == 1 # batch size
def test_aiia_dataset():
items = [(Image.new('RGB', (224, 224)), 0), (Image.new('RGB', (224, 224)), 1)]
dataset = AIIADataset(items)
assert len(dataset) == 2
item = dataset[0]
assert isinstance(item[0], torch.Tensor)
assert item[1] == 0
def test_aiia_dataset_pre_training():
items = [(Image.new('RGB', (224, 224)), 'denoise', Image.new('RGB', (224, 224)))]
dataset = AIIADataset(items, pretraining=True)
assert len(dataset) == 1
item = dataset[0]
assert isinstance(item[0], torch.Tensor)
assert isinstance(item[2], str)
def test_aiia_dataset_invalid_image():
items = [(Image.new('RGB', (50, 50)), 0)] # Create small image
dataset = AIIADataset(items)
with pytest.raises(ValueError, match="Invalid image dimensions"):
dataset[0]
def test_aiia_dataset_invalid_task():
items = [(Image.new('RGB', (224, 224)), 'invalid_task', Image.new('RGB', (224, 224)))]
dataset = AIIADataset(items, pretraining=True)
with pytest.raises(ValueError):
dataset[0]
def test_aiia_data_loader_invalid_column():
dataset = create_sample_dataset()
with pytest.raises(ValueError, match="Column 'invalid_column' not found"):
AIIADataLoader(dataset, column='invalid_column')
if __name__ == "__main__":
pytest.main(['-v'])