Fixed DataLoader and added test cases.
This commit is contained in:
parent
24a3a7bf56
commit
fd518ff080
|
@ -15,7 +15,7 @@ class FilePathLoader:
|
|||
self.successful_count = 0
|
||||
self.skipped_count = 0
|
||||
|
||||
if self.file_path_column not in dataset.column_names:
|
||||
if self.file_path_column not in dataset.columns:
|
||||
raise ValueError(f"Column '{self.file_path_column}' not found in dataset.")
|
||||
|
||||
def _get_image(self, item):
|
||||
|
@ -106,7 +106,11 @@ class JPGImageLoader:
|
|||
print(f"Skipped {self.skipped_count} images due to errors.")
|
||||
|
||||
class AIIADataLoader:
|
||||
def __init__(self, dataset, batch_size=32, val_split=0.2, seed=42, column="file_path", label_column=None, pretraining=False, **dataloader_kwargs):
|
||||
def __init__(self, dataset, batch_size=32, val_split=0.2, seed=42, column="file_path",
|
||||
label_column=None, pretraining=False, **dataloader_kwargs):
|
||||
if column not in dataset.columns:
|
||||
raise ValueError(f"Column '{column}' not found in dataset")
|
||||
|
||||
self.batch_size = batch_size
|
||||
self.val_split = val_split
|
||||
self.seed = seed
|
||||
|
@ -145,7 +149,6 @@ class AIIADataLoader:
|
|||
if not self.items:
|
||||
raise ValueError("No valid items were loaded from the dataset")
|
||||
|
||||
|
||||
train_indices, val_indices = self._split_data()
|
||||
|
||||
self.train_dataset = self._create_subset(train_indices)
|
||||
|
@ -192,9 +195,11 @@ class AIIADataset(torch.utils.data.Dataset):
|
|||
if not isinstance(image, Image.Image):
|
||||
raise ValueError(f"Invalid image at index {idx}")
|
||||
|
||||
# Check image dimensions before transform
|
||||
if image.size[0] < 224 or image.size[1] < 224:
|
||||
raise ValueError("Invalid image dimensions")
|
||||
|
||||
image = self.transform(image)
|
||||
if image.shape != (3, 224, 224):
|
||||
raise ValueError(f"Invalid image shape at index {idx}: {image.shape}")
|
||||
|
||||
if task == 'denoise':
|
||||
noise_std = 0.1
|
||||
|
@ -214,15 +219,20 @@ class AIIADataset(torch.utils.data.Dataset):
|
|||
image, label = item
|
||||
if not isinstance(image, Image.Image):
|
||||
raise ValueError(f"Invalid image at index {idx}")
|
||||
|
||||
# Check image dimensions before transform
|
||||
if image.size[0] < 224 or image.size[1] < 224:
|
||||
raise ValueError("Invalid image dimensions")
|
||||
|
||||
image = self.transform(image)
|
||||
if image.shape != (3, 224, 224):
|
||||
raise ValueError(f"Invalid image shape at index {idx}: {image.shape}")
|
||||
return image, label
|
||||
else:
|
||||
if isinstance(item, Image.Image):
|
||||
image = self.transform(item)
|
||||
else:
|
||||
image = self.transform(item[0])
|
||||
if image.shape != (3, 224, 224):
|
||||
raise ValueError(f"Invalid image shape at index {idx}: {image.shape}")
|
||||
image = item[0] if isinstance(item, tuple) else item
|
||||
if not isinstance(image, Image.Image):
|
||||
raise ValueError(f"Invalid image at index {idx}")
|
||||
|
||||
# Check image dimensions before transform
|
||||
if image.size[0] < 224 or image.size[1] < 224:
|
||||
raise ValueError("Invalid image dimensions")
|
||||
image = self.transform(image)
|
||||
return image
|
||||
|
|
|
@ -0,0 +1,112 @@
|
|||
import pytest
|
||||
from PIL import Image
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision import transforms
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
from aiia.data.DataLoader import FilePathLoader, JPGImageLoader, AIIADataLoader, AIIADataset
|
||||
|
||||
def create_sample_dataset(file_paths=None):
    """Build a tiny DataFrame with 'file_path' and 'label' columns for loader tests."""
    # Fall back to two dummy paths when the caller supplies none.
    paths = ['path/to/image1.jpg', 'path/to/image2.png'] if file_paths is None else file_paths
    # Every row gets label 0 so the label column always matches the path count.
    return pd.DataFrame({'file_path': paths, 'label': [0] * len(paths)})
|
||||
|
||||
def create_sample_bytes_dataset(bytes_data=None):
    """Build a tiny DataFrame with raw 'jpg' bytes and a matching 'label' column."""
    # Fall back to two fake byte payloads when the caller supplies none.
    payloads = [b'fake_image_data_1', b'fake_image_data_2'] if bytes_data is None else bytes_data
    # One zero label per byte blob keeps the two columns the same length.
    return pd.DataFrame({'jpg': payloads, 'label': [0] * len(payloads)})
|
||||
|
||||
def test_file_path_loader(mocker):
    """FilePathLoader should yield a (PIL image, label) pair and print its summary."""
    # Patch PIL so no real file is touched; every path "opens" to the same image.
    fake_image = Image.new('RGB', (224, 224))
    mocker.patch('PIL.Image.open', return_value=fake_image)

    loader = FilePathLoader(create_sample_dataset(), label_column='label')
    entry = loader.get_item(0)
    assert isinstance(entry[0], Image.Image)
    assert entry[1] == 0

    # Summary printing must not raise.
    loader.print_summary()
|
||||
|
||||
def test_jpg_image_loader(mocker):
    """JPGImageLoader should decode byte rows into (PIL image, label) pairs."""
    # Patch PIL so the fake byte payloads "decode" to a real in-memory image.
    fake_image = Image.new('RGB', (224, 224))
    mocker.patch('PIL.Image.open', return_value=fake_image)

    loader = JPGImageLoader(create_sample_bytes_dataset(), label_column='label')
    entry = loader.get_item(0)
    assert isinstance(entry[0], Image.Image)
    assert entry[1] == 0

    # Summary printing must not raise.
    loader.print_summary()
|
||||
|
||||
def test_aiia_data_loader(mocker):
    """AIIADataLoader should split the dataset and serve (images, labels) batches."""
    # Avoid disk access: every path "opens" to the same in-memory image.
    mocker.patch('PIL.Image.open', return_value=Image.new('RGB', (224, 224)))

    loader = AIIADataLoader(create_sample_dataset(), batch_size=2, label_column='label')

    # One batch from the training loader is enough to validate its structure.
    first_batch = next(iter(loader.train_loader))
    assert isinstance(first_batch, list)
    # A supervised batch carries exactly two entries: images and labels.
    assert len(first_batch) == 2
    # Two rows with a 0.2 validation split leave a single training sample.
    assert first_batch[0].shape[0] == 1
|
||||
|
||||
def test_aiia_dataset():
    """AIIADataset should report its length and return (tensor, label) items."""
    pairs = [(Image.new('RGB', (224, 224)), 0), (Image.new('RGB', (224, 224)), 1)]
    ds = AIIADataset(pairs)

    assert len(ds) == 2

    sample = ds[0]
    # The image comes back transformed into a tensor; the label is untouched.
    assert isinstance(sample[0], torch.Tensor)
    assert sample[1] == 0
|
||||
|
||||
def test_aiia_dataset_pre_training():
    """In pretraining mode the dataset consumes (image, task, target) triples."""
    triples = [(Image.new('RGB', (224, 224)), 'denoise', Image.new('RGB', (224, 224)))]
    ds = AIIADataset(triples, pretraining=True)

    assert len(ds) == 1

    sample = ds[0]
    assert isinstance(sample[0], torch.Tensor)
    # The task identifier travels with the sample as a plain string.
    assert isinstance(sample[2], str)
|
||||
|
||||
def test_aiia_dataset_invalid_image():
    """Images smaller than 224x224 must be rejected when the item is accessed."""
    undersized = [(Image.new('RGB', (50, 50)), 0)]
    ds = AIIADataset(undersized)

    with pytest.raises(ValueError, match="Invalid image dimensions"):
        ds[0]
|
||||
|
||||
def test_aiia_dataset_invalid_task():
    """An unrecognized pretraining task name must raise ValueError on access."""
    ds = AIIADataset(
        [(Image.new('RGB', (224, 224)), 'invalid_task', Image.new('RGB', (224, 224)))],
        pretraining=True,
    )

    with pytest.raises(ValueError):
        ds[0]
|
||||
|
||||
def test_aiia_data_loader_invalid_column():
    """Constructing the loader with an unknown column name must fail loudly."""
    frame = create_sample_dataset()
    with pytest.raises(ValueError, match="Column 'invalid_column' not found"):
        AIIADataLoader(frame, column='invalid_column')
|
||||
|
||||
# Allow running this test module directly; delegates discovery to pytest.
if __name__ == "__main__":
    pytest.main(['-v'])
|
Loading…
Reference in New Issue